Commit 87e3e56e authored by yuguo's avatar yuguo
Browse files

Merge commit '734bcedd' of...

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
...@@ -42,7 +42,7 @@ class AllReduce(BasicOperation): ...@@ -42,7 +42,7 @@ class AllReduce(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -22,9 +22,7 @@ from ...distributed import ( ...@@ -22,9 +22,7 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager, Recipe from ...fp8 import FP8GlobalStateManager, Recipe
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor from .._common import maybe_dequantize, is_quantized_tensor
...@@ -291,10 +289,19 @@ class BasicLinear(BasicOperation): ...@@ -291,10 +289,19 @@ class BasicLinear(BasicOperation):
# Quantize if needed # Quantize if needed
if self._with_quantized_weight: if self._with_quantized_weight:
quantizer = self.get_quantizer("forward", 1) quantizer = self.get_quantizer("forward", 1)
if quantizer is None:
raise RuntimeError(
"Tried to quantize weight with deferred initialization "
"due to meta device, but no quantizer was available. "
"This is most likely because the weight was initialized "
"within fp8_model_init, but the forward pass was not "
"performed within fp8_autocast."
)
quantizer.set_usage( quantizer.set_usage(
rowwise=True, rowwise=True,
columnwise=torch.is_grad_enabled(), columnwise=torch.is_grad_enabled(),
) )
quantizer.internal = False
with torch.no_grad(): with torch.no_grad():
weight = quantizer(weight) weight = quantizer(weight)
...@@ -303,72 +310,52 @@ class BasicLinear(BasicOperation): ...@@ -303,72 +310,52 @@ class BasicLinear(BasicOperation):
weight = torch.nn.Parameter(weight) weight = torch.nn.Parameter(weight)
self.weight = weight self.weight = weight
def pre_first_forward( def pre_first_fuser_forward(self) -> None:
self, super().pre_first_fuser_forward()
*, if self.weight.device.type == "meta":
recipe: Optional[Recipe],
) -> None:
super().pre_first_forward(recipe=recipe)
# Initialize weights if needed
weight = self.weight
if weight.device.type == "meta":
self.reset_parameters() self.reset_parameters()
weight = self.weight
# Configure quantizers def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
if recipe is not None: super().reset_recipe_state(recipe=recipe)
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
grad_output_quantizer = self.get_quantizer("backward", 0)
# Specify required tensor formats # Input/grad output quantizers use internal tensors
input_quantizer = self.get_quantizer("forward", 0)
grad_output_quantizer = self.get_quantizer("backward", 0)
if input_quantizer is not None:
input_quantizer.internal = True input_quantizer.internal = True
weight_quantizer.internal = True if grad_output_quantizer is not None:
grad_output_quantizer.internal = True grad_output_quantizer.internal = True
# Recipe-specific configuration # Handle weight quantizer
if recipe.float8_current_scaling(): # Note: This function may be called in base class constructor,
if any( # before any basic linear attrs have been set.
not isinstance(q, Float8CurrentScalingQuantizer) weight_quantizer = self.get_quantizer("forward", 1)
for q in (input_quantizer, weight_quantizer, grad_output_quantizer) if weight_quantizer is None:
): pass
raise RuntimeError( elif is_quantized_tensor(getattr(self, "weight", None)):
"FP8 current-scaling recipe is enabled, " # Make sure weight param has correct quantizer
f"but input quantizer is {input_quantizer.__class__.__name__}, " weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
f"weight quantizer is {weight_quantizer.__class__.__name__}, " weight_quantizer.internal = False
f"grad output quantizer is {grad_output_quantizer.__class__.__name__}" self.weight.update_quantizer(weight_quantizer.copy())
) else:
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale # Use internal tensors if quantized weights will not be
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon # exposed externally
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale weight_quantizer.internal = (
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon not FP8GlobalStateManager.with_fp8_parameters()
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale and not getattr(self, "_with_quantized_weight", False)
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon )
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
# recipe changed
if isinstance(
weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
) and isinstance(weight, Float8TensorBase):
weight._quantizer = weight_quantizer
@staticmethod @staticmethod
def _functional_forward( def _functional_forward(
input: torch.Tensor, # pylint: disable=redefined-builtin input: torch.Tensor, # pylint: disable=redefined-builtin
weight: torch.Tensor, weight: torch.Tensor,
*, *,
alpha: float = 1.0,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None, # pylint: disable=unused-argument device: Optional[torch.device] = None, # pylint: disable=unused-argument
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
beta: Optional[float] = None,
accumulate_into_out: bool = False, accumulate_into_out: bool = False,
tensor_parallel_mode: Optional[str] = None, tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
...@@ -388,6 +375,8 @@ class BasicLinear(BasicOperation): ...@@ -388,6 +375,8 @@ class BasicLinear(BasicOperation):
Input tensor Input tensor
weight: torch.Tensor weight: torch.Tensor
Weight tensor Weight tensor
alpha: float, default = 1.0
Scaling factor applied to the result of the GEMM
bias: torch.Tensor, optional bias: torch.Tensor, optional
Bias tensor Bias tensor
device: torch.device, default = default CUDA device device: torch.device, default = default CUDA device
...@@ -396,6 +385,8 @@ class BasicLinear(BasicOperation): ...@@ -396,6 +385,8 @@ class BasicLinear(BasicOperation):
Tensor datatype Tensor datatype
out: torch.Tensor, optional out: torch.Tensor, optional
Output tensor Output tensor
beta: float, optional
Scaling factor applied to original value of out when accumulating into it
accumulate_into_out: bool, default = `False` accumulate_into_out: bool, default = `False`
Add result to output tensor instead of overwriting Add result to output tensor instead of overwriting
tensor_parallel_mode: {`None`, "column", "row"}, default = `None` tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
...@@ -441,7 +432,7 @@ class BasicLinear(BasicOperation): ...@@ -441,7 +432,7 @@ class BasicLinear(BasicOperation):
if dtype is None: if dtype is None:
if out is not None and isinstance(out, torch.Tensor): if out is not None and isinstance(out, torch.Tensor):
dtype = out.dtype dtype = out.dtype
elif weight is not None and isinstance(out, torch.Tensor): elif weight is not None and isinstance(weight, torch.Tensor):
dtype = weight.dtype dtype = weight.dtype
else: else:
raise ValueError( raise ValueError(
...@@ -516,18 +507,11 @@ class BasicLinear(BasicOperation): ...@@ -516,18 +507,11 @@ class BasicLinear(BasicOperation):
raise ValueError("Output tensor is quantized, but quantizer was not provided") raise ValueError("Output tensor is quantized, but quantizer was not provided")
else: else:
output_quantizer = None output_quantizer = None
if isinstance(output_quantizer, MXFP8Quantizer):
raise RuntimeError(
"Attempting to generate MXFP8 output tensor, "
"but GEMM with MXFP8 output is not supported"
)
if isinstance(output_quantizer, Float8BlockQuantizer):
raise RuntimeError(
"Attempting to generate Float8BlockQuantized output tensor, "
"but GEMM with Float8BlockQuantized output is not supported"
)
if output_quantizer is not None: if output_quantizer is not None:
if not isinstance(output_quantizer, Float8Quantizer):
raise RuntimeError(
"Attempting to generate quantized output tensor with unsupported quantizer"
)
output_quantizer.set_usage(rowwise=True, columnwise=False) output_quantizer.set_usage(rowwise=True, columnwise=False)
# Check if accumulating into output tensor # Check if accumulating into output tensor
...@@ -552,6 +536,8 @@ class BasicLinear(BasicOperation): ...@@ -552,6 +536,8 @@ class BasicLinear(BasicOperation):
get_workspace(), get_workspace(),
out_dtype=dtype, out_dtype=dtype,
quantization_params=output_quantizer, quantization_params=output_quantizer,
alpha=alpha,
beta=beta,
accumulate=accumulate_into_out, accumulate=accumulate_into_out,
out=y, out=y,
bias=bias, bias=bias,
...@@ -589,13 +575,17 @@ class BasicLinear(BasicOperation): ...@@ -589,13 +575,17 @@ class BasicLinear(BasicOperation):
input: Optional[torch.Tensor], # pylint: disable=redefined-builtin input: Optional[torch.Tensor], # pylint: disable=redefined-builtin
weight: Optional[torch.Tensor], weight: Optional[torch.Tensor],
*, *,
grad_input_alpha: Optional[float] = None,
input_requires_grad: bool = True, input_requires_grad: bool = True,
grad_weight_alpha: Optional[float] = None,
weight_requires_grad: bool = True, weight_requires_grad: bool = True,
device: Optional[torch.device] = None, # pylint: disable=unused-argument device: Optional[torch.device] = None, # pylint: disable=unused-argument
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
grad_weight: Optional[torch.Tensor] = None, grad_weight: Optional[torch.Tensor] = None,
grad_weight_beta: Optional[float] = None,
accumulate_into_grad_weight: bool = False, accumulate_into_grad_weight: bool = False,
grad_input: Optional[torch.Tensor] = None, grad_input: Optional[torch.Tensor] = None,
grad_input_beta: Optional[float] = None,
accumulate_into_grad_input: bool = False, accumulate_into_grad_input: bool = False,
tensor_parallel_mode: Optional[str] = None, tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
...@@ -618,8 +608,12 @@ class BasicLinear(BasicOperation): ...@@ -618,8 +608,12 @@ class BasicLinear(BasicOperation):
weight: torch.Tensor, optional weight: torch.Tensor, optional
Weight tensor. Required to compute loss gradient w.r.t. Weight tensor. Required to compute loss gradient w.r.t.
input. input.
grad_input_alpha: float, optional
Scaling factor applied to the result of the dgrad GEMM
input_requires_grad: bool input_requires_grad: bool
Whether to compute loss gradient w.r.t. input tensor Whether to compute loss gradient w.r.t. input tensor
grad_weight_alpha: float, optional
Scaling factor applied to the result of the wgrad GEMM
weight_requires_grad: bool weight_requires_grad: bool
Whether to compute loss gradient w.r.t. weight tensor Whether to compute loss gradient w.r.t. weight tensor
device: torch.device, default = default CUDA device device: torch.device, default = default CUDA device
...@@ -628,10 +622,14 @@ class BasicLinear(BasicOperation): ...@@ -628,10 +622,14 @@ class BasicLinear(BasicOperation):
Tensor datatype Tensor datatype
grad_weight: torch.Tensor, optional grad_weight: torch.Tensor, optional
Loss gradient w.r.t. weight tensor Loss gradient w.r.t. weight tensor
grad_weight_beta: float, optional
Scaling factor applied to original value of grad_weight when accumulating into it
accumulate_into_grad_weight: bool, default = `False` accumulate_into_grad_weight: bool, default = `False`
Add result to weight grad instead of overwriting Add result to weight grad instead of overwriting
grad_input: torch.Tensor, optional grad_input: torch.Tensor, optional
Loss gradient w.r.t. input tensor Loss gradient w.r.t. input tensor
grad_input_beta: float, optional
Scaling factor applied to original value of grad_input when accumulating into it
accumulate_into_grad_input: bool, default = `False` accumulate_into_grad_input: bool, default = `False`
Add result to input grad instead of overwriting Add result to input grad instead of overwriting
tensor_parallel_mode: {`None`, "column", "row"}, default = `None` tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
...@@ -801,11 +799,12 @@ class BasicLinear(BasicOperation): ...@@ -801,11 +799,12 @@ class BasicLinear(BasicOperation):
) )
else: else:
grad_input_quantizer = None grad_input_quantizer = None
if isinstance(grad_input_quantizer, MXFP8Quantizer): if grad_input_quantizer is not None:
raise RuntimeError( if not isinstance(grad_input_quantizer, Float8Quantizer):
"Attempting to generate MXFP8 grad input tensor, " raise RuntimeError(
"but GEMM with MXFP8 output is not supported" "Attempting to generate quantized grad input tensor "
) "with unsupported quantizer"
)
# Check if accumulating into grad input tensor # Check if accumulating into grad input tensor
if accumulate_into_grad_input: if accumulate_into_grad_input:
...@@ -827,6 +826,8 @@ class BasicLinear(BasicOperation): ...@@ -827,6 +826,8 @@ class BasicLinear(BasicOperation):
get_workspace(), get_workspace(),
out_dtype=dtype, out_dtype=dtype,
quantization_params=grad_input_quantizer, quantization_params=grad_input_quantizer,
alpha=grad_input_alpha,
beta=grad_input_beta,
accumulate=accumulate_into_grad_input, accumulate=accumulate_into_grad_input,
layout="NN", layout="NN",
out=dx, out=dx,
...@@ -877,6 +878,8 @@ class BasicLinear(BasicOperation): ...@@ -877,6 +878,8 @@ class BasicLinear(BasicOperation):
dy, dy,
get_workspace(), get_workspace(),
out_dtype=dw_dtype, out_dtype=dw_dtype,
alpha=grad_weight_alpha,
beta=grad_weight_beta,
accumulate=accumulate_into_grad_weight, accumulate=accumulate_into_grad_weight,
layout="NT", layout="NT",
out=dw, out=dw,
...@@ -894,7 +897,7 @@ class BasicLinear(BasicOperation): ...@@ -894,7 +897,7 @@ class BasicLinear(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -903,27 +906,34 @@ class BasicLinear(BasicOperation): ...@@ -903,27 +906,34 @@ class BasicLinear(BasicOperation):
weight_requires_grad = ctx.requires_grad and self.weight.requires_grad weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
# FP8 metadata # FP8 metadata
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
output_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute: if with_quantized_compute:
# Get quantizers
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Configure quantizers # Configure quantizers
# Note: We cache the quantized input for backward pass, # Note: We cache the quantized input for backward pass,
# but discard the quantized weights. # but discard the quantized weights.
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False) weight_quantizer.set_usage(rowwise=True, columnwise=False)
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Get autocast dtype if needed # Get autocast dtype if needed
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda") dtype = torch.get_autocast_dtype("cuda")
...@@ -947,15 +957,16 @@ class BasicLinear(BasicOperation): ...@@ -947,15 +957,16 @@ class BasicLinear(BasicOperation):
) )
# Save state for backward pass # Save state for backward pass
ctx.save_for_backward(x_local, w) if ctx.requires_grad:
ctx.with_quantized_compute = with_quantized_compute ctx.save_for_backward(x_local, w)
ctx.input_quantizer = input_quantizer ctx.with_quantized_compute = with_quantized_compute
ctx.weight_quantizer = weight_quantizer ctx.input_quantizer = input_quantizer
ctx.grad_output_quantizer = grad_output_quantizer ctx.weight_quantizer = weight_quantizer
ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_output_quantizer = grad_output_quantizer
ctx.dtype = dtype ctx.grad_input_quantizer = grad_input_quantizer
ctx.input_requires_grad = input_requires_grad ctx.dtype = dtype
ctx.weight_requires_grad = weight_requires_grad ctx.input_requires_grad = input_requires_grad
ctx.weight_requires_grad = weight_requires_grad
return output return output
......
...@@ -10,15 +10,8 @@ from typing import Optional ...@@ -10,15 +10,8 @@ from typing import Optional
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.ops.op import ( from ..op import BasicOperation, OperationContext
BasicOperation, from ...utils import canonicalize_device, canonicalize_dtype
OperationContext,
)
from ...utils import (
canonicalize_device,
canonicalize_dtype,
)
from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer from ...tensor import Quantizer
...@@ -114,8 +107,8 @@ class Bias(BasicOperation): ...@@ -114,8 +107,8 @@ class Bias(BasicOperation):
bias = torch.nn.Parameter(bias) bias = torch.nn.Parameter(bias)
self.bias = bias self.bias = bias
def pre_first_forward(self, *args, **kwargs) -> None: def pre_first_fuser_forward(self) -> None:
super().pre_first_forward(*args, **kwargs) super().pre_first_fuser_forward()
if self.bias.device.type == "meta": if self.bias.device.type == "meta":
self.reset_parameters() self.reset_parameters()
...@@ -123,24 +116,14 @@ class Bias(BasicOperation): ...@@ -123,24 +116,14 @@ class Bias(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
x = input_ x = input_
b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
# Check if backward pass is needed if ctx.requires_grad:
requires_grad = ctx.requires_grad ctx.grad_input_quantizer = prev_op_grad_output_quantizer
# Check if previous op quantizes its output's gradient
grad_input_quantizer = None
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
grad_input_quantizer = prev_op_grad_input_quantizer
if requires_grad:
ctx.with_quantized_compute = with_quantized_compute
ctx.grad_input_quantizer = grad_input_quantizer
return x + b return x + b
...@@ -152,10 +135,10 @@ class Bias(BasicOperation): ...@@ -152,10 +135,10 @@ class Bias(BasicOperation):
dy = grad_output dy = grad_output
if dy.dim() > 1: if dy.dim() > 1:
quantizer = ctx.grad_input_quantizer quantizer = ctx.grad_input_quantizer
if ctx.with_quantized_compute and quantizer is not None: if quantizer is None:
db, dy = tex.bgrad_quantize(dy, quantizer)
else:
db = dy.sum(tuple(range(dy.dim() - 1))) db = dy.sum(tuple(range(dy.dim() - 1)))
else:
db, dy = tex.bgrad_quantize(dy, quantizer)
else: else:
db = dy db = dy
return dy, (db,) return dy, (db,)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for constant scaling."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer
class ConstantScale(BasicOperation):
    """Scale the input tensor by a fixed constant factor.

    The same factor is applied in both the forward and backward
    passes, since d(scale * x)/dx = scale.
    """

    def __init__(self, scale: float) -> None:
        super().__init__()
        # Constant multiplicative factor applied to the input tensor.
        self.scale = scale

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        # Nothing needs to be stashed in ctx for the backward pass:
        # the scale factor lives on the module itself.
        return self.scale * input_

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        # The gradient of a constant scaling is the same constant
        # scaling. There are no parameters, hence the empty tuple.
        return self.scale * grad_output, ()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for dropout."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer
class Dropout(BasicOperation):
    """Randomly zero out tensor entries during training

    During training, tensor entries are randomly set to zero with
    probability :math:`p` and remaining entries are scaled by
    :math:`1/(1-p)`.

    Parameters
    ----------
    p: float
        Probability that each tensor entry is zeroed out. Must
        satisfy ``0 <= p < 1`` so the :math:`1/(1-p)` rescaling is
        well-defined.

    Raises
    ------
    ValueError
        If ``p`` is outside ``[0, 1)``.

    """

    def __init__(self, p: float) -> None:
        super().__init__()
        # Validate up front so an invalid probability fails loudly at
        # construction time instead of producing inf/NaN mask values
        # (1 / (1 - p)) in the forward pass.
        if not 0 <= p < 1:
            raise ValueError(f"Dropout probability must satisfy 0 <= p < 1, but got p={p}")
        self.dropout_probability = p

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:

        # Compute dropout if training
        out = input_
        is_training = self.training
        mask = None
        if is_training:
            keep_prob = 1 - self.dropout_probability
            # Sample a Bernoulli keep-mask and fold the 1/keep_prob
            # rescaling into it, so forward and backward each need
            # only a single elementwise multiply.
            mask = torch.empty_like(input_)
            mask.bernoulli_(keep_prob)
            mask *= 1 / keep_prob
            out = out * mask

        # Save state for backward pass
        if ctx.requires_grad:
            ctx.save_for_backward(mask)
            ctx.is_training = is_training

        return out

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        # Reuse the pre-scaled keep-mask sampled in the forward pass
        # (None when the module was in eval mode, in which case the
        # gradient passes through unchanged).
        (mask,) = ctx.saved_tensors
        grad_input = grad_output
        if ctx.is_training:
            grad_input = grad_input * mask
        return grad_input, ()
...@@ -23,7 +23,7 @@ class Identity(BasicOperation): ...@@ -23,7 +23,7 @@ class Identity(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
return input_ return input_
......
...@@ -6,10 +6,12 @@ ...@@ -6,10 +6,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional from typing import Optional
import os
import torch import torch
from ...utils import clear_tensor_data from ...utils import clear_tensor_data
from ... import torch_version
from .._common import maybe_dequantize from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from ...jit import ( from ...jit import (
...@@ -60,7 +62,11 @@ class L2Normalization(BasicOperation): ...@@ -60,7 +62,11 @@ class L2Normalization(BasicOperation):
# JIT warmup for L2Normalization fused operations # JIT warmup for L2Normalization fused operations
if seq_length and micro_batch_size: if seq_length and micro_batch_size:
if torch.cuda.is_available(): if (
torch.cuda.is_available()
and torch_version() >= (2, 0, 0)
and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1")))
):
set_jit_fusion_options() set_jit_fusion_options()
# For L2Normalization, we don't know the hidden size until forward pass, # For L2Normalization, we don't know the hidden size until forward pass,
# but we can warm up with common sizes. For QK normalization, this will be # but we can warm up with common sizes. For QK normalization, this will be
...@@ -74,7 +80,7 @@ class L2Normalization(BasicOperation): ...@@ -74,7 +80,7 @@ class L2Normalization(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
# Use input directly - torch.compile can handle multi-dimensional tensors # Use input directly - torch.compile can handle multi-dimensional tensors
...@@ -86,7 +92,7 @@ class L2Normalization(BasicOperation): ...@@ -86,7 +92,7 @@ class L2Normalization(BasicOperation):
# Compute L2 normalization using fused implementation # Compute L2 normalization using fused implementation
# L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps) # L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps)
if requires_grad: if requires_grad:
# Training: use version that returns both output and intermediate values # Training: use version that returns output and intermediate values for backward pass
y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps) y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps)
else: else:
# Inference: use lightweight version that only returns output # Inference: use lightweight version that only returns output
...@@ -110,7 +116,7 @@ class L2Normalization(BasicOperation): ...@@ -110,7 +116,7 @@ class L2Normalization(BasicOperation):
dy = maybe_dequantize(grad_output) dy = maybe_dequantize(grad_output)
# Compute L2 norm backward pass using fused implementation # Compute L2 norm backward pass using fused implementation - recalculates l2_norm_squared_eps
dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps) dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)
# Clear saved tensors if possible # Clear saved tensors if possible
......
...@@ -13,7 +13,6 @@ from typing import Optional ...@@ -13,7 +13,6 @@ from typing import Optional
import torch import torch
from transformer_engine_torch import layernorm_bwd, layernorm_fwd from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...constants import TE_DType from ...constants import TE_DType
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
...@@ -168,8 +167,8 @@ class LayerNorm(BasicOperation): ...@@ -168,8 +167,8 @@ class LayerNorm(BasicOperation):
self.weight = weight self.weight = weight
self.bias = bias self.bias = bias
def pre_first_forward(self, *args, **kwargs) -> None: def pre_first_fuser_forward(self) -> None:
super().pre_first_forward(*args, **kwargs) super().pre_first_fuser_forward()
if self.weight.device.type == "meta" or self.bias.device.type == "meta": if self.weight.device.type == "meta" or self.bias.device.type == "meta":
self.reset_parameters() self.reset_parameters()
...@@ -177,7 +176,7 @@ class LayerNorm(BasicOperation): ...@@ -177,7 +176,7 @@ class LayerNorm(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
...@@ -200,31 +199,22 @@ class LayerNorm(BasicOperation): ...@@ -200,31 +199,22 @@ class LayerNorm(BasicOperation):
w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
b = maybe_dequantize(self.bias, dtype).view((inner_dim,)) b = maybe_dequantize(self.bias, dtype).view((inner_dim,))
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if output is quantized
output_quantizer = None
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
output_quantizer = next_op_input_quantizer
# Compute layer norm # Compute layer norm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"] sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
y, means, rstdevs = layernorm_fwd( y, means, rstdevs = layernorm_fwd(
x, x,
w, w,
b, b,
self.eps, self.eps,
None, None,
output_quantizer, next_op_input_quantizer,
TE_DType[dtype], TE_DType[dtype],
sm_margin, sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
) )
# Save state for backward pass # Save state for backward pass
if requires_grad: if ctx.requires_grad:
ctx.save_for_backward(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype ctx.dtype = dtype
......
...@@ -22,14 +22,20 @@ class MakeExtraOutput(BasicOperation): ...@@ -22,14 +22,20 @@ class MakeExtraOutput(BasicOperation):
If this operation is included in the operation fuser, then the If this operation is included in the operation fuser, then the
operation fuser will return the intermediate tensor as an extra operation fuser will return the intermediate tensor as an extra
tensor output. In the backward pass, the gradient is directly tensor output.
accumulated into the gradient w.r.t. the extra output.
This operation is considered an advanced feature and most users In the backward pass, the gradient may be directly
are discouraged from using it. In-place operations break some accumulated into the gradient w.r.t. the extra output. This is
autograd assumptions and they can result in subtle, esoteric bugs. controlled by the in_place kwarg. Currently, the BackwardLinearAdd
fusion is able to happen only with in_place=True.
Compare to `AddInPlace`, which does a similar operation in the Using this operation with in_place=True is
considered an advanced feature. Most users are discouraged
from enabling it in-place gradient accumulation, as in-place
operations break some autograd assumptions and they can result
in subtle, esoteric bugs.
Compare to `AddExtraInput`, which does a similar operation in the
backward pass. backward pass.
""" """
...@@ -37,6 +43,10 @@ class MakeExtraOutput(BasicOperation): ...@@ -37,6 +43,10 @@ class MakeExtraOutput(BasicOperation):
# Operation expects buffer for output tensor # Operation expects buffer for output tensor
num_extra_outputs: int = 1 num_extra_outputs: int = 1
def __init__(self, *, in_place: bool = False):
super().__init__()
self._in_place: bool = in_place
def op_forward(self, *args, **kwargs) -> None: def op_forward(self, *args, **kwargs) -> None:
raise RuntimeError( raise RuntimeError(
"{self.__class__.__name__} operation has " "{self.__class__.__name__} operation has "
...@@ -59,7 +69,7 @@ class MakeExtraOutput(BasicOperation): ...@@ -59,7 +69,7 @@ class MakeExtraOutput(BasicOperation):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...@@ -76,6 +86,10 @@ class MakeExtraOutput(BasicOperation): ...@@ -76,6 +86,10 @@ class MakeExtraOutput(BasicOperation):
Iterable[Iterable[Optional[torch.Tensor]]], Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]], Iterable[Iterable[Optional[torch.Tensor]]],
]: ]:
grad_input = basic_op_grad_extra_outputs[0][0] grad_extra_output = basic_op_grad_extra_outputs[0][0]
grad_input += grad_output if self._in_place:
grad_extra_output += grad_output
grad_input = grad_extra_output
else:
grad_input = grad_extra_output + grad_output
return grad_input, [()], [()] return grad_input, [()], [()]
...@@ -50,7 +50,7 @@ class Quantize(BasicOperation): ...@@ -50,7 +50,7 @@ class Quantize(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -64,7 +64,8 @@ class Quantize(BasicOperation): ...@@ -64,7 +64,8 @@ class Quantize(BasicOperation):
if quantize_forward and not is_quantized_tensor(out): if quantize_forward and not is_quantized_tensor(out):
out = self.get_quantizer("forward", 0)(out) out = self.get_quantizer("forward", 0)(out)
ctx.quantize_backward = quantize_backward if ctx.requires_grad:
ctx.quantize_backward = quantize_backward
return out return out
def op_backward( def op_backward(
......
...@@ -40,7 +40,7 @@ class ReduceScatter(BasicOperation): ...@@ -40,7 +40,7 @@ class ReduceScatter(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -38,10 +38,11 @@ class Reshape(BasicOperation): ...@@ -38,10 +38,11 @@ class Reshape(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
ctx.input_shape = input_.size() if ctx.requires_grad:
ctx.input_shape = input_.size()
return input_.reshape(*self._shape) return input_.reshape(*self._shape)
def op_backward( def op_backward(
......
...@@ -13,7 +13,6 @@ from typing import Optional ...@@ -13,7 +13,6 @@ from typing import Optional
import torch import torch
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...constants import TE_DType from ...constants import TE_DType
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
...@@ -151,8 +150,8 @@ class RMSNorm(BasicOperation): ...@@ -151,8 +150,8 @@ class RMSNorm(BasicOperation):
weight = torch.nn.Parameter(weight) weight = torch.nn.Parameter(weight)
self.weight = weight self.weight = weight
def pre_first_forward(self, *args, **kwargs) -> None: def pre_first_fuser_forward(self) -> None:
super().pre_first_forward(*args, **kwargs) super().pre_first_fuser_forward()
if self.weight.device.type == "meta": if self.weight.device.type == "meta":
self.reset_parameters() self.reset_parameters()
...@@ -160,7 +159,7 @@ class RMSNorm(BasicOperation): ...@@ -160,7 +159,7 @@ class RMSNorm(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
...@@ -182,30 +181,21 @@ class RMSNorm(BasicOperation): ...@@ -182,30 +181,21 @@ class RMSNorm(BasicOperation):
x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim)) x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if output is quantized
output_quantizer = None
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
output_quantizer = next_op_input_quantizer
# Compute RMSNorm # Compute RMSNorm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"] sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
y, _, rstdevs = rmsnorm_fwd( y, _, rstdevs = rmsnorm_fwd(
x, x,
w, w,
self.eps, self.eps,
None, None,
output_quantizer, next_op_input_quantizer,
TE_DType[dtype], TE_DType[dtype],
sm_margin, sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
) )
# Save state for backward pass # Save state for backward pass
if requires_grad: if ctx.requires_grad:
ctx.save_for_backward(x, rstdevs) ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype ctx.dtype = dtype
......
...@@ -4,14 +4,18 @@ ...@@ -4,14 +4,18 @@
"""Compound tensor operation supported by the operation fuser.""" """Compound tensor operation supported by the operation fuser."""
from .backward_bias_activation import ( from .backward_activation_bias import (
BackwardBiasActivation, BackwardActivationBias,
fuse_backward_bias_activation, fuse_backward_activation_bias,
) )
from .backward_linear_add import ( from .backward_linear_add import (
BackwardLinearAdd, BackwardLinearAdd,
fuse_backward_linear_add, fuse_backward_linear_add,
) )
from .backward_linear_scale import (
BackwardLinearScale,
fuse_backward_linear_scale,
)
from .forward_linear_bias_activation import ( from .forward_linear_bias_activation import (
ForwardLinearBiasActivation, ForwardLinearBiasActivation,
fuse_forward_linear_bias_activation, fuse_forward_linear_bias_activation,
...@@ -20,6 +24,10 @@ from .forward_linear_bias_add import ( ...@@ -20,6 +24,10 @@ from .forward_linear_bias_add import (
ForwardLinearBiasAdd, ForwardLinearBiasAdd,
fuse_forward_linear_bias_add, fuse_forward_linear_bias_add,
) )
from .forward_linear_scale_add import (
ForwardLinearScaleAdd,
fuse_forward_linear_scale_add,
)
from .userbuffers_backward_linear import ( from .userbuffers_backward_linear import (
UserbuffersBackwardLinear, UserbuffersBackwardLinear,
fuse_userbuffers_backward_linear, fuse_userbuffers_backward_linear,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""Fused backward dbias + dact + quantize.""" """Fused backward dact + dbias + quantize."""
from __future__ import annotations from __future__ import annotations
from typing import Optional from typing import Optional
...@@ -29,8 +29,8 @@ _fused_activations = {GELU: tex.dbias_dgelu, ReLU: tex.dbias_drelu} ...@@ -29,8 +29,8 @@ _fused_activations = {GELU: tex.dbias_dgelu, ReLU: tex.dbias_drelu}
_fusible_activations = tuple(_fused_activations.keys()) _fusible_activations = tuple(_fused_activations.keys())
class BackwardBiasActivation(FusedOperation): class BackwardActivationBias(FusedOperation):
"""Fused backward dbias + dact + quantize """Fused backward dact + dbias + quantize
Uses the next operation's input quantizer. Uses the next operation's input quantizer.
...@@ -66,15 +66,10 @@ class BackwardBiasActivation(FusedOperation): ...@@ -66,15 +66,10 @@ class BackwardBiasActivation(FusedOperation):
dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype) dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype)
# Get previous op quantizer # Get previous op quantizer
if not bias_op_ctx.with_quantized_compute:
raise RuntimeError(
"BackwardBiasActivation requires quantized compute, "
"but Bias context has it disabled"
)
quantizer = bias_op_ctx.grad_input_quantizer quantizer = bias_op_ctx.grad_input_quantizer
if quantizer is None: if quantizer is None:
raise RuntimeError( raise RuntimeError(
"BackwardBiasActivation requires previous op's grad output quantizer, " "BackwardActivationBias requires previous op's grad output quantizer, "
"but Bias context has no quantizer" "but Bias context has no quantizer"
) )
...@@ -87,11 +82,11 @@ class BackwardBiasActivation(FusedOperation): ...@@ -87,11 +82,11 @@ class BackwardBiasActivation(FusedOperation):
return dx, [(), (db,)], [(), ()] return dx, [(), (db,)], [(), ()]
def fuse_backward_bias_activation( def fuse_backward_activation_bias(
ops: list[tuple[FusibleOperation, list[int]]], ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe], recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]: ) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dbias + dact + quantize """Fused backward dact + dbias + quantize
Parameters Parameters
---------- ----------
...@@ -109,7 +104,7 @@ def fuse_backward_bias_activation( ...@@ -109,7 +104,7 @@ def fuse_backward_bias_activation(
""" """
# Check if recipe supports bias activation fusion # Check if recipe supports bias activation fusion
if recipe is None or not (recipe.delayed() or recipe.mxfp8()): if recipe is None:
return ops return ops
# Scan through ops, fusing if possible # Scan through ops, fusing if possible
...@@ -138,7 +133,7 @@ def fuse_backward_bias_activation( ...@@ -138,7 +133,7 @@ def fuse_backward_bias_activation(
ops = ops[1:] ops = ops[1:]
# Replace window with fused op # Replace window with fused op
op = BackwardBiasActivation( op = BackwardActivationBias(
activation=window[0][0], activation=window[0][0],
bias=window[1][0], bias=window[1][0],
) )
......
...@@ -29,10 +29,10 @@ class BackwardLinearAdd(FusedOperation): ...@@ -29,10 +29,10 @@ class BackwardLinearAdd(FusedOperation):
def __init__( def __init__(
self, self,
*, *,
linear: BasicLinear,
backward_add: MakeExtraOutput, backward_add: MakeExtraOutput,
linear: BasicLinear,
) -> None: ) -> None:
super().__init__((linear, backward_add)) super().__init__((backward_add, linear))
def fuser_backward( def fuser_backward(
self, self,
...@@ -47,7 +47,7 @@ class BackwardLinearAdd(FusedOperation): ...@@ -47,7 +47,7 @@ class BackwardLinearAdd(FusedOperation):
]: ]:
# Get basic operations # Get basic operations
linear_op = self.basic_ops[0] linear_op = self.basic_ops[1]
linear_op_ctx = basic_op_ctxs[0] linear_op_ctx = basic_op_ctxs[0]
# Saved tensors from forward pass # Saved tensors from forward pass
...@@ -139,6 +139,8 @@ def fuse_backward_linear_add( ...@@ -139,6 +139,8 @@ def fuse_backward_linear_add(
op, _ = ops[0] op, _ = ops[0]
if not isinstance(op, MakeExtraOutput): if not isinstance(op, MakeExtraOutput):
continue continue
if not op._in_place:
continue
window.extend(ops[:1]) window.extend(ops[:1])
ops = ops[1:] ops = ops[1:]
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused backward dgrad GEMM + scale."""
from __future__ import annotations
from typing import Optional
import torch
from ..basic import BasicLinear, ConstantScale
from ..op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data
class BackwardLinearScale(FusedOperation):
    """Fused backward dgrad GEMM + scale

    Fuses a ``ConstantScale`` op into the preceding ``BasicLinear``
    op's backward pass by passing the constant scale as the ``alpha``
    of the dgrad and wgrad GEMMs, avoiding a separate scaling kernel.

    Column tensor parallelism is not supported since that requires
    communication immediately after the dgrad GEMM.

    """

    def __init__(
        self,
        *,
        scale: ConstantScale,
        linear: BasicLinear,
    ) -> None:
        # Basic ops are registered in forward-pass order: (linear, scale)
        super().__init__((linear, scale))

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[tuple[Optional[torch.Tensor], ...]],
        list[tuple[()]],
    ]:
        # Get basic operations.
        # NOTE(review): self.basic_ops is in forward order (linear, scale),
        # while basic_op_ctxs appears to follow the backward fusion window
        # order (scale, linear) built by fuse_backward_linear_scale -- hence
        # the differing indices. Confirm against the fuser's ctx ordering.
        linear_op = self.basic_ops[0]
        linear_op_ctx = basic_op_ctxs[1]
        scale_op = self.basic_ops[1]
        # Saved tensors from forward pass: (possibly local shard of) input
        # and weight
        (x_local, w) = linear_op_ctx.saved_tensors
        # wgrad fusion: if configured, accumulate the weight gradient
        # directly into the weight's main_grad buffer
        accumulate_into_main_grad = linear_op._accumulate_into_main_grad
        grad_weight = None
        if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
            if hasattr(linear_op.weight, "__fsdp_param__"):
                # FSDP-marked params expose main_grad through a getter
                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
            if not hasattr(linear_op.weight, "main_grad"):
                raise RuntimeError(
                    "BasicLinear op is configured with "
                    "accumulate_into_main_grad=True, "
                    "but weight parameter does not have main_grad attribute"
                )
            grad_weight = linear_op.weight.main_grad.detach()
        else:
            # Weight grad not needed, or no main_grad accumulation requested
            accumulate_into_main_grad = False
        # Linear backward pass, with the constant scale folded in as the
        # alpha of both the dgrad and wgrad GEMMs
        grad_input, grad_weight = BasicLinear._functional_backward(
            grad_output=grad_output,
            input=x_local,
            weight=w,
            input_requires_grad=linear_op_ctx.input_requires_grad,
            grad_input_alpha=scale_op.scale,
            weight_requires_grad=linear_op_ctx.weight_requires_grad,
            grad_weight_alpha=scale_op.scale,
            dtype=linear_op_ctx.dtype,
            grad_weight=grad_weight,
            accumulate_into_grad_weight=accumulate_into_main_grad,
            tensor_parallel_mode=linear_op.tensor_parallel_mode,
            tensor_parallel_group=linear_op.tensor_parallel_group,
            sequence_parallel=linear_op.sequence_parallel,
            with_quantized_compute=linear_op_ctx.with_quantized_compute,
            input_quantizer=linear_op_ctx.input_quantizer,
            weight_quantizer=linear_op_ctx.weight_quantizer,
            grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
            grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
        )
        if accumulate_into_main_grad:
            # Gradient was accumulated in-place into main_grad, so do not
            # also return it as a parameter grad
            grad_weight = None
        # Clear input tensor if possible
        clear_tensor_data(x_local)
        # Per-op param grads match basic_op_ctxs ordering: no params for
        # the scale op, (grad_weight,) for the linear op
        return grad_input, [(), (grad_weight,)], [(), ()]
def fuse_backward_linear_scale(
    ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
    """Fused backward dgrad GEMM + constant scale

    Parameters
    ----------
    ops: list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.

    Returns
    -------
    ops: list of tuples
        Updated backward pass operations

    """

    # Accumulated output ops and the candidate window currently under
    # consideration
    result: list[tuple[FusibleOperation, list[int]]] = []
    pending: list[tuple[FusibleOperation, list[int]]] = []

    # Slide over the op list, attempting to fuse (scale, linear) pairs
    while len(ops) >= 2:
        result.extend(pending)

        # Window must start with a constant scale
        pending, ops = [ops[0]], ops[1:]
        scale_candidate = pending[0][0]
        if not isinstance(scale_candidate, ConstantScale):
            continue

        # Next op must be a linear without column tensor-parallelism,
        # since that requires communication after the dgrad GEMM
        linear_candidate = ops[0][0]
        if not isinstance(linear_candidate, BasicLinear):
            continue
        if linear_candidate.tensor_parallel_mode == "column":
            continue
        pending.append(ops[0])
        ops = ops[1:]

        # Collapse the window into a single fused op
        fused = BackwardLinearScale(
            scale=scale_candidate,
            linear=linear_candidate,
        )
        idxs = [entry[1][0] for entry in pending]
        pending = [(fused, idxs)]

    # Flush the remaining window and any unexamined ops
    result.extend(pending)
    result.extend(ops)
    return result
...@@ -59,7 +59,7 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -59,7 +59,7 @@ class ForwardLinearBiasActivation(FusedOperation):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...@@ -89,18 +89,12 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -89,18 +89,12 @@ class ForwardLinearBiasActivation(FusedOperation):
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata # FP8 metadata
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
output_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed # Get autocast dtype if needed
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
...@@ -126,18 +120,18 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -126,18 +120,18 @@ class ForwardLinearBiasActivation(FusedOperation):
) )
# Save state for backward pass # Save state for backward pass
linear_op_ctx.save_for_backward(x_local, w) if linear_op_ctx.requires_grad:
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.dtype = dtype linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.dtype = dtype
linear_op_ctx.weight_requires_grad = weight_requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
if bias_op is not None: linear_op_ctx.weight_requires_grad = weight_requires_grad
bias_op_ctx.with_quantized_compute = with_quantized_compute if bias_op is not None and bias_op_ctx.requires_grad:
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
......
...@@ -11,7 +11,7 @@ from typing import Any, Optional ...@@ -11,7 +11,7 @@ from typing import Any, Optional
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias
from transformer_engine.pytorch.ops.op import ( from transformer_engine.pytorch.ops.op import (
FusedOperation, FusedOperation,
FusibleOperation, FusibleOperation,
...@@ -33,7 +33,7 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -33,7 +33,7 @@ class ForwardLinearBiasAdd(FusedOperation):
*, *,
linear: BasicLinear, linear: BasicLinear,
bias: Optional[Bias], bias: Optional[Bias],
add: AddInPlace, add: AddExtraInput,
) -> None: ) -> None:
# Basic operations that comprise this fused operation # Basic operations that comprise this fused operation
...@@ -57,7 +57,7 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -57,7 +57,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...@@ -83,17 +83,12 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -83,17 +83,12 @@ class ForwardLinearBiasAdd(FusedOperation):
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata # FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() input_quantizer = linear_op.get_quantizer("forward", 0)
input_quantizer = None weight_quantizer = linear_op.get_quantizer("forward", 1)
weight_quantizer = None
output_quantizer = None output_quantizer = None
grad_output_quantizer = None grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = None grad_input_quantizer = prev_op_grad_output_quantizer
if with_quantized_compute: with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed # Get autocast dtype if needed
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
...@@ -122,18 +117,18 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -122,18 +117,18 @@ class ForwardLinearBiasAdd(FusedOperation):
) )
# Save state for backward pass # Save state for backward pass
linear_op_ctx.save_for_backward(x_local, w) if linear_op_ctx.requires_grad:
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.dtype = dtype linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.dtype = dtype
linear_op_ctx.weight_requires_grad = weight_requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
if bias_op is not None: linear_op_ctx.weight_requires_grad = weight_requires_grad
bias_op_ctx.with_quantized_compute = with_quantized_compute if bias_op is not None and bias_op_ctx.requires_grad:
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
...@@ -184,8 +179,10 @@ def fuse_forward_linear_bias_add( ...@@ -184,8 +179,10 @@ def fuse_forward_linear_bias_add(
continue continue
op, _ = ops[0] op, _ = ops[0]
# Check if next op is add in-place # Check if next op is in-place add extra input
if not isinstance(op, AddInPlace): if not isinstance(op, AddExtraInput):
continue
if not op._in_place:
continue continue
add = op add = op
window.extend(ops[:1]) window.extend(ops[:1])
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused operation for forward GEMM + scale + add."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
from ...fp8 import FP8GlobalStateManager
from ..basic import AddExtraInput, BasicLinear, ConstantScale
from ..op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...tensor import Quantizer
class ForwardLinearScaleAdd(FusedOperation):
    """Fused forward GEMM + scale + add

    Fuses a ``BasicLinear``, a ``ConstantScale``, and an in-place
    ``AddExtraInput`` into a single GEMM: the constant scale is passed
    as the GEMM ``alpha`` and the extra input tensor is used as an
    in-place accumulation buffer for the GEMM output.

    Row tensor parallelism is not supported since that requires
    communication immediately after the GEMM.

    """

    def __init__(
        self,
        *,
        linear: BasicLinear,
        scale: ConstantScale,
        add: AddExtraInput,
    ) -> None:
        # Basic ops are registered in forward-pass order
        super().__init__((linear, scale, add))

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
        # Get basic operations and the linear op's autograd context
        linear_op = self.basic_ops[0]
        linear_op_ctx = basic_op_ctxs[0]
        scale_op = self.basic_ops[1]
        # Check which grads are required
        input_requires_grad = linear_op_ctx.requires_grad
        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
        # FP8 metadata
        input_quantizer = linear_op.get_quantizer("forward", 0)
        weight_quantizer = linear_op.get_quantizer("forward", 1)
        # Output is not quantized here: it is accumulated in-place into
        # the extra input tensor
        output_quantizer = None
        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
        grad_input_quantizer = prev_op_grad_output_quantizer
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        # Get extra input tensor for add operation (third basic op, first
        # extra input)
        extra_input = basic_op_extra_inputs[2][0]
        # Get autocast dtype if needed, otherwise fall back to the
        # weight's dtype
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("cuda")
        else:
            dtype = linear_op.weight.dtype
        # Linear forward with the constant scale as GEMM alpha,
        # accumulating into extra_input in-place
        output, x_local, w = BasicLinear._functional_forward(
            input=input_,
            weight=linear_op.weight,
            alpha=scale_op.scale,
            dtype=dtype,
            out=extra_input,
            accumulate_into_out=True,
            tensor_parallel_mode=linear_op.tensor_parallel_mode,
            tensor_parallel_group=linear_op.tensor_parallel_group,
            sequence_parallel=linear_op.sequence_parallel,
            with_quantized_compute=with_quantized_compute,
            input_quantizer=input_quantizer,
            weight_quantizer=weight_quantizer,
            output_quantizer=output_quantizer,
            input_requires_grad=input_requires_grad,
            weight_requires_grad=weight_requires_grad,
        )
        # Save state for backward pass (only if grads will be computed)
        if linear_op_ctx.requires_grad:
            linear_op_ctx.save_for_backward(x_local, w)
            linear_op_ctx.with_quantized_compute = with_quantized_compute
            linear_op_ctx.input_quantizer = input_quantizer
            linear_op_ctx.weight_quantizer = weight_quantizer
            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
            linear_op_ctx.dtype = dtype
            linear_op_ctx.input_requires_grad = input_requires_grad
            linear_op_ctx.weight_requires_grad = weight_requires_grad
        # No extra outputs for any of the basic ops
        return output, [() for _ in range(len(self.basic_ops))]
def fuse_forward_linear_scale_add(
    ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
    """Fuse forward GEMM + scale + add

    Parameters
    ----------
    ops: list of tuples
        Forward pass operations and the indices of the corresponding
        basic operations.

    Returns
    -------
    ops: list of tuples
        Updated forward pass operations

    """

    # Accumulated output ops and the candidate window currently under
    # consideration
    result: list[tuple[FusibleOperation, list[int]]] = []
    pending: list[tuple[FusibleOperation, list[int]]] = []

    # Slide over the op list, attempting to fuse (linear, scale, add)
    # triples
    while len(ops) >= 3:
        result.extend(pending)

        # Window must start with a linear op without row
        # tensor-parallelism, since that requires communication right
        # after the GEMM
        pending, ops = [ops[0]], ops[1:]
        linear_candidate = pending[0][0]
        if not isinstance(linear_candidate, BasicLinear):
            continue
        if linear_candidate.tensor_parallel_mode == "row":
            continue

        # Next op must be a constant scale
        scale_candidate = ops[0][0]
        if not isinstance(scale_candidate, ConstantScale):
            continue
        pending.append(ops[0])
        ops = ops[1:]

        # Next op must be an in-place add of an extra input
        add_candidate = ops[0][0]
        if not isinstance(add_candidate, AddExtraInput):
            continue
        if not add_candidate._in_place:
            continue
        pending.append(ops[0])
        ops = ops[1:]

        # Collapse the window into a single fused op
        fused = ForwardLinearScaleAdd(
            linear=linear_candidate,
            scale=scale_candidate,
            add=add_candidate,
        )
        idxs = [entry[1][0] for entry in pending]
        pending = [(fused, idxs)]

    # Flush the remaining window and any unexamined ops
    result.extend(pending)
    result.extend(ops)
    return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment