Commit 87e3e56e authored by yuguo

Merge commit '734bcedd' of...

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
......@@ -42,7 +42,7 @@ class AllReduce(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
......
......@@ -22,9 +22,7 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager, Recipe
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
......@@ -291,10 +289,19 @@ class BasicLinear(BasicOperation):
# Quantize if needed
if self._with_quantized_weight:
quantizer = self.get_quantizer("forward", 1)
if quantizer is None:
raise RuntimeError(
"Tried to quantize weight with deferred initialization "
"due to meta device, but no quantizer was available. "
"This is most likely because the weight was initialized "
"within fp8_model_init, but the forward pass was not "
"performed within fp8_autocast."
)
quantizer.set_usage(
rowwise=True,
columnwise=torch.is_grad_enabled(),
)
quantizer.internal = False
with torch.no_grad():
weight = quantizer(weight)
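For context, a hedged sketch of the usage pattern the new error message points to (the te.ops module path and constructor arguments are assumptions based on this diff): a weight created deferred under fp8_model_init only gets a quantizer once the forward pass runs under fp8_autocast.

```python
# Hedged sketch of the pattern the error message above describes.
# Module paths and constructor arguments are assumptions based on this diff.
import torch
import transformer_engine.pytorch as te

with te.fp8_model_init(enabled=True):
    # Weight may be created on a meta device and quantized lazily.
    op = te.ops.BasicLinear(1024, 1024, device="cuda")

x = torch.randn(16, 1024, device="cuda")

# Running the forward pass inside fp8_autocast ensures a weight quantizer
# exists when the deferred weight is materialized and quantized.
with te.fp8_autocast(enabled=True):
    y = op(x)
```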
......@@ -303,72 +310,52 @@ class BasicLinear(BasicOperation):
weight = torch.nn.Parameter(weight)
self.weight = weight
def pre_first_forward(
self,
*,
recipe: Optional[Recipe],
) -> None:
super().pre_first_forward(recipe=recipe)
# Initialize weights if needed
weight = self.weight
if weight.device.type == "meta":
def pre_first_fuser_forward(self) -> None:
super().pre_first_fuser_forward()
if self.weight.device.type == "meta":
self.reset_parameters()
weight = self.weight
# Configure quantizers
if recipe is not None:
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
grad_output_quantizer = self.get_quantizer("backward", 0)
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe)
# Specify required tensor formats
# Input/grad output quantizers use internal tensors
input_quantizer = self.get_quantizer("forward", 0)
grad_output_quantizer = self.get_quantizer("backward", 0)
if input_quantizer is not None:
input_quantizer.internal = True
weight_quantizer.internal = True
if grad_output_quantizer is not None:
grad_output_quantizer.internal = True
# Recipe-specific configuration
if recipe.float8_current_scaling():
if any(
not isinstance(q, Float8CurrentScalingQuantizer)
for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
):
raise RuntimeError(
"FP8 current-scaling recipe is enabled, "
f"but input quantizer is {input_quantizer.__class__.__name__}, "
f"weight quantizer is {weight_quantizer.__class__.__name__}, "
f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
)
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
# recipe changed
if isinstance(
weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
) and isinstance(weight, Float8TensorBase):
weight._quantizer = weight_quantizer
# Handle weight quantizer
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
weight_quantizer = self.get_quantizer("forward", 1)
if weight_quantizer is None:
pass
elif is_quantized_tensor(getattr(self, "weight", None)):
# Make sure weight param has correct quantizer
weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
weight_quantizer.internal = False
self.weight.update_quantizer(weight_quantizer.copy())
else:
# Use internal tensors if quantized weights will not be
# exposed externally
weight_quantizer.internal = (
not FP8GlobalStateManager.with_fp8_parameters()
and not getattr(self, "_with_quantized_weight", False)
)
@staticmethod
def _functional_forward(
input: torch.Tensor, # pylint: disable=redefined-builtin
weight: torch.Tensor,
*,
alpha: float = 1.0,
bias: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None, # pylint: disable=unused-argument
dtype: Optional[torch.dtype] = None,
out: Optional[torch.Tensor] = None,
beta: Optional[float] = None,
accumulate_into_out: bool = False,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
......@@ -388,6 +375,8 @@ class BasicLinear(BasicOperation):
Input tensor
weight: torch.Tensor
Weight tensor
alpha: float, default = 1.0
Scaling factor applied to the result of the GEMM
bias: torch.Tensor, optional
Bias tensor
device: torch.device, default = default CUDA device
......@@ -396,6 +385,8 @@ class BasicLinear(BasicOperation):
Tensor datatype
out: torch.Tensor, optional
Output tensor
beta: float, optional
Scaling factor applied to original value of out when accumulating into it
accumulate_into_out: bool, default = `False`
Add result to output tensor instead of overwriting
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
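The new alpha/beta arguments follow standard GEMM conventions. A minimal plain-PyTorch sketch of the intended semantics, inferred from the docstrings above (reference only, not the fused cuBLAS path):

```python
# Reference-only sketch of the alpha/beta semantics for the forward GEMM
# (plain torch, ignoring quantization and tensor parallelism).
import torch

def reference_forward(x, w, *, alpha=1.0, bias=None, out=None,
                      beta=None, accumulate_into_out=False):
    y = alpha * (x @ w.transpose(-1, -2))   # alpha scales the GEMM result
    if bias is not None:
        y = y + bias
    if accumulate_into_out and out is not None:
        # beta scales the original value of out before accumulation
        return (1.0 if beta is None else beta) * out + y
    return y
```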
......@@ -441,7 +432,7 @@ class BasicLinear(BasicOperation):
if dtype is None:
if out is not None and isinstance(out, torch.Tensor):
dtype = out.dtype
elif weight is not None and isinstance(out, torch.Tensor):
elif weight is not None and isinstance(weight, torch.Tensor):
dtype = weight.dtype
else:
raise ValueError(
......@@ -516,18 +507,11 @@ class BasicLinear(BasicOperation):
raise ValueError("Output tensor is quantized, but quantizer was not provided")
else:
output_quantizer = None
if isinstance(output_quantizer, MXFP8Quantizer):
raise RuntimeError(
"Attempting to generate MXFP8 output tensor, "
"but GEMM with MXFP8 output is not supported"
)
if isinstance(output_quantizer, Float8BlockQuantizer):
raise RuntimeError(
"Attempting to generate Float8BlockQuantized output tensor, "
"but GEMM with Float8BlockQuantized output is not supported"
)
if output_quantizer is not None:
if not isinstance(output_quantizer, Float8Quantizer):
raise RuntimeError(
"Attempting to generate quantized output tensor with unsupported quantizer"
)
output_quantizer.set_usage(rowwise=True, columnwise=False)
# Check if accumulating into output tensor
......@@ -552,6 +536,8 @@ class BasicLinear(BasicOperation):
get_workspace(),
out_dtype=dtype,
quantization_params=output_quantizer,
alpha=alpha,
beta=beta,
accumulate=accumulate_into_out,
out=y,
bias=bias,
......@@ -589,13 +575,17 @@ class BasicLinear(BasicOperation):
input: Optional[torch.Tensor], # pylint: disable=redefined-builtin
weight: Optional[torch.Tensor],
*,
grad_input_alpha: Optional[float] = None,
input_requires_grad: bool = True,
grad_weight_alpha: Optional[float] = None,
weight_requires_grad: bool = True,
device: Optional[torch.device] = None, # pylint: disable=unused-argument
dtype: Optional[torch.dtype] = None,
grad_weight: Optional[torch.Tensor] = None,
grad_weight_beta: Optional[float] = None,
accumulate_into_grad_weight: bool = False,
grad_input: Optional[torch.Tensor] = None,
grad_input_beta: Optional[float] = None,
accumulate_into_grad_input: bool = False,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
......@@ -618,8 +608,12 @@ class BasicLinear(BasicOperation):
weight: torch.Tensor, optional
Weight tensor. Required to compute loss gradient w.r.t.
input.
grad_input_alpha: float, optional
Scaling factor applied to the result of the dgrad GEMM
input_requires_grad: bool
Whether to compute loss gradient w.r.t. input tensor
grad_weight_alpha: float, optional
Scaling factor applied to the result of the wgrad GEMM
weight_requires_grad: bool
Whether to compute loss gradient w.r.t. weight tensor
device: torch.device, default = default CUDA device
......@@ -628,10 +622,14 @@ class BasicLinear(BasicOperation):
Tensor datatype
grad_weight: torch.Tensor, optional
Loss gradient w.r.t. weight tensor
grad_weight_beta: float, optional
Scaling factor applied to original value of grad_weight when accumulating into it
accumulate_into_grad_weight: bool, default = `False`
Add result to weight grad instead of overwriting
grad_input: torch.Tensor, optional
Loss gradient w.r.t. input tensor
grad_input_beta: float, optional
Scaling factor applied to original value of grad_input when accumulating into it
accumulate_into_grad_input: bool, default = `False`
Add result to input grad instead of overwriting
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
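Similarly, a hedged plain-PyTorch sketch of how the new grad-side alpha/beta parameters combine with the dgrad and wgrad GEMMs (reference only):

```python
# Reference-only sketch of the grad alpha/beta semantics for the
# dgrad/wgrad GEMMs (plain torch, ignoring quantization and parallelism).
import torch

def reference_backward(dy, x, w, *,
                       grad_input_alpha=1.0, grad_weight_alpha=1.0,
                       grad_input=None, grad_input_beta=None,
                       grad_weight=None, grad_weight_beta=None):
    dx = grad_input_alpha * (dy @ w)                     # dgrad: dL/dx
    dw = grad_weight_alpha * (dy.transpose(-1, -2) @ x)  # wgrad: dL/dw
    if grad_input is not None:   # accumulate into an existing grad buffer
        dx = (1.0 if grad_input_beta is None else grad_input_beta) * grad_input + dx
    if grad_weight is not None:
        dw = (1.0 if grad_weight_beta is None else grad_weight_beta) * grad_weight + dw
    return dx, dw
```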
......@@ -801,11 +799,12 @@ class BasicLinear(BasicOperation):
)
else:
grad_input_quantizer = None
if isinstance(grad_input_quantizer, MXFP8Quantizer):
raise RuntimeError(
"Attempting to generate MXFP8 grad input tensor, "
"but GEMM with MXFP8 output is not supported"
)
if grad_input_quantizer is not None:
if not isinstance(grad_input_quantizer, Float8Quantizer):
raise RuntimeError(
"Attempting to generate quantized grad input tensor "
"with unsupported quantizer"
)
# Check if accumulating into grad input tensor
if accumulate_into_grad_input:
......@@ -827,6 +826,8 @@ class BasicLinear(BasicOperation):
get_workspace(),
out_dtype=dtype,
quantization_params=grad_input_quantizer,
alpha=grad_input_alpha,
beta=grad_input_beta,
accumulate=accumulate_into_grad_input,
layout="NN",
out=dx,
......@@ -877,6 +878,8 @@ class BasicLinear(BasicOperation):
dy,
get_workspace(),
out_dtype=dw_dtype,
alpha=grad_weight_alpha,
beta=grad_weight_beta,
accumulate=accumulate_into_grad_weight,
layout="NT",
out=dw,
......@@ -894,7 +897,7 @@ class BasicLinear(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
......@@ -903,27 +906,34 @@ class BasicLinear(BasicOperation):
weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
# FP8 metadata
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
output_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
# Get quantizers
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Configure quantizers
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Get autocast dtype if needed
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
......@@ -947,15 +957,16 @@ class BasicLinear(BasicOperation):
)
# Save state for backward pass
ctx.save_for_backward(x_local, w)
ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizer = input_quantizer
ctx.weight_quantizer = weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
ctx.dtype = dtype
ctx.input_requires_grad = input_requires_grad
ctx.weight_requires_grad = weight_requires_grad
if ctx.requires_grad:
ctx.save_for_backward(x_local, w)
ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizer = input_quantizer
ctx.weight_quantizer = weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
ctx.dtype = dtype
ctx.input_requires_grad = input_requires_grad
ctx.weight_requires_grad = weight_requires_grad
return output
......
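One practical effect of gating the saved state on ctx.requires_grad is that inference runs no longer keep activations or quantizer state alive. A hedged sketch, assuming the fuser marks ctx.requires_grad as False when autograd is disabled (the te.ops module path is an assumption based on this diff):

```python
# Hedged sketch: under torch.no_grad() the op context presumably reports
# requires_grad=False, so ctx.save_for_backward is skipped entirely.
# Module path and constructor arguments are assumptions based on this diff.
import torch
import transformer_engine.pytorch as te

op = te.ops.BasicLinear(1024, 1024, device="cuda")
x = torch.randn(16, 1024, device="cuda")

with torch.no_grad():
    y = op(x)   # nothing is stashed for a backward pass
```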
......@@ -10,15 +10,8 @@ from typing import Optional
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...utils import (
canonicalize_device,
canonicalize_dtype,
)
from ...fp8 import FP8GlobalStateManager
from ..op import BasicOperation, OperationContext
from ...utils import canonicalize_device, canonicalize_dtype
from ...tensor import Quantizer
......@@ -114,8 +107,8 @@ class Bias(BasicOperation):
bias = torch.nn.Parameter(bias)
self.bias = bias
def pre_first_forward(self, *args, **kwargs) -> None:
super().pre_first_forward(*args, **kwargs)
def pre_first_fuser_forward(self) -> None:
super().pre_first_fuser_forward()
if self.bias.device.type == "meta":
self.reset_parameters()
......@@ -123,24 +116,14 @@ class Bias(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
x = input_
b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if previous op quantizes its output's gradient
grad_input_quantizer = None
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
grad_input_quantizer = prev_op_grad_input_quantizer
if requires_grad:
ctx.with_quantized_compute = with_quantized_compute
ctx.grad_input_quantizer = grad_input_quantizer
if ctx.requires_grad:
ctx.grad_input_quantizer = prev_op_grad_output_quantizer
return x + b
......@@ -152,10 +135,10 @@ class Bias(BasicOperation):
dy = grad_output
if dy.dim() > 1:
quantizer = ctx.grad_input_quantizer
if ctx.with_quantized_compute and quantizer is not None:
db, dy = tex.bgrad_quantize(dy, quantizer)
else:
if quantizer is None:
db = dy.sum(tuple(range(dy.dim() - 1)))
else:
db, dy = tex.bgrad_quantize(dy, quantizer)
else:
db = dy
return dy, (db,)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for constant scaling."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer
class ConstantScale(BasicOperation):
"""Multiply by a constant"""
def __init__(self, scale: float) -> None:
super().__init__()
self.scale = scale
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
return input_ * self.scale
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
return grad_output * self.scale, ()
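A minimal hedged usage sketch of the new ConstantScale op inside an operation-fuser Sequential (module paths assumed from this diff); placed next to a BasicLinear it becomes eligible for the new scale fusions added elsewhere in this commit:

```python
# Hedged usage sketch (te.ops module paths are assumptions based on this
# diff): ConstantScale composed after a linear op, so the op fuser can
# apply the new scale fusions when the surrounding ops line up.
import torch
import transformer_engine.pytorch as te

model = te.ops.Sequential(
    te.ops.BasicLinear(1024, 1024, device="cuda"),
    te.ops.ConstantScale(0.5),
)
x = torch.randn(16, 1024, device="cuda", requires_grad=True)
y = model(x)
y.sum().backward()   # backward scales grad_output by the same constant
```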
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for dropout."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer
class Dropout(BasicOperation):
"""Randomly zero out tensor entries during training
During training, tensor entries are randomly set to zero with
probability :math:`p` and remaining entries are scaled by
:math:`1/(1-p)`.
"""
def __init__(self, p: float) -> None:
super().__init__()
self.dropout_probability = p
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Compute dropout if training
out = input_
is_training = self.training
mask = None
if is_training:
keep_prob = 1 - self.dropout_probability
mask = torch.empty_like(input_)
mask.bernoulli_(keep_prob)
mask *= 1 / keep_prob
out = out * mask
# Save context for backward
if ctx.requires_grad:
ctx.save_for_backward(mask)
ctx.is_training = is_training
return out
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
(mask,) = ctx.saved_tensors
grad_input = grad_output
if ctx.is_training:
grad_input = grad_input * mask
return grad_input, ()
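A quick plain-PyTorch check of the train-time scaling described in the docstring: kept entries are scaled by 1/(1-p), so the expected activation matches the untouched inference output.

```python
# Plain-torch check of the 1/(1-p) scaling used above.
import torch

p = 0.25
keep_prob = 1 - p
x = torch.ones(1_000_000)
mask = torch.empty_like(x).bernoulli_(keep_prob) / keep_prob
print((x * mask).mean())   # ~1.0, matching the inference-time output
```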
......@@ -23,7 +23,7 @@ class Identity(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
return input_
......
......@@ -6,10 +6,12 @@
from __future__ import annotations
from typing import Optional
import os
import torch
from ...utils import clear_tensor_data
from ... import torch_version
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...jit import (
......@@ -60,7 +62,11 @@ class L2Normalization(BasicOperation):
# JIT warmup for L2Normalization fused operations
if seq_length and micro_batch_size:
if torch.cuda.is_available():
if (
torch.cuda.is_available()
and torch_version() >= (2, 0, 0)
and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1")))
):
set_jit_fusion_options()
# For L2Normalization, we don't know the hidden size until forward pass,
# but we can warm up with common sizes. For QK normalization, this will be
......@@ -74,7 +80,7 @@ class L2Normalization(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Use input directly - torch.compile can handle multi-dimensional tensors
......@@ -86,7 +92,7 @@ class L2Normalization(BasicOperation):
# Compute L2 normalization using fused implementation
# L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps)
if requires_grad:
# Training: use version that returns both output and intermediate values
# Training: use version that returns output and intermediate values for backward pass
y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps)
else:
# Inference: use lightweight version that only returns output
......@@ -110,7 +116,7 @@ class L2Normalization(BasicOperation):
dy = maybe_dequantize(grad_output)
# Compute L2 norm backward pass using fused implementation
# Compute L2 norm backward pass using fused implementation - recalculates l2_norm_squared_eps
dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)
# Clear saved tensors if possible
......
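For reference, a hedged plain-PyTorch version of the L2 normalization computed by the fused kernels above, assuming normalization over the innermost dimension:

```python
# Plain-PyTorch reference for the fused L2 normalization above
# (reference only; the real path uses a torch.compile-fused kernel).
import torch

def l2norm_reference(x: torch.Tensor, eps: float) -> torch.Tensor:
    rsqrt_norm = torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)
    return x * rsqrt_norm
```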
......@@ -13,7 +13,6 @@ from typing import Optional
import torch
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...constants import TE_DType
from ...utils import (
canonicalize_device,
......@@ -168,8 +167,8 @@ class LayerNorm(BasicOperation):
self.weight = weight
self.bias = bias
def pre_first_forward(self, *args, **kwargs) -> None:
super().pre_first_forward(*args, **kwargs)
def pre_first_fuser_forward(self) -> None:
super().pre_first_fuser_forward()
if self.weight.device.type == "meta" or self.bias.device.type == "meta":
self.reset_parameters()
......@@ -177,7 +176,7 @@ class LayerNorm(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
if is_in_onnx_export_mode():
......@@ -200,31 +199,22 @@ class LayerNorm(BasicOperation):
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
b = maybe_dequantize(self.bias, dtype).view((inner_dim,))
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if output is quantized
output_quantizer = None
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
output_quantizer = next_op_input_quantizer
# Compute layer norm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
y, means, rstdevs = layernorm_fwd(
x,
w,
b,
self.eps,
None,
output_quantizer,
next_op_input_quantizer,
TE_DType[dtype],
sm_margin,
self.zero_centered_gamma,
)
# Save state for backward pass
if requires_grad:
if ctx.requires_grad:
ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype
......
......@@ -22,14 +22,20 @@ class MakeExtraOutput(BasicOperation):
If this operation is included in the operation fuser, then the
operation fuser will return the intermediate tensor as an extra
tensor output. In the backward pass, the gradient is directly
accumulated into the gradient w.r.t. the extra output.
tensor output.
This operation is considered an advanced feature and most users
are discouraged from using it. In-place operations break some
autograd assumptions and they can result in subtle, esoteric bugs.
In the backward pass, the gradient may be directly
accumulated into the gradient w.r.t. the extra output. This is
controlled by the in_place kwarg. Currently, the BackwardLinearAdd
fusion is only applied when in_place=True.
Compare to `AddInPlace`, which does a similar operation in the
Using this operation with in_place=True is
considered an advanced feature. Most users are discouraged
from enabling in-place gradient accumulation, since in-place
operations break some autograd assumptions and can result
in subtle, esoteric bugs.
Compare to `AddExtraInput`, which does a similar operation in the
backward pass.
"""
......@@ -37,6 +43,10 @@ class MakeExtraOutput(BasicOperation):
# Operation expects buffer for output tensor
num_extra_outputs: int = 1
def __init__(self, *, in_place: bool = False):
super().__init__()
self._in_place: bool = in_place
def op_forward(self, *args, **kwargs) -> None:
raise RuntimeError(
"{self.__class__.__name__} operation has "
......@@ -59,7 +69,7 @@ class MakeExtraOutput(BasicOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -76,6 +86,10 @@ class MakeExtraOutput(BasicOperation):
Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]],
]:
grad_input = basic_op_grad_extra_outputs[0][0]
grad_input += grad_output
grad_extra_output = basic_op_grad_extra_outputs[0][0]
if self._in_place:
grad_extra_output += grad_output
grad_input = grad_extra_output
else:
grad_input = grad_extra_output + grad_output
return grad_input, [()], [()]
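A small plain-torch sketch of the two backward paths above, showing what the in_place flag changes:

```python
# Plain-torch sketch of the two backward paths above (no TE types).
import torch

def make_extra_output_backward(grad_output, grad_extra_output, *, in_place):
    if in_place:
        grad_extra_output += grad_output       # accumulate into the shared buffer
        return grad_extra_output               # grad_input aliases that buffer
    return grad_extra_output + grad_output     # out-of-place: fresh grad_input
```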
......@@ -50,7 +50,7 @@ class Quantize(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
......@@ -64,7 +64,8 @@ class Quantize(BasicOperation):
if quantize_forward and not is_quantized_tensor(out):
out = self.get_quantizer("forward", 0)(out)
ctx.quantize_backward = quantize_backward
if ctx.requires_grad:
ctx.quantize_backward = quantize_backward
return out
def op_backward(
......
......@@ -40,7 +40,7 @@ class ReduceScatter(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
......
......@@ -38,10 +38,11 @@ class Reshape(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
ctx.input_shape = input_.size()
if ctx.requires_grad:
ctx.input_shape = input_.size()
return input_.reshape(*self._shape)
def op_backward(
......
......@@ -13,7 +13,6 @@ from typing import Optional
import torch
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...constants import TE_DType
from ...utils import (
canonicalize_device,
......@@ -151,8 +150,8 @@ class RMSNorm(BasicOperation):
weight = torch.nn.Parameter(weight)
self.weight = weight
def pre_first_forward(self, *args, **kwargs) -> None:
super().pre_first_forward(*args, **kwargs)
def pre_first_fuser_forward(self) -> None:
super().pre_first_fuser_forward()
if self.weight.device.type == "meta":
self.reset_parameters()
......@@ -160,7 +159,7 @@ class RMSNorm(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
if is_in_onnx_export_mode():
......@@ -182,30 +181,21 @@ class RMSNorm(BasicOperation):
x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if output is quantized
output_quantizer = None
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
output_quantizer = next_op_input_quantizer
# Compute RMSNorm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
y, _, rstdevs = rmsnorm_fwd(
x,
w,
self.eps,
None,
output_quantizer,
next_op_input_quantizer,
TE_DType[dtype],
sm_margin,
self.zero_centered_gamma,
)
# Save state for backward pass
if requires_grad:
if ctx.requires_grad:
ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype
......
......@@ -4,14 +4,18 @@
"""Compound tensor operation supported by the operation fuser."""
from .backward_bias_activation import (
BackwardBiasActivation,
fuse_backward_bias_activation,
from .backward_activation_bias import (
BackwardActivationBias,
fuse_backward_activation_bias,
)
from .backward_linear_add import (
BackwardLinearAdd,
fuse_backward_linear_add,
)
from .backward_linear_scale import (
BackwardLinearScale,
fuse_backward_linear_scale,
)
from .forward_linear_bias_activation import (
ForwardLinearBiasActivation,
fuse_forward_linear_bias_activation,
......@@ -20,6 +24,10 @@ from .forward_linear_bias_add import (
ForwardLinearBiasAdd,
fuse_forward_linear_bias_add,
)
from .forward_linear_scale_add import (
ForwardLinearScaleAdd,
fuse_forward_linear_scale_add,
)
from .userbuffers_backward_linear import (
UserbuffersBackwardLinear,
fuse_userbuffers_backward_linear,
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Fused backward dbias + dact + quantize."""
"""Fused backward dact + dbias + quantize."""
from __future__ import annotations
from typing import Optional
......@@ -29,8 +29,8 @@ _fused_activations = {GELU: tex.dbias_dgelu, ReLU: tex.dbias_drelu}
_fusible_activations = tuple(_fused_activations.keys())
class BackwardBiasActivation(FusedOperation):
"""Fused backward dbias + dact + quantize
class BackwardActivationBias(FusedOperation):
"""Fused backward dact + dbias + quantize
Uses the next operation's input quantizer.
......@@ -66,15 +66,10 @@ class BackwardBiasActivation(FusedOperation):
dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype)
# Get previous op quantizer
if not bias_op_ctx.with_quantized_compute:
raise RuntimeError(
"BackwardBiasActivation requires quantized compute, "
"but Bias context has it disabled"
)
quantizer = bias_op_ctx.grad_input_quantizer
if quantizer is None:
raise RuntimeError(
"BackwardBiasActivation requires previous op's grad output quantizer, "
"BackwardActivationBias requires previous op's grad output quantizer, "
"but Bias context has no quantizer"
)
......@@ -87,11 +82,11 @@ class BackwardBiasActivation(FusedOperation):
return dx, [(), (db,)], [(), ()]
def fuse_backward_bias_activation(
def fuse_backward_activation_bias(
ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dbias + dact + quantize
"""Fused backward dact + dbias + quantize
Parameters
----------
......@@ -109,7 +104,7 @@ def fuse_backward_bias_activation(
"""
# Check if recipe supports bias activation fusion
if recipe is None or not (recipe.delayed() or recipe.mxfp8()):
if recipe is None:
return ops
# Scan through ops, fusing if possible
......@@ -138,7 +133,7 @@ def fuse_backward_bias_activation(
ops = ops[1:]
# Replace window with fused op
op = BackwardBiasActivation(
op = BackwardActivationBias(
activation=window[0][0],
bias=window[1][0],
)
......
......@@ -29,10 +29,10 @@ class BackwardLinearAdd(FusedOperation):
def __init__(
self,
*,
linear: BasicLinear,
backward_add: MakeExtraOutput,
linear: BasicLinear,
) -> None:
super().__init__((linear, backward_add))
super().__init__((backward_add, linear))
def fuser_backward(
self,
......@@ -47,7 +47,7 @@ class BackwardLinearAdd(FusedOperation):
]:
# Get basic operations
linear_op = self.basic_ops[0]
linear_op = self.basic_ops[1]
linear_op_ctx = basic_op_ctxs[0]
# Saved tensors from forward pass
......@@ -139,6 +139,8 @@ def fuse_backward_linear_add(
op, _ = ops[0]
if not isinstance(op, MakeExtraOutput):
continue
if not op._in_place:
continue
window.extend(ops[:1])
ops = ops[1:]
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused backward dgrad GEMM + scale."""
from __future__ import annotations
from typing import Optional
import torch
from ..basic import BasicLinear, ConstantScale
from ..op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data
class BackwardLinearScale(FusedOperation):
"""Fused backward dgrad GEMM + scale
Column tensor parallelism is not supported since that requires
communication immediately after the dgrad GEMM.
"""
def __init__(
self,
*,
scale: ConstantScale,
linear: BasicLinear,
) -> None:
super().__init__((linear, scale))
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
list[tuple[Optional[torch.Tensor], ...]],
list[tuple[()]],
]:
# Get basic operations
linear_op = self.basic_ops[0]
linear_op_ctx = basic_op_ctxs[1]
scale_op = self.basic_ops[1]
# Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"):
linear_op.weight.main_grad = linear_op.weight.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
grad_weight = linear_op.weight.main_grad.detach()
else:
accumulate_into_main_grad = False
# Linear backward pass
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=w,
input_requires_grad=linear_op_ctx.input_requires_grad,
grad_input_alpha=scale_op.scale,
weight_requires_grad=linear_op_ctx.weight_requires_grad,
grad_weight_alpha=scale_op.scale,
dtype=linear_op_ctx.dtype,
grad_weight=grad_weight,
accumulate_into_grad_weight=accumulate_into_main_grad,
tensor_parallel_mode=linear_op.tensor_parallel_mode,
tensor_parallel_group=linear_op.tensor_parallel_group,
sequence_parallel=linear_op.sequence_parallel,
with_quantized_compute=linear_op_ctx.with_quantized_compute,
input_quantizer=linear_op_ctx.input_quantizer,
weight_quantizer=linear_op_ctx.weight_quantizer,
grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
)
if accumulate_into_main_grad:
grad_weight = None
# Clear input tensor if possible
clear_tensor_data(x_local)
return grad_input, [(), (grad_weight,)], [(), ()]
def fuse_backward_linear_scale(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dgrad GEMM + constant scale
Parameters
----------
ops: list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
out.extend(window)
# Check if first op is constant scale
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, ConstantScale):
continue
# Check if second op is linear
op, _ = ops[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication after the dgrad GEMM
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardLinearScale(
scale=window[0][0],
linear=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
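To illustrate the rewrite this pass performs, a hedged sketch that calls the fusion function directly on a two-op backward list (import path and constructor arguments are assumptions based on this diff; the index lists are hypothetical):

```python
# Hedged illustration: an adjacent (ConstantScale, BasicLinear) pair in the
# backward op list is collapsed into one BackwardLinearScale, folding the
# constant into the dgrad/wgrad GEMM alpha.
import transformer_engine.pytorch as te
from transformer_engine.pytorch.ops.fused import fuse_backward_linear_scale

scale_op = te.ops.ConstantScale(2.0)
linear_op = te.ops.BasicLinear(1024, 1024, device="cuda")
ops = [(scale_op, [1]), (linear_op, [0])]   # backward order: scale, then linear
fused = fuse_backward_linear_scale(ops)
print(fused)   # expect a single (BackwardLinearScale, [1, 0]) entry
```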
......@@ -59,7 +59,7 @@ class ForwardLinearBiasActivation(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -89,18 +89,12 @@ class ForwardLinearBiasActivation(FusedOperation):
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
output_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed
if torch.is_autocast_enabled():
......@@ -126,18 +120,18 @@ class ForwardLinearBiasActivation(FusedOperation):
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
if linear_op_ctx.requires_grad:
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None and bias_op_ctx.requires_grad:
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -11,7 +11,7 @@ from typing import Any, Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias
from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
......@@ -33,7 +33,7 @@ class ForwardLinearBiasAdd(FusedOperation):
*,
linear: BasicLinear,
bias: Optional[Bias],
add: AddInPlace,
add: AddExtraInput,
) -> None:
# Basic operations that comprise this fused operation
......@@ -57,7 +57,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -83,17 +83,12 @@ class ForwardLinearBiasAdd(FusedOperation):
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
# Get autocast dtype if needed
if torch.is_autocast_enabled():
......@@ -122,18 +117,18 @@ class ForwardLinearBiasAdd(FusedOperation):
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
if linear_op_ctx.requires_grad:
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None and bias_op_ctx.requires_grad:
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()
return output, [() for _ in range(len(self.basic_ops))]
......@@ -184,8 +179,10 @@ def fuse_forward_linear_bias_add(
continue
op, _ = ops[0]
# Check if next op is add in-place
if not isinstance(op, AddInPlace):
# Check if next op is in-place add extra input
if not isinstance(op, AddExtraInput):
continue
if not op._in_place:
continue
add = op
window.extend(ops[:1])
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused operation for forward GEMM + scale + add."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
from ...fp8 import FP8GlobalStateManager
from ..basic import AddExtraInput, BasicLinear, ConstantScale
from ..op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...tensor import Quantizer
class ForwardLinearScaleAdd(FusedOperation):
"""Fused forward GEMM + scale + add
Row tensor parallelism is not supported since that requires
communication immediately after the GEMM.
"""
def __init__(
self,
*,
linear: BasicLinear,
scale: ConstantScale,
add: AddExtraInput,
) -> None:
super().__init__((linear, scale, add))
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
# Get basic operations
linear_op = self.basic_ops[0]
linear_op_ctx = basic_op_ctxs[0]
scale_op = self.basic_ops[1]
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = None
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
# Get extra input tensor for add operation
extra_input = basic_op_extra_inputs[2][0]
# Get autocast dtype if needed
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = linear_op.weight.dtype
# Linear forward
output, x_local, w = BasicLinear._functional_forward(
input=input_,
weight=linear_op.weight,
alpha=scale_op.scale,
dtype=dtype,
out=extra_input,
accumulate_into_out=True,
tensor_parallel_mode=linear_op.tensor_parallel_mode,
tensor_parallel_group=linear_op.tensor_parallel_group,
sequence_parallel=linear_op.sequence_parallel,
with_quantized_compute=with_quantized_compute,
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
)
# Save state for backward pass
if linear_op_ctx.requires_grad:
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
return output, [() for _ in range(len(self.basic_ops))]
def fuse_forward_linear_scale_add(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse forward GEMM + scale + add
Parameters
----------
ops: list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 3:
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear = op
op, _ = ops[0]
# Check if next op is constant scale
if not isinstance(op, ConstantScale):
continue
scale = op
window.extend(ops[:1])
ops = ops[1:]
op, _ = ops[0]
# Check if next op is in-place add extra input
if not isinstance(op, AddExtraInput):
continue
if not op._in_place:
continue
add = op
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearScaleAdd(
linear=linear,
scale=scale,
add=add,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
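A hedged end-to-end usage sketch of the chain this fusion targets (module paths, the AddExtraInput constructor, and passing the extra tensor as an additional positional argument to Sequential are all assumptions based on this diff):

```python
# Hedged usage sketch: a Linear -> ConstantScale -> in-place AddExtraInput
# chain is eligible for the ForwardLinearScaleAdd fusion.
import torch
import transformer_engine.pytorch as te

model = te.ops.Sequential(
    te.ops.BasicLinear(1024, 1024, device="cuda"),
    te.ops.ConstantScale(0.5),
    te.ops.AddExtraInput(in_place=True),
)
x = torch.randn(8, 1024, device="cuda")
residual = torch.randn(8, 1024, device="cuda")
y = model(x, residual)   # result is accumulated into the residual buffer
```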