Commit 2b05e121 (OpenDAS / TransformerEngine), authored Jun 17, 2025 by yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine

Parents: 0fd441c2, a69692ac
Changes: 245

Showing 5 changed files with 152 additions and 73 deletions (+152, -73).
Files on this page:
- transformer_engine/pytorch/tensor/float8_tensor.py (+9, -2)
- transformer_engine/pytorch/tensor/mxfp8_tensor.py (+7, -2)
- transformer_engine/pytorch/tensor/quantized_tensor.py (+16, -0)
- transformer_engine/pytorch/transformer.py (+111, -69)
- transformer_engine/pytorch/utils.py (+9, -0)
transformer_engine/pytorch/tensor/float8_tensor.py (+9, -2)

@@ -4,13 +4,14 @@
 """Tensor class with FP8 data"""
 from __future__ import annotations
-from typing import Optional, Tuple, Iterable
+from typing import Optional, Tuple, Iterable, Union
 import warnings

 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

+from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
 from ..utils import canonicalize_process_group, devices_match
 from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
@@ -166,6 +167,9 @@ class Float8Quantizer(Quantizer):
             quantizer=self,
         )

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return DelayedScaling
+

 class Float8CurrentScalingQuantizer(Quantizer):
     """Builder class for FP8 tensors with per-tensor current scaling
@@ -328,6 +332,9 @@ class Float8CurrentScalingQuantizer(Quantizer):
         """Get process group for amax reduction"""
         return canonicalize_process_group(self.amax_reduction_group)

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return Float8CurrentScaling
+

 class Float8Tensor(Float8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
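Both FP8 quantizers now advertise the recipe class they correspond to, which gives framework code a uniform way to detect a mismatch between a tensor's quantizer and the currently active recipe. A minimal sketch of such a check, assuming a quantizer instance and a recipe instance (the helper below is hypothetical, not part of the commit):

def quantizer_matches_recipe(quantizer, recipe) -> bool:
    """True if `quantizer` was built for `recipe`'s class (hedged sketch)."""
    compatible_cls = quantizer._get_compatible_recipe()
    return compatible_cls is not None and isinstance(recipe, compatible_cls)

# A Float8Quantizer reports DelayedScaling, so it matches DelayedScaling
# recipe instances and rejects Float8CurrentScaling ones.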
transformer_engine/pytorch/tensor/mxfp8_tensor.py (+7, -2)

@@ -6,12 +6,13 @@
 from __future__ import annotations
 from collections.abc import Iterable
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

+from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
 from ..constants import MXFP8_BLOCK_SCALING_SIZE
 from ..utils import devices_match, round_up_to_nearest_multiple
@@ -135,6 +136,9 @@ class MXFP8Quantizer(Quantizer):
         # TODO(ksivamani): No calibration needed for mxfp8?
         pass

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return MXFP8BlockScaling
+

 class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
@@ -380,6 +384,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
         # Quantize to FP8
         assert self._quantizer is not None, "Can't quantize without a quantizer"
+        self._quantizer.internal = False
         self.data = self._quantizer.quantize(tensor)
         if self.requires_grad != tensor.requires_grad:
             self.requires_grad_(requires_grad=tensor.requires_grad)
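The one-line addition flips the quantizer out of "internal" mode before re-quantizing in place. Judging from the surrounding code (this reading is an inference, not stated in the diff), `internal=True` lets `quantize()` return a lightweight base-class object, while the in-place update here needs the full `MXFP8Tensor` subclass, so the flag is cleared first:

# Inferred shape of the pattern above (simplified, hedged sketch):
def quantize_inplace(qtensor, src):
    qtensor._quantizer.internal = False   # request a full QuantizedTensor, not a base object
    qtensor.data = qtensor._quantizer.quantize(src)
    if qtensor.requires_grad != src.requires_grad:
        qtensor.requires_grad_(requires_grad=src.requires_grad)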
transformer_engine/pytorch/tensor/quantized_tensor.py (+16, -0)

@@ -8,11 +8,13 @@ from __future__ import annotations
 from typing import Optional, Tuple, Iterable, Any, Dict, Union
 import abc
 import copy
+import warnings

 import torch
 from torch.utils._pytree import tree_map
 import transformer_engine_torch as tex

+from transformer_engine.common.recipe import Recipe


 class QuantizedTensorBase:
@@ -31,6 +33,8 @@ class QuantizedTensorBase:
     XTensor should only implement the functionality needed
     to behave like regular torch.Tensor (liek __torch_dispatch__)."""

+    _quantizer: Optional[Quantizer]
+
     def update_usage(
         self,
         rowwise_usage: Optional[bool] = None,
@@ -69,6 +73,14 @@ class QuantizedTensorBase:
             f"{self.__class__.__name__} class does not implement restore_from_saved function"
         )

+    def update_quantizer(self, quantizer: Quantizer):
+        """Update quantizer for the tensor"""
+        if self._quantizer is None:
+            raise RuntimeError("To be updated, quantizer must be set")
+        if self._quantizer is not quantizer:
+            warnings.warn("Quantizer is being updated, this may affect model behavior")
+        self._quantizer = quantizer
+

 def prepare_for_saving(
     *tensors: Union[torch.Tensor, QuantizedTensorBase],
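The new `update_quantizer` refuses to attach a quantizer to a tensor that never had one, and warns when the instance actually changes. A usage sketch (the tensor and quantizer arguments are placeholders; constructing real ones needs FP8-capable hardware):

import warnings

def rebind_quantizer(qtensor, new_quantizer):
    """Swap the quantizer on a quantized tensor, collecting any warning raised."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        qtensor.update_quantizer(new_quantizer)  # warns if new_quantizer is a different instance
    return [str(w.message) for w in caught]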
@@ -238,6 +250,10 @@ class Quantizer(abc.ABC):
         """Create shallow copy"""
         return copy.copy(self)

+    @abc.abstractmethod
+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        """Returns recipe class that is compatible with this quantizer"""
+

 class _QuantizeFunc(torch.autograd.Function):
     """Cast to FP8 from other dtype"""
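Because `_get_compatible_recipe` is now an `abc.abstractmethod` on `Quantizer`, every concrete quantizer must declare its recipe class, exactly as the FP8 and MXFP8 quantizers above do. A minimal sketch of a hypothetical subclass (the other abstract methods of `Quantizer` are omitted, so the class is illustrative only):

from typing import Union
from transformer_engine.common.recipe import DelayedScaling, Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer

class MyQuantizer(Quantizer):
    """Sketch: a quantizer that pairs with the DelayedScaling recipe."""

    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
        return DelayedScaling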
transformer_engine/pytorch/transformer.py (+111, -69)
@@ -179,12 +179,12 @@ class TransformerLayer(torch.nn.Module):
              The device on which the parameters of the model will be allocated. It is the user's
              responsibility to ensure all parameters are moved to the GPU before running the
              forward pass.
-    attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
-                    This controls whether the dimensions of the
-                    intermediate hidden states is 'batch first' ('bshd')
-                    or 'sequence first' ('sbhd'). `s` stands for the sequence
-                    length, `b` batch size, `h` the number of heads, `d`
-                    head size. Note that these formats are very closely
+    attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
+                    This controls whether the dimensions of the
+                    intermediate hidden states is 'sequence first' ('sbhd'),
+                    'batch first' ('bshd'), or 'token first' ('thd'). `s` stands
+                    for the sequence length, `b` batch size, `t` the total number
+                    of tokens, `h` the number of heads, `d` head size.
+                    Note that these formats are very closely
                     related to the `qkv_format` in the `MultiHeadAttention`
                     and `DotProductAttention` modules.
     name: str, default = `None`
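For orientation, the three layouts imply the following hidden-state shapes; 'thd' packs variable-length sequences into one token dimension and relies on cumulative sequence lengths to mark boundaries. Sizes below are illustrative:

import torch

s, b, h, d = 512, 4, 16, 64            # sequence length, batch, heads, head size
t = 1820                               # total tokens across a packed batch

x_sbhd = torch.empty(s, b, h * d)      # 'sbhd': sequence first
x_bshd = torch.empty(b, s, h * d)      # 'bshd': batch first
x_thd = torch.empty(t, h * d)          # 'thd': token first, no explicit batch dim

# 'thd' boundaries, e.g. four sequences of lengths 500, 320, 488, 512:
cu_seqlens = torch.tensor([0, 500, 820, 1308, 1820], dtype=torch.int32)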
@@ -235,6 +235,14 @@ class TransformerLayer(torch.nn.Module):
              parameter for query-key-value. This enables optimizations such as QKV
              fusion without concatentations/splits and also enables the argument
              `fuse_wgrad_accumulation`.
+    use_qk_norm: bool, default = 'False'
+                if set to `True`, L2 normalization is applied to query and key tensors
+                after RoPE (if applicable) but before attention computation.
+                This follows the Llama4 approach for QK normalization to improve
+                training stability and model performance.
+    qk_norm_eps: float, default = 1e-6
+                epsilon value for L2 normalization of query and key tensors.
+                Only used when `use_qk_norm` is True.
     """

     def __init__(
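The documented operation is simple to state in code. The following is a rough functional sketch of Llama4-style L2 QK normalization as described, not the module's actual implementation (which lives inside `MultiheadAttention`):

import torch

def l2_qk_norm(q: torch.Tensor, k: torch.Tensor, eps: float = 1e-6):
    """Normalize q and k to unit L2 norm along the head dimension.

    Per the docstring: applied after RoPE (if used) and before the
    attention-score computation.
    """
    q = q / torch.linalg.vector_norm(q, dim=-1, keepdim=True).clamp_min(eps)
    k = k / torch.linalg.vector_norm(k, dim=-1, keepdim=True).clamp_min(eps)
    return q, k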
@@ -284,6 +292,8 @@ class TransformerLayer(torch.nn.Module):
         device: Union[torch.device, str] = "cuda",
         attn_input_format: str = "sbhd",
         name: str = None,
+        use_qk_norm: bool = False,
+        qk_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
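A construction example with the new keyword arguments (hyperparameters are illustrative; requires a CUDA device, since parameters default to device="cuda"):

import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    attn_input_format="bshd",
    use_qk_norm=True,    # new in this commit: L2-normalize q/k before attention
    qk_norm_eps=1e-6,    # new in this commit: normalization epsilon
)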
@@ -373,6 +383,8 @@ class TransformerLayer(torch.nn.Module):
             "ub_overlap_rs": ub_overlap_rs,
             "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
             "qkv_format": self.attn_input_format,
+            "seq_length": seq_length,
+            "micro_batch_size": micro_batch_size,
         }

         self.self_attention = MultiheadAttention(
@@ -384,6 +396,8 @@ class TransformerLayer(torch.nn.Module):
             return_bias=not self.parallel_attention_mlp,
             normalization=normalization,
             device=device,
+            use_qk_norm=use_qk_norm,
+            qk_norm_eps=qk_norm_eps,
             name=name + ".self_attention" if name is not None else None,
         )
@@ -398,6 +412,8 @@ class TransformerLayer(torch.nn.Module):
             return_bias=True,
             normalization=normalization,
             device=device,
+            use_qk_norm=use_qk_norm,
+            qk_norm_eps=qk_norm_eps,
             name=name + ".inter_attention" if name is not None else None,
         )
@@ -552,6 +568,8 @@ class TransformerLayer(torch.nn.Module):
         alibi_slopes: Optional[torch.Tensor] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
+        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
         max_seqlen_q: Optional[int] = None,
         max_seqlen_kv: Optional[int] = None,
         fast_zero_fill: bool = True,
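The split between the plain and the `_padded` variants: `cu_seqlens_*` accumulate only real token counts, while `cu_seqlens_*_padded` accumulate offsets that include padding between packed sequences. When no padding separates sequences the two coincide, which is why each `_padded` tensor falls back to its plain counterpart if left as None. An illustrative construction:

import torch

# Three packed sequences of real lengths 5, 3, 7, each padded out to 8 slots.
cu_seqlens_q = torch.tensor([0, 5, 8, 15], dtype=torch.int32)          # without offset
cu_seqlens_q_padded = torch.tensor([0, 8, 16, 24], dtype=torch.int32)  # with offset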
@@ -568,88 +586,99 @@ class TransformerLayer(torch.nn.Module):
         Parameters
         ----------
         hidden_states : torch.Tensor
             Input tensor.
         attention_mask : Optional[torch.Tensor], default = `None`
             Boolean tensor used to mask out self-attention softmax input. It should be
             in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
             mask. It should be `None` for causal masks and "`no_mask`" type.
             A `True` value means the corresponding position is masked out and
             a `False` means that position is allowed to participate in attention.
         self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
             'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'},
             default = `causal`
             Type of attention mask passed into softmax operation for encoder.
             By default, causal masks are aligned to the top left corner of
             the softmax matrix. When "`bottom_right`" is specified in the mask type,
             causal masks are aligned to the bottom right corner.
         window_size: Optional[Tuple[int, int]], default = `None`
             Sliding window size for local attention in encoder.
         encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
         enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensors used to mask out inter-attention softmax input if
             using `layer_type="decoder"`. It should be a tuple of two masks in
             [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
             It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
             for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
             A `True` value means the corresponding position is masked out and a `False`
             means that position is allowed to participate in attention.
         enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
             default = `None`
             Type of attention mask passed into softmax operation for decoder.
         enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
             Sliding window size for local attention in decoder.
         is_first_microbatch : {True, False, None}, default = None
             During training using either gradient accumulation or
             pipeline parallelism a minibatch of data is further split
             into microbatches. Between the microbatches of the same minibatch
             the model weights are not updated. Setting this parameter indicates
             whether the current microbatch is the first in a minibatch or not.
             When set, this parameter enables additional optimizations:
             * during FP8 training, it allows caching of the FP8 versions of
               the weights
             * it also allows skipping gradient accumulation during the
               first microbatch (since it is the first gradient being
               produced)
         checkpoint_core_attention: bool, default = `False`
             If true, forward activations for core attention are recomputed
             during the backward pass in order to save memory that would
             otherwise be occupied to store the forward activations until
             backprop.
         rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
             Embeddings for query and key tensors for applying rotary position
             embedding. By default no input embedding is applied.
         core_attention_bias_type: str, default = `no_bias`
             Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
         core_attention_bias: Optional[torch.Tensor], default = `None`
             Bias tensor for Q * K.T
         alibi_slopes: Optional[torch.Tensor], default = `None`
             ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
             It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
             to the attention score of query i and key j.
         cu_seqlens_q: Optional[torch.Tensor], default = `None`
             Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
             with shape [batch_size + 1] and dtype torch.int32.
             Used by encoders, or decoders' self-attention.
         cu_seqlens_kv: Optional[torch.Tensor], default = `None`
             Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
             and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
             Used by decoders' cross-attention.
+        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
+            with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None.
+            Used by encoders, or decoders' self-attention.
+        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
+            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+            Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention.
         max_seqlen_q: Optional[int], default = `None`
             Maximum sequence length in `query_layer`.
-            Calculated from `cu_seqlens_q` if not provided.
+            Calculated from `cu_seqlens_q_padded` if not provided.
         max_seqlen_kv: Optional[int], default = `None`
             Maximum sequence length in `key_layer` and `value_layer`.
-            Calculated from `cu_seqlens_kv` if not provided.
+            Calculated from `cu_seqlens_kv_padded` if not provided.
         fast_zero_fill: bool, default = `True`
             Whether to set output tensors to 0 or not before use.
         inference_params: InferenceParams, default = None
             Inference parameters that are passed to the main model in order
             to efficiently calculate and store the context during inference.
         pad_between_seqs: Optional[bool], default = `None`
             If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
-            If true, there are padding tokens between individual sequences in a packed batch.
+            If true, there are padding tokens between individual sequences in a packed batch,
+            i.e. qkv_format = 'thd'.
         """
         if self_attn_mask_type is None:
@@ -678,7 +707,9 @@ class TransformerLayer(torch.nn.Module):
         if (
             "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
         ) and attention_mask is not None:
-            assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
+            assert all(
+                attention_mask[i].dtype == torch.bool for i in range(len(attention_mask))
+            ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors"
         if (
             "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
         ) and enc_dec_attn_mask is not None:
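The relaxed assertion accepts either a single boolean mask (indexing it iterates the batch dimension, whose slices are still boolean) or a list/tuple of two masks as used by decoder inter-attention. Shapes below are illustrative:

import torch

b, sq, skv = 2, 128, 96
single_mask = torch.zeros(b, 1, 1, sq, dtype=torch.bool)   # passes the new assert
mask_pair = (
    torch.zeros(b, 1, 1, sq, dtype=torch.bool),
    torch.zeros(b, 1, 1, skv, dtype=torch.bool),
)                                                          # also passes
# True marks positions excluded from the attention softmax.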
@@ -707,9 +738,11 @@ class TransformerLayer(torch.nn.Module):
                 core_attention_bias=core_attention_bias,
                 alibi_slopes=alibi_slopes,
                 cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_kv=cu_seqlens_kv,
+                cu_seqlens_kv=cu_seqlens_q,
+                cu_seqlens_q_padded=cu_seqlens_q_padded,
+                cu_seqlens_kv_padded=cu_seqlens_q_padded,
                 max_seqlen_q=max_seqlen_q,
-                max_seqlen_kv=max_seqlen_kv,
+                max_seqlen_kv=max_seqlen_q,
                 fast_zero_fill=fast_zero_fill,
                 pad_between_seqs=pad_between_seqs,
             )
@@ -733,12 +766,21 @@ class TransformerLayer(torch.nn.Module):
                 attn_mask_type=enc_dec_attn_mask_type,
                 window_size=enc_dec_window_size,
                 encoder_output=encoder_output,
                 inference_params=inference_params,
                 is_first_microbatch=is_first_microbatch,
                 checkpoint_core_attention=checkpoint_core_attention,
                 rotary_pos_emb=rotary_pos_emb,
                 core_attention_bias_type=core_attention_bias_type,
                 core_attention_bias=core_attention_bias,
                 alibi_slopes=alibi_slopes,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_kv=cu_seqlens_kv,
+                cu_seqlens_q_padded=cu_seqlens_q_padded,
+                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_kv=max_seqlen_kv,
+                fast_zero_fill=fast_zero_fill,
+                pad_between_seqs=pad_between_seqs,
             )

         if self.apply_residual_connection_post_layernorm:
             attention_output, attention_bias, residual = inter_attention_outputs
transformer_engine/pytorch/utils.py (+9, -0)
@@ -37,8 +37,16 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     Must be used carefully.
     """
     for t in tensors:
         if t is not None:
+            # Workaround for double buffering in cpu offload
+            if hasattr(t, "do_not_clear"):
+                continue
+            if hasattr(t, "get_data_tensors"):
+                if any(
+                    hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()
+                ):
+                    continue
             if hasattr(t, "clear"):
                 t.clear()
             else:
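The guard honors a dynamic `do_not_clear` attribute, checked with `hasattr` either on the tensor itself or on any of the data tensors of a quantized wrapper, so CPU-offload double buffers survive the cleanup. A small usage sketch:

import torch
from transformer_engine.pytorch.utils import clear_tensor_data

a = torch.ones(4)
b = torch.ones(4)
b.do_not_clear = True        # dynamic marker; clear_tensor_data skips this tensor

clear_tensor_data(a, b)      # releases a's storage, leaves b intact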
@@ -462,6 +470,7 @@ def is_bf16_compatible() -> None:
     return torch.cuda.get_device_capability()[0] >= 8


+@functools.lru_cache(maxsize=None)
 def is_non_tn_fp8_gemm_supported() -> bool:
     """Checks whether the device supports
     non-TN layouts for FP8 GEMMs.
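Memoizing the probe is safe because a process's device capability does not change at runtime, and helpers like this can be called in hot paths. The new decorator follows the same pattern as this generic example:

import functools
import torch

@functools.lru_cache(maxsize=None)
def device_major_capability() -> int:
    """Query the CUDA compute capability once; serve later calls from cache."""
    return torch.cuda.get_device_capability()[0]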