jit.py 821 Bytes
Newer Older
xingjinliang's avatar
xingjinliang committed
1
2
3
4
5
6
7
8
9
10
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch

from megatron.core.utils import is_torch_min_version

# Decorator used for JIT fusion. TorchScript's nvFuser is deprecated starting
# with PyTorch 2.2, so from that version onward torch.compile takes over;
# older releases keep torch.jit.script. The "2.2.0a0" threshold presumably
# also admits 2.2 pre-release builds — confirm against is_torch_min_version.
jit_fuser = torch.compile if is_torch_min_version("2.2.0a0") else torch.jit.script
11
12
13
14
15
16
17
18
19
20
21
22
23
24

# Decorator factory that disables TorchDynamo for the decorated function.
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
#
# It must be called before use, e.g. ``@no_torch_dynamo()`` or
# ``@no_torch_dynamo(recursive=False)``.
# NOTE(review): the checks below compare ``torch.__version__`` against plain
# strings. If it is a ``TorchVersion``, these are real version comparisons in
# which a pre-release such as "2.1.0a0+..." sorts below "2.1" and so takes the
# older code path — confirm this is intended for NGC/nightly builds.


def _dynamo_passthrough(recursive=True):
    # PyTorch < 2 has no TorchDynamo, so the returned decorator is a no-op.
    return lambda func: func


def _dynamo_disable_v20(recursive=True):
    # no "recursive" option in PyTorch 2.0 - torch._dynamo.disable acts as if
    # recursive were True, so the argument is accepted but ignored.
    return torch._dynamo.disable


def _dynamo_disable_v21(recursive=True):
    # PyTorch >= 2.1: forward the caller's ``recursive`` choice.
    def _decorate(f):
        return torch._dynamo.disable(f, recursive=recursive)

    return _decorate


if torch.__version__ >= "2":
    import torch._dynamo

    no_torch_dynamo = (
        _dynamo_disable_v21 if torch.__version__ >= "2.1" else _dynamo_disable_v20
    )
else:
    no_torch_dynamo = _dynamo_passthrough