OpenDAS
TransformerEngine
Commit 970620a5
authored Dec 27, 2025 by wenjh
merge nv_release_v2.10 to release_v2.10
Signed-off-by: wenjh <wenjh@sugon.com>
parents c1a1c04e 769ed778
Changes 135
Showing 20 changed files with 66 additions and 71 deletions
transformer_engine/pytorch/ops/basic/all_reduce.py  +1 -1
transformer_engine/pytorch/ops/basic/basic_linear.py  +12 -14
transformer_engine/pytorch/ops/basic/bias.py  +5 -5
transformer_engine/pytorch/ops/basic/l2normalization.py  +3 -3
transformer_engine/pytorch/ops/basic/layer_norm.py  +4 -4
transformer_engine/pytorch/ops/basic/quantize.py  +2 -2
transformer_engine/pytorch/ops/basic/reduce_scatter.py  +1 -1
transformer_engine/pytorch/ops/basic/reshape.py  +1 -1
transformer_engine/pytorch/ops/basic/rmsnorm.py  +7 -5
transformer_engine/pytorch/ops/fused/backward_activation_bias.py  +3 -3
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py  +2 -2
transformer_engine/pytorch/ops/fused/backward_linear_add.py  +2 -2
transformer_engine/pytorch/ops/fused/backward_linear_scale.py  +2 -2
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py  +2 -2
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py  +2 -2
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py  +2 -2
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py  +2 -5
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py  +2 -4
transformer_engine/pytorch/ops/fuser.py  +1 -1
transformer_engine/pytorch/ops/linear.py  +10 -10
transformer_engine/pytorch/ops/basic/all_reduce.py
@@ -24,7 +24,7 @@ class AllReduce(BasicOperation):
     Parameters
     ----------
-    process_group: torch.distributed.ProcessGroup, default = world group
+    process_group : torch.distributed.ProcessGroup, default = world group
         Process group for communication
     """
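The one-line change in this hunk recurs throughout the commit: a space is added before the colon so that parameter lines follow the numpydoc `name : type` convention. As a rough illustration only, a rewrite of this kind could be scripted as below. The helper name `normalize_param_line` and the regular expression are assumptions made for this sketch and are not part of Transformer Engine; the pattern is deliberately simple and would also touch prose lines of the form `Note: ...`.

```python
import re

# Hypothetical helper, not part of Transformer Engine.
# Matches an indented identifier immediately followed by ":" and a type,
# e.g. "    process_group: torch.distributed.ProcessGroup".
_PARAM_LINE = re.compile(r"^(\s*)([A-Za-z_][A-Za-z0-9_]*):(?=\s+\S)")


def normalize_param_line(line: str) -> str:
    """Rewrite `name: type` as `name : type`; leave other lines unchanged."""
    return _PARAM_LINE.sub(r"\1\2 :", line, count=1)


if __name__ == "__main__":
    before = "    process_group: torch.distributed.ProcessGroup, default = world group"
    print(normalize_param_line(before))
```

Lines that already use the `name : type` form, such as `eps : float, default = 1e-6`, do not match the pattern and pass through untouched.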
transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -25,7 +25,6 @@ from ...module.base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
     get_dummy_wgrad,
-    get_workspace,
 )
 from ...tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer
@@ -54,27 +53,27 @@ class BasicLinear(BasicOperation):
     Parameters
     ----------
-    in_features: int
+    in_features : int
         Inner dimension of input tensor
-    out_features: int
+    out_features : int
         Inner dimension of output tensor
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
-    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
+    tensor_parallel_mode : {`None`, "column", "row"}, default = `None`
         Mode for tensor parallelism
-    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
+    tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
         Process group for tensor parallelism
-    sequence_parallel: bool, default = `False`
+    sequence_parallel : bool, default = `False`
         Whether to apply sequence parallelism together with tensor
         parallelism, i.e. distributing input or output tensors along
         outer dimension (sequence or batch dim) when not distributing
         along inner dimension (embedding dim)
-    rng_state_tracker_function: callable
+    rng_state_tracker_function : callable
         Function that returns `CudaRNGStatesTracker`, which is used
         for model-parallel weight initialization
-    accumulate_into_main_grad: bool, default = `False`
+    accumulate_into_main_grad : bool, default = `False`
         Whether to directly accumulate weight gradients into the
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and
@@ -138,8 +137,10 @@ class BasicLinear(BasicOperation):
             out_features=out_features,
         )
-        # Whether weight tensor is natively quantized
+        # Initialize recipe state if needed for natively quantized weight
         self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters()
+        if self._with_quantized_weight:
+            self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
         # Initialize parameters if needed
         weight = torch.empty(
@@ -585,7 +586,6 @@ class BasicLinear(BasicOperation):
         y, *_ = general_gemm(
             w,
             x,
-            get_workspace(),
             out_dtype=dtype,
             quantization_params=output_quantizer,
             alpha=alpha,
@@ -875,7 +875,6 @@ class BasicLinear(BasicOperation):
         dx, *_ = general_gemm(
             w,
             dy,
-            get_workspace(),
             out_dtype=dtype,
             quantization_params=grad_input_quantizer,
             alpha=grad_input_alpha,
@@ -928,7 +927,6 @@ class BasicLinear(BasicOperation):
         dw, *_ = general_gemm(
             x,
             dy,
-            get_workspace(),
             out_dtype=dw_dtype,
             alpha=grad_weight_alpha,
             beta=grad_weight_beta,
transformer_engine/pytorch/ops/basic/bias.py
@@ -22,16 +22,16 @@ class Bias(BasicOperation):
     Parameters
     ----------
-    size: int
+    size : int
         Inner dimension of input tensor
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
-    tensor_parallel: bool, default = `False`
+    tensor_parallel : bool, default = `False`
         Whether to distribute input tensor and bias tensors along
         inner dimension
-    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
+    tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
         Process group for tensor parallelism
     """
transformer_engine/pytorch/ops/basic/l2normalization.py
@@ -10,7 +10,7 @@ import os
 import torch
-from ... import torch_version
+from ...torch_version import torch_version
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...jit import (
     l2normalization_fused,
@@ -40,11 +40,11 @@ class L2Normalization(BasicOperation):
     ----------
     eps : float, default = 1e-6
         A value added to the denominator for numerical stability
-    seq_length: int, default = None
+    seq_length : int, default = None
         sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
         functions are warmed up before training to ensure same kernels are used for forward
         propagation and activation recompute phase.
-    micro_batch_size: int, default = None
+    micro_batch_size : int, default = None
         batch size per training step. Needed for JIT Warmup, a technique where jit
         fused functions are warmed up before training to ensure same kernels are
         used for forward propagation and activation recompute phase.
transformer_engine/pytorch/ops/basic/layer_norm.py
@@ -42,14 +42,14 @@ class LayerNorm(BasicOperation):
     Parameters
     ----------
-    normalized_shape: int or iterable of int
+    normalized_shape : int or iterable of int
         Inner dimensions of input tensor
     eps : float, default = 1e-5
         A value added to the denominator of layer normalization for
         numerical stability
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
     zero_centered_gamma : bool, default = 'False'
         If `True`, the :math:`\gamma` parameter is initialized to zero
@@ -58,7 +58,7 @@ class LayerNorm(BasicOperation):
         .. math::
             y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
-    sm_margin: int or dict, default = 0
+    sm_margin : int or dict, default = 0
         Number of SMs to exclude when launching CUDA kernels. This
         helps overlap with other kernels, e.g. communication kernels.
         For more fine-grained control, provide a dict with the SM
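The `zero_centered_gamma` formula quoted in the LayerNorm docstring can be checked with a few lines of plain Python. This is a minimal sketch over a single vector, assuming the population (biased) variance that `torch.nn.LayerNorm` uses; the function name `layer_norm` and its list-based signature are made up for this example and do not reflect the operation's real interface.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5, zero_centered_gamma=False):
    """Layer-normalize one vector `x` (list of floats) over its only dimension.

    Illustrative only: with zero_centered_gamma=True the scale is (1 + gamma),
    matching the docstring formula, so gamma = 0 means "no scaling".
    """
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # population variance (assumed)
    inv_std = 1.0 / math.sqrt(var + eps)
    out = []
    for v, g, b in zip(x, gamma, beta):
        scale = 1.0 + g if zero_centered_gamma else g
        out.append((v - mean) * inv_std * scale + b)
    return out


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0]
    # gamma = 0 with zero-centering behaves like gamma = 1 without it.
    print(layer_norm(x, [0.0] * 3, [0.0] * 3, zero_centered_gamma=True))
    print(layer_norm(x, [1.0] * 3, [0.0] * 3))
```

The two calls print the same values, which is the point of the zero-centered parameterization: a zero-initialized weight leaves the normalized input unscaled.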
transformer_engine/pytorch/ops/basic/quantize.py
@@ -23,9 +23,9 @@ class Quantize(BasicOperation):
     Parameters
     ----------
-    forward: bool, default = `True`
+    forward : bool, default = `True`
         Perform quantization in forward pass
-    backward: bool, default = `False`
+    backward : bool, default = `False`
         Perform quantization in backward pass
     """
transformer_engine/pytorch/ops/basic/reduce_scatter.py
@@ -23,7 +23,7 @@ class ReduceScatter(BasicOperation):
     Parameters
     ----------
-    process_group: torch.distributed.ProcessGroup, default = world group
+    process_group : torch.distributed.ProcessGroup, default = world group
         Process group for communication
     """
transformer_engine/pytorch/ops/basic/reshape.py
@@ -24,7 +24,7 @@ class Reshape(BasicOperation):
     Parameters
     ----------
-    shape: iterable of int
+    shape : iterable of int
         Output tensor dimensions. If one dimension is -1, it is
         inferred based on input tensor dimensions.
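The `shape` docstring says that a single `-1` entry is inferred from the input dimensions. The sketch below shows how that inference usually works, the same rule as `torch.Tensor.reshape`. The helper `infer_shape` is hypothetical and written for this illustration; it is not the function the `Reshape` operation calls.

```python
from math import prod


def infer_shape(input_shape, shape):
    """Resolve a single -1 in `shape` so the element count matches `input_shape`.

    Hypothetical helper for illustration; raises ValueError on invalid requests.
    """
    shape = list(shape)
    unknown = [i for i, d in enumerate(shape) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one dimension may be -1")
    total = prod(input_shape)
    if unknown:
        known = prod(d for d in shape if d != -1)
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot reshape {tuple(input_shape)} into {tuple(shape)}")
        shape[unknown[0]] = total // known
    elif prod(shape) != total:
        raise ValueError(f"cannot reshape {tuple(input_shape)} into {tuple(shape)}")
    return tuple(shape)


if __name__ == "__main__":
    print(infer_shape((4, 6), (2, -1)))
    print(infer_shape((4, 6), (-1,)))
```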
transformer_engine/pytorch/ops/basic/rmsnorm.py
@@ -42,13 +42,13 @@ class RMSNorm(BasicOperation):
     Parameters
     ----------
-    normalized_shape: int or iterable of int
+    normalized_shape : int or iterable of int
         Inner dimensions of input tensor
     eps : float, default = 1e-5
         A value added to the denominator for numerical stability
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
     zero_centered_gamma : bool, default = 'False'
         If `True`, the :math:`\gamma` parameter is initialized to zero
@@ -57,7 +57,7 @@ class RMSNorm(BasicOperation):
         .. math::
             y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
-    sm_margin: int, default = 0
+    sm_margin : int, default = 0
         Number of SMs to exclude when launching CUDA kernels. This
         helps overlap with other kernels, e.g. communication kernels.
         For more fine-grained control, provide a dict with the SM
@@ -249,4 +249,6 @@ class RMSNorm(BasicOperation):
     ) -> torch.Tensor:
         """Every operand in this function has a defined ONNX translation."""
         weight = self.weight + 1 if self.zero_centered_gamma else self.weight
-        return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps)
+        variance = input_.pow(2).mean(-1, keepdim=True)
+        normalized = input_ * torch.rsqrt(variance + self.eps)
+        return normalized * weight
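The last RMSNorm hunk replaces the `torch.nn.functional.rms_norm` call with three explicit steps: the mean of squares over the last dimension, a reciprocal square root, and a multiply by the weight. The same arithmetic can be followed in plain Python on a single vector. This is a sketch only; the function name `rms_norm` and its list-based signature are assumptions for illustration, not the operation's real interface. Note that the quantity the code calls `variance` is the mean of squares, with no mean subtraction.

```python
import math


def rms_norm(x, weight, eps=1e-5, zero_centered_gamma=False):
    """RMS-normalize one vector, mirroring the three steps in the hunk above."""
    # weight = self.weight + 1 if self.zero_centered_gamma else self.weight
    w = [g + 1.0 for g in weight] if zero_centered_gamma else list(weight)
    # variance = input_.pow(2).mean(-1, keepdim=True)
    variance = sum(v * v for v in x) / len(x)
    # normalized = input_ * torch.rsqrt(variance + self.eps)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    # return normalized * weight
    return [v * inv_rms * g for v, g in zip(x, w)]


if __name__ == "__main__":
    print(rms_norm([3.0, 4.0], [1.0, 1.0], eps=0.0))
```

With `eps=0.0` and unit weights, the output vector has a root mean square of exactly 1, which is a quick sanity check for any reimplementation.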
transformer_engine/pytorch/ops/fused/backward_activation_bias.py
@@ -90,15 +90,15 @@ def fuse_backward_activation_bias(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
-    recipe: Recipe, optional
+    recipe : Recipe, optional
         Used quantization recipe
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py
@@ -87,13 +87,13 @@ def fuse_backward_add_rmsnorm(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """
transformer_engine/pytorch/ops/fused/backward_linear_add.py
@@ -119,13 +119,13 @@ def fuse_backward_linear_add(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """
transformer_engine/pytorch/ops/fused/backward_linear_scale.py
@@ -119,13 +119,13 @@ def fuse_backward_linear_scale(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
@@ -142,13 +142,13 @@ def fuse_forward_linear_bias_activation(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """
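The `fuse_*` docstrings describe a common contract: each function takes a list of `(operation, basic-op indices)` tuples and returns an updated list. The body of these functions is not shown in this diff, so the following is only a guess at the general shape of such a pass, written with plain strings standing in for operation objects. The names `fuse_adjacent`, `pattern`, and `make_fused` are hypothetical and do not appear in Transformer Engine.

```python
def fuse_adjacent(ops, pattern, make_fused):
    """Replace each run of ops matching `pattern` with a single fused op.

    ops        : list of (op, indices) tuples; `op` is a plain string here
    pattern    : sequence of op names to match consecutively
    make_fused : callable that builds the fused op from the matched ops
    """
    out = []
    n = len(pattern)
    i = 0
    while i < len(ops):
        window = ops[i : i + n]
        if len(window) == n and all(op == name for (op, _), name in zip(window, pattern)):
            fused = make_fused([op for op, _ in window])
            # The fused op keeps the indices of every basic op it replaces.
            indices = [j for _, idxs in window for j in idxs]
            out.append((fused, indices))
            i += n
        else:
            out.append(ops[i])
            i += 1
    return out


if __name__ == "__main__":
    ops = [("linear", [0]), ("bias", [1]), ("gelu", [2]), ("linear", [3])]
    print(fuse_adjacent(ops, ("linear", "bias", "gelu"), "+".join))
```

Carrying the original indices along with the fused operation is what lets a caller map results back to the basic operations they came from.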
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
@@ -139,13 +139,13 @@ def fuse_forward_linear_bias_add(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
@@ -118,13 +118,13 @@ def fuse_forward_linear_scale_add(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
@@ -19,7 +19,6 @@ from ...module.base import (
     fill_userbuffers_buffer_for_all_gather,
     get_dummy_wgrad,
     get_ub,
-    get_workspace,
 )
 from ...quantized_tensor import Quantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
@@ -378,7 +377,6 @@ class UserbuffersBackwardLinear(FusedOperation):
         dx, *_ = general_gemm(
             w,
             dy,
-            get_workspace(),
             out_dtype=dtype,
             quantization_params=grad_input_quantizer,
             layout="NN",
@@ -464,7 +462,6 @@ class UserbuffersBackwardLinear(FusedOperation):
         dw, *_ = general_gemm(
             x,
             dy,
-            get_workspace(),
             out_dtype=dw_dtype,
             accumulate=accumulate_into_grad_weight,
             layout="NT",
@@ -592,13 +589,13 @@ def fuse_userbuffers_backward_linear(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
@@ -18,7 +18,6 @@ from ...quantization import FP8GlobalStateManager
 from ...module.base import (
     fill_userbuffers_buffer_for_all_gather,
     get_ub,
-    get_workspace,
     _2X_ACC_FPROP,
 )
 from ...quantized_tensor import Quantizer
@@ -243,7 +242,6 @@ class UserbuffersForwardLinear(FusedOperation):
         gemm_output, *_, reduce_scatter_output = general_gemm(
             w,
             x,
-            get_workspace(),
             out_dtype=dtype,
             quantization_params=output_quantizer,
             bias=bias,
@@ -379,13 +377,13 @@ def fuse_userbuffers_forward_linear(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """
transformer_engine/pytorch/ops/fuser.py
@@ -310,7 +310,7 @@ class OperationFuser:
     Parameters
     ----------
-    ops: list of FusibleOperation
+    ops : list of FusibleOperation
         Pipeline of operations
     """
transformer_engine/pytorch/ops/linear.py
@@ -27,29 +27,29 @@ class Linear(FusedOperation):
     Parameters
     ----------
-    in_features: int
+    in_features : int
         Inner dimension of input tensor
-    out_features: int
+    out_features : int
         Inner dimension of output tensor
-    bias: bool, default = `True`
+    bias : bool, default = `True`
         Apply additive bias
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
-    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
+    tensor_parallel_mode : {`None`, "column", "row"}, default = `None`
         Mode for tensor parallelism
-    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
+    tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
         Process group for tensor parallelism
-    sequence_parallel: bool, default = `False`
+    sequence_parallel : bool, default = `False`
         Whether to apply sequence parallelism together with tensor
         parallelism, i.e. distributing input or output tensors along
         outer dimension (sequence or batch dim) when not distributing
         along inner dimension (embedding dim)
-    rng_state_tracker_function: callable
+    rng_state_tracker_function : callable
         Function that returns CudaRNGStatesTracker, which is used for
         model-parallel weight initialization
-    accumulate_into_main_grad: bool, default = `False`
+    accumulate_into_main_grad : bool, default = `False`
         Whether to directly accumulate weight gradients into the
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and