Unverified commit c6538d6e authored by Tim Moon, committed by GitHub
Browse files

Disable TorchDynamo optimizations in PyTorch modules (#312)



* Disable TorchDynamo optimizations in PyTorch modules
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add test for Torch Dynamo
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add torch.dynamo test to qa
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Skip torch.compile test for <v2.0
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 36873ec8
......@@ -10,3 +10,4 @@ pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import pytest
import torch
import transformer_engine.pytorch as te
# Model names for test_torch_dynamo
_model_names = ["Linear", "LayerNorm", "LayerNormLinear", "LayerNormMLP"]
@pytest.mark.skipif(torch.__version__ < "2", reason="torch.compile not available")
@pytest.mark.parametrize("model_name", _model_names)
def test_torch_dynamo(model_name: str):
    """Test compatibility with Torch Dynamo.

    Construct the named Transformer Engine module, optimize it with
    Torch Dynamo (``torch.compile``), and perform a single forward and
    backward pass to make sure tracing does not choke on the module.
    """

    # Helper function to construct a tensor with default options.
    def make_tensor(
        dims: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        requires_grad: bool = True,
        **kwargs,
    ):
        return torch.zeros(
            dims,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
            **kwargs,
        )

    # Construct model and input tensors
    model = None
    inputs = []
    if model_name == "Linear":
        model = te.Linear(16, 16)
        inputs = [make_tensor([16, 16])]
    elif model_name == "LayerNorm":
        model = te.LayerNorm(16)
        inputs = [make_tensor([16, 16])]
    elif model_name == "LayerNormLinear":
        model = te.LayerNormLinear(16, 16)
        inputs = [make_tensor([16, 16])]
    elif model_name == "LayerNormMLP":
        model = te.LayerNormMLP(16, 16)
        inputs = [make_tensor([16, 16])]
    assert model is not None, f"could not construct {model_name}"

    # Optimize model with TorchDynamo.
    # Fix: torch.compile returns the optimized module rather than
    # mutating its argument; the original discarded the return value,
    # so the test silently exercised the *uncompiled* model.
    model = torch.compile(model)

    # Forward and backward pass
    out = model(*inputs)
    out.backward(torch.zeros_like(out))
......@@ -12,6 +12,13 @@ jit_fuser = torch.jit.script
# On PyTorch 2.x, prefer torch.compile as the JIT fuser unless the user
# opted out via NVTE_TORCH_COMPILE=0.
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    jit_fuser = torch.compile


# Decorator that keeps Torch Dynamo from tracing a function.
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
def no_torch_dynamo(func):
    """Return *func* unchanged (no-op fallback for PyTorch < 2)."""
    return func


if torch.__version__ >= "2":
    # torch._dynamo.disable marks the wrapped function as a graph break.
    import torch._dynamo

    no_torch_dynamo = torch._dynamo.disable
def set_jit_fusion_options() -> None:
"""Set PyTorch JIT layer fusion options."""
......
......@@ -14,6 +14,7 @@ import transformer_engine_extensions as tex
from ..cpp_extensions import (
layernorm_fwd_inf,
)
from ..jit import no_torch_dynamo
__all__ = ["LayerNorm"]
......@@ -160,6 +161,7 @@ class LayerNorm(torch.nn.Module):
init.zeros_(self.bias)
@no_torch_dynamo
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
# Maintain backward compatibility.
......
......@@ -49,6 +49,7 @@ from ..cpp_extensions import (
cast_from_fp8,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
__all__ = ["LayerNormLinear"]
......@@ -821,6 +822,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
return fp8_weight_tensors
@no_torch_dynamo
def forward(
self,
inp: torch.Tensor,
......
......@@ -44,6 +44,7 @@ from ..distributed import (
from .. import cpp_extensions as tex
from ..constants import dist_group_type, TE_DType
from ..jit import no_torch_dynamo
__all__ = ["LayerNormMLP"]
......@@ -1140,6 +1141,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
return fp8_weight_tensors
@no_torch_dynamo
def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
......
......@@ -43,6 +43,7 @@ from ..cpp_extensions import (
cast_to_fp8,
)
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
__all__ = ["Linear"]
......@@ -668,6 +669,7 @@ class Linear(TransformerEngineBaseModule):
return fp8_weight_tensors
@no_torch_dynamo
def forward(
self,
inp: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment