Unverified commit 303c6d16, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Enforce PyTorch version 2.1 and run attention tests with torch.compile (#1516)



* Enforce torch 2.0 and run attn tests with torch.compile
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* replace torch.compile with jit_fuser
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9588109d
...@@ -23,6 +23,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1 ...@@ -23,6 +23,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
exit $FAIL exit $FAIL
...@@ -103,7 +103,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -103,7 +103,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements # Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks: if "pytorch" in frameworks:
install_reqs.extend(["torch"]) install_reqs.extend(["torch>=2.1"])
# Blackwell is not supported as of Triton 3.2.0, need custom internal build # Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton") # install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "prettytable"]) test_reqs.extend(["numpy", "torchvision", "prettytable"])
......
...@@ -6,8 +6,8 @@ import os ...@@ -6,8 +6,8 @@ import os
import pytest import pytest
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import torch_version
import torch import torch
......
...@@ -7,16 +7,25 @@ ...@@ -7,16 +7,25 @@
# pylint: disable=wrong-import-position,wrong-import-order # pylint: disable=wrong-import-position,wrong-import-order
import logging import logging
import functools
import sys
import importlib import importlib
import importlib.util import importlib.util
import sys
import torch
from importlib.metadata import version from importlib.metadata import version
from packaging.version import Version as PkgVersion
import torch
from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension from transformer_engine.common import _get_sys_extension
@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
    """Return the release segment of the installed PyTorch version.

    ``torch.__version__`` is converted to a string and parsed with
    ``packaging.version.Version``; only the ``release`` tuple of that parsed
    version is returned. The result is memoized, so the parse happens once
    per process.
    """
    parsed = PkgVersion(str(torch.__version__))
    return parsed.release
def _load_library(): def _load_library():
"""Load shared library with Transformer Engine C extensions""" """Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_torch" module_name = "transformer_engine_torch"
...@@ -60,6 +69,9 @@ def _load_library(): ...@@ -60,6 +69,9 @@ def _load_library():
spec.loader.exec_module(solib) spec.loader.exec_module(solib)
assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
_load_library() _load_library()
from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import Linear
......
...@@ -1385,7 +1385,7 @@ def _get_full_cu_seqlens( ...@@ -1385,7 +1385,7 @@ def _get_full_cu_seqlens(
return _cu_seqlens_cache[(batch_size, max_seqlen)] return _cu_seqlens_cache[(batch_size, max_seqlen)]
@torch.compile @jit_fuser
def pack_tensor( def pack_tensor(
indices: torch.Tensor, indices: torch.Tensor,
tensor: torch.Tensor, tensor: torch.Tensor,
...@@ -1409,7 +1409,7 @@ def pack_tensor( ...@@ -1409,7 +1409,7 @@ def pack_tensor(
return packed return packed
@torch.compile @jit_fuser
def pack_2_tensors( def pack_2_tensors(
indices: torch.Tensor, indices: torch.Tensor,
t1: torch.Tensor, t1: torch.Tensor,
...@@ -1423,7 +1423,7 @@ def pack_2_tensors( ...@@ -1423,7 +1423,7 @@ def pack_2_tensors(
return t1_packed, t2_packed return t1_packed, t2_packed
@torch.compile @jit_fuser
def pack_3_tensors( def pack_3_tensors(
indices: torch.Tensor, indices: torch.Tensor,
t1: torch.Tensor, t1: torch.Tensor,
...@@ -1439,7 +1439,7 @@ def pack_3_tensors( ...@@ -1439,7 +1439,7 @@ def pack_3_tensors(
return t1_packed, t2_packed, t3_packed return t1_packed, t2_packed, t3_packed
@torch.compile @jit_fuser
def unpack_tensor( def unpack_tensor(
indices: torch.Tensor, indices: torch.Tensor,
dim0: int, dim0: int,
...@@ -1462,7 +1462,7 @@ def unpack_tensor( ...@@ -1462,7 +1462,7 @@ def unpack_tensor(
return unpacked return unpacked
@torch.compile @jit_fuser
def unpack_2_tensors( def unpack_2_tensors(
indices: torch.Tensor, indices: torch.Tensor,
dim0: int, dim0: int,
...@@ -1477,7 +1477,7 @@ def unpack_2_tensors( ...@@ -1477,7 +1477,7 @@ def unpack_2_tensors(
return t1_unpacked, t2_unpacked return t1_unpacked, t2_unpacked
@torch.compile @jit_fuser
def unpack_3_tensors( def unpack_3_tensors(
indices: torch.Tensor, indices: torch.Tensor,
dim0: int, dim0: int,
...@@ -1645,7 +1645,7 @@ def get_cu_seqlens_on_cp_rank( ...@@ -1645,7 +1645,7 @@ def get_cu_seqlens_on_cp_rank(
return cu_seqlens_on_cp_rank return cu_seqlens_on_cp_rank
@torch.compile @jit_fuser
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
""" """
Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
...@@ -1665,7 +1665,7 @@ def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): ...@@ -1665,7 +1665,7 @@ def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
return chunk_ids return chunk_ids
@torch.compile @jit_fuser
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
"""Reorder sequence chunk for A2A communication.""" """Reorder sequence chunk for A2A communication."""
if before_attn: if before_attn:
......
...@@ -10,28 +10,20 @@ import torch ...@@ -10,28 +10,20 @@ import torch
# pylint: disable=unnecessary-lambda-assignment # pylint: disable=unnecessary-lambda-assignment
jit_fuser = torch.jit.script jit_fuser = lambda func: func
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile jit_fuser = torch.compile
# See: https://github.com/NVIDIA/TransformerEngine/issues/597 # See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script dropout_fuser = torch.jit.script
if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
dropout_fuser = torch.compile dropout_fuser = torch.compile
# Decorator to disable Torch Dynamo # Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)
if torch.__version__ >= "2":
import torch._dynamo
if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
f, recursive=recursive
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
def set_jit_fusion_options() -> None: def set_jit_fusion_options() -> None:
......
...@@ -10,13 +10,13 @@ from typing import Any, Iterable, Optional ...@@ -10,13 +10,13 @@ from typing import Any, Iterable, Optional
import torch import torch
from transformer_engine_torch import FP8TensorMeta from transformer_engine_torch import FP8TensorMeta
from .. import torch_version
from ..fp8 import FP8GlobalStateManager from ..fp8 import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor from ..tensor.float8_tensor import Float8Tensor
from ..utils import ( from ..utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
devices_match, devices_match,
torch_version,
) )
......
...@@ -8,7 +8,6 @@ import functools ...@@ -8,7 +8,6 @@ import functools
import math import math
import os import os
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, List, Optional, Tuple
from packaging.version import Version as PkgVersion
import torch import torch
import transformer_engine.pytorch.cpp_extensions as ext import transformer_engine.pytorch.cpp_extensions as ext
...@@ -387,9 +386,3 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None: ...@@ -387,9 +386,3 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:
# Pop NVTX range # Pop NVTX range
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
    """Return the release segment of the installed PyTorch version.

    ``torch.__version__`` is converted to a string and parsed with
    ``packaging.version.Version``; only the ``release`` tuple of that parsed
    version is returned. The result is memoized, so the parse happens once
    per process.
    """
    parsed = PkgVersion(str(torch.__version__))
    return parsed.release
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment