export.py 918 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Export utilities for TransformerEngine"""
from contextlib import contextmanager
from typing import Generator

# Module-level flag recording whether ONNX export mode is currently active.
# It is set for the duration of an `onnx_export` context and read through
# `is_in_onnx_export_mode()`; it is off by default.
_IN_ONNX_EXPORT_MODE = False

@contextmanager
def onnx_export(
    enabled: bool = False,
) -> Generator[None, None, None]:
    """
    Context manager for exporting to ONNX.

    While the context is active, the module-level export flag is set to
    `enabled`. On exit the flag is restored to the value it had on entry,
    even if the body raises, so nested contexts unwind correctly.

    .. code-block:: python

        with onnx_export(enabled=True):
            torch.onnx.export(model)

    Parameters
    ----------
    enabled: bool, default = `False`
             whether or not to enable export
    """

    global _IN_ONNX_EXPORT_MODE
    # Remember the state on entry so that it (rather than a hard-coded False)
    # is what gets restored when the context exits.
    previous_state = _IN_ONNX_EXPORT_MODE
    try:
        _IN_ONNX_EXPORT_MODE = enabled
        yield
    finally:
        _IN_ONNX_EXPORT_MODE = previous_state

def is_in_onnx_export_mode() -> bool:
    """
    Report whether ONNX export mode is currently active.

    Returns
    -------
    bool
        The current value of the module-level `_IN_ONNX_EXPORT_MODE` flag,
        i.e. the value most recently set by an active `onnx_export` context,
        or the module default when no context is active.
    """
    export_mode_active = _IN_ONNX_EXPORT_MODE
    return export_mode_active