Unverified commit 303c6d16, authored by Kirthi Shankar Sivamani, committed by GitHub

Enforce PyTorch version 2.1 and run attention tests with torch.compile (#1516)



* Enforce torch 2.0 and run attn tests with torch.compile
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* replace torch.compile with jit_fuser
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9588109d
@@ -23,6 +23,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
-NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
+NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1

 exit $FAIL
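Note: dropping the NVTE_TORCH_COMPILE=0 override means this suite now exercises the torch.compile path by default. The old behavior can still be reproduced locally by prefixing the run with the flag, e.g. `NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py` (the flag is read in transformer_engine/pytorch/jit.py, shown further down).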
@@ -103,7 +103,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
-            install_reqs.extend(["torch"])
+            install_reqs.extend(["torch>=2.1"])
             # Blackwell is not supported as of Triton 3.2.0, need custom internal build
             # install_reqs.append("triton")
             test_reqs.extend(["numpy", "torchvision", "prettytable"])
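With the version floor in place, a from-source install that selects the PyTorch framework lets pip enforce the minimum; assuming the usual extras spelling, `pip install .[pytorch]` from the repository root now resolves torch>=2.1 instead of accepting any torch.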
@@ -6,8 +6,8 @@ import os
 import pytest
 import subprocess
 from pathlib import Path

+from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.utils import torch_version
 import torch
@@ -7,16 +7,25 @@
 # pylint: disable=wrong-import-position,wrong-import-order
 import logging
-import sys
+import functools
 import importlib
 import importlib.util
-import torch
+import sys
 from importlib.metadata import version
+from packaging.version import Version as PkgVersion
+import torch

 from transformer_engine.common import get_te_path, is_package_installed
 from transformer_engine.common import _get_sys_extension


+@functools.lru_cache(maxsize=None)
+def torch_version() -> tuple[int, ...]:
+    """Get PyTorch version"""
+    return PkgVersion(str(torch.__version__)).release


 def _load_library():
     """Load shared library with Transformer Engine C extensions"""
     module_name = "transformer_engine_torch"
@@ -60,6 +69,9 @@ def _load_library():
     spec.loader.exec_module(solib)


+assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
+
+
 _load_library()
 from transformer_engine.pytorch.module import LayerNormLinear
 from transformer_engine.pytorch.module import Linear
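For reference, a short sketch of why the new torch_version() helper compares cleanly as a tuple (the version string below is illustrative, not from this diff): PkgVersion(...).release keeps only the numeric release segment, discarding pre-release and local-build suffixes common in container builds.

from packaging.version import Version as PkgVersion

PkgVersion("2.1.0a0+fe05266").release   # -> (2, 1, 0)
(2, 1, 0) >= (2, 1)                     # -> True, so the assert above passes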
@@ -1385,7 +1385,7 @@ def _get_full_cu_seqlens(
     return _cu_seqlens_cache[(batch_size, max_seqlen)]


-@torch.compile
+@jit_fuser
 def pack_tensor(
     indices: torch.Tensor,
     tensor: torch.Tensor,
@@ -1409,7 +1409,7 @@ def pack_tensor(
     return packed


-@torch.compile
+@jit_fuser
 def pack_2_tensors(
     indices: torch.Tensor,
     t1: torch.Tensor,
@@ -1423,7 +1423,7 @@ def pack_2_tensors(
     return t1_packed, t2_packed


-@torch.compile
+@jit_fuser
 def pack_3_tensors(
     indices: torch.Tensor,
     t1: torch.Tensor,
@@ -1439,7 +1439,7 @@ def pack_3_tensors(
     return t1_packed, t2_packed, t3_packed


-@torch.compile
+@jit_fuser
 def unpack_tensor(
     indices: torch.Tensor,
     dim0: int,
@@ -1462,7 +1462,7 @@ def unpack_tensor(
     return unpacked


-@torch.compile
+@jit_fuser
 def unpack_2_tensors(
     indices: torch.Tensor,
     dim0: int,
@@ -1477,7 +1477,7 @@ def unpack_2_tensors(
     return t1_unpacked, t2_unpacked


-@torch.compile
+@jit_fuser
 def unpack_3_tensors(
     indices: torch.Tensor,
     dim0: int,
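All of these helpers were previously hard-wired to torch.compile; they now inherit whichever backend jit_fuser selects (see the jit.py hunk below), so NVTE_TORCH_COMPILE=0 disables compilation for them too. As a rough illustration of the pack/unpack semantics, a simplified sketch with assumed behavior, not the TE implementation:

import torch

def pack_tensor_sketch(indices: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor:
    # Keep only the rows selected by `indices` (e.g. valid tokens in THD layout).
    return tensor[indices]

def unpack_tensor_sketch(indices: torch.Tensor, dim0: int, packed: torch.Tensor) -> torch.Tensor:
    # Scatter the packed rows back into a zero-filled tensor with leading size `dim0`.
    out = torch.zeros(dim0, *packed.shape[1:], dtype=packed.dtype, device=packed.device)
    out[indices] = packed
    return out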
@@ -1645,7 +1645,7 @@ def get_cu_seqlens_on_cp_rank(
     return cu_seqlens_on_cp_rank


-@torch.compile
+@jit_fuser
 def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
     """
     Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
@@ -1665,7 +1665,7 @@ def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
     return chunk_ids


-@torch.compile
+@jit_fuser
 def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
     """Reorder sequence chunk for A2A communication."""
     if before_attn:
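Per the docstring, each context-parallel rank owns two discontiguous chunks. A hypothetical numeric illustration of the commonly described pairing (the exact chunk-id layout lives in the function body, which the diff truncates):

# With cp_size ranks the sequence is cut into 2 * cp_size chunks; pairing an
# early (cheap) chunk with a late (expensive) chunk balances causal-attention cost.
cp_size = 2
assignment = {r: (r, 2 * cp_size - 1 - r) for r in range(cp_size)}
# assignment == {0: (0, 3), 1: (1, 2)}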
@@ -10,28 +10,20 @@ import torch

 # pylint: disable=unnecessary-lambda-assignment
-jit_fuser = torch.jit.script
+jit_fuser = lambda func: func
 if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     jit_fuser = torch.compile

 # See: https://github.com/NVIDIA/TransformerEngine/issues/597
 dropout_fuser = torch.jit.script
 if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     dropout_fuser = torch.compile

 # Decorator to disable Torch Dynamo
 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
-no_torch_dynamo = lambda recursive=True: lambda func: func
-if torch.__version__ >= "2":
-    import torch._dynamo
-
-    if torch.__version__ >= "2.1":
-        no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
-            f, recursive=recursive
-        )
-    else:
-        # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
-        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
+no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)


 def set_jit_fusion_options() -> None:
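A standalone sketch of the resulting default (simplified: the torch >= 2.1 floor is already guaranteed by the new assert in __init__.py, so the environment variable is effectively the only remaining gate):

import os
import torch

jit_fuser = lambda func: func  # no-op fallback when compilation is opted out
if bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    jit_fuser = torch.compile

@jit_fuser
def scale_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Simple elementwise chain, a candidate for kernel fusion when compiled.
    return x * 2.0 + y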
@@ -10,13 +10,13 @@ from typing import Any, Iterable, Optional
 import torch

 from transformer_engine_torch import FP8TensorMeta

+from .. import torch_version
 from ..fp8 import FP8GlobalStateManager
 from ..tensor.float8_tensor import Float8Tensor
 from ..utils import (
     canonicalize_device,
     canonicalize_dtype,
     devices_match,
-    torch_version,
 )
@@ -8,7 +8,6 @@ import functools
 import math
 import os
 from typing import Any, Callable, List, Optional, Tuple
-from packaging.version import Version as PkgVersion

 import torch
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -387,9 +386,3 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:

     # Pop NVTX range
     torch.cuda.nvtx.range_pop()
-
-
-@functools.lru_cache(maxsize=None)
-def torch_version() -> tuple[int, ...]:
-    """Get PyTorch version"""
-    return PkgVersion(str(torch.__version__)).release