Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
646
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
71 additions
and
70 deletions
+71
-70
transformer_engine/pytorch/ops/basic/reduce_scatter.py
transformer_engine/pytorch/ops/basic/reduce_scatter.py
+2
-2
transformer_engine/pytorch/ops/basic/reshape.py
transformer_engine/pytorch/ops/basic/reshape.py
+2
-2
transformer_engine/pytorch/ops/basic/rmsnorm.py
transformer_engine/pytorch/ops/basic/rmsnorm.py
+8
-6
transformer_engine/pytorch/ops/fused/__init__.py
transformer_engine/pytorch/ops/fused/__init__.py
+1
-1
transformer_engine/pytorch/ops/fused/backward_activation_bias.py
...rmer_engine/pytorch/ops/fused/backward_activation_bias.py
+4
-4
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py
+3
-3
transformer_engine/pytorch/ops/fused/backward_linear_add.py
transformer_engine/pytorch/ops/fused/backward_linear_add.py
+3
-3
transformer_engine/pytorch/ops/fused/backward_linear_scale.py
...sformer_engine/pytorch/ops/fused/backward_linear_scale.py
+3
-3
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
...ngine/pytorch/ops/fused/forward_linear_bias_activation.py
+3
-3
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
...ormer_engine/pytorch/ops/fused/forward_linear_bias_add.py
+3
-3
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
...rmer_engine/pytorch/ops/fused/forward_linear_scale_add.py
+3
-3
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
...r_engine/pytorch/ops/fused/userbuffers_backward_linear.py
+4
-6
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
...er_engine/pytorch/ops/fused/userbuffers_forward_linear.py
+3
-5
transformer_engine/pytorch/ops/fuser.py
transformer_engine/pytorch/ops/fuser.py
+2
-2
transformer_engine/pytorch/ops/linear.py
transformer_engine/pytorch/ops/linear.py
+11
-11
transformer_engine/pytorch/ops/op.py
transformer_engine/pytorch/ops/op.py
+2
-5
transformer_engine/pytorch/ops/sequential.py
transformer_engine/pytorch/ops/sequential.py
+1
-1
transformer_engine/pytorch/optimizers/__init__.py
transformer_engine/pytorch/optimizers/__init__.py
+1
-1
transformer_engine/pytorch/optimizers/fused_adam.py
transformer_engine/pytorch/optimizers/fused_adam.py
+11
-5
transformer_engine/pytorch/optimizers/fused_sgd.py
transformer_engine/pytorch/optimizers/fused_sgd.py
+1
-1
No files found.
transformer_engine/pytorch/ops/basic/reduce_scatter.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -23,7 +23,7 @@ class ReduceScatter(BasicOperation):
...
@@ -23,7 +23,7 @@ class ReduceScatter(BasicOperation):
Parameters
Parameters
----------
----------
process_group: torch.distributed.ProcessGroup, default = world group
process_group
: torch.distributed.ProcessGroup, default = world group
Process group for communication
Process group for communication
"""
"""
...
...
transformer_engine/pytorch/ops/basic/reshape.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -24,7 +24,7 @@ class Reshape(BasicOperation):
...
@@ -24,7 +24,7 @@ class Reshape(BasicOperation):
Parameters
Parameters
----------
----------
shape: iterable of int
shape
: iterable of int
Output tensor dimensions. If one dimension is -1, it is
Output tensor dimensions. If one dimension is -1, it is
inferred based on input tensor dimensions.
inferred based on input tensor dimensions.
...
...
transformer_engine/pytorch/ops/basic/rmsnorm.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -42,13 +42,13 @@ class RMSNorm(BasicOperation):
...
@@ -42,13 +42,13 @@ class RMSNorm(BasicOperation):
Parameters
Parameters
----------
----------
normalized_shape: int or iterable of int
normalized_shape
: int or iterable of int
Inner dimensions of input tensor
Inner dimensions of input tensor
eps : float, default = 1e-5
eps : float, default = 1e-5
A value added to the denominator for numerical stability
A value added to the denominator for numerical stability
device: torch.device, default = default CUDA device
device
: torch.device, default = default CUDA device
Tensor device
Tensor device
dtype: torch.dtype, default = default dtype
dtype
: torch.dtype, default = default dtype
Tensor datatype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
zero_centered_gamma : bool, default = 'False'
If `True`, the :math:`\gamma` parameter is initialized to zero
If `True`, the :math:`\gamma` parameter is initialized to zero
...
@@ -57,7 +57,7 @@ class RMSNorm(BasicOperation):
...
@@ -57,7 +57,7 @@ class RMSNorm(BasicOperation):
.. math::
.. math::
y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
sm_margin: int, default = 0
sm_margin
: int, default = 0
Number of SMs to exclude when launching CUDA kernels. This
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
For more fine-grained control, provide a dict with the SM
...
@@ -249,4 +249,6 @@ class RMSNorm(BasicOperation):
...
@@ -249,4 +249,6 @@ class RMSNorm(BasicOperation):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Every operand in this function has a defined ONNX translation."""
"""Every operand in this function has a defined ONNX translation."""
weight
=
self
.
weight
+
1
if
self
.
zero_centered_gamma
else
self
.
weight
weight
=
self
.
weight
+
1
if
self
.
zero_centered_gamma
else
self
.
weight
return
torch
.
nn
.
functional
.
rms_norm
(
input_
,
input_
.
shape
[
-
1
:],
weight
,
self
.
eps
)
variance
=
input_
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
normalized
=
input_
*
torch
.
rsqrt
(
variance
+
self
.
eps
)
return
normalized
*
weight
transformer_engine/pytorch/ops/fused/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/ops/fused/backward_activation_bias.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -90,15 +90,15 @@ def fuse_backward_activation_bias(
...
@@ -90,15 +90,15 @@ def fuse_backward_activation_bias(
Parameters
Parameters
----------
----------
ops: list of tuples
ops
: list of tuples
Backward pass operations and the indices of the corresponding
Backward pass operations and the indices of the corresponding
basic operations.
basic operations.
recipe: Recipe, optional
recipe
: Recipe, optional
Used quantization recipe
Used quantization recipe
Returns
Returns
-------
-------
ops: list of tuples
ops
: list of tuples
Updated backward pass operations
Updated backward pass operations
"""
"""
...
...
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -87,13 +87,13 @@ def fuse_backward_add_rmsnorm(
...
@@ -87,13 +87,13 @@ def fuse_backward_add_rmsnorm(
Parameters
Parameters
----------
----------
ops: list of tuples
ops
: list of tuples
Backward pass operations and the indices of the corresponding
Backward pass operations and the indices of the corresponding
basic operations.
basic operations.
Returns
Returns
-------
-------
ops: list of tuples
ops
: list of tuples
Updated backward pass operations
Updated backward pass operations
"""
"""
...
...
transformer_engine/pytorch/ops/fused/backward_linear_add.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -119,13 +119,13 @@ def fuse_backward_linear_add(
...
@@ -119,13 +119,13 @@ def fuse_backward_linear_add(
Parameters
Parameters
----------
----------
ops: list of tuples
ops
: list of tuples
Backward pass operations and the indices of the corresponding
Backward pass operations and the indices of the corresponding
basic operations.
basic operations.
Returns
Returns
-------
-------
ops: list of tuples
ops
: list of tuples
Updated backward pass operations
Updated backward pass operations
"""
"""
...
...
transformer_engine/pytorch/ops/fused/backward_linear_scale.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -119,13 +119,13 @@ def fuse_backward_linear_scale(
...
@@ -119,13 +119,13 @@ def fuse_backward_linear_scale(
Parameters
Parameters
----------
----------
ops: list of tuples
ops
: list of tuples
Backward pass operations and the indices of the corresponding
Backward pass operations and the indices of the corresponding
basic operations.
basic operations.
Returns
Returns
-------
-------
ops: list of tuples
ops
: list of tuples
Updated backward pass operations
Updated backward pass operations
"""
"""
...
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -142,13 +142,13 @@ def fuse_forward_linear_bias_activation(
...
@@ -142,13 +142,13 @@ def fuse_forward_linear_bias_activation(
Parameters
Parameters
----------
----------
ops: list of tuples
ops
: list of tuples
Forward pass operations and the indices of the corresponding
Forward pass operations and the indices of the corresponding
basic operations.
basic operations.
Returns
Returns
-------
-------
ops: list of tuples
ops
: list of tuples
Updated forward pass operations
Updated forward pass operations
"""
"""
...
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -139,13 +139,13 @@ def fuse_forward_linear_bias_add(
...
@@ -139,13 +139,13 @@ def fuse_forward_linear_bias_add(
Parameters
Parameters
----------
----------
ops: list of tuples
ops
: list of tuples
Forward pass operations and the indices of the corresponding
Forward pass operations and the indices of the corresponding
basic operations.
basic operations.
Returns
Returns
-------
-------
ops: list of tuples
ops
: list of tuples
Updated forward pass operations
Updated forward pass operations
"""
"""
...
...
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -118,13 +118,13 @@ def fuse_forward_linear_scale_add(
...
@@ -118,13 +118,13 @@ def fuse_forward_linear_scale_add(
Parameters
Parameters
----------
----------
ops: list of tuples
ops
: list of tuples
Forward pass operations and the indices of the corresponding
Forward pass operations and the indices of the corresponding
basic operations.
basic operations.
Returns
Returns
-------
-------
ops: list of tuples
ops
: list of tuples
Updated forward pass operations
Updated forward pass operations
"""
"""
...
...
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -19,7 +19,6 @@ from ...module.base import (
...
@@ -19,7 +19,6 @@ from ...module.base import (
fill_userbuffers_buffer_for_all_gather
,
fill_userbuffers_buffer_for_all_gather
,
get_dummy_wgrad
,
get_dummy_wgrad
,
get_ub
,
get_ub
,
get_workspace
,
)
)
from
...quantized_tensor
import
Quantizer
from
...quantized_tensor
import
Quantizer
from
...tensor.mxfp8_tensor
import
MXFP8Quantizer
from
...tensor.mxfp8_tensor
import
MXFP8Quantizer
...
@@ -293,6 +292,7 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -293,6 +292,7 @@ class UserbuffersBackwardLinear(FusedOperation):
rowwise
=
True
,
rowwise
=
True
,
columnwise
=
with_columnwise
,
columnwise
=
with_columnwise
,
)
)
grad_output_quantizer
.
optimize_for_gemm
=
False
dy_local
=
grad_output_quantizer
(
dy_local
)
dy_local
=
grad_output_quantizer
(
dy_local
)
else
:
else
:
dy_local
=
maybe_dequantize
(
dy_local
,
dtype
)
dy_local
=
maybe_dequantize
(
dy_local
,
dtype
)
...
@@ -378,7 +378,6 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -378,7 +378,6 @@ class UserbuffersBackwardLinear(FusedOperation):
dx
,
*
_
=
general_gemm
(
dx
,
*
_
=
general_gemm
(
w
,
w
,
dy
,
dy
,
get_workspace
(),
out_dtype
=
dtype
,
out_dtype
=
dtype
,
quantization_params
=
grad_input_quantizer
,
quantization_params
=
grad_input_quantizer
,
layout
=
"NN"
,
layout
=
"NN"
,
...
@@ -464,7 +463,6 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -464,7 +463,6 @@ class UserbuffersBackwardLinear(FusedOperation):
dw
,
*
_
=
general_gemm
(
dw
,
*
_
=
general_gemm
(
x
,
x
,
dy
,
dy
,
get_workspace
(),
out_dtype
=
dw_dtype
,
out_dtype
=
dw_dtype
,
accumulate
=
accumulate_into_grad_weight
,
accumulate
=
accumulate_into_grad_weight
,
layout
=
"NT"
,
layout
=
"NT"
,
...
@@ -592,13 +590,13 @@ def fuse_userbuffers_backward_linear(
...
@@ -592,13 +590,13 @@ def fuse_userbuffers_backward_linear(
Parameters
Parameters
----------
----------
ops: list of tuples
ops
: list of tuples
Backward pass operations and the indices of the corresponding
Backward pass operations and the indices of the corresponding
basic operations.
basic operations.
Returns
Returns
-------
-------
ops: list of tuples
ops
: list of tuples
Updated backward pass operations
Updated backward pass operations
"""
"""
...
...
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -18,7 +18,6 @@ from ...quantization import FP8GlobalStateManager
...
@@ -18,7 +18,6 @@ from ...quantization import FP8GlobalStateManager
from
...module.base
import
(
from
...module.base
import
(
fill_userbuffers_buffer_for_all_gather
,
fill_userbuffers_buffer_for_all_gather
,
get_ub
,
get_ub
,
get_workspace
,
_2X_ACC_FPROP
,
_2X_ACC_FPROP
,
)
)
from
...quantized_tensor
import
Quantizer
from
...quantized_tensor
import
Quantizer
...
@@ -243,7 +242,6 @@ class UserbuffersForwardLinear(FusedOperation):
...
@@ -243,7 +242,6 @@ class UserbuffersForwardLinear(FusedOperation):
gemm_output
,
*
_
,
reduce_scatter_output
=
general_gemm
(
gemm_output
,
*
_
,
reduce_scatter_output
=
general_gemm
(
w
,
w
,
x
,
x
,
get_workspace
(),
out_dtype
=
dtype
,
out_dtype
=
dtype
,
quantization_params
=
output_quantizer
,
quantization_params
=
output_quantizer
,
bias
=
bias
,
bias
=
bias
,
...
@@ -379,13 +377,13 @@ def fuse_userbuffers_forward_linear(
...
@@ -379,13 +377,13 @@ def fuse_userbuffers_forward_linear(
Parameters
Parameters
----------
----------
ops: list of tuples
ops
: list of tuples
Forward pass operations and the indices of the corresponding
Forward pass operations and the indices of the corresponding
basic operations.
basic operations.
Returns
Returns
-------
-------
ops: list of tuples
ops
: list of tuples
Updated forward pass operations
Updated forward pass operations
"""
"""
...
...
transformer_engine/pytorch/ops/fuser.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -310,7 +310,7 @@ class OperationFuser:
...
@@ -310,7 +310,7 @@ class OperationFuser:
Parameters
Parameters
----------
----------
ops: list of FusibleOperation
ops
: list of FusibleOperation
Pipeline of operations
Pipeline of operations
"""
"""
...
...
transformer_engine/pytorch/ops/linear.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -27,29 +27,29 @@ class Linear(FusedOperation):
...
@@ -27,29 +27,29 @@ class Linear(FusedOperation):
Parameters
Parameters
----------
----------
in_features: int
in_features
: int
Inner dimension of input tensor
Inner dimension of input tensor
out_features: int
out_features
: int
Inner dimension of output tensor
Inner dimension of output tensor
bias: bool, default = `True`
bias
: bool, default = `True`
Apply additive bias
Apply additive bias
device: torch.device, default = default CUDA device
device
: torch.device, default = default CUDA device
Tensor device
Tensor device
dtype: torch.dtype, default = default dtype
dtype
: torch.dtype, default = default dtype
Tensor datatype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
tensor_parallel_mode
: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
tensor_parallel_group
: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
sequence_parallel
: bool, default = `False`
Whether to apply sequence parallelism together with tensor
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors along
parallelism, i.e. distributing input or output tensors along
outer dimension (sequence or batch dim) when not distributing
outer dimension (sequence or batch dim) when not distributing
along inner dimension (embedding dim)
along inner dimension (embedding dim)
rng_state_tracker_function: callable
rng_state_tracker_function
: callable
Function that returns CudaRNGStatesTracker, which is used for
Function that returns CudaRNGStatesTracker, which is used for
model-parallel weight initialization
model-parallel weight initialization
accumulate_into_main_grad: bool, default = `False`
accumulate_into_main_grad
: bool, default = `False`
Whether to directly accumulate weight gradients into the
Whether to directly accumulate weight gradients into the
weight's `main_grad` attribute instead of relying on PyTorch
weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and
autograd. The weight's `main_grad` must be set externally and
...
...
transformer_engine/pytorch/ops/op.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -188,9 +188,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
...
@@ -188,9 +188,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Objects for quantization
# Objects for quantization
self
.
_fp8_metas
:
Optional
[
dict
[
str
,
dict
[
str
,
Any
]]]
=
None
self
.
_fp8_metas
:
Optional
[
dict
[
str
,
dict
[
str
,
Any
]]]
=
None
self
.
_quantizers
:
Optional
[
dict
[
str
,
list
[
Quantizer
]]]
=
None
self
.
_quantizers
:
Optional
[
dict
[
str
,
list
[
Quantizer
]]]
=
None
with_fp8_parameters
=
FP8GlobalStateManager
.
with_fp8_parameters
()
recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
with_fp8_parameters
else
None
self
.
reset_recipe_state
(
recipe
=
recipe
)
@
property
@
property
def
is_fused_op
(
self
)
->
bool
:
def
is_fused_op
(
self
)
->
bool
:
...
@@ -687,7 +684,7 @@ class FusedOperation(FusibleOperation):
...
@@ -687,7 +684,7 @@ class FusedOperation(FusibleOperation):
Parameters
Parameters
----------
----------
basic_ops: iterable of FusibleOperation
basic_ops
: iterable of FusibleOperation
Basic ops that are interchangeable with this op
Basic ops that are interchangeable with this op
"""
"""
...
...
transformer_engine/pytorch/ops/sequential.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/optimizers/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/pytorch/optimizers/fused_adam.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -11,8 +11,10 @@ from typing import Optional
...
@@ -11,8 +11,10 @@ from typing import Optional
import
warnings
import
warnings
import
torch
import
torch
from
torch.distributed._tensor
import
DTensor
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.tensor.float8_tensor
import
Float8Tensor
,
Float8Quantizer
from
transformer_engine.pytorch.tensor.float8_tensor
import
Float8Tensor
,
Float8Quantizer
from
transformer_engine.pytorch.quantized_tensor
import
QuantizedTensor
from
.multi_tensor_apply
import
multi_tensor_applier
from
.multi_tensor_apply
import
multi_tensor_applier
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
...
@@ -371,10 +373,12 @@ class FusedAdam(torch.optim.Optimizer):
...
@@ -371,10 +373,12 @@ class FusedAdam(torch.optim.Optimizer):
store_param_remainders (bool): Store only trailing remainder bits.
store_param_remainders (bool): Store only trailing remainder bits.
"""
"""
dtype
=
self
.
name_to_dtype_map
[
state_name
]
dtype
=
self
.
name_to_dtype_map
[
state_name
]
# Handle QuantizedTensor by dequantizing first
param_for_empty
=
param
.
dequantize
()
if
isinstance
(
param
,
QuantizedTensor
)
else
param
if
store_param_remainders
:
if
store_param_remainders
:
data
=
torch
.
zeros
(
param
.
shape
,
dtype
=
torch
.
int16
,
device
=
param
.
device
)
data
=
torch
.
zeros
_like
(
param
_for_empty
,
dtype
=
torch
.
int16
)
else
:
else
:
data
=
torch
.
empty
(
param
.
shape
,
dtype
=
dtype
,
device
=
param
.
device
)
data
=
torch
.
empty
_like
(
param
_for_empty
,
dtype
=
dtype
)
if
zero_buffer
:
if
zero_buffer
:
data
.
zero_
()
data
.
zero_
()
...
@@ -567,8 +571,10 @@ class FusedAdam(torch.optim.Optimizer):
...
@@ -567,8 +571,10 @@ class FusedAdam(torch.optim.Optimizer):
unscaled_lists
[
name
].
append
(
unscaled
)
unscaled_lists
[
name
].
append
(
unscaled
)
scaled_lists
[
name
].
append
(
state
[
name
])
scaled_lists
[
name
].
append
(
state
[
name
])
state_scales
[
name
].
append
(
self
.
_scales
[
p
][
name
])
state_scales
[
name
].
append
(
self
.
_scales
[
p
][
name
])
if
isinstance
(
p
,
Float8Tensor
)
or
(
if
isinstance
(
p
,
Float8Tensor
):
isinstance
(
p
,
DTensor
)
and
isinstance
(
p
.
_local_tensor
,
Float8Tensor
)
):
p
=
p
.
_local_tensor
if
isinstance
(
p
,
DTensor
)
else
p
out_dtype
=
p
.
_fp8_dtype
out_dtype
=
p
.
_fp8_dtype
p_fp8_model
.
append
(
p
.
_data
.
data
)
p_fp8_model
.
append
(
p
.
_data
.
data
)
scale
,
amax
,
scale_inv
=
get_fp8_meta
(
p
)
scale
,
amax
,
scale_inv
=
get_fp8_meta
(
p
)
...
...
transformer_engine/pytorch/optimizers/fused_sgd.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
Prev
1
…
27
28
29
30
31
32
33
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment