Unverified Commit 0832cd2c authored by Kirthi Shankar Sivamani, committed by GitHub

Use torch.compile for version 2.0 and higher (#255)



* Use torch.compile for version 2.0 and higher
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove unused import
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Use torch.__version__
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Use NVFuser for dropout fusions
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6280dc7a
@@ -9,4 +9,4 @@ set -e
 pip install pytest==6.2.5 onnxruntime==1.13.1
 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
 PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
-pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
+NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
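
Note: the ONNX export tests now run with NVTE_TORCH_COMPILE=0 so that jit_fuser falls back to torch.jit.script, presumably because the TorchScript-based exporter does not interoperate with torch.compile-wrapped functions. Since jit.py reads the variable once at module import (see the jit.py hunk below), it has to be set before transformer_engine.pytorch is first imported. A minimal sketch, assuming the usual te import alias:

    import os

    # NVTE_TORCH_COMPILE is read once, at import time, in jit.py below,
    # so set it before transformer_engine.pytorch is first imported.
    os.environ["NVTE_TORCH_COMPILE"] = "0"

    import transformer_engine.pytorch as te  # jit_fuser is now torch.jit.script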
@@ -14,6 +14,7 @@ from transformer_engine.common.recipe import DelayedScaling, Format
 from .constants import dist_group_type
 from .utils import get_device_compute_capability
+from .jit import jit_fuser

 _FP8_ENABLED = False
 _FP8_CALIBRATION = False
@@ -368,7 +369,7 @@ def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
     return amax_history


-@torch.jit.script
+@jit_fuser
 def _default_get_amax(
     amax_history: torch.Tensor,
     amax_compute_algo: str,
@@ -383,7 +384,7 @@ def _default_get_amax(
     return amax_history, amax


-@torch.jit.script
+@jit_fuser
 def _default_sf_compute(
     amax: torch.Tensor,
     scale: torch.Tensor,
@@ -400,7 +401,7 @@ def _default_sf_compute(
     return sf


-@torch.jit.script
+@jit_fuser
 def _compute_scaling_factor_inverse(
     scale: torch.Tensor,
     scale_inv: torch.Tensor,
@@ -413,7 +414,7 @@ def _compute_scaling_factor_inverse(
     return torch.where(non_weight_mask, 1.0 / scale, scale_inv)


-@torch.jit.script
+@jit_fuser
 def fused_amax_and_scale_update(
     amax_history: torch.Tensor,
     scale: torch.Tensor,
......
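
For context, the helpers switched to jit_fuser above implement FP8 delayed scaling. A minimal sketch of the math they perform, paraphrased from their signatures; fp8_max, margin, and the "max"/"most_recent" algorithm names are assumptions from the surrounding code, not part of these hunks:

    import torch

    def delayed_scaling_sketch(
        amax_history: torch.Tensor,   # [history_len, num_scales]
        scale: torch.Tensor,
        amax_compute_algo: str,       # assumed: "max" or "most_recent"
        fp8_max: float,
        margin: int,
    ) -> torch.Tensor:
        # Reduce the rolling amax history to one amax per scaling factor.
        if amax_compute_algo == "max":
            amax = torch.max(amax_history, dim=0).values
        else:
            amax = amax_history[0]
        # Pick a scale that maps the largest observed magnitude near
        # fp8_max, backed off by 2**margin; keep the previous scale
        # where amax is zero or non-finite.
        sf = (fp8_max / amax) / (2 ** margin)
        sf = torch.where(amax > 0.0, sf, scale)
        sf = torch.where(torch.isfinite(amax), sf, scale)
        return sf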
@@ -3,9 +3,15 @@
 # See LICENSE for license information.
 """NVFuser functions and JIT utilities"""

+import os
 from typing import Callable, Optional, Tuple

 import torch
+
+jit_fuser = torch.jit.script
+if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+    jit_fuser = torch.compile
+

 def set_jit_fusion_options() -> None:
     """Set PyTorch JIT layer fusion options."""
@@ -29,14 +35,14 @@ def set_jit_fusion_options() -> None:
     torch._C._jit_override_can_fuse_on_gpu(True)


-@torch.jit.script
+@jit_fuser
 def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
     """Bias-GeLU fused"""
     x = inp + bias
     return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


-@torch.jit.script
+@jit_fuser
 def gelu_fused_(inp: torch.Tensor) -> torch.Tensor:
     """
     GeLU fused, this is copy of bias_gelu_fused cause jit fusion doesn't allow conditioning.
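
The constant 0.79788456 is sqrt(2/pi), so bias_gelu_fused_ computes the standard tanh approximation of GeLU. As a sanity check, it should agree closely with PyTorch's built-in tanh-approximate GeLU (available since PyTorch 1.12):

    import torch

    x = torch.randn(4, 8)
    approx = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
    builtin = torch.nn.functional.gelu(x, approximate="tanh")
    assert torch.allclose(approx, builtin, atol=1e-6)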
@@ -48,7 +54,7 @@ def gelu_fused_(inp: torch.Tensor) -> torch.Tensor:
 # gradient of tanh approximation of gelu
 # gradient of actual gelu is:
 # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
-@torch.jit.script
+@jit_fuser
 def bgrad_dgelu_fused_(
     grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
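
For reference, the gradient these fused backward helpers compute follows from differentiating the tanh approximation above: with u = 0.79788456 * x * (1 + 0.044715 * x * x), the derivative is g'(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * 0.79788456 * (1 + 3 * 0.044715 * x * x). A sketch of the dgelu part (bgrad is then a reduction of dgelu over the leading dimension; the helper name is illustrative, not from this diff):

    import torch

    def dgelu_sketch(grad_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Derivative of the tanh GeLU approximation used above;
        # note 0.1070322243 == 3 * 0.044715 * 0.79788456.
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        ff = 0.5 * x * (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) \
            + 0.5 * (1 + tanh_out)
        return ff * grad_output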
@@ -64,7 +70,7 @@ def bgrad_dgelu_fused_(
     return bgrad, dgelu


-@torch.jit.script
+@jit_fuser
 def dgelu_fused_(
     grad_output: torch.Tensor, inp: torch.Tensor
 ) -> torch.Tensor:
......
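
With this change, jit_fuser is a drop-in decorator either way: torch.compile on PyTorch 2.0 and later (unless NVTE_TORCH_COMPILE=0), torch.jit.script otherwise. A usage sketch; the import path is assumed from this file's location, and scale_add is an illustrative function, not part of the diff. The explicit type annotations are kept because the fallback path is TorchScript:

    import torch
    from transformer_engine.pytorch.jit import jit_fuser  # assumed import path

    @jit_fuser
    def scale_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Elementwise chain that either fuser can fold into one kernel.
        return 2.0 * x + y

    out = scale_add(torch.randn(16, device="cuda"), torch.randn(16, device="cuda"))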