Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
565fd629
Commit
565fd629
authored
Oct 20, 2025
by
zhaochao
Browse files
[DCU] Fix the bug in test_onnx_export.py under L0
Signed-off-by:
zhaochao
<
zhaochao1@sugon.com
>
parent
68d851d1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
65 additions
and
2 deletions
+65
-2
tests/pytorch/test_onnx_export.py
tests/pytorch/test_onnx_export.py
+65
-2
No files found.
tests/pytorch/test_onnx_export.py
View file @
565fd629
...
...
@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import
transformer_engine.pytorch
as
te
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.export
import
is_in_onnx_export_mode
,
te_translation_table
from
transformer_engine.pytorch.onnx_extensions
import
te_translation_table
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
transformer_engine.pytorch.export
import
is_in_onnx_export_mode
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
from
transformer_engine.pytorch.utils
import
get_default_init_method
import
tensorrt
as
trt
...
...
@@ -468,16 +470,22 @@ def _test_export_linear(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_linear_recipe
(
seed_default_rng
,
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_linear
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_export_linear_use_bias
(
seed_default_rng
,
use_bias
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_linear
(
use_bias
=
use_bias
)
@
pytest
.
mark
.
parametrize
(
"return_bias"
,
[
True
,
False
])
def
test_export_linear_return_bias
(
seed_default_rng
,
return_bias
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_linear
(
return_bias
=
return_bias
)
...
...
@@ -539,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng):
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
)
def
test_export_layernorm_normalization
(
seed_default_rng
,
normalization
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm
(
normalization
=
normalization
)
...
...
@@ -602,27 +612,39 @@ def _test_export_layernorm_linear(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_layernorm_linear_recipe
(
seed_default_rng
,
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_layernorm_linear_return_ln_out
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
return_layernorm_output
=
True
)
def
test_export_layernorm_linear_zero_centered_gamma
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
zero_centered_gamma
=
True
)
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
[
1
:])
def
test_export_layernorm_linear_normalization
(
seed_default_rng
,
normalization
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
normalization
=
normalization
)
def
test_export_layernorm_linear_no_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
use_bias
=
False
)
def
test_export_layernorm_linear_return_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
return_bias
=
True
)
...
...
@@ -681,32 +703,46 @@ def _test_export_layernorm_mlp(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_layernorm_mlp
(
seed_default_rng
,
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_layernorm_mlp_return_layernorm_output
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
return_layernorm_output
=
True
)
def
test_export_layernorm_mlp_return_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
return_bias
=
True
)
def
test_export_layernorm_mlp_no_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
use_bias
=
False
)
def
test_export_layernorm_mlp_zero_centered_gamma
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
zero_centered_gamma
=
True
)
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
[
1
:])
def
test_export_layernorm_mlp_normalization
(
seed_default_rng
,
normalization
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
normalization
=
normalization
)
@
pytest
.
mark
.
parametrize
(
"activation"
,
supported_activations
[
1
:])
def
test_export_layernorm_mlp_activation
(
seed_default_rng
,
activation
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
activation
=
activation
)
...
...
@@ -728,6 +764,8 @@ def test_export_core_attention(
use_mask
:
bool
,
attn_mask_type
:
str
,
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
# Set dimensions (these are arbitrary).
seq_len
,
batch_size
,
num_attention_heads
,
kv_channels
=
(
64
,
4
,
1
,
64
)
qkv_size
=
(
seq_len
,
batch_size
,
num_attention_heads
,
kv_channels
)
...
...
@@ -929,22 +967,32 @@ def _test_export_multihead_attention(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_multihead_attention_recipe
(
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_multihead_attention_no_mask
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
use_mask
=
False
)
def
test_export_multihead_attention_no_input_layernorm
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
input_layernorm
=
False
)
def
test_export_multihead_attention_cross_attn
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
attention_type
=
"cross"
)
def
test_export_multihead_attention_unfused_qkv_params
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
fuse_qkv_params
=
False
)
...
...
@@ -1020,27 +1068,39 @@ def _test_export_transformer_layer(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_transformer_layer_recipe
(
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_transformer_layer_no_mask
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
use_mask
=
False
)
def
test_export_transformer_layer_output_layernorm
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
output_layernorm
=
True
)
def
test_export_transformer_layer_unfused_qkv_params
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
fuse_qkv_params
=
False
)
def
test_export_transformer_layer_zero_centered_gamma
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
zero_centered_gamma
=
True
)
@
pytest
.
mark
.
parametrize
(
"activation"
,
supported_activations
[
1
:])
def
test_export_transformer_layer_activation
(
activation
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
activation
=
activation
)
...
...
@@ -1053,7 +1113,8 @@ def test_export_gpt_generation(
"""Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths.
"""
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
# Layer configuration
hidden_size
=
64
sequence_length
=
128
...
...
@@ -1144,6 +1205,8 @@ def test_export_ctx_manager(enabled):
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
def
test_trt_integration
(
fp8_recipe
:
recipe
.
Recipe
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"TRT is not supported for HIP"
)
model
=
te
.
TransformerLayer
(
hidden_size
=
128
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment