Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
d5cd815f
Commit
d5cd815f
authored
Nov 03, 2025
by
zhaochao
Browse files
[DCU] Fix the bug in test_onnx_export.py under L0
Signed-off-by:
zhaochao
<
zhaochao1@sugon.com
>
parent
ef65dd33
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
70 additions
and
15 deletions
+70
-15
tests/pytorch/test_onnx_export.py
tests/pytorch/test_onnx_export.py
+70
-15
No files found.
tests/pytorch/test_onnx_export.py
View file @
d5cd815f
...
...
@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import
transformer_engine.pytorch
as
te
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.export
import
is_in_onnx_export_mode
,
te_translation_table
from
transformer_engine.pytorch.onnx_extensions
import
te_translation_table
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
transformer_engine.pytorch.export
import
is_in_onnx_export_mode
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
from
transformer_engine.pytorch.utils
import
get_default_init_method
import
tensorrt
as
trt
...
...
@@ -65,7 +67,6 @@ if mxfp8_available:
fp8_recipes
.
append
(
recipe
.
MXFP8BlockScaling
())
if
fp8_available
:
fp8_recipes
.
append
(
recipe
.
DelayedScaling
())
fp8_recipes
.
append
(
recipe
.
Float8CurrentScaling
())
fp8_recipes
.
append
(
None
)
supported_activations
=
[
"gelu"
,
"relu"
,
"reglu"
,
"geglu"
,
"swiglu"
]
...
...
@@ -82,11 +83,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
],
outputs
=
[
PyCustomOpDef
.
dt_uint8
],
)
def
trt_fp8_quantize
(
t
,
scale
_inv
):
def
trt_fp8_quantize
(
t
,
scale
):
"""FP8 quantization extension for ONNX Runtime."""
x
=
torch
.
from_numpy
(
t
).
cuda
()
q
=
te
.
tensor
.
float8_tensor
.
Float8Quantizer
(
scale
=
1
/
torch
.
from_numpy
(
scale
_inv
).
cuda
(),
scale
=
1
/
torch
.
from_numpy
(
scale
).
cuda
(),
amax
=
torch
.
zeros
([
1
]).
cuda
(),
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
,
)
...
...
@@ -102,11 +103,11 @@ def trt_fp8_quantize(t, scale_inv):
],
outputs
=
[
PyCustomOpDef
.
dt_float
],
)
def
trt_fp8_dequantize
(
t
,
scale
_inv
):
def
trt_fp8_dequantize
(
t
,
scale
):
"""FP8 dequantization extension for ONNX Runtime."""
x
=
torch
.
from_numpy
(
t
).
cuda
()
q
=
te
.
tensor
.
float8_tensor
.
Float8Quantizer
(
scale
=
1
/
torch
.
from_numpy
(
scale
_inv
).
cuda
(),
scale
=
1
/
torch
.
from_numpy
(
scale
).
cuda
(),
amax
=
torch
.
zeros
([
1
]).
cuda
(),
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
,
)
...
...
@@ -469,16 +470,22 @@ def _test_export_linear(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_linear_recipe
(
seed_default_rng
,
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_linear
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_export_linear_use_bias
(
seed_default_rng
,
use_bias
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_linear
(
use_bias
=
use_bias
)
@
pytest
.
mark
.
parametrize
(
"return_bias"
,
[
True
,
False
])
def
test_export_linear_return_bias
(
seed_default_rng
,
return_bias
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_linear
(
return_bias
=
return_bias
)
...
...
@@ -540,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng):
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
)
def
test_export_layernorm_normalization
(
seed_default_rng
,
normalization
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm
(
normalization
=
normalization
)
...
...
@@ -594,9 +603,7 @@ def _test_export_layernorm_linear(
fname
,
inp
,
model
,
# For current scaling we use Float8Quantizer in tests + amax computed by hand,
# which has slightly different numerics than Float8CurrentScalingQuantizer.
atol
=
1e-3
if
fp8_recipe
.
__class__
is
not
recipe
.
Float8CurrentScaling
else
2e-2
,
atol
=
1e-3
,
is_fp8
=
fp8_recipe
is
not
None
,
te_outputs
=
te_outputs
,
)
...
...
@@ -605,27 +612,39 @@ def _test_export_layernorm_linear(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_layernorm_linear_recipe
(
seed_default_rng
,
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_layernorm_linear_return_ln_out
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
return_layernorm_output
=
True
)
def
test_export_layernorm_linear_zero_centered_gamma
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
zero_centered_gamma
=
True
)
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
[
1
:])
def
test_export_layernorm_linear_normalization
(
seed_default_rng
,
normalization
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
normalization
=
normalization
)
def
test_export_layernorm_linear_no_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
use_bias
=
False
)
def
test_export_layernorm_linear_return_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_linear
(
return_bias
=
True
)
...
...
@@ -684,32 +703,46 @@ def _test_export_layernorm_mlp(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_layernorm_mlp
(
seed_default_rng
,
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_layernorm_mlp_return_layernorm_output
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
return_layernorm_output
=
True
)
def
test_export_layernorm_mlp_return_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
return_bias
=
True
)
def
test_export_layernorm_mlp_no_bias
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
use_bias
=
False
)
def
test_export_layernorm_mlp_zero_centered_gamma
(
seed_default_rng
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
zero_centered_gamma
=
True
)
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
[
1
:])
def
test_export_layernorm_mlp_normalization
(
seed_default_rng
,
normalization
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
normalization
=
normalization
)
@
pytest
.
mark
.
parametrize
(
"activation"
,
supported_activations
[
1
:])
def
test_export_layernorm_mlp_activation
(
seed_default_rng
,
activation
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_layernorm_mlp
(
activation
=
activation
)
...
...
@@ -731,6 +764,8 @@ def test_export_core_attention(
use_mask
:
bool
,
attn_mask_type
:
str
,
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
# Set dimensions (these are arbitrary).
seq_len
,
batch_size
,
num_attention_heads
,
kv_channels
=
(
64
,
4
,
1
,
64
)
qkv_size
=
(
seq_len
,
batch_size
,
num_attention_heads
,
kv_channels
)
...
...
@@ -932,22 +967,32 @@ def _test_export_multihead_attention(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_multihead_attention_recipe
(
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_multihead_attention_no_mask
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
use_mask
=
False
)
def
test_export_multihead_attention_no_input_layernorm
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
input_layernorm
=
False
)
def
test_export_multihead_attention_cross_attn
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
attention_type
=
"cross"
)
def
test_export_multihead_attention_unfused_qkv_params
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_multihead_attention
(
fuse_qkv_params
=
False
)
...
...
@@ -1023,27 +1068,39 @@ def _test_export_transformer_layer(
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
@
pytest
.
mark
.
parametrize
(
"precision"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_export_transformer_layer_recipe
(
fp8_recipe
,
precision
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
fp8_recipe
=
fp8_recipe
,
precision
=
precision
)
def
test_export_transformer_layer_no_mask
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
use_mask
=
False
)
def
test_export_transformer_layer_output_layernorm
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
output_layernorm
=
True
)
def
test_export_transformer_layer_unfused_qkv_params
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
fuse_qkv_params
=
False
)
def
test_export_transformer_layer_zero_centered_gamma
():
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
zero_centered_gamma
=
True
)
@
pytest
.
mark
.
parametrize
(
"activation"
,
supported_activations
[
1
:])
def
test_export_transformer_layer_activation
(
activation
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
_test_export_transformer_layer
(
activation
=
activation
)
...
...
@@ -1056,7 +1113,8 @@ def test_export_gpt_generation(
"""Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths.
"""
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ONNX is not currently required in hip"
)
# Layer configuration
hidden_size
=
64
sequence_length
=
128
...
...
@@ -1147,17 +1205,14 @@ def test_export_ctx_manager(enabled):
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
fp8_recipes
)
def
test_trt_integration
(
fp8_recipe
:
recipe
.
Recipe
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"TRT is not supported for HIP"
)
model
=
te
.
TransformerLayer
(
hidden_size
=
128
,
ffn_hidden_size
=
128
,
num_attention_heads
=
4
,
).
eval
()
if
type
(
fp8_recipe
)
==
recipe
.
Float8CurrentScaling
:
# TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
model
=
te
.
LayerNormMLP
(
128
,
128
)
inps
=
(
torch
.
randn
([
16
,
16
,
128
],
device
=
"cuda"
,
requires_grad
=
False
),)
with
te
.
fp8_autocast
(
enabled
=
fp8_recipe
is
not
None
,
fp8_recipe
=
fp8_recipe
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment