export.py 2.06 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Export utilities for TransformerEngine"""

from contextlib import contextmanager
from typing import Generator
import torch


# Module-wide flag; starts out disabled and is toggled by the `onnx_export` context manager.
_IN_ONNX_EXPORT_MODE = False
# First two dot-separated components of the installed PyTorch version, as integers.
TORCH_MAJOR, TORCH_MINOR = (int(part) for part in torch.__version__.split(".")[:2])


@contextmanager
def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
    """
    Context manager that switches ONNX export mode on or off for the enclosed block.

    While the ``with`` body runs, :func:`is_in_onnx_export_mode` reports the value
    passed as ``enabled``. The mode that was active before entering is put back on
    exit, including when the body raises.

    .. code-block:: python

        from transformer_engine.pytorch.export import onnx_export, te_translation_table

        with onnx_export(enabled=True):
            torch.onnx.export(model, dynamo=True, custom_translation_table=te_translation_table)

    Parameters
    ----------
    enabled : bool, default = False
             whether or not to enable export

    Raises
    ------
    RuntimeError
        On entering the context when the installed PyTorch is older than 2.4; the check
        is made regardless of ``enabled``.
    """

    global _IN_ONNX_EXPORT_MODE
    # Remember the mode in effect so nested or repeated uses unwind to it afterwards.
    previous_mode = _IN_ONNX_EXPORT_MODE
    torch_too_old = (TORCH_MAJOR, TORCH_MINOR) < (2, 4)
    if torch_too_old:
        raise RuntimeError("ONNX export is not supported for PyTorch versions less than 2.4")
    _IN_ONNX_EXPORT_MODE = enabled
    try:
        yield
    finally:
        # Always restore, even if the body of the `with` statement raised.
        _IN_ONNX_EXPORT_MODE = previous_mode


def is_in_onnx_export_mode() -> bool:
    """Report the current value of the module-level ONNX export mode flag."""
    current_mode = _IN_ONNX_EXPORT_MODE
    return current_mode


def assert_warmed_up(module: torch.nn.Module) -> None:
    """Check that the model has been warmed up before exporting to ONNX.

    The check passes when ``module`` has an attribute named
    ``forwarded_at_least_once``; only the attribute's presence is tested, not
    its value.

    Parameters
    ----------
    module : torch.nn.Module
            module to check

    Raises
    ------
    AssertionError
        If ``module`` has no ``forwarded_at_least_once`` attribute.
    """
    # Raise explicitly rather than via an `assert` statement: asserts are removed
    # when Python runs with -O, which would silently skip this check.
    if not hasattr(module, "forwarded_at_least_once"):
        raise AssertionError(
            "Model must be warmed up before exporting to ONNX, please run model with the"
            " same recipe before exporting."
        )


# The ONNX extension symbols are only imported (and re-exported from this module)
# on PyTorch 2.4 or newer.
if (TORCH_MAJOR, TORCH_MINOR) >= (2, 4):
    # pylint: disable=unused-import
    from .onnx_extensions import (
        torch_onnx_gemm_inf_op,
        onnx_quantize_fp8_op,
        onnx_dequantize_fp8_op,
        onnx_quantize_mxfp8_op,
        onnx_dequantize_mxfp8_op,
        onnx_layernorm,
        onnx_attention_mask_func,
        onnx_gemm,
        te_translation_table,
    )