Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
970620a5
Commit
970620a5
authored
Dec 27, 2025
by
wenjh
Browse files
merge nv_release_v2.10 to release_v2.10
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parents
c1a1c04e
769ed778
Changes
135
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
465 additions
and
320 deletions
+465
-320
transformer_engine/pytorch/ops/op.py
transformer_engine/pytorch/ops/op.py
+1
-4
transformer_engine/pytorch/permutation.py
transformer_engine/pytorch/permutation.py
+22
-22
transformer_engine/pytorch/quantization.py
transformer_engine/pytorch/quantization.py
+11
-11
transformer_engine/pytorch/quantized_tensor.py
transformer_engine/pytorch/quantized_tensor.py
+4
-9
transformer_engine/pytorch/router.py
transformer_engine/pytorch/router.py
+22
-22
transformer_engine/pytorch/setup.py
transformer_engine/pytorch/setup.py
+17
-5
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+22
-6
transformer_engine/pytorch/tensor/float8_tensor.py
transformer_engine/pytorch/tensor/float8_tensor.py
+50
-10
transformer_engine/pytorch/tensor/mxfp8_tensor.py
transformer_engine/pytorch/tensor/mxfp8_tensor.py
+19
-5
transformer_engine/pytorch/tensor/nvfp4_tensor.py
transformer_engine/pytorch/tensor/nvfp4_tensor.py
+29
-9
transformer_engine/pytorch/tensor/utils.py
transformer_engine/pytorch/tensor/utils.py
+1
-1
transformer_engine/pytorch/torch_version.py
transformer_engine/pytorch/torch_version.py
+15
-0
transformer_engine/pytorch/transformer.py
transformer_engine/pytorch/transformer.py
+191
-175
transformer_engine/pytorch/triton/permutation.py
transformer_engine/pytorch/triton/permutation.py
+38
-38
transformer_engine/pytorch/utils.py
transformer_engine/pytorch/utils.py
+23
-3
No files found.
transformer_engine/pytorch/ops/op.py
View file @
970620a5
...
@@ -188,9 +188,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
...
@@ -188,9 +188,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Objects for quantization
# Objects for quantization
self
.
_fp8_metas
:
Optional
[
dict
[
str
,
dict
[
str
,
Any
]]]
=
None
self
.
_fp8_metas
:
Optional
[
dict
[
str
,
dict
[
str
,
Any
]]]
=
None
self
.
_quantizers
:
Optional
[
dict
[
str
,
list
[
Quantizer
]]]
=
None
self
.
_quantizers
:
Optional
[
dict
[
str
,
list
[
Quantizer
]]]
=
None
with_fp8_parameters
=
FP8GlobalStateManager
.
with_fp8_parameters
()
recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
with_fp8_parameters
else
None
self
.
reset_recipe_state
(
recipe
=
recipe
)
@
property
@
property
def
is_fused_op
(
self
)
->
bool
:
def
is_fused_op
(
self
)
->
bool
:
...
@@ -687,7 +684,7 @@ class FusedOperation(FusibleOperation):
...
@@ -687,7 +684,7 @@ class FusedOperation(FusibleOperation):
Parameters
Parameters
----------
----------
basic_ops: iterable of FusibleOperation
basic_ops
: iterable of FusibleOperation
Basic ops that are interchangeable with this op
Basic ops that are interchangeable with this op
"""
"""
...
...
transformer_engine/pytorch/permutation.py
View file @
970620a5
...
@@ -514,22 +514,22 @@ def moe_permute(
...
@@ -514,22 +514,22 @@ def moe_permute(
Parameters
Parameters
----------
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
routing_map: torch.Tensor
routing_map
: torch.Tensor
The token to expert mapping tensor.
The token to expert mapping tensor.
If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
The values in it: 1 means the token is routed to this expert and 0 means not.
The values in it: 1 means the token is routed to this expert and 0 means not.
If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
The values in it are the routed expert indices.
The values in it are the routed expert indices.
num_out_tokens: int, default = -1
num_out_tokens
: int, default = -1
The effective output token count, representing the number of tokens not dropped.
The effective output token count, representing the number of tokens not dropped.
By default, set to '-1', meaning no tokens are dropped.
By default, set to '-1', meaning no tokens are dropped.
max_token_num: int, default = -1
max_token_num
: int, default = -1
The maximum number of tokens, used for workspace allocation.
The maximum number of tokens, used for workspace allocation.
By default, set to '-1', meaning the calculation of the size of workspace is
By default, set to '-1', meaning the calculation of the size of workspace is
automatically taken over by the operator.
automatically taken over by the operator.
map_type: str, default = 'mask'
map_type
: str, default = 'mask'
Type of the routing map tensor.
Type of the routing map tensor.
Options are: 'mask', 'index'.
Options are: 'mask', 'index'.
Refer to `routing_map` for more details.
Refer to `routing_map` for more details.
...
@@ -556,16 +556,16 @@ def moe_permute_with_probs(
...
@@ -556,16 +556,16 @@ def moe_permute_with_probs(
Parameters
Parameters
----------
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs: torch.Tensor
probs
: torch.Tensor
The tensor of probabilities corresponding to the permuted tokens and is
The tensor of probabilities corresponding to the permuted tokens and is
of shape [num_tokens, num_experts]. It will be permuted with the tokens
of shape [num_tokens, num_experts]. It will be permuted with the tokens
according to the routing_map.
according to the routing_map.
routing_map: torch.Tensor
routing_map
: torch.Tensor
The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
The values in it: 1 means the token is routed to this expert and 0 means not.
The values in it: 1 means the token is routed to this expert and 0 means not.
num_out_tokens: int, default = -1
num_out_tokens
: int, default = -1
The effective output token count, representing the number of tokens not dropped.
The effective output token count, representing the number of tokens not dropped.
By default, set to '-1', meaning no tokens are dropped.
By default, set to '-1', meaning no tokens are dropped.
"""
"""
...
@@ -589,21 +589,21 @@ def moe_unpermute(
...
@@ -589,21 +589,21 @@ def moe_unpermute(
Parameters
Parameters
----------
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
row_id_map: torch.Tensor
row_id_map
: torch.Tensor
The tensor of a mapping table for sorted indices used to unpermute the tokens,
The tensor of a mapping table for sorted indices used to unpermute the tokens,
which is the second output tensor of `Permute`.
which is the second output tensor of `Permute`.
merging_probs: torch.Tensor, default = None
merging_probs
: torch.Tensor, default = None
The tensor of probabilities corresponding to the permuted tokens. If provided,
The tensor of probabilities corresponding to the permuted tokens. If provided,
the unpermuted tokens will be merged with their respective probabilities.
the unpermuted tokens will be merged with their respective probabilities.
By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
restore_shape: torch.Size, default = None
restore_shape
: torch.Size, default = None
The output shape after the unpermute operation.
The output shape after the unpermute operation.
map_type: str, default = 'mask'
map_type
: str, default = 'mask'
Type of the routing map tensor. Should be the same as the value passed to moe_permute.
Type of the routing map tensor. Should be the same as the value passed to moe_permute.
Options are: 'mask', 'index'.
Options are: 'mask', 'index'.
probs: torch.Tensor, default = None
probs
: torch.Tensor, default = None
Renamed to merging_probs. Keep for backward compatibility.
Renamed to merging_probs. Keep for backward compatibility.
"""
"""
if
probs
is
not
None
:
if
probs
is
not
None
:
...
@@ -733,11 +733,11 @@ def moe_sort_chunks_by_index(
...
@@ -733,11 +733,11 @@ def moe_sort_chunks_by_index(
Parameters
Parameters
----------
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
split_sizes: torch.Tensor
split_sizes
: torch.Tensor
Chunk sizes of the inp tensor along the 0-th dimension.
Chunk sizes of the inp tensor along the 0-th dimension.
sorted_indices: torch.Tensor
sorted_indices
: torch.Tensor
Chunk indices used to permute the chunks.
Chunk indices used to permute the chunks.
"""
"""
output
,
_
=
_moe_chunk_sort
.
apply
(
inp
,
split_sizes
,
sorted_index
,
None
)
output
,
_
=
_moe_chunk_sort
.
apply
(
inp
,
split_sizes
,
sorted_index
,
None
)
...
@@ -757,15 +757,15 @@ def moe_sort_chunks_by_index_with_probs(
...
@@ -757,15 +757,15 @@ def moe_sort_chunks_by_index_with_probs(
Parameters
Parameters
----------
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs: torch.Tensor
probs
: torch.Tensor
The tensor of probabilities corresponding to the permuted tokens and is
The tensor of probabilities corresponding to the permuted tokens and is
of shape [num_tokens]. It will be permuted with the tokens according to
of shape [num_tokens]. It will be permuted with the tokens according to
the split_sizes and sorted_indices.
the split_sizes and sorted_indices.
split_sizes: torch.Tensor
split_sizes
: torch.Tensor
Chunk sizes of the inp tensor along the 0-th dimension.
Chunk sizes of the inp tensor along the 0-th dimension.
sorted_indices: torch.Tensor
sorted_indices
: torch.Tensor
Chunk indices used to permute the chunks.
Chunk indices used to permute the chunks.
"""
"""
output
,
permuted_probs
=
_moe_chunk_sort
.
apply
(
inp
,
split_sizes
,
sorted_index
,
probs
)
output
,
permuted_probs
=
_moe_chunk_sort
.
apply
(
inp
,
split_sizes
,
sorted_index
,
probs
)
...
...
transformer_engine/pytorch/quantization.py
View file @
970620a5
...
@@ -26,8 +26,8 @@ from transformer_engine.common.recipe import (
...
@@ -26,8 +26,8 @@ from transformer_engine.common.recipe import (
NVFP4BlockScaling
,
NVFP4BlockScaling
,
CustomRecipe
,
CustomRecipe
,
)
)
from
.constants
import
dist_group_type
from
.constants
import
dist_group_type
from
.utils
import
get_device_compute_capability
from
.utils
import
get_device_compute_capability
from
.jit
import
jit_fuser
from
.jit
import
jit_fuser
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
...
@@ -687,7 +687,7 @@ def fp8_model_init(
...
@@ -687,7 +687,7 @@ def fp8_model_init(
.. warning::
.. warning::
fp8_model_init is deprecated and will be removed in a future release. Use
fp8_model_init is deprecated and will be removed in a future release. Use
quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.
``
quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...)
``
instead.
"""
"""
...
@@ -732,7 +732,7 @@ def quantized_model_init(
...
@@ -732,7 +732,7 @@ def quantized_model_init(
Parameters
Parameters
----------
----------
enabled: bool, default =
`
True
`
enabled
: bool, default = True
when enabled, Transformer Engine modules created inside this `quantized_model_init`
when enabled, Transformer Engine modules created inside this `quantized_model_init`
region will hold only quantized copies of its parameters, as opposed to the default
region will hold only quantized copies of its parameters, as opposed to the default
behavior where both higher precision and quantized copies are present. Setting this
behavior where both higher precision and quantized copies are present. Setting this
...
@@ -743,9 +743,9 @@ def quantized_model_init(
...
@@ -743,9 +743,9 @@ def quantized_model_init(
precision copies of weights are already present in the optimizer.
precision copies of weights are already present in the optimizer.
* inference, where only the quantized copies of the parameters are used.
* inference, where only the quantized copies of the parameters are used.
* LoRA-like fine-tuning, where the main parameters of the model do not change.
* LoRA-like fine-tuning, where the main parameters of the model do not change.
recipe: transformer_engine.common.recipe.Recipe, default =
`
None
`
recipe
: transformer_engine.common.recipe.Recipe, default = None
Recipe used to create the parameters. If left to None, it uses the default recipe.
Recipe used to create the parameters. If left to None, it uses the default recipe.
preserve_high_precision_init_val: bool, default =
`
False
`
preserve_high_precision_init_val
: bool, default = False
when enabled, store the high precision tensor used to initialize quantized parameters
when enabled, store the high precision tensor used to initialize quantized parameters
in CPU memory, and add two function attributes named `get_high_precision_init_val()`
in CPU memory, and add two function attributes named `get_high_precision_init_val()`
and `clear_high_precision_init_val()` to quantized parameters to get/clear this high
and `clear_high_precision_init_val()` to quantized parameters to get/clear this high
...
@@ -782,8 +782,8 @@ def fp8_autocast(
...
@@ -782,8 +782,8 @@ def fp8_autocast(
"""
"""
.. warning::
.. warning::
fp8_autocast is deprecated and will be removed in a future release.
``
fp8_autocast
``
is deprecated and will be removed in a future release.
Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.
Use
``
autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...)
``
instead.
"""
"""
...
@@ -837,16 +837,16 @@ def autocast(
...
@@ -837,16 +837,16 @@ def autocast(
Parameters
Parameters
----------
----------
enabled: bool, default =
`
True
`
enabled
: bool, default = True
whether or not to enable low precision quantization (FP8/FP4).
whether or not to enable low precision quantization (FP8/FP4).
calibrating: bool, default =
`
False
`
calibrating
: bool, default = False
calibration mode allows collecting statistics such as amax and scale
calibration mode allows collecting statistics such as amax and scale
data of quantized tensors even when executing without quantization enabled.
data of quantized tensors even when executing without quantization enabled.
This is useful for saving an inference ready checkpoint while training
This is useful for saving an inference ready checkpoint while training
using a higher precision.
using a higher precision.
recipe: recipe.Recipe, default =
`
None
`
recipe
: recipe.Recipe, default = None
recipe used for low precision quantization.
recipe used for low precision quantization.
amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default =
`
None
`
amax_reduction_group
: torch._C._distributed_c10d.ProcessGroup, default = None
distributed group over which amaxes for the quantized tensors
distributed group over which amaxes for the quantized tensors
are reduced at the end of each training step.
are reduced at the end of each training step.
"""
"""
...
...
transformer_engine/pytorch/quantized_tensor.py
View file @
970620a5
...
@@ -7,7 +7,6 @@
...
@@ -7,7 +7,6 @@
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Optional
,
Tuple
,
Iterable
,
Any
,
Dict
,
Union
from
typing
import
Optional
,
Tuple
,
Iterable
,
Any
,
Dict
,
Union
import
abc
import
abc
import
copy
import
warnings
import
warnings
import
math
import
math
...
@@ -28,7 +27,7 @@ _quantized_tensor_cpu_supported_ops = (
...
@@ -28,7 +27,7 @@ _quantized_tensor_cpu_supported_ops = (
class
QuantizedTensorStorage
:
class
QuantizedTensorStorage
:
r
"""Base class for all
*
TensorStorage classes.
r
"""Base class for all TensorStorage classes.
This class (and its subclasses) are optimization for when
This class (and its subclasses) are optimization for when
the full QuantizedTensor is not needed (when it is fully
the full QuantizedTensor is not needed (when it is fully
...
@@ -55,11 +54,11 @@ class QuantizedTensorStorage:
...
@@ -55,11 +54,11 @@ class QuantizedTensorStorage:
Parameters
Parameters
----------
----------
rowwise_usage : Optional[bool[, default =
`
None
`
rowwise_usage : Optional[bool[, default = None
Whether to create or keep the data needed for using the tensor
Whether to create or keep the data needed for using the tensor
in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as `None`
in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as `None`
preserves the original value in the tensor.
preserves the original value in the tensor.
columnwise_usage : Optional[bool], default =
`
None
`
columnwise_usage : Optional[bool], default = None
Whether to create or keep the data needed for using the tensor
Whether to create or keep the data needed for using the tensor
in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as
in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as
`None` preserves the original value in the tensor.
`None` preserves the original value in the tensor.
...
@@ -129,7 +128,7 @@ def prepare_for_saving(
...
@@ -129,7 +128,7 @@ def prepare_for_saving(
]:
]:
"""Prepare tensors for saving. Needed because save_for_backward accepts only
"""Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
the internal
*
TensorStorage types too."""
the internal TensorStorage types too."""
tensor_list
,
tensor_objects_list
=
[],
[]
tensor_list
,
tensor_objects_list
=
[],
[]
for
tensor
in
tensors
:
for
tensor
in
tensors
:
...
@@ -297,10 +296,6 @@ class Quantizer(abc.ABC):
...
@@ -297,10 +296,6 @@ class Quantizer(abc.ABC):
if
columnwise
is
not
None
:
if
columnwise
is
not
None
:
self
.
columnwise_usage
=
columnwise
self
.
columnwise_usage
=
columnwise
def
copy
(
self
)
->
Quantizer
:
"""Create shallow copy"""
return
copy
.
copy
(
self
)
def
onnx_quantize
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
def
onnx_quantize
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
"""Symbolic function for ONNX export"""
"""Symbolic function for ONNX export"""
raise
NotImplementedError
(
raise
NotImplementedError
(
...
...
transformer_engine/pytorch/router.py
View file @
970620a5
...
@@ -92,24 +92,24 @@ def fused_topk_with_score_function(
...
@@ -92,24 +92,24 @@ def fused_topk_with_score_function(
Fused topk with score function router.
Fused topk with score function router.
Parameters
Parameters
----------
----------
logits: torch.Tensor
logits
: torch.Tensor
topk: int
topk
: int
use_pre_softmax: bool
use_pre_softmax
: bool
if enabled, the computation order: softmax -> topk
if enabled, the computation order: softmax -> topk
num_groups: int
num_groups
: int
used in the group topk
used in the group topk
group_topk: int
group_topk
: int
used in the group topk
used in the group topk
scaling_factor: float
scaling_factor
: float
score_function: str
score_function
: str
currently only support softmax and sigmoid
currently only support softmax and sigmoid
expert_bias: torch.Tensor
expert_bias
: torch.Tensor
could be used in the sigmoid
could be used in the sigmoid
Returns
Returns
-------
-------
probs: torch.Tensor
probs
: torch.Tensor
routing_map: torch.Tensor
routing_map
: torch.Tensor
"""
"""
if
logits
.
dtype
==
torch
.
float64
:
if
logits
.
dtype
==
torch
.
float64
:
raise
ValueError
(
"Current TE does not support float64 router type"
)
raise
ValueError
(
"Current TE does not support float64 router type"
)
...
@@ -186,15 +186,15 @@ def fused_compute_score_for_moe_aux_loss(
...
@@ -186,15 +186,15 @@ def fused_compute_score_for_moe_aux_loss(
Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function.
Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function.
Parameters
Parameters
----------
----------
logits: torch.Tensor
logits
: torch.Tensor
topk: int
topk
: int
score_function: str
score_function
: str
currently only support softmax and sigmoid
currently only support softmax and sigmoid
Returns
Returns
-------
-------
routing_map: torch.Tensor
routing_map
: torch.Tensor
scores: torch.Tensor
scores
: torch.Tensor
"""
"""
return
FusedComputeScoresForMoEAuxLoss
.
apply
(
logits
,
topk
,
score_function
)
return
FusedComputeScoresForMoEAuxLoss
.
apply
(
logits
,
topk
,
score_function
)
...
@@ -258,18 +258,18 @@ def fused_moe_aux_loss(
...
@@ -258,18 +258,18 @@ def fused_moe_aux_loss(
Fused MoE aux loss.
Fused MoE aux loss.
Parameters
Parameters
----------
----------
probs: torch.Tensor
probs
: torch.Tensor
tokens_per_expert: torch.Tensor
tokens_per_expert
: torch.Tensor
the number of tokens per expert
the number of tokens per expert
total_num_tokens: int
total_num_tokens
: int
the total number of tokens, involved in the aux loss calculation
the total number of tokens, involved in the aux loss calculation
num_experts: int
num_experts
: int
topk: int
topk
: int
coeff: float
coeff
: float
the coefficient of the aux loss
the coefficient of the aux loss
Returns
Returns
-------
-------
aux_loss: torch.scalar
aux_loss
: torch.scalar
"""
"""
return
FusedAuxLoss
.
apply
(
probs
,
tokens_per_expert
,
total_num_tokens
,
num_experts
,
topk
,
coeff
)
return
FusedAuxLoss
.
apply
(
probs
,
tokens_per_expert
,
total_num_tokens
,
num_experts
,
topk
,
coeff
)
transformer_engine/pytorch/setup.py
View file @
970620a5
...
@@ -75,21 +75,29 @@ def get_platform():
...
@@ -75,21 +75,29 @@ def get_platform():
def
get_wheel_url
():
def
get_wheel_url
():
"""Construct the wheel URL for the current platform."""
"""Construct the wheel URL for the current platform."""
torch_version_raw
=
parse
(
torch
.
__version__
)
python_version
=
f
"cp
{
sys
.
version_info
.
major
}{
sys
.
version_info
.
minor
}
"
python_version
=
f
"cp
{
sys
.
version_info
.
major
}{
sys
.
version_info
.
minor
}
"
platform_name
=
get_platform
()
platform_name
=
get_platform
()
nvte_version
=
te_version
()
nvte_version
=
te_version
()
torch_version
=
f
"
{
torch_version_raw
.
major
}
.
{
torch_version_raw
.
minor
}
"
cxx11_abi
=
str
(
torch
.
_C
.
_GLIBCXX_USE_CXX11_ABI
).
upper
()
cxx11_abi
=
str
(
torch
.
_C
.
_GLIBCXX_USE_CXX11_ABI
).
upper
()
# Determine the version numbers that will be used to determine the correct wheel
# Determine the version numbers that will be used to determine the correct wheel
# We're using the CUDA version used to build torch, not the one currently installed
# We're using the CUDA version used to build torch, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_cuda_version
=
parse
(
torch
.
version
.
cuda
)
torch_cuda_version
=
parse
(
torch
.
version
.
cuda
)
# For CUDA
11, we only compile for CUDA 11.8, and for CUDA
12 we only compile for CUDA 12.3
# For CUDA 12 we only compile for CUDA 12.3
# to save CI time. Minor versions should be compatible.
# to save CI time. Minor versions should be compatible.
torch_cuda_version
=
parse
(
"11.8"
)
if
torch_cuda_version
.
major
==
11
else
parse
(
"12.3"
)
if
torch_cuda_version
.
major
==
12
:
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
torch_cuda_version
=
parse
(
"12.3"
)
elif
torch_cuda_version
.
major
==
13
:
torch_cuda_version
=
parse
(
"13.0"
)
else
:
raise
ValueError
(
f
"CUDA version
{
torch_cuda_version
}
not supported"
)
if
os
.
environ
.
get
(
"NVIDIA_PRODUCT_NAME"
,
""
)
==
"PyTorch"
:
torch_version
=
str
(
os
.
environ
.
get
(
"NVIDIA_PYTORCH_VERSION"
))
else
:
torch_version
=
f
"
{
torch
.
__version__
}
"
cuda_version
=
f
"
{
torch_cuda_version
.
major
}
"
cuda_version
=
f
"
{
torch_cuda_version
.
major
}
"
# Determine wheel URL based on CUDA version, torch version, python version and OS
# Determine wheel URL based on CUDA version, torch version, python version and OS
...
@@ -109,8 +117,10 @@ class CachedWheelsCommand(_bdist_wheel):
...
@@ -109,8 +117,10 @@ class CachedWheelsCommand(_bdist_wheel):
"""
"""
def
run
(
self
):
def
run
(
self
):
"""Acts a proxy before _bdist_wheel.run() and downloads a prebuilt wheel if available."""
if
FORCE_BUILD
:
if
FORCE_BUILD
:
super
().
run
()
super
().
run
()
return
wheel_url
,
wheel_filename
=
get_wheel_url
()
wheel_url
,
wheel_filename
=
get_wheel_url
()
print
(
"Guessing wheel URL: "
,
wheel_url
)
print
(
"Guessing wheel URL: "
,
wheel_url
)
...
@@ -129,10 +139,12 @@ class CachedWheelsCommand(_bdist_wheel):
...
@@ -129,10 +139,12 @@ class CachedWheelsCommand(_bdist_wheel):
wheel_path
=
os
.
path
.
join
(
self
.
dist_dir
,
archive_basename
+
".whl"
)
wheel_path
=
os
.
path
.
join
(
self
.
dist_dir
,
archive_basename
+
".whl"
)
print
(
"Raw wheel path"
,
wheel_path
)
print
(
"Raw wheel path"
,
wheel_path
)
os
.
rename
(
wheel_filename
,
wheel_path
)
os
.
rename
(
wheel_filename
,
wheel_path
)
return
except
(
urllib
.
error
.
HTTPError
,
urllib
.
error
.
URLError
):
except
(
urllib
.
error
.
HTTPError
,
urllib
.
error
.
URLError
):
print
(
"Precompiled wheel not found. Building from source..."
)
print
(
"Precompiled wheel not found. Building from source..."
)
# If the wheel could not be downloaded, build from source
# If the wheel could not be downloaded, build from source
super
().
run
()
super
().
run
()
return
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
View file @
970620a5
...
@@ -60,6 +60,22 @@ class Float8BlockQuantizer(Quantizer):
...
@@ -60,6 +60,22 @@ class Float8BlockQuantizer(Quantizer):
self
.
block_scaling_dim
=
block_scaling_dim
self
.
block_scaling_dim
=
block_scaling_dim
self
.
all_gather_usage
=
all_gather_usage
self
.
all_gather_usage
=
all_gather_usage
def
copy
(
self
)
->
Float8BlockQuantizer
:
"""Create shallow copy"""
quantizer
=
Float8BlockQuantizer
(
fp8_dtype
=
self
.
dtype
,
rowwise
=
self
.
rowwise_usage
,
columnwise
=
self
.
columnwise_usage
,
block_scaling_dim
=
self
.
block_scaling_dim
,
all_gather_usage
=
self
.
all_gather_usage
,
amax_epsilon
=
self
.
amax_epsilon
,
force_pow_2_scales
=
self
.
force_pow_2_scales
,
)
quantizer
.
internal
=
self
.
internal
return
quantizer
def
update_quantized
(
def
update_quantized
(
self
,
self
,
src
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
...
@@ -294,18 +310,18 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
...
@@ -294,18 +310,18 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
Parameters
Parameters
----------
----------
rowwise_data: torch.Tensor
rowwise_data
: torch.Tensor
FP8 data in a uint8 tensor matching shape of dequantized tensor.
FP8 data in a uint8 tensor matching shape of dequantized tensor.
rowwise_scale_inv: torch.Tensor
rowwise_scale_inv
: torch.Tensor
FP32 dequantization scales in GEMM format for dequantizing rowwise_data.
FP32 dequantization scales in GEMM format for dequantizing rowwise_data.
columnwise_data: Optional[torch.Tensor]
columnwise_data
: Optional[torch.Tensor]
FP8 data in a uint8 tensor matching shape of dequantized tensor transpose.
FP8 data in a uint8 tensor matching shape of dequantized tensor transpose.
columnwise_scale_inv: Optional[torch.Tensor]
columnwise_scale_inv
: Optional[torch.Tensor]
FP32 dequantization scales in GEMM format for dequantizing columnwise_data.
FP32 dequantization scales in GEMM format for dequantizing columnwise_data.
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
fp8_dtype
: transformer_engine_torch.DType, default = kFloat8E4M3
FP8 format.
FP8 format.
quantizer: Quantizer - the Float8BlockQuantizer that quantized this tensor and
quantizer
: Quantizer - the Float8BlockQuantizer that quantized this tensor and
holds configuration about quantization and dequantization modes.
holds configuration about quantization and dequantization modes.
"""
"""
...
...
transformer_engine/pytorch/tensor/float8_tensor.py
View file @
970620a5
...
@@ -67,6 +67,20 @@ class Float8Quantizer(Quantizer):
...
@@ -67,6 +67,20 @@ class Float8Quantizer(Quantizer):
self
.
amax
=
amax
self
.
amax
=
amax
self
.
dtype
=
fp8_dtype
self
.
dtype
=
fp8_dtype
def
copy
(
self
)
->
Float8Quantizer
:
"""Create shallow copy"""
quantizer
=
Float8Quantizer
(
scale
=
self
.
scale
,
amax
=
self
.
amax
,
fp8_dtype
=
self
.
dtype
,
rowwise
=
self
.
rowwise_usage
,
columnwise
=
self
.
columnwise_usage
,
)
quantizer
.
internal
=
self
.
internal
return
quantizer
def
update_quantized
(
def
update_quantized
(
self
,
self
,
src
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
...
@@ -246,10 +260,16 @@ class Float8CurrentScalingQuantizer(Quantizer):
...
@@ -246,10 +260,16 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax_reduction_group
:
Optional
[
dist_group_type
]
=
None
,
amax_reduction_group
:
Optional
[
dist_group_type
]
=
None
,
force_pow_2_scales
:
bool
=
False
,
force_pow_2_scales
:
bool
=
False
,
amax_epsilon
:
float
=
0.0
,
amax_epsilon
:
float
=
0.0
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
amax
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
(
rowwise
=
rowwise
,
columnwise
=
columnwise
)
super
().
__init__
(
rowwise
=
rowwise
,
columnwise
=
columnwise
)
self
.
scale
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
if
scale
is
None
:
self
.
amax
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
scale
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
if
amax
is
None
:
amax
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
self
.
scale
=
scale
self
.
amax
=
amax
self
.
dtype
=
tex
.
DType
.
kInt8
if
int8_simulation_fp8_tensorwise
else
fp8_dtype
self
.
dtype
=
tex
.
DType
.
kInt8
if
int8_simulation_fp8_tensorwise
else
fp8_dtype
self
.
use_existing_amax
=
use_existing_amax
self
.
use_existing_amax
=
use_existing_amax
self
.
with_amax_reduction
=
with_amax_reduction
self
.
with_amax_reduction
=
with_amax_reduction
...
@@ -257,6 +277,26 @@ class Float8CurrentScalingQuantizer(Quantizer):
...
@@ -257,6 +277,26 @@ class Float8CurrentScalingQuantizer(Quantizer):
self
.
force_pow_2_scales
=
force_pow_2_scales
self
.
force_pow_2_scales
=
force_pow_2_scales
self
.
amax_epsilon
=
amax_epsilon
self
.
amax_epsilon
=
amax_epsilon
def
copy
(
self
)
->
Float8CurrentScalingQuantizer
:
"""Create shallow copy"""
quantizer
=
Float8CurrentScalingQuantizer
(
fp8_dtype
=
self
.
dtype
,
device
=
0
,
rowwise
=
self
.
rowwise_usage
,
columnwise
=
self
.
columnwise_usage
,
with_amax_reduction
=
self
.
with_amax_reduction
,
amax_reduction_group
=
self
.
amax_reduction_group
,
use_existing_amax
=
self
.
use_existing_amax
,
force_pow_2_scales
=
self
.
force_pow_2_scales
,
amax_epsilon
=
self
.
amax_epsilon
,
scale
=
self
.
scale
,
amax
=
self
.
amax
,
)
quantizer
.
internal
=
self
.
internal
return
quantizer
def
update_quantized
(
def
update_quantized
(
self
,
self
,
src
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
...
@@ -414,23 +454,23 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
...
@@ -414,23 +454,23 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
Parameters
Parameters
----------
----------
shape: int or iterable of int
shape
: int or iterable of int
Tensor dimensions.
Tensor dimensions.
dtype: torch.dtype
dtype
: torch.dtype
Nominal tensor datatype.
Nominal tensor datatype.
requires_grad: bool, optional = False
requires_grad
: bool, optional = False
Whether to compute gradients for this tensor.
Whether to compute gradients for this tensor.
data: torch.Tensor
data
: torch.Tensor
Raw FP8 data in a uint8 tensor
Raw FP8 data in a uint8 tensor
fp8_scale_inv: torch.Tensor
fp8_scale_inv
: torch.Tensor
Reciprocal of the scaling factor applied when casting to FP8,
Reciprocal of the scaling factor applied when casting to FP8,
i.e. the scaling factor that must be applied when casting from
i.e. the scaling factor that must be applied when casting from
FP8 to higher precision.
FP8 to higher precision.
fp8_dtype: transformer_engine_torch.DType
fp8_dtype
: transformer_engine_torch.DType
FP8 format.
FP8 format.
data_transpose: torch.Tensor, optional
data_transpose
: torch.Tensor, optional
FP8 transpose data in a uint8 tensor
FP8 transpose data in a uint8 tensor
quantizer: Float8Quantizer, Float8CurrentScalingQuantizer, optional
quantizer
: Float8Quantizer, Float8CurrentScalingQuantizer, optional
Builder class for FP8 tensors
Builder class for FP8 tensors
"""
"""
...
...
transformer_engine/pytorch/tensor/mxfp8_tensor.py
View file @
970620a5
...
@@ -45,6 +45,18 @@ class MXFP8Quantizer(Quantizer):
...
@@ -45,6 +45,18 @@ class MXFP8Quantizer(Quantizer):
super
().
__init__
(
rowwise
=
rowwise
,
columnwise
=
columnwise
)
super
().
__init__
(
rowwise
=
rowwise
,
columnwise
=
columnwise
)
self
.
dtype
=
fp8_dtype
self
.
dtype
=
fp8_dtype
def
copy
(
self
)
->
MXFP8Quantizer
:
"""Create shallow copy"""
quantizer
=
MXFP8Quantizer
(
fp8_dtype
=
self
.
dtype
,
rowwise
=
self
.
rowwise_usage
,
columnwise
=
self
.
columnwise_usage
,
)
quantizer
.
internal
=
self
.
internal
return
quantizer
def
update_quantized
(
def
update_quantized
(
self
,
self
,
src
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
...
@@ -122,7 +134,9 @@ class MXFP8Quantizer(Quantizer):
...
@@ -122,7 +134,9 @@ class MXFP8Quantizer(Quantizer):
columnwise_data
=
None
columnwise_data
=
None
columnwise_scale_inv
=
None
columnwise_scale_inv
=
None
if
self
.
columnwise_usage
:
if
self
.
columnwise_usage
:
columnwise_data
=
torch
.
empty_like
(
data
,
pin_memory
=
pin_memory
)
columnwise_data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
)
columnwise_scale_inv
=
torch
.
empty
(
columnwise_scale_inv
=
torch
.
empty
(
round_up_to_nearest_multiple
(
math
.
prod
(
shape
[:
-
1
])
//
MXFP8_BLOCK_SCALING_SIZE
,
4
),
round_up_to_nearest_multiple
(
math
.
prod
(
shape
[:
-
1
])
//
MXFP8_BLOCK_SCALING_SIZE
,
4
),
round_up_to_nearest_multiple
(
shape
[
-
1
],
128
),
round_up_to_nearest_multiple
(
shape
[
-
1
],
128
),
...
@@ -190,16 +204,16 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
...
@@ -190,16 +204,16 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
Parameters
Parameters
----------
----------
data: torch.Tensor
data
: torch.Tensor
Raw FP8 data in a uint8 tensor
Raw FP8 data in a uint8 tensor
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
fp8_dtype
: transformer_engine_torch.DType, default = kFloat8E4M3
FP8 format.
FP8 format.
fp8_scale_inv: torch.Tensor
fp8_scale_inv
: torch.Tensor
Reciprocal of the scaling factor applied when
Reciprocal of the scaling factor applied when
casting to FP8, i.e. the scaling factor that must
casting to FP8, i.e. the scaling factor that must
be applied when casting from FP8 to higher
be applied when casting from FP8 to higher
precision.
precision.
dtype: torch.dtype, default = torch.float32
dtype
: torch.dtype, default = torch.float32
Nominal tensor datatype.
Nominal tensor datatype.
"""
"""
...
...
transformer_engine/pytorch/tensor/nvfp4_tensor.py
View file @
970620a5
...
@@ -176,6 +176,26 @@ class NVFP4Quantizer(Quantizer):
...
@@ -176,6 +176,26 @@ class NVFP4Quantizer(Quantizer):
return
dst
return
dst
def
copy
(
self
)
->
NVFP4Quantizer
:
"""Create shallow copy"""
quantizer
=
NVFP4Quantizer
(
fp4_dtype
=
self
.
dtype
,
rowwise
=
self
.
rowwise_usage
,
columnwise
=
self
.
columnwise_usage
,
with_amax_reduction
=
self
.
with_amax_reduction
,
amax_reduction_group
=
self
.
amax_reduction_group
,
with_rht
=
self
.
with_rht
,
with_post_rht_amax
=
self
.
with_post_rht_amax
,
with_2d_quantization
=
self
.
with_2d_quantization
,
stochastic_rounding
=
self
.
stochastic_rounding
,
)
quantizer
.
internal
=
self
.
internal
quantizer
.
rht_matrix
=
self
.
rht_matrix
quantizer
.
rht_matrix_random_sign_mask_t
=
self
.
rht_matrix_random_sign_mask_t
return
quantizer
def
quantize_impl
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
def
quantize_impl
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
"""Quantize tensor implementation"""
"""Quantize tensor implementation"""
return
tex
.
quantize
(
tensor
,
self
)
return
tex
.
quantize
(
tensor
,
self
)
...
@@ -360,26 +380,26 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
...
@@ -360,26 +380,26 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
Parameters
Parameters
----------
----------
rowwise_data: torch.Tensor
rowwise_data
: torch.Tensor
Raw FP4 data in a uint8 tensor (rowwise layout).
Raw FP4 data in a uint8 tensor (rowwise layout).
rowwise_scale_inv: torch.Tensor
rowwise_scale_inv
: torch.Tensor
Reciprocal of the scaling factor applied when
Reciprocal of the scaling factor applied when
casting to FP4, i.e. the scaling factor that must
casting to FP4, i.e. the scaling factor that must
be applied when casting from FP4 to higher
be applied when casting from FP4 to higher
precision (rowwise).
precision (rowwise).
columnwise_data: torch.Tensor, optional
columnwise_data
: torch.Tensor, optional
Raw FP4 data in a uint8 tensor (columnwise layout).
Raw FP4 data in a uint8 tensor (columnwise layout).
columnwise_scale_inv: torch.Tensor, optional
columnwise_scale_inv
: torch.Tensor, optional
Reciprocal of the scaling factor for columnwise FP4 data.
Reciprocal of the scaling factor for columnwise FP4 data.
amax_rowwise: torch.Tensor, optional
amax_rowwise
: torch.Tensor, optional
Rowwise amax tracking tensor.
Rowwise amax tracking tensor.
amax_columnwise: torch.Tensor, optional
amax_columnwise
: torch.Tensor, optional
Columnwise amax tracking tensor.
Columnwise amax tracking tensor.
fp4_dtype: TE_DType
fp4_dtype
: TE_DType
The FP4 data type used for quantization.
The FP4 data type used for quantization.
quantizer: Quantizer
quantizer
: Quantizer
The quantizer instance used for this tensor.
The quantizer instance used for this tensor.
dtype: torch.dtype, default = torch.float32
dtype
: torch.dtype, default = torch.float32
Nominal tensor datatype, used in dequantize.
Nominal tensor datatype, used in dequantize.
"""
"""
...
...
transformer_engine/pytorch/tensor/utils.py
View file @
970620a5
...
@@ -74,7 +74,7 @@ def cast_master_weights_to_fp8(
...
@@ -74,7 +74,7 @@ def cast_master_weights_to_fp8(
fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are
fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are
not sharded. Otherwise, it means that the model weights are sharded and we get
not sharded. Otherwise, it means that the model weights are sharded and we get
target model weights data storage using the FSDP shard model weights.
target model weights data storage using the FSDP shard model weights.
manual_post_all_gather_processing: bool, default = `False`.
manual_post_all_gather_processing
: bool, default = `False`.
If False, post processing will be automatically triggered during next forward.
If False, post processing will be automatically triggered during next forward.
If True, the timing of calling post_all_gather_processing is left to the user.
If True, the timing of calling post_all_gather_processing is left to the user.
Note that users must call `post_all_gather_processing` if it's set to True,
Note that users must call `post_all_gather_processing` if it's set to True,
...
...
transformer_engine/pytorch/torch_version.py
0 → 100644
View file @
970620a5
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""PyTorch version utilities"""
from
__future__
import
annotations
import
functools
import
torch
from
packaging.version
import
Version
as
PkgVersion
@
functools
.
lru_cache
(
maxsize
=
None
)
def
torch_version
()
->
tuple
[
int
,
...]:
"""Get PyTorch version"""
return
PkgVersion
(
str
(
torch
.
__version__
)).
release
transformer_engine/pytorch/transformer.py
View file @
970620a5
...
@@ -10,7 +10,7 @@ from typing import Callable, List, Optional, Tuple, Union
...
@@ -10,7 +10,7 @@ from typing import Callable, List, Optional, Tuple, Union
import
torch
import
torch
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
.torch_version
import
torch_version
from
transformer_engine.pytorch.module
import
LayerNormMLP
,
LayerNorm
,
RMSNorm
from
transformer_engine.pytorch.module
import
LayerNormMLP
,
LayerNorm
,
RMSNorm
from
transformer_engine.debug.pytorch.debug_state
import
TEDebugState
from
transformer_engine.debug.pytorch.debug_state
import
TEDebugState
from
transformer_engine.pytorch.attention.multi_head_attention
import
MultiheadAttention
from
transformer_engine.pytorch.attention.multi_head_attention
import
MultiheadAttention
...
@@ -75,8 +75,8 @@ class TransformerLayer(torch.nn.Module):
...
@@ -75,8 +75,8 @@ class TransformerLayer(torch.nn.Module):
.. note::
.. note::
Argument :attr:`attention_mask` in the `forward` call is only used when
Argument :attr:`attention_mask` in the
:meth:
`forward` call is only used when
:attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.
:attr:`self_attn_mask_type` includes
`
`"padding"`
`
or
`
`"arbitrary"`
`
.
Parameters
Parameters
----------
----------
...
@@ -86,76 +86,76 @@ class TransformerLayer(torch.nn.Module):
...
@@ -86,76 +86,76 @@ class TransformerLayer(torch.nn.Module):
intermediate size to which input samples are projected.
intermediate size to which input samples are projected.
num_attention_heads : int
num_attention_heads : int
number of attention heads in the transformer layer.
number of attention heads in the transformer layer.
num_gqa_groups : int, default =
`
None
`
num_gqa_groups : int, default = None
number of GQA groups in the transformer layer.
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the querys.
This only affects the keys and values, not the querys.
GQA-1 is equivalent to Multi-Query Attention
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
is equivalent to MHA, i.e.
`
`num_gqa_groups = num_attention_heads`
`
.
layernorm_epsilon : float, default = 1e-5
layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization
a value added to the denominator of layer normalization
for numerical stability.
for numerical stability.
hidden_dropout: float, default = 0.1
hidden_dropout
: float, default = 0.1
dropout probability for the dropout op after FC2 layer.
dropout probability for the dropout op after FC2 layer.
attention_dropout: float, default = 0.1
attention_dropout
: float, default = 0.1
dropout probability for the dropout op during multi-head attention.
dropout probability for the dropout op during multi-head attention.
init_method : Callable, default =
`
None
`
init_method : Callable, default = None
used for initializing weights of QKV and FC1 weights in the following way:
used for initializing weights of QKV and FC1 weights in the following way:
`init_method(weight)`. When set to `None`, defaults to
`
`init_method(weight)`
`
. When set to
`
`None`
`
, defaults to
`torch.nn.init.normal_(mean=0.0, std=0.023)`.
`
`torch.nn.init.normal_(mean=0.0, std=0.023)`
`
.
output_layer_init_method : Callable, default =
`
None
`
output_layer_init_method : Callable, default = None
used for initializing weights of PROJ and FC2 in the following way:
used for initializing weights of PROJ and FC2 in the following way:
`output_layer_init_method(weight)`. When set to `None`, defaults to
`
`output_layer_init_method(weight)`
`
. When set to
`
`None`
`
, defaults to
`torch.nn.init.normal_(mean=0.0, std=0.023)`.
`
`torch.nn.init.normal_(mean=0.0, std=0.023)`
`
.
apply_residual_connection_post_layernorm : bool, default =
`
False
`
apply_residual_connection_post_layernorm : bool, default = False
if set to `True`, residual connections are taken
if set to
`
`True`
`
, residual connections are taken
from the output of layer norm (default is taken
from the output of layer norm (default is taken
from input of layer norm)
from input of layer norm)
layer_number: int, default =
`
None
`
layer_number
: int, default = None
layer number of the current `TransformerLayer` when multiple such modules are
layer number of the current
:class:
`TransformerLayer` when multiple such modules are
concatenated to form a transformer block.
concatenated to form a transformer block.
output_layernorm: bool, default =
`
False
`
output_layernorm
: bool, default = False
if set to `True`, layer normalization is applied on the output side,
if set to
`
`True`
`
, layer normalization is applied on the output side,
after the final dropout-add. default behavior is to apply layer
after the final dropout-add. default behavior is to apply layer
normalization on the input side, before the QKV transformation.
normalization on the input side, before the QKV transformation.
parallel_attention_mlp: bool, default =
`
False
`
parallel_attention_mlp
: bool, default = False
if set to `True`, self-attention and feedforward network are computed
if set to
`
`True`
`
, self-attention and feedforward network are computed
based on the same input (in parallel) instead of sequentially.
based on the same input (in parallel) instead of sequentially.
Both blocks have an independent normalization.
Both blocks have an independent normalization.
This architecture is used in `Falcon` models.
This architecture is used in `Falcon` models.
layer_type: {'encoder', 'decoder'}, default =
`
encoder
`
layer_type
: {'encoder', 'decoder'}, default =
"
encoder
"
if set to `decoder`, an additional cross-attn block is added after self-attn.
if set to `
`"
decoder
"`
`, an additional cross-attn block is added after self-attn.
This can be used for structures like `T5` Transformer in conjunction with the
This can be used for structures like `T5` Transformer in conjunction with the
`encoder` option.
`
`"
encoder
"`
` option.
kv_channels: int, default =
`
None
`
kv_channels
: int, default = None
number of query-key-value channels per attention head. defaults to
number of query-key-value channels per attention head. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
:attr:`hidden_size` / :attr:`num_attention_heads` if
`
`None`
`
.
self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
self_attn_mask_type
: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right', 'arbitrary'},
'padding_causal_bottom_right', 'arbitrary'},
default =
`
causal
`
default =
"
causal
"
type of attention mask passed into softmax operation for encoder.
type of attention mask passed into softmax operation for encoder.
Overridden by :attr:`self_attn_mask_type` in the `forward` method.
Overridden by :attr:`self_attn_mask_type` in the
:meth:
`forward` method.
The forward arg is useful for dynamically changing mask types, e.g.
The
:meth:`
forward
`
arg is useful for dynamically changing mask types, e.g.
a different mask for training and inference. The init arg is useful
a different mask for training and inference. The
:meth:`__
init
__`
arg is useful
for cases involving compilation/tracing, e.g. ONNX export.
for cases involving compilation/tracing, e.g. ONNX export.
window_size: Optional[Tuple[int, int]], default =
`
None
`
window_size
: Optional[Tuple[int, int]], default = None
sliding window size for local attention in encoder, where query at position i
sliding window size for local attention in encoder, where query at position i
attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
attends to keys in
``
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
- seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
- seqlen_q + window_size[1]]
``
inclusive. Special cases
``
(-1, -1)
``
and
``
(-1, 0)
``
mean
no sliding window and causal mask specifically. Both `causal` and
no sliding window and causal mask specifically. Both `
`"
causal
"`
` and
`causal_bottom_right` masks map to `window_size = (-1, 0)` and Transformer Engine
`
`"
causal_bottom_right
"`
` masks map to
:attr:
`window_size
`
=
``
(-1, 0)`
`
and Transformer Engine
distinguishes them based on `self_attn_mask_type` or `enc_dec_attn_mask_type`.
distinguishes them based on
:attr:
`self_attn_mask_type` or
:attr:
`enc_dec_attn_mask_type`.
Similar to :attr:`self_attn_mask_type`, `window_size` can be overridden by
Similar to :attr:`self_attn_mask_type`,
:attr:
`window_size` can be overridden by
:attr:`window_size` in `forward` as well.
:attr:`window_size` in
:meth:
`forward` as well.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
enc_dec_attn_mask_type
: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default =
`
no_mask
`
default =
"
no_mask
"
type of attention mask passed into softmax operation for decoder.
type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default =
`
None
`
enc_dec_window_size
: Optional[Tuple[int, int]], default = None
sliding window size for local attention in decoder.
sliding window size for local attention in decoder.
zero_centered_gamma : bool, default =
'
False
'
zero_centered_gamma : bool, default = False
if set to
'
True
'
, gamma parameter in LayerNorm is initialized to 0 and
if set to
``
True
``
, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
the LayerNorm formula changes to
.. math::
.. math::
...
@@ -163,111 +163,126 @@ class TransformerLayer(torch.nn.Module):
...
@@ -163,111 +163,126 @@ class TransformerLayer(torch.nn.Module):
(1 + \gamma) + \beta
(1 + \gamma) + \beta
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
type of normalization applied.
qkv_weight_interleaved : bool, default =
`
True
`
qkv_weight_interleaved : bool, default = True
if set to `False`, the QKV weight is interpreted as a concatenation of
if set to
`
`False`
`
, the QKV weight is interpreted as a concatenation of
query, key, and value weights along the `0th` dimension. The default
query, key, and value weights along the
`
`0th`
`
dimension. The default
interpretation is that the individual `
q
`, `k`, and `
v
` weights for each
interpretation is that the individual `
`q`
`,
`
`k`
`
, and `
`v`
` weights for each
attention head are interleaved. This parameter is set to `False` when
attention head are interleaved. This parameter is set to
`
`False`
`
when
using :attr:`fuse_qkv_params=False`.
using :attr:`fuse_qkv_params=False`.
rotary_pos_interleaved : bool, default =
`
False
`
rotary_pos_interleaved : bool, default = False
whether to use interleaved rotary position embeddings.
whether to use interleaved rotary position embeddings.
bias : bool, default =
`
True
`
bias : bool, default = True
if set to `False`, the transformer layer will not learn any additive biases.
if set to
`
`False`
`
, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
activation : str, default = 'gelu'
Type of activation used in MLP block.
Type of activation used in MLP block.
Options are: 'gelu'
,
'geglu'
,
'qgelu'
,
'qgeglu'
, 'relu',
'reglu'
,
'srelu'
,
'sreglu',
Options are:
``
'gelu'
``, ``
'geglu'
``, ``
'qgelu'
``, ``
'qgeglu'
``, ``'relu'``, ``
'reglu'
``, ``
'srelu'
``, ``
'sreglu'
``
,
'silu',
'swiglu', and 'clamped_swiglu'.
``'silu'``, ``
'swiglu'
``
, and
``
'clamped_swiglu'
``
.
activation_params : Optional[dict], default =
`
None
`
activation_params : Optional[dict], default = None
Additional parameters for the activation function.
Additional parameters for the activation function.
At the moment, only used for 'clamped_swiglu' activation which
At the moment, only used for
``
'clamped_swiglu'
``
activation which
supports 'limit' and 'alpha' parameters. You can set these as
supports
``
'limit'
``
and
``
'alpha'
``
parameters. You can set these as
`activation_params={'limit': 7.0, 'alpha': 1.702}`.
`
`activation_params={'limit': 7.0, 'alpha': 1.702}`
`
.
device : Union[torch.device, str], default = "cuda"
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
forward pass.
attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
attn_input_format
: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
This controls whether the dimensions of the
This controls whether the dimensions of the
intermediate hidden states is 'sequence first' ('sbhd'), 'batch first' ('bshd'),
intermediate hidden states is 'sequence first' (
``
'sbhd'
``
), 'batch first' (
``
'bshd'
``
),
or 'token first' ('thd'). `
s
` stands for the sequence length, `
b
` batch size,
or 'token first' (
``
'thd'
``
). `
`s`
` stands for the sequence length, `
`b`
` batch size,
`t` the total number of tokens, `
h
` the number of heads, `
d
` head size.
`
`t`
`
the total number of tokens, `
`h`
` the number of heads, `
`d`
` head size.
Note that these formats are very closely
Note that these formats are very closely
related to the `qkv_format`
in the
`MultiHeadAttention`
related to the
:attr:
`qkv_format`
parameter in the :class:
`MultiHeadAttention`
and `DotProductAttention` modules.
and
:class:
`DotProductAttention` modules.
name: str, default =
`
None
`
name
: str, default = None
name of the module, currently used for debugging purposes.
name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax_type
: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
s
oftmax type as described in th
is
paper
:
S
oftmax type as described in th
e
paper
`Efficient Streaming Language Models with Attention Sinks
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* ``'vanilla'``:
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
.. math::
('zero sink' and 'learnable sink').
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
* ``'off-by-one'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
* ``'learnable'``:
.. math::
Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
where :math:`\\alpha` is a learnable parameter of shape ``[h]``.
``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``).
Parallelism parameters
Parallelism parameters
----------------------
----------------------
set_parallel_mode : bool, default =
`
False
`
set_parallel_mode : bool, default = False
if set to `True`, QKV and FC1 layers are used as Column Parallel
if set to
`
`True`
`
, QKV and FC1 layers are used as Column Parallel
whereas PROJ and FC2 is used as Row Parallel as described
whereas PROJ and FC2 is used as Row Parallel as described
`here <https://arxiv.org/pdf/1909.08053.pdf>`_.
`here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default =
`
False
`
sequence_parallel : bool, default = False
if set to `True`, uses sequence parallelism.
if set to
`
`True`
`
, uses sequence parallelism.
tp_group : ProcessGroup, default =
`
None
`
tp_group : ProcessGroup, default = None
tensor parallel process group.
tensor parallel process group.
tp_size : int, default = 1
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
initialization. In this case, users must call the
`set_tensor_parallel_group
(tp_group)
` method on the initialized module before the
:meth:
`set_tensor_parallel_group` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel collectives.
Optimization parameters
Optimization parameters
-----------------------
-----------------------
fuse_wgrad_accumulation : bool, default =
'
False
'
fuse_wgrad_accumulation : bool, default = False
if set to `True`, enables fusing of creation and accumulation of
if set to
`
`True`
`
, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
have an additional
:attr:
`main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
regular
:attr:
`grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in.
params_dtype : torch.dtype, default =
`
torch.get_default_dtype()
`
params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
would not fit in GPU memory.
seq_length: int
seq_length
: int
sequence length of input samples. Needed for JIT Warmup, a technique where jit
sequence length of input samples. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are used for
fused functions are warmed up before training to ensure same kernels are used for
forward propogation and activation recompute phase.
forward propogation and activation recompute phase.
micro_batch_size: int
micro_batch_size
: int
batch size per training step. Needed for JIT Warmup, a technique where jit
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
fused functions are warmed up before training to ensure same kernels are
used for forward propogation and activation recompute phase.
used for forward propogation and activation recompute phase.
drop_path_rate: float, default = 0.0
drop_path_rate
: float, default = 0.0
when > 0.0, applies stochastic depth per sample in
when > 0.0, applies stochastic depth per sample in
the main path of the residual block.
the main path of the residual block.
fuse_qkv_params: bool, default =
'
False
'
fuse_qkv_params
: bool, default = False
if set to `True`, `TransformerLayer` module exposes a single fused
if set to
`
`True`
`
,
:class:
`TransformerLayer` module exposes a single fused
parameter for query-key-value. This enables optimizations such as QKV
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatentations/splits and also enables the argument
fusion without concatentations/splits and also enables the argument
`fuse_wgrad_accumulation`.
:attr:
`fuse_wgrad_accumulation`.
qk_norm_type: Optional[str], default = None
qk_norm_type
: Optional[str], default = None
type of normalization to apply to query and key tensors.
type of normalization to apply to query and key tensors.
Options: None
,
'L2Normalization'
,
'RMSNorm'
,
'LayerNorm'. When None, no normalization is applied.
Options:
``
None
``, ``
'L2Normalization'
``, ``
'RMSNorm'
``, ``
'LayerNorm'
``
. When
``
None
``
, no normalization is applied.
When 'L2Normalization', L2 normalization is applied to query and key tensors.
When
``
'L2Normalization'
``
, L2 normalization is applied to query and key tensors.
When 'RMSNorm', RMS normalization is applied to query and key tensors.
When
``
'RMSNorm'
``
, RMS normalization is applied to query and key tensors.
When 'LayerNorm', layer normalization is applied to query and key tensors.
When
``
'LayerNorm'
``
, layer normalization is applied to query and key tensors.
Normalization is applied after RoPE (if applicable) but before attention computation
Normalization is applied after RoPE (if applicable) but before attention computation
when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach for
when
`
`qk_norm_before_rope`
`
is
``
False
``
. This follows the e.g. Llama4 approach for
QK normalization to improve training stability and model performance.
QK normalization to improve training stability and model performance.
qk_norm_eps: float, default = 1e-6
qk_norm_eps
: float, default = 1e-6
epsilon value for normalization of query and key tensors.
epsilon value for normalization of query and key tensors.
Only used when `qk_norm_type` is not None.
Only used when
`
`qk_norm_type`
`
is not
``
None
``
.
qk_norm_before_rope: bool, default =
`
False
`
qk_norm_before_rope
: bool, default = False
if set to `True`, query and key normalization is applied before rotary position
if set to
`
`True`
`
, query and key normalization is applied before rotary position
embedding. When `False` (default), normalization is applied after RoPE.
embedding. When
`
`False`
`
(default), normalization is applied after RoPE.
This parameter allows supporting different architectural variants that apply
This parameter allows supporting different architectural variants that apply
QK normalization at different points.
QK normalization at different points.
"""
"""
...
@@ -523,7 +538,7 @@ class TransformerLayer(torch.nn.Module):
...
@@ -523,7 +538,7 @@ class TransformerLayer(torch.nn.Module):
Parameters
Parameters
----------
----------
tp_group : ProcessGroup, default =
`
None
`
tp_group : ProcessGroup, default = None
tensor parallel process group.
tensor parallel process group.
"""
"""
# Deep iterate but skip self to avoid infinite recursion.
# Deep iterate but skip self to avoid infinite recursion.
...
@@ -549,7 +564,7 @@ class TransformerLayer(torch.nn.Module):
...
@@ -549,7 +564,7 @@ class TransformerLayer(torch.nn.Module):
cp_stream
:
torch
.
cuda
.
Stream
,
cp_stream
:
torch
.
cuda
.
Stream
,
cp_comm_type
:
str
=
"p2p"
,
cp_comm_type
:
str
=
"p2p"
,
)
->
None
:
)
->
None
:
"""
r
"""
Set the context parallel attributes for the given
Set the context parallel attributes for the given
module before executing the forward pass.
module before executing the forward pass.
...
@@ -557,25 +572,26 @@ class TransformerLayer(torch.nn.Module):
...
@@ -557,25 +572,26 @@ class TransformerLayer(torch.nn.Module):
----------
----------
cp_group : Union[ProcessGroup, List[ProcessGroup]]
cp_group : Union[ProcessGroup, List[ProcessGroup]]
context parallel process group.
context parallel process group.
ProcessGroup is for cp_comm_type of "p2p"
,
"all_gather", and "a2a".
ProcessGroup is for cp_comm_type of
``
"p2p"
``, ``
"all_gather"
``
, and
``
"a2a"
``
.
List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
List[ProcessGroup] is for cp_comm_type of
``
"a2a+p2p"
``
, where
``
cp_group[0]
``
and cp_group[1] are for a2a and p2p communications respectively.
and
``
cp_group[1]
``
are for a2a and p2p communications respectively.
cp_global_ranks : List[int]
cp_global_ranks : List[int]
list of global ranks in the context group.
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cuda stream for context parallel execution.
cp_comm_type : str, default =
`
p2p
`
cp_comm_type : str, default =
"
p2p
"
inter-gpu communication type for context parallelism.
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather" or "a2a", or "a2a+p2p".
Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``.
"p2p": Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
- ``"p2p"``: Exchange KV chunks with P2P communications in ring topology.
"all_gather": All-gather to get full sequence of KV before attention.
P2P is async and can be overlapped with attention compute.
The all-gather is not async, and cannot be overlapped.
- ``"all_gather"``: All-gather to get full sequence of KV before attention.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
The all-gather is not async, and cannot be overlapped.
group, and gather to get full sequence of QKV.
- ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP
"a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
group, and gather to get full sequence of QKV.
across each CP sub-group (e.g., via NVLink), then exchanging KV with
- ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV
p2p between sub-groups (e.g., via IBLink).
across each CP sub-group (e.g., via NVLink), then exchanging KV with
p2p between sub-groups (e.g., via IBLink).
"""
"""
# Deep iterate but skip self to avoid infinite recursion.
# Deep iterate but skip self to avoid infinite recursion.
for
index
,
child
in
enumerate
(
self
.
modules
()):
for
index
,
child
in
enumerate
(
self
.
modules
()):
...
@@ -610,49 +626,49 @@ class TransformerLayer(torch.nn.Module):
...
@@ -610,49 +626,49 @@ class TransformerLayer(torch.nn.Module):
fast_zero_fill
:
bool
=
True
,
fast_zero_fill
:
bool
=
True
,
pad_between_seqs
:
Optional
[
bool
]
=
None
,
pad_between_seqs
:
Optional
[
bool
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
r
"""
Transformer Layer: attention block and a feedforward network (MLP)
Transformer Layer: attention block and a feedforward network (MLP)
.. note::
.. note::
Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
includes `"padding"` or `"arbitrary"`.
includes
`
`"padding"`
`
or
`
`"arbitrary"`
`
.
Parameters
Parameters
----------
----------
hidden_states : torch.Tensor
hidden_states : torch.Tensor
Input tensor.
Input tensor.
attention_mask : Optional[torch.Tensor], default =
`
None
`
attention_mask : Optional[torch.Tensor], default = None
Boolean tensor used to mask out self-attention softmax input. It should be
Boolean tensor used to mask out self-attention softmax input. It should be
in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
in
``
[batch_size, 1, 1, seqlen_q]
``
for padding masks, and broadcastable
to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for
"
`arbitrary
`
"
to
``
[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
``
for `
`"
arbitrary"
``
mask. It should be `None` for causal masks and
"
`no_mask
`
" type.
mask. It should be
`
`None`
`
for causal masks and `
`"
no_mask"
``
type.
A `True` value means the corresponding position is masked out and
A
`
`True`
`
value means the corresponding position is masked out and
a `False` means that position is allowed to participate in attention.
a
`
`False`
`
means that position is allowed to participate in attention.
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'},
'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'},
default =
`
causal
`
default =
"
causal
"
Type of attention mask passed into softmax operation for encoder.
Type of attention mask passed into softmax operation for encoder.
By default, causal masks are aligned to the top left corner of
By default, causal masks are aligned to the top left corner of
the softmax matrix. When
"
`bottom_right
`
" is specified in the mask type,
the softmax matrix. When `
`"
bottom_right"
``
is specified in the mask type,
causal masks are aligned to the bottom right corner.
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default =
`
None
`
window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in encoder.
Sliding window size for local attention in encoder.
encoder_output : Optional[torch.Tensor], default =
`
None
`
encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using
Output of the encoder block to be fed into the decoder block if using
`layer_type
=
"decoder"`.
:attr:
`layer_type
` = ``
"decoder"`
`
.
enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default =
`
None
`
. Boolean tensors used to mask out inter-attention softmax input if
default = None. Boolean tensors used to mask out inter-attention softmax input if
using `layer_type
=
"decoder"`. It should be a tuple of two masks in
using
:attr:
`layer_type
` = ``
"decoder"`
`
. It should be a tuple of two masks in
[batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
``
[batch_size, 1, 1, seqlen_q]
``
and
``
[batch_size, 1, 1, seqlen_kv]
``
for padding masks.
It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
It should be broadcastable to
``
[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
``
for
"
`arbitrary
`
" mask. It should be `None` for causal masks and
"
`no_mask
`
".
for `
`"
arbitrary"
``
mask. It should be
`
`None`
`
for causal masks and `
`"
no_mask"
``
.
A `True` value means the corresponding position is masked out and a `False`
A
`
`True`
`
value means the corresponding position is masked out and a
`
`False`
`
means that position is allowed to participate in attention.
means that position is allowed to participate in attention.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default =
`
None
`
default = None
Type of attention mask passed into softmax operation for decoder.
Type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default =
`
None
`
enc_dec_window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in decoder.
Sliding window size for local attention in decoder.
is_first_microbatch : {True, False, None}, default = None
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
During training using either gradient accumulation or
...
@@ -667,53 +683,53 @@ class TransformerLayer(torch.nn.Module):
...
@@ -667,53 +683,53 @@ class TransformerLayer(torch.nn.Module):
* it also allows skipping gradient accumulation during the
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
first microbatch (since it is the first gradient being
produced)
produced)
checkpoint_core_attention: bool, default =
`
False
`
checkpoint_core_attention: bool, default = False
If
t
rue, forward activations for core attention are recomputed
If
``T
rue
``
, forward activations for core attention are recomputed
during the backward pass in order to save memory that would
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
otherwise be occupied to store the forward activations until
backprop.
backprop.
rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default =
`
None
`
rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None
Embeddings for query and key tensors for applying rotary position
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
embedding. By default no input embedding is applied.
core_attention_bias_type: str, default =
`
no_bias
`
core_attention_bias_type: str, default =
"
no_bias
"
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
Bias type, {`
`"
no_bias
"`
`, `
`"
pre_scale_bias
"`
`, `
`"
post_scale_bias
"`
`, `
`"
alibi
"`
`}
core_attention_bias: Optional[torch.Tensor], default =
`
None
`
core_attention_bias: Optional[torch.Tensor], default = None
Bias tensor for
Q * K.T
Bias tensor for
:math:`Q \cdot K^T`
alibi_slopes: Optional[torch.Tensor], default =
`
None
`
alibi_slopes: Optional[torch.Tensor], default = None
ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
ALiBi slopes in FP32 and shape
``
[nheads]
``
or
``
[batch_size, nheads]
``
.
It adds a bias of
(-alibi_slope * (i +
seqlen_k - seqlen_q - j))
It adds a bias of
:math:`(-\text{alibi_slope} \cdot (i + \text{
seqlen_k
}
-
\text{
seqlen_q
}
- j))
`
to the attention score of query i and key j.
to the attention score of query i and key j.
cu_seqlens_q: Optional[torch.Tensor], default =
`
None
`
cu_seqlens_q: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (without offset) in a batch for
`
query
_
layer
`
,
Cumulative sum of sequence lengths (without offset) in a batch for query
layer,
with shape [batch_size + 1] and dtype torch.int32.
with shape
``
[batch_size + 1]
``
and dtype torch.int32.
Used by encoders, or decoders' self-attention.
Used by encoders, or decoders' self-attention.
cu_seqlens_kv: Optional[torch.Tensor], default =
`
None
`
cu_seqlens_kv: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (without offset) in a batch for
`
key
_
layer
`
Cumulative sum of sequence lengths (without offset) in a batch for key
layer
and
`
value
_
layer
`
, with shape [batch_size + 1] and dtype torch.int32.
and value
layer, with shape
``
[batch_size + 1]
``
and dtype torch.int32.
Used by decoders' cross-attention.
Used by decoders' cross-attention.
cu_seqlens_q_padded: Optional[torch.Tensor], default =
`
None
`
cu_seqlens_q_padded: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (with offset) in a batch for
`
query
_
layer
`
,
Cumulative sum of sequence lengths (with offset) in a batch for query
layer,
with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None.
with shape
``
[batch_size + 1]
``
and dtype torch.int32. Set to
:attr:
`cu_seqlens_q` if
``
None
``
.
Used by encoders, or decoders' self-attention.
Used by encoders, or decoders' self-attention.
cu_seqlens_kv_padded: Optional[torch.Tensor], default =
`
None
`
cu_seqlens_kv_padded: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (with offset) in a batch for
`
key
_
layer
`
Cumulative sum of sequence lengths (with offset) in a batch for key
layer
and
`
value
_
layer
`
, with shape [batch_size + 1] and dtype torch.int32.
and value
layer, with shape
``
[batch_size + 1]
``
and dtype torch.int32.
Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention.
Set to
:attr:
`cu_seqlens_kv` if
``
None
``
. Used by decoders' cross-attention.
max_seqlen_q: Optional[int], default =
`
None
`
max_seqlen_q: Optional[int], default = None
Maximum sequence length in
`
query
_
layer
`
.
Maximum sequence length in query
layer.
Calculated from `cu_seqlens_q_padded` if not provided.
Calculated from
:attr:
`cu_seqlens_q_padded` if not provided.
max_seqlen_kv: Optional[int], default =
`
None
`
max_seqlen_kv: Optional[int], default = None
Maximum sequence length in
`
key
_
layer
`
and
`
value
_
layer
`
.
Maximum sequence length in key
layer and value
layer.
Calculated from `cu_seqlens_kv_padded` if not provided.
Calculated from
:attr:
`cu_seqlens_kv_padded` if not provided.
fast_zero_fill: bool, default =
`
True
`
fast_zero_fill: bool, default = True
Whether to set output tensors to 0 or not before use.
Whether to set output tensors to 0 or not before use.
inference_params: InferenceParams, default = None
inference_params: InferenceParams, default = None
Inference parameters that are passed to the main model in order
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
to efficiently calculate and store the context during inference.
pad_between_seqs: Optional[bool], default =
`
None
`
pad_between_seqs: Optional[bool], default = None
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If
``
None
``
, inferred from
:attr:`
qkv_format
`
, cu_seqlens and cu_seqlens_padded.
If
t
rue, there are padding tokens between individual sequences in a packed batch,
If
``T
rue
``
, there are padding tokens between individual sequences in a packed batch,
i.e. qkv_format = 'thd'.
i.e.
:attr:`
qkv_format
`
=
``
'thd'
``
.
"""
"""
if
self_attn_mask_type
is
None
:
if
self_attn_mask_type
is
None
:
...
...
transformer_engine/pytorch/triton/permutation.py
View file @
970620a5
...
@@ -31,18 +31,18 @@ def make_row_id_map(
...
@@ -31,18 +31,18 @@ def make_row_id_map(
Parameters
Parameters
----------
----------
routing_map: torch.Tensor
routing_map
: torch.Tensor
Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
which experts are routed to which tokens. The values in it: 1 means the token is routed to
which experts are routed to which tokens. The values in it: 1 means the token is routed to
this expert and 0 means not.
this expert and 0 means not.
num_tokens: int
num_tokens
: int
Number of tokens in the input tensor.
Number of tokens in the input tensor.
num_experts: int
num_experts
: int
Number of experts in the input tensor.
Number of experts in the input tensor.
Returns
Returns
-------
-------
row_id_map: torch.Tensor
row_id_map
: torch.Tensor
The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
For each token, the last item is the number of experts that are routed (n_routed).
For each token, the last item is the number of experts that are routed (n_routed).
The first n_routed items are the destination row indices in the permuted tokens.
The first n_routed items are the destination row indices in the permuted tokens.
...
@@ -134,23 +134,23 @@ def permute_with_mask_map(
...
@@ -134,23 +134,23 @@ def permute_with_mask_map(
Parameters
Parameters
----------
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
row_id_map: torch.Tensor
row_id_map
: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
probs: torch.Tensor
probs
: torch.Tensor
The probabilities of the input tensor. If it is not None, it will be permuted.
The probabilities of the input tensor. If it is not None, it will be permuted.
scale: torch.Tensor
scale
: torch.Tensor
The scale of the input tensor. If it is not None, it will be permuted.
The scale of the input tensor. If it is not None, it will be permuted.
num_tokens: int
num_tokens
: int
Number of tokens in the input tensor.
Number of tokens in the input tensor.
num_experts: int
num_experts
: int
Number of experts in the input tensor.
Number of experts in the input tensor.
num_out_tokens: int
num_out_tokens
: int
Number of tokens in the permuted tensor.
Number of tokens in the permuted tensor.
hidden_size: int
hidden_size
: int
Hidden size of the input tensor.
Hidden size of the input tensor.
scale_hidden_dim: int
scale_hidden_dim
: int
Hidden size of the scale tensor.
Hidden size of the scale tensor.
"""
"""
output
=
torch
.
empty
((
num_out_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
,
device
=
"cuda"
)
output
=
torch
.
empty
((
num_out_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
,
device
=
"cuda"
)
...
@@ -211,20 +211,20 @@ def unpermute_with_mask_map(
...
@@ -211,20 +211,20 @@ def unpermute_with_mask_map(
Parameters
Parameters
----------
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_out_tokens, hidden_size]`.
Input tensor of shape `[num_out_tokens, hidden_size]`.
row_id_map: torch.Tensor
row_id_map
: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
merging_probs: torch.Tensor
merging_probs
: torch.Tensor
The merging probabilities of the input tensor. If it is not None, it will be used as weights
The merging probabilities of the input tensor. If it is not None, it will be used as weights
to reduce the unpermuted tokens.
to reduce the unpermuted tokens.
permuted_probs: torch.Tensor
permuted_probs
: torch.Tensor
The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
num_tokens: int
num_tokens
: int
Number of tokens in the permuted tensor.
Number of tokens in the permuted tensor.
num_experts: int
num_experts
: int
Number of experts in the permuted tensor.
Number of experts in the permuted tensor.
hidden_size: int
hidden_size
: int
Hidden size of the permuted tensor.
Hidden size of the permuted tensor.
"""
"""
output
=
torch
.
empty
((
num_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
,
device
=
"cuda"
)
output
=
torch
.
empty
((
num_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
,
device
=
"cuda"
)
...
@@ -278,21 +278,21 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
...
@@ -278,21 +278,21 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
Parameters
Parameters
----------
----------
fwd_output_grad: torch.Tensor
fwd_output_grad
: torch.Tensor
The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
row_id_map: torch.Tensor
row_id_map
: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
fwd_input: torch.Tensor
fwd_input
: torch.Tensor
The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
merging_probs: torch.Tensor
merging_probs
: torch.Tensor
The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
num_tokens: int
num_tokens
: int
Number of tokens in the permuted tensor.
Number of tokens in the permuted tensor.
num_experts: int
num_experts
: int
Number of experts in the permuted tensor.
Number of experts in the permuted tensor.
num_out_tokens: int
num_out_tokens
: int
Number of tokens in the output tensor.
Number of tokens in the output tensor.
hidden_size: int
hidden_size
: int
Hidden size of the output tensor.
Hidden size of the output tensor.
"""
"""
act_grad
=
torch
.
empty
(
act_grad
=
torch
.
empty
(
...
@@ -339,13 +339,13 @@ def make_chunk_sort_map(
...
@@ -339,13 +339,13 @@ def make_chunk_sort_map(
Parameters
Parameters
----------
----------
split_sizes: torch.Tensor
split_sizes
: torch.Tensor
The sizes of the chunks of shape `[num_splits,]`.
The sizes of the chunks of shape `[num_splits,]`.
sorted_indices: torch.Tensor
sorted_indices
: torch.Tensor
The indices of the sorted chunks of shape `[num_splits,]`.
The indices of the sorted chunks of shape `[num_splits,]`.
num_tokens: int
num_tokens
: int
Number of tokens in the input tensor.
Number of tokens in the input tensor.
num_splits: int
num_splits
: int
Number of splits of split_sizes and sorted_indices.
Number of splits of split_sizes and sorted_indices.
"""
"""
row_id_map
=
torch
.
empty
((
num_tokens
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
row_id_map
=
torch
.
empty
((
num_tokens
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
...
@@ -373,17 +373,17 @@ def sort_chunks_by_map(
...
@@ -373,17 +373,17 @@ def sort_chunks_by_map(
Parameters
Parameters
----------
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`.
Input tensor of shape `[num_tokens, hidden_size]`.
row_id_map: torch.Tensor
row_id_map
: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens,]`.
The token to expert mapping tensor of shape `[num_tokens,]`.
probs: torch.Tensor
probs
: torch.Tensor
The probabilities of the input tensor. If it is not None, it will be permuted.
The probabilities of the input tensor. If it is not None, it will be permuted.
num_tokens: int
num_tokens
: int
Number of tokens in the input tensor.
Number of tokens in the input tensor.
hidden_size: int
hidden_size
: int
Hidden size of the input tensor.
Hidden size of the input tensor.
is_forward: bool
is_forward
: bool
Whether the sort is for forward or backward.
Whether the sort is for forward or backward.
"""
"""
output
=
torch
.
empty
((
num_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
,
device
=
"cuda"
)
output
=
torch
.
empty
((
num_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
,
device
=
"cuda"
)
...
...
transformer_engine/pytorch/utils.py
View file @
970620a5
...
@@ -8,11 +8,13 @@ import functools
...
@@ -8,11 +8,13 @@ import functools
import
math
import
math
import
os
import
os
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
contextlib
import
nullcontext
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
.
import
torch_version
from
.quantized_tensor
import
Quantizer
from
.quantized_tensor
import
Quantizer
from
.torch_version
import
torch_version
from
..debug.pytorch.debug_quantization
import
DebugQuantizedTensor
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
__all__
=
[
"get_device_compute_capability"
,
"get_cudnn_version"
,
"is_bf16_available"
]
__all__
=
[
"get_device_compute_capability"
,
"get_cudnn_version"
,
"is_bf16_available"
]
...
@@ -622,6 +624,24 @@ def _nvtx_enabled() -> bool:
...
@@ -622,6 +624,24 @@ def _nvtx_enabled() -> bool:
_nvtx_range_messages
:
list
[
str
]
=
[]
_nvtx_range_messages
:
list
[
str
]
=
[]
def
get_nvtx_range_context
(
msg
:
str
):
"""Get NVTX context manager to tag module forward and backward passes.
Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX
context manager for module level profiling tags.
Parameters
----------
msg : str
Message to associate with profiling context.
"""
if
_nvtx_enabled
():
return
torch
.
cuda
.
nvtx
.
range
(
msg
)
return
nullcontext
()
def
nvtx_range_push
(
msg
:
str
)
->
None
:
def
nvtx_range_push
(
msg
:
str
)
->
None
:
"""Push NVTX range onto stack, if NVTX range profiling is enabled
"""Push NVTX range onto stack, if NVTX range profiling is enabled
...
@@ -630,7 +650,7 @@ def nvtx_range_push(msg: str) -> None:
...
@@ -630,7 +650,7 @@ def nvtx_range_push(msg: str) -> None:
Parameters
Parameters
----------
----------
msg: str
msg
: str
Message to associate with range
Message to associate with range
"""
"""
...
@@ -648,7 +668,7 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:
...
@@ -648,7 +668,7 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:
Parameters
Parameters
----------
----------
msg: str, optional
msg
: str, optional
Message associated with range
Message associated with range
"""
"""
...
...
Prev
1
…
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment