OpenDAS / TransformerEngine / Commits / 9df0c4a3

Commit 9df0c4a3, authored Mar 03, 2026 by wenjh
Merge branch 'nv_main'
Parents: 0d874a4e, f122b07d

Changes: 221; showing 1 changed file with 59 additions and 9 deletions.

transformer_engine/pytorch/transformer.py (+59, -9)
@@ -12,7 +12,6 @@ import torch
 from transformer_engine.pytorch.torch_version import torch_version
 from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
 from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.jit import (
@@ -35,7 +34,7 @@ from transformer_engine.pytorch.constants import (
 from transformer_engine.pytorch.distributed import get_distributed_world_size
 from transformer_engine.pytorch.export import is_in_onnx_export_mode
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
 import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
 warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
@@ -149,11 +148,21 @@ class TransformerLayer(torch.nn.Module):
                  distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`.
                  Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by
                  :attr:`window_size` in :meth:`forward` as well.
+    bottom_right_diagonal: Optional[bool], default = `None`
+                          Align sliding window and ALiBi diagonal to the top left (`False`)
+                          or bottom right (`True`) corner of the softmax matrix in the encoder.
+                          If `None`, it will be set to `False` for `self_attn_mask_type` =
+                          {`causal`, `padding_causal`} and `True` for other mask types.
     enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                             default = "no_mask"
                             type of attention mask passed into softmax operation for decoder.
     enc_dec_window_size : Optional[Tuple[int, int]], default = None
                          sliding window size for local attention in decoder.
+    enc_dec_bottom_right_diagonal: Optional[bool], default = `None`
+                          Align sliding window and ALiBi diagonal to the top left (`False`)
+                          or bottom right (`True`) corner of the softmax matrix in the decoder.
+                          If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
+                          {`causal`, `padding_causal`} and `True` for other mask types.
     zero_centered_gamma : bool, default = False
                          if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
                          the LayerNorm formula changes to
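
Note (illustrative, not part of the diff): a minimal construction sketch showing where the new `bottom_right_diagonal` knob sits next to the existing sliding-window arguments. The layer sizes and the window value are placeholders; leaving `bottom_right_diagonal=None` defers to the mask-type defaults described in the docstring above.

    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,                 # illustrative sizes only
        ffn_hidden_size=4096,
        num_attention_heads=16,
        self_attn_mask_type="causal",
        window_size=(128, 0),             # local attention: 128 tokens to the left, 0 to the right
        bottom_right_diagonal=None,       # None -> False for causal/padding_causal, True otherwise
    )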
@@ -175,7 +184,7 @@ class TransformerLayer(torch.nn.Module):
                 if set to ``False``, the transformer layer will not learn any additive biases.
     activation : str, default = 'gelu'
                 Type of activation used in MLP block.
-                Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
+                Options are: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
                 ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
     activation_params : Optional[dict], default = None
                        Additional parameters for the activation function.
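
Note (illustrative, not part of the diff): the option list now documents ``'glu'``. A minimal sketch of selecting it via LayerNormMLP (imported in the first hunk); whether ``'glu'`` is accepted depends on running a Transformer Engine build that documents it here, and the sizes are placeholders.

    from transformer_engine.pytorch import LayerNormMLP

    mlp = LayerNormMLP(
        hidden_size=1024,
        ffn_hidden_size=4096,
        activation="glu",   # gated linear unit variant newly listed among the options
    )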
@@ -302,7 +311,9 @@ class TransformerLayer(torch.nn.Module):
         kv_channels: Optional[int] = None,
         self_attn_mask_type: str = "causal",
         window_size: Optional[Tuple[int, int]] = None,
+        bottom_right_diagonal: Optional[bool] = None,
         enc_dec_attn_mask_type: str = "no_mask",
+        enc_dec_bottom_right_diagonal: Optional[bool] = None,
         enc_dec_window_size: Optional[Tuple[int, int]] = None,
         tp_group: Optional[dist_group_type] = None,
         tp_size: int = 1,
@@ -344,8 +355,10 @@ class TransformerLayer(torch.nn.Module):
         self.self_attn_mask_type = self_attn_mask_type
         self.window_size = window_size
+        self.bottom_right_diagonal = bottom_right_diagonal
         self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
         self.enc_dec_window_size = enc_dec_window_size
+        self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
         ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
@@ -398,6 +411,7 @@ class TransformerLayer(torch.nn.Module):
         self.softmax_type = softmax_type
         self.name = name
+        TransformerEngineBaseModule._validate_name(self)
         attention_args = (
             hidden_size,
@@ -446,7 +460,7 @@ class TransformerLayer(torch.nn.Module):
             qk_norm_type=qk_norm_type,
             qk_norm_eps=qk_norm_eps,
             qk_norm_before_rope=qk_norm_before_rope,
-            name=name + ".self_attention" if name is not None else None,
+            name=self.name + ".self_attention" if self.name is not None else None,
         )

         if layer_type == "decoder":
@@ -463,7 +477,7 @@ class TransformerLayer(torch.nn.Module):
                 qk_norm_type=qk_norm_type,
                 qk_norm_eps=qk_norm_eps,
                 qk_norm_before_rope=qk_norm_before_rope,
-                name=name + ".inter_attention" if name is not None else None,
+                name=self.name + ".inter_attention" if self.name is not None else None,
             )

         # LayerNorm -> activation(Linear + Bias) -> Linear
@@ -499,7 +513,7 @@ class TransformerLayer(torch.nn.Module):
             activation_params=activation_params,
             normalization=normalization,
             device=device,
-            name=name + ".layernorm_mlp" if name is not None else None,
+            name=self.name + ".layernorm_mlp" if self.name is not None else None,
         )

         self.hidden_dropout = hidden_dropout
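
Note (illustrative, not part of the diff): the three hunks above switch the child-module name expression from the local ``name`` argument to the stored ``self.name`` attribute, which is set and validated earlier in ``__init__``. A standalone sketch of that naming expression; ``child_name`` is a hypothetical helper, not a Transformer Engine function.

    def child_name(parent_name, suffix):
        # Hypothetical helper mirroring the expression used in the diff:
        # derive a child module's debug name from the parent's stored name.
        return parent_name + suffix if parent_name is not None else None

    assert child_name("layer0", ".self_attention") == "layer0.self_attention"
    assert child_name(None, ".layernorm_mlp") is None  # unnamed parents stay unnamed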
@@ -606,10 +620,12 @@ class TransformerLayer(torch.nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         self_attn_mask_type: Optional[str] = None,
         window_size: Optional[Tuple[int, int]] = None,
+        bottom_right_diagonal: Optional[bool] = None,
         encoder_output: Optional[torch.Tensor] = None,
         enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         enc_dec_attn_mask_type: Optional[str] = None,
         enc_dec_window_size: Optional[Tuple[int, int]] = None,
+        enc_dec_bottom_right_diagonal: Optional[bool] = None,
         is_first_microbatch: Optional[bool] = None,
         checkpoint_core_attention: bool = False,
         inference_params: Optional[InferenceParams] = None,
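
Note (illustrative, not part of the diff): a forward-call sketch of the new per-call overrides. It assumes the default sequence-first ([seq, batch, hidden]) input layout and a CUDA device; all sizes are placeholders, and the arguments shown simply take precedence over the construction-time values.

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(1024, 4096, 16, self_attn_mask_type="causal").cuda()
    x = torch.randn(512, 2, 1024, device="cuda")  # [seq, batch, hidden], illustrative sizes

    y = layer(
        x,
        window_size=(256, 0),          # override the construction-time sliding window
        bottom_right_diagonal=False,   # explicit alignment; None would apply the mask-type default
    )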
@@ -654,6 +670,11 @@ class TransformerLayer(torch.nn.Module):
                   causal masks are aligned to the bottom right corner.
         window_size: Optional[Tuple[int, int]], default = None
                     Sliding window size for local attention in encoder.
+        bottom_right_diagonal: Optional[bool], default = `None`
+                              Align sliding window and ALiBi diagonal to the top left (`False`)
+                              or bottom right (`True`) corner of the softmax matrix in the encoder.
+                              If `None`, it will be set to `False` for `self_attn_mask_type` =
+                              {`causal`, `padding_causal`} and `True` for other mask types.
         encoder_output : Optional[torch.Tensor], default = None
                         Output of the encoder block to be fed into the decoder block if using
                         :attr:`layer_type` = ``"decoder"``.
@@ -670,6 +691,11 @@ class TransformerLayer(torch.nn.Module):
                                 Type of attention mask passed into softmax operation for decoder.
         enc_dec_window_size: Optional[Tuple[int, int]], default = None
                             Sliding window size for local attention in decoder.
+        enc_dec_bottom_right_diagonal: Optional[bool], default = `None`
+                              Align sliding window and ALiBi diagonal to the top left (`False`)
+                              or bottom right (`True`) corner of the softmax matrix in the decoder.
+                              If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
+                              {`causal`, `padding_causal`} and `True` for other mask types.
         is_first_microbatch : {True, False, None}, default = None
                              During training using either gradient accumulation or
                              pipeline parallelism a minibatch of data is further split
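
Note (illustrative, not part of the diff): a cross-attention sketch for the decoder-side arguments documented above. Layer sizes, sequence lengths, and the window value are placeholders; `enc_dec_bottom_right_diagonal=None` falls back to the mask-type default, and `(-1, -1)` is the conventional "no sliding window" value.

    import torch
    import transformer_engine.pytorch as te

    decoder = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        layer_type="decoder",
        enc_dec_attn_mask_type="no_mask",
    ).cuda()

    tgt = torch.randn(128, 2, 1024, device="cuda")   # decoder input  [seq, batch, hidden]
    mem = torch.randn(512, 2, 1024, device="cuda")   # encoder output [seq, batch, hidden]

    out = decoder(
        tgt,
        encoder_output=mem,
        enc_dec_window_size=(-1, -1),           # disable sliding window for cross-attention
        enc_dec_bottom_right_diagonal=None,     # None -> True for "no_mask" per the docstring
    )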
@@ -736,10 +762,35 @@ class TransformerLayer(torch.nn.Module):
             self_attn_mask_type = self.self_attn_mask_type
         if window_size is None:
             window_size = self.window_size
         window_size = dpa_utils.check_set_window_size(self_attn_mask_type, window_size)
         if enc_dec_attn_mask_type is None:
             enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
         if enc_dec_window_size is None:
             enc_dec_window_size = self.enc_dec_window_size
         enc_dec_window_size = dpa_utils.check_set_window_size(
             enc_dec_attn_mask_type, enc_dec_window_size
         )
+        if bottom_right_diagonal is None:
+            bottom_right_diagonal = self.bottom_right_diagonal
+            if self_attn_mask_type in {"causal", "padding_causal"}:
+                bottom_right_diagonal = False
+        if bottom_right_diagonal is None or self_attn_mask_type in {
+            "causal_bottom_right",
+            "padding_causal_bottom_right",
+        }:
+            bottom_right_diagonal = True
+        if enc_dec_bottom_right_diagonal is None:
+            enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal
+            if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
+                enc_dec_bottom_right_diagonal = False
+        if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
+            "causal_bottom_right",
+            "padding_causal_bottom_right",
+        }:
+            enc_dec_bottom_right_diagonal = True

         assert (
             self_attn_mask_type in AttnMaskTypes
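
Note (illustrative, not part of the diff): the defaulting rule added above, restated as a standalone function so the behaviour documented in the docstrings is easy to check. ``resolve_bottom_right_diagonal`` is a hypothetical name, and the nesting mirrors one plausible reading of the hunk (indentation is not recoverable from the rendered diff).

    def resolve_bottom_right_diagonal(arg, attr, mask_type):
        # Per-call argument wins; otherwise fall back to the module attribute,
        # then to a mask-type default: False for top-left-aligned causal masks,
        # True otherwise, and forced True for *_bottom_right mask types.
        if arg is None:
            arg = attr
            if mask_type in {"causal", "padding_causal"}:
                arg = False
        if arg is None or mask_type in {"causal_bottom_right", "padding_causal_bottom_right"}:
            arg = True
        return arg

    assert resolve_bottom_right_diagonal(None, None, "causal") is False
    assert resolve_bottom_right_diagonal(None, None, "no_mask") is True
    assert resolve_bottom_right_diagonal(None, None, "causal_bottom_right") is True
    assert resolve_bottom_right_diagonal(False, None, "padding") is False  # explicit value respected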
@@ -768,9 +819,6 @@ class TransformerLayer(torch.nn.Module):
                 enc_dec_attn_mask[i].dtype == torch.bool
                 for i in range(len(enc_dec_attn_mask))
             ), "Encoder-decoder attention mask must be boolean tensor(s)"

-        if TEDebugState.debug_enabled:
-            TransformerEngineBaseModule._validate_name(self)

         # For AMP
         if torch.is_autocast_enabled():
             hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())
@@ -781,6 +829,7 @@ class TransformerLayer(torch.nn.Module):
             attention_mask=attention_mask,
             attn_mask_type=self_attn_mask_type,
             window_size=window_size,
+            bottom_right_diagonal=bottom_right_diagonal,
             inference_params=inference_params,
             is_first_microbatch=is_first_microbatch,
             checkpoint_core_attention=checkpoint_core_attention,
@@ -816,6 +865,7 @@ class TransformerLayer(torch.nn.Module):
                 attention_mask=enc_dec_attn_mask,
                 attn_mask_type=enc_dec_attn_mask_type,
                 window_size=enc_dec_window_size,
+                bottom_right_diagonal=enc_dec_bottom_right_diagonal,
                 encoder_output=encoder_output,
                 inference_params=inference_params,
                 is_first_microbatch=is_first_microbatch,