OpenDAS / TransformerEngine, commit fd60eedd

Commit fd60eedd authored Mar 09, 2026 by wenjh

Support GLM params

Signed-off-by: wenjh <wenjh@sugon.com>

Parent: 99a8a0c5
Pipeline #3435 failed with stages in 0 seconds

Showing 4 changed files with 74 additions and 30 deletions (+74 -30)
tests/pytorch/attention/test_attention.py  (+6 -0)
transformer_engine/pytorch/attention/dot_product_attention/backends.py  (+52 -19)
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py  (+4 -0)
transformer_engine/pytorch/attention/dot_product_attention/utils.py  (+12 -11)
tests/pytorch/attention/test_attention.py

@@ -83,6 +83,8 @@ model_configs_base = {
     "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
     "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
     "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
+    "base_7_0": ModelConfig(4, 1226, 32, 256),
 }

@@ -277,6 +279,8 @@ model_configs_mla = {
     # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
     # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
     "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
+    "mla_4_0": ModelConfig(4, 1226, 32, 256),
 }

@@ -332,6 +336,8 @@ model_configs_mask = {
     "mask_10_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"),
+    "mask_11_0": ModelConfig(4, 1226, 32, 256, attn_mask_type="padding_causal"),
 }
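The entries with sequence length 1226 are the GLM-shaped test configurations this commit adds. Read positionally against the neighboring entries, where max_seqlen_kv, head_dim_v and attn_mask_type appear as keywords, the four positional arguments are most plausibly batch size, query sequence length, number of heads, and head dimension. The ModelConfig signature itself lives in the test utilities, so treat the sketch below as an inferred reading rather than the authoritative definition.

# Inferred meaning of the positional arguments (assumption; ModelConfig is the
# helper class used throughout tests/pytorch/attention/test_attention.py):
#   ModelConfig(batch_size, max_seqlen_q, num_heads, head_dim_qk, **kwargs)
glm_base = ModelConfig(4, 1226, 32, 256)                                    # full-attention case
glm_mask = ModelConfig(4, 1226, 32, 256, attn_mask_type="padding_causal")   # padded causal case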
transformer_engine/pytorch/attention/dot_product_attention/backends.py

@@ -446,6 +446,7 @@ class FlashAttention(torch.nn.Module):
         attention_type: str = "self",
         layer_number: Optional[int] = None,
         deterministic: bool = False,
+        return_qk_max: Optional[bool] = False,
     ) -> None:
         super().__init__()

@@ -470,6 +471,8 @@ class FlashAttention(torch.nn.Module):
         if not self.logger.hasHandlers():
             self.logger.addHandler(attn_log._stream_handler)
+        self.return_qk_max = return_qk_max

     @classmethod
     def _get_cached_page_offsets(
         cls,
         split_factor: int,
         device: torch.device,
         dtype: torch.dtype,

@@ -724,6 +727,7 @@ class FlashAttention(torch.nn.Module):
                 alibi_slopes is None
             ), "Alibi slope bias addition is not supported with context parallelism."
             with self.attention_dropout_ctx():
+                assert (
+                    not self.return_qk_max
+                ), "attn_forward_func_with_cp does not support returning qk_max yet."
                 output = attn_forward_func_with_cp(
                     self.training,
                     query_layer,

@@ -821,16 +825,29 @@ class FlashAttention(torch.nn.Module):
                         allow_negative_entries=False,
                     )
                     fa_optional_forward_kwargs["block_table"] = remapped_block_table
-                output = func(
-                    query_layer,
-                    key_layer,
-                    value_layer,
-                    *fa_optional_forward_args_thd,
-                    self.attention_dropout if self.training else 0.0,
-                    softmax_scale=self.softmax_scale,
-                    causal="causal" in attn_mask_type,
-                    **fa_optional_forward_kwargs,
-                )
+                if not self.return_qk_max:
+                    output = func(
+                        query_layer,
+                        key_layer,
+                        value_layer,
+                        *fa_optional_forward_args_thd,
+                        self.attention_dropout if self.training else 0.0,
+                        softmax_scale=self.softmax_scale,
+                        causal="causal" in attn_mask_type,
+                        **fa_optional_forward_kwargs,
+                    )
+                else:
+                    output, qk_max = func(
+                        query_layer,
+                        key_layer,
+                        value_layer,
+                        *fa_optional_forward_args_thd,
+                        self.attention_dropout if self.training else 0.0,
+                        softmax_scale=self.softmax_scale,
+                        causal="causal" in attn_mask_type,
+                        return_qkmax=True,
+                        **fa_optional_forward_kwargs,
+                    )
             else:
                 fa_3_optional_forward_kwargs = {}
                 fa_3_optional_forward_kwargs["window_size"] = window_size

@@ -886,15 +903,27 @@ class FlashAttention(torch.nn.Module):
                         for x in [query_layer, key_layer, value_layer]
                     )
                 try:
-                    output = func(
-                        query_layer,
-                        key_layer,
-                        value_layer,
-                        *fa_optional_forward_args_thd,
-                        softmax_scale=self.softmax_scale,
-                        causal="causal" in attn_mask_type,
-                        **fa_3_optional_forward_kwargs,
-                    )
+                    if not self.return_qk_max:
+                        output = func(
+                            query_layer,
+                            key_layer,
+                            value_layer,
+                            *fa_optional_forward_args_thd,
+                            softmax_scale=self.softmax_scale,
+                            causal="causal" in attn_mask_type,
+                            **fa_3_optional_forward_kwargs,
+                        )
+                    else:
+                        output, qk_max = func(
+                            query_layer,
+                            key_layer,
+                            value_layer,
+                            *fa_optional_forward_args_thd,
+                            softmax_scale=self.softmax_scale,
+                            causal="causal" in attn_mask_type,
+                            return_qkmax=True,
+                            **fa_3_optional_forward_kwargs,
+                        )
                     if isinstance(output, (List, Tuple)):
                         output = output[0]
                 except TypeError as e:

@@ -956,6 +985,10 @@ class FlashAttention(torch.nn.Module):
         elif q_format == "thd":
             # thd -> t(hd)
             output = output.reshape(output.shape[0], -1)
+        if self.return_qk_max:
+            return output.contiguous(), qk_max
         return output.contiguous()
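Because the FlashAttention backend now returns either a tensor or an (output, qk_max) tuple depending on the flag, code downstream of the backend has to accept both shapes. A minimal, self-contained sketch of that handling; the helper name is illustrative and not part of the library:

from typing import Optional, Tuple, Union
import torch

def unpack_attention_result(
    result: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # With return_qk_max=True the backend returns (output, qk_max);
    # otherwise it returns just the output tensor.
    if isinstance(result, tuple):
        output, qk_max = result
        return output, qk_max
    return result, None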
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

@@ -223,6 +223,7 @@ class DotProductAttention(TransformerEngineBaseModule):
         cp_stream: torch.cuda.Stream = None,
         cp_comm_type: str = "p2p",
         softmax_scale: Optional[float] = None,
+        return_qk_max: Optional[bool] = False,
     ) -> None:
         super().__init__()

@@ -251,6 +252,8 @@ class DotProductAttention(TransformerEngineBaseModule):
         self.cp_stream = cp_stream
         self.cp_comm_type = cp_comm_type
+        self.return_qk_max = return_qk_max
         self.hidden_size_per_attention_head_k = (
             kv_channels if isinstance(kv_channels, int) else kv_channels[0]
         )

@@ -317,6 +320,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             attention_type=attention_type,
             layer_number=layer_number,
             deterministic=self.deterministic,
+            return_qk_max=self.return_qk_max,
             **attn_kwargs,
         )
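Taken together, the new constructor argument is stored on the module and forwarded to the FlashAttention backend. A minimal usage sketch follows; it assumes the existing DotProductAttention keyword names, the default "sbhd" layout, and GLM-like shapes matching the new test configs, and whether forward() actually yields an (output, qk_max) pair depends on the FlashAttention backend being the one selected:

import torch
import transformer_engine.pytorch as te

# Hypothetical GLM-like shapes: 4 sequences, 1226 tokens, 32 heads, head_dim 256.
attn = te.DotProductAttention(
    num_attention_heads=32,
    kv_channels=256,
    attention_dropout=0.0,
    return_qk_max=True,  # new flag added by this commit
)
q = torch.randn(1226, 4, 32, 256, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1226, 4, 32, 256, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1226, 4, 32, 256, dtype=torch.bfloat16, device="cuda")
result = attn(q, k, v)  # expected to be (output, qk_max) when the flash backend runs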
transformer_engine/pytorch/attention/dot_product_attention/utils.py

@@ -507,17 +507,18 @@ def get_attention_backend(
             and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
         )
     ):
-        if FlashAttentionUtils.is_installed:
-            logger.debug(
-                "Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. "
-                "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
-                "head_dim_qk <= 256 (>192 requires sm80/90/100+). "
-                "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
-                head_dim_qk,
-                head_dim_v,
-                ".".join([str(i) for i in device_compute_capability]),
-            )
-        use_flash_attention_2 = False
+        if not (IS_HIP_EXTENSION and head_dim_qk == 256 and head_dim_v == 256):
+            if FlashAttentionUtils.is_installed:
+                logger.debug(
+                    "Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. "
+                    "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
+                    "head_dim_qk <= 256 (>192 requires sm80/90/100+). "
+                    "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
+                    head_dim_qk,
+                    head_dim_v,
+                    ".".join([str(i) for i in device_compute_capability]),
+                )
+            use_flash_attention_2 = False
     if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128):
         if FlashAttentionUtils.v3_is_installed:
             logger.debug("Disabling FlashAttention 3 for head_dim > 128")
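The net effect of this hunk is a ROCm-specific carve-out: on HIP builds, head_dim_qk == head_dim_v == 256 (the GLM case exercised by the new test configs) no longer disables FlashAttention 2, while every other unsupported combination still does. A standalone restatement of that predicate, as an illustrative helper rather than library code:

def keep_fa2_despite_head_dims(head_dim_qk: int, head_dim_v: int, is_hip_extension: bool) -> bool:
    # Mirrors the carve-out added in this hunk: only ROCm builds
    # (IS_HIP_EXTENSION) with 256/256 head dims skip the FA2 disable path.
    return is_hip_extension and head_dim_qk == 256 and head_dim_v == 256

# How the gate behaves:
assert keep_fa2_despite_head_dims(256, 256, is_hip_extension=True) is True    # GLM case on ROCm
assert keep_fa2_despite_head_dims(256, 256, is_hip_extension=False) is False   # CUDA build: still disabled
assert keep_fa2_despite_head_dims(256, 128, is_hip_extension=True) is False    # mismatched dims: still disabled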