OpenDAS / TransformerEngine / Commits / 87e3e56e

Commit 87e3e56e, authored Aug 27, 2025 by yuguo
Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
Parents: 2f11bd2e, 734bcedd
Changes: 217
Showing 20 changed files with 989 additions and 1104 deletions (+989 / -1104):

  tests/pytorch/test_numerics.py                                             +107  -310
  tests/pytorch/test_onnx_export.py                                          +186  -199
  tests/pytorch/test_parallel_cross_entropy.py                                 +0    -1
  tests/pytorch/test_qk_norm.py                                              +182   -35
  tests/pytorch/test_recipe.py                                                +35   -37
  tests/pytorch/test_sanity.py                                                +94  -461
  tests/pytorch/utils.py                                                     +187    -0
  transformer_engine/common/CMakeLists.txt                                     +7    -0
  transformer_engine/common/__init__.py                                       +15    -3
  transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp            +35    -3
  transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu       +24    -0
  transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h         +8    -0
  transformer_engine/common/common.cu                                           +5    -2
  transformer_engine/common/common.h                                            +3    -2
  transformer_engine/common/fused_attn/fused_attn.cpp                          +14   -10
  transformer_engine/common/fused_router/fused_moe_aux_loss.cu                  +2    -2
  transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu        +6    -3
  transformer_engine/common/fused_router/fused_topk_with_score_function.cu      +3    -3
  transformer_engine/common/fused_router/utils.h                               +48   -21
  transformer_engine/common/gemm/cublaslt_gemm.cu                              +28   -12
tests/pytorch/test_numerics.py

...
@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional
...
@@ -39,54 +38,39 @@ from transformer_engine.pytorch import (
    Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
+from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
+from utils import ModelConfig, reset_rng_states, get_available_attention_backends

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
+mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
sm_80plus = get_device_compute_capability() >= (8, 0)

-seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-# Record initial RNG state from script run.
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
+# Reset RNG states.
+reset_rng_states()

if torch_version() >= (2, 7, 0):
    torch._dynamo.config.recompile_limit = 16
else:
    torch._dynamo.config.cache_size_limit = 16

-class ModelConfig:
-    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
-        self.hidden_size = hidden_size
-        self.eps = eps
-        self.num_attention_heads = num_attention_heads
-        self.embed = embed
-        self.num_layers = num_layers
-        self.seq_len = seq_len

model_configs = {
-    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
-    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
+    "small": ModelConfig(1, 128, 8, 16, num_layers=4),
+    "126m": ModelConfig(1, 2048, 12, 64, num_layers=12),
}

model_configs_inference = {
-    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
-    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
+    "126m": ModelConfig(1, 256, 12, 64, num_layers=12),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
...
@@ -120,12 +104,27 @@ if NVTE_TEST_NVINSPECT_ENABLED:
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

-fp8_recipes = [
-    recipe.MXFP8BlockScaling(),
-    recipe.DelayedScaling(),
-    recipe.Float8CurrentScaling(),
-    recipe.Float8BlockScaling(),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_block_scaling_available:
+    fp8_recipes.append(recipe.Float8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.Float8CurrentScaling())
+    fp8_recipes.append(recipe.DelayedScaling())

+def is_fused_attn_available(
+    config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
+):
+    _, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtype,
+        qkv_layout=qkv_layout,
+        is_training=is_training,
+    )
+    return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends

def get_causal_attn_mask(sq: int) -> torch.Tensor:
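Note on the recipe-list change above: building fp8_recipes from the availability flags means every recipe handed to pytest.mark.parametrize is already supported on the current device, which is why the per-recipe skip branches disappear from the tests further down. A minimal, self-contained sketch of the same pattern (the _supports_* helpers below are illustrative stand-ins for the FP8GlobalStateManager queries, not TransformerEngine APIs):

import pytest

def _supports_mxfp8() -> bool:   # illustrative stand-in for a device-capability query
    return False

def _supports_fp8() -> bool:     # illustrative stand-in
    return True

recipes = []
if _supports_mxfp8():
    recipes.append("mxfp8_block_scaling")
if _supports_fp8():
    recipes.append("delayed_scaling")

@pytest.mark.parametrize("recipe_name", recipes)
def test_with_supported_recipes(recipe_name):
    # Every value reaching this point is usable on the current device,
    # so no per-recipe skip logic is needed in the test body.
    assert recipe_name in recipes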
...
...
@@ -177,12 +176,6 @@ def assert_allclose(
        raise AssertionError(msg)

-def reset_rng_states() -> None:
-    """revert back to initial RNG state."""
-    torch.set_rng_state(_cpu_rng_state)
-    torch.cuda.set_rng_state(_cuda_rng_state)

@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
...
@@ -535,13 +528,13 @@ def _test_e2e_selective_recompute(
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
...
@@ -550,13 +543,13 @@ def _test_e2e_selective_recompute(
    )
    te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        te_out = block(
...
@@ -582,14 +575,8 @@ def _test_e2e_selective_recompute(
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
...
@@ -630,13 +617,13 @@ def _test_e2e_full_recompute(
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
...
@@ -645,14 +632,14 @@ def _test_e2e_full_recompute(
    )
    te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=use_reentrant,
    )
    if use_reentrant:
        te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if recompute:
...
@@ -698,14 +685,8 @@ def _test_e2e_full_recompute(
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant):
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
...
...
@@ -761,13 +742,13 @@ def _test_e2e_checkpointing_get_model(config, dtype):
    return TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
...
@@ -779,7 +760,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
    reset_rng_states()

    te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
...
@@ -809,14 +790,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
            if p.requires_grad:
                param_grads.append(p.grad.clone())

-        global _cpu_rng_state, _cuda_rng_state
        _cpu_rng_state = torch.get_rng_state()
        _cuda_rng_state = torch.cuda.get_rng_state()

        del block
        block = _test_e2e_checkpointing_get_model(config, dtype)
        block.load_state_dict(torch.load(path, weights_only=False))
-        reset_rng_states()
+        torch.set_rng_state(_cpu_rng_state)
+        torch.cuda.set_rng_state(_cuda_rng_state)

        for p in block.parameters():
            if p.requires_grad:
...
@@ -849,6 +830,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
    config = model_configs[model]
+    if not is_fused_attn_available(config, dtype):
+        pytest.skip("No attention backend available.")
    outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
    outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
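The checkpointing test now captures and restores the RNG state locally instead of rewinding module-level globals via reset_rng_states(). A minimal sketch of that capture-and-replay pattern in plain PyTorch (CPU-only here so the sketch stays self-contained):

import torch

cpu_rng_state = torch.get_rng_state()   # capture the CPU RNG state
first_draw = torch.randn(4)             # consumes the RNG stream

torch.set_rng_state(cpu_rng_state)      # rewind to the captured state
second_draw = torch.randn(4)            # replays the identical stream

assert torch.equal(first_draw, second_draw)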
...
...
@@ -869,13 +852,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
    reset_rng_states()

    inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
-    inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

    out = block(inp_hidden_states, attention_mask=inp_attn_mask)
    loss = out.sum()
...
@@ -895,11 +878,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    config = model_configs[model]
+    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+        pytest.skip("No attention backend available.")

    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
-        num_attention_heads=config.num_attention_heads,
+        num_attention_heads=config.num_heads,
        layernorm_epsilon=config.eps,
        attention_dropout=0.1,
        hidden_dropout=0.1,
...
@@ -914,7 +899,7 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
        TorchGPT(
            config.hidden_size,
            config.eps,
-            config.num_attention_heads,
+            config.num_heads,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
...
@@ -975,13 +960,13 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
    reset_rng_states()

    inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
-    inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
+    inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None

    forward_kwargs = {}
    if te:
...
@@ -1006,10 +991,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
    config = model_configs[model]
+    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+        pytest.skip("No attention backend available.")

    te_mha = MultiheadAttention(
        config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        fuse_qkv_params=True,
        params_dtype=dtype,
        qkv_weight_interleaved=False,
...
@@ -1020,7 +1007,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
    torch_mha = (
        TorchMHA(
            config.hidden_size,
-            config.num_attention_heads,
+            config.num_heads,
        )
        .to(dtype=dtype)
        .cuda()
...
@@ -1066,7 +1053,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
    FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
...
@@ -1098,11 +1085,12 @@ def _test_dpa_accuracy(block, bs, dtype, config):
    reset_rng_states()

    mask = torch.triu(
-        torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
+        torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"),
+        diagonal=1,
    )
    query, key, value = [
        torch.randn(
-            (config.seq_len, bs, config.num_attention_heads, config.embed),
+            (config.max_seqlen_q, bs, config.num_heads, config.kv_channels),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
...
@@ -1131,8 +1119,8 @@ def test_dpa_accuracy(dtype, bs, model):
    te_dpa = (
        DotProductAttention(
-            config.num_attention_heads,
-            config.embed,
+            config.num_heads,
+            config.kv_channels,
            attention_dropout=0.0,  # disable dropout, FU uses rng differently
        )
        .to(dtype=dtype)
...
@@ -1141,7 +1129,7 @@ def test_dpa_accuracy(dtype, bs, model):
    torch_dpa = (
        TorchDotProductAttention(
-            config.embed,
+            config.kv_channels,
            0.0,  # dropout
        )
        .to(dtype=dtype)
...
@@ -1267,8 +1255,8 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
        te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
    )

-    # Shoule be bit-wise match
-    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
+    # Should be bit-wise match
+    for _, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
...
...
@@ -1280,17 +1268,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
    fuse_wgrad_accumulation = True
    fp8_model_params = False
    fp8 = recipe is not None
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
    config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
...
@@ -1653,14 +1636,12 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
@pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("bs", [2])
@pytest.mark.parametrize("model", ["small"])
-@pytest.mark.parametrize("activation", all_activations)
-@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
-    dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation
+    dtype, bs, model, bias, fuse_wgrad_accumulation
):
    config = model_configs[model]
...
@@ -1669,7 +1650,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
-        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=True,
...
@@ -1681,7 +1661,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
-        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=False,
...
@@ -1691,8 +1670,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
    # Share params
    with torch.no_grad():
        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
-        if normalization != "RMSNorm":
-            ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
        if bias:
...
@@ -1730,7 +1708,7 @@ def _test_grouped_linear_accuracy(
    FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
...
@@ -1743,14 +1721,14 @@ def _test_grouped_linear_accuracy(
        split_size = 16
        if recipe.mxfp8():
            split_size = 128
-        m = config.seq_len // split_size
+        m = config.max_seqlen_q // split_size
        dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
        dist.append(dist[-1])  # Manually add a zero
        m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
        m_splits = m_splits * split_size
-        assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
+        assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms
    else:
-        m_splits = torch.tensor([config.seq_len])
+        m_splits = torch.tensor([config.max_seqlen_q])

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, GroupedLinear):
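The m_splits construction above is easiest to follow with concrete numbers: sorted random cut points over m = max_seqlen_q // split_size blocks are differenced into per-GEMM block counts (the duplicated last cut point yields one empty split, matching the "Manually add a zero" comment) and scaled back by split_size, so the splits always sum to the full sequence length. A small self-contained sketch with assumed sizes:

import torch

max_seqlen_q, split_size, num_gemms = 2048, 16, 4
m = max_seqlen_q // split_size  # number of split_size-sized blocks

torch.manual_seed(0)
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1])  # duplicate the last cut point -> one split of size zero
m_splits = (torch.tensor(dist + [m]) - torch.tensor([0] + dist)) * split_size

assert m_splits.sum() == max_seqlen_q and len(m_splits) == num_gemms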
...
...
@@ -1806,17 +1784,11 @@ def test_grouped_linear_accuracy(
    parallel_mode=None,
):
    fp8 = recipe is not None
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
...
@@ -1908,19 +1880,13 @@ def test_grouped_linear_accuracy_save_original_input(
    parallel_mode=None,
):
    fp8 = recipe is not None
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
    config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
...
@@ -2074,14 +2040,14 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
    FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
-        (config.seq_len * bs, config.hidden_size),
+        (config.max_seqlen_q * bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
-    m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)
+    m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, TorchGroupedLinearWithPadding):
...
@@ -2124,17 +2090,11 @@ def test_padding_grouped_linear_accuracy(
    fp8_model_params,
    parallel_mode=None,
):
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
...
@@ -2199,19 +2159,13 @@ def test_padding_grouped_linear_accuracy_save_original_input(
    fp8_model_params,
    parallel_mode=None,
):
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
    config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
...
...
@@ -2268,9 +2222,11 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    # Placeholders used for graph capture.
    static_input = torch.randn(
-        config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
+        config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
    )
-    static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)
+    static_target = torch.randn(
+        config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype
+    )

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)
...
@@ -2334,7 +2290,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
    block_args = (
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
    )
    block_kwargs = dict(
        layernorm_epsilon=config.eps,
...
@@ -2342,7 +2298,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
...
@@ -2377,13 +2333,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
...
@@ -2392,13 +2348,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    )
    te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

    with fp8_autocast(enabled=True, fp8_recipe=recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
...
@@ -2418,14 +2374,8 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
...
@@ -2461,13 +2411,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
    block_sbhd = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0,
        attention_dropout=0,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
...
@@ -2482,13 +2432,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
    block_bshd = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0,
        attention_dropout=0,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
...
@@ -2500,13 +2450,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
    block_thd = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0,
        attention_dropout=0,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
...
@@ -2521,15 +2471,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
        assert torch.all(
            torch.eq(p1, p2) & torch.eq(p1, p3)
        ), f"{n1}, {n2} and {n3} not identical"

    x_sbhd = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    x_bshd = x_sbhd.transpose(0, 1).contiguous()
-    x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
-    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len
+    x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous()
+    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q

    # To make sure forward is also identical (just in case some module decides
    # to act fancy)
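For the layout-equivalence check, every sequence in the batch has the same length, so the thd cumulative-sequence-length tensor is just an arithmetic progression; with bs = 2 and max_seqlen_q = 4 it is [0, 4, 8]. A tiny sketch of the sbhd -> bshd -> thd reshaping (the sizes below are illustrative):

import torch

bs, max_seqlen_q, hidden_size = 2, 4, 8
x_sbhd = torch.randn(max_seqlen_q, bs, hidden_size)        # (s, b, h)
x_bshd = x_sbhd.transpose(0, 1).contiguous()               # (b, s, h)
x_thd = x_bshd.reshape(bs * max_seqlen_q, hidden_size)     # (t, h): sequences packed back to back
cu_seqlens = torch.arange(bs + 1, dtype=torch.int32) * max_seqlen_q
print(cu_seqlens.tolist())  # [0, 4, 8]: start offset of each sequence inside x_thd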
...
...
@@ -2556,167 +2506,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
        x_thd,
        cu_seqlens_q=x_thd_cumsum,
        cu_seqlens_kv=x_thd_cumsum,
-        max_seqlen_q=config.seq_len,
-        max_seqlen_kv=config.seq_len,
+        max_seqlen_q=config.max_seqlen_q,
+        max_seqlen_kv=config.max_seqlen_kv,
    )

    torch.testing.assert_close(
        y_bshd,
-        y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+        y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
    )

-@pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model_key", model_configs_inference.keys())
-@pytest.mark.parametrize("use_RoPE", all_boolean)
-@pytest.mark.parametrize("input_format", input_formats_inference)
-@pytest.mark.parametrize("module", module_inference)
-@pytest.mark.parametrize("backend", backends_inference)
-@pytest.mark.parametrize("is_paged", [False, True])
-def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
-    reset_rng_states()
-    if backend in ["FusedAttention"]:
-        pytest.skip("Not support FusedAttention")
-    if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
-        pytest.skip("FusedAttention and FlashAttention do not support FP32")
-    if use_RoPE:
-        pytest.skip("KV cache does not support starting positions for RoPE")
-    if (
-        backend == "FusedAttention"
-        and get_device_compute_capability() == (8, 9)
-        and get_cudnn_version() < (9, 12, 0)
-    ):
-        pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12")
-
-    os.environ["NVTE_FLASH_ATTN"] = "0"
-    os.environ["NVTE_FUSED_ATTN"] = "0"
-    os.environ["NVTE_UNFUSED_ATTN"] = "0"
-    if backend == "FlashAttention":
-        os.environ["NVTE_FLASH_ATTN"] = "1"
-    elif backend == "FusedAttention":
-        os.environ["NVTE_FUSED_ATTN"] = "1"
-    elif backend == "UnfusedAttention":
-        os.environ["NVTE_UNFUSED_ATTN"] = "1"
-
-    config = model_configs_inference[model_key]
-
-    S = config.seq_len
-    B = bs
-    H = config.num_attention_heads
-    D = config.hidden_size
-    head_size = config.embed
-    layer_number = 1
-    # Limits the max size of KV-cache
-    B_max = B
-    S_max = S
-
-    if module == "TransformerLayer":
-        model = TransformerLayer(
-            hidden_size=D,
-            ffn_hidden_size=4 * D,
-            num_attention_heads=H,
-            attn_input_format=input_format,
-            self_attn_mask_type="causal",
-            enc_dec_attn_mask_type="causal",
-            layer_number=layer_number,
-            attention_dropout=0.0,
-            params_dtype=dtype,
-            device="cuda",
-        ).eval()
-    else:
-        model = (
-            MultiheadAttention(
-                hidden_size=D,
-                num_attention_heads=H,
-                qkv_format=input_format,
-                layer_number=layer_number,
-                attention_dropout=0.0,
-                attn_mask_type="causal",
-                params_dtype=dtype,
-            )
-            .cuda()
-            .eval()
-        )
-
-    inference_params = InferenceParams(
-        max_batch_size=B_max,
-        max_sequence_length=S_max,
-        num_heads_kv=H,
-        head_dim_k=head_size,
-        dtype=dtype,
-        is_paged=is_paged,
-        total_num_pages=int(B_max * S_max / 256),
-        page_size=256,
-    )
-
-    rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
-
-    input = torch.randn((S, B, D), dtype=dtype, device="cuda")
-    if input_format == "bshd":
-        input = input.transpose(0, 1).contiguous()
-
-    incremental_output = torch.zeros_like(input)
-
-    # Generate output for the entire sequence
-    full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
-
-    # Incrementaly generate outputs using KV-cache
-    step_dict = OrderedDict(zip(list(range(B)), [1] * B))
-    for i in range(S):
-        inference_params.pre_step(step_dict)
-        if input_format == "sbhd":
-            incremental_input = input[i].view(1, B, D)
-        else:
-            incremental_input = input[:, i, :].view(B, 1, D)
-
-        seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
-        cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
-        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-        cu_seqlens_kv = cu_seqlens_q.clone()
-        mask_type = "padding"
-        kwargs = {}
-        if module == "TransformerLayer":
-            kwargs["self_attn_mask_type"] = mask_type
-        else:
-            kwargs["attn_mask_type"] = mask_type
-
-        line_output = model(
-            hidden_states=incremental_input,
-            inference_params=inference_params,
-            rotary_pos_emb=rotary_freqs if use_RoPE else None,
-            **kwargs,
-            max_seqlen_q=1,
-            max_seqlen_kv=S,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_kv=cu_seqlens_kv,
-        )
-        if input_format == "sbhd":
-            incremental_output[i, :, :] = line_output.view(B, D)
-        else:
-            incremental_output[:, i, :] = line_output.view(B, D)
-
-    if module == "TransformerLayer":
-        atol = {
-            torch.float32: 5e-3,
-            torch.half: 5e-3,
-            torch.bfloat16: 5e-2,
-        }
-    else:
-        atol = {
-            torch.float32: 1e-3,
-            torch.half: 1e-3,
-            torch.bfloat16: 1e-2,
-        }
-
-    # Check if the fully generated output matches the one generated incrementally
-    assert_allclose(full_output, incremental_output, atol[dtype])

@pytest.mark.parametrize(
    "shape",
...
@@ -2815,9 +2613,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
        (16, 4096, 128, 512),
    ],
)
-@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
-def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
+def test_fp8_grouped_gemm(shape, accumulate):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
...
...
tests/pytorch/test_onnx_export.py

...
@@ -27,7 +27,6 @@ import warnings
import numpy as np
import onnxruntime as ort
import torch
import random
from torch import nn as nn
from typing import Optional, Union, Tuple, List
from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
...
@@ -59,14 +58,13 @@ TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)

-fp8_recipes = [
-    None,
-    recipe.DelayedScaling(),
-    recipe.MXFP8BlockScaling(),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.DelayedScaling())
+fp8_recipes.append(None)

supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
...
...
@@ -369,14 +367,6 @@ def validate_result(
    )

-def create_meta(scale_factor: float, size: int = 1):
-    meta = tex.FP8TensorMeta()
-    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
-    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
-    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
-    return meta

def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
    if fake_bf16_io:
        assert dtype == torch.bfloat16
...
@@ -413,36 +403,12 @@ Test cases begin here.
"""

-@pytest.mark.parametrize("scale_factor", [112])
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-# Returning the bias is a TE fusion optimization we don't care about.
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, False),
-        (torch.float16, True),
-        # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
-        (torch.bfloat16, False),
-        # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
-        (torch.bfloat16, True),
-    ],
-)
-def test_export_linear(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    precision: torch.dtype,
+def _test_export_linear(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    precision: torch.dtype = torch.float32,
):
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")
...
@@ -498,32 +464,28 @@ def test_export_linear(
    )

-@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize(
-    "precision",
-    [
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    normalization: str,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_linear(fp8_recipe=fp8_recipe, precision=precision)

+@pytest.mark.parametrize("use_bias", [True, False])
+def test_export_linear_use_bias(seed_default_rng, use_bias):
+    _test_export_linear(use_bias=use_bias)

+@pytest.mark.parametrize("return_bias", [True, False])
+def test_export_linear_return_bias(seed_default_rng, return_bias):
+    _test_export_linear(return_bias=return_bias)

+def _test_export_layernorm(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    normalization: str = all_normalizations[0],
+):
    # Set dimensions (these are arbitrary).
    batch_size = 4
    in_features = 64
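The refactor in this file replaces fully crossed parametrizations (recipe x precision x bias x return_bias x ...) with default-argument helpers such as _test_export_linear plus small, focused wrapper tests, so each axis is exercised once instead of as a Cartesian product. A generic sketch of that pattern (the names below are illustrative, not the file's own):

import pytest

def _check_export(precision: str = "float32", use_bias: bool = True) -> None:
    # Shared body: exercise one configuration end to end.
    assert precision in {"float32", "float16", "bfloat16"}
    assert isinstance(use_bias, bool)

@pytest.mark.parametrize("precision", ["float32", "float16", "bfloat16"])
def test_export_precision(precision):
    _check_export(precision=precision)

@pytest.mark.parametrize("use_bias", [True, False])
def test_export_use_bias(use_bias):
    _check_export(use_bias=use_bias)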
...
...
@@ -564,39 +526,31 @@ def test_export_layernorm(
    )

-@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize("return_layernorm_output", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, True),
-        (torch.float16, False),
-        (torch.bfloat16, True),
-        (torch.bfloat16, False),
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm(fp8_recipe=fp8_recipe, precision=precision)

+def test_export_layernorm_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm(zero_centered_gamma=True)

@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm_linear(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    return_layernorm_output: bool,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    normalization: str,
+def test_export_layernorm_normalization(seed_default_rng, normalization):
+    _test_export_layernorm(normalization=normalization)

+def _test_export_layernorm_linear(
+    scale_factor: float = 112,
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    return_layernorm_output: bool = False,
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    normalization: str = all_normalizations[0],
):
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")
...
@@ -644,41 +598,44 @@ def test_export_layernorm_linear(
    )

-@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize("return_layernorm_output", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, True),
-        (torch.float16, False),
-        (torch.bfloat16, True),
-        (torch.bfloat16, False),
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("activation", supported_activations)
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm_mlp(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    return_layernorm_output: bool,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    activation: str,
-    normalization: str,
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)

+def test_export_layernorm_linear_return_ln_out(seed_default_rng):
+    _test_export_layernorm_linear(return_layernorm_output=True)

+def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm_linear(zero_centered_gamma=True)

+@pytest.mark.parametrize("normalization", all_normalizations[1:])
+def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
+    _test_export_layernorm_linear(normalization=normalization)

+def test_export_layernorm_linear_no_bias(seed_default_rng):
+    _test_export_layernorm_linear(use_bias=False)

+def test_export_layernorm_linear_return_bias(seed_default_rng):
+    _test_export_layernorm_linear(return_bias=True)

+def _test_export_layernorm_mlp(
+    scale_factor: float = 112,
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    return_layernorm_output: bool = False,
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    activation: str = supported_activations[0],
+    normalization: str = all_normalizations[0],
):
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")
...
...
@@ -720,6 +677,38 @@ def test_export_layernorm_mlp(
    )

+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)

+def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
+    _test_export_layernorm_mlp(return_layernorm_output=True)

+def test_export_layernorm_mlp_return_bias(seed_default_rng):
+    _test_export_layernorm_mlp(return_bias=True)

+def test_export_layernorm_mlp_no_bias(seed_default_rng):
+    _test_export_layernorm_mlp(use_bias=False)

+def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm_mlp(zero_centered_gamma=True)

+@pytest.mark.parametrize("normalization", all_normalizations[1:])
+def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
+    _test_export_layernorm_mlp(normalization=normalization)

+@pytest.mark.parametrize("activation", supported_activations[1:])
+def test_export_layernorm_mlp_activation(seed_default_rng, activation):
+    _test_export_layernorm_mlp(activation=activation)

@pytest.mark.parametrize(
    "precision, use_mask, attn_mask_type",
    [
...
@@ -734,8 +723,6 @@ def test_export_layernorm_mlp(
    ],
)
def test_export_core_attention(
    seed_default_rng,
    set_max_seq_len,
    precision: torch.dtype,
    use_mask: bool,
    attn_mask_type: str,
...
@@ -777,11 +764,6 @@ def test_export_core_attention(
    )

-test_configs_multihead_attention = [
-    # "use_mask, attn_mask_type"
-    (False, "no_mask"),  # calls ScaledSoftmax
-    (True, "arbitrary"),  # calls ScaledMaskedSoftmax
-]
test_configs_attention_type = [
    # "input_layernorm, attention_type, fuse_qkv_params"
    (True, "self", True),
...
@@ -795,31 +777,14 @@ test_configs_attention_type = [
]

@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
-@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("return_layernorm_output", [False])
-@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
-def test_export_multihead_attention(
-    seed_default_rng,
-    set_max_seq_len,
-    fp8_recipe: recipe.Recipe,
-    use_mask: bool,
-    attn_mask_type: str,
-    precision: torch.dtype,
-    return_layernorm_output: bool,
-    input_layernorm: bool,
-    attention_type: str,
-    fuse_qkv_params: bool,
+def _test_export_multihead_attention(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_mask: bool = True,
+    precision: torch.dtype = torch.float32,
+    input_layernorm: bool = True,
+    attention_type: str = "self",
+    fuse_qkv_params: bool = True,
):
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    hidden_size = 256
    sequence_length = 128
    batch_size = 4
...
@@ -837,6 +802,7 @@ def test_export_multihead_attention(
        init_method,
        output_layer_init_method,
    )
+    attn_mask_type = "arbitrary" if use_mask else "no_mask"
    hidden_states_context = torch.randn(
        sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
...
@@ -868,7 +834,7 @@ def test_export_multihead_attention(
        *attention_args,
        attn_mask_type=attn_mask_type,
        params_dtype=precision,
-        return_layernorm_output=return_layernorm_output,
+        return_layernorm_output=False,
        input_layernorm=input_layernorm,
        attention_type=attention_type,
        fuse_qkv_params=fuse_qkv_params,
...
@@ -960,30 +926,37 @@ def test_export_multihead_attention(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
-@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("fuse_qkv_params", [False, True])
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("activation", supported_activations)
-def test_export_transformer_layer(
-    seed_default_rng,
-    set_max_seq_len,
-    fp8_recipe: recipe.Recipe,
-    use_mask: bool,
-    attn_mask_type: str,
-    output_layernorm: bool,
-    precision: torch.dtype,
-    fuse_qkv_params: bool,
-    zero_centered_gamma: bool,
-    activation: str,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+def test_export_multihead_attention_recipe(fp8_recipe, precision):
+    _test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)

+def test_export_multihead_attention_no_mask():
+    _test_export_multihead_attention(use_mask=False)

+def test_export_multihead_attention_no_input_layernorm():
+    _test_export_multihead_attention(input_layernorm=False)

+def test_export_multihead_attention_cross_attn():
+    _test_export_multihead_attention(attention_type="cross")

+def test_export_multihead_attention_unfused_qkv_params():
+    _test_export_multihead_attention(fuse_qkv_params=False)

+def _test_export_transformer_layer(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_mask: bool = True,
+    attn_mask_type: str = "arbitrary",
+    output_layernorm: bool = False,
+    precision: torch.dtype = torch.float32,
+    fuse_qkv_params: bool = True,
+    zero_centered_gamma: bool = False,
+    activation: str = supported_activations[0],
+):
    # Layer configuration
    hidden_size = 64
    sequence_length = 128
...
@@ -1043,28 +1016,43 @@ def test_export_transformer_layer(
    )

@skip_FP8
@skip_MXFP8
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_transformer_layer_recipe(fp8_recipe, precision):
+    _test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)

+def test_export_transformer_layer_no_mask():
+    _test_export_transformer_layer(use_mask=False)

+def test_export_transformer_layer_output_layernorm():
+    _test_export_transformer_layer(output_layernorm=True)

+def test_export_transformer_layer_unfused_qkv_params():
+    _test_export_transformer_layer(fuse_qkv_params=False)

+def test_export_transformer_layer_zero_centered_gamma():
+    _test_export_transformer_layer(zero_centered_gamma=True)

+@pytest.mark.parametrize("activation", supported_activations[1:])
+def test_export_transformer_layer_activation(activation):
+    _test_export_transformer_layer(activation=activation)

@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [True])
def test_export_gpt_generation(
    seed_default_rng,
    set_max_seq_len,
    fp8_recipe: recipe.Recipe,
    precision: torch.dtype,
    zero_centered_gamma: bool,
):
    """Test that the ONNX model can correctly handle inputs with different shapes and that
    the attention mask is adjusted on-the-fly to different sequence lengths.
    """
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Layer configuration
    hidden_size = 64
    sequence_length = 128
...
@@ -1091,7 +1079,6 @@ def test_export_gpt_generation(
        output_layernorm=output_layernorm,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,
-        zero_centered_gamma=zero_centered_gamma,
    ).to(device="cuda")

    # "Context phase": use full input sequence length
...
...
tests/pytorch/test_parallel_cross_entropy.py

...
@@ -3,7 +3,6 @@
# See LICENSE for license information.
import random
import pytest
import torch

from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
...
...
tests/pytorch/test_qk_norm.py

...
@@ -8,10 +8,10 @@ import pytest
import torch

-@pytest.mark.parametrize("use_qk_norm", [False, True])
+@pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"])
@pytest.mark.parametrize("attention_type", ["self", "cross"])
@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
-def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None:
+def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None:
    """Test QK normalization functionality, module structure, and numerical behavior."""
    hidden_size = 256
    num_attention_heads = 8
...
...
@@ -22,25 +22,59 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        attention_type=attention_type,
-        use_qk_norm=use_qk_norm,
+        qk_norm_type=qk_norm_type,
        qk_norm_eps=qk_norm_eps,
        bias=False,
        device="cuda",
    ).cuda()

-    # Check module structure based on use_qk_norm parameter
-    if use_qk_norm:
-        assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True"
-        assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module"
-        assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module"
-        # Check that the module is L2Norm type
-        from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
-
-        assert isinstance(mha.qk_norm, L2Normalization), "qk_norm should be an L2Normalization module"
+    # Check module structure based on qk_norm_type parameter
+    if qk_norm_type is not None:
+        assert mha.q_norm is not None, "Should have q_norm module when qk_norm_type is not None"
+        assert mha.k_norm is not None, "Should have k_norm module when qk_norm_type is not None"
+        # Check that the modules are of the correct type
+        if qk_norm_type == "L2Normalization":
+            from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
+
+            assert isinstance(mha.q_norm, L2Normalization), "q_norm should be an L2Normalization module"
+            assert isinstance(mha.k_norm, L2Normalization), "k_norm should be an L2Normalization module"
+            # For L2 normalization, q_norm and k_norm should be the same instance (parameter-free)
+            assert (
+                mha.q_norm is mha.k_norm
+            ), "q_norm and k_norm should be the same instance for L2 normalization"
+        elif qk_norm_type == "RMSNorm":
+            from transformer_engine.pytorch.module.rmsnorm import RMSNorm
+
+            assert isinstance(mha.q_norm, RMSNorm), "q_norm should be an RMSNorm module"
+            assert isinstance(mha.k_norm, RMSNorm), "k_norm should be an RMSNorm module"
+            # For RMS normalization, q_norm and k_norm should be separate instances
+            assert (
+                mha.q_norm is not mha.k_norm
+            ), "q_norm and k_norm should be separate instances for RMS normalization"
+        elif qk_norm_type == "LayerNorm":
+            from transformer_engine.pytorch.module.layernorm import LayerNorm
+
+            assert isinstance(mha.q_norm, LayerNorm), "q_norm should be a LayerNorm module"
+            assert isinstance(mha.k_norm, LayerNorm), "k_norm should be a LayerNorm module"
+            # For LayerNorm, q_norm and k_norm should be separate instances
+            assert (
+                mha.q_norm is not mha.k_norm
+            ), "q_norm and k_norm should be separate instances for LayerNorm"
+        else:
+            # For extensibility - just ensure they exist
+            assert mha.q_norm is not None, f"q_norm should exist for qk_norm_type={qk_norm_type}"
+            assert mha.k_norm is not None, f"k_norm should exist for qk_norm_type={qk_norm_type}"
    else:
-        assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False"
+        assert mha.q_norm is None, "Should not have q_norm module when qk_norm_type is None"
+        assert mha.k_norm is None, "Should not have k_norm module when qk_norm_type is None"

    # Create input tensors
    batch_size = 2  # Use a fixed batch size for testing
...
...
@@ -89,17 +123,14 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None
    assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf"


-def test_qk_norm_output_difference() -> None:
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_output_difference(qk_norm_type) -> None:
    """Test that QK normalization actually changes the output compared to no normalization."""
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 128
    batch_size = 2

-    # Use same random seed to ensure identical weight initialization
-    current_rng_state = torch.get_rng_state()
-    current_cuda_rng_state = torch.cuda.get_rng_state()
    # Reset to a known seed for reproducible initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
...
@@ -108,7 +139,7 @@ def test_qk_norm_output_difference() -> None:
    mha_with_norm = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
-        use_qk_norm=True,
+        qk_norm_type=qk_norm_type,
        bias=False,
        device="cuda",
    ).cuda()
...
@@ -121,7 +152,7 @@ def test_qk_norm_output_difference() -> None:
    mha_no_norm = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
-        use_qk_norm=False,
+        qk_norm_type=None,
        bias=False,
        device="cuda",
    ).cuda()
...
@@ -139,10 +170,11 @@ def test_qk_norm_output_difference() -> None:
    # Outputs should be different when QK normalization is enabled
    assert not torch.allclose(
        output_with_norm, output_no_norm, atol=1e-6
-    ), "QK normalization should change the output, but outputs are identical"
+    ), f"QK normalization ({qk_norm_type}) should change the output, but outputs are identical"


-def test_qk_norm_with_fused_qkv() -> None:
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_with_fused_qkv(qk_norm_type) -> None:
    """Test QK normalization works with fused QKV parameters."""
    hidden_size = 256
    num_attention_heads = 8
...
@@ -152,7 +184,7 @@ def test_qk_norm_with_fused_qkv() -> None:
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        fuse_qkv_params=True,
-        use_qk_norm=True,
+        qk_norm_type=qk_norm_type,
        bias=False,
        device="cuda",
    ).cuda()
...
@@ -173,7 +205,8 @@ def test_qk_norm_with_fused_qkv() -> None:
    ), f"Output shape mismatch: {output.shape}"


-def test_qk_norm_transformer_layer_output_difference() -> None:
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_transformer_layer_output_difference(qk_norm_type) -> None:
    """Test that QK normalization actually changes TransformerLayer output compared to no normalization."""
    from transformer_engine.pytorch import TransformerLayer
...
@@ -183,10 +216,6 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
    seq_len = 128
    batch_size = 2

-    # Use same random seed to ensure identical weight initialization
-    current_rng_state = torch.get_rng_state()
-    current_cuda_rng_state = torch.cuda.get_rng_state()
    # Reset to a known seed for reproducible initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
...
@@ -196,7 +225,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
-        use_qk_norm=True,
+        qk_norm_type=qk_norm_type,
        bias=False,
        device="cuda",
    ).cuda()
...
@@ -210,7 +239,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
-        use_qk_norm=False,
+        qk_norm_type=None,
        bias=False,
        device="cuda",
    ).cuda()
...
@@ -226,9 +255,10 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
        output_no_norm = transformer_no_norm(hidden_states)

    # Outputs should be different when QK normalization is enabled
-    assert not torch.allclose(
-        output_with_norm, output_no_norm, atol=1e-6
-    ), "QK normalization should change the TransformerLayer output, but outputs are identical"
+    assert not torch.allclose(output_with_norm, output_no_norm, atol=1e-6), (
+        f"QK normalization ({qk_norm_type}) should change the TransformerLayer output, but outputs"
+        " are identical"
+    )

    # Check that outputs have expected shapes and properties
    assert output_with_norm.shape == (
...
@@ -240,3 +270,120 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
    assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf"
    assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN"
    assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf"


@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_before_after_rope(qk_norm_type) -> None:
    """Test that QK normalization before and after RoPE works without errors."""
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 64
    batch_size = 2

    # Create model with QK norm after RoPE (default)
    mha_after = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        qk_norm_type=qk_norm_type,
        qk_norm_before_rope=False,
        bias=False,
        device="cuda",
    ).cuda()

    # Create model with QK norm before RoPE
    mha_before = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        qk_norm_type=qk_norm_type,
        qk_norm_before_rope=True,
        bias=False,
        device="cuda",
    ).cuda()

    hidden_states = torch.randn(seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32)

    # Create RoPE embeddings
    head_dim = hidden_size // num_attention_heads
    rotary_dim = head_dim // 2
    rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)

    with torch.no_grad():
        output_after_rope = mha_after(hidden_states, rotary_pos_emb=rotary_pos_emb)
        output_before_rope = mha_before(hidden_states, rotary_pos_emb=rotary_pos_emb)
        output_after_no_rope = mha_after(hidden_states)
        output_before_no_rope = mha_before(hidden_states)

    # Check output shapes and properties
    expected_shape = (seq_len, batch_size, hidden_size)
    for output in [
        output_after_rope,
        output_before_rope,
        output_after_no_rope,
        output_before_no_rope,
    ]:
        assert output.shape == expected_shape, f"Output shape mismatch: {output.shape}"
        assert not torch.isnan(output).any(), "Output contains NaN"
        assert not torch.isinf(output).any(), "Output contains Inf"

    assert output_after_rope.shape == output_before_rope.shape, "Outputs should have same shape"
    assert mha_after.qk_norm_before_rope == False, "mha_after should have qk_norm_before_rope=False"
    assert mha_before.qk_norm_before_rope == True, "mha_before should have qk_norm_before_rope=True"


def test_different_qk_norm_types_produce_different_outputs() -> None:
    """Test that different QK normalization types produce different outputs."""
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 128
    batch_size = 2

    # Use same random seed to ensure identical weight initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Create model with L2 normalization
    mha_l2 = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        qk_norm_type="L2Normalization",
        bias=False,
        device="cuda",
    ).cuda()

    # Reset to same seed for identical initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Create model with RMS normalization
    mha_rms = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        qk_norm_type="RMSNorm",
        bias=False,
        device="cuda",
    ).cuda()

    # Create input tensors
    hidden_states = torch.randn(seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32)

    # Compare outputs with identical weights but different QK norm types
    with torch.no_grad():
        output_l2 = mha_l2(hidden_states)
        output_rms = mha_rms(hidden_states)

    # Outputs should be different when using different normalization types
    assert not torch.allclose(
        output_l2, output_rms, atol=1e-6
    ), "L2 and RMS normalization should produce different outputs, but outputs are identical"

    # Check that outputs have expected shapes and properties
    assert output_l2.shape == output_rms.shape, "L2 and RMS outputs should have same shape"
    assert not torch.isnan(output_l2).any(), "L2 output contains NaN"
    assert not torch.isinf(output_l2).any(), "L2 output contains Inf"
    assert not torch.isnan(output_rms).any(), "RMS output contains NaN"
    assert not torch.isinf(output_rms).any(), "RMS output contains Inf"
tests/pytorch/test_recipe.py
@@ -192,12 +192,6 @@ class TestFP8Recipe:
            amax_compute_algo=amax_compute_algo,
        )

-        # Get FP8 meta tensors
-        with te.fp8_autocast(fp8_recipe=recipe):
-            x_fp8_meta = op.get_quantizer("forward", 0)
-            w_fp8_meta = op.get_quantizer("forward", 1)
-            dy_fp8_meta = op.get_quantizer("backward", 0)

        # Perform training steps
        x_history = []
        w_history = []
...
@@ -229,19 +223,30 @@ class TestFP8Recipe:
                y = op(x)
                y.backward(dy)

-        def check_amax_history(
-            fp8_meta: dict,
-            ref_amax_history: Iterable[float],
-        ) -> None:
-            """Check that amax history matches expected values"""
-            if len(ref_amax_history) > amax_history_len:
-                ref_amax_history = ref_amax_history[-amax_history_len:]
+        def check_metas(
+            test_scale: float,
+            test_amax_history: torch.Tensor,
+            ref_amax_history_list: list[float],
+            stage: str,
+        ):
+            """Check that meta tensors match expected values"""
+            # Compute amax
+            if len(ref_amax_history_list) > amax_history_len:
+                ref_amax_history_list = ref_amax_history_list[-(amax_history_len + 1) :]
            ref_amax_history = torch.tensor(
-                ref_amax_history,
+                ref_amax_history_list,
                dtype=torch.float32,
                device=device,
            )
-            test_amax_history = fp8_meta.amax_history[:, 0]
+            if amax_compute_algo == "max":
+                ref_amax = max(ref_amax_history_list)
+            elif amax_compute_algo == "most_recent":
+                ref_amax = ref_amax_history_list[-1]
+            else:
+                raise RuntimeError(f"{amax_compute_algo=} is not supported")

            # Compare amax history
            tols = dict(rtol=0, atol=0)
            torch.testing.assert_close(
                test_amax_history[-(step + 1) :],
...
@@ -249,23 +254,6 @@ class TestFP8Recipe:
                **tols,
            )

-        def check_scale(
-            quantizer: Float8Quantizer,
-            ref_amax_history: Iterable[float],
-            stage: str,
-        ):
-            """Check that scale and scale reciprocal match expected values"""
-            # Compute amax
-            if len(ref_amax_history) > amax_history_len:
-                ref_amax_history = ref_amax_history[-(amax_history_len + 1) :]
-            if amax_compute_algo == "max":
-                ref_amax = max(ref_amax_history)
-            elif amax_compute_algo == "most_recent":
-                ref_amax = ref_amax_history[-1]
-            else:
-                raise RuntimeError(f"{amax_compute_algo=} is not supported")
            # Compute scale
            max_val = {
                "forward": 448.0,
...
@@ -273,16 +261,26 @@ class TestFP8Recipe:
            }[stage]
            ref_scale = (max_val / ref_amax) / (2**margin)

-            # Check values in FP8 meta tensors
+            # Compare scale
            torch.testing.assert_close(
-                quantizer.scale.item(),
+                test_scale,
                ref_scale,
            )

+        # Get scaling factors
+        x_test_scale = op.get_quantizer("forward", 0).scale.item()
+        w_test_scale = op.get_quantizer("forward", 1).scale.item()
+        dy_test_scale = op.get_quantizer("backward", 0).scale.item()

+        # Get amax histories
+        x_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 0]
+        w_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 1]
+        dy_test_history = op._fp8_metas["backward"][backward_key].amax_history[:, 0]

        # Check that results match expected values
-        check_scale(x_fp8_meta, x_history, "forward")
-        check_scale(w_fp8_meta, w_history, "forward")
-        check_scale(dy_fp8_meta, dy_history, "backward")
+        check_metas(x_test_scale, x_test_history, x_history, "forward")
+        check_metas(w_test_scale, w_test_history, w_history, "forward")
+        check_metas(dy_test_scale, dy_test_history, dy_history, "backward")

    @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
    @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
...
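The reference scale computed in check_metas above follows the usual delayed-scaling rule, scale = (fp8_max / amax) / 2**margin, with 448.0 being the representable maximum of FP8 E4M3 used for the forward stage in this test. A small worked sketch of that arithmetic (the helper name and sample amax values are illustrative, not from the commit):

    # Sketch of the scale rule the test checks: scale = (fp8_max / amax) / 2**margin.
    def ref_scale(amax: float, fp8_max: float = 448.0, margin: int = 0) -> float:
        return (fp8_max / amax) / (2 ** margin)

    assert ref_scale(3.5) == 128.0              # amax of 3.5 maps to a scale of 448 / 3.5 = 128
    assert ref_scale(448.0, margin=1) == 0.5    # a margin of 1 halves the scale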
tests/pytorch/test_sanity.py
@@ -2,9 +2,7 @@
#
# See LICENSE for license information.
-from dataclasses import dataclass
from typing import Optional
-from contextlib import nullcontext
import torch
import pytest
...
@@ -18,11 +16,9 @@ from transformer_engine.pytorch.fp8 import (
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
-    get_device_compute_capability,
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
-    get_cudnn_version,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
...
@@ -32,7 +28,6 @@ from transformer_engine.pytorch import (
    TransformerLayer,
    RMSNorm,
    LayerNorm,
-    get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
...
@@ -47,21 +42,17 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
-from utils import dtype_tols
+from utils import ModelConfig

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
+fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Record initial RNG state from script run.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
...
@@ -79,88 +70,33 @@ if NVTE_TEST_NVINSPECT_ENABLED:
    )


def create_meta(scale_factor: float, size: int = 1):
    meta = tex.FP8TensorMeta()
    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
    return meta


if IS_HIP_EXTENSION:
    from functools import cache

    @cache
    def use_hipblaslt() -> bool:
        return os.getenv("NVTE_USE_HIPBLASLT") is not None or os.getenv("NVTE_USE_ROCBLAS") is None


-def custom_amax_to_scale(
-    amax: torch.Tensor,
-    scale: torch.Tensor,
-    fp8_max: torch.Tensor,
-    recipe: recipe.DelayedScaling,
-) -> torch.Tensor:
-    """Custom func to test recipe."""
-    sf = fp8_max / amax
-    sf = torch.where(amax > 0.0, sf, scale)
-    sf = torch.where(torch.isfinite(amax), sf, scale)
-    return sf
-
-
-def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
-    """Custom func to test recipe."""
-    return torch.min(amax_history, dim=0).values
-
-
-def reset_rng_states() -> None:
-    """revert back to initial RNG state."""
-    global _cpu_rng_state, _cuda_rng_state
-    torch.set_rng_state(_cpu_rng_state)
-    torch.cuda.set_rng_state(_cuda_rng_state)
-
-
-@dataclass
-class ModelConfig:
-    """Transformer model configuration"""
-    num_layers: int
-    seq_len: int
-    batch_size: int
-    hidden_size: int
-    num_attention_heads: int
-    kv_channels: Optional[int] = None
-
-    def is_fp8_supported(self):
-        if self.seq_len * self.batch_size % 16:
-            return False
-        if self.hidden_size % 16:
-            return False
-        return True
+def is_fp8_supported(config: ModelConfig):
+    if config.max_seqlen_q * config.batch_size % 16 or config.max_seqlen_kv * config.batch_size % 16:
+        return False
+    if config.hidden_size % 16 or config.hidden_size_kv % 16:
+        return False
+    return True


model_configs = {
-    "126m": ModelConfig(12, 2048, 2, 768, 12),
-    "small": ModelConfig(2, 32, 2, 64, 2),
-    "weird": ModelConfig(2, 37, 3, 69, 3),
-    "large": ModelConfig(1, 128, 2, 512, 4, 128),
+    "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
+    "small": ModelConfig(2, 32, 2, 32, num_layers=2),
+    "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
+    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}

-fp8_recipes = [
-    None,  # Test non-FP8
-    recipe.MXFP8BlockScaling(),  # Test default
-    recipe.Float8CurrentScaling(),  # Test default
-    recipe.Float8BlockScaling(),  # Test default
-    recipe.DelayedScaling(),  # Test default
-    recipe.DelayedScaling(  # Test most_recent algo
-        amax_history_len=16,
-        amax_compute_algo="most_recent",
-    ),
-    recipe.DelayedScaling(  # Test custom amax and scale compute algo
-        fp8_format=recipe.Format.E4M3,
-        amax_compute_algo=custom_amax_compute,
-        scaling_factor_compute_algo=custom_amax_to_scale,
-    ),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_block_scaling_available:
+    fp8_recipes.append(recipe.Float8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.Float8CurrentScaling())
+    fp8_recipes.append(recipe.DelayedScaling())
+fp8_recipes.append(None)

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
...
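A short illustration of the divisibility check introduced above. This is only a sketch mirroring the is_fp8_supported() logic from the hunk (the 16-element alignment it guards reflects FP8 GEMM dimension requirements); the variable names are ad hoc, and the numbers come from the "weird" entry in the new model_configs table:

    # Sketch: why the "weird" config is skipped for FP8 parametrizations.
    batch_size, max_seqlen_q = 3, 37          # ModelConfig(3, 37, 3, 23, num_layers=2)
    hidden_size = 3 * 23                       # num_heads * head_dim_qk = 69
    tokens_aligned = (max_seqlen_q * batch_size) % 16 == 0   # 111 % 16 != 0
    dims_aligned = hidden_size % 16 == 0                     # 69 % 16 != 0
    print(tokens_aligned and dims_aligned)                   # False -> FP8 cases are skipped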
@@ -184,66 +120,9 @@ def reset_global_fp8_state():
    FP8GlobalStateManager.reset()


-def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
-    # Initialize loss function and optimizer.
-    loss_fn = torch.nn.MSELoss()
-    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
-
-    # Placeholders used for capture.
-    static_input = torch.randn(
-        config.seq_len,
-        config.batch_size,
-        config.hidden_size,
-        device="cuda",
-        dtype=dtype,
-        requires_grad=True,
-    )
-    static_target = torch.randn(
-        config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
-    )
-    real_input = torch.rand_like(static_input)
-    real_target = torch.rand_like(static_target)
-
-    use_fp8 = fp8_recipe is not None
-    if skip_wgrad:
-        _disable_wgrads(block)
-
-    # Pre graph capture warmup in a separate stream.
-    s = torch.cuda.Stream()
-    s.wait_stream(torch.cuda.current_stream())
-    with torch.cuda.stream(s):
-        for _ in range(3):
-            optimizer.zero_grad(set_to_none=True)
-            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
-                out = block(static_input)
-                loss = loss_fn(out, static_target)
-            loss.backward()
-            optimizer.step()
-    torch.cuda.current_stream().wait_stream(s)
-
-    # Capture.
-    g = torch.cuda.CUDAGraph()
-    optimizer.zero_grad(set_to_none=True)
-    with torch.cuda.graph(g):
-        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
-            static_output = block(static_input)
-            static_loss = loss_fn(static_output, static_target)
-        static_loss.backward()
-        optimizer.step()
-
-    # Fills the graph's input memory with new data to compute on
-    with torch.no_grad():
-        static_input.copy_(real_input)
-        static_target.copy_(real_target)
-    g.replay()
-
-    torch.cuda.synchronize()


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
...
@@ -251,7 +130,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(
        2,
-        (1, 1, config.seq_len, config.seq_len),
+        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )
...
@@ -278,14 +157,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
-        (1, 1, config.seq_len, config.seq_len),
+        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )
...
@@ -316,9 +195,9 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."


-def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
+def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
...
@@ -327,16 +206,9 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
    if skip_wgrad:
        _disable_wgrads(block)

-    if cpu_offload:
-        offload_context, sync_function = get_cpu_offload_context(enabled=True)
-    else:
-        offload_context = nullcontext()
-        sync_function = lambda x: x

    use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states)
-    te_out = sync_function(te_out)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()
...
@@ -344,7 +216,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
...
@@ -352,7 +224,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_attn_mask = torch.randint(
        2,
-        (config.batch_size, 1, 1, config.seq_len),
+        (config.batch_size, 1, 1, config.max_seqlen_q),
        dtype=torch.bool,
        device="cuda",
    )
...
@@ -370,21 +242,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
-        (1, 1, config.seq_len, config.seq_len),
+        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )
    enc_dec_attn_mask = torch.randint(
        2,
-        (config.batch_size, 1, 1, config.seq_len),
+        (config.batch_size, 1, 1, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )
...
@@ -412,7 +284,7 @@ def _test_sanity_common(
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
...
@@ -440,7 +312,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=True,
    )
...
@@ -495,13 +367,7 @@ def test_sanity_layernorm_linear(
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
...
@@ -529,13 +395,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
...
@@ -562,16 +422,10 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
        pytest.skip("Quantized model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
-    num_tokens = bs * config.seq_len
+    num_tokens = bs * config.max_seqlen_q
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
...
@@ -607,16 +461,10 @@ def test_sanity_grouped_linear(
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
-    num_tokens = bs * config.seq_len * (num_gemms - 1)
+    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
...
@@ -628,7 +476,7 @@ def test_sanity_grouped_linear(
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
-    m_splits = [bs * config.seq_len] * num_gemms
+    m_splits = [bs * config.max_seqlen_q] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
...
@@ -666,13 +514,7 @@ def test_sanity_layernorm_mlp(
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
...
@@ -697,36 +539,24 @@ def test_sanity_layernorm_mlp(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
-@pytest.mark.parametrize("activation", all_activations)
+@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
-@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
-    cpu_offload,
):
-    if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
-        pytest.skip("CPU offload is not supported in debug mode.")
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
...
@@ -736,7 +566,7 @@ def test_sanity_gpt(
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
...
@@ -745,7 +575,6 @@ def test_sanity_gpt(
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        bias=bias,
        activation=activation,
        normalization=normalization,
...
@@ -753,7 +582,7 @@ def test_sanity_gpt(
        parallel_attention_mlp=parallel_attention_mlp,
    )
-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_gpt_126m():
...
@@ -770,12 +599,10 @@ def test_sanity_gpt_126m():
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
-        cpu_offload=False,
    )
...
@@ -783,19 +610,14 @@ def test_sanity_gpt_126m():
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
-@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
+def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
...
@@ -805,7 +627,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
...
@@ -814,7 +636,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
-        zero_centered_gamma=zero_centered_gamma,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
...
@@ -835,7 +656,6 @@ def test_sanity_bert_126m():
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
-        zero_centered_gamma=False,
        normalization="LayerNorm",
    )
...
@@ -844,19 +664,14 @@ def test_sanity_bert_126m():
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
-@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
+def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
...
@@ -866,7 +681,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
...
@@ -876,7 +691,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
-        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        device="cuda",
    )
...
@@ -896,7 +710,6 @@ def test_sanity_T5_126m():
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
-        zero_centered_gamma=False,
        normalization="LayerNorm",
    )
...
@@ -909,13 +722,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
...
@@ -925,7 +732,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
...
@@ -941,18 +748,11 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
-@pytest.mark.parametrize("skip_wgrad", all_boolean)
-def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
+def test_sanity_drop_path(dtype, fp8_recipe, model):
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
...
@@ -962,7 +762,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
...
@@ -975,7 +775,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
        device="cuda",
    )

-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, False)


@pytest.mark.parametrize("dtype", param_types)
...
@@ -986,13 +786,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
...
@@ -1002,7 +796,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
...
@@ -1015,27 +809,18 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
        device="cuda",
    )

-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
-@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
...
@@ -1045,7 +830,7 @@ def test_sanity_gradient_accumulation_fusion(
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
...
@@ -1054,7 +839,6 @@ def test_sanity_gradient_accumulation_fusion(
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
-        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
...
@@ -1063,56 +847,6 @@
    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


-@pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("model", ["small"])
-@pytest.mark.parametrize("skip_wgrad", all_boolean)
-@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
-    if IS_HIP_EXTENSION:
-        if not use_hipblaslt():
-            pytest.skip("CUDA graph capture not supported with rocBLAS path")
-    config = model_configs[model]
-    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if fp8_recipe.float8_block_scaling():
-            pytest.skip("cuda graph not supported for float8_block_scaling recipe")
-        if not config.is_fp8_supported():
-            pytest.skip("Model config does not support FP8")
-
-    sigma = 0.023
-    init_method = init_method_normal(sigma)
-    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-
-    block = TransformerLayer(
-        config.hidden_size,
-        4 * config.hidden_size,
-        config.num_attention_heads,
-        init_method=init_method,
-        output_layer_init_method=output_layer_init_method,
-        hidden_dropout=0.1,
-        attention_dropout=0.1,
-        kv_channels=config.kv_channels,
-        params_dtype=dtype,
-        apply_residual_connection_post_layernorm=False,
-        output_layernorm=False,
-        zero_centered_gamma=zero_centered_gamma,
-        fuse_qkv_params=True,
-        normalization=normalization,
-        device="cuda",
-    )
-
-    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)
...
@@ -1167,133 +901,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    torch.cuda.synchronize()


-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
-@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
-@pytest.mark.parametrize("model", ["large"])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_sanity_attention_extra_state(model, dtype):
-    config = model_configs[model]
-    outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
-    outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
-    outputs_checkpoint_v1_6 = _run_attention_extra_state(dtype, config, mimic_v1_6=True, checkpoint=True)
-
-    # Check that results match
-    tols = dtype_tols(dtype)
-    if dtype in (torch.float16, torch.bfloat16):
-        tols.update(dict(rtol=2e-2, atol=2e-3))
-    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
-        torch.testing.assert_close(
-            test,
-            ref,
-            **tols,
-        )
-    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
-        torch.testing.assert_close(
-            test,
-            ref,
-            **tols,
-        )
-
-
-def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
-    steps = 10
-    path = "checkpoint.pt"
-    fp8_enabled = True
-    fp8_recipe = recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.HYBRID,
-        amax_history_len=1,
-        amax_compute_algo="most_recent",
-        fp8_dpa=fp8_enabled,
-        fp8_mha=False,
-    )
-
-    reset_rng_states()
-    hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
-        dtype=dtype,
-        device="cuda",
-        requires_grad=True,
-    )
-
-    def get_model(dtype, config):
-        sigma = 0.023
-        init_method = init_method_normal(sigma)
-        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-
-        with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
-            block = TransformerLayer(
-                config.hidden_size,
-                4 * config.hidden_size,
-                config.num_attention_heads,
-                init_method=init_method,
-                output_layer_init_method=output_layer_init_method,
-                hidden_dropout=0.0,
-                attention_dropout=0.0,
-                fuse_qkv_params=True,
-                params_dtype=dtype,
-                device="cuda",
-            )
-        return block
-
-    block = get_model(dtype, config)
-    for i in range(steps // 2):
-        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
-            output = block(hidden_states, None)
-            loss = output.sum()
-            loss.backward()
-
-    if checkpoint:
-        sd = block.state_dict()
-        if mimic_v1_6:
-            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
-                "self_attention.core_attention._extra_state"
-            ]
-            del sd["self_attention.core_attention._extra_state"]
-        torch.save(sd, path)
-
-        param_grads = []
-        for p in block.parameters():
-            if p.requires_grad:
-                param_grads.append(p.grad.clone())
-
-        _cpu_rng_state_new = torch.get_rng_state()
-        _cuda_rng_state_new = torch.cuda.get_rng_state()
-
-        del block
-        block = get_model(dtype, config)
-        block.load_state_dict(torch.load(path, weights_only=False))
-        torch.set_rng_state(_cpu_rng_state_new)
-        torch.cuda.set_rng_state(_cuda_rng_state_new)
-
-        for p in block.parameters():
-            if p.requires_grad:
-                p.grad = param_grads.pop(0)
-
-        assert not param_grads, "Oops!"
-
-    for i in range((steps + 1) // 2):
-        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
-            output = block(hidden_states, None)
-            loss = output.sum()
-            loss.backward()
-
-    torch.cuda.synchronize()
-
-    if os.path.exists(path):
-        os.remove(path)
-
-    outputs = [output, hidden_states.grad]
-    for p in block.parameters():
-        if p.requires_grad:
-            outputs.append(p.grad)
-
-    return outputs


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Test the functionality of replace_raw_data"""
...
torch
.
testing
.
assert_close
(
grad_checkpoint
,
grad_standard
)
@
pytest
.
mark
.
skipif
(
not
fp8_available
,
reason
=
reason_for_no_fp8
)
def
test_linear_frozen_weights_memory_default_recipe
():
"""Test that memory usage is optimized when weights are frozen for MXFP8."""
dim
=
1024
linear
=
Linear
(
dim
,
dim
,
bias
=
False
)
x
=
torch
.
randn
(
dim
,
dim
,
requires_grad
=
True
,
device
=
"cuda"
)
# Freeze weights
linear
.
weight
.
requires_grad
=
False
# Forward and backward pass with FP8
with
fp8_autocast
():
o
=
linear
(
x
)
g_o
=
torch
.
randn_like
(
o
)
max_memory_before_backward
=
torch
.
cuda
.
max_memory_allocated
()
o
.
backward
(
g_o
)
max_memory_after_backward
=
torch
.
cuda
.
max_memory_allocated
()
memory_diff
=
(
max_memory_after_backward
-
max_memory_before_backward
)
/
1e6
assert
memory_diff
<
5.5
,
(
f
"Memory usage with frozen weights (
{
memory_diff
}
MB) should be less than 5.5MB as the"
" grad_output should be quantized only columnwise."
)
@
pytest
.
mark
.
parametrize
(
"module_name"
,
(
"Linear"
,
"LayerNormLinear"
,
"LayerNormMLP"
,
"GroupedLinear"
,
"ops.Linear"
),
...
...
tests/pytorch/utils.py
@@ -4,12 +4,24 @@
from __future__ import annotations

import logging
import os
from contextlib import contextmanager

import pytest
import torch

import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    get_attention_backend,
    AttentionParams,
    AttentionLogging,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend


def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
...
@@ -106,3 +118,178 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
    if name == "fp8_block_scaling":
        return transformer_engine.common.recipe.Float8BlockScaling()
    raise ValueError(f"Unsupported quantization scheme ({name})")


# Cached RNG state
_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None


def reset_rng_states() -> None:
    """Revert to deterministic RNG state"""
    global _rng_states
    if _rng_states is None:
        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)
        _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state())
    else:
        cpu_rng_state, cuda_rng_state = _rng_states
        torch.set_rng_state(cpu_rng_state)
        torch.cuda.set_rng_state(cuda_rng_state)


class ModelConfig:
    def __init__(
        self,
        batch_size: int,
        max_seqlen_q: int,
        num_heads: int,
        head_dim_qk: int,
        max_seqlen_kv: int = None,
        num_gqa_groups: int = None,
        head_dim_v: int = None,
        dropout_p: float = 0.0,
        attn_mask_type: str = "no_mask",
        attn_bias_type: str = "no_bias",
        alibi_type: str = "none",
        bias_shape: str = "1hss",
        window_size: Tuple[int, int] = (-1, -1),
        total_requests: int = None,
        max_ctx_len: int = None,
        num_layers: int = 1,
        eps: float = 1e-5,
    ):
        self.batch_size = batch_size
        self.max_seqlen_q = max_seqlen_q
        self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv
        self.num_heads = num_heads
        self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups
        self.head_dim_qk = head_dim_qk
        self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
        if self.head_dim_qk == self.head_dim_v:
            self.kv_channels = self.head_dim_qk
        else:
            self.kv_channels = (self.head_dim_qk, self.head_dim_v)
        self.hidden_size = self.num_heads * self.head_dim_qk
        self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
        self.dropout_p = dropout_p
        self.attn_mask_type = attn_mask_type
        self.attn_bias_type = attn_bias_type
        self.alibi_type = alibi_type
        self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
        self.bias_shape = bias_shape
        self.window_size = window_size
        self.total_requests = total_requests
        self.max_ctx_len = max_ctx_len
        self.num_layers = num_layers
        self.eps = eps
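A brief usage sketch of the shared ModelConfig class defined above, using the "small" entry from the updated test_sanity.py table. The assertions simply restate the derived fields computed in __init__; nothing here goes beyond what the diff shows.

    # Sketch: constructing a ModelConfig and reading its derived fields.
    from utils import ModelConfig   # tests/pytorch/utils.py

    cfg = ModelConfig(2, 32, 2, 32, num_layers=2)   # batch_size, max_seqlen_q, num_heads, head_dim_qk
    assert cfg.hidden_size == 64                     # num_heads * head_dim_qk
    assert cfg.max_seqlen_kv == 32                   # defaults to max_seqlen_q
    assert cfg.attn_type == "self"                   # equal q/kv sequence lengths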
@contextmanager
def logging_context(highest_level=logging.WARNING):
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(previous_level)


def get_available_attention_backends(
    config: ModelConfig,
    qkv_dtype: torch.dtype,
    qkv_layout: str,
    window_size: Tuple[int, int] = (-1, -1),
    pad_between_seqs: bool = False,
    context_parallel: bool = False,
    deterministic: bool = False,
    fp8: bool = False,
    fp8_meta: Optional[Dict[str, Any]] = None,
    is_training: bool = True,
    inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
    """Check for all available attention backends that support a model configuration"""
    os.environ["NVTE_FLASH_ATTN"] = "1"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    os.environ["NVTE_UNFUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    alibi_slopes_shape = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        if config.bias_shape == "1hss":
            alibi_slopes_shape = [config.num_heads]
        if config.bias_shape == "bhss":
            alibi_slopes_shape = [config.batch_size, config.num_heads]

    core_attention_bias_shape = (
        config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
    )
    core_attention_bias_requires_grad = False
    # d=256 is supported by cuDNN 9.0+ for inference but not training
    if (
        config.attn_bias_type == "post_scale_bias"
        and config.head_dim_qk <= 128
        and config.head_dim_v <= 128
    ):
        core_attention_bias_requires_grad = True

    fused_attn_backends = []
    available_backends = None
    flash_attention_backend = None
    fused_attention_backend = None

    def test():
        attention_params = AttentionParams(
            qkv_dtype=qkv_dtype,
            qkv_layout=qkv_layout,
            batch_size=config.batch_size,
            num_heads=config.num_heads,
            num_gqa_groups=config.num_gqa_groups,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
            head_dim_qk=config.head_dim_qk,
            head_dim_v=config.head_dim_v,
            attn_mask_type=config.attn_mask_type,
            window_size=window_size,
            alibi_slopes_shape=alibi_slopes_shape,
            core_attention_bias_type=config.attn_bias_type,
            core_attention_bias_shape=core_attention_bias_shape,
            core_attention_bias_requires_grad=core_attention_bias_requires_grad,
            pad_between_seqs=pad_between_seqs,
            attention_dropout=config.dropout_p,
            context_parallel=context_parallel,
            deterministic=deterministic,
            fp8=fp8,
            fp8_meta=fp8_meta,
            is_training=is_training,
            inference_params=inference_params,
        )
        (
            use_flash_attention,
            use_fused_attention,
            flash_attention_backend,
            fused_attention_backend,
            use_unfused_attention,
            available_backends,
        ) = get_attention_backend(attention_params)
        # Set attention.py _attention_backends var using return value
        # from get_attention_backend()
        _attention_backends["use_flash_attention"] = use_flash_attention
        _attention_backends["use_fused_attention"] = use_fused_attention
        _attention_backends["flash_attention_backend"] = flash_attention_backend
        _attention_backends["fused_attention_backend"] = fused_attention_backend
        _attention_backends["use_unfused_attention"] = use_unfused_attention
        _attention_backends["backend_selection_requires_update"] = False
        return available_backends, flash_attention_backend, fused_attention_backend

    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
    if AttentionLogging._is_logging_setup is False:
        AttentionLogging.setup_logging()
    with logging_context(highest_level=AttentionLogging._log_level):
        for i in range(3):
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
            _attention_backends["backend_selection_requires_update"] = True
            available_backends, flash_attention_backend, fused_attention_backend = test()
            if fused_attention_backend == FusedAttnBackend[backends[i]]:
                fused_attn_backends.append(fused_attention_backend)

    return available_backends, flash_attention_backend, fused_attn_backends
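A hedged usage sketch of the helper above: it probes the three fused-attention backend choices and returns the ones that accept a given configuration. The dtype and the "sb3hd" layout string below are illustrative assumptions for the sketch, not values taken from the commit.

    # Sketch: querying which attention backends support a configuration.
    import torch
    from utils import ModelConfig, get_available_attention_backends

    cfg = ModelConfig(2, 128, 8, 64)   # batch 2, seqlen 128, 8 heads, head_dim 64
    available_backends, flash_backend, fused_backends = get_available_attention_backends(
        cfg,
        qkv_dtype=torch.bfloat16,
        qkv_layout="sb3hd",            # assumed layout string for the example
    )
    print(available_backends, flash_backend, fused_backends)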
transformer_engine/common/CMakeLists.txt
View file @
87e3e56e
...
...
@@ -126,6 +126,7 @@ if(USE_CUDA)
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
...
...
@@ -189,6 +190,7 @@ else()
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
...
...
@@ -347,6 +349,8 @@ if(USE_CUDA)
                               string_code_transpose_rtc_cast_transpose_cu)
  make_string_header_from_file(transpose/rtc/transpose.cu
                               string_code_transpose_rtc_transpose_cu)
  make_string_header_from_file(transpose/rtc/swap_first_dims.cu
                               string_code_transpose_rtc_swap_first_dims_cu)
  make_string_header_from_file(utils.cuh
                               string_code_utils_cuh)
else()
...
...
@@ -358,6 +362,8 @@ else()
                               string_code_transpose_rtc_cast_transpose_cu)
  make_string_header_from_file(transpose/rtc/transpose.hip
                               string_code_transpose_rtc_transpose_cu)
  make_string_header_from_file(transpose/rtc/swap_first_dims.cu
                               string_code_transpose_rtc_swap_first_dims_cu)
endif()
...
...
@@ -385,6 +391,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
  set_source_files_properties(
    activation/gelu.cu activation/relu.cu activation/swiglu.cu util/cast.cu
    PROPERTIES COMPILE_OPTIONS "--use_fast_math")
endif()
...
...
transformer_engine/common/__init__.py
View file @
87e3e56e
...
...
@@ -246,6 +246,18 @@ def _load_cudnn():
    if found:
        return handle

    # Attempt to locate libcudnn via ldconfig
    libs = subprocess.check_output(
        f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True
    )
    libs = libs.decode("utf-8").split("\n")
    sos = []
    for lib in libs:
        if "libcudnn" in lib and "=>" in lib:
            sos.append(lib.split(">")[1].strip())
    if sos:
        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)

    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
    return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
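As a quick illustration of what this lookup parses, here is a standalone sketch on a made-up `ldconfig -p` entry (not part of the diff):

# Example of the parsing above on a typical `ldconfig -p` line (illustrative input).
line = "\tlibcudnn.so.9 (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libcudnn.so.9"
if "libcudnn" in line and "=>" in line:
    path = line.split(">")[1].strip()
    print(path)  # /usr/lib/x86_64-linux-gnu/libcudnn.so.9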
...
...
@@ -267,12 +279,12 @@ def _load_nvrtc():
        return handle

    # Attempt to locate NVRTC via ldconfig
-   libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
+   libs = subprocess.check_output(
+       f"ldconfig -p | grep 'libnvrtc{_get_sys_extension()}'", shell=True
+   )
    libs = libs.decode("utf-8").split("\n")
    sos = []
    for lib in libs:
        if "stub" in lib or "libnvrtc-builtins" in lib:
            continue
        if "libnvrtc" in lib and "=>" in lib:
            sos.append(lib.split(">")[1].strip())
    if sos:
...
...
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
87e3e56e
...
...
@@ -189,14 +189,26 @@ CommOverlapCore::~CommOverlapCore() {
  if (_atomic_gemm) cudaFree(_counter.dptr());
- for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);
+ for (size_t i = 0; i < _stream_compute.size(); i++) {
+   cudaStreamSynchronize(_stream_compute[i]);
+   cudaStreamDestroy(_stream_compute[i]);
+ }
+ auto error = cudaGetLastError();
+ if (error != cudaSuccess) {
+   NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error));
+ }
  if (_comm_created) {
+   try {
#ifdef NVTE_UB_WITH_MPI
-     destroy_communicator_mpi(_ub_comm);
+     destroy_communicator_mpi(_ub_comm);
#else
-     destroy_communicator(_ub_comm);
+     destroy_communicator(_ub_comm);
#endif
+   } catch (const std::exception &e) {
+     NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what());
+   }
    _comm_created = false;
  }
}
...
...
@@ -382,6 +394,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
CommOverlapBase::~CommOverlapBase() {
  cudaEventDestroy(_start_d2dcopy);
+ cudaStreamSynchronize(_stream_comm);
  cudaStreamDestroy(_stream_comm);
}
...
...
@@ -704,6 +717,25 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
}  // CommOverlapBase::split_overlap_rs

void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
                                               cudaStream_t stream_main) {
  int comm_bytes = _ubuf.bytes();
  int comm_bytes_per_rank = comm_bytes / _tp_size;

  // We use the reference to the overlap_gemm to get the streams to send and receive on,
  // ensuring the kernels do not finish until the previous GEMM has been flushed.
  userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
                       send_stream);
  userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
                       recv_stream);
  for (auto stream : {send_stream, recv_stream}) {
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
    // We sync with the comm stream so the destructor can wait for the comm stream to finish
    // before freeing the ubuf.
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0));
  }
}
/***************************************************************************************************
* Comm+GEMM Overlap P2P Base (Ring-Exchange)
**************************************************************************************************/
...
...
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
View file @
87e3e56e
...
...
@@ -2652,6 +2652,30 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
  }
}

void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, communicator *comm, cudaStream_t stream) {
  for (int j = 1; j < tp_size; j++) {
    int i = (tp_rank + j) % tp_size;
    int send_offset = srcoffset + bytes_per_slice * tp_rank;
    int recv_offset = dstoffset + bytes_per_slice * tp_rank;
    userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
                     stream);
  }
}

void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, communicator *comm, cudaStream_t stream) {
  for (int j = tp_size - 1; j > 0; j--) {
    int i = (tp_rank + j) % tp_size;
    int send_offset = srcoffset + bytes_per_slice * i;
    int recv_offset = dstoffset + bytes_per_slice * i;
    userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
                     stream);
  }
}
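To see how the slices move in this all-gather, here is a small standalone Python sketch that mirrors the two loops above for an illustrative tp_size of 4 (a plain simulation, not library code):

# Illustrative simulation of the send/recv ordering above for tp_size = 4.
tp_size = 4
bytes_per_slice = 1024

for tp_rank in range(tp_size):
    sends = []
    recvs = []
    # userbuffers_send_all: every rank pushes its own slice to every peer.
    for j in range(1, tp_size):
        peer = (tp_rank + j) % tp_size
        sends.append((peer, tp_rank * bytes_per_slice))
    # userbuffers_recv_all: every rank pulls each peer's slice into that peer's offset.
    for j in range(tp_size - 1, 0, -1):
        peer = (tp_rank + j) % tp_size
        recvs.append((peer, peer * bytes_per_slice))
    print(f"rank {tp_rank}: send own slice to {sends}, receive peer slices {recvs}")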
// producer
static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) {
  // Decrement atomic val to signal current output tile finish
...
...
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
View file @
87e3e56e
...
...
@@ -312,4 +312,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp
void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream);

void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, communicator *comm, cudaStream_t stream);

void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, communicator *comm, cudaStream_t stream);

#endif  // TRANSFORMER_ENGINE_USERBUFFERS_H_
transformer_engine/common/common.cu
View file @
87e3e56e
...
...
@@ -98,6 +98,9 @@ void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
  return;
#else
  // Ensure the thread's "current" CUDA context is set.
  cuda_driver::ensure_context_exists();
  CUcontext ctx;
  const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
  switch (driver_status) {
...
...
@@ -167,10 +170,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
  void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
                                           (offset_elems * type_num_bits) / 8);
- NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment),
+ NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT),
             "Tensor data pointer must be 16B aligned");
- const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits;
+ const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits;
  NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
             "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
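As a quick illustration of what that shape check requires, here is the arithmetic in a few lines of standalone Python (illustrative only, not library code):

# The innermost dimension must cover a whole multiple of the 16-byte TMA alignment.
TMA_GMEM_ALIGNMENT = 16  # bytes
for dtype_bits in (8, 16, 32):
    needed = (TMA_GMEM_ALIGNMENT * 8) // dtype_bits
    print(f"{dtype_bits}-bit elements: inner dim must be a multiple of {needed}")
# 8-bit -> 16 elements, 16-bit -> 8 elements, 32-bit -> 4 elements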
...
...
transformer_engine/common/common.h
View file @
87e3e56e
...
...
@@ -94,7 +94,7 @@ struct SimpleTensor {
            nvte_make_shape(this->shape.data(), this->shape.size())};
  }

- int numel() const {
+ size_t numel() const {
    size_t acc = 1;
    for (const auto &dim : shape) {
      acc *= dim;
...
...
@@ -737,7 +737,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;

// Alignment requirements for the Tensor Memory Accelerator (TMA)
- constexpr int TMA_gmem_alignment = 16;       // global memory address alignment
+ constexpr size_t TMA_GMEM_ALIGNMENT = 16;    // global memory address alignment
+ constexpr size_t TMA_SHMEM_ALIGNMENT = 128;  // shared memory address alignment

inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
  return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
...
...
transformer_engine/common/fused_attn/fused_attn.cpp
View file @
87e3e56e
...
...
@@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
      (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
-     !requires_64bit_ragged_offset) {
+     !requires_64bit_ragged_offset &&
+     // 9.10.0: known bugs with SDPA FP8
+     (cudnn_runtime_version != 91000)) {
    if (cudnn_runtime_version >= 8900) {
      backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
    } else {
...
...
@@ -239,20 +241,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
          (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 &&
           max_seqlen_q > 1 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
-         // 9.10: any head_dim + any arch + fprop + paged
-         // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1
-         // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
-         (!is_training && cudnn_runtime_version >= 91000 &&
+         // 9.10.2: any head_dim + any arch + fprop + paged
+         // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1
+         // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
+         (!is_training && cudnn_runtime_version >= 91002 &&
           (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
            (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
             attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
          // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
          (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
           cudnn_runtime_version >= 91100)) &&
-        // 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
-        (!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 &&
-           head_dim_qk >= 128 && head_dim_v >= 128 &&
-           !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
+        // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
+        (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training &&
+           sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
+           !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
        // bias type
        ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
         (cudnn_runtime_version >= 8906 &&
...
...
@@ -358,7 +360,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
           max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
           dropout == 0.0)))) &&
      // check 64-bit ragged offset support
-     (supported_ragged_offset_size)) {
+     (supported_ragged_offset_size) &&
+     // 9.10.0/9.10.1: known bugs with SDPA F16
+     (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
    flag_arb = true;
  }
  if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
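The gating pattern added here, restated as a tiny standalone Python sketch (version numbers taken from the comments above; not library code): backend eligibility is the AND of the feature checks and explicit exclusions of known-bad cuDNN builds.

# Illustrative restatement of the version-exclusion style used above.
KNOWN_BAD_F16_BUILDS = {91000, 91001}   # 9.10.0 / 9.10.1, per the comment above

def arbitrary_seqlen_backend_ok(feature_checks_pass: bool, cudnn_runtime_version: int) -> bool:
    return feature_checks_pass and cudnn_runtime_version not in KNOWN_BAD_F16_BUILDS

print(arbitrary_seqlen_backend_ok(True, 91001))  # False
print(arbitrary_seqlen_backend_ok(True, 91002))  # True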
...
...
transformer_engine/common/fused_router/fused_moe_aux_loss.cu
View file @
87e3e56e
...
...
@@ -90,7 +90,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
     * Section: Reduce to get the sum of aggregated_probs_per_expert
     */
    CompType intermediate_result =
-       warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
+       warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id);
    __syncwarp();

    if (lane_id == 0) {
...
...
@@ -146,7 +146,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
     * Section: Reduce to get the sum of aggregated_probs_per_expert
     */
    CompType intermediate_result =
-       warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
+       warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id);
    __syncwarp();

    if (lane_id == 0) {
...
...
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu
View file @
87e3e56e
...
...
@@ -107,7 +107,8 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
  if (score_function == 0) {
    if (topk > 1) {
-     auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id);
+     auto sum_logits =
+         warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id);
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        local_logits[i] = static_cast<DataType>(static_cast<double>(local_logits[i]) /
                                                (static_cast<double>(sum_logits) + epsilon));
...
...
@@ -231,13 +232,15 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int
     */
    // Sigmoid Post-processing bwd when topk > 1
    if (topk > 1 && score_function == 0) {
-     auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id);
+     auto sum_fwd_input =
+         warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id);
      // Put the result of output * grad to the comp_buf
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i];
      }
      __syncwarp();
-     auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id);
+     auto sum_Output_x_Grad =
+         warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id);
      // In-place update
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        local_grad[i] =
...
...
transformer_engine/common/fused_router/fused_topk_with_score_function.cu
View file @
87e3e56e
...
...
@@ -220,7 +220,7 @@ __global__ void fused_topk_with_score_function_forward_kernel(
    // score_function == 0 means sigmoid
    if (score_function == 0) {
      if (topk > 1) {
-       double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id);
+       double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id);
        for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
          topk_scores[i] = static_cast<double>(topk_scores[i]) / (sum_scores + epsilon);
        }
...
...
@@ -362,7 +362,7 @@ __global__ void fused_topk_with_score_function_backward_kernel(
        /*data ptr = */ local_act_from_fwd,
        /*mask ptr = */ local_routing_map,
        /*data size = */ num_experts,
-       /*reduce func = */ sum, lane_id);
+       /*reduce func = */ ReduceFuncType::SUM, lane_id);
    // Put the result of output * grad to the comp_buf
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      local_comp_buf[i] = (local_routing_map[i] ? static_cast<double>(local_grad[i]) *
...
...
@@ -374,7 +374,7 @@ __global__ void fused_topk_with_score_function_backward_kernel(
        /*data ptr = */ local_comp_buf,
        /*mask ptr = */ local_routing_map,
        /*data size = */ num_experts,
-       /*reduce func = */ sum, lane_id);
+       /*reduce func = */ ReduceFuncType::SUM, lane_id);
    // In-place update
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      if (local_routing_map[i]) {
...
...
transformer_engine/common/fused_router/utils.h
View file @
87e3e56e
...
...
@@ -30,14 +30,28 @@ __device__ inline T sum(T a, T b) {
  return a + b;
}

+enum ReduceFuncType {
+  SUM,
+  MAX,
+};

template <typename T>
-__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T),
+__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type,
                                          int lane_id) {
+ T (*reduce_func)(T, T);
+ double default_val = 0;
+ if (type == ReduceFuncType::SUM) {
+   reduce_func = sum;
+   default_val = 0;
+ } else if (type == ReduceFuncType::MAX) {
+   reduce_func = max;
+   default_val = -std::numeric_limits<double>::infinity();
+ }
  // Some values are handled in the local thread
  // Thread 0 is responsible for the 0-th, 32-th, 64-th, 96-th ... elements
  // Reduce the values in the local thread
- volatile double val =
-     lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : static_cast<double>(0);
+ volatile double val = lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : default_val;
  for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
    val = reduce_func(val, data_ptr[i]);
  }
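The same dispatch idea in a few lines of standalone Python, to make the role of default_val explicit: it is the identity element of the chosen reduction, so lanes with no data contribute nothing (illustrative only).

# Illustrative restatement of the ReduceFuncType dispatch above.
import math

def warp_reduce(values, op):
    if op == "SUM":
        func, identity = lambda a, b: a + b, 0.0
    elif op == "MAX":
        func, identity = max, -math.inf
    acc = identity
    for v in values:
        acc = func(acc, v)
    return acc

print(warp_reduce([0.25, 0.5, 0.125], "SUM"))  # 0.875
print(warp_reduce([-3.0, -1.0, -2.0], "MAX"))  # -1.0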
...
...
@@ -69,13 +83,22 @@ __device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, i
template <typename T>
__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size,
-                                                T (*reduce_func)(T, T), int lane_id) {
+                                                ReduceFuncType type, int lane_id) {
+ T (*reduce_func)(T, T);
+ double default_val = 0;
+ if (type == ReduceFuncType::SUM) {
+   reduce_func = sum;
+   default_val = 0;
+ } else if (type == ReduceFuncType::MAX) {
+   reduce_func = max;
+   default_val = -std::numeric_limits<double>::infinity();
+ }
  // Some values are handled in the local thread
  // Thread 0 is responsible for the 0-th, 32-th, 64-th, 96-th ... elements
  // Reduce the values in the local thread
- volatile double val = lane_id < data_size && mask[lane_id]
-                           ? static_cast<double>(data_ptr[lane_id])
-                           : static_cast<double>(0);
+ volatile double val = lane_id < data_size && mask[lane_id]
+                           ? static_cast<double>(data_ptr[lane_id])
+                           : default_val;
  for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
    if (mask[i]) {
      val = reduce_func(val, data_ptr[i]);
...
...
@@ -128,7 +151,7 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_
  float sum_Output_x_Grad = warp_reduce_on_shmem(
      /*data ptr = */ comp_buf,
      /*data size = */ data_size,
-     /*reduce func = */ sum, lane_id);
+     /*reduce func = */ ReduceFuncType::SUM, lane_id);
  // In-place update
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    if (mask) {
...
...
@@ -147,14 +170,16 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_
template <typename DataType>
__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) {
  // 1. compute the max of value
- float max_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, max, lane_id));
+ float max_val =
+     static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id));
  // 2. value -> exp_value
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    scores[i] = static_cast<float>(exp(static_cast<float>(scores[i]) - max_val));
  }
  __syncwarp();
  // 3. compute the sum of exp_value
- float sum_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, sum, lane_id));
+ float sum_val =
+     static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id));
  // 4. update the softmax value
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    scores[i] = static_cast<float>(scores[i]) / sum_val;
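For reference, the numerically stable softmax that this device routine implements, written out in a few lines of standalone Python (illustrative only):

# Numerically stable softmax: subtract the max before exponentiating so exp() never overflows.
import math

def stable_softmax(scores):
    max_val = max(scores)                            # step 1: max
    exps = [math.exp(s - max_val) for s in scores]   # step 2: shifted exponentials
    total = sum(exps)                                # step 3: sum
    return [e / total for e in exps]                 # step 4: normalize

print(stable_softmax([1000.0, 1001.0, 1002.0]))  # works even for large logits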
...
...
@@ -165,19 +190,29 @@ __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, i
template <typename T>
__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
                                           T *topk_scores, int lane_id) {
+ // Check if the index was already selected (masked) by an earlier iteration
+ auto is_masked = [&topk_indices](int k, int index) {
+   if (k == 0) return false;
+   for (int i = 0; i < k; i++) {
+     if (topk_indices[i] == index) return true;
+   }
+   return false;
+ };
  // Topk Times: Find the max value and its index
  // Then mask it, and record the index in the topk_indices
  // After looping topk times, the topk_indices will be the topk indices
  for (int k = 0; k < topk; k++) {
    // Find the max value and its index
-   volatile double val =
-       (lane_id < data_size) ? static_cast<double>(scores[lane_id]) : static_cast<double>(0);
+   volatile double val = (lane_id < data_size && !is_masked(k, lane_id))
+                             ? static_cast<double>(scores[lane_id])
+                             : -std::numeric_limits<double>::infinity();
    volatile int index = (lane_id < data_size) ? lane_id : 0;
    // Some values are handled in the local thread
    // Thread 0 is responsible for the 0-th, 32-th, 64-th, 96-th ... elements
    // Reduce the values in the local thread
    for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
-     volatile double cur_val = scores[i];
+     volatile double cur_val = (is_masked(k, i)) ? -std::numeric_limits<double>::infinity()
+                                                 : static_cast<double>(scores[i]);
      if (cur_val > val) {
        val = cur_val;
        index = i;
...
...
@@ -200,17 +235,9 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
    if (lane_id == 0) {
      topk_indices[k] = index;
      topk_scores[k] = val;
-     scores[index] = static_cast<double>(-1.0) - val;  // mark the selected experts using val = -1 - val
    }
    __syncwarp();
  }
- // Reset the scores to the original value
- for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
-   scores[topk_indices[i]] =
-       static_cast<double>(-1.0) - static_cast<double>(scores[topk_indices[i]]);
- }
}
// Current TE only supports float32/bf16/fp16; float64 probs should be considered in the future
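The selection strategy after this change, restated as standalone Python: instead of temporarily overwriting selected scores and restoring them afterwards, previously chosen indices are simply treated as -inf in later iterations (illustrative only, not library code).

# Illustrative restatement of the masked top-k selection above.
import math

def naive_topk(scores, topk):
    topk_indices, topk_scores = [], []
    for _ in range(topk):
        best_idx, best_val = None, -math.inf
        for i, s in enumerate(scores):
            if i in topk_indices:          # is_masked(): skip already-selected experts
                continue
            if s > best_val:
                best_idx, best_val = i, s
        topk_indices.append(best_idx)
        topk_scores.append(best_val)
    return topk_indices, topk_scores

print(naive_topk([0.1, 0.7, 0.2, 0.9], 2))  # ([3, 1], [0.9, 0.7])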
...
...
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
87e3e56e
...
...
@@ -253,8 +253,9 @@ using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublas
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
-                bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split,
-                int n_split, bool gemm_producer, const Tensor *inputCounter,
-                cudaStream_t stream) {
+                float alpha, float beta, bool use_split_accumulator, int math_sm_count,
+                int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
+                cudaStream_t stream) {
  // Tensor dims in row-major order
  const int A0 = inputA->flat_first_dim();
  const int A1 = inputA->flat_last_dim();
...
...
@@ -310,13 +311,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"fp8 Aux output for gemm + gelu fusion not supported!"
);
}
if
(
is_fp8_dtype
(
outputD
->
data
.
dtype
))
{
NVTE_CHECK
(
!
accumulate
,
"Accumulation mode not supported with FP8 GEMM output!"
);
NVTE_CHECK
(
beta
==
0.0
f
,
"Accumulation mode not supported with FP8 GEMM output!"
);
}
float
one
=
1.0
;
float
zero
=
0.0
;
float
beta
=
(
accumulate
)
?
one
:
zero
;
cublasLtHandle_t
handle
=
cublasHandleManager
::
Instance
().
GetHandle
();
cublasLtMatmulDesc_t
operationDesc
=
nullptr
;
...
...
@@ -601,7 +598,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
-                                  static_cast<const void *>(&one),    /* alpha */
+                                  static_cast<const void *>(&alpha),  /* alpha */
                                   param.A,                            /* A */
                                   Adesc, param.B,                     /* B */
                                   Bdesc,
                                   static_cast<const void *>(&beta),   /* beta */
...
...
@@ -752,8 +749,27 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu,
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
              wspace->data.dptr, wspace->data.shape[0],
-             accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
+             1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false,
+             nullptr, stream);
#endif  //__HIP_PLATFORM_AMD__
}

void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm_scaled);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu,
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
              wspace->data.dptr, wspace->data.shape[0], alpha, beta, use_split_accumulator,
              math_sm_count, 0, 0, false, nullptr, stream);
}
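The new entry point forwards user-supplied alpha/beta instead of deriving beta from `accumulate`. In standard GEMM terms, this is what those scalars mean, shown as a plain Python/NumPy illustration (not the library API):

# D = alpha * (A @ B) + beta * D  -- the semantics of the alpha/beta arguments.
import numpy as np

rng = np.random.default_rng(0)
A, B = rng.standard_normal((4, 3)), rng.standard_normal((3, 5))
D = rng.standard_normal((4, 5))   # pre-existing output (accumulation target)
alpha, beta = 0.5, 1.0            # beta = 1.0 reproduces the old accumulate=True path

D_new = alpha * (A @ B) + beta * D
print(D_new.shape)  # (4, 5)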
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
...
...
@@ -846,8 +862,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu,
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
              wspace->data.dptr, wspace->data.shape[0],
-             accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
-             inputCounter, stream);
+             1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split,
+             n_split, gemm_producer, inputCounter, stream);
#endif  //__HIP_PLATFORM_AMD__
}
...
...