You need to sign in or sign up before continuing.
Unverified Commit c6538d6e authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Disable TorchDynamo optimizations in PyTorch modules (#312)



* Disable TorchDynamo optimizations in PyTorch modules
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add test for Torch Dynamo
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add torch.dynamo test to qa
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Skip torch.compile test for <v2.0
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 36873ec8
......@@ -10,3 +10,4 @@ pip install pytest==6.2.5 onnxruntime==1.13.1
# Basic sanity checks for all modules.
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
# Numerics tests: disable TorchScript JIT and force deterministic kernels
# so results are reproducible across runs.
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
# ONNX export: disable torch.compile, which interferes with ONNX tracing.
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
# JIT/TorchDynamo compatibility tests.
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import pytest
import torch
import transformer_engine.pytorch as te
# Model names for test_torch_dynamo
_model_names = ["Linear", "LayerNorm", "LayerNormLinear", "LayerNormMLP"]
@pytest.mark.skipif(torch.__version__ < "2", reason="torch.compile not available")
@pytest.mark.parametrize("model_name", _model_names)
def test_torch_dynamo(model_name: str):
    """Test compatibility with Torch Dynamo.

    Construct a Transformer Engine module, optimize it with Torch
    Dynamo (``torch.compile``), and perform a single forward and
    backward pass through the compiled module.
    """

    # Helper function to construct a tensor with default options
    # (float32, CUDA, gradients enabled).
    def make_tensor(
        dims: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        requires_grad: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        return torch.zeros(
            dims,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
            **kwargs,
        )

    # Construct model and input tensors
    model = None
    inputs = []
    if model_name == "Linear":
        model = te.Linear(16, 16)
        inputs = [make_tensor([16, 16])]
    elif model_name == "LayerNorm":
        model = te.LayerNorm(16)
        inputs = [make_tensor([16, 16])]
    elif model_name == "LayerNormLinear":
        model = te.LayerNormLinear(16, 16)
        inputs = [make_tensor([16, 16])]
    elif model_name == "LayerNormMLP":
        model = te.LayerNormMLP(16, 16)
        inputs = [make_tensor([16, 16])]
    assert model is not None, f"could not construct {model_name}"

    # Optimize model with TorchDynamo.
    # Bug fix: torch.compile returns the optimized module and does NOT
    # modify its argument in place, so the returned module must be kept
    # — otherwise the test silently runs the uncompiled model.
    model = torch.compile(model)

    # Forward and backward pass through the compiled module.
    out = model(*inputs)
    out.backward(torch.zeros_like(out))
......@@ -12,6 +12,13 @@ jit_fuser = torch.jit.script
# Prefer TorchDynamo-based compilation over TorchScript when available,
# unless explicitly disabled via the NVTE_TORCH_COMPILE env var.
# NOTE(review): this version check is a lexicographic string comparison;
# it works for current releases but would misclassify a hypothetical
# torch major version >= 10 — consider a numeric major-version parse.
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    jit_fuser = torch.compile

# Decorator to disable Torch Dynamo graph capture on a wrapped callable.
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
def no_torch_dynamo(func):
    """No-op fallback decorator for torch < 2, where Dynamo is unavailable."""
    return func

if torch.__version__ >= "2":
    import torch._dynamo

    # torch._dynamo.disable marks a callable so Dynamo skips tracing it
    # (the wrapped function still executes normally in eager mode).
    no_torch_dynamo = torch._dynamo.disable
def set_jit_fusion_options() -> None:
"""Set PyTorch JIT layer fusion options."""
......
......@@ -14,6 +14,7 @@ import transformer_engine_extensions as tex
from ..cpp_extensions import (
layernorm_fwd_inf,
)
from ..jit import no_torch_dynamo
__all__ = ["LayerNorm"]
......@@ -160,6 +161,7 @@ class LayerNorm(torch.nn.Module):
init.zeros_(self.bias)
@no_torch_dynamo
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
# Maintain backward compatibility.
......
......@@ -49,6 +49,7 @@ from ..cpp_extensions import (
cast_from_fp8,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
__all__ = ["LayerNormLinear"]
......@@ -821,6 +822,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
return fp8_weight_tensors
@no_torch_dynamo
def forward(
self,
inp: torch.Tensor,
......
......@@ -44,6 +44,7 @@ from ..distributed import (
from .. import cpp_extensions as tex
from ..constants import dist_group_type, TE_DType
from ..jit import no_torch_dynamo
__all__ = ["LayerNormMLP"]
......@@ -1140,6 +1141,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
return fp8_weight_tensors
@no_torch_dynamo
def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
......
......@@ -43,6 +43,7 @@ from ..cpp_extensions import (
cast_to_fp8,
)
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
__all__ = ["Linear"]
......@@ -668,6 +669,7 @@ class Linear(TransformerEngineBaseModule):
return fp8_weight_tensors
@no_torch_dynamo
def forward(
self,
inp: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment