Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
646
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
177 additions
and
231 deletions
+177
-231
transformer_engine/pytorch/module/linear.py
transformer_engine/pytorch/module/linear.py
+104
-161
transformer_engine/pytorch/module/rmsnorm.py
transformer_engine/pytorch/module/rmsnorm.py
+11
-14
transformer_engine/pytorch/numerics_debug.py
transformer_engine/pytorch/numerics_debug.py
+1
-1
transformer_engine/pytorch/onnx_extensions.py
transformer_engine/pytorch/onnx_extensions.py
+4
-2
transformer_engine/pytorch/ops/__init__.py
transformer_engine/pytorch/ops/__init__.py
+1
-1
transformer_engine/pytorch/ops/_common.py
transformer_engine/pytorch/ops/_common.py
+2
-2
transformer_engine/pytorch/ops/basic/__init__.py
transformer_engine/pytorch/ops/basic/__init__.py
+1
-1
transformer_engine/pytorch/ops/basic/activation.py
transformer_engine/pytorch/ops/basic/activation.py
+5
-5
transformer_engine/pytorch/ops/basic/add_extra_input.py
transformer_engine/pytorch/ops/basic/add_extra_input.py
+1
-1
transformer_engine/pytorch/ops/basic/all_gather.py
transformer_engine/pytorch/ops/basic/all_gather.py
+2
-2
transformer_engine/pytorch/ops/basic/all_reduce.py
transformer_engine/pytorch/ops/basic/all_reduce.py
+2
-2
transformer_engine/pytorch/ops/basic/basic_linear.py
transformer_engine/pytorch/ops/basic/basic_linear.py
+21
-17
transformer_engine/pytorch/ops/basic/bias.py
transformer_engine/pytorch/ops/basic/bias.py
+6
-6
transformer_engine/pytorch/ops/basic/constant_scale.py
transformer_engine/pytorch/ops/basic/constant_scale.py
+1
-1
transformer_engine/pytorch/ops/basic/dropout.py
transformer_engine/pytorch/ops/basic/dropout.py
+1
-1
transformer_engine/pytorch/ops/basic/identity.py
transformer_engine/pytorch/ops/basic/identity.py
+1
-1
transformer_engine/pytorch/ops/basic/l2normalization.py
transformer_engine/pytorch/ops/basic/l2normalization.py
+4
-4
transformer_engine/pytorch/ops/basic/layer_norm.py
transformer_engine/pytorch/ops/basic/layer_norm.py
+5
-5
transformer_engine/pytorch/ops/basic/make_extra_output.py
transformer_engine/pytorch/ops/basic/make_extra_output.py
+1
-1
transformer_engine/pytorch/ops/basic/quantize.py
transformer_engine/pytorch/ops/basic/quantize.py
+3
-3
No files found.
transformer_engine/pytorch/module/linear.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -14,13 +14,12 @@ import torch
...
@@ -14,13 +14,12 @@ import torch
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
.torch_version
import
torch_version
from
.base
import
(
from
.base
import
(
fill_userbuffers_buffer_for_all_gather
,
fill_userbuffers_buffer_for_all_gather
,
get_dummy_wgrad
,
get_dummy_wgrad
,
get_ub
,
get_ub
,
get_workspace
,
TransformerEngineBaseModule
,
TransformerEngineBaseModule
,
_2X_ACC_FPROP
,
_2X_ACC_FPROP
,
_2X_ACC_DGRAD
,
_2X_ACC_DGRAD
,
...
@@ -39,6 +38,7 @@ from ..utils import (
...
@@ -39,6 +38,7 @@ from ..utils import (
assert_dim_for_all_gather
,
assert_dim_for_all_gather
,
nvtx_range_pop
,
nvtx_range_pop
,
nvtx_range_push
,
nvtx_range_push
,
get_nvtx_range_context
,
)
)
from
..distributed
import
(
from
..distributed
import
(
set_tensor_model_parallel_attributes
,
set_tensor_model_parallel_attributes
,
...
@@ -91,42 +91,46 @@ class _Linear(torch.autograd.Function):
...
@@ -91,42 +91,46 @@ class _Linear(torch.autograd.Function):
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
inp
:
torch
.
Tensor
,
inp
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
],
is_first_microbatch
:
Union
[
bool
,
None
],
non_tensor_args
:
Tuple
,
fp8
:
bool
,
fp8_calibration
:
bool
,
wgrad_store
:
WeightGradStore
,
input_quantizer
:
Optional
[
Quantizer
],
weight_quantizer
:
Optional
[
Quantizer
],
output_quantizer
:
Optional
[
Quantizer
],
grad_input_quantizer
:
Optional
[
Quantizer
],
grad_weight_quantizer
:
Optional
[
Quantizer
],
grad_output_quantizer
:
Optional
[
Quantizer
],
fuse_wgrad_accumulation
:
bool
,
cpu_offloading
:
bool
,
tp_group
:
Union
[
dist_group_type
,
None
],
tp_size
:
int
,
sequence_parallel
:
bool
,
tensor_parallel
:
bool
,
activation_dtype
:
torch
.
dtype
,
parallel_mode
:
Union
[
str
,
None
],
is_grad_enabled
:
bool
,
ub_overlap_rs_fprop
:
bool
,
ub_overlap_ag_dgrad
:
bool
,
ub_overlap_ag_fprop
:
bool
,
ub_overlap_rs_dgrad
:
bool
,
ub_bulk_dgrad
:
bool
,
ub_bulk_wgrad
:
bool
,
ub_name
:
str
,
fp8_output
:
bool
,
# pylint: disable=unused-argument
fsdp_group
:
Union
[
dist_group_type
,
None
],
module
:
torch
.
nn
.
Module
,
skip_fp8_weight_update
:
bool
,
symmetric_ar_type
:
str
,
save_original_input
:
bool
=
False
,
debug
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# pylint: disable=missing-function-docstring
# pylint: disable=missing-function-docstring
(
is_first_microbatch
,
fp8
,
fp8_calibration
,
wgrad_store
,
input_quantizer
,
weight_quantizer
,
output_quantizer
,
grad_input_quantizer
,
grad_weight_quantizer
,
grad_output_quantizer
,
fuse_wgrad_accumulation
,
cpu_offloading
,
tp_group
,
tp_size
,
sequence_parallel
,
tensor_parallel
,
activation_dtype
,
parallel_mode
,
is_grad_enabled
,
ub_overlap_rs_fprop
,
ub_overlap_ag_dgrad
,
ub_overlap_ag_fprop
,
ub_overlap_rs_dgrad
,
ub_bulk_dgrad
,
ub_bulk_wgrad
,
ub_name
,
fp8_output
,
# pylint: disable=unused-variable
fsdp_group
,
module
,
skip_fp8_weight_update
,
symmetric_ar_type
,
save_original_input
,
debug
,
)
=
non_tensor_args
# NVTX label for profiling
# NVTX label for profiling
nvtx_label
=
"transformer_engine._Linear.forward"
nvtx_label
=
"transformer_engine._Linear.forward"
if
ub_name
is
not
None
:
if
ub_name
is
not
None
:
...
@@ -321,7 +325,6 @@ class _Linear(torch.autograd.Function):
...
@@ -321,7 +325,6 @@ class _Linear(torch.autograd.Function):
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
weightmat
,
weightmat
,
inputmat_total
,
inputmat_total
,
get_workspace
(),
quantization_params
=
output_quantizer
,
quantization_params
=
output_quantizer
,
out_dtype
=
activation_dtype
,
out_dtype
=
activation_dtype
,
bias
=
bias
,
bias
=
bias
,
...
@@ -426,6 +429,7 @@ class _Linear(torch.autograd.Function):
...
@@ -426,6 +429,7 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module
# weights if weights are externally touched outside this module
ctx
.
weight_object
=
weight
ctx
.
weight_object
=
weight
if
cpu_offloading
:
mark_not_offload
(
weight
,
weightmat
,
bias
)
mark_not_offload
(
weight
,
weightmat
,
bias
)
# TODO(ksivamani): Check memory usage
# TODO(ksivamani): Check memory usage
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
...
@@ -498,7 +502,7 @@ class _Linear(torch.autograd.Function):
...
@@ -498,7 +502,7 @@ class _Linear(torch.autograd.Function):
if
ctx
.
ub_name
is
not
None
:
if
ctx
.
ub_name
is
not
None
:
nvtx_label
=
f
"
{
nvtx_label
}
.
{
ctx
.
ub_name
}
"
nvtx_label
=
f
"
{
nvtx_label
}
.
{
ctx
.
ub_name
}
"
with
torch
.
cuda
.
nvtx
.
range
(
"_Linear_backward"
):
with
get_
nvtx
_
range
_context
(
"_Linear_backward"
):
saved_tensors
=
ctx
.
saved_tensors
saved_tensors
=
ctx
.
saved_tensors
inputmat
,
weight_fp8
,
weight
,
bias
=
(
# pylint: disable=unbalanced-tuple-unpacking
inputmat
,
weight_fp8
,
weight
,
bias
=
(
# pylint: disable=unbalanced-tuple-unpacking
restore_from_saved
(
ctx
.
tensor_objects
,
saved_tensors
)
restore_from_saved
(
ctx
.
tensor_objects
,
saved_tensors
)
...
@@ -721,7 +725,6 @@ class _Linear(torch.autograd.Function):
...
@@ -721,7 +725,6 @@ class _Linear(torch.autograd.Function):
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
weight_fp8
,
weight_fp8
,
grad_output
,
grad_output
,
get_workspace
(),
layout
=
"NN"
,
layout
=
"NN"
,
grad
=
True
,
grad
=
True
,
quantization_params
=
ctx
.
grad_input_quantizer
,
quantization_params
=
ctx
.
grad_input_quantizer
,
...
@@ -847,7 +850,6 @@ class _Linear(torch.autograd.Function):
...
@@ -847,7 +850,6 @@ class _Linear(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure
# Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs
=
{
wgrad_gemm_kwargs
=
{
"workspace"
:
get_workspace
(),
"out_dtype"
:
(
"out_dtype"
:
(
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
),
...
@@ -982,46 +984,14 @@ class _Linear(torch.autograd.Function):
...
@@ -982,46 +984,14 @@ class _Linear(torch.autograd.Function):
wgrad
,
wgrad
,
dgrad
.
view
(
ctx
.
inp_shape
)
if
ctx
.
requires_dgrad
else
None
,
dgrad
.
view
(
ctx
.
inp_shape
)
if
ctx
.
requires_dgrad
else
None
,
grad_bias
,
grad_bias
,
None
,
# is_first_microbatch
None
,
None
,
# fp8
None
,
# fp8_calibration
None
,
# wgrad_store
None
,
# input_quantizer
None
,
# weight_quantizer
None
,
# output_quantizer
None
,
# grad_input_quantizer
None
,
# grad_weight_quantizer
None
,
# grad_output_quantizer
None
,
# fuse_wgrad_accumulation
None
,
# cpu_offloading
None
,
# tp_group
None
,
# tp_size
None
,
# sequence_parallel
None
,
# tensor_parallel
None
,
# activation_dtype
None
,
# parallel_mode
None
,
# is_grad_enabled
None
,
# ub_overlap_rs_fprop
None
,
# ub_overlap_ag_dgrad
None
,
# ub_overlap_ag_fprop
None
,
# ub_overlap_rs_dgrad
None
,
# ub_bulk_dgrad
None
,
# ub_bulk_wgrad
None
,
# ub_name
None
,
# fp8_output
None
,
# fsdp_group
None
,
# module
None
,
# skip_fp8_weight_update
None
,
# symmetric_ar_type
None
,
# save_original_input
None
,
# debug
)
)
class
Linear
(
TransformerEngineBaseModule
):
class
Linear
(
TransformerEngineBaseModule
):
"""Applies a linear transformation to the incoming data :math:`y = xA^T + b`
"""Applies a linear transformation to the incoming data :math:`y = xA^T + b`
On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
On NVIDIA GPUs it is a drop-in replacement for
`
`torch.nn.Linear`
`
.
Parameters
Parameters
----------
----------
...
@@ -1029,14 +999,14 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1029,14 +999,14 @@ class Linear(TransformerEngineBaseModule):
size of each input sample.
size of each input sample.
out_features : int
out_features : int
size of each output sample.
size of each output sample.
bias : bool, default =
`
True
`
bias : bool, default = True
if set to `False`, the layer will not learn an additive bias.
if set to
`
`False`
`
, the layer will not learn an additive bias.
init_method : Callable, default =
`
None
`
init_method : Callable, default = None
used for initializing weights in the following way: `init_method(weight)`.
used for initializing weights in the following way:
`
`init_method(weight)`
`
.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
When set to
`
`None`
`
, defaults to
`
`torch.nn.init.normal_(mean=0.0, std=0.023)`
`
.
get_rng_state_tracker : Callable, default =
`
None
`
get_rng_state_tracker : Callable, default = None
used to get the random number generator state tracker for initializing weights.
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default =
`
None
`
rng_tracker_name : str, default = None
the param passed to get_rng_state_tracker to get the specific rng tracker.
the param passed to get_rng_state_tracker to get the specific rng tracker.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
Configuration for splitting the weight and bias tensors along dim 0 into
Configuration for splitting the weight and bias tensors along dim 0 into
...
@@ -1044,62 +1014,62 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1044,62 +1014,62 @@ class Linear(TransformerEngineBaseModule):
they are used to make the names of equally-sized parameters. If a dict
they are used to make the names of equally-sized parameters. If a dict
(preferably an OrderedDict) is provided, the keys are used as names and
(preferably an OrderedDict) is provided, the keys are used as names and
values as split sizes along dim 0. The resulting parameters will have
values as split sizes along dim 0. The resulting parameters will have
names that end in `_weight` or `_bias`, so trailing underscores are
names that end in
`
`_weight`
`
or
`
`_bias`
`
, so trailing underscores are
stripped from any provided names.
stripped from any provided names.
device : Union[torch.device, str], default = "cuda"
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
forward pass.
name: str, default =
`
None
`
name
: str, default = None
name of the module, currently used for debugging purposes.
name of the module, currently used for debugging purposes.
Parallelism parameters
Parallelism parameters
----------------------
----------------------
sequence_parallel : bool, default =
`
False
`
sequence_parallel : bool, default = False
if set to `True`, uses sequence parallelism.
if set to
`
`True`
`
, uses sequence parallelism.
tp_group : ProcessGroup, default =
`
None
`
tp_group : ProcessGroup, default = None
tensor parallel process group.
tensor parallel process group.
tp_size : int, default = 1
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
`
`set_tensor_parallel_group(tp_group)`
`
method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel collectives.
parallel_mode : {None, 'column', 'row'}, default =
`
None
`
parallel_mode : {None, 'column', 'row'}, default = None
used to decide whether this Linear layer is Column Parallel Linear or Row
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
When set to
`
`None`
`
, no communication is performed.
Optimization parameters
Optimization parameters
-----------------------
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
if set to
`
`True`
`
, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
have an additional
`
`main_grad`
`
attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
regular
`
`grad`
`
) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
will overwrite
`
`main_grad`
`
instead of accumulating.
return_bias : bool, default =
`
False
`
return_bias : bool, default = False
when set to `True`, this module will not apply the additive bias itself, but
when set to
`
`True`
`
, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default =
`
torch.get_default_dtype()
`
params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
would not fit in GPU memory.
delay_wgrad_compute : bool, default =
`
False
`
delay_wgrad_compute : bool, default = False
Whether or not to delay weight gradient computation. If set to `True`,
Whether or not to delay weight gradient computation. If set to
`
`True`
`
,
it's the user's responsibility to call `module.backward_dw` to compute
it's the user's responsibility to call
`
`module.backward_dw`
`
to compute
weight gradients.
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations.
This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
Requires PyTorch version 2.7.0 or higher. When set to
``
None
``
, standard all-reduce
is used.
is used.
save_original_input : bool, default =
`
False
`
save_original_input : bool, default = False
If set to `True`, always saves the original input tensor rather than the
If set to
`
`True`
`
, always saves the original input tensor rather than the
cast tensor. In some scenarios, the input tensor is used by multiple modules,
cast tensor. In some scenarios, the input tensor is used by multiple modules,
and saving the original input tensor may reduce the memory usage.
and saving the original input tensor may reduce the memory usage.
Cannot work with FP8 DelayedScaling recipe.
Cannot work with FP8 DelayedScaling recipe.
...
@@ -1348,15 +1318,12 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1348,15 +1318,12 @@ class Linear(TransformerEngineBaseModule):
"""Init scales and amaxes for fwd | bwd."""
"""Init scales and amaxes for fwd | bwd."""
super
().
set_meta_tensor
(
fwd
,
recipe
)
super
().
set_meta_tensor
(
fwd
,
recipe
)
#
customize quantizers based on each recipe & layer configs
#
Recipe-specific quantizer configuration
recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
recipe
.
float8_current_scaling
():
if
recipe
.
float8_current_scaling
():
self
.
_customize_quantizers_float8_current_scaling
(
fwd
,
recipe
)
self
.
_customize_quantizers_float8_current_scaling
(
fwd
,
recipe
)
elif
recipe
.
float8_block_scaling
():
self
.
_customize_quantizers_float8_blockwise_scaling
(
fwd
,
recipe
)
elif
recipe
.
nvfp4
():
elif
recipe
.
nvfp4
():
self
.
_customize_quantizers_nvfp4
(
fwd
,
recipe
)
self
.
_customize_quantizers_nvfp4
(
fwd
,
recipe
)
# elif for other recipes (mxfp8, etc.)
def
reset_parameters
(
self
,
defer_init
=
False
):
def
reset_parameters
(
self
,
defer_init
=
False
):
super
().
reset_parameters
(
defer_init
=
defer_init
)
super
().
reset_parameters
(
defer_init
=
defer_init
)
...
@@ -1408,8 +1375,10 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1408,8 +1375,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
first microbatch (since it is the first gradient being
produced)
produced)
"""
"""
is_grad_enabled
=
torch
.
is_grad_enabled
()
if
is_in_onnx_export_mode
():
if
is_in_onnx_export_mode
():
return
self
.
onnx_forward
(
inp
,
fp8_output
)
return
self
.
onnx_forward
(
inp
,
fp8_output
,
is_grad_enabled
)
debug
=
self
.
is_debug_iter
()
debug
=
self
.
is_debug_iter
()
...
@@ -1431,9 +1400,7 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1431,9 +1400,7 @@ class Linear(TransformerEngineBaseModule):
).
is_fp8_ubuf
():
).
is_fp8_ubuf
():
fp8_grad
=
True
fp8_grad
=
True
with
torch
.
cuda
.
device
(
with
self
.
prepare_forward
(
getattr
(
self
,
list
(
self
.
named_parameters
())[
0
][
0
]).
device
),
self
.
prepare_forward
(
inp
,
inp
,
allow_non_contiguous
=
isinstance
(
inp
,
QuantizedTensor
),
allow_non_contiguous
=
isinstance
(
inp
,
QuantizedTensor
),
)
as
inp
:
)
as
inp
:
...
@@ -1441,14 +1408,14 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1441,14 +1408,14 @@ class Linear(TransformerEngineBaseModule):
weight_tensor
,
bias_tensor
=
self
.
_get_weight_and_bias_tensors
()
weight_tensor
,
bias_tensor
=
self
.
_get_weight_and_bias_tensors
()
quantizers
=
(
quantizers
=
(
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
)
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
if
not
debug
if
not
debug
else
self
.
_get_debug_quantizers
(
fp8_output
,
fp8_grad
)
else
self
.
_get_debug_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
)
)
if
debug
:
if
debug
:
if
self
.
no_debug_features_active
(
quantizers
):
if
self
.
no_debug_features_active
(
quantizers
):
debug
=
False
debug
=
False
quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
)
quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
(
(
input_quantizer
,
input_quantizer
,
...
@@ -1459,16 +1426,14 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1459,16 +1426,14 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer
,
grad_output_quantizer
,
)
=
quantizers
)
=
quantizers
if
torch
.
is_grad_enabled
()
:
if
is_grad_enabled
:
linear_fn
=
_Linear
.
apply
linear_fn
=
_Linear
.
apply
a
rgs
=
[]
a
utograd_ctx
=
[]
else
:
else
:
linear_fn
=
_Linear
.
forward
linear_fn
=
_Linear
.
forward
args
=
[
None
]
autograd_ctx
=
[
None
]
args
+=
(
weight_tensor
,
non_tensor_args
=
(
inp
,
bias_tensor
if
(
self
.
apply_bias
and
not
self
.
gemm_bias_unfused_add
)
else
None
,
is_first_microbatch
,
is_first_microbatch
,
self
.
fp8
,
self
.
fp8
,
self
.
fp8_calibration
,
self
.
fp8_calibration
,
...
@@ -1487,7 +1452,7 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1487,7 +1452,7 @@ class Linear(TransformerEngineBaseModule):
self
.
tp_size
>
1
,
self
.
tp_size
>
1
,
self
.
activation_dtype
,
self
.
activation_dtype
,
self
.
parallel_mode
,
self
.
parallel_mode
,
torch
.
is_grad_enabled
()
,
is_grad_enabled
,
self
.
ub_overlap_rs_fprop
,
self
.
ub_overlap_rs_fprop
,
self
.
ub_overlap_ag_dgrad
,
self
.
ub_overlap_ag_dgrad
,
self
.
ub_overlap_ag_fprop
,
self
.
ub_overlap_ag_fprop
,
...
@@ -1503,7 +1468,13 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1503,7 +1468,13 @@ class Linear(TransformerEngineBaseModule):
self
.
save_original_input
,
self
.
save_original_input
,
debug
,
debug
,
)
)
out
=
linear_fn
(
*
args
)
out
=
linear_fn
(
*
autograd_ctx
,
weight_tensor
,
inp
,
bias_tensor
if
(
self
.
apply_bias
and
not
self
.
gemm_bias_unfused_add
)
else
None
,
non_tensor_args
,
)
if
self
.
gemm_bias_unfused_add
:
if
self
.
gemm_bias_unfused_add
:
out
=
out
+
cast_if_needed
(
bias_tensor
,
self
.
activation_dtype
)
out
=
out
+
cast_if_needed
(
bias_tensor
,
self
.
activation_dtype
)
...
@@ -1511,7 +1482,7 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1511,7 +1482,7 @@ class Linear(TransformerEngineBaseModule):
return
out
,
cast_if_needed
(
bias_tensor
,
self
.
activation_dtype
)
return
out
,
cast_if_needed
(
bias_tensor
,
self
.
activation_dtype
)
return
out
return
out
def
_get_quantizers
(
self
,
fp8_output
,
fp8_grad
):
def
_get_quantizers
(
self
,
fp8_output
,
fp8_grad
,
is_grad_enabled
):
if
not
self
.
fp8
:
if
not
self
.
fp8
:
return
[
None
]
*
6
return
[
None
]
*
6
grad_input_quantizer
=
None
grad_input_quantizer
=
None
...
@@ -1520,12 +1491,16 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1520,12 +1491,16 @@ class Linear(TransformerEngineBaseModule):
output_quantizer
=
None
output_quantizer
=
None
input_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
]
input_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
]
input_quantizer
.
internal
=
True
input_quantizer
.
internal
=
True
if
not
(
self
.
parallel_mode
==
"column"
and
self
.
sequence_parallel
):
input_quantizer
.
optimize_for_gemm
=
True
(
weight_quantizer
,)
=
self
.
_get_weight_quantizers
()
(
weight_quantizer
,)
=
self
.
_get_weight_quantizers
()
if
fp8_output
:
if
fp8_output
:
output_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_OUTPUT
]
output_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_OUTPUT
]
if
torch
.
is_grad_enabled
()
:
if
is_grad_enabled
:
grad_output_quantizer
=
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
]
grad_output_quantizer
=
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
]
grad_output_quantizer
.
internal
=
True
grad_output_quantizer
.
internal
=
True
if
not
(
self
.
parallel_mode
==
"row"
and
self
.
sequence_parallel
):
grad_output_quantizer
.
optimize_for_gemm
=
True
if
fp8_grad
:
if
fp8_grad
:
grad_input_quantizer
=
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_INPUT1
]
grad_input_quantizer
=
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_INPUT1
]
return
(
return
(
...
@@ -1537,8 +1512,8 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1537,8 +1512,8 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer
,
grad_output_quantizer
,
)
)
def
_get_debug_quantizers
(
self
,
fp8_output
,
fp8_grad
):
def
_get_debug_quantizers
(
self
,
fp8_output
,
fp8_grad
,
is_grad_enabled
):
original_quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
)
original_quantizers
=
self
.
_get_quantizers
(
fp8_output
,
fp8_grad
,
is_grad_enabled
)
assert
TEDebugState
.
debug_enabled
assert
TEDebugState
.
debug_enabled
from
...debug.pytorch.debug_quantization
import
DebugQuantizer
from
...debug.pytorch.debug_quantization
import
DebugQuantizer
...
@@ -1568,31 +1543,18 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1568,31 +1543,18 @@ class Linear(TransformerEngineBaseModule):
def
_get_weight_and_bias_tensors
(
self
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
def
_get_weight_and_bias_tensors
(
self
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
# Get concatenated weight and bias tensors
# Get concatenated weight and bias tensors
unfused_weights
=
self
.
_get_weight_tensors
()
unfused_weights
=
self
.
_get_weight_tensors
()
if
any
(
isinstance
(
w
,
QuantizedTensor
)
for
w
in
unfused_weights
):
if
self
.
fp8
:
if
len
(
unfused_weights
)
!=
1
:
raise
RuntimeError
(
"Splitting QuantizedTensor into multiple params is not supported"
)
else
:
warnings
.
warn
(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights
=
[
w
.
dequantize
()
for
w
in
unfused_weights
]
weight_tensor
=
noop_cat
(
unfused_weights
)
weight_tensor
=
noop_cat
(
unfused_weights
)
if
self
.
use_bias
:
if
self
.
use_bias
:
bias_tensor
=
noop_cat
([
getattr
(
self
,
name
)
for
name
in
self
.
bias_names
])
bias_tensor
=
noop_cat
([
getattr
(
self
,
name
)
for
name
in
self
.
bias_names
])
else
:
else
:
bias_tensor
=
None
bias_tensor
=
None
return
weight_tensor
,
bias_tensor
return
weight_tensor
,
bias_tensor
def
onnx_forward
(
def
onnx_forward
(
self
,
self
,
inp
:
torch
.
Tensor
,
inp
:
torch
.
Tensor
,
fp8_output
:
bool
,
fp8_output
:
bool
,
is_grad_enabled
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
ONNX-compatible version of the forward function that provides numerical equivalence
ONNX-compatible version of the forward function that provides numerical equivalence
...
@@ -1609,7 +1571,7 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1609,7 +1571,7 @@ class Linear(TransformerEngineBaseModule):
weight_quantizer
,
weight_quantizer
,
output_quantizer
,
output_quantizer
,
*
_
,
*
_
,
)
=
self
.
_get_quantizers
(
fp8_output
,
False
)
)
=
self
.
_get_quantizers
(
fp8_output
,
False
,
is_grad_enabled
)
inp_dtype
=
inp
.
dtype
inp_dtype
=
inp
.
dtype
if
input_quantizer
is
not
None
:
if
input_quantizer
is
not
None
:
...
@@ -1713,22 +1675,3 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1713,22 +1675,3 @@ class Linear(TransformerEngineBaseModule):
weight_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_WEIGHT
]
weight_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_WEIGHT
]
weight_quantizer
.
internal
=
True
weight_quantizer
.
internal
=
True
return
[
weight_quantizer
]
return
[
weight_quantizer
]
def
_customize_quantizers_float8_blockwise_scaling
(
self
,
fwd
:
bool
,
recipe
:
Recipe
)
->
None
:
"""Customize quantizers based on blockwise scaling recipe + linear."""
assert
(
recipe
.
float8_block_scaling
()
),
"blockwise scaling recipe quantizer customization here"
if
fwd
:
if
self
.
sequence_parallel
and
self
.
parallel_mode
==
"column"
:
# set compact for inp tensor X
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
all_gather_usage
=
True
else
:
if
self
.
sequence_parallel
and
self
.
parallel_mode
==
"row"
:
# set compact for grad_output tensor dY
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
all_gather_usage
=
True
transformer_engine/pytorch/module/rmsnorm.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -33,32 +33,29 @@ class RMSNorm(_RMSNormOp):
...
@@ -33,32 +33,29 @@ class RMSNorm(_RMSNormOp):
Parameters
Parameters
----------
----------
normalized_shape: int or iterable of int
normalized_shape
: int or iterable of int
Inner dimensions of input tensor
Inner dimensions of input tensor
eps : float, default = 1e-5
eps : float, default = 1e-5
A value added to the denominator for numerical stability
A value added to the denominator for numerical stability
device: torch.device, default = default CUDA device
device
: torch.device, default = default CUDA device
Tensor device
Tensor device
dtype: torch.dtype, default = default dtype
dtype
: torch.dtype, default = default dtype
Tensor datatype
Tensor datatype
zero_centered_gamma : bool, default =
'
False
'
zero_centered_gamma : bool, default = False
If `True`, the :math:`\gamma` parameter is initialized to zero
If
`
`True`
`
, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to
and the calculation changes to
.. math::
.. math::
y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
sm_margin: int, default = 0
sm_margin
: int, default = 0
Number of SMs to exclude when launching CUDA kernels. This
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
margin at each compute stage (``"forward"``, ``"backward"``,
"inference").
``"inference"``).
sequence_parallel : bool
Legacy
**Legacy parameter.** Set a bool attr named ``sequence_parallel`` in the parameters.
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
This is custom logic for Megatron-LM integration.
This is custom logic for Megatron-LM integration.
"""
"""
...
...
transformer_engine/pytorch/numerics_debug.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/onnx_extensions.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -356,7 +356,9 @@ def onnx_layernorm(
...
@@ -356,7 +356,9 @@ def onnx_layernorm(
)
)
if
normalization
==
"RMSNorm"
:
if
normalization
==
"RMSNorm"
:
ln_out
=
torch
.
nn
.
functional
.
rms_norm
(
inp
,
inp
.
shape
[
-
1
:],
ln_weight
,
eps
)
variance
=
inp
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
ln_out
=
inp
*
torch
.
rsqrt
(
variance
+
eps
)
ln_out
=
ln_out
*
ln_weight
else
:
else
:
ln_out
=
torch
.
nn
.
functional
.
layer_norm
(
ln_out
=
torch
.
nn
.
functional
.
layer_norm
(
inp
,
inp
.
shape
[
-
1
:],
ln_weight
,
layer_norm_bias
,
eps
inp
,
inp
.
shape
[
-
1
:],
ln_weight
,
layer_norm_bias
,
eps
...
...
transformer_engine/pytorch/ops/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/ops/_common.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -10,7 +10,7 @@ from typing import Optional
...
@@ -10,7 +10,7 @@ from typing import Optional
import
torch
import
torch
from
transformer_engine_torch
import
FP8TensorMeta
from
transformer_engine_torch
import
FP8TensorMeta
from
..
import
torch_version
from
..
torch_version
import
torch_version
from
..quantization
import
FP8GlobalStateManager
from
..quantization
import
FP8GlobalStateManager
from
..tensor.float8_tensor
import
Float8Tensor
from
..tensor.float8_tensor
import
Float8Tensor
from
..quantized_tensor
import
QuantizedTensorStorage
from
..quantized_tensor
import
QuantizedTensorStorage
...
...
transformer_engine/pytorch/ops/basic/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/ops/basic/activation.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -53,7 +53,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
...
@@ -53,7 +53,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
Parameters
Parameters
----------
----------
cache_quantized_input: bool, default = False
cache_quantized_input
: bool, default = False
Quantize input tensor when caching for use in the backward
Quantize input tensor when caching for use in the backward
pass. This will typically reduce memory usage but require
pass. This will typically reduce memory usage but require
extra compute and increase numerical error. This feature is
extra compute and increase numerical error. This feature is
...
@@ -408,11 +408,11 @@ class ClampedSwiGLU(_ActivationOperation):
...
@@ -408,11 +408,11 @@ class ClampedSwiGLU(_ActivationOperation):
Parameters
Parameters
----------
----------
limit: float
limit
: float
The clamp limit.
The clamp limit.
alpha: float
alpha
: float
The scaling factor for the sigmoid function used in the activation.
The scaling factor for the sigmoid function used in the activation.
cache_quantized_input: bool, default = False
cache_quantized_input
: bool, default = False
Quantize input tensor when caching for use in the backward pass.
Quantize input tensor when caching for use in the backward pass.
"""
"""
...
...
transformer_engine/pytorch/ops/basic/add_extra_input.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/ops/basic/all_gather.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -23,7 +23,7 @@ class AllGather(BasicOperation):
...
@@ -23,7 +23,7 @@ class AllGather(BasicOperation):
Parameters
Parameters
----------
----------
process_group: torch.distributed.ProcessGroup, default = world group
process_group
: torch.distributed.ProcessGroup, default = world group
Process group for communication
Process group for communication
"""
"""
...
...
transformer_engine/pytorch/ops/basic/all_reduce.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -24,7 +24,7 @@ class AllReduce(BasicOperation):
...
@@ -24,7 +24,7 @@ class AllReduce(BasicOperation):
Parameters
Parameters
----------
----------
process_group: torch.distributed.ProcessGroup, default = world group
process_group
: torch.distributed.ProcessGroup, default = world group
Process group for communication
Process group for communication
"""
"""
...
...
transformer_engine/pytorch/ops/basic/basic_linear.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -25,7 +25,6 @@ from ...module.base import (
...
@@ -25,7 +25,6 @@ from ...module.base import (
_2X_ACC_DGRAD
,
_2X_ACC_DGRAD
,
_2X_ACC_WGRAD
,
_2X_ACC_WGRAD
,
get_dummy_wgrad
,
get_dummy_wgrad
,
get_workspace
,
)
)
from
...tensor
import
Quantizer
from
...tensor
import
Quantizer
from
...tensor.float8_tensor
import
Float8Quantizer
from
...tensor.float8_tensor
import
Float8Quantizer
...
@@ -54,27 +53,27 @@ class BasicLinear(BasicOperation):
...
@@ -54,27 +53,27 @@ class BasicLinear(BasicOperation):
Parameters
Parameters
----------
----------
in_features: int
in_features
: int
Inner dimension of input tensor
Inner dimension of input tensor
out_features: int
out_features
: int
Inner dimension of output tensor
Inner dimension of output tensor
device: torch.device, default = default CUDA device
device
: torch.device, default = default CUDA device
Tensor device
Tensor device
dtype: torch.dtype, default = default dtype
dtype
: torch.dtype, default = default dtype
Tensor datatype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
tensor_parallel_mode
: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
tensor_parallel_group
: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
sequence_parallel
: bool, default = `False`
Whether to apply sequence parallelism together with tensor
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors along
parallelism, i.e. distributing input or output tensors along
outer dimension (sequence or batch dim) when not distributing
outer dimension (sequence or batch dim) when not distributing
along inner dimension (embedding dim)
along inner dimension (embedding dim)
rng_state_tracker_function: callable
rng_state_tracker_function
: callable
Function that returns `CudaRNGStatesTracker`, which is used
Function that returns `CudaRNGStatesTracker`, which is used
for model-parallel weight initialization
for model-parallel weight initialization
accumulate_into_main_grad: bool, default = `False`
accumulate_into_main_grad
: bool, default = `False`
Whether to directly accumulate weight gradients into the
Whether to directly accumulate weight gradients into the
weight's `main_grad` attribute instead of relying on PyTorch
weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and
autograd. The weight's `main_grad` must be set externally and
...
@@ -138,8 +137,10 @@ class BasicLinear(BasicOperation):
...
@@ -138,8 +137,10 @@ class BasicLinear(BasicOperation):
out_features
=
out_features
,
out_features
=
out_features
,
)
)
#
Whether weight tensor is
natively quantized
#
Initialize recipe state if needed for
natively quantized
weight
self
.
_with_quantized_weight
:
bool
=
FP8GlobalStateManager
.
with_fp8_parameters
()
self
.
_with_quantized_weight
:
bool
=
FP8GlobalStateManager
.
with_fp8_parameters
()
if
self
.
_with_quantized_weight
:
self
.
reset_recipe_state
(
recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
())
# Initialize parameters if needed
# Initialize parameters if needed
weight
=
torch
.
empty
(
weight
=
torch
.
empty
(
...
@@ -341,15 +342,21 @@ class BasicLinear(BasicOperation):
...
@@ -341,15 +342,21 @@ class BasicLinear(BasicOperation):
def
reset_recipe_state
(
self
,
*
,
recipe
:
Optional
[
Recipe
])
->
None
:
def
reset_recipe_state
(
self
,
*
,
recipe
:
Optional
[
Recipe
])
->
None
:
super
().
reset_recipe_state
(
recipe
=
recipe
)
super
().
reset_recipe_state
(
recipe
=
recipe
)
# Input/grad output quantizers use internal tensors
# Configure input/grad output tensor
# Note: These tensors are only used internally. If there is no
# tensor-parallel communication, they are only used for GEMM.
input_quantizer
=
self
.
get_quantizer
(
"forward"
,
0
)
input_quantizer
=
self
.
get_quantizer
(
"forward"
,
0
)
grad_output_quantizer
=
self
.
get_quantizer
(
"backward"
,
0
)
grad_output_quantizer
=
self
.
get_quantizer
(
"backward"
,
0
)
if
input_quantizer
is
not
None
:
if
input_quantizer
is
not
None
:
input_quantizer
.
internal
=
True
input_quantizer
.
internal
=
True
if
not
(
self
.
tensor_parallel_mode
==
"column"
and
self
.
sequence_parallel
):
input_quantizer
.
optimize_for_gemm
=
True
if
grad_output_quantizer
is
not
None
:
if
grad_output_quantizer
is
not
None
:
grad_output_quantizer
.
internal
=
True
grad_output_quantizer
.
internal
=
True
if
not
(
self
.
tensor_parallel_mode
==
"row"
and
self
.
sequence_parallel
):
grad_output_quantizer
.
optimize_for_gemm
=
True
#
Handl
e weight quantizer
#
Configur
e weight quantizer
# Note: This function may be called in base class constructor,
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
# before any basic linear attrs have been set.
weight_quantizer
=
self
.
get_quantizer
(
"forward"
,
1
)
weight_quantizer
=
self
.
get_quantizer
(
"forward"
,
1
)
...
@@ -585,7 +592,6 @@ class BasicLinear(BasicOperation):
...
@@ -585,7 +592,6 @@ class BasicLinear(BasicOperation):
y
,
*
_
=
general_gemm
(
y
,
*
_
=
general_gemm
(
w
,
w
,
x
,
x
,
get_workspace
(),
out_dtype
=
dtype
,
out_dtype
=
dtype
,
quantization_params
=
output_quantizer
,
quantization_params
=
output_quantizer
,
alpha
=
alpha
,
alpha
=
alpha
,
...
@@ -875,7 +881,6 @@ class BasicLinear(BasicOperation):
...
@@ -875,7 +881,6 @@ class BasicLinear(BasicOperation):
dx
,
*
_
=
general_gemm
(
dx
,
*
_
=
general_gemm
(
w
,
w
,
dy
,
dy
,
get_workspace
(),
out_dtype
=
dtype
,
out_dtype
=
dtype
,
quantization_params
=
grad_input_quantizer
,
quantization_params
=
grad_input_quantizer
,
alpha
=
grad_input_alpha
,
alpha
=
grad_input_alpha
,
...
@@ -928,7 +933,6 @@ class BasicLinear(BasicOperation):
...
@@ -928,7 +933,6 @@ class BasicLinear(BasicOperation):
dw
,
*
_
=
general_gemm
(
dw
,
*
_
=
general_gemm
(
x
,
x
,
dy
,
dy
,
get_workspace
(),
out_dtype
=
dw_dtype
,
out_dtype
=
dw_dtype
,
alpha
=
grad_weight_alpha
,
alpha
=
grad_weight_alpha
,
beta
=
grad_weight_beta
,
beta
=
grad_weight_beta
,
...
...
transformer_engine/pytorch/ops/basic/bias.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -22,16 +22,16 @@ class Bias(BasicOperation):
...
@@ -22,16 +22,16 @@ class Bias(BasicOperation):
Parameters
Parameters
----------
----------
size: int
size
: int
Inner dimension of input tensor
Inner dimension of input tensor
device: torch.device, default = default CUDA device
device
: torch.device, default = default CUDA device
Tensor device
Tensor device
dtype: torch.dtype, default = default dtype
dtype
: torch.dtype, default = default dtype
Tensor datatype
Tensor datatype
tensor_parallel: bool, default = `False`
tensor_parallel
: bool, default = `False`
Whether to distribute input tensor and bias tensors along
Whether to distribute input tensor and bias tensors along
inner dimension
inner dimension
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
tensor_parallel_group
: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
Process group for tensor parallelism
"""
"""
...
...
transformer_engine/pytorch/ops/basic/constant_scale.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/ops/basic/dropout.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/ops/basic/identity.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/ops/basic/l2normalization.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -10,7 +10,7 @@ import os
...
@@ -10,7 +10,7 @@ import os
import
torch
import
torch
from
...
import
torch_version
from
...
torch_version
import
torch_version
from
...cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
...cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
...jit
import
(
from
...jit
import
(
l2normalization_fused
,
l2normalization_fused
,
...
@@ -40,11 +40,11 @@ class L2Normalization(BasicOperation):
...
@@ -40,11 +40,11 @@ class L2Normalization(BasicOperation):
----------
----------
eps : float, default = 1e-6
eps : float, default = 1e-6
A value added to the denominator for numerical stability
A value added to the denominator for numerical stability
seq_length: int, default = None
seq_length
: int, default = None
sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
functions are warmed up before training to ensure same kernels are used for forward
functions are warmed up before training to ensure same kernels are used for forward
propagation and activation recompute phase.
propagation and activation recompute phase.
micro_batch_size: int, default = None
micro_batch_size
: int, default = None
batch size per training step. Needed for JIT Warmup, a technique where jit
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase.
used for forward propagation and activation recompute phase.
...
...
transformer_engine/pytorch/ops/basic/layer_norm.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -42,14 +42,14 @@ class LayerNorm(BasicOperation):
...
@@ -42,14 +42,14 @@ class LayerNorm(BasicOperation):
Parameters
Parameters
----------
----------
normalized_shape: int or iterable of int
normalized_shape
: int or iterable of int
Inner dimensions of input tensor
Inner dimensions of input tensor
eps : float, default = 1e-5
eps : float, default = 1e-5
A value added to the denominator of layer normalization for
A value added to the denominator of layer normalization for
numerical stability
numerical stability
device: torch.device, default = default CUDA device
device
: torch.device, default = default CUDA device
Tensor device
Tensor device
dtype: torch.dtype, default = default dtype
dtype
: torch.dtype, default = default dtype
Tensor datatype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
zero_centered_gamma : bool, default = 'False'
If `True`, the :math:`\gamma` parameter is initialized to zero
If `True`, the :math:`\gamma` parameter is initialized to zero
...
@@ -58,7 +58,7 @@ class LayerNorm(BasicOperation):
...
@@ -58,7 +58,7 @@ class LayerNorm(BasicOperation):
.. math::
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
sm_margin: int or dict, default = 0
sm_margin
: int or dict, default = 0
Number of SMs to exclude when launching CUDA kernels. This
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
For more fine-grained control, provide a dict with the SM
...
...
transformer_engine/pytorch/ops/basic/make_extra_output.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/ops/basic/quantize.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -23,9 +23,9 @@ class Quantize(BasicOperation):
...
@@ -23,9 +23,9 @@ class Quantize(BasicOperation):
Parameters
Parameters
----------
----------
forward: bool, default = `True`
forward
: bool, default = `True`
Perform quantization in forward pass
Perform quantization in forward pass
backward: bool, default = `False`
backward
: bool, default = `False`
Perform quantization in backward pass
Perform quantization in backward pass
"""
"""
...
...
Prev
1
…
26
27
28
29
30
31
32
33
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment