Commit 970620a5 authored by wenjh's avatar wenjh
Browse files

merge nv_release_v2.10 to release_v2.10


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parents c1a1c04e 769ed778
...@@ -115,6 +115,11 @@ std::vector<py::object> fused_attn_fwd( ...@@ -115,6 +115,11 @@ std::vector<py::object> fused_attn_fwd(
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
assert(false); assert(false);
#else #else
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(cu_seqlens_q.device());
auto none = py::none(); auto none = py::none();
// create QKV tensor wrappers // create QKV tensor wrappers
......
...@@ -97,6 +97,11 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans ...@@ -97,6 +97,11 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
bool bulk_overlap, float alpha, std::optional<float> beta) { bool bulk_overlap, float alpha, std::optional<float> beta) {
using namespace transformer_engine::pytorch::detail; using namespace transformer_engine::pytorch::detail;
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(workspace.device());
// Input tensors // Input tensors
NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
...@@ -353,6 +358,11 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, ...@@ -353,6 +358,11 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
at::Tensor workspace, size_t workspaceSize, bool accumulate, at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, at::Tensor counter) { bool gemm_producer, at::Tensor counter) {
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(workspace.device());
// TODO: Handle scaling modes // TODO: Handle scaling modes
NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING;
...@@ -402,6 +412,11 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -402,6 +412,11 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
NVTE_ERROR("not implemented, D should be allocated for single output case."); NVTE_ERROR("not implemented, D should be allocated for single output case.");
} }
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(workspace[0].device());
void* output_data_ptr = nullptr; void* output_data_ptr = nullptr;
if (single_output) { if (single_output) {
output_data_ptr = (*D)[0].data_ptr(); output_data_ptr = (*D)[0].data_ptr();
......
...@@ -64,6 +64,11 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe ...@@ -64,6 +64,11 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail; using namespace transformer_engine::pytorch::detail;
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
// Input and param tensors // Input and param tensors
auto none = py::none(); auto none = py::none();
const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
...@@ -294,6 +299,11 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w ...@@ -294,6 +299,11 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
const int sm_margin, const bool zero_centered_gamma) { const int sm_margin, const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail; using namespace transformer_engine::pytorch::detail;
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
// Input and param tensors // Input and param tensors
auto none = py::none(); auto none = py::none();
const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
......
...@@ -30,7 +30,7 @@ except ImportError: ...@@ -30,7 +30,7 @@ except ImportError:
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
from . import torch_version from .torch_version import torch_version
from .utils import ( from .utils import (
is_non_tn_fp8_gemm_supported, is_non_tn_fp8_gemm_supported,
safely_set_viewless_tensor_data, safely_set_viewless_tensor_data,
...@@ -642,18 +642,18 @@ def checkpoint( ...@@ -642,18 +642,18 @@ def checkpoint(
Parameters Parameters
---------- ----------
function: Callable function : Callable
pytorch module used to run the forward and backward passes using pytorch module used to run the forward and backward passes using
the specified :attr:`args` and :attr:`kwargs`. the specified :attr:`args` and :attr:`kwargs`.
distribute_saved_activations: bool, default = False distribute_saved_activations : bool, default = False
if set to `True` and `use_reentrant=True`, first tensor argument is distributed if set to ``True`` and ``use_reentrant=True``, first tensor argument is distributed
across the specified tensor parallel group (`tp_group`) before saving it for the across the specified tensor parallel group (``tp_group``) before saving it for the
backward pass. This has no effect when `use_reentrant=False`. backward pass. This has no effect when ``use_reentrant=False``.
get_rng_state_tracker: `Callable`, default = None get_rng_state_tracker : Callable, default = None
python callable which returns an instance of :func:`CudaRNGStatesTracker`. python callable which returns an instance of :class:`CudaRNGStatesTracker`.
tp_group : ProcessGroup, default = None tp_group : ProcessGroup, default = None
tensor parallel process group. Used only when `distribute_saved_activations=True` tensor parallel process group. Used only when ``distribute_saved_activations=True``
and `use_reentrant=True`. If `None`, it falls back to the default group. and ``use_reentrant=True``. If ``None``, it falls back to the default group.
use_reentrant : bool, default = True use_reentrant : bool, default = True
perform checkpointing in reentrant mode. perform checkpointing in reentrant mode.
args : tuple args : tuple
...@@ -778,8 +778,8 @@ class CudaRNGStatesTracker: ...@@ -778,8 +778,8 @@ class CudaRNGStatesTracker:
For model parallelism, multiple RNG states need to simultaneously exist in order For model parallelism, multiple RNG states need to simultaneously exist in order
to execute operations in or out of the model parallel region. This class keeps to execute operations in or out of the model parallel region. This class keeps
track of the various RNG states and provides utility methods to maintain them and track of the various RNG states and provides utility methods to maintain them and
execute parts of the model under a given RNG setting. Using the `add` method, a execute parts of the model under a given RNG setting. Using the :meth:`add` method, a
cuda rng state is initialized based on the input `seed` and is assigned to `name`. cuda rng state is initialized based on the input ``seed`` and is assigned to ``name``.
Later, by forking the rng state, we can perform operations and return to our starting Later, by forking the rng state, we can perform operations and return to our starting
cuda state. cuda state.
""" """
...@@ -812,7 +812,9 @@ class CudaRNGStatesTracker: ...@@ -812,7 +812,9 @@ class CudaRNGStatesTracker:
Set the rng states. For efficiency purposes, we do not Set the rng states. For efficiency purposes, we do not
check the size of seed for compatibility. check the size of seed for compatibility.
states: Dict[str, torch.Tensor] Parameters
----------
states : Dict[str, torch.Tensor]
A mapping from string names to RNG states. A mapping from string names to RNG states.
""" """
self.states_ = states self.states_ = states
...@@ -821,9 +823,11 @@ class CudaRNGStatesTracker: ...@@ -821,9 +823,11 @@ class CudaRNGStatesTracker:
""" """
Adds a new RNG state. Adds a new RNG state.
name: str Parameters
----------
name : str
string identifier for the RNG state. string identifier for the RNG state.
seed: int seed : int
PyTorch seed for the RNG state. PyTorch seed for the RNG state.
""" """
# Check seed is not already used. # Check seed is not already used.
...@@ -857,7 +861,9 @@ class CudaRNGStatesTracker: ...@@ -857,7 +861,9 @@ class CudaRNGStatesTracker:
Fork the cuda rng state, perform operations, and exit with Fork the cuda rng state, perform operations, and exit with
the original state. the original state.
name: str Parameters
----------
name : str
string identifier for the RNG state. string identifier for the RNG state.
""" """
# Check if we have added the state # Check if we have added the state
...@@ -948,7 +954,13 @@ def _all_gather_fp8( ...@@ -948,7 +954,13 @@ def _all_gather_fp8(
if isinstance(inp, Float8Tensor): if isinstance(inp, Float8Tensor):
dtype = inp.dtype dtype = inp.dtype
device = inp.device device = inp.device
# Temporarily ensure rowwise usage for output tensor creation
# since we're gathering rowwise data, not the transpose
init_rowwise_usage = quantizer.rowwise_usage
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=init_columnwise_usage)
out = quantizer.make_empty(out_shape, dtype=dtype, device=device) out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage)
elif isinstance(inp, Float8Tensor): elif isinstance(inp, Float8Tensor):
out = inp.make_like(inp, shape=out_shape) out = inp.make_like(inp, shape=out_shape)
out._data = torch.empty( out._data = torch.empty(
...@@ -2001,7 +2013,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: ...@@ -2001,7 +2013,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
Parameters Parameters
---------- ----------
fsdp_root: torch.nn.Module fsdp_root : torch.nn.Module
FSDP-wrapped root module that may contain FSDP-wrapped TE modules. FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
""" """
assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped." assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."
......
...@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]: ...@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
Parameters Parameters
---------- ----------
enabled: bool, default = `False` enabled : bool, default = False
whether or not to enable export whether or not to enable export
""" """
......
...@@ -950,38 +950,38 @@ def make_graphed_callables( ...@@ -950,38 +950,38 @@ def make_graphed_callables(
Positional arguments to callable(s). Positional arguments to callable(s).
num_warmup_iters: int, default = 3 num_warmup_iters: int, default = 3
Number of warmup iterations. Number of warmup iterations.
allow_unused_input: bool, default = `False` allow_unused_input: bool, default = False
Whether to handle case where callable inputs Whether to handle case where callable inputs
and outputs are disconnected in compute graph. and outputs are disconnected in compute graph.
sample_kwargs: (tuple of) dict, optional sample_kwargs: (tuple of) dict, optional
Keyword arguments to callable(s) Keyword arguments to callable(s)
pool: (tuple of) int, default = `None`, optional pool: (tuple of) int, default = None, optional
An instance returned from function `torch.cuda.graph_pool_handle` that hints An instance returned from function `torch.cuda.graph_pool_handle` that hints
this graph may share memory with the indicated pool. this graph may share memory with the indicated pool.
retain_graph_in_backward: bool, default = `False` retain_graph_in_backward: bool, default = False
Whether to set retain_graph=True in backward graph capture. Whether to set retain_graph=True in backward graph capture.
_reuse_graph_input_output_buffers: bool, default = `False` _reuse_graph_input_output_buffers: bool, default = False
Reduce memory usage by reusing input/output data buffers between Reduce memory usage by reusing input/output data buffers between
graphs. Only supported with Mcore interleaved pipeline parallelism, i.e. graphs. Only supported with Mcore interleaved pipeline parallelism, i.e.
when `_order` is provided. All callables in `modules` are assumed to have when `_order` is provided. All callables in `modules` are assumed to have
inputs and outputs with the same dtype and shape. inputs and outputs with the same dtype and shape.
Quantization related parameters Quantization parameters
---------------------- -----------------------
enabled: (tuple of) bool, default = `False` enabled: (tuple of) bool, default = False
whether or not to enable low precision quantization (FP8/FP4). whether or not to enable low precision quantization (FP8/FP4).
If tuple, the length must match the number of modules. If tuple, the length must match the number of modules.
calibrating: bool, default = `False` calibrating: bool, default = False
calibration mode allows collecting statistics such as amax and scale calibration mode allows collecting statistics such as amax and scale
data of quantized tensors even when executing without quantization enabled. data of quantized tensors even when executing without quantization enabled.
This is useful for saving an inference ready checkpoint while training This is useful for saving an inference ready checkpoint while training
using a higher precision. using a higher precision.
recipe: recipe.Recipe, default = `None` recipe: recipe.Recipe, default = None
recipe used for low precision quantization. recipe used for low precision quantization.
amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None` amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = None
distributed group over which amaxes for the quantized tensors distributed group over which amaxes for the quantized tensors
are reduced at the end of each training step. are reduced at the end of each training step.
cache_quantized_params: bool, default = `False` cache_quantized_params: bool, default = False
Whether or not to cache quantized weights across microbatches. if set to `True`, Whether or not to cache quantized weights across microbatches. if set to `True`,
the `is_first_microbatch` boolean argument must be passed into the forward the `is_first_microbatch` boolean argument must be passed into the forward
method for TransformerEngine modules. When storing primary weights in low precision method for TransformerEngine modules. When storing primary weights in low precision
......
...@@ -8,7 +8,7 @@ from functools import wraps ...@@ -8,7 +8,7 @@ from functools import wraps
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
import torch import torch
from . import torch_version from .torch_version import torch_version
from .export import is_in_onnx_export_mode from .export import is_in_onnx_export_mode
from .utils import gpu_autocast_ctx from .utils import gpu_autocast_ctx
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
......
...@@ -20,7 +20,6 @@ import torch.nn.functional as F ...@@ -20,7 +20,6 @@ import torch.nn.functional as F
from torch.distributed.tensor import DTensor from torch.distributed.tensor import DTensor
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from ._common import _ParameterInitMeta, noop_cat from ._common import _ParameterInitMeta, noop_cat
from ..quantization import ( from ..quantization import (
...@@ -39,13 +38,18 @@ from ..distributed import ( ...@@ -39,13 +38,18 @@ from ..distributed import (
_fsdp_gather_tensors, _fsdp_gather_tensors,
) )
from ..constants import dist_group_type from ..constants import dist_group_type
from ..cpp_extensions.gemm import _NUM_MAX_UB_STREAMS
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..utils import (
is_non_tn_fp8_gemm_supported,
torch_get_autocast_gpu_dtype,
get_nvtx_range_context,
)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ...common.recipe import DelayedScaling, Recipe from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
...@@ -58,13 +62,9 @@ __all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] ...@@ -58,13 +62,9 @@ __all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]
_2X_ACC_FPROP = False _2X_ACC_FPROP = False
_2X_ACC_DGRAD = True _2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True _2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_dummy_wgrads = {} _dummy_wgrads = {}
_multi_stream_cublas_batchgemm_workspace = [] _multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None
_ub_communicators = None _ub_communicators = None
ub_stream_nums = int(os.getenv("NVTE_UB_STREAM_NUMS", "2"))
_NUM_MAX_UB_STREAMS = ub_stream_nums if IS_HIP_EXTENSION else 3
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = [] layers_atomic_ring_exchange = []
...@@ -78,38 +78,6 @@ class UserBufferQuantizationMode(Enum): ...@@ -78,38 +78,6 @@ class UserBufferQuantizationMode(Enum):
FP8 = "fp8" FP8 = "fp8"
def get_cublas_workspace_size_bytes() -> None:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
# Add env for control the padding for blaslt
if IS_HIP_EXTENSION:
return 134_217_728
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
# 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
return 32 * 1024 * 1024 + 1024
return 4_194_304
def get_workspace() -> torch.Tensor:
"""Returns workspace for cublas."""
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
)
return _cublas_workspace
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_workspace
if not _multi_stream_cublas_workspace:
for _ in range(tex.get_num_cublas_streams()):
_multi_stream_cublas_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
)
return _multi_stream_cublas_workspace
def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]: def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas.""" """Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_batchgemm_workspace global _multi_stream_cublas_batchgemm_workspace
...@@ -154,55 +122,55 @@ def initialize_ub( ...@@ -154,55 +122,55 @@ def initialize_ub(
) -> None: ) -> None:
r""" r"""
Initialize the Userbuffers communicator for overlapping tensor-parallel communications with Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules. GEMM compute in ``te.Linear``, ``te.LayerNormLinear`` and ``te.LayerNormMLP`` modules.
Parameters Parameters
---------- ----------
shape : list shape : list
shape of the communication buffer, typically set to be the same as the global shape of shape of the communication buffer, typically set to be the same as the global shape of
the input tensor to a te.TransformerLayer forward pass, with the sequence and batch the input tensor to a ``te.TransformerLayer`` forward pass, with the sequence and batch
dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)` dimensions collapsed together -- i.e.: ``(sequence_length * batch_size, hidden_size)``
tp_size : int tp_size : int
number of GPUs in the tensor-parallel process group number of GPUs in the tensor-parallel process group
use_fp8 : bool = False use_fp8 : bool = False
allocate the communication buffer for FP8 GEMM inputs/outputs. allocate the communication buffer for FP8 GEMM inputs/outputs.
DEPRECATED: Please use `quantization_modes` instead. DEPRECATED: Please use ``quantization_modes`` instead.
quantization_modes : List[UserBufferQuantizationMode] = None quantization_modes : List[UserBufferQuantizationMode] = None
if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list. if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list.
falls back to the legacy `use_fp8` parameter if `None` is provided. falls back to the legacy ``use_fp8`` parameter if ``None`` is provided.
dtype : torch.dtype = torch.bfloat16 dtype : torch.dtype = torch.bfloat16
non-FP8 data type of the communication buffer when `use_fp8 = False` non-FP8 data type of the communication buffer when ``use_fp8 = False``
ub_cfgs: dict = None ub_cfgs : dict = None
Configuration dictionary with the structure Configuration dictionary with the structure::
```
{ {
<gemm_name> : { <gemm_name> : {
"method": <"ring_exchange" or "pipeline">, "method": <"ring_exchange" or "pipeline">,
"is_reduce_scatter": bool, "is_reduce_scatter": bool,
"num_sm": int, "num_sm": int,
"cga_size": int, "cga_size": int,
"set_sm_margin": bool, "set_sm_margin": bool,
"num_splits": int, "num_splits": int,
"aggregate": bool, "aggregate": bool,
"atomic_gemm": bool, "atomic_gemm": bool,
"use_ce": bool, "use_ce": bool,
"fp8_buf": bool, "fp8_buf": bool,
} }
} }
```
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", for ``te.TransformerLayer`` GEMM layers in ``["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
"fc2_fprop", "fc2_wgrad"]`. "fc2_fprop", "fc2_wgrad"]``.
a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes` a list may be provided to specify different overlap configurations for different the quantization settings in ``quantization_modes``
bootstrap_backend : str = None bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and ``torch.distributed`` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are barrier collectives during Userbuffers initialization. Not all backends are
valid for every cluster configuration and distributed launch method even if valid for every cluster configuration and distributed launch method even if
they are available in PyTorch. When left unset, the initialization prefers they are available in PyTorch. When left unset, the initialization prefers
to use the MPI backend, falling back first on Gloo and then NCCL if MPI is to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this not available. Setting ``NVTE_UB_WITH_MPI=1`` when building TE overrides this
option and always initializes Userbuffers with direct MPI calls in C++, option and always initializes Userbuffers with direct MPI calls in C++,
which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time. which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time.
""" """
if not tex.device_supports_multicast(): if not tex.device_supports_multicast():
assert bool(int(os.getenv("UB_SKIPMC", "1"))), ( assert bool(int(os.getenv("UB_SKIPMC", "1"))), (
...@@ -299,16 +267,6 @@ def initialize_ub( ...@@ -299,16 +267,6 @@ def initialize_ub(
flush=True, flush=True,
) )
# Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS:
# This ensures we don't do `.repeat()` on an already expanded workspace
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
).repeat(_NUM_MAX_UB_STREAMS)
# Default buffer precision: AllGather buffers use fp8 when using fp8 recipe # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
layers_all_gather_overlap = [ layers_all_gather_overlap = [
"qkv_fprop", "qkv_fprop",
...@@ -1033,7 +991,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1033,7 +991,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
Parameters Parameters
---------- ----------
tp_group : ProcessGroup, default = `None` tp_group : ProcessGroup, default = None
tensor parallel process group. tensor parallel process group.
""" """
self.tp_group = tp_group self.tp_group = tp_group
...@@ -1123,8 +1081,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1123,8 +1081,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
""" """
self.allow_different_data_and_param_types = allow_different_data_and_param_types self.allow_different_data_and_param_types = allow_different_data_and_param_types
self.forwarded_at_least_once = True self.forwarded_at_least_once = True
# Activation recomputation is used and this is the second forward phase. # Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
delayed_scaling_recipe = self.fp8_meta["recipe"].delayed()
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
else: else:
assert inp.is_cuda, "TransformerEngine needs CUDA." assert inp.is_cuda, "TransformerEngine needs CUDA."
...@@ -1136,25 +1096,27 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1136,25 +1096,27 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.init_fp8_metadata(num_gemms=num_gemms) self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence() self._check_weight_tensor_recipe_correspondence()
if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
assert self.fp8_meta["recipe"].reduce_amax, ( if delayed_scaling_recipe:
"Amax reduction across tensor parallel group is " if self.sequence_parallel:
"necessary when using sequence parallelism with FP8." assert self.fp8_meta["recipe"].reduce_amax, (
) "Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8."
)
if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): if not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)
# Activation recomputation is used and this is the first forward phase. # Activation recomputation is used and this is the first forward phase.
if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): if self.training and is_fp8_activation_recompute_enabled():
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): with get_nvtx_range_context(self.__class__.__name__ + " forward"):
if not allow_non_contiguous and not inp.is_contiguous(): if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous() inp = inp.contiguous()
yield inp yield inp
if self.fp8 and in_fp8_activation_recompute_phase(): if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
def set_nccl_overlap_warning_if_tp(self) -> None: def set_nccl_overlap_warning_if_tp(self) -> None:
...@@ -1434,7 +1396,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1434,7 +1396,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
workspace is being constructed or updated. workspace is being constructed or updated.
cache_name: str, optional cache_name: str, optional
Key for caching. Key for caching.
update_workspace: bool, default = `True` update_workspace: bool, default = True
Update workspace with values from `tensor`. Update workspace with values from `tensor`.
skip_update_flag: torch.Tensor, optional skip_update_flag: torch.Tensor, optional
GPU flag to skip updating the workspace. Take precedence GPU flag to skip updating the workspace. Take precedence
...@@ -1576,7 +1538,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1576,7 +1538,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
""" """
if not self.need_backward_dw(): if not self.need_backward_dw():
return return
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): with get_nvtx_range_context(f"_{self.__class__.__name__}_wgrad"):
(wgrad, bgrad), _ = self.wgrad_store.pop() (wgrad, bgrad), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation: if not self.fuse_wgrad_accumulation:
weight_tensor = noop_cat(self._get_weight_tensors()) weight_tensor = noop_cat(self._get_weight_tensors())
...@@ -1673,6 +1635,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1673,6 +1635,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
""" """
if not self.fp8 and not self.fp8_calibration: if not self.fp8 and not self.fp8_calibration:
return return
if not self.primary_weights_in_fp8:
return
if not hasattr(self, "weight_names") or not self.weight_names: if not hasattr(self, "weight_names") or not self.weight_names:
return return
......
...@@ -24,11 +24,14 @@ class _Fp8Padding(torch.autograd.Function): ...@@ -24,11 +24,14 @@ class _Fp8Padding(torch.autograd.Function):
def forward( def forward(
ctx, ctx,
inp: torch.Tensor, inp: torch.Tensor,
m_splits: List[int], non_tensor_args: Tuple,
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = inp.shape[-1] in_features = inp.shape[-1]
...@@ -65,7 +68,7 @@ class _Fp8Padding(torch.autograd.Function): ...@@ -65,7 +68,7 @@ class _Fp8Padding(torch.autograd.Function):
grad_output.view(-1, in_features), grad_input, ctx.padded_m_splits, ctx.m_splits grad_output.view(-1, in_features), grad_input, ctx.padded_m_splits, ctx.m_splits
) )
return (grad_input, None, None, None) return grad_input, None
class Fp8Padding(torch.nn.Module): class Fp8Padding(torch.nn.Module):
...@@ -128,19 +131,20 @@ class Fp8Padding(torch.nn.Module): ...@@ -128,19 +131,20 @@ class Fp8Padding(torch.nn.Module):
if m_splits == padded_m_splits: if m_splits == padded_m_splits:
return inp, m_splits return inp, m_splits
if torch.is_grad_enabled(): is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
fn = _Fp8Padding.apply fn = _Fp8Padding.apply
args = [] autograd_ctx = []
else: else:
fn = _Fp8Padding.forward fn = _Fp8Padding.forward
args = [None] autograd_ctx = [None]
args += ( non_tensor_args = (
inp,
m_splits, m_splits,
padded_m_splits, padded_m_splits,
torch.is_grad_enabled(), is_grad_enabled,
) )
out = fn(*args) out = fn(*autograd_ctx, inp, non_tensor_args)
return out, padded_m_splits return out, padded_m_splits
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""FP8 Padding API""" """FP8 Padding API"""
from typing import List, Optional from typing import List, Optional, Tuple
import torch import torch
...@@ -24,11 +24,14 @@ class _Fp8Unpadding(torch.autograd.Function): ...@@ -24,11 +24,14 @@ class _Fp8Unpadding(torch.autograd.Function):
def forward( def forward(
ctx, ctx,
inp: torch.Tensor, inp: torch.Tensor,
m_splits: List[int], non_tensor_args: Tuple,
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args
in_features = inp.shape[-1] in_features = inp.shape[-1]
# Allocate cast and transpose output tensor # Allocate cast and transpose output tensor
...@@ -63,7 +66,7 @@ class _Fp8Unpadding(torch.autograd.Function): ...@@ -63,7 +66,7 @@ class _Fp8Unpadding(torch.autograd.Function):
grad_output.view(-1, in_features), grad_input, ctx.m_splits, ctx.padded_m_splits grad_output.view(-1, in_features), grad_input, ctx.m_splits, ctx.padded_m_splits
) )
return (grad_input, None, None, None) return grad_input, None
class Fp8Unpadding(torch.nn.Module): class Fp8Unpadding(torch.nn.Module):
...@@ -126,19 +129,20 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -126,19 +129,20 @@ class Fp8Unpadding(torch.nn.Module):
if m_splits == padded_m_splits: if m_splits == padded_m_splits:
return inp return inp
if torch.is_grad_enabled(): is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
fn = _Fp8Unpadding.apply fn = _Fp8Unpadding.apply
args = [] autograd_ctx = []
else: else:
fn = _Fp8Unpadding.forward fn = _Fp8Unpadding.forward
args = [None] autograd_ctx = [None]
args += ( non_tensor_args = (
inp,
m_splits, m_splits,
padded_m_splits, padded_m_splits,
torch.is_grad_enabled(), is_grad_enabled,
) )
out = fn(*args) out = fn(*autograd_ctx, inp, non_tensor_args)
return out return out
...@@ -14,8 +14,6 @@ import transformer_engine_torch as tex ...@@ -14,8 +14,6 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from .base import ( from .base import (
get_dummy_wgrad,
get_multi_stream_cublas_workspace,
get_dummy_wgrad, get_dummy_wgrad,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
...@@ -30,6 +28,7 @@ from ..utils import ( ...@@ -30,6 +28,7 @@ from ..utils import (
clear_tensor_data, clear_tensor_data,
init_method_constant, init_method_constant,
requires_grad, requires_grad,
get_nvtx_range_context,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -42,7 +41,6 @@ from ..cpp_extensions import ( ...@@ -42,7 +41,6 @@ from ..cpp_extensions import (
) )
from ..constants import GemmParallelModes, dist_group_type from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
...@@ -66,29 +64,35 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -66,29 +64,35 @@ class _GroupedLinear(torch.autograd.Function):
def forward( def forward(
ctx, ctx,
inp: torch.Tensor, inp: torch.Tensor,
m_splits: List[int], non_tensor_args: Tuple,
use_bias: bool,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizers: List[Quantizer],
weight_quantizers: List[Quantizer],
output_quantizers: List[Quantizer],
grad_output_quantizers: List[Quantizer],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
sequence_parallel: bool,
activation_dtype: torch.dtype,
is_grad_enabled: bool,
module,
skip_fp8_weight_update,
save_original_input,
fine_grained_activation_offloading,
*weights_and_biases, *weights_and_biases,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
m_splits,
use_bias,
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_output_quantizers,
fuse_wgrad_accumulation,
cpu_offloading,
sequence_parallel,
activation_dtype,
is_grad_enabled,
module,
skip_fp8_weight_update,
save_original_input,
fine_grained_activation_offloading,
) = non_tensor_args
num_gemms = len(m_splits) num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms] weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:] biases = weights_and_biases[num_gemms:]
...@@ -187,7 +191,6 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -187,7 +191,6 @@ class _GroupedLinear(torch.autograd.Function):
inputmats, inputmats,
[out], [out],
activation_dtype, activation_dtype,
get_multi_stream_cublas_workspace(),
single_output=True, single_output=True,
m_splits=m_splits, m_splits=m_splits,
bias=biases, bias=biases,
...@@ -313,7 +316,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -313,7 +316,7 @@ class _GroupedLinear(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_GroupedLinear_backward"): with get_nvtx_range_context("_GroupedLinear_backward"):
saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
N = ctx.num_gemms N = ctx.num_gemms
inputmats = saved_tensors[:N] inputmats = saved_tensors[:N]
...@@ -404,7 +407,6 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -404,7 +407,6 @@ class _GroupedLinear(torch.autograd.Function):
grad_output, grad_output,
[dgrad], [dgrad],
ctx.activation_dtype, ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
single_output=True, single_output=True,
layout="NN", layout="NN",
m_splits=ctx.m_splits, m_splits=ctx.m_splits,
...@@ -451,7 +453,6 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -451,7 +453,6 @@ class _GroupedLinear(torch.autograd.Function):
grouped_gemm_wgrad = functools.partial( grouped_gemm_wgrad = functools.partial(
general_grouped_gemm, general_grouped_gemm,
out_dtype=ctx.activation_dtype, out_dtype=ctx.activation_dtype,
workspaces=get_multi_stream_cublas_workspace(),
layout="NT", layout="NT",
grad=True, grad=True,
m_splits=ctx.m_splits, m_splits=ctx.m_splits,
...@@ -523,29 +524,11 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -523,29 +524,11 @@ class _GroupedLinear(torch.autograd.Function):
): ):
grad_biases = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): if ctx.reduce_and_update_bwd_fp8_tensors:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return ( return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
None, None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
*wgrad_list, *wgrad_list,
*grad_biases, *grad_biases,
) )
...@@ -563,14 +546,14 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -563,14 +546,14 @@ class GroupedLinear(TransformerEngineBaseModule):
size of each input sample. size of each input sample.
out_features : int out_features : int
size of each output sample. size of each output sample.
bias : bool, default = `True` bias : bool, default = True
if set to `False`, the layer will not learn an additive bias. if set to ``False``, the layer will not learn an additive bias.
init_method : Callable, default = `None` init_method : Callable, default = None
used for initializing weights in the following way: `init_method(weight)`. used for initializing weights in the following way: ``init_method(weight)``.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
get_rng_state_tracker : Callable, default = `None` get_rng_state_tracker : Callable, default = None
used to get the random number generator state tracker for initializing weights. used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None` rng_tracker_name : str, default = None
the param passed to get_rng_state_tracker to get the specific rng tracker. the param passed to get_rng_state_tracker to get the specific rng tracker.
device : Union[torch.device, str], default = "cuda" device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
...@@ -579,34 +562,36 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -579,34 +562,36 @@ class GroupedLinear(TransformerEngineBaseModule):
Optimization parameters Optimization parameters
----------------------- -----------------------
fuse_wgrad_accumulation : bool, default = 'False' fuse_wgrad_accumulation : bool, default = False
if set to `True`, enables fusing of creation and accumulation of if set to ``True``, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional ``main_grad`` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular ``grad``) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating. will overwrite ``main_grad`` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = False
when set to `True`, this module will not apply the additive bias itself, but when set to ``True``, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations. the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.get_default_dtype()` params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False` delay_wgrad_compute : bool, default = False
Whether to delay weight gradient computation Whether to delay weight gradient computation
save_original_input : bool, default = `False` save_original_input : bool, default = False
If set to `True`, always saves the original input tensor rather than the If set to ``True``, always saves the original input tensor rather than the
cast tensor. In some scenarios, the input tensor is used by multiple modules, cast tensor. In some scenarios, the input tensor is used by multiple modules,
and saving the original input tensor may reduce the memory usage. and saving the original input tensor may reduce the memory usage.
Cannot work with FP8 DelayedScaling recipe. Cannot work with FP8 DelayedScaling recipe.
Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and Notes
`parallel_mode` are used to determine the shapes of weights and biases. -----
The TP communication should be handled in the dispatch and combine stages of MoE models. GroupedLinear doesn't really handle the TP communications inside. The ``tp_size`` and
``parallel_mode`` are used to determine the shapes of weights and biases.
The TP communication should be handled in the dispatch and combine stages of MoE models.
""" """
def __init__( def __init__(
...@@ -807,16 +792,9 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -807,16 +792,9 @@ class GroupedLinear(TransformerEngineBaseModule):
), "GroupedLinear doesn't support input tensor in FP8." ), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if FP8GlobalStateManager.fp8_graph_capturing(): is_grad_enabled = torch.is_grad_enabled()
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with torch.cuda.device( with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = self._get_weight_tensors() weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
...@@ -836,7 +814,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -836,7 +814,7 @@ class GroupedLinear(TransformerEngineBaseModule):
# TODO: use internal after #1638 is merged. # pylint: disable=fixme # TODO: use internal after #1638 is merged. # pylint: disable=fixme
for i in range(self.num_gemms): for i in range(self.num_gemms):
input_quantizers[i].internal = False input_quantizers[i].internal = False
if torch.is_grad_enabled(): if is_grad_enabled:
grad_output_quantizers = [ grad_output_quantizers = [
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
...@@ -846,14 +824,14 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -846,14 +824,14 @@ class GroupedLinear(TransformerEngineBaseModule):
for i in range(self.num_gemms): for i in range(self.num_gemms):
grad_output_quantizers[i].internal = True grad_output_quantizers[i].internal = True
if torch.is_grad_enabled(): if is_grad_enabled:
linear_fn = _GroupedLinear.apply linear_fn = _GroupedLinear.apply
args = [] autograd_ctx = []
else: else:
linear_fn = _GroupedLinear.forward linear_fn = _GroupedLinear.forward
args = [None] autograd_ctx = [None]
args += (
inp, non_tensor_args = (
m_splits, m_splits,
self.apply_bias, self.apply_bias,
is_first_microbatch, is_first_microbatch,
...@@ -868,15 +846,13 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -868,15 +846,13 @@ class GroupedLinear(TransformerEngineBaseModule):
is_cpu_offload_enabled(), is_cpu_offload_enabled(),
self.sequence_parallel, self.sequence_parallel,
self.activation_dtype, self.activation_dtype,
torch.is_grad_enabled(), is_grad_enabled,
self, self,
skip_fp8_weight_update, None, # skip_fp8_weight_update
self.save_original_input, self.save_original_input,
self.fine_grained_activation_offloading, self.fine_grained_activation_offloading,
*weight_tensors,
*bias_tensors,
) )
out = linear_fn(*args) out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
if self.return_bias: if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
...@@ -889,7 +865,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -889,7 +865,7 @@ class GroupedLinear(TransformerEngineBaseModule):
""" """
if not self.need_backward_dw(): if not self.need_backward_dw():
return return
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): with get_nvtx_range_context("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop() (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
wgrad_list = tensor_list[2] wgrad_list = tensor_list[2]
weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
......
...@@ -28,33 +28,30 @@ class LayerNorm(_LayerNormOp): ...@@ -28,33 +28,30 @@ class LayerNorm(_LayerNormOp):
Parameters Parameters
---------- ----------
normalized_shape: int or iterable of int normalized_shape : int or iterable of int
Inner dimensions of input tensor Inner dimensions of input tensor
eps : float, default = 1e-5 eps : float, default = 1e-5
A value added to the denominator of layer normalization for A value added to the denominator of layer normalization for
numerical stability numerical stability
device: torch.device, default = default CUDA device device : torch.device, default = default CUDA device
Tensor device Tensor device
dtype: torch.dtype, default = default dtype dtype : torch.dtype, default = default dtype
Tensor datatype Tensor datatype
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = 'False'
If `True`, the :math:`\gamma` parameter is initialized to zero If ``True``, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to and the calculation changes to
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
sm_margin: int or dict, default = 0 sm_margin : int or dict, default = 0
Number of SMs to exclude when launching CUDA kernels. This Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels. helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward", margin at each compute stage (``"forward"``, ``"backward"``,
"inference"). ``"inference"``).
sequence_parallel : bool
Legacy **Legacy parameter.** Set a bool attr named ``sequence_parallel`` in the parameters.
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
This is custom logic for Megatron-LM integration. This is custom logic for Megatron-LM integration.
""" """
......
...@@ -15,11 +15,10 @@ from torch.nn import init ...@@ -15,11 +15,10 @@ from torch.nn import init
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version from transformer_engine.pytorch.torch_version import torch_version
from transformer_engine.pytorch.tensor.utils import is_custom from transformer_engine.pytorch.tensor.utils import is_custom
from .base import ( from .base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_workspace,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
get_dummy_wgrad, get_dummy_wgrad,
...@@ -40,6 +39,7 @@ from ..utils import ( ...@@ -40,6 +39,7 @@ from ..utils import (
nvtx_range_push, nvtx_range_push,
requires_grad, requires_grad,
needs_quantized_gemm, needs_quantized_gemm,
get_nvtx_range_context,
get_activation_offloading, get_activation_offloading,
) )
from ..distributed import ( from ..distributed import (
...@@ -104,48 +104,54 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -104,48 +104,54 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias: Union[torch.Tensor, None], ln_bias: Union[torch.Tensor, None],
weight: torch.Tensor, weight: torch.Tensor,
bias: torch.Tensor, bias: torch.Tensor,
eps: float, non_tensor_args: Tuple,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
normalization: str,
ub_overlap_ag_fprop: bool,
ub_overlap_rs_fprop: bool,
ub_overlap_ag_dgrad: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_name: str,
fine_grained_activation_offloading: bool,
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
eps,
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
cpu_offloading,
tp_group,
tp_size,
sequence_parallel,
tensor_parallel,
activation_dtype,
parallel_mode,
return_layernorm_output,
return_layernorm_output_gathered,
is_grad_enabled,
fwd_ln_sm_margin,
bwd_ln_sm_margin,
zero_centered_gamma,
normalization,
ub_overlap_ag_fprop,
ub_overlap_rs_fprop,
ub_overlap_ag_dgrad,
ub_overlap_rs_dgrad,
ub_bulk_wgrad,
ub_bulk_dgrad,
ub_name,
fine_grained_activation_offloading,
fsdp_group,
module,
skip_fp8_weight_update,
symmetric_ar_type,
debug,
) = non_tensor_args
# NVTX label for profiling # NVTX label for profiling
nvtx_label = "transformer_engine._LayerNormLinear.forward" nvtx_label = "transformer_engine._LayerNormLinear.forward"
if ub_name is not None: if ub_name is not None:
...@@ -364,7 +370,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -364,7 +370,6 @@ class _LayerNormLinear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat, weightmat,
ln_out_total, ln_out_total,
get_workspace(),
quantization_params=output_quantizer, quantization_params=output_quantizer,
out_dtype=activation_dtype, out_dtype=activation_dtype,
bias=bias, bias=bias,
...@@ -553,7 +558,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -553,7 +558,7 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.ub_name is not None: if ctx.ub_name is not None:
nvtx_label = f"{nvtx_label}.{ctx.ub_name}" nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_LayerNormLinear_backward"): with get_nvtx_range_context("_LayerNormLinear_backward"):
saved_tensors = ctx.saved_tensors saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking ( # pylint: disable=unbalanced-tuple-unpacking
inputmat, inputmat,
...@@ -742,7 +747,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -742,7 +747,6 @@ class _LayerNormLinear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
weight, weight,
grad_output, grad_output,
get_workspace(),
layout="NN", layout="NN",
grad=True, grad=True,
quantization_params=ctx.grad_input_quantizer, quantization_params=ctx.grad_input_quantizer,
...@@ -869,7 +873,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -869,7 +873,6 @@ class _LayerNormLinear(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure # Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs = { wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": ( "out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
), ),
...@@ -1044,45 +1047,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -1044,45 +1047,7 @@ class _LayerNormLinear(torch.autograd.Function):
dbeta, dbeta,
wgrad, wgrad,
grad_bias, grad_bias,
None, # eps None,
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # return_layernorm_output
None, # return_layernorm_output_gathered
None, # is_grad_enabled
None, # fwd_ln_sm_margin
None, # bwd_ln_sm_margin
None, # zero_centered_gamma
None, # normalization
None, # ub_overlap_ag_fprop
None, # ub_overlap_rs_fprop
None, # ub_overlap_ag_dgrad
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fine_grained_activation_offloading
None, # fsdp_group
None, # debug
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
) )
...@@ -1098,20 +1063,20 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1098,20 +1063,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
size of each output sample. size of each output sample.
eps : float, default = 1e-5 eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability. a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True` bias : bool, default = True
if set to `False`, the layer will not learn an additive bias. if set to ``False``, the layer will not learn an additive bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied. type of normalization applied.
init_method : Callable, default = `None` init_method : Callable, default = None
used for initializing weights in the following way: `init_method(weight)`. used for initializing weights in the following way: ``init_method(weight)``.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
return_layernorm_output : bool, default = `False` return_layernorm_output : bool, default = False
if set to `True`, output of layernorm is returned from the forward if set to ``True``, output of layernorm is returned from the forward
together with the output of the linear transformation. together with the output of the linear transformation.
Example use case: residual connection for transformer module is Example use case: residual connection for transformer module is
taken post layernorm. taken post layernorm.
return_layernorm_output_gathered : bool, default = `False` return_layernorm_output_gathered : bool, default = False
if set to `True`, output of layernorm is returned after the all if set to ``True``, output of layernorm is returned after the all
gather operation. Ignored if return_layernorm_output is False. gather operation. Ignored if return_layernorm_output is False.
Example use case: with sequence parallel, input to residual connection Example use case: with sequence parallel, input to residual connection
for transformer module (e.g. LoRA) will need to be gathered. for transformer module (e.g. LoRA) will need to be gathered.
...@@ -1122,10 +1087,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1122,10 +1087,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
they are used to make the names of equally-sized parameters. If a dict they are used to make the names of equally-sized parameters. If a dict
(preferably an OrderedDict) is provided, the keys are used as names and (preferably an OrderedDict) is provided, the keys are used as names and
values as split sizes along dim 0. The resulting parameters will have values as split sizes along dim 0. The resulting parameters will have
names that end in `_weight` or `_bias`, so trailing underscores are names that end in ``_weight`` or ``_bias``, so trailing underscores are
stripped from any provided names. stripped from any provided names.
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and if set to ``'True'``, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to the LayerNorm formula changes to
.. math:: .. math::
...@@ -1135,53 +1100,53 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1135,53 +1100,53 @@ class LayerNormLinear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
name: str, default = `None` name : str, default = None
name of the module, currently used for debugging purposes. name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
sequence_parallel : bool, default = `False` sequence_parallel : bool, default = False
if set to `True`, uses sequence parallelism. if set to ``True``, uses sequence parallelism.
tp_group : ProcessGroup, default = `None` tp_group : ProcessGroup, default = None
tensor parallel process group. tensor parallel process group.
tp_size : int, default = 1 tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives. parallel collectives.
parallel_mode : {None, 'column', 'row'}, default = `None` parallel_mode : {None, 'column', 'row'}, default = None
used to decide whether this Linear layer is Column Parallel Linear or Row used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_. Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed. When set to ``None``, no communication is performed.
Optimization parameters Optimization parameters
----------------------- -----------------------
fuse_wgrad_accumulation : bool, default = 'False' fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of if set to ``True``, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional ``main_grad`` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular ``grad``) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating. will overwrite ``main_grad`` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = False
when set to `True`, this module will not apply the additive bias itself, but when set to ``True``, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations. the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.get_default_dtype()` params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False` delay_wgrad_compute : bool, default = False
Whether or not to delay weight gradient computation. If set to `True`, Whether or not to delay weight gradient computation. If set to ``True``,
it's the user's responsibility to call `module.backward_dw` to compute it's the user's responsibility to call ``module.backward_dw`` to compute
weight gradients. weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass. Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations. This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce
is used. is used.
""" """
...@@ -1544,8 +1509,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1544,8 +1509,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
is_grad_enabled = torch.is_grad_enabled()
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output) return self.onnx_forward(inp, fp8_output, is_grad_enabled)
debug = self.is_debug_iter() debug = self.is_debug_iter()
...@@ -1567,9 +1534,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1567,9 +1534,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
).is_fp8_ubuf(): ).is_fp8_ubuf():
fp8_grad = True fp8_grad = True
with torch.cuda.device( with self.prepare_forward(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp: ) as inp:
...@@ -1577,14 +1542,14 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1577,14 +1542,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = ( quantizers = (
self._get_quantizers(fp8_output, fp8_grad) self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad) else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
) )
if debug: if debug:
if self.no_debug_features_active(quantizers): if self.no_debug_features_active(quantizers):
debug = False debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad) quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
( (
input_quantizer, input_quantizer,
...@@ -1595,18 +1560,13 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1595,18 +1560,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
grad_output_quantizer, grad_output_quantizer,
) = quantizers ) = quantizers
if torch.is_grad_enabled(): if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply fwd_fn = _LayerNormLinear.apply
args = [] autograd_ctx = []
else: else:
fwd_fn = _LayerNormLinear.forward fwd_fn = _LayerNormLinear.forward
args = [None] autograd_ctx = [None]
args += ( non_tensor_args = (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
...@@ -1628,8 +1588,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1628,8 +1588,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.parallel_mode, self.parallel_mode,
self.return_layernorm_output, self.return_layernorm_output,
self.return_layernorm_output_gathered, self.return_layernorm_output_gathered,
torch.is_grad_enabled(), is_grad_enabled,
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.normalization, self.normalization,
...@@ -1647,7 +1607,15 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1647,7 +1607,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.symmetric_ar_type, self.symmetric_ar_type,
debug, debug,
) )
out = fwd_fn(*args) out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
...@@ -1663,7 +1631,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1663,7 +1631,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
return out, ln_out return out, ln_out
return out return out
def _get_quantizers(self, fp8_output, fp8_grad): def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
if not self.fp8: if not self.fp8:
return [None] * 6 return [None] * 6
grad_input_quantizer = None grad_input_quantizer = None
...@@ -1675,7 +1643,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1675,7 +1643,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
(weight_quantizer,) = self._get_weight_quantizers() (weight_quantizer,) = self._get_weight_quantizers()
if fp8_output: if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled(): if is_grad_enabled:
grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
grad_output_quantizer.internal = True grad_output_quantizer.internal = True
if fp8_grad: if fp8_grad:
...@@ -1690,8 +1658,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1690,8 +1658,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
grad_output_quantizer, grad_output_quantizer,
) )
def _get_debug_quantizers(self, fp8_output, fp8_grad): def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad) original_quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
assert TEDebugState.debug_enabled assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_quantization import DebugQuantizer
...@@ -1716,6 +1684,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1716,6 +1684,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self, self,
inp: torch.Tensor, inp: torch.Tensor,
fp8_output: bool, fp8_output: bool,
is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
ONNX-compatible version of the forward function that provides numerical equivalence ONNX-compatible version of the forward function that provides numerical equivalence
...@@ -1731,7 +1700,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1731,7 +1700,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
*_, *_,
) = self._get_quantizers(fp8_output, fp8_grad=False) ) = self._get_quantizers(fp8_output, False, is_grad_enabled)
inp_dtype = inp.dtype inp_dtype = inp.dtype
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
......
...@@ -17,11 +17,10 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION ...@@ -17,11 +17,10 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version from transformer_engine.pytorch.torch_version import torch_version
from transformer_engine.pytorch.tensor.utils import is_custom from transformer_engine.pytorch.tensor.utils import is_custom
from .base import ( from .base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_workspace,
_ub_communicators, _ub_communicators,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
...@@ -46,6 +45,7 @@ from ..utils import ( ...@@ -46,6 +45,7 @@ from ..utils import (
clear_tensor_data, clear_tensor_data,
requires_grad, requires_grad,
needs_quantized_gemm, needs_quantized_gemm,
get_nvtx_range_context,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -181,55 +181,61 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -181,55 +181,61 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias: torch.Tensor, fc1_bias: torch.Tensor,
fc2_weight: torch.Tensor, fc2_weight: torch.Tensor,
fc2_bias: torch.Tensor, fc2_bias: torch.Tensor,
eps: float, non_tensor_args: Tuple,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
fc1_input_quantizer: Optional[Quantizer],
fc1_weight_quantizer: Optional[Quantizer],
fc1_output_quantizer: Optional[Quantizer],
fc1_grad_input_quantizer: Optional[Quantizer],
fc1_grad_weight_quantizer: Optional[Quantizer],
fc1_grad_output_quantizer: Optional[Quantizer],
fc2_input_quantizer: Optional[Quantizer],
fc2_weight_quantizer: Optional[Quantizer],
fc2_output_quantizer: Optional[Quantizer],
fc2_grad_input_quantizer: Optional[Quantizer],
fc2_grad_weight_quantizer: Optional[Quantizer],
fc2_grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
bias_gelu_fusion: bool,
set_parallel_mode: bool,
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
activation: str,
activation_params: Optional[dict],
normalization: str,
ub_overlap_ag: bool,
ub_overlap_rs: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
gemm_gelu_fusion: bool,
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
eps,
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
cpu_offloading,
tp_group,
tp_size,
sequence_parallel,
tensor_parallel,
activation_dtype,
return_layernorm_output,
return_layernorm_output_gathered,
bias_gelu_fusion,
set_parallel_mode,
is_grad_enabled,
fwd_ln_sm_margin,
bwd_ln_sm_margin,
zero_centered_gamma,
activation,
activation_params,
normalization,
ub_overlap_ag,
ub_overlap_rs,
ub_overlap_rs_dgrad,
ub_bulk_wgrad,
ub_bulk_dgrad,
gemm_gelu_fusion,
fsdp_group,
module,
skip_fp8_weight_update,
symmetric_ar_type,
debug,
) = non_tensor_args
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features, inp_shape = ln_weight.numel(), inp.shape in_features, inp_shape = ln_weight.numel(), inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible" assert inp_shape[-1] == in_features, "GEMM not possible"
...@@ -440,7 +446,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -440,7 +446,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_outputs = general_gemm( fc1_outputs = general_gemm(
fc1_weight_final, fc1_weight_final,
ln_out_total, ln_out_total,
get_workspace(),
quantization_params=( quantization_params=(
fc2_input_quantizer fc2_input_quantizer
if gemm_gelu_fusion if gemm_gelu_fusion
...@@ -524,7 +529,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -524,7 +529,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
fc2_weight_final, fc2_weight_final,
act_out, act_out,
get_workspace(),
out_dtype=activation_dtype, out_dtype=activation_dtype,
bias=fc2_bias, bias=fc2_bias,
quantization_params=fc2_output_quantizer, quantization_params=fc2_output_quantizer,
...@@ -711,7 +715,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -711,7 +715,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_LayerNormMLP_backward"): with get_nvtx_range_context("_LayerNormMLP_backward"):
saved_tensors = ctx.saved_tensors saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking ( # pylint: disable=unbalanced-tuple-unpacking
inputmat, inputmat,
...@@ -881,7 +885,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -881,7 +885,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_output, *_ = general_gemm( gemm_output, *_ = general_gemm(
fc2_weight, fc2_weight,
grad_output, grad_output,
get_workspace(),
layout="NN", layout="NN",
grad=True, grad=True,
quantization_params=( quantization_params=(
...@@ -975,7 +978,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -975,7 +978,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure # Arguments to include in wgrad GEMM closure
fc2_wgrad_gemm_kwargs = { fc2_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": ( "out_dtype": (
origin_fc2_weight.main_grad.dtype origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation if ctx.fuse_wgrad_accumulation
...@@ -1153,7 +1155,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1153,7 +1155,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
fc1_weight, fc1_weight,
dact, dact,
get_workspace(),
out=gemm_out, out=gemm_out,
out_dtype=ctx.activation_dtype, out_dtype=ctx.activation_dtype,
quantization_params=ctx.fc1_grad_input_quantizer, quantization_params=ctx.fc1_grad_input_quantizer,
...@@ -1232,7 +1233,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1232,7 +1233,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure # Arguments to include in wgrad GEMM closure
fc1_wgrad_gemm_kwargs = { fc1_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": ( "out_dtype": (
origin_fc1_weight.main_grad.dtype origin_fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation if ctx.fuse_wgrad_accumulation
...@@ -1427,52 +1427,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1427,52 +1427,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias_grad if fc1_bias is not None else None, fc1_bias_grad if fc1_bias is not None else None,
fc2_wgrad, # pylint: disable=possibly-used-before-assignment fc2_wgrad, # pylint: disable=possibly-used-before-assignment
fc2_bias_grad, fc2_bias_grad,
None, # eps None,
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # fc1_input_quantizer,
None, # fc1_weight_quantizer,
None, # fc1_output_quantizer,
None, # fc1_grad_input_quantizer,
None, # fc1_grad_weight_quantizer,
None, # fc1_grad_output_quantizer,
None, # fc2_input_quantizer,
None, # fc2_weight_quantizer,
None, # fc2_output_quantizer,
None, # fc2_grad_input_quantizer,
None, # fc2_grad_weight_quantizer,
None, # fc2_grad_output_quantizer,
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # return_layernorm_output
None, # return_layernorm_output_gathered
None, # bias_gelu_fusion
None, # set_parallel_mode
None, # is_grad_enabled
None, # fwd_ln_sm_margin
None, # bwd_ln_sm_margin
None, # zero_centered_gamma
None, # activation
None, # activation_params
None, # normalization
None, # ub_overlap_ag
None, # ub_overlap_rs
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # gemm_gelu_fusion
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # debug
) )
...@@ -1489,38 +1444,38 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1489,38 +1444,38 @@ class LayerNormMLP(TransformerEngineBaseModule):
intermediate size to which input samples are projected. intermediate size to which input samples are projected.
eps : float, default = 1e-5 eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability. a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True` bias : bool, default = True
if set to `False`, the FC1 and FC2 layers will not learn an additive bias. if set to ``False``, the FC1 and FC2 layers will not learn an additive bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied. type of normalization applied.
activation : str, default = 'gelu' activation : str, default = 'gelu'
activation function used. activation function used.
Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
'silu', 'swiglu', and 'clamped_swiglu'. ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
activation_params : dict, default = `None` activation_params : dict, default = None
Additional parameters for the activation function. Additional parameters for the activation function.
At the moment, only used for 'clamped_swiglu' activation which At the moment, only used for ``'clamped_swiglu'`` activation which
supports 'limit' and 'alpha' parameters. supports ``'limit'`` and ``'alpha'`` parameters.
init_method : Callable, default = `None` init_method : Callable, default = None
used for initializing FC1 weights in the following way: `init_method(weight)`. used for initializing FC1 weights in the following way: ``init_method(weight)``.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
output_layer_init_method : Callable, default = `None` output_layer_init_method : Callable, default = None
used for initializing FC2 weights in the following way: used for initializing FC2 weights in the following way:
`output_layer_init_method(weight)`. When set to `None`, defaults to ``output_layer_init_method(weight)``. When set to ``None``, defaults to
`torch.nn.init.normal_(mean=0.0, std=0.023)`. ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
return_layernorm_output : bool, default = `False` return_layernorm_output : bool, default = False
if set to `True`, output of layernorm is returned from the forward if set to ``True``, output of layernorm is returned from the :meth:`forward` method
together with the output of the linear transformation. together with the output of the linear transformation.
Example use case: residual connection for transformer module Example use case: residual connection for transformer module
is taken post layernorm. is taken post layernorm.
return_layernorm_output_gathered : bool, default = `False` return_layernorm_output_gathered : bool, default = False
if set to `True`, output of layernorm is returned after the all if set to ``True``, output of layernorm is returned after the all
gather operation. Ignored if return_layernorm_output is False. gather operation. Ignored if ``return_layernorm_output`` is False.
Example use case: with sequence parallel, input to residual connection Example use case: with sequence parallel, input to residual connection
for transformer module (e.g. LoRA) will need to be gathered. for transformer module (e.g. LoRA) will need to be gathered.
Returning layernorm output gathered will prevent a redundant gather. Returning layernorm output gathered will prevent a redundant gather.
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = False
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to the LayerNorm formula changes to
.. math:: .. math::
...@@ -1530,61 +1485,65 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1530,61 +1485,65 @@ class LayerNormMLP(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
name: str, default = `None` name : str, default = None
name of the module, currently used for debugging purposes. name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
set_parallel_mode : bool, default = `False` set_parallel_mode : bool, default = False
if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row if set to ``True``, FC1 is used as Column Parallel and FC2 is used as Row
Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_. Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default = `False` sequence_parallel : bool, default = False
if set to `True`, uses sequence parallelism. if set to ``True``, uses sequence parallelism.
tp_group : ProcessGroup, default = `None` tp_group : ProcessGroup, default = None
tensor parallel process group. tensor parallel process group.
tp_size : int, default = 1 tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives. parallel collectives.
Optimization parameters Optimization parameters
----------------------- -----------------------
fuse_wgrad_accumulation : bool, default = 'False' fuse_wgrad_accumulation : bool, default = False
if set to `True`, enables fusing of creation and accumulation of if set to ``True``, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional ``main_grad`` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular ``grad``) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True weight tensor having attribute ``'overwrite_main_grad'`` set to True
will overwrite `main_grad` instead of accumulating. will overwrite ``main_grad`` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = False
when set to `True`, this module will not apply the additive bias for FC2, but when set to ``True``, this module will not apply the additive bias for FC2, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations. the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.get_default_dtype()` params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
seq_length: int seq_length : int
sequence length of input samples. Needed for JIT Warmup, a technique where jit fused sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
functions are warmed up before training to ensure same kernels are used for forward functions are warmed up before training to ensure same kernels are used for forward
propogation and activation recompute phase. propogation and activation recompute phase.
micro_batch_size: int micro_batch_size : int
batch size per training step. Needed for JIT Warmup, a technique where jit batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are fused functions are warmed up before training to ensure same kernels are
used for forward propogation and activation recompute phase. used for forward propogation and activation recompute phase.
delay_wgrad_compute : bool, default = `False` delay_wgrad_compute : bool, default = False
Whether or not to delay weight gradient computation. If set to `True`, Whether or not to delay weight gradient computation. If set to ``True``,
it's the user's responsibility to call `module.backward_dw` to compute it's the user's responsibility to call :meth:`backward_dw` to compute
weight gradients. weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass. Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations. This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce
is used. is used.
checkpoint : bool, default = False
whether to use selective activation checkpointing, where activations are not saved for bwd,
and instead are recomputed (skipping fc2, as it is not needed for backward). Trades compute
for memory. default is false, in which activations are saved in fwd. not supported for onnx forward
""" """
def __init__( def __init__(
...@@ -1855,8 +1814,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1855,8 +1814,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
is_grad_enabled = torch.is_grad_enabled()
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
return self.onnx_forward(inp) return self.onnx_forward(inp, is_grad_enabled)
debug = self.is_debug_iter() debug = self.is_debug_iter()
...@@ -1872,19 +1833,17 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1872,19 +1833,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True fp8_output = True
with torch.cuda.device( with self.prepare_forward(inp, num_gemms=2) as inp:
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=2) as inp:
quantizers = ( quantizers = (
self._get_quantizers(fp8_output) self._get_quantizers(fp8_output, is_grad_enabled)
if not debug if not debug
else self._get_debug_quantizers(fp8_output) else self._get_debug_quantizers(fp8_output, is_grad_enabled)
) )
if debug: if debug:
if self.no_debug_features_active(quantizers): if self.no_debug_features_active(quantizers):
debug = False debug = False
quantizers = self._get_quantizers(fp8_output) quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
# Get quantizers # Get quantizers
( (
...@@ -1917,20 +1876,14 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1917,20 +1876,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
and self.bias_gelu_nvfusion and not use_reentrant_activation_recompute() ): and self.bias_gelu_nvfusion and not use_reentrant_activation_recompute() ):
self.bias_gelu_nvfusion = False self.bias_gelu_nvfusion = False
if torch.is_grad_enabled(): if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply fwd_fn = _LayerNormMLP.apply
args = [] autograd_ctx = []
else: else:
fwd_fn = _LayerNormMLP.forward fwd_fn = _LayerNormMLP.forward
args = [None] autograd_ctx = [None]
args += (
inp, non_tensor_args = (
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
...@@ -1959,8 +1912,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1959,8 +1912,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.return_layernorm_output_gathered, self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug, self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode, self.set_parallel_mode,
torch.is_grad_enabled(), is_grad_enabled,
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.activation, self.activation,
...@@ -1978,7 +1931,17 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1978,7 +1931,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.symmetric_ar_type, self.symmetric_ar_type,
debug, debug,
) )
out = fwd_fn(*args) out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
...@@ -1994,7 +1957,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1994,7 +1957,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
return out, ln_out return out, ln_out
return out return out
def _get_quantizers(self, fp8_output): def _get_quantizers(self, fp8_output, is_grad_enabled):
( (
fc1_input_quantizer, fc1_input_quantizer,
fc1_output_quantizer, fc1_output_quantizer,
...@@ -2024,7 +1987,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2024,7 +1987,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_output_quantizer = self.quantizers["scaling_fwd"][ fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT tex.FP8FwdTensors.GEMM2_OUTPUT
] ]
if torch.is_grad_enabled(): if is_grad_enabled:
fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2 tex.FP8BwdTensors.GRAD_OUTPUT2
] ]
...@@ -2049,9 +2012,11 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2049,9 +2012,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_grad_output_quantizer, fc2_grad_output_quantizer,
) )
def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: def onnx_forward(
self, inp: torch.Tensor, is_grad_enabled: bool
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
""" """
ONNX-compatible version of the forward function that provides numerical equivalence ONNX-compatible version of the :meth:`forward` method that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations. while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios. This simplified implementation is designed specifically for inference scenarios.
""" """
...@@ -2066,7 +2031,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2066,7 +2031,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_quantizer, fc2_weight_quantizer,
output_quantizer, output_quantizer,
*_, *_,
) = self._get_quantizers(False) ) = self._get_quantizers(False, is_grad_enabled)
inp_dtype = inp.dtype inp_dtype = inp.dtype
fc1_weight, fc2_weight = self._get_weight_tensors() fc1_weight, fc2_weight = self._get_weight_tensors()
...@@ -2151,10 +2116,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2151,10 +2116,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
return fc2_out, fc2_bias.to(inp_dtype) return fc2_out, fc2_bias.to(inp_dtype)
return fc2_out return fc2_out
def _get_debug_quantizers(self, fp8_output): def _get_debug_quantizers(self, fp8_output, is_grad_enabled):
from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_quantization import DebugQuantizer
base_quantizers = list(self._get_quantizers(fp8_output)) base_quantizers = list(self._get_quantizers(fp8_output, is_grad_enabled))
assert TEDebugState.debug_enabled assert TEDebugState.debug_enabled
def make_debug(prefix, offset): def make_debug(prefix, offset):
...@@ -2297,7 +2262,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2297,7 +2262,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
""" """
if not self.need_backward_dw(): if not self.need_backward_dw():
return return
with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"): with get_nvtx_range_context("_LayerNormMLP_wgrad"):
(fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop() (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
if self.use_bias and self.fc1_bias.grad is None: if self.use_bias and self.fc1_bias.grad is None:
(fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop() (fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop()
......
...@@ -14,13 +14,12 @@ import torch ...@@ -14,13 +14,12 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version from transformer_engine.pytorch.torch_version import torch_version
from .base import ( from .base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad, get_dummy_wgrad,
get_ub, get_ub,
get_workspace,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
...@@ -39,6 +38,7 @@ from ..utils import ( ...@@ -39,6 +38,7 @@ from ..utils import (
assert_dim_for_all_gather, assert_dim_for_all_gather,
nvtx_range_pop, nvtx_range_pop,
nvtx_range_push, nvtx_range_push,
get_nvtx_range_context,
get_activation_offloading, get_activation_offloading,
) )
from ..distributed import ( from ..distributed import (
...@@ -92,43 +92,47 @@ class _Linear(torch.autograd.Function): ...@@ -92,43 +92,47 @@ class _Linear(torch.autograd.Function):
weight: torch.Tensor, weight: torch.Tensor,
inp: torch.Tensor, inp: torch.Tensor,
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
is_first_microbatch: Union[bool, None], non_tensor_args: Tuple,
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
ub_overlap_rs_fprop: bool,
ub_overlap_ag_dgrad: bool,
ub_overlap_ag_fprop: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_dgrad: bool,
ub_bulk_wgrad: bool,
ub_name: str,
fine_grained_activation_offloading: bool,
fp8_output: bool, # pylint: disable=unused-argument
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
save_original_input: bool = False,
debug: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
(
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
fuse_wgrad_accumulation,
cpu_offloading,
tp_group,
tp_size,
sequence_parallel,
tensor_parallel,
activation_dtype,
parallel_mode,
is_grad_enabled,
ub_overlap_rs_fprop,
ub_overlap_ag_dgrad,
ub_overlap_ag_fprop,
ub_overlap_rs_dgrad,
ub_bulk_dgrad,
ub_bulk_wgrad,
ub_name,
fine_grained_activation_offloading,
fp8_output, # pylint: disable=unused-variable
fsdp_group,
module,
skip_fp8_weight_update,
symmetric_ar_type,
save_original_input,
debug,
) = non_tensor_args
# NVTX label for profiling # NVTX label for profiling
nvtx_label = "transformer_engine._Linear.forward" nvtx_label = "transformer_engine._Linear.forward"
if ub_name is not None: if ub_name is not None:
...@@ -323,7 +327,6 @@ class _Linear(torch.autograd.Function): ...@@ -323,7 +327,6 @@ class _Linear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat, weightmat,
inputmat_total, inputmat_total,
get_workspace(),
quantization_params=output_quantizer, quantization_params=output_quantizer,
out_dtype=activation_dtype, out_dtype=activation_dtype,
bias=bias, bias=bias,
...@@ -520,7 +523,7 @@ class _Linear(torch.autograd.Function): ...@@ -520,7 +523,7 @@ class _Linear(torch.autograd.Function):
if ctx.ub_name is not None: if ctx.ub_name is not None:
nvtx_label = f"{nvtx_label}.{ctx.ub_name}" nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_Linear_backward"): with get_nvtx_range_context("_Linear_backward"):
saved_tensors = ctx.saved_tensors saved_tensors = ctx.saved_tensors
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors) restore_from_saved(ctx.tensor_objects, saved_tensors)
...@@ -744,7 +747,6 @@ class _Linear(torch.autograd.Function): ...@@ -744,7 +747,6 @@ class _Linear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
weight_fp8, weight_fp8,
grad_output, grad_output,
get_workspace(),
layout="NN", layout="NN",
grad=True, grad=True,
quantization_params=ctx.grad_input_quantizer, quantization_params=ctx.grad_input_quantizer,
...@@ -870,7 +872,6 @@ class _Linear(torch.autograd.Function): ...@@ -870,7 +872,6 @@ class _Linear(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure # Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs = { wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": ( "out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
), ),
...@@ -1005,47 +1006,14 @@ class _Linear(torch.autograd.Function): ...@@ -1005,47 +1006,14 @@ class _Linear(torch.autograd.Function):
wgrad, wgrad,
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
grad_bias, grad_bias,
None, # is_first_microbatch None,
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # fuse_wgrad_accumulation
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # is_grad_enabled
None, # ub_overlap_rs_fprop
None, # ub_overlap_ag_dgrad
None, # ub_overlap_ag_fprop
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fine_grained_activation_offloading
None, # fp8_output
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # save_original_input
None, # debug
) )
class Linear(TransformerEngineBaseModule): class Linear(TransformerEngineBaseModule):
"""Applies a linear transformation to the incoming data :math:`y = xA^T + b` """Applies a linear transformation to the incoming data :math:`y = xA^T + b`
On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`. On NVIDIA GPUs it is a drop-in replacement for ``torch.nn.Linear``.
Parameters Parameters
---------- ----------
...@@ -1053,14 +1021,14 @@ class Linear(TransformerEngineBaseModule): ...@@ -1053,14 +1021,14 @@ class Linear(TransformerEngineBaseModule):
size of each input sample. size of each input sample.
out_features : int out_features : int
size of each output sample. size of each output sample.
bias : bool, default = `True` bias : bool, default = True
if set to `False`, the layer will not learn an additive bias. if set to ``False``, the layer will not learn an additive bias.
init_method : Callable, default = `None` init_method : Callable, default = None
used for initializing weights in the following way: `init_method(weight)`. used for initializing weights in the following way: ``init_method(weight)``.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
get_rng_state_tracker : Callable, default = `None` get_rng_state_tracker : Callable, default = None
used to get the random number generator state tracker for initializing weights. used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None` rng_tracker_name : str, default = None
the param passed to get_rng_state_tracker to get the specific rng tracker. the param passed to get_rng_state_tracker to get the specific rng tracker.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
Configuration for splitting the weight and bias tensors along dim 0 into Configuration for splitting the weight and bias tensors along dim 0 into
...@@ -1068,62 +1036,62 @@ class Linear(TransformerEngineBaseModule): ...@@ -1068,62 +1036,62 @@ class Linear(TransformerEngineBaseModule):
they are used to make the names of equally-sized parameters. If a dict they are used to make the names of equally-sized parameters. If a dict
(preferably an OrderedDict) is provided, the keys are used as names and (preferably an OrderedDict) is provided, the keys are used as names and
values as split sizes along dim 0. The resulting parameters will have values as split sizes along dim 0. The resulting parameters will have
names that end in `_weight` or `_bias`, so trailing underscores are names that end in ``_weight`` or ``_bias``, so trailing underscores are
stripped from any provided names. stripped from any provided names.
device : Union[torch.device, str], default = "cuda" device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
name: str, default = `None` name : str, default = None
name of the module, currently used for debugging purposes. name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
sequence_parallel : bool, default = `False` sequence_parallel : bool, default = False
if set to `True`, uses sequence parallelism. if set to ``True``, uses sequence parallelism.
tp_group : ProcessGroup, default = `None` tp_group : ProcessGroup, default = None
tensor parallel process group. tensor parallel process group.
tp_size : int, default = 1 tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives. parallel collectives.
parallel_mode : {None, 'column', 'row'}, default = `None` parallel_mode : {None, 'column', 'row'}, default = None
used to decide whether this Linear layer is Column Parallel Linear or Row used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_. Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed. When set to ``None``, no communication is performed.
Optimization parameters Optimization parameters
----------------------- -----------------------
fuse_wgrad_accumulation : bool, default = 'False' fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of if set to ``True``, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional ``main_grad`` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular ``grad``) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating. will overwrite ``main_grad`` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = False
when set to `True`, this module will not apply the additive bias itself, but when set to ``True``, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations. the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.get_default_dtype()` params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False` delay_wgrad_compute : bool, default = False
Whether or not to delay weight gradient computation. If set to `True`, Whether or not to delay weight gradient computation. If set to ``True``,
it's the user's responsibility to call `module.backward_dw` to compute it's the user's responsibility to call ``module.backward_dw`` to compute
weight gradients. weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass. Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations. This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce
is used. is used.
save_original_input : bool, default = `False` save_original_input : bool, default = False
If set to `True`, always saves the original input tensor rather than the If set to ``True``, always saves the original input tensor rather than the
cast tensor. In some scenarios, the input tensor is used by multiple modules, cast tensor. In some scenarios, the input tensor is used by multiple modules,
and saving the original input tensor may reduce the memory usage. and saving the original input tensor may reduce the memory usage.
Cannot work with FP8 DelayedScaling recipe. Cannot work with FP8 DelayedScaling recipe.
...@@ -1434,8 +1402,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -1434,8 +1402,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
is_grad_enabled = torch.is_grad_enabled()
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output) return self.onnx_forward(inp, fp8_output, is_grad_enabled)
debug = self.is_debug_iter() debug = self.is_debug_iter()
...@@ -1457,9 +1427,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1457,9 +1427,7 @@ class Linear(TransformerEngineBaseModule):
).is_fp8_ubuf(): ).is_fp8_ubuf():
fp8_grad = True fp8_grad = True
with torch.cuda.device( with self.prepare_forward(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
inp, inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor), allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp: ) as inp:
...@@ -1467,14 +1435,14 @@ class Linear(TransformerEngineBaseModule): ...@@ -1467,14 +1435,14 @@ class Linear(TransformerEngineBaseModule):
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = ( quantizers = (
self._get_quantizers(fp8_output, fp8_grad) self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad) else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
) )
if debug: if debug:
if self.no_debug_features_active(quantizers): if self.no_debug_features_active(quantizers):
debug = False debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad) quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
( (
input_quantizer, input_quantizer,
...@@ -1485,16 +1453,14 @@ class Linear(TransformerEngineBaseModule): ...@@ -1485,16 +1453,14 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer, grad_output_quantizer,
) = quantizers ) = quantizers
if torch.is_grad_enabled(): if is_grad_enabled:
linear_fn = _Linear.apply linear_fn = _Linear.apply
args = [] autograd_ctx = []
else: else:
linear_fn = _Linear.forward linear_fn = _Linear.forward
args = [None] autograd_ctx = [None]
args += (
weight_tensor, non_tensor_args = (
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
...@@ -1513,7 +1479,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1513,7 +1479,7 @@ class Linear(TransformerEngineBaseModule):
self.tp_size > 1, self.tp_size > 1,
self.activation_dtype, self.activation_dtype,
self.parallel_mode, self.parallel_mode,
torch.is_grad_enabled(), is_grad_enabled,
self.ub_overlap_rs_fprop, self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad, self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop, self.ub_overlap_ag_fprop,
...@@ -1530,7 +1496,13 @@ class Linear(TransformerEngineBaseModule): ...@@ -1530,7 +1496,13 @@ class Linear(TransformerEngineBaseModule):
self.save_original_input, self.save_original_input,
debug, debug,
) )
out = linear_fn(*args) out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
if self.gemm_bias_unfused_add: if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype) out = out + cast_if_needed(bias_tensor, self.activation_dtype)
...@@ -1538,7 +1510,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1538,7 +1510,7 @@ class Linear(TransformerEngineBaseModule):
return out, cast_if_needed(bias_tensor, self.activation_dtype) return out, cast_if_needed(bias_tensor, self.activation_dtype)
return out return out
def _get_quantizers(self, fp8_output, fp8_grad): def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
if not self.fp8: if not self.fp8:
return [None] * 6 return [None] * 6
grad_input_quantizer = None grad_input_quantizer = None
...@@ -1550,7 +1522,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1550,7 +1522,7 @@ class Linear(TransformerEngineBaseModule):
(weight_quantizer,) = self._get_weight_quantizers() (weight_quantizer,) = self._get_weight_quantizers()
if fp8_output: if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled(): if is_grad_enabled:
grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
grad_output_quantizer.internal = True grad_output_quantizer.internal = True
if fp8_grad: if fp8_grad:
...@@ -1564,8 +1536,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -1564,8 +1536,8 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer, grad_output_quantizer,
) )
def _get_debug_quantizers(self, fp8_output, fp8_grad): def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad) original_quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
assert TEDebugState.debug_enabled assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_quantization import DebugQuantizer
...@@ -1620,6 +1592,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1620,6 +1592,7 @@ class Linear(TransformerEngineBaseModule):
self, self,
inp: torch.Tensor, inp: torch.Tensor,
fp8_output: bool, fp8_output: bool,
is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
ONNX-compatible version of the forward function that provides numerical equivalence ONNX-compatible version of the forward function that provides numerical equivalence
...@@ -1636,7 +1609,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1636,7 +1609,7 @@ class Linear(TransformerEngineBaseModule):
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
*_, *_,
) = self._get_quantizers(fp8_output, False) ) = self._get_quantizers(fp8_output, False, is_grad_enabled)
inp_dtype = inp.dtype inp_dtype = inp.dtype
if input_quantizer is not None: if input_quantizer is not None:
......
...@@ -33,32 +33,29 @@ class RMSNorm(_RMSNormOp): ...@@ -33,32 +33,29 @@ class RMSNorm(_RMSNormOp):
Parameters Parameters
---------- ----------
normalized_shape: int or iterable of int normalized_shape : int or iterable of int
Inner dimensions of input tensor Inner dimensions of input tensor
eps : float, default = 1e-5 eps : float, default = 1e-5
A value added to the denominator for numerical stability A value added to the denominator for numerical stability
device: torch.device, default = default CUDA device device : torch.device, default = default CUDA device
Tensor device Tensor device
dtype: torch.dtype, default = default dtype dtype : torch.dtype, default = default dtype
Tensor datatype Tensor datatype
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = False
If `True`, the :math:`\gamma` parameter is initialized to zero If ``True``, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to and the calculation changes to
.. math:: .. math::
y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
sm_margin: int, default = 0 sm_margin : int, default = 0
Number of SMs to exclude when launching CUDA kernels. This Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels. helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward", margin at each compute stage (``"forward"``, ``"backward"``,
"inference"). ``"inference"``).
sequence_parallel : bool
Legacy **Legacy parameter.** Set a bool attr named ``sequence_parallel`` in the parameters.
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
This is custom logic for Megatron-LM integration. This is custom logic for Megatron-LM integration.
""" """
......
...@@ -356,7 +356,9 @@ def onnx_layernorm( ...@@ -356,7 +356,9 @@ def onnx_layernorm(
) )
if normalization == "RMSNorm": if normalization == "RMSNorm":
ln_out = torch.nn.functional.rms_norm(inp, inp.shape[-1:], ln_weight, eps) variance = inp.pow(2).mean(-1, keepdim=True)
ln_out = inp * torch.rsqrt(variance + eps)
ln_out = ln_out * ln_weight
else: else:
ln_out = torch.nn.functional.layer_norm( ln_out = torch.nn.functional.layer_norm(
inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps
......
@@ -10,7 +10,7 @@ from typing import Optional
import torch
from transformer_engine_torch import FP8TensorMeta
from ..torch_version import torch_version
from ..quantization import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor
from ..quantized_tensor import QuantizedTensorStorage
...
@@ -53,7 +53,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
    Parameters
    ----------
    cache_quantized_input : bool, default = False
        Quantize input tensor when caching for use in the backward
        pass. This will typically reduce memory usage but require
        extra compute and increase numerical error. This feature is
@@ -408,11 +408,11 @@ class ClampedSwiGLU(_ActivationOperation):
    Parameters
    ----------
    limit : float
        The clamp limit.
    alpha : float
        The scaling factor for the sigmoid function used in the activation.
    cache_quantized_input : bool, default = False
        Quantize input tensor when caching for use in the backward pass.
    """
...
@@ -23,7 +23,7 @@ class AllGather(BasicOperation):
    Parameters
    ----------
    process_group : torch.distributed.ProcessGroup, default = world group
        Process group for communication
    """
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment