# jit.py (364 bytes) — last committed by wangsen; header recovered from a web-page export.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch

# Integer major/minor components of the installed PyTorch version, taken from
# the first two dot-separated fields of torch.__version__ (e.g. "2.1.0+cu118"
# -> 2, 1). Anything after the minor field is ignored.
_TORCH_VERSION_FIELDS = torch.__version__.split(".")
TORCH_MAJOR = int(_TORCH_VERSION_FIELDS[0])
TORCH_MINOR = int(_TORCH_VERSION_FIELDS[1])

# nvFuser is deprecated in PyTorch JIT starting from 2.2, so on PyTorch >= 2.2
# jit_fuser can be switched to torch.compile. That switch is currently
# DISABLED (flag below is False), so torch.jit.script is used on every
# PyTorch version.
# NOTE(review): the reason for disabling torch.compile is not recorded in this
# file — confirm (e.g. backend support on the target platform) before enabling.
_ENABLE_TORCH_COMPILE_FUSER = False

jit_fuser = torch.jit.script
if _ENABLE_TORCH_COMPILE_FUSER and (TORCH_MAJOR, TORCH_MINOR) >= (2, 2):
    jit_fuser = torch.compile