Unverified Commit b612cdeb authored by Kirthi Shankar Sivamani, committed by GitHub

Fix TE ops API compatibility with PyTorch versions < 2.4.3 (#1494)



* Fix TE Sequential for older PyTorch versions
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fceff07a
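The gist of the change: the TE ops code queried autocast state through the device-type-aware API (`torch.is_autocast_enabled(device_type)` / `torch.get_autocast_dtype(device_type)`), which this commit only relies on from PyTorch 2.4.3 onward; older releases are served by the argument-less `torch.is_autocast_enabled()` and the GPU-specific `torch.get_autocast_gpu_dtype()`. A minimal standalone sketch of the version gate, assuming only `torch` and `packaging` are installed (the plain default-dtype fallback stands in for TE's `canonicalize_dtype`):

import functools
from typing import Optional

import torch
from packaging.version import Version as PkgVersion


@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
    """PyTorch version as a comparable tuple, e.g. (2, 4, 1)."""
    return PkgVersion(str(torch.__version__)).release


def maybe_autocast_dtype(
    device_type: str = "cuda",
    default_dtype: Optional[torch.dtype] = None,
) -> torch.dtype:
    """Return the active autocast dtype, or a default when autocast is off."""
    if torch_version() >= (2, 4, 3):
        # Newer API: autocast state is queried per device type.
        if torch.is_autocast_enabled(device_type):
            return torch.get_autocast_dtype(device_type)
    else:
        # Older API: no device_type argument; CUDA-specific dtype getter.
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()
    # Simplified fallback; the real helper calls canonicalize_dtype here.
    return default_dtype if default_dtype is not None else torch.get_default_dtype()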
@@ -7,19 +7,9 @@ import pytest
 import subprocess
 from pathlib import Path
 
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-import torch
-from packaging.version import Version as PkgVersion
-
-
-def get_torch_version():
-    """Get PyTorch version from __version__"""
-
-    def get_torch_version_str():
-        import torch
-
-        return str(torch.__version__)
-
-    return PkgVersion(get_torch_version_str())
+from transformer_engine.pytorch.utils import torch_version
 
+import torch
 
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -44,7 +34,7 @@ def _run_test(fp_init, sharding_dims):
 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
 @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
-@pytest.mark.skipif(not get_torch_version() >= PkgVersion("2.4"), reason="Requires PyTorch 2.4.0+")
+@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
 @pytest.mark.parametrize("fp8_init", (False, True))
 def test_distributed(fp8_init, sharding_dims):
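One behavioral note on the new skipif: `torch_version()` compares plain tuples, which treats pre-release builds slightly differently than the old `PkgVersion` comparison did. A small illustration, assuming `packaging` is installed:

from packaging.version import Version as PkgVersion

# .release keeps only the numeric release segment, so dev/local builds
# still yield clean tuples.
assert PkgVersion("2.4.1+cu121").release == (2, 4, 1)
assert PkgVersion("2.5.0a0+git1234ab").release == (2, 5, 0)

# Difference: a 2.4.0 pre-release satisfies the tuple check but not the
# old PkgVersion check.
assert (2, 4, 0) >= (2, 4, 0)
assert not (PkgVersion("2.4.0a0") >= PkgVersion("2.4"))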
@@ -16,6 +16,7 @@ from ..utils import (
     canonicalize_device,
     canonicalize_dtype,
     devices_match,
+    torch_version,
 )
@@ -98,8 +99,13 @@ def maybe_autocast_dtype(
     default_dtype: Optional[torch.dtype] = None,
 ) -> torch.dtype:
     """Get autocast dtype if enabled"""
-    if torch.is_autocast_enabled(device_type):
-        return torch.get_autocast_dtype(device_type)
+    if torch_version() >= (2, 4, 3):
+        if torch.is_autocast_enabled(device_type):
+            return torch.get_autocast_dtype(device_type)
+    else:
+        if torch.is_autocast_enabled():
+            return torch.get_autocast_gpu_dtype()
     return canonicalize_dtype(default_dtype)
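For context, a usage sketch of how the gated helper is expected to behave on a CUDA build of PyTorch (this exercises the standalone `maybe_autocast_dtype` sketch from the top of this page, not the TE-internal one):

import torch

# Inside an autocast region the helper resolves to the autocast dtype on
# both old and new PyTorch versions.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    assert maybe_autocast_dtype("cuda") == torch.bfloat16

# Outside autocast it falls through to the default.
assert maybe_autocast_dtype("cuda", default_dtype=torch.float32) == torch.float32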
@@ -8,6 +8,7 @@ import functools
 import math
 import os
 from typing import Any, Callable, List, Optional, Tuple
+from packaging.version import Version as PkgVersion
 
 import torch
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -386,3 +387,9 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:
     # Pop NVTX range
     torch.cuda.nvtx.range_pop()
+
+
+@functools.lru_cache(maxsize=None)
+def torch_version() -> tuple[int, ...]:
+    """Get PyTorch version"""
+    return PkgVersion(str(torch.__version__)).release
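The new helper is cached with `functools.lru_cache`, so the version string is parsed once per process and subsequent gates are plain tuple comparisons. A quick interactive check (output shown for a hypothetical torch 2.3.1 install):

>>> from transformer_engine.pytorch.utils import torch_version
>>> torch_version()
(2, 3, 1)
>>> torch_version() >= (2, 4, 3)
False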