Unverified commit cf00d537, authored by Peter St. John and committed by GitHub
Browse files

[PyTorch] Defer torch compilation steps until first function call (#1599)



* Defer torch compilation steps until first function call
Signed-off-by: Peter St. John <pstjohn@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix function call in smoke test
Signed-off-by: Peter St. John <pstjohn@nvidia.com>

---------
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent b59d1d8b
...@@ -56,3 +56,10 @@ def test_torch_dynamo(model_name: str): ...@@ -56,3 +56,10 @@ def test_torch_dynamo(model_name: str):
# Forward and backward pass # Forward and backward pass
out = model(*inputs) out = model(*inputs)
out.backward(torch.zeros_like(out)) out.backward(torch.zeros_like(out))
def test_lazy_compile():
    """Smoke test that a lazily compiled function can be imported and called."""
    # Imported locally, inside the test body, rather than at module level.
    from transformer_engine.pytorch.jit import dgelu_fused_

    first = torch.randn(10, 10)
    second = torch.randn(10, 10)
    # No assertion on the result: the test passes as long as the call does not raise.
    dgelu_fused_(first, second)
...@@ -4,21 +4,41 @@ ...@@ -4,21 +4,41 @@
"""NVFuser functions and JIT utilities""" """NVFuser functions and JIT utilities"""
import os import os
from functools import wraps
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
import torch import torch
# pylint: disable=unnecessary-lambda-assignment # pylint: disable=unnecessary-lambda-assignment
def lazy_compile(func):
    """Decorate ``func`` so that ``torch.compile`` is applied on first call.

    Nothing is compiled when the decorator is applied, which keeps the module's
    import time down when the decorated function is never used. The first call
    to the returned wrapper compiles ``func``; that call and every later one
    are forwarded to the compiled callable.
    """
    compiled = None  # Set by the first call to the wrapper below, then reused.

    @wraps(func)
    def deferred(*args, **kwargs):
        nonlocal compiled
        if compiled is None:
            # First call: compile now and remember the result.
            compiled = torch.compile(func)
        return compiled(*args, **kwargs)

    return deferred
jit_fuser = lambda func: func jit_fuser = lambda func: func
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile jit_fuser = lazy_compile
# See: https://github.com/NVIDIA/TransformerEngine/issues/597 # See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script dropout_fuser = torch.jit.script
if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
dropout_fuser = torch.compile dropout_fuser = lazy_compile
# Decorator to disable Torch Dynamo # Decorator to disable Torch Dynamo
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment