Commit 0d874a4e authored by wenjh
Browse files

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -23,7 +23,7 @@ class ReduceScatter(BasicOperation):
Parameters
----------
process_group: torch.distributed.ProcessGroup, default = world group
process_group : torch.distributed.ProcessGroup, default = world group
Process group for communication
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -24,7 +24,7 @@ class Reshape(BasicOperation):
Parameters
----------
shape: iterable of int
shape : iterable of int
Output tensor dimensions. If one dimension is -1, it is
inferred based on input tensor dimensions.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -42,13 +42,13 @@ class RMSNorm(BasicOperation):
Parameters
----------
normalized_shape: int or iterable of int
normalized_shape : int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
A value added to the denominator for numerical stability
device: torch.device, default = default CUDA device
device : torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
dtype : torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
If `True`, the :math:`\gamma` parameter is initialized to zero
......@@ -57,7 +57,7 @@ class RMSNorm(BasicOperation):
.. math::
y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
sm_margin: int, default = 0
sm_margin : int, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
......@@ -249,4 +249,6 @@ class RMSNorm(BasicOperation):
) -> torch.Tensor:
"""Every operand in this function has a defined ONNX translation."""
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps)
variance = input_.pow(2).mean(-1, keepdim=True)
normalized = input_ * torch.rsqrt(variance + self.eps)
return normalized * weight
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -90,15 +90,15 @@ def fuse_backward_activation_bias(
Parameters
----------
ops: list of tuples
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
recipe: Recipe, optional
recipe : Recipe, optional
Used quantization recipe
Returns
-------
ops: list of tuples
ops : list of tuples
Updated backward pass operations
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -87,13 +87,13 @@ def fuse_backward_add_rmsnorm(
Parameters
----------
ops: list of tuples
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
ops : list of tuples
Updated backward pass operations
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -119,13 +119,13 @@ def fuse_backward_linear_add(
Parameters
----------
ops: list of tuples
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
ops : list of tuples
Updated backward pass operations
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -119,13 +119,13 @@ def fuse_backward_linear_scale(
Parameters
----------
ops: list of tuples
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
ops : list of tuples
Updated backward pass operations
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -142,13 +142,13 @@ def fuse_forward_linear_bias_activation(
Parameters
----------
ops: list of tuples
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
ops : list of tuples
Updated forward pass operations
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -139,13 +139,13 @@ def fuse_forward_linear_bias_add(
Parameters
----------
ops: list of tuples
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
ops : list of tuples
Updated forward pass operations
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -118,13 +118,13 @@ def fuse_forward_linear_scale_add(
Parameters
----------
ops: list of tuples
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
ops : list of tuples
Updated forward pass operations
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -19,7 +19,6 @@ from ...module.base import (
fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub,
get_workspace,
)
from ...quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
......@@ -293,6 +292,7 @@ class UserbuffersBackwardLinear(FusedOperation):
rowwise=True,
columnwise=with_columnwise,
)
grad_output_quantizer.optimize_for_gemm = False
dy_local = grad_output_quantizer(dy_local)
else:
dy_local = maybe_dequantize(dy_local, dtype)
......@@ -378,7 +378,6 @@ class UserbuffersBackwardLinear(FusedOperation):
dx, *_ = general_gemm(
w,
dy,
get_workspace(),
out_dtype=dtype,
quantization_params=grad_input_quantizer,
layout="NN",
......@@ -464,7 +463,6 @@ class UserbuffersBackwardLinear(FusedOperation):
dw, *_ = general_gemm(
x,
dy,
get_workspace(),
out_dtype=dw_dtype,
accumulate=accumulate_into_grad_weight,
layout="NT",
......@@ -592,13 +590,13 @@ def fuse_userbuffers_backward_linear(
Parameters
----------
ops: list of tuples
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
ops : list of tuples
Updated backward pass operations
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -18,7 +18,6 @@ from ...quantization import FP8GlobalStateManager
from ...module.base import (
fill_userbuffers_buffer_for_all_gather,
get_ub,
get_workspace,
_2X_ACC_FPROP,
)
from ...quantized_tensor import Quantizer
......@@ -243,7 +242,6 @@ class UserbuffersForwardLinear(FusedOperation):
gemm_output, *_, reduce_scatter_output = general_gemm(
w,
x,
get_workspace(),
out_dtype=dtype,
quantization_params=output_quantizer,
bias=bias,
......@@ -379,13 +377,13 @@ def fuse_userbuffers_forward_linear(
Parameters
----------
ops: list of tuples
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
ops : list of tuples
Updated forward pass operations
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -310,7 +310,7 @@ class OperationFuser:
Parameters
----------
ops: list of FusibleOperation
ops : list of FusibleOperation
Pipeline of operations
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -27,29 +27,29 @@ class Linear(FusedOperation):
Parameters
----------
in_features: int
in_features : int
Inner dimension of input tensor
out_features: int
out_features : int
Inner dimension of output tensor
bias: bool, default = `True`
bias : bool, default = `True`
Apply additive bias
device: torch.device, default = default CUDA device
device : torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
dtype : torch.dtype, default = default dtype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
tensor_parallel_mode : {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
sequence_parallel : bool, default = `False`
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors along
outer dimension (sequence or batch dim) when not distributing
along inner dimension (embedding dim)
rng_state_tracker_function: callable
rng_state_tracker_function : callable
Function that returns CudaRNGStatesTracker, which is used for
model-parallel weight initialization
accumulate_into_main_grad: bool, default = `False`
accumulate_into_main_grad : bool, default = `False`
Whether to directly accumulate weight gradients into the
weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -188,9 +188,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Objects for quantization
self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
self._quantizers: Optional[dict[str, list[Quantizer]]] = None
with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
self.reset_recipe_state(recipe=recipe)
@property
def is_fused_op(self) -> bool:
......@@ -687,7 +684,7 @@ class FusedOperation(FusibleOperation):
Parameters
----------
basic_ops: iterable of FusibleOperation
basic_ops : iterable of FusibleOperation
Basic ops that are interchangeable with this op
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -11,8 +11,10 @@ from typing import Optional
import warnings
import torch
from torch.distributed._tensor import DTensor
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from .multi_tensor_apply import multi_tensor_applier
from torch.utils.cpp_extension import IS_HIP_EXTENSION
......@@ -371,10 +373,12 @@ class FusedAdam(torch.optim.Optimizer):
store_param_remainders (bool): Store only trailing remainder bits.
"""
dtype = self.name_to_dtype_map[state_name]
# Handle QuantizedTensor by dequantizing first
param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else:
data = torch.empty(param.shape, dtype=dtype, device=param.device)
data = torch.empty_like(param_for_empty, dtype=dtype)
if zero_buffer:
data.zero_()
......@@ -567,8 +571,10 @@ class FusedAdam(torch.optim.Optimizer):
unscaled_lists[name].append(unscaled)
scaled_lists[name].append(state[name])
state_scales[name].append(self._scales[p][name])
if isinstance(p, Float8Tensor):
if isinstance(p, Float8Tensor) or (
isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor)
):
p = p._local_tensor if isinstance(p, DTensor) else p
out_dtype = p._fp8_dtype
p_fp8_model.append(p._data.data)
scale, amax, scale_inv = get_fp8_meta(p)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment