# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
6
"""JAX related extensions."""
import os
7
from pathlib import Path
8
from packaging import version
9
10
11

import setuptools

12
from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled
13
14
from typing import List

15

16
17
def install_requirements() -> List[str]:
    """Install dependencies for TE/JAX extensions.

    Returns:
        Requirement specifiers for the JAX extension: ``jax`` (unpinned)
        and ``flax>=0.7.1``. A new list is built on every call, so the
        caller is free to mutate the result.
    """
    return ["jax", "flax>=0.7.1"]
19
20
21


def test_requirements() -> List[str]:
    """Test dependencies for TE/JAX extensions.

    Triton Package Selection:
        The triton package is selected based on NVTE_USE_PYTORCH_TRITON environment variable:

        Default (NVTE_USE_PYTORCH_TRITON unset, empty, or "0"):
            Returns 'triton' - OpenAI's standard package from PyPI.
            Install with: pip install triton

        NVTE_USE_PYTORCH_TRITON=1:
            Returns 'pytorch-triton' - for mixed JAX+PyTorch environments.
            Install with: pip install pytorch-triton --index-url https://download.pytorch.org/whl/cu121

            Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder.

    Returns:
        ``["numpy", <triton package name>]``.

    Raises:
        ValueError: If NVTE_USE_PYTORCH_TRITON is set to a non-empty value
            that cannot be parsed as an integer.
    """
    # A variable that is exported but empty (e.g. `NVTE_USE_PYTORCH_TRITON=`)
    # is treated the same as an unset one instead of crashing in int("").
    raw_flag = os.environ.get("NVTE_USE_PYTORCH_TRITON", "").strip()
    try:
        # Any non-zero integer enables the PyTorch-flavoured package.
        use_pytorch_triton = bool(int(raw_flag)) if raw_flag else False
    except ValueError as err:
        raise ValueError(
            f"NVTE_USE_PYTORCH_TRITON must be an integer such as 0 or 1, got {raw_flag!r}"
        ) from err

    triton_package = "pytorch-triton" if use_pytorch_triton else "triton"

    return [
        "numpy",
        triton_package,
    ]
45
46


47
48
49
50
51
def xla_path() -> str:
    """XLA root path lookup.

    When JAX can be imported, the directory reported by its FFI module's
    ``include_dir()`` is used (``jax.ffi`` for JAX >= 0.5.0,
    ``jax.extend.ffi`` for older versions). When the import fails, the
    ``XLA_HOME`` environment variable is used if it is set and non-empty,
    otherwise ``/opt/xla``.

    Returns:
        The XLA directory as a string.

    Raises:
        FileNotFoundError: If the selected directory does not exist.
    """
    try:
        import jax

        if version.parse(jax.__version__) >= version.parse("0.5.0"):
            from jax import ffi  # pylint: disable=ungrouped-imports
        else:
            from jax.extend import ffi  # pylint: disable=ungrouped-imports
    except ImportError:
        # JAX (or its FFI module) is unavailable: fall back to the
        # environment, then to the default location. The variable is read
        # once; an empty value falls through to the default.
        xla_home = str(Path(os.getenv("XLA_HOME") or "/opt/xla"))
    else:
        # Normalise to str so every branch matches the declared return type.
        xla_home = str(ffi.include_dir())

    if not os.path.isdir(xla_home):
        raise FileNotFoundError(f"Could not find xla source. Looked in: {xla_home}")
    return xla_home
70

71
72
73
74
75
76
77
78
79

def setup_jax_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup PyBind11 extension for JAX support.

    Args:
        csrc_source_files: Directory of the JAX C++ sources; every ``cpp``
            file found by ``all_files_in_dir`` under its ``extensions``
            subdirectory is compiled. Accepts a ``str`` or ``Path``.
        csrc_header_files: Include directory for the JAX C++ headers.
        common_header_files: Root of the common sources; it, its ``common``
            subdirectory and ``common/include`` are added to the include
            path. Accepts a ``str`` or ``Path``.

    Returns:
        A ``Pybind11Extension`` named ``transformer_engine_jax`` that links
        against ``nccl``.

    Raises:
        FileNotFoundError: Propagated from ``xla_path()`` when the XLA
            directory cannot be located.
    """
    # Source files
    csrc_source_files = Path(csrc_source_files)
    extensions_dir = csrc_source_files / "extensions"
    sources = all_files_in_dir(extensions_dir, name_extension="cpp")

    # Header files
    # Coerce to Path so that plain-string arguments work with the `/` joins
    # below, mirroring what is done for csrc_source_files.
    common_header_files = Path(common_header_files)
    # Copy before extending so the object returned by the helper is never
    # mutated (it may be shared between callers - not visible from here).
    include_dirs = list(get_cuda_include_dirs())
    include_dirs.extend(
        [
            common_header_files,
            common_header_files / "common",
            common_header_files / "common" / "include",
            csrc_header_files,
            xla_path(),
        ]
    )

    # Compile flags
    cxx_flags = ["-O3"]
    if debug_build_enabled():
        # Keep debug symbols and undefine NDEBUG so assertions stay active.
        cxx_flags.extend(["-g", "-UNDEBUG"])
    else:
        cxx_flags.append("-g0")

    # Define TE/JAX as a Pybind11Extension
    # Imported here rather than at module level, as in the original code;
    # presumably so this module can be imported without pybind11 installed.
    from pybind11.setup_helpers import Pybind11Extension

    return Pybind11Extension(
        "transformer_engine_jax",
        sources=[str(path) for path in sources],
        include_dirs=[str(path) for path in include_dirs],
        extra_compile_args=cxx_flags,
        libraries=["nccl"],
    )