Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
970620a5
Commit
970620a5
authored
Dec 27, 2025
by
wenjh
Browse files
merge nv_release_v2.10 to release_v2.10
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parents
c1a1c04e
769ed778
Changes
135
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
591 additions
and
698 deletions
+591
-698
transformer_engine/pytorch/csrc/extensions/attention.cpp
transformer_engine/pytorch/csrc/extensions/attention.cpp
+5
-0
transformer_engine/pytorch/csrc/extensions/gemm.cpp
transformer_engine/pytorch/csrc/extensions/gemm.cpp
+15
-0
transformer_engine/pytorch/csrc/extensions/normalization.cpp
transformer_engine/pytorch/csrc/extensions/normalization.cpp
+10
-0
transformer_engine/pytorch/distributed.py
transformer_engine/pytorch/distributed.py
+29
-17
transformer_engine/pytorch/export.py
transformer_engine/pytorch/export.py
+1
-1
transformer_engine/pytorch/graph.py
transformer_engine/pytorch/graph.py
+11
-11
transformer_engine/pytorch/jit.py
transformer_engine/pytorch/jit.py
+1
-1
transformer_engine/pytorch/module/base.py
transformer_engine/pytorch/module/base.py
+57
-93
transformer_engine/pytorch/module/fp8_padding.py
transformer_engine/pytorch/module/fp8_padding.py
+15
-11
transformer_engine/pytorch/module/fp8_unpadding.py
transformer_engine/pytorch/module/fp8_unpadding.py
+16
-12
transformer_engine/pytorch/module/grouped_linear.py
transformer_engine/pytorch/module/grouped_linear.py
+63
-87
transformer_engine/pytorch/module/layernorm.py
transformer_engine/pytorch/module/layernorm.py
+9
-12
transformer_engine/pytorch/module/layernorm_linear.py
transformer_engine/pytorch/module/layernorm_linear.py
+106
-137
transformer_engine/pytorch/module/layernorm_mlp.py
transformer_engine/pytorch/module/layernorm_mlp.py
+137
-172
transformer_engine/pytorch/module/linear.py
transformer_engine/pytorch/module/linear.py
+97
-124
transformer_engine/pytorch/module/rmsnorm.py
transformer_engine/pytorch/module/rmsnorm.py
+10
-13
transformer_engine/pytorch/onnx_extensions.py
transformer_engine/pytorch/onnx_extensions.py
+3
-1
transformer_engine/pytorch/ops/_common.py
transformer_engine/pytorch/ops/_common.py
+1
-1
transformer_engine/pytorch/ops/basic/activation.py
transformer_engine/pytorch/ops/basic/activation.py
+4
-4
transformer_engine/pytorch/ops/basic/all_gather.py
transformer_engine/pytorch/ops/basic/all_gather.py
+1
-1
No files found.
transformer_engine/pytorch/csrc/extensions/attention.cpp
View file @
970620a5
...
...
@@ -115,6 +115,11 @@ std::vector<py::object> fused_attn_fwd(
#ifdef __HIP_PLATFORM_AMD__
assert
(
false
);
#else
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at
::
cuda
::
CUDAGuard
device_guard
(
cu_seqlens_q
.
device
());
auto
none
=
py
::
none
();
// create QKV tensor wrappers
...
...
transformer_engine/pytorch/csrc/extensions/gemm.cpp
View file @
970620a5
...
...
@@ -97,6 +97,11 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
bool
bulk_overlap
,
float
alpha
,
std
::
optional
<
float
>
beta
)
{
using
namespace
transformer_engine
::
pytorch
::
detail
;
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at
::
cuda
::
CUDAGuard
device_guard
(
workspace
.
device
());
// Input tensors
NVTE_CHECK
(
!
A
.
is_none
(),
"Tensor A has not been provided"
);
NVTE_CHECK
(
!
B
.
is_none
(),
"Tensor B has not been provided"
);
...
...
@@ -353,6 +358,11 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
at
::
Tensor
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
at
::
Tensor
counter
)
{
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at
::
cuda
::
CUDAGuard
device_guard
(
workspace
.
device
());
// TODO: Handle scaling modes
NVTEScalingMode
nvte_scaling_modeA
=
NVTE_DELAYED_TENSOR_SCALING
;
NVTEScalingMode
nvte_scaling_modeB
=
NVTE_DELAYED_TENSOR_SCALING
;
...
...
@@ -402,6 +412,11 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
NVTE_ERROR
(
"not implemented, D should be allocated for single output case."
);
}
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at
::
cuda
::
CUDAGuard
device_guard
(
workspace
[
0
].
device
());
void
*
output_data_ptr
=
nullptr
;
if
(
single_output
)
{
output_data_ptr
=
(
*
D
)[
0
].
data_ptr
();
...
...
transformer_engine/pytorch/csrc/extensions/normalization.cpp
View file @
970620a5
...
...
@@ -64,6 +64,11 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
const
bool
zero_centered_gamma
)
{
using
namespace
transformer_engine
::
pytorch
::
detail
;
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at
::
cuda
::
CUDAGuard
device_guard
(
input
.
cast
<
at
::
Tensor
>
().
device
());
// Input and param tensors
auto
none
=
py
::
none
();
const
TensorWrapper
&
input_nvte
=
makeTransformerEngineTensor
(
input
,
none
);
...
...
@@ -294,6 +299,11 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
const
int
sm_margin
,
const
bool
zero_centered_gamma
)
{
using
namespace
transformer_engine
::
pytorch
::
detail
;
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at
::
cuda
::
CUDAGuard
device_guard
(
input
.
cast
<
at
::
Tensor
>
().
device
());
// Input and param tensors
auto
none
=
py
::
none
();
const
TensorWrapper
&
input_nvte
=
makeTransformerEngineTensor
(
input
,
none
);
...
...
transformer_engine/pytorch/distributed.py
View file @
970620a5
...
...
@@ -30,7 +30,7 @@ except ImportError:
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.triton.pad
import
pad_columnwise_scale_inv
from
.
import
torch_version
from
.
torch_version
import
torch_version
from
.utils
import
(
is_non_tn_fp8_gemm_supported
,
safely_set_viewless_tensor_data
,
...
...
@@ -642,18 +642,18 @@ def checkpoint(
Parameters
----------
function: Callable
function
: Callable
pytorch module used to run the forward and backward passes using
the specified :attr:`args` and :attr:`kwargs`.
distribute_saved_activations: bool, default = False
if set to `True` and `use_reentrant=True`, first tensor argument is distributed
across the specified tensor parallel group (`tp_group`) before saving it for the
backward pass. This has no effect when `use_reentrant=False`.
get_rng_state_tracker:
`
Callable
`
, default = None
python callable which returns an instance of :
func
:`CudaRNGStatesTracker`.
distribute_saved_activations
: bool, default = False
if set to
`
`True`
`
and
`
`use_reentrant=True`
`
, first tensor argument is distributed
across the specified tensor parallel group (`
`
tp_group`
`
) before saving it for the
backward pass. This has no effect when
`
`use_reentrant=False`
`
.
get_rng_state_tracker
: Callable, default = None
python callable which returns an instance of :
class
:`CudaRNGStatesTracker`.
tp_group : ProcessGroup, default = None
tensor parallel process group. Used only when `distribute_saved_activations=True`
and `use_reentrant=True`. If `None`, it falls back to the default group.
tensor parallel process group. Used only when
`
`distribute_saved_activations=True`
`
and
`
`use_reentrant=True`
`
. If
`
`None`
`
, it falls back to the default group.
use_reentrant : bool, default = True
perform checkpointing in reentrant mode.
args : tuple
...
...
@@ -778,8 +778,8 @@ class CudaRNGStatesTracker:
For model parallelism, multiple RNG states need to simultaneously exist in order
to execute operations in or out of the model parallel region. This class keeps
track of the various RNG states and provides utility methods to maintain them and
execute parts of the model under a given RNG setting. Using the `add` method, a
cuda rng state is initialized based on the input `seed` and is assigned to `name`.
execute parts of the model under a given RNG setting. Using the
:meth:
`add` method, a
cuda rng state is initialized based on the input
`
`seed`
`
and is assigned to
`
`name`
`
.
Later, by forking the rng state, we can perform operations and return to our starting
cuda state.
"""
...
...
@@ -812,7 +812,9 @@ class CudaRNGStatesTracker:
Set the rng states. For efficiency purposes, we do not
check the size of seed for compatibility.
states: Dict[str, torch.Tensor]
Parameters
----------
states : Dict[str, torch.Tensor]
A mapping from string names to RNG states.
"""
self
.
states_
=
states
...
...
@@ -821,9 +823,11 @@ class CudaRNGStatesTracker:
"""
Adds a new RNG state.
name: str
Parameters
----------
name : str
string identifier for the RNG state.
seed: int
seed
: int
PyTorch seed for the RNG state.
"""
# Check seed is not already used.
...
...
@@ -857,7 +861,9 @@ class CudaRNGStatesTracker:
Fork the cuda rng state, perform operations, and exit with
the original state.
name: str
Parameters
----------
name : str
string identifier for the RNG state.
"""
# Check if we have added the state
...
...
@@ -948,7 +954,13 @@ def _all_gather_fp8(
if
isinstance
(
inp
,
Float8Tensor
):
dtype
=
inp
.
dtype
device
=
inp
.
device
# Temporarily ensure rowwise usage for output tensor creation
# since we're gathering rowwise data, not the transpose
init_rowwise_usage
=
quantizer
.
rowwise_usage
init_columnwise_usage
=
quantizer
.
columnwise_usage
quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
init_columnwise_usage
)
out
=
quantizer
.
make_empty
(
out_shape
,
dtype
=
dtype
,
device
=
device
)
quantizer
.
set_usage
(
rowwise
=
init_rowwise_usage
,
columnwise
=
init_columnwise_usage
)
elif
isinstance
(
inp
,
Float8Tensor
):
out
=
inp
.
make_like
(
inp
,
shape
=
out_shape
)
out
.
_data
=
torch
.
empty
(
...
...
@@ -2001,7 +2013,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
Parameters
----------
fsdp_root: torch.nn.Module
fsdp_root
: torch.nn.Module
FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
"""
assert
isinstance
(
fsdp_root
,
FSDP
),
"Root module must be FSDP-wrapped."
...
...
transformer_engine/pytorch/export.py
View file @
970620a5
...
...
@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
Parameters
----------
enabled: bool, default =
`
False
`
enabled
: bool, default = False
whether or not to enable export
"""
...
...
transformer_engine/pytorch/graph.py
View file @
970620a5
...
...
@@ -950,38 +950,38 @@ def make_graphed_callables(
Positional arguments to callable(s).
num_warmup_iters: int, default = 3
Number of warmup iterations.
allow_unused_input: bool, default =
`
False
`
allow_unused_input: bool, default = False
Whether to handle case where callable inputs
and outputs are disconnected in compute graph.
sample_kwargs: (tuple of) dict, optional
Keyword arguments to callable(s)
pool: (tuple of) int, default =
`
None
`
, optional
pool: (tuple of) int, default = None, optional
An instance returned from function `torch.cuda.graph_pool_handle` that hints
this graph may share memory with the indicated pool.
retain_graph_in_backward: bool, default =
`
False
`
retain_graph_in_backward: bool, default = False
Whether to set retain_graph=True in backward graph capture.
_reuse_graph_input_output_buffers: bool, default =
`
False
`
_reuse_graph_input_output_buffers: bool, default = False
Reduce memory usage by reusing input/output data buffers between
graphs. Only supported with Mcore interleaved pipeline parallelism, i.e.
when `_order` is provided. All callables in `modules` are assumed to have
inputs and outputs with the same dtype and shape.
Quantization
related
parameters
----------------------
enabled: (tuple of) bool, default =
`
False
`
Quantization parameters
----------------------
-
enabled: (tuple of) bool, default = False
whether or not to enable low precision quantization (FP8/FP4).
If tuple, the length must match the number of modules.
calibrating: bool, default =
`
False
`
calibrating: bool, default = False
calibration mode allows collecting statistics such as amax and scale
data of quantized tensors even when executing without quantization enabled.
This is useful for saving an inference ready checkpoint while training
using a higher precision.
recipe: recipe.Recipe, default =
`
None
`
recipe: recipe.Recipe, default = None
recipe used for low precision quantization.
amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default =
`
None
`
amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = None
distributed group over which amaxes for the quantized tensors
are reduced at the end of each training step.
cache_quantized_params: bool, default =
`
False
`
cache_quantized_params: bool, default = False
Whether or not to cache quantized weights across microbatches. if set to `True`,
the `is_first_microbatch` boolean argument must be passed into the forward
method for TransformerEngine modules. When storing primary weights in low precision
...
...
transformer_engine/pytorch/jit.py
View file @
970620a5
...
...
@@ -8,7 +8,7 @@ from functools import wraps
from
typing
import
Callable
,
Optional
,
Tuple
import
torch
from
.
import
torch_version
from
.
torch_version
import
torch_version
from
.export
import
is_in_onnx_export_mode
from
.utils
import
gpu_autocast_ctx
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
...
...
transformer_engine/pytorch/module/base.py
View file @
970620a5
...
...
@@ -20,7 +20,6 @@ import torch.nn.functional as F
from
torch.distributed.tensor
import
DTensor
import
transformer_engine_torch
as
tex
from
transformer_engine.common.recipe
import
Recipe
from
._common
import
_ParameterInitMeta
,
noop_cat
from
..quantization
import
(
...
...
@@ -39,13 +38,18 @@ from ..distributed import (
_fsdp_gather_tensors
,
)
from
..constants
import
dist_group_type
from
..cpp_extensions.gemm
import
_NUM_MAX_UB_STREAMS
from
..quantized_tensor
import
QuantizedTensor
,
QuantizedTensorStorage
,
Quantizer
from
..tensor.float8_tensor
import
Float8Quantizer
,
Float8CurrentScalingQuantizer
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..tensor.float8_blockwise_tensor
import
Float8BlockQuantizer
from
..tensor.storage.float8_tensor_storage
import
Float8TensorStorage
from
..tensor.storage.mxfp8_tensor_storage
import
MXFP8TensorStorage
from
..utils
import
is_non_tn_fp8_gemm_supported
,
torch_get_autocast_gpu_dtype
from
..utils
import
(
is_non_tn_fp8_gemm_supported
,
torch_get_autocast_gpu_dtype
,
get_nvtx_range_context
,
)
from
..tensor.storage.float8_blockwise_tensor_storage
import
Float8BlockwiseQTensorStorage
from
...common.recipe
import
DelayedScaling
,
Recipe
from
...debug.pytorch.debug_state
import
TEDebugState
...
...
@@ -58,13 +62,9 @@ __all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]
_2X_ACC_FPROP
=
False
_2X_ACC_DGRAD
=
True
_2X_ACC_WGRAD
=
True
_multi_stream_cublas_workspace
=
[]
_dummy_wgrads
=
{}
_multi_stream_cublas_batchgemm_workspace
=
[]
_cublas_workspace
=
None
_ub_communicators
=
None
ub_stream_nums
=
int
(
os
.
getenv
(
"NVTE_UB_STREAM_NUMS"
,
"2"
))
_NUM_MAX_UB_STREAMS
=
ub_stream_nums
if
IS_HIP_EXTENSION
else
3
_MIN_STREAM_PRIORITY
,
_MAX_STREAM_PRIORITY
=
None
,
None
layers_atomic_ring_exchange
=
[]
...
...
@@ -78,38 +78,6 @@ class UserBufferQuantizationMode(Enum):
FP8
=
"fp8"
def
get_cublas_workspace_size_bytes
()
->
None
:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
# Add env for control the padding for blaslt
if
IS_HIP_EXTENSION
:
return
134_217_728
if
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
major
>=
9
:
# 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
return
32
*
1024
*
1024
+
1024
return
4_194_304
def
get_workspace
()
->
torch
.
Tensor
:
"""Returns workspace for cublas."""
global
_cublas_workspace
if
_cublas_workspace
is
None
:
_cublas_workspace
=
torch
.
empty
(
get_cublas_workspace_size_bytes
(),
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
return
_cublas_workspace
def
get_multi_stream_cublas_workspace
()
->
List
[
torch
.
Tensor
]:
"""Returns workspace for multi-stream cublas."""
global
_multi_stream_cublas_workspace
if
not
_multi_stream_cublas_workspace
:
for
_
in
range
(
tex
.
get_num_cublas_streams
()):
_multi_stream_cublas_workspace
.
append
(
torch
.
empty
(
get_cublas_workspace_size_bytes
(),
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
)
return
_multi_stream_cublas_workspace
def
get_multi_stream_cublas_batchgemm_workspace
()
->
List
[
torch
.
Tensor
]:
"""Returns workspace for multi-stream cublas."""
global
_multi_stream_cublas_batchgemm_workspace
...
...
@@ -154,27 +122,27 @@ def initialize_ub(
)
->
None
:
r
"""
Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
GEMM compute in te.Linear
,
te.LayerNormLinear and te.LayerNormMLP modules.
GEMM compute in
``
te.Linear
``, ``
te.LayerNormLinear
``
and
``
te.LayerNormMLP
``
modules.
Parameters
----------
shape : list
shape of the communication buffer, typically set to be the same as the global shape of
the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
the input tensor to a
``
te.TransformerLayer
``
forward pass, with the sequence and batch
dimensions collapsed together -- i.e.:
`
`(sequence_length * batch_size, hidden_size)`
`
tp_size : int
number of GPUs in the tensor-parallel process group
use_fp8 : bool = False
allocate the communication buffer for FP8 GEMM inputs/outputs.
DEPRECATED: Please use `quantization_modes` instead.
DEPRECATED: Please use
`
`quantization_modes`
`
instead.
quantization_modes : List[UserBufferQuantizationMode] = None
if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list.
falls back to the legacy `use_fp8` parameter if `None` is provided.
falls back to the legacy
`
`use_fp8`
`
parameter if
`
`None`
`
is provided.
dtype : torch.dtype = torch.bfloat16
non-FP8 data type of the communication buffer when `use_fp8 = False`
ub_cfgs: dict = None
Configuration dictionary with the structure
```
non-FP8 data type of the communication buffer when
`
`use_fp8 = False`
`
ub_cfgs
: dict = None
Configuration dictionary with the structure
::
{
<gemm_name> : {
"method": <"ring_exchange" or "pipeline">,
...
...
@@ -189,20 +157,20 @@ def initialize_ub(
"fp8_buf": bool,
}
}
```
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
for
`
`te.TransformerLayer`
`
GEMM layers in
`
`["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
"fc2_fprop", "fc2_wgrad"]`.
a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes`
"fc2_fprop", "fc2_wgrad"]`
`
.
a list may be provided to specify different overlap configurations for different the quantization settings in
`
`quantization_modes`
`
bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and
`
`torch.distributed`
`
communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are
valid for every cluster configuration and distributed launch method even if
they are available in PyTorch. When left unset, the initialization prefers
to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
not available. Setting
`
`NVTE_UB_WITH_MPI=1`
`
when building TE overrides this
option and always initializes Userbuffers with direct MPI calls in C++,
which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.
which also requires
`
`MPI_HOME=/path/to/mpi/root`
`
to be set at compile time.
"""
if
not
tex
.
device_supports_multicast
():
assert
bool
(
int
(
os
.
getenv
(
"UB_SKIPMC"
,
"1"
))),
(
...
...
@@ -299,16 +267,6 @@ def initialize_ub(
flush
=
True
,
)
# Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls
global
_cublas_workspace
if
_cublas_workspace
is
None
:
_cublas_workspace
=
get_workspace
().
repeat
(
_NUM_MAX_UB_STREAMS
)
elif
_cublas_workspace
.
numel
()
!=
get_cublas_workspace_size_bytes
()
*
_NUM_MAX_UB_STREAMS
:
# This ensures we don't do `.repeat()` on an already expanded workspace
_cublas_workspace
=
torch
.
empty
(
get_cublas_workspace_size_bytes
(),
dtype
=
torch
.
uint8
,
device
=
"cuda"
).
repeat
(
_NUM_MAX_UB_STREAMS
)
# Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
layers_all_gather_overlap
=
[
"qkv_fprop"
,
...
...
@@ -1033,7 +991,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
Parameters
----------
tp_group : ProcessGroup, default =
`
None
`
tp_group : ProcessGroup, default = None
tensor parallel process group.
"""
self
.
tp_group
=
tp_group
...
...
@@ -1123,8 +1081,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""
self
.
allow_different_data_and_param_types
=
allow_different_data_and_param_types
self
.
forwarded_at_least_once
=
True
# Activation recomputation is used and this is the second forward phase.
if
self
.
fp8
and
in_fp8_activation_recompute_phase
():
delayed_scaling_recipe
=
self
.
fp8_meta
[
"recipe"
].
delayed
()
FP8GlobalStateManager
.
get_old_fp8_meta_tensors_for_recompute
(
self
.
fp8_meta
)
else
:
assert
inp
.
is_cuda
,
"TransformerEngine needs CUDA."
...
...
@@ -1136,25 +1096,27 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self
.
init_fp8_metadata
(
num_gemms
=
num_gemms
)
self
.
_check_weight_tensor_recipe_correspondence
()
if
self
.
fp8
and
self
.
sequence_parallel
and
self
.
fp8_meta
[
"recipe"
].
delayed
():
delayed_scaling_recipe
=
self
.
fp8
and
self
.
fp8_meta
[
"recipe"
].
delayed
()
if
delayed_scaling_recipe
:
if
self
.
sequence_parallel
:
assert
self
.
fp8_meta
[
"recipe"
].
reduce_amax
,
(
"Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8."
)
if
self
.
fp8
and
not
FP8GlobalStateManager
.
fp8_graph_capturing
():
if
not
FP8GlobalStateManager
.
fp8_graph_capturing
():
FP8GlobalStateManager
.
add_fp8_tensors_to_global_buffer
(
self
.
fp8_meta
)
# Activation recomputation is used and this is the first forward phase.
if
self
.
fp8
and
self
.
training
and
is_fp8_activation_recompute_enabled
():
if
self
.
training
and
is_fp8_activation_recompute_enabled
():
FP8GlobalStateManager
.
copy_forward_fp8_meta_tensors_for_recompute
(
self
.
fp8_meta
)
with
torch
.
cuda
.
nvtx
.
range
(
self
.
__class__
.
__name__
+
" forward"
):
with
get_
nvtx
_
range
_context
(
self
.
__class__
.
__name__
+
" forward"
):
if
not
allow_non_contiguous
and
not
inp
.
is_contiguous
():
inp
=
inp
.
contiguous
()
yield
inp
if
self
.
fp8
and
in_fp8_activation_recompute_phase
():
if
delayed_scaling_recipe
and
self
.
fp8
and
in_fp8_activation_recompute_phase
():
FP8GlobalStateManager
.
restore_fp8_meta_tensors
(
self
.
fp8_meta
)
def
set_nccl_overlap_warning_if_tp
(
self
)
->
None
:
...
...
@@ -1434,7 +1396,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
workspace is being constructed or updated.
cache_name: str, optional
Key for caching.
update_workspace: bool, default =
`
True
`
update_workspace: bool, default = True
Update workspace with values from `tensor`.
skip_update_flag: torch.Tensor, optional
GPU flag to skip updating the workspace. Take precedence
...
...
@@ -1576,7 +1538,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""
if
not
self
.
need_backward_dw
():
return
with
torch
.
cuda
.
nvtx
.
range
(
f
"_
{
self
.
__class__
.
__name__
}
_wgrad"
):
with
get_
nvtx
_
range
_context
(
f
"_
{
self
.
__class__
.
__name__
}
_wgrad"
):
(
wgrad
,
bgrad
),
_
=
self
.
wgrad_store
.
pop
()
if
not
self
.
fuse_wgrad_accumulation
:
weight_tensor
=
noop_cat
(
self
.
_get_weight_tensors
())
...
...
@@ -1673,6 +1635,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""
if
not
self
.
fp8
and
not
self
.
fp8_calibration
:
return
if
not
self
.
primary_weights_in_fp8
:
return
if
not
hasattr
(
self
,
"weight_names"
)
or
not
self
.
weight_names
:
return
...
...
transformer_engine/pytorch/module/fp8_padding.py
View file @
970620a5
...
...
@@ -24,11 +24,14 @@ class _Fp8Padding(torch.autograd.Function):
def
forward
(
ctx
,
inp
:
torch
.
Tensor
,
m_splits
:
List
[
int
],
padded_m_splits
:
List
[
int
],
is_grad_enabled
:
bool
,
non_tensor_args
:
Tuple
,
)
->
torch
.
Tensor
:
# pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
m_splits
,
padded_m_splits
,
is_grad_enabled
)
=
non_tensor_args
# Make sure input dimensions are compatible
in_features
=
inp
.
shape
[
-
1
]
...
...
@@ -65,7 +68,7 @@ class _Fp8Padding(torch.autograd.Function):
grad_output
.
view
(
-
1
,
in_features
),
grad_input
,
ctx
.
padded_m_splits
,
ctx
.
m_splits
)
return
(
grad_input
,
None
,
None
,
None
)
return
grad_input
,
None
class
Fp8Padding
(
torch
.
nn
.
Module
):
...
...
@@ -128,19 +131,20 @@ class Fp8Padding(torch.nn.Module):
if
m_splits
==
padded_m_splits
:
return
inp
,
m_splits
if
torch
.
is_grad_enabled
():
is_grad_enabled
=
torch
.
is_grad_enabled
()
if
is_grad_enabled
:
fn
=
_Fp8Padding
.
apply
a
rgs
=
[]
a
utograd_ctx
=
[]
else
:
fn
=
_Fp8Padding
.
forward
a
rgs
=
[
None
]
a
utograd_ctx
=
[
None
]
args
+=
(
inp
,
non_tensor_args
=
(
m_splits
,
padded_m_splits
,
torch
.
is_grad_enabled
()
,
is_grad_enabled
,
)
out
=
fn
(
*
args
)
out
=
fn
(
*
autograd_ctx
,
inp
,
non_tensor_
args
)
return
out
,
padded_m_splits
transformer_engine/pytorch/module/fp8_unpadding.py
View file @
970620a5
...
...
@@ -4,7 +4,7 @@
"""FP8 Padding API"""
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Tuple
import
torch
...
...
@@ -24,11 +24,14 @@ class _Fp8Unpadding(torch.autograd.Function):
def
forward
(
ctx
,
inp
:
torch
.
Tensor
,
m_splits
:
List
[
int
],
padded_m_splits
:
List
[
int
],
is_grad_enabled
:
bool
,
non_tensor_args
:
Tuple
,
)
->
torch
.
Tensor
:
# pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
m_splits
,
padded_m_splits
,
is_grad_enabled
)
=
non_tensor_args
in_features
=
inp
.
shape
[
-
1
]
# Allocate cast and transpose output tensor
...
...
@@ -63,7 +66,7 @@ class _Fp8Unpadding(torch.autograd.Function):
grad_output
.
view
(
-
1
,
in_features
),
grad_input
,
ctx
.
m_splits
,
ctx
.
padded_m_splits
)
return
(
grad_input
,
None
,
None
,
None
)
return
grad_input
,
None
class
Fp8Unpadding
(
torch
.
nn
.
Module
):
...
...
@@ -126,19 +129,20 @@ class Fp8Unpadding(torch.nn.Module):
if
m_splits
==
padded_m_splits
:
return
inp
if
torch
.
is_grad_enabled
():
is_grad_enabled
=
torch
.
is_grad_enabled
()
if
is_grad_enabled
:
fn
=
_Fp8Unpadding
.
apply
a
rgs
=
[]
a
utograd_ctx
=
[]
else
:
fn
=
_Fp8Unpadding
.
forward
a
rgs
=
[
None
]
a
utograd_ctx
=
[
None
]
args
+=
(
inp
,
non_tensor_args
=
(
m_splits
,
padded_m_splits
,
torch
.
is_grad_enabled
()
,
is_grad_enabled
,
)
out
=
fn
(
*
args
)
out
=
fn
(
*
autograd_ctx
,
inp
,
non_tensor_
args
)
return
out
transformer_engine/pytorch/module/grouped_linear.py
View file @
970620a5
...
...
@@ -14,8 +14,6 @@ import transformer_engine_torch as tex
from
transformer_engine.common.recipe
import
Recipe
from
.base
import
(
get_dummy_wgrad
,
get_multi_stream_cublas_workspace
,
get_dummy_wgrad
,
TransformerEngineBaseModule
,
_2X_ACC_FPROP
,
...
...
@@ -30,6 +28,7 @@ from ..utils import (
clear_tensor_data
,
init_method_constant
,
requires_grad
,
get_nvtx_range_context
,
)
from
..distributed
import
(
set_tensor_model_parallel_attributes
,
...
...
@@ -42,7 +41,6 @@ from ..cpp_extensions import (
)
from
..constants
import
GemmParallelModes
,
dist_group_type
from
..jit
import
no_torch_dynamo
from
..graph
import
is_graph_capturing
from
..cpu_offload
import
is_cpu_offload_enabled
,
mark_not_offload
,
start_offload
from
..tensor.float8_tensor
import
Float8CurrentScalingQuantizer
,
Float8Quantizer
...
...
@@ -66,28 +64,34 @@ class _GroupedLinear(torch.autograd.Function):
def
forward
(
ctx
,
inp
:
torch
.
Tensor
,
m_splits
:
List
[
int
],
use_bias
:
bool
,
is_first_microbatch
:
Union
[
bool
,
None
],
fp8
:
bool
,
fp8_calibration
:
bool
,
wgrad_store
:
WeightGradStore
,
input_quantizers
:
List
[
Quantizer
],
weight_quantizers
:
List
[
Quantizer
],
output_quantizers
:
List
[
Quantizer
],
grad_output_quantizers
:
List
[
Quantizer
],
fuse_wgrad_accumulation
:
bool
,
cpu_offloading
:
bool
,
sequence_parallel
:
bool
,
activation_dtype
:
torch
.
dtype
,
is_grad_enabled
:
bool
,
non_tensor_args
:
Tuple
,
*
weights_and_biases
,
)
->
torch
.
Tensor
:
# pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
m_splits
,
use_bias
,
is_first_microbatch
,
fp8
,
fp8_calibration
,
wgrad_store
,
input_quantizers
,
weight_quantizers
,
output_quantizers
,
grad_output_quantizers
,
fuse_wgrad_accumulation
,
cpu_offloading
,
sequence_parallel
,
activation_dtype
,
is_grad_enabled
,
module
,
skip_fp8_weight_update
,
save_original_input
,
fine_grained_activation_offloading
,
*
weights_and_biases
,
)
->
torch
.
Tensor
:
# pylint: disable=missing-function-docstring
)
=
non_tensor_args
num_gemms
=
len
(
m_splits
)
weights
=
weights_and_biases
[:
num_gemms
]
...
...
@@ -187,7 +191,6 @@ class _GroupedLinear(torch.autograd.Function):
inputmats
,
[
out
],
activation_dtype
,
get_multi_stream_cublas_workspace
(),
single_output
=
True
,
m_splits
=
m_splits
,
bias
=
biases
,
...
...
@@ -313,7 +316,7 @@ class _GroupedLinear(torch.autograd.Function):
@
staticmethod
def
backward
(
ctx
,
grad_output
:
torch
.
Tensor
)
->
Tuple
[
Union
[
torch
.
Tensor
,
None
],
...]:
# pylint: disable=missing-function-docstring
with
torch
.
cuda
.
nvtx
.
range
(
"_GroupedLinear_backward"
):
with
get_
nvtx
_
range
_context
(
"_GroupedLinear_backward"
):
saved_tensors
=
restore_from_saved
(
ctx
.
tensor_objects
,
ctx
.
saved_tensors
)
N
=
ctx
.
num_gemms
inputmats
=
saved_tensors
[:
N
]
...
...
@@ -404,7 +407,6 @@ class _GroupedLinear(torch.autograd.Function):
grad_output
,
[
dgrad
],
ctx
.
activation_dtype
,
get_multi_stream_cublas_workspace
(),
single_output
=
True
,
layout
=
"NN"
,
m_splits
=
ctx
.
m_splits
,
...
...
@@ -451,7 +453,6 @@ class _GroupedLinear(torch.autograd.Function):
grouped_gemm_wgrad
=
functools
.
partial
(
general_grouped_gemm
,
out_dtype
=
ctx
.
activation_dtype
,
workspaces
=
get_multi_stream_cublas_workspace
(),
layout
=
"NT"
,
grad
=
True
,
m_splits
=
ctx
.
m_splits
,
...
...
@@ -523,29 +524,11 @@ class _GroupedLinear(torch.autograd.Function):
):
grad_biases
=
[
None
]
*
ctx
.
num_gemms
if
ctx
.
reduce_and_update_bwd_fp8_tensors
and
not
is_graph_capturing
()
:
if
ctx
.
reduce_and_update_bwd_fp8_tensors
:
FP8GlobalStateManager
.
reduce_and_update_fp8_tensors
(
forward
=
False
)
return
(
dgrad
.
view
(
ctx
.
inp_shape
)
if
ctx
.
requires_dgrad
else
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
*
wgrad_list
,
*
grad_biases
,
)
...
...
@@ -563,14 +546,14 @@ class GroupedLinear(TransformerEngineBaseModule):
size of each input sample.
out_features : int
size of each output sample.
bias : bool, default =
`
True
`
if set to `False`, the layer will not learn an additive bias.
init_method : Callable, default =
`
None
`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default =
`
None
`
bias : bool, default = True
if set to
`
`False`
`
, the layer will not learn an additive bias.
init_method : Callable, default = None
used for initializing weights in the following way:
`
`init_method(weight)`
`
.
When set to
`
`None`
`
, defaults to
`
`torch.nn.init.normal_(mean=0.0, std=0.023)`
`
.
get_rng_state_tracker : Callable, default = None
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default =
`
None
`
rng_tracker_name : str, default = None
the param passed to get_rng_state_tracker to get the specific rng tracker.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
...
...
@@ -579,33 +562,35 @@ class GroupedLinear(TransformerEngineBaseModule):
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default =
'
False
'
if set to `True`, enables fusing of creation and accumulation of
fuse_wgrad_accumulation : bool, default = False
if set to
`
`True`
`
, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
have an additional
`
`main_grad`
`
attribute (used instead of the
regular
`
`grad`
`
) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default =
`
False
`
when set to `True`, this module will not apply the additive bias itself, but
will overwrite
`
`main_grad`
`
instead of accumulating.
return_bias : bool, default = False
when set to
`
`True`
`
, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default =
`
torch.get_default_dtype()
`
params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default =
`
False
`
delay_wgrad_compute : bool, default = False
Whether to delay weight gradient computation
save_original_input : bool, default =
`
False
`
If set to `True`, always saves the original input tensor rather than the
save_original_input : bool, default = False
If set to
`
`True`
`
, always saves the original input tensor rather than the
cast tensor. In some scenarios, the input tensor is used by multiple modules,
and saving the original input tensor may reduce the memory usage.
Cannot work with FP8 DelayedScaling recipe.
Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and
`parallel_mode` are used to determine the shapes of weights and biases.
Notes
-----
GroupedLinear doesn't really handle the TP communications inside. The ``tp_size`` and
``parallel_mode`` are used to determine the shapes of weights and biases.
The TP communication should be handled in the dispatch and combine stages of MoE models.
"""
...
...
@@ -807,16 +792,9 @@ class GroupedLinear(TransformerEngineBaseModule):
),
"GroupedLinear doesn't support input tensor in FP8."
assert
len
(
m_splits
)
==
self
.
num_gemms
,
"Number of splits should match number of GEMMs."
if
FP8GlobalStateManager
.
fp8_graph_capturing
():
skip_fp8_weight_update
=
FP8GlobalStateManager
.
get_skip_fp8_weight_update_tensor
()
else
:
skip_fp8_weight_update
=
None
if
skip_fp8_weight_update
is
not
None
:
is_first_microbatch
=
False
is_grad_enabled
=
torch
.
is_grad_enabled
()
with
torch
.
cuda
.
device
(
getattr
(
self
,
list
(
self
.
named_parameters
())[
0
][
0
]).
device
),
self
.
prepare_forward
(
inp
,
num_gemms
=
self
.
num_gemms
)
as
inp
:
with
self
.
prepare_forward
(
inp
,
num_gemms
=
self
.
num_gemms
)
as
inp
:
weight_tensors
=
self
.
_get_weight_tensors
()
bias_tensors
=
[
getattr
(
self
,
f
"bias
{
i
}
"
)
for
i
in
range
(
self
.
num_gemms
)]
...
...
@@ -836,7 +814,7 @@ class GroupedLinear(TransformerEngineBaseModule):
# TODO: use internal after #1638 is merged. # pylint: disable=fixme
for
i
in
range
(
self
.
num_gemms
):
input_quantizers
[
i
].
internal
=
False
if
torch
.
is_grad_enabled
()
:
if
is_grad_enabled
:
grad_output_quantizers
=
[
self
.
quantizers
[
"scaling_bwd"
][
self
.
_offsets
[
"input"
]
+
i
*
self
.
_num_fp8_tensors_per_gemm
[
"bwd"
]
...
...
@@ -846,14 +824,14 @@ class GroupedLinear(TransformerEngineBaseModule):
for
i
in
range
(
self
.
num_gemms
):
grad_output_quantizers
[
i
].
internal
=
True
if
torch
.
is_grad_enabled
()
:
if
is_grad_enabled
:
linear_fn
=
_GroupedLinear
.
apply
a
rgs
=
[]
a
utograd_ctx
=
[]
else
:
linear_fn
=
_GroupedLinear
.
forward
a
rgs
=
[
None
]
args
+=
(
inp
,
a
utograd_ctx
=
[
None
]
non_tensor_args
=
(
m_splits
,
self
.
apply_bias
,
is_first_microbatch
,
...
...
@@ -868,15 +846,13 @@ class GroupedLinear(TransformerEngineBaseModule):
is_cpu_offload_enabled
(),
self
.
sequence_parallel
,
self
.
activation_dtype
,
torch
.
is_grad_enabled
()
,
is_grad_enabled
,
self
,
skip_fp8_weight_update
,
None
,
#
skip_fp8_weight_update
self
.
save_original_input
,
self
.
fine_grained_activation_offloading
,
*
weight_tensors
,
*
bias_tensors
,
)
out
=
linear_fn
(
*
ar
g
s
)
out
=
linear_fn
(
*
a
utograd_ctx
,
inp
,
non_tensor_args
,
*
weight_tensors
,
*
bias_tenso
rs
)
if
self
.
return_bias
:
return
out
,
[
cast_if_needed
(
b
,
self
.
activation_dtype
)
for
b
in
bias_tensors
]
...
...
@@ -889,7 +865,7 @@ class GroupedLinear(TransformerEngineBaseModule):
"""
if
not
self
.
need_backward_dw
():
return
with
torch
.
cuda
.
nvtx
.
range
(
"_GroupedLinear_wgrad"
):
with
get_
nvtx
_
range
_context
(
"_GroupedLinear_wgrad"
):
(
_
,
grad_biases_
,
_
),
tensor_list
=
self
.
wgrad_store
.
pop
()
wgrad_list
=
tensor_list
[
2
]
weight_params
=
[
getattr
(
self
,
f
"weight
{
i
}
"
)
for
i
in
range
(
self
.
num_gemms
)]
...
...
transformer_engine/pytorch/module/layernorm.py
View file @
970620a5
...
...
@@ -28,33 +28,30 @@ class LayerNorm(_LayerNormOp):
Parameters
----------
normalized_shape: int or iterable of int
normalized_shape
: int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
A value added to the denominator of layer normalization for
numerical stability
device: torch.device, default = default CUDA device
device
: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
dtype
: torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
If `True`, the :math:`\gamma` parameter is initialized to zero
If
`
`True`
`
, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
sm_margin: int or dict, default = 0
sm_margin
: int or dict, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
Legacy
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
margin at each compute stage (``"forward"``, ``"backward"``,
``"inference"``).
sequence_parallel : bool
**Legacy parameter.** Set a bool attr named ``sequence_parallel`` in the parameters.
This is custom logic for Megatron-LM integration.
"""
...
...
transformer_engine/pytorch/module/layernorm_linear.py
View file @
970620a5
...
...
@@ -15,11 +15,10 @@ from torch.nn import init
import
transformer_engine_torch
as
tex
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
.torch_version
import
torch_version
from
transformer_engine.pytorch.tensor.utils
import
is_custom
from
.base
import
(
fill_userbuffers_buffer_for_all_gather
,
get_workspace
,
get_ub
,
TransformerEngineBaseModule
,
get_dummy_wgrad
,
...
...
@@ -40,6 +39,7 @@ from ..utils import (
nvtx_range_push
,
requires_grad
,
needs_quantized_gemm
,
get_nvtx_range_context
,
get_activation_offloading
,
)
from
..distributed
import
(
...
...
@@ -104,48 +104,54 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias
:
Union
[
torch
.
Tensor
,
None
],
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
eps
:
float
,
is_first_microbatch
:
Union
[
bool
,
None
],
fp8
:
bool
,
fp8_calibration
:
bool
,
wgrad_store
:
WeightGradStore
,
fuse_wgrad_accumulation
:
bool
,
input_quantizer
:
Optional
[
Quantizer
],
weight_quantizer
:
Optional
[
Quantizer
],
output_quantizer
:
Optional
[
Quantizer
],
grad_input_quantizer
:
Optional
[
Quantizer
],
grad_weight_quantizer
:
Optional
[
Quantizer
],
grad_output_quantizer
:
Optional
[
Quantizer
],
cpu_offloading
:
bool
,
tp_group
:
Union
[
dist_group_type
,
None
],
tp_size
:
int
,
sequence_parallel
:
bool
,
tensor_parallel
:
bool
,
activation_dtype
:
torch
.
dtype
,
parallel_mode
:
Union
[
str
,
None
],
return_layernorm_output
:
bool
,
return_layernorm_output_gathered
:
bool
,
is_grad_enabled
:
bool
,
fwd_ln_sm_margin
:
int
,
bwd_ln_sm_margin
:
int
,
zero_centered_gamma
:
bool
,
normalization
:
str
,
ub_overlap_ag_fprop
:
bool
,
ub_overlap_rs_fprop
:
bool
,
ub_overlap_ag_dgrad
:
bool
,
ub_overlap_rs_dgrad
:
bool
,
ub_bulk_wgrad
:
bool
,
ub_bulk_dgrad
:
bool
,
ub_name
:
str
,
fine_grained_activation_offloading
:
bool
,
fsdp_group
:
Union
[
dist_group_type
,
None
],
module
:
torch
.
nn
.
Module
,
skip_fp8_weight_update
:
bool
,
symmetric_ar_type
:
str
,
debug
:
Optional
[
bool
]
=
False
,
non_tensor_args
:
Tuple
,
)
->
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
# pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
eps
,
is_first_microbatch
,
fp8
,
fp8_calibration
,
wgrad_store
,
fuse_wgrad_accumulation
,
input_quantizer
,
weight_quantizer
,
output_quantizer
,
grad_input_quantizer
,
grad_weight_quantizer
,
grad_output_quantizer
,
cpu_offloading
,
tp_group
,
tp_size
,
sequence_parallel
,
tensor_parallel
,
activation_dtype
,
parallel_mode
,
return_layernorm_output
,
return_layernorm_output_gathered
,
is_grad_enabled
,
fwd_ln_sm_margin
,
bwd_ln_sm_margin
,
zero_centered_gamma
,
normalization
,
ub_overlap_ag_fprop
,
ub_overlap_rs_fprop
,
ub_overlap_ag_dgrad
,
ub_overlap_rs_dgrad
,
ub_bulk_wgrad
,
ub_bulk_dgrad
,
ub_name
,
fine_grained_activation_offloading
,
fsdp_group
,
module
,
skip_fp8_weight_update
,
symmetric_ar_type
,
debug
,
)
=
non_tensor_args
# NVTX label for profiling
nvtx_label
=
"transformer_engine._LayerNormLinear.forward"
if
ub_name
is
not
None
:
...
...
@@ -364,7 +370,6 @@ class _LayerNormLinear(torch.autograd.Function):
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
weightmat
,
ln_out_total
,
get_workspace
(),
quantization_params
=
output_quantizer
,
out_dtype
=
activation_dtype
,
bias
=
bias
,
...
...
@@ -553,7 +558,7 @@ class _LayerNormLinear(torch.autograd.Function):
if
ctx
.
ub_name
is
not
None
:
nvtx_label
=
f
"
{
nvtx_label
}
.
{
ctx
.
ub_name
}
"
with
torch
.
cuda
.
nvtx
.
range
(
"_LayerNormLinear_backward"
):
with
get_
nvtx
_
range
_context
(
"_LayerNormLinear_backward"
):
saved_tensors
=
ctx
.
saved_tensors
(
# pylint: disable=unbalanced-tuple-unpacking
inputmat
,
...
...
@@ -742,7 +747,6 @@ class _LayerNormLinear(torch.autograd.Function):
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
weight
,
grad_output
,
get_workspace
(),
layout
=
"NN"
,
grad
=
True
,
quantization_params
=
ctx
.
grad_input_quantizer
,
...
...
@@ -869,7 +873,6 @@ class _LayerNormLinear(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs
=
{
"workspace"
:
get_workspace
(),
"out_dtype"
:
(
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
...
...
@@ -1044,45 +1047,7 @@ class _LayerNormLinear(torch.autograd.Function):
dbeta
,
wgrad
,
grad_bias
,
None
,
# eps
None
,
# is_first_microbatch
None
,
# fp8
None
,
# fp8_calibration
None
,
# wgrad_store
None
,
# fuse_wgrad_accumulation
None
,
# input_quantizer
None
,
# weight_quantizer
None
,
# output_quantizer
None
,
# grad_input_quantizer
None
,
# grad_weight_quantizer
None
,
# grad_output_quantizer
None
,
# cpu_offloading
None
,
# tp_group
None
,
# tp_size
None
,
# sequence_parallel
None
,
# tensor_parallel
None
,
# activation_dtype
None
,
# parallel_mode
None
,
# return_layernorm_output
None
,
# return_layernorm_output_gathered
None
,
# is_grad_enabled
None
,
# fwd_ln_sm_margin
None
,
# bwd_ln_sm_margin
None
,
# zero_centered_gamma
None
,
# normalization
None
,
# ub_overlap_ag_fprop
None
,
# ub_overlap_rs_fprop
None
,
# ub_overlap_ag_dgrad
None
,
# ub_overlap_rs_dgrad
None
,
# ub_bulk_dgrad
None
,
# ub_bulk_wgrad
None
,
# ub_name
None
,
# fine_grained_activation_offloading
None
,
# fsdp_group
None
,
# debug
None
,
# module
None
,
# skip_fp8_weight_update
None
,
# symmetric_ar_type
None
,
)
...
...
@@ -1098,20 +1063,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
size of each output sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
bias : bool, default =
`
True
`
if set to `False`, the layer will not learn an additive bias.
bias : bool, default = True
if set to
`
`False`
`
, the layer will not learn an additive bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
init_method : Callable, default =
`
None
`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
return_layernorm_output : bool, default =
`
False
`
if set to `True`, output of layernorm is returned from the forward
init_method : Callable, default = None
used for initializing weights in the following way:
`
`init_method(weight)`
`
.
When set to
`
`None`
`
, defaults to
`
`torch.nn.init.normal_(mean=0.0, std=0.023)`
`
.
return_layernorm_output : bool, default = False
if set to
`
`True`
`
, output of layernorm is returned from the forward
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
return_layernorm_output_gathered : bool, default =
`
False
`
if set to `True`, output of layernorm is returned after the all
return_layernorm_output_gathered : bool, default = False
if set to
`
`True`
`
, output of layernorm is returned after the all
gather operation. Ignored if return_layernorm_output is False.
Example use case: with sequence parallel, input to residual connection
for transformer module (e.g. LoRA) will need to be gathered.
...
...
@@ -1122,10 +1087,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
they are used to make the names of equally-sized parameters. If a dict
(preferably an OrderedDict) is provided, the keys are used as names and
values as split sizes along dim 0. The resulting parameters will have
names that end in `_weight` or `_bias`, so trailing underscores are
names that end in
`
`_weight`
`
or
`
`_bias`
`
, so trailing underscores are
stripped from any provided names.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
if set to
``
'True'
``
, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
...
...
@@ -1135,53 +1100,53 @@ class LayerNormLinear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default =
`
None
`
name
: str, default = None
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
sequence_parallel : bool, default =
`
False
`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default =
`
None
`
sequence_parallel : bool, default = False
if set to
`
`True`
`
, uses sequence parallelism.
tp_group : ProcessGroup, default = None
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
`
`set_tensor_parallel_group(tp_group)`
`
method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'column', 'row'}, default =
`
None
`
parallel_mode : {None, 'column', 'row'}, default = None
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
When set to
`
`None`
`
, no communication is performed.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
if set to
`
`True`
`
, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
have an additional
`
`main_grad`
`
attribute (used instead of the
regular
`
`grad`
`
) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default =
`
False
`
when set to `True`, this module will not apply the additive bias itself, but
will overwrite
`
`main_grad`
`
instead of accumulating.
return_bias : bool, default = False
when set to
`
`True`
`
, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default =
`
torch.get_default_dtype()
`
params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default =
`
False
`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
delay_wgrad_compute : bool, default = False
Whether or not to delay weight gradient computation. If set to
`
`True`
`
,
it's the user's responsibility to call
`
`module.backward_dw`
`
to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
Requires PyTorch version 2.7.0 or higher. When set to
``
None
``
, standard all-reduce
is used.
"""
...
...
@@ -1544,8 +1509,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
is_grad_enabled
=
torch
.
is_grad_enabled
()
if
is_in_onnx_export_mode
():
return
self
.
onnx_forward
(
inp
,
fp8_output
)
return
self
.
onnx_forward
(
inp
,
fp8_output
,
is_grad_enabled
)
debug
=
self
.
is_debug_iter
()
...
...
@@ -1567,9 +1534,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
).
is_fp8_ubuf
():
fp8_grad
=
True
with
torch
.
cuda
.
device
(
getattr
(
self
,
list
(
self
.
named_parameters
())[
0
][
0
]).
device
),
self
.
prepare_forward
(
with
self
.
prepare_forward
(
inp
,
allow_non_contiguous
=
False
# removed .contiguous from inside the layer
)
as
inp
:
...
...
@@ -1577,14 +1542,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_tensor
,
bias_tensor
=
self
.
_get_weight_and_bias_tensors
()
quantizers
=
(
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
)
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
if
not
debug
else
self
.
_get_debug_quantizers
(
fp8_output
,
fp8_grad
)
else
self
.
_get_debug_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
)
if
debug
:
if
self
.
no_debug_features_active
(
quantizers
):
debug
=
False
quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
)
quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
(
input_quantizer
,
...
...
@@ -1595,18 +1560,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
grad_output_quantizer
,
)
=
quantizers
if
torch
.
is_grad_enabled
()
:
if
is_grad_enabled
:
fwd_fn
=
_LayerNormLinear
.
apply
a
rgs
=
[]
a
utograd_ctx
=
[]
else
:
fwd_fn
=
_LayerNormLinear
.
forward
args
=
[
None
]
args
+=
(
inp
,
self
.
layer_norm_weight
,
self
.
layer_norm_bias
,
weight_tensor
,
bias_tensor
if
self
.
apply_bias
and
not
self
.
gemm_bias_unfused_add
else
None
,
autograd_ctx
=
[
None
]
non_tensor_args
=
(
self
.
eps
,
is_first_microbatch
,
self
.
fp8
,
...
...
@@ -1628,8 +1588,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self
.
parallel_mode
,
self
.
return_layernorm_output
,
self
.
return_layernorm_output_gathered
,
torch
.
is_grad_enabled
()
,
self
.
fwd_ln_sm_margin
if
torch
.
is_grad_enabled
()
else
self
.
inf_ln_sm_margin
,
is_grad_enabled
,
self
.
fwd_ln_sm_margin
if
is_grad_enabled
else
self
.
inf_ln_sm_margin
,
self
.
bwd_ln_sm_margin
,
self
.
zero_centered_gamma
,
self
.
normalization
,
...
...
@@ -1647,7 +1607,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
self
.
symmetric_ar_type
,
debug
,
)
out
=
fwd_fn
(
*
args
)
out
=
fwd_fn
(
*
autograd_ctx
,
inp
,
self
.
layer_norm_weight
,
self
.
layer_norm_bias
,
weight_tensor
,
bias_tensor
if
self
.
apply_bias
and
not
self
.
gemm_bias_unfused_add
else
None
,
non_tensor_args
,
)
if
self
.
return_layernorm_output
:
out
,
ln_out
=
out
...
...
@@ -1663,7 +1631,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
return
out
,
ln_out
return
out
def
_get_quantizers
(
self
,
fp8_output
,
fp8_grad
):
def
_get_quantizers
(
self
,
fp8_output
,
fp8_grad
,
is_grad_enabled
):
if
not
self
.
fp8
:
return
[
None
]
*
6
grad_input_quantizer
=
None
...
...
@@ -1675,7 +1643,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
(
weight_quantizer
,)
=
self
.
_get_weight_quantizers
()
if
fp8_output
:
output_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_OUTPUT
]
if
torch
.
is_grad_enabled
()
:
if
is_grad_enabled
:
grad_output_quantizer
=
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
]
grad_output_quantizer
.
internal
=
True
if
fp8_grad
:
...
...
@@ -1690,8 +1658,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
grad_output_quantizer
,
)
def
_get_debug_quantizers
(
self
,
fp8_output
,
fp8_grad
):
original_quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
)
def
_get_debug_quantizers
(
self
,
fp8_output
,
fp8_grad
,
is_grad_enabled
):
original_quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
assert
TEDebugState
.
debug_enabled
from
...debug.pytorch.debug_quantization
import
DebugQuantizer
...
...
@@ -1716,6 +1684,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self
,
inp
:
torch
.
Tensor
,
fp8_output
:
bool
,
is_grad_enabled
:
bool
,
)
->
torch
.
Tensor
:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
...
...
@@ -1731,7 +1700,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_quantizer
,
output_quantizer
,
*
_
,
)
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
=
False
)
)
=
self
.
_get_quantizers
(
fp8_output
,
False
,
is_grad_enabled
)
inp_dtype
=
inp
.
dtype
weight_tensor
,
bias_tensor
=
self
.
_get_weight_and_bias_tensors
()
...
...
transformer_engine/pytorch/module/layernorm_mlp.py
View file @
970620a5
...
...
@@ -17,11 +17,10 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import
transformer_engine_torch
as
tex
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
.torch_version
import
torch_version
from
transformer_engine.pytorch.tensor.utils
import
is_custom
from
.base
import
(
fill_userbuffers_buffer_for_all_gather
,
get_workspace
,
_ub_communicators
,
get_ub
,
TransformerEngineBaseModule
,
...
...
@@ -46,6 +45,7 @@ from ..utils import (
clear_tensor_data
,
requires_grad
,
needs_quantized_gemm
,
get_nvtx_range_context
,
)
from
..distributed
import
(
set_tensor_model_parallel_attributes
,
...
...
@@ -181,55 +181,61 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias
:
torch
.
Tensor
,
fc2_weight
:
torch
.
Tensor
,
fc2_bias
:
torch
.
Tensor
,
eps
:
float
,
is_first_microbatch
:
Union
[
bool
,
None
],
fp8
:
bool
,
fp8_calibration
:
bool
,
wgrad_store
:
WeightGradStore
,
fuse_wgrad_accumulation
:
bool
,
fc1_input_quantizer
:
Optional
[
Quantizer
],
fc1_weight_quantizer
:
Optional
[
Quantizer
],
fc1_output_quantizer
:
Optional
[
Quantizer
],
fc1_grad_input_quantizer
:
Optional
[
Quantizer
],
fc1_grad_weight_quantizer
:
Optional
[
Quantizer
],
fc1_grad_output_quantizer
:
Optional
[
Quantizer
],
fc2_input_quantizer
:
Optional
[
Quantizer
],
fc2_weight_quantizer
:
Optional
[
Quantizer
],
fc2_output_quantizer
:
Optional
[
Quantizer
],
fc2_grad_input_quantizer
:
Optional
[
Quantizer
],
fc2_grad_weight_quantizer
:
Optional
[
Quantizer
],
fc2_grad_output_quantizer
:
Optional
[
Quantizer
],
cpu_offloading
:
bool
,
tp_group
:
Union
[
dist_group_type
,
None
],
tp_size
:
int
,
sequence_parallel
:
bool
,
tensor_parallel
:
bool
,
activation_dtype
:
torch
.
dtype
,
return_layernorm_output
:
bool
,
return_layernorm_output_gathered
:
bool
,
bias_gelu_fusion
:
bool
,
set_parallel_mode
:
bool
,
is_grad_enabled
:
bool
,
fwd_ln_sm_margin
:
int
,
bwd_ln_sm_margin
:
int
,
zero_centered_gamma
:
bool
,
activation
:
str
,
activation_params
:
Optional
[
dict
],
normalization
:
str
,
ub_overlap_ag
:
bool
,
ub_overlap_rs
:
bool
,
ub_overlap_rs_dgrad
:
bool
,
ub_bulk_wgrad
:
bool
,
ub_bulk_dgrad
:
bool
,
gemm_gelu_fusion
:
bool
,
fsdp_group
:
Union
[
dist_group_type
,
None
],
module
:
torch
.
nn
.
Module
,
skip_fp8_weight_update
:
bool
,
symmetric_ar_type
:
str
,
debug
:
Optional
[
bool
]
=
False
,
non_tensor_args
:
Tuple
,
)
->
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
# pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
eps
,
is_first_microbatch
,
fp8
,
fp8_calibration
,
wgrad_store
,
fuse_wgrad_accumulation
,
fc1_input_quantizer
,
fc1_weight_quantizer
,
fc1_output_quantizer
,
fc1_grad_input_quantizer
,
fc1_grad_weight_quantizer
,
fc1_grad_output_quantizer
,
fc2_input_quantizer
,
fc2_weight_quantizer
,
fc2_output_quantizer
,
fc2_grad_input_quantizer
,
fc2_grad_weight_quantizer
,
fc2_grad_output_quantizer
,
cpu_offloading
,
tp_group
,
tp_size
,
sequence_parallel
,
tensor_parallel
,
activation_dtype
,
return_layernorm_output
,
return_layernorm_output_gathered
,
bias_gelu_fusion
,
set_parallel_mode
,
is_grad_enabled
,
fwd_ln_sm_margin
,
bwd_ln_sm_margin
,
zero_centered_gamma
,
activation
,
activation_params
,
normalization
,
ub_overlap_ag
,
ub_overlap_rs
,
ub_overlap_rs_dgrad
,
ub_bulk_wgrad
,
ub_bulk_dgrad
,
gemm_gelu_fusion
,
fsdp_group
,
module
,
skip_fp8_weight_update
,
symmetric_ar_type
,
debug
,
)
=
non_tensor_args
# Make sure input dimensions are compatible
in_features
,
inp_shape
=
ln_weight
.
numel
(),
inp
.
shape
assert
inp_shape
[
-
1
]
==
in_features
,
"GEMM not possible"
...
...
@@ -440,7 +446,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_outputs
=
general_gemm
(
fc1_weight_final
,
ln_out_total
,
get_workspace
(),
quantization_params
=
(
fc2_input_quantizer
if
gemm_gelu_fusion
...
...
@@ -524,7 +529,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
fc2_weight_final
,
act_out
,
get_workspace
(),
out_dtype
=
activation_dtype
,
bias
=
fc2_bias
,
quantization_params
=
fc2_output_quantizer
,
...
...
@@ -711,7 +715,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx
,
*
grad_outputs
:
Tuple
[
torch
.
Tensor
,
...]
)
->
Tuple
[
Union
[
torch
.
Tensor
,
None
],
...]:
# pylint: disable=missing-function-docstring
with
torch
.
cuda
.
nvtx
.
range
(
"_LayerNormMLP_backward"
):
with
get_
nvtx
_
range
_context
(
"_LayerNormMLP_backward"
):
saved_tensors
=
ctx
.
saved_tensors
(
# pylint: disable=unbalanced-tuple-unpacking
inputmat
,
...
...
@@ -881,7 +885,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_output
,
*
_
=
general_gemm
(
fc2_weight
,
grad_output
,
get_workspace
(),
layout
=
"NN"
,
grad
=
True
,
quantization_params
=
(
...
...
@@ -975,7 +978,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure
fc2_wgrad_gemm_kwargs
=
{
"workspace"
:
get_workspace
(),
"out_dtype"
:
(
origin_fc2_weight
.
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
...
...
@@ -1153,7 +1155,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
fc1_weight
,
dact
,
get_workspace
(),
out
=
gemm_out
,
out_dtype
=
ctx
.
activation_dtype
,
quantization_params
=
ctx
.
fc1_grad_input_quantizer
,
...
...
@@ -1232,7 +1233,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure
fc1_wgrad_gemm_kwargs
=
{
"workspace"
:
get_workspace
(),
"out_dtype"
:
(
origin_fc1_weight
.
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
...
...
@@ -1427,52 +1427,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias_grad
if
fc1_bias
is
not
None
else
None
,
fc2_wgrad
,
# pylint: disable=possibly-used-before-assignment
fc2_bias_grad
,
None
,
# eps
None
,
# is_first_microbatch
None
,
# fp8
None
,
# fp8_calibration
None
,
# wgrad_store
None
,
# fuse_wgrad_accumulation
None
,
# fc1_input_quantizer,
None
,
# fc1_weight_quantizer,
None
,
# fc1_output_quantizer,
None
,
# fc1_grad_input_quantizer,
None
,
# fc1_grad_weight_quantizer,
None
,
# fc1_grad_output_quantizer,
None
,
# fc2_input_quantizer,
None
,
# fc2_weight_quantizer,
None
,
# fc2_output_quantizer,
None
,
# fc2_grad_input_quantizer,
None
,
# fc2_grad_weight_quantizer,
None
,
# fc2_grad_output_quantizer,
None
,
# cpu_offloading
None
,
# tp_group
None
,
# tp_size
None
,
# sequence_parallel
None
,
# tensor_parallel
None
,
# activation_dtype
None
,
# return_layernorm_output
None
,
# return_layernorm_output_gathered
None
,
# bias_gelu_fusion
None
,
# set_parallel_mode
None
,
# is_grad_enabled
None
,
# fwd_ln_sm_margin
None
,
# bwd_ln_sm_margin
None
,
# zero_centered_gamma
None
,
# activation
None
,
# activation_params
None
,
# normalization
None
,
# ub_overlap_ag
None
,
# ub_overlap_rs
None
,
# ub_overlap_rs_dgrad
None
,
# ub_bulk_dgrad
None
,
# ub_bulk_wgrad
None
,
# gemm_gelu_fusion
None
,
# fsdp_group
None
,
# module
None
,
# skip_fp8_weight_update
None
,
# symmetric_ar_type
None
,
# debug
None
,
)
...
...
@@ -1489,38 +1444,38 @@ class LayerNormMLP(TransformerEngineBaseModule):
intermediate size to which input samples are projected.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
bias : bool, default =
`
True
`
if set to `False`, the FC1 and FC2 layers will not learn an additive bias.
bias : bool, default = True
if set to
`
`False`
`
, the FC1 and FC2 layers will not learn an additive bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
activation : str, default = 'gelu'
activation function used.
Options: 'gelu'
,
'geglu'
,
'qgelu'
,
'qgeglu'
, 'relu',
'reglu'
,
'srelu'
,
'sreglu',
'silu',
'swiglu', and 'clamped_swiglu'.
activation_params : dict, default =
`
None
`
Options:
``
'gelu'
``, ``
'geglu'
``, ``
'qgelu'
``, ``
'qgeglu'
``, ``'relu'``, ``
'reglu'
``, ``
'srelu'
``, ``
'sreglu'
``
,
``'silu'``, ``
'swiglu'
``
, and
``
'clamped_swiglu'
``
.
activation_params : dict, default = None
Additional parameters for the activation function.
At the moment, only used for 'clamped_swiglu' activation which
supports 'limit' and 'alpha' parameters.
init_method : Callable, default =
`
None
`
used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
output_layer_init_method : Callable, default =
`
None
`
At the moment, only used for
``
'clamped_swiglu'
``
activation which
supports
``
'limit'
``
and
``
'alpha'
``
parameters.
init_method : Callable, default = None
used for initializing FC1 weights in the following way:
`
`init_method(weight)`
`
.
When set to
`
`None`
`
, defaults to
`
`torch.nn.init.normal_(mean=0.0, std=0.023)`
`
.
output_layer_init_method : Callable, default = None
used for initializing FC2 weights in the following way:
`output_layer_init_method(weight)`. When set to `None`, defaults to
`torch.nn.init.normal_(mean=0.0, std=0.023)`.
return_layernorm_output : bool, default =
`
False
`
if set to `True`, output of layernorm is returned from the
forwar
d
`
`output_layer_init_method(weight)`
`
. When set to
`
`None`
`
, defaults to
`
`torch.nn.init.normal_(mean=0.0, std=0.023)`
`
.
return_layernorm_output : bool, default = False
if set to
`
`True`
`
, output of layernorm is returned from the
:meth:`forward` metho
d
together with the output of the linear transformation.
Example use case: residual connection for transformer module
is taken post layernorm.
return_layernorm_output_gathered : bool, default =
`
False
`
if set to `True`, output of layernorm is returned after the all
gather operation. Ignored if return_layernorm_output is False.
return_layernorm_output_gathered : bool, default = False
if set to
`
`True`
`
, output of layernorm is returned after the all
gather operation. Ignored if
``
return_layernorm_output
``
is False.
Example use case: with sequence parallel, input to residual connection
for transformer module (e.g. LoRA) will need to be gathered.
Returning layernorm output gathered will prevent a redundant gather.
zero_centered_gamma : bool, default =
'
False
'
if set to
'
True
'
, gamma parameter in LayerNorm is initialized to 0 and
zero_centered_gamma : bool, default = False
if set to
``
True
``
, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
...
...
@@ -1530,61 +1485,65 @@ class LayerNormMLP(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default =
`
None
`
name
: str, default = None
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
set_parallel_mode : bool, default =
`
False
`
if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row
set_parallel_mode : bool, default = False
if set to
`
`True`
`
, FC1 is used as Column Parallel and FC2 is used as Row
Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default =
`
False
`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default =
`
None
`
sequence_parallel : bool, default = False
if set to
`
`True`
`
, uses sequence parallelism.
tp_group : ProcessGroup, default = None
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
`
`set_tensor_parallel_group(tp_group)`
`
method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default =
'
False
'
if set to `True`, enables fusing of creation and accumulation of
fuse_wgrad_accumulation : bool, default = False
if set to
`
`True`
`
, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
have an additional
`
`main_grad`
`
attribute (used instead of the
regular
`
`grad`
`
) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default =
`
False
`
when set to `True`, this module will not apply the additive bias for FC2, but
weight tensor having attribute
``
'overwrite_main_grad'
``
set to True
will overwrite
`
`main_grad`
`
instead of accumulating.
return_bias : bool, default = False
when set to
`
`True`
`
, this module will not apply the additive bias for FC2, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default =
`
torch.get_default_dtype()
`
params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
seq_length: int
seq_length
: int
sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
functions are warmed up before training to ensure same kernels are used for forward
propagation and activation recompute phase.
micro_batch_size: int
micro_batch_size
: int
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase.
delay_wgrad_compute : bool, default =
`
False
`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call
`module.
backward_dw` to compute
delay_wgrad_compute : bool, default = False
Whether or not to delay weight gradient computation. If set to
`
`True`
`
,
it's the user's responsibility to call
:meth:`
backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
Requires PyTorch version 2.7.0 or higher. When set to
``
None
``
, standard all-reduce
is used.
checkpoint : bool, default = False
whether to use selective activation checkpointing, where activations are not saved for bwd,
and instead are recomputed (skipping fc2, as it is not needed for backward). Trades compute
for memory. The default is False, in which case activations are saved in the forward pass. Not supported for ONNX forward.
"""
def
__init__
(
...
...
@@ -1855,8 +1814,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
is_grad_enabled
=
torch
.
is_grad_enabled
()
if
is_in_onnx_export_mode
():
return
self
.
onnx_forward
(
inp
)
return
self
.
onnx_forward
(
inp
,
is_grad_enabled
)
debug
=
self
.
is_debug_iter
()
...
...
@@ -1872,19 +1833,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
if
get_ub
(
"fc2_fprop"
,
FP8GlobalStateManager
.
is_fp8_enabled
()).
is_fp8_ubuf
():
fp8_output
=
True
with
torch
.
cuda
.
device
(
getattr
(
self
,
list
(
self
.
named_parameters
())[
0
][
0
]).
device
),
self
.
prepare_forward
(
inp
,
num_gemms
=
2
)
as
inp
:
with
self
.
prepare_forward
(
inp
,
num_gemms
=
2
)
as
inp
:
quantizers
=
(
self
.
_get_quantizers
(
fp8_output
)
self
.
_get_quantizers
(
fp8_output
,
is_grad_enabled
)
if
not
debug
else
self
.
_get_debug_quantizers
(
fp8_output
)
else
self
.
_get_debug_quantizers
(
fp8_output
,
is_grad_enabled
)
)
if
debug
:
if
self
.
no_debug_features_active
(
quantizers
):
debug
=
False
quantizers
=
self
.
_get_quantizers
(
fp8_output
)
quantizers
=
self
.
_get_quantizers
(
fp8_output
,
is_grad_enabled
)
# Get quantizers
(
...
...
@@ -1917,20 +1876,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
and
self
.
bias_gelu_nvfusion
and
not
use_reentrant_activation_recompute
()
):
self
.
bias_gelu_nvfusion
=
False
if
torch
.
is_grad_enabled
()
:
if
is_grad_enabled
:
fwd_fn
=
_LayerNormMLP
.
apply
a
rgs
=
[]
a
utograd_ctx
=
[]
else
:
fwd_fn
=
_LayerNormMLP
.
forward
args
=
[
None
]
args
+=
(
inp
,
self
.
layer_norm_weight
,
self
.
layer_norm_bias
,
fc1_weight
,
fc1_bias
,
fc2_weight
,
fc2_bias
if
self
.
apply_bias
and
not
self
.
gemm_bias_unfused_add
else
None
,
autograd_ctx
=
[
None
]
non_tensor_args
=
(
self
.
eps
,
is_first_microbatch
,
self
.
fp8
,
...
...
@@ -1959,8 +1912,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self
.
return_layernorm_output_gathered
,
self
.
bias_gelu_nvfusion
and
not
self
.
fp8
and
not
debug
,
self
.
set_parallel_mode
,
torch
.
is_grad_enabled
()
,
self
.
fwd_ln_sm_margin
if
torch
.
is_grad_enabled
()
else
self
.
inf_ln_sm_margin
,
is_grad_enabled
,
self
.
fwd_ln_sm_margin
if
is_grad_enabled
else
self
.
inf_ln_sm_margin
,
self
.
bwd_ln_sm_margin
,
self
.
zero_centered_gamma
,
self
.
activation
,
...
...
@@ -1978,7 +1931,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
self
.
symmetric_ar_type
,
debug
,
)
out
=
fwd_fn
(
*
args
)
out
=
fwd_fn
(
*
autograd_ctx
,
inp
,
self
.
layer_norm_weight
,
self
.
layer_norm_bias
,
fc1_weight
,
fc1_bias
,
fc2_weight
,
fc2_bias
if
self
.
apply_bias
and
not
self
.
gemm_bias_unfused_add
else
None
,
non_tensor_args
,
)
if
self
.
return_layernorm_output
:
out
,
ln_out
=
out
...
...
@@ -1994,7 +1957,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
return
out
,
ln_out
return
out
def
_get_quantizers
(
self
,
fp8_output
):
def
_get_quantizers
(
self
,
fp8_output
,
is_grad_enabled
):
(
fc1_input_quantizer
,
fc1_output_quantizer
,
...
...
@@ -2024,7 +1987,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_output_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM2_OUTPUT
]
if
torch
.
is_grad_enabled
()
:
if
is_grad_enabled
:
fc2_grad_output_quantizer
=
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT2
]
...
...
@@ -2049,9 +2012,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_grad_output_quantizer
,
)
def
onnx_forward
(
self
,
inp
:
torch
.
Tensor
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
...]]:
def
onnx_forward
(
self
,
inp
:
torch
.
Tensor
,
is_grad_enabled
:
bool
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
...]]:
"""
ONNX-compatible version of the
forward function
that provides numerical equivalence
ONNX-compatible version of the
:meth:`forward` method
that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios.
"""
...
...
@@ -2066,7 +2031,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_quantizer
,
output_quantizer
,
*
_
,
)
=
self
.
_get_quantizers
(
False
)
)
=
self
.
_get_quantizers
(
False
,
is_grad_enabled
)
inp_dtype
=
inp
.
dtype
fc1_weight
,
fc2_weight
=
self
.
_get_weight_tensors
()
...
...
@@ -2151,10 +2116,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
return
fc2_out
,
fc2_bias
.
to
(
inp_dtype
)
return
fc2_out
def
_get_debug_quantizers
(
self
,
fp8_output
):
def
_get_debug_quantizers
(
self
,
fp8_output
,
is_grad_enabled
):
from
...debug.pytorch.debug_quantization
import
DebugQuantizer
base_quantizers
=
list
(
self
.
_get_quantizers
(
fp8_output
))
base_quantizers
=
list
(
self
.
_get_quantizers
(
fp8_output
,
is_grad_enabled
))
assert
TEDebugState
.
debug_enabled
def
make_debug
(
prefix
,
offset
):
...
...
@@ -2297,7 +2262,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
"""
if
not
self
.
need_backward_dw
():
return
with
torch
.
cuda
.
nvtx
.
range
(
"_LayerNormMLP_wgrad"
):
with
get_
nvtx
_
range
_context
(
"_LayerNormMLP_wgrad"
):
(
fc2_wgrad
,
fc2_bias_grad_
,
*
_
),
tensor_list_fc2
=
self
.
wgrad_store
.
pop
()
if
self
.
use_bias
and
self
.
fc1_bias
.
grad
is
None
:
(
fc1_wgrad
,
fc1_bias_grad
,
*
_
),
_
=
self
.
wgrad_store
.
pop
()
...
...
transformer_engine/pytorch/module/linear.py
View file @
970620a5
...
...
@@ -14,13 +14,12 @@ import torch
import
transformer_engine_torch
as
tex
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
.torch_version
import
torch_version
from
.base
import
(
fill_userbuffers_buffer_for_all_gather
,
get_dummy_wgrad
,
get_ub
,
get_workspace
,
TransformerEngineBaseModule
,
_2X_ACC_FPROP
,
_2X_ACC_DGRAD
,
...
...
@@ -39,6 +38,7 @@ from ..utils import (
assert_dim_for_all_gather
,
nvtx_range_pop
,
nvtx_range_push
,
get_nvtx_range_context
,
get_activation_offloading
,
)
from
..distributed
import
(
...
...
@@ -92,43 +92,47 @@ class _Linear(torch.autograd.Function):
weight
:
torch
.
Tensor
,
inp
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
is_first_microbatch
:
Union
[
bool
,
None
],
fp8
:
bool
,
fp8_calibration
:
bool
,
wgrad_store
:
WeightGradStore
,
input_quantizer
:
Optional
[
Quantizer
],
weight_quantizer
:
Optional
[
Quantizer
],
output_quantizer
:
Optional
[
Quantizer
],
grad_input_quantizer
:
Optional
[
Quantizer
],
grad_weight_quantizer
:
Optional
[
Quantizer
],
grad_output_quantizer
:
Optional
[
Quantizer
],
fuse_wgrad_accumulation
:
bool
,
cpu_offloading
:
bool
,
tp_group
:
Union
[
dist_group_type
,
None
],
tp_size
:
int
,
sequence_parallel
:
bool
,
tensor_parallel
:
bool
,
activation_dtype
:
torch
.
dtype
,
parallel_mode
:
Union
[
str
,
None
],
is_grad_enabled
:
bool
,
ub_overlap_rs_fprop
:
bool
,
ub_overlap_ag_dgrad
:
bool
,
ub_overlap_ag_fprop
:
bool
,
ub_overlap_rs_dgrad
:
bool
,
ub_bulk_dgrad
:
bool
,
ub_bulk_wgrad
:
bool
,
ub_name
:
str
,
fine_grained_activation_offloading
:
bool
,
fp8_output
:
bool
,
# pylint: disable=unused-argument
fsdp_group
:
Union
[
dist_group_type
,
None
],
module
:
torch
.
nn
.
Module
,
skip_fp8_weight_update
:
bool
,
symmetric_ar_type
:
str
,
save_original_input
:
bool
=
False
,
debug
:
Optional
[
bool
]
=
False
,
non_tensor_args
:
Tuple
,
)
->
torch
.
Tensor
:
# pylint: disable=missing-function-docstring
(
is_first_microbatch
,
fp8
,
fp8_calibration
,
wgrad_store
,
input_quantizer
,
weight_quantizer
,
output_quantizer
,
grad_input_quantizer
,
grad_weight_quantizer
,
grad_output_quantizer
,
fuse_wgrad_accumulation
,
cpu_offloading
,
tp_group
,
tp_size
,
sequence_parallel
,
tensor_parallel
,
activation_dtype
,
parallel_mode
,
is_grad_enabled
,
ub_overlap_rs_fprop
,
ub_overlap_ag_dgrad
,
ub_overlap_ag_fprop
,
ub_overlap_rs_dgrad
,
ub_bulk_dgrad
,
ub_bulk_wgrad
,
ub_name
,
fine_grained_activation_offloading
,
fp8_output
,
# pylint: disable=unused-variable
fsdp_group
,
module
,
skip_fp8_weight_update
,
symmetric_ar_type
,
save_original_input
,
debug
,
)
=
non_tensor_args
# NVTX label for profiling
nvtx_label
=
"transformer_engine._Linear.forward"
if
ub_name
is
not
None
:
...
...
@@ -323,7 +327,6 @@ class _Linear(torch.autograd.Function):
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
weightmat
,
inputmat_total
,
get_workspace
(),
quantization_params
=
output_quantizer
,
out_dtype
=
activation_dtype
,
bias
=
bias
,
...
...
@@ -520,7 +523,7 @@ class _Linear(torch.autograd.Function):
if
ctx
.
ub_name
is
not
None
:
nvtx_label
=
f
"
{
nvtx_label
}
.
{
ctx
.
ub_name
}
"
with
torch
.
cuda
.
nvtx
.
range
(
"_Linear_backward"
):
with
get_
nvtx
_
range
_context
(
"_Linear_backward"
):
saved_tensors
=
ctx
.
saved_tensors
inputmat
,
weight_fp8
,
weight
,
bias
=
(
# pylint: disable=unbalanced-tuple-unpacking
restore_from_saved
(
ctx
.
tensor_objects
,
saved_tensors
)
...
...
@@ -744,7 +747,6 @@ class _Linear(torch.autograd.Function):
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
weight_fp8
,
grad_output
,
get_workspace
(),
layout
=
"NN"
,
grad
=
True
,
quantization_params
=
ctx
.
grad_input_quantizer
,
...
...
@@ -870,7 +872,6 @@ class _Linear(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs
=
{
"workspace"
:
get_workspace
(),
"out_dtype"
:
(
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
...
...
@@ -1005,47 +1006,14 @@ class _Linear(torch.autograd.Function):
wgrad
,
dgrad
.
view
(
ctx
.
inp_shape
)
if
ctx
.
requires_dgrad
else
None
,
grad_bias
,
None
,
# is_first_microbatch
None
,
# fp8
None
,
# fp8_calibration
None
,
# wgrad_store
None
,
# input_quantizer
None
,
# weight_quantizer
None
,
# output_quantizer
None
,
# grad_input_quantizer
None
,
# grad_weight_quantizer
None
,
# grad_output_quantizer
None
,
# fuse_wgrad_accumulation
None
,
# cpu_offloading
None
,
# tp_group
None
,
# tp_size
None
,
# sequence_parallel
None
,
# tensor_parallel
None
,
# activation_dtype
None
,
# parallel_mode
None
,
# is_grad_enabled
None
,
# ub_overlap_rs_fprop
None
,
# ub_overlap_ag_dgrad
None
,
# ub_overlap_ag_fprop
None
,
# ub_overlap_rs_dgrad
None
,
# ub_bulk_dgrad
None
,
# ub_bulk_wgrad
None
,
# ub_name
None
,
# fine_grained_activation_offloading
None
,
# fp8_output
None
,
# fsdp_group
None
,
# module
None
,
# skip_fp8_weight_update
None
,
# symmetric_ar_type
None
,
# save_original_input
None
,
# debug
None
,
)
class
Linear
(
TransformerEngineBaseModule
):
"""Applies a linear transformation to the incoming data :math:`y = xA^T + b`
On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
On NVIDIA GPUs it is a drop-in replacement for
`
`torch.nn.Linear`
`
.
Parameters
----------
...
...
@@ -1053,14 +1021,14 @@ class Linear(TransformerEngineBaseModule):
size of each input sample.
out_features : int
size of each output sample.
bias : bool, default =
`
True
`
if set to `False`, the layer will not learn an additive bias.
init_method : Callable, default =
`
None
`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default =
`
None
`
bias : bool, default = True
if set to
`
`False`
`
, the layer will not learn an additive bias.
init_method : Callable, default = None
used for initializing weights in the following way:
`
`init_method(weight)`
`
.
When set to
`
`None`
`
, defaults to
`
`torch.nn.init.normal_(mean=0.0, std=0.023)`
`
.
get_rng_state_tracker : Callable, default = None
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default =
`
None
`
rng_tracker_name : str, default = None
the param passed to get_rng_state_tracker to get the specific rng tracker.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
Configuration for splitting the weight and bias tensors along dim 0 into
...
...
@@ -1068,62 +1036,62 @@ class Linear(TransformerEngineBaseModule):
they are used to make the names of equally-sized parameters. If a dict
(preferably an OrderedDict) is provided, the keys are used as names and
values as split sizes along dim 0. The resulting parameters will have
names that end in `_weight` or `_bias`, so trailing underscores are
names that end in
`
`_weight`
`
or
`
`_bias`
`
, so trailing underscores are
stripped from any provided names.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default =
`
None
`
name
: str, default = None
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
sequence_parallel : bool, default =
`
False
`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default =
`
None
`
sequence_parallel : bool, default = False
if set to
`
`True`
`
, uses sequence parallelism.
tp_group : ProcessGroup, default = None
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
`
`set_tensor_parallel_group(tp_group)`
`
method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'column', 'row'}, default =
`
None
`
parallel_mode : {None, 'column', 'row'}, default = None
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
When set to
`
`None`
`
, no communication is performed.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
if set to
`
`True`
`
, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
have an additional
`
`main_grad`
`
attribute (used instead of the
regular
`
`grad`
`
) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default =
`
False
`
when set to `True`, this module will not apply the additive bias itself, but
will overwrite
`
`main_grad`
`
instead of accumulating.
return_bias : bool, default = False
when set to
`
`True`
`
, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default =
`
torch.get_default_dtype()
`
params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default =
`
False
`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
delay_wgrad_compute : bool, default = False
Whether or not to delay weight gradient computation. If set to
`
`True`
`
,
it's the user's responsibility to call
`
`module.backward_dw`
`
to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
Requires PyTorch version 2.7.0 or higher. When set to
``
None
``
, standard all-reduce
is used.
save_original_input : bool, default =
`
False
`
If set to `True`, always saves the original input tensor rather than the
save_original_input : bool, default = False
If set to
`
`True`
`
, always saves the original input tensor rather than the
cast tensor. In some scenarios, the input tensor is used by multiple modules,
and saving the original input tensor may reduce the memory usage.
Cannot work with FP8 DelayedScaling recipe.
...
...
@@ -1434,8 +1402,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
is_grad_enabled
=
torch
.
is_grad_enabled
()
if
is_in_onnx_export_mode
():
return
self
.
onnx_forward
(
inp
,
fp8_output
)
return
self
.
onnx_forward
(
inp
,
fp8_output
,
is_grad_enabled
)
debug
=
self
.
is_debug_iter
()
...
...
@@ -1457,9 +1427,7 @@ class Linear(TransformerEngineBaseModule):
).
is_fp8_ubuf
():
fp8_grad
=
True
with
torch
.
cuda
.
device
(
getattr
(
self
,
list
(
self
.
named_parameters
())[
0
][
0
]).
device
),
self
.
prepare_forward
(
with
self
.
prepare_forward
(
inp
,
allow_non_contiguous
=
isinstance
(
inp
,
QuantizedTensor
),
)
as
inp
:
...
...
@@ -1467,14 +1435,14 @@ class Linear(TransformerEngineBaseModule):
weight_tensor
,
bias_tensor
=
self
.
_get_weight_and_bias_tensors
()
quantizers
=
(
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
)
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
if
not
debug
else
self
.
_get_debug_quantizers
(
fp8_output
,
fp8_grad
)
else
self
.
_get_debug_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
)
if
debug
:
if
self
.
no_debug_features_active
(
quantizers
):
debug
=
False
quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
)
quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
(
input_quantizer
,
...
...
@@ -1485,16 +1453,14 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer
,
)
=
quantizers
if
torch
.
is_grad_enabled
()
:
if
is_grad_enabled
:
linear_fn
=
_Linear
.
apply
a
rgs
=
[]
a
utograd_ctx
=
[]
else
:
linear_fn
=
_Linear
.
forward
args
=
[
None
]
args
+=
(
weight_tensor
,
inp
,
bias_tensor
if
(
self
.
apply_bias
and
not
self
.
gemm_bias_unfused_add
)
else
None
,
autograd_ctx
=
[
None
]
non_tensor_args
=
(
is_first_microbatch
,
self
.
fp8
,
self
.
fp8_calibration
,
...
...
@@ -1513,7 +1479,7 @@ class Linear(TransformerEngineBaseModule):
self
.
tp_size
>
1
,
self
.
activation_dtype
,
self
.
parallel_mode
,
torch
.
is_grad_enabled
()
,
is_grad_enabled
,
self
.
ub_overlap_rs_fprop
,
self
.
ub_overlap_ag_dgrad
,
self
.
ub_overlap_ag_fprop
,
...
...
@@ -1530,7 +1496,13 @@ class Linear(TransformerEngineBaseModule):
self
.
save_original_input
,
debug
,
)
out
=
linear_fn
(
*
args
)
out
=
linear_fn
(
*
autograd_ctx
,
weight_tensor
,
inp
,
bias_tensor
if
(
self
.
apply_bias
and
not
self
.
gemm_bias_unfused_add
)
else
None
,
non_tensor_args
,
)
if
self
.
gemm_bias_unfused_add
:
out
=
out
+
cast_if_needed
(
bias_tensor
,
self
.
activation_dtype
)
...
...
@@ -1538,7 +1510,7 @@ class Linear(TransformerEngineBaseModule):
return
out
,
cast_if_needed
(
bias_tensor
,
self
.
activation_dtype
)
return
out
def
_get_quantizers
(
self
,
fp8_output
,
fp8_grad
):
def
_get_quantizers
(
self
,
fp8_output
,
fp8_grad
,
is_grad_enabled
):
if
not
self
.
fp8
:
return
[
None
]
*
6
grad_input_quantizer
=
None
...
...
@@ -1550,7 +1522,7 @@ class Linear(TransformerEngineBaseModule):
(
weight_quantizer
,)
=
self
.
_get_weight_quantizers
()
if
fp8_output
:
output_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_OUTPUT
]
if
torch
.
is_grad_enabled
()
:
if
is_grad_enabled
:
grad_output_quantizer
=
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
]
grad_output_quantizer
.
internal
=
True
if
fp8_grad
:
...
...
@@ -1564,8 +1536,8 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer
,
)
def
_get_debug_quantizers
(
self
,
fp8_output
,
fp8_grad
):
original_quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
)
def
_get_debug_quantizers
(
self
,
fp8_output
,
fp8_grad
,
is_grad_enabled
):
original_quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
assert
TEDebugState
.
debug_enabled
from
...debug.pytorch.debug_quantization
import
DebugQuantizer
...
...
@@ -1620,6 +1592,7 @@ class Linear(TransformerEngineBaseModule):
self
,
inp
:
torch
.
Tensor
,
fp8_output
:
bool
,
is_grad_enabled
:
bool
,
)
->
torch
.
Tensor
:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
...
...
@@ -1636,7 +1609,7 @@ class Linear(TransformerEngineBaseModule):
weight_quantizer
,
output_quantizer
,
*
_
,
)
=
self
.
_get_quantizers
(
fp8_output
,
False
)
)
=
self
.
_get_quantizers
(
fp8_output
,
False
,
is_grad_enabled
)
inp_dtype
=
inp
.
dtype
if
input_quantizer
is
not
None
:
...
...
transformer_engine/pytorch/module/rmsnorm.py
View file @
970620a5
...
...
@@ -33,32 +33,29 @@ class RMSNorm(_RMSNormOp):
Parameters
----------
normalized_shape: int or iterable of int
normalized_shape
: int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
A value added to the denominator for numerical stability
device: torch.device, default = default CUDA device
device
: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
dtype
: torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default =
'
False
'
If `True`, the :math:`\gamma` parameter is initialized to zero
zero_centered_gamma : bool, default = False
If
`
`True`
`
, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to
.. math::
y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
sm_margin: int, default = 0
sm_margin
: int, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
Legacy
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
margin at each compute stage (``"forward"``, ``"backward"``,
``"inference"``).
sequence_parallel : bool
**Legacy parameter.** Set a bool attr named ``sequence_parallel`` in the parameters.
This is custom logic for Megatron-LM integration.
"""
...
...
transformer_engine/pytorch/onnx_extensions.py
View file @
970620a5
...
...
@@ -356,7 +356,9 @@ def onnx_layernorm(
)
if
normalization
==
"RMSNorm"
:
ln_out
=
torch
.
nn
.
functional
.
rms_norm
(
inp
,
inp
.
shape
[
-
1
:],
ln_weight
,
eps
)
variance
=
inp
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
ln_out
=
inp
*
torch
.
rsqrt
(
variance
+
eps
)
ln_out
=
ln_out
*
ln_weight
else
:
ln_out
=
torch
.
nn
.
functional
.
layer_norm
(
inp
,
inp
.
shape
[
-
1
:],
ln_weight
,
layer_norm_bias
,
eps
...
...
transformer_engine/pytorch/ops/_common.py
View file @
970620a5
...
...
@@ -10,7 +10,7 @@ from typing import Optional
import
torch
from
transformer_engine_torch
import
FP8TensorMeta
from
..
import
torch_version
from
..
torch_version
import
torch_version
from
..quantization
import
FP8GlobalStateManager
from
..tensor.float8_tensor
import
Float8Tensor
from
..quantized_tensor
import
QuantizedTensorStorage
...
...
transformer_engine/pytorch/ops/basic/activation.py
View file @
970620a5
...
...
@@ -53,7 +53,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
Parameters
----------
cache_quantized_input: bool, default = False
cache_quantized_input
: bool, default = False
Quantize input tensor when caching for use in the backward
pass. This will typically reduce memory usage but require
extra compute and increase numerical error. This feature is
...
...
@@ -408,11 +408,11 @@ class ClampedSwiGLU(_ActivationOperation):
Parameters
----------
limit: float
limit
: float
The clamp limit.
alpha: float
alpha
: float
The scaling factor for the sigmoid function used in the activation.
cache_quantized_input: bool, default = False
cache_quantized_input
: bool, default = False
Quantize input tensor when caching for use in the backward pass.
"""
...
...
transformer_engine/pytorch/ops/basic/all_gather.py
View file @
970620a5
...
...
@@ -23,7 +23,7 @@ class AllGather(BasicOperation):
Parameters
----------
process_group: torch.distributed.ProcessGroup, default = world group
process_group
: torch.distributed.ProcessGroup, default = world group
Process group for communication
"""
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment