OpenDAS / TransformerEngine · Commits

Commit 87e3e56e, authored Aug 27, 2025 by yuguo

    Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine

Parents: 2f11bd2e, 734bcedd
Changes: 217 files in total; showing 20 changed files with 1690 additions and 748 deletions (+1690, -748).
tests/jax/utils.py                                    +6    -4
tests/pytorch/attention/run_attention_with_cp.py      +1    -1
tests/pytorch/attention/test_attention.py             +538  -448
tests/pytorch/attention/test_attention_with_cp.py     +57   -38
tests/pytorch/attention/test_kv_cache.py              +15   -20
tests/pytorch/debug/run_distributed.py                +34   -0
tests/pytorch/debug/test_api_features.py              +81   -79
tests/pytorch/debug/test_configs/log_config.yaml      +19   -0
tests/pytorch/debug/test_configs/perf_config.yaml     +13   -0
tests/pytorch/debug/test_log.py                       +250  -0
tests/pytorch/debug/test_perf.py                      +76   -0
tests/pytorch/distributed/run_layer_with_overlap.py   +1    -0
tests/pytorch/distributed/test_sanity.py              +121  -0
tests/pytorch/test_cpu_offloading.py                  +25   -17
tests/pytorch/test_cuda_graphs.py                     +37   -65
tests/pytorch/test_float8blockwisetensor.py           +1    -1
tests/pytorch/test_fused_optimizer.py                 +0    -15
tests/pytorch/test_fused_router.py                    +37   -21
tests/pytorch/test_fusible_ops.py                     +378  -38
tests/pytorch/test_hf_integration.py                  +0    -1

tests/jax/utils.py  (view file @ 87e3e56e)

@@ -1604,16 +1604,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
 @contextmanager
 def use_jax_gemm(enabled=False):
-    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
+    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None)
     try:
         if enabled:
-            os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
+            os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"
+        else:
+            os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true"
         yield
     finally:
         if enabled:
             if orig_custom_calls_filter is None:
-                os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
+                os.environ.pop("NVTE_JAX_CUSTOM_CALLS")
             else:
-                os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
+                os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter
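Note: with this change the helper keys the override entirely off NVTE_JAX_CUSTOM_CALLS ("GemmPrimitive=false"/"GemmPrimitive=true") instead of the old NVTE_JAX_CUSTOM_CALLS_RE regex. A minimal usage sketch, assuming TE for JAX is installed; the shapes and the assertion are illustrative, not part of this diff:

    # Hypothetical caller: compare TE's GemmPrimitive path against native JAX GEMM.
    import jax.numpy as jnp
    from utils import use_jax_gemm  # helper defined in tests/jax/utils.py

    a = jnp.ones((16, 32), dtype=jnp.bfloat16)
    b = jnp.ones((32, 8), dtype=jnp.bfloat16)

    with use_jax_gemm(enabled=True):   # sets GemmPrimitive=false -> native JAX dot
        ref = jnp.dot(a, b)
    out = jnp.dot(a, b)                # default path, TE custom call allowed
    assert jnp.allclose(ref, out)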

tests/pytorch/fused_attn/run_fused_attn_with_cp.py → tests/pytorch/attention/run_attention_with_cp.py  (view file @ 87e3e56e)

@@ -13,7 +13,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
     get_cu_seqlens_on_cp_rank,
 )
 import transformer_engine_torch as tex
-from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
+from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
 from transformer_engine.common.recipe import DelayedScaling

tests/pytorch/fused_attn/test_fused_attn.py → tests/pytorch/attention/test_attention.py  (view file @ 87e3e56e)
@@ -4,12 +4,12 @@
 import logging
 import math
 import os
-from typing import Any, Dict, List, Tuple, Union, Optional
-from contextlib import contextmanager
+import sys
+import pathlib
+from typing import Any, Dict, Tuple, Union
 import pytest
 import torch
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from transformer_engine.common import recipe
 from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
@@ -20,11 +20,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import (
 from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
 from transformer_engine.pytorch.attention.dot_product_attention.utils import (
     FlashAttentionUtils,
     get_attention_backend,
     check_set_window_size,
     AttentionParams,
 )
 from transformer_engine.pytorch.attention import InferenceParams
 from transformer_engine.pytorch.attention import RotaryPositionEmbedding
 import transformer_engine.pytorch.cpp_extensions as ext
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
@@ -49,21 +46,21 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
     restore_from_saved,
 )
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import (
+    reset_rng_states,
+    ModelConfig,
+    dtype_tols,
+    get_available_attention_backends,
+)
 # Only run FP8 tests on H100
 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
-# Initialize RNG state
-seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
-def reset_rng_states() -> None:
-    """Revert back to initial RNG state"""
-    torch.set_rng_state(_cpu_rng_state)
-    torch.cuda.set_rng_state(_cuda_rng_state)
+# Reset RNG states
+reset_rng_states()
 @pytest.fixture(autouse=True)
@@ -72,170 +69,20 @@ def reset_global_fp8_state():
     fp8.FP8GlobalStateManager.reset()
-class ModelConfig:
-    def __init__(
-        self,
-        batch_size: int,
-        num_heads: int,
-        num_gqa_groups: int,
-        head_dim_qk: int,
-        max_seqlen_q: int,
-        max_seqlen_kv: int,
-        dropout_p: float,
-        attn_mask_type: str,
-        attn_bias_type: str,
-        head_dim_v: int = None,
-        alibi_type: str = "none",
-        num_layers: int = 1,
-        bias_shape: str = "1hss",
-        window_size: Tuple[int, int] = (-1, -1),
-        total_requests: int = None,
-        max_ctx_len: int = None,
-    ):
-        self.batch_size = batch_size
-        self.num_heads = num_heads
-        self.num_gqa_groups = num_gqa_groups
-        self.head_dim_qk = head_dim_qk
-        self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
-        self.hidden_size = num_heads * head_dim_qk
-        self.hidden_size_kv = num_gqa_groups * self.head_dim_v
-        self.max_seqlen_q = max_seqlen_q
-        self.max_seqlen_kv = max_seqlen_kv
-        self.dropout_p = dropout_p
-        self.attn_mask_type = attn_mask_type
-        self.attn_bias_type = attn_bias_type
-        self.alibi_type = alibi_type
-        self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
-        self.num_layers = num_layers
-        self.bias_shape = bias_shape
-        self.window_size = window_size
-        self.total_requests = total_requests
-        self.max_ctx_len = max_ctx_len
-@contextmanager
-def logging_context(highest_level=logging.WARNING):
-    previous_level = logging.root.manager.disable
-    logging.disable(highest_level)
-    try:
-        yield
-    finally:
-        logging.disable(previous_level)
-def _get_attention_backends(
-    config: ModelConfig,
-    qkv_dtype: torch.dtype,
-    qkv_layout: str,
-    window_size: Tuple[int, int] = (-1, -1),
-    pad_between_seqs: bool = False,
-    context_parallel: bool = False,
-    deterministic: bool = False,
-    fp8: bool = False,
-    fp8_meta: Optional[Dict[str, Any]] = None,
-    is_training: bool = True,
-    inference_params: Optional[InferenceParams] = None,
-) -> Tuple[List, List]:
-    """Check if what attention backends support a model configuration"""
-    os.environ["NVTE_FLASH_ATTN"] = "1"
-    os.environ["NVTE_FUSED_ATTN"] = "1"
-    os.environ["NVTE_UNFUSED_ATTN"] = "1"
-    _attention_backends["backend_selection_requires_update"] = True
-    alibi_slopes_shape = None
-    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
-        if config.bias_shape == "1hss":
-            alibi_slopes_shape = [config.num_heads]
-        if config.bias_shape == "bhss":
-            alibi_slopes_shape = [config.batch_size, config.num_heads]
-    core_attention_bias_shape = (
-        config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
-    )
-    core_attention_bias_requires_grad = False
-    # d=256 is supported by cuDNN 9.0+ for inference but not training
-    if (
-        config.attn_bias_type == "post_scale_bias"
-        and config.head_dim_qk <= 128
-        and config.head_dim_v <= 128
-    ):
-        core_attention_bias_requires_grad = True
-    fused_attn_backends = []
-    available_backends = None
-    flash_attention_backend = None
-    fused_attention_backend = None
-    def test():
-        attention_params = AttentionParams(
-            qkv_dtype=qkv_dtype,
-            qkv_layout=qkv_layout,
-            batch_size=config.batch_size,
-            num_heads=config.num_heads,
-            num_gqa_groups=config.num_gqa_groups,
-            max_seqlen_q=config.max_seqlen_q,
-            max_seqlen_kv=config.max_seqlen_kv,
-            head_dim_qk=config.head_dim_qk,
-            head_dim_v=config.head_dim_v,
-            attn_mask_type=config.attn_mask_type,
-            window_size=window_size,
-            alibi_slopes_shape=alibi_slopes_shape,
-            core_attention_bias_type=config.attn_bias_type,
-            core_attention_bias_shape=core_attention_bias_shape,
-            core_attention_bias_requires_grad=core_attention_bias_requires_grad,
-            pad_between_seqs=pad_between_seqs,
-            attention_dropout=config.dropout_p,
-            context_parallel=context_parallel,
-            deterministic=deterministic,
-            fp8=fp8,
-            fp8_meta=fp8_meta,
-            is_training=is_training,
-            inference_params=inference_params,
-        )
-        (
-            use_flash_attention,
-            use_fused_attention,
-            flash_attention_backend,
-            fused_attention_backend,
-            use_unfused_attention,
-            available_backends,
-        ) = get_attention_backend(attention_params)
-        # Set attention.py _attention_backends var using return value
-        # from get_attention_backend()
-        _attention_backends["use_flash_attention"] = use_flash_attention
-        _attention_backends["use_fused_attention"] = use_fused_attention
-        _attention_backends["flash_attention_backend"] = flash_attention_backend
-        _attention_backends["fused_attention_backend"] = fused_attention_backend
-        _attention_backends["use_unfused_attention"] = use_unfused_attention
-        _attention_backends["backend_selection_requires_update"] = False
-        return available_backends, flash_attention_backend, fused_attention_backend
-    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
-    with logging_context():
-        for i in range(3):
-            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
-            _attention_backends["backend_selection_requires_update"] = True
-            available_backends, flash_attention_backend, fused_attention_backend = test()
-            if fused_attention_backend == FusedAttnBackend[backends[i]]:
-                fused_attn_backends.append(fused_attention_backend)
-    return available_backends, flash_attention_backend, fused_attn_backends
 model_configs_base = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
-    "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
-    "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
-    "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"),
-    "base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"),
-    "base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"),
+    "base_1_0": ModelConfig(8, 128, 16, 64),
+    "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
+    "base_2_0": ModelConfig(2, 2048, 24, 128),
+    "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
+    "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
+    "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
+    "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
+    "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
+    "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
+    "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
+    "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
+    "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
 }
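Most of this file's churn comes from ModelConfig moving into tests/pytorch/utils.py with a keyword-based signature: the old positional order (batch_size, num_heads, num_gqa_groups, head_dim_qk, max_seqlen_q, max_seqlen_kv, dropout_p, attn_mask_type, attn_bias_type) becomes (batch_size, max_seqlen_q, num_heads, head_dim_qk) plus keywords. A sketch of the equivalence, inferred from the paired entries above; the defaults listed are an assumption, not shown in this hunk:

    # Old call (ModelConfig defined locally in test_fused_attn.py):
    #   ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias")
    # New call (ModelConfig imported from tests/pytorch/utils.py), assuming
    # num_gqa_groups defaults to num_heads, max_seqlen_kv to max_seqlen_q,
    # dropout_p to 0.0, attn_mask_type to "no_mask", attn_bias_type to "no_bias":
    #   ModelConfig(8, 128, 16, 64)
    old = dict(batch_size=8, num_heads=16, num_gqa_groups=16, head_dim_qk=64,
               max_seqlen_q=128, max_seqlen_kv=128, dropout_p=0.0,
               attn_mask_type="no_mask", attn_bias_type="no_bias")
    new = dict(batch_size=8, max_seqlen_q=128, num_heads=16, head_dim_qk=64)
    # "base_1_0": ModelConfig(8, 128, 16, 64) encodes exactly the old tuple above.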
@@ -279,7 +126,7 @@ def test_dot_product_attention(
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
     is_training = True
-    available_backends, _, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
@@ -290,7 +137,7 @@ def test_dot_product_attention(
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported:
         is_training = False
-        available_backends, _, fused_attn_backends = _get_attention_backends(
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
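Throughout the file, the local _get_attention_backends helper is replaced by get_available_attention_backends from tests/pytorch/utils.py; judging by these hunks, the call shape and the returned triple are unchanged. A sketch of the consuming pattern, with illustrative config and dtype values:

    # Hypothetical test body showing how the returned triple is consumed.
    config = model_configs_base["base_1_0"]
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.bfloat16,
        qkv_layout="sb3hd",
    )
    flash_ok, fused_ok, unfused_ok = available_backends
    if not (flash_ok or fused_ok or unfused_ok):
        pytest.skip("No attention backend available.")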
@@ -411,62 +258,26 @@ def test_dpa_checkpoint(dtype, model_configs, model):
     """Test DotProductAttention module with checkpointing"""
     test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
-if IS_HIP_EXTENSION:
-    model_configs_mla = {
-        # test: b, h, hg, dqk, sq, skv, p, mask, bias  # attn , backend
-        "mla_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128),  # self , 0
-        "mla_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128),  # cross, 0
-        "mla_1_2": ModelConfig(4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128),  # cross, 0
-        "mla_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64),  # self , 1
-        "mla_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64),  # cross, 1
-        "mla_2_2": ModelConfig(1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128),  # cross, 1
-    }
-else:
-    model_configs_mla = {
-        # test: b, h, hg, dqk, sq, skv, p, mask, bias  # attn , backend
-        "mla_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128),  # self , 0
-        "mla_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128),  # cross, 0
-        "mla_1_2": ModelConfig(4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128),  # cross, 0
-        "mla_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64),  # self , 1
-        "mla_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64),  # cross, 1
-        "mla_2_2": ModelConfig(1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128),  # cross, 1
-        "mla_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64),  # inference
-        "mla_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128),  # inference
-        "mla_3_2": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128),  # inference
-    }
+model_configs_mla = {
+    # test: b, h, hg, dqk, sq, skv, p, mask, bias  # attn , backend
+    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
+    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
+    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
+    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
+    "mla_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64),  # cross, 1
+    "mla_2_2": ModelConfig(1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128),  # cross, 1
+    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
+    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
+}
-@pytest.mark.skipif(not IS_HIP_EXTENSION and get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("model_configs", [model_configs_mla])
 @pytest.mark.parametrize("model", model_configs_mla.keys())
@@ -477,40 +288,46 @@ def test_dpa_mla(dtype, model_configs, model):
 model_configs_mask = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
-    "mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
-    "mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
-    "mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
-    "mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
-    "mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
-    "mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
-    "mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
-    "mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
-    "mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    "mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    "mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
-    "mask_5_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "mask_5_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "mask_5_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
-    "mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
-    "mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
-    "mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
-    "mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"),
-    "mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"),
-    "mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"),
-    "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"),
-    "mask_10_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
+    "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
+    "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
+    "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
+    "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
+    "mask_2_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"),
+    "mask_2_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"),
+    "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
+    "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
+    "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
+    "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
+    "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
+    "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
+    "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
+    "mask_5_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"),
+    "mask_5_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"),
+    "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
+    "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
+    "mask_7_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"),
+    "mask_7_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"),
+    "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
+    "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
+    "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
+    "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
+    "mask_10_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"),
+    "mask_10_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"),
 }
@@ -526,44 +343,102 @@ def test_dpa_mask(dtype, model_configs, model):
 model_configs_bias = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
-    "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
-    "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
-    "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
-    "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"),  # skipped
-    "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"),  # skipped
-    "bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),  # skipped
-    "bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"),  # skipped
-    "bias_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"),  # skipped
-    "bias_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"),  # skipped
-    "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"),  # skipped
-    "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"),  # skipped
-    "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
-    "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
-    "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
-    "bias_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"),  # skipped
-    "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
-    "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"),  # skipped
-    "bias_4_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"),  # skipped
-    "bias_4_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),  # skipped
-    "bias_4_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"),  # skipped
-    "bias_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),  # skipped
-    "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"),  # skipped
-    "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"),  # skipped
+    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
+    "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
+    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
+    "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
+    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),  # skipped
+    "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),  # skipped
+    "bias_2_0": ModelConfig(4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_2_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_2_2": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),  # skipped
+    "bias_2_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"),  # skipped
+    "bias_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "bias_3_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "bias_3_2": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "bias_3_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
+    "bias_3_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"),  # skipped
+    "bias_4_0": ModelConfig(4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_4_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_4_2": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_4_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"),  # skipped
+    "bias_4_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="alibi"),  # skipped
 }
@@ -578,33 +453,29 @@ def test_dpa_bias(dtype, model_configs, model):
 model_configs_bias_shapes = {
     # test: b, h, hg, d, sq, skv, p, mask, bias, bias_shape, alibi_type
-    "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="11ss"),
-    "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss"),
-    "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss"),
-    "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss"),
-    "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom"),
-    "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom"),
+    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
+    "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
+    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
+    "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
+    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", bias_shape="1hss", alibi_type="custom"),
+    "bias_1_5": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", bias_shape="bhss", alibi_type="custom"),
 }
@@ -620,34 +491,36 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 model_configs_swa = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
-    "swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
-    "swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"),
-    "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
-    "swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
-    "swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
-    "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
-    "swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
-    "swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"),
-    "swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
-    "swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    "swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    "swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
-    "swa_6_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "swa_6_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "swa_6_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
+    "swa_1_1": ModelConfig(2, 2048, 16, 64),
+    "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
+    "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
+    "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
+    "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
+    "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
+    "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
+    "swa_3_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"),
+    "swa_3_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"),
+    "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
+    "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
+    "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
+    "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
+    "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
+    "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
+    "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
+    "swa_6_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"),
+    "swa_6_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"),
 }
-@pytest.mark.skipif((not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus), reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_swa])
 @pytest.mark.parametrize("model", model_configs_swa.keys())
@@ -658,18 +531,36 @@ def test_dpa_sliding_window(dtype, model_configs, model):
 model_configs_alibi_slopes = {
     # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
-    "alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
-    "alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
-    "alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom"),
-    "alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom"),
+    "alibi_1_0": ModelConfig(2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"),
+    "alibi_1_1": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"),
+    "alibi_2_0": ModelConfig(2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"),
+    "alibi_2_1": ModelConfig(1, 1024, 24, 128, max_seqlen_kv=2048, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"),
 }
-@pytest.mark.skipif((not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus), reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
 @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
@@ -694,16 +585,38 @@ qkv_layouts = [
 model_configs_layout = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
-    "layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
-    "layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
-    "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
-    "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
-    "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
-    "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
-    "layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
+    "layout_0_0": ModelConfig(2, 128, 16, 64),
+    "layout_0_1": ModelConfig(2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
+    "layout_0_3": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),
+    "layout_1_0": ModelConfig(2, 2048, 24, 128),
+    "layout_1_1": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
+    "layout_1_3": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),
+    "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
+    "layout_2_1": ModelConfig(2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
 }
@@ -720,55 +633,54 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
 qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
 model_configs_layout_thd = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
-    "layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
-    "layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
-    "layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    "layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    "layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
-    "layout_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "layout_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "layout_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "layout_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)),
-    "layout_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)),
-    "layout_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4)),
-    "layout_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)),
-    "layout_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)),
-    "layout_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0)),
-    "layout_5_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)),
-    "layout_5_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)),
-    "layout_5_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)),
+    "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
+    "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
+    "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
+    "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
+    "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
+    "layout_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
+    "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
+    "layout_2_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"),
+    "layout_2_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"),
+    "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
+    "layout_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)),
+    "layout_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)),
+    "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
+    "layout_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)),
+    "layout_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)),
+    "layout_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)),
+    "layout_5_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)),
+    "layout_5_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)),
 }
@@ -1158,16 +1070,22 @@ def _run_dot_product_attention(
 model_configs_te_layer = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
-    "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
-    "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
-    "te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
-    "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
-    "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
-    "te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
-    "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
+    "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
+    "te_1_1": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "te_1_2": ModelConfig(2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"),
+    "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
+    "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
+    "te_2_1": ModelConfig(2, 2048, 16, 64),
+    "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
+    "te_2_3": ModelConfig(1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"),
+    "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
+    "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
 }
@@ -1189,26 +1107,27 @@ def test_transformer_layer(
     tols = dict(atol=5e-2, rtol=5e-2)
     workspace_opt = True
     qkv_layout = "sbh3d" if fused_qkv_params else "sb3hd"
     # override the qkv_layout in mqa gqa mode in ROCm TE
     if IS_HIP_EXTENSION and model_configs[model].num_gqa_groups != model_configs[model].num_heads:
         qkv_layout = "sbhd_sbhd_sbhd"
     # Test backend availability
     is_training = True
-    available_backends, _, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
-        qkv_layout=qkv_layout,
+        qkv_layout=(
+            qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
+        ),
         is_training=is_training,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported:
         is_training = False
-        available_backends, _, fused_attn_backends = _get_attention_backends(
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
             config,
             qkv_dtype=dtype,
-            qkv_layout=qkv_layout,
+            qkv_layout=(
+                qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
+            ),
             is_training=is_training,
         )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
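The backend probe now derives the packed layout from qkv_format rather than the precomputed qkv_layout. For reference, the replace() calls above are simple string mechanics (values shown are what Python actually produces):

    # "hd" -> "h3d" packs QKV after the head dim; "hd" -> "3hd" packs before it.
    for qkv_format in ("sbhd", "bshd", "thd"):
        print(qkv_format.replace("hd", "h3d"))  # sbh3d, bsh3d, th3d
        print(qkv_format.replace("hd", "3hd"))  # sb3hd, bs3hd, t3hd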
@@ -1514,20 +1433,164 @@ def _run_transformer_layer(
     return out, inp.grad
+model_configs_fp8_extra_state = {
+    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
+}
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
+@pytest.mark.parametrize("model", ["large"])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_sanity_attention_extra_state(model, dtype):
+    config = model_configs_fp8_extra_state[model]
+    # Test backend availability
+    is_training = True
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=torch.float8_e4m3fn,
+        qkv_layout="sb3hd",
+        is_training=is_training,
+    )
+    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+    if not fused_attn_supported and not flash_attn_supported:
+        pytest.skip("No attention backend available.")
+    outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
+    outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
+    outputs_checkpoint_v1_6 = _run_attention_extra_state(dtype, config, mimic_v1_6=True, checkpoint=True)
+    # Check that results match
+    tols = dtype_tols(dtype)
+    if dtype in (torch.float16, torch.bfloat16):
+        tols.update(dict(rtol=2e-2, atol=2e-3))
+    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
+        torch.testing.assert_close(test, ref, **tols)
+    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
+        torch.testing.assert_close(test, ref, **tols)
+def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
+    steps = 10
+    path = "checkpoint.pt"
+    fp8_enabled = True
+    fp8_recipe = recipe.DelayedScaling(
+        margin=0,
+        fp8_format=recipe.Format.HYBRID,
+        amax_history_len=1,
+        amax_compute_algo="most_recent",
+        fp8_dpa=fp8_enabled,
+        fp8_mha=False,
+    )
+    reset_rng_states()
+    hidden_states = torch.randn(
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
+        dtype=dtype,
+        device="cuda",
+        requires_grad=True,
+    )
+    def get_model(dtype, config):
+        sigma = 0.023
+        init_method = init_method_normal(sigma)
+        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
+        with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
+            block = TransformerLayer(
+                config.hidden_size,
+                4 * config.hidden_size,
+                config.num_heads,
+                init_method=init_method,
+                output_layer_init_method=output_layer_init_method,
+                hidden_dropout=0.0,
+                attention_dropout=0.0,
+                fuse_qkv_params=True,
+                params_dtype=dtype,
+                device="cuda",
+            )
+        return block
+    block = get_model(dtype, config)
+    for i in range(steps // 2):
+        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
+            output = block(hidden_states, None)
+            loss = output.sum()
+            loss.backward()
+    if checkpoint:
+        sd = block.state_dict()
+        if mimic_v1_6:
+            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
+                "self_attention.core_attention._extra_state"
+            ]
+            del sd["self_attention.core_attention._extra_state"]
+        torch.save(sd, path)
+        param_grads = []
+        for p in block.parameters():
+            if p.requires_grad:
+                param_grads.append(p.grad.clone())
+        _cpu_rng_state_new = torch.get_rng_state()
+        _cuda_rng_state_new = torch.cuda.get_rng_state()
+        del block
+        block = get_model(dtype, config)
+        block.load_state_dict(torch.load(path, weights_only=False))
+        torch.set_rng_state(_cpu_rng_state_new)
+        torch.cuda.set_rng_state(_cuda_rng_state_new)
+        for p in block.parameters():
+            if p.requires_grad:
+                p.grad = param_grads.pop(0)
+        assert not param_grads, "Oops!"
+    for i in range((steps + 1) // 2):
+        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
+            output = block(hidden_states, None)
+            loss = output.sum()
+            loss.backward()
+    torch.cuda.synchronize()
+    if os.path.exists(path):
+        os.remove(path)
+    outputs = [output, hidden_states.grad]
+    for p in block.parameters():
+        if p.requires_grad:
+            outputs.append(p.grad)
+    return outputs
 model_configs_fp8_vs_f16 = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
-    "fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),
-    "fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
-    "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
-    "fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"),
-    "fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"),
-    "fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"),
-    "fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    "fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    "fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"),
+    "fp8_9": ModelConfig(2, 2048, 16, 128),
+    "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
+    "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
+    "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
+    "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
+    "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
+    "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
+    "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
+    "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
+    "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
+    "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
+    "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
 }
 param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
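The mimic_v1_6 branch above emulates pre-1.7 checkpoints by renaming the attention extra-state key before saving. Conceptually, the key move is just this (a toy dict standing in for the real state dict; the bytes value is a placeholder):

    # Toy illustration of the v1.6-style key rename performed above.
    sd = {"self_attention.core_attention._extra_state": b"..."}
    sd["self_attention.core_attention.fused_attention._extra_state"] = sd.pop(
        "self_attention.core_attention._extra_state"
    )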
@@ -1561,7 +1624,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
     )
 )
 @pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@@ -1576,18 +1639,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     config = model_configs_fp8_vs_f16[model]
     if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (9, 7, 0):
         pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
     if FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
+        # Test backend availability
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
+            config,
+            qkv_dtype=torch.float8_e4m3fn,
+            qkv_layout=qkv_format.replace("hd", "h3d"),
+            is_training=is_training,
+        )
+        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+        # Skip if only unfused backend is supported
+        if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
+            pytest.skip("Less than two backends to compare.")
+    if not fp8_dpa_bwd:
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
+            config,
+            qkv_dtype=dtype,
+            qkv_layout=qkv_format.replace("hd", "h3d"),
+            is_training=is_training,
+        )
+        _, fused_attn_supported, _ = available_backends
+        if not fused_attn_supported:
+            pytest.skip("No attention backend available.")
     if flash_attn_supported:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
@@ -1613,11 +1688,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     rtol = 5e-1
     rmse_tol = 0.15
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if (FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type):
+    if flash_attn_supported:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
@@ -1768,7 +1839,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
-        return out, param_names, tuple(x.grad for x in params)
+    return out, param_names, tuple(None for x in params)
 @pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@@ -1790,23 +1861,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     # if get_device_compute_capability() >= (10, 0):
     #     config.dropout_p = 0.1
     if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (9, 7, 0):
         pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
-    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
-        pytest.skip("qkv_layout not applicable for MQA/GQA")
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     if FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
+        # Test backend availability
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
+            config,
+            qkv_dtype=torch.float8_e4m3fn,
+            qkv_layout=qkv_layout,
+            is_training=is_training,
+        )
+        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+        # Skip if only unfused backend is supported
+        if flash_attn_supported + fused_attn_supported < 1:
+            pytest.skip("No FP8 attention backend available.")
+    if not fp8_dpa_bwd:
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
+            config,
+            qkv_dtype=dtype,
+            qkv_layout=qkv_layout,
+            is_training=is_training,
+        )
+        _, fused_attn_supported, _ = available_backends
+        if not fused_attn_supported:
+            pytest.skip("No attention backend available.")
+    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
+        pytest.skip("qkv_layout not applicable for MQA/GQA")
     if flash_attn_supported:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
@@ -1835,11 +1917,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     rmse_tol = 0.11
     bwd_names = ["dq", "dk", "dv"]
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if (FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type):
+    if flash_attn_supported:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
@@ -2013,21 +2091,21 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
model_configs_fp8 = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
-    "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
-    "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
-    "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
-    "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
-    "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
-    "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+    # test: b, sq, hq, dqk
+    "fp8_1": ModelConfig(1, 512, 1, 64),
+    "fp8_2": ModelConfig(4, 512, 16, 64),
+    "fp8_3": ModelConfig(1, 2048, 1, 128),
+    "fp8_4": ModelConfig(2, 2048, 24, 128),
+    "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
+    "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
+    "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
+    "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
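
The rewritten entries rely on a keyword-style `ModelConfig` from `tests/pytorch/utils.py`. A hypothetical sketch of the fields these diffs exercise is shown below; the positional order is inferred from the "# test: b, sq, hq, dqk" comment, and all defaults are assumptions drawn from usage, not the real definition:

    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass
    class ModelConfig:
        # Hypothetical sketch, not the actual class in tests/pytorch/utils.py.
        batch_size: int
        max_seqlen_q: int
        num_heads: int
        head_dim_qk: int
        dropout_p: float = 0.0
        attn_mask_type: str = "no_mask"       # e.g. "causal", "padding"
        attn_bias_type: str = "no_bias"       # e.g. "post_scale_bias"
        num_gqa_groups: Optional[int] = None  # assumed to default to num_heads
        max_seqlen_kv: Optional[int] = None   # assumed to default to max_seqlen_q
        head_dim_v: Optional[int] = None      # assumed to default to head_dim_qk
        window_size: Optional[Tuple[int, int]] = None

        def __post_init__(self):
            if self.num_gqa_groups is None:
                self.num_gqa_groups = self.num_heads
            if self.max_seqlen_kv is None:
                self.max_seqlen_kv = self.max_seqlen_q
            if self.head_dim_v is None:
                self.head_dim_v = self.head_dim_qk

Under this reading, the old positional entry ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias") and the new ModelConfig(4, 512, 16, 64) describe the same shape.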
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]

@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(
    (get_cudnn_version() < (8, 9, 3)
...
...
@@ -2049,6 +2127,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
    config = model_configs_fp8[model]
+    # Test backend availability
+    is_training = True
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=torch.float8_e4m3fn,
+        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
+        is_training=is_training,
+    )
+    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+    if not (fused_attn_backends and unfused_attn_supported):
+        pytest.skip("Not enough backends to run this test with.")
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
...
...
tests/pytorch/fused_attn/test_fused_attn_with_cp.py → tests/pytorch/attention/test_attention_with_cp.py
...
...
@@ -4,6 +4,8 @@
import os
import subprocess
+import sys
+import pathlib

import pytest
import torch
...
...
@@ -12,27 +14,29 @@ from transformer_engine.pytorch.utils import (
    get_cudnn_version,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
-from test_fused_attn import ModelConfig
+
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import ModelConfig, get_available_attention_backends

# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

from torch.utils.cpp_extension import IS_HIP_EXTENSION
model_configs_flash_attn = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
-    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
-    "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)),  # MHA
-    "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)),  # MHA
-    "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
-    "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
-    "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)),  # GQA
-    "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)),  # GQA
+    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
+    "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
+    "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
+    "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)),  # MHA
+    "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
+    "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
+    "cp_2_2": ModelConfig(
+        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+    ),  # GQA
+    "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)),  # GQA
}
...
...
@@ -44,7 +48,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
"--nproc-per-node="
+
str
(
num_gpus_per_node
),
]
te_path
=
os
.
getenv
(
"TE_PATH"
,
"/opt/transformerengine"
)
script_path
=
os
.
path
.
join
(
te_path
,
"tests/pytorch/
fused_attn/run_fused_att
n_with_cp.py"
)
script_path
=
os
.
path
.
join
(
te_path
,
"tests/pytorch/
attention/run_attentio
n_with_cp.py"
)
args
.
append
(
script_path
)
for
k
,
v
in
kwargs
.
items
():
args
.
append
(
f
"
{
k
}
=
{
v
}
"
)
...
...
@@ -94,37 +98,41 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
-    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
-    "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # MHA
-    "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # MHA
-    "cp_1_4": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)),  # MHA
-    "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
-    "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
-    "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # GQA
-    "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # GQA
-    "cp_2_4": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)),  # GQA
-    "cp_3_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64),  # MLA
-    "cp_3_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64),  # MLA
-    "cp_3_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64),  # MLA
-    "cp_3_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64),  # MLA
+    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
+    "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
+    "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"),  # MHA
+    "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"),  # MHA
+    "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
+    "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
+    "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
+    "cp_2_2": ModelConfig(
+        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", attn_bias_type="post_scale_bias",
+    ),  # GQA
+    "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"),  # GQA
+    "cp_2_4": ModelConfig(
+        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+    ),  # GQA
+    "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64),  # MLA
+    "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64),  # MLA
+    "cp_3_2": ModelConfig(
+        2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
+    ),  # MLA
+    "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64),  # MLA
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
-@pytest.mark.skipif(
-    IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0),
-    reason="DTK not surpport fused attn for now, CP tests require sm80+.",
-)
+@pytest.mark.skipif(
+    IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0),
+    reason="CP tests require sm80+.",
+)
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
...
...
@@ -176,6 +184,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
        pytest.skip("MLA CP currently only support KV P2P!")
    if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
        pytest.skip("MLA CP currently does not support FP8 attention!")

    dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtypes[dtype],
+        qkv_layout="_".join([qkv_format] * 3),
+        window_size=config.window_size,
+        context_parallel=True,
+    )
+    _, fused_attn_supported, _ = available_backends
+    if not fused_attn_supported:
+        pytest.skip("No attention backend available.")
    subprocess.run(
        get_bash_arguments(
...
...
tests/pytorch/fused_attn/test_kv_cache.py → tests/pytorch/attention/test_kv_cache.py
...
...
@@ -5,18 +5,14 @@
from collections import OrderedDict
from typing import List
import os
+import sys
+import pathlib
import logging
import math
import pytest
import torch
-from test_fused_attn import (
-    ModelConfig,
-    reset_rng_states,
-    _get_attention_backends,
-)
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
...
...
@@ -34,26 +30,25 @@ from transformer_engine.pytorch.utils import (
    is_bf16_compatible,
)

-# Initialize RNG state
-seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import (
+    ModelConfig,
+    reset_rng_states,
+    get_available_attention_backends,
+)
+
+# Reset RNG states
+reset_rng_states()

param_types = [torch.float16]
if is_bf16_compatible():
    param_types.append(torch.bfloat16)

model_configs_infer = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
-    "infer_0": ModelConfig(4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16),
-    "infer_1": ModelConfig(2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16),
+    # test: b, sq, hq, dqk,
+    "infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16),
+    "infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16),
}
qkv_formats = ["bshd", "sbhd", "thd"]
...
...
@@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
    qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
    if is_paged:
        qkv_layout = "paged_kv_" + qkv_layout
-    available_backends, _, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
...
...
tests/pytorch/debug/run_distributed.py
...
...
@@ -364,6 +364,40 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
    set_weight_tensor_tp_group_reduce(True)  # reset


+@run_debug_test
+def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
+    from test_log import LOG_QUANTIZED_CONFIG
+
+    kwargs["config_file"].write(LOG_QUANTIZED_CONFIG)
+    kwargs["config_file"].flush()
+    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
+    set_weight_tensor_tp_group_reduce(gather_weight)
+    if WORLD_SIZE % 2 != 0:
+        return  # skip
+    TP_SIZE = WORLD_SIZE // 2
+    DP_SIZE = 2
+    TP_RANK = WORLD_RANK % TP_SIZE
+    DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
+    debug_api.set_tensor_reduction_group(NCCL_WORLD)
+    x, weight = _get_tensors(
+        parallel_mode,
+        weight_seed=TP_RANK * 1234,
+        data_seed=DP_RANK * 1234,
+        tp_size=TP_SIZE,
+        tp_rank=TP_RANK,
+    )
+    tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
+    tp_group = dist.new_group(ranks=tp_group_ranks)
+    model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
+    _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
+    set_weight_tensor_tp_group_reduce(True)  # reset


@run_debug_test
def test_log_expert_parallel(**kwargs):
    """
...
...
tests/pytorch/debug/test_api_features.py
...
...
@@ -24,22 +24,17 @@ def test_transformer_engine_no_config(feature_dirs):
        # FP8 enabled - true by default
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]

-        # modify_tensor_enabled - False by default
+        # modify_tensor_enabled - (False, None) by default
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]

-        # inspect_tensor_enabled - False by default
+        # inspect_tensor_enabled - (False, None) by default
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.attn.qkv", tensor_name="activation", iteration=0
-        )
-
-        # inspect_tensor_postquantize - False by default
-        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
-            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]
    finally:
        debug_api.end_debug()
...
...
@@ -51,24 +46,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs):
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-        )
+        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-        )
+        )[0]

        # caching
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-        )
+        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-        )
+        )[0]
    finally:
        debug_api.end_debug()
...
...
@@ -80,22 +75,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs):
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", iteration=0
-        )
+        )[0]
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", iteration=0
-        )
+        )[0]
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", iteration=0
-        )
+        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-        )
+        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-        )
+        )[0]
    finally:
        debug_api.end_debug()
...
...
@@ -111,22 +106,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
        # check modify_tensor_enabled
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
-        )
+        )[0]
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
-        )
+        )[0]
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
-        )
+        )[0]
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
-        )
+        )[0]
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
-        )
+        )[0]

        # check modify_tensor
...
...
@@ -168,14 +163,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
            gemm="wgrad",
            tensor_name="gradient",
            iteration=0,
-        )
+        )[0]
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc4",
            gemm="fprop",
            tensor_name="activation",
            iteration=0,
-        )
+        )[0]
    finally:
        debug_api.end_debug()
...
...
@@ -191,11 +186,11 @@ def test_fake_quant(configs_dir, feature_dirs):
        # modify_tensor_enabled
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
-        )
+        )[0]

        # modify_tensor
        debug_api.transformer_engine.modify_tensor(
...
@@ -218,11 +213,11 @@ def test_fake_quant(configs_dir, feature_dirs):
assert
debug_api
.
transformer_engine
.
fp8_gemm_enabled
(
"decoder.1.fc2"
,
gemm
=
"wgrad"
,
iteration
=
0
)
)
[
0
]
# caching
assert
debug_api
.
transformer_engine
.
fp8_gemm_enabled
(
"decoder.1.fc2"
,
gemm
=
"wgrad"
,
iteration
=
0
)
)
[
0
]
finally
:
debug_api
.
end_debug
()
...
...
@@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs):
    )

    tensor = torch.randn((100, 100, 5)).cuda()
-    tensor_fp8 = Float8Tensor(
-        data=tensor.to(torch.uint8).cuda(),
-        fp8_scale_inv=torch.full([1], 1.0).cuda(),
-        fp8_dtype=tex.DType.kFloat8E4M3,
-        shape=tensor.shape,
-        dtype=torch.float32,
-    )
+    quantizer = Float8Quantizer(
+        scale=torch.full([1], 1.0).cuda(),
+        amax=torch.full([1], 1.0).cuda(),
+        fp8_dtype=tex.DType.kFloat8E4M3,
+    )
+    tensor_fp8 = quantizer(tensor)

    def log():
        from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
...
...
@@ -260,54 +254,64 @@ def test_statistics_collection(configs_dir, feature_dirs):
            tensor_name="activation",
            iteration=200,
            tp_group=None,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=tensor_fp8,
+            columnwise_quantized_tensor=tensor_fp8,
        )
        stats = log()
        assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
-        )
+        )[0]
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.2.mlp.fc1", tensor_name="activation", iteration=200
        )
-        assert not debug_api.transformer_engine.inspect_tensor_enabled(
-            "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
-        )[0]
-        expected_underflows = (
-            ((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
-        )
+        expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
+        expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5)
+        assert debug_api.transformer_engine.inspect_tensor_enabled(
+            "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
+        )[0]

        # TE FP8 tensor stats --
-        assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
-            "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
-        )
-        debug_api.transformer_engine.inspect_tensor_postquantize(
+        assert debug_api.transformer_engine.inspect_tensor_enabled(
+            "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
+        )[0]
+        debug_api.transformer_engine.inspect_tensor(
            "decoder.1.mlp.fc1",
-            tensor=tensor_fp8,
            tensor_name="gradient",
            iteration=200,
-            rowwise=True,
            tp_group=None,
+            tensor=tensor,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=tensor_fp8,
+            columnwise_quantized_tensor=tensor_fp8,
        )
        stats = log()
        torch.testing.assert_close(
            stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
        )
-        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
-            "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
-        )
-        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
-            "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
-        )
+        assert not debug_api.transformer_engine.inspect_tensor_enabled(
+            "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
+        )[0]
+        assert not debug_api.transformer_engine.inspect_tensor_enabled(
+            "decoder.2.mlp.fc1", tensor_name="gradient", iteration=200
+        )[0]

        # Second config in same yaml
        tensor = torch.rand((100, 100, 5))
        debug_api.transformer_engine.inspect_tensor(
            "decoder.6.mlp.fc1",
-            tensor=tensor,
            tensor_name="activation",
            iteration=200,
            tp_group=None,
+            tensor=tensor,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=tensor_fp8,
+            columnwise_quantized_tensor=tensor_fp8,
        )
        stats = log()
        stats_names = [x[3] for x in stats.keys()]
...
...
@@ -316,10 +320,13 @@ def test_statistics_collection(configs_dir, feature_dirs):
        debug_api.transformer_engine.inspect_tensor(
            "decoder.7.mlp.fc1",
-            tensor=tensor,
            tensor_name="weight",
            iteration=200,
            tp_group=None,
+            tensor=tensor,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=tensor_fp8,
+            columnwise_quantized_tensor=tensor_fp8,
        )
        stats = log()
        stats_names = [x[3] for x in stats.keys()]
...
@@ -328,7 +335,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
"decoder.7.mlp.fc1"
,
tensor_name
=
"weight"
,
iteration
=
201
)
)
[
0
]
assert_empty
()
finally
:
...
...
@@ -343,21 +350,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
        default_logging_enabled=False,
    )

-    def feed(tensor, tensor_fp8):
+    def feed(tensor, tensor_fp8, quantizer):
        debug_api.transformer_engine.inspect_tensor(
            "decoder.5.mlp.fc1",
-            tensor=tensor,
            tensor_name="activation",
            iteration=1,
            tp_group=None,
-        )
-        debug_api.transformer_engine.inspect_tensor_postquantize(
-            "decoder.5.mlp.fc1",
-            tensor=tensor_fp8,
-            tensor_name="activation",
-            iteration=1,
-            rowwise=True,
-            tp_group=None,
+            tensor=tensor,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=tensor_fp8,
+            columnwise_quantized_tensor=tensor_fp8,
        )

    def log_stats():
...
@@ -365,26 +367,26 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
        return STATS_BUFFERS.log_stats()

+    quantizer = Float8Quantizer(
+        scale=torch.full([1], 1.0).cuda(),
+        amax=torch.full([1], 1.0).cuda(),
+        fp8_dtype=tex.DType.kFloat8E4M3,
+    )
+
    def fp8_tensor(t):
-        return Float8Tensor(
-            data=t.to(torch.uint8).cuda(),
-            fp8_scale_inv=torch.ones([1]).cuda(),
-            fp8_dtype=tex.DType.kFloat8E4M3,
-            shape=t.shape,
-            dtype=torch.float32,
-        )
+        return quantizer(t.cuda())

    shape = [1024, 1024]
    tensors = [torch.randn(shape) for _ in range(2)]
    tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]

-    feed(tensors[0], tensors_fp8[0])
-    feed(tensors[1], tensors_fp8[1])
+    feed(tensors[0], tensors_fp8[0], quantizer)
+    feed(tensors[1], tensors_fp8[1], quantizer)
    stats1 = log_stats()

    tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
    fp8tensor2 = fp8_tensor(tensor2)
-    feed(tensor2, fp8tensor2)
+    feed(tensor2, fp8tensor2, quantizer)
    stats2 = log_stats()

    assert len(stats1.keys()) > 0
...
...
tests/pytorch/debug/test_configs/log_config.yaml 0 → 100644
test:
  enabled: True
  layers:
    layer_name_regex_pattern: .*
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [cur_amax, dynamic_range, mean, std, l1_norm]
          start_step: 1
          freq: 3
    LogFp8TensorStats:
      enabled: True
      tensors: activation
      stats: [underflows%]
      start_step: 1
      freq: 5
\ No newline at end of file
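
The two features in this config deliberately use different cadences (freq: 3 for LogTensorStats, freq: 5 for LogFp8TensorStats), which is exactly what test_log.py's test_log_every_3_or_5_layers checks further down: iterations divisible by neither 3 nor 5 must produce no log lines. A minimal sketch of how such a config file is consumed (the feature_dirs and log_dir values here are placeholders, not paths the suite actually hardcodes):

    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        config_file="tests/pytorch/debug/test_configs/log_config.yaml",
        feature_dirs=["<transformer_engine debug feature dir>"],  # placeholder
        log_dir="/tmp/debug_logs",                                # placeholder
    )
    # ... run forward/backward, calling debug_api.step() once per iteration ...
    debug_api.end_debug()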
tests/pytorch/debug/test_configs/perf_config.yaml 0 → 100644
test:
  enabled: True
  layers:
    layer_name_regex_pattern: .*1
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [cur_amax, dynamic_range, mean, std, l1_norm]
          start_step: 0
          freq: 100000
\ No newline at end of file
tests/pytorch/debug/test_log.py 0 → 100644
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import torch
import tempfile
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState
import pytest
import contextlib
import os
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.debug.pytorch.debug_state import TEDebugState

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)
LOG_QUANTIZED_CONFIG_BASE = """
log:
  layers:
    layer_name_regex_pattern: .*
  enabled:
    True
  transformer_engine:
    LogFp8TensorStats:
      enabled: True
      stats: [{stats}]
      tensors: [activation, gradient, weight]
      freq: 2
      start_step: 0
      end_step: 10
"""

recipes = [
    "fp8_delayed_scaling",
    "fp8_current_scaling",
    "fp8_block_scaling",
    "mxfp8",
]

bare_stats = [
    "underflows%",
    "scale_inv_min",
    "scale_inv_max",
    "mse",
]
all_stats = []
for r in recipes:
    for stat in bare_stats:
        for columnwise_postfix in ["", "_columnwise"]:
            if (
                r in ["fp8_current_scaling", "fp8_block_scaling"]
                and torch.cuda.get_device_capability()[0] < 9
            ):
                # hopper is needed for current-scaling, block-scaling
                continue
            if r == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
                # blackwell is needed for mxfp8
                continue
            if (
                r in ["fp8_delayed_scaling", "fp8_current_scaling"]
                and columnwise_postfix == "_columnwise"
            ):
                # columnwise stats are not supported for fp8_delayed_scaling and fp8_current_scaling
                continue
            all_stats.append(f"{r}_{stat}{columnwise_postfix}")
all_stats.append("fp8_delayed_scaling_overflows%")  # only delayed-scaling supports overflows%
@contextlib.contextmanager
def debug_session(config_str: str, feature_dirs):
    """
    Helper context manager that
    1. writes the YAML `config_str` to a temporary file,
    2. starts a debug session, and
    3. yields the directory that contains the statistics log.

    The session is closed automatically – even on exceptions – so every test
    stays concise and leak-free.
    """
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as cfg_file, tempfile.TemporaryDirectory() as log_dir:
        cfg_file.write(config_str)
        cfg_file.flush()
        debug_api.initialize(
            config_file=cfg_file.name,
            feature_dirs=feature_dirs,
            log_dir=log_dir,
        )
        try:
            yield log_dir
        finally:
            debug_api.end_debug()


def read_log(log_dir: str) -> str:
    """Return the content of the statistics log produced by `debug_session`."""
    stat_path = os.path.join(
        log_dir,
        "nvdlfw_inspect_statistics_logs",
        "nvdlfw_inspect_globalrank-0.log",
    )
    with open(stat_path, "r") as f:
        return f.read()
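
A minimal usage sketch for the two helpers above, assuming a `feature_dirs` list like the one the suite's pytest fixture supplies:

    def example(feature_dirs):
        cfg = LOG_QUANTIZED_CONFIG_BASE.format(stats="underflows%")
        with debug_session(cfg, feature_dirs) as log_dir:
            ...  # run training steps, calling debug_api.step() between them
            print(read_log(log_dir))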
def test_sanity(feature_dirs):
    log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
    with debug_session(log_all_stats_config, feature_dirs) as log_dir:
        model = te.Linear(128, 128, params_dtype=torch.bfloat16)
        inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
        for _ in range(10):
            with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
                output = model(inp)
            loss = output.sum()
            loss.backward()
            debug_api.step()
        output = read_log(log_dir)

    assert output, "Output is empty"
    for stat in all_stats:
        assert stat in output, f"Stat {stat} not found in output"
fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
    recipe.Float8BlockScaling(),
]


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_numerics(fp8_recipe, feature_dirs):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
        pytest.skip(reason_for_no_mxfp8)
    if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling():
        pytest.skip(reason_for_no_fp8_block_scaling)

    log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats))
    with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
        recipe_state = RecipeState.create(
            fp8_recipe,
            mode="forward",
            num_quantizers=3,
        )
        tensor = torch.zeros(1024, 1024).cuda()
        tensor[0, :] = 1000
        quantizer = recipe_state.make_quantizers()[0]
        quantized_tensor = quantizer(tensor)

        debug_api.transformer_engine.inspect_tensor(
            layer_name="layer_name",
            tensor_name="activation",
            iteration=0,
            tp_group=None,
            tensor=tensor,
            quantizer=quantizer,
            rowwise_quantized_tensor=quantized_tensor,
            columnwise_quantized_tensor=quantized_tensor,
        )
        debug_api.step()
        dequantized_tensor = quantized_tensor.dequantize()
        output = read_log(log_dir)

    for line in output.splitlines():
        if "underflows%" in line:
            underflows = float(line.split("value=")[1])
            expected = (
                ((dequantized_tensor == 0).sum() - (tensor == 0).sum())
                / dequantized_tensor.numel()
                * 100
            )
            assert underflows == pytest.approx(expected.cpu(), abs=1e-4)
        if "mse" in line:
            mse = float(line.split("value=")[1])
            expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean")
            assert mse == pytest.approx(expected.cpu(), abs=1e-6)
        if "overflows%" in line:
            overflows = float(line.split("value=")[1])
            expected = (
                (abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100
            )
            assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
    # If a layer does not invoke any feature in the current iteration,
    # it switches into non-debug mode.
    # This test checks that this works correctly:
    # non-quantized statistics should be logged every 3 iterations,
    # and quantized statistics every 5 iterations.
    with tempfile.TemporaryDirectory() as temp_dir:
        debug_api.initialize(
            config_file=configs_dir + "/log_config.yaml",
            feature_dirs=feature_dirs,
            log_dir=temp_dir,
        )
        if layer == "linear":
            model = te.Linear(128, 128, name="linear1")
        elif layer == "transformer":
            model = te.TransformerLayer(128, 128, 4, name="transformer1")
        else:
            raise ValueError(f"Invalid layer: {layer}")

        for i in range(20):
            x = torch.randn(4, 128, 128).cuda()
            with te.fp8_autocast(enabled=True):
                y = model(x)
            y.sum().backward()
            debug_api.step()

        with open(
            os.path.join(temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"),
            "r",
        ) as f:
            file_content = f.read()
        for i in range(1, 20):
            if i % 3 == 0 or i % 5 == 0:
                assert f"iteration={i:06d}" in file_content
            else:
                assert f"iteration={i:06d}" not in file_content
        debug_api.end_debug()
        TEDebugState._reset()
tests/pytorch/debug/test_perf.py 0 → 100644
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import time

import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState


def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs):
    debug_api.end_debug()
    TEDebugState._reset()
    if debug_tools_initialized:
        # This config logs stats starting from step 0, every N iterations for huge N >> NUM_ITERS.
        # So after one warm-up iteration these layers should run in non-debug mode.
        debug_api.initialize(config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs)

    try:
        if layer == "linear":
            model = torch.nn.Sequential(
                te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
            ).cuda()
            NUM_ITERS = 18000
        elif layer == "transformer":
            model = torch.nn.Sequential(
                te.TransformerLayer(1, 1, 1, name="transformer1"),
                te.TransformerLayer(1, 1, 1, name="transformer2"),
            ).cuda()
            NUM_ITERS = 2000

        x = torch.randn(1, 1, 1).cuda()

        y = model(x)
        y.sum().backward()
        debug_api.step()

        torch.cuda.synchronize()
        time_start = time.time()
        for i in range(NUM_ITERS):
            y = model(x)
            y.sum().backward()
            if debug_tools_initialized:
                debug_api.step()
        torch.cuda.synchronize()
        time_end = time.time()
    finally:
        if debug_tools_initialized:
            debug_api.end_debug()

    return time_end - time_start


@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_cpu_overhead(layer, configs_dir, feature_dirs):
    # Runs one layer many times on a very small tensor:
    # GPU time should be negligible, so the total is dominated by CPU time.
    # If a layer does not invoke any feature in the current iteration,
    # it switches into non-debug mode and should add no non-negligible CPU overhead
    # compared to a layer run without debug tools initialized.
    with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs)
    without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs)

    print(f"with_debug_tools: {with_debug_tools} s")
    print(f"without_debug_tools: {without_debug_tools} s")

    assert with_debug_tools < without_debug_tools * 1.25  # 25% overhead margin
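
The synchronize/time/loop/synchronize/time pattern above is what makes the wall-clock number meaningful on CUDA, since kernel launches are asynchronous and time.time() alone would only measure launch latency. The same pattern in generic form:

    import time
    import torch

    def timed(fn, iters):
        torch.cuda.synchronize()   # drain previously queued work before starting the clock
        start = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()   # wait for all queued kernels before stopping the clock
        return time.time() - start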
tests/pytorch/distributed/run_layer_with_overlap.py
...
...
@@ -519,6 +519,7 @@ def _train(opts):
    if opts.use_cuda_graphs:
        del test_graph
    torch.cuda.synchronize()
+    te.module.base.destroy_ub()
    dist_print("Destroying Userbuffers objects...", debug=True)
...
...
tests/pytorch/distributed/test_sanity.py 0 → 100644
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib
import sys

import pytest
import torch

import transformer_engine
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch import TransformerLayer, Linear

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig

model_configs = {
    "small": ModelConfig(2, 10, 2, 16),
}


@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
def test_current_device(model, module):
    """Test cases where the current device is different from the tensor device"""
    num_devices = torch.cuda.device_count()
    assert num_devices > 1, "This test requires more than one GPU!"
    tensor_device = num_devices - 1
    dtype = torch.bfloat16
    config = model_configs[model]

    args = []
    kwargs = {}
    bwd_args = []
    if module == "TransformerLayer":
        model = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            params_dtype=dtype,
            attn_input_format="thd",
            self_attn_mask_type="padding",
            device=f"cuda:{tensor_device}",
        )
        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
        args = [
            torch.randn(
                (num_tokens, config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
        ]
        cu_seqlens_q, cu_seqlens_kv = [
            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
        ]
        kwargs["cu_seqlens_q"] = cu_seqlens_q
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
    if module == "DotProductAttention":
        model = DotProductAttention(
            config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
        )
        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
        args = [
            torch.randn(
                num_tokens,
                config.num_heads,
                config.head_dim_qk,
                dtype=dtype,
                device=tensor_device,
                requires_grad=True,
            )
            for _ in range(3)
        ]
        cu_seqlens_q, cu_seqlens_kv = [
            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
        ]
        kwargs["cu_seqlens_q"] = cu_seqlens_q
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
        bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
    elif module == "Linear":
        model = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            params_dtype=dtype,
            device=f"cuda:{tensor_device}",
        )
        args = [
            torch.randn(
                (config.max_seqlen_q, config.batch_size, config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
        ]

    current_device_before = torch.cuda.current_device()
    out = model(*args, **kwargs)
    if module == "DotProductAttention":
        out.backward(*bwd_args)
    else:
        loss = out.sum()
        loss.backward()
    current_device_after = torch.cuda.current_device()
    tensor_device_out = out.get_device()
    tensor_device_grad = args[0].grad.get_device()

    assert (
        current_device_after == current_device_before
    ), "The current device should not have changed!"
    assert (
        tensor_device_out == tensor_device
    ), "The output tensor should be on the same device as the input tensors!"
    assert (
        tensor_device_grad == tensor_device
    ), "The gradient tensor should be on the same device as the input tensors!"
tests/pytorch/test_cpu_offloading.py
...
...
@@ -10,22 +10,24 @@ import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
+
+from utils import ModelConfig, get_available_attention_backends

# Check if FP8 is supported
-fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

-fp8_recipes = [None]
-if fp8_available:
-    fp8_recipes.append(recipe.Float8CurrentScaling())
-    fp8_recipes.append(recipe.DelayedScaling())
+fp8_recipes = [
+    None,  # non-fp8
+    # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading does not work yet
+    recipe.Float8CurrentScaling(),
+    recipe.DelayedScaling(),
+]

-SIZE = 512
-NUM_HEADS = 8
-NUM_LAYERS = 5
-EPSILON = 0.1
+model_config = {
+    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
+}
+SIZE = model_config["small"].hidden_size
+NUM_HEADS = model_config["small"].num_heads
+NUM_LAYERS = model_config["small"].num_layers
+EPSILON = model_config["small"].eps

# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
...
...
@@ -124,11 +126,17 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
    model_cls = model_types[model_key]
    models_list = [model_cls() for _ in range(NUM_LAYERS)]

    if fp8_recipe and not fp8_available:
        pytest.skip(reason_for_no_fp8)
+    if fp8_recipe is not None:
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+    if model_key in ["multihead_attention", "transformer_layer"]:
+        available_backends, *_ = get_available_attention_backends(
+            model_config["small"],
+            qkv_dtype=torch.bfloat16,
+            qkv_layout="sbhd_sbhd_sbhd",
+        )
+        _, fused_attn_supported, _ = available_backends
+        if not fused_attn_supported:
+            pytest.skip("Fused attention backend not available.")
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True

    without_offloading = _measure_memory_between_forward_and_backward(
        models_list, fp8_recipe, False
...
...
tests/pytorch/test_cuda_graphs.py
...
...
@@ -2,9 +2,7 @@
#
# See LICENSE for license information.
-from dataclasses import dataclass
-import itertools
-from typing import Iterable, List, Tuple, Union
+from typing import Iterable, List, Union

import pytest
import torch
...
...
@@ -23,46 +21,32 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
+from utils import ModelConfig, reset_rng_states
from torch.utils.cpp_extension import IS_HIP_EXTENSION

if IS_HIP_EXTENSION:
    import os
    from functools import cache

# Check if FP8 is supported.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-
-# Record initial RNG state.
-seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
-
-
-@dataclass
-class ModelConfig:
-    """Data tensor dimensions within Transformer model"""
-
-    sequence_length: int
-    batch_size: int
-    hidden_size: int
-    num_heads: int
-    kv_channels: int
-
-
-model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
-
-fp8_recipes = [
-    recipe.DelayedScaling(),
-    recipe.MXFP8BlockScaling(),
-    recipe.Float8CurrentScaling(),
-    recipe.Float8BlockScaling(),
-]
+fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+
+# Reset RNG states.
+reset_rng_states()
+
+model_configs = {
+    "small": ModelConfig(32, 2, 2, 32),
+}
+
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_block_scaling_available:
+    fp8_recipes.append(recipe.Float8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.Float8CurrentScaling())
+    fp8_recipes.append(recipe.DelayedScaling())

# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
...
...
@@ -70,12 +54,6 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


-def reset_rng_states() -> None:
-    """Revert to initial RNG state."""
-    torch.set_rng_state(_cpu_rng_state)
-    torch.cuda.set_rng_state(_cuda_rng_state)
-
-
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
...
...
@@ -119,7 +97,7 @@ def generate_data(
"""Generate synthetic data."""
gen_func
=
torch
.
ones
if
warmup
else
torch
.
randn
return
gen_func
(
model_config
.
seq
u
en
ce_length
,
model_config
.
max_
seq
l
en
_q
,
model_config
.
batch_size
,
model_config
.
hidden_size
,
device
=
"cuda"
,
...
...
@@ -157,10 +135,12 @@ class _Sequential(torch.nn.Sequential):
# Supported modules
_test_cuda_graphs_modules: List[str] = [
+    # Put linear first to test the case where the cuda context might not be set in
+    # creating TMA descriptor for MXFP8 quantization.
+    "linear",
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
-    "linear",
    "mha",
    "linear_op",
]
...
...
@@ -310,35 +290,27 @@ def _test_cuda_graphs(
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
-    fp8: bool,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:

    # Skip invalid configurations.
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
+    fp8 = fp8_recipe is not None
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
-    if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
-    if fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_recipe.float8_block_scaling() and module == "linear_op":
+    if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
    kwargs = dict(
...
...
@@ -351,9 +323,11 @@ def test_make_graphed_callables(
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
-    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
+    # Put graphed callables first to test the case where the cuda context might not be set in
+    # creating TMA descriptor for MXFP8 quantization.
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
+    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
...
...
@@ -369,7 +343,6 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
]


-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
...
...
@@ -385,7 +358,6 @@ def test_make_graphed_callables_with_fp8_weight_caching(
    test_make_graphed_callables(
        module=module,
        dtype=torch.float32,
-        fp8=True,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
...
...
@@ -401,7 +373,7 @@ def generate_data_for_dot_product_attention(
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
-            model_config.sequence_length,
+            model_config.max_seqlen_q,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
...
...
@@ -495,8 +467,8 @@ def _test_cuda_graphs_with_kwargs(
            (
                model_config.batch_size,
                1,
-                model_config.sequence_length,
-                model_config.sequence_length,
+                model_config.max_seqlen_q,
+                model_config.max_seqlen_kv,
            ),
            dtype=torch.bool,
            device="cuda",
...
...
@@ -522,8 +494,8 @@ def _test_cuda_graphs_with_kwargs(
            (
                model_config.batch_size,
                1,
-                model_config.sequence_length,
-                model_config.sequence_length,
+                model_config.max_seqlen_q,
+                model_config.max_seqlen_kv,
            ),
            dtype=torch.bool,
            device="cuda",
...
...
tests/pytorch/test_float8blockwisetensor.py
...
...
@@ -223,7 +223,7 @@ class TestFloat8BlockwiseTensor:
            rowwise=True,
            columnwise=dq_columnwise,
            block_scaling_dim=block_scaling_dim,
-            all_gather_usage=True,
+            all_gather_usage=(block_scaling_dim == 1),
        )
        self._test_quantize_dequantize(
            quantizer=quantizer,
...
...
tests/pytorch/test_fused_optimizer.py
...
...
@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
-from itertools import product
import copy
from contextlib import nullcontext
...
...
@@ -112,13 +111,6 @@ class TestFusedAdam(TestFusedOptimizer):
    def test_bfloat16(self):
        self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True)

-    @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
-    def test_multi_device(self):
-        devices = ("cuda:0", "cuda:1")
-        for current_dev, tensor_dev in product(devices, devices):
-            with torch.cuda.device(current_dev):
-                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
-
    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
...
...
@@ -530,13 +522,6 @@ class TestFusedSGD(TestFusedOptimizer):
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

-    @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
-    def test_multi_device(self):
-        devices = ("cuda:0", "cuda:1")
-        for current_dev, tensor_dev in product(devices, devices):
-            with torch.cuda.device(current_dev):
-                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
-

class Model(torch.nn.Module):
    def __init__(self):
...
...
tests/pytorch/test_fused_router.py
...
...
@@ -2,8 +2,7 @@
#
# See LICENSE for license information.
import torch
-import math
-from typing import Optional, Dict
+from typing import Optional
from transformer_engine.pytorch.router import (
    fused_topk_with_score_function,
    fused_compute_score_for_moe_aux_loss,
...
...
@@ -149,11 +148,21 @@ def run_comparison(
    # Set some parameters
    if score_function == "sigmoid":
        # Construct the special logits to avoid inf in the sigmoid function
-        offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
-        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
+        offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
+        logits = (
+            torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
+        )
        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    else:
-        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
+        logits = (
+            torch.arange(
+                -num_tokens * num_experts // 2,
+                num_tokens * num_experts // 2,
+                device="cuda",
+                dtype=dtype,
+            )
+            * 1e-4
+        )
        logits = logits.view(num_tokens, num_experts)
    logits.requires_grad = True

    if enable_bias and score_function == "sigmoid":
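
A plausible motivation for the switch to symmetric ranges (an inference from the surrounding comment, not stated in the diff): float16 saturates above 65504, so a long strictly increasing sequence of raw values overflows at the top end before any downscaling, while a zero-centered sequence with the same spread halves the largest magnitude produced:

    import torch

    top_heavy = torch.tensor([60000.0], dtype=torch.float16) * 2
    print(top_heavy)   # tensor([inf], dtype=torch.float16): 120000 exceeds the fp16 max of 65504
    symmetric = torch.tensor([-30000.0, 30000.0], dtype=torch.float16) * 2
    print(symmetric)   # tensor([-60000., 60000.], dtype=torch.float16): still finite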
...
...
@@ -282,11 +291,21 @@ def test_topk_softmax(
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
    if score_function == "sigmoid":
        # Construct the special logits to avoid inf in the sigmoid function
-        offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
-        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
+        offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
+        logits = (
+            torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
+        )
        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    else:
-        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
+        logits = (
+            torch.arange(
+                -num_tokens * num_experts // 2,
+                num_tokens * num_experts // 2,
+                device="cuda",
+                dtype=dtype,
+            )
+            * 1e-4
+        )
        logits = logits.view(num_tokens, num_experts)
    logits.requires_grad = True
...
...
@@ -322,8 +341,8 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
@pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
    # Construct the special probs to avoid inf in the sigmoid function
-    offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
-    probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
+    offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
+    probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
    probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    probs = probs.view(num_tokens, num_experts)
    probs.requires_grad = True
...
...
@@ -380,15 +399,12 @@ def profile_topk_softmax(
if __name__ == "__main__":
-    test_fused_scores_for_aux_loss(
-        dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax"
-    )
+    test_topk_softmax(
+        dtype=torch.float32,
+        num_tokens=1024,
+        num_experts=128,
+        topk=4,
+        use_pre_softmax=False,
+        group_topk=None,
+        scaling_factor=None,
+    )
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
tests/pytorch/test_fusible_ops.py
...
...
@@ -21,10 +21,12 @@ import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
-    BackwardBiasActivation,
+    BackwardActivationBias,
    BackwardLinearAdd,
+    BackwardLinearScale,
    ForwardLinearBiasActivation,
    ForwardLinearBiasAdd,
+    ForwardLinearScaleAdd,
)
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
...
...
@@ -39,7 +41,7 @@ import transformer_engine_torch as tex
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
-from utils import dtype_tols, make_recipe
+from utils import dtype_tols, make_recipe, reset_rng_states

if IS_HIP_EXTENSION:
    import os
...
...
@@ -271,16 +273,72 @@ class TestSequentialContainer:
         model(torch.zeros(1))
         assert len(model._module_groups) == 6

+    def test_extra_tensors(self, size: int = 16) -> None:
+        """Check that extra inputs are distributed properly between module groups
+        and that extra outputs are properly collected"""
+
+        # Construct sequential container
+        bias = te_ops.Bias(size=size, device="cpu")
+        with torch.no_grad():
+            bias.bias.copy_(torch.rand((size,)))
+        model = te_ops.Sequential(
+            #                                         | Inputs  | Outputs
+            torch.nn.Identity(),                    # | x1      | x1
+            te_ops.MakeExtraOutput(in_place=True),  # | x1      | x1 [x1]
+            bias,                                   # | x1      | h1 (= x1 + b)
+            te_ops.MakeExtraOutput(in_place=True),  # | h1      | h1 [h1]
+            te_ops.AddExtraInput(in_place=True),    # | h1 [x2] | x2 (= x2 + h1)
+            te_ops.MakeExtraOutput(in_place=True),  # | x2      | x2 [x2]
+            torch.nn.Identity(),                    # | x2      | x2
+            bias,                                   # | x2      | h2 (= x2 + b)
+            te_ops.AddExtraInput(in_place=True),    # | h2 [x3] | x3 (= x3 + h2)
+            te_ops.MakeExtraOutput(in_place=True),  # | x3      | x3 [x3]
+            te_ops.AddExtraInput(in_place=True),    # | x3 [x4] | x4 (= x4 + x3)
+            torch.nn.Identity(),                    # | x4      | x4
+            te_ops.Identity(),                      # | x4      | x4
+            te_ops.MakeExtraOutput(in_place=True),  # | x4      | x4 [x4]
+            te_ops.Identity(),                      # | x4      | x4
+        )
+
+        # Create input tensors
+        x1 = torch.rand((size,))
+        x2 = torch.rand((size,))
+        x3 = torch.rand((size,))
+        x4 = torch.rand((size,))
+
+        # Save original input tensor values
+        x1_orig = x1.clone()
+        x2_orig = x2.clone()
+        x3_orig = x3.clone()
+        x4_orig = x4.clone()
+
+        # Run forward
+        ys = model(x1, x2, x3, x4)
+
+        # Check whether outputs match (x4, x1, h1, x2, x3, x4)
+        assert len(ys) == 6
+        assert ys[0].data_ptr() == x4.data_ptr()
+        assert ys[1].data_ptr() == x1.data_ptr()
+        assert ys[2].data_ptr() not in [x.data_ptr() for x in (x1, x2, x3, x4)]
+        assert ys[3].data_ptr() == x2.data_ptr()
+        assert ys[4].data_ptr() == x3.data_ptr()
+        assert ys[5].data_ptr() == x4.data_ptr()
+
+        # Check whether tensors have correct values
+        b = bias.bias
+        h1 = ys[2]
+        torch.testing.assert_close(x1, x1_orig)
+        torch.testing.assert_close(h1, x1_orig + b)
+        torch.testing.assert_close(x2, x2_orig + h1)
+        torch.testing.assert_close(x3, x3_orig + x2 + b)
+        torch.testing.assert_close(x4, x4_orig + x3)
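Note: for orientation, the extra-tensor mechanics exercised by test_extra_tensors reduce to a small sketch. This is illustrative only, assuming (as the test does) that these ops accept CPU tensors and that te_ops.Sequential returns the main output first, followed by extra outputs in order:

import torch
import transformer_engine.pytorch.ops as te_ops

# Minimal extra-input/extra-output plumbing (out-of-place, CPU).
x1, x2 = torch.rand(16), torch.rand(16)
model = te_ops.Sequential(
    te_ops.AddExtraInput(in_place=False),    # consumes x2 as an extra input
    te_ops.MakeExtraOutput(in_place=False),  # emits its input as an extra output
)
y, y_extra = model(x1, x2)  # main output first, then the extra output
torch.testing.assert_close(y, x1 + x2)
torch.testing.assert_close(y_extra, x1 + x2)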
 class TestFuser:
     """Tests for operation fusion infrastructure"""

     @staticmethod
     def setup_class(cls) -> None:
-        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()

     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_scale_update(
...
...
@@ -494,10 +552,7 @@ class TestBasicOps:

     @staticmethod
     def setup_class(cls) -> None:
-        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()

     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
...
...
@@ -795,10 +850,9 @@ class TestBasicOps:
             pytest.skip("FP8 output is only supported with FP8 GEMMs")
         if quantized_grad_input and not quantized_compute:
             pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
-        if quantization == "mxfp8" and quantized_output:
-            pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
-        if quantization == "mxfp8" and quantized_grad_input:
-            pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
+        if quantization not in (None, "fp8"):
+            if quantized_output or quantized_grad_input:
+                pytest.skip("Recipe does not support quantized GEMM output")
         if (
             IS_HIP_EXTENSION
             and not use_hipblaslt()
             and accumulate_into_main_grad
             and dtype != torch.float32
             and not quantized_compute
         ):
             pytest.skip("Parameters combination is not supported by ROCBLAS")
...
...
@@ -1353,18 +1407,17 @@ class TestBasicOps:
         dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(y_test, y_ref, **tols)
         # L2Norm backward pass requires slightly looser atol for bfloat16
         if dtype == torch.bfloat16:
             tols["atol"] = 2e-3
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

+    @pytest.mark.parametrize("in_place", (True, False))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
     @pytest.mark.parametrize("quantization", _quantization_list)
-    def test_add_in_place(
+    def test_add_extra_input(
         self,
         *,
         in_shape: Iterable[int] = (32, 32),
+        in_place: bool,
         dtype: torch.dtype,
         device: torch.device,
         quantization: Optional[str],
...
...
@@ -1410,7 +1463,7 @@ class TestBasicOps:
         dx2_ref = dy_ref

         # Implementation with fusible operation
-        op = te_ops.AddInPlace()
+        op = te_ops.AddExtraInput(in_place=in_place)
         y_test = op(x1_test, x2_test)
         y_test.backward(dy_test)
...
...
@@ -1425,6 +1478,7 @@ class TestBasicOps:
         torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
         torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)

+    @pytest.mark.parametrize("in_place", (True, False))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
     @pytest.mark.parametrize("quantization", _quantization_list)
...
...
@@ -1432,6 +1486,7 @@ class TestBasicOps:
         self,
         *,
         in_shape: Iterable[int] = (32, 32),
+        in_place: bool,
         dtype: torch.dtype,
         device: torch.device,
         quantization: Optional[str],
...
...
@@ -1477,7 +1532,7 @@ class TestBasicOps:
         (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()

         # Implementation with fusible operation
-        op = te_ops.MakeExtraOutput()
+        op = te_ops.MakeExtraOutput(in_place=in_place)
         y1_test, y2_test = op(x_test)
         (y1_test * dy1_test + y2_test * dy2_test).sum().backward()
...
...
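Note: the in_place parametrization above leaves the ops' gradient contract unchanged. As a reminder of that contract, a minimal out-of-place sketch (CPU tensors assumed to be accepted, as in the parametrized tests): MakeExtraOutput sums the gradients flowing into its two outputs, and AddExtraInput routes the incoming gradient to both addends.

import torch
import transformer_engine.pytorch.ops as te_ops

# MakeExtraOutput: dx = dy1 + dy2 when both outputs are consumed.
x = torch.rand(8, requires_grad=True)
y1, y2 = te_ops.MakeExtraOutput(in_place=False)(x)
(2.0 * y1 + 3.0 * y2).sum().backward()
torch.testing.assert_close(x.grad, torch.full_like(x, 5.0))  # 2 + 3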
@@ -1645,16 +1700,107 @@ class TestBasicOps:
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

+    @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
+    @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("device", _devices)
+    def test_constant_scale(
+        self,
+        *,
+        scale: float,
+        shape: Iterable[int],
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        y_ref = scale * x_ref
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operation
+        op = te_ops.ConstantScale(scale)
+        y_test = op(x_test)
+        y_test.backward(dy_test)
+
+        # Check results
+        tols = dtype_tols(dtype)
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
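Note: the new test_constant_scale pins down the op's contract: ConstantScale(scale) is a pure elementwise scale, so its backward is dx = scale * dy. A minimal sketch, assuming CPU tensors are accepted as the _devices parametrization suggests:

import torch
import transformer_engine.pytorch.ops as te_ops

# ConstantScale forward y = scale * x implies backward dx = scale * dy.
x = torch.randn(4, 4, requires_grad=True)
y = te_ops.ConstantScale(-2.5)(x)
y.backward(torch.ones_like(y))
torch.testing.assert_close(x.grad, torch.full_like(x, -2.5))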
+    @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
+    @pytest.mark.parametrize("is_training", (True, False))
+    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    def test_dropout(
+        self,
+        *,
+        prob: float,
+        is_training: bool,
+        shape: Iterable[int],
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+    ):
+
+        # Random data
+        x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
+        x_test = x_ref.clone().requires_grad_()
+        dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
+        dy_test = dy_ref.clone()
+
+        # Apply dropout
+        op = te_ops.Dropout(prob)
+        if is_training:
+            op.train()
+        else:
+            op.eval()
+        y = op(x_test)
+        y.backward(dy_test)
+
+        # Check values
+        if is_training:
+            mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
+            torch.testing.assert_close(y, x_ref * mask)
+            torch.testing.assert_close(x_test.grad, dy_ref * mask)
+        else:
+            torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
+            torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
+
+        # Hypothesis testing for number of zeros
+        # Note: A Bernoulli random variable with probability p has
+        # mean p and standard deviation sqrt(p*(1-p)). By the central
+        # limit theorem, the mean of n iid Bernoulli variables
+        # converges to a normal random variable with mean p and
+        # standard deviation sqrt(p*(1-p)/n). If the observed mean is
+        # below the 0.5th or above the 99.5th percentiles, then the
+        # p-value is less than 1% and we assume that the dropout
+        # distribution is incorrect.
+        if is_training:
+            prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
+            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
+            assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
 class TestFusedOps:
     """Tests for fused operations"""

     @staticmethod
     def setup_class(cls) -> None:
-        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()

     @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
     @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
...
...
@@ -1841,7 +1987,7 @@ class TestFusedOps:
                 device=device,
                 dtype=dtype,
             ),
-            te_ops.AddInPlace(),
+            te_ops.AddExtraInput(in_place=True),
         )
         with torch.no_grad():
             model[0].weight.copy_(w_test)
...
...
@@ -1878,11 +2024,114 @@ class TestFusedOps:
         db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(db_test, b_ref.grad, **tols)

+    @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_forward_linear_scale_add(
+        self,
+        *,
+        scale: float,
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Iterable[int] = (32, -1),
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantized_weight: bool = False,
+    ) -> None:
+        """Forward GEMM + scale + add"""
+
+        # Make input and weight shapes consistent
+        out_features, in_features = weight_shape
+        in_shape = list(in_shape)[:-1] + [in_features]
+        out_shape = in_shape[:-1] + [out_features]
+
+        # Skip invalid configurations
+        quantized_compute = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=out_shape)
+        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
+            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
+
+        # Random data
+        x1_ref, x1_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w_ref, w_test = make_reference_and_test_tensors(
+            (out_features, in_features),
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        x2_ref, x2_test = make_reference_and_test_tensors(
+            out_shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            out_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        y_ref = torch.nn.functional.linear(x1_ref, w_ref) * scale + x2_ref
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operations
+        recipe = make_recipe(quantization)
+        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+            model = te_ops.Sequential(
+                te_ops.Linear(
+                    in_features,
+                    out_features,
+                    bias=False,
+                    device=device,
+                    dtype=dtype,
+                ),
+                te_ops.ConstantScale(scale),
+                te_ops.AddExtraInput(in_place=True),
+                te_ops.Quantize(),
+            )
+        with torch.no_grad():
+            model[0].weight.copy_(w_test)
+            del w_test
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = model(x1_test, x2_test)
+        y_test.backward(dy_test)
+
+        # Check that forward operations have been fused
+        forward_ops = model._module_groups[0]._forward_ops
+        assert len(forward_ops) == 2
+        assert isinstance(forward_ops[0][0], ForwardLinearScaleAdd)
+        assert isinstance(forward_ops[1][0], te_ops.Quantize)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if dtype == torch.float32:
+            tols = dtype_tols(torch.float16)  # TF32 GEMM
+        if quantized_compute:
+            tols = dtype_tols(tex.DType.kFloat8E4M3)
+
+        # Check results
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
+        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
+        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
+        torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
+        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
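Note: the fused ForwardLinearScaleAdd path above must reproduce the plain-PyTorch reference y = linear(x1, w) * scale + x2. A small sanity sketch of the gradients that reference implies (pure PyTorch, no TE; sizes illustrative):

import torch

# For y = x1 @ W.T * scale + x2: dy/dx2 is the identity, dx1 = scale * dy @ W.
x1 = torch.randn(32, 32, requires_grad=True)
x2 = torch.randn(32, 32, requires_grad=True)
W = torch.randn(32, 32)
scale = 3.5
y = torch.nn.functional.linear(x1, W) * scale + x2
dy = torch.randn_like(y)
y.backward(dy)
torch.testing.assert_close(x2.grad, dy)
torch.testing.assert_close(x1.grad, scale * dy @ W)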
     @pytest.mark.parametrize("activation", ("relu", "gelu"))
     @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
-    def test_backward_bias_activation(
+    def test_backward_activation_bias(
         self,
         *,
         activation: str,
...
...
@@ -1891,7 +2140,7 @@ class TestFusedOps:
         device: torch.device = "cuda",
         quantization: Optional[str],
     ) -> None:
-        """Backward dbias + dact + quantize"""
+        """Backward dact + dbias + quantize"""

         # Tensor dimensions
         in_shape = list(out_shape)
...
...
@@ -1948,9 +2197,9 @@ class TestFusedOps:
         # Check that backward operations have been fused
         backward_ops = model._module_groups[0]._backward_ops
-        if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
+        if with_quantization:
             assert len(backward_ops) == 2
-            assert isinstance(backward_ops[0][0], BackwardBiasActivation)
+            assert isinstance(backward_ops[0][0], BackwardActivationBias)
             assert isinstance(backward_ops[1][0], te_ops.Quantize)
         else:
             assert len(backward_ops) == 3
...
...
@@ -1963,6 +2212,7 @@ class TestFusedOps:
         if with_quantization:
             tols = dtype_tols(tex.DType.kFloat8E4M3)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
         db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
...
...
@@ -2033,7 +2283,7 @@ class TestFusedOps:
         recipe = make_recipe(quantization)
         with te.fp8_model_init(enabled=quantized_weight):
             model = te_ops.Sequential(
-                te_ops.MakeExtraOutput(),
+                te_ops.MakeExtraOutput(in_place=True),
                 te_ops.Linear(
                     in_features,
                     out_features,
...
...
@@ -2071,16 +2321,106 @@ class TestFusedOps:
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
         torch.testing.assert_close(dw_test, w_ref.grad, **tols)

+    @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_backward_linear_scale(
+        self,
+        *,
+        scale: float,
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Iterable[int] = (32, -1),
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantized_weight: bool = False,
+    ) -> None:
+        """Backward dgrad GEMM + scale"""
+
+        # Make input and weight shapes consistent
+        out_features, in_features = weight_shape
+        in_shape = list(in_shape)[:-1] + [in_features]
+        out_shape = in_shape[:-1] + [out_features]
+
+        # Skip invalid configurations
+        quantized_compute = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=out_shape)
+        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
+            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w_ref, w_test = make_reference_and_test_tensors(
+            (out_features, in_features),
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            out_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        y_ref = torch.nn.functional.linear(x_ref, w_ref) * scale
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operations
+        recipe = make_recipe(quantization)
+        with te.fp8_model_init(enabled=quantized_weight):
+            model = te_ops.Sequential(
+                te_ops.Linear(
+                    in_features,
+                    out_features,
+                    bias=False,
+                    device=device,
+                    dtype=dtype,
+                ),
+                te_ops.ConstantScale(scale),
+            )
+        with torch.no_grad():
+            model[0].weight.copy_(w_test)
+            del w_test
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = model(x_test)
+        (y_test * dy_test).sum().backward()
+
+        # Check that backward operations have been fused
+        backward_ops = model._module_groups[0]._backward_ops
+        assert len(backward_ops) == 1
+        assert isinstance(backward_ops[0][0], BackwardLinearScale)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if dtype == torch.float32:
+            tols = dtype_tols(torch.float16)  # TF32 GEMM
+        if quantized_compute:
+            tols = dtype_tols(tex.DType.kFloat8E4M3)
+
+        # Check results
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
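Note: likewise, the fused BackwardLinearScale dgrad must match the identity behind the reference above: for y = linear(x, w) * scale, dx = scale * dy @ w. A plain-PyTorch check (no TE; sizes illustrative):

import torch

# dgrad identity for a scaled linear layer: dx = scale * dy @ W.
x = torch.randn(4, 8, requires_grad=True)
W = torch.randn(6, 8)
scale = -2.5
y = torch.nn.functional.linear(x, W) * scale
dy = torch.randn_like(y)
y.backward(dy)
torch.testing.assert_close(x.grad, scale * dy @ W)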
 class TestCheckpointing:
     """Tests for checkpointing"""

     @staticmethod
     def setup_class(cls) -> None:
-        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()

     @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("quantized_weight", (False, True))
...
...
@@ -2192,11 +2532,9 @@ class TestSequentialModules:

     @staticmethod
     def setup_class(cls) -> None:
-        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()

+    @pytest.mark.parametrize("requires_grad", (False, True))
     @pytest.mark.parametrize("bias", (False, True))
     @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
     @pytest.mark.parametrize("quantized_compute", (False, True))
...
...
@@ -2206,6 +2544,7 @@ class TestSequentialModules:
     def test_layernorm_mlp(
         self,
         *,
+        requires_grad: bool,
         bias: bool,
         normalization: str,
         quantized_compute: bool,
...
...
@@ -2246,6 +2585,7 @@ class TestSequentialModules:
             quantization=quantization,
             test_dtype=dtype,
             test_device=device,
+            requires_grad=requires_grad,
         )
         _, dy_test = make_reference_and_test_tensors(
             in_shape,
...
...
tests/pytorch/test_hf_integration.py
View file @ 87e3e56e
...
...
@@ -7,7 +7,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 from transformer_engine.pytorch.transformer import TransformerLayer
-from transformer_engine.pytorch.utils import is_bf16_compatible


 class SimpleTEModel(PreTrainedModel):
...
...