Unverified commit 51f19fdc authored by Paweł Gadziński, committed by GitHub

[PyTorch] Add test for TRT integration + fix for mxfp8 export (#2083)



* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5b4d89c3
@@ -23,8 +23,6 @@ set -x
 mkdir -p "$XML_LOG_DIR"
 pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
-pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
-pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
@@ -40,7 +38,6 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
-python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
@@ -0,0 +1,7 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+pip3 install onnxruntime==1.20.1
+pip3 install onnxruntime_extensions==0.13.0
+: ${TE_PATH:=/opt/transformerengine}
+python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
@@ -36,6 +36,7 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import get_default_init_method
+import tensorrt as trt
 
 # Global test configuration knobs.
@@ -113,7 +114,7 @@ def trt_fp8_dequantize(t, scale):
 @onnx_op(
-    op_type="trt::TRT_MXFP8QuantizeLinear",
+    op_type="trt::TRT_MXFP8DynamicQuantize",
     domain="trt",
     inputs=[
         PyCustomOpDef.dt_float,
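
For context: MXFP8 packs values into blocks of 32 that share a single power-of-two (E8M0) scale, with elements encoded as FP8 E4M3. Below is a minimal sketch of the block-wise dynamic quantization this ORT custom op is presumed to emulate for testing; the helper name and rounding details are illustrative assumptions, not TE's actual kernel.

```python
import torch

def mxfp8_dynamic_quantize_ref(x: torch.Tensor, block: int = 32):
    """Illustrative MXFP8 block quantization (not TE's kernel).
    Each block of `block` values shares one power-of-two (E8M0) scale;
    elements are stored as FP8 E4M3. Assumes x.numel() % block == 0."""
    xb = x.float().reshape(-1, block)
    amax = xb.abs().amax(dim=-1, keepdim=True)
    # Shared exponent per the OCP MX convention: floor(log2(amax)) minus
    # E4M3's max exponent (8), clamped to E8M0's representable range.
    exp = (torch.floor(torch.log2(amax.clamp(min=2.0**-127))) - 8.0).clamp(-127.0, 127.0)
    scale = torch.pow(2.0, exp)
    # Saturate to E4M3's max normal (448) before casting.
    q = (xb / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    scale_e8m0 = (exp + 127.0).to(torch.uint8)  # biased E8M0 exponent byte
    return q.reshape(x.shape), scale_e8m0
```

Dequantization then just multiplies each block by `2.0 ** (scale_e8m0.float() - 127)`, which is the inverse the paired dequantize op applies.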
@@ -1139,3 +1140,59 @@ def test_export_ctx_manager(enabled):
     with te.onnx_export(enabled):
         assert is_in_onnx_export_mode() == enabled
     assert is_in_onnx_export_mode() == False
+
+
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+def test_trt_integration(fp8_recipe: recipe.Recipe):
+    model = te.TransformerLayer(
+        hidden_size=128,
+        ffn_hidden_size=128,
+        num_attention_heads=4,
+    ).eval()
+    inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
+
+    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+        out_ref = model(*inps)
+
+    onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
+    os.close(onnx_fd)
+    try:
+        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+            with te.onnx_export(enabled=True):
+                torch.onnx.export(
+                    model,
+                    inps,
+                    onnx_path,
+                    output_names=["output"],
+                    dynamo=True,
+                    custom_translation_table=te_translation_table,
+                )
+
+        os.system(f"trtexec --onnx={onnx_path} --saveEngine={onnx_path}.engine")
+
+        # Run TRT engine
+        logger = trt.Logger(trt.Logger.WARNING)
+        runtime = trt.Runtime(logger)
+        with open(onnx_path + ".engine", "rb") as f:
+            engine_data = f.read()
+        engine = runtime.deserialize_cuda_engine(engine_data)
+        context = engine.create_execution_context()
+        context.set_tensor_address(engine.get_tensor_name(0), inps[0].data_ptr())
+        stream = torch.cuda.Stream()
+        out = torch.zeros_like(out_ref)
+        context.set_tensor_address("output", out.data_ptr())
+        context.execute_async_v3(stream_handle=stream.cuda_stream)
+        stream.synchronize()
+
+        # Compare TRT and TE outputs
+        atol = 5e-2 if fp8_recipe is not None else 1e-4
+        rtol = 5e-2 if fp8_recipe is not None else 1e-4
+        torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
+    finally:
+        try:
+            os.remove(onnx_path)
+        except FileNotFoundError:
+            pass
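
A possible alternative to shelling out to trtexec is building the engine in-process with the TensorRT Python builder. A rough sketch, assuming a TensorRT 10.x-style API (the helper name is illustrative):

```python
import tensorrt as trt

def build_serialized_engine(onnx_path: str):
    """Parse an ONNX file and build a serialized TRT engine in-process."""
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(0)  # explicit batch; the default in TRT 10
    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            errors = "; ".join(str(parser.get_error(i)) for i in range(parser.num_errors))
            raise RuntimeError(f"ONNX parse failed: {errors}")
    config = builder.create_builder_config()
    # Returns an IHostMemory blob, accepted by runtime.deserialize_cuda_engine().
    return builder.build_serialized_network(network, config)
```

The subprocess route mirrors how models are typically converted in practice, at the cost of keeping build diagnostics out of the pytest output.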
@@ -194,12 +194,12 @@ def onnx_quantize_mxfp8_symbolic(
     tensor: onnxscript.onnx_types.TensorType,
 ) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
     """Symbolic quantize to MXFP8Tensor used for inference."""
-    tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor)
+    tensor_out, scale_inv_out = TRT_MXFP8DynamicQuantize(tensor)
     return tensor_out, scale_inv_out
 
 
 schema = defs.OpSchema(
-    name="TRT_MXFP8QuantizeLinear",
+    name="TRT_MXFP8DynamicQuantize",
     domain="trt",
     since_version=1,
     doc="TRT MXFP8 Quantize Linear used for inference.",
@@ -214,8 +214,8 @@ schema = defs.OpSchema(
     ],
 )
 
-TRT_MXFP8QuantizeLinear = onnxscript.values.Op(
-    opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema
+TRT_MXFP8DynamicQuantize = onnxscript.values.Op(
+    opset=trt_opset, name="TRT_MXFP8DynamicQuantize", op_schema=schema
 )
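
To sanity-check the rename end to end, one could inspect an exported graph for the new node type using the standard onnx package; a hedged snippet (the model path is a placeholder):

```python
import onnx

model = onnx.load("exported_model.onnx")  # placeholder path
trt_nodes = [n for n in model.graph.node if n.domain == "trt"]
assert any(n.op_type == "TRT_MXFP8DynamicQuantize" for n in trt_nodes), (
    "expected a trt::TRT_MXFP8DynamicQuantize node after export with an MXFP8 recipe"
)
```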