OpenDAS / TransformerEngine / Commits / 1edc9e13

Commit 1edc9e13, authored Nov 26, 2025 by wenjh
Merge branch 'develop_v2.8' into release_v2.8
Parents: 5e7dd67e, 3a040217
Showing 15 changed files with 2149 additions and 2093 deletions (+2149, -2093).
qa/L0_pytorch_unittest/test.sh                        +1    -0
tests/pytorch/attention/test_attention.py             +17   -16
tests/pytorch/debug/run_distributed.py                +43   -19
tests/pytorch/distributed/run_numerics.py             +40   -7
tests/pytorch/distributed/test_numerics.py            +13   -1
tests/pytorch/test_cuda_graphs.py                     +9    -3
tests/pytorch/test_float8_blockwise_gemm_exact.py     +0    -1
tests/pytorch/test_fusible_ops.py                     +4    -2
tests/pytorch/test_numerics.py                        +55   -4
tests/pytorch/test_onnx_export.py                     +70   -15
tests/pytorch/test_sanity.py                          +57   -9
transformer_engine/common/gemm/cublaslt_gemm.cu       +0    -9
transformer_engine/common/gemm/rocm_gemm.cu           +1832 -2005
transformer_engine/common/transpose/transpose.cu      +7    -1
transformer_engine/pytorch/module/_common.py          +1    -1
qa/L0_pytorch_unittest/test.sh  (view file @ 1edc9e13)

...
@@ -49,6 +49,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_tes
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
mkdir -p $TE_PATH/artifacts/tests/pytorch/test_checkpoint && python $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all --checkpoint-dir $TE_PATH/artifacts/tests/pytorch/test_checkpoint/
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
...
tests/pytorch/attention/test_attention.py  (view file @ 1edc9e13)

...
@@ -340,22 +340,23 @@ def test_dpa_softmax(dtype, model_configs, model):
model_configs_mla = {
    # test: ModelConfig(b, sq, hq, dqk)
    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),
    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),
    "mla_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64),
    "mla_2_2": ModelConfig(1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128),
    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),
    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),
    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),
    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),
    # TODO: FlashAttention on ROCm only supports MLA with head_dim_qk == head_dim_v
    # # test: ModelConfig(b, sq, hq, dqk)
    # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
    # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
    # "mla_2_1": ModelConfig(
    #     1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
    # ),  # cross, 1
    # "mla_2_2": ModelConfig(
    #     1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
    # ),  # cross, 1
    # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
    # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
    # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
    # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
}
...
tests/pytorch/debug/run_distributed.py  (view file @ 1edc9e13)

...
@@ -17,6 +17,7 @@ import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from test_numerics import (
    _emulate_linear,
...
@@ -47,7 +48,6 @@ TEST_NR = 0
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()


def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
    if tp_size is None:
        tp_size = WORLD_SIZE
...
@@ -72,6 +72,16 @@ def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None,
def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
    if IS_HIP_EXTENSION:
        model = transformer_engine.pytorch.Linear(
            IN_SIZE,
            OUT_SIZE,
            name=name,
            bias=False,
            parallel_mode=parallel_mode,
            tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
        )
    else:
        model = transformer_engine.pytorch.Linear(
            IN_SIZE,
            OUT_SIZE,
...
@@ -363,7 +373,6 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
    )
    set_weight_tensor_tp_group_reduce(True)  # reset


@run_debug_test
def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
    from test_log import LOG_QUANTIZED_CONFIG
...
@@ -580,6 +589,9 @@ def test_fake_quant_fp8(
        "dgrad_fp8": not (dgrad_weight or dgrad_grad),
        "wgrad_fp8": not (wgrad_grad or wgrad_input),
    }
    if IS_HIP_EXTENSION:
        if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
            return  # Output type 32 (FP32) does not support int8 simulation.
    if WORLD_RANK == 0:
        fake_quant_fp8_create_config(
            fprop_inp,
...
@@ -667,6 +679,10 @@ if __name__ == "__main__":
    random.seed(SEED)
    _init_distributed()

    if IS_HIP_EXTENSION:
        # Output type 32 (FP32) does not support int8 simulation.
        pass
    else:
        test_log_expert_parallel()

    for parallel_mode in ["column", "row"]:
        for gather_weight in [True, False]:
...
@@ -676,6 +692,11 @@ if __name__ == "__main__":
    for parallel_mode in ["row", "column"]:
        test_disable_fp8_layer(parallel_mode)

    if IS_HIP_EXTENSION:
        # Output type 32 (FP32) does not support int8 simulation.
        pass
    else:
        # test_disable_fp8_gemms
        _run_test_with_combinations(
            test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]
...
@@ -690,7 +711,10 @@ if __name__ == "__main__":
        extra_args=["column", "row"],
        sample_size=20,
    )

    if IS_HIP_EXTENSION:
        # Output type 32 (FP32) does not support int8 simulation.
        pass
    else:
        _run_test_with_combinations(
            test_per_tensor_scaling,
            all_boolean,
...
tests/pytorch/distributed/run_numerics.py  (view file @ 1edc9e13)

...
@@ -733,7 +733,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
def test_linear():
    """Run linear layer tests with various configurations."""
    kwargs_list = [
    base_kwargs_list = [
        {},
        {"bias": False},
        {"init_method": _constant},
...
@@ -743,7 +743,15 @@ def test_linear():
        {"delay_wgrad_compute": True},
        {"save_original_input": True},
    ]
    # TODO: The blockwise recipe does not currently support bias=True.
    """
    On AMD platforms, when the quantization recipe is fp8_block_scaling, keep only the
    configurations from base_kwargs_list that explicitly set bias=False; configurations
    that leave bias unset or set it to True are dropped.
    """
    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
    else:
        kwargs_list = base_kwargs_list
    for kwargs in kwargs_list:
        if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
            continue
...
@@ -913,7 +921,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
def test_layernorm_linear():
    kwargs_list = [
    base_kwargs_list = [
        {},
        {"bias": False},
        {"init_method": _constant},
...
@@ -924,7 +932,15 @@ def test_layernorm_linear():
        {"return_layernorm_output": True},
        {"delay_wgrad_compute": True},
    ]
    # TODO: The blockwise recipe does not currently support bias=True.
    """
    On AMD platforms, when the quantization recipe is fp8_block_scaling, keep only the
    configurations from base_kwargs_list that explicitly set bias=False; configurations
    that leave bias unset or set it to True are dropped.
    """
    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
    else:
        kwargs_list = base_kwargs_list
    for kwargs in kwargs_list:
        for parallel_mode in ["column"]:
            for sequence_parallel in [False, True]:
...
@@ -1019,7 +1035,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
def test_layernorm_mlp():
    kwargs_list = [
    base_kwargs_list = [
        {},
        {"init_method": _constant},
        {"output_layer_init_method": _constant},
...
@@ -1033,7 +1049,15 @@ def test_layernorm_mlp():
        {"return_layernorm_output": True},
        {"delay_wgrad_compute": True},
    ]
    # TODO: The blockwise recipe does not currently support bias=True.
    """
    On AMD platforms, when the quantization recipe is fp8_block_scaling, keep only the
    configurations from base_kwargs_list that explicitly set bias=False; configurations
    that leave bias unset or set it to True are dropped.
    """
    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
    else:
        kwargs_list = base_kwargs_list
    for kwargs in kwargs_list:
        for set_parallel_mode in [True]:
            for sequence_parallel in [False, True]:
...
@@ -1108,7 +1132,7 @@ def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
def test_transformer_layer():
    kwargs_list = [
    base_kwargs_list = [
        {},
        {"num_gqa_groups": 4},
        {"init_method": _constant},
...
@@ -1128,6 +1152,15 @@ def test_transformer_layer():
        {"fuse_qkv_params": True},
        {"activation": "relu"},
    ]
    # TODO: The blockwise recipe does not currently support bias=True.
    """
    On AMD platforms, when the quantization recipe is fp8_block_scaling, keep only the
    configurations from base_kwargs_list that explicitly set bias=False; configurations
    that leave bias unset or set it to True are dropped.
    """
    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
    else:
        kwargs_list = base_kwargs_list
    for kwargs in kwargs_list:
        for sequence_parallel in [False, True]:
...
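Note: the same bias filter is repeated verbatim in all four test functions above. As a minimal sketch only (the helper name and signature below are illustrative and not part of this commit), it could be factored out once and reused:

def _filter_bias_configs(base_kwargs_list, is_hip, quantization):
    """Keep only configurations that explicitly disable bias when running the
    fp8_block_scaling recipe on HIP, where bias is not yet supported."""
    if is_hip and quantization == "fp8_block_scaling":
        return [kw for kw in base_kwargs_list if kw.get("bias", True) is False]
    return list(base_kwargs_list)

# Illustrative usage with the module-level flags the tests above already define:
# kwargs_list = _filter_bias_configs(base_kwargs_list, IS_HIP_EXTENSION, QUANTIZATION)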
tests/pytorch/distributed/test_numerics.py  (view file @ 1edc9e13)

...
@@ -9,7 +9,8 @@ from pathlib import Path
import pytest
import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine as te

"""
Distributed numerics tests
...
@@ -66,4 +67,15 @@ def test_distributed(quantization):
        pytest.skip(reason_for_no_fp8_block_scaling)
    if quantization == "nvfp4" and not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)
    if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
        import importlib
        ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", "None")
        os.environ["NVTE_INT8_SIM_FP8"] = "1"
        importlib.reload(te.pytorch.fp8)
    _run_test(quantization)
    if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
        if ori_int8_sim_fp8 is None or ori_int8_sim_fp8 == "None":
            os.environ["NVTE_INT8_SIM_FP8"] = "0"
        else:
            del os.environ["NVTE_INT8_SIM_FP8"]
        importlib.reload(te.pytorch.fp8)
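Note: the test above sets and restores NVTE_INT8_SIM_FP8 by hand around the importlib.reload calls. A minimal sketch of the same save/restore as a context manager (illustrative only; the helper is not part of this commit, and the usage comment assumes the `te` and `_run_test` names from the test module):

import os
import importlib
from contextlib import contextmanager

@contextmanager
def _temp_env(name, value):
    """Set an environment variable for the duration of a block, then restore
    the previous value (or remove the variable if it was previously unset)."""
    previous = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous

# Illustrative usage for the int8-simulation path above:
# with _temp_env("NVTE_INT8_SIM_FP8", "1"):
#     importlib.reload(te.pytorch.fp8)
#     _run_test(quantization)
# importlib.reload(te.pytorch.fp8)  # reload again with the restored environment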
tests/pytorch/test_cuda_graphs.py  (view file @ 1edc9e13)

...
@@ -28,9 +28,9 @@ if IS_HIP_EXTENSION:
    from functools import cache

# Check if FP8 is supported.
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Reset RNG states.
reset_rng_states()
...
@@ -367,6 +367,12 @@ def test_make_graphed_callables(
        )
    if fp8_params:
        pytest.skip("NVFP4 params not supported")
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
...
tests/pytorch/test_float8_blockwise_gemm_exact.py  (view file @ 1edc9e13)

...
@@ -6,7 +6,6 @@ import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import use_lightop_w8a8
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
...
tests/pytorch/test_fusible_ops.py  (view file @ 1edc9e13)

...
@@ -2111,7 +2111,8 @@ class TestFusedOps:
        quantized_weight: bool = False,
    ) -> None:
        """Forward GEMM + scale + add"""
        if IS_HIP_EXTENSION and scale != 1:
            pytest.skip("alpha must be 1.0 for hip")
        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
...
@@ -2496,7 +2497,8 @@ class TestFusedOps:
        quantized_weight: bool = False,
    ) -> None:
        """Backward dgrad GEMM + scale"""
        if IS_HIP_EXTENSION and scale != 1:
            pytest.skip("alpha must be 1.0 for hip")
        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
...
tests/pytorch/test_numerics.py  (view file @ 1edc9e13)

...
@@ -56,7 +56,7 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
...
@@ -606,6 +606,13 @@ def _test_e2e_selective_recompute(
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    config = model_configs[model]
...
@@ -714,8 +721,15 @@ def _test_e2e_full_recompute(
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant):
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
...
@@ -1301,9 +1315,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
    fuse_wgrad_accumulation = True
    fp8_model_params = False
    fp8 = recipe is not None
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
    if fp8 and recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
...
@@ -1818,6 +1837,12 @@ def test_grouped_linear_accuracy(
    use_cutlass=False,
):
    fp8 = recipe is not None
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8 and recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
...
@@ -1863,7 +1888,8 @@ def test_grouped_linear_accuracy(
            weight_i = getattr(grouped_linear, f"weight{i}")
            weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
            sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()

    if IS_HIP_EXTENSION:
        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
    outputs_ref = _test_grouped_linear_accuracy(
        sequential_linear,
        num_gemms,
...
@@ -1886,7 +1912,8 @@ def test_grouped_linear_accuracy(
        fuse_wgrad_accumulation,
        delay_wgrad_compute,
    )
    if IS_HIP_EXTENSION:
        os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"

    for o, o_ref in zip(outputs, outputs_ref):
        if use_cutlass:
            torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
...
@@ -1956,6 +1983,12 @@ def test_grouped_linear_accuracy_save_original_input(
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8 and recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
...
@@ -2162,8 +2195,14 @@ def test_padding_grouped_linear_accuracy(
    fp8_model_params,
    parallel_mode=None,
):
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
...
@@ -2235,6 +2274,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
...
@@ -2446,8 +2491,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
...
tests/pytorch/test_onnx_export.py  (view file @ 1edc9e13)

...
@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
from transformer_engine.pytorch.onnx_extensions import te_translation_table
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
import tensorrt as trt
...
@@ -65,7 +67,6 @@ if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.DelayedScaling())
    fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(None)

supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
...
@@ -82,11 +83,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
    ],
    outputs=[PyCustomOpDef.dt_uint8],
)
def trt_fp8_quantize(t, scale_inv):
def trt_fp8_quantize(t, scale):
    """FP8 quantization extension for ONNX Runtime."""
    x = torch.from_numpy(t).cuda()
    q = te.tensor.float8_tensor.Float8Quantizer(
        scale=1 / torch.from_numpy(scale_inv).cuda(),
        scale=1 / torch.from_numpy(scale).cuda(),
        amax=torch.zeros([1]).cuda(),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
...
@@ -102,11 +103,11 @@ def trt_fp8_quantize(t, scale_inv):
    ],
    outputs=[PyCustomOpDef.dt_float],
)
def trt_fp8_dequantize(t, scale_inv):
def trt_fp8_dequantize(t, scale):
    """FP8 dequantization extension for ONNX Runtime."""
    x = torch.from_numpy(t).cuda()
    q = te.tensor.float8_tensor.Float8Quantizer(
        scale=1 / torch.from_numpy(scale_inv).cuda(),
        scale=1 / torch.from_numpy(scale).cuda(),
        amax=torch.zeros([1]).cuda(),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
...
@@ -469,16 +470,22 @@ def _test_export_linear(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_linear(fp8_recipe=fp8_recipe, precision=precision)


@pytest.mark.parametrize("use_bias", [True, False])
def test_export_linear_use_bias(seed_default_rng, use_bias):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_linear(use_bias=use_bias)


@pytest.mark.parametrize("return_bias", [True, False])
def test_export_linear_return_bias(seed_default_rng, return_bias):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_linear(return_bias=return_bias)
...
@@ -540,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng):
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_normalization(seed_default_rng, normalization):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm(normalization=normalization)
...
@@ -594,9 +603,7 @@ def _test_export_layernorm_linear(
        fname,
        inp,
        model,
        # For current scaling we use Float8Quantizer in tests + amax computed by hand,
        # which has slightly different numerics than Float8CurrentScalingQuantizer.
        atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
        atol=1e-3,
        is_fp8=fp8_recipe is not None,
        te_outputs=te_outputs,
    )
...
@@ -605,27 +612,39 @@ def _test_export_layernorm_linear(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)


def test_export_layernorm_linear_return_ln_out(seed_default_rng):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(return_layernorm_output=True)


def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(zero_centered_gamma=True)


@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(normalization=normalization)


def test_export_layernorm_linear_no_bias(seed_default_rng):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(use_bias=False)


def test_export_layernorm_linear_return_bias(seed_default_rng):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(return_bias=True)
...
@@ -684,32 +703,46 @@ def _test_export_layernorm_mlp(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)


def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(return_layernorm_output=True)


def test_export_layernorm_mlp_return_bias(seed_default_rng):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(return_bias=True)


def test_export_layernorm_mlp_no_bias(seed_default_rng):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(use_bias=False)


def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(zero_centered_gamma=True)


@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(normalization=normalization)


@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_layernorm_mlp_activation(seed_default_rng, activation):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(activation=activation)
...
@@ -731,6 +764,8 @@ def test_export_core_attention(
    use_mask: bool,
    attn_mask_type: str,
):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    # Set dimensions (these are arbitrary).
    seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
    qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
...
@@ -932,22 +967,32 @@ def _test_export_multihead_attention(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_multihead_attention_recipe(fp8_recipe, precision):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)


def test_export_multihead_attention_no_mask():
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_multihead_attention(use_mask=False)


def test_export_multihead_attention_no_input_layernorm():
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_multihead_attention(input_layernorm=False)


def test_export_multihead_attention_cross_attn():
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_multihead_attention(attention_type="cross")


def test_export_multihead_attention_unfused_qkv_params():
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_multihead_attention(fuse_qkv_params=False)
...
@@ -1023,27 +1068,39 @@ def _test_export_transformer_layer(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_transformer_layer_recipe(fp8_recipe, precision):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)


def test_export_transformer_layer_no_mask():
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(use_mask=False)


def test_export_transformer_layer_output_layernorm():
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(output_layernorm=True)


def test_export_transformer_layer_unfused_qkv_params():
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(fuse_qkv_params=False)


def test_export_transformer_layer_zero_centered_gamma():
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(zero_centered_gamma=True)


@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_transformer_layer_activation(activation):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(activation=activation)
...
@@ -1056,7 +1113,8 @@ def test_export_gpt_generation(
    """Test that the ONNX model can correctly handle inputs with different shapes and that
    the attention mask is adjusted on-the-fly to different sequence lengths.
    """
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
    # Layer configuration
    hidden_size = 64
    sequence_length = 128
...
@@ -1147,17 +1205,14 @@ def test_export_ctx_manager(enabled):
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_trt_integration(fp8_recipe: recipe.Recipe):
    if IS_HIP_EXTENSION:
        pytest.skip("TRT is not supported for HIP")
    model = te.TransformerLayer(
        hidden_size=128,
        ffn_hidden_size=128,
        num_attention_heads=4,
    ).eval()
    if type(fp8_recipe) == recipe.Float8CurrentScaling:
        # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
        model = te.LayerNormMLP(128, 128)

    inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)

    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
...
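Note: every test in this file now repeats the same `if IS_HIP_EXTENSION: pytest.skip(...)` guard. A minimal alternative sketch (not part of this commit) is a module-level marker that skips the whole file on HIP builds in one place; the reason string below is illustrative:

import pytest
from torch.utils.cpp_extension import IS_HIP_EXTENSION

# Module-level marker: pytest applies it to every test in this module.
pytestmark = pytest.mark.skipif(
    IS_HIP_EXTENSION, reason="ONNX export is not currently required on HIP"
)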
tests/pytorch/test_sanity.py  (view file @ 1edc9e13)

...
@@ -46,7 +46,7 @@ from utils import ModelConfig
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Record initial RNG state from script run.
...
@@ -388,7 +388,13 @@ def test_sanity_layernorm_linear(
    config = model_configs[model]
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -450,7 +456,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
    num_tokens = bs * config.max_seqlen_q
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -543,7 +555,13 @@ def test_sanity_layernorm_mlp(
    config = model_configs[model]
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -587,7 +605,13 @@ def test_sanity_gpt(
    config = model_configs[model]
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -759,7 +783,13 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -791,7 +821,13 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
    config = model_configs[model]
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -827,7 +863,13 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -863,7 +905,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
    config = model_configs[model]
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
...
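Note: the same block of recipe checks is repeated in each sanity test above. A rough sketch of a consolidation (not part of this commit; the availability flags and reason strings below are the module-level names that test_sanity.py already defines at the top of the file):

import pytest
import torch

def _skip_if_recipe_unsupported(fp8_recipe, config, dtype):
    """Hypothetical helper mirroring the per-test guards above."""
    if fp8_recipe is None:
        return
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if not config.is_fp8_supported():
        pytest.skip("Model config does not support FP8")
    if fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")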
transformer_engine/common/gemm/cublaslt_gemm.cu  (view file @ 1edc9e13)

...
@@ -1561,15 +1561,6 @@ void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor
    NVTE_ERROR("TT layout not allowed.");
  }

  hipblasLtHandle_t handle = nullptr;
  // Init hipblaslt handles (once, globally)
  static std::once_flag init_flag;
  static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
  std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
  handle = hipblaslt_handles[0];

  NVTE_ERROR("Remove nvte_cublas_batchgemm_tensorwise_int8 for now.");
}
...
transformer_engine/common/gemm/rocm_gemm.cu  (view file @ 1edc9e13)

...
@@ -465,106 +465,6 @@ transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t)
namespace {

static class HandlePool {
 public:
  hipblasLtHandle_t get(int device_id) {
    std::lock_guard<std::mutex> lock(mt);
    if (pool.empty()) {
      int device_count = 0;
      NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
      pool.resize(device_count);
      return nullptr;
    }
    if (!pool[device_id].empty()) {
      hipblasLtHandle_t h = pool[device_id].front();
      pool[device_id].pop_front();
      return h;
    }
    return nullptr;
  }
  hipblasLtHandle_t obtain(int device_id) {
    hipblasLtHandle_t h = get(device_id);
    if (h == nullptr) {
      NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&h));
    }
    return h;
  }
  void store(const std::vector<hipblasLtHandle_t>& handles) {
    std::lock_guard<std::mutex> lock(mt);
    if (pool.empty()) {
      std::cout << "[ERROR] Attempt to store handles to invalid pool" << std::endl;
    }
    for (unsigned int i = 0; i < pool.size(); i++) {
      if (handles[i] != nullptr) {
        pool[i].push_front(handles[i]);
      }
    }
  }
  ~HandlePool() {
#if DESTROY_HIPBLASLT_HANDLES_POOL
    std::lock_guard<std::mutex> lock(mt);
    for (auto& hlist : pool) {
      for (auto& h : hlist) {
        hipblasLtDestroy(h);
      }
    }
    pool.clear();
#endif
  }
  inline size_t get_size() const { return pool.size(); }

 private:
  std::mutex mt;
  using Pool = std::vector<std::forward_list<hipblasLtHandle_t>>;
  // The order of destruction between thread_local and global objects is not guaranteed.
  // As a simple workaround, make the pool storage "leaky": do not destruct it and do not
  // destroy the hipBlasLt handles; let the OS reclaim them on application exit.
#if DESTROY_HIPBLASLT_HANDLES_POOL
  Pool pool;
#else
  Pool& pool = *new Pool();
#endif
} handle_pool;

thread_local static class HandleCache {
 public:
  hipblasLtHandle_t get(int device_id) const { return d.empty() ? nullptr : d[device_id]; }
  hipblasLtHandle_t obtain(int device_id) {
    hipblasLtHandle_t h = get(device_id);
    if (h) {
      return h;
    }
    h = handle_pool.obtain(device_id);
    set(device_id, h);
    return h;
  }
  void set(int device_id, hipblasLtHandle_t h) {
    if (d.empty()) {
      d.resize(handle_pool.get_size());
    }
    d[device_id] = h;
  }
  ~HandleCache() {
    if (!d.empty()) {
      handle_pool.store(d);
    }
  }

 private:
  std::vector<hipblasLtHandle_t> d;
} cached_handles;

class csv_helper {
 public:
  struct start {};
...
@@ -987,18 +887,12 @@ static inline int getIntEnv(const char* name, int defval, int minval) {
}
// namespace

/* Warning: only call once per device!
 * When calling nvte_multi_stream_cublas_gemm with the hipblaslt backend,
 * multiple handles corresponding to compute_streams need to be created
 * to avoid a handle being used by multiple streams concurrently.
 */
static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
  NVTE_CHECK(hipblaslt_handles != nullptr);
  for (int i = 0; i < compute_num_streams; i++) {
    NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i]));
  }
static inline void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
  NVTE_CHECK_HIPBLASLT(hipblasLtCreate(handle));
}

using hipBlasLtHandleManager = detail::HandleManager<hipblasLtHandle_t, CreateHipBlasLtHandle>;

transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
  using namespace transformer_engine;
  switch (t) {
...
@@ -1018,8 +912,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                    int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
                    bool grad, void* workspace, size_t workspaceSize, bool accumulate,
                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                    bool gemm_producer, const Tensor* inputCounter, hipStream_t stream,
                    hipblasLtHandle_t handle) {
                    bool gemm_producer, const Tensor* inputCounter, hipStream_t stream) {
  void* A = inputA->data.dptr;
  void* A_scale_inverse = inputA->scale_inv.dptr;
  float* A_scale_inverse_float = (float*)(inputA->scale_inv.dptr);
...
@@ -1064,12 +957,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
  int device_id;
  NVTE_CHECK_CUDA(hipGetDevice(&device_id));

  if (handle == nullptr) {
    handle = cached_handles.get(device_id);
    if (handle == nullptr) {
      handle = cached_handles.obtain(device_id);
    }
  }
  hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();

  hipblasLtMatmulDesc_t operationDesc = nullptr;
  hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
...
@@ -1352,82 +1240,41 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}

class userArgsManager {
 public:
  userArgsManager() {}
  ~userArgsManager() {
    // Release all userArgs when the manager is destroyed
    for (auto& device_pair : userArgs_map_) {
      hipFree(device_pair.second);  // Only one userArgs per device
    }
  }
  // Get a userArgs for the given device (creates if necessary)
  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Check if the userArgs for this device exists
    auto device_it = userArgs_map_.find(device_id);
    if (device_it != userArgs_map_.end()) {
      return device_it->second;
    }
    // Create a new userArgs for this device if it doesn't exist
    hipblaslt_ext::UserArguments* userArgs;
    NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
    // Store the userArgs in the map for this device
    userArgs_map_[device_id] = userArgs;
    return userArgs;
struct HipBlasLtUserArgsDeleter {
  void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept { hipFree(ptr); }
 private:
  std::unordered_map<int, hipblaslt_ext::UserArguments*> userArgs_map_;  // Map from device_id to hipblasHandle
  std::mutex mutex_;
};

class d_userArgsManager {
 public:
  d_userArgsManager() {}
using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;
  ~d_userArgsManager() {
    // Release all userArgs when the manager is destroyed
    for (auto& device_pair : d_userArgs_map_) {
      hipFree(device_pair.second);  // Only one userArgs per device
    }
inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
  hipblaslt_ext::UserArguments* raw_ptr = nullptr;
  if (host) {
    NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
  } else {
    NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
  }
  return HipBlasLtUserArgsPtr(raw_ptr);
}
  // Get a userArgs for the given device (creates if necessary)
  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Check if the userArgs for this device exists
    auto device_it = d_userArgs_map_.find(device_id);
    if (device_it != d_userArgs_map_.end()) {
      return device_it->second;
inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
  std::unordered_map<size_t, HipBlasLtUserArgsPtr>& user_args_cache =
      host ? host_userargs_cache : device_userargs_cache;
  auto size_it = user_args_cache.find(size);
  if (size_it != user_args_cache.end()) {
    return size_it->second.get();
  }
    // Create a new userArgs for this device if it doesn't exist
    hipblaslt_ext::UserArguments* d_userArgs;
    NVTE_CHECK_CUDA(hipMalloc(&d_userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
    // Store the userArgs in the map for this device
    d_userArgs_map_[device_id] = d_userArgs;
    return d_userArgs;
  else {
    HipBlasLtUserArgsPtr user_args = make_hipblaslt_user_args_ptr(size, host);
    hipblaslt_ext::UserArguments* raw_ptr = user_args.get();
    user_args_cache[size] = std::move(user_args);
    return raw_ptr;
  }
}
 private:
  std::unordered_map<int, hipblaslt_ext::UserArguments*> d_userArgs_map_;  // Map from device_id to hipblasHandle
  std::mutex mutex_;
};

// Define a static userArgs manager
static userArgsManager UAManager;
static d_userArgsManager d_UAManager;

void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
                           std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
...
@@ -1438,23 +1285,13 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
  // Check compute_stream_offset valid.
  NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);

  int device_id;
  hipGetDevice(&device_id);
  hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
  hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
  hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
  hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
  // hipblaslt_ext::UserArguments* userArgs;
  // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));

  hipblasLtHandle_t handle = nullptr;
  if (compute_stream_offset != -1) {
    // Init hipblaslt handles (once, globally)
    static std::once_flag init_flag;
    static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
    std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
    handle = hipblaslt_handles[compute_stream_offset];
  }
  hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();

  const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
  const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
...
@@ -1972,20 +1809,10 @@ void cublas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
  // Check compute_stream_offset valid.
  NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);

  hipblasLtHandle_t handle = nullptr;
  if (compute_stream_offset != -1) {
    // Init hipblaslt handles (once, globally)
    static std::once_flag init_flag;
    static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
    std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
    handle = hipblaslt_handles[compute_stream_offset];
  }
  hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, workspace, workspaceSize, accumulate, use_split_accumulator, math_sm_count,
                 m_split, n_split, gemm_producer, inputCounter, stream, handle);
                 m_split, n_split, gemm_producer, inputCounter, stream);
  return;
}
...
transformer_engine/common/transpose/transpose.cu  (view file @ 1edc9e13)

...
@@ -217,7 +217,13 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
  // Choose between runtime-compiled or statically-compiled kernel
  const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
  if (aligned && rtc::is_enabled()) {  // Runtime-compiled tuned kernel
  // TODO: RTC may cause kernel crashes on ROCm, so use_rtc is set to false there to avoid RTC
  // until the crash is resolved.
#ifdef USE_ROCM
  const bool use_rtc = false;
#else
  const bool use_rtc = true;
#endif
  if (aligned && rtc::is_enabled() && use_rtc) {  // Runtime-compiled tuned kernel
    // Pick kernel config
    std::vector<KernelConfig> kernel_configs;
    kernel_configs.reserve(16);
...
transformer_engine/pytorch/module/_common.py  (view file @ 1edc9e13)

...
@@ -55,7 +55,7 @@ def apply_normalization(
    normalization_func = _get_normalization_func(normalization, True)

    inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
    if (
        enable_lightop
        and (ln_bias is None)
        and normalization == "RMSNorm"
        and output_quantizer is None
        and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32)
    ):
    if (
        enable_lightop
        and (ln_bias is None)
        and normalization == "RMSNorm"
        and output_quantizer is None
        and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32)
        and not zero_centered_gamma
    ):
        out, rsigma = rmsnorm_forward(inputmat, ln_weight, ln_out, eps, True)
        return out, None, rsigma
    else:
...
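For context on the added `not zero_centered_gamma` guard: with zero-centered gamma the learnable weight stores gamma - 1, so a kernel that applies the weight directly would drop the implicit +1. A minimal reference sketch (illustrative only, not the lightop kernel):

import torch

def rmsnorm_reference(x, weight, eps, zero_centered_gamma=False):
    """Reference RMSNorm showing why the zero-centered-gamma case must bypass
    the fast path above: the stored weight then represents (gamma - 1), so an
    extra +1 is needed before scaling the normalized activations."""
    rsigma = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    gamma = weight + 1 if zero_centered_gamma else weight
    return x * rsigma * gamma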