Unverified commit b612cdeb authored by Kirthi Shankar Sivamani, committed by GitHub

Fix TE ops API compatibility with PyTorch versions < 2.4.3 (#1494)



* Fix TE Sequential for older PyTorch versions
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fceff07a
@@ -7,19 +7,9 @@ import pytest
 import subprocess
 from pathlib import Path
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-import torch
-from packaging.version import Version as PkgVersion
-
-
-def get_torch_version():
-    """Get PyTorch version from __version__"""
-
-    def get_torch_version_str():
-        import torch
-
-        return str(torch.__version__)
-
-    return PkgVersion(get_torch_version_str())
+from transformer_engine.pytorch.utils import torch_version
+import torch
 
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -44,7 +34,7 @@ def _run_test(fp_init, sharding_dims):
 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
 @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
-@pytest.mark.skipif(not get_torch_version() >= PkgVersion("2.4"), reason="Requires PyTorch 2.4.0+")
+@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
 @pytest.mark.parametrize("fp8_init", (False, True))
 def test_distributed(fp8_init, sharding_dims):
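For context on the decorator change above: `torch_version()` returns the release as a plain tuple of ints, so the version gate becomes a tuple comparison instead of a `PkgVersion` comparison. A minimal sketch of the two styles, assuming `packaging` is installed (the version string is illustrative, not from a real run):

```python
from packaging.version import Version as PkgVersion

v = PkgVersion("2.3.1+cu121")      # e.g. an older PyTorch build
print(v >= PkgVersion("2.4"))      # False -- old-style check used by the test
print(v.release)                   # (2, 3, 1)
print(v.release >= (2, 4, 0))      # False -- new tuple-based check, same outcome
```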
@@ -16,6 +16,7 @@ from ..utils import (
     canonicalize_device,
     canonicalize_dtype,
     devices_match,
+    torch_version,
 )
@@ -98,8 +99,13 @@ def maybe_autocast_dtype(
     default_dtype: Optional[torch.dtype] = None,
 ) -> torch.dtype:
     """Get autocast dtype if enabled"""
-    if torch.is_autocast_enabled(device_type):
-        return torch.get_autocast_dtype(device_type)
+    if torch_version() >= (2, 4, 3):
+        if torch.is_autocast_enabled(device_type):
+            return torch.get_autocast_dtype(device_type)
+    else:
+        if torch.is_autocast_enabled():
+            return torch.get_autocast_gpu_dtype()
     return canonicalize_dtype(default_dtype)
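The branch above exists because the autocast query API changed: newer PyTorch exposes per-device-type calls (`torch.is_autocast_enabled(device_type)` and `torch.get_autocast_dtype(device_type)`), while older releases only provide the argument-free `torch.is_autocast_enabled()` and `torch.get_autocast_gpu_dtype()`. A self-contained sketch of the same pattern outside TE, with an illustrative helper name and the same 2.4.3 cutoff the commit uses:

```python
import torch
from packaging.version import Version as PkgVersion

_TORCH_RELEASE = PkgVersion(str(torch.__version__)).release  # e.g. (2, 4, 1)


def autocast_dtype_or(default: torch.dtype, device_type: str = "cuda") -> torch.dtype:
    """Return the active autocast dtype for `device_type`, else `default`."""
    if _TORCH_RELEASE >= (2, 4, 3):
        # Newer API: autocast state is queried per device type.
        if torch.is_autocast_enabled(device_type):
            return torch.get_autocast_dtype(device_type)
    else:
        # Legacy API: only the global, GPU-oriented autocast queries are available.
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()
    return default
```

Inside a `with torch.autocast("cuda", dtype=torch.bfloat16):` region this should return `torch.bfloat16`; outside any autocast region it returns the supplied default.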
@@ -8,6 +8,7 @@ import functools
 import math
 import os
 from typing import Any, Callable, List, Optional, Tuple
+from packaging.version import Version as PkgVersion
 
 import torch
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -386,3 +387,9 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:
     # Pop NVTX range
     torch.cuda.nvtx.range_pop()
+
+
+@functools.lru_cache(maxsize=None)
+def torch_version() -> tuple[int, ...]:
+    """Get PyTorch version"""
+    return PkgVersion(str(torch.__version__)).release
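A short usage sketch for the new helper; the `lru_cache` means the version string is parsed only once per process, and the printed value below is illustrative (it assumes a hypothetical PyTorch 2.3.1 install):

```python
from transformer_engine.pytorch.utils import torch_version

print(torch_version())            # e.g. (2, 3, 1) on PyTorch 2.3.1
if torch_version() >= (2, 4, 3):
    pass  # safe to use the per-device-type autocast API
else:
    pass  # fall back to the legacy GPU autocast calls
```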