Unverified commit 0a1499fa authored by Paweł Gadziński, committed by GitHub

[Pytorch] Dynamo ONNX export support (#1497)



* some initial code
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* onnx support
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* mxfp8 support
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixed returning layernorm etc
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* formatting
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* license fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests passing
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* lint
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* added pip install to test.sh
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/export.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* float8currentscaling quantizer exception
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added to wheels
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* onnx versions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* installations in tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: root <root@prenyx0221.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: root <pgadzinski@nvidia.com>

* fixes
Signed-off-by: root <pgadzinski@nvidia.com>

* fixes
Signed-off-by: root <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update setup.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* onnxscript version change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@gmail.com>

* Update build.yml
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update pytorch.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: root <root@prenyx0221.a51.clusters.nvidia.com>
Signed-off-by: root <pgadzinski@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: root <root@prenyx0221.a51.clusters.nvidia.com>
Co-authored-by: Pawel Gadzinski <pgadzinski@gmail.com>
parent c0c12e20
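
For orientation, the export path added by this PR works as follows: run a warm-up forward pass under the desired quantization recipe (so FP8 metadata exists), enable TE's ONNX export mode, then call torch.onnx.export with the Dynamo exporter and TE's custom translation table. A minimal sketch, assuming a plain te.Linear and an illustrative output file name (based on the export.py docstring and the do_export test helper in this diff):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.export import te_translation_table

model = te.Linear(64, 64).cuda().eval()
inp = torch.randn(4, 64, device="cuda")

with torch.inference_mode(), te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
    model(inp)  # warm-up pass: initializes FP8 state before export
    with te.onnx_export(True):
        torch.onnx.export(
            model,
            (inp,),
            "te_linear_fp8.onnx",  # illustrative file name
            dynamo=True,  # Dynamo-based exporter (requires torch >= 2.4)
            custom_translation_table=te_translation_table,
        )

FP8 graphs produced this way contain custom trt:: quantize/dequantize nodes, which is why the tests below register matching ONNX Runtime custom ops.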
......@@ -43,7 +43,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -83,7 +83,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install torch pybind11[global] einops
run: pip install torch pybind11[global] einops onnxscript
- name: 'Checkout'
uses: actions/checkout@v3
with:
......
......@@ -13,12 +13,19 @@ from typing import List
def install_requirements() -> List[str]:
"""Install dependencies for TE/JAX extensions."""
reqs = ["torch>=2.1", "einops"]
"""Install dependencies for TE/PyTorch extensions."""
reqs = ["torch>=2.1", "einops", "onnxscript"]
reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
reqs.extend(
[
"torch>=2.1",
"onnx",
"onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871",
]
)
return reqs
......
......@@ -23,6 +23,8 @@ set -x
mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
......@@ -38,6 +40,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gem
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This file contains tests for exporting TransformerEngine models to ONNX.
The purpose of these tests is to validate that TE models are converted to their correct ONNX
representation. Toward this end, each test captures the output of a TE module forward pass,
converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and
validate the output against TE's output.
Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented
using custom ORT operations.
To run many repetitive tests use pytest-loop:
$ python3 -m pip install pytest-loop
$ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm
For reproducibility use: torch.manual_seed(0)
"""
import os
import tempfile
import pytest
import warnings
import numpy as np
import onnxruntime as ort
import torch
import random
from torch import nn as nn
from typing import Optional, Union, Tuple, List
from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
# Global test configuration knobs.
# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance).
SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0")))
if SAVE_TEST_IO:
from polygraphy.json import save_json
from polygraphy.comparator import RunResults
# The directory where generated ONNX test models are stored.
NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR")
NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
tempfile.gettempdir(), "./gen_onnx_models"
)
# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
fp8_recipes = [
None,
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
]
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
@onnx_op(
op_type="trt::TRT_FP8QuantizeLinear",
domain="trt",
inputs=[
PyCustomOpDef.dt_float,
PyCustomOpDef.dt_float,
],
outputs=[PyCustomOpDef.dt_uint8],
)
def trt_fp8_quantize(t, scale):
"""FP8 quantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale).cuda(),
amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
return q(x)._data.cpu().numpy()
@onnx_op(
op_type="trt::TRT_FP8DequantizeLinear",
domain="trt",
inputs=[
PyCustomOpDef.dt_uint8,
PyCustomOpDef.dt_float,
],
outputs=[PyCustomOpDef.dt_float],
)
def trt_fp8_dequantize(t, scale):
"""FP8 dequantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale).cuda(),
amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
quantizer_tensor = q.create_tensor_from_data(x, fake_dtype=torch.float32)
return quantizer_tensor.dequantize().cpu().numpy()
@onnx_op(
op_type="trt::TRT_MXFP8QuantizeLinear",
domain="trt",
inputs=[
PyCustomOpDef.dt_float,
],
outputs=[PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8],
)
def trt_mxfp8_quantize(t):
"""MXFP8 quantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3)
return q(x)._rowwise_data.cpu().numpy(), q(x)._rowwise_scale_inv.cpu().numpy()
@onnx_op(
op_type="trt::TRT_MXFP8DequantizeLinear",
domain="trt",
inputs=[
PyCustomOpDef.dt_uint8,
PyCustomOpDef.dt_uint8,
],
outputs=[PyCustomOpDef.dt_float],
)
def trt_mxfp8_dequantize(t, scale_inv):
"""MXFP8 dequantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
scale_inv_tensor = torch.from_numpy(scale_inv).cuda()
q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3)
quantizer_tensor = q.create_tensor_from_data(x, scale_inv_tensor, fake_dtype=torch.float32)
return quantizer_tensor.dequantize().cpu().numpy()
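# The four @onnx_op registrations above give ONNX Runtime Python implementations of the
# custom trt:: FP8/MXFP8 quantize/dequantize nodes that TE emits during export. A minimal
# sketch of how a session picks them up (the same pattern as create_ort_session further
# down; this helper is an illustrative example and is not used by the tests):
def _example_session_with_custom_ops(model_path="te_linear_fp8.onnx"):
    sess_options = ort.SessionOptions()
    # Make the trt::TRT_FP8*/TRT_MXFP8* custom ops visible to ORT.
    sess_options.register_custom_ops_library(get_library_path())
    return ort.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )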
@pytest.fixture()
def seed_default_rng():
"""Reseed the PRNG for test reproducibility"""
torch.manual_seed(1234)
@pytest.fixture()
def set_max_seq_len(max_seq_len=128):
"""Set the maximum sequence length that can be used for attention masking"""
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}"
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
def do_export(
model: torch.nn.Module,
inp: torch.Tensor,
fname: str,
fp8_recipe: recipe.Recipe,
input_names: List[str] = None,
output_names: List[str] = None,
dynamic_shapes: List[str] = None,
):
"""Export to ONNX"""
input_names = input_names or ["input"]
output_names = output_names or ["output"]
with torch.inference_mode(), te.fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
), warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")
model.cuda().eval()
os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,)
assert len(inps) == len(input_names)
inds_to_del = [i for i in range(len(inps)) if inps[i] is None]
input_names = [input_names[i] for i in range(len(inps)) if i not in inds_to_del]
model(*inps) # warm-up run
with te.export.onnx_export(True):
model(*inps)
with te.export.onnx_export(True):
torch.onnx.export(
model,
inps,
fname,
dynamo=True,
custom_translation_table=te_translation_table,
verbose=True,
dynamic_shapes=dynamic_shapes,
input_names=input_names,
output_names=output_names,
optimize=inps[0].dtype
!= torch.bfloat16, # optimizer does not work with bfloat16 yet - will need to change that after onnxscript supports bfloat16
)
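# A typical call to do_export (mirrors test_export_linear below; the sizes, file name,
# and dynamic batch dimension are illustrative):
#
#     model = te.Linear(64, 64).cuda().eval()
#     inp = torch.randn(4, 64, 64, device="cuda")
#     bs = torch.export.Dim("bs", min=2, max=1256)
#     do_export(model, inp, "te.linear_fp32.onnx", fp8_recipe=None,
#               dynamic_shapes={"inp": {0: bs}})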
def to_numpy(tensor):
if isinstance(tensor, torch.Tensor):
if tensor.dtype == torch.bfloat16:
tensor = tensor.type(torch.float32)
tensor = tensor.detach().cpu().numpy()
return tensor
def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
"""Initialize the FP8 quantization scales in module"""
module.init_fp8_metadata(num_gemms)
for quantizer in module.quantizers["scaling_fwd"]:
quantizer.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
def te_infer(
model: torch.nn.Module,
inps: Union[Tuple[torch.Tensor], torch.Tensor],
is_fp8: bool,
fp8_recipe: recipe.Recipe,
):
"""Transformer Engine forward propagation."""
with torch.inference_mode(), te.fp8_autocast(
enabled=is_fp8, fp8_recipe=fp8_recipe
), warnings.catch_warnings():
te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
if not isinstance(te_outputs, tuple):
te_outputs = (te_outputs,)
return te_outputs
def compare_outputs(
onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
):
"""Compare ORT and TE outputs."""
assert len(onnx_outputs) == len(te_outputs)
# Compare ORT and PyTorch outputs.
for onnx_output, te_output in zip(onnx_outputs, te_outputs):
# np.isclose: abs(a - b) <= (atol + rtol * abs(b))
te_output = to_numpy(te_output)
onnx_output = to_numpy(onnx_output)
ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
mismatches = ac.nonzero()
mismatched_ids = [loc for loc in zip(*mismatches)]
if mismatched_ids:
# Log some information in case of error.
print("*" * 100)
nb_errors = len(mismatched_ids)
nb_vals = min(nb_errors, max_errors_printed)
print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})")
print(f"Showing first {nb_vals} errors (ONNX -- TE):")
abs_err = np.abs(onnx_output - te_output)
errors = abs_err[mismatches]
for loc in mismatched_ids[:nb_vals]:
ref = te_output[loc]
print(
f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >"
f" {atol + rtol * abs(ref)}"
)
print(f"Max error: {np.max(errors)}")
if nb_errors > allow_cnt_errors:
raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
def serialize_inputs_outputs(
fname: str,
inputs: Union[Tuple[torch.Tensor], torch.Tensor],
te_outputs: List[torch.Tensor],
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
):
if not SAVE_TEST_IO:
return
fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
input_names = input_names or ["input"]
output_names = output_names or ["output"]
inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
named_inputs = zip(input_names, inputs)
input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}]
json_fname = fname[: -len(".onnx")] + "_inputs.json"
save_json(input_data, json_fname, description="custom input data")
json_fname = fname[: -len(".onnx")] + "_output.json"
named_outputs = zip(output_names, te_outputs)
output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None}
custom_outputs = RunResults()
custom_outputs.add([output_data], runner_name="custom_runner")
custom_outputs.save(json_fname)
def validate_result(
fname: str,
inps: Union[Tuple[torch.Tensor], torch.Tensor],
model: torch.nn.Module,
atol: float = 1.0e-8, # np.isclose default atol
rtol: float = 1.0e-5, # np.isclose default rtol
max_errors_printed: int = 10,
is_fp8: bool = False,
allow_cnt_errors: int = 0,
input_names: List[str] = None,
output_names: List[str] = None,
te_outputs: List[torch.Tensor] = None,
):
"""Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
representation using ONNX Runtime (ORT) and ensure they are close.
The purpose of the output comparison is to validate that TE models are converted to
their correct ONNX representation by testing that TE and ORT outputs match within some
small threshold (allowing for finite precision errors).
Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring
a very small number (0-3) of outliers. This is acceptable because these outliers are due to
small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX
representation (the tests assume both ORT and TE kernels are correct).
Argument `te_outputs` can be used to provide pre-computed TE outputs.
"""
def create_ort_session(fname: str, is_fp8: bool):
def load_custom_ops(session_opts: ort.SessionOptions):
"""For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension."""
session_opts.register_custom_ops_library(get_library_path())
print("registered custom FP8 Q/DQ ops!")
"""Create an ONNX Runtime session for validation."""
kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]}
if is_fp8:
sess_options = ort.SessionOptions()
load_custom_ops(sess_options)
kwargs["sess_options"] = sess_options
s = ort.InferenceSession(fname, **kwargs)
return s
def create_ort_input_dict(session, inputs):
inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
input_names = [x.name for x in session.get_inputs()]
inps = [to_numpy(x) for x in inputs if x is not None]
inp_dict = dict(zip(input_names, inps))
return inp_dict
input_names = input_names or ["input"]
output_names = output_names or ["output"]
# Run ORT session and TE model.
fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
if not te_outputs:
te_outputs = te_infer(model, inps, is_fp8)
ort_s = create_ort_session(fname, is_fp8)
input_feed = create_ort_input_dict(ort_s, inps)
onnx_outputs = ort_s.run(None, input_feed=input_feed)
compare_outputs(
onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
)
def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
if fake_bf16_io:
assert dtype == torch.bfloat16
return "_fake_bf16"
return {
torch.float32: "_fp32",
torch.float16: "_fp16",
torch.bfloat16: "_bf16",
}[dtype]
def as_te_type(dtype: torch.dtype):
return {
torch.float32: tex.DType.kFloat32,
torch.float16: tex.DType.kFloat16,
torch.bfloat16: tex.DType.kBFloat16,
}[dtype]
def get_attn_mask_str(use_mask, attn_mask_type):
# See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
if attn_mask_type is None:
return "_mask" if use_mask else "_no-mask"
attn_mask_str = "_arbitrary-no-mask"
attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
attn_mask_str = (
"_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
)
return attn_mask_str
"""
Test cases begin here.
"""
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, False),
(torch.float16, True),
# Todo: cannot configure BF16 when bias is disabled (ORT issue?)
(torch.bfloat16, False),
# Todo: cannot configure BF16 when bias is enabled (ORT issue?)
(torch.bfloat16, True),
],
)
def test_export_linear(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
precision: torch.dtype,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
# Set dimensions (these are arbitrary).
batch_size = 4
in_features = 64
out_features = 64
hidden_size = 64
class Test_Linear(nn.Module):
def __init__(self, in_features, out_features, use_bias, return_bias, precision):
super().__init__()
self.linear = te.Linear(
in_features,
out_features,
bias=use_bias,
return_bias=return_bias,
params_dtype=precision,
)
def forward(self, inp):
ret = self.linear(inp)
return ret
inp = torch.randn(batch_size, hidden_size, in_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if fp8_recipe is not None else ""
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(
device="cuda"
)
# dynamic shape
bs = torch.export.Dim("bs", min=2, max=1256)
do_export(
model,
inp,
fname,
fp8_recipe,
dynamic_shapes={"inp": {0: bs}},
)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16,):
return
if fp8_recipe is None:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
else:
validate_result(
fname, inp, model, atol=1e-2, is_fp8=fp8_recipe is not None, te_outputs=te_outputs
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize(
"precision",
[
torch.float32,
torch.float16,
torch.bfloat16,
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
precision: torch.dtype,
zero_centered_gamma: bool,
normalization: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Set dimensions (these are arbitrary).
batch_size = 4
in_features = 64
out_features = 256
hidden_size = 256
inp = torch.ones(batch_size, in_features, out_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if fp8_recipe is not None else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx"
with torch.no_grad():
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm
model = layernorm_cls(
hidden_size,
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
).to(device="cuda")
# dynamic shape
bs = torch.export.Dim("bs", min=2, max=1256)
do_export(model, inp, fname, fp8_recipe, dynamic_shapes={"input": {0: bs}})
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16,):
return
if fp8_recipe is None:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
elif precision != torch.bfloat16:
validate_result(
fname,
inp,
model,
atol=1e-3,
is_fp8=fp8_recipe is not None,
te_outputs=te_outputs,
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_linear(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype,
zero_centered_gamma: bool,
normalization: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
# Set dimensions (these are arbitrary).
in_features = 64
out_features = 256
hidden_size = 256
inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if fp8_recipe is not None else ""
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with torch.no_grad():
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
model = te.LayerNormLinear(
hidden_size,
3 * hidden_size,
bias=use_bias,
return_bias=return_bias,
return_layernorm_output=return_layernorm_output,
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
).to(device="cuda")
if fp8_recipe is not None:
set_layer_scale(model, scale_factor, num_gemms=2)
do_export(model, inp, fname, fp8_recipe)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16,):
return
if fp8_recipe is None:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
elif precision != torch.bfloat16:
validate_result(
fname,
inp,
model,
atol=1e-3,
is_fp8=fp8_recipe is not None,
te_outputs=te_outputs,
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_mlp(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype,
zero_centered_gamma: bool,
activation: str,
normalization: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
# Set dimensions (these are arbitrary).
in_features = 64
out_features = 256
hidden_size = 256
ffn_hidden_size = 256
inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if fp8_recipe is not None else ""
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx"
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
model = te.LayerNormMLP(
hidden_size,
ffn_hidden_size,
bias=use_bias,
return_bias=return_bias,
return_layernorm_output=return_layernorm_output,
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
activation=activation,
normalization=normalization,
).to(device="cuda")
if fp8_recipe is not None:
set_layer_scale(model, scale_factor, num_gemms=2)
do_export(model, inp, fname, fp8_recipe)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16,):
return
atol = (
2e-2 if fp8_recipe is not None else (5e-1 if activation == "swiglu" else 1e-3)
) # TODO(pgadzinski) - check 2e-2
validate_result(
fname, inp, model, atol=atol, is_fp8=fp8_recipe is not None, te_outputs=te_outputs
)
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type",
[
(torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
],
)
def test_export_core_attention(
seed_default_rng,
set_max_seq_len,
precision: torch.dtype,
use_mask: bool,
attn_mask_type: str,
):
# Set dimensions (these are arbitrary).
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
qkv_format = "sbhd"
query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
input_names = ["query", "key", "value", "attention_mask"]
attention_mask = None
if use_mask:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (query_layer, key_layer, value_layer, attention_mask)
mask_str = get_attn_mask_str(use_mask, attn_mask_type)
high_prec_str = dtype2str(precision)
fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
).to(device="cuda")
do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16,):
return
validate_result(
fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
)
test_configs_multihead_attention = [
# "use_mask, attn_mask_type"
(False, "no_mask"), # calls ScaledSoftmax
(True, "arbitrary"), # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [
# "input_layernorm, attention_type, fuse_qkv_params"
(True, "self", True),
(False, "self", True),
(True, "self", False),
(False, "self", False),
(True, "cross", True),
(False, "cross", True),
(True, "cross", False),
(False, "cross", False),
]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
"input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type
)
def test_export_multihead_attention(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
use_mask: bool,
attn_mask_type: str,
precision: torch.dtype,
return_layernorm_output: bool,
input_layernorm: bool,
attention_type: str,
fuse_qkv_params: bool,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
hidden_size = 256
sequence_length = 128
batch_size = 4
num_attention_heads = 32
kv_channels = 8
attention_dropout = 0.1
layernorm_epsilon = 1e-5
init_method = output_layer_init_method = get_default_init_method()
attention_args = (
hidden_size,
num_attention_heads,
kv_channels,
attention_dropout,
layernorm_epsilon,
init_method,
output_layer_init_method,
)
hidden_states_context = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(
batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
encoder_output = None
if attention_type == "cross":
encoder_output = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
fp8_str = "_fp8" if fp8_recipe is not None else ""
dtype_str = dtype2str(precision)
attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention"
fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else ""
attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
input_ln_str = "_input-ln" if input_layernorm else ""
fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"
model = te.MultiheadAttention(
*attention_args,
attn_mask_type=attn_mask_type,
params_dtype=precision,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attention_type=attention_type,
fuse_qkv_params=fuse_qkv_params,
return_bias=True,
).to(device="cuda")
inp_context = (hidden_states_context, attention_mask, encoder_output)
input_names = ["hidden_states", "attention_mask", "encoder_output"]
output_names = ["attention_output", "attention_bias"]
seq = torch.export.Dim("seq", min=2, max=1256)
bs = torch.export.Dim("bs", min=2, max=1256)
do_export(
model,
inp_context,
fname,
fp8_recipe,
input_names=input_names,
output_names=output_names,
dynamic_shapes={
"hidden_states": {0: seq, 1: bs},
"attention_mask": {2: seq, 0: bs} if use_mask else None,
"encoder_output": {0: seq, 1: bs} if attention_type == "cross" else None,
},
)
te_outputs = te_infer(model, inp_context, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(
fname, inp_context, te_outputs, input_names=input_names, output_names=output_names
)
if precision in (torch.bfloat16,):
return
if fp8_recipe is None:
validate_result(
fname,
inp_context,
model,
atol=1e-3,
input_names=input_names,
output_names=output_names,
te_outputs=te_outputs,
)
else:
validate_result(
fname,
inp_context,
model,
atol=1e-2,
is_fp8=fp8_recipe is not None,
input_names=input_names,
output_names=output_names,
allow_cnt_errors=3,
te_outputs=te_outputs,
)
# In GPT generative phase (inference) the input sequence is smaller than the maximum
# allowed sequence length and we want to test this condition.
# Pretend that we're in generative phase when it makes sense (causal mask and self-attention).
is_generative_phase = attn_mask_type == "causal" and attention_type == "self"
if is_generative_phase:
seq_len_offset = 8
hidden_states_generative = torch.randn(
sequence_length - seq_len_offset,
batch_size,
hidden_size,
dtype=precision,
device="cuda",
)
inp_generative = (hidden_states_generative, attention_mask, encoder_output)
if fp8_recipe is None:
validate_result(
fname,
inp_generative,
model,
atol=1e-3,
input_names=input_names,
output_names=output_names,
)
else:
validate_result(
fname,
inp_generative,
model,
atol=1e-2,
is_fp8=fp8_recipe is not None,
input_names=input_names,
output_names=output_names,
allow_cnt_errors=3,
)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
def test_export_transformer_layer(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
use_mask: bool,
attn_mask_type: str,
output_layernorm: bool,
precision: torch.dtype,
fuse_qkv_params: bool,
zero_centered_gamma: bool,
activation: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Layer configuration
hidden_size = 64
sequence_length = 128
batch_size = 1
ffn_hidden_size = 256
num_attention_heads = 4
input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
input_names = ["input", "attention_mask"]
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(
batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (input_tensor, attention_mask)
fp8_str = "_fp8" if fp8_recipe is not None else ""
fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
high_prec_str = dtype2str(precision)
attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}_{activation}.onnx"
model = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_attention_heads,
self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma,
activation=activation,
).to(device="cuda")
do_export(model, inp, fname, fp8_recipe, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(
fname,
inp,
te_outputs,
input_names=input_names,
)
if precision in (torch.bfloat16,):
return
atol = 5e-1 if fp8_recipe is not None else (5e-1 if activation == "swiglu" else 5e-3)
validate_result(
fname,
inp,
model,
atol=atol,
is_fp8=fp8_recipe is not None,
input_names=input_names,
te_outputs=te_outputs,
)
@skip_FP8
@skip_MXFP8
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [True])
def test_export_gpt_generation(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
precision: torch.dtype,
zero_centered_gamma: bool,
):
"""Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths.
"""
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Layer configuration
hidden_size = 64
sequence_length = 128
batch_size = 4
ffn_hidden_size = 256
num_attention_heads = 4
attention_mask = None
use_mask = True
attn_mask_type = "causal"
fuse_qkv_params = True
output_layernorm = False
fp8_str = "_fp8" if fp8_recipe is not None else ""
fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
high_prec_str = dtype2str(precision)
attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx"
model = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_attention_heads,
self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma,
).to(device="cuda")
# "Context phase": use full input sequence length
input_names = ["input"]
output_names = ["output"]
input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
inp = (input_tensor,)
# dynamic shape
seq = torch.export.Dim("seq", min=2, max=1256)
bs = torch.export.Dim("bs", min=2, max=1256)
do_export(
model,
inp,
fname,
fp8_recipe,
dynamic_shapes={"hidden_states": {0: seq, 1: bs}},
)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(
fname, inp, te_outputs, input_names=input_names, output_names=output_names
)
if precision not in (torch.bfloat16,):
validate_result(
fname,
inp,
model,
atol=1e-2,
is_fp8=fp8_recipe is not None,
input_names=input_names,
te_outputs=te_outputs,
)
# "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8 and for MXFP8 we need to pad to mult of 32.
sequence_length = 1 if fp8_recipe is None else 32
input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
inp = (input_tensor, attention_mask)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision not in (torch.bfloat16,):
validate_result(
fname,
inp,
model,
atol=1e-2,
is_fp8=fp8_recipe is not None,
input_names=input_names,
te_outputs=te_outputs,
)
@pytest.mark.parametrize("enabled", [True, False])
def test_export_ctx_manager(enabled):
assert is_in_onnx_export_mode() == False
with te.onnx_export(enabled):
assert is_in_onnx_export_mode() == enabled
assert is_in_onnx_export_mode() == False
......@@ -53,6 +53,7 @@ from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
try:
......
......@@ -56,6 +56,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log,
)
from transformer_engine.pytorch import export
from transformer_engine.pytorch.export import is_in_onnx_export_mode
# Global vars for flash attn v2 and v3 imports
flash_attn_cuda_bwd = None
......@@ -148,7 +150,14 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)
def mask_func(x, y):
return (
export.onnx_attention_mask_func(x, y)
if is_in_onnx_export_mode()
else attention_mask_func(x, y)
)
self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
......
......@@ -17,6 +17,7 @@ from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
AttnTypes,
......@@ -963,6 +964,13 @@ class DotProductAttention(TransformerEngineBaseModule):
inference_params=inference_params,
)
global _attention_backends
if is_in_onnx_export_mode():
# We do not want to call get_attention_backend() in ONNX mode
# and we want to avoid using any global variables like _attention_backends.
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = True
else:
if (
_attention_backends["attention_params"] is None
or attention_params != _attention_backends["attention_params"]
......
......@@ -8,6 +8,7 @@ from typing import Callable, Tuple, Union, Optional
import torch
from torch import nn
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode
THREADS_PER_WARP = 32
......@@ -19,12 +20,18 @@ _default_causal_mask = {}
def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input"""
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask:
def _get_mask():
diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
_default_causal_mask[matrix_identifiers] = torch.triu(
return torch.triu(
torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset
)
if is_in_onnx_export_mode():
return _get_mask()
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask:
_default_causal_mask[matrix_identifiers] = _get_mask()
return _default_causal_mask[matrix_identifiers]
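
For intuition, a small worked example of the masks this helper builds (True marks positions that get masked out); purely illustrative and run on CPU:

import torch

sq, sk = 3, 5
causal = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)
bottom_right = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=sk - sq + 1)
print(causal.int())
# tensor([[0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)
print(bottom_right.int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)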
......@@ -169,7 +176,11 @@ class FusedScaleMaskSoftmax(nn.Module):
self.attn_mask_type = attn_mask_type
assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled"
if is_in_onnx_export_mode():
return self.forward_torch_softmax(inp, mask, scale)
# We do not want to connect this if with previous if,
# because we want to avoid calling is_kernel_available() in ONNX mode.
if self.is_kernel_available(mask, *inp.size()):
return self.forward_fused_softmax(inp, mask, scale)
return self.forward_torch_softmax(inp, mask, scale)
......@@ -245,15 +256,15 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.attn_mask_type in ["causal", "causal_bottom_right"]:
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)
if mask is None:
mask = causal_mask
else:
mask = torch.logical_or(mask, causal_mask)
mask_output = inp
if mask is not None and self.attn_mask_type != "no_mask":
mask_output = self.mask_func(inp, mask)
probs = torch.nn.Softmax(dim=-1)(mask_output)
probs = torch.nn.functional.softmax(mask_output, dim=-1)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
......
......@@ -44,6 +44,7 @@ from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser
......@@ -1140,9 +1141,7 @@ def get_full_mask(
swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q + window_size[1]
).view(batch_size, 1, 1, 1)
swa_mask = torch.logical_not(
torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
)
swa_mask = torch.logical_not((swa_left <= 0) & ~(swa_right < 0))
if attention_mask is not None:
attention_mask = torch.logical_or(swa_mask, attention_mask)
else:
......@@ -1333,14 +1332,22 @@ def get_full_cu_seqlens(
"""
global _cu_seqlens_cache
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
def _get_cu_seqlens(batch_size, max_seqlen, device):
return torch.arange(
0,
(batch_size + 1) * max_seqlen,
step=max_seqlen,
dtype=torch.int32,
device=device,
)
if is_in_onnx_export_mode():
return _get_cu_seqlens(batch_size, max_seqlen, device)
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = _get_cu_seqlens(
batch_size, max_seqlen, device
)
return _cu_seqlens_cache[(batch_size, max_seqlen)]
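
For reference, a tiny worked example of the cumulative-sequence-length tensor built by _get_cu_seqlens above (illustration only, on CPU):

import torch

batch_size, max_seqlen = 3, 4
cu_seqlens = torch.arange(0, (batch_size + 1) * max_seqlen, step=max_seqlen, dtype=torch.int32)
print(cu_seqlens)  # tensor([ 0,  4,  8, 12], dtype=torch.int32)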
......@@ -1616,6 +1623,11 @@ def get_qkv_layout(
def run_iteratively(q, k, v):
# check data pointers
if is_in_onnx_export_mode():
check_ptrs_qkv = False
check_ptrs_qk = False
check_ptrs_kv = False
else:
data_ptr = q.untyped_storage().data_ptr()
check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
......@@ -1708,7 +1720,10 @@ def get_qkv_layout(
return qkv_layout
if not is_in_onnx_export_mode():
qkv_layout = run_iteratively(q, k, v)
else:
qkv_layout = "not_supported"
if qkv_layout == "not_supported":
# force q,k,v to be contiguous and run get_layout again
q, k, v = [x.contiguous() for x in [q, k, v]]
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Export utilities for TransformerEngine"""
from contextlib import contextmanager
from typing import Generator
import torch
_IN_ONNX_EXPORT_MODE = False
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
@contextmanager
def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
"""
Context manager for exporting to ONNX.
.. code-block:: python
from transformer_engine.pytorch.export import onnx_export, te_translation_table
with onnx_export(enabled=True):
torch.onnx.export(model, dynamo=True, custom_translation_table=te_translation_table)
Parameters
----------
enabled: bool, default = `False`
whether or not to enable export
"""
global _IN_ONNX_EXPORT_MODE
onnx_export_state = _IN_ONNX_EXPORT_MODE
if (TORCH_MAJOR, TORCH_MINOR) < (2, 4):
raise RuntimeError("ONNX export is not supported for PyTorch versions less than 2.4")
try:
_IN_ONNX_EXPORT_MODE = enabled
yield
finally:
_IN_ONNX_EXPORT_MODE = onnx_export_state
def is_in_onnx_export_mode() -> bool:
"""Returns True if onnx export mode is enabled, False otherwise."""
return _IN_ONNX_EXPORT_MODE
def assert_warmed_up(module: torch.nn.Module) -> None:
"""Assert that the model has been warmed up before exporting to ONNX."""
assert hasattr(module, "forwarded_at_least_once"), (
"Model must be warmed up before exporting to ONNX, please run model with the"
" same recipe before exporting."
)
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4 or TORCH_MAJOR > 2:
# pylint: disable=unused-import
from .onnx_extensions import (
torch_onnx_gemm_inf_op,
onnx_quantize_fp8_op,
onnx_dequantize_fp8_op,
onnx_quantize_mxfp8_op,
onnx_dequantize_mxfp8_op,
onnx_layernorm,
onnx_attention_mask_func,
onnx_gemm,
te_translation_table,
)
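
The module changes in the hunks that follow all use the same gating pattern: when export mode is active, the fused CUDA path is swapped for a simplified path built only from ops that have ONNX translations. A schematic sketch of that pattern (the class and the _fused_forward name are hypothetical stand-ins; the real forward/onnx_forward pairs appear in the LayerNormLinear and LayerNormMLP hunks below):

import torch
from transformer_engine.pytorch.export import is_in_onnx_export_mode

class SomeTEModule(torch.nn.Module):  # stand-in for LayerNormLinear, LayerNormMLP, ...
    def forward(self, inp):
        if is_in_onnx_export_mode():
            # Simplified inference path using only ONNX-translatable ops.
            return self.onnx_forward(inp)
        # Regular fused CUDA path.
        return self._fused_forward(inp)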
......@@ -6,10 +6,10 @@
import os
from functools import wraps
from typing import Callable, Optional, Tuple
import torch
from . import torch_version
from .export import is_in_onnx_export_mode
from .utils import gpu_autocast_ctx
# pylint: disable=unnecessary-lambda-assignment
......@@ -46,7 +46,17 @@ if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"
# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
import torch._dynamo
if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: (
f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
def set_jit_fusion_options() -> None:
......
......@@ -13,6 +13,7 @@ import torch
from .. import cpp_extensions as tex
from ..constants import TE_DType
from ..utils import get_default_init_method
from ..export import is_in_onnx_export_mode
def _get_normalization_func(normalization: str, forward: bool):
......@@ -164,6 +165,8 @@ def noop_cat(
raise ValueError("Attempted to concatenate 0 tensors")
if len(tensors) == 1:
return tensors[0]
if is_in_onnx_export_mode():
return torch.cat(tensors, dim=dim)
return _NoopCatFunc.apply(dim, *tensors)
......
......@@ -989,6 +989,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
self.forwarded_at_least_once = True
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
......
......@@ -68,6 +68,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
......@@ -1463,6 +1464,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
......@@ -1486,12 +1489,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
......@@ -1621,6 +1619,72 @@ class LayerNormLinear(TransformerEngineBaseModule):
for name, q in zip(names, original_quantizers)
)
def _get_weight_and_bias_tensors(self):
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
return weight_tensor, bias_tensor
def onnx_forward(
self,
inp: torch.Tensor,
fp8_output: bool,
) -> torch.Tensor:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios.
"""
from ..export import onnx_layernorm, onnx_gemm
assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
assert_warmed_up(self)
(
input_quantizer,
weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(fp8_output, fp8_grad=False)
inp_dtype = inp.dtype
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
ln_out, ln_out_return = onnx_layernorm(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self.eps,
self.normalization,
self.zero_centered_gamma,
inp_dtype,
self.return_layernorm_output,
input_quantizer,
)
if weight_quantizer is not None:
weight_tensor_quantized = weight_quantizer.onnx_quantize(weight_tensor)
weight_tensor = weight_quantizer.onnx_dequantize(weight_tensor_quantized)
weight_tensor = weight_tensor.to(inp_dtype)
if bias_tensor is not None:
bias_tensor = bias_tensor.to(inp_dtype)
output = onnx_gemm(weight_tensor, ln_out, bias_tensor if self.apply_bias else None)
if output_quantizer is not None:
raise NotImplementedError("ONNX export of quantized output is not supported")
if self.return_layernorm_output and self.return_bias:
return output, bias_tensor.to(inp_dtype), ln_out_return
if self.return_layernorm_output:
return output, ln_out_return
if self.return_bias:
return output, bias_tensor.to(inp_dtype)
return output
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_linear."""
assert (
......
......@@ -77,6 +77,7 @@ from ..tensor.quantized_tensor import (
from ..cpp_extensions import (
general_gemm,
)
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ...debug.pytorch.utils import any_feature_enabled
from ...debug.pytorch.debug_state import TEDebugState
......@@ -1721,6 +1722,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
if is_in_onnx_export_mode():
return self.onnx_forward(inp)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
......@@ -1910,6 +1913,89 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_grad_output_quantizer,
)
def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios.
"""
from ..export import onnx_layernorm, onnx_gemm
assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
assert_warmed_up(self)
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(False)
inp_dtype = inp.dtype
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
# layernorm + fp8 cast
ln_out, ln_out_return = onnx_layernorm(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self.eps,
self.normalization,
self.zero_centered_gamma,
inp_dtype,
self.return_layernorm_output,
fc1_input_quantizer,
)
if fc1_weight_quantizer is not None:
fc1_weight_q = fc1_weight_quantizer.onnx_quantize(fc1_weight)
fc1_weight = fc1_weight_quantizer.onnx_dequantize(fc1_weight_q)
fc1_weight = fc1_weight.to(inp_dtype)
fc1_out = onnx_gemm(fc1_weight, ln_out, fc1_bias)
fc1_out = fc1_out.to(torch.float32) # activation is computed in fp32
activation_map = {
"gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"relu": torch.nn.functional.relu,
"geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
* x.chunk(2, -1)[1],
"qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"srelu": torch.nn.functional.softplus,
}
if self.activation not in activation_map:
raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
act_out = activation_map[self.activation](fc1_out)
if fc2_weight_quantizer is not None:
fc2_weight_q = fc2_weight_quantizer.onnx_quantize(fc2_weight)
fc2_weight = fc2_weight_quantizer.onnx_dequantize(fc2_weight_q)
fc2_weight = fc2_weight.to(inp_dtype)
if fc2_input_quantizer is not None:
act_out_q = fc2_input_quantizer.onnx_quantize(act_out)
act_out = fc2_input_quantizer.onnx_dequantize(act_out_q)
act_out = act_out.to(inp_dtype)
fc2_out = onnx_gemm(fc2_weight, act_out, fc2_bias)
if output_quantizer is not None:
raise NotImplementedError("ONNX export of quantized output is not supported")
if self.return_layernorm_output:
if self.return_bias:
return fc2_out, fc2_bias.to(inp_dtype), ln_out_return
return fc2_out, ln_out_return
if self.return_bias:
return fc2_out, fc2_bias.to(inp_dtype)
return fc2_out
def _get_debug_quantizers(self, fp8_output):
from ...debug.pytorch.debug_quantization import DebugQuantizer
......
......@@ -67,6 +67,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
......@@ -1278,6 +1279,9 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
......@@ -1301,13 +1305,7 @@ class Linear(TransformerEngineBaseModule):
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = None
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
......@@ -1420,6 +1418,95 @@ class Linear(TransformerEngineBaseModule):
for name, q in zip(names, original_quantizers)
)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
return unfused_weights
def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = None
return weight_tensor, bias_tensor
def onnx_forward(
self,
inp: torch.Tensor,
fp8_output: bool,
) -> torch.Tensor:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios.
"""
from ..export import onnx_gemm
assert_warmed_up(self)
assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export."
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
(
input_quantizer,
weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(fp8_output, False)
inp_dtype = inp.dtype
if input_quantizer is not None:
inp_q = input_quantizer.onnx_quantize(inp)
inp = input_quantizer.onnx_dequantize(inp_q)
inp = inp.to(inp_dtype)
if weight_quantizer is not None:
weight_q = weight_quantizer.onnx_quantize(weight_tensor)
weight_tensor = weight_quantizer.onnx_dequantize(weight_q)
if bias_tensor is not None:
bias_tensor = bias_tensor.to(inp_dtype)
weight_tensor = weight_tensor.to(inp_dtype)
if self.apply_bias:
output = onnx_gemm(weight_tensor, inp, bias_tensor)
else:
output = onnx_gemm(weight_tensor, inp, None)
if output_quantizer is not None:
raise NotImplementedError("ONNX export of quantized output is not supported")
if self.return_bias:
return output, bias_tensor
return output
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert (
......@@ -1467,23 +1554,6 @@ class Linear(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
return unfused_weights
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
File containing torch.ops extensions and their corresponding ONNX symbolic functions.
Many Transformer Engine layers rely on custom calls into the transformer_engine_torch module, which makes ONNX export challenging because:
1. They often accept Python objects (quantizers), which ONNX does not support.
2. They are complex, incorporating fusions and precomputing values for the backward pass, none of which is needed for ONNX export.
For these reasons, we introduce an onnx_forward method in each layer that is simpler and
primarily leverages torch operators with known ONNX symbolic functions.
These methods avoid fusions and backward-pass precomputation.
The main remaining concern is quantization, which PyTorch does not natively support, so we implement the ONNX symbolic functions for it ourselves.
Since ONNX does not yet support these quantization operations, operators from TensorRT are employed.
The primary goal of ONNX export is to enable inference compatibility with TensorRT.
"""
from typing import Tuple
import math
import torch
import onnxscript
from onnxscript import opset18 as op
from onnx import defs
import transformer_engine_torch as tex
from .tensor.float8_tensor import Float8Quantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .constants import MXFP8_BLOCK_SCALING_SIZE
from .utils import round_up_to_nearest_multiple
from .export import is_in_onnx_export_mode
trt_opset = onnxscript.values.Opset(
"trt", version=1
) # opset from TensorRT which supports FP8 quantization
# ONNX GEMM for inference
def onnx_gemm(weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""ONNX GEMM used for inference."""
reshaped_inp = inp.reshape(-1, inp.shape[-1])
out = torch_onnx_gemm_inf_op(weight, reshaped_inp, bias)
return out.reshape(inp.shape[:-1] + (-1,))
@torch.library.custom_op("tex::gemm_inf", mutates_args=[])
def torch_onnx_gemm_inf_op(
weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
"""Gemm used for inference -- weight is transposed"""
out = inp @ weight.T
if bias is not None:
out = out + bias
return out
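# register_fake supplies a meta ("fake") implementation so torch.export/Dynamo
# can infer output shapes and dtypes without running the real computation.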
@torch_onnx_gemm_inf_op.register_fake
def _(weight, inp, bias):
"""Fake gemm used for inference."""
out = inp @ weight.T
if bias is not None:
out = out + bias
return out
def onnx_gemm_inf_symbolic(
weight: onnxscript.onnx_types.TensorType,
inp: onnxscript.onnx_types.TensorType,
bias: onnxscript.onnx_types.TensorType,
) -> onnxscript.onnx_types.TensorType:
"""Symbolic gemm used for inference."""
return op.Gemm(inp, weight, bias, transA=0, transB=1)
# ONNX FP8 Quantization
@torch.library.custom_op("tex::fp8_quantize", mutates_args=[])
def onnx_quantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
"""Quantize to Float8Tensor used for inference."""
scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
amax_tensor = torch.tensor([1], dtype=torch.float32, device=tensor.device)
quantizer = Float8Quantizer(scale_tensor, amax_tensor, tex.DType.kFloat8E4M3)
return quantizer.quantize(tensor)._data
@onnx_quantize_fp8_op.register_fake
def _(tensor, *_):
"""Fake quantize to Float8Tensor used for inference."""
return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device)
def onnx_quantize_fp8_symbolic(
tensor: onnxscript.onnx_types.TensorType,
scale: float,
) -> onnxscript.onnx_types.UINT8:
"""Symbolic quantize used for inference."""
scale_inv = op.Constant(value_float=1 / scale)
return TRT_FP8QuantizeLinear(tensor, scale_inv)
# Define the schema for the custom operator
schema = defs.OpSchema(
name="TRT_FP8QuantizeLinear",
domain="trt",
since_version=1,
doc="TRT FP8 Quantize Linear used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"),
],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
)
TRT_FP8QuantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_FP8QuantizeLinear", op_schema=schema
)
# ONNX FP8 Dequantization
@torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
"""Dequantize from Float8Tensor used for inference."""
scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
quantizer = Float8Quantizer(
scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
)
quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
return quantizer_tensor.dequantize()
@onnx_dequantize_fp8_op.register_fake
def _(tensor: torch.Tensor, _) -> torch.Tensor:
"""Fake dequantize from Float8Tensor used for inference."""
return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)
def onnx_dequantize_fp8_symbolic(
tensor: onnxscript.onnx_types.TensorType, scale: float
) -> onnxscript.onnx_types.TensorType:
"""Symbolic dequantize from Float8Tensor used for inference."""
scale_inv = op.Constant(value_float=1 / scale)
return TRT_FP8DequantizeLinear(tensor, scale_inv)
schema = defs.OpSchema(
name="TRT_FP8DequantizeLinear",
domain="trt",
since_version=1,
doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"),
],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)
TRT_FP8DequantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
)
# ONNX MXFP8 Quantization
@torch.library.custom_op("tex::mxfp8_quantize", mutates_args=[])
def onnx_quantize_mxfp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Quantize to MXFP8Tensor used for inference."""
quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3)
quantized_tensor = quantizer(tensor)
return quantized_tensor._rowwise_data, quantized_tensor._rowwise_scale_inv
@onnx_quantize_mxfp8_op.register_fake
def _(tensor: torch.Tensor):
"""Fake quantize to MXFP8Tensor used for inference."""
mxfp8_scale_shape = [
round_up_to_nearest_multiple(math.prod(tensor.shape[:-1]), 128),
round_up_to_nearest_multiple(tensor.shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
]
return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.empty(
mxfp8_scale_shape, dtype=torch.uint8, device=tensor.device
)
def onnx_quantize_mxfp8_symbolic(
tensor: onnxscript.onnx_types.TensorType,
) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
"""Symbolic quantize to MXFP8Tensor used for inference."""
tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor)
return tensor_out, scale_inv_out
schema = defs.OpSchema(
name="TRT_MXFP8QuantizeLinear",
domain="trt",
since_version=1,
doc="TRT MXFP8 Quantize Linear used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
],
outputs=[
defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor"),
defs.OpSchema.FormalParameter(
"scale_inv", "tensor(uint8)", "Scale factor for quantization"
),
],
)
TRT_MXFP8QuantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema
)
# ONNX MXFP8 Dequantization
@torch.library.custom_op("tex::mxfp8_dequantize", mutates_args=[])
def onnx_dequantize_mxfp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
"""Dequantize from MXFP8Tensor used for inference."""
quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3)
quantizer_tensor = quantizer.create_tensor_from_data(
tensor, scale_inv, fake_dtype=torch.float32
)
return quantizer_tensor.dequantize()
@onnx_dequantize_mxfp8_op.register_fake
def _(tensor: torch.Tensor, _):
"""Fake dequantize from MXFP8Tensor used for inference."""
return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)
def onnx_dequantize_mxfp8_symbolic(
tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
) -> onnxscript.onnx_types.TensorType:
"""Symbolic dequantize from MXFP8Tensor used for inference."""
return TRT_MXFP8DequantizeLinear(tensor, scale_inv)
schema = defs.OpSchema(
name="TRT_MXFP8DequantizeLinear",
domain="trt",
since_version=1,
doc="TRT MXFP8 Dequantize Linear from MXFP8Tensor used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
defs.OpSchema.FormalParameter(
"scale_inv", "tensor(uint8)", "Scale factor for dequantization"
),
],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)
TRT_MXFP8DequantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_MXFP8DequantizeLinear", op_schema=schema
)
# ONNX LayerNorm
@torch.library.custom_op("tex::layernorm", mutates_args=[])
def onnx_layernorm_op(
inp: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor:
"""ONNX LayerNorm used for inference."""
model = tex.LayerNorm(inp.shape[1], eps=eps)
model.weight.data = weight
model.bias.data = bias
return model(inp)
@onnx_layernorm_op.register_fake
def _(inp, *_):
"""Fake ONNX LayerNorm used for inference."""
return inp
def onnx_layernorm_symbolic(
inp: onnxscript.onnx_types.TensorType,
weight: onnxscript.onnx_types.TensorType,
bias: onnxscript.onnx_types.TensorType,
eps: float,
) -> onnxscript.onnx_types.TensorType:
"""Symbolic ONNX LayerNorm used for inference."""
return op.LayerNormalization(inp, weight, bias, epsilon=eps)
# ONNX layernorm helper function - handles layernorm with optional input quantization
def onnx_layernorm(
inp: torch.Tensor,
layer_norm_weight: torch.Tensor,
layer_norm_bias: torch.Tensor,
eps: float,
normalization: str,
zero_centered_gamma: bool,
output_dtype: torch.dtype,
return_layernorm_output: bool,
input_quantizer,
) -> torch.Tensor:
"""ONNX LayerNorm used for inference."""
ln_weight = layer_norm_weight if not zero_centered_gamma else layer_norm_weight + 1
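# Cast gamma to the input dtype and back to fp32 so its precision matches the
# fused kernel, while the normalization itself is computed in fp32.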
ln_weight = ln_weight.to(inp.dtype).to(torch.float32)
inp = inp.to(torch.float32)
layer_norm_bias = (
layer_norm_bias.to(output_dtype).to(torch.float32) if layer_norm_bias is not None else None
)
if normalization == "RMSNorm":
ln_out = torch.nn.functional.rms_norm(inp, inp.shape[-1:], ln_weight, eps)
else:
ln_out = torch.nn.functional.layer_norm(
inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps
)
ln_out_return = ln_out
if input_quantizer is not None:
if return_layernorm_output:
# In case of return_layernorm_output, layernorm is not fused with fp8 cast,
# so we cast to input_dtype and then perform cast to fp8 if needed
ln_out = ln_out.to(output_dtype).to(torch.float32)
ln_out_return = ln_out
elif isinstance(input_quantizer, MXFP8Quantizer):
# layernorm + mxfp8 quantizer behaves differently
ln_out = ln_out.to(output_dtype).to(torch.float32)
ln_out_quantized = input_quantizer.onnx_quantize(ln_out)
ln_out = input_quantizer.onnx_dequantize(ln_out_quantized)
ln_out = ln_out.to(output_dtype)
return ln_out, ln_out_return
# utility functions
def onnx_attention_mask_func(
attention_scores: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""Get attention mask without inp"""
assert is_in_onnx_export_mode()
return attention_scores.masked_fill(attention_mask, -10000.0)
# This translation table should be passed to torch.onnx.export function
# using the custom_translation_table=te_translation_table option.
te_translation_table = {
torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
}
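# Minimal export sketch (illustrative; the module, sizes, and file name below
# are assumptions, not part of this file). A regular forward pass must run
# before export so assert_warmed_up() passes, and torch.onnx.export must be
# called inside ONNX export mode (the context manager from export.py) with a
# PyTorch version that supports the dynamo=True path:
#
#     import torch
#     import transformer_engine.pytorch as te
#
#     model = te.Linear(1024, 1024).cuda().eval()
#     inp = torch.randn(32, 1024, device="cuda")
#     with torch.no_grad():
#         model(inp)  # warm-up forward outside export mode
#         with te.onnx_export(True):
#             torch.onnx.export(
#                 model,
#                 (inp,),
#                 "te_linear.onnx",
#                 dynamo=True,
#                 custom_translation_table=te_translation_table,
#             )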
......@@ -23,6 +23,7 @@ from ...utils import (
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
......@@ -179,6 +180,8 @@ class LayerNorm(BasicOperation):
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
if is_in_onnx_export_mode():
return self.op_onnx_forward(input_)
# Check tensor dims
weight = self.weight
......@@ -268,3 +271,13 @@ class LayerNorm(BasicOperation):
grad_weight = dw.view(weight_dims)
grad_bias = db.view(weight_dims)
return grad_input, (grad_weight, grad_bias)
def op_onnx_forward(
self,
input_: torch.Tensor,
) -> torch.Tensor:
"""Every operand in this function has a defined ONNX translation."""
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
return torch.nn.functional.layer_norm(
input_, input_.shape[-1:], weight, self.bias, self.eps
)
......@@ -23,6 +23,7 @@ from ...utils import (
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
......@@ -162,6 +163,8 @@ class RMSNorm(BasicOperation):
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
if is_in_onnx_export_mode():
return self.op_onnx_forward(input_)
# Check tensor dims
weight = self.weight
......@@ -246,3 +249,11 @@ class RMSNorm(BasicOperation):
grad_input = dx.view(grad_output.size())
grad_weight = dw.view(weight_dims)
return grad_input, (grad_weight,)
def op_onnx_forward(
self,
input_: torch.Tensor,
) -> torch.Tensor:
"""Every operand in this function has a defined ONNX translation."""
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps)
......@@ -167,6 +167,21 @@ class Float8Quantizer(Quantizer):
quantizer=self,
)
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Function using primitives with ONNX defined translations."""
# Q inputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the input if needed.
if tensor.dtype != torch.float32:
tensor = tensor.to(torch.float32)
data = torch.ops.tex.fp8_quantize(tensor, self.scale.item())
return self.create_tensor_from_data(data, fake_dtype=torch.float32)
def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
"""Function using primitives with ONNX defined translations."""
out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item())
out = out.to(tensor.dtype)
return out
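# Sketch of how the layer onnx_forward methods use this pair (fake-quantization
# round trip; names are illustrative):
#
#     q = quantizer.onnx_quantize(weight)              # Float8Tensor wrapper
#     weight = quantizer.onnx_dequantize(q).to(dtype)  # back to high precision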
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return DelayedScaling
......@@ -328,6 +343,18 @@ class Float8CurrentScalingQuantizer(Quantizer):
quantizer=self,
)
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Function using primitives with ONNX defined translations."""
raise NotImplementedError(
"Float8CurrentScalingQuantizer does not support ONNX quantization yet."
)
def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
"""Function using primitives with ONNX defined translations."""
raise NotImplementedError(
"Float8CurrentScalingQuantizer does not support ONNX dequantization yet."
)
def _canonicalized_amax_reduction_group(self) -> dist_group_type:
"""Get process group for amax reduction"""
return canonicalize_process_group(self.amax_reduction_group)
......