OpenDAS / TransformerEngine · Commit 0d874a4e

Merge branch 'nv_main' of v2.12
Authored Mar 03, 2026 by wenjh
Parents: a68e5f87, dfdd3820

The commit touches 640+ files in total; this page shows 20 changed files with 71 additions and 70 deletions (+71 / -70).
Files shown on this page:

  transformer_engine/pytorch/ops/basic/reduce_scatter.py                  +2  -2
  transformer_engine/pytorch/ops/basic/reshape.py                         +2  -2
  transformer_engine/pytorch/ops/basic/rmsnorm.py                         +8  -6
  transformer_engine/pytorch/ops/fused/__init__.py                        +1  -1
  transformer_engine/pytorch/ops/fused/backward_activation_bias.py        +4  -4
  transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py            +3  -3
  transformer_engine/pytorch/ops/fused/backward_linear_add.py             +3  -3
  transformer_engine/pytorch/ops/fused/backward_linear_scale.py           +3  -3
  transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py  +3  -3
  transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py         +3  -3
  transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py        +3  -3
  transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py     +4  -6
  transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py      +3  -5
  transformer_engine/pytorch/ops/fuser.py                                 +2  -2
  transformer_engine/pytorch/ops/linear.py                                +11 -11
  transformer_engine/pytorch/ops/op.py                                    +2  -5
  transformer_engine/pytorch/ops/sequential.py                            +1  -1
  transformer_engine/pytorch/optimizers/__init__.py                       +1  -1
  transformer_engine/pytorch/optimizers/fused_adam.py                     +11 -5
  transformer_engine/pytorch/optimizers/fused_sgd.py                      +1  -1
transformer_engine/pytorch/ops/basic/reduce_scatter.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -23,7 +23,7 @@ class ReduceScatter(BasicOperation):
    Parameters
    ----------
-    process_group: torch.distributed.ProcessGroup, default = world group
+    process_group : torch.distributed.ProcessGroup, default = world group
        Process group for communication
    """
...
...
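Aside from the copyright-year bump, the recurring one-line docstring edits in this and the following files appear to be whitespace-only: a space is inserted before the colon of each parameter entry so it follows the numpydoc "name : type" convention already used by unchanged entries such as `eps : float` in rmsnorm.py. A minimal sketch of the target style (hypothetical function, not part of this commit):

# Hypothetical illustration of the numpydoc parameter style; only the docstring
# formatting is the point here, the body is a stub.
def reduce_scatter_example(tensor, process_group=None):
    """Reduce-scatter a tensor across a process group.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor.
    process_group : torch.distributed.ProcessGroup, default = world group
        Process group for communication.
    """
    raise NotImplementedError("docstring-style illustration only")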
transformer_engine/pytorch/ops/basic/reshape.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -24,7 +24,7 @@ class Reshape(BasicOperation):
    Parameters
    ----------
-    shape: iterable of int
+    shape : iterable of int
        Output tensor dimensions. If one dimension is -1, it is
        inferred based on input tensor dimensions.
...
...
transformer_engine/pytorch/ops/basic/rmsnorm.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -42,13 +42,13 @@ class RMSNorm(BasicOperation):
    Parameters
    ----------
-    normalized_shape: int or iterable of int
+    normalized_shape : int or iterable of int
        Inner dimensions of input tensor
    eps : float, default = 1e-5
        A value added to the denominator for numerical stability
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
        Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
        Tensor datatype
    zero_centered_gamma : bool, default = 'False'
        If `True`, the :math:`\gamma` parameter is initialized to zero
...
...
@@ -57,7 +57,7 @@ class RMSNorm(BasicOperation):
    .. math::
        y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
-    sm_margin: int, default = 0
+    sm_margin : int, default = 0
        Number of SMs to exclude when launching CUDA kernels. This
        helps overlap with other kernels, e.g. communication kernels.
        For more fine-grained control, provide a dict with the SM
...
...
@@ -249,4 +249,6 @@ class RMSNorm(BasicOperation):
    ) -> torch.Tensor:
        """Every operand in this function has a defined ONNX translation."""
        weight = self.weight + 1 if self.zero_centered_gamma else self.weight
-        return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps)
+        variance = input_.pow(2).mean(-1, keepdim=True)
+        normalized = input_ * torch.rsqrt(variance + self.eps)
+        return normalized * weight
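The hunk above swaps the single `torch.nn.functional.rms_norm` call in the ONNX-export forward for an explicit decomposition built only from operators with defined ONNX translations. A quick numerical check of that equivalence, as a hedged sketch (shapes, tolerances, and the availability of `torch.nn.functional.rms_norm` in the local PyTorch build are assumptions):

# Sketch: the explicit decomposition used for ONNX export should match the fused
# rms_norm op it replaces (shown here without the zero_centered_gamma "+ 1" adjustment).
import torch

x = torch.randn(4, 8, 64)      # (batch, seq, hidden) -- illustrative shape
weight = torch.randn(64)
eps = 1e-5

# Reference: fused op (requires a PyTorch release that provides F.rms_norm)
ref = torch.nn.functional.rms_norm(x, x.shape[-1:], weight, eps)

# Decomposition from the new export path: mean of squares -> rsqrt -> scale by weight
variance = x.pow(2).mean(-1, keepdim=True)
out = x * torch.rsqrt(variance + eps) * weight

torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-5)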
transformer_engine/pytorch/ops/fused/__init__.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/pytorch/ops/fused/backward_activation_bias.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -90,15 +90,15 @@ def fuse_backward_activation_bias(
    Parameters
    ----------
-    ops: list of tuples
+    ops : list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.
-    recipe: Recipe, optional
+    recipe : Recipe, optional
        Used quantization recipe
    Returns
    -------
-    ops: list of tuples
+    ops : list of tuples
        Updated backward pass operations
    """
...
...
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -87,13 +87,13 @@ def fuse_backward_add_rmsnorm(
    Parameters
    ----------
-    ops: list of tuples
+    ops : list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.
    Returns
    -------
-    ops: list of tuples
+    ops : list of tuples
        Updated backward pass operations
    """
...
...
transformer_engine/pytorch/ops/fused/backward_linear_add.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -119,13 +119,13 @@ def fuse_backward_linear_add(
    Parameters
    ----------
-    ops: list of tuples
+    ops : list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.
    Returns
    -------
-    ops: list of tuples
+    ops : list of tuples
        Updated backward pass operations
    """
...
...
transformer_engine/pytorch/ops/fused/backward_linear_scale.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -119,13 +119,13 @@ def fuse_backward_linear_scale(
    Parameters
    ----------
-    ops: list of tuples
+    ops : list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.
    Returns
    -------
-    ops: list of tuples
+    ops : list of tuples
        Updated backward pass operations
    """
...
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -142,13 +142,13 @@ def fuse_forward_linear_bias_activation(
    Parameters
    ----------
-    ops: list of tuples
+    ops : list of tuples
        Forward pass operations and the indices of the corresponding
        basic operations.
    Returns
    -------
-    ops: list of tuples
+    ops : list of tuples
        Updated forward pass operations
    """
...
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -139,13 +139,13 @@ def fuse_forward_linear_bias_add(
    Parameters
    ----------
-    ops: list of tuples
+    ops : list of tuples
        Forward pass operations and the indices of the corresponding
        basic operations.
    Returns
    -------
-    ops: list of tuples
+    ops : list of tuples
        Updated forward pass operations
    """
...
...
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -118,13 +118,13 @@ def fuse_forward_linear_scale_add(
    Parameters
    ----------
-    ops: list of tuples
+    ops : list of tuples
        Forward pass operations and the indices of the corresponding
        basic operations.
    Returns
    -------
-    ops: list of tuples
+    ops : list of tuples
        Updated forward pass operations
    """
...
...
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -19,7 +19,6 @@ from ...module.base import (
    fill_userbuffers_buffer_for_all_gather,
    get_dummy_wgrad,
    get_ub,
    get_workspace,
)
from ...quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
...
...
@@ -293,6 +292,7 @@ class UserbuffersBackwardLinear(FusedOperation):
                rowwise=True,
                columnwise=with_columnwise,
            )
            grad_output_quantizer.optimize_for_gemm = False
            dy_local = grad_output_quantizer(dy_local)
        else:
            dy_local = maybe_dequantize(dy_local, dtype)
...
...
@@ -378,7 +378,6 @@ class UserbuffersBackwardLinear(FusedOperation):
        dx, *_ = general_gemm(
            w,
            dy,
            get_workspace(),
            out_dtype=dtype,
            quantization_params=grad_input_quantizer,
            layout="NN",
...
...
@@ -464,7 +463,6 @@ class UserbuffersBackwardLinear(FusedOperation):
        dw, *_ = general_gemm(
            x,
            dy,
            get_workspace(),
            out_dtype=dw_dtype,
            accumulate=accumulate_into_grad_weight,
            layout="NT",
...
...
@@ -592,13 +590,13 @@ def fuse_userbuffers_backward_linear(
    Parameters
    ----------
-    ops: list of tuples
+    ops : list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.
    Returns
    -------
-    ops: list of tuples
+    ops : list of tuples
        Updated backward pass operations
    """
...
...
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -18,7 +18,6 @@ from ...quantization import FP8GlobalStateManager
from ...module.base import (
    fill_userbuffers_buffer_for_all_gather,
    get_ub,
    get_workspace,
    _2X_ACC_FPROP,
)
from ...quantized_tensor import Quantizer
...
...
@@ -243,7 +242,6 @@ class UserbuffersForwardLinear(FusedOperation):
        gemm_output, *_, reduce_scatter_output = general_gemm(
            w,
            x,
            get_workspace(),
            out_dtype=dtype,
            quantization_params=output_quantizer,
            bias=bias,
...
...
@@ -379,13 +377,13 @@ def fuse_userbuffers_forward_linear(
    Parameters
    ----------
-    ops: list of tuples
+    ops : list of tuples
        Forward pass operations and the indices of the corresponding
        basic operations.
    Returns
    -------
-    ops: list of tuples
+    ops : list of tuples
        Updated forward pass operations
    """
...
...
transformer_engine/pytorch/ops/fuser.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -310,7 +310,7 @@ class OperationFuser:
    Parameters
    ----------
-    ops: list of FusibleOperation
+    ops : list of FusibleOperation
        Pipeline of operations
    """
...
...
transformer_engine/pytorch/ops/linear.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -27,29 +27,29 @@ class Linear(FusedOperation):
    Parameters
    ----------
-    in_features: int
+    in_features : int
        Inner dimension of input tensor
-    out_features: int
+    out_features : int
        Inner dimension of output tensor
-    bias: bool, default = `True`
+    bias : bool, default = `True`
        Apply additive bias
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
        Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
        Tensor datatype
-    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
+    tensor_parallel_mode : {`None`, "column", "row"}, default = `None`
        Mode for tensor parallelism
-    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
+    tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
        Process group for tensor parallelism
-    sequence_parallel: bool, default = `False`
+    sequence_parallel : bool, default = `False`
        Whether to apply sequence parallelism together with tensor
        parallelism, i.e. distributing input or output tensors along
        outer dimension (sequence or batch dim) when not distributing
        along inner dimension (embedding dim)
-    rng_state_tracker_function: callable
+    rng_state_tracker_function : callable
        Function that returns CudaRNGStatesTracker, which is used for
        model-parallel weight initialization
-    accumulate_into_main_grad: bool, default = `False`
+    accumulate_into_main_grad : bool, default = `False`
        Whether to directly accumulate weight gradients into the
        weight's `main_grad` attribute instead of relying on PyTorch
        autograd. The weight's `main_grad` must be set externally and
...
...
transformer_engine/pytorch/ops/op.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -188,9 +188,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
        # Objects for quantization
        self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
        self._quantizers: Optional[dict[str, list[Quantizer]]] = None
-        with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
-        self.reset_recipe_state(recipe=recipe)

    @property
    def is_fused_op(self) -> bool:
...
...
@@ -687,7 +684,7 @@ class FusedOperation(FusibleOperation):
    Parameters
    ----------
-    basic_ops: iterable of FusibleOperation
+    basic_ops : iterable of FusibleOperation
        Basic ops that are interchangeable with this op
    """
...
...
transformer_engine/pytorch/ops/sequential.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/pytorch/optimizers/__init__.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/pytorch/optimizers/fused_adam.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -11,8 +11,10 @@ from typing import Optional
import warnings
import torch
+from torch.distributed._tensor import DTensor
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
+from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from .multi_tensor_apply import multi_tensor_applier
from torch.utils.cpp_extension import IS_HIP_EXTENSION
...
...
@@ -371,10 +373,12 @@ class FusedAdam(torch.optim.Optimizer):
            store_param_remainders (bool): Store only trailing remainder bits.
        """
        dtype = self.name_to_dtype_map[state_name]
+        # Handle QuantizedTensor by dequantizing first
+        param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
        if store_param_remainders:
-            data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
+            data = torch.zeros_like(param_for_empty, dtype=torch.int16)
        else:
-            data = torch.empty(param.shape, dtype=dtype, device=param.device)
+            data = torch.empty_like(param_for_empty, dtype=dtype)
        if zero_buffer:
            data.zero_()
...
...
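The state-initialization hunk above dequantizes quantized parameters before calling `zeros_like`/`empty_like`, so optimizer state is allocated as a plain dense tensor with the parameter's shape rather than as a quantized tensor subclass. A self-contained sketch of that pattern, with a placeholder `QuantizedTensor` so it also runs without Transformer Engine installed (the helper name `allocate_state` is hypothetical):

import torch

try:
    from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
except ImportError:  # let the sketch run without Transformer Engine installed
    class QuantizedTensor:  # placeholder; isinstance() below is then always False
        pass

def allocate_state(param, dtype, store_param_remainders=False, zero_buffer=True):
    """Allocate an optimizer state buffer for ``param`` (sketch of the new logic)."""
    # Dequantize quantized parameters first, so zeros_like/empty_like return a
    # plain dense tensor with the parameter's shape and device.
    param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
    if store_param_remainders:
        # Trailing-remainder storage uses int16.
        data = torch.zeros_like(param_for_empty, dtype=torch.int16)
    else:
        data = torch.empty_like(param_for_empty, dtype=dtype)
    if zero_buffer:
        data.zero_()
    return data

# Plain tensors pass straight through; a quantized parameter would be dequantized first.
exp_avg = allocate_state(torch.randn(16, 16), dtype=torch.float32)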
@@ -567,8 +571,10 @@ class FusedAdam(torch.optim.Optimizer):
                    unscaled_lists[name].append(unscaled)
                    scaled_lists[name].append(state[name])
                    state_scales[name].append(self._scales[p][name])
-            if isinstance(p, Float8Tensor):
+            if isinstance(p, Float8Tensor) or (
+                isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor)
+            ):
+                p = p._local_tensor if isinstance(p, DTensor) else p
                out_dtype = p._fp8_dtype
                p_fp8_model.append(p._data.data)
                scale, amax, scale_inv = get_fp8_meta(p)
...
...
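The second FusedAdam hunk widens the FP8 branch so that a `DTensor` whose local shard is a `Float8Tensor` is also routed through the FP8 kernel path, after unwrapping `_local_tensor`. A hedged sketch of just that dispatch (the helper name is hypothetical, and a placeholder class stands in for Transformer Engine's `Float8Tensor` so the snippet is importable on its own):

import torch
from torch.distributed._tensor import DTensor

try:
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
except ImportError:  # placeholder so the sketch runs without Transformer Engine
    class Float8Tensor:
        pass

def local_fp8_param(p):
    """Return the local Float8Tensor for ``p`` if it is FP8 (possibly wrapped in a
    DTensor), otherwise None -- mirrors the new branch condition in FusedAdam."""
    if isinstance(p, Float8Tensor) or (
        isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor)
    ):
        # DTensor keeps its shard in ``_local_tensor``; plain FP8 tensors pass through.
        return p._local_tensor if isinstance(p, DTensor) else p
    return None

# A regular dense parameter is not FP8, so it stays on the non-FP8 path.
assert local_fp8_param(torch.randn(8)) is None

Unwrapping first lets the following lines in the hunk (`p._fp8_dtype`, `p._data.data`, `get_fp8_meta(p)`) operate on the local FP8 shard exactly as they previously did for a plain Float8Tensor.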
transformer_engine/pytorch/optimizers/fused_sgd.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...