Unverified commit 0a1499fa, authored by Paweł Gadziński and committed by GitHub

[Pytorch] Dynamo ONNX export support (#1497)



* some initial code
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* onnx support
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* mxfp8 support
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixed returning layernorm etc.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* formatting
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* license fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests passing
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* lint
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* added pip install to test.sh
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/export.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* float8currentscaling quantizer exception
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added to wheels
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* onnx versions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* installations in tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: root <root@prenyx0221.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: root <pgadzinski@nvidia.com>

* fixes
Signed-off-by: root <pgadzinski@nvidia.com>

* fixes
Signed-off-by: root <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update setup.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* onnxscript version change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@gmail.com>

* Update build.yml
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update pytorch.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: root <root@prenyx0221.a51.clusters.nvidia.com>
Signed-off-by: root <pgadzinski@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: root <root@prenyx0221.a51.clusters.nvidia.com>
Co-authored-by: Pawel Gadzinski <pgadzinski@gmail.com>
parent c0c12e20
......@@ -43,7 +43,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -83,7 +83,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install torch pybind11[global] einops
run: pip install torch pybind11[global] einops onnxscript
- name: 'Checkout'
uses: actions/checkout@v3
with:
......
......@@ -13,12 +13,19 @@ from typing import List
def install_requirements() -> List[str]:
"""Install dependencies for TE/JAX extensions."""
reqs = ["torch>=2.1", "einops"]
"""Install dependencies for TE/PyTorch extensions."""
reqs = ["torch>=2.1", "einops", "onnxscript"]
reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
reqs.extend(
[
"torch>=2.1",
"onnx",
"onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871",
]
)
return reqs
......
......@@ -23,6 +23,8 @@ set -x
mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
......@@ -38,6 +40,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gem
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
......
This diff is collapsed.
......@@ -53,6 +53,7 @@ from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
try:
......
......@@ -56,6 +56,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log,
)
from transformer_engine.pytorch import export
from transformer_engine.pytorch.export import is_in_onnx_export_mode
# Global vars for flash attn v2 and v3 imports
flash_attn_cuda_bwd = None
......@@ -148,7 +150,14 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)
def mask_func(x, y):
return (
export.onnx_attention_mask_func(x, y)
if is_in_onnx_export_mode()
else attention_mask_func(x, y)
)
self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
......
......@@ -17,6 +17,7 @@ from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
AttnTypes,
......@@ -963,6 +964,13 @@ class DotProductAttention(TransformerEngineBaseModule):
inference_params=inference_params,
)
global _attention_backends
if is_in_onnx_export_mode():
# We do not want to call get_attention_backend() in ONNX mode
# and we want to avoid using any global variables like _attention_backends.
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = True
else:
if (
_attention_backends["attention_params"] is None
or attention_params != _attention_backends["attention_params"]
......
......@@ -8,6 +8,7 @@ from typing import Callable, Tuple, Union, Optional
import torch
from torch import nn
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode
THREADS_PER_WARP = 32
......@@ -19,12 +20,18 @@ _default_causal_mask = {}
def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input"""
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask:
def _get_mask():
diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
_default_causal_mask[matrix_identifiers] = torch.triu(
return torch.triu(
torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset
)
if is_in_onnx_export_mode():
return _get_mask()
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask:
_default_causal_mask[matrix_identifiers] = _get_mask()
return _default_causal_mask[matrix_identifiers]
......@@ -169,7 +176,11 @@ class FusedScaleMaskSoftmax(nn.Module):
self.attn_mask_type = attn_mask_type
assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled"
if is_in_onnx_export_mode():
return self.forward_torch_softmax(inp, mask, scale)
# Kept separate from the preceding if-statement so that
# is_kernel_available() is never called in ONNX export mode.
if self.is_kernel_available(mask, *inp.size()):
return self.forward_fused_softmax(inp, mask, scale)
return self.forward_torch_softmax(inp, mask, scale)
......@@ -245,15 +256,15 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.attn_mask_type in ["causal", "causal_bottom_right"]:
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)
if mask is None:
mask = causal_mask
else:
mask = torch.logical_or(mask, causal_mask)
mask_output = inp
if mask is not None and self.attn_mask_type != "no_mask":
mask_output = self.mask_func(inp, mask)
probs = torch.nn.Softmax(dim=-1)(mask_output)
probs = torch.nn.functional.softmax(mask_output, dim=-1)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
......
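The diagonal-offset logic in `_get_default_causal_mask` above is easy to check in isolation. A pure-Python sketch (the helper name `causal_mask` is illustrative; it mirrors what `torch.triu(torch.ones(sq, sk), diagonal=offset)` produces, with True marking masked-out positions):

```python
def causal_mask(mask_type: str, sq: int, sk: int):
    # "bottom_right" aligns the causal diagonal with the end of the key
    # sequence; the plain variant aligns it with the start.
    diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
    # Equivalent to torch.triu: position (i, j) is masked when j - i >= offset.
    return [[j - i >= diagonal_offset for j in range(sk)] for i in range(sq)]

# Square shapes: both variants coincide (the offset is 1 either way).
assert causal_mask("causal", 3, 3) == causal_mask("causal_bottom_right", 3, 3)

# sq=2, sk=4: bottom_right lets queries attend to the trailing keys
# rather than only the leading ones.
print(causal_mask("causal", 2, 4))
print(causal_mask("causal_bottom_right", 2, 4))
```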
......@@ -44,6 +44,7 @@ from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser
......@@ -1140,9 +1141,7 @@ def get_full_mask(
swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q + window_size[1]
).view(batch_size, 1, 1, 1)
swa_mask = torch.logical_not(
torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
)
swa_mask = torch.logical_not((swa_left <= 0) & ~(swa_right < 0))
if attention_mask is not None:
attention_mask = torch.logical_or(swa_mask, attention_mask)
else:
......@@ -1333,14 +1332,22 @@ def get_full_cu_seqlens(
"""
global _cu_seqlens_cache
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
def _get_cu_seqlens(batch_size, max_seqlen, device):
return torch.arange(
0,
(batch_size + 1) * max_seqlen,
step=max_seqlen,
dtype=torch.int32,
device=device,
)
if is_in_onnx_export_mode():
return _get_cu_seqlens(batch_size, max_seqlen, device)
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = _get_cu_seqlens(
batch_size, max_seqlen, device
)
return _cu_seqlens_cache[(batch_size, max_seqlen)]
......@@ -1616,6 +1623,11 @@ def get_qkv_layout(
def run_iteratively(q, k, v):
# check data pointers
if is_in_onnx_export_mode():
check_ptrs_qkv = False
check_ptrs_qk = False
check_ptrs_kv = False
else:
data_ptr = q.untyped_storage().data_ptr()
check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
......@@ -1708,7 +1720,10 @@ def get_qkv_layout(
return qkv_layout
if not is_in_onnx_export_mode():
qkv_layout = run_iteratively(q, k, v)
else:
qkv_layout = "not_supported"
if qkv_layout == "not_supported":
# force q,k,v to be contiguous and run get_layout again
q, k, v = [x.contiguous() for x in [q, k, v]]
......
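`get_full_cu_seqlens` above only materializes cumulative sequence-start offsets for a batch of equal-length sequences, and it skips the module-level cache during ONNX export so no device tensor gets baked into the graph. A minimal sketch of the values the cached `torch.arange` holds, using a hypothetical pure-Python helper:

```python
def full_cu_seqlens(batch_size: int, max_seqlen: int):
    # Offsets [0, L, 2L, ..., B*L]: entry i is where sequence i starts in
    # the flattened token stream; the last entry is the total token count.
    return list(range(0, (batch_size + 1) * max_seqlen, max_seqlen))

assert full_cu_seqlens(3, 128) == [0, 128, 256, 384]
```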
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Export utilities for TransformerEngine"""
from contextlib import contextmanager
from typing import Generator
import torch
_IN_ONNX_EXPORT_MODE = False
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
@contextmanager
def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
"""
Context manager for exporting to ONNX.
.. code-block:: python
from transformer_engine.pytorch.export import onnx_export, te_translation_table
with onnx_export(enabled=True):
torch.onnx.export(model, dynamo=True, custom_translation_table=te_translation_table)
Parameters
----------
enabled: bool, default = `False`
whether or not to enable export
"""
global _IN_ONNX_EXPORT_MODE
onnx_export_state = _IN_ONNX_EXPORT_MODE
if (TORCH_MAJOR, TORCH_MINOR) < (2, 4):
raise RuntimeError("ONNX export is not supported for PyTorch versions less than 2.4")
try:
_IN_ONNX_EXPORT_MODE = enabled
yield
finally:
_IN_ONNX_EXPORT_MODE = onnx_export_state
def is_in_onnx_export_mode() -> bool:
"""Returns True if onnx export mode is enabled, False otherwise."""
return _IN_ONNX_EXPORT_MODE
def assert_warmed_up(module: torch.nn.Module) -> None:
"""Assert that the model has been warmed up before exporting to ONNX."""
assert hasattr(module, "forwarded_at_least_once"), (
"Model must be warmed up before exporting to ONNX, please run model with the"
" same recipe before exporting."
)
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4 or TORCH_MAJOR > 2:
# pylint: disable=unused-import
from .onnx_extensions import (
torch_onnx_gemm_inf_op,
onnx_quantize_fp8_op,
onnx_dequantize_fp8_op,
onnx_quantize_mxfp8_op,
onnx_dequantize_mxfp8_op,
onnx_layernorm,
onnx_attention_mask_func,
onnx_gemm,
te_translation_table,
)
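The `onnx_export` context manager above is a plain save-and-restore around a module-level flag. A standalone sketch of the same pattern (outside Transformer Engine, so the names are illustrative) showing why the previously active state, not a hard-coded False, must be restored:

```python
from contextlib import contextmanager

_IN_EXPORT_MODE = False  # module-level flag, mirroring _IN_ONNX_EXPORT_MODE


@contextmanager
def export_mode(enabled: bool = False):
    global _IN_EXPORT_MODE
    prev = _IN_EXPORT_MODE  # save whatever was active before entering
    try:
        _IN_EXPORT_MODE = enabled
        yield
    finally:
        _IN_EXPORT_MODE = prev  # restore even if the body raised


def in_export_mode() -> bool:
    return _IN_EXPORT_MODE


with export_mode(enabled=True):
    assert in_export_mode()
    with export_mode(enabled=False):  # nesting restores the outer state
        assert not in_export_mode()
    assert in_export_mode()
assert not in_export_mode()
```

Restoring `prev` rather than `False` is what makes nested uses (and re-entrant calls from library code) behave correctly.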
......@@ -6,10 +6,10 @@
import os
from functools import wraps
from typing import Callable, Optional, Tuple
import torch
from . import torch_version
from .export import is_in_onnx_export_mode
from .utils import gpu_autocast_ctx
# pylint: disable=unnecessary-lambda-assignment
......@@ -46,7 +46,17 @@ if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"
# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
import torch._dynamo
if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: (
f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
def set_jit_fusion_options() -> None:
......
......@@ -13,6 +13,7 @@ import torch
from .. import cpp_extensions as tex
from ..constants import TE_DType
from ..utils import get_default_init_method
from ..export import is_in_onnx_export_mode
def _get_normalization_func(normalization: str, forward: bool):
......@@ -164,6 +165,8 @@ def noop_cat(
raise ValueError("Attempted to concatenate 0 tensors")
if len(tensors) == 1:
return tensors[0]
if is_in_onnx_export_mode():
return torch.cat(tensors, dim=dim)
return _NoopCatFunc.apply(dim, *tensors)
......
......@@ -989,6 +989,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
self.forwarded_at_least_once = True
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
......
......@@ -68,6 +68,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
......@@ -1463,6 +1464,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
......@@ -1486,12 +1489,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
......@@ -1621,6 +1619,72 @@ class LayerNormLinear(TransformerEngineBaseModule):
for name, q in zip(names, original_quantizers)
)
def _get_weight_and_bias_tensors(self):
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
return weight_tensor, bias_tensor
def onnx_forward(
self,
inp: torch.Tensor,
fp8_output: bool,
) -> torch.Tensor:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios.
"""
from ..export import onnx_layernorm, onnx_gemm
assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
assert_warmed_up(self)
(
input_quantizer,
weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(fp8_output, fp8_grad=False)
inp_dtype = inp.dtype
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
ln_out, ln_out_return = onnx_layernorm(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self.eps,
self.normalization,
self.zero_centered_gamma,
inp_dtype,
self.return_layernorm_output,
input_quantizer,
)
if weight_quantizer is not None:
weight_tensor_quantized = weight_quantizer.onnx_quantize(weight_tensor)
weight_tensor = weight_quantizer.onnx_dequantize(weight_tensor_quantized)
weight_tensor = weight_tensor.to(inp_dtype)
if bias_tensor is not None:
bias_tensor = bias_tensor.to(inp_dtype)
output = onnx_gemm(weight_tensor, ln_out, bias_tensor if self.apply_bias else None)
if output_quantizer is not None:
raise NotImplementedError("ONNX export of quantized output is not supported")
if self.return_layernorm_output and self.return_bias:
return output, bias_tensor.to(inp_dtype), ln_out_return
if self.return_layernorm_output:
return output, ln_out_return
if self.return_bias:
return output, bias_tensor.to(inp_dtype)
return output
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_linear."""
assert (
......
......@@ -77,6 +77,7 @@ from ..tensor.quantized_tensor import (
from ..cpp_extensions import (
general_gemm,
)
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ...debug.pytorch.utils import any_feature_enabled
from ...debug.pytorch.debug_state import TEDebugState
......@@ -1721,6 +1722,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
if is_in_onnx_export_mode():
return self.onnx_forward(inp)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
......@@ -1910,6 +1913,89 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_grad_output_quantizer,
)
def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios.
"""
from ..export import onnx_layernorm, onnx_gemm
assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
assert_warmed_up(self)
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(False)
inp_dtype = inp.dtype
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
# layernorm + fp8 cast
ln_out, ln_out_return = onnx_layernorm(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self.eps,
self.normalization,
self.zero_centered_gamma,
inp_dtype,
self.return_layernorm_output,
fc1_input_quantizer,
)
if fc1_weight_quantizer is not None:
fc1_weight_q = fc1_weight_quantizer.onnx_quantize(fc1_weight)
fc1_weight = fc1_weight_quantizer.onnx_dequantize(fc1_weight_q)
fc1_weight = fc1_weight.to(inp_dtype)
fc1_out = onnx_gemm(fc1_weight, ln_out, fc1_bias)
fc1_out = fc1_out.to(torch.float32) # activation is computed in fp32
activation_map = {
"gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"relu": torch.nn.functional.relu,
"geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
* x.chunk(2, -1)[1],
"qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"srelu": torch.nn.functional.softplus,
}
if self.activation not in activation_map:
raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
act_out = activation_map[self.activation](fc1_out)
if fc2_weight_quantizer is not None:
fc2_weight_q = fc2_weight_quantizer.onnx_quantize(fc2_weight)
fc2_weight = fc2_weight_quantizer.onnx_dequantize(fc2_weight_q)
fc2_weight = fc2_weight.to(inp_dtype)
if fc2_input_quantizer is not None:
act_out_q = fc2_input_quantizer.onnx_quantize(act_out)
act_out = fc2_input_quantizer.onnx_dequantize(act_out_q)
act_out = act_out.to(inp_dtype)
fc2_out = onnx_gemm(fc2_weight, act_out, fc2_bias)
if output_quantizer is not None:
raise NotImplementedError("ONNX export of quantized output is not supported")
if self.return_layernorm_output:
if self.return_bias:
return fc2_out, fc2_bias.to(inp_dtype), ln_out_return
return fc2_out, ln_out_return
if self.return_bias:
return fc2_out, fc2_bias.to(inp_dtype)
return fc2_out
def _get_debug_quantizers(self, fp8_output):
from ...debug.pytorch.debug_quantization import DebugQuantizer
......
......@@ -67,6 +67,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
......@@ -1278,6 +1279,9 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
......@@ -1301,13 +1305,7 @@ class Linear(TransformerEngineBaseModule):
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = None
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
......@@ -1420,6 +1418,95 @@ class Linear(TransformerEngineBaseModule):
for name, q in zip(names, original_quantizers)
)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
return unfused_weights
def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = None
return weight_tensor, bias_tensor
def onnx_forward(
self,
inp: torch.Tensor,
fp8_output: bool,
) -> torch.Tensor:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios.
"""
from ..export import onnx_gemm
assert_warmed_up(self)
assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export."
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
(
input_quantizer,
weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(fp8_output, False)
inp_dtype = inp.dtype
if input_quantizer is not None:
inp_q = input_quantizer.onnx_quantize(inp)
inp = input_quantizer.onnx_dequantize(inp_q)
inp = inp.to(inp_dtype)
if weight_quantizer is not None:
weight_q = weight_quantizer.onnx_quantize(weight_tensor)
weight_tensor = weight_quantizer.onnx_dequantize(weight_q)
if bias_tensor is not None:
bias_tensor = bias_tensor.to(inp_dtype)
weight_tensor = weight_tensor.to(inp_dtype)
if self.apply_bias:
output = onnx_gemm(weight_tensor, inp, bias_tensor)
else:
output = onnx_gemm(weight_tensor, inp, None)
if output_quantizer is not None:
raise NotImplementedError("ONNX export of quantized output is not supported")
if self.return_bias:
return output, bias_tensor
return output
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert (
......@@ -1467,23 +1554,6 @@ class Linear(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
return unfused_weights
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
......
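The `onnx_forward` methods above replace fused FP8 GEMMs with an explicit quantize-dequantize (Q/DQ) pair on each operand, so the exported graph carries ordinary-precision tensors plus Q/DQ nodes that TensorRT can fuse back into quantized compute. A minimal stand-in (a uniform rounding grid instead of real FP8; the function names are hypothetical) showing that the dequantized value, rounding error included, is what feeds the matmul:

```python
def quantize(x: float, scale: float, qmax: float = 448.0) -> float:
    # Scale up, round to the grid, clamp to the representable range
    # (448 is the largest normal value of the FP8 E4M3 format).
    return max(-qmax, min(qmax, float(round(x * scale))))


def dequantize(q: float, scale: float) -> float:
    return q / scale


scale = 4.0
x = 1.3
x_q = quantize(x, scale)       # lives on the quantized grid
x_dq = dequantize(x_q, scale)  # what the GEMM actually consumes
assert x_q == 5.0
assert x_dq == 1.25            # 1.3 -> 1.25: the rounding error is explicit
```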
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Extensions to torch.ops and their corresponding ONNX symbolic functions.

Many Transformer Engine layers rely on custom calls into the transformer_engine_torch
module, which makes ONNX export challenging because:
1. They often accept Python objects (quantizers), which ONNX does not support.
2. They are complex, incorporating fusions and precomputing values for the backward
   pass, none of which is needed for ONNX export.

For these reasons, each layer provides a simpler onnx_forward method that primarily
uses torch operators with known ONNX symbolic functions; these methods avoid fusions
and backward-pass precomputation.

The main remaining concern is quantization, which PyTorch does not support natively,
so we implement the ONNX symbolic functions ourselves. Since ONNX does not yet
support these quantized operators, ops from TensorRT are employed; the primary goal
of this ONNX export path is inference compatibility with TensorRT.
"""
from typing import Optional, Tuple
import math

import torch
import onnxscript
from onnxscript import opset18 as op
from onnx import defs

import transformer_engine_torch as tex
from .tensor.float8_tensor import Float8Quantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .constants import MXFP8_BLOCK_SCALING_SIZE
from .utils import round_up_to_nearest_multiple
from .export import is_in_onnx_export_mode

# Opset from TensorRT which supports FP8 quantization.
trt_opset = onnxscript.values.Opset("trt", version=1)


# ONNX GEMM for inference
def onnx_gemm(
    weight: torch.Tensor, inp: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """ONNX GEMM used for inference."""
    reshaped_inp = inp.reshape(-1, inp.shape[-1])
    out = torch_onnx_gemm_inf_op(weight, reshaped_inp, bias)
    return out.reshape(inp.shape[:-1] + (-1,))


@torch.library.custom_op("tex::gemm_inf", mutates_args=[])
def torch_onnx_gemm_inf_op(
    weight: torch.Tensor, inp: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """GEMM used for inference -- weight is transposed."""
    out = inp @ weight.T
    if bias is not None:
        out = out + bias
    return out


@torch_onnx_gemm_inf_op.register_fake
def _(weight, inp, bias):
    """Fake GEMM used for inference."""
    out = inp @ weight.T
    if bias is not None:
        out = out + bias
    return out


def onnx_gemm_inf_symbolic(
    weight: onnxscript.onnx_types.TensorType,
    inp: onnxscript.onnx_types.TensorType,
    bias: onnxscript.onnx_types.TensorType,
) -> onnxscript.onnx_types.TensorType:
    """Symbolic GEMM used for inference."""
    return op.Gemm(inp, weight, bias, transA=0, transB=1)
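The reshape-then-GEMM pattern above (flatten all leading dimensions, multiply by the transposed weight, restore the leading dimensions) can be illustrated in pure Python. This is a hand-rolled sketch, not TE or torch API; `matmul_t` and the shapes are made up for the example.

```python
def matmul_t(inp, weight, bias=None):
    """Compute inp @ weight.T (+ bias) for 2D lists, mirroring Gemm(transB=1)."""
    out_features, in_features = len(weight), len(weight[0])
    out = []
    for row in inp:
        acc = [sum(row[k] * weight[j][k] for k in range(in_features)) for j in range(out_features)]
        if bias is not None:
            acc = [a + b for a, b in zip(acc, bias)]
        out.append(acc)
    return out

# A (batch, seq, hidden) input is flattened to (batch*seq, hidden) before the GEMM,
# just as onnx_gemm reshapes inp to (-1, inp.shape[-1]) above.
batch, seq, hidden, out_f = 2, 3, 4, 5
inp3d = [[[1.0] * hidden for _ in range(seq)] for _ in range(batch)]
flat = [row for mat in inp3d for row in mat]      # (batch*seq, hidden)
weight = [[0.5] * hidden for _ in range(out_f)]   # (out_f, hidden)
out = matmul_t(flat, weight, bias=[1.0] * out_f)  # (batch*seq, out_f)
```

The leading dimensions would then be restored by reshaping `out` back to `(batch, seq, out_f)`, which is exactly what the final `reshape` in `onnx_gemm` does.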
# ONNX FP8 Quantization
@torch.library.custom_op("tex::fp8_quantize", mutates_args=[])
def onnx_quantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    """Quantize to Float8Tensor used for inference."""
    scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
    amax_tensor = torch.tensor([1], dtype=torch.float32, device=tensor.device)
    quantizer = Float8Quantizer(scale_tensor, amax_tensor, tex.DType.kFloat8E4M3)
    return quantizer.quantize(tensor)._data


@onnx_quantize_fp8_op.register_fake
def _(tensor, *_):
    """Fake quantize to Float8Tensor used for inference."""
    return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device)


def onnx_quantize_fp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType,
    scale: float,
) -> onnxscript.onnx_types.UINT8:
    """Symbolic quantize used for inference."""
    scale_inv = op.Constant(value_float=1 / scale)
    return TRT_FP8QuantizeLinear(tensor, scale_inv)


# Define the schema for the custom operator
schema = defs.OpSchema(
    name="TRT_FP8QuantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT FP8 Quantize Linear used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"),
    ],
    outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
)

TRT_FP8QuantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_FP8QuantizeLinear", op_schema=schema
)
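A simplified model of the quantize/dequantize pair above: values are scaled, clamped to the FP8 E4M3 representable range (largest finite value 448), then descaled with `scale_inv = 1/scale`, which is why the symbolic function emits `1 / scale` as the TRT inverse scale. This sketch deliberately ignores mantissa rounding; it only models the dynamic range.

```python
E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def fake_fp8_qdq(x, scale):
    """Scale, clamp to the E4M3 range, then dequantize with scale_inv = 1/scale.
    Mantissa rounding is deliberately omitted -- this is a range model only."""
    q = max(-E4M3_MAX, min(E4M3_MAX, x * scale))
    scale_inv = 1.0 / scale
    return q * scale_inv

in_range = fake_fp8_qdq(100.0, 1.0)   # within range: survives the round trip
clamped = fake_fp8_qdq(1000.0, 1.0)   # out of range: clamped to 448
rescued = fake_fp8_qdq(1000.0, 0.25)  # a smaller scale keeps it representable
```

This is also why the quantizer carries a scale at all: choosing `scale` so that `amax * scale <= 448` avoids the clamping shown in the second call.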
# ONNX FP8 Dequantization
@torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    """Dequantize from Float8Tensor used for inference."""
    scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
    quantizer = Float8Quantizer(
        scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
    )
    quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
    return quantizer_tensor.dequantize()


@onnx_dequantize_fp8_op.register_fake
def _(tensor: torch.Tensor, _) -> torch.Tensor:
    """Fake dequantize from Float8Tensor used for inference."""
    return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)


def onnx_dequantize_fp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType, scale: float
) -> onnxscript.onnx_types.TensorType:
    """Symbolic dequantize from Float8Tensor used for inference."""
    scale_inv = op.Constant(value_float=1 / scale)
    return TRT_FP8DequantizeLinear(tensor, scale_inv)


schema = defs.OpSchema(
    name="TRT_FP8DequantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"),
    ],
    outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)

TRT_FP8DequantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
)
# ONNX MXFP8 Quantization
@torch.library.custom_op("tex::mxfp8_quantize", mutates_args=[])
def onnx_quantize_mxfp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize to MXFP8Tensor used for inference."""
    quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3)
    quantized_tensor = quantizer(tensor)
    return quantized_tensor._rowwise_data, quantized_tensor._rowwise_scale_inv


@onnx_quantize_mxfp8_op.register_fake
def _(tensor: torch.Tensor):
    """Fake quantize to MXFP8Tensor used for inference."""
    mxfp8_scale_shape = [
        round_up_to_nearest_multiple(math.prod(tensor.shape[:-1]), 128),
        round_up_to_nearest_multiple(tensor.shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
    ]
    return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.empty(
        mxfp8_scale_shape, dtype=torch.uint8, device=tensor.device
    )


def onnx_quantize_mxfp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType,
) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
    """Symbolic quantize to MXFP8Tensor used for inference."""
    tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor)
    return tensor_out, scale_inv_out


schema = defs.OpSchema(
    name="TRT_MXFP8QuantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT MXFP8 Quantize Linear used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
    ],
    outputs=[
        defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor"),
        defs.OpSchema.FormalParameter(
            "scale_inv", "tensor(uint8)", "Scale factor for quantization"
        ),
    ],
)

TRT_MXFP8QuantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema
)
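The scale-tensor shape computed by the fake MXFP8 op above can be reproduced without TE: MXFP8 stores one scale per block of consecutive elements along the last dimension, with the row count padded to a multiple of 128 and the per-row block count padded to a multiple of 4 (padding values taken from the fake op above). The helper below mirrors `round_up_to_nearest_multiple`; the block size of 32 is an assumption matching TE's `MXFP8_BLOCK_SCALING_SIZE`.

```python
MXFP8_BLOCK_SIZE = 32  # assumed value of TE's MXFP8_BLOCK_SCALING_SIZE

def round_up(x, m):
    """Round x up to the nearest multiple of m (mirrors round_up_to_nearest_multiple)."""
    return ((x + m - 1) // m) * m

def mxfp8_scale_shape(shape):
    """Shape of the row-wise scale-inverse tensor for an input of `shape`."""
    rows = 1
    for d in shape[:-1]:
        rows *= d  # flatten all leading dimensions, as in the fake op
    return (round_up(rows, 128), round_up(shape[-1] // MXFP8_BLOCK_SIZE, 4))
```

For a (256, 1024) activation this gives a (256, 32) scale tensor: 256 is already a multiple of 128, and 1024 elements per row form 32 blocks of 32.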
# ONNX MXFP8 Dequantization
@torch.library.custom_op("tex::mxfp8_dequantize", mutates_args=[])
def onnx_dequantize_mxfp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Dequantize from MXFP8Tensor used for inference."""
    quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3)
    quantizer_tensor = quantizer.create_tensor_from_data(
        tensor, scale_inv, fake_dtype=torch.float32
    )
    return quantizer_tensor.dequantize()


@onnx_dequantize_mxfp8_op.register_fake
def _(tensor: torch.Tensor, _):
    """Fake dequantize from MXFP8Tensor used for inference."""
    return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)


def onnx_dequantize_mxfp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
) -> onnxscript.onnx_types.TensorType:
    """Symbolic dequantize from MXFP8Tensor used for inference."""
    return TRT_MXFP8DequantizeLinear(tensor, scale_inv)


schema = defs.OpSchema(
    name="TRT_MXFP8DequantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT MXFP8 Dequantize Linear from MXFP8Tensor used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
        defs.OpSchema.FormalParameter(
            "scale_inv", "tensor(uint8)", "Scale factor for dequantization"
        ),
    ],
    outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)

TRT_MXFP8DequantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_MXFP8DequantizeLinear", op_schema=schema
)
# ONNX LayerNorm
@torch.library.custom_op("tex::layernorm", mutates_args=[])
def onnx_layernorm_op(
    inp: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor:
    """ONNX LayerNorm used for inference."""
    model = tex.LayerNorm(inp.shape[1], eps=eps)
    model.weight.data = weight
    model.bias.data = bias
    return model(inp)


@onnx_layernorm_op.register_fake
def _(inp, *_):
    """Fake ONNX LayerNorm used for inference."""
    # Fake impls must not alias their inputs, so return a fresh tensor of the same shape.
    return torch.empty_like(inp)


def onnx_layernorm_symbolic(
    inp: onnxscript.onnx_types.TensorType,
    weight: onnxscript.onnx_types.TensorType,
    bias: onnxscript.onnx_types.TensorType,
    eps: float,
) -> onnxscript.onnx_types.TensorType:
    """Symbolic ONNX LayerNorm used for inference."""
    return op.LayerNormalization(inp, weight, bias, epsilon=eps)
# ONNX layernorm helper function -- handles LayerNorm/RMSNorm with optional quantization
def onnx_layernorm(
    inp: torch.Tensor,
    layer_norm_weight: torch.Tensor,
    layer_norm_bias: torch.Tensor,
    eps: float,
    normalization: str,
    zero_centered_gamma: bool,
    output_dtype: torch.dtype,
    return_layernorm_output: bool,
    input_quantizer,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ONNX LayerNorm used for inference."""
    ln_weight = layer_norm_weight if not zero_centered_gamma else layer_norm_weight + 1
    ln_weight = ln_weight.to(inp.dtype).to(torch.float32)
    inp = inp.to(torch.float32)
    layer_norm_bias = (
        layer_norm_bias.to(output_dtype).to(torch.float32) if layer_norm_bias is not None else None
    )
    if normalization == "RMSNorm":
        ln_out = torch.nn.functional.rms_norm(inp, inp.shape[-1:], ln_weight, eps)
    else:
        ln_out = torch.nn.functional.layer_norm(
            inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps
        )
    ln_out_return = ln_out
    if input_quantizer is not None:
        if return_layernorm_output:
            # With return_layernorm_output, layernorm is not fused with the FP8 cast,
            # so we cast to the output dtype first and then quantize to FP8 if needed.
            ln_out = ln_out.to(output_dtype).to(torch.float32)
            ln_out_return = ln_out
        elif isinstance(input_quantizer, MXFP8Quantizer):
            # LayerNorm fused with an MXFP8 quantizer behaves differently.
            ln_out = ln_out.to(output_dtype).to(torch.float32)
        ln_out_quantized = input_quantizer.onnx_quantize(ln_out)
        ln_out = input_quantizer.onnx_dequantize(ln_out_quantized)
    ln_out = ln_out.to(output_dtype)
    return ln_out, ln_out_return
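The two normalizations dispatched above differ only in whether the mean is subtracted and a bias is added. A plain-Python reference over the last dimension, written to match the semantics of `torch.nn.functional.layer_norm` / `rms_norm` (eps added under the square root), is a useful sanity check; the function names here are illustrative, not TE API.

```python
import math

def layer_norm(x, weight, bias, eps=1e-5):
    """Reference LayerNorm over a 1D list: subtract mean, divide by std, scale, shift."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) * w + b
            for v, w, b in zip(x, weight, bias)]

def rms_norm(x, weight, eps=1e-5):
    """Reference RMSNorm: divide by root-mean-square, scale. No mean subtraction, no bias."""
    ms = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(ms + eps) * w for v, w in zip(x, weight)]

x = [1.0, 2.0, 3.0, 4.0]
ln = layer_norm(x, [1.0] * 4, [0.0] * 4)  # zero-mean, unit-variance output
rn = rms_norm(x, [1.0] * 4)               # unit-RMS output, mean preserved in sign
```

With unit weight and zero bias, `ln` has (numerically) zero mean while `rn` keeps the input's sign pattern with RMS close to 1.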
# Utility functions
def onnx_attention_mask_func(
    attention_scores: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Apply the attention mask to attention scores (ONNX-exportable version)."""
    assert is_in_onnx_export_mode()
    return attention_scores.masked_fill(attention_mask, -10000.0)
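`masked_fill` above replaces masked positions with a large negative sentinel so they contribute (almost) nothing after softmax. The effect in plain Python, using the same -10000.0 sentinel; `masked_softmax` is an illustrative helper, not TE API:

```python
import math

def masked_softmax(scores, mask):
    """Fill masked positions with -10000.0, then softmax -- masked entries go to ~0."""
    filled = [-10000.0 if m else s for s, m in zip(scores, mask)]
    mx = max(filled)  # subtract the max for numerical stability
    exps = [math.exp(v - mx) for v in filled]
    total = sum(exps)
    return [e / total for e in exps]

probs = masked_softmax([1.0, 2.0, 3.0], [False, True, False])
```

The masked middle position ends up with essentially zero probability, while the remaining mass is renormalized over the unmasked positions.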
# This translation table should be passed to the torch.onnx.export function
# via the custom_translation_table=te_translation_table option.
te_translation_table = {
    torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
    torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
    torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
    torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
    torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
    torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
}
......@@ -23,6 +23,7 @@ from ...utils import (
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
......@@ -179,6 +180,8 @@ class LayerNorm(BasicOperation):
        prev_op_grad_input_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        if is_in_onnx_export_mode():
            return self.op_onnx_forward(input_)
        # Check tensor dims
        weight = self.weight
......@@ -268,3 +271,13 @@ class LayerNorm(BasicOperation):
        grad_weight = dw.view(weight_dims)
        grad_bias = db.view(weight_dims)
        return grad_input, (grad_weight, grad_bias)
    def op_onnx_forward(
        self,
        input_: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass using only operators that have defined ONNX translations."""
        weight = self.weight + 1 if self.zero_centered_gamma else self.weight
        return torch.nn.functional.layer_norm(
            input_, input_.shape[-1:], weight, self.bias, self.eps
        )
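The `weight + 1` adjustment above implements zero-centered gamma: the parameter is stored as gamma - 1 so that a zero-initialized weight corresponds to an identity scale. A quick plain-Python check of that equivalence (illustrative only, not TE API):

```python
def apply_gamma(x, stored_weight, zero_centered):
    """Scale normalized activations by gamma; stored weight is gamma - 1 when zero-centered."""
    gamma = [w + 1.0 for w in stored_weight] if zero_centered else stored_weight
    return [v * g for v, g in zip(x, gamma)]

x = [0.5, -1.0, 2.0]
# A zero-initialized weight with zero-centered gamma acts as the identity scale:
identity = apply_gamma(x, [0.0, 0.0, 0.0], zero_centered=True)
```

This is why zero-centered layers can initialize their weight to zeros while still starting from an identity normalization.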
......@@ -23,6 +23,7 @@ from ...utils import (
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
......@@ -162,6 +163,8 @@ class RMSNorm(BasicOperation):
        prev_op_grad_input_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        if is_in_onnx_export_mode():
            return self.op_onnx_forward(input_)
        # Check tensor dims
        weight = self.weight
......@@ -246,3 +249,11 @@ class RMSNorm(BasicOperation):
        grad_input = dx.view(grad_output.size())
        grad_weight = dw.view(weight_dims)
        return grad_input, (grad_weight,)
    def op_onnx_forward(
        self,
        input_: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass using only operators that have defined ONNX translations."""
        weight = self.weight + 1 if self.zero_centered_gamma else self.weight
        return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps)
......@@ -167,6 +167,21 @@ class Float8Quantizer(Quantizer):
            quantizer=self,
        )
    def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
        """Quantize using primitives that have defined ONNX translations."""
        # Q inputs are currently constrained to FP32 due to a similar limitation in ORT
        # custom ops, so cast the input if needed.
        if tensor.dtype != torch.float32:
            tensor = tensor.to(torch.float32)
        data = torch.ops.tex.fp8_quantize(tensor, self.scale.item())
        return self.create_tensor_from_data(data, fake_dtype=torch.float32)

    def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
        """Dequantize using primitives that have defined ONNX translations."""
        out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item())
        return out.to(tensor.dtype)
    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
        return DelayedScaling
......@@ -328,6 +343,18 @@ class Float8CurrentScalingQuantizer(Quantizer):
            quantizer=self,
        )
    def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
        """Quantize using primitives that have defined ONNX translations."""
        raise NotImplementedError(
            "Float8CurrentScalingQuantizer does not support ONNX quantization yet."
        )

    def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
        """Dequantize using primitives that have defined ONNX translations."""
        raise NotImplementedError(
            "Float8CurrentScalingQuantizer does not support ONNX dequantization yet."
        )
    def _canonicalized_amax_reduction_group(self) -> dist_group_type:
        """Get process group for amax reduction"""
        return canonicalize_process_group(self.amax_reduction_group)
......