OpenDAS / TransformerEngine · Commits · 1edc9e13

Commit 1edc9e13, authored Nov 26, 2025 by wenjh
Merge branch 'develop_v2.8' into release_v2.8
Parents: 5e7dd67e, 3a040217

Showing 15 changed files with 2149 additions and 2093 deletions.
Changed files:
  qa/L0_pytorch_unittest/test.sh                        +1    -0
  tests/pytorch/attention/test_attention.py             +17   -16
  tests/pytorch/debug/run_distributed.py                +43   -19
  tests/pytorch/distributed/run_numerics.py             +40   -7
  tests/pytorch/distributed/test_numerics.py            +13   -1
  tests/pytorch/test_cuda_graphs.py                     +9    -3
  tests/pytorch/test_float8_blockwise_gemm_exact.py     +0    -1
  tests/pytorch/test_fusible_ops.py                     +4    -2
  tests/pytorch/test_numerics.py                        +55   -4
  tests/pytorch/test_onnx_export.py                     +70   -15
  tests/pytorch/test_sanity.py                          +57   -9
  transformer_engine/common/gemm/cublaslt_gemm.cu       +0    -9
  transformer_engine/common/gemm/rocm_gemm.cu           +1832 -2005
  transformer_engine/common/transpose/transpose.cu      +7    -1
  transformer_engine/pytorch/module/_common.py          +1    -1
qa/L0_pytorch_unittest/test.sh  (+1 -0)

@@ -49,6 +49,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_tes
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
+mkdir -p $TE_PATH/artifacts/tests/pytorch/test_checkpoint && python $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all --checkpoint-dir $TE_PATH/artifacts/tests/pytorch/test_checkpoint/
 NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
tests/pytorch/attention/test_attention.py  (+17 -16)

@@ -340,22 +340,23 @@ def test_dpa_softmax(dtype, model_configs, model):
 model_configs_mla = {
-    # test: ModelConfig(b, sq, hq, dqk)
-    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
-    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
-    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),
-    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),
-    "mla_2_1": ModelConfig(
-        1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
-    ),
-    "mla_2_2": ModelConfig(
-        1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
-    ),
-    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),
-    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),
-    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),
-    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),
-    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),
+    #TODO:FlashAttention on ROCm only support MLA with head_dim_qk = head_dim_v
+    # # test: ModelConfig(b, sq, hq, dqk)
+    # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
+    # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
+    # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
+    # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
+    # "mla_2_1": ModelConfig(
+    #     1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
+    # ), # cross, 1
+    # "mla_2_2": ModelConfig(
+    #     1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
+    # ), # cross, 1
+    # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
+    # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
+    # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
+    # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
+    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
 }
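The TODO added in this hunk states the constraint directly: FlashAttention on ROCm only supports MLA when head_dim_qk equals head_dim_v, so only the entry whose dimensions match (mla_3_4, 160/160) stays active. A minimal sketch of expressing the same constraint programmatically instead of by commenting entries out; it is not part of the commit, and the head_dim_qk / head_dim_v attribute names are assumed here, not taken from ModelConfig's actual definition.

def flash_attn_compatible(configs):
    # Keep only entries whose QK and V head dims match (hypothetical attribute
    # names head_dim_qk / head_dim_v; adjust to ModelConfig's real fields).
    return {
        name: cfg
        for name, cfg in configs.items()
        if getattr(cfg, "head_dim_qk", None) is not None
        and getattr(cfg, "head_dim_qk") == getattr(cfg, "head_dim_v", None)
    }

Applied to the dictionary above, such a filter would keep only "mla_3_4".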
tests/pytorch/debug/run_distributed.py  (+43 -19)

@@ -17,6 +17,7 @@ import transformer_engine_torch as tex
 import nvdlfw_inspect.api as debug_api
 from transformer_engine.debug import set_weight_tensor_tp_group_reduce
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from test_numerics import (
     _emulate_linear,

@@ -47,7 +48,6 @@ TEST_NR = 0
 fp8_available, _ = FP8GlobalStateManager.is_fp8_available()

 def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
     if tp_size is None:
         tp_size = WORLD_SIZE

@@ -72,6 +72,16 @@ def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None,
 def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
+    if IS_HIP_EXTENSION:
+        model = transformer_engine.pytorch.Linear(
+            IN_SIZE,
+            OUT_SIZE,
+            name=name,
+            bias=False,
+            parallel_mode=parallel_mode,
+            tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
+        )
+    else:
         model = transformer_engine.pytorch.Linear(
             IN_SIZE,
             OUT_SIZE,

@@ -363,7 +373,6 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
     )
     set_weight_tensor_tp_group_reduce(True)  # reset

 @run_debug_test
 def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
     from test_log import LOG_QUANTIZED_CONFIG

@@ -580,6 +589,9 @@ def test_fake_quant_fp8(
         "dgrad_fp8": not (dgrad_weight or dgrad_grad),
         "wgrad_fp8": not (wgrad_grad or wgrad_input),
     }
+    if IS_HIP_EXTENSION:
+        if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
+            return  # Output type 32 (FP32) does not support int8 simulation.
     if WORLD_RANK == 0:
         fake_quant_fp8_create_config(
             fprop_inp,

@@ -667,6 +679,10 @@ if __name__ == "__main__":
     random.seed(SEED)
     _init_distributed()
+    if IS_HIP_EXTENSION:
+        # Output type 32 (FP32) does not support int8 simulation.
+        pass
+    else:
         test_log_expert_parallel()
     for parallel_mode in ["column", "row"]:
         for gather_weight in [True, False]:

@@ -676,6 +692,11 @@ if __name__ == "__main__":
     for parallel_mode in ["row", "column"]:
         test_disable_fp8_layer(parallel_mode)
+    if IS_HIP_EXTENSION:
+        # Output type 32 (FP32) does not support int8 simulation.
+        pass
+    else:
         # test_disable_fp8_gemms
         _run_test_with_combinations(
             test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]

@@ -690,7 +711,10 @@ if __name__ == "__main__":
         extra_args=["column", "row"],
         sample_size=20,
     )
+    if IS_HIP_EXTENSION:
+        # Output type 32 (FP32) does not support int8 simulation.
+        pass
+    else:
         _run_test_with_combinations(
             test_per_tensor_scaling,
             all_boolean,
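The new IS_HIP_EXTENSION branches in this file all share one shape: when a debug test would exercise an FP8 GEMM path under the ROCm int8 simulation, it returns or skips early because an FP32 output type is not supported there. A minimal sketch of that guard, with a hypothetical helper name; only the fp8_kwargs keys come from the code above.

from torch.utils.cpp_extension import IS_HIP_EXTENSION

def skip_fp8_paths_on_hip(fp8_kwargs):
    """Hypothetical helper: return True when the caller should bail out because
    any GEMM would run in FP8 under the ROCm int8 simulation."""
    return IS_HIP_EXTENSION and any(
        fp8_kwargs.get(key) for key in ("fprop_fp8", "dgrad_fp8", "wgrad_fp8")
    )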
tests/pytorch/distributed/run_numerics.py  (+40 -7)

@@ -733,7 +733,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
 def test_linear():
     """Run linear layer tests with various configurations."""
-    kwargs_list = [
+    base_kwargs_list = [
         {},
         {"bias": False},
         {"init_method": _constant},

@@ -743,7 +743,15 @@ def test_linear():
         {"delay_wgrad_compute": True},
         {"save_original_input": True},
     ]
+    #TODO:The blockwise recipe does not currently support calculations with bias set to true.
+    """
+    For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
+    and if the bias value is not set in kwargs or the bias value is true, set bias to false.
+    """
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
         if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
             continue

@@ -913,7 +921,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
 def test_layernorm_linear():
-    kwargs_list = [
+    base_kwargs_list = [
         {},
         {"bias": False},
         {"init_method": _constant},

@@ -924,7 +932,15 @@ def test_layernorm_linear():
         {"return_layernorm_output": True},
         {"delay_wgrad_compute": True},
     ]
+    #TODO:The blockwise recipe does not currently support calculations with bias set to true.
+    """
+    For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
+    and if the bias value is not set in kwargs or the bias value is true, set bias to false.
+    """
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
         for parallel_mode in ["column"]:
             for sequence_parallel in [False, True]:

@@ -1019,7 +1035,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
 def test_layernorm_mlp():
-    kwargs_list = [
+    base_kwargs_list = [
         {},
         {"init_method": _constant},
         {"output_layer_init_method": _constant},

@@ -1033,7 +1049,15 @@ def test_layernorm_mlp():
         {"return_layernorm_output": True},
         {"delay_wgrad_compute": True},
     ]
+    #TODO:The blockwise recipe does not currently support calculations with bias set to true.
+    """
+    For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
+    and if the bias value is not set in kwargs or the bias value is true, set bias to false.
+    """
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
         for set_parallel_mode in [True]:
             for sequence_parallel in [False, True]:

@@ -1108,7 +1132,7 @@ def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
 def test_transformer_layer():
-    kwargs_list = [
+    base_kwargs_list = [
         {},
         {"num_gqa_groups": 4},
         {"init_method": _constant},

@@ -1128,6 +1152,15 @@ def test_transformer_layer():
         {"fuse_qkv_params": True},
         {"activation": "relu"},
     ]
+    #TODO:The blockwise recipe does not currently support calculations with bias set to true.
+    """
+    For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
+    and if the bias value is not set in kwargs or the bias value is true, set bias to false.
+    """
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
         for sequence_parallel in [False, True]:
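The filtering added above is a plain list comprehension over the base kwargs list: on ROCm with the fp8_block_scaling recipe, only configurations that explicitly pass bias=False survive, and every other variation is dropped rather than rerun with bias disabled. A small self-contained sketch of that selection behaviour; the IS_HIP_EXTENSION and QUANTIZATION values below are stand-ins for the module-level names used in the test.

# Sketch of the selection pattern used above; the flag values are assumed for
# the example, the comprehension mirrors the one added in the commit.
IS_HIP_EXTENSION = True
QUANTIZATION = "fp8_block_scaling"

base_kwargs_list = [{}, {"bias": False}, {"delay_wgrad_compute": True}]

if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
    # keep only configurations that explicitly disable bias
    kwargs_list = [kw for kw in base_kwargs_list if kw.get("bias", True) is False]
else:
    kwargs_list = base_kwargs_list

print(kwargs_list)  # -> [{'bias': False}]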
tests/pytorch/distributed/test_numerics.py  (+13 -1)

@@ -9,7 +9,8 @@ from pathlib import Path
 import pytest
 import torch
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+import transformer_engine as te

 """
 Distributed numerics tests

@@ -66,4 +67,15 @@ def test_distributed(quantization):
         pytest.skip(reason_for_no_fp8_block_scaling)
     if quantization == "nvfp4" and not nvfp4_available:
         pytest.skip(reason_for_no_nvfp4)
+    if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
+        import importlib
+        ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", "None")
+        os.environ["NVTE_INT8_SIM_FP8"] = "1"
+        importlib.reload(te.pytorch.fp8)
     _run_test(quantization)
+    if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
+        if ori_int8_sim_fp8 is None or ori_int8_sim_fp8 == "None":
+            os.environ["NVTE_INT8_SIM_FP8"] = "0"
+        else:
+            del os.environ["NVTE_INT8_SIM_FP8"]
+        importlib.reload(te.pytorch.fp8)
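The test above flips NVTE_INT8_SIM_FP8 before launching the distributed run and undoes the change afterwards, reloading te.pytorch.fp8 each time so the module-level state picks up the new value. A sketch of the same save/set/restore idea wrapped in a context manager so the restoration also happens when the run raises; the environment-variable name comes from the commit, the helper itself is hypothetical and not the commit's code.

import importlib
import os
from contextlib import contextmanager

@contextmanager
def env_var(name, value):
    """Temporarily set an environment variable and restore it on exit."""
    old = os.environ.get(name)          # None if the variable was unset
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)  # unset it again
        else:
            os.environ[name] = old      # restore the original value

# Usage mirroring the test, assuming te and _run_test are in scope:
# with env_var("NVTE_INT8_SIM_FP8", "1"):
#     importlib.reload(te.pytorch.fp8)
#     _run_test(quantization)
# importlib.reload(te.pytorch.fp8)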
tests/pytorch/test_cuda_graphs.py  (+9 -3)

@@ -28,9 +28,9 @@ if IS_HIP_EXTENSION:
     from functools import cache

 # Check if FP8 is supported.
-fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 # Reset RNG states.
 reset_rng_states()

@@ -367,6 +367,12 @@ def test_make_graphed_callables(
     )
     if fp8_params:
         pytest.skip("NVFP4 params not supported")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
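The change above relies on the FP8GlobalStateManager availability helpers returning an (available, reason) pair: the commit stops discarding the reason so it can be handed straight to pytest.skip. A compact sketch of that guard pattern; collecting the checks in a helper is an assumption for illustration, not how the test is structured.

import pytest
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

def maybe_skip(fp8, fp8_recipe):
    """Hypothetical helper gathering the guards used in test_make_graphed_callables."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)       # skip with the reported reason
    if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)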
tests/pytorch/test_float8_blockwise_gemm_exact.py  (+0 -1)

@@ -6,7 +6,6 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
-from transformer_engine.pytorch.utils import use_lightop_w8a8
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
tests/pytorch/test_fusible_ops.py  (+4 -2)

@@ -2111,7 +2111,8 @@ class TestFusedOps:
         quantized_weight: bool = False,
     ) -> None:
         """Forward GEMM + scale + add"""
+        if IS_HIP_EXTENSION and scale != 1:
+            pytest.skip("alpha must be 1.0 for hip")
         # Make input and weight shapes consistent
         out_features, in_features = weight_shape
         in_shape = list(in_shape)[:-1] + [in_features]

@@ -2496,7 +2497,8 @@ class TestFusedOps:
         quantized_weight: bool = False,
     ) -> None:
         """Backward dgrad GEMM + scale"""
+        if IS_HIP_EXTENSION and scale != 1:
+            pytest.skip("alpha must be 1.0 for hip")
         # Make input and weight shapes consistent
         out_features, in_features = weight_shape
         in_shape = list(in_shape)[:-1] + [in_features]
tests/pytorch/test_numerics.py  (+55 -4)

@@ -56,7 +56,7 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend
 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
 sm_80plus = get_device_compute_capability() >= (8, 0)

@@ -606,6 +606,13 @@ def _test_e2e_selective_recompute(
 def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     config = model_configs[model]

@@ -714,8 +721,15 @@ def _test_e2e_full_recompute(
 def test_gpt_full_activation_recompute(
     dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
 ):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]

@@ -1301,9 +1315,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     fuse_wgrad_accumulation = True
     fp8_model_params = False
     fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:

@@ -1818,6 +1837,12 @@ def test_grouped_linear_accuracy(
     use_cutlass=False,
 ):
     fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")

@@ -1863,7 +1888,8 @@ def test_grouped_linear_accuracy(
         weight_i = getattr(grouped_linear, f"weight{i}")
         weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
         sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
     outputs_ref = _test_grouped_linear_accuracy(
         sequential_linear,
         num_gemms,

@@ -1886,7 +1912,8 @@ def test_grouped_linear_accuracy(
         fuse_wgrad_accumulation,
         delay_wgrad_compute,
     )
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"
     for o, o_ref in zip(outputs, outputs_ref):
         if use_cutlass:
             torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)

@@ -1956,6 +1983,12 @@ def test_grouped_linear_accuracy_save_original_input(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:

@@ -2162,8 +2195,14 @@ def test_padding_grouped_linear_accuracy(
     fp8_model_params,
     parallel_mode=None,
 ):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:

@@ -2235,6 +2274,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:

@@ -2446,8 +2491,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
 @pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("recipe", fp8_recipes)
 def test_gpt_fp8_parameters(dtype, bs, model, recipe):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
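In test_grouped_linear_accuracy the reference outputs are now produced with NVTE_FORCE_ROCM_GEMM set to "1" and the flag is switched back to "0" immediately afterwards. A sketch of the same toggle expressed with try/finally so the flag is reset even if the reference run fails; the wrapper function is hypothetical, only the environment-variable name and the IS_HIP_EXTENSION flag come from the commit.

import os
from torch.utils.cpp_extension import IS_HIP_EXTENSION

def with_rocm_gemm(fn, *args, **kwargs):
    """Run fn with the ROCm GEMM backend forced on, then switch it back off."""
    if not IS_HIP_EXTENSION:
        return fn(*args, **kwargs)
    os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
    try:
        return fn(*args, **kwargs)
    finally:
        os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"

# e.g. outputs_ref = with_rocm_gemm(_test_grouped_linear_accuracy, sequential_linear, num_gemms, ...)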
tests/pytorch/test_onnx_export.py  (+70 -15)

@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
-from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
+from transformer_engine.pytorch.onnx_extensions import te_translation_table
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+from transformer_engine.pytorch.export import is_in_onnx_export_mode
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import get_default_init_method
 import tensorrt as trt

@@ -65,7 +67,6 @@ if mxfp8_available:
     fp8_recipes.append(recipe.MXFP8BlockScaling())
 if fp8_available:
     fp8_recipes.append(recipe.DelayedScaling())
-    fp8_recipes.append(recipe.Float8CurrentScaling())
 fp8_recipes.append(None)

 supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]

@@ -82,11 +83,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
     ],
     outputs=[PyCustomOpDef.dt_uint8],
 )
-def trt_fp8_quantize(t, scale_inv):
+def trt_fp8_quantize(t, scale):
     """FP8 quantization extension for ONNX Runtime."""
     x = torch.from_numpy(t).cuda()
     q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale_inv).cuda(),
+        scale=1 / torch.from_numpy(scale).cuda(),
         amax=torch.zeros([1]).cuda(),
         fp8_dtype=tex.DType.kFloat8E4M3,
     )

@@ -102,11 +103,11 @@ def trt_fp8_quantize(t, scale_inv):
     ],
     outputs=[PyCustomOpDef.dt_float],
 )
-def trt_fp8_dequantize(t, scale_inv):
+def trt_fp8_dequantize(t, scale):
     """FP8 dequantization extension for ONNX Runtime."""
     x = torch.from_numpy(t).cuda()
     q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale_inv).cuda(),
+        scale=1 / torch.from_numpy(scale).cuda(),
         amax=torch.zeros([1]).cuda(),
         fp8_dtype=tex.DType.kFloat8E4M3,
     )

@@ -469,16 +470,22 @@ def _test_export_linear(
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
 def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_linear(fp8_recipe=fp8_recipe, precision=precision)

 @pytest.mark.parametrize("use_bias", [True, False])
 def test_export_linear_use_bias(seed_default_rng, use_bias):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_linear(use_bias=use_bias)

 @pytest.mark.parametrize("return_bias", [True, False])
 def test_export_linear_return_bias(seed_default_rng, return_bias):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_linear(return_bias=return_bias)

@@ -540,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng):
 @pytest.mark.parametrize("normalization", all_normalizations)
 def test_export_layernorm_normalization(seed_default_rng, normalization):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm(normalization=normalization)

@@ -594,9 +603,7 @@ def _test_export_layernorm_linear(
         fname,
         inp,
         model,
-        # For current scaling we use Float8Quantizer in tests + amax computed by hand,
-        # which has slightly different numerics than Float8CurrentScalingQuantizer.
-        atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
+        atol=1e-3,
         is_fp8=fp8_recipe is not None,
         te_outputs=te_outputs,
     )

The remaining hunks in this file (@@ -605,27 +612,39 @@, @@ -684,32 +703,46 @@, @@ -731,6 +764,8 @@, @@ -932,22 +967,32 @@, @@ -1023,27 +1068,39 @@ and @@ -1056,7 +1113,8 @@) insert the same two-line guard

    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")

at the top of the LayerNorm-Linear export tests (test_export_layernorm_linear_recipe, _return_ln_out, _zero_centered_gamma, _normalization, _no_bias, _return_bias), the LayerNorm-MLP export tests (test_export_layernorm_mlp and its _return_layernorm_output, _return_bias, _no_bias, _zero_centered_gamma, _normalization, _activation variants), test_export_core_attention, the multihead-attention export tests (test_export_multihead_attention_recipe, _no_mask, _no_input_layernorm, _cross_attn, _unfused_qkv_params), the transformer-layer export tests (test_export_transformer_layer_recipe, _no_mask, _output_layernorm, _unfused_qkv_params, _zero_centered_gamma, _activation), and test_export_gpt_generation.

@@ -1147,17 +1205,14 @@ def test_export_ctx_manager(enabled):
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 def test_trt_integration(fp8_recipe: recipe.Recipe):
+    if IS_HIP_EXTENSION:
+        pytest.skip("TRT is not supported for HIP")
     model = te.TransformerLayer(
         hidden_size=128,
         ffn_hidden_size=128,
         num_attention_heads=4,
     ).eval()
-    if type(fp8_recipe) == recipe.Float8CurrentScaling:
-        # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
-        model = te.LayerNormMLP(128, 128)
     inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
     with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
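The same two-line ROCm guard is repeated in roughly two dozen export tests here. A sketch of an alternative that expresses the skip once as a reusable pytest marker; this is a design alternative for illustration, not what the commit does, and the example test name is hypothetical.

import pytest
from torch.utils.cpp_extension import IS_HIP_EXTENSION

# One marker, applied per test instead of repeating the in-body guard.
skip_on_hip = pytest.mark.skipif(
    IS_HIP_EXTENSION, reason="ONNX is not currently required in hip"
)

@skip_on_hip
def test_export_linear_use_bias_example():
    ...  # test body unchanged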
tests/pytorch/test_sanity.py
View file @
1edc9e13
...
@@ -46,7 +46,7 @@ from utils import ModelConfig
...
@@ -46,7 +46,7 @@ from utils import ModelConfig
# Only run FP8 tests on supported devices.
# Only run FP8 tests on supported devices.
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_block_scaling_available
,
_
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
fp8_block_scaling_available
,
reason_for_no_fp8_block_scaling
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
# Record initial RNG state from script run.
# Record initial RNG state from script run.
...
@@ -388,7 +388,13 @@ def test_sanity_layernorm_linear(
...
@@ -388,7 +388,13 @@ def test_sanity_layernorm_linear(
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -450,7 +456,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
...
@@ -450,7 +456,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
num_tokens
=
bs
*
config
.
max_seqlen_q
num_tokens
=
bs
*
config
.
max_seqlen_q
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -543,7 +555,13 @@ def test_sanity_layernorm_mlp(
...
@@ -543,7 +555,13 @@ def test_sanity_layernorm_mlp(
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -587,7 +605,13 @@ def test_sanity_gpt(
...
@@ -587,7 +605,13 @@ def test_sanity_gpt(
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
is_fp8_supported
(
config
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
if
fp8_recipe
.
nvfp4
()
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
pytest
.
skip
(
"FP16 output for NVFP4 not supported"
)
...
@@ -759,7 +783,13 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
         if fp8_recipe.nvfp4() and dtype == torch.float16:
             pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -791,7 +821,13 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
         if fp8_recipe.nvfp4() and dtype == torch.float16:
             pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -827,7 +863,13 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
         if fp8_recipe.nvfp4() and dtype == torch.float16:
             pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -863,7 +905,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
         if fp8_recipe.nvfp4() and dtype == torch.float16:
             pytest.skip("FP16 output for NVFP4 not supported")
...
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
1edc9e13
...
@@ -1561,15 +1561,6 @@ void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor
     NVTE_ERROR("TT layout not allowed.");
   }
-  hipblasLtHandle_t handle = nullptr;
-  // Init hipblaslt handles (once, globally)
-  static std::once_flag init_flag;
-  static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
-  std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
-  handle = hipblaslt_handles[0];
   NVTE_ERROR("Remove nvte_cublas_batchgemm_tensorwise_int8 for now.");
 }
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
1edc9e13
...
@@ -465,106 +465,6 @@ transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t)
 namespace {
-static class HandlePool {
- public:
-  hipblasLtHandle_t get(int device_id) {
-    std::lock_guard<std::mutex> lock(mt);
-    if (pool.empty()) {
-      int device_count = 0;
-      NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
-      pool.resize(device_count);
-      return nullptr;
-    }
-    if (!pool[device_id].empty()) {
-      hipblasLtHandle_t h = pool[device_id].front();
-      pool[device_id].pop_front();
-      return h;
-    }
-    return nullptr;
-  }
-  hipblasLtHandle_t obtain(int device_id) {
-    hipblasLtHandle_t h = get(device_id);
-    if (h == nullptr) {
-      NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&h));
-    }
-    return h;
-  }
-  void store(const std::vector<hipblasLtHandle_t>& handles) {
-    std::lock_guard<std::mutex> lock(mt);
-    if (pool.empty()) {
-      std::cout << "[ERROR] Attempt to store handles to invalid pool" << std::endl;
-    }
-    for (unsigned int i = 0; i < pool.size(); i++) {
-      if (handles[i] != nullptr) {
-        pool[i].push_front(handles[i]);
-      }
-    }
-  }
-  ~HandlePool() {
-#if DESTROY_HIPBLASLT_HANDLES_POOL
-    std::lock_guard<std::mutex> lock(mt);
-    for (auto& hlist : pool) {
-      for (auto& h : hlist) {
-        hipblasLtDestroy(h);
-      }
-    }
-    pool.clear();
-#endif
-  }
-  inline size_t get_size() const { return pool.size(); }
- private:
-  std::mutex mt;
-  using Pool = std::vector<std::forward_list<hipblasLtHandle_t>>;
-  // Order of destructors between thread_local and global is not actually guaranteed
-  // As a simple w/a make pool storage "leaky"
-  // Just do not destruct it and do not destroy hipbladLt handles
-  // Let OS deal with it on application exit
-#if DESTROY_HIPBLASLT_HANDLES_POOL
-  Pool pool;
-#else
-  Pool& pool = *new Pool();
-#endif
-} handle_pool;
-thread_local static class HandleCache {
- public:
-  hipblasLtHandle_t get(int device_id) const { return d.empty() ? nullptr : d[device_id]; }
-  hipblasLtHandle_t obtain(int device_id) {
-    hipblasLtHandle_t h = get(device_id);
-    if (h) {
-      return h;
-    }
-    h = handle_pool.obtain(device_id);
-    set(device_id, h);
-    return h;
-  }
-  void set(int device_id, hipblasLtHandle_t h) {
-    if (d.empty()) {
-      d.resize(handle_pool.get_size());
-    }
-    d[device_id] = h;
-  }
-  ~HandleCache() {
-    if (!d.empty()) {
-      handle_pool.store(d);
-    }
-  }
- private:
-  std::vector<hipblasLtHandle_t> d;
-} cached_handles;
 class csv_helper {
  public:
   struct start {};
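The comment inside the removed HandlePool explains why its backing storage was allocated with new and never freed: the thread_local HandleCache above returns handles to the global pool from its destructor, and the code did not want to rely on the relative destruction order of thread_local and global statics. The following standalone sketch is not part of the commit; it uses a plain int in place of hipblasLtHandle_t and only illustrates the same "leaky storage" idiom:

#include <forward_list>
#include <mutex>
#include <vector>

// Global pool whose backing storage is heap-allocated and deliberately never
// freed, so it is still usable when thread_local destructors run at thread
// exit; the OS reclaims the memory at process exit.
class LeakyPool {
 public:
  void put(int v) {
    std::lock_guard<std::mutex> lock(mt_);
    items_.push_front(v);
  }

 private:
  std::mutex mt_;
  std::forward_list<int>& items_ = *new std::forward_list<int>();  // intentionally leaked
} pool;

// Per-thread cache that hands its contents back to the pool when the thread ends.
thread_local struct Cache {
  std::vector<int> held;
  ~Cache() {
    for (int v : held) pool.put(v);  // safe: the pool's storage is never destroyed
  }
} cache;

int main() {
  cache.held.push_back(42);
  return 0;
}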
...
@@ -987,18 +887,12 @@ static inline int getIntEnv(const char* name, int defval, int minval) {
 } //namespace
-/* Warning: only call once per device!
- * When calling nvte_multi_stream_cublas_gemm with hipblaslt backend
- * need to create multiple handles corresponding to compute_streams
- * to avoid a handle be used by multi-streams concurrently.
- */
-static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
-  NVTE_CHECK(hipblaslt_handles != nullptr);
-  for (int i = 0; i < compute_num_streams; i++) {
-    NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i]));
-  }
-}
+static inline void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
+  NVTE_CHECK_HIPBLASLT(hipblasLtCreate(handle));
+}
+
+using hipBlasLtHandleManager = detail::HandleManager<hipblasLtHandle_t, CreateHipBlasLtHandle>;
 transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
   using namespace transformer_engine;
   switch (t) {
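The warning comment on the removed init_hipblaslt_handles states the constraint both designs work around: a single hipBLASLt handle should not be used by several compute streams at the same time, so one handle is kept per stream (or, after this change, obtained from hipBlasLtHandleManager). Below is a minimal standalone sketch of the once-only, per-stream initialization idiom the removed code used; the Handle type, kNumStreams and create_handle are placeholders, not the repository's API:

#include <array>
#include <cstdio>
#include <mutex>

struct Handle { int id; };               // stand-in for hipblasLtHandle_t
constexpr int kNumStreams = 4;           // stand-in for compute_num_streams
void create_handle(Handle* h, int id) { h->id = id; }  // stand-in for hipblasLtCreate

// One handle per compute stream, created exactly once even if many threads
// reach this function concurrently.
Handle& handle_for_stream(int stream_idx) {
  static std::array<Handle, kNumStreams> handles;
  static std::once_flag init_flag;
  std::call_once(init_flag, [] {
    for (int i = 0; i < kNumStreams; ++i) create_handle(&handles[i], i);
  });
  return handles[stream_idx];
}

int main() {
  std::printf("stream 2 uses handle %d\n", handle_for_stream(2).id);
  return 0;
}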
...
@@ -1018,8 +912,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                     int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
                     bool grad, void* workspace, size_t workspaceSize, bool accumulate,
                     bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
-                    bool gemm_producer, const Tensor* inputCounter, hipStream_t stream,
-                    hipblasLtHandle_t handle) {
+                    bool gemm_producer, const Tensor* inputCounter, hipStream_t stream) {
   void* A = inputA->data.dptr;
   void* A_scale_inverse = inputA->scale_inv.dptr;
   float* A_scale_inverse_float = (float*)(inputA->scale_inv.dptr);
...
@@ -1064,12 +957,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
   int device_id;
   NVTE_CHECK_CUDA(hipGetDevice(&device_id));
-  if (handle == nullptr) {
-    handle = cached_handles.get(device_id);
-    if (handle == nullptr) {
-      handle = cached_handles.obtain(device_id);
-    }
-  }
+  hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
   hipblasLtMatmulDesc_t operationDesc = nullptr;
   hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
...
@@ -1352,82 +1240,41 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
 }
-class userArgsManager {
- public:
-  userArgsManager() {}
-  ~userArgsManager() {
-    // Release all userArgs when the manager is destroyed
-    for (auto& device_pair : userArgs_map_) {
-      hipFree(device_pair.second);  // Only one userArgs per device
-    }
-  }
-  // Get a userArgs for the given device (creates if necessary)
-  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
-    std::lock_guard<std::mutex> lock(mutex_);
-    // Check if the userArgs for this device exists
-    auto device_it = userArgs_map_.find(device_id);
-    if (device_it != userArgs_map_.end()) {
-      return device_it->second;
-    }
-    // Create a new userArgs for this device if it doesn't exist
-    hipblaslt_ext::UserArguments* userArgs;
-    NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
-    // Store the userArgs in the map for this device
-    userArgs_map_[device_id] = userArgs;
-    return userArgs;
-  }
- private:
-  std::unordered_map<int, hipblaslt_ext::UserArguments*> userArgs_map_;  // Map from device_id to hipblasHandle
-  std::mutex mutex_;
-};
-class d_userArgsManager {
- public:
-  d_userArgsManager() {}
-  ~d_userArgsManager() {
-    // Release all userArgs when the manager is destroyed
-    for (auto& device_pair : d_userArgs_map_) {
-      hipFree(device_pair.second);  // Only one userArgs per device
-    }
-  }
-  // Get a userArgs for the given device (creates if necessary)
-  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
-    std::lock_guard<std::mutex> lock(mutex_);
-    // Check if the userArgs for this device exists
-    auto device_it = d_userArgs_map_.find(device_id);
-    if (device_it != d_userArgs_map_.end()) {
-      return device_it->second;
-    }
-    // Create a new userArgs for this device if it doesn't exist
-    hipblaslt_ext::UserArguments* d_userArgs;
-    NVTE_CHECK_CUDA(hipMalloc(&d_userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
-    // Store the userArgs in the map for this device
-    d_userArgs_map_[device_id] = d_userArgs;
-    return d_userArgs;
-  }
- private:
-  std::unordered_map<int, hipblaslt_ext::UserArguments*> d_userArgs_map_;  // Map from device_id to hipblasHandle
-  std::mutex mutex_;
-};
-// Define a static userArgs manager
-static userArgsManager UAManager;
-static d_userArgsManager d_UAManager;
+struct HipBlasLtUserArgsDeleter {
+  void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept {
+    hipFree(ptr);
+  }
+};
+using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;
+inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
+  hipblaslt_ext::UserArguments* raw_ptr = nullptr;
+  if (host) {
+    NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
+  } else {
+    NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
+  }
+  return HipBlasLtUserArgsPtr(raw_ptr);
+}
+inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
+  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
+  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
+  std::unordered_map<size_t, HipBlasLtUserArgsPtr>& user_args_cache =
+      host ? host_userargs_cache : device_userargs_cache;
+  auto size_it = user_args_cache.find(size);
+  if (size_it != user_args_cache.end()) {
+    return size_it->second.get();
+  } else {
+    HipBlasLtUserArgsPtr user_args = make_hipblaslt_user_args_ptr(size, host);
+    hipblaslt_ext::UserArguments* raw_ptr = user_args.get();
+    user_args_cache[size] = std::move(user_args);
+    return raw_ptr;
+  }
+}
 void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
                            std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
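The replacement shown above trades per-device maps guarded by a mutex (userArgsManager / d_userArgsManager) for per-thread caches keyed by buffer size, with ownership expressed through std::unique_ptr and a custom deleter. Here is a standalone sketch of that caching idiom, not part of the commit, using plain malloc/free where the real code uses hipHostMalloc, hipMalloc and hipFree:

#include <cstdio>
#include <cstdlib>
#include <memory>
#include <unordered_map>

struct Args { int data[8]; };  // stand-in for hipblaslt_ext::UserArguments

struct FreeDeleter {
  void operator()(Args* p) const noexcept { std::free(p); }
};
using ArgsPtr = std::unique_ptr<Args, FreeDeleter>;

// Return a buffer large enough for `count` elements, allocating on first use.
// thread_local storage avoids locking; the buffer is reused by later calls on
// the same thread and released automatically when the thread exits.
Args* get_args(std::size_t count) {
  thread_local std::unordered_map<std::size_t, ArgsPtr> cache;
  auto it = cache.find(count);
  if (it != cache.end()) {
    return it->second.get();
  }
  ArgsPtr buf(static_cast<Args*>(std::malloc(count * sizeof(Args))));
  Args* raw = buf.get();
  cache[count] = std::move(buf);
  return raw;
}

int main() {
  Args* a = get_args(4);
  Args* b = get_args(4);  // same cached buffer on this thread
  std::printf("reused: %s\n", a == b ? "yes" : "no");
  return 0;
}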
...
@@ -1438,23 +1285,13 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
   // Check compute_stream_offset valid.
   NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
-  int device_id;
-  hipGetDevice(&device_id);
-  hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
-  hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
+  hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
+  hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
   // hipblaslt_ext::UserArguments* userArgs;
   // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
-  hipblasLtHandle_t handle = nullptr;
-  if (compute_stream_offset != -1) {
-    // Init hipblaslt handles (once, globally)
-    static std::once_flag init_flag;
-    static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
-    std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
-    handle = hipblaslt_handles[compute_stream_offset];
-  }
+  hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
   const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
   const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
...
@@ -1972,20 +1809,10 @@ void cublas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
   // Check compute_stream_offset valid.
   NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
-  hipblasLtHandle_t handle = nullptr;
-  if (compute_stream_offset != -1) {
-    // Init hipblaslt handles (once, globally)
-    static std::once_flag init_flag;
-    static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
-    std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
-    handle = hipblaslt_handles[compute_stream_offset];
-  }
   hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
                  (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                  grad, workspace, workspaceSize, accumulate, use_split_accumulator, math_sm_count,
-                 m_split, n_split, gemm_producer, inputCounter, stream, handle);
+                 m_split, n_split, gemm_producer, inputCounter, stream);
   return;
 }
...
transformer_engine/common/transpose/transpose.cu
View file @
1edc9e13
...
@@ -217,7 +217,13 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
   // Choose between runtime-compiled or statically-compiled kernel
   const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
-  if (aligned && rtc::is_enabled()) {
-    // Runtime-compiled tuned kernel
+  //TODO:Using RTC may cause kernel crashes. Therefore, set use_rtc to true to avoid using RTC and resolve the kernel crash issue.
+#ifdef USE_ROCM
+  const bool use_rtc = false;
+#else
+  const bool use_rtc = true;
+#endif
+  if (aligned && rtc::is_enabled() && use_rtc) {
+    // Runtime-compiled tuned kernel
     // Pick kernel config
     std::vector<KernelConfig> kernel_configs;
     kernel_configs.reserve(16);
...
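The new use_rtc flag simply removes the runtime-compiled kernel path from consideration on ROCm builds while leaving other builds unchanged. A trivial standalone sketch of that compile-time gating, not taken from the repository:

#include <cstdio>

// Stand-ins for the two kernel paths chosen in transpose().
void launch_rtc_tuned() { std::puts("runtime-compiled tuned kernel"); }
void launch_static_general() { std::puts("statically-compiled general kernel"); }

int main() {
  const bool aligned = true;      // e.g. dimensions divisible by the warp size
  const bool rtc_enabled = true;  // runtime compilation available in this build
#ifdef USE_ROCM
  const bool use_rtc = false;     // RTC path disabled on ROCm
#else
  const bool use_rtc = true;
#endif
  if (aligned && rtc_enabled && use_rtc) {
    launch_rtc_tuned();
  } else {
    launch_static_general();
  }
  return 0;
}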
transformer_engine/pytorch/module/_common.py
View file @
1edc9e13
...
@@ -55,7 +55,7 @@ def apply_normalization(
     normalization_func = _get_normalization_func(normalization, True)
     inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
-    if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32):
+    if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32) and not zero_centered_gamma:
         out, rsigma = rmsnorm_forward(inputmat, ln_weight, ln_out, eps, True)
         return out, None, rsigma
     else:
...