jax.py 2.2 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
6
"""JAX related extensions."""
import os
7
8
9
10
from pathlib import Path

import setuptools

11
from .utils import cuda_path, all_files_in_dir
12
13
from typing import List

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

def xla_path() -> str:
    """XLA root path lookup.

    The directory is resolved in this order:

    1. If JAX's FFI module can be imported -- ``jax.ffi`` first, then the
       older ``jax.extend.ffi`` location -- use its ``include_dir()``.
    2. Otherwise use the ``XLA_HOME`` environment variable, if it is set
       and non-empty.
    3. Otherwise fall back to ``/opt/xla``.

    Returns:
        The resolved directory, always as a ``str`` regardless of which of
        the sources above supplied it.

    Raises:
        FileNotFoundError: If the resolved path is not an existing directory.
    """
    try:
        # Location of the FFI helpers in newer JAX releases.
        from jax import ffi
    except ImportError:
        try:
            # Older location; kept so builds against older JAX keep working.
            from jax.extend import ffi
        except ImportError:
            ffi = None

    if ffi is not None:
        xla_home = str(ffi.include_dir())
    else:
        # An empty XLA_HOME is treated the same as an unset one.
        xla_home = os.getenv("XLA_HOME") or "/opt/xla"

    if not os.path.isdir(xla_home):
        raise FileNotFoundError("Could not find xla source.")
    return xla_home
32

33
34
35
36
37
38
39
40
41

def setup_jax_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup PyBind11 extension for JAX support.

    Args:
        csrc_source_files: Directory of the JAX C++ sources. The files
            returned by ``all_files_in_dir(<dir>/extensions, ".cpp")`` are
            compiled.
        csrc_header_files: Directory added to the include path.
        common_header_files: Root of the common headers. It and its
            ``common`` and ``common/include`` subdirectories are added to
            the include path.

        All three path arguments may be given as ``str`` or path-like
        objects.

    Returns:
        A ``Pybind11Extension`` subclass instance named
        ``transformer_engine_jax`` whose C++ compile flags are stored under
        the ``"cxx"`` key of ``extra_compile_args``.

    Raises:
        FileNotFoundError: Propagated from ``xla_path()`` when no XLA header
            directory can be located.
    """
    # Normalize every path argument so the ``/`` joins below also work when
    # callers pass plain strings.
    csrc_source_files = Path(csrc_source_files)
    csrc_header_files = Path(csrc_header_files)
    common_header_files = Path(common_header_files)

    # Source files
    extensions_dir = csrc_source_files / "extensions"
    sources = all_files_in_dir(extensions_dir, ".cpp")

    # Header files
    cuda_home, _ = cuda_path()
    xla_home = xla_path()
    include_dirs = [
        Path(cuda_home) / "include",
        common_header_files,
        common_header_files / "common",
        common_header_files / "common" / "include",
        csrc_header_files,
        xla_home,
    ]

    # Compile flags
    cxx_flags = ["-O3"]

    # Define TE/JAX as a Pybind11Extension. Imported here rather than at module
    # level, presumably so pybind11 is only needed when this extension is set up.
    from pybind11.setup_helpers import Pybind11Extension

    class Pybind11CPPExtension(Pybind11Extension):
        """Modified Pybind11Extension to allow custom CXX flags.

        The base class's flag hook handles a list-valued
        ``extra_compile_args``; this subclass also accepts a dict keyed by
        compiler and routes the injected flags into its ``"cxx"`` entry.
        """

        def _add_cflags(self, flags: List[str]) -> None:
            if isinstance(self.extra_compile_args, dict):
                # Prepend, exactly like the list branch below, so flags supplied
                # by the caller come last and can override injected defaults.
                # Build a new list instead of extending the caller's list in
                # place (it is aliased to the dict value passed in).
                user_cxx_flags = self.extra_compile_args.pop("cxx", [])
                self.extra_compile_args["cxx"] = list(flags) + list(user_cxx_flags)
            else:
                self.extra_compile_args[:0] = flags

    return Pybind11CPPExtension(
        "transformer_engine_jax",
        sources=[str(path) for path in sources],
        include_dirs=[str(path) for path in include_dirs],
        extra_compile_args={"cxx": cxx_flags},
    )