Unverified Commit 7e7f0920 authored by Przemyslaw Tredak, committed by GitHub

Disable dynamo for Fused Attention (#558)



* Disable dynamo for Fused Attention
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added test

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent efd4b62a
@@ -10,11 +10,17 @@ import torch
 import torch
 
 import transformer_engine.pytorch as te
 
 # Model names for test_torch_dynamo
-_model_names = ["Linear", "LayerNorm", "LayerNormLinear", "LayerNormMLP"]
+_model_factory = {
+    "Linear": [(lambda: te.Linear(16, 16)), [16, 16]],
+    "LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]],
+    "LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]],
+    "LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]],
+    "TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]],
+}
 
 @pytest.mark.skipif(torch.__version__ < "2", reason="torch.compile not available")
-@pytest.mark.parametrize("model_name", _model_names)
+@pytest.mark.parametrize("model_name", list(_model_factory.keys()))
 def test_torch_dynamo(model_name: str):
     """Test compatibility with Torch Dynamo

@@ -40,21 +46,9 @@ def test_torch_dynamo(model_name: str):
     )
 
     # Construct model and input tensors
-    model = None
-    inputs = []
-    if model_name == "Linear":
-        model = te.Linear(16, 16)
-        inputs = [make_tensor([16,16])]
-    elif model_name == "LayerNorm":
-        model = te.LayerNorm(16)
-        inputs = [make_tensor([16,16])]
-    elif model_name == "LayerNormLinear":
-        model = te.LayerNormLinear(16,16)
-        inputs = [make_tensor([16,16])]
-    elif model_name == "LayerNormMLP":
-        model = te.LayerNormMLP(16,16)
-        inputs = [make_tensor([16,16])]
-    assert model is not None, f"could not construct {model_name}"
+    model_builder, input_builder = _model_factory[model_name]
+    model = model_builder()
+    inputs = [make_tensor(input_builder)]
 
     # Optimize model with TorchDynamo
     torch.compile(model)
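The refactor above replaces the per-model if/elif chain with a lookup table, so covering a new module (here TransformerLayer) only takes one dictionary entry. A minimal stand-alone sketch of the same pattern, using plain torch.nn modules and a hypothetical make_tensor helper so it runs without Transformer Engine:

```python
import pytest
import torch

def make_tensor(shape):
    """Hypothetical stand-in for the test helper: random input of the given shape."""
    return torch.randn(shape)

# Each entry maps a test name to (model constructor, input shape).
_model_factory = {
    "Linear": [(lambda: torch.nn.Linear(16, 16)), [16, 16]],
    "LayerNorm": [(lambda: torch.nn.LayerNorm(16)), [16, 16]],
}

@pytest.mark.parametrize("model_name", list(_model_factory.keys()))
def test_torch_dynamo(model_name: str):
    model_builder, input_shape = _model_factory[model_name]
    model = model_builder()
    inputs = [make_tensor(input_shape)]
    # torch.compile wraps the module; calling the wrapper triggers Dynamo tracing.
    torch.compile(model)(*inputs)
```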
@@ -51,7 +51,7 @@ from transformer_engine.pytorch.distributed import (
     checkpoint,
 )
 from transformer_engine.pytorch.export import is_in_onnx_export_mode
-from transformer_engine.pytorch.jit import jit_fuser
+from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
 
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("1.0.6")
@@ -1714,6 +1714,7 @@ class FusedAttention(torch.nn.Module):
         if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
             os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
 
+    @no_torch_dynamo()
     def forward(
         self,
         query_layer: torch.Tensor,
@@ -2053,6 +2054,7 @@ class DotProductAttention(torch.nn.Module):
         self.cp_global_ranks = cp_global_ranks
         self.cp_stream = cp_stream
 
+    @no_torch_dynamo(recursive=False)
     def forward(
         self,
         query_layer: torch.Tensor,
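Note the asymmetry: FusedAttention.forward is hidden from Dynamo entirely, while DotProductAttention.forward uses recursive=False so that only its own frame is skipped and the functions it calls can still be traced. A toy sketch of the difference, assuming a torch 2.x build where torch._dynamo.disable accepts the recursive argument (which the jit.py change just below relies on):

```python
import torch
import torch._dynamo

def inner(x):
    return torch.relu(x) + 1

# Default (recursive=True): Dynamo skips this frame and everything it calls.
@torch._dynamo.disable
def fully_skipped(x):
    return inner(x)

# recursive=False: only this frame is skipped; inner() may still be compiled.
@torch._dynamo.disable(recursive=False)
def frame_skipped(x):
    return inner(x)

x = torch.ones(4)
print(fully_skipped(x), frame_skipped(x))
```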
@@ -14,10 +14,10 @@ if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
 
 # Decorator to disable Torch Dynamo
 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
-no_torch_dynamo = lambda func: func
+no_torch_dynamo = lambda recursive=True: lambda func: func
 if torch.__version__ >= "2":
     import torch._dynamo
-    no_torch_dynamo = torch._dynamo.disable
+    no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)
 
 def set_jit_fusion_options() -> None:
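no_torch_dynamo is now a decorator factory: it takes a recursive flag and returns the actual decorator, which is why every use site below changes from @no_torch_dynamo to @no_torch_dynamo(). An equivalent spelled-out version of the new one-liners, as a sketch:

```python
import torch

def no_torch_dynamo(recursive: bool = True):
    """Return a decorator hiding a function from TorchDynamo (no-op on torch < 2)."""
    if torch.__version__ >= "2":
        import torch._dynamo
        return lambda f: torch._dynamo.disable(f, recursive=recursive)
    return lambda f: f

@no_torch_dynamo()  # note the parentheses: the factory returns the decorator
def forward(x):
    return x * 2
```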
@@ -156,7 +156,7 @@ class LayerNorm(torch.nn.Module):
             init.zeros_(self.weight)
             init.zeros_(self.bias)
 
-    @no_torch_dynamo
+    @no_torch_dynamo()
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """LayerNorm FWD"""
         # Set the activation type for AMP.
@@ -913,7 +913,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         return fp8_weight_tensors
 
-    @no_torch_dynamo
+    @no_torch_dynamo()
     def forward(
         self,
         inp: torch.Tensor,
@@ -1313,7 +1313,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         return fp8_weight_tensors
 
-    @no_torch_dynamo
+    @no_torch_dynamo()
     def forward(
         self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
@@ -791,7 +791,7 @@ class Linear(TransformerEngineBaseModule):
         return fp8_weight_tensors
 
-    @no_torch_dynamo
+    @no_torch_dynamo()
     def forward(
         self,
         inp: torch.Tensor,
@@ -158,7 +158,7 @@ class RMSNorm(torch.nn.Module):
         init.zeros_(self.weight)
 
-    @no_torch_dynamo
+    @no_torch_dynamo()
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """RMSNorm FWD"""