Unverified Commit 051db0d7 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Switch to torch.compile for dropout for torch v>2.2 (#607)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a1e8f19d
...@@ -12,6 +12,11 @@ jit_fuser = torch.jit.script ...@@ -12,6 +12,11 @@ jit_fuser = torch.jit.script
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile jit_fuser = torch.compile
# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
dropout_fuser = torch.compile
# Decorator to disable Torch Dynamo # Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func no_torch_dynamo = lambda recursive=True: lambda func: func
...@@ -134,7 +139,7 @@ def get_bias_dropout_add(training: bool) -> Callable: ...@@ -134,7 +139,7 @@ def get_bias_dropout_add(training: bool) -> Callable:
return _bias_dropout_add return _bias_dropout_add
@jit_fuser @dropout_fuser
def bias_dropout_add_fused_train_( def bias_dropout_add_fused_train_(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -151,7 +156,7 @@ def bias_dropout_add_fused_train( ...@@ -151,7 +156,7 @@ def bias_dropout_add_fused_train(
return bias_dropout_add_fused_train_(x, bias, residual, prob) return bias_dropout_add_fused_train_(x, bias, residual, prob)
@jit_fuser @dropout_fuser
def bias_dropout_add_fused_inference_( def bias_dropout_add_fused_inference_(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor: ) -> torch.Tensor:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment