Commit 0d874a4e authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Current scaling recipe reference implementation."""
import dataclasses
import math
from typing import Optional, Tuple, Iterable
import torch
from transformer_engine.pytorch.custom_recipes import quantization
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer
def current_scaling_ref_quantizer_factory(role):
"""Factory function for current scaling reference quantizer.
Usage with CustomRecipe and autocast:
custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_ref_quantizer_factory)
with autocast(recipe=custom_recipe):
output = model(input)
"""
if role in ("linear_input", "linear_weight"):
dtype = torch.float8_e4m3fn
elif role in ("linear_output", "linear_grad_output"):
dtype = torch.float8_e5m2
else:
return None
return CurrentScalingQuantizerRef(
dtype=dtype,
rowwise=True,
columnwise=True,
pow_2_scales=False,
eps=0.0,
)
@dataclasses.dataclass
class CurrentScalingTensorRef(QuantizedTensorStorage):
"""Reference implementation of current scaling quantized tensor"""
data: Optional[torch.Tensor] = None
scale: Optional[torch.Tensor] = None
data_t: Optional[torch.Tensor] = None
scale_t: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None
device: Optional[torch.device] = None
quant_dtype: Optional[torch.dtype] = None
original_shape: Optional[Tuple[int, ...]] = None
_quantizer: Optional[Quantizer] = None
@property
def custom(self) -> bool:
"""Flag to indicate this quantized tensor is custom."""
return True
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the quantization result for saving for backward"""
tensors = [self.data, self.data_t, self.scale, self.scale_t]
self.data = None
self.data_t = None
self.scale = None
self.scale_t = None
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the quantization result from the saved tensors"""
self.data = tensors[0]
self.data_t = tensors[1]
self.scale = tensors[2]
self.scale_t = tensors[3]
return tensors[4:]
# Compatibility
@property
def _data(self):
return self.data
@_data.setter
def _data(self, value):
self.data = value
@property
def _scale_inv(self):
return self.scale
@_scale_inv.setter
def _scale_inv(self, value):
self.scale = value
def __repr__(self):
return (
f"{self.__class__.__name__}("
f"dtype={self.dtype}, "
f"device={self.device}, "
f"quant_dtype={self.quant_dtype}, "
f"original_shape={self.original_shape}"
")"
)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
"""Generate or remove quantized data based on provided usage."""
has_data = self.data is not None
has_data_transpose = self.data_t is not None
needs_data = has_data
needs_data_transpose = has_data_transpose
if rowwise_usage is not None:
needs_data = rowwise_usage
if columnwise_usage is not None:
needs_data_transpose = columnwise_usage
# Generate data that is required
if needs_data and not has_data:
raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
if needs_data_transpose and not has_data_transpose:
if not has_data:
raise RuntimeError("FP8 data is required to generate FP8 data transpose")
self._create_transpose()
# Delete data that is not required
if not needs_data:
self.data = None
if not needs_data_transpose:
self.data_t = None
def _create_transpose(self):
"""Create transposed quantized tensor"""
if not self.data.is_contiguous():
self.data = self.data.contiguous()
self.data_t = self.data.t().contiguous()
self.scale_t = self.scale
def size(self, *args, **kwargs):
"""Get the size of the quantized tensor"""
if self.data is not None:
return self.data.size(*args, **kwargs)
size = self.data_t.size(*args, **kwargs)
return torch.Size([size[-1], math.prod(size[:-1])])
def _scale_from_amax_tensor(
x_dtype: torch.dtype,
amax: torch.Tensor,
quant_dtype: torch.dtype,
*,
eps: float,
pow_2_scales: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Derives quantization and dequantization from amax and options.
Reference implementation for scale calculation.
Returns:
- scale: quantization scales
- scale_inv: dequantization scales
- amax: Amax tensor with updates made for extrema values.
"""
assert amax.dtype == torch.float, "amax must be a float tensor."
fp8_max = torch.finfo(quant_dtype).max
# Clamping amax to avoid division by small numbers
amax = torch.max(amax, torch.tensor(eps))
# Compute scale factor
scale = torch.div(fp8_max, amax)
# Take care of inf before pow_2_scales
scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale)
if pow_2_scales:
_, exp = torch.frexp(scale)
exp = exp - 1
assert (exp > -127).all()
unity = torch.tensor([1.0], device=exp.device)
torch.ldexp(unity, exp, out=scale)
scale = torch.where(amax == float("inf"), 0.0, scale)
# Handle overflow cases for amax zero causing NaN
scale = torch.where(amax == 0, 1.0, scale)
# Compute scale_inv
scale_inv = torch.reciprocal(scale)
return scale, scale_inv, amax
class CurrentScalingQuantizerRef(Quantizer):
"""Reference implementation of current scaling quantizer"""
def __init__(
self,
dtype: torch.dtype,
rowwise: bool = True,
columnwise: bool = True,
pow_2_scales: bool = False,
eps: float = 0.0,
):
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.internal = True
self.dtype = dtype
self.pow_2_scales = pow_2_scales
self.eps = eps
self.with_amax_reduction = False
self.amax_reduction_group = None
@property
def custom(self) -> bool:
"""Flag to indicate this quantizer is custom."""
return True
@property
def supports_allgather_fp8(self) -> bool:
"""Flag to indicate this quantizer supports allgather fp8"""
return True
@classmethod
def compute_scale(
cls,
x: torch.Tensor,
quant_dtype: torch.dtype,
eps=0.0,
pow_2_scales: bool = False,
):
"""Compute the scale from the amax tensor"""
# Use float32 for computation
x_fp32 = x.to(torch.float32)
if x_fp32.numel() == 0:
amax = torch.empty(1, dtype=torch.float32, device=x.device)
else:
amax = torch.amax(torch.abs(x_fp32)).view(1)
return _scale_from_amax_tensor(
x.dtype,
amax=amax,
quant_dtype=quant_dtype,
eps=eps,
pow_2_scales=pow_2_scales,
)
def _quantize(self, tensor: torch.Tensor) -> Tuple[
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
"""
Python implementation of quantization (c++ kernel can be used as an option instead).
Parameters
----------
tensor : torch.Tensor
Input tensor to quantize (should be 2D)
Returns
-------
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]
(qx, sx, qx_t, sx_t) where:
- qx: quantized data in row-major order (if rowwise_usage), None otherwise
- sx: empty scale tensor for qx (if rowwise_usage), None otherwise
- qx_t: quantized data in column-major order (if columnwise_usage), None otherwise
- sx_t: empty scale tensor for qx_t (if columnwise_usage), None otherwise
"""
# Handle amax reduction if enabled
if self.with_amax_reduction:
assert (
self.amax_reduction_group is not None
), "amax_reduction_group must be set when with_amax_reduction is True"
# Compute local amax
if tensor.numel() == 0:
amax = torch.empty(1, dtype=torch.float32, device=tensor.device)
else:
amax = torch.amax(torch.abs(tensor)).view(1).to(torch.float32)
# Reduce amax across all ranks
torch.distributed.all_reduce(
amax, group=self.amax_reduction_group, op=torch.distributed.ReduceOp.MAX
)
# Compute scale using the global amax
scale, scale_inv, _ = _scale_from_amax_tensor(
tensor.dtype,
amax=amax,
quant_dtype=self.dtype,
eps=self.eps,
pow_2_scales=self.pow_2_scales,
)
else:
# compute scale factor using local amax
scale, scale_inv, _ = self.compute_scale(
tensor,
self.dtype,
eps=self.eps,
pow_2_scales=self.pow_2_scales,
)
qx: Optional[torch.Tensor] = (tensor.float() * scale).to(self.dtype)
sx: Optional[torch.Tensor] = scale_inv
# transpose if needed
if self.columnwise_usage:
assert qx is not None
qx_t = qx.t().contiguous()
sx_t = sx
else:
qx_t, sx_t = None, None
if not self.rowwise_usage:
qx = None
sx = None
return qx, sx, qx_t, sx_t
def quantize(
self,
tensor: torch.Tensor,
**kwargs, # pylint: disable=unused-argument
) -> CurrentScalingTensorRef:
# sanity checks
assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype."
# Make it work with 3D tensors
original_shape = tensor.shape
if tensor.ndim > 2:
tensor = tensor.view(-1, tensor.shape[-1])
qx, sx, qx_t, sx_t = self._quantize(tensor)
return CurrentScalingTensorRef(
data=qx,
scale=sx,
data_t=qx_t,
scale_t=sx_t,
dtype=tensor.dtype,
device=tensor.device,
quant_dtype=self.dtype,
_quantizer=self,
original_shape=original_shape,
)
def dequantize(
self, tensor: torch.Tensor, scale: torch.Tensor, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
"""Dequantize the quantized tensor"""
tensor = tensor.to(torch.float32) * scale
if dtype is None:
return tensor
return tensor.to(dtype)
def qgemm(
self,
qx: torch.Tensor,
qw: torch.Tensor,
m_params: quantization.MMParams,
out_dtype: torch.dtype,
sx: torch.Tensor,
sw: torch.Tensor,
bias: torch.Tensor | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP, # pylint: disable=unused-argument
qresult_x: QuantizedTensorStorage | None = None, # pylint: disable=unused-argument
qresult_w: QuantizedTensorStorage | None = None, # pylint: disable=unused-argument
) -> torch.Tensor:
"""Python implementation of quantized gemm."""
M, K = qx.shape
N, _ = qw.shape
if M == 0 or K == 0 or N == 0:
if accumulate:
assert out is not None
y = out
else:
y = torch.zeros((M, N), dtype=out_dtype, device=qx.device)
if bias is not None:
y += bias
return y
# cublas fp8 gemm does not support fp32 bias
use_bias_in_gemm = (
bias is not None and out_dtype != torch.float32 and bias.dtype != torch.float32
)
# Run quantized gemm: y = qw * qx
scaled_mm_res = torch._scaled_mm(
qx,
qw.transpose(-1, -2),
scale_a=sx,
scale_b=sw,
out_dtype=out_dtype,
use_fast_accum=not m_params.use_split_accumulator,
bias=bias if use_bias_in_gemm else None,
)
y = scaled_mm_res[0] if isinstance(scaled_mm_res, tuple) else scaled_mm_res
if bias is not None and not use_bias_in_gemm:
# Check number of elements in bias tensor because it can be an empty tensor
if bias.numel():
y += bias
if accumulate:
assert out is not None, "Output tensor must be provided for accumulation."
out.add_(y)
y = out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y
def transpose_qresult(self, qresult: CurrentScalingTensorRef) -> CurrentScalingTensorRef:
"""Python implementation of transpose qresult."""
qx = qresult.data
scale = qresult.scale
assert qresult.data_t is None
assert qresult.scale_t is None
assert qx is not None
qx_t = qx.transpose(-2, -1).contiguous()
scale_t = scale
qresult.data_t = qx_t
qresult.scale_t = scale_t
return qresult
def update_quantized(
self,
src: torch.Tensor,
dst: QuantizedTensorStorage,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensorStorage:
"""Update the quantized tensor with the given tensor in-place
Parameters
----------
src: torch.Tensor
Source tensor to copy from
dst: ExperimentalQuantizedTensor
Destination ExperimentalQuantizedTensor to update
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
# Handle noop flag
if noop_flag is not None and noop_flag.item() != 0:
return dst
# Make sure input is in expected format
if not src.is_contiguous():
src = src.contiguous()
# Store the original shape and reshape for processing
original_shape = src.shape
if src.ndim > 2:
src = src.view(-1, src.shape[-1])
qx, sx, qx_t, sx_t = self._quantize(src)
# Update the destination with new data
dst.data = qx
dst.scale = sx
dst.data_t = qx_t
dst.scale_t = sx_t
dst.dtype = src.dtype
dst.quant_dtype = self.dtype
dst.original_shape = original_shape
return dst
def make_empty(
self,
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False, # pylint: disable=unused-argument
) -> CurrentScalingTensorRef:
assert len(shape) == 2, "shape is not 2d"
# Canonicalize tensor attributes
if device is None:
device = torch.device("cuda")
# Allocate quantized data
qx = torch.empty(shape, dtype=self.dtype, device=device)
sx = torch.empty(1, dtype=torch.float32, device=device)
# Allocate quantized data transpose if needed
qx_t = None
sx_t = None
if self.columnwise_usage:
inner_dim = qx.size(-1)
qx_t = torch.empty(
inner_dim,
qx.numel() // inner_dim,
dtype=self.dtype,
device=device,
)
sx_t = torch.empty(1, dtype=torch.float32, device=device)
# Construct quantized tensor
return CurrentScalingTensorRef(
data=qx,
scale=sx,
data_t=qx_t,
scale_t=sx_t,
dtype=dtype,
device=device,
quant_dtype=self.dtype,
_quantizer=self,
original_shape=shape,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -18,9 +18,9 @@ def nvfp4_ref_rht_2d_quantizer_factory(role): ...@@ -18,9 +18,9 @@ def nvfp4_ref_rht_2d_quantizer_factory(role):
""" """
Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights). Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights).
Usage with CustomRecipe and fp8_autocast: Usage with CustomRecipe and autocast:
custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory) custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory)
with fp8_autocast(fp8_recipe=custom_recipe): with autocast(fp8_recipe=custom_recipe):
output = model(input) output = model(input)
""" """
if role == "linear_input": if role == "linear_input":
...@@ -338,7 +338,7 @@ def get_wgrad_sign_vector() -> torch.Tensor: ...@@ -338,7 +338,7 @@ def get_wgrad_sign_vector() -> torch.Tensor:
class NVFP4QuantizerRef(Quantizer): class NVFP4QuantizerRef(Quantizer):
"""NVFP4 quantizer for middleware between Transformer Engine and Kitchen""" """Reference implementation of NVFP4 quantizer"""
def __init__( def __init__(
self, self,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -30,7 +30,7 @@ except ImportError: ...@@ -30,7 +30,7 @@ except ImportError:
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
from . import torch_version from .torch_version import torch_version
from .utils import ( from .utils import (
is_non_tn_fp8_gemm_supported, is_non_tn_fp8_gemm_supported,
safely_set_viewless_tensor_data, safely_set_viewless_tensor_data,
...@@ -48,7 +48,7 @@ from .tensor.storage.float8_tensor_storage import Float8TensorStorage ...@@ -48,7 +48,7 @@ from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
__all__ = ["checkpoint", "CudaRNGStatesTracker"] __all__ = ["checkpoint", "CudaRNGStatesTracker"]
...@@ -90,6 +90,11 @@ def graph_safe_rng_available() -> bool: ...@@ -90,6 +90,11 @@ def graph_safe_rng_available() -> bool:
) )
def is_graph_safe_rng_state(state: Union[torch.Tensor, torch.Generator]) -> bool:
"""Returns whether the rng state is a graph safe version."""
return graph_safe_rng_available() and isinstance(state, torch.Generator)
def _get_cuda_rng_state( def _get_cuda_rng_state(
device: Union[int, str, torch.device] = "cuda", device: Union[int, str, torch.device] = "cuda",
clone: bool = False, clone: bool = False,
...@@ -340,9 +345,16 @@ class _CheckpointFunction(torch.autograd.Function): ...@@ -340,9 +345,16 @@ class _CheckpointFunction(torch.autograd.Function):
# Copy the rng states. # Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state() ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
if get_rng_state_tracker is not None: if get_rng_state_tracker is not None:
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
ctx.graph_safe_rng_state = (
is_graph_safe_rng_state(next(iter(ctx.fwd_cuda_rng_state_tracker.values())))
if ctx.fwd_cuda_rng_state_tracker
else False
)
else:
ctx.graph_safe_rng_state = False
ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
if context_fn is not None: if context_fn is not None:
forward_ctx, recompute_ctx = context_fn() forward_ctx, recompute_ctx = context_fn()
...@@ -406,13 +418,13 @@ class _CheckpointFunction(torch.autograd.Function): ...@@ -406,13 +418,13 @@ class _CheckpointFunction(torch.autograd.Function):
# Store the current states. # Store the current states.
bwd_cpu_rng_state = torch.get_rng_state() bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
if get_rng_state_tracker is not None: if get_rng_state_tracker is not None:
bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
# Set the states to what it used to be before the forward pass. # Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state) torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False) _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
if get_rng_state_tracker is not None: if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
...@@ -427,7 +439,7 @@ class _CheckpointFunction(torch.autograd.Function): ...@@ -427,7 +439,7 @@ class _CheckpointFunction(torch.autograd.Function):
# Set the states back to what it was at the start of this function. # Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state) torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False) _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
if get_rng_state_tracker is not None: if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker) get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
...@@ -470,12 +482,21 @@ class _CheckpointFrame: ...@@ -470,12 +482,21 @@ class _CheckpointFrame:
def cache_rng_states(self, forward=True): def cache_rng_states(self, forward=True):
"""Cache fwd/bwd RNG states in the frame to restore later.""" """Cache fwd/bwd RNG states in the frame to restore later."""
rng_states = ( rng_states = (torch.get_rng_state(),)
torch.get_rng_state(),
_get_cuda_rng_state(graph_safe=False),
)
if self.get_rng_state_tracker is not None: if self.get_rng_state_tracker is not None:
rng_states += (self.get_rng_state_tracker().get_states(),) tracker_states = self.get_rng_state_tracker().get_states()
self.graph_safe_rng_state = (
is_graph_safe_rng_state(next(iter(tracker_states.values())))
if tracker_states
else False
)
rng_states += (
_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),
tracker_states,
)
else:
self.graph_safe_rng_state = False
rng_states += (_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),)
if forward: if forward:
self.fwd_rng_states = rng_states self.fwd_rng_states = rng_states
...@@ -490,7 +511,7 @@ class _CheckpointFrame: ...@@ -490,7 +511,7 @@ class _CheckpointFrame:
rng_states = self.bwd_rng_states rng_states = self.bwd_rng_states
torch.set_rng_state(rng_states[0]) torch.set_rng_state(rng_states[0])
_set_cuda_rng_state(rng_states[1], graph_safe=False) _set_cuda_rng_state(rng_states[1], graph_safe=self.graph_safe_rng_state)
if self.get_rng_state_tracker is not None: if self.get_rng_state_tracker is not None:
self.get_rng_state_tracker().set_states(rng_states[2]) self.get_rng_state_tracker().set_states(rng_states[2])
...@@ -642,18 +663,18 @@ def checkpoint( ...@@ -642,18 +663,18 @@ def checkpoint(
Parameters Parameters
---------- ----------
function: Callable function : Callable
pytorch module used to run the forward and backward passes using pytorch module used to run the forward and backward passes using
the specified :attr:`args` and :attr:`kwargs`. the specified :attr:`args` and :attr:`kwargs`.
distribute_saved_activations: bool, default = False distribute_saved_activations : bool, default = False
if set to `True` and `use_reentrant=True`, first tensor argument is distributed if set to ``True`` and ``use_reentrant=True``, first tensor argument is distributed
across the specified tensor parallel group (`tp_group`) before saving it for the across the specified tensor parallel group (``tp_group``) before saving it for the
backward pass. This has no effect when `use_reentrant=False`. backward pass. This has no effect when ``use_reentrant=False``.
get_rng_state_tracker: `Callable`, default = None get_rng_state_tracker : Callable, default = None
python callable which returns an instance of :func:`CudaRNGStatesTracker`. python callable which returns an instance of :class:`CudaRNGStatesTracker`.
tp_group : ProcessGroup, default = None tp_group : ProcessGroup, default = None
tensor parallel process group. Used only when `distribute_saved_activations=True` tensor parallel process group. Used only when ``distribute_saved_activations=True``
and `use_reentrant=True`. If `None`, it falls back to the default group. and ``use_reentrant=True``. If ``None``, it falls back to the default group.
use_reentrant : bool, default = True use_reentrant : bool, default = True
perform checkpointing in reentrant mode. perform checkpointing in reentrant mode.
args : tuple args : tuple
...@@ -778,8 +799,8 @@ class CudaRNGStatesTracker: ...@@ -778,8 +799,8 @@ class CudaRNGStatesTracker:
For model parallelism, multiple RNG states need to simultaneously exist in order For model parallelism, multiple RNG states need to simultaneously exist in order
to execute operations in or out of the model parallel region. This class keeps to execute operations in or out of the model parallel region. This class keeps
track of the various RNG states and provides utility methods to maintain them and track of the various RNG states and provides utility methods to maintain them and
execute parts of the model under a given RNG setting. Using the `add` method, a execute parts of the model under a given RNG setting. Using the :meth:`add` method, a
cuda rng state is initialized based on the input `seed` and is assigned to `name`. cuda rng state is initialized based on the input ``seed`` and is assigned to ``name``.
Later, by forking the rng state, we can perform operations and return to our starting Later, by forking the rng state, we can perform operations and return to our starting
cuda state. cuda state.
""" """
...@@ -812,18 +833,24 @@ class CudaRNGStatesTracker: ...@@ -812,18 +833,24 @@ class CudaRNGStatesTracker:
Set the rng states. For efficiency purposes, we do not Set the rng states. For efficiency purposes, we do not
check the size of seed for compatibility. check the size of seed for compatibility.
states: Dict[str, torch.Tensor] Parameters
----------
states : Dict[str, torch.Tensor]
A mapping from string names to RNG states. A mapping from string names to RNG states.
""" """
self.states_ = states self.states_ = states
# Update global states.
set_all_rng_states(self.states_)
def add(self, name: str, seed: int) -> None: def add(self, name: str, seed: int) -> None:
""" """
Adds a new RNG state. Adds a new RNG state.
name: str Parameters
----------
name : str
string identifier for the RNG state. string identifier for the RNG state.
seed: int seed : int
PyTorch seed for the RNG state. PyTorch seed for the RNG state.
""" """
# Check seed is not already used. # Check seed is not already used.
...@@ -857,7 +884,9 @@ class CudaRNGStatesTracker: ...@@ -857,7 +884,9 @@ class CudaRNGStatesTracker:
Fork the cuda rng state, perform operations, and exit with Fork the cuda rng state, perform operations, and exit with
the original state. the original state.
name: str Parameters
----------
name : str
string identifier for the RNG state. string identifier for the RNG state.
""" """
# Check if we have added the state # Check if we have added the state
...@@ -901,6 +930,34 @@ def reduce_scatter_along_first_dim( ...@@ -901,6 +930,34 @@ def reduce_scatter_along_first_dim(
return output, handle return output, handle
@dataclass
class _AsyncHandle:
"""Handle for asynchronous collectives."""
async_handle: torch.distributed.Work
post_process_function: Optional[Callable] = None
post_process_function_args: Optional[Tuple[Any, ...]] = None
post_process_function_kwargs: Optional[Dict[str, Any]] = None
_synchronized: bool = False
def wait(self) -> None:
"""Synchronize the asynchronous communicaton.
Perform post-processing if needed.
"""
if self._synchronized:
return
self.async_handle.wait()
if self.post_process_function is not None:
args = self.post_process_function_args
args = () if args is None else args
kwargs = self.post_process_function_kwargs
kwargs = {} if kwargs is None else kwargs
self.post_process_function(*args, **kwargs)
self._synchronized = True
def _all_gather_fp8( def _all_gather_fp8(
inp: torch.Tensor, inp: torch.Tensor,
process_group: dist_group_type, process_group: dist_group_type,
...@@ -948,7 +1005,13 @@ def _all_gather_fp8( ...@@ -948,7 +1005,13 @@ def _all_gather_fp8(
if isinstance(inp, Float8Tensor): if isinstance(inp, Float8Tensor):
dtype = inp.dtype dtype = inp.dtype
device = inp.device device = inp.device
# Temporarily ensure rowwise usage for output tensor creation
# since we're gathering rowwise data, not the transpose
init_rowwise_usage = quantizer.rowwise_usage
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=init_columnwise_usage)
out = quantizer.make_empty(out_shape, dtype=dtype, device=device) out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage)
elif isinstance(inp, Float8Tensor): elif isinstance(inp, Float8Tensor):
out = inp.make_like(inp, shape=out_shape) out = inp.make_like(inp, shape=out_shape)
out._data = torch.empty( out._data = torch.empty(
...@@ -985,77 +1048,7 @@ def _all_gather_fp8( ...@@ -985,77 +1048,7 @@ def _all_gather_fp8(
return out, handle return out, handle
def _get_quantizer_format(quantizer: Quantizer) -> Optional[bool]: def _start_all_gather_fp8_blockwise(
"""Get quantizer format."""
if isinstance(quantizer, DebugQuantizer):
quantizer = quantizer.parent_quantizer
if isinstance(quantizer, Float8BlockQuantizer):
return quantizer.all_gather_usage
return None
def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
"""Make quantizer compact"""
_quantizer = quantizer
if isinstance(quantizer, DebugQuantizer):
_quantizer = quantizer.parent_quantizer
if isinstance(_quantizer, Float8BlockQuantizer):
_quantizer.all_gather_usage = compact
def _post_process_fp8_blockwise_gather(
out: Float8BlockwiseQTensorStorage,
quantizer: Float8BlockQuantizer,
handle: Optional[torch.distributed.Work] = None,
) -> Float8BlockwiseQTensorStorage:
"""Post-process FP8 blockwise gather."""
if handle is not None:
handle.wait()
handle = None
if out._is_gemm_ready_format():
return out
needs_columnwise_data_transpose = (
quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported(is_blockwise=True)
)
need_rowwise_scale_transpose = (
quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported(is_blockwise=True)
)
# CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024
# columnwise compact format means doing 128x1 quantization of it
# so quantized tensor is 256x1024, scale inv is 2x1024
# If we were doing GEMM_READY format, then it's equivalent to do 1x128 quantization
# on a transposed 1024x256 tensor, so scale inv is 1024x2, cublas requries 2x1024
# Thereforce, it turns out we don't need to transpose the scale inv, only columnwise data
if needs_columnwise_data_transpose:
out._transpose_columnwise_data()
if need_rowwise_scale_transpose:
out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous()
out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY
return out
@dataclass
class _FP8BlockwiseAllGatherAsyncHandle:
"""Handle for asynchronous FP8 blockwise all-gather."""
tensor: Float8BlockwiseQTensorStorage
quantizer: Float8BlockQuantizer
async_handle: torch.distributed.Work
_synchronized: bool = False
def wait(self) -> None:
"""Wait for the async operation to complete and post-process the tensor."""
if self._synchronized:
return
self.async_handle.wait()
_post_process_fp8_blockwise_gather(self.tensor, self.quantizer)
self._synchronized = True
def _all_gather_fp8_blockwise(
inp: torch.Tensor, inp: torch.Tensor,
process_group: dist_group_type, process_group: dist_group_type,
*, *,
...@@ -1094,44 +1087,25 @@ def _all_gather_fp8_blockwise( ...@@ -1094,44 +1087,25 @@ def _all_gather_fp8_blockwise(
) )
world_size = get_distributed_world_size(process_group) world_size = get_distributed_world_size(process_group)
# Check that quantizer is valid
if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
if not (quantizer.block_scaling_dim == 1 and (quantizer.block_len == 128 or quantizer.block_len == 64)):
raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")
# Output tensor dims # Output tensor dims
if out_shape is None: if out_shape is None:
out_shape = list(inp.size()) out_shape = list(inp.size())
out_shape[0] *= world_size out_shape[0] *= world_size
# Doing BF16 gather for now as baseline because it's simpler # Check that quantizer is valid
if ( if quantizer is None:
not isinstance(inp, Float8BlockwiseQTensorStorage) raise ValueError("Quantizer is missing")
and quantizer is not None if not isinstance(quantizer, Float8BlockQuantizer):
and not quantizer.is_quantizable(inp) raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
):
out = torch.empty( # Fall back to high-precision all-gather if FP8 is not supported
out_shape, if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
dtype=dtype, out = torch.empty(out_shape, dtype=dtype, device=device)
device=device,
memory_format=torch.contiguous_format,
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
orig_all_gather_usage = quantizer.all_gather_usage
quantizer.all_gather_usage = False
out = quantizer(out) out = quantizer(out)
quantizer.all_gather_usage = orig_all_gather_usage
return out, None return out, None
# Implementation of fp8 gather needs to account for: # Quantize input tensor if needed
# * Getting columnwise data as a transpose of how it is stored for GEMMS.
# * Gathering non GEMM swizzled scales.
# Cast input tensor to Float8BlockwiseQTensor with required data
# Set to compact usage in case the quantizer is not correctly configured
orig_all_gather_usage = quantizer.all_gather_usage
quantizer.all_gather_usage = True
if not isinstance(inp, Float8BlockwiseQTensorStorage): if not isinstance(inp, Float8BlockwiseQTensorStorage):
inp = quantizer(inp) inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
...@@ -1146,14 +1120,9 @@ def _all_gather_fp8_blockwise( ...@@ -1146,14 +1120,9 @@ def _all_gather_fp8_blockwise(
# Construct Float8BlockwiseQTensor output tensor # Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device) out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
quantizer.all_gather_usage = orig_all_gather_usage # Temporary buffers for all-gathering transposed buffers
interleaved_rowwise_scale_inv = None
# Begin to do network communication, need to make sure compact format interleaved_columnwise_data = None
if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT:
raise RuntimeError(
"All-gather with FP8 block-wise quantized tensor requires compact data format, "
f"but found data_format={inp._data_format}"
)
# Coalesce NCCL collectives # Coalesce NCCL collectives
with torch.distributed._coalescing_manager( with torch.distributed._coalescing_manager(
...@@ -1162,11 +1131,17 @@ def _all_gather_fp8_blockwise( ...@@ -1162,11 +1131,17 @@ def _all_gather_fp8_blockwise(
async_ops=async_op, async_ops=async_op,
) as coalescing_manager: ) as coalescing_manager:
# Gather Float8BlockwiseQTensor data for row-wise usage # Gather row-wise data
if quantizer.rowwise_usage: if quantizer.rowwise_usage:
# Launch all-gathers scale_inv_shape = list(inp._rowwise_scale_inv.size())
scale_inv_shape[0] *= world_size
interleaved_rowwise_scale_inv = torch.empty(
scale_inv_shape,
dtype=inp._rowwise_scale_inv.dtype,
device=device,
)
torch.distributed.all_gather_into_tensor( torch.distributed.all_gather_into_tensor(
out._rowwise_scale_inv, interleaved_rowwise_scale_inv,
inp._rowwise_scale_inv, inp._rowwise_scale_inv,
group=process_group, group=process_group,
) )
...@@ -1176,36 +1151,73 @@ def _all_gather_fp8_blockwise( ...@@ -1176,36 +1151,73 @@ def _all_gather_fp8_blockwise(
group=process_group, group=process_group,
) )
# Gather Float8BlockwiseQTensor data for column-wise usage # Column-wise data
if quantizer.columnwise_usage: if quantizer.columnwise_usage:
# Launch all-gathers data_shape = list(inp._columnwise_data.size())
data_shape[0] *= world_size
interleaved_columnwise_data = torch.empty(
data_shape,
dtype=inp._columnwise_data.dtype,
device=device,
)
torch.distributed.all_gather_into_tensor( torch.distributed.all_gather_into_tensor(
out._columnwise_scale_inv, out._columnwise_scale_inv,
inp._columnwise_scale_inv, inp._columnwise_scale_inv,
group=process_group, group=process_group,
) )
torch.distributed.all_gather_into_tensor( torch.distributed.all_gather_into_tensor(
out._columnwise_data, interleaved_columnwise_data,
inp._columnwise_data, inp._columnwise_data,
group=process_group, group=process_group,
) )
handle = coalescing_manager if async_op else None # Finalize communication if needed
async_handle = None
# Unlike MXFP8, this fp8 blockwise tensor primarily works with Hopper
# This means that we need to transpose the gathered columnwise data
# Example usage is grad_output tensor, ie. dY in linear backward
# We want to gather two FP8 tensors (rowwise and columnwise) along dim0
# and then transpose the columnwise data to match the rowwise data
# Make sure FP8 transpose is populated if needed
if async_op: if async_op:
handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle) async_handle = _AsyncHandle(
coalescing_manager,
post_process_function=_finish_all_gather_fp8_blockwise,
post_process_function_args=(
out,
world_size,
interleaved_rowwise_scale_inv,
interleaved_columnwise_data,
),
)
else: else:
# if it's a sync op, we need to do the transpose here as post processing step _finish_all_gather_fp8_blockwise(
_post_process_fp8_blockwise_gather(out, quantizer, handle) out,
world_size,
interleaved_rowwise_scale_inv,
interleaved_columnwise_data,
)
return out, handle return out, async_handle
def _finish_all_gather_fp8_blockwise(
out: Float8BlockwiseQTensorStorage,
world_size: int,
interleaved_rowwise_scale_inv: Optional[torch.Tensor],
interleaved_columnwise_data: Optional[torch.Tensor],
) -> Float8BlockwiseQTensorStorage:
"""Post-process FP8 blockwise gather."""
# Fix interleaving in row-wise scales
if interleaved_rowwise_scale_inv is not None:
dim0 = out._rowwise_scale_inv.size(0)
view_in = interleaved_rowwise_scale_inv.view(world_size, dim0, -1)
view_out = out._rowwise_scale_inv.view(dim0, world_size, -1)
tex.swap_first_dims(view_in, out=view_out)
# Fix interleaving in column-wise data
if interleaved_columnwise_data is not None:
dim0 = out._columnwise_data.size(0)
view_in = interleaved_columnwise_data.view(world_size, dim0, -1)
view_out = out._columnwise_data.view(dim0, world_size, -1)
tex.swap_first_dims(view_in, out=view_out)
return out
def _swap_first_dims(tensor: torch.Tensor, world_size: int): def _swap_first_dims(tensor: torch.Tensor, world_size: int):
...@@ -1219,7 +1231,7 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int): ...@@ -1219,7 +1231,7 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int):
""" """
shape = tensor.shape shape = tensor.shape
assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave." assert len(shape) >= 2, "Wrong number of dimensions for fixing interleave."
first_dim = shape[0] first_dim = shape[0]
flattened_trailing = math.prod(shape[1:]) flattened_trailing = math.prod(shape[1:])
assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave." assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave."
...@@ -1650,7 +1662,7 @@ def gather_along_first_dim( ...@@ -1650,7 +1662,7 @@ def gather_along_first_dim(
if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance( if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance(
quantizer, Float8BlockQuantizer quantizer, Float8BlockQuantizer
): ):
return _all_gather_fp8_blockwise( return _start_all_gather_fp8_blockwise(
inp, inp,
process_group, process_group,
async_op=async_op, async_op=async_op,
...@@ -1688,10 +1700,6 @@ def gather_along_first_dim( ...@@ -1688,10 +1700,6 @@ def gather_along_first_dim(
) )
if isinstance(inp, QuantizedTensorStorage): if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize() inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format
compact = _get_quantizer_format(quantizer)
_set_quantizer_format(quantizer, compact=False)
out = torch.empty( out = torch.empty(
out_shape, out_shape,
dtype=inp.dtype, dtype=inp.dtype,
...@@ -1700,7 +1708,6 @@ def gather_along_first_dim( ...@@ -1700,7 +1708,6 @@ def gather_along_first_dim(
) )
torch.distributed.all_gather_into_tensor(out, inp, group=process_group) torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
out = quantizer(out) out = quantizer(out)
_set_quantizer_format(quantizer, compact=compact)
return out, None return out, None
# Dequantize quantized tensor if not supported # Dequantize quantized tensor if not supported
...@@ -2001,7 +2008,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: ...@@ -2001,7 +2008,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
Parameters Parameters
---------- ----------
fsdp_root: torch.nn.Module fsdp_root : torch.nn.Module
FSDP-wrapped root module that may contain FSDP-wrapped TE modules. FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
""" """
assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped." assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]: ...@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
Parameters Parameters
---------- ----------
enabled: bool, default = `False` enabled : bool, default = False
whether or not to enable export whether or not to enable export
""" """
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -7,6 +7,7 @@ from collections.abc import Iterable ...@@ -7,6 +7,7 @@ from collections.abc import Iterable
import contextlib import contextlib
import gc import gc
import warnings import warnings
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch import torch
...@@ -61,6 +62,21 @@ def graph_pool_handle(): ...@@ -61,6 +62,21 @@ def graph_pool_handle():
return _graph_pool_handle() return _graph_pool_handle()
@contextlib.contextmanager
def _none_grad_context_wrapper(inputs):
"""
Wrapper to set the gradients of the inputs to None,
in case the backward pass makes grad accumulations.
"""
original_input_grads = []
for input_tensor in inputs:
original_input_grads.append(input_tensor.grad)
input_tensor.grad = None
yield
for input_tensor, original_grad in zip(inputs, original_input_grads):
input_tensor.grad = original_grad
@contextlib.contextmanager @contextlib.contextmanager
def _graph_context_wrapper(*args, **kwargs): def _graph_context_wrapper(*args, **kwargs):
"""Wrapper around `torch.cuda.graph`. """Wrapper around `torch.cuda.graph`.
...@@ -127,6 +143,8 @@ def _make_graphed_callables( ...@@ -127,6 +143,8 @@ def _make_graphed_callables(
) )
# Check sizes of args # Check sizes of args
_order_without_wgrad = None
delay_wgrad_compute = False
if _order is None: if _order is None:
assert len(sample_args) == len(callables) assert len(sample_args) == len(callables)
assert len(sample_kwargs) == len(callables) assert len(sample_kwargs) == len(callables)
...@@ -145,17 +163,34 @@ def _make_graphed_callables( ...@@ -145,17 +163,34 @@ def _make_graphed_callables(
# values indicate backward passes. Each # values indicate backward passes. Each
# entry in sample_args corresponds to one of the forward # entry in sample_args corresponds to one of the forward
# passes. # passes.
num_model_chunks = max(_order) _order_without_wgrad = []
num_microbatches = len(_order) // num_model_chunks // 2 for c_id in _order:
assert num_model_chunks * num_microbatches * 2 == len(_order) if ceil(c_id) != c_id:
delay_wgrad_compute = True
continue
_order_without_wgrad.append(c_id)
num_model_chunks = max(_order_without_wgrad)
num_microbatches = len(_order_without_wgrad) // num_model_chunks // 2
assert num_model_chunks * num_microbatches * 2 == len(_order_without_wgrad)
# When delay_wgrad_compute is enabled, each layer is treated as a model chunk, which
# allows for fine-grained graph capture order.
if delay_wgrad_compute:
assert (
_num_layers_per_chunk is not None
), "'_num_layers_per_chunk' must be provided when delay_wgrad_compute is True."
for num_layers in _num_layers_per_chunk:
assert (
num_layers == 1
), "Each model chunk must have only one layer when delay_wgrad_compute is True."
# Determine number of layers in each model chunk. # Determine number of layers in each model chunk.
if _num_layers_per_chunk is None: if _num_layers_per_chunk is None:
assert len(sample_args) * 2 >= len(_order) and ( assert len(sample_args) * 2 >= len(_order_without_wgrad) and (
len(sample_args) * 2 % len(_order) == 0 len(sample_args) * 2 % len(_order_without_wgrad) == 0
), ( ), (
f"{len(sample_args)} * 2 >= {len(_order)} and {len(sample_args)} * 2 %" f"{len(sample_args)} * 2 >= {len(_order_without_wgrad)} and {len(sample_args)} * 2"
f" {len(_order)} == 0" f" % {len(_order_without_wgrad)} == 0"
) )
num_layers = len(sample_args) // num_model_chunks // num_microbatches num_layers = len(sample_args) // num_model_chunks // num_microbatches
_num_layers_per_chunk = [num_layers] * num_model_chunks _num_layers_per_chunk = [num_layers] * num_model_chunks
...@@ -175,7 +210,7 @@ def _make_graphed_callables( ...@@ -175,7 +210,7 @@ def _make_graphed_callables(
+ f"entries when order input is provided but got {len(callables)}." + f"entries when order input is provided but got {len(callables)}."
) )
assert len(sample_args) == total_num_layers * num_microbatches, ( assert len(sample_args) == total_num_layers * num_microbatches, (
f"Expected {total_num_layers * num_microbatches}" f"Expected {total_num_layers * num_microbatches} "
+ f"args tuple, but got {len(sample_args)}." + f"args tuple, but got {len(sample_args)}."
) )
...@@ -198,9 +233,10 @@ def _make_graphed_callables( ...@@ -198,9 +233,10 @@ def _make_graphed_callables(
assert ( assert (
is_training is_training
), "`_reuse_graph_input_output_buffers` is only available in training mode." ), "`_reuse_graph_input_output_buffers` is only available in training mode."
assert isinstance( if isinstance(sample_args, tuple):
sample_args, list sample_args = list(sample_args)
), "sample_args must be a list for _reuse_graph_input_output_buffers." if isinstance(sample_kwargs, tuple):
sample_kwargs = list(sample_kwargs)
# Reorganize args and kwargs for input tensor reuse. # Reorganize args and kwargs for input tensor reuse.
# fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples. # fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples.
...@@ -214,7 +250,7 @@ def _make_graphed_callables( ...@@ -214,7 +250,7 @@ def _make_graphed_callables(
consumed_sample_q = {} consumed_sample_q = {}
fwd_idx = [0] * num_model_chunks fwd_idx = [0] * num_model_chunks
for c_id in _order: for c_id in _order:
m_chunk = abs(c_id) - 1 m_chunk = abs(ceil(c_id)) - 1
if c_id > 0: if c_id > 0:
sample_start_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( sample_start_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
...@@ -241,6 +277,8 @@ def _make_graphed_callables( ...@@ -241,6 +277,8 @@ def _make_graphed_callables(
sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx] sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx]
sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx] sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx]
fwd_idx[m_chunk] += 1 fwd_idx[m_chunk] += 1
elif ceil(c_id) != c_id:
continue
else: else:
num_consumed_samples = min( num_consumed_samples = min(
len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk] len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk]
...@@ -411,13 +449,15 @@ def _make_graphed_callables( ...@@ -411,13 +449,15 @@ def _make_graphed_callables(
for hook in hooks: for hook in hooks:
hook.remove() hook.remove()
if is_training: if is_training:
grad_inputs = torch.autograd.grad( inputs = tuple(i for i in static_input_surface if i.requires_grad)
outputs=tuple(o for o in outputs if o.requires_grad), with _none_grad_context_wrapper(inputs):
inputs=tuple(i for i in static_input_surface if i.requires_grad), torch.autograd.backward(
grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad), tuple(o for o in outputs if o.requires_grad),
only_inputs=True, grad_tensors=tuple(
allow_unused=allow_unused_input, torch.empty_like(o) for o in outputs if o.requires_grad
),
) )
grad_inputs = tuple(input.grad for input in inputs)
# Filter module params that get None grad from grad_inputs and remove them # Filter module params that get None grad from grad_inputs and remove them
# from static_input_surface. This is to ensure that the backward hooks # from static_input_surface. This is to ensure that the backward hooks
...@@ -432,6 +472,14 @@ def _make_graphed_callables( ...@@ -432,6 +472,14 @@ def _make_graphed_callables(
module_params_with_grad = [] module_params_with_grad = []
for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx): for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
if ( if (
grad_inputs[grad_inputs_idx] is None
and grad_inputs_idx < num_required_grad_sample_args
):
assert allow_unused_input, (
"The input tensor requires grad, but the grad is None after"
" backward pass."
)
elif (
grad_inputs[grad_inputs_idx] is not None grad_inputs[grad_inputs_idx] is not None
and grad_inputs_idx >= num_required_grad_sample_args and grad_inputs_idx >= num_required_grad_sample_args
): ):
...@@ -477,9 +525,11 @@ def _make_graphed_callables( ...@@ -477,9 +525,11 @@ def _make_graphed_callables(
fwd_idx = [0] * num_model_chunks fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks
static_grad_outputs_dict = {} static_grad_outputs_dict = {}
wgrad_validation_list = [None] * len(_order)
previous_chunk_last_callable_bwd_idx = None previous_chunk_last_callable_bwd_idx = None
for c_id in _order: for i, c_id in enumerate(_order):
if c_id > 0: if c_id > 0:
assert isinstance(c_id, int), "Forward order value must be an integer."
# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
m_chunk = c_id - 1 m_chunk = c_id - 1
for l_no in range(_num_layers_per_chunk[m_chunk]): for l_no in range(_num_layers_per_chunk[m_chunk]):
...@@ -499,12 +549,65 @@ def _make_graphed_callables( ...@@ -499,12 +549,65 @@ def _make_graphed_callables(
fwd_idx[m_chunk] += 1 fwd_idx[m_chunk] += 1
else: else:
# Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
m_chunk = -c_id - 1 m_chunk = -ceil(c_id) - 1
previous_per_callable_bwd_idx = None previous_per_callable_bwd_idx = None
for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))): for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
) )
if ceil(c_id) == c_id and need_bwd_dw_graph[per_callable_bwd_idx]:
# Check if bwd graph has corresponding wgrad graph:
# Number of dgrad backward graphs should be equal to number of
# wgrad backward graphs.
# Note: For MCore, the validation rule is more strict (the next backward
# of dgrad graph must be corresponding wgrad graph).
if wgrad_validation_list[i] is None:
same_bwd_c_id_list = [i]
num_wgrad_c_id = 0
for idx in range(i + 1, len(_order)):
if _order[idx] > 0:
continue
if _order[idx] == c_id:
same_bwd_c_id_list.append(idx)
if _order[idx] + 0.5 == c_id:
num_wgrad_c_id += 1
if len(same_bwd_c_id_list) == num_wgrad_c_id:
for same_c_id_idx in same_bwd_c_id_list:
wgrad_validation_list[same_c_id_idx] = True
break
if len(same_bwd_c_id_list) < num_wgrad_c_id:
# It's impossible to have more wgrad than dgrad.
wgrad_validation_list[i] = False
break
if wgrad_validation_list[i] is None:
wgrad_validation_list[i] = False
assert wgrad_validation_list[i], (
f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
)
elif ceil(c_id) != c_id:
per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
assert is_training, "Only training mode supports backward_dw."
# If no one module needs the backward_dw, the bwd_dw_graph will be empty.
# So skip capturing it. For backward_dw, the order value is c_id - 0.5 to indicate
# the specific order of backward_dw.
assert ceil(c_id) - c_id == 0.5, (
"The order diff of wgrad and dgrad must be 0.5, "
f"get {ceil(c_id) - c_id}."
)
assert need_bwd_dw_graph[
per_callable_bwd_idx
], "No module needs wgrad computation but get float in order"
bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[per_callable_bwd_idx]:
if (
hasattr(module, "need_backward_dw")
and module.need_backward_dw()
):
module.backward_dw()
continue
static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx] static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]
static_outputs = per_callable_static_outputs[per_callable_bwd_idx] static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx] bwd_graph = bwd_graphs[per_callable_bwd_idx]
...@@ -528,26 +631,17 @@ def _make_graphed_callables( ...@@ -528,26 +631,17 @@ def _make_graphed_callables(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs torch.empty_like(o) if o.requires_grad else None for o in static_outputs
) )
if is_training: if is_training:
with _graph_context_wrapper(bwd_graph, pool=mempool): inputs = tuple(i for i in static_input_surface if i.requires_grad)
grad_inputs = torch.autograd.grad( with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
outputs=tuple(o for o in static_outputs if o.requires_grad), bwd_graph, pool=mempool
inputs=tuple(i for i in static_input_surface if i.requires_grad), ):
grad_outputs=tuple(o for o in static_grad_outputs if o is not None), torch.autograd.backward(
only_inputs=True, tuple(o for o in static_outputs if o.requires_grad),
allow_unused=allow_unused_input, grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward, retain_graph=retain_graph_in_backward,
) )
# If no one module needs the backward_dw, the bwd_dw_graph will be empty. grad_inputs = tuple(input.grad for input in inputs)
# So skip capturing it.
if need_bwd_dw_graph[per_callable_bwd_idx]:
bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[per_callable_bwd_idx]:
if (
hasattr(module, "need_backward_dw")
and module.need_backward_dw()
):
module.backward_dw()
# Constructs a tuple suitable for returning from Graphed.backward: # Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs # Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern. # that don't require grad. I couldn't think of a one-liner for this pattern.
...@@ -596,7 +690,7 @@ def _make_graphed_callables( ...@@ -596,7 +690,7 @@ def _make_graphed_callables(
per_callable_static_grad_inputs[idx] per_callable_static_grad_inputs[idx]
) )
previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx
if ceil(c_id) == c_id:
bwd_idx[m_chunk] += 1 bwd_idx[m_chunk] += 1
else: else:
# Capture forward graphs # Capture forward graphs
...@@ -628,15 +722,17 @@ def _make_graphed_callables( ...@@ -628,15 +722,17 @@ def _make_graphed_callables(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs torch.empty_like(o) if o.requires_grad else None for o in static_outputs
) )
if is_training: if is_training:
with _graph_context_wrapper(bwd_graph, pool=mempool): inputs = tuple(i for i in static_input_surface if i.requires_grad)
grad_inputs = torch.autograd.grad( with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
outputs=tuple(o for o in static_outputs if o.requires_grad), bwd_graph, pool=mempool
inputs=tuple(i for i in static_input_surface if i.requires_grad), ):
grad_outputs=tuple(o for o in static_grad_outputs if o is not None), torch.autograd.backward(
only_inputs=True, tuple(o for o in static_outputs if o.requires_grad),
allow_unused=allow_unused_input, grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward, retain_graph=retain_graph_in_backward,
) )
grad_inputs = tuple(input.grad for input in inputs)
if need_bwd_dw_graph[bwd_idx]: if need_bwd_dw_graph[bwd_idx]:
with _graph_context_wrapper(bwd_dw_graph, pool=mempool): with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[bwd_idx]: for module in visited_te_modules[bwd_idx]:
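The new capture path above replaces `torch.autograd.grad` with `torch.autograd.backward` plus a read of each input's `.grad` field (with `_none_grad_context_wrapper`, an internal helper, ensuring those fields start out empty). A minimal standalone sketch of the same pattern outside of graph capture, clearing `.grad` by hand in place of the helper:
```
import torch

x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
inputs = (x, w)

# Clear .grad so the values read afterwards come only from this backward pass
# (a stand-in for the internal _none_grad_context_wrapper helper).
for t in inputs:
    t.grad = None

# Accumulate gradients into .grad rather than returning them directly.
out = (x * w).sum()
torch.autograd.backward((out,), grad_tensors=(torch.ones_like(out),))
grad_inputs = tuple(t.grad for t in inputs)

# torch.autograd.grad computes the same values but returns them directly.
ref = torch.autograd.grad((x * w).sum(), inputs)
assert all(torch.allclose(a, b) for a, b in zip(grad_inputs, ref))
```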
...@@ -950,38 +1046,38 @@ def make_graphed_callables( ...@@ -950,38 +1046,38 @@ def make_graphed_callables(
Positional arguments to callable(s). Positional arguments to callable(s).
num_warmup_iters: int, default = 3 num_warmup_iters: int, default = 3
Number of warmup iterations. Number of warmup iterations.
allow_unused_input: bool, default = `False` allow_unused_input: bool, default = False
Whether to handle case where callable inputs Whether to handle case where callable inputs
and outputs are disconnected in compute graph. and outputs are disconnected in compute graph.
sample_kwargs: (tuple of) dict, optional sample_kwargs: (tuple of) dict, optional
Keyword arguments to callable(s) Keyword arguments to callable(s)
pool: (tuple of) int, default = `None`, optional pool: (tuple of) int, default = None, optional
An instance returned from function `torch.cuda.graph_pool_handle` that hints An instance returned from function `torch.cuda.graph_pool_handle` that hints
this graph may share memory with the indicated pool. this graph may share memory with the indicated pool.
retain_graph_in_backward: bool, default = `False` retain_graph_in_backward: bool, default = False
Whether to set retain_graph=True in backward graph capture. Whether to set retain_graph=True in backward graph capture.
_reuse_graph_input_output_buffers: bool, default = `False` _reuse_graph_input_output_buffers: bool, default = False
Reduce memory usage by reusing input/output data buffers between Reduce memory usage by reusing input/output data buffers between
graphs. Only supported with Mcore interleaved pipeline parallelism, i.e. graphs. Only supported with Mcore interleaved pipeline parallelism, i.e.
when `_order` is provided. All callables in `modules` are assumed to have when `_order` is provided. All callables in `modules` are assumed to have
inputs and outputs with the same dtype and shape. inputs and outputs with the same dtype and shape.
Quantization related parameters Quantization parameters
---------------------- -----------------------
enabled: (tuple of) bool, default = `False` enabled: (tuple of) bool, default = False
whether or not to enable low precision quantization (FP8/FP4). whether or not to enable low precision quantization (FP8/FP4).
If tuple, the length must match the number of modules. If tuple, the length must match the number of modules.
calibrating: bool, default = `False` calibrating: bool, default = False
calibration mode allows collecting statistics such as amax and scale calibration mode allows collecting statistics such as amax and scale
data of quantized tensors even when executing without quantization enabled. data of quantized tensors even when executing without quantization enabled.
This is useful for saving an inference ready checkpoint while training This is useful for saving an inference ready checkpoint while training
using a higher precision. using a higher precision.
recipe: recipe.Recipe, default = `None` recipe: recipe.Recipe, default = None
recipe used for low precision quantization. recipe used for low precision quantization.
amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None` amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = None
distributed group over which amaxes for the quantized tensors distributed group over which amaxes for the quantized tensors
are reduced at the end of each training step. are reduced at the end of each training step.
cache_quantized_params: bool, default = `False` cache_quantized_params: bool, default = False
Whether or not to cache quantized weights across microbatches. If set to `True`, Whether or not to cache quantized weights across microbatches. If set to `True`,
the `is_first_microbatch` boolean argument must be passed into the forward the `is_first_microbatch` boolean argument must be passed into the forward
method for TransformerEngine modules. When storing primary weights in low precision method for TransformerEngine modules. When storing primary weights in low precision
......
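For context, a hedged usage sketch of the API documented above. The quantization keyword names (`enabled`, `recipe`) follow the docstring excerpt; exact argument names vary between Transformer Engine versions, so treat them as assumptions rather than the definitive signature.
```
# Usage sketch for make_graphed_callables; quantization keyword names are assumptions.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

model = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    num_warmup_iters=3,
    enabled=True,                    # enable low-precision quantization (assumed keyword)
    recipe=recipe.DelayedScaling(),  # quantization recipe (assumed keyword)
)

out = graphed_model(sample_input)
out.sum().backward()
```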
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -8,7 +8,7 @@ from functools import wraps ...@@ -8,7 +8,7 @@ from functools import wraps
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
import torch import torch
from . import torch_version from .torch_version import torch_version
from .export import is_in_onnx_export_mode from .export import is_in_onnx_export_mode
from .utils import gpu_autocast_ctx from .utils import gpu_autocast_ctx
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -20,7 +20,6 @@ import torch.nn.functional as F ...@@ -20,7 +20,6 @@ import torch.nn.functional as F
from torch.distributed.tensor import DTensor from torch.distributed.tensor import DTensor
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from ._common import _ParameterInitMeta, noop_cat from ._common import _ParameterInitMeta, noop_cat
from ..quantization import ( from ..quantization import (
...@@ -39,13 +38,19 @@ from ..distributed import ( ...@@ -39,13 +38,19 @@ from ..distributed import (
_fsdp_gather_tensors, _fsdp_gather_tensors,
) )
from ..constants import dist_group_type from ..constants import dist_group_type
from ..cpp_extensions.gemm import _NUM_MAX_UB_STREAMS
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from ..utils import (
is_non_tn_fp8_gemm_supported,
torch_get_autocast_gpu_dtype,
get_nvtx_range_context,
)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ...common.recipe import DelayedScaling, Recipe from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
...@@ -58,13 +63,9 @@ __all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] ...@@ -58,13 +63,9 @@ __all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]
_2X_ACC_FPROP = False _2X_ACC_FPROP = False
_2X_ACC_DGRAD = True _2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True _2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_dummy_wgrads = {} _dummy_wgrads = {}
_multi_stream_cublas_batchgemm_workspace = [] _multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None
_ub_communicators = None _ub_communicators = None
ub_stream_nums = int(os.getenv("NVTE_UB_STREAM_NUMS", "2"))
_NUM_MAX_UB_STREAMS = ub_stream_nums if IS_HIP_EXTENSION else 3
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = [] layers_atomic_ring_exchange = []
...@@ -77,39 +78,6 @@ class UserBufferQuantizationMode(Enum): ...@@ -77,39 +78,6 @@ class UserBufferQuantizationMode(Enum):
NONE = "none" NONE = "none"
FP8 = "fp8" FP8 = "fp8"
def get_cublas_workspace_size_bytes() -> None:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
# Add env for control the padding for blaslt
if IS_HIP_EXTENSION:
return 134_217_728
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
# 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
return 32 * 1024 * 1024 + 1024
return 4_194_304
def get_workspace() -> torch.Tensor:
"""Returns workspace for cublas."""
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
)
return _cublas_workspace
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_workspace
if not _multi_stream_cublas_workspace:
for _ in range(tex.get_num_cublas_streams()):
_multi_stream_cublas_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
)
return _multi_stream_cublas_workspace
def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]: def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas.""" """Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_batchgemm_workspace global _multi_stream_cublas_batchgemm_workspace
...@@ -126,7 +94,6 @@ if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))): ...@@ -126,7 +94,6 @@ if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))):
else: else:
remove_ag_gemm_dgrad = [] remove_ag_gemm_dgrad = []
def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
"""Returns a dummy tensor of given shape.""" """Returns a dummy tensor of given shape."""
assert len(shape) == 2 assert len(shape) == 2
...@@ -154,27 +121,27 @@ def initialize_ub( ...@@ -154,27 +121,27 @@ def initialize_ub(
) -> None: ) -> None:
r""" r"""
Initialize the Userbuffers communicator for overlapping tensor-parallel communications with Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules. GEMM compute in ``te.Linear``, ``te.LayerNormLinear`` and ``te.LayerNormMLP`` modules.
Parameters Parameters
---------- ----------
shape : list shape : list
shape of the communication buffer, typically set to be the same as the global shape of shape of the communication buffer, typically set to be the same as the global shape of
the input tensor to a te.TransformerLayer forward pass, with the sequence and batch the input tensor to a ``te.TransformerLayer`` forward pass, with the sequence and batch
dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)` dimensions collapsed together -- i.e.: ``(sequence_length * batch_size, hidden_size)``
tp_size : int tp_size : int
number of GPUs in the tensor-parallel process group number of GPUs in the tensor-parallel process group
use_fp8 : bool = False use_fp8 : bool = False
allocate the communication buffer for FP8 GEMM inputs/outputs. allocate the communication buffer for FP8 GEMM inputs/outputs.
DEPRECATED: Please use `quantization_modes` instead. DEPRECATED: Please use ``quantization_modes`` instead.
quantization_modes : List[UserBufferQuantizationMode] = None quantization_modes : List[UserBufferQuantizationMode] = None
if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list. if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list.
falls back to the legacy `use_fp8` parameter if `None` is provided. falls back to the legacy ``use_fp8`` parameter if ``None`` is provided.
dtype : torch.dtype = torch.bfloat16 dtype : torch.dtype = torch.bfloat16
non-FP8 data type of the communication buffer when `use_fp8 = False` non-FP8 data type of the communication buffer when ``use_fp8 = False``
ub_cfgs: dict = None ub_cfgs : dict = None
Configuration dictionary with the structure Configuration dictionary with the structure::
```
{ {
<gemm_name> : { <gemm_name> : {
"method": <"ring_exchange" or "pipeline">, "method": <"ring_exchange" or "pipeline">,
...@@ -189,20 +156,20 @@ def initialize_ub( ...@@ -189,20 +156,20 @@ def initialize_ub(
"fp8_buf": bool, "fp8_buf": bool,
} }
} }
```
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", for ``te.TransformerLayer`` GEMM layers in ``["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
"fc2_fprop", "fc2_wgrad"]`. "fc2_fprop", "fc2_wgrad"]``.
a list may be provided to specify different overlap configurations for the different quantization settings in `quantization_modes` a list may be provided to specify different overlap configurations for the different quantization settings in ``quantization_modes``
bootstrap_backend : str = None bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and ``torch.distributed`` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are barrier collectives during Userbuffers initialization. Not all backends are
valid for every cluster configuration and distributed launch method even if valid for every cluster configuration and distributed launch method even if
they are available in PyTorch. When left unset, the initialization prefers they are available in PyTorch. When left unset, the initialization prefers
to use the MPI backend, falling back first on Gloo and then NCCL if MPI is to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this not available. Setting ``NVTE_UB_WITH_MPI=1`` when building TE overrides this
option and always initializes Userbuffers with direct MPI calls in C++, option and always initializes Userbuffers with direct MPI calls in C++,
which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time. which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time.
""" """
if not tex.device_supports_multicast(): if not tex.device_supports_multicast():
assert bool(int(os.getenv("UB_SKIPMC", "1"))), ( assert bool(int(os.getenv("UB_SKIPMC", "1"))), (
...@@ -299,16 +266,6 @@ def initialize_ub( ...@@ -299,16 +266,6 @@ def initialize_ub(
flush=True, flush=True,
) )
# Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS:
# This ensures we don't do `.repeat()` on an already expanded workspace
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
).repeat(_NUM_MAX_UB_STREAMS)
# Default buffer precision: AllGather buffers use fp8 when using fp8 recipe # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
layers_all_gather_overlap = [ layers_all_gather_overlap = [
"qkv_fprop", "qkv_fprop",
...@@ -642,6 +599,8 @@ def fill_userbuffers_buffer_for_all_gather( ...@@ -642,6 +599,8 @@ def fill_userbuffers_buffer_for_all_gather(
"Userbuffers requires MXFP8 tensor dims that are divisible by 128, " "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
f"but got MXFP8 tensor with shape={tuple(local_shape)}" f"but got MXFP8 tensor with shape={tuple(local_shape)}"
) )
if local_tensor._with_gemm_swizzled_scales:
raise ValueError("Userbuffers assumes MXFP8 tensors have unswizzled scales")
local_scale_inv = ( local_scale_inv = (
local_tensor._rowwise_scale_inv local_tensor._rowwise_scale_inv
if with_rowwise_data if with_rowwise_data
...@@ -674,6 +633,7 @@ def fill_userbuffers_buffer_for_all_gather( ...@@ -674,6 +633,7 @@ def fill_userbuffers_buffer_for_all_gather(
columnwise_scale_inv=columnwise_scale_inv, columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=local_tensor._fp8_dtype, fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer, quantizer=quantizer,
with_gemm_swizzled_scales=False,
) )
return global_tensor, local_tensor return global_tensor, local_tensor
...@@ -1033,7 +993,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1033,7 +993,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
Parameters Parameters
---------- ----------
tp_group : ProcessGroup, default = `None` tp_group : ProcessGroup, default = None
tensor parallel process group. tensor parallel process group.
""" """
self.tp_group = tp_group self.tp_group = tp_group
...@@ -1123,8 +1083,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1123,8 +1083,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
""" """
self.allow_different_data_and_param_types = allow_different_data_and_param_types self.allow_different_data_and_param_types = allow_different_data_and_param_types
self.forwarded_at_least_once = True self.forwarded_at_least_once = True
# Activation recomputation is used and this is the second forward phase. # Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
delayed_scaling_recipe = self.fp8_meta["recipe"].delayed()
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
else: else:
assert inp.is_cuda, "TransformerEngine needs CUDA." assert inp.is_cuda, "TransformerEngine needs CUDA."
...@@ -1136,25 +1098,27 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1136,25 +1098,27 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.init_fp8_metadata(num_gemms=num_gemms) self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence() self._check_weight_tensor_recipe_correspondence()
if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe:
if self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, ( assert self.fp8_meta["recipe"].reduce_amax, (
"Amax reduction across tensor parallel group is " "Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8." "necessary when using sequence parallelism with FP8."
) )
if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): if not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)
# Activation recomputation is used and this is the first forward phase. # Activation recomputation is used and this is the first forward phase.
if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): if self.training and is_fp8_activation_recompute_enabled():
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): with get_nvtx_range_context(self.__class__.__name__ + " forward"):
if not allow_non_contiguous and not inp.is_contiguous(): if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous() inp = inp.contiguous()
yield inp yield inp
if self.fp8 and in_fp8_activation_recompute_phase(): if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
def set_nccl_overlap_warning_if_tp(self) -> None: def set_nccl_overlap_warning_if_tp(self) -> None:
...@@ -1243,18 +1207,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1243,18 +1207,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
if ctx.debug: if ctx.debug:
grad_output_ = quantizer(grad_output) grad_output_ = quantizer(grad_output)
if ( if ctx.use_bias:
isinstance(
grad_output_.get_tensor(True),
(
QuantizedTensor,
Float8TensorStorage,
MXFP8TensorStorage,
Float8BlockwiseQTensorStorage,
),
)
and ctx.use_bias
):
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else: else:
grad_bias = None grad_bias = None
...@@ -1434,7 +1387,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1434,7 +1387,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
workspace is being constructed or updated. workspace is being constructed or updated.
cache_name: str, optional cache_name: str, optional
Key for caching. Key for caching.
update_workspace: bool, default = `True` update_workspace: bool, default = True
Update workspace with values from `tensor`. Update workspace with values from `tensor`.
skip_update_flag: torch.Tensor, optional skip_update_flag: torch.Tensor, optional
GPU flag to skip updating the workspace. Takes precedence GPU flag to skip updating the workspace. Takes precedence
...@@ -1478,6 +1431,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1478,6 +1431,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
reset_cache = True reset_cache = True
elif quantizer.columnwise_usage and out._columnwise_data is None: elif quantizer.columnwise_usage and out._columnwise_data is None:
reset_cache = True reset_cache = True
elif isinstance(out, NVFP4TensorStorage):
if quantizer.rowwise_usage and out._rowwise_data is None:
reset_cache = True
elif quantizer.columnwise_usage and out._columnwise_data is None:
reset_cache = True
if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
reset_cache = True reset_cache = True
if reset_cache: if reset_cache:
...@@ -1576,7 +1534,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1576,7 +1534,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
""" """
if not self.need_backward_dw(): if not self.need_backward_dw():
return return
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): with get_nvtx_range_context(f"_{self.__class__.__name__}_wgrad"):
(wgrad, bgrad), _ = self.wgrad_store.pop() (wgrad, bgrad), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation: if not self.fuse_wgrad_accumulation:
weight_tensor = noop_cat(self._get_weight_tensors()) weight_tensor = noop_cat(self._get_weight_tensors())
...@@ -1618,6 +1576,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1618,6 +1576,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# we use the debug value from the first invocation in the iteration. # we use the debug value from the first invocation in the iteration.
debug = self.debug_enabled_in_this_iteration debug = self.debug_enabled_in_this_iteration
self.debug_last_iteration = TEDebugState.get_iteration()
if self.wgrad_store is not None:
if debug and self.wgrad_store.delay_wgrad_compute():
raise RuntimeError("Delayed wgrad compute is not supported in debug mode.")
return debug return debug
def no_debug_features_active(self, quantizers): def no_debug_features_active(self, quantizers):
...@@ -1673,6 +1637,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1673,6 +1637,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
""" """
if not self.fp8 and not self.fp8_calibration: if not self.fp8 and not self.fp8_calibration:
return return
if not self.primary_weights_in_fp8:
return
if not hasattr(self, "weight_names") or not self.weight_names: if not hasattr(self, "weight_names") or not self.weight_names:
return return
......
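Several call sites in this file switch from `torch.cuda.nvtx.range` to a `get_nvtx_range_context` helper imported from `..utils`. The helper's implementation is not part of this diff; one plausible shape for it, sketched here purely as an assumption, is an NVTX range that can degrade to a no-op:
```
# Hypothetical sketch of an NVTX-range helper, NOT the actual TE implementation.
import contextlib
import os
import torch


def get_nvtx_range_context(msg: str):
    """Return an NVTX range context manager, or a no-op if ranges are disabled."""
    # NVTE_DISABLE_NVTX_RANGES is a made-up switch used only for this illustration.
    if bool(int(os.getenv("NVTE_DISABLE_NVTX_RANGES", "0"))):
        return contextlib.nullcontext()
    return torch.cuda.nvtx.range(msg)
```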
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ..quantization import FP8GlobalStateManager from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
...@@ -24,11 +24,14 @@ class _Fp8Padding(torch.autograd.Function): ...@@ -24,11 +24,14 @@ class _Fp8Padding(torch.autograd.Function):
def forward( def forward(
ctx, ctx,
inp: torch.Tensor, inp: torch.Tensor,
m_splits: List[int], non_tensor_args: Tuple,
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = inp.shape[-1] in_features = inp.shape[-1]
...@@ -65,7 +68,7 @@ class _Fp8Padding(torch.autograd.Function): ...@@ -65,7 +68,7 @@ class _Fp8Padding(torch.autograd.Function):
grad_output.view(-1, in_features), grad_input, ctx.padded_m_splits, ctx.m_splits grad_output.view(-1, in_features), grad_input, ctx.padded_m_splits, ctx.m_splits
) )
return (grad_input, None, None, None) return grad_input, None
class Fp8Padding(torch.nn.Module): class Fp8Padding(torch.nn.Module):
...@@ -111,14 +114,8 @@ class Fp8Padding(torch.nn.Module): ...@@ -111,14 +114,8 @@ class Fp8Padding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None: if self.align_size is None:
self.align_size = ( recipe = FP8GlobalStateManager.get_fp8_recipe()
32 self.align_size = get_align_size_for_quantization(recipe)
if (
FP8GlobalStateManager.get_fp8_recipe().mxfp8()
or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
)
else 16
)
# FP8 padding calculate # FP8 padding calculate
padded_m_splits = [ padded_m_splits = [
...@@ -128,19 +125,20 @@ class Fp8Padding(torch.nn.Module): ...@@ -128,19 +125,20 @@ class Fp8Padding(torch.nn.Module):
if m_splits == padded_m_splits: if m_splits == padded_m_splits:
return inp, m_splits return inp, m_splits
if torch.is_grad_enabled(): is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
fn = _Fp8Padding.apply fn = _Fp8Padding.apply
args = [] autograd_ctx = []
else: else:
fn = _Fp8Padding.forward fn = _Fp8Padding.forward
args = [None] autograd_ctx = [None]
args += ( non_tensor_args = (
inp,
m_splits, m_splits,
padded_m_splits, padded_m_splits,
torch.is_grad_enabled(), is_grad_enabled,
) )
out = fn(*args) out = fn(*autograd_ctx, inp, non_tensor_args)
return out, padded_m_splits return out, padded_m_splits
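The `_Fp8Padding` change above (and the matching changes to `_Fp8Unpadding` and `_GroupedLinear` below) packs all non-tensor options into a single `non_tensor_args` tuple so that `torch.autograd.Function` only has to inspect two forward arguments. A small self-contained illustration of the pattern, independent of TE:
```
import torch


class _ScaleAndShift(torch.autograd.Function):
    """Toy autograd function using the single non_tensor_args tuple pattern."""

    @staticmethod
    def forward(ctx, inp: torch.Tensor, non_tensor_args: tuple) -> torch.Tensor:
        # Unpack options from one tuple instead of many positional arguments.
        scale, shift = non_tensor_args
        ctx.scale = scale
        return inp * scale + shift

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient slot per forward argument: the tensor and the tuple.
        return grad_output * ctx.scale, None


x = torch.randn(3, requires_grad=True)
y = _ScaleAndShift.apply(x, (2.0, 1.0))
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
```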
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
"""FP8 Padding API""" """FP8 Padding API"""
from typing import List, Optional from typing import List, Optional, Tuple
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ..quantization import FP8GlobalStateManager from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
...@@ -24,11 +24,14 @@ class _Fp8Unpadding(torch.autograd.Function): ...@@ -24,11 +24,14 @@ class _Fp8Unpadding(torch.autograd.Function):
def forward( def forward(
ctx, ctx,
inp: torch.Tensor, inp: torch.Tensor,
m_splits: List[int], non_tensor_args: Tuple,
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args
in_features = inp.shape[-1] in_features = inp.shape[-1]
# Allocate cast and transpose output tensor # Allocate cast and transpose output tensor
...@@ -63,7 +66,7 @@ class _Fp8Unpadding(torch.autograd.Function): ...@@ -63,7 +66,7 @@ class _Fp8Unpadding(torch.autograd.Function):
grad_output.view(-1, in_features), grad_input, ctx.m_splits, ctx.padded_m_splits grad_output.view(-1, in_features), grad_input, ctx.m_splits, ctx.padded_m_splits
) )
return (grad_input, None, None, None) return grad_input, None
class Fp8Unpadding(torch.nn.Module): class Fp8Unpadding(torch.nn.Module):
...@@ -109,14 +112,8 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -109,14 +112,8 @@ class Fp8Unpadding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None: if self.align_size is None:
self.align_size = ( recipe = FP8GlobalStateManager.get_fp8_recipe()
32 self.align_size = get_align_size_for_quantization(recipe)
if (
FP8GlobalStateManager.get_fp8_recipe().mxfp8()
or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
)
else 16
)
# FP8 padding calculate # FP8 padding calculate
padded_m_splits = [ padded_m_splits = [
...@@ -126,19 +123,20 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -126,19 +123,20 @@ class Fp8Unpadding(torch.nn.Module):
if m_splits == padded_m_splits: if m_splits == padded_m_splits:
return inp return inp
if torch.is_grad_enabled(): is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
fn = _Fp8Unpadding.apply fn = _Fp8Unpadding.apply
args = [] autograd_ctx = []
else: else:
fn = _Fp8Unpadding.forward fn = _Fp8Unpadding.forward
args = [None] autograd_ctx = [None]
args += ( non_tensor_args = (
inp,
m_splits, m_splits,
padded_m_splits, padded_m_splits,
torch.is_grad_enabled(), is_grad_enabled,
) )
out = fn(*args) out = fn(*autograd_ctx, inp, non_tensor_args)
return out return out
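Both padding modules now take their alignment from `get_align_size_for_quantization(recipe)` rather than hard-coding 32 (MXFP8/NVFP4) vs 16. The per-split rounding performed by the `padded_m_splits` computation, whose body is elided in this diff, amounts to the following (an illustration, not the exact TE code):
```
# Sketch of rounding per-GEMM token counts up to the quantization alignment.
# The 16 vs 32 align_size values follow the old hard-coded logic shown above.
def pad_m_splits(m_splits: list[int], align_size: int) -> list[int]:
    """Round each split up to the nearest multiple of align_size."""
    return [(m + align_size - 1) // align_size * align_size for m in m_splits]


m_splits = [100, 255, 17]
print(pad_m_splits(m_splits, 16))  # [112, 256, 32]
print(pad_m_splits(m_splits, 32))  # [128, 256, 32]
```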
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
"""GroupedLinear API""" """GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List from typing import Union, Optional, Callable, Tuple, List
from itertools import chain
import warnings import warnings
import os import os
import functools import functools
import torch import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from .base import ( from .base import (
get_dummy_wgrad,
get_multi_stream_cublas_workspace,
get_dummy_wgrad, get_dummy_wgrad,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
...@@ -30,6 +30,7 @@ from ..utils import ( ...@@ -30,6 +30,7 @@ from ..utils import (
clear_tensor_data, clear_tensor_data,
init_method_constant, init_method_constant,
requires_grad, requires_grad,
get_nvtx_range_context,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -42,7 +43,6 @@ from ..cpp_extensions import ( ...@@ -42,7 +43,6 @@ from ..cpp_extensions import (
) )
from ..constants import GemmParallelModes, dist_group_type from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
...@@ -52,7 +52,8 @@ from ..quantized_tensor import ( ...@@ -52,7 +52,8 @@ from ..quantized_tensor import (
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
from torch.utils.cpp_extension import IS_HIP_EXTENSION from ...debug.pytorch.debug_quantization import DebugQuantizer
from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["GroupedLinear"] __all__ = ["GroupedLinear"]
...@@ -62,32 +63,42 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -62,32 +63,42 @@ class _GroupedLinear(torch.autograd.Function):
Calls custom cuda extensions. Calls custom cuda extensions.
""" """
# pylint: disable=keyword-arg-before-vararg
@staticmethod @staticmethod
def forward( def forward(
ctx, ctx,
inp: torch.Tensor, inp: torch.Tensor,
m_splits: List[int], non_tensor_args: Tuple,
use_bias: bool,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizers: List[Quantizer],
weight_quantizers: List[Quantizer],
output_quantizers: List[Quantizer],
grad_output_quantizers: List[Quantizer],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
sequence_parallel: bool,
activation_dtype: torch.dtype,
is_grad_enabled: bool,
module,
skip_fp8_weight_update,
save_original_input,
*weights_and_biases, *weights_and_biases,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
m_splits,
use_bias,
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
fuse_wgrad_accumulation,
cpu_offloading,
sequence_parallel,
activation_dtype,
is_grad_enabled,
module,
skip_fp8_weight_update,
save_original_input,
debug,
) = non_tensor_args
num_gemms = len(m_splits) num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms] weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:] biases = weights_and_biases[num_gemms:]
...@@ -133,8 +144,17 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -133,8 +144,17 @@ class _GroupedLinear(torch.autograd.Function):
) )
inp_view = inp.reshape(-1, in_features) inp_view = inp.reshape(-1, in_features)
inputmats: list inputmats: list
if fp8: if fp8 and not debug:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers) # Disable bulk allocation when CPU offloading is active: offloading skips small
# tensors (like scales), but bulk allocation shares storage across all tensors,
# so if scales can't be offloaded, nothing in the group can be offloaded.
inputmats = tex.split_quantize(
inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, input_quantizers, m_splits, activation_dtype
)
else: else:
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
...@@ -143,7 +163,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -143,7 +163,7 @@ class _GroupedLinear(torch.autograd.Function):
# Initialize weights # Initialize weights
weights_fp8: list weights_fp8: list
if fp8: if fp8 or debug:
# FP8 cast to workspace buffer # FP8 cast to workspace buffer
weights_fp8 = [] weights_fp8 = []
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
...@@ -154,6 +174,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -154,6 +174,7 @@ class _GroupedLinear(torch.autograd.Function):
cache_name=(None if is_first_microbatch is None else f"weight{i}"), cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
workspace_dtype=activation_dtype,
) )
weights_fp8.append(weight_fp8) weights_fp8.append(weight_fp8)
...@@ -165,7 +186,6 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -165,7 +186,6 @@ class _GroupedLinear(torch.autograd.Function):
if fp8 and activation_dtype == torch.float32: if fp8 and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16 # FP8 GEMM only supports BF16/FP16 bias bias_dtype = torch.bfloat16 # FP8 GEMM only supports BF16/FP16 bias
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
# Initialize output tensor # Initialize output tensor
out = torch.empty( out = torch.empty(
[sum(m_splits), weights_fp8[0].size(0)], [sum(m_splits), weights_fp8[0].size(0)],
...@@ -181,12 +201,12 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -181,12 +201,12 @@ class _GroupedLinear(torch.autograd.Function):
use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
# Perform GEMM # Perform GEMM
_ = general_grouped_gemm( general_grouped_gemm(
weights_fp8, weights_fp8,
inputmats, inputmats,
[out], [out],
output_quantizers,
activation_dtype, activation_dtype,
get_multi_stream_cublas_workspace(),
single_output=True, single_output=True,
m_splits=m_splits, m_splits=m_splits,
bias=biases, bias=biases,
...@@ -243,6 +263,10 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -243,6 +263,10 @@ class _GroupedLinear(torch.autograd.Function):
ctx.save_for_backward(*tensors_to_save) ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects ctx.tensor_objects = tensor_objects
ctx.grad_input_quantizers = grad_input_quantizers
ctx.grad_output_quantizers = grad_output_quantizers
ctx.grad_weight_quantizers = grad_weight_quantizers
ctx.weights_requires_grad = weights[0].requires_grad ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad: if fuse_wgrad_accumulation and ctx.weights_requires_grad:
# This check is needed to ensure that main_grad is not created # This check is needed to ensure that main_grad is not created
...@@ -258,7 +282,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -258,7 +282,7 @@ class _GroupedLinear(torch.autograd.Function):
else: else:
ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)] ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)]
ctx.device = device ctx.device = device
ctx.grad_output_quantizers = grad_output_quantizers ctx.output_quantizers = output_quantizers
ctx.m_splits = m_splits ctx.m_splits = m_splits
ctx.num_gemms = num_gemms ctx.num_gemms = num_gemms
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
...@@ -278,6 +302,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -278,6 +302,7 @@ class _GroupedLinear(torch.autograd.Function):
or FP8GlobalStateManager.is_first_fp8_module() or FP8GlobalStateManager.is_first_fp8_module()
) )
ctx.wgrad_store = wgrad_store ctx.wgrad_store = wgrad_store
ctx.debug = debug
ctx.save_original_input = save_original_input ctx.save_original_input = save_original_input
ctx.input_quantizers = input_quantizers ctx.input_quantizers = input_quantizers
...@@ -287,7 +312,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -287,7 +312,7 @@ class _GroupedLinear(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_GroupedLinear_backward"): with get_nvtx_range_context("_GroupedLinear_backward"):
saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
N = ctx.num_gemms N = ctx.num_gemms
inputmats = saved_tensors[:N] inputmats = saved_tensors[:N]
...@@ -310,7 +335,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -310,7 +335,7 @@ class _GroupedLinear(torch.autograd.Function):
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
grad_output = [None] * ctx.num_gemms grad_output = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms
if ctx.fp8: if ctx.fp8 and not ctx.debug:
if ctx.use_bias: if ctx.use_bias:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits) grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
recipe = ctx.fp8_recipe recipe = ctx.fp8_recipe
...@@ -337,6 +362,13 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -337,6 +362,13 @@ class _GroupedLinear(torch.autograd.Function):
ctx.m_splits, ctx.m_splits,
ctx.grad_output_quantizers, ctx.grad_output_quantizers,
) )
elif ctx.debug:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output = DebugQuantizer.multi_tensor_quantize(
grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype
)
else: else:
# Only split grad output. Grad bias is fused with # Only split grad output. Grad bias is fused with
# wgrad GEMM. # wgrad GEMM.
...@@ -354,7 +386,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -354,7 +386,7 @@ class _GroupedLinear(torch.autograd.Function):
if ctx.requires_dgrad: if ctx.requires_dgrad:
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
if ctx.fp8: if ctx.fp8 or ctx.debug:
recipe = ctx.fp8_recipe recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"): if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_gemm_use_split_accumulator = ( dgrad_gemm_use_split_accumulator = (
...@@ -374,8 +406,8 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -374,8 +406,8 @@ class _GroupedLinear(torch.autograd.Function):
weights, weights,
grad_output, grad_output,
[dgrad], [dgrad],
ctx.grad_input_quantizers,
ctx.activation_dtype, ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
single_output=True, single_output=True,
layout="NN", layout="NN",
m_splits=ctx.m_splits, m_splits=ctx.m_splits,
...@@ -412,17 +444,20 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -412,17 +444,20 @@ class _GroupedLinear(torch.autograd.Function):
else: else:
input_quantizer.set_usage(rowwise=False, columnwise=True) input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmats: list inputmats: list
if ctx.fp8: if ctx.fp8 and not ctx.debug:
inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
elif ctx.debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype
)
else: else:
inputmats = torch.split( inputmats = torch.split(
cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits
) )
grouped_gemm_wgrad = functools.partial( grouped_gemm_wgrad = functools.partial(
general_grouped_gemm, general_grouped_gemm,
quantization_params=ctx.grad_weight_quantizers,
out_dtype=ctx.activation_dtype, out_dtype=ctx.activation_dtype,
workspaces=get_multi_stream_cublas_workspace(),
layout="NT", layout="NT",
grad=True, grad=True,
m_splits=ctx.m_splits, m_splits=ctx.m_splits,
...@@ -494,28 +529,11 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -494,28 +529,11 @@ class _GroupedLinear(torch.autograd.Function):
): ):
grad_biases = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): if ctx.reduce_and_update_bwd_fp8_tensors:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return ( return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
None, None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
*wgrad_list, *wgrad_list,
*grad_biases, *grad_biases,
) )
...@@ -533,14 +551,14 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -533,14 +551,14 @@ class GroupedLinear(TransformerEngineBaseModule):
size of each input sample. size of each input sample.
out_features : int out_features : int
size of each output sample. size of each output sample.
bias : bool, default = `True` bias : bool, default = True
if set to `False`, the layer will not learn an additive bias. if set to ``False``, the layer will not learn an additive bias.
init_method : Callable, default = `None` init_method : Callable, default = None
used for initializing weights in the following way: `init_method(weight)`. used for initializing weights in the following way: ``init_method(weight)``.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
get_rng_state_tracker : Callable, default = `None` get_rng_state_tracker : Callable, default = None
used to get the random number generator state tracker for initializing weights. used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None` rng_tracker_name : str, default = None
the param passed to get_rng_state_tracker to get the specific rng tracker. the param passed to get_rng_state_tracker to get the specific rng tracker.
device : Union[torch.device, str], default = "cuda" device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
...@@ -549,33 +567,35 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -549,33 +567,35 @@ class GroupedLinear(TransformerEngineBaseModule):
Optimization parameters Optimization parameters
----------------------- -----------------------
fuse_wgrad_accumulation : bool, default = 'False' fuse_wgrad_accumulation : bool, default = False
if set to `True`, enables fusing of creation and accumulation of if set to ``True``, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional ``main_grad`` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular ``grad``) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating. will overwrite ``main_grad`` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = False
when set to `True`, this module will not apply the additive bias itself, but when set to ``True``, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations. the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.get_default_dtype()` params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False` delay_wgrad_compute : bool, default = False
Whether to delay weight gradient computation Whether to delay weight gradient computation
save_original_input : bool, default = `False` save_original_input : bool, default = False
If set to `True`, always saves the original input tensor rather than the If set to ``True``, always saves the original input tensor rather than the
cast tensor. In some scenarios, the input tensor is used by multiple modules, cast tensor. In some scenarios, the input tensor is used by multiple modules,
and saving the original input tensor may reduce the memory usage. and saving the original input tensor may reduce the memory usage.
Cannot work with FP8 DelayedScaling recipe. Cannot work with FP8 DelayedScaling recipe.
Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and Notes
`parallel_mode` are used to determine the shapes of weights and biases. -----
GroupedLinear doesn't really handle the TP communications inside. The ``tp_size`` and
``parallel_mode`` are used to determine the shapes of weights and biases.
The TP communication should be handled in the dispatch and combine stages of MoE models. The TP communication should be handled in the dispatch and combine stages of MoE models.
""" """
...@@ -601,6 +621,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -601,6 +621,7 @@ class GroupedLinear(TransformerEngineBaseModule):
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
save_original_input: bool = False, save_original_input: bool = False,
name: Optional[str] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -621,6 +642,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -621,6 +642,7 @@ class GroupedLinear(TransformerEngineBaseModule):
), "GroupedLinear doesn't support Userbuffer overlap." ), "GroupedLinear doesn't support Userbuffer overlap."
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self.name = name
self.wgrad_store = WeightGradStore(delay_wgrad_compute) self.wgrad_store = WeightGradStore(delay_wgrad_compute)
...@@ -694,7 +716,8 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -694,7 +716,8 @@ class GroupedLinear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms) self.init_fp8_metadata(num_gemms=self.num_gemms)
self.reset_parameters(defer_init=device == "meta") is_meta = torch.device(device).type == "meta"
self.reset_parameters(defer_init=is_meta)
if self.wgrad_store.delay_wgrad_compute(): if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters(): for name, param in self.named_parameters():
...@@ -706,13 +729,9 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -706,13 +729,9 @@ class GroupedLinear(TransformerEngineBaseModule):
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe) super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs # Recipe-specific quantizer configuration
recipe = FP8GlobalStateManager.get_fp8_recipe() recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling(): if recipe.float8_current_scaling():
assert not self.tp_size > 1, (
"GroupedLinear doesn't support TP > 1 with Float8 current scaling. "
"Because the TP communication is handled outside of this module."
)
self._customize_quantizers_float8_current_scaling(fwd, recipe) self._customize_quantizers_float8_current_scaling(fwd, recipe)
def reset_parameters(self, defer_init=False): def reset_parameters(self, defer_init=False):
...@@ -770,58 +789,46 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -770,58 +789,46 @@ class GroupedLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
debug = self.is_debug_iter()
assert not isinstance( assert not isinstance(
inp, QuantizedTensorStorage inp, QuantizedTensorStorage
), "GroupedLinear doesn't support input tensor in FP8." ), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if FP8GlobalStateManager.fp8_graph_capturing(): is_grad_enabled = torch.is_grad_enabled()
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with torch.cuda.device( with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = self._get_weight_tensors() weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
weight_quantizers = self._get_weight_quantizers() quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
input_quantizers, output_quantizers = (
[None] * self.num_gemms,
[None] * self.num_gemms,
)
grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
if self.fp8:
input_quantizers = [
self.quantizers["scaling_fwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
]
for i in range(self.num_gemms)
]
# TODO: use internal after #1638 is merged. # pylint: disable=fixme
for i in range(self.num_gemms):
input_quantizers[i].internal = False
if torch.is_grad_enabled():
grad_output_quantizers = [
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
grad_output_quantizers[i].internal = True
if torch.is_grad_enabled(): if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()
if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")
(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers
if is_grad_enabled:
linear_fn = _GroupedLinear.apply linear_fn = _GroupedLinear.apply
args = [] autograd_ctx = []
else: else:
linear_fn = _GroupedLinear.forward linear_fn = _GroupedLinear.forward
args = [None] autograd_ctx = [None]
args += (
inp, non_tensor_args = (
m_splits, m_splits,
self.apply_bias, self.apply_bias,
is_first_microbatch, is_first_microbatch,
...@@ -831,19 +838,20 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -831,19 +838,20 @@ class GroupedLinear(TransformerEngineBaseModule):
input_quantizers, input_quantizers,
weight_quantizers, weight_quantizers,
output_quantizers, output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers, grad_output_quantizers,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(), is_cpu_offload_enabled(),
self.sequence_parallel, self.sequence_parallel,
self.activation_dtype, self.activation_dtype,
torch.is_grad_enabled(), is_grad_enabled,
self, self,
skip_fp8_weight_update, None, # skip_fp8_weight_update
self.save_original_input, self.save_original_input,
*weight_tensors, debug,
*bias_tensors,
) )
out = linear_fn(*args) out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
if self.return_bias: if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
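# A hedged usage sketch of the GroupedLinear forward contract above: the input
# rows are split into num_gemms groups along the first dimension, so m_splits
# needs one entry per GEMM and the entries should sum to the number of input
# rows. The import path and sizes below are assumptions for illustration only.
import torch
import transformer_engine.pytorch as te

grouped = te.GroupedLinear(num_gemms=3, in_features=1024, out_features=1024)
x = torch.randn(96, 1024, device="cuda")
m_splits = [32, 48, 16]            # one split per GEMM, summing to x.shape[0]
out = grouped(x, m_splits)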
...@@ -856,7 +864,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -856,7 +864,7 @@ class GroupedLinear(TransformerEngineBaseModule):
""" """
if not self.need_backward_dw(): if not self.need_backward_dw():
return return
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): with get_nvtx_range_context("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop() (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
wgrad_list = tensor_list[2] wgrad_list = tensor_list[2]
weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
...@@ -876,9 +884,12 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -876,9 +884,12 @@ class GroupedLinear(TransformerEngineBaseModule):
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear.""" """Customize quantizers based on current scaling recipe + linear."""
assert (
recipe.float8_current_scaling() assert not self.tp_size > 1, (
), "current scaling recipe quantizer customization here" "GroupedLinear doesn't support TP > 1 with Float8 current scaling. "
"Because the TP communication is handled outside of this module."
)
if fwd: if fwd:
for i in range(self.num_gemms): for i in range(self.num_gemms):
# set configs about amax epsilon and power_2_scale # set configs about amax epsilon and power_2_scale
...@@ -932,3 +943,56 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -932,3 +943,56 @@ class GroupedLinear(TransformerEngineBaseModule):
for i in range(self.num_gemms): for i in range(self.num_gemms):
weight_quantizers[i].internal = True weight_quantizers[i].internal = True
return weight_quantizers return weight_quantizers
def _get_quantizers(self):
weight_quantizers = self._get_weight_quantizers()
input_quantizers, output_quantizers = (
[None] * self.num_gemms,
[None] * self.num_gemms,
)
grad_input_quantizers, grad_weight_quantizers, grad_output_quantizers = (
[None] * self.num_gemms,
[None] * self.num_gemms,
[None] * self.num_gemms,
)
if self.fp8:
input_quantizers = [
self.quantizers["scaling_fwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
input_quantizers[i].internal = True
input_quantizers[i].optimize_for_gemm = True
if torch.is_grad_enabled():
grad_output_quantizers = [
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
grad_output_quantizers[i].internal = True
grad_output_quantizers[i].optimize_for_gemm = True
return (
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
)
def _get_debug_quantizers(self):
original_quantizers = self._get_quantizers()
assert TEDebugState.debug_enabled
names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
[
DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group)
for q_id, q in enumerate(qs)
]
for name, qs in zip(names, original_quantizers)
)
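# A minimal sketch (illustrative names, not the exact Transformer Engine layout)
# of the flat per-GEMM quantizer indexing used in _get_quantizers above: the
# forward quantizers for all GEMMs live in a single list, and GEMM i's input
# quantizer sits at offsets["input"] + i * num_fp8_tensors_per_gemm.
def input_quantizer_index(gemm_id: int, input_offset: int, tensors_per_gemm: int) -> int:
    """Index of GEMM `gemm_id`'s input quantizer in a flat scaling_fwd list."""
    return input_offset + gemm_id * tensors_per_gemm

# For example, with three forward tensors per GEMM and the input offset at 0,
# GEMM 2's input quantizer would sit at index 6.
assert input_quantizer_index(2, 0, 3) == 6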
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -28,33 +28,30 @@ class LayerNorm(_LayerNormOp): ...@@ -28,33 +28,30 @@ class LayerNorm(_LayerNormOp):
Parameters Parameters
---------- ----------
normalized_shape: int or iterable of int normalized_shape : int or iterable of int
Inner dimensions of input tensor Inner dimensions of input tensor
eps : float, default = 1e-5 eps : float, default = 1e-5
A value added to the denominator of layer normalization for A value added to the denominator of layer normalization for
numerical stability numerical stability
device: torch.device, default = default CUDA device device : torch.device, default = default CUDA device
Tensor device Tensor device
dtype: torch.dtype, default = default dtype dtype : torch.dtype, default = default dtype
Tensor datatype Tensor datatype
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = 'False'
If `True`, the :math:`\gamma` parameter is initialized to zero If ``True``, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to and the calculation changes to
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
sm_margin: int or dict, default = 0 sm_margin : int or dict, default = 0
Number of SMs to exclude when launching CUDA kernels. This Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels. helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward", margin at each compute stage (``"forward"``, ``"backward"``,
"inference"). ``"inference"``).
sequence_parallel : bool
Legacy **Legacy parameter.** Set a bool attr named ``sequence_parallel`` in the parameters.
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
This is custom logic for Megatron-LM integration. This is custom logic for Megatron-LM integration.
""" """
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -15,11 +15,10 @@ from torch.nn import init ...@@ -15,11 +15,10 @@ from torch.nn import init
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version from transformer_engine.pytorch.torch_version import torch_version
from transformer_engine.pytorch.tensor.utils import is_custom from transformer_engine.pytorch.tensor.utils import is_custom
from .base import ( from .base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_workspace,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
get_dummy_wgrad, get_dummy_wgrad,
...@@ -40,6 +39,7 @@ from ..utils import ( ...@@ -40,6 +39,7 @@ from ..utils import (
nvtx_range_push, nvtx_range_push,
requires_grad, requires_grad,
needs_quantized_gemm, needs_quantized_gemm,
get_nvtx_range_context,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -64,7 +64,6 @@ from ..quantized_tensor import ( ...@@ -64,7 +64,6 @@ from ..quantized_tensor import (
restore_from_saved, restore_from_saved,
) )
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..cpu_offload import ( from ..cpu_offload import (
is_cpu_offload_enabled, is_cpu_offload_enabled,
...@@ -107,47 +106,53 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -107,47 +106,53 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias: Union[torch.Tensor, None], ln_bias: Union[torch.Tensor, None],
weight: torch.Tensor, weight: torch.Tensor,
bias: torch.Tensor, bias: torch.Tensor,
eps: float, non_tensor_args: Tuple,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
normalization: str,
ub_overlap_ag_fprop: bool,
ub_overlap_rs_fprop: bool,
ub_overlap_ag_dgrad: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_name: str,
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
eps,
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
cpu_offloading,
tp_group,
tp_size,
sequence_parallel,
tensor_parallel,
activation_dtype,
parallel_mode,
return_layernorm_output,
return_layernorm_output_gathered,
is_grad_enabled,
fwd_ln_sm_margin,
bwd_ln_sm_margin,
zero_centered_gamma,
normalization,
ub_overlap_ag_fprop,
ub_overlap_rs_fprop,
ub_overlap_ag_dgrad,
ub_overlap_rs_dgrad,
ub_bulk_wgrad,
ub_bulk_dgrad,
ub_name,
fsdp_group,
module,
skip_fp8_weight_update,
symmetric_ar_type,
debug,
) = non_tensor_args
# NVTX label for profiling # NVTX label for profiling
nvtx_label = "transformer_engine._LayerNormLinear.forward" nvtx_label = "transformer_engine._LayerNormLinear.forward"
if ub_name is not None: if ub_name is not None:
...@@ -258,8 +263,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -258,8 +263,6 @@ class _LayerNormLinear(torch.autograd.Function):
if fp8 or debug: if fp8 or debug:
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(input_quantizer, Float8BlockQuantizer):
input_quantizer.all_gather_usage = False
ln_out_total = input_quantizer(ln_out_total) ln_out_total = input_quantizer(ln_out_total)
else: else:
quantizer = None quantizer = None
...@@ -366,7 +369,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -366,7 +369,6 @@ class _LayerNormLinear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat, weightmat,
ln_out_total, ln_out_total,
get_workspace(),
quantization_params=output_quantizer, quantization_params=output_quantizer,
out_dtype=activation_dtype, out_dtype=activation_dtype,
bias=bias, bias=bias,
...@@ -555,7 +557,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -555,7 +557,7 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.ub_name is not None: if ctx.ub_name is not None:
nvtx_label = f"{nvtx_label}.{ctx.ub_name}" nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_LayerNormLinear_backward"): with get_nvtx_range_context("_LayerNormLinear_backward"):
saved_tensors = ctx.saved_tensors saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking ( # pylint: disable=unbalanced-tuple-unpacking
inputmat, inputmat,
...@@ -743,7 +745,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -743,7 +745,6 @@ class _LayerNormLinear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
weight, weight,
grad_output, grad_output,
get_workspace(),
layout="NN", layout="NN",
grad=True, grad=True,
quantization_params=ctx.grad_input_quantizer, quantization_params=ctx.grad_input_quantizer,
...@@ -870,7 +871,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -870,7 +871,6 @@ class _LayerNormLinear(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure # Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs = { wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": ( "out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
), ),
...@@ -1045,44 +1045,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -1045,44 +1045,7 @@ class _LayerNormLinear(torch.autograd.Function):
dbeta, dbeta,
wgrad, wgrad,
grad_bias, grad_bias,
None, # eps None,
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # return_layernorm_output
None, # return_layernorm_output_gathered
None, # is_grad_enabled
None, # fwd_ln_sm_margin
None, # bwd_ln_sm_margin
None, # zero_centered_gamma
None, # normalization
None, # ub_overlap_ag_fprop
None, # ub_overlap_rs_fprop
None, # ub_overlap_ag_dgrad
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fsdp_group
None, # debug
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
) )
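# A minimal sketch (illustrative names, not the Transformer Engine API) of the
# argument-packing pattern used above: every non-tensor option is bundled into a
# single tuple before calling the autograd function, which keeps PyTorch's
# per-argument handling in apply() cheap and lets backward return a single None
# for the whole bundle instead of a long list of Nones.
import torch

class _PackedArgsScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, non_tensor_args):
        (scale,) = non_tensor_args          # unpack once inside the function
        ctx.scale = scale
        ctx.save_for_backward(inp, weight)
        return scale * (inp @ weight)

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        grad_inp = ctx.scale * (grad_out @ weight.t())
        grad_w = ctx.scale * (inp.t() @ grad_out)
        # A single None covers the entire packed non_tensor_args tuple.
        return grad_inp, grad_w, None

# Usage: out = _PackedArgsScale.apply(x, w, (2.0,))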
...@@ -1098,20 +1061,20 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1098,20 +1061,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
size of each output sample. size of each output sample.
eps : float, default = 1e-5 eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability. a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True` bias : bool, default = True
if set to `False`, the layer will not learn an additive bias. if set to ``False``, the layer will not learn an additive bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied. type of normalization applied.
init_method : Callable, default = `None` init_method : Callable, default = None
used for initializing weights in the following way: `init_method(weight)`. used for initializing weights in the following way: ``init_method(weight)``.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
return_layernorm_output : bool, default = `False` return_layernorm_output : bool, default = False
if set to `True`, output of layernorm is returned from the forward if set to ``True``, output of layernorm is returned from the forward
together with the output of the linear transformation. together with the output of the linear transformation.
Example use case: residual connection for transformer module is Example use case: residual connection for transformer module is
taken post layernorm. taken post layernorm.
return_layernorm_output_gathered : bool, default = `False` return_layernorm_output_gathered : bool, default = False
if set to `True`, output of layernorm is returned after the all if set to ``True``, output of layernorm is returned after the all
gather operation. Ignored if return_layernorm_output is False. gather operation. Ignored if return_layernorm_output is False.
Example use case: with sequence parallel, input to residual connection Example use case: with sequence parallel, input to residual connection
for transformer module (e.g. LoRA) will need to be gathered. for transformer module (e.g. LoRA) will need to be gathered.
...@@ -1122,10 +1085,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1122,10 +1085,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
they are used to make the names of equally-sized parameters. If a dict they are used to make the names of equally-sized parameters. If a dict
(preferably an OrderedDict) is provided, the keys are used as names and (preferably an OrderedDict) is provided, the keys are used as names and
values as split sizes along dim 0. The resulting parameters will have values as split sizes along dim 0. The resulting parameters will have
names that end in `_weight` or `_bias`, so trailing underscores are names that end in ``_weight`` or ``_bias``, so trailing underscores are
stripped from any provided names. stripped from any provided names.
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and if set to ``'True'``, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to the LayerNorm formula changes to
.. math:: .. math::
...@@ -1135,53 +1098,53 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1135,53 +1098,53 @@ class LayerNormLinear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
name: str, default = `None` name : str, default = None
name of the module, currently used for debugging purposes. name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
sequence_parallel : bool, default = `False` sequence_parallel : bool, default = False
if set to `True`, uses sequence parallelism. if set to ``True``, uses sequence parallelism.
tp_group : ProcessGroup, default = `None` tp_group : ProcessGroup, default = None
tensor parallel process group. tensor parallel process group.
tp_size : int, default = 1 tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives. parallel collectives.
parallel_mode : {None, 'column', 'row'}, default = `None` parallel_mode : {None, 'column', 'row'}, default = None
used to decide whether this Linear layer is Column Parallel Linear or Row used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_. Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed. When set to ``None``, no communication is performed.
Optimization parameters Optimization parameters
----------------------- -----------------------
fuse_wgrad_accumulation : bool, default = 'False' fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of if set to ``True``, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional ``main_grad`` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular ``grad``) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating. will overwrite ``main_grad`` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = False
when set to `True`, this module will not apply the additive bias itself, but when set to ``True``, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations. the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.get_default_dtype()` params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False` delay_wgrad_compute : bool, default = False
Whether or not to delay weight gradient computation. If set to `True`, Whether or not to delay weight gradient computation. If set to ``True``,
it's the user's responsibility to call `module.backward_dw` to compute it's the user's responsibility to call ``module.backward_dw`` to compute
weight gradients. weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass. Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations. This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce
is used. is used.
""" """
...@@ -1462,15 +1425,12 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1462,15 +1425,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe) super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs # Recipe-specific quantizer configuration
recipe = FP8GlobalStateManager.get_fp8_recipe() recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling(): if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe) self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4(): elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe) self._customize_quantizers_nvfp4(fwd, recipe)
# elif other recipes (mxfp8, etc)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
...@@ -1542,8 +1502,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1542,8 +1502,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
is_grad_enabled = torch.is_grad_enabled()
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output) return self.onnx_forward(inp, fp8_output, is_grad_enabled)
debug = self.is_debug_iter() debug = self.is_debug_iter()
...@@ -1565,9 +1527,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1565,9 +1527,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
).is_fp8_ubuf(): ).is_fp8_ubuf():
fp8_grad = True fp8_grad = True
with torch.cuda.device( with self.prepare_forward(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp: ) as inp:
...@@ -1575,14 +1535,14 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1575,14 +1535,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = ( quantizers = (
self._get_quantizers(fp8_output, fp8_grad) self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad) else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
) )
if debug: if debug:
if self.no_debug_features_active(quantizers): if self.no_debug_features_active(quantizers):
debug = False debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad) quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
( (
input_quantizer, input_quantizer,
...@@ -1593,18 +1553,13 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1593,18 +1553,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
grad_output_quantizer, grad_output_quantizer,
) = quantizers ) = quantizers
if torch.is_grad_enabled(): if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply fwd_fn = _LayerNormLinear.apply
args = [] autograd_ctx = []
else: else:
fwd_fn = _LayerNormLinear.forward fwd_fn = _LayerNormLinear.forward
args = [None] autograd_ctx = [None]
args += ( non_tensor_args = (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
...@@ -1626,8 +1581,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1626,8 +1581,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.parallel_mode, self.parallel_mode,
self.return_layernorm_output, self.return_layernorm_output,
self.return_layernorm_output_gathered, self.return_layernorm_output_gathered,
torch.is_grad_enabled(), is_grad_enabled,
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.normalization, self.normalization,
...@@ -1644,7 +1599,15 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1644,7 +1599,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.symmetric_ar_type, self.symmetric_ar_type,
debug, debug,
) )
out = fwd_fn(*args) out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
...@@ -1660,7 +1623,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1660,7 +1623,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
return out, ln_out return out, ln_out
return out return out
def _get_quantizers(self, fp8_output, fp8_grad): def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
if not self.fp8: if not self.fp8:
return [None] * 6 return [None] * 6
grad_input_quantizer = None grad_input_quantizer = None
...@@ -1669,12 +1632,16 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1669,12 +1632,16 @@ class LayerNormLinear(TransformerEngineBaseModule):
output_quantizer = None output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True input_quantizer.internal = True
if not (self.parallel_mode == "column" and self.sequence_parallel):
input_quantizer.optimize_for_gemm = True
(weight_quantizer,) = self._get_weight_quantizers() (weight_quantizer,) = self._get_weight_quantizers()
if fp8_output: if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled(): if is_grad_enabled:
grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
grad_output_quantizer.internal = True grad_output_quantizer.internal = True
if not (self.parallel_mode == "row" and self.sequence_parallel):
grad_output_quantizer.optimize_for_gemm = True
if fp8_grad: if fp8_grad:
grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
...@@ -1687,8 +1654,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1687,8 +1654,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
grad_output_quantizer, grad_output_quantizer,
) )
def _get_debug_quantizers(self, fp8_output, fp8_grad): def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad) original_quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
assert TEDebugState.debug_enabled assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_quantization import DebugQuantizer
...@@ -1713,6 +1680,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1713,6 +1680,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self, self,
inp: torch.Tensor, inp: torch.Tensor,
fp8_output: bool, fp8_output: bool,
is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
ONNX-compatible version of the forward function that provides numerical equivalence ONNX-compatible version of the forward function that provides numerical equivalence
...@@ -1728,7 +1696,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1728,7 +1696,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
*_, *_,
) = self._get_quantizers(fp8_output, fp8_grad=False) ) = self._get_quantizers(fp8_output, False, is_grad_enabled)
inp_dtype = inp.dtype inp_dtype = inp.dtype
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
...@@ -1857,14 +1825,3 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1857,14 +1825,3 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True weight_quantizer.internal = True
return [weight_quantizer] return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_linear."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -17,11 +17,10 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION ...@@ -17,11 +17,10 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version from transformer_engine.pytorch.torch_version import torch_version
from transformer_engine.pytorch.tensor.utils import is_custom from transformer_engine.pytorch.tensor.utils import is_custom
from .base import ( from .base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_workspace,
_ub_communicators, _ub_communicators,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
...@@ -46,6 +45,7 @@ from ..utils import ( ...@@ -46,6 +45,7 @@ from ..utils import (
clear_tensor_data, clear_tensor_data,
requires_grad, requires_grad,
needs_quantized_gemm, needs_quantized_gemm,
get_nvtx_range_context,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -57,6 +57,8 @@ from ..distributed import ( ...@@ -57,6 +57,8 @@ from ..distributed import (
use_reentrant_activation_recompute, use_reentrant_activation_recompute,
in_fp8_activation_recompute_phase, in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors, _fsdp_scatter_tensors,
_get_cuda_rng_state,
_set_cuda_rng_state,
) )
from ..constants import dist_group_type from ..constants import dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
...@@ -174,7 +176,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -174,7 +176,7 @@ class _LayerNormMLP(torch.autograd.Function):
""" """
@staticmethod @staticmethod
def forward( def _forward(
ctx, ctx,
inp: torch.Tensor, inp: torch.Tensor,
ln_weight: torch.Tensor, ln_weight: torch.Tensor,
...@@ -183,55 +185,155 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -183,55 +185,155 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias: torch.Tensor, fc1_bias: torch.Tensor,
fc2_weight: torch.Tensor, fc2_weight: torch.Tensor,
fc2_bias: torch.Tensor, fc2_bias: torch.Tensor,
eps: float, non_tensor_args: Tuple,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
fc1_input_quantizer: Optional[Quantizer],
fc1_weight_quantizer: Optional[Quantizer],
fc1_output_quantizer: Optional[Quantizer],
fc1_grad_input_quantizer: Optional[Quantizer],
fc1_grad_weight_quantizer: Optional[Quantizer],
fc1_grad_output_quantizer: Optional[Quantizer],
fc2_input_quantizer: Optional[Quantizer],
fc2_weight_quantizer: Optional[Quantizer],
fc2_output_quantizer: Optional[Quantizer],
fc2_grad_input_quantizer: Optional[Quantizer],
fc2_grad_weight_quantizer: Optional[Quantizer],
fc2_grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
bias_gelu_fusion: bool,
set_parallel_mode: bool,
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
activation: str,
activation_params: Optional[dict],
normalization: str,
ub_overlap_ag: bool,
ub_overlap_rs: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
gemm_gelu_fusion: bool,
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
eps,
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
cpu_offloading,
tp_group,
tp_size,
sequence_parallel,
tensor_parallel,
activation_dtype,
return_layernorm_output,
return_layernorm_output_gathered,
bias_gelu_fusion,
set_parallel_mode,
is_grad_enabled,
fwd_ln_sm_margin,
bwd_ln_sm_margin,
zero_centered_gamma,
activation,
activation_params,
normalization,
ub_overlap_ag,
ub_overlap_rs,
ub_overlap_rs_dgrad,
ub_bulk_wgrad,
ub_bulk_dgrad,
gemm_gelu_fusion,
fsdp_group,
module,
skip_fp8_weight_update,
symmetric_ar_type,
checkpoint,
debug,
recompute_for_bwd,
) = non_tensor_args
# if grad is enabled and this is not the bwd recompute stage, record the checkpoint flag so bwd knows which path to take
if is_grad_enabled and not recompute_for_bwd:
ctx.checkpoint = checkpoint
if checkpoint:
# save the state of autocast and quantizers for recomputation
ctx.autocast_state = (
FP8GlobalStateManager.get_autocast_state()
) # to restore autocast state during recomputation
if (
fp8
and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__
== "DelayedScaling"
): # only applicable for delayed scaling
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(
module.fp8_meta
) # to restore quantizers during recomputation
# save the rng states
ctx.cpu_rng_state = torch.get_rng_state()
ctx.cuda_rng_state = _get_cuda_rng_state()
# whether to save activations regularly, or save inputs for recomputation in bwd
save_for_checkpoint = checkpoint and is_grad_enabled and not recompute_for_bwd
# whether we are in the forward stage, or recomputing in the bwd stage (false if not checkpointing)
is_recomputation = checkpoint and is_grad_enabled and recompute_for_bwd
# save the initial state for recomputation by bwd
if save_for_checkpoint:
# save tensors
tensors_to_save, tensor_objects = prepare_for_saving(
inp,
ln_weight,
ln_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias,
)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
ctx.other_args = {
"eps": eps,
"is_first_microbatch": is_first_microbatch,
"fp8": fp8,
"fp8_calibration": fp8_calibration,
"wgrad_store": wgrad_store,
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"fc1_input_quantizer": fc1_input_quantizer,
"fc1_weight_quantizer": fc1_weight_quantizer,
"fc1_output_quantizer": fc1_output_quantizer,
"fc1_grad_input_quantizer": fc1_grad_input_quantizer,
"fc1_grad_weight_quantizer": fc1_grad_weight_quantizer,
"fc1_grad_output_quantizer": fc1_grad_output_quantizer,
"fc2_input_quantizer": fc2_input_quantizer,
"fc2_weight_quantizer": fc2_weight_quantizer,
"fc2_output_quantizer": fc2_output_quantizer,
"fc2_grad_input_quantizer": fc2_grad_input_quantizer,
"fc2_grad_weight_quantizer": fc2_grad_weight_quantizer,
"fc2_grad_output_quantizer": fc2_grad_output_quantizer,
"cpu_offloading": cpu_offloading,
"tp_group": tp_group,
"tp_size": tp_size,
"sequence_parallel": sequence_parallel,
"tensor_parallel": tensor_parallel,
"activation_dtype": activation_dtype,
"return_layernorm_output": return_layernorm_output,
"return_layernorm_output_gathered": return_layernorm_output_gathered,
"bias_gelu_fusion": bias_gelu_fusion,
"set_parallel_mode": set_parallel_mode,
"is_grad_enabled": is_grad_enabled,
"fwd_ln_sm_margin": fwd_ln_sm_margin,
"bwd_ln_sm_margin": bwd_ln_sm_margin,
"zero_centered_gamma": zero_centered_gamma,
"activation": activation,
"activation_params": activation_params,
"normalization": normalization,
"ub_overlap_ag": ub_overlap_ag,
"ub_overlap_rs": ub_overlap_rs,
"ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
"ub_bulk_wgrad": ub_bulk_wgrad,
"ub_bulk_dgrad": ub_bulk_dgrad,
"gemm_gelu_fusion": gemm_gelu_fusion,
"fsdp_group": fsdp_group,
"module": module,
"skip_fp8_weight_update": skip_fp8_weight_update,
"symmetric_ar_type": symmetric_ar_type,
"checkpoint": checkpoint,
"debug": debug,
"recompute_for_bwd": True, # set this to true for recomputation phase
}
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features, inp_shape = ln_weight.numel(), inp.shape in_features, inp_shape = ln_weight.numel(), inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible" assert inp_shape[-1] == in_features, "GEMM not possible"
...@@ -253,7 +355,14 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -253,7 +355,14 @@ class _LayerNormMLP(torch.autograd.Function):
start_offload(inputmat) start_offload(inputmat)
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
# bwd needs fc1 input when grad is enabled, fc1 needs grad, and either
# 1) no checkpointing
# or 2) doing the recomputation with checkpointing
backwards_needs_fc1_input = fc1_weight.requires_grad and (
(is_grad_enabled and not checkpoint) or is_recomputation
)
device = inp.device device = inp.device
# Configure Userbuffers communication (comm+GEMM overlap) # Configure Userbuffers communication (comm+GEMM overlap)
...@@ -311,7 +420,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -311,7 +420,9 @@ class _LayerNormMLP(torch.autograd.Function):
zero_centered_gamma, zero_centered_gamma,
) )
ln_out_return = None ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
# do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing
if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation:
ln_out_return = ln_out ln_out_return = ln_out
# Prepare GEMM input # Prepare GEMM input
...@@ -319,7 +430,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -319,7 +430,9 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total = None ln_out_total = None
ub_obj_lnout = None ub_obj_lnout = None
if sequence_parallel: if sequence_parallel:
if return_layernorm_output_gathered:
# the ln output is not needed during checkpoint recomputation, so skip returning it
if return_layernorm_output_gathered and not is_recomputation:
# Perform all-gather in high precision if gathered # Perform all-gather in high precision if gathered
# norm output will be returned # norm output will be returned
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
...@@ -327,8 +440,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -327,8 +440,6 @@ class _LayerNormMLP(torch.autograd.Function):
if fp8 or debug: if fp8 or debug:
ln_out = fc1_input_quantizer(ln_out) ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
fc1_input_quantizer.all_gather_usage = False
ln_out_total = fc1_input_quantizer(ln_out_total) ln_out_total = fc1_input_quantizer(ln_out_total)
else: else:
quantizer = None quantizer = None
...@@ -442,7 +553,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -442,7 +553,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_outputs = general_gemm( fc1_outputs = general_gemm(
fc1_weight_final, fc1_weight_final,
ln_out_total, ln_out_total,
get_workspace(),
quantization_params=( quantization_params=(
fc2_input_quantizer fc2_input_quantizer
if gemm_gelu_fusion if gemm_gelu_fusion
...@@ -463,7 +573,12 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -463,7 +573,12 @@ class _LayerNormMLP(torch.autograd.Function):
# ------------------------------------------------------ # ------------------------------------------------------
# Deallocate FC1 GEMM input tensor if no longer needed # Deallocate FC1 GEMM input tensor if no longer needed
if not is_grad_enabled and (ln_out_total is not ln_out_return): # first part of if statement means that we only clear ln_out_total if
# 1) checkpointing and not recomputing (in the forward stage, not bwd recompute stage)
# 2) not checkpointing and grad disabled
if ((checkpoint and not is_recomputation) or not is_grad_enabled) and (
ln_out_total is not ln_out_return
):
clear_tensor_data(ln_out_total) clear_tensor_data(ln_out_total)
# ACTIVATION - sometimes activation is fused with the GEMM above. # ACTIVATION - sometimes activation is fused with the GEMM above.
...@@ -501,12 +616,27 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -501,12 +616,27 @@ class _LayerNormMLP(torch.autograd.Function):
else: else:
act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params) act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params)
if not is_grad_enabled:
clear_tensor_data(fc1_out)
if not fp8 and fp8_calibration: if not fp8 and fp8_calibration:
if fc2_input_quantizer is not None: if fc2_input_quantizer is not None:
fc2_input_quantizer.calibrate(act_out) fc2_input_quantizer.calibrate(act_out)
# we want to skip fc2 computation if we are checkpointing and recomputing,
# otherwise we compute fc2
if not (is_recomputation and checkpoint):
# If we get to this point, this is not the bwd recomputation, so we are in the
# forward pass. If grad is disabled we can safely free fc1_out; if grad is
# enabled we may only free it when checkpointing, since it will be recomputed
# in the backward pass. Without checkpointing it must be kept.
if (
checkpoint or not is_grad_enabled
): # we can safely get rid of these if this is the case
clear_tensor_data(fc1_out)
if not fp8 and fp8_calibration:
if fc2_weight_quantizer is not None: if fc2_weight_quantizer is not None:
fc2_weight_quantizer.calibrate(fc2_weight) fc2_weight_quantizer.calibrate(fc2_weight)
...@@ -526,7 +656,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -526,7 +656,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
fc2_weight_final, fc2_weight_final,
act_out, act_out,
get_workspace(),
out_dtype=activation_dtype, out_dtype=activation_dtype,
bias=fc2_bias, bias=fc2_bias,
quantization_params=fc2_output_quantizer, quantization_params=fc2_output_quantizer,
...@@ -539,8 +668,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -539,8 +668,8 @@ class _LayerNormMLP(torch.autograd.Function):
# Finished FC2 GEMM... # Finished FC2 GEMM...
# ------------------------------------------------------ # ------------------------------------------------------
# Deallocate tensors if no longer needed # Deallocate tensors if no longer needed, again, can safely deallocate
if not is_grad_enabled: if checkpoint or not is_grad_enabled: # same logic as last clear_tensor_data block
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
# Prepare output tensor # Prepare output tensor
...@@ -561,8 +690,24 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -561,8 +690,24 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_out = gemm_out fc2_out = gemm_out
fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1]) fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
# Cache state for backward pass # now saving stuff for bwd:
if is_grad_enabled: # if we are using checkpointing, this information will be saved in the bwd recomputation stage, so can skip it in fwd
# if we are not checkpointing, the activations must be saved here whenever grad is enabled
if is_grad_enabled and not save_for_checkpoint:
ctx.fc1_weight_quantizer = fc1_weight_quantizer
ctx.fc2_weight_quantizer = fc2_weight_quantizer
if not fc1_weight.requires_grad:
if not return_layernorm_output:
clear_tensor_data(ln_out)
ln_out = None
if not fc2_weight.requires_grad:
clear_tensor_data(act_out)
act_out = None
if not checkpoint: # regular path, no selective activation checkpointing
if cpu_offloading: if cpu_offloading:
mark_activation_offload( mark_activation_offload(
inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
...@@ -572,26 +717,27 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -572,26 +717,27 @@ class _LayerNormMLP(torch.autograd.Function):
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves # shards/unshards the base weights so we don't do it ourselves
ctx.fsdp_group = fsdp_group ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
ctx.fsdp_shapes = (
_fsdp_scatter_tensors( # again, only relevant if we have activations to save
fsdp_group, fsdp_group,
mu, mu,
rsigma, rsigma,
ln_out, ln_out,
fc1_out_without_bias if bias_gelu_fusion else fc1_out, fc1_out_without_bias if bias_gelu_fusion else fc1_out,
act_out, act_out,
fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) else None, (
fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, fc1_weight_final
if fp8 and not isinstance(fc1_weight, Float8Tensor)
else None
),
(
fc2_weight_final
if fp8 and not isinstance(fc2_weight, Float8Tensor)
else None
),
)
) )
ctx.fc1_weight_quantizer = fc1_weight_quantizer
ctx.fc2_weight_quantizer = fc2_weight_quantizer
if not fc1_weight.requires_grad:
if not return_layernorm_output:
clear_tensor_data(ln_out)
ln_out = None
if not fc2_weight.requires_grad:
clear_tensor_data(act_out)
act_out = None
if cpu_offloading: if cpu_offloading:
mark_not_offload( mark_not_offload(
...@@ -604,7 +750,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -604,7 +750,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight, fc2_weight,
fc2_bias, fc2_bias,
) )
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
inputmat, inputmat,
ln_weight, ln_weight,
...@@ -622,6 +767,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -622,6 +767,9 @@ class _LayerNormMLP(torch.autograd.Function):
rsigma, rsigma,
) )
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
if fuse_wgrad_accumulation: if fuse_wgrad_accumulation:
# This check is needed to ensure that main_grad is not created # This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates # during the forward pass when using MCore FSDP as it creates
...@@ -638,9 +786,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -638,9 +786,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fc1_main_grad_func = lambda: fc1_weight.main_grad ctx.fc1_main_grad_func = lambda: fc1_weight.main_grad
ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
...@@ -695,11 +840,30 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -695,11 +840,30 @@ class _LayerNormMLP(torch.autograd.Function):
): ):
_first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase(): if in_fp8_activation_recompute_phase() or is_recomputation:
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store ctx.wgrad_store = wgrad_store
if is_recomputation: # return the recomputed tensors
return (
ctx,
inputmat,
ln_weight,
ln_out,
fc1_weight_final,
fc1_weight,
fc1_bias,
fc1_out,
fc1_out_without_bias,
act_out,
fc2_weight_final,
fc2_weight,
fc2_bias,
mu,
rsigma,
)
# we only get to this point if we are not recomputing for bwd, since that would have returned in the block above
if return_layernorm_output: if return_layernorm_output:
if return_layernorm_output_gathered: if return_layernorm_output_gathered:
shape = list(inp_shape) shape = list(inp_shape)
...@@ -708,14 +872,101 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -708,14 +872,101 @@ class _LayerNormMLP(torch.autograd.Function):
return fc2_out, ln_out_return.view(inp_shape) return fc2_out, ln_out_return.view(inp_shape)
return fc2_out return fc2_out
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
fc1_weight: torch.Tensor,
fc1_bias: torch.Tensor,
fc2_weight: torch.Tensor,
fc2_bias: torch.Tensor,
non_tensor_args: Tuple,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
# append recompute_for_bwd=False for the regular (non-recompute) forward path
non_tensor_args += (False,)
return _LayerNormMLP._forward(
ctx,
inp,
ln_weight,
ln_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias,
non_tensor_args,
)
@staticmethod
def _recompute(ctx):
# pylint: disable=missing-function-docstring
saved_tensors = ctx.saved_tensors
tensors = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
if ctx.checkpoint: # do recomputation from the original args
# backward does not run inside the autocast context, so set the forward-time autocast state here
# the quantizer states must also be reset to their pre-forward values (only relevant for the DelayedScaling recipe)
final_autocast_state = (
FP8GlobalStateManager.get_autocast_state()
) # get current autocast state
FP8GlobalStateManager.set_autocast_state(ctx.autocast_state) # set old autocast state
if (
ctx.other_args["fp8"]
and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling"
): # only applicable for delayed scaling
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(
ctx.other_args["module"].fp8_meta
) # set old quantizer state
# get current rng state
final_cpu_rng_state = torch.get_rng_state()
final_cuda_rng_state = _get_cuda_rng_state()
# set rng state for fwd
torch.set_rng_state(ctx.cpu_rng_state)
_set_cuda_rng_state(ctx.cuda_rng_state)
out = _LayerNormMLP._forward( # recompute
ctx,
*tensors,
tuple(ctx.other_args.values()),
)
FP8GlobalStateManager.set_autocast_state(final_autocast_state) # restore autocast state
if (
ctx.other_args["fp8"]
and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling"
):
FP8GlobalStateManager.restore_fp8_meta_tensors(
ctx.other_args["module"].fp8_meta
) # restore quantizers
# restore the rng state captured at the start of the recompute
torch.set_rng_state(final_cpu_rng_state)
_set_cuda_rng_state(final_cuda_rng_state)
return out
# not checkpointing: tensors were loaded from saved activations; ctx is returned to match the recompute branch's return signature
return tuple([ctx] + tensors)
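For orientation, the checkpointed recompute above follows a standard save/replay/restore pattern for RNG state. Below is a minimal, generic sketch of that pattern (illustrative only, not the Transformer Engine implementation; the helper name replay_with_saved_rng is hypothetical):

import torch

def replay_with_saved_rng(saved_cpu_state, saved_cuda_state, recompute_fn):
    # Snapshot the RNG state the backward pass is currently using.
    backward_cpu_state = torch.get_rng_state()
    backward_cuda_state = torch.cuda.get_rng_state()
    try:
        # Replay with the RNG state captured at forward time so dropout and
        # other stochastic ops reproduce the original activations exactly.
        torch.set_rng_state(saved_cpu_state)
        torch.cuda.set_rng_state(saved_cuda_state)
        return recompute_fn()
    finally:
        # Restore the backward-time RNG state, even if recomputation fails.
        torch.set_rng_state(backward_cpu_state)
        torch.cuda.set_rng_state(backward_cuda_state)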
@staticmethod @staticmethod
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_LayerNormMLP_backward"): with get_nvtx_range_context("_LayerNormMLP_backward"):
saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking ( # pylint: disable=unbalanced-tuple-unpacking
ctx,
inputmat, inputmat,
ln_weight, ln_weight,
ln_out, ln_out,
...@@ -730,11 +981,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -730,11 +981,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_bias, fc2_bias,
mu, mu,
rsigma, rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors) ) = _LayerNormMLP._recompute(ctx)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
# Since main_grad can be modified inplace, it should not be a part of saved_tensors # Since main_grad can be modified inplace, it should not be a part of saved_tensors
fc1_weight_main_grad = ( fc1_weight_main_grad = (
...@@ -883,7 +1130,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -883,7 +1130,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_output, *_ = general_gemm( gemm_output, *_ = general_gemm(
fc2_weight, fc2_weight,
grad_output, grad_output,
get_workspace(),
layout="NN", layout="NN",
grad=True, grad=True,
quantization_params=( quantization_params=(
...@@ -977,7 +1223,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -977,7 +1223,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure # Arguments to include in wgrad GEMM closure
fc2_wgrad_gemm_kwargs = { fc2_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": ( "out_dtype": (
origin_fc2_weight.main_grad.dtype origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation if ctx.fuse_wgrad_accumulation
...@@ -1155,7 +1400,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1155,7 +1400,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
fc1_weight, fc1_weight,
dact, dact,
get_workspace(),
out=gemm_out, out=gemm_out,
out_dtype=ctx.activation_dtype, out_dtype=ctx.activation_dtype,
quantization_params=ctx.fc1_grad_input_quantizer, quantization_params=ctx.fc1_grad_input_quantizer,
...@@ -1234,7 +1478,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1234,7 +1478,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure # Arguments to include in wgrad GEMM closure
fc1_wgrad_gemm_kwargs = { fc1_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": ( "out_dtype": (
origin_fc1_weight.main_grad.dtype origin_fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation if ctx.fuse_wgrad_accumulation
...@@ -1429,52 +1672,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1429,52 +1672,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias_grad if fc1_bias is not None else None, fc1_bias_grad if fc1_bias is not None else None,
fc2_wgrad, # pylint: disable=possibly-used-before-assignment fc2_wgrad, # pylint: disable=possibly-used-before-assignment
fc2_bias_grad, fc2_bias_grad,
None, # eps None,
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # fc1_input_quantizer,
None, # fc1_weight_quantizer,
None, # fc1_output_quantizer,
None, # fc1_grad_input_quantizer,
None, # fc1_grad_weight_quantizer,
None, # fc1_grad_output_quantizer,
None, # fc2_input_quantizer,
None, # fc2_weight_quantizer,
None, # fc2_output_quantizer,
None, # fc2_grad_input_quantizer,
None, # fc2_grad_weight_quantizer,
None, # fc2_grad_output_quantizer,
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # return_layernorm_output
None, # return_layernorm_output_gathered
None, # bias_gelu_fusion
None, # set_parallel_mode
None, # is_grad_enabled
None, # fwd_ln_sm_margin
None, # bwd_ln_sm_margin
None, # zero_centered_gamma
None, # activation
None, # activation_params
None, # normalization
None, # ub_overlap_ag
None, # ub_overlap_rs
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # gemm_gelu_fusion
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # debug
) )
...@@ -1491,38 +1689,38 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1491,38 +1689,38 @@ class LayerNormMLP(TransformerEngineBaseModule):
intermediate size to which input samples are projected. intermediate size to which input samples are projected.
eps : float, default = 1e-5 eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability. a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True` bias : bool, default = True
if set to `False`, the FC1 and FC2 layers will not learn an additive bias. if set to ``False``, the FC1 and FC2 layers will not learn an additive bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied. type of normalization applied.
activation : str, default = 'gelu' activation : str, default = 'gelu'
activation function used. activation function used.
Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
'silu', 'swiglu', and 'clamped_swiglu'. ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
activation_params : dict, default = `None` activation_params : dict, default = None
Additional parameters for the activation function. Additional parameters for the activation function.
At the moment, only used for 'clamped_swiglu' activation which At the moment, only used for ``'clamped_swiglu'`` activation which
supports 'limit' and 'alpha' parameters. supports ``'limit'`` and ``'alpha'`` parameters.
init_method : Callable, default = `None` init_method : Callable, default = None
used for initializing FC1 weights in the following way: `init_method(weight)`. used for initializing FC1 weights in the following way: ``init_method(weight)``.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
output_layer_init_method : Callable, default = `None` output_layer_init_method : Callable, default = None
used for initializing FC2 weights in the following way: used for initializing FC2 weights in the following way:
`output_layer_init_method(weight)`. When set to `None`, defaults to ``output_layer_init_method(weight)``. When set to ``None``, defaults to
`torch.nn.init.normal_(mean=0.0, std=0.023)`. ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
return_layernorm_output : bool, default = `False` return_layernorm_output : bool, default = False
if set to `True`, output of layernorm is returned from the forward if set to ``True``, output of layernorm is returned from the :meth:`forward` method
together with the output of the linear transformation. together with the output of the linear transformation.
Example use case: residual connection for transformer module Example use case: residual connection for transformer module
is taken post layernorm. is taken post layernorm.
return_layernorm_output_gathered : bool, default = `False` return_layernorm_output_gathered : bool, default = False
if set to `True`, output of layernorm is returned after the all if set to ``True``, output of layernorm is returned after the all
gather operation. Ignored if return_layernorm_output is False. gather operation. Ignored if ``return_layernorm_output`` is False.
Example use case: with sequence parallel, input to residual connection Example use case: with sequence parallel, input to residual connection
for transformer module (e.g. LoRA) will need to be gathered. for transformer module (e.g. LoRA) will need to be gathered.
Returning layernorm output gathered will prevent a redundant gather. Returning layernorm output gathered will prevent a redundant gather.
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = False
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to the LayerNorm formula changes to
.. math:: .. math::
...@@ -1532,61 +1730,65 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1532,61 +1730,65 @@ class LayerNormMLP(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
name: str, default = `None` name : str, default = None
name of the module, currently used for debugging purposes. name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
set_parallel_mode : bool, default = `False` set_parallel_mode : bool, default = False
if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row if set to ``True``, FC1 is used as Column Parallel and FC2 is used as Row
Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_. Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default = `False` sequence_parallel : bool, default = False
if set to `True`, uses sequence parallelism. if set to ``True``, uses sequence parallelism.
tp_group : ProcessGroup, default = `None` tp_group : ProcessGroup, default = None
tensor parallel process group. tensor parallel process group.
tp_size : int, default = 1 tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives. parallel collectives.
Optimization parameters Optimization parameters
----------------------- -----------------------
fuse_wgrad_accumulation : bool, default = 'False' fuse_wgrad_accumulation : bool, default = False
if set to `True`, enables fusing of creation and accumulation of if set to ``True``, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional ``main_grad`` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular ``grad``) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True weight tensor having attribute ``'overwrite_main_grad'`` set to True
will overwrite `main_grad` instead of accumulating. will overwrite ``main_grad`` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = False
when set to `True`, this module will not apply the additive bias for FC2, but when set to ``True``, this module will not apply the additive bias for FC2, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations. the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.get_default_dtype()` params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
seq_length: int seq_length : int
sequence length of input samples. Needed for JIT Warmup, a technique where jit fused sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
functions are warmed up before training to ensure same kernels are used for forward functions are warmed up before training to ensure same kernels are used for forward
propagation and activation recompute phase. propagation and activation recompute phase.
micro_batch_size: int micro_batch_size : int
batch size per training step. Needed for JIT Warmup, a technique where jit batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase. used for forward propagation and activation recompute phase.
delay_wgrad_compute : bool, default = `False` delay_wgrad_compute : bool, default = False
Whether or not to delay weight gradient computation. If set to `True`, Whether or not to delay weight gradient computation. If set to ``True``,
it's the user's responsibility to call `module.backward_dw` to compute it's the user's responsibility to call :meth:`backward_dw` to compute
weight gradients. weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass. Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations. This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce
is used. is used.
checkpoint : bool, default = False
if set to ``True``, uses selective activation checkpointing: activations are not saved
for the backward pass and are instead recomputed (skipping FC2, which is not needed for
backward). This trades compute for memory. When ``False``, activations are saved in the
forward pass. Not supported for ONNX export (see the usage sketch below).
""" """
def __init__( def __init__(
...@@ -1622,6 +1824,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1622,6 +1824,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None, symmetric_ar_type: Optional[str] = None,
checkpoint: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1642,6 +1845,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1642,6 +1845,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.set_parallel_mode = set_parallel_mode self.set_parallel_mode = set_parallel_mode
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type self.symmetric_ar_type = symmetric_ar_type
self.checkpoint = checkpoint
# GEMM-GELU fusion is currently only supported with split GEMM-AG overlap # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
self.gemm_gelu_fusion = ( self.gemm_gelu_fusion = (
...@@ -1788,15 +1992,12 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1788,15 +1992,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe) super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs # Recipe-specific quantizer configuration
recipe = FP8GlobalStateManager.get_fp8_recipe() recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling(): if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe) self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4(): elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe) self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
...@@ -1857,8 +2058,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1857,8 +2058,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
is_grad_enabled = torch.is_grad_enabled()
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
return self.onnx_forward(inp) return self.onnx_forward(inp, is_grad_enabled)
debug = self.is_debug_iter() debug = self.is_debug_iter()
...@@ -1874,19 +2077,17 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1874,19 +2077,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True fp8_output = True
with torch.cuda.device( with self.prepare_forward(inp, num_gemms=2) as inp:
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=2) as inp:
quantizers = ( quantizers = (
self._get_quantizers(fp8_output) self._get_quantizers(fp8_output, is_grad_enabled)
if not debug if not debug
else self._get_debug_quantizers(fp8_output) else self._get_debug_quantizers(fp8_output, is_grad_enabled)
) )
if debug: if debug:
if self.no_debug_features_active(quantizers): if self.no_debug_features_active(quantizers):
debug = False debug = False
quantizers = self._get_quantizers(fp8_output) quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
# Get quantizers # Get quantizers
( (
...@@ -1919,20 +2120,14 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1919,20 +2120,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
and self.bias_gelu_nvfusion and not use_reentrant_activation_recompute() ): and self.bias_gelu_nvfusion and not use_reentrant_activation_recompute() ):
self.bias_gelu_nvfusion = False self.bias_gelu_nvfusion = False
if torch.is_grad_enabled(): if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply fwd_fn = _LayerNormMLP.apply
args = [] autograd_ctx = []
else: else:
fwd_fn = _LayerNormMLP.forward fwd_fn = _LayerNormMLP.forward
args = [None] autograd_ctx = [None]
args += (
inp, non_tensor_args = (
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
...@@ -1961,8 +2156,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1961,8 +2156,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.return_layernorm_output_gathered, self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug, self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode, self.set_parallel_mode,
torch.is_grad_enabled(), is_grad_enabled,
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.activation, self.activation,
...@@ -1978,9 +2173,20 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1978,9 +2173,20 @@ class LayerNormMLP(TransformerEngineBaseModule):
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
self.symmetric_ar_type, self.symmetric_ar_type,
self.checkpoint,
debug, debug,
) )
out = fwd_fn(*args) out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
...@@ -1996,7 +2202,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1996,7 +2202,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
return out, ln_out return out, ln_out
return out return out
def _get_quantizers(self, fp8_output): def _get_quantizers(self, fp8_output, is_grad_enabled):
( (
fc1_input_quantizer, fc1_input_quantizer,
fc1_output_quantizer, fc1_output_quantizer,
...@@ -2013,6 +2219,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2013,6 +2219,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.fp8 or self.fp8_calibration: if self.fp8 or self.fp8_calibration:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = True fc1_input_quantizer.internal = True
if not self.sequence_parallel:
fc1_input_quantizer.optimize_for_gemm = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
fc2_input_quantizer.set_usage( fc2_input_quantizer.set_usage(
rowwise=True, rowwise=True,
...@@ -2021,20 +2229,24 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2021,20 +2229,24 @@ class LayerNormMLP(TransformerEngineBaseModule):
(MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer), (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer),
), ),
) )
fc1_input_quantizer.internal = True fc2_input_quantizer.internal = True
fc2_input_quantizer.optimize_for_gemm = True
if fp8_output: if fp8_output:
fc2_output_quantizer = self.quantizers["scaling_fwd"][ fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT tex.FP8FwdTensors.GEMM2_OUTPUT
] ]
if torch.is_grad_enabled(): if is_grad_enabled:
fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2 tex.FP8BwdTensors.GRAD_OUTPUT2
] ]
fc2_grad_output_quantizer.internal = True fc2_grad_output_quantizer.internal = True
if not self.sequence_parallel:
fc2_grad_output_quantizer.optimize_for_gemm = True
fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][ fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
] ]
fc1_grad_output_quantizer.internal = True fc1_grad_output_quantizer.internal = True
fc1_grad_output_quantizer.optimize_for_gemm = True
return ( return (
fc1_input_quantizer, fc1_input_quantizer,
...@@ -2051,9 +2263,11 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2051,9 +2263,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_grad_output_quantizer, fc2_grad_output_quantizer,
) )
def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: def onnx_forward(
self, inp: torch.Tensor, is_grad_enabled: bool
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
""" """
ONNX-compatible version of the forward function that provides numerical equivalence ONNX-compatible version of the :meth:`forward` method that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations. while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios. This simplified implementation is designed specifically for inference scenarios.
""" """
...@@ -2061,14 +2275,23 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2061,14 +2275,23 @@ class LayerNormMLP(TransformerEngineBaseModule):
assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export" assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
assert_warmed_up(self) assert_warmed_up(self)
# Get quantizers
( (
fc1_input_quantizer, fc1_input_quantizer,
fc1_weight_quantizer, fc1_weight_quantizer,
_,
_,
_,
_,
fc2_input_quantizer, fc2_input_quantizer,
fc2_weight_quantizer, fc2_weight_quantizer,
output_quantizer, fc2_output_quantizer,
*_, _,
) = self._get_quantizers(False) _,
_,
) = self._get_quantizers(False, is_grad_enabled)
inp_dtype = inp.dtype inp_dtype = inp.dtype
fc1_weight, fc2_weight = self._get_weight_tensors() fc1_weight, fc2_weight = self._get_weight_tensors()
...@@ -2142,7 +2365,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2142,7 +2365,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_out = onnx_gemm(fc2_weight, act_out, fc2_bias) fc2_out = onnx_gemm(fc2_weight, act_out, fc2_bias)
if output_quantizer is not None: if fc2_output_quantizer is not None:
raise NotImplementedError("ONNX export of quantized output is not supported") raise NotImplementedError("ONNX export of quantized output is not supported")
if self.return_layernorm_output: if self.return_layernorm_output:
...@@ -2153,10 +2376,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2153,10 +2376,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
return fc2_out, fc2_bias.to(inp_dtype) return fc2_out, fc2_bias.to(inp_dtype)
return fc2_out return fc2_out
def _get_debug_quantizers(self, fp8_output): def _get_debug_quantizers(self, fp8_output, is_grad_enabled):
from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_quantization import DebugQuantizer
base_quantizers = list(self._get_quantizers(fp8_output)) base_quantizers = list(self._get_quantizers(fp8_output, is_grad_enabled))
assert TEDebugState.debug_enabled assert TEDebugState.debug_enabled
def make_debug(prefix, offset): def make_debug(prefix, offset):
...@@ -2276,22 +2499,6 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2276,22 +2499,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_quantizer.internal = True fc2_weight_quantizer.internal = True
return [fc1_weight_quantizer, fc2_weight_quantizer] return [fc1_weight_quantizer, fc2_weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_mlp."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].all_gather_usage = True
def backward_dw(self): def backward_dw(self):
""" """
Execute the delayed weight gradient computation. Execute the delayed weight gradient computation.
...@@ -2299,7 +2506,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2299,7 +2506,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
""" """
if not self.need_backward_dw(): if not self.need_backward_dw():
return return
with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"): with get_nvtx_range_context("_LayerNormMLP_wgrad"):
(fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop() (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
if self.use_bias and self.fc1_bias.grad is None: if self.use_bias and self.fc1_bias.grad is None:
(fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop() (fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop()
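To round out the backward_dw path above, a hedged sketch of the delayed weight-gradient flow enabled by delay_wgrad_compute=True (module sizes and the loss are illustrative; backward_dw is the method shown in this diff):

import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,          # illustrative size
    ffn_hidden_size=4096,      # illustrative size
    delay_wgrad_compute=True,  # dgrad runs in backward; wgrad GEMMs are deferred
).cuda()

x = torch.randn(8, 1024, device="cuda", requires_grad=True)
loss = mlp(x).sum()
loss.backward()    # weight-gradient work is queued in the module's wgrad store
mlp.backward_dw()  # explicitly run the deferred weight-gradient computation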