jit.py 401 Bytes
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch

from megatron.core.utils import is_torch_min_version

# ``jit_fuser`` is the JIT decorator chosen according to the installed PyTorch
# version. nvFuser is deprecated in PyTorch JIT starting from 2.2, so from
# torch 2.2.0a0 onward ``torch.compile`` is used. Older releases keep
# ``torch.jit.script``.
try:
    jit_fuser = (
        torch.compile if is_torch_min_version("2.2.0a0") else torch.jit.script
    )
except ImportError:
    # The version probe itself raised ImportError. Presumably this means the
    # torch version could not be determined; TODO confirm in megatron.core.utils.
    # No JIT is applied then: decorated functions are returned as-is.

    def noop_decorator(func):
        """Return ``func`` unchanged; the fallback value of ``jit_fuser``."""
        return func

    jit_fuser = noop_decorator