Unverified Commit 0832cd2c authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Use torch.compile for version 2.0 and higher (#255)



* Use torch.compile for version 2.0 and higher
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove unused import
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* use torch.__version__
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Use NVFuser for dropout fusions
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix onnx tests
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6280dc7a
...@@ -9,4 +9,4 @@ set -e ...@@ -9,4 +9,4 @@ set -e
pip install pytest==6.2.5 onnxruntime==1.13.1 pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
...@@ -14,6 +14,7 @@ from transformer_engine.common.recipe import DelayedScaling, Format ...@@ -14,6 +14,7 @@ from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type from .constants import dist_group_type
from .utils import get_device_compute_capability from .utils import get_device_compute_capability
from .jit import jit_fuser
_FP8_ENABLED = False _FP8_ENABLED = False
_FP8_CALIBRATION = False _FP8_CALIBRATION = False
...@@ -368,7 +369,7 @@ def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: ...@@ -368,7 +369,7 @@ def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
return amax_history return amax_history
@torch.jit.script @jit_fuser
def _default_get_amax( def _default_get_amax(
amax_history: torch.Tensor, amax_history: torch.Tensor,
amax_compute_algo: str, amax_compute_algo: str,
...@@ -383,7 +384,7 @@ def _default_get_amax( ...@@ -383,7 +384,7 @@ def _default_get_amax(
return amax_history, amax return amax_history, amax
@torch.jit.script @jit_fuser
def _default_sf_compute( def _default_sf_compute(
amax: torch.Tensor, amax: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
...@@ -400,7 +401,7 @@ def _default_sf_compute( ...@@ -400,7 +401,7 @@ def _default_sf_compute(
return sf return sf
@torch.jit.script @jit_fuser
def _compute_scaling_factor_inverse( def _compute_scaling_factor_inverse(
scale: torch.Tensor, scale: torch.Tensor,
scale_inv: torch.Tensor, scale_inv: torch.Tensor,
...@@ -413,7 +414,7 @@ def _compute_scaling_factor_inverse( ...@@ -413,7 +414,7 @@ def _compute_scaling_factor_inverse(
return torch.where(non_weight_mask, 1.0 / scale, scale_inv) return torch.where(non_weight_mask, 1.0 / scale, scale_inv)
@torch.jit.script @jit_fuser
def fused_amax_and_scale_update( def fused_amax_and_scale_update(
amax_history: torch.Tensor, amax_history: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
......
...@@ -3,9 +3,15 @@ ...@@ -3,9 +3,15 @@
# See LICENSE for license information. # See LICENSE for license information.
"""NVFuser functions and JIT utilities""" """NVFuser functions and JIT utilities"""
import os
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
import torch import torch
# Select the JIT fuser used to compile the small fused element-wise kernels
# in this module. Default is TorchScript; on PyTorch 2.0+ prefer
# torch.compile, unless disabled via NVTE_TORCH_COMPILE=0 (the ONNX-export
# tests rely on this opt-out).
jit_fuser = torch.jit.script
# NOTE: compare the *numeric* major version. The original lexicographic
# string comparison (torch.__version__ >= "2") compares the whole version
# string character-by-character and would misorder a major version >= 10.
if int(torch.__version__.split(".")[0]) >= 2 and bool(
    int(os.getenv("NVTE_TORCH_COMPILE", "1"))
):
    jit_fuser = torch.compile
def set_jit_fusion_options() -> None: def set_jit_fusion_options() -> None:
"""Set PyTorch JIT layer fusion options.""" """Set PyTorch JIT layer fusion options."""
...@@ -29,14 +35,14 @@ def set_jit_fusion_options() -> None: ...@@ -29,14 +35,14 @@ def set_jit_fusion_options() -> None:
torch._C._jit_override_can_fuse_on_gpu(True) torch._C._jit_override_can_fuse_on_gpu(True)
@jit_fuser
def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Fused bias-add followed by the tanh approximation of GeLU."""
    y = inp + bias
    # tanh-based GeLU approximation: 0.5 * y * (1 + tanh(sqrt(2/pi) * (y + 0.044715 * y^3)))
    tanh_arg = 0.79788456 * y * (1 + 0.044715 * y * y)
    return 0.5 * y * (1.0 + torch.tanh(tanh_arg))
@torch.jit.script @jit_fuser
def gelu_fused_(inp: torch.Tensor) -> torch.Tensor: def gelu_fused_(inp: torch.Tensor) -> torch.Tensor:
""" """
GeLU fused, this is copy of bias_gelu_fused cause jit fusion doesn't allow conditioning. GeLU fused, this is copy of bias_gelu_fused cause jit fusion doesn't allow conditioning.
...@@ -48,7 +54,7 @@ def gelu_fused_(inp: torch.Tensor) -> torch.Tensor: ...@@ -48,7 +54,7 @@ def gelu_fused_(inp: torch.Tensor) -> torch.Tensor:
# gradient of tanh approximation of gelu # gradient of tanh approximation of gelu
# gradient of actual gelu is: # gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script @jit_fuser
def bgrad_dgelu_fused_( def bgrad_dgelu_fused_(
grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -64,7 +70,7 @@ def bgrad_dgelu_fused_( ...@@ -64,7 +70,7 @@ def bgrad_dgelu_fused_(
return bgrad, dgelu return bgrad, dgelu
@torch.jit.script @jit_fuser
def dgelu_fused_( def dgelu_fused_(
grad_output: torch.Tensor, inp: torch.Tensor grad_output: torch.Tensor, inp: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment