Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
1edc9e13
Commit
1edc9e13
authored
Nov 26, 2025
by
wenjh
Browse files
Merge branch 'develop_v2.8' into release_v2.8
parents
5e7dd67e
3a040217
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
2149 additions
and
2093 deletions
+2149
-2093
qa/L0_pytorch_unittest/test.sh
qa/L0_pytorch_unittest/test.sh
+1
-0
tests/pytorch/attention/test_attention.py
tests/pytorch/attention/test_attention.py
+17
-16
tests/pytorch/debug/run_distributed.py
tests/pytorch/debug/run_distributed.py
+43
-19
tests/pytorch/distributed/run_numerics.py
tests/pytorch/distributed/run_numerics.py
+40
-7
tests/pytorch/distributed/test_numerics.py
tests/pytorch/distributed/test_numerics.py
+13
-1
tests/pytorch/test_cuda_graphs.py
tests/pytorch/test_cuda_graphs.py
+9
-3
tests/pytorch/test_float8_blockwise_gemm_exact.py
tests/pytorch/test_float8_blockwise_gemm_exact.py
+0
-1
tests/pytorch/test_fusible_ops.py
tests/pytorch/test_fusible_ops.py
+4
-2
tests/pytorch/test_numerics.py
tests/pytorch/test_numerics.py
+55
-4
tests/pytorch/test_onnx_export.py
tests/pytorch/test_onnx_export.py
+70
-15
tests/pytorch/test_sanity.py
tests/pytorch/test_sanity.py
+57
-9
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+0
-9
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+1832
-2005
transformer_engine/common/transpose/transpose.cu
transformer_engine/common/transpose/transpose.cu
+7
-1
transformer_engine/pytorch/module/_common.py
transformer_engine/pytorch/module/_common.py
+1
-1
No files found.
qa/L0_pytorch_unittest/test.sh
View file @
1edc9e13
...
@@ -49,6 +49,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_tes
...
@@ -49,6 +49,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_tes
python3
-m
pytest
--tb
=
auto
--junitxml
=
$XML_LOG_DIR
/pytest_test_attention.xml
$TE_PATH
/tests/pytorch/attention/test_attention.py
||
test_fail
"test_attention.py"
python3
-m
pytest
--tb
=
auto
--junitxml
=
$XML_LOG_DIR
/pytest_test_attention.xml
$TE_PATH
/tests/pytorch/attention/test_attention.py
||
test_fail
"test_attention.py"
python3
-m
pytest
--tb
=
auto
--junitxml
=
$XML_LOG_DIR
/pytest_test_kv_cache.xml
$TE_PATH
/tests/pytorch/attention/test_kv_cache.py
||
test_fail
"test_kv_cache.py"
python3
-m
pytest
--tb
=
auto
--junitxml
=
$XML_LOG_DIR
/pytest_test_kv_cache.xml
$TE_PATH
/tests/pytorch/attention/test_kv_cache.py
||
test_fail
"test_kv_cache.py"
python3
-m
pytest
--tb
=
auto
--junitxml
=
$XML_LOG_DIR
/pytest_test_hf_integration.xml
$TE_PATH
/tests/pytorch/test_hf_integration.py
||
test_fail
"test_hf_integration.py"
python3
-m
pytest
--tb
=
auto
--junitxml
=
$XML_LOG_DIR
/pytest_test_hf_integration.xml
$TE_PATH
/tests/pytorch/test_hf_integration.py
||
test_fail
"test_hf_integration.py"
mkdir
-p
$TE_PATH
/artifacts/tests/pytorch/test_checkpoint
&&
python
$TE_PATH
/tests/pytorch/test_checkpoint.py
--save-checkpoint
all
--checkpoint-dir
$TE_PATH
/artifacts/tests/pytorch/test_checkpoint/
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH
=
$TE_PATH
/artifacts/tests/pytorch/test_checkpoint python3
-m
pytest
--tb
=
auto
--junitxml
=
$XML_LOG_DIR
/pytest_test_checkpoint.xml
$TE_PATH
/tests/pytorch/test_checkpoint.py
||
test_fail
"test_checkpoint.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH
=
$TE_PATH
/artifacts/tests/pytorch/test_checkpoint python3
-m
pytest
--tb
=
auto
--junitxml
=
$XML_LOG_DIR
/pytest_test_checkpoint.xml
$TE_PATH
/tests/pytorch/test_checkpoint.py
||
test_fail
"test_checkpoint.py"
python3
-m
pytest
--tb
=
auto
--junitxml
=
$XML_LOG_DIR
/pytest_test_fused_router.xml
$TE_PATH
/tests/pytorch/test_fused_router.py
||
test_fail
"test_fused_router.py"
python3
-m
pytest
--tb
=
auto
--junitxml
=
$XML_LOG_DIR
/pytest_test_fused_router.xml
$TE_PATH
/tests/pytorch/test_fused_router.py
||
test_fail
"test_fused_router.py"
...
...
tests/pytorch/attention/test_attention.py
View file @
1edc9e13
...
@@ -340,22 +340,23 @@ def test_dpa_softmax(dtype, model_configs, model):
...
@@ -340,22 +340,23 @@ def test_dpa_softmax(dtype, model_configs, model):
model_configs_mla
=
{
model_configs_mla
=
{
# test: ModelConfig(b, sq, hq, dqk)
#TODO:FlashAttention on ROCm only support MLA with head_dim_qk = head_dim_v
"mla_1_0"
:
ModelConfig
(
8
,
128
,
16
,
64
,
head_dim_v
=
128
),
# # test: ModelConfig(b, sq, hq, dqk)
"mla_1_1"
:
ModelConfig
(
4
,
128
,
16
,
64
,
max_seqlen_kv
=
256
,
head_dim_v
=
128
),
# "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
"mla_1_2"
:
ModelConfig
(
4
,
128
,
16
,
192
,
max_seqlen_kv
=
256
,
head_dim_v
=
128
),
# "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_2_0"
:
ModelConfig
(
2
,
2048
,
24
,
128
,
attn_mask_type
=
"causal"
,
head_dim_v
=
64
),
# "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_2_1"
:
ModelConfig
(
# "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
1
,
2048
,
24
,
128
,
max_seqlen_kv
=
4096
,
attn_mask_type
=
"causal"
,
head_dim_v
=
64
# "mla_2_1": ModelConfig(
),
# 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
"mla_2_2"
:
ModelConfig
(
# ), # cross, 1
1
,
2048
,
24
,
192
,
max_seqlen_kv
=
4096
,
attn_mask_type
=
"causal"
,
head_dim_v
=
128
# "mla_2_2": ModelConfig(
),
# 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
"mla_3_0"
:
ModelConfig
(
8
,
1
,
16
,
128
,
max_seqlen_kv
=
2048
,
head_dim_v
=
64
),
# ), # cross, 1
"mla_3_1"
:
ModelConfig
(
8
,
1
,
16
,
256
,
max_seqlen_kv
=
2048
,
head_dim_v
=
128
),
# "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
"mla_3_2"
:
ModelConfig
(
8
,
1
,
16
,
192
,
max_seqlen_kv
=
2048
,
head_dim_v
=
128
),
# "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_3"
:
ModelConfig
(
8
,
1
,
16
,
160
,
max_seqlen_kv
=
2048
,
head_dim_v
=
128
),
# "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_4"
:
ModelConfig
(
8
,
1
,
16
,
160
,
max_seqlen_kv
=
2048
,
head_dim_v
=
160
),
# "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_4"
:
ModelConfig
(
8
,
1
,
16
,
160
,
max_seqlen_kv
=
2048
,
head_dim_v
=
160
),
# inference
}
}
...
...
tests/pytorch/debug/run_distributed.py
View file @
1edc9e13
...
@@ -17,6 +17,7 @@ import transformer_engine_torch as tex
...
@@ -17,6 +17,7 @@ import transformer_engine_torch as tex
import
nvdlfw_inspect.api
as
debug_api
import
nvdlfw_inspect.api
as
debug_api
from
transformer_engine.debug
import
set_weight_tensor_tp_group_reduce
from
transformer_engine.debug
import
set_weight_tensor_tp_group_reduce
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
test_numerics
import
(
from
test_numerics
import
(
_emulate_linear
,
_emulate_linear
,
...
@@ -47,7 +48,6 @@ TEST_NR = 0
...
@@ -47,7 +48,6 @@ TEST_NR = 0
fp8_available
,
_
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_available
,
_
=
FP8GlobalStateManager
.
is_fp8_available
()
def
_get_tensors
(
parallel_mode
,
weight_seed
=
SEED
,
data_seed
=
SEED
,
tp_size
=
None
,
tp_rank
=
None
):
def
_get_tensors
(
parallel_mode
,
weight_seed
=
SEED
,
data_seed
=
SEED
,
tp_size
=
None
,
tp_rank
=
None
):
if
tp_size
is
None
:
if
tp_size
is
None
:
tp_size
=
WORLD_SIZE
tp_size
=
WORLD_SIZE
...
@@ -72,13 +72,23 @@ def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None,
...
@@ -72,13 +72,23 @@ def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None,
def
_init_model
(
weight
,
parallel_mode
=
None
,
tp_group
=
None
,
name
=
"linear"
):
def
_init_model
(
weight
,
parallel_mode
=
None
,
tp_group
=
None
,
name
=
"linear"
):
model
=
transformer_engine
.
pytorch
.
Linear
(
if
IS_HIP_EXTENSION
:
model
=
transformer_engine
.
pytorch
.
Linear
(
IN_SIZE
,
IN_SIZE
,
OUT_SIZE
,
OUT_SIZE
,
name
=
name
,
name
=
name
,
bias
=
False
,
parallel_mode
=
parallel_mode
,
parallel_mode
=
parallel_mode
,
tp_group
=
(
tp_group
or
NCCL_WORLD
if
parallel_mode
else
None
),
tp_group
=
(
tp_group
or
NCCL_WORLD
if
parallel_mode
else
None
),
)
)
else
:
model
=
transformer_engine
.
pytorch
.
Linear
(
IN_SIZE
,
OUT_SIZE
,
name
=
name
,
parallel_mode
=
parallel_mode
,
tp_group
=
(
tp_group
or
NCCL_WORLD
if
parallel_mode
else
None
),
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
model
.
weight
.
copy_
(
weight
)
model
.
weight
.
copy_
(
weight
)
return
model
return
model
...
@@ -363,7 +373,6 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
...
@@ -363,7 +373,6 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
)
)
set_weight_tensor_tp_group_reduce
(
True
)
# reset
set_weight_tensor_tp_group_reduce
(
True
)
# reset
@
run_debug_test
@
run_debug_test
def
sanity_test_log_quantized_stats
(
parallel_mode
,
gather_weight
,
**
kwargs
):
def
sanity_test_log_quantized_stats
(
parallel_mode
,
gather_weight
,
**
kwargs
):
from
test_log
import
LOG_QUANTIZED_CONFIG
from
test_log
import
LOG_QUANTIZED_CONFIG
...
@@ -580,6 +589,9 @@ def test_fake_quant_fp8(
...
@@ -580,6 +589,9 @@ def test_fake_quant_fp8(
"dgrad_fp8"
:
not
(
dgrad_weight
or
dgrad_grad
),
"dgrad_fp8"
:
not
(
dgrad_weight
or
dgrad_grad
),
"wgrad_fp8"
:
not
(
wgrad_grad
or
wgrad_input
),
"wgrad_fp8"
:
not
(
wgrad_grad
or
wgrad_input
),
}
}
if
IS_HIP_EXTENSION
:
if
fp8_kwargs
[
"fprop_fp8"
]
or
fp8_kwargs
[
"dgrad_fp8"
]
or
fp8_kwargs
[
"wgrad_fp8"
]:
return
# Output type 32 (FP32) does not support int8 simulation.
if
WORLD_RANK
==
0
:
if
WORLD_RANK
==
0
:
fake_quant_fp8_create_config
(
fake_quant_fp8_create_config
(
fprop_inp
,
fprop_inp
,
...
@@ -667,30 +679,42 @@ if __name__ == "__main__":
...
@@ -667,30 +679,42 @@ if __name__ == "__main__":
random
.
seed
(
SEED
)
random
.
seed
(
SEED
)
_init_distributed
()
_init_distributed
()
test_log_expert_parallel
()
if
IS_HIP_EXTENSION
:
for
parallel_mode
in
[
"column"
,
"row"
]:
# Output type 32 (FP32) does not support int8 simulation.
for
gather_weight
in
[
True
,
False
]:
pass
test_log_distributed
(
parallel_mode
,
gather_weight
)
else
:
test_log_expert_parallel
()
for
parallel_mode
in
[
"column"
,
"row"
]:
for
gather_weight
in
[
True
,
False
]:
test_log_distributed
(
parallel_mode
,
gather_weight
)
if
fp8_available
:
if
fp8_available
:
for
parallel_mode
in
[
"row"
,
"column"
]:
for
parallel_mode
in
[
"row"
,
"column"
]:
test_disable_fp8_layer
(
parallel_mode
)
test_disable_fp8_layer
(
parallel_mode
)
# test_disable_fp8_gemms
_run_test_with_combinations
(
test_disable_fp8_gemms
,
all_boolean
,
num_repeat
=
3
,
extra_args
=
[
"column"
,
"row"
]
)
# test_fake_quant_fp8
if
IS_HIP_EXTENSION
:
dtype_options
=
[
tex
.
DType
.
kFloat8E4M3
,
tex
.
DType
.
kFloat8E5M2
,
None
]
# Output type 32 (FP32) does not support int8 simulation.
pass
else
:
# test_disable_fp8_gemms
_run_test_with_combinations
(
_run_test_with_combinations
(
test_fake_quant_fp8
,
test_disable_fp8_gemms
,
all_boolean
,
num_repeat
=
3
,
extra_args
=
[
"column"
,
"row"
]
dtype_options
,
num_repeat
=
6
,
extra_args
=
[
"column"
,
"row"
],
sample_size
=
20
,
)
)
# test_fake_quant_fp8
dtype_options
=
[
tex
.
DType
.
kFloat8E4M3
,
tex
.
DType
.
kFloat8E5M2
,
None
]
_run_test_with_combinations
(
test_fake_quant_fp8
,
dtype_options
,
num_repeat
=
6
,
extra_args
=
[
"column"
,
"row"
],
sample_size
=
20
,
)
if
IS_HIP_EXTENSION
:
# Output type 32 (FP32) does not support int8 simulation.
pass
else
:
_run_test_with_combinations
(
_run_test_with_combinations
(
test_per_tensor_scaling
,
test_per_tensor_scaling
,
all_boolean
,
all_boolean
,
...
...
tests/pytorch/distributed/run_numerics.py
View file @
1edc9e13
...
@@ -733,7 +733,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
...
@@ -733,7 +733,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
def
test_linear
():
def
test_linear
():
"""Run linear layer tests with various configurations."""
"""Run linear layer tests with various configurations."""
kwargs_list
=
[
base_
kwargs_list
=
[
{},
{},
{
"bias"
:
False
},
{
"bias"
:
False
},
{
"init_method"
:
_constant
},
{
"init_method"
:
_constant
},
...
@@ -743,7 +743,15 @@ def test_linear():
...
@@ -743,7 +743,15 @@ def test_linear():
{
"delay_wgrad_compute"
:
True
},
{
"delay_wgrad_compute"
:
True
},
{
"save_original_input"
:
True
},
{
"save_original_input"
:
True
},
]
]
#TODO:The blockwise recipe does not currently support calculations with bias set to true.
"""
For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
and if the bias value is not set in kwargs or the bias value is true, set bias to false.
"""
if
IS_HIP_EXTENSION
and
QUANTIZATION
==
"fp8_block_scaling"
:
kwargs_list
=
[
kwargs
for
kwargs
in
base_kwargs_list
if
kwargs
.
get
(
"bias"
,
True
)
is
False
]
else
:
kwargs_list
=
base_kwargs_list
for
kwargs
in
kwargs_list
:
for
kwargs
in
kwargs_list
:
if
kwargs
.
get
(
"save_original_input"
,
False
)
and
QUANTIZATION
==
"fp8"
:
if
kwargs
.
get
(
"save_original_input"
,
False
)
and
QUANTIZATION
==
"fp8"
:
continue
continue
...
@@ -913,7 +921,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
...
@@ -913,7 +921,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
def
test_layernorm_linear
():
def
test_layernorm_linear
():
kwargs_list
=
[
base_
kwargs_list
=
[
{},
{},
{
"bias"
:
False
},
{
"bias"
:
False
},
{
"init_method"
:
_constant
},
{
"init_method"
:
_constant
},
...
@@ -924,7 +932,15 @@ def test_layernorm_linear():
...
@@ -924,7 +932,15 @@ def test_layernorm_linear():
{
"return_layernorm_output"
:
True
},
{
"return_layernorm_output"
:
True
},
{
"delay_wgrad_compute"
:
True
},
{
"delay_wgrad_compute"
:
True
},
]
]
#TODO:The blockwise recipe does not currently support calculations with bias set to true.
"""
For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
and if the bias value is not set in kwargs or the bias value is true, set bias to false.
"""
if
IS_HIP_EXTENSION
and
QUANTIZATION
==
"fp8_block_scaling"
:
kwargs_list
=
[
kwargs
for
kwargs
in
base_kwargs_list
if
kwargs
.
get
(
"bias"
,
True
)
is
False
]
else
:
kwargs_list
=
base_kwargs_list
for
kwargs
in
kwargs_list
:
for
kwargs
in
kwargs_list
:
for
parallel_mode
in
[
"column"
]:
for
parallel_mode
in
[
"column"
]:
for
sequence_parallel
in
[
False
,
True
]:
for
sequence_parallel
in
[
False
,
True
]:
...
@@ -1019,7 +1035,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
...
@@ -1019,7 +1035,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
def
test_layernorm_mlp
():
def
test_layernorm_mlp
():
kwargs_list
=
[
base_
kwargs_list
=
[
{},
{},
{
"init_method"
:
_constant
},
{
"init_method"
:
_constant
},
{
"output_layer_init_method"
:
_constant
},
{
"output_layer_init_method"
:
_constant
},
...
@@ -1033,7 +1049,15 @@ def test_layernorm_mlp():
...
@@ -1033,7 +1049,15 @@ def test_layernorm_mlp():
{
"return_layernorm_output"
:
True
},
{
"return_layernorm_output"
:
True
},
{
"delay_wgrad_compute"
:
True
},
{
"delay_wgrad_compute"
:
True
},
]
]
#TODO:The blockwise recipe does not currently support calculations with bias set to true.
"""
For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
and if the bias value is not set in kwargs or the bias value is true, set bias to false.
"""
if
IS_HIP_EXTENSION
and
QUANTIZATION
==
"fp8_block_scaling"
:
kwargs_list
=
[
kwargs
for
kwargs
in
base_kwargs_list
if
kwargs
.
get
(
"bias"
,
True
)
is
False
]
else
:
kwargs_list
=
base_kwargs_list
for
kwargs
in
kwargs_list
:
for
kwargs
in
kwargs_list
:
for
set_parallel_mode
in
[
True
]:
for
set_parallel_mode
in
[
True
]:
for
sequence_parallel
in
[
False
,
True
]:
for
sequence_parallel
in
[
False
,
True
]:
...
@@ -1108,7 +1132,7 @@ def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
...
@@ -1108,7 +1132,7 @@ def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
def
test_transformer_layer
():
def
test_transformer_layer
():
kwargs_list
=
[
base_
kwargs_list
=
[
{},
{},
{
"num_gqa_groups"
:
4
},
{
"num_gqa_groups"
:
4
},
{
"init_method"
:
_constant
},
{
"init_method"
:
_constant
},
...
@@ -1128,6 +1152,15 @@ def test_transformer_layer():
...
@@ -1128,6 +1152,15 @@ def test_transformer_layer():
{
"fuse_qkv_params"
:
True
},
{
"fuse_qkv_params"
:
True
},
{
"activation"
:
"relu"
},
{
"activation"
:
"relu"
},
]
]
#TODO:The blockwise recipe does not currently support calculations with bias set to true.
"""
For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
and if the bias value is not set in kwargs or the bias value is true, set bias to false.
"""
if
IS_HIP_EXTENSION
and
QUANTIZATION
==
"fp8_block_scaling"
:
kwargs_list
=
[
kwargs
for
kwargs
in
base_kwargs_list
if
kwargs
.
get
(
"bias"
,
True
)
is
False
]
else
:
kwargs_list
=
base_kwargs_list
for
kwargs
in
kwargs_list
:
for
kwargs
in
kwargs_list
:
for
sequence_parallel
in
[
False
,
True
]:
for
sequence_parallel
in
[
False
,
True
]:
...
...
tests/pytorch/distributed/test_numerics.py
View file @
1edc9e13
...
@@ -9,7 +9,8 @@ from pathlib import Path
...
@@ -9,7 +9,8 @@ from pathlib import Path
import
pytest
import
pytest
import
torch
import
torch
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
import
transformer_engine
as
te
"""
"""
Distributed numerics tests
Distributed numerics tests
...
@@ -66,4 +67,15 @@ def test_distributed(quantization):
...
@@ -66,4 +67,15 @@ def test_distributed(quantization):
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
quantization
==
"nvfp4"
and
not
nvfp4_available
:
if
quantization
==
"nvfp4"
and
not
nvfp4_available
:
pytest
.
skip
(
reason_for_no_nvfp4
)
pytest
.
skip
(
reason_for_no_nvfp4
)
if
IS_HIP_EXTENSION
and
quantization
==
"fp8_block_scaling"
:
import
importlib
ori_int8_sim_fp8
=
os
.
environ
.
get
(
"NVTE_INT8_SIM_FP8"
,
"None"
)
os
.
environ
[
"NVTE_INT8_SIM_FP8"
]
=
"1"
importlib
.
reload
(
te
.
pytorch
.
fp8
)
_run_test
(
quantization
)
_run_test
(
quantization
)
if
IS_HIP_EXTENSION
and
quantization
==
"fp8_block_scaling"
:
if
ori_int8_sim_fp8
is
None
or
ori_int8_sim_fp8
==
"None"
:
os
.
environ
[
"NVTE_INT8_SIM_FP8"
]
=
"0"
else
:
del
os
.
environ
[
"NVTE_INT8_SIM_FP8"
]
importlib
.
reload
(
te
.
pytorch
.
fp8
)
tests/pytorch/test_cuda_graphs.py
View file @
1edc9e13
...
@@ -28,9 +28,9 @@ if IS_HIP_EXTENSION:
...
@@ -28,9 +28,9 @@ if IS_HIP_EXTENSION:
from
functools
import
cache
from
functools
import
cache
# Check if FP8 is supported.
# Check if FP8 is supported.
fp8_available
,
_
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_block_scaling_available
,
_
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
fp8_block_scaling_available
,
reason_for_no_fp8_block_scaling
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
mxfp8_available
,
_
=
FP8GlobalStateManager
.
is_mxfp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
# Reset RNG states.
# Reset RNG states.
reset_rng_states
()
reset_rng_states
()
...
@@ -367,6 +367,12 @@ def test_make_graphed_callables(
...
@@ -367,6 +367,12 @@ def test_make_graphed_callables(
)
)
if
fp8_params
:
if
fp8_params
:
pytest
.
skip
(
"NVFP4 params not supported"
)
pytest
.
skip
(
"NVFP4 params not supported"
)
if
fp8
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8
and
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8
and
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
# Run model with different CUDA graph settings.
# Run model with different CUDA graph settings.
model_config
=
model_configs
[
model_config
]
model_config
=
model_configs
[
model_config
]
...
...
tests/pytorch/test_float8_blockwise_gemm_exact.py
View file @
1edc9e13
...
@@ -6,7 +6,6 @@ import pytest
...
@@ -6,7 +6,6 @@ import pytest
import
torch
import
torch
import
transformer_engine
as
te
import
transformer_engine
as
te
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.utils
import
use_lightop_w8a8
from
transformer_engine.pytorch.constants
import
TE_DType
from
transformer_engine.pytorch.constants
import
TE_DType
from
transformer_engine.pytorch.fp8
import
(
FP8GlobalStateManager
,
blockwise_fp8_block_len
,
int8_simulation_fp8
)
from
transformer_engine.pytorch.fp8
import
(
FP8GlobalStateManager
,
blockwise_fp8_block_len
,
int8_simulation_fp8
)
from
transformer_engine.pytorch.tensor.float8_blockwise_tensor
import
(
from
transformer_engine.pytorch.tensor.float8_blockwise_tensor
import
(
...
...
tests/pytorch/test_fusible_ops.py
View file @
1edc9e13
...
@@ -2111,7 +2111,8 @@ class TestFusedOps:
...
@@ -2111,7 +2111,8 @@ class TestFusedOps:
quantized_weight
:
bool
=
False
,
quantized_weight
:
bool
=
False
,
)
->
None
:
)
->
None
:
"""Forward GEMM + scale + add"""
"""Forward GEMM + scale + add"""
if
IS_HIP_EXTENSION
and
scale
!=
1
:
pytest
.
skip
(
"alpha must be 1.0 for hip"
)
# Make input and weight shapes consistent
# Make input and weight shapes consistent
out_features
,
in_features
=
weight_shape
out_features
,
in_features
=
weight_shape
in_shape
=
list
(
in_shape
)[:
-
1
]
+
[
in_features
]
in_shape
=
list
(
in_shape
)[:
-
1
]
+
[
in_features
]
...
@@ -2496,7 +2497,8 @@ class TestFusedOps:
...
@@ -2496,7 +2497,8 @@ class TestFusedOps:
quantized_weight
:
bool
=
False
,
quantized_weight
:
bool
=
False
,
)
->
None
:
)
->
None
:
"""Backward dgrad GEMM + scale"""
"""Backward dgrad GEMM + scale"""
if
IS_HIP_EXTENSION
and
scale
!=
1
:
pytest
.
skip
(
"alpha must be 1.0 for hip"
)
# Make input and weight shapes consistent
# Make input and weight shapes consistent
out_features
,
in_features
=
weight_shape
out_features
,
in_features
=
weight_shape
in_shape
=
list
(
in_shape
)[:
-
1
]
+
[
in_features
]
in_shape
=
list
(
in_shape
)[:
-
1
]
+
[
in_features
]
...
...
tests/pytorch/test_numerics.py
View file @
1edc9e13
...
@@ -56,7 +56,7 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend
...
@@ -56,7 +56,7 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend
# Only run FP8 tests on supported devices.
# Only run FP8 tests on supported devices.
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
fp8_block_scaling_available
,
_
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
fp8_block_scaling_available
,
reason_for_no_fp8_block_scaling
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
sm_80plus
=
get_device_compute_capability
()
>=
(
8
,
0
)
sm_80plus
=
get_device_compute_capability
()
>=
(
8
,
0
)
...
@@ -606,6 +606,13 @@ def _test_e2e_selective_recompute(
...
@@ -606,6 +606,13 @@ def _test_e2e_selective_recompute(
def
test_gpt_selective_activation_recompute
(
dtype
,
bs
,
model
,
fp8
,
recipe
,
fp8_model_params
):
def
test_gpt_selective_activation_recompute
(
dtype
,
bs
,
model
,
fp8
,
recipe
,
fp8_model_params
):
if
fp8_model_params
and
NVTE_TEST_NVINSPECT_ENABLED
:
if
fp8_model_params
and
NVTE_TEST_NVINSPECT_ENABLED
:
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
if
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
...
@@ -714,8 +721,15 @@ def _test_e2e_full_recompute(
...
@@ -714,8 +721,15 @@ def _test_e2e_full_recompute(
def
test_gpt_full_activation_recompute
(
def
test_gpt_full_activation_recompute
(
dtype
,
bs
,
model
,
fp8
,
recipe
,
fp8_model_params
,
use_reentrant
dtype
,
bs
,
model
,
fp8
,
recipe
,
fp8_model_params
,
use_reentrant
):
):
if
fp8
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8_model_params
and
NVTE_TEST_NVINSPECT_ENABLED
:
if
fp8_model_params
and
NVTE_TEST_NVINSPECT_ENABLED
:
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
if
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
...
@@ -1301,9 +1315,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
...
@@ -1301,9 +1315,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
fuse_wgrad_accumulation
=
True
fuse_wgrad_accumulation
=
True
fp8_model_params
=
False
fp8_model_params
=
False
fp8
=
recipe
is
not
None
fp8
=
recipe
is
not
None
if
fp8
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8
and
recipe
.
delayed
():
if
fp8
and
recipe
.
delayed
():
pytest
.
skip
(
"DelayedScaling recipe is not supported with save_original_input"
)
pytest
.
skip
(
"DelayedScaling recipe is not supported with save_original_input"
)
if
fp8
and
recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8
and
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
...
@@ -1818,6 +1837,12 @@ def test_grouped_linear_accuracy(
...
@@ -1818,6 +1837,12 @@ def test_grouped_linear_accuracy(
use_cutlass
=
False
,
use_cutlass
=
False
,
):
):
fp8
=
recipe
is
not
None
fp8
=
recipe
is
not
None
if
fp8
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8
and
recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8
and
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8
and
fp8_model_params
and
NVTE_TEST_NVINSPECT_ENABLED
:
if
fp8
and
fp8_model_params
and
NVTE_TEST_NVINSPECT_ENABLED
:
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
...
@@ -1863,7 +1888,8 @@ def test_grouped_linear_accuracy(
...
@@ -1863,7 +1888,8 @@ def test_grouped_linear_accuracy(
weight_i
=
getattr
(
grouped_linear
,
f
"weight
{
i
}
"
)
weight_i
=
getattr
(
grouped_linear
,
f
"weight
{
i
}
"
)
weight_i
.
main_grad
=
torch
.
rand_like
(
weight_i
,
dtype
=
torch
.
float32
)
weight_i
.
main_grad
=
torch
.
rand_like
(
weight_i
,
dtype
=
torch
.
float32
)
sequential_linear
[
i
].
weight
.
main_grad
=
weight_i
.
main_grad
.
clone
()
sequential_linear
[
i
].
weight
.
main_grad
=
weight_i
.
main_grad
.
clone
()
if
IS_HIP_EXTENSION
:
os
.
environ
[
"NVTE_FORCE_ROCM_GEMM"
]
=
"1"
outputs_ref
=
_test_grouped_linear_accuracy
(
outputs_ref
=
_test_grouped_linear_accuracy
(
sequential_linear
,
sequential_linear
,
num_gemms
,
num_gemms
,
...
@@ -1886,7 +1912,8 @@ def test_grouped_linear_accuracy(
...
@@ -1886,7 +1912,8 @@ def test_grouped_linear_accuracy(
fuse_wgrad_accumulation
,
fuse_wgrad_accumulation
,
delay_wgrad_compute
,
delay_wgrad_compute
,
)
)
if
IS_HIP_EXTENSION
:
os
.
environ
[
"NVTE_FORCE_ROCM_GEMM"
]
=
"0"
for
o
,
o_ref
in
zip
(
outputs
,
outputs_ref
):
for
o
,
o_ref
in
zip
(
outputs
,
outputs_ref
):
if
use_cutlass
:
if
use_cutlass
:
torch
.
testing
.
assert_close
(
o
,
o_ref
,
rtol
=
1e-3
,
atol
=
1e-3
)
torch
.
testing
.
assert_close
(
o
,
o_ref
,
rtol
=
1e-3
,
atol
=
1e-3
)
...
@@ -1956,6 +1983,12 @@ def test_grouped_linear_accuracy_save_original_input(
...
@@ -1956,6 +1983,12 @@ def test_grouped_linear_accuracy_save_original_input(
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
if
fp8
and
recipe
.
delayed
():
if
fp8
and
recipe
.
delayed
():
pytest
.
skip
(
"DelayedScaling recipe is not supported with save_original_input"
)
pytest
.
skip
(
"DelayedScaling recipe is not supported with save_original_input"
)
if
fp8
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8
and
recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8
and
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
...
@@ -2162,8 +2195,14 @@ def test_padding_grouped_linear_accuracy(
...
@@ -2162,8 +2195,14 @@ def test_padding_grouped_linear_accuracy(
fp8_model_params
,
fp8_model_params
,
parallel_mode
=
None
,
parallel_mode
=
None
,
):
):
if
fp8
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8_model_params
and
NVTE_TEST_NVINSPECT_ENABLED
:
if
fp8_model_params
and
NVTE_TEST_NVINSPECT_ENABLED
:
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
if
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
...
@@ -2235,6 +2274,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
...
@@ -2235,6 +2274,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
if
fp8
and
recipe
.
delayed
():
if
fp8
and
recipe
.
delayed
():
pytest
.
skip
(
"DelayedScaling recipe is not supported with save_original_input"
)
pytest
.
skip
(
"DelayedScaling recipe is not supported with save_original_input"
)
if
fp8
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
...
@@ -2446,8 +2491,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
...
@@ -2446,8 +2491,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"126m"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"126m"
])
@
pytest
.
mark
.
parametrize
(
"recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"recipe"
,
fp8_recipes
)
def
test_gpt_fp8_parameters
(
dtype
,
bs
,
model
,
recipe
):
def
test_gpt_fp8_parameters
(
dtype
,
bs
,
model
,
recipe
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
NVTE_TEST_NVINSPECT_ENABLED
:
if
NVTE_TEST_NVINSPECT_ENABLED
:
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
if
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
...
...
tests/pytorch/test_onnx_export.py
View file @
1edc9e13
...
@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
...
@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import
transformer_engine.pytorch
as
te
import
transformer_engine.pytorch
as
te
from
transformer_engine.common
import
recipe
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.export
import
is_in_onnx_export_mode
,
te_translation_table
from
transformer_engine.pytorch.onnx_extensions
import
te_translation_table
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
transformer_engine.pytorch.export
import
is_in_onnx_export_mode
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
from
transformer_engine.pytorch.utils
import
get_default_init_method
from
transformer_engine.pytorch.utils
import
get_default_init_method
import
tensorrt
as
trt
import
tensorrt
as
trt
...
@@ -65,7 +67,6 @@ if mxfp8_available:
...
@@ -65,7 +67,6 @@ if mxfp8_available:
fp8_recipes
.
append
(
recipe
.
MXFP8BlockScaling
())
fp8_recipes
.
append
(
recipe
.
MXFP8BlockScaling
())
if
fp8_available
:
if
fp8_available
:
fp8_recipes
.
append
(
recipe
.
DelayedScaling
())
fp8_recipes
.
append
(
recipe
.
DelayedScaling
())
fp8_recipes
.
append
(
recipe
.
Float8CurrentScaling
())
fp8_recipes
.
append
(
None
)
fp8_recipes
.
append
(
None
)
supported_activations
=
[
"gelu"
,
"relu"
,
"reglu"
,
"geglu"
,
"swiglu"
]
supported_activations
=
[
"gelu"
,
"relu"
,
"reglu"
,
"geglu"
,
"swiglu"
]
...
@@ -82,11 +83,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
...
@@ -82,11 +83,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
],
],
outputs
=
[
PyCustomOpDef
.
dt_uint8
],
outputs
=
[
PyCustomOpDef
.
dt_uint8
],
)
)
def
trt_fp8_quantize
(
t
,
scale
_inv
):
def
trt_fp8_quantize
(
t
,
scale
):
"""FP8 quantization extension for ONNX Runtime."""
"""FP8 quantization extension for ONNX Runtime."""
x
=
torch
.
from_numpy
(
t
).
cuda
()
x
=
torch
.
from_numpy
(
t
).
cuda
()
q
=
te
.
tensor
.
float8_tensor
.
Float8Quantizer
(
q
=
te
.
tensor
.
float8_tensor
.
Float8Quantizer
(
scale
=
1
/
torch
.
from_numpy
(
scale
_inv
).
cuda
(),
scale
=
1
/
torch
.
from_numpy
(
scale
).
cuda
(),
amax
=
torch
.
zeros
([
1
]).
cuda
(),
amax
=
torch
.
zeros
([
1
]).
cuda
(),
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
,
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
,
)
)
...
@@ -102,11 +103,11 @@ def trt_fp8_quantize(t, scale_inv):
...
@@ -102,11 +103,11 @@ def trt_fp8_quantize(t, scale_inv):
],
],
outputs
=
[
PyCustomOpDef
.
dt_float
],
outputs
=
[
PyCustomOpDef
.
dt_float
],
)
)
def
trt_fp8_dequantize
(
t
,
scale
_inv
):
def
trt_fp8_dequantize
(
t
,
scale
):
"""FP8 dequantization extension for ONNX Runtime."""
"""FP8 dequantization extension for ONNX Runtime."""
x
=
torch
.
from_numpy
(
t
).
cuda
()
x
=
torch
.
from_numpy
(
t
).
cuda
()
q
=
te
.
tensor
.
float8_tensor
.
Float8Quantizer
(
q
=
te
.
tensor
.
float8_tensor
.
Float8Quantizer
(
scale
=
1
/
torch
.
from_numpy
(
scale
_inv
).
cuda
(),
scale
=
1
/
torch
.
from_numpy
(
scale
).
cuda
(),
amax
=
torch
.
zeros
([
1
]).
cuda
(),
amax
=
torch
.
zeros
([
1
]).
cuda
(),
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
,
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
,
)
)
...
@@ -469,16 +470,22 @@ def _test_export_linear(
...
@@ -469,16 +470,22 @@ def _test_export_linear(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_linear_recipe
(
seed_default_rng
,
fp8_recipe
,
precision
):
def
test_export_linear_recipe
(
seed_default_rng
,
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_linear
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
_test_export_linear
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_export_linear_use_bias
(
seed_default_rng
,
use_bias
):
def
test_export_linear_use_bias
(
seed_default_rng
,
use_bias
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_linear
(
use_bias
=
use_bias
)
_test_export_linear
(
use_bias
=
use_bias
)
@
pytest
.
mark
.
parametrize
(
"return_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"return_bias"
,
[
True
,
False
])
def
test_export_linear_return_bias
(
seed_default_rng
,
return_bias
):
def
test_export_linear_return_bias
(
seed_default_rng
,
return_bias
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_linear
(
return_bias
=
return_bias
)
_test_export_linear
(
return_bias
=
return_bias
)
...
@@ -540,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng):
...
@@ -540,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng):
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
)
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
)
def
test_export_layernorm_normalization
(
seed_default_rng
,
normalization
):
def
test_export_layernorm_normalization
(
seed_default_rng
,
normalization
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm
(
normalization
=
normalization
)
_test_export_layernorm
(
normalization
=
normalization
)
...
@@ -594,9 +603,7 @@ def _test_export_layernorm_linear(
...
@@ -594,9 +603,7 @@ def _test_export_layernorm_linear(
fname
,
fname
,
inp
,
inp
,
model
,
model
,
# For current scaling we use Float8Quantizer in tests + amax computed by hand,
atol
=
1e-3
,
# which has slightly different numerics than Float8CurrentScalingQuantizer.
atol
=
1e-3
if
fp8_recipe
.
__class__
is
not
recipe
.
Float8CurrentScaling
else
2e-2
,
is_fp8
=
fp8_recipe
is
not
None
,
is_fp8
=
fp8_recipe
is
not
None
,
te_outputs
=
te_outputs
,
te_outputs
=
te_outputs
,
)
)
...
@@ -605,27 +612,39 @@ def _test_export_layernorm_linear(
...
@@ -605,27 +612,39 @@ def _test_export_layernorm_linear(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_layernorm_linear_recipe
(
seed_default_rng
,
fp8_recipe
,
precision
):
def
test_export_layernorm_linear_recipe
(
seed_default_rng
,
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
_test_export_layernorm_linear
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_layernorm_linear_return_ln_out
(
seed_default_rng
):
def
test_export_layernorm_linear_return_ln_out
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
return_layernorm_output
=
True
)
_test_export_layernorm_linear
(
return_layernorm_output
=
True
)
def
test_export_layernorm_linear_zero_centered_gamma
(
seed_default_rng
):
def
test_export_layernorm_linear_zero_centered_gamma
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
zero_centered_gamma
=
True
)
_test_export_layernorm_linear
(
zero_centered_gamma
=
True
)
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
[
1
:])
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
[
1
:])
def
test_export_layernorm_linear_normalization
(
seed_default_rng
,
normalization
):
def
test_export_layernorm_linear_normalization
(
seed_default_rng
,
normalization
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
normalization
=
normalization
)
_test_export_layernorm_linear
(
normalization
=
normalization
)
def
test_export_layernorm_linear_no_bias
(
seed_default_rng
):
def
test_export_layernorm_linear_no_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
use_bias
=
False
)
_test_export_layernorm_linear
(
use_bias
=
False
)
def
test_export_layernorm_linear_return_bias
(
seed_default_rng
):
def
test_export_layernorm_linear_return_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
return_bias
=
True
)
_test_export_layernorm_linear
(
return_bias
=
True
)
...
@@ -684,32 +703,46 @@ def _test_export_layernorm_mlp(
...
@@ -684,32 +703,46 @@ def _test_export_layernorm_mlp(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_layernorm_mlp
(
seed_default_rng
,
fp8_recipe
,
precision
):
def
test_export_layernorm_mlp
(
seed_default_rng
,
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
_test_export_layernorm_mlp
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_layernorm_mlp_return_layernorm_output
(
seed_default_rng
):
def
test_export_layernorm_mlp_return_layernorm_output
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
return_layernorm_output
=
True
)
_test_export_layernorm_mlp
(
return_layernorm_output
=
True
)
def
test_export_layernorm_mlp_return_bias
(
seed_default_rng
):
def
test_export_layernorm_mlp_return_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
return_bias
=
True
)
_test_export_layernorm_mlp
(
return_bias
=
True
)
def
test_export_layernorm_mlp_no_bias
(
seed_default_rng
):
def
test_export_layernorm_mlp_no_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
use_bias
=
False
)
_test_export_layernorm_mlp
(
use_bias
=
False
)
def
test_export_layernorm_mlp_zero_centered_gamma
(
seed_default_rng
):
def
test_export_layernorm_mlp_zero_centered_gamma
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
zero_centered_gamma
=
True
)
_test_export_layernorm_mlp
(
zero_centered_gamma
=
True
)
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
[
1
:])
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
[
1
:])
def
test_export_layernorm_mlp_normalization
(
seed_default_rng
,
normalization
):
def
test_export_layernorm_mlp_normalization
(
seed_default_rng
,
normalization
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
normalization
=
normalization
)
_test_export_layernorm_mlp
(
normalization
=
normalization
)
@
pytest
.
mark
.
parametrize
(
"activation"
,
supported_activations
[
1
:])
@
pytest
.
mark
.
parametrize
(
"activation"
,
supported_activations
[
1
:])
def
test_export_layernorm_mlp_activation
(
seed_default_rng
,
activation
):
def
test_export_layernorm_mlp_activation
(
seed_default_rng
,
activation
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
activation
=
activation
)
_test_export_layernorm_mlp
(
activation
=
activation
)
...
@@ -731,6 +764,8 @@ def test_export_core_attention(
...
@@ -731,6 +764,8 @@ def test_export_core_attention(
use_mask
:
bool
,
use_mask
:
bool
,
attn_mask_type
:
str
,
attn_mask_type
:
str
,
):
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
# Set dimensions (these are arbitrary).
# Set dimensions (these are arbitrary).
seq_len
,
batch_size
,
num_attention_heads
,
kv_channels
=
(
64
,
4
,
1
,
64
)
seq_len
,
batch_size
,
num_attention_heads
,
kv_channels
=
(
64
,
4
,
1
,
64
)
qkv_size
=
(
seq_len
,
batch_size
,
num_attention_heads
,
kv_channels
)
qkv_size
=
(
seq_len
,
batch_size
,
num_attention_heads
,
kv_channels
)
...
@@ -932,22 +967,32 @@ def _test_export_multihead_attention(
...
@@ -932,22 +967,32 @@ def _test_export_multihead_attention(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_multihead_attention_recipe
(
fp8_recipe
,
precision
):
def
test_export_multihead_attention_recipe
(
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
_test_export_multihead_attention
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_multihead_attention_no_mask
():
def
test_export_multihead_attention_no_mask
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
use_mask
=
False
)
_test_export_multihead_attention
(
use_mask
=
False
)
def
test_export_multihead_attention_no_input_layernorm
():
def
test_export_multihead_attention_no_input_layernorm
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
input_layernorm
=
False
)
_test_export_multihead_attention
(
input_layernorm
=
False
)
def
test_export_multihead_attention_cross_attn
():
def
test_export_multihead_attention_cross_attn
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
attention_type
=
"cross"
)
_test_export_multihead_attention
(
attention_type
=
"cross"
)
def
test_export_multihead_attention_unfused_qkv_params
():
def
test_export_multihead_attention_unfused_qkv_params
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
fuse_qkv_params
=
False
)
_test_export_multihead_attention
(
fuse_qkv_params
=
False
)
...
@@ -1023,27 +1068,39 @@ def _test_export_transformer_layer(
...
@@ -1023,27 +1068,39 @@ def _test_export_transformer_layer(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_transformer_layer_recipe
(
fp8_recipe
,
precision
):
def
test_export_transformer_layer_recipe
(
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
_test_export_transformer_layer
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_transformer_layer_no_mask
():
def
test_export_transformer_layer_no_mask
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
use_mask
=
False
)
_test_export_transformer_layer
(
use_mask
=
False
)
def
test_export_transformer_layer_output_layernorm
():
def
test_export_transformer_layer_output_layernorm
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
output_layernorm
=
True
)
_test_export_transformer_layer
(
output_layernorm
=
True
)
def
test_export_transformer_layer_unfused_qkv_params
():
def
test_export_transformer_layer_unfused_qkv_params
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
fuse_qkv_params
=
False
)
_test_export_transformer_layer
(
fuse_qkv_params
=
False
)
def
test_export_transformer_layer_zero_centered_gamma
():
def
test_export_transformer_layer_zero_centered_gamma
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
zero_centered_gamma
=
True
)
_test_export_transformer_layer
(
zero_centered_gamma
=
True
)
@
pytest
.
mark
.
parametrize
(
"activation"
,
supported_activations
[
1
:])
@
pytest
.
mark
.
parametrize
(
"activation"
,
supported_activations
[
1
:])
def
test_export_transformer_layer_activation
(
activation
):
def
test_export_transformer_layer_activation
(
activation
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
activation
=
activation
)
_test_export_transformer_layer
(
activation
=
activation
)
...
@@ -1056,7 +1113,8 @@ def test_export_gpt_generation(
...
@@ -1056,7 +1113,8 @@ def test_export_gpt_generation(
"""Test that the ONNX model can correctly handle inputs with different shapes and that
"""Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths.
the attention mask is adjusted on-the-fly to different sequence lengths.
"""
"""
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
# Layer configuration
# Layer configuration
hidden_size
=
64
hidden_size
=
64
sequence_length
=
128
sequence_length
=
128
...
@@ -1147,17 +1205,14 @@ def test_export_ctx_manager(enabled):
...
@@ -1147,17 +1205,14 @@ def test_export_ctx_manager(enabled):
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
def
test_trt_integration
(
fp8_recipe
:
recipe
.
Recipe
):
def
test_trt_integration
(
fp8_recipe
:
recipe
.
Recipe
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"TRT is not supported for HIP"
)
model
=
te
.
TransformerLayer
(
model
=
te
.
TransformerLayer
(
hidden_size
=
128
,
hidden_size
=
128
,
ffn_hidden_size
=
128
,
ffn_hidden_size
=
128
,
num_attention_heads
=
4
,
num_attention_heads
=
4
,
).
eval
()
).
eval
()
if
type
(
fp8_recipe
)
==
recipe
.
Float8CurrentScaling
:
# TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
model
=
te
.
LayerNormMLP
(
128
,
128
)
inps
=
(
torch
.
randn
([
16
,
16
,
128
],
device
=
"cuda"
,
requires_grad
=
False
),)
inps
=
(
torch
.
randn
([
16
,
16
,
128
],
device
=
"cuda"
,
requires_grad
=
False
),)
with
te
.
fp8_autocast
(
enabled
=
fp8_recipe
is
not
None
,
fp8_recipe
=
fp8_recipe
):
with
te
.
fp8_autocast
(
enabled
=
fp8_recipe
is
not
None
,
fp8_recipe
=
fp8_recipe
):
...
...
tests/pytorch/test_sanity.py
View file @
1edc9e13
...
@@ -46,7 +46,7 @@ from utils import ModelConfig
...
@@ -46,7 +46,7 @@ from utils import ModelConfig
# Only run FP8 tests on supported devices.
# Only run FP8 tests on supported devices.
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_block_scaling_available
,
_
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
fp8_block_scaling_available
,
reason_for_no_fp8_block_scaling
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
# Record initial RNG state from script run.
# Record initial RNG state from script run.
...
@@ -388,7 +388,13 @@ def test_sanity_layernorm_linear(
...
@@ -388,7 +388,13 @@ def test_sanity_layernorm_linear(
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -450,7 +456,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
...
@@ -450,7 +456,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
num_tokens
=
bs
*
config
.
max_seqlen_q
num_tokens
=
bs
*
config
.
max_seqlen_q
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -543,7 +555,13 @@ def test_sanity_layernorm_mlp(
...
@@ -543,7 +555,13 @@ def test_sanity_layernorm_mlp(
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -587,7 +605,13 @@ def test_sanity_gpt(
...
@@ -587,7 +605,13 @@ def test_sanity_gpt(
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -759,7 +783,13 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
...
@@ -759,7 +783,13 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -791,7 +821,13 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
...
@@ -791,7 +821,13 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -827,7 +863,13 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
...
@@ -827,7 +863,13 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -863,7 +905,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
...
@@ -863,7 +905,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
...
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
1edc9e13
...
@@ -1561,15 +1561,6 @@ void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor
...
@@ -1561,15 +1561,6 @@ void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor
NVTE_ERROR
(
"TT layout not allowed."
);
NVTE_ERROR
(
"TT layout not allowed."
);
}
}
hipblasLtHandle_t
handle
=
nullptr
;
// Init hipblaslt handles (once, globally)
static
std
::
once_flag
init_flag
;
static
hipblasLtHandle_t
hipblaslt_handles
[
compute_num_streams
];
std
::
call_once
(
init_flag
,
init_hipblaslt_handles
,
hipblaslt_handles
);
handle
=
hipblaslt_handles
[
0
];
NVTE_ERROR
(
"Remove nvte_cublas_batchgemm_tensorwise_int8 for now."
);
NVTE_ERROR
(
"Remove nvte_cublas_batchgemm_tensorwise_int8 for now."
);
}
}
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
1edc9e13
This diff is collapsed.
Click to expand it.
transformer_engine/common/transpose/transpose.cu
View file @
1edc9e13
...
@@ -217,7 +217,13 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
...
@@ -217,7 +217,13 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
// Choose between runtime-compiled or statically-compiled kernel
// Choose between runtime-compiled or statically-compiled kernel
const
bool
aligned
=
(
row_length
%
THREADS_PER_WARP
==
0
&&
num_rows
%
THREADS_PER_WARP
==
0
);
const
bool
aligned
=
(
row_length
%
THREADS_PER_WARP
==
0
&&
num_rows
%
THREADS_PER_WARP
==
0
);
if
(
aligned
&&
rtc
::
is_enabled
())
{
// Runtime-compiled tuned kernel
// TODO: Using RTC may cause kernel crashes on ROCm. Therefore, use_rtc is set to false under USE_ROCM to avoid using RTC and work around the kernel crash issue.
#ifdef USE_ROCM
const
bool
use_rtc
=
false
;
#else
const
bool
use_rtc
=
true
;
#endif
if
(
aligned
&&
rtc
::
is_enabled
()
&&
use_rtc
)
{
// Runtime-compiled tuned kernel
// Pick kernel config
// Pick kernel config
std
::
vector
<
KernelConfig
>
kernel_configs
;
std
::
vector
<
KernelConfig
>
kernel_configs
;
kernel_configs
.
reserve
(
16
);
kernel_configs
.
reserve
(
16
);
...
...
transformer_engine/pytorch/module/_common.py
View file @
1edc9e13
...
@@ -55,7 +55,7 @@ def apply_normalization(
...
@@ -55,7 +55,7 @@ def apply_normalization(
normalization_func
=
_get_normalization_func
(
normalization
,
True
)
normalization_func
=
_get_normalization_func
(
normalization
,
True
)
inputs
=
(
inputmat
,
ln_weight
)
if
ln_bias
is
None
else
(
inputmat
,
ln_weight
,
ln_bias
)
inputs
=
(
inputmat
,
ln_weight
)
if
ln_bias
is
None
else
(
inputmat
,
ln_weight
,
ln_bias
)
if
enable_lightop
and
(
ln_bias
is
None
)
and
normalization
==
"RMSNorm"
and
output_quantizer
is
None
and
(
output_dtype
is
torch
.
bfloat16
or
output_dtype
is
torch
.
float16
or
output_dtype
is
torch
.
float32
):
if
enable_lightop
and
(
ln_bias
is
None
)
and
normalization
==
"RMSNorm"
and
output_quantizer
is
None
and
(
output_dtype
is
torch
.
bfloat16
or
output_dtype
is
torch
.
float16
or
output_dtype
is
torch
.
float32
)
and
not
zero_centered_gamma
:
out
,
rsigma
=
rmsnorm_forward
(
inputmat
,
ln_weight
,
ln_out
,
eps
,
True
)
out
,
rsigma
=
rmsnorm_forward
(
inputmat
,
ln_weight
,
ln_out
,
eps
,
True
)
return
out
,
None
,
rsigma
return
out
,
None
,
rsigma
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment