Unverified Commit 7e7f0920 authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Disable dynamo for Fused Attention (#558)



* Disable dynamo for Fused Attention
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added test
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent efd4b62a
......@@ -10,11 +10,17 @@ import torch
import transformer_engine.pytorch as te
# Model names for test_torch_dynamo
_model_names = ["Linear", "LayerNorm", "LayerNormLinear", "LayerNormMLP"]
# Maps a model name to [constructor thunk, input tensor shape] for
# test_torch_dynamo.  The lambda defers module construction until the
# test body runs; the shape list is passed to make_tensor to build the
# single input tensor (see the `inputs = [make_tensor(input_builder)]`
# call site).
_model_factory = {
"Linear": [(lambda: te.Linear(16, 16)), [16, 16]],
"LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]],
"LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]],
"LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]],
"TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]],
}
@pytest.mark.skipif(torch.__version__ < "2", reason="torch.compile not available")
@pytest.mark.parametrize("model_name", _model_names)
@pytest.mark.parametrize("model_name", list(_model_factory.keys()))
def test_torch_dynamo(model_name: str):
"""Test compatibility with Torch Dynamo
......@@ -40,21 +46,9 @@ def test_torch_dynamo(model_name: str):
)
# Construct model and input tensors
model = None
inputs = []
if model_name == "Linear":
model = te.Linear(16, 16)
inputs = [make_tensor([16,16])]
elif model_name == "LayerNorm":
model = te.LayerNorm(16)
inputs = [make_tensor([16,16])]
elif model_name == "LayerNormLinear":
model = te.LayerNormLinear(16,16)
inputs = [make_tensor([16,16])]
elif model_name == "LayerNormMLP":
model = te.LayerNormMLP(16,16)
inputs = [make_tensor([16,16])]
assert model is not None, f"could not construct {model_name}"
model_builder, input_builder = _model_factory[model_name]
model = model_builder()
inputs = [make_tensor(input_builder)]
# Optimize model with TorchDynamo
torch.compile(model)
......
......@@ -51,7 +51,7 @@ from transformer_engine.pytorch.distributed import (
checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("1.0.6")
......@@ -1714,6 +1714,7 @@ class FusedAttention(torch.nn.Module):
if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
@no_torch_dynamo()
def forward(
self,
query_layer: torch.Tensor,
......@@ -2053,6 +2054,7 @@ class DotProductAttention(torch.nn.Module):
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream
@no_torch_dynamo(recursive=False)
def forward(
self,
query_layer: torch.Tensor,
......
......@@ -14,10 +14,10 @@ if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda func: func
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
import torch._dynamo
no_torch_dynamo = torch._dynamo.disable
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)
def set_jit_fusion_options() -> None:
......
......@@ -156,7 +156,7 @@ class LayerNorm(torch.nn.Module):
init.zeros_(self.weight)
init.zeros_(self.bias)
@no_torch_dynamo
@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
# Set the activation type for AMP.
......
......@@ -913,7 +913,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
return fp8_weight_tensors
@no_torch_dynamo
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
......
......@@ -1313,7 +1313,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
return fp8_weight_tensors
@no_torch_dynamo
@no_torch_dynamo()
def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
......
......@@ -791,7 +791,7 @@ class Linear(TransformerEngineBaseModule):
return fp8_weight_tensors
@no_torch_dynamo
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
......
......@@ -158,7 +158,7 @@ class RMSNorm(torch.nn.Module):
init.zeros_(self.weight)
@no_torch_dynamo
@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""RMSNorm FWD"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment