OpenDAS / TransformerEngine / Commits

Commit 0d874a4e, authored Mar 03, 2026 by wenjh

    Merge branch 'nv_main' of v2.12

Parents: a68e5f87, dfdd3820

Showing 20 of the 640 changed files, with 177 additions and 231 deletions.
transformer_engine/pytorch/module/linear.py                +104  -161
transformer_engine/pytorch/module/rmsnorm.py                +11   -14
transformer_engine/pytorch/numerics_debug.py                +1    -1
transformer_engine/pytorch/onnx_extensions.py               +4    -2
transformer_engine/pytorch/ops/__init__.py                  +1    -1
transformer_engine/pytorch/ops/_common.py                   +2    -2
transformer_engine/pytorch/ops/basic/__init__.py            +1    -1
transformer_engine/pytorch/ops/basic/activation.py          +5    -5
transformer_engine/pytorch/ops/basic/add_extra_input.py     +1    -1
transformer_engine/pytorch/ops/basic/all_gather.py          +2    -2
transformer_engine/pytorch/ops/basic/all_reduce.py          +2    -2
transformer_engine/pytorch/ops/basic/basic_linear.py        +21   -17
transformer_engine/pytorch/ops/basic/bias.py                +6    -6
transformer_engine/pytorch/ops/basic/constant_scale.py      +1    -1
transformer_engine/pytorch/ops/basic/dropout.py             +1    -1
transformer_engine/pytorch/ops/basic/identity.py            +1    -1
transformer_engine/pytorch/ops/basic/l2normalization.py     +4    -4
transformer_engine/pytorch/ops/basic/layer_norm.py          +5    -5
transformer_engine/pytorch/ops/basic/make_extra_output.py   +1    -1
transformer_engine/pytorch/ops/basic/quantize.py            +3    -3
transformer_engine/pytorch/module/linear.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
@@ -14,13 +14,12 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
-from transformer_engine.pytorch import torch_version
+from transformer_engine.pytorch.torch_version import torch_version
 from .base import (
     fill_userbuffers_buffer_for_all_gather,
     get_dummy_wgrad,
     get_ub,
     get_workspace,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,

...
@@ -39,6 +38,7 @@ from ..utils import (
     assert_dim_for_all_gather,
     nvtx_range_pop,
     nvtx_range_push,
+    get_nvtx_range_context,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,

...
@@ -91,42 +91,46 @@ class _Linear(torch.autograd.Function):
         weight: torch.Tensor,
         inp: torch.Tensor,
         bias: Optional[torch.Tensor],
-        is_first_microbatch: Union[bool, None],
-        fp8: bool,
-        fp8_calibration: bool,
-        wgrad_store: WeightGradStore,
-        input_quantizer: Optional[Quantizer],
-        weight_quantizer: Optional[Quantizer],
-        output_quantizer: Optional[Quantizer],
-        grad_input_quantizer: Optional[Quantizer],
-        grad_weight_quantizer: Optional[Quantizer],
-        grad_output_quantizer: Optional[Quantizer],
-        fuse_wgrad_accumulation: bool,
-        cpu_offloading: bool,
-        tp_group: Union[dist_group_type, None],
-        tp_size: int,
-        sequence_parallel: bool,
-        tensor_parallel: bool,
-        activation_dtype: torch.dtype,
-        parallel_mode: Union[str, None],
-        is_grad_enabled: bool,
-        ub_overlap_rs_fprop: bool,
-        ub_overlap_ag_dgrad: bool,
-        ub_overlap_ag_fprop: bool,
-        ub_overlap_rs_dgrad: bool,
-        ub_bulk_dgrad: bool,
-        ub_bulk_wgrad: bool,
-        ub_name: str,
-        fp8_output: bool,
-        # pylint: disable=unused-argument
-        fsdp_group: Union[dist_group_type, None],
-        module: torch.nn.Module,
-        skip_fp8_weight_update: bool,
-        symmetric_ar_type: str,
-        save_original_input: bool = False,
-        debug: Optional[bool] = False,
+        non_tensor_args: Tuple,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
+        (
+            is_first_microbatch,
+            fp8,
+            fp8_calibration,
+            wgrad_store,
+            input_quantizer,
+            weight_quantizer,
+            output_quantizer,
+            grad_input_quantizer,
+            grad_weight_quantizer,
+            grad_output_quantizer,
+            fuse_wgrad_accumulation,
+            cpu_offloading,
+            tp_group,
+            tp_size,
+            sequence_parallel,
+            tensor_parallel,
+            activation_dtype,
+            parallel_mode,
+            is_grad_enabled,
+            ub_overlap_rs_fprop,
+            ub_overlap_ag_dgrad,
+            ub_overlap_ag_fprop,
+            ub_overlap_rs_dgrad,
+            ub_bulk_dgrad,
+            ub_bulk_wgrad,
+            ub_name,
+            fp8_output,
+            # pylint: disable=unused-variable
+            fsdp_group,
+            module,
+            skip_fp8_weight_update,
+            symmetric_ar_type,
+            save_original_input,
+            debug,
+        ) = non_tensor_args

         # NVTX label for profiling
         nvtx_label = "transformer_engine._Linear.forward"
         if ub_name is not None:

...
@@ -321,7 +325,6 @@ class _Linear(torch.autograd.Function):
             gemm_out, *_, reduce_scatter_out = general_gemm(
                 weightmat,
                 inputmat_total,
                 get_workspace(),
                 quantization_params=output_quantizer,
                 out_dtype=activation_dtype,
                 bias=bias,

...
@@ -426,7 +429,8 @@ class _Linear(torch.autograd.Function):
             # weights if weights are externally touched outside this module
             ctx.weight_object = weight

-            mark_not_offload(weight, weightmat, bias)
+            if cpu_offloading:
+                mark_not_offload(weight, weightmat, bias)

             # TODO(ksivamani): Check memory usage
             tensors_to_save, tensor_objects = prepare_for_saving(
                 saved_inputmat,

...
@@ -498,7 +502,7 @@ class _Linear(torch.autograd.Function):
         if ctx.ub_name is not None:
             nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
-        with torch.cuda.nvtx.range("_Linear_backward"):
+        with get_nvtx_range_context("_Linear_backward"):
             saved_tensors = ctx.saved_tensors
             inputmat, weight_fp8, weight, bias = (
                 # pylint: disable=unbalanced-tuple-unpacking
                 restore_from_saved(ctx.tensor_objects, saved_tensors)

...
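The backward path here swaps the raw torch.cuda.nvtx.range context manager for the get_nvtx_range_context helper added to ..utils in the earlier import hunk. Its implementation is not shown in this diff; a minimal sketch of what such a helper might look like (name reused for illustration only, gating logic is an assumption):

    import contextlib
    import torch

    @contextlib.contextmanager
    def get_nvtx_range_context(msg: str, enabled: bool = True):
        # Sketch only: push an NVTX range when profiling is enabled and act
        # as a no-op otherwise, so call sites need no conditional logic.
        if enabled:
            torch.cuda.nvtx.range_push(msg)
        try:
            yield
        finally:
            if enabled:
                torch.cuda.nvtx.range_pop()

...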
@@ -721,7 +725,6 @@ class _Linear(torch.autograd.Function):
                 gemm_out, *_, reduce_scatter_out = general_gemm(
                     weight_fp8,
                     grad_output,
                     get_workspace(),
                     layout="NN",
                     grad=True,
                     quantization_params=ctx.grad_input_quantizer,

...
@@ -847,7 +850,6 @@ class _Linear(torch.autograd.Function):
             # Arguments to include in wgrad GEMM closure
             wgrad_gemm_kwargs = {
                 "workspace": get_workspace(),
                 "out_dtype": (
                     main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                 ),

...
@@ -982,46 +984,14 @@ class _Linear(torch.autograd.Function):
         return (
             wgrad,
             dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             grad_bias,
-            None,  # is_first_microbatch
-            None,  # fp8
-            None,  # fp8_calibration
-            None,  # wgrad_store
-            None,  # input_quantizer
-            None,  # weight_quantizer
-            None,  # output_quantizer
-            None,  # grad_input_quantizer
-            None,  # grad_weight_quantizer
-            None,  # grad_output_quantizer
-            None,  # fuse_wgrad_accumulation
-            None,  # cpu_offloading
-            None,  # tp_group
-            None,  # tp_size
-            None,  # sequence_parallel
-            None,  # tensor_parallel
-            None,  # activation_dtype
-            None,  # parallel_mode
-            None,  # is_grad_enabled
-            None,  # ub_overlap_rs_fprop
-            None,  # ub_overlap_ag_dgrad
-            None,  # ub_overlap_ag_fprop
-            None,  # ub_overlap_rs_dgrad
-            None,  # ub_bulk_dgrad
-            None,  # ub_bulk_wgrad
-            None,  # ub_name
-            None,  # fp8_output
-            None,  # fsdp_group
-            None,  # module
-            None,  # skip_fp8_weight_update
-            None,  # symmetric_ar_type
-            None,  # save_original_input
-            None,  # debug
+            None,  # non_tensor_args
         )
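This return hunk is the payoff of the new forward signature: autograd requires backward to return one gradient per forward input, so packing all configuration into a single non_tensor_args tuple collapses 33 per-argument None entries into one. A toy sketch of the pattern (hypothetical _Scale function, not TE code):

    import torch

    class _Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp: torch.Tensor, non_tensor_args: tuple) -> torch.Tensor:
            scale, offset = non_tensor_args  # unpack config inside forward
            ctx.scale = scale
            return inp * scale + offset

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            # One gradient per forward input: the tensor, then a single
            # None covering the whole non_tensor_args tuple.
            return grad_output * ctx.scale, None

    x = torch.randn(4, requires_grad=True)
    y = _Scale.apply(x, (2.0, 1.0))
    y.sum().backward()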
 class Linear(TransformerEngineBaseModule):
     """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

-    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
+    On NVIDIA GPUs it is a drop-in replacement for ``torch.nn.Linear``.

     Parameters
     ----------

...
@@ -1029,14 +999,14 @@ class Linear(TransformerEngineBaseModule):
         size of each input sample.
     out_features : int
         size of each output sample.
-    bias : bool, default = `True`
-        if set to `False`, the layer will not learn an additive bias.
-    init_method : Callable, default = `None`
-        used for initializing weights in the following way: `init_method(weight)`.
-        When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
-    get_rng_state_tracker : Callable, default = `None`
+    bias : bool, default = True
+        if set to ``False``, the layer will not learn an additive bias.
+    init_method : Callable, default = None
+        used for initializing weights in the following way: ``init_method(weight)``.
+        When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
+    get_rng_state_tracker : Callable, default = None
         used to get the random number generator state tracker for initializing weights.
-    rng_tracker_name : str, default = `None`
+    rng_tracker_name : str, default = None
         the param passed to get_rng_state_tracker to get the specific rng tracker.
     parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
         Configuration for splitting the weight and bias tensors along dim 0 into

...
@@ -1044,62 +1014,62 @@ class Linear(TransformerEngineBaseModule):
         they are used to make the names of equally-sized parameters. If a dict
         (preferably an OrderedDict) is provided, the keys are used as names and
         values as split sizes along dim 0. The resulting parameters will have
-        names that end in `_weight` or `_bias`, so trailing underscores are
+        names that end in ``_weight`` or ``_bias``, so trailing underscores are
         stripped from any provided names.
     device : Union[torch.device, str], default = "cuda"
         The device on which the parameters of the model will be allocated. It is the user's
         responsibility to ensure all parameters are moved to the GPU before running the
         forward pass.
-    name: str, default = `None`
+    name : str, default = None
         name of the module, currently used for debugging purposes.

     Parallelism parameters
     ----------------------
-    sequence_parallel : bool, default = `False`
-        if set to `True`, uses sequence parallelism.
-    tp_group : ProcessGroup, default = `None`
+    sequence_parallel : bool, default = False
+        if set to ``True``, uses sequence parallelism.
+    tp_group : ProcessGroup, default = None
         tensor parallel process group.
     tp_size : int, default = 1
         used as TP (tensor parallel) world size when TP groups are not formed during
         initialization. In this case, users must call the
-        `set_tensor_parallel_group(tp_group)` method on the initialized module before the
+        ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
         forward pass to supply the tensor parallel group needed for tensor and sequence
         parallel collectives.
-    parallel_mode : {None, 'column', 'row'}, default = `None`
+    parallel_mode : {None, 'column', 'row'}, default = None
         used to decide whether this Linear layer is Column Parallel Linear or Row
         Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
-        When set to `None`, no communication is performed.
+        When set to ``None``, no communication is performed.

     Optimization parameters
     -----------------------
     fuse_wgrad_accumulation : bool, default = 'False'
-        if set to `True`, enables fusing of creation and accumulation of
+        if set to ``True``, enables fusing of creation and accumulation of
         the weight gradient. When enabled, it is assumed that the weights
-        have an additional `main_grad` attribute (used instead of the
-        regular `grad`) which is a pre-allocated buffer of the correct
+        have an additional ``main_grad`` attribute (used instead of the
+        regular ``grad``) which is a pre-allocated buffer of the correct
         size to accumulate gradients in. This argument along with
         weight tensor having attribute 'overwrite_main_grad' set to True
-        will overwrite `main_grad` instead of accumulating.
-    return_bias : bool, default = `False`
-        when set to `True`, this module will not apply the additive bias itself, but
+        will overwrite ``main_grad`` instead of accumulating.
+    return_bias : bool, default = False
+        when set to ``True``, this module will not apply the additive bias itself, but
         instead return the bias value during the forward pass together with the
         output of the linear transformation :math:`y = xA^T`. This is useful when
         the bias addition can be fused to subsequent operations.
-    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
+    params_dtype : torch.dtype, default = torch.get_default_dtype()
         it controls the type used to allocate the initial parameters. Useful when
         the model is trained with lower precision and the original FP32 parameters
         would not fit in GPU memory.
-    delay_wgrad_compute : bool, default = `False`
-        Whether or not to delay weight gradient computation. If set to `True`,
-        it's the user's responsibility to call `module.backward_dw` to compute
+    delay_wgrad_compute : bool, default = False
+        Whether or not to delay weight gradient computation. If set to ``True``,
+        it's the user's responsibility to call ``module.backward_dw`` to compute
         weight gradients.
     symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
         Type of symmetric memory all-reduce to use during the forward pass.
         This can help in latency bound communication situations.
-        Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
+        Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce
         is used.
-    save_original_input : bool, default = `False`
-        If set to `True`, always saves the original input tensor rather than the
+    save_original_input : bool, default = False
+        If set to ``True``, always saves the original input tensor rather than the
         cast tensor. In some scenarios, the input tensor is used by multiple modules,
         and saving the original input tensor may reduce the memory usage.
         Cannot work with FP8 DelayedScaling recipe.

...
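For context, a minimal usage sketch of the module these docstrings describe (shapes, dtype, and device chosen arbitrarily; assumes a CUDA device is available):

    import torch
    import transformer_engine.pytorch as te

    # Drop-in replacement for torch.nn.Linear on NVIDIA GPUs
    layer = te.Linear(768, 3072, bias=True, params_dtype=torch.bfloat16)
    x = torch.randn(16, 768, device="cuda", dtype=torch.bfloat16)
    y = layer(x)  # shape: [16, 3072]

...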
@@ -1348,15 +1318,12 @@ class Linear(TransformerEngineBaseModule):
         """Init scales and amaxes for fwd | bwd."""
         super().set_meta_tensor(fwd, recipe)

-        # customize quantizers based on each recipe & layer configs
+        # Recipe-specific quantizer configuration
         recipe = FP8GlobalStateManager.get_fp8_recipe()
         if recipe.float8_current_scaling():
             self._customize_quantizers_float8_current_scaling(fwd, recipe)
-        elif recipe.float8_block_scaling():
-            self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
         elif recipe.nvfp4():
             self._customize_quantizers_nvfp4(fwd, recipe)
-        # elif for other recipes (mxfp8, etc.)

     def reset_parameters(self, defer_init=False):
         super().reset_parameters(defer_init=defer_init)

...
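The dispatch above keys off whichever recipe is active when the module runs under fp8_autocast. A minimal sketch of selecting a recipe at the call site (DelayedScaling shown; requires FP8-capable hardware, and availability of other recipes depends on TE version and GPU):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    fp8_recipe = recipe.DelayedScaling(margin=0, amax_history_len=16)
    layer = te.Linear(768, 768)
    x = torch.randn(16, 768, device="cuda")

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)

...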
@@ -1408,8 +1375,10 @@ class Linear(TransformerEngineBaseModule):
             first microbatch (since it is the first gradient being
             produced)
         """
+        is_grad_enabled = torch.is_grad_enabled()
+
         if is_in_onnx_export_mode():
-            return self.onnx_forward(inp, fp8_output)
+            return self.onnx_forward(inp, fp8_output, is_grad_enabled)

         debug = self.is_debug_iter()

...
@@ -1431,9 +1400,7 @@ class Linear(TransformerEngineBaseModule):
             ).is_fp8_ubuf():
                 fp8_grad = True

-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(
+        with self.prepare_forward(
             inp,
             allow_non_contiguous=isinstance(inp, QuantizedTensor),
         ) as inp:

...
@@ -1441,14 +1408,14 @@ class Linear(TransformerEngineBaseModule):
             weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

             quantizers = (
-                self._get_quantizers(fp8_output, fp8_grad)
+                self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
                 if not debug
-                else self._get_debug_quantizers(fp8_output, fp8_grad)
+                else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
             )
             if debug:
                 if self.no_debug_features_active(quantizers):
                     debug = False
-                    quantizers = self._get_quantizers(fp8_output, fp8_grad)
+                    quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
             (
                 input_quantizer,

...
@@ -1459,16 +1426,14 @@ class Linear(TransformerEngineBaseModule):
                 grad_output_quantizer,
             ) = quantizers

-            if torch.is_grad_enabled():
+            if is_grad_enabled:
                 linear_fn = _Linear.apply
-                args = []
+                autograd_ctx = []
             else:
                 linear_fn = _Linear.forward
-                args = [None]
-            args += (
-                weight_tensor,
-                inp,
-                bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
+                autograd_ctx = [None]
+            non_tensor_args = (
                 is_first_microbatch,
                 self.fp8,
                 self.fp8_calibration,

...
@@ -1487,7 +1452,7 @@ class Linear(TransformerEngineBaseModule):
                 self.tp_size > 1,
                 self.activation_dtype,
                 self.parallel_mode,
-                torch.is_grad_enabled(),
+                is_grad_enabled,
                 self.ub_overlap_rs_fprop,
                 self.ub_overlap_ag_dgrad,
                 self.ub_overlap_ag_fprop,

...
@@ -1503,7 +1468,13 @@ class Linear(TransformerEngineBaseModule):
                 self.save_original_input,
                 debug,
             )
-            out = linear_fn(*args)
+            out = linear_fn(
+                *autograd_ctx,
+                weight_tensor,
+                inp,
+                bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
+                non_tensor_args,
+            )

         if self.gemm_bias_unfused_add:
             out = out + cast_if_needed(bias_tensor, self.activation_dtype)

...
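The autograd_ctx list keeps a single call site for both execution paths: Function.apply supplies ctx implicitly and records the op for backward, while calling forward directly under no_grad needs an explicit ctx placeholder. A toy sketch of the same dispatch (hypothetical _Double function, not TE code):

    import torch

    class _Double(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp * 2

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output * 2

    x = torch.randn(3, requires_grad=True)
    if torch.is_grad_enabled():
        fn, autograd_ctx = _Double.apply, []        # ctx supplied by autograd
    else:
        fn, autograd_ctx = _Double.forward, [None]  # explicit ctx placeholder
    y = fn(*autograd_ctx, x)

...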
@@ -1511,7 +1482,7 @@ class Linear(TransformerEngineBaseModule):
             return out, cast_if_needed(bias_tensor, self.activation_dtype)
         return out

-    def _get_quantizers(self, fp8_output, fp8_grad):
+    def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
         if not self.fp8:
             return [None] * 6
         grad_input_quantizer = None

...
@@ -1520,12 +1491,16 @@ class Linear(TransformerEngineBaseModule):
         output_quantizer = None
         input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
         input_quantizer.internal = True
+        if not (self.parallel_mode == "column" and self.sequence_parallel):
+            input_quantizer.optimize_for_gemm = True
         (weight_quantizer,) = self._get_weight_quantizers()
         if fp8_output:
             output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
-        if torch.is_grad_enabled():
+        if is_grad_enabled:
             grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
             grad_output_quantizer.internal = True
+            if not (self.parallel_mode == "row" and self.sequence_parallel):
+                grad_output_quantizer.optimize_for_gemm = True
             if fp8_grad:
                 grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
         return (

...
@@ -1537,8 +1512,8 @@ class Linear(TransformerEngineBaseModule):
             grad_output_quantizer,
         )

-    def _get_debug_quantizers(self, fp8_output, fp8_grad):
-        original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
+    def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
+        original_quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
         assert TEDebugState.debug_enabled
         from ...debug.pytorch.debug_quantization import DebugQuantizer

...
@@ -1568,31 +1543,18 @@ class Linear(TransformerEngineBaseModule):
     def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Get concatenated weight and bias tensors
         unfused_weights = self._get_weight_tensors()
-        if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
-            if self.fp8:
-                if len(unfused_weights) != 1:
-                    raise RuntimeError(
-                        "Splitting QuantizedTensor into multiple params is not supported"
-                    )
-            else:
-                warnings.warn(
-                    "You are using quantized weights without quantized compute. "
-                    "Please make sure this is intentional."
-                )
-                unfused_weights = [w.dequantize() for w in unfused_weights]
         weight_tensor = noop_cat(unfused_weights)
         if self.use_bias:
             bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
         else:
             bias_tensor = None
         return weight_tensor, bias_tensor

     def onnx_forward(
         self,
         inp: torch.Tensor,
         fp8_output: bool,
+        is_grad_enabled: bool,
     ) -> torch.Tensor:
         """
         ONNX-compatible version of the forward function that provides numerical equivalence

...
@@ -1609,7 +1571,7 @@ class Linear(TransformerEngineBaseModule):
             weight_quantizer,
             output_quantizer,
             *_,
-        ) = self._get_quantizers(fp8_output, False)
+        ) = self._get_quantizers(fp8_output, False, is_grad_enabled)
         inp_dtype = inp.dtype

         if input_quantizer is not None:

...
@@ -1713,22 +1675,3 @@ class Linear(TransformerEngineBaseModule):
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
         return [weight_quantizer]
-
-    def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
-        """Customize quantizers based on blockwise scaling recipe + linear."""
-        assert (
-            recipe.float8_block_scaling()
-        ), "blockwise scaling recipe quantizer customization here"
-        if fwd:
-            if self.sequence_parallel and self.parallel_mode == "column":
-                # set compact for inp tensor X
-                self.quantizers["scaling_fwd"][
-                    tex.FP8FwdTensors.GEMM1_INPUT
-                ].all_gather_usage = True
-        else:
-            if self.sequence_parallel and self.parallel_mode == "row":
-                # set compact for grad_output tensor dY
-                self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
-                ].all_gather_usage = True
transformer_engine/pytorch/module/rmsnorm.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
@@ -33,32 +33,29 @@ class RMSNorm(_RMSNormOp):
     Parameters
     ----------
-    normalized_shape: int or iterable of int
+    normalized_shape : int or iterable of int
         Inner dimensions of input tensor
     eps : float, default = 1e-5
         A value added to the denominator for numerical stability
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
-    zero_centered_gamma : bool, default = 'False'
-        If `True`, the :math:`\gamma` parameter is initialized to zero
+    zero_centered_gamma : bool, default = False
+        If ``True``, the :math:`\gamma` parameter is initialized to zero
         and the calculation changes to

        .. math::
            y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)

-    sm_margin: int, default = 0
+    sm_margin : int, default = 0
         Number of SMs to exclude when launching CUDA kernels. This
         helps overlap with other kernels, e.g. communication kernels.
         For more fine-grained control, provide a dict with the SM
-        margin at each compute stage ("forward", "backward",
-        "inference").
-
-    Legacy
-    ------
-    sequence_parallel: bool
-        Set a bool attr named `sequence_parallel` in the parameters.
+        margin at each compute stage (``"forward"``, ``"backward"``,
+        ``"inference"``).
+    sequence_parallel : bool
+        **Legacy parameter.** Set a bool attr named ``sequence_parallel`` in the parameters.
         This is custom logic for Megatron-LM integration.
     """

...
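The zero-centered formulation is easy to check in plain PyTorch. A small sketch of the computation the docstring describes (not TE's fused kernel; the mean-square statistic is used as the RMS normalizer):

    import torch

    def rmsnorm_zero_centered(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5):
        # gamma is initialized to zero, so the layer starts as pure RMS
        # normalization; the effective scale is (1 + gamma).
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x * rms * (1 + gamma)

    y = rmsnorm_zero_centered(torch.randn(2, 8), torch.zeros(8))

...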
transformer_engine/pytorch/numerics_debug.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
transformer_engine/pytorch/onnx_extensions.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
@@ -356,7 +356,9 @@ def onnx_layernorm(
     )
     if normalization == "RMSNorm":
-        ln_out = torch.nn.functional.rms_norm(inp, inp.shape[-1:], ln_weight, eps)
+        variance = inp.pow(2).mean(-1, keepdim=True)
+        ln_out = inp * torch.rsqrt(variance + eps)
+        ln_out = ln_out * ln_weight
     else:
         ln_out = torch.nn.functional.layer_norm(
             inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps

...
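The new export path decomposes RMSNorm into primitive ops, presumably so it traces cleanly to export targets that lack the fused op. Assuming PyTorch >= 2.4, where torch.nn.functional.rms_norm is available, the decomposition can be checked against the fused reference:

    import torch

    x, w, eps = torch.randn(4, 8), torch.randn(8), 1e-5

    # Decomposed form used in the export path
    variance = x.pow(2).mean(-1, keepdim=True)
    decomposed = x * torch.rsqrt(variance + eps) * w

    # Fused reference
    fused = torch.nn.functional.rms_norm(x, x.shape[-1:], w, eps)
    torch.testing.assert_close(decomposed, fused)

...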
transformer_engine/pytorch/ops/__init__.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
transformer_engine/pytorch/ops/_common.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...

@@ -10,7 +10,7 @@ from typing import Optional
 import torch
 from transformer_engine_torch import FP8TensorMeta

-from .. import torch_version
+from ..torch_version import torch_version
 from ..quantization import FP8GlobalStateManager
 from ..tensor.float8_tensor import Float8Tensor
 from ..quantized_tensor import QuantizedTensorStorage

...
transformer_engine/pytorch/ops/basic/__init__.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
transformer_engine/pytorch/ops/basic/activation.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...

@@ -53,7 +53,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
     Parameters
     ----------
-    cache_quantized_input: bool, default = False
+    cache_quantized_input : bool, default = False
         Quantize input tensor when caching for use in the backward
         pass. This will typically reduce memory usage but require
         extra compute and increase numerical error. This feature is

...

@@ -408,11 +408,11 @@ class ClampedSwiGLU(_ActivationOperation):
     Parameters
     ----------
-    limit: float
+    limit : float
         The clamp limit.
-    alpha: float
+    alpha : float
         The scaling factor for the sigmoid function used in the activation.
-    cache_quantized_input: bool, default = False
+    cache_quantized_input : bool, default = False
         Quantize input tensor when caching for use in the backward pass.
     """

...
transformer_engine/pytorch/ops/basic/add_extra_input.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
transformer_engine/pytorch/ops/basic/all_gather.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...

@@ -23,7 +23,7 @@ class AllGather(BasicOperation):
     Parameters
     ----------
-    process_group: torch.distributed.ProcessGroup, default = world group
+    process_group : torch.distributed.ProcessGroup, default = world group
         Process group for communication
     """

...
transformer_engine/pytorch/ops/basic/all_reduce.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...

@@ -24,7 +24,7 @@ class AllReduce(BasicOperation):
     Parameters
     ----------
-    process_group: torch.distributed.ProcessGroup, default = world group
+    process_group : torch.distributed.ProcessGroup, default = world group
         Process group for communication
     """

...
transformer_engine/pytorch/ops/basic/basic_linear.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
@@ -25,7 +25,6 @@ from ...module.base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
-    get_dummy_wgrad,
     get_workspace,
 )
 from ...tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer

...
@@ -54,27 +53,27 @@ class BasicLinear(BasicOperation):
     Parameters
     ----------
-    in_features: int
+    in_features : int
         Inner dimension of input tensor
-    out_features: int
+    out_features : int
         Inner dimension of output tensor
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
-    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
+    tensor_parallel_mode : {`None`, "column", "row"}, default = `None`
         Mode for tensor parallelism
-    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
+    tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
         Process group for tensor parallelism
-    sequence_parallel: bool, default = `False`
+    sequence_parallel : bool, default = `False`
         Whether to apply sequence parallelism together with tensor
         parallelism, i.e. distributing input or output tensors along
         outer dimension (sequence or batch dim) when not distributing
         along inner dimension (embedding dim)
-    rng_state_tracker_function: callable
+    rng_state_tracker_function : callable
         Function that returns `CudaRNGStatesTracker`, which is used
         for model-parallel weight initialization
-    accumulate_into_main_grad: bool, default = `False`
+    accumulate_into_main_grad : bool, default = `False`
         Whether to directly accumulate weight gradients into the
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and

...
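A minimal usage sketch of this fusible op; te.ops operations are nn.Modules and are typically composed with te.ops.Sequential so that adjacent ops can be fused (shapes and device chosen arbitrarily):

    import torch
    import transformer_engine.pytorch as te

    # Compose fusible ops; adjacent ops may be fused into one kernel
    model = te.ops.Sequential(
        te.ops.BasicLinear(768, 3072),
        te.ops.Bias(3072),
    )
    x = torch.randn(16, 768, device="cuda")
    y = model(x)

...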
@@ -138,8 +137,10 @@ class BasicLinear(BasicOperation):
             out_features=out_features,
         )

-        # Whether weight tensor is natively quantized
+        # Initialize recipe state if needed for natively quantized weight
         self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters()
+        if self._with_quantized_weight:
+            self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())

         # Initialize parameters if needed
         weight = torch.empty(

...
@@ -341,15 +342,21 @@ class BasicLinear(BasicOperation):
     def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
         super().reset_recipe_state(recipe=recipe)

-        # Input/grad output quantizers use internal tensors
+        # Configure input/grad output tensor
+        # Note: These tensors are only used internally. If there is no
+        # tensor-parallel communication, they are only used for GEMM.
         input_quantizer = self.get_quantizer("forward", 0)
         grad_output_quantizer = self.get_quantizer("backward", 0)
         if input_quantizer is not None:
             input_quantizer.internal = True
+            if not (self.tensor_parallel_mode == "column" and self.sequence_parallel):
+                input_quantizer.optimize_for_gemm = True
         if grad_output_quantizer is not None:
             grad_output_quantizer.internal = True
+            if not (self.tensor_parallel_mode == "row" and self.sequence_parallel):
+                grad_output_quantizer.optimize_for_gemm = True

-        # Handle weight quantizer
+        # Configure weight quantizer
         # Note: This function may be called in base class constructor,
         # before any basic linear attrs have been set.
         weight_quantizer = self.get_quantizer("forward", 1)

...
@@ -585,7 +592,6 @@ class BasicLinear(BasicOperation):
         y, *_ = general_gemm(
             w,
             x,
             get_workspace(),
             out_dtype=dtype,
             quantization_params=output_quantizer,
             alpha=alpha,

...
@@ -875,7 +881,6 @@ class BasicLinear(BasicOperation):
         dx, *_ = general_gemm(
             w,
             dy,
             get_workspace(),
             out_dtype=dtype,
             quantization_params=grad_input_quantizer,
             alpha=grad_input_alpha,

...
@@ -928,7 +933,6 @@ class BasicLinear(BasicOperation):
         dw, *_ = general_gemm(
             x,
             dy,
             get_workspace(),
             out_dtype=dw_dtype,
             alpha=grad_weight_alpha,
             beta=grad_weight_beta,

...
transformer_engine/pytorch/ops/basic/bias.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...

@@ -22,16 +22,16 @@ class Bias(BasicOperation):
     Parameters
     ----------
-    size: int
+    size : int
         Inner dimension of input tensor
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
-    tensor_parallel: bool, default = `False`
+    tensor_parallel : bool, default = `False`
         Whether to distribute input tensor and bias tensors along
         inner dimension
-    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
+    tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
         Process group for tensor parallelism
     """

...
transformer_engine/pytorch/ops/basic/constant_scale.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
transformer_engine/pytorch/ops/basic/dropout.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
transformer_engine/pytorch/ops/basic/identity.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
transformer_engine/pytorch/ops/basic/l2normalization.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...

@@ -10,7 +10,7 @@ import os
 import torch

-from ... import torch_version
+from ...torch_version import torch_version
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...jit import (
     l2normalization_fused,

...

@@ -40,11 +40,11 @@ class L2Normalization(BasicOperation):
     ----------
     eps : float, default = 1e-6
         A value added to the denominator for numerical stability
-    seq_length: int, default = None
+    seq_length : int, default = None
         sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
         functions are warmed up before training to ensure same kernels are used for forward
         propagation and activation recompute phase.
-    micro_batch_size: int, default = None
+    micro_batch_size : int, default = None
         batch size per training step. Needed for JIT Warmup, a technique where jit
         fused functions are warmed up before training to ensure same kernels are
         used for forward propagation and activation recompute phase.

...
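For reference, the op this class implements is an L2 normalization along the last dimension; a minimal PyTorch equivalent (not TE's fused JIT kernel):

    import torch

    def l2_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Scale each vector to unit L2 norm; eps guards the rsqrt.
        return x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)

    y = l2_normalize(torch.randn(4, 16))

...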
transformer_engine/pytorch/ops/basic/layer_norm.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...

@@ -42,14 +42,14 @@ class LayerNorm(BasicOperation):
     Parameters
     ----------
-    normalized_shape: int or iterable of int
+    normalized_shape : int or iterable of int
         Inner dimensions of input tensor
     eps : float, default = 1e-5
         A value added to the denominator of layer normalization for
         numerical stability
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
     zero_centered_gamma : bool, default = 'False'
         If `True`, the :math:`\gamma` parameter is initialized to zero

...

@@ -58,7 +58,7 @@ class LayerNorm(BasicOperation):
        .. math::
            y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta

-    sm_margin: int or dict, default = 0
+    sm_margin : int or dict, default = 0
         Number of SMs to exclude when launching CUDA kernels. This
         helps overlap with other kernels, e.g. communication kernels.
         For more fine-grained control, provide a dict with the SM

...
transformer_engine/pytorch/ops/basic/make_extra_output.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
transformer_engine/pytorch/ops/basic/quantize.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...

@@ -23,9 +23,9 @@ class Quantize(BasicOperation):
     Parameters
     ----------
-    forward: bool, default = `True`
+    forward : bool, default = `True`
         Perform quantization in forward pass
-    backward: bool, default = `False`
+    backward : bool, default = `False`
         Perform quantization in backward pass
     """

...