OpenDAS / TransformerEngine, commit 87e3e56e

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
Authored Aug 27, 2025 by yuguo
Parents: 2f11bd2e, 734bcedd

Changes: 217 files in the merge; showing 20 changed files with 1690 additions and 748 deletions (+1690 / -748)
Changed files shown:
  tests/jax/utils.py  (+6, -4)
  tests/pytorch/attention/run_attention_with_cp.py  (+1, -1)
  tests/pytorch/attention/test_attention.py  (+538, -448)
  tests/pytorch/attention/test_attention_with_cp.py  (+57, -38)
  tests/pytorch/attention/test_kv_cache.py  (+15, -20)
  tests/pytorch/debug/run_distributed.py  (+34, -0)
  tests/pytorch/debug/test_api_features.py  (+81, -79)
  tests/pytorch/debug/test_configs/log_config.yaml  (+19, -0)
  tests/pytorch/debug/test_configs/perf_config.yaml  (+13, -0)
  tests/pytorch/debug/test_log.py  (+250, -0)
  tests/pytorch/debug/test_perf.py  (+76, -0)
  tests/pytorch/distributed/run_layer_with_overlap.py  (+1, -0)
  tests/pytorch/distributed/test_sanity.py  (+121, -0)
  tests/pytorch/test_cpu_offloading.py  (+25, -17)
  tests/pytorch/test_cuda_graphs.py  (+37, -65)
  tests/pytorch/test_float8blockwisetensor.py  (+1, -1)
  tests/pytorch/test_fused_optimizer.py  (+0, -15)
  tests/pytorch/test_fused_router.py  (+37, -21)
  tests/pytorch/test_fusible_ops.py  (+378, -38)
  tests/pytorch/test_hf_integration.py  (+0, -1)
tests/jax/utils.py    View file @ 87e3e56e

@@ -1604,16 +1604,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
 @contextmanager
 def use_jax_gemm(enabled=False):
-    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
+    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None)
     try:
         if enabled:
-            os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
+            os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"
+        else:
+            os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true"
         yield
     finally:
         if enabled:
             if orig_custom_calls_filter is None:
-                os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
+                os.environ.pop("NVTE_JAX_CUSTOM_CALLS")
             else:
-                os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
+                os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter
tests/pytorch/fused_attn/run_fused_attn_with_cp.py → tests/pytorch/attention/run_attention_with_cp.py    View file @ 87e3e56e

@@ -13,7 +13,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
     get_cu_seqlens_on_cp_rank,
 )
 import transformer_engine_torch as tex
-from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
+from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
 from transformer_engine.common.recipe import DelayedScaling
tests/pytorch/fused_attn/test_fused_attn.py → tests/pytorch/attention/test_attention.py    View file @ 87e3e56e

@@ -4,12 +4,12 @@
 import logging
 import math
 import os
-from typing import Any, Dict, List, Tuple, Union, Optional
+import sys
+from contextlib import contextmanager
+import pathlib
+from typing import Any, Dict, Tuple, Union

 import pytest
 import torch
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from transformer_engine.common import recipe
 from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
@@ -20,11 +20,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import (
 from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
 from transformer_engine.pytorch.attention.dot_product_attention.utils import (
     FlashAttentionUtils,
-    get_attention_backend,
     check_set_window_size,
-    AttentionParams,
 )
-from transformer_engine.pytorch.attention import InferenceParams
 from transformer_engine.pytorch.attention import RotaryPositionEmbedding
 import transformer_engine.pytorch.cpp_extensions as ext
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
@@ -49,21 +46,21 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
     restore_from_saved,
 )

+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import (
+    reset_rng_states,
+    ModelConfig,
+    dtype_tols,
+    get_available_attention_backends,
+)
+
 # Only run FP8 tests on H100
 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()

-# Initialize RNG state
 seed = 1234
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
-
-def reset_rng_states() -> None:
-    """Revert back to initial RNG state"""
-    torch.set_rng_state(_cpu_rng_state)
-    torch.cuda.set_rng_state(_cuda_rng_state)
+# Reset RNG states
+reset_rng_states()

 @pytest.fixture(autouse=True)
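For orientation (an inference from the paths in this commit, not text taken from the diff): with the test file now at tests/pytorch/attention/test_attention.py, _current_file.parent.parent resolves to tests/pytorch/, so the from utils import ... line above picks up the shared tests/pytorch/utils.py helpers. A tiny self-contained illustration of that path arithmetic:

    import pathlib
    p = pathlib.Path("tests/pytorch/attention/test_attention.py").resolve()
    print(p.parent.parent)   # .../tests/pytorch, the directory assumed to contain utils.py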
@@ -72,170 +69,20 @@ def reset_global_fp8_state():
     fp8.FP8GlobalStateManager.reset()

-class ModelConfig:
-    def __init__(
-        self,
-        batch_size: int,
-        num_heads: int,
-        num_gqa_groups: int,
-        head_dim_qk: int,
-        max_seqlen_q: int,
-        max_seqlen_kv: int,
-        dropout_p: float,
-        attn_mask_type: str,
-        attn_bias_type: str,
-        head_dim_v: int = None,
-        alibi_type: str = "none",
-        num_layers: int = 1,
-        bias_shape: str = "1hss",
-        window_size: Tuple[int, int] = (-1, -1),
-        total_requests: int = None,
-        max_ctx_len: int = None,
-    ):
-        self.batch_size = batch_size
-        self.num_heads = num_heads
-        self.num_gqa_groups = num_gqa_groups
-        self.head_dim_qk = head_dim_qk
-        self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
-        self.hidden_size = num_heads * head_dim_qk
-        self.hidden_size_kv = num_gqa_groups * self.head_dim_v
-        self.max_seqlen_q = max_seqlen_q
-        self.max_seqlen_kv = max_seqlen_kv
-        self.dropout_p = dropout_p
-        self.attn_mask_type = attn_mask_type
-        self.attn_bias_type = attn_bias_type
-        self.alibi_type = alibi_type
-        self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
-        self.num_layers = num_layers
-        self.bias_shape = bias_shape
-        self.window_size = window_size
-        self.total_requests = total_requests
-        self.max_ctx_len = max_ctx_len
-
-@contextmanager
-def logging_context(highest_level=logging.WARNING):
-    previous_level = logging.root.manager.disable
-    logging.disable(highest_level)
-    try:
-        yield
-    finally:
-        logging.disable(previous_level)
-
-def _get_attention_backends(
-    config: ModelConfig,
-    qkv_dtype: torch.dtype,
-    qkv_layout: str,
-    window_size: Tuple[int, int] = (-1, -1),
-    pad_between_seqs: bool = False,
-    context_parallel: bool = False,
-    deterministic: bool = False,
-    fp8: bool = False,
-    fp8_meta: Optional[Dict[str, Any]] = None,
-    is_training: bool = True,
-    inference_params: Optional[InferenceParams] = None,
-) -> Tuple[List, List]:
-    """Check if what attention backends support a model configuration"""
-    os.environ["NVTE_FLASH_ATTN"] = "1"
-    os.environ["NVTE_FUSED_ATTN"] = "1"
-    os.environ["NVTE_UNFUSED_ATTN"] = "1"
-    _attention_backends["backend_selection_requires_update"] = True
-    alibi_slopes_shape = None
-    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
-        if config.bias_shape == "1hss":
-            alibi_slopes_shape = [config.num_heads]
-        if config.bias_shape == "bhss":
-            alibi_slopes_shape = [config.batch_size, config.num_heads]
-    core_attention_bias_shape = (
-        config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
-    )
-    core_attention_bias_requires_grad = False
-    # d=256 is supported by cuDNN 9.0+ for inference but not training
-    if (
-        config.attn_bias_type == "post_scale_bias"
-        and config.head_dim_qk <= 128
-        and config.head_dim_v <= 128
-    ):
-        core_attention_bias_requires_grad = True
-    fused_attn_backends = []
-    available_backends = None
-    flash_attention_backend = None
-    fused_attention_backend = None
-
-    def test():
-        attention_params = AttentionParams(
-            qkv_dtype=qkv_dtype, qkv_layout=qkv_layout, batch_size=config.batch_size,
-            num_heads=config.num_heads, num_gqa_groups=config.num_gqa_groups,
-            max_seqlen_q=config.max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv,
-            head_dim_qk=config.head_dim_qk, head_dim_v=config.head_dim_v,
-            attn_mask_type=config.attn_mask_type, window_size=window_size,
-            alibi_slopes_shape=alibi_slopes_shape,
-            core_attention_bias_type=config.attn_bias_type,
-            core_attention_bias_shape=core_attention_bias_shape,
-            core_attention_bias_requires_grad=core_attention_bias_requires_grad,
-            pad_between_seqs=pad_between_seqs, attention_dropout=config.dropout_p,
-            context_parallel=context_parallel, deterministic=deterministic,
-            fp8=fp8, fp8_meta=fp8_meta, is_training=is_training,
-            inference_params=inference_params,
-        )
-        (
-            use_flash_attention, use_fused_attention, flash_attention_backend,
-            fused_attention_backend, use_unfused_attention, available_backends,
-        ) = get_attention_backend(attention_params)
-        # Set attention.py _attention_backends var using return value
-        # from get_attention_backend()
-        _attention_backends["use_flash_attention"] = use_flash_attention
-        _attention_backends["use_fused_attention"] = use_fused_attention
-        _attention_backends["flash_attention_backend"] = flash_attention_backend
-        _attention_backends["fused_attention_backend"] = fused_attention_backend
-        _attention_backends["use_unfused_attention"] = use_unfused_attention
-        _attention_backends["backend_selection_requires_update"] = False
-        return available_backends, flash_attention_backend, fused_attention_backend
-
-    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
-    with logging_context():
-        for i in range(3):
-            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
-            _attention_backends["backend_selection_requires_update"] = True
-            available_backends, flash_attention_backend, fused_attention_backend = test()
-            if fused_attention_backend == FusedAttnBackend[backends[i]]:
-                fused_attn_backends.append(fused_attention_backend)
-    return available_backends, flash_attention_backend, fused_attn_backends
-
 model_configs_base = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
-    "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
-    "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
-    "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"),
-    "base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"),
-    "base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"),
-    "base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"),
+    "base_1_0": ModelConfig(8, 128, 16, 64),
+    "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
+    "base_2_0": ModelConfig(2, 2048, 24, 128),
+    "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
+    "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
+    "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
+    "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
+    "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
+    "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
+    "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
+    "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
+    "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
 }
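The config conversions above (and in the remaining hunks of this file) all follow one pattern: the old positional signature (batch, heads, gqa groups, head dim, seqlen q, seqlen kv, dropout, mask, bias) becomes a short call whose positional arguments appear to be (batch, seqlen q, heads, head dim), with everything else keyword-only and defaulted. The sketch below is my reading of the new tests/pytorch/utils.ModelConfig interface inferred from these calls, not code from the commit:

    # Sketch only: assumed defaults are num_gqa_groups=num_heads, max_seqlen_kv=max_seqlen_q,
    # dropout_p=0.0, attn_mask_type="no_mask", attn_bias_type="no_bias".
    from dataclasses import dataclass

    @dataclass
    class ModelConfigSketch:
        batch_size: int
        max_seqlen_q: int
        num_heads: int
        head_dim_qk: int
        num_gqa_groups: int = None
        max_seqlen_kv: int = None
        dropout_p: float = 0.0
        attn_mask_type: str = "no_mask"
        attn_bias_type: str = "no_bias"

        def __post_init__(self):
            if self.num_gqa_groups is None:
                self.num_gqa_groups = self.num_heads
            if self.max_seqlen_kv is None:
                self.max_seqlen_kv = self.max_seqlen_q

    # Old "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias")
    # maps onto the new call below under these assumptions:
    cfg = ModelConfigSketch(1, 2048, 24, 128, max_seqlen_kv=4096)
    assert (cfg.num_heads, cfg.num_gqa_groups, cfg.head_dim_qk) == (24, 24, 128)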
@@ -279,7 +126,7 @@ def test_dot_product_attention(
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)

     is_training = True
-    available_backends, _, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
@@ -290,7 +137,7 @@ def test_dot_product_attention(
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported:
         is_training = False
-        available_backends, _, fused_attn_backends = _get_attention_backends(
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
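The probe-then-downgrade pattern above recurs throughout the updated tests. A self-contained sketch of that control flow (the helper name and its 3-tuple return shape are taken from the hunks; the stub body is my assumption, not the real tests/pytorch/utils.py implementation):

    # Stand-in for get_available_attention_backends: returns
    # (available_backends, flash_backend, fused_backends), where available_backends
    # unpacks as (flash, fused, unfused) support flags for the given config.
    def get_available_attention_backends_stub(config, qkv_dtype, qkv_layout, is_training=True):
        return (False, not is_training, True), None, []

    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends_stub(
        None, qkv_dtype="bf16", qkv_layout="sb3hd", is_training=is_training)
    flash_ok, fused_ok, unfused_ok = available_backends
    if not fused_ok:
        # no fused-attention training support: re-probe for inference only, as the tests do
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends_stub(
            None, qkv_dtype="bf16", qkv_layout="sb3hd", is_training=is_training)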
@@ -411,62 +258,26 @@ def test_dpa_checkpoint(dtype, model_configs, model):
     """Test DotProductAttention module with checkpointing"""
     test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)

-if IS_HIP_EXTENSION:
-    model_configs_mla = {
-        # test: b, h, hg, dqk, sq, skv, p, mask, bias  # attn , backend
-        "mla_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128),  # self , 0
-        "mla_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128),  # cross, 0
-        "mla_1_2": ModelConfig(4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128),  # cross, 0
-        "mla_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64),  # self , 1
-        "mla_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64),  # cross, 1
-        "mla_2_2": ModelConfig(1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128),  # cross, 1
-    }
-else:
-    model_configs_mla = {
-        # test: b, h, hg, dqk, sq, skv, p, mask, bias  # attn , backend
-        ...the same six "mla_1_*"/"mla_2_*" entries as above, plus...
-        "mla_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64),  # inference
-        "mla_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128),  # inference
-        "mla_3_2": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128),  # inference
-    }
+model_configs_mla = {
+    # test: b, h, hg, dqk, sq, skv, p, mask, bias  # attn , backend
+    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
+    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
+    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
+    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
+    "mla_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64),  # cross, 1
+    "mla_2_2": ModelConfig(1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128),  # cross, 1
+    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
+    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
+}
-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(
+    not IS_HIP_EXTENSION and get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required."
+)
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("model_configs", [model_configs_mla])
 @pytest.mark.parametrize("model", model_configs_mla.keys())
@@ -477,40 +288,46 @@ def test_dpa_mla(dtype, model_configs, model):
 model_configs_mask = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
-    "mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
-    ...the remaining mask_* entries in the same old positional form...
+    "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
+    "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
+    "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
+    "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
+    "mask_2_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"),
+    "mask_2_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"),
+    "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
+    "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
+    "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
+    "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
+    "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
+    "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
+    "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
+    "mask_5_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"),
+    "mask_5_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"),
+    "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
+    "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
+    "mask_7_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"),
+    "mask_7_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"),
+    "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
+    "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
+    "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
+    "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
+    "mask_10_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"),
+    "mask_10_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"),
 }
@@ -526,44 +343,102 @@ def test_dpa_mask(dtype, model_configs, model):
 model_configs_bias = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
-    "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
-    ...the remaining bias_* entries in the same old positional form...
+    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
+    "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
+    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
+    "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
+    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),  # skipped
+    "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),  # skipped
+    "bias_2_0": ModelConfig(4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_2_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_2_2": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),  # skipped
+    "bias_2_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"),  # skipped
+    "bias_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "bias_3_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "bias_3_2": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "bias_3_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
+    "bias_3_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"),  # skipped
+    "bias_4_0": ModelConfig(4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_4_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_4_2": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),  # skipped
+    "bias_4_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"),  # skipped
+    "bias_4_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="alibi"),  # skipped
 }
@@ -578,33 +453,29 @@ def test_dpa_bias(dtype, model_configs, model):
 model_configs_bias_shapes = {
     # test: b, h, hg, d, sq, skv, p,
-    # mask, bias, bias_shape,
-    "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="11ss"),
-    "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss"),
-    "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss"),
-    "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss"),
-    "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom"),
-    "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom"),
+    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
+    "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
+    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
+    "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
+    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", bias_shape="1hss", alibi_type="custom"),
+    "bias_1_5": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", bias_shape="bhss", alibi_type="custom"),
 }
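For readers unfamiliar with the bias_shape codes, my understanding (an assumption, not stated anywhere in this diff) is that the four characters describe the bias tensor's broadcast pattern over (batch, heads, seqlen_q, seqlen_kv), with "1" marking a broadcast dimension. A small self-contained illustration for the entries above:

    # Assumed mapping only: "b1ss" -> (batch, 1, seqlen_q, seqlen_kv), i.e. broadcast over heads.
    def bias_shape_to_dims(code, b, h, sq, skv):
        lookup = {"b": b, "h": h, "1": 1, "s": None}
        dims = [lookup[c] for c in code]
        dims[2], dims[3] = sq, skv   # the two trailing "s" slots are the sequence lengths
        return tuple(dims)

    assert bias_shape_to_dims("b1ss", 4, 24, 2048, 2048) == (4, 1, 2048, 2048)   # bias_1_2
    assert bias_shape_to_dims("bhss", 2, 24, 2048, 2048) == (2, 24, 2048, 2048)  # bias_1_3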
@@ -620,34 +491,36 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 model_configs_swa = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    ...the remaining swa_* entries in the same old positional form...
+    "swa_1_1": ModelConfig(2, 2048, 16, 64),
+    "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
+    "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
+    "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
+    "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
+    "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
+    "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
+    "swa_3_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"),
+    "swa_3_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"),
+    "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
+    "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
+    "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
+    "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
+    "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
+    "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
+    "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
+    "swa_6_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"),
+    "swa_6_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"),
 }
-@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(
+    (not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus),
+    reason="Flash-attn 2.3+ is required.",
+)
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_swa])
 @pytest.mark.parametrize("model", model_configs_swa.keys())
@@ -658,18 +531,36 @@ def test_dpa_sliding_window(dtype, model_configs, model):
 model_configs_alibi_slopes = {
     # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
-    "alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
-    "alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
-    "alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom"),
-    "alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom"),
+    "alibi_1_0": ModelConfig(2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"),
+    "alibi_1_1": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"),
+    "alibi_2_0": ModelConfig(2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"),
+    "alibi_2_1": ModelConfig(1, 1024, 24, 128, max_seqlen_kv=2048, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"),
 }
-@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(
+    (not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus),
+    reason="Flash-attn 2.3+ is required.",
+)
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
 @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
@@ -694,16 +585,38 @@ qkv_layouts = [
 model_configs_layout = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
-    "layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
-    ...the remaining layout_* entries in the same old positional form...
+    "layout_0_0": ModelConfig(2, 128, 16, 64),
+    "layout_0_1": ModelConfig(2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
+    "layout_0_3": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),
+    "layout_1_0": ModelConfig(2, 2048, 24, 128),
+    "layout_1_1": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
+    "layout_1_3": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"),
+    "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
+    "layout_2_1": ModelConfig(2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
 }
@@ -720,55 +633,54 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
 qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
 model_configs_layout_thd = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
-    "layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
-    ...the remaining layout_* entries in the same old positional form...
+    "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
+    "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
+    "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
+    "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
+    "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
+    "layout_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
+    "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
+    "layout_2_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"),
+    "layout_2_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"),
+    "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
+    "layout_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)),
+    "layout_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)),
+    "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
+    "layout_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)),
+    "layout_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)),
+    "layout_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)),
+    "layout_5_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)),
+    "layout_5_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)),
 }
@@ -1158,16 +1070,22 @@ def _run_dot_product_attention(
 model_configs_te_layer = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
-    "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
-    ...the remaining te_* entries in the same old positional form...
+    "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
+    "te_1_1": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"),
+    "te_1_2": ModelConfig(2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"),
+    "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
+    "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
+    "te_2_1": ModelConfig(2, 2048, 16, 64),
+    "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
+    "te_2_3": ModelConfig(1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"),
+    "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
+    "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
 }
@@ -1189,26 +1107,27 @@ def test_transformer_layer(
     tols = dict(atol=5e-2, rtol=5e-2)
     workspace_opt = True

-    qkv_layout = "sbh3d" if fused_qkv_params else "sb3hd"
-    # override the qkv_layout in mqa gqa mode in ROCm TE
-    if IS_HIP_EXTENSION and model_configs[model].num_gqa_groups != model_configs[model].num_heads:
-        qkv_layout = "sbhd_sbhd_sbhd"
     # Test backend availability
     is_training = True
-    available_backends, _, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
-        qkv_layout=qkv_layout,
+        qkv_layout=(
+            qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
+        ),
         is_training=is_training,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported:
         is_training = False
-        available_backends, _, fused_attn_backends = _get_attention_backends(
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
             config,
             qkv_dtype=dtype,
-            qkv_layout=qkv_layout,
+            qkv_layout=(
+                qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
+            ),
             is_training=is_training,
         )
         flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
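The qkv_layout strings passed to the backend probe above are now derived from qkv_format by splicing in the QKV packing dimension, which is exactly what the two replace() calls do. A tiny worked example of that string transformation:

    for qkv_format in ("sbhd", "bshd", "thd"):
        fused = qkv_format.replace("hd", "h3d")    # fused QKV parameters, e.g. "sbhd" -> "sbh3d"
        unfused = qkv_format.replace("hd", "3hd")  # separate Q/K/V projections, e.g. "sbhd" -> "sb3hd"
        print(qkv_format, "->", fused, "/", unfused)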
@@ -1514,20 +1433,164 @@ def _run_transformer_layer(
...
@@ -1514,20 +1433,164 @@ def _run_transformer_layer(
return
out
,
inp
.
grad
return
out
,
inp
.
grad
model_configs_fp8_extra_state
=
{
"large"
:
ModelConfig
(
2
,
128
,
4
,
128
,
num_layers
=
1
),
}
@
pytest
.
mark
.
skipif
(
not
fp8_available
,
reason
=
reason_for_no_fp8
)
@
pytest
.
mark
.
skipif
(
get_device_compute_capability
()
<
(
9
,
0
),
reason
=
"FP8 tests require Hopper."
)
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
3
,
0
),
reason
=
"cuDNN 9.3.0+ is required."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"large"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
def
test_sanity_attention_extra_state
(
model
,
dtype
):
config
=
model_configs_fp8_extra_state
[
model
]
# Test backend availability
is_training
=
True
available_backends
,
_
,
fused_attn_backends
=
get_available_attention_backends
(
config
,
qkv_dtype
=
torch
.
float8_e4m3fn
,
qkv_layout
=
"sb3hd"
,
is_training
=
is_training
,
)
flash_attn_supported
,
fused_attn_supported
,
unfused_attn_supported
=
available_backends
if
not
fused_attn_supported
and
not
flash_attn_supported
:
pytest
.
skip
(
"No attention backend available."
)
outputs
=
_run_attention_extra_state
(
dtype
,
config
,
checkpoint
=
False
)
outputs_checkpoint
=
_run_attention_extra_state
(
dtype
,
config
,
checkpoint
=
True
)
outputs_checkpoint_v1_6
=
_run_attention_extra_state
(
dtype
,
config
,
mimic_v1_6
=
True
,
checkpoint
=
True
)
# Check that results match
tols
=
dtype_tols
(
dtype
)
if
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
):
tols
.
update
(
dict
(
rtol
=
2e-2
,
atol
=
2e-3
))
for
i
,
(
ref
,
test
)
in
enumerate
(
zip
(
outputs
,
outputs_checkpoint
)):
torch
.
testing
.
assert_close
(
test
,
ref
,
**
tols
,
)
for
i
,
(
ref
,
test
)
in
enumerate
(
zip
(
outputs
,
outputs_checkpoint_v1_6
)):
torch
.
testing
.
assert_close
(
test
,
ref
,
**
tols
,
)
def
_run_attention_extra_state
(
dtype
,
config
,
checkpoint
=
False
,
mimic_v1_6
=
False
):
steps
=
10
path
=
"checkpoint.pt"
fp8_enabled
=
True
fp8_recipe
=
recipe
.
DelayedScaling
(
margin
=
0
,
fp8_format
=
recipe
.
Format
.
HYBRID
,
amax_history_len
=
1
,
amax_compute_algo
=
"most_recent"
,
fp8_dpa
=
fp8_enabled
,
fp8_mha
=
False
,
)
reset_rng_states
()
hidden_states
=
torch
.
randn
(
(
config
.
max_seqlen_q
,
config
.
batch_size
,
config
.
hidden_size
),
dtype
=
dtype
,
device
=
"cuda"
,
requires_grad
=
True
,
)
def
get_model
(
dtype
,
config
):
sigma
=
0.023
init_method
=
init_method_normal
(
sigma
)
output_layer_init_method
=
scaled_init_method_normal
(
sigma
,
config
.
num_layers
)
with
fp8_model_init
(
enabled
=
fp8_enabled
,
recipe
=
fp8_recipe
):
block
=
TransformerLayer
(
config
.
hidden_size
,
4
*
config
.
hidden_size
,
config
.
num_heads
,
init_method
=
init_method
,
output_layer_init_method
=
output_layer_init_method
,
hidden_dropout
=
0.0
,
attention_dropout
=
0.0
,
fuse_qkv_params
=
True
,
params_dtype
=
dtype
,
device
=
"cuda"
,
)
return
block
block
=
get_model
(
dtype
,
config
)
for
i
in
range
(
steps
//
2
):
with
fp8_autocast
(
enabled
=
fp8_enabled
,
fp8_recipe
=
fp8_recipe
):
output
=
block
(
hidden_states
,
None
)
loss
=
output
.
sum
()
loss
.
backward
()
if
checkpoint
:
sd
=
block
.
state_dict
()
if
mimic_v1_6
:
sd
[
"self_attention.core_attention.fused_attention._extra_state"
]
=
sd
[
"self_attention.core_attention._extra_state"
]
del
sd
[
"self_attention.core_attention._extra_state"
]
torch
.
save
(
sd
,
path
)
param_grads
=
[]
for
p
in
block
.
parameters
():
if
p
.
requires_grad
:
param_grads
.
append
(
p
.
grad
.
clone
())
_cpu_rng_state_new
=
torch
.
get_rng_state
()
_cuda_rng_state_new
=
torch
.
cuda
.
get_rng_state
()
del
block
block
=
get_model
(
dtype
,
config
)
block
.
load_state_dict
(
torch
.
load
(
path
,
weights_only
=
False
))
torch
.
set_rng_state
(
_cpu_rng_state_new
)
torch
.
cuda
.
set_rng_state
(
_cuda_rng_state_new
)
for
p
in
block
.
parameters
():
if
p
.
requires_grad
:
p
.
grad
=
param_grads
.
pop
(
0
)
assert
not
param_grads
,
"Oops!"
for
i
in
range
((
steps
+
1
)
//
2
):
with
fp8_autocast
(
enabled
=
fp8_enabled
,
fp8_recipe
=
fp8_recipe
):
output
=
block
(
hidden_states
,
None
)
loss
=
output
.
sum
()
loss
.
backward
()
torch
.
cuda
.
synchronize
()
if
os
.
path
.
exists
(
path
):
os
.
remove
(
path
)
outputs
=
[
output
,
hidden_states
.
grad
]
for
p
in
block
.
parameters
():
if
p
.
requires_grad
:
outputs
.
append
(
p
.
grad
)
return
outputs
model_configs_fp8_vs_f16 = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
-   "fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+   "fp8_9": ModelConfig(2, 2048, 16, 128),
-   "fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+   "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
-   "fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
+   "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
-   "fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+   "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
-   "fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+   "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
-   "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
+   "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
-   "fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"),
+   "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
-   "fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"),
+   "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
-   "fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"),
+   "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
-   "fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+   "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
-   "fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+   "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
-   "fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"),
+   "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}

param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
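The pairs above reflect the repo-wide ModelConfig refactor in this commit: the old nine positional fields (b, h, hg, d, sq, skv, p, mask, bias) collapse to four positional fields (batch, sequence length, query heads, QK head dim), and everything else moves to keyword arguments. A minimal sketch of the new call style, inferred from the replacements shown here (the exact keyword defaults are an assumption, not quoted from the source):

# old entry: ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias")
# new entry: positional (b, sq, hq, dqk); GQA groups and mask type become keywords
cfg = ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal")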
...
@@ -1561,7 +1624,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
    )
)
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
...
@@ -1576,18 +1639,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    config = model_configs_fp8_vs_f16[model]
+   if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
+       9,
+       7,
+       0,
+   ):
+       pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
-   if (
-       FlashAttentionUtils.v3_is_installed
-       and not is_training
-       and "padding" not in config.attn_mask_type
-   ):
+   # Test backend availability
+   available_backends, _, fused_attn_backends = get_available_attention_backends(
+       config,
+       qkv_dtype=torch.float8_e4m3fn,
+       qkv_layout=qkv_format.replace("hd", "h3d"),
+       is_training=is_training,
+   )
+   flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+   # Skip if only unfused backend is supported
+   if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
+       pytest.skip("Less than two backends to compare.")
+   if not fp8_dpa_bwd:
+       available_backends, _, fused_attn_backends = get_available_attention_backends(
+           config,
+           qkv_dtype=dtype,
+           qkv_layout=qkv_format.replace("hd", "h3d"),
+           is_training=is_training,
+       )
+       _, fused_attn_supported, _ = available_backends
+       if not fused_attn_supported:
+           pytest.skip("No attention backend available.")
+   if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
...
@@ -1613,11 +1688,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
    rtol = 5e-1
    rmse_tol = 0.15
    logging.debug("========== {:^25s} ==========".format("forward output"))
-   if (
-       FlashAttentionUtils.v3_is_installed
-       and not is_training
-       and "padding" not in config.attn_mask_type
-   ):
+   if flash_attn_supported:
        _error(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
...
@@ -1768,7 +1839,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
        return out, param_names, tuple(x.grad for x in params)
    return out, param_names, tuple(None for x in params)


@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
...
@@ -1790,23 +1861,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
    # if get_device_compute_capability() >= (10, 0):
    #     config.dropout_p = 0.1
+   if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
+       9,
+       7,
+       0,
+   ):
+       pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
-   if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
-       pytest.skip("qkv_layout not applicable for MQA/GQA")
    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
-   if (
-       FlashAttentionUtils.v3_is_installed
-       and not is_training
-       and "padding" not in config.attn_mask_type
-   ):
+   # Test backend availability
+   available_backends, _, fused_attn_backends = get_available_attention_backends(
+       config,
+       qkv_dtype=torch.float8_e4m3fn,
+       qkv_layout=qkv_layout,
+       is_training=is_training,
+   )
+   flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+   # Skip if only unfused backend is supported
+   if flash_attn_supported + fused_attn_supported < 1:
+       pytest.skip("No FP8 attention backend available.")
+   if not fp8_dpa_bwd:
+       available_backends, _, fused_attn_backends = get_available_attention_backends(
+           config,
+           qkv_dtype=dtype,
+           qkv_layout=qkv_layout,
+           is_training=is_training,
+       )
+       _, fused_attn_supported, _ = available_backends
+       if not fused_attn_supported:
+           pytest.skip("No attention backend available.")
+   if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
+       pytest.skip("qkv_layout not applicable for MQA/GQA")
+   if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
...
@@ -1835,11 +1917,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
    rmse_tol = 0.11
    bwd_names = ["dq", "dk", "dv"]
    logging.debug("========== {:^25s} ==========".format("forward output"))
-   if (
-       FlashAttentionUtils.v3_is_installed
-       and not is_training
-       and "padding" not in config.attn_mask_type
-   ):
+   if flash_attn_supported:
        _error(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
...
@@ -2013,21 +2091,21 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
model_configs_fp8 = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
-   "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
+   "fp8_1": ModelConfig(1, 512, 1, 64),
-   "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
+   "fp8_2": ModelConfig(4, 512, 16, 64),
-   "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+   "fp8_3": ModelConfig(1, 2048, 1, 128),
-   "fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+   "fp8_4": ModelConfig(2, 2048, 24, 128),
-   "fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
+   "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
-   "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
+   "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
-   "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+   "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
-   "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+   "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}

param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(
    (
        get_cudnn_version() < (8, 9, 3)
...
@@ -2049,6 +2127,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
    config = model_configs_fp8[model]
+   # Test backend availability
+   is_training = True
+   available_backends, _, fused_attn_backends = get_available_attention_backends(
+       config,
+       qkv_dtype=torch.float8_e4m3fn,
+       qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
+       is_training=is_training,
+   )
+   flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+   if not (fused_attn_backends and unfused_attn_supported):
+       pytest.skip("Not enough backends to run this test with.")
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
...
tests/pytorch/fused_attn/test_fused_attn_with_cp.py → tests/pytorch/attention/test_attention_with_cp.py
View file @ 87e3e56e
...
@@ -4,6 +4,8 @@
import os
import subprocess
+import sys
+import pathlib
import pytest
import torch
...
@@ -12,27 +14,29 @@ from transformer_engine.pytorch.utils import (
    get_cudnn_version,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
-from test_fused_attn import ModelConfig
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import ModelConfig, get_available_attention_backends
+
+# Initialize RNG state
+seed = 1234
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
from torch.utils.cpp_extension import IS_HIP_EXTENSION
model_configs_flash_attn = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
-   "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
+   "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
-   "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
+   "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
-   "cp_1_2": ModelConfig(
-       2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-   ),  # MHA
+   "cp_1_2": ModelConfig(
+       2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)
+   ),  # MHA
-   "cp_1_3": ModelConfig(
-       2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
-   ),  # MHA
+   "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)),  # MHA
-   "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
+   "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
-   "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
+   "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
-   "cp_2_2": ModelConfig(
-       2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-   ),  # GQA
+   "cp_2_2": ModelConfig(
+       2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+   ),  # GQA
-   "cp_2_3": ModelConfig(
-       2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
-   ),  # GQA
+   "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)),  # GQA
}
...
@@ -44,7 +48,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
        "--nproc-per-node=" + str(num_gpus_per_node),
    ]
    te_path = os.getenv("TE_PATH", "/opt/transformerengine")
-   script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
+   script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py")
    args.append(script_path)
    for k, v in kwargs.items():
        args.append(f"{k}={v}")
...
@@ -94,37 +98,41 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
-   "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
+   "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
-   "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
+   "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
-   "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # MHA
+   "cp_1_2": ModelConfig(
+       2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
+   ),  # MHA
-   "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # MHA
+   "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"),  # MHA
-   "cp_1_4": ModelConfig(
-       2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-   ),  # MHA
+   "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
-   "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
+   "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
-   "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
+   "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
-   "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # GQA
+   "cp_2_2": ModelConfig(
+       2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", attn_bias_type="post_scale_bias",
+   ),  # GQA
-   "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # GQA
+   "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"),  # GQA
-   "cp_2_4": ModelConfig(
-       2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-   ),  # GQA
+   "cp_2_4": ModelConfig(
+       2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+   ),  # GQA
-   "cp_3_0": ModelConfig(
-       2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64
-   ),  # MLA
+   "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64),  # MLA
-   "cp_3_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64),  # MLA
+   "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64),  # MLA
-   "cp_3_2": ModelConfig(
-       2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64
-   ),  # MLA
+   "cp_3_2": ModelConfig(
+       2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
+   ),  # MLA
-   "cp_3_3": ModelConfig(
-       2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
-   ),  # MLA
+   "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64),  # MLA
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
-@pytest.mark.skipif(
-    IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0),
-    reason="DTK not surpport fused attn for now, CP tests require sm80+.",
-)
+@pytest.mark.skipif(
+    IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="CP tests require sm80+."
+)
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
...
@@ -176,6 +184,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
        pytest.skip("MLA CP currently only support KV P2P!")
    if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
        pytest.skip("MLA CP currently does not support FP8 attention!")
+   dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
+   available_backends, _, fused_attn_backends = get_available_attention_backends(
+       config,
+       qkv_dtype=dtypes[dtype],
+       qkv_layout="_".join([qkv_format] * 3),
+       window_size=config.window_size,
+       context_parallel=True,
+   )
+   _, fused_attn_supported, _ = available_backends
+   if not fused_attn_supported:
+       pytest.skip("No attention backend available.")
    subprocess.run(
        get_bash_arguments(
...
tests/pytorch/fused_attn/test_kv_cache.py → tests/pytorch/attention/test_kv_cache.py
View file @ 87e3e56e
...
@@ -5,18 +5,14 @@
from collections import OrderedDict
from typing import List
import os
+import sys
+import pathlib
import logging
import math

import pytest
import torch

-from test_fused_attn import (
-    ModelConfig,
-    reset_rng_states,
-    _get_attention_backends,
-)
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
...
@@ -34,26 +30,25 @@ from transformer_engine.pytorch.utils import (
    is_bf16_compatible,
)

-# Initialize RNG state
-seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import (
+    ModelConfig,
+    reset_rng_states,
+    get_available_attention_backends,
+)
+
+# Reset RNG states
+reset_rng_states()

param_types = [torch.float16]
if is_bf16_compatible():
    param_types.append(torch.bfloat16)

model_configs_infer = {
-   # test: b, h, hg, d, sq, skv, p, mask, bias
+   # test: b, sq, hq, dqk,
-   "infer_0": ModelConfig(
-       4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
-   ),
+   "infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16),
-   "infer_1": ModelConfig(
-       2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
-   ),
+   "infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16),
}
qkv_formats = ["bshd", "sbhd", "thd"]
...
@@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
    qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
    if is_paged:
        qkv_layout = "paged_kv_" + qkv_layout
-   available_backends, _, fused_attn_backends = _get_attention_backends(
+   available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
...
tests/pytorch/debug/run_distributed.py
View file @ 87e3e56e
...
@@ -364,6 +364,40 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
    set_weight_tensor_tp_group_reduce(True)  # reset

+@run_debug_test
+def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
+    from test_log import LOG_QUANTIZED_CONFIG
+
+    kwargs["config_file"].write(LOG_QUANTIZED_CONFIG)
+    kwargs["config_file"].flush()
+    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
+    set_weight_tensor_tp_group_reduce(gather_weight)
+
+    if WORLD_SIZE % 2 != 0:
+        return  # skip
+
+    TP_SIZE = WORLD_SIZE // 2
+    DP_SIZE = 2
+    TP_RANK = WORLD_RANK % TP_SIZE
+    DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
+
+    debug_api.set_tensor_reduction_group(NCCL_WORLD)
+
+    x, weight = _get_tensors(
+        parallel_mode,
+        weight_seed=TP_RANK * 1234,
+        data_seed=DP_RANK * 1234,
+        tp_size=TP_SIZE,
+        tp_rank=TP_RANK,
+    )
+    tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
+    tp_group = dist.new_group(ranks=tp_group_ranks)
+    model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
+    _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
+
+    set_weight_tensor_tp_group_reduce(True)  # reset

@run_debug_test
def test_log_expert_parallel(**kwargs):
    """
...
tests/pytorch/debug/test_api_features.py
View file @ 87e3e56e
...
@@ -24,22 +24,17 @@ def test_transformer_engine_no_config(feature_dirs):
        # FP8 enabled - true by the default
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
-       )
+       )[0]
-       # modify_tensor_enabled - False by default
+       # modify_tensor_enabled - (False, None) by default
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
-       )
+       )[0]
-       # inspect_tensor_enabled - False by default
+       # inspect_tensor_enabled - (False, None) by default
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.attn.qkv", tensor_name="activation", iteration=0
-       )
+       )[0]
-       # inspect_tensor_postquantize - False by default
-       assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
-           "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
-       )
    finally:
        debug_api.end_debug()
...
@@ -51,24 +46,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs):
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
-       )
+       )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-       )
+       )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-       )
+       )[0]
        # caching
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
-       )
+       )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-       )
+       )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-       )
+       )[0]
    finally:
        debug_api.end_debug()
...
@@ -80,22 +75,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs):
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", iteration=0
-       )
+       )[0]
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", iteration=0
-       )
+       )[0]
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", iteration=0
-       )
+       )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
-       )
+       )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-       )
+       )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-       )
+       )[0]
    finally:
        debug_api.end_debug()
@@ -111,22 +106,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
...
@@ -111,22 +106,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
# check modify_tensor_enabled
# check modify_tensor_enabled
assert
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
assert
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
"decoder.1.mlp.fc1"
,
gemm
=
"fprop"
,
tensor_name
=
"activation"
,
iteration
=
0
"decoder.1.mlp.fc1"
,
gemm
=
"fprop"
,
tensor_name
=
"activation"
,
iteration
=
0
)
)
[
0
]
assert
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
assert
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
"decoder.1.mlp.fc1"
,
gemm
=
"fprop"
,
tensor_name
=
"weight"
,
iteration
=
0
"decoder.1.mlp.fc1"
,
gemm
=
"fprop"
,
tensor_name
=
"weight"
,
iteration
=
0
)
)
[
0
]
assert
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
assert
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
"decoder.1.mlp.fc1"
,
gemm
=
"dgrad"
,
tensor_name
=
"gradient"
,
iteration
=
0
"decoder.1.mlp.fc1"
,
gemm
=
"dgrad"
,
tensor_name
=
"gradient"
,
iteration
=
0
)
)
[
0
]
assert
not
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
assert
not
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
"decoder.1.mlp.fc1"
,
gemm
=
"dgrad"
,
tensor_name
=
"weight"
,
iteration
=
0
"decoder.1.mlp.fc1"
,
gemm
=
"dgrad"
,
tensor_name
=
"weight"
,
iteration
=
0
)
)
[
0
]
assert
not
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
assert
not
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
"decoder.1.mlp.fc1"
,
gemm
=
"wgrad"
,
tensor_name
=
"gradient"
,
iteration
=
0
"decoder.1.mlp.fc1"
,
gemm
=
"wgrad"
,
tensor_name
=
"gradient"
,
iteration
=
0
)
)
[
0
]
assert
not
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
assert
not
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
"decoder.1.mlp.fc1"
,
gemm
=
"wgrad"
,
tensor_name
=
"activation"
,
iteration
=
0
"decoder.1.mlp.fc1"
,
gemm
=
"wgrad"
,
tensor_name
=
"activation"
,
iteration
=
0
)
)
[
0
]
# check modify_tensor
# check modify_tensor
...
@@ -168,14 +163,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
...
@@ -168,14 +163,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
gemm
=
"wgrad"
,
gemm
=
"wgrad"
,
tensor_name
=
"gradient"
,
tensor_name
=
"gradient"
,
iteration
=
0
,
iteration
=
0
,
)
)
[
0
]
assert
not
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
assert
not
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
"decoder.1.mlp.fc4"
,
"decoder.1.mlp.fc4"
,
gemm
=
"fprop"
,
gemm
=
"fprop"
,
tensor_name
=
"activation"
,
tensor_name
=
"activation"
,
iteration
=
0
,
iteration
=
0
,
)
)
[
0
]
finally
:
finally
:
debug_api
.
end_debug
()
debug_api
.
end_debug
()
...
@@ -191,11 +186,11 @@ def test_fake_quant(configs_dir, feature_dirs):
...
@@ -191,11 +186,11 @@ def test_fake_quant(configs_dir, feature_dirs):
# modify_tensor_enabled
# modify_tensor_enabled
assert
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
assert
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
"decoder.1.mlp.fc1"
,
gemm
=
"fprop"
,
tensor_name
=
"activation"
,
iteration
=
0
"decoder.1.mlp.fc1"
,
gemm
=
"fprop"
,
tensor_name
=
"activation"
,
iteration
=
0
)
)
[
0
]
assert
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
assert
debug_api
.
transformer_engine
.
modify_tensor_enabled
(
"decoder.1.mlp.fc1"
,
gemm
=
"dgrad"
,
tensor_name
=
"gradient"
,
iteration
=
0
"decoder.1.mlp.fc1"
,
gemm
=
"dgrad"
,
tensor_name
=
"gradient"
,
iteration
=
0
)
)
[
0
]
# modify_tensor
# modify_tensor
debug_api
.
transformer_engine
.
modify_tensor
(
debug_api
.
transformer_engine
.
modify_tensor
(
...
@@ -218,11 +213,11 @@ def test_fake_quant(configs_dir, feature_dirs):
...
@@ -218,11 +213,11 @@ def test_fake_quant(configs_dir, feature_dirs):
assert
debug_api
.
transformer_engine
.
fp8_gemm_enabled
(
assert
debug_api
.
transformer_engine
.
fp8_gemm_enabled
(
"decoder.1.fc2"
,
gemm
=
"wgrad"
,
iteration
=
0
"decoder.1.fc2"
,
gemm
=
"wgrad"
,
iteration
=
0
)
)
[
0
]
# caching
# caching
assert
debug_api
.
transformer_engine
.
fp8_gemm_enabled
(
assert
debug_api
.
transformer_engine
.
fp8_gemm_enabled
(
"decoder.1.fc2"
,
gemm
=
"wgrad"
,
iteration
=
0
"decoder.1.fc2"
,
gemm
=
"wgrad"
,
iteration
=
0
)
)
[
0
]
finally
:
finally
:
debug_api
.
end_debug
()
debug_api
.
end_debug
()
...
@@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs):
...
@@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs):
)
)
tensor
=
torch
.
randn
((
100
,
100
,
5
)).
cuda
()
tensor
=
torch
.
randn
((
100
,
100
,
5
)).
cuda
()
tensor_fp8
=
Float8
Tenso
r
(
quantizer
=
Float8
Quantize
r
(
data
=
tensor
.
to
(
torch
.
uint8
).
cuda
(),
scale
=
torch
.
full
([
1
],
1.0
).
cuda
(),
fp8_scale_inv
=
torch
.
full
([
1
],
1.0
).
cuda
(),
amax
=
torch
.
full
([
1
],
1.0
).
cuda
(),
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
,
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
,
shape
=
tensor
.
shape
,
dtype
=
torch
.
float32
,
)
)
tensor_fp8
=
quantizer
(
tensor
)
def
log
():
def
log
():
from
transformer_engine.debug.features.utils.stats_buffer
import
STATS_BUFFERS
from
transformer_engine.debug.features.utils.stats_buffer
import
STATS_BUFFERS
...
@@ -260,54 +254,64 @@ def test_statistics_collection(configs_dir, feature_dirs):
...
@@ -260,54 +254,64 @@ def test_statistics_collection(configs_dir, feature_dirs):
tensor_name
=
"activation"
,
tensor_name
=
"activation"
,
iteration
=
200
,
iteration
=
200
,
tp_group
=
None
,
tp_group
=
None
,
quantizer
=
quantizer
,
rowwise_quantized_tensor
=
tensor_fp8
,
columnwise_quantized_tensor
=
tensor_fp8
,
)
)
stats
=
log
()
stats
=
log
()
assert
stats
[(
"decoder.1.mlp.fc1"
,
"activation"
,
"cur_amax"
,
200
)]
==
tensor
.
abs
().
max
()
assert
stats
[(
"decoder.1.mlp.fc1"
,
"activation"
,
"cur_amax"
,
200
)]
==
tensor
.
abs
().
max
()
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
"decoder.1.mlp.fc1"
,
tensor_name
=
"activation"
,
iteration
=
201
"decoder.1.mlp.fc1"
,
tensor_name
=
"activation"
,
iteration
=
201
)
)
[
0
]
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
"decoder.2.mlp.fc1"
,
tensor_name
=
"activation"
,
iteration
=
200
"decoder.2.mlp.fc1"
,
tensor_name
=
"activation"
,
iteration
=
200
)
)[
0
]
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
"decoder.1.mlp.fc1"
,
tensor_name
=
"gradient"
,
iteration
=
200
expected_underflows
=
(
((
tensor_fp8
.
_data
==
0
).
sum
()
-
(
tensor
==
0
).
sum
())
*
100
/
(
100
*
100
*
5
)
)
)
expected_underflows
=
(
tensor_fp8
.
_data
==
0
).
sum
()
*
100
/
(
100
*
100
*
5
)
assert
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
expected_overflows
=
(
tensor_fp8
.
_data
==
126
).
sum
()
*
100
/
(
100
*
100
*
5
)
"decoder.1.mlp.fc1"
,
tensor_name
=
"gradient"
,
iteration
=
200
)[
0
]
# TE FP8 tensor stats --
# TE FP8 tensor stats --
assert
debug_api
.
transformer_engine
.
inspect_tensor_
postquantize_
enabled
(
assert
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
"decoder.1.mlp.fc1"
,
tensor_name
=
"gradient"
,
gemm
=
"wgrad"
,
iteration
=
200
"decoder.1.mlp.fc1"
,
tensor_name
=
"gradient"
,
iteration
=
200
)
)
[
0
]
debug_api
.
transformer_engine
.
inspect_tensor
_postquantize
(
debug_api
.
transformer_engine
.
inspect_tensor
(
"decoder.1.mlp.fc1"
,
"decoder.1.mlp.fc1"
,
tensor
=
tensor_fp8
,
tensor_name
=
"gradient"
,
tensor_name
=
"gradient"
,
iteration
=
200
,
iteration
=
200
,
rowwise
=
True
,
tp_group
=
None
,
tp_group
=
None
,
tensor
=
tensor
,
quantizer
=
quantizer
,
rowwise_quantized_tensor
=
tensor_fp8
,
columnwise_quantized_tensor
=
tensor_fp8
,
)
)
stats
=
log
()
stats
=
log
()
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
stats
[(
"decoder.1.mlp.fc1"
,
"gradient"
,
"underflows%"
,
200
)],
expected_underflows
stats
[(
"decoder.1.mlp.fc1"
,
"gradient"
,
"underflows%"
,
200
)],
expected_underflows
)
)
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_
postquantize_
enabled
(
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
"decoder.1.mlp.fc1"
,
tensor_name
=
"activation"
,
gemm
=
"fprop"
,
iteration
=
201
"decoder.1.mlp.fc1"
,
tensor_name
=
"activation"
,
iteration
=
201
)
)
[
0
]
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_
postquantize_
enabled
(
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
"decoder.2.mlp.fc1"
,
tensor_name
=
"gradient"
,
gemm
=
"wgrad"
,
iteration
=
200
"decoder.2.mlp.fc1"
,
tensor_name
=
"gradient"
,
iteration
=
200
)
)
[
0
]
# Second config in same yaml
# Second config in same yaml
tensor
=
torch
.
rand
((
100
,
100
,
5
))
tensor
=
torch
.
rand
((
100
,
100
,
5
))
debug_api
.
transformer_engine
.
inspect_tensor
(
debug_api
.
transformer_engine
.
inspect_tensor
(
"decoder.6.mlp.fc1"
,
"decoder.6.mlp.fc1"
,
tensor
=
tensor
,
tensor_name
=
"activation"
,
tensor_name
=
"activation"
,
iteration
=
200
,
iteration
=
200
,
tp_group
=
None
,
tp_group
=
None
,
tensor
=
tensor
,
quantizer
=
quantizer
,
rowwise_quantized_tensor
=
tensor_fp8
,
columnwise_quantized_tensor
=
tensor_fp8
,
)
)
stats
=
log
()
stats
=
log
()
stats_names
=
[
x
[
3
]
for
x
in
stats
.
keys
()]
stats_names
=
[
x
[
3
]
for
x
in
stats
.
keys
()]
...
@@ -316,10 +320,13 @@ def test_statistics_collection(configs_dir, feature_dirs):
...
@@ -316,10 +320,13 @@ def test_statistics_collection(configs_dir, feature_dirs):
debug_api
.
transformer_engine
.
inspect_tensor
(
debug_api
.
transformer_engine
.
inspect_tensor
(
"decoder.7.mlp.fc1"
,
"decoder.7.mlp.fc1"
,
tensor
=
tensor
,
tensor_name
=
"weight"
,
tensor_name
=
"weight"
,
iteration
=
200
,
iteration
=
200
,
tp_group
=
None
,
tp_group
=
None
,
tensor
=
tensor
,
quantizer
=
quantizer
,
rowwise_quantized_tensor
=
tensor_fp8
,
columnwise_quantized_tensor
=
tensor_fp8
,
)
)
stats
=
log
()
stats
=
log
()
stats_names
=
[
x
[
3
]
for
x
in
stats
.
keys
()]
stats_names
=
[
x
[
3
]
for
x
in
stats
.
keys
()]
...
@@ -328,7 +335,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
...
@@ -328,7 +335,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
"decoder.7.mlp.fc1"
,
tensor_name
=
"weight"
,
iteration
=
201
"decoder.7.mlp.fc1"
,
tensor_name
=
"weight"
,
iteration
=
201
)
)
[
0
]
assert_empty
()
assert_empty
()
finally
:
finally
:
...
@@ -343,21 +350,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
...
@@ -343,21 +350,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
default_logging_enabled
=
False
,
default_logging_enabled
=
False
,
)
)
def
feed
(
tensor
,
tensor_fp8
):
def
feed
(
tensor
,
tensor_fp8
,
quantizer
):
debug_api
.
transformer_engine
.
inspect_tensor
(
debug_api
.
transformer_engine
.
inspect_tensor
(
"decoder.5.mlp.fc1"
,
"decoder.5.mlp.fc1"
,
tensor
=
tensor
,
tensor
=
tensor
,
tensor_name
=
"activation"
,
tensor_name
=
"activation"
,
iteration
=
1
,
iteration
=
1
,
tp_group
=
None
,
tp_group
=
None
,
)
quantizer
=
quantizer
,
debug_api
.
transformer_engine
.
inspect_tensor_postquantize
(
rowwise_quantized_tensor
=
tensor_fp8
,
"decoder.5.mlp.fc1"
,
columnwise_quantized_tensor
=
tensor_fp8
,
tensor
=
tensor_fp8
,
tensor_name
=
"activation"
,
iteration
=
1
,
rowwise
=
True
,
tp_group
=
None
,
)
)
def
log_stats
():
def
log_stats
():
...
@@ -365,26 +367,26 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
...
@@ -365,26 +367,26 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
return
STATS_BUFFERS
.
log_stats
()
return
STATS_BUFFERS
.
log_stats
()
quantizer
=
Float8Quantizer
(
scale
=
torch
.
full
([
1
],
1.0
).
cuda
(),
amax
=
torch
.
full
([
1
],
1.0
).
cuda
(),
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
,
)
def
fp8_tensor
(
t
):
def
fp8_tensor
(
t
):
return
Float8Tensor
(
return
quantizer
(
t
.
cuda
())
data
=
t
.
to
(
torch
.
uint8
).
cuda
(),
fp8_scale_inv
=
torch
.
ones
([
1
]).
cuda
(),
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
,
shape
=
t
.
shape
,
dtype
=
torch
.
float32
,
)
shape
=
[
1024
,
1024
]
shape
=
[
1024
,
1024
]
tensors
=
[
torch
.
randn
(
shape
)
for
_
in
range
(
2
)]
tensors
=
[
torch
.
randn
(
shape
)
for
_
in
range
(
2
)]
tensors_fp8
=
[
fp8_tensor
(
tensors
[
i
])
for
i
in
range
(
2
)]
tensors_fp8
=
[
fp8_tensor
(
tensors
[
i
])
for
i
in
range
(
2
)]
feed
(
tensors
[
0
],
tensors_fp8
[
0
])
feed
(
tensors
[
0
],
tensors_fp8
[
0
]
,
quantizer
)
feed
(
tensors
[
1
],
tensors_fp8
[
1
])
feed
(
tensors
[
1
],
tensors_fp8
[
1
]
,
quantizer
)
stats1
=
log_stats
()
stats1
=
log_stats
()
tensor2
=
torch
.
cat
((
tensors
[
0
],
tensors
[
1
])).
cuda
()
tensor2
=
torch
.
cat
((
tensors
[
0
],
tensors
[
1
])).
cuda
()
fp8tensor2
=
fp8_tensor
(
tensor2
)
fp8tensor2
=
fp8_tensor
(
tensor2
)
feed
(
tensor2
,
fp8tensor2
)
feed
(
tensor2
,
fp8tensor2
,
quantizer
)
stats2
=
log_stats
()
stats2
=
log_stats
()
assert
len
(
stats1
.
keys
())
>
0
assert
len
(
stats1
.
keys
())
>
0
...
...
tests/pytorch/debug/test_configs/log_config.yaml
0 → 100644
View file @ 87e3e56e

test:
  enabled: True
  layers:
    layer_name_regex_pattern: .*
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [cur_amax, dynamic_range, mean, std, l1_norm]
          start_step: 1
          freq: 3
    LogFp8TensorStats:
      enabled: True
      tensors: activation
      stats: [underflows%]
      start_step: 1
      freq: 5
\ No newline at end of file
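This config enables LogTensorStats every 3 steps and LogFp8TensorStats every 5 steps, starting at step 1, for every layer (regex `.*`). A minimal sketch of how such a file is consumed, modeled on the tests added later in this commit; the feature_dirs and log_dir values are placeholders, not taken from the source:

import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.pytorch as te

debug_api.initialize(
    config_file="tests/pytorch/debug/test_configs/log_config.yaml",
    feature_dirs=["<path to transformer_engine debug features>"],  # placeholder
    log_dir="/tmp/debug_logs",  # placeholder
)
model = te.Linear(128, 128, name="linear1")
for _ in range(6):
    x = torch.randn(4, 128, 128).cuda()
    with te.fp8_autocast(enabled=True):
        y = model(x)
    y.sum().backward()
    debug_api.step()  # statistics are written on steps matching start_step/freq
debug_api.end_debug()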
tests/pytorch/debug/test_configs/perf_config.yaml
0 → 100644
View file @ 87e3e56e

test:
  enabled: True
  layers:
    layer_name_regex_pattern: .*1
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [cur_amax, dynamic_range, mean, std, l1_norm]
          start_step: 0
          freq: 100000
\ No newline at end of file
tests/pytorch/debug/test_log.py
0 → 100644
View file @ 87e3e56e

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import torch
import tempfile
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState
import pytest
import contextlib
import os

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.debug.pytorch.debug_state import TEDebugState

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)

LOG_QUANTIZED_CONFIG_BASE = """
log:
  layers:
    layer_name_regex_pattern: .*
  enabled:
    True
  transformer_engine:
    LogFp8TensorStats:
      enabled: True
      stats: [{stats}]
      tensors: [activation, gradient, weight]
      freq: 2
      start_step: 0
      end_step: 10
"""

recipes = [
    "fp8_delayed_scaling",
    "fp8_current_scaling",
    "fp8_block_scaling",
    "mxfp8",
]

bare_stats = [
    "underflows%",
    "scale_inv_min",
    "scale_inv_max",
    "mse",
]

all_stats = []
for r in recipes:
    for stat in bare_stats:
        for columnwise_postfix in ["", "_columnwise"]:
            if (
                r in ["fp8_current_scaling", "fp8_block_scaling"]
                and torch.cuda.get_device_capability()[0] < 9
            ):
                # hopper is needed for current-scaling, block-scaling
                continue
            if r == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
                # blackwell is needed for mxfp8
                continue
            if (
                r in ["fp8_delayed_scaling", "fp8_current_scaling"]
                and columnwise_postfix == "_columnwise"
            ):
                # columnwise stats are not supported for fp8_delayed_scaling and fp8_current_scaling
                continue
            all_stats.append(f"{r}_{stat}{columnwise_postfix}")
all_stats.append("fp8_delayed_scaling_overflows%")  # only delayed-scaling supports overflows%
@contextlib.contextmanager
def debug_session(config_str: str, feature_dirs):
    """
    Helper context manager that
        1. writes the YAML `config_str` to a temporary file,
        2. starts a debug session, and
        3. yields the directory that contains the statistics log.
    The session is closed automatically – even on exceptions – so every test
    stays concise and leak-free.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False
    ) as cfg_file, tempfile.TemporaryDirectory() as log_dir:
        cfg_file.write(config_str)
        cfg_file.flush()
        debug_api.initialize(
            config_file=cfg_file.name,
            feature_dirs=feature_dirs,
            log_dir=log_dir,
        )
        try:
            yield log_dir
        finally:
            debug_api.end_debug()


def read_log(log_dir: str) -> str:
    """Return the content of the statistics log produced by `debug_session`."""
    stat_path = os.path.join(
        log_dir,
        "nvdlfw_inspect_statistics_logs",
        "nvdlfw_inspect_globalrank-0.log",
    )
    with open(stat_path, "r") as f:
        return f.read()
def test_sanity(feature_dirs):
    log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
    with debug_session(log_all_stats_config, feature_dirs) as log_dir:
        model = te.Linear(128, 128, params_dtype=torch.bfloat16)
        inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
        for _ in range(10):
            with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
                output = model(inp)
            loss = output.sum()
            loss.backward()
            debug_api.step()

        output = read_log(log_dir)
        assert output, "Output is empty"
        for stat in all_stats:
            assert stat in output, f"Stat {stat} not found in output"
fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
    recipe.Float8BlockScaling(),
]


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_numerics(fp8_recipe, feature_dirs):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
        pytest.skip(reason_for_no_mxfp8)
    if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling():
        pytest.skip(reason_for_no_fp8_block_scaling)

    log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats))
    with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
        recipe_state = RecipeState.create(
            fp8_recipe,
            mode="forward",
            num_quantizers=3,
        )

        tensor = torch.zeros(1024, 1024).cuda()
        tensor[0, :] = 1000

        quantizer = recipe_state.make_quantizers()[0]
        quantized_tensor = quantizer(tensor)

        debug_api.transformer_engine.inspect_tensor(
            layer_name="layer_name",
            tensor_name="activation",
            iteration=0,
            tp_group=None,
            tensor=tensor,
            quantizer=quantizer,
            rowwise_quantized_tensor=quantized_tensor,
            columnwise_quantized_tensor=quantized_tensor,
        )
        debug_api.step()

        dequantized_tensor = quantized_tensor.dequantize()

        output = read_log(log_dir)
        for line in output.splitlines():
            if "underflows%" in line:
                underflows = float(line.split("value=")[1])
                expected = (
                    ((dequantized_tensor == 0).sum() - (tensor == 0).sum())
                    / dequantized_tensor.numel()
                    * 100
                )
                assert underflows == pytest.approx(expected.cpu(), abs=1e-4)
            if "mse" in line:
                mse = float(line.split("value=")[1])
                expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean")
                assert mse == pytest.approx(expected.cpu(), abs=1e-6)
            if "overflows%" in line:
                overflows = float(line.split("value=")[1])
                expected = (
                    (abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100
                )
                assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
    # If layer does not invoke any feature in current iteration,
    # then it changed into non-debug mode.
    # This test checks whether this works correctly -
    # non-quantized statistics should be logged every 3 iterations,
    # and quantized statistics should be logged every 5 iterations.
    with tempfile.TemporaryDirectory() as temp_dir:
        debug_api.initialize(
            config_file=configs_dir + "/log_config.yaml",
            feature_dirs=feature_dirs,
            log_dir=temp_dir,
        )
        if layer == "linear":
            model = te.Linear(128, 128, name="linear1")
        elif layer == "transformer":
            model = te.TransformerLayer(128, 128, 4, name="transformer1")
        else:
            raise ValueError(f"Invalid layer: {layer}")

        for i in range(20):
            x = torch.randn(4, 128, 128).cuda()
            with te.fp8_autocast(enabled=True):
                y = model(x)
            y.sum().backward()
            debug_api.step()

        with open(
            os.path.join(temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"),
            "r",
        ) as f:
            file_content = f.read()
            for i in range(1, 20):
                if i % 3 == 0 or i % 5 == 0:
                    assert f"iteration={i:06d}" in file_content
                else:
                    assert f"iteration={i:06d}" not in file_content

        debug_api.end_debug()
        TEDebugState._reset()
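The underflows% assertion in test_numerics above counts elements that become zero only after quantization, excluding zeros already present in the input. A tiny standalone illustration of that formula in plain PyTorch; the example tensors are made up for illustration, not taken from the test:

import torch

tensor = torch.tensor([0.0, 1e-7, 2.0, 3.0])
dequantized = torch.tensor([0.0, 0.0, 2.0, 3.0])  # pretend the FP8 round trip flushed 1e-7 to zero
underflows_pct = ((dequantized == 0).sum() - (tensor == 0).sum()) / dequantized.numel() * 100
print(underflows_pct.item())  # 25.0, i.e. one of the four elements underflowed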
tests/pytorch/debug/test_perf.py
0 → 100644
View file @ 87e3e56e

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch
import transformer_engine.pytorch as te
import time

import nvdlfw_inspect.api as debug_api

from transformer_engine.debug.pytorch.debug_state import TEDebugState


def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs):
    debug_api.end_debug()
    TEDebugState._reset()
    if debug_tools_initialized:
        # This config log stats starting from 0, every N iterations for huge N >> NUM_ITERS.
        # So after 1 warm-up iteration, this layers should work in non-debug mode.
        debug_api.initialize(config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs)

    try:
        if layer == "linear":
            model = torch.nn.Sequential(
                te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
            ).cuda()
            NUM_ITERS = 18000
        elif layer == "transformer":
            model = torch.nn.Sequential(
                te.TransformerLayer(1, 1, 1, name="transformer1"),
                te.TransformerLayer(1, 1, 1, name="transformer2"),
            ).cuda()
            NUM_ITERS = 2000

        x = torch.randn(1, 1, 1).cuda()

        y = model(x)
        y.sum().backward()
        debug_api.step()
        torch.cuda.synchronize()

        time_start = time.time()
        for i in range(NUM_ITERS):
            y = model(x)
            y.sum().backward()
            if debug_tools_initialized:
                debug_api.step()
        torch.cuda.synchronize()
        time_end = time.time()
    finally:
        if debug_tools_initialized:
            debug_api.end_debug()

    return time_end - time_start


@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_cpu_overhead(layer, configs_dir, feature_dirs):
    # runs one layer many times on very small tensor
    # - gpu time should be negligible, so time should be dominated by cpu time.
    # if layers does not invoke any feature in current iteration,
    # then it changed into non-debug mode and should not have any non-negligible cpu overhead
    # compared to layer without debug tools initialized.

    with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs)
    without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs)

    print(f"with_debug_tools: {with_debug_tools} s")
    print(f"without_debug_tools: {without_debug_tools} s")

    assert with_debug_tools < without_debug_tools * 1.25  # 25% overhead margin
View file @
87e3e56e
...
@@ -519,6 +519,7 @@ def _train(opts):
...
@@ -519,6 +519,7 @@ def _train(opts):
if
opts
.
use_cuda_graphs
:
if
opts
.
use_cuda_graphs
:
del
test_graph
del
test_graph
torch
.
cuda
.
synchronize
()
te
.
module
.
base
.
destroy_ub
()
te
.
module
.
base
.
destroy_ub
()
dist_print
(
"Destroying Userbuffers objects..."
,
debug
=
True
)
dist_print
(
"Destroying Userbuffers objects..."
,
debug
=
True
)
...
...
tests/pytorch/distributed/test_sanity.py
0 → 100644
View file @ 87e3e56e

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pathlib
import sys

import pytest
import torch

import transformer_engine
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch import TransformerLayer, Linear

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig

model_configs = {
    "small": ModelConfig(2, 10, 2, 16),
}


@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
def test_current_device(model, module):
    """Test cases where current device is different from tensor device"""
    num_devices = torch.cuda.device_count()
    assert num_devices > 1, "This test requires more than one GPU!"
    tensor_device = num_devices - 1
    dtype = torch.bfloat16
    config = model_configs[model]

    args = []
    kwargs = {}
    bwd_args = []
    if module == "TransformerLayer":
        model = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            params_dtype=dtype,
            attn_input_format="thd",
            self_attn_mask_type="padding",
            device=f"cuda:{tensor_device}",
        )
        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
        args = [
            torch.randn(
                (num_tokens, config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
        ]
        cu_seqlens_q, cu_seqlens_kv = [
            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
        ]
        kwargs["cu_seqlens_q"] = cu_seqlens_q
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
    if module == "DotProductAttention":
        model = DotProductAttention(
            config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
        )
        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
        args = [
            torch.randn(
                num_tokens,
                config.num_heads,
                config.head_dim_qk,
                dtype=dtype,
                device=tensor_device,
                requires_grad=True,
            )
            for _ in range(3)
        ]
        cu_seqlens_q, cu_seqlens_kv = [
            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
        ]
        kwargs["cu_seqlens_q"] = cu_seqlens_q
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
        bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
    elif module == "Linear":
        model = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            params_dtype=dtype,
            device=f"cuda:{tensor_device}",
        )
        args = [
            torch.randn(
                (config.max_seqlen_q, config.batch_size, config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
        ]

    current_device_before = torch.cuda.current_device()

    out = model(*args, **kwargs)
    if module == "DotProductAttention":
        out.backward(*bwd_args)
    else:
        loss = out.sum()
        loss.backward()

    current_device_after = torch.cuda.current_device()
    tensor_device_out = out.get_device()
    tensor_device_grad = args[0].grad.get_device()

    assert (
        current_device_after == current_device_before
    ), "The current device should not have changed!"
    assert (
        tensor_device_out == tensor_device
    ), "The output tensor should be the same as the input tensors!"
    assert (
        tensor_device_grad == tensor_device
    ), "The gradient tensor should be the same as the input tensors!"
tests/pytorch/test_cpu_offloading.py  View file @ 87e3e56e
...
@@ -10,22 +10,24 @@ import torch
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
+
+from utils import ModelConfig, get_available_attention_backends
 
 # Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
 
-fp8_recipes = [None]
-if fp8_available:
-    fp8_recipes.append(recipe.Float8CurrentScaling())
-    fp8_recipes.append(recipe.DelayedScaling())
+fp8_recipes = [
+    None,  # non-fp8
+    # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet
+    recipe.Float8CurrentScaling(),
+    recipe.DelayedScaling(),
+]
 
-SIZE = 512
-NUM_HEADS = 8
-NUM_LAYERS = 5
-EPSILON = 0.1
+model_config = {
+    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
+}
+SIZE = model_config["small"].hidden_size
+NUM_HEADS = model_config["small"].num_heads
+NUM_LAYERS = model_config["small"].num_layers
+EPSILON = model_config["small"].eps
 
 # Flash attention saves some internal tensor for the backward pass
 # that cannot be offloaded to CPU.
...
@@ -124,11 +126,17 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
     model_cls = model_types[model_key]
     models_list = [model_cls() for _ in range(NUM_LAYERS)]
 
-    if fp8_recipe and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None:
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
+    if model_key in ["multihead_attention", "transformer_layer"]:
+        available_backends, *_ = get_available_attention_backends(
+            model_config["small"],
+            qkv_dtype=torch.bfloat16,
+            qkv_layout="sbhd_sbhd_sbhd",
+        )
+        _, fused_attn_supported, _ = available_backends
+        if not fused_attn_supported:
+            pytest.skip("Fused attention backend not available.")
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True
 
     without_offloading = _measure_memory_between_forward_and_backward(
         models_list, fp8_recipe, False
...
tests/pytorch/test_cuda_graphs.py  View file @ 87e3e56e
...
@@ -2,9 +2,7 @@
 #
 # See LICENSE for license information.
-from dataclasses import dataclass
-from typing import Iterable, List, Union
-import itertools
+from typing import Iterable, List, Tuple, Union
 import pytest
 import torch
...
@@ -23,46 +21,32 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import is_bf16_compatible
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.common import recipe
+from utils import ModelConfig, reset_rng_states
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 if IS_HIP_EXTENSION:
     import os
     from functools import cache
 
 # Check if FP8 is supported.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
 
-# Record initial RNG state.
-seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
-
-@dataclass
-class ModelConfig:
-    """Data tensor dimensions within Transformer model"""
-
-    sequence_length: int
-    batch_size: int
-    hidden_size: int
-    num_heads: int
-    kv_channels: int
-
-model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
-
-fp8_recipes = [
-    recipe.DelayedScaling(),
-    recipe.MXFP8BlockScaling(),
-    recipe.Float8CurrentScaling(),
-    recipe.Float8BlockScaling(),
-]
+# Reset RNG states.
+reset_rng_states()
+
+model_configs = {
+    "small": ModelConfig(32, 2, 2, 32),
+}
+
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_block_scaling_available:
+    fp8_recipes.append(recipe.Float8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.Float8CurrentScaling())
+    fp8_recipes.append(recipe.DelayedScaling())
 
 # Supported data types
 dtypes: List[torch.dtype] = [torch.float32, torch.float16]
...
@@ -70,12 +54,6 @@ if is_bf16_compatible():  # bf16 requires sm_80 or higher
 dtypes.append(torch.bfloat16)
 
-def reset_rng_states() -> None:
-    """Revert to initial RNG state."""
-    torch.set_rng_state(_cpu_rng_state)
-    torch.cuda.set_rng_state(_cuda_rng_state)
-
 @pytest.fixture(autouse=True)
 def reset_global_fp8_state():
     yield
...
@@ -119,7 +97,7 @@ def generate_data(
     """Generate synthetic data."""
     gen_func = torch.ones if warmup else torch.randn
     return gen_func(
-        model_config.sequence_length,
+        model_config.max_seqlen_q,
         model_config.batch_size,
         model_config.hidden_size,
         device="cuda",
...
@@ -157,10 +135,12 @@ class _Sequential(torch.nn.Sequential):
 # Supported modules
 _test_cuda_graphs_modules: List[str] = [
+    # Put linear first to test the case where the cuda context might not be set in
+    # creating TMA descriptor for MXFP8 quantization.
+    "linear",
     "transformer",
     "layernorm_mlp",
     "layernorm_linear",
-    "linear",
     "mha",
     "linear_op",
 ]
...
@@ -310,35 +290,27 @@ def _test_cuda_graphs(
 @pytest.mark.parametrize("module", _test_cuda_graphs_modules)
 @pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("fp8", (False, True))
 @pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
 def test_make_graphed_callables(
     *,
     module: str,
     model_config: str = "small",
     num_layers: int = 3,
     dtype: torch.dtype,
-    fp8: bool,
     fp8_params: bool,
     fp8_recipe: recipe.Recipe,
     fp8_weight_caching: bool = False,
 ) -> None:
-    # Skip invalid configurations.
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
+    fp8 = fp8_recipe is not None
     if fp8_params and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
-    if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
-    if fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_recipe.float8_block_scaling() and module == "linear_op":
+    if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
         pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
 
     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
     kwargs = dict(
...
@@ -351,9 +323,11 @@ def test_make_graphed_callables(
         fp8_weight_caching=fp8_weight_caching,
         fp8_recipe=fp8_recipe,
     )
-    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
+    # Put graphed callables first to test the case where the cuda context might not be set in
+    # creating TMA descriptor for MXFP8 quantization.
     graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
     graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
+    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
 
     # Check that results match.
     assert_all_equal(outputs, graph_outputs_mode1)
...
@@ -369,7 +343,6 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
 ]
 
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.parametrize(
     "module",
     _test_make_graphed_callables_with_fp8_weight_caching_modules,
...
@@ -385,7 +358,6 @@ def test_make_graphed_callables_with_fp8_weight_caching(
     test_make_graphed_callables(
         module=module,
         dtype=torch.float32,
-        fp8=True,
         fp8_params=fp8_params,
         fp8_recipe=fp8_recipe,
         fp8_weight_caching=True,
...
@@ -401,7 +373,7 @@ def generate_data_for_dot_product_attention(
     gen_func = torch.ones if warmup else torch.randn
     return [
         gen_func(
-            model_config.sequence_length,
+            model_config.max_seqlen_q,
             model_config.batch_size,
             model_config.num_heads,
             model_config.kv_channels,
...
@@ -495,8 +467,8 @@ def _test_cuda_graphs_with_kwargs(
             (
                 model_config.batch_size,
                 1,
-                model_config.sequence_length,
-                model_config.sequence_length,
+                model_config.max_seqlen_q,
+                model_config.max_seqlen_kv,
             ),
             dtype=torch.bool,
             device="cuda",
...
@@ -522,8 +494,8 @@ def _test_cuda_graphs_with_kwargs(
             (
                 model_config.batch_size,
                 1,
-                model_config.sequence_length,
-                model_config.sequence_length,
+                model_config.max_seqlen_q,
+                model_config.max_seqlen_kv,
             ),
             dtype=torch.bool,
             device="cuda",
...
tests/pytorch/test_float8blockwisetensor.py  View file @ 87e3e56e
...
@@ -223,7 +223,7 @@ class TestFloat8BlockwiseTensor:
             rowwise=True,
             columnwise=dq_columnwise,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=True,
+            all_gather_usage=(block_scaling_dim == 1),
         )
         self._test_quantize_dequantize(
             quantizer=quantizer,
...
tests/pytorch/test_fused_optimizer.py  View file @ 87e3e56e
...
@@ -2,7 +2,6 @@
 #
 # See LICENSE for license information.
-from itertools import product
 import copy
 from contextlib import nullcontext
...
@@ -112,13 +111,6 @@ class TestFusedAdam(TestFusedOptimizer):
     def test_bfloat16(self):
         self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True)
 
-    @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
-    def test_multi_device(self):
-        devices = ("cuda:0", "cuda:1")
-        for current_dev, tensor_dev in product(devices, devices):
-            with torch.cuda.device(current_dev):
-                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
-
     def test_multi_params(self):
         sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
...
@@ -530,13 +522,6 @@ class TestFusedSGD(TestFusedOptimizer):
     def test_half(self):
         self.gen_single_type_test(param_type=torch.float16)
 
-    @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
-    def test_multi_device(self):
-        devices = ("cuda:0", "cuda:1")
-        for current_dev, tensor_dev in product(devices, devices):
-            with torch.cuda.device(current_dev):
-                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
-
 class Model(torch.nn.Module):
     def __init__(self):
...
tests/pytorch/test_fused_router.py  View file @ 87e3e56e
...
@@ -2,8 +2,7 @@
 #
 # See LICENSE for license information.
 import torch
-import math
-from typing import Optional
+from typing import Optional, Dict
 from transformer_engine.pytorch.router import (
     fused_topk_with_score_function,
     fused_compute_score_for_moe_aux_loss,
...
@@ -149,11 +148,21 @@ def run_comparison(
     # Set some parameters
     if score_function == "sigmoid":
         # Construct the special logits to avoid inf in the sigmoid function
-        offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
-        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
+        offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
+        logits = (
+            torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
+        )
         logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
     else:
-        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
+        logits = (
+            torch.arange(
+                -num_tokens * num_experts // 2,
+                num_tokens * num_experts // 2,
+                device="cuda",
+                dtype=dtype,
+            )
+            * 1e-4
+        )
         logits = logits.view(num_tokens, num_experts)
     logits.requires_grad = True
     if enable_bias and score_function == "sigmoid":
...
@@ -282,11 +291,21 @@ def test_topk_softmax(
 def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
     if score_function == "sigmoid":
         # Construct the special logits to avoid inf in the sigmoid function
-        offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
-        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
+        offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
+        logits = (
+            torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
+        )
        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
     else:
-        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
+        logits = (
+            torch.arange(
+                -num_tokens * num_experts // 2,
+                num_tokens * num_experts // 2,
+                device="cuda",
+                dtype=dtype,
+            )
+            * 1e-4
+        )
         logits = logits.view(num_tokens, num_experts)
     logits.requires_grad = True
...
@@ -322,8 +341,8 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
 @pytest.mark.parametrize("topk", [4])
 def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
     # Construct the special probs to avoid inf in the sigmoid function
-    offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
-    probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
+    offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
+    probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
     probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
     probs = probs.view(num_tokens, num_experts)
     probs.requires_grad = True
...
@@ -380,15 +399,12 @@ def profile_topk_softmax(
 if __name__ == "__main__":
-    test_fused_scores_for_aux_loss(
-        dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax"
-    )
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
+    test_topk_softmax(
+        dtype=torch.float32,
+        num_tokens=1024,
+        num_experts=128,
+        topk=4,
+        use_pre_softmax=False,
+        group_topk=None,
+        scaling_factor=None,
+    )
tests/pytorch/test_fusible_ops.py  View file @ 87e3e56e
...
@@ -21,10 +21,12 @@ import transformer_engine.pytorch as te
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch.ops.fused import (
-    BackwardBiasActivation,
+    BackwardActivationBias,
     BackwardLinearAdd,
+    BackwardLinearScale,
     ForwardLinearBiasActivation,
     ForwardLinearBiasAdd,
+    ForwardLinearScaleAdd,
 )
 from transformer_engine.pytorch.tensor import QuantizedTensor
 from transformer_engine.pytorch.tensor.float8_tensor import (
...
@@ -39,7 +41,7 @@ import transformer_engine_torch as tex
 # Import utility functions
 _current_file = pathlib.Path(__file__).resolve()
 sys.path.append(str(_current_file.parent))
-from utils import dtype_tols, make_recipe
+from utils import dtype_tols, make_recipe, reset_rng_states
 if IS_HIP_EXTENSION:
     import os
...
@@ -271,16 +273,72 @@ class TestSequentialContainer:
         model(torch.zeros(1))
         assert len(model._module_groups) == 6
 
+    def test_extra_tensors(self, size: int = 16) -> None:
+        """Check that extra inputs are distributed properly between module groups
+        and that extra outputs are properly collected"""
+
+        # Construct sequential container
+        bias = te_ops.Bias(size=size, device="cpu")
+        with torch.no_grad():
+            bias.bias.copy_(torch.rand((size,)))
+        model = te_ops.Sequential(
+            #                                         | Inputs  | Outputs
+            torch.nn.Identity(),                    # | x1      | x1
+            te_ops.MakeExtraOutput(in_place=True),  # | x1      | x1 [x1]
+            bias,                                   # | x1      | h1 (= x1 + b)
+            te_ops.MakeExtraOutput(in_place=True),  # | h1      | h1 [h1]
+            te_ops.AddExtraInput(in_place=True),    # | h1 [x2] | x2 (= x2 + h1)
+            te_ops.MakeExtraOutput(in_place=True),  # | x2      | x2 [x2]
+            torch.nn.Identity(),                    # | x2      | x2
+            bias,                                   # | x2      | h2 (= x2 + b)
+            te_ops.AddExtraInput(in_place=True),    # | h2 [x3] | x3 (= x3 + h2)
+            te_ops.MakeExtraOutput(in_place=True),  # | x3      | x3 [x3]
+            te_ops.AddExtraInput(in_place=True),    # | x3 [x4] | x4 (= x4 + x3)
+            torch.nn.Identity(),                    # | x4      | x4
+            te_ops.Identity(),                      # | x4      | x4
+            te_ops.MakeExtraOutput(in_place=True),  # | x4      | x4 [x4]
+            te_ops.Identity(),                      # | x4      | x4
+        )
+
+        # Create input tensors
+        x1 = torch.rand((size,))
+        x2 = torch.rand((size,))
+        x3 = torch.rand((size,))
+        x4 = torch.rand((size,))
+
+        # Save original input tensor values
+        x1_orig = x1.clone()
+        x2_orig = x2.clone()
+        x3_orig = x3.clone()
+        x4_orig = x4.clone()
+
+        # Run forward
+        ys = model(x1, x2, x3, x4)
+
+        # Check whether outputs match (x4, x1, h1, x2, x3, x4)
+        assert len(ys) == 6
+        assert ys[0].data_ptr() == x4.data_ptr()
+        assert ys[1].data_ptr() == x1.data_ptr()
+        assert ys[2].data_ptr() not in [x.data_ptr() for x in (x1, x2, x3, x4)]
+        assert ys[3].data_ptr() == x2.data_ptr()
+        assert ys[4].data_ptr() == x3.data_ptr()
+        assert ys[5].data_ptr() == x4.data_ptr()
+
+        # Check whether tensors have correct values
+        b = bias.bias
+        h1 = ys[2]
+        torch.testing.assert_close(x1, x1_orig)
+        torch.testing.assert_close(h1, x1_orig + b)
+        torch.testing.assert_close(x2, x2_orig + h1)
+        torch.testing.assert_close(x3, x3_orig + x2 + b)
+        torch.testing.assert_close(x4, x4_orig + x3)
+
 class TestFuser:
     """Tests for operation fusion infrastructure"""
 
     @staticmethod
     def setup_class(cls) -> None:
         # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()
 
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_scale_update(
...
@@ -494,10 +552,7 @@ class TestBasicOps:
     @staticmethod
     def setup_class(cls) -> None:
         # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()
 
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
...
@@ -795,10 +850,9 @@ class TestBasicOps:
             pytest.skip("FP8 output is only supported with FP8 GEMMs")
         if quantized_grad_input and not quantized_compute:
             pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
-        if quantization == "mxfp8" and quantized_output:
-            pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
-        if quantization == "mxfp8" and quantized_grad_input:
-            pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
+        if quantization not in (None, "fp8"):
+            if quantized_output or quantized_grad_input:
+                pytest.skip("Recipe does not support quantized GEMM output")
         if (IS_HIP_EXTENSION and not use_hipblaslt() and
                 accumulate_into_main_grad and dtype != torch.float32 and not quantized_compute):
             pytest.skip("Parameters combination is not supported by ROCBLAS")
...
@@ -1353,18 +1407,17 @@ class TestBasicOps:
         dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(y_test, y_ref, **tols)
-        # L2Norm backward pass requires slightly looser atol for bfloat16
-        if dtype == torch.bfloat16:
-            tols["atol"] = 2e-3
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
 
+    @pytest.mark.parametrize("in_place", (True, False))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
     @pytest.mark.parametrize("quantization", _quantization_list)
-    def test_add_in_place(
+    def test_add_extra_input(
         self,
         *,
         in_shape: Iterable[int] = (32, 32),
+        in_place: bool,
         dtype: torch.dtype,
         device: torch.device,
         quantization: Optional[str],
...
@@ -1410,7 +1463,7 @@ class TestBasicOps:
         dx2_ref = dy_ref
 
         # Implementation with fusible operation
-        op = te_ops.AddInPlace()
+        op = te_ops.AddExtraInput(in_place=in_place)
         y_test = op(x1_test, x2_test)
         y_test.backward(dy_test)
...
@@ -1425,6 +1478,7 @@ class TestBasicOps:
         torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
         torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)
 
+    @pytest.mark.parametrize("in_place", (True, False))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
     @pytest.mark.parametrize("quantization", _quantization_list)
...
@@ -1432,6 +1486,7 @@ class TestBasicOps:
         self,
         *,
         in_shape: Iterable[int] = (32, 32),
+        in_place: bool,
         dtype: torch.dtype,
         device: torch.device,
         quantization: Optional[str],
...
@@ -1477,7 +1532,7 @@ class TestBasicOps:
         (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
 
         # Implementation with fusible operation
-        op = te_ops.MakeExtraOutput()
+        op = te_ops.MakeExtraOutput(in_place=in_place)
         y1_test, y2_test = op(x_test)
         (y1_test * dy1_test + y2_test * dy2_test).sum().backward()
...
@@ -1645,16 +1700,107 @@ class TestBasicOps:
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
 
+    @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
+    @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("device", _devices)
+    def test_constant_scale(
+        self,
+        *,
+        scale: float,
+        shape: Iterable[int],
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        y_ref = scale * x_ref
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operation
+        op = te_ops.ConstantScale(scale)
+        y_test = op(x_test)
+        y_test.backward(dy_test)
+
+        # Check results
+        tols = dtype_tols(dtype)
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+
+    @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
+    @pytest.mark.parametrize("is_training", (True, False))
+    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    def test_dropout(
+        self,
+        *,
+        prob: float,
+        is_training: bool,
+        shape: Iterable[int],
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+    ):
+
+        # Random data
+        x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
+        x_test = x_ref.clone().requires_grad_()
+        dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
+        dy_test = dy_ref.clone()
+
+        # Apply dropout
+        op = te_ops.Dropout(prob)
+        if is_training:
+            op.train()
+        else:
+            op.eval()
+        y = op(x_test)
+        y.backward(dy_test)
+
+        # Check values
+        if is_training:
+            mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
+            torch.testing.assert_close(y, x_ref * mask)
+            torch.testing.assert_close(x_test.grad, dy_ref * mask)
+        else:
+            torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
+            torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
+
+        # Hypothesis testing for number of zeros
+        # Note: A Bernoulli random variable with probability p has
+        # mean p and standard deviation sqrt(p*(1-p)). By the central
+        # limit theorem, the mean of n iid Bernoulli variables
+        # converges to a normal random variable with mean p and
+        # standard deviation sqrt(p*(1-p)/n). If the observed mean is
+        # below the 0.5th or above the 99.5th percentiles, then the
+        # p-value is less than 1% and we assume that the dropout
+        # distribution is incorrect.
+        if is_training:
+            prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
+            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
+            assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
 class TestFusedOps:
     """Tests for fused operations"""
 
     @staticmethod
     def setup_class(cls) -> None:
         # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()
 
     @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
     @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
...
@@ -1841,7 +1987,7 @@ class TestFusedOps:
                 device=device,
                 dtype=dtype,
             ),
-            te_ops.AddInPlace(),
+            te_ops.AddExtraInput(in_place=True),
         )
         with torch.no_grad():
             model[0].weight.copy_(w_test)
...
@@ -1878,11 +2024,114 @@ class TestFusedOps:
         db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(db_test, b_ref.grad, **tols)
 
+    @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_forward_linear_scale_add(
+        self,
+        *,
+        scale: float,
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Iterable[int] = (32, -1),
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantized_weight: bool = False,
+    ) -> None:
+        """Forward GEMM + scale + add"""
+
+        # Make input and weight shapes consistent
+        out_features, in_features = weight_shape
+        in_shape = list(in_shape)[:-1] + [in_features]
+        out_shape = in_shape[:-1] + [out_features]
+
+        # Skip invalid configurations
+        quantized_compute = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=out_shape)
+        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
+            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
+
+        # Random data
+        x1_ref, x1_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w_ref, w_test = make_reference_and_test_tensors(
+            (out_features, in_features),
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        x2_ref, x2_test = make_reference_and_test_tensors(
+            out_shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            out_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        y_ref = torch.nn.functional.linear(x1_ref, w_ref) * scale + x2_ref
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operations
+        recipe = make_recipe(quantization)
+        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+            model = te_ops.Sequential(
+                te_ops.Linear(
+                    in_features,
+                    out_features,
+                    bias=False,
+                    device=device,
+                    dtype=dtype,
+                ),
+                te_ops.ConstantScale(scale),
+                te_ops.AddExtraInput(in_place=True),
+                te_ops.Quantize(),
+            )
+        with torch.no_grad():
+            model[0].weight.copy_(w_test)
+            del w_test
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = model(x1_test, x2_test)
+        y_test.backward(dy_test)
+
+        # Check that forward operations have been fused
+        forward_ops = model._module_groups[0]._forward_ops
+        assert len(forward_ops) == 2
+        assert isinstance(forward_ops[0][0], ForwardLinearScaleAdd)
+        assert isinstance(forward_ops[1][0], te_ops.Quantize)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if dtype == torch.float32:
+            tols = dtype_tols(torch.float16)  # TF32 GEMM
+        if quantized_compute:
+            tols = dtype_tols(tex.DType.kFloat8E4M3)
+
+        # Check results
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
+        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
+        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
+        torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
+        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
+
     @pytest.mark.parametrize("activation", ("relu", "gelu"))
     @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
-    def test_backward_bias_activation(
+    def test_backward_activation_bias(
         self,
         *,
         activation: str,
...
@@ -1891,7 +2140,7 @@ class TestFusedOps:
         device: torch.device = "cuda",
         quantization: Optional[str],
     ) -> None:
-        """Backward dbias + dact + quantize"""
+        """Backward dact + dbias + quantize"""
 
         # Tensor dimensions
         in_shape = list(out_shape)
...
@@ -1948,9 +2197,9 @@ class TestFusedOps:
         # Check that backward operations have been fused
         backward_ops = model._module_groups[0]._backward_ops
-        if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
+        if with_quantization:
             assert len(backward_ops) == 2
-            assert isinstance(backward_ops[0][0], BackwardBiasActivation)
+            assert isinstance(backward_ops[0][0], BackwardActivationBias)
             assert isinstance(backward_ops[1][0], te_ops.Quantize)
         else:
             assert len(backward_ops) == 3
...
@@ -1963,6 +2212,7 @@ class TestFusedOps:
         if with_quantization:
             tols = dtype_tols(tex.DType.kFloat8E4M3)
 
+        # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
         dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
         db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
...
@@ -2033,7 +2283,7 @@ class TestFusedOps:
         recipe = make_recipe(quantization)
         with te.fp8_model_init(enabled=quantized_weight):
             model = te_ops.Sequential(
-                te_ops.MakeExtraOutput(),
+                te_ops.MakeExtraOutput(in_place=True),
                 te_ops.Linear(
                     in_features,
                     out_features,
...
@@ -2071,16 +2321,106 @@ class TestFusedOps:
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
         torch.testing.assert_close(dw_test, w_ref.grad, **tols)
 
+    @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_backward_linear_scale(
+        self,
+        *,
+        scale: float,
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Iterable[int] = (32, -1),
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantized_weight: bool = False,
+    ) -> None:
+        """Backward dgrad GEMM + scale"""
+
+        # Make input and weight shapes consistent
+        out_features, in_features = weight_shape
+        in_shape = list(in_shape)[:-1] + [in_features]
+        out_shape = in_shape[:-1] + [out_features]
+
+        # Skip invalid configurations
+        quantized_compute = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=out_shape)
+        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
+            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w_ref, w_test = make_reference_and_test_tensors(
+            (out_features, in_features),
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            out_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        y_ref = torch.nn.functional.linear(x_ref, w_ref) * scale
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operations
+        recipe = make_recipe(quantization)
+        with te.fp8_model_init(enabled=quantized_weight):
+            model = te_ops.Sequential(
+                te_ops.Linear(
+                    in_features,
+                    out_features,
+                    bias=False,
+                    device=device,
+                    dtype=dtype,
+                ),
+                te_ops.ConstantScale(scale),
+            )
+        with torch.no_grad():
+            model[0].weight.copy_(w_test)
+            del w_test
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = model(x_test)
+        (y_test * dy_test).sum().backward()
+
+        # Check that backward operations have been fused
+        backward_ops = model._module_groups[0]._backward_ops
+        assert len(backward_ops) == 1
+        assert isinstance(backward_ops[0][0], BackwardLinearScale)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if dtype == torch.float32:
+            tols = dtype_tols(torch.float16)  # TF32 GEMM
+        if quantized_compute:
+            tols = dtype_tols(tex.DType.kFloat8E4M3)
+
+        # Check results
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
+
 class TestCheckpointing:
     """Tests for checkpointing"""
 
     @staticmethod
     def setup_class(cls) -> None:
         # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()
 
     @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("quantized_weight", (False, True))
...
@@ -2192,11 +2532,9 @@ class TestSequentialModules:
     @staticmethod
     def setup_class(cls) -> None:
         # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()
 
+    @pytest.mark.parametrize("requires_grad", (False, True))
     @pytest.mark.parametrize("bias", (False, True))
     @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
     @pytest.mark.parametrize("quantized_compute", (False, True))
...
@@ -2206,6 +2544,7 @@ class TestSequentialModules:
     def test_layernorm_mlp(
         self,
         *,
+        requires_grad: bool,
         bias: bool,
         normalization: str,
         quantized_compute: bool,
...
@@ -2246,6 +2585,7 @@ class TestSequentialModules:
             quantization=quantization,
             test_dtype=dtype,
             test_device=device,
+            requires_grad=requires_grad,
         )
         _, dy_test = make_reference_and_test_tensors(
             in_shape,
...
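The dropout hypothesis test added to test_fusible_ops.py above (the z-score check on the number of dropped elements) can be exercised in isolation; a small sketch with made-up numbers, not part of the test suite:

import math

prob = 0.25               # dropout probability
n = 1_000_000             # hypothetical number of observed elements
observed_zeros = 250_500  # hypothetical count of dropped elements

prob_observed = observed_zeros / n
z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / n)
# 2.5758 is the 99.5th percentile of the standard normal distribution, so
# |z| < 2.5758 keeps the two-sided p-value above 1%.
assert abs(z_score) < 2.5758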
tests/pytorch/test_hf_integration.py  View file @ 87e3e56e
...
@@ -7,7 +7,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 from transformer_engine.pytorch.transformer import TransformerLayer
-from transformer_engine.pytorch.utils import is_bf16_compatible
 
 class SimpleTEModel(PreTrainedModel):
...