Unverified Commit cf00d537 authored by Peter St. John's avatar Peter St. John Committed by GitHub
Browse files

[PyTorch] Defer torch compilation steps until first function call (#1599)



* Defer torch compilation steps until first function call
Signed-off-by: default avatarPeter St. John <pstjohn@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix function call in smoke test
Signed-off-by: default avatarPeter St. John <pstjohn@nvidia.com>

---------
Signed-off-by: default avatarPeter St. John <pstjohn@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent b59d1d8b
......@@ -56,3 +56,10 @@ def test_torch_dynamo(model_name: str):
# Forward and backward pass
out = model(*inputs)
out.backward(torch.zeros_like(out))
def test_lazy_compile():
"""Smoke test to ensure lazy compilation is working."""
from transformer_engine.pytorch.jit import dgelu_fused_
dgelu_fused_(torch.randn(10, 10), torch.randn(10, 10))
......@@ -4,21 +4,41 @@
"""NVFuser functions and JIT utilities"""
import os
from functools import wraps
from typing import Callable, Optional, Tuple
import torch
# pylint: disable=unnecessary-lambda-assignment
def lazy_compile(func):
"""Lazy compile a function with torch.compile
This decorator defers the compilation of a function until the first call, speeding up the
overall module's import time if these functions are not used.
"""
compiled_func = None
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal compiled_func
if compiled_func is None:
compiled_func = torch.compile(func)
return compiled_func(*args, **kwargs)
return wrapper
jit_fuser = lambda func: func
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile
jit_fuser = lazy_compile
# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
dropout_fuser = torch.compile
dropout_fuser = lazy_compile
# Decorator to disable Torch Dynamo
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment