# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import inspect

import torch

from megatron.core.utils import is_torch_min_version

# Decorator used for kernel fusion; TorchScript is the fallback on older PyTorch.
jit_fuser = torch.jit.script
# nvFuser is deprecated in PyTorch JIT starting from 2.2. The "a0" suffix lets
# pre-release builds (e.g. "2.2.0a0+<hash>") pass the check as well.
if is_torch_min_version("2.2.0a0"):
    jit_fuser = torch.compile(mode='max-autotune-no-cudagraphs')


def _identity_decorator(func):
    """Return ``func`` unchanged (used when Torch Dynamo is not available)."""
    return func


# Detect Dynamo by availability rather than by comparing version strings:
# version comparisons misclassify pre-release builds such as "2.1.0a0+<hash>",
# which sort below "2.1" under PEP 440 ordering.
try:
    import torch._dynamo
except ImportError:
    # PyTorch builds without Dynamo: there is nothing to disable.
    _HAVE_TORCH_DYNAMO = False
else:
    _HAVE_TORCH_DYNAMO = True


def _dynamo_disable_accepts_recursive():
    """Return True if ``torch._dynamo.disable`` declares a ``recursive`` parameter.

    The option is absent in PyTorch 2.0. If the signature cannot be
    introspected, assume it is absent so that decorating never raises.
    """
    try:
        return "recursive" in inspect.signature(torch._dynamo.disable).parameters
    except (TypeError, ValueError):
        return False


# Computed once at import time so each call to no_torch_dynamo stays cheap.
_DYNAMO_DISABLE_HAS_RECURSIVE = _HAVE_TORCH_DYNAMO and _dynamo_disable_accepts_recursive()


def no_torch_dynamo(recursive=True):
    """Build a decorator that excludes the decorated function from Torch Dynamo.

    See: https://github.com/NVIDIA/TransformerEngine/issues/308

    Args:
        recursive: Forwarded to ``torch._dynamo.disable`` when that function
            supports it. PyTorch 2.0 has no such option and always behaves as if
            ``recursive`` were True, so the argument is ignored there.

    Returns:
        A decorator. When Dynamo is unavailable it returns the function
        unchanged; otherwise it wraps the function with
        ``torch._dynamo.disable``.
    """
    if not _HAVE_TORCH_DYNAMO:
        return _identity_decorator
    if _DYNAMO_DISABLE_HAS_RECURSIVE:
        return lambda func: torch._dynamo.disable(func, recursive=recursive)
    # No "recursive" option (PyTorch 2.0) - it acts as if recursive was True.
    return torch._dynamo.disable