Commit d5cd815f authored by zhaochao's avatar zhaochao
Browse files

[DCU] Fix the bug in test_onnx_export.py under L0


Signed-off-by: default avatarzhaochao <zhaochao1@sugon.com>
parent ef65dd33
...@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op ...@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.common import recipe from transformer_engine.common import recipe
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table from transformer_engine.pytorch.onnx_extensions import te_translation_table
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method from transformer_engine.pytorch.utils import get_default_init_method
import tensorrt as trt import tensorrt as trt
...@@ -65,7 +67,6 @@ if mxfp8_available: ...@@ -65,7 +67,6 @@ if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling()) fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_available: if fp8_available:
fp8_recipes.append(recipe.DelayedScaling()) fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(None) fp8_recipes.append(None)
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
...@@ -82,11 +83,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"] ...@@ -82,11 +83,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
], ],
outputs=[PyCustomOpDef.dt_uint8], outputs=[PyCustomOpDef.dt_uint8],
) )
def trt_fp8_quantize(t, scale_inv): def trt_fp8_quantize(t, scale):
"""FP8 quantization extension for ONNX Runtime.""" """FP8 quantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda() x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer( q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale_inv).cuda(), scale=1 / torch.from_numpy(scale).cuda(),
amax=torch.zeros([1]).cuda(), amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3, fp8_dtype=tex.DType.kFloat8E4M3,
) )
...@@ -102,11 +103,11 @@ def trt_fp8_quantize(t, scale_inv): ...@@ -102,11 +103,11 @@ def trt_fp8_quantize(t, scale_inv):
], ],
outputs=[PyCustomOpDef.dt_float], outputs=[PyCustomOpDef.dt_float],
) )
def trt_fp8_dequantize(t, scale_inv): def trt_fp8_dequantize(t, scale):
"""FP8 dequantization extension for ONNX Runtime.""" """FP8 dequantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda() x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer( q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale_inv).cuda(), scale=1 / torch.from_numpy(scale).cuda(),
amax=torch.zeros([1]).cuda(), amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3, fp8_dtype=tex.DType.kFloat8E4M3,
) )
...@@ -469,16 +470,22 @@ def _test_export_linear( ...@@ -469,16 +470,22 @@ def _test_export_linear(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision): def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_linear(fp8_recipe=fp8_recipe, precision=precision) _test_export_linear(fp8_recipe=fp8_recipe, precision=precision)
@pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_export_linear_use_bias(seed_default_rng, use_bias): def test_export_linear_use_bias(seed_default_rng, use_bias):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_linear(use_bias=use_bias) _test_export_linear(use_bias=use_bias)
@pytest.mark.parametrize("return_bias", [True, False]) @pytest.mark.parametrize("return_bias", [True, False])
def test_export_linear_return_bias(seed_default_rng, return_bias): def test_export_linear_return_bias(seed_default_rng, return_bias):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_linear(return_bias=return_bias) _test_export_linear(return_bias=return_bias)
...@@ -540,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng): ...@@ -540,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng):
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_normalization(seed_default_rng, normalization): def test_export_layernorm_normalization(seed_default_rng, normalization):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm(normalization=normalization) _test_export_layernorm(normalization=normalization)
...@@ -594,9 +603,7 @@ def _test_export_layernorm_linear( ...@@ -594,9 +603,7 @@ def _test_export_layernorm_linear(
fname, fname,
inp, inp,
model, model,
# For current scaling we use Float8Quantizer in tests + amax computed by hand, atol=1e-3,
# which has slightly different numerics than Float8CurrentScalingQuantizer.
atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
is_fp8=fp8_recipe is not None, is_fp8=fp8_recipe is not None,
te_outputs=te_outputs, te_outputs=te_outputs,
) )
...@@ -605,27 +612,39 @@ def _test_export_layernorm_linear( ...@@ -605,27 +612,39 @@ def _test_export_layernorm_linear(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision): def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision) _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)
def test_export_layernorm_linear_return_ln_out(seed_default_rng): def test_export_layernorm_linear_return_ln_out(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(return_layernorm_output=True) _test_export_layernorm_linear(return_layernorm_output=True)
def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng): def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(zero_centered_gamma=True) _test_export_layernorm_linear(zero_centered_gamma=True)
@pytest.mark.parametrize("normalization", all_normalizations[1:]) @pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_linear_normalization(seed_default_rng, normalization): def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(normalization=normalization) _test_export_layernorm_linear(normalization=normalization)
def test_export_layernorm_linear_no_bias(seed_default_rng): def test_export_layernorm_linear_no_bias(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(use_bias=False) _test_export_layernorm_linear(use_bias=False)
def test_export_layernorm_linear_return_bias(seed_default_rng): def test_export_layernorm_linear_return_bias(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(return_bias=True) _test_export_layernorm_linear(return_bias=True)
...@@ -684,32 +703,46 @@ def _test_export_layernorm_mlp( ...@@ -684,32 +703,46 @@ def _test_export_layernorm_mlp(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision): def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision) _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)
def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng): def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(return_layernorm_output=True) _test_export_layernorm_mlp(return_layernorm_output=True)
def test_export_layernorm_mlp_return_bias(seed_default_rng): def test_export_layernorm_mlp_return_bias(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(return_bias=True) _test_export_layernorm_mlp(return_bias=True)
def test_export_layernorm_mlp_no_bias(seed_default_rng): def test_export_layernorm_mlp_no_bias(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(use_bias=False) _test_export_layernorm_mlp(use_bias=False)
def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng): def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(zero_centered_gamma=True) _test_export_layernorm_mlp(zero_centered_gamma=True)
@pytest.mark.parametrize("normalization", all_normalizations[1:]) @pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_mlp_normalization(seed_default_rng, normalization): def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(normalization=normalization) _test_export_layernorm_mlp(normalization=normalization)
@pytest.mark.parametrize("activation", supported_activations[1:]) @pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_layernorm_mlp_activation(seed_default_rng, activation): def test_export_layernorm_mlp_activation(seed_default_rng, activation):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(activation=activation) _test_export_layernorm_mlp(activation=activation)
...@@ -731,6 +764,8 @@ def test_export_core_attention( ...@@ -731,6 +764,8 @@ def test_export_core_attention(
use_mask: bool, use_mask: bool,
attn_mask_type: str, attn_mask_type: str,
): ):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
# Set dimensions (these are arbitrary). # Set dimensions (these are arbitrary).
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels) qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
...@@ -932,22 +967,32 @@ def _test_export_multihead_attention( ...@@ -932,22 +967,32 @@ def _test_export_multihead_attention(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_multihead_attention_recipe(fp8_recipe, precision): def test_export_multihead_attention_recipe(fp8_recipe, precision):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision) _test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)
def test_export_multihead_attention_no_mask(): def test_export_multihead_attention_no_mask():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_multihead_attention(use_mask=False) _test_export_multihead_attention(use_mask=False)
def test_export_multihead_attention_no_input_layernorm(): def test_export_multihead_attention_no_input_layernorm():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_multihead_attention(input_layernorm=False) _test_export_multihead_attention(input_layernorm=False)
def test_export_multihead_attention_cross_attn(): def test_export_multihead_attention_cross_attn():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_multihead_attention(attention_type="cross") _test_export_multihead_attention(attention_type="cross")
def test_export_multihead_attention_unfused_qkv_params(): def test_export_multihead_attention_unfused_qkv_params():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_multihead_attention(fuse_qkv_params=False) _test_export_multihead_attention(fuse_qkv_params=False)
...@@ -1023,27 +1068,39 @@ def _test_export_transformer_layer( ...@@ -1023,27 +1068,39 @@ def _test_export_transformer_layer(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_transformer_layer_recipe(fp8_recipe, precision): def test_export_transformer_layer_recipe(fp8_recipe, precision):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision) _test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)
def test_export_transformer_layer_no_mask(): def test_export_transformer_layer_no_mask():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(use_mask=False) _test_export_transformer_layer(use_mask=False)
def test_export_transformer_layer_output_layernorm(): def test_export_transformer_layer_output_layernorm():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(output_layernorm=True) _test_export_transformer_layer(output_layernorm=True)
def test_export_transformer_layer_unfused_qkv_params(): def test_export_transformer_layer_unfused_qkv_params():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(fuse_qkv_params=False) _test_export_transformer_layer(fuse_qkv_params=False)
def test_export_transformer_layer_zero_centered_gamma(): def test_export_transformer_layer_zero_centered_gamma():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(zero_centered_gamma=True) _test_export_transformer_layer(zero_centered_gamma=True)
@pytest.mark.parametrize("activation", supported_activations[1:]) @pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_transformer_layer_activation(activation): def test_export_transformer_layer_activation(activation):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(activation=activation) _test_export_transformer_layer(activation=activation)
...@@ -1056,7 +1113,8 @@ def test_export_gpt_generation( ...@@ -1056,7 +1113,8 @@ def test_export_gpt_generation(
"""Test that the ONNX model can correctly handle inputs with different shapes and that """Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths. the attention mask is adjusted on-the-fly to different sequence lengths.
""" """
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
# Layer configuration # Layer configuration
hidden_size = 64 hidden_size = 64
sequence_length = 128 sequence_length = 128
...@@ -1147,17 +1205,14 @@ def test_export_ctx_manager(enabled): ...@@ -1147,17 +1205,14 @@ def test_export_ctx_manager(enabled):
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_trt_integration(fp8_recipe: recipe.Recipe): def test_trt_integration(fp8_recipe: recipe.Recipe):
if IS_HIP_EXTENSION:
pytest.skip("TRT is not supported for HIP")
model = te.TransformerLayer( model = te.TransformerLayer(
hidden_size=128, hidden_size=128,
ffn_hidden_size=128, ffn_hidden_size=128,
num_attention_heads=4, num_attention_heads=4,
).eval() ).eval()
if type(fp8_recipe) == recipe.Float8CurrentScaling:
# TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
model = te.LayerNormMLP(128, 128)
inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),) inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment