# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""JAX related extensions."""
from pathlib import Path

import setuptools
9
from glob import glob
10

11
from .utils import cuda_path, all_files_in_dir
12
13
from typing import List

14
15
from jax.extend import ffi

16
17
18
19
20
21
22
23
24

def setup_jax_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup PyBind11 extension for JAX support.

    Parameters
    ----------
    csrc_source_files:
        Directory holding the JAX extension sources. ``utils.cu`` is taken
        from this directory and every ``.cpp`` file is collected from its
        ``extensions`` subdirectory.
    csrc_header_files:
        Directory added to the include path for the JAX extension headers.
    common_header_files:
        Root directory of the shared headers. The directory itself, its
        ``common`` subdirectory and ``common/include`` are all added to the
        include path.

    All three arguments may be given as ``str`` or ``pathlib.Path``.

    Returns
    -------
    setuptools.Extension
        A ``Pybind11Extension`` subclass instance named
        ``transformer_engine_jax`` whose ``extra_compile_args`` is a dict
        with separate ``"cxx"`` and ``"nvcc"`` flag lists.
    """
    # Normalize every path-like argument up front so the ``/`` operator below
    # works whether the caller passed a string or a Path.
    csrc_source_files = Path(csrc_source_files)
    csrc_header_files = Path(csrc_header_files)
    common_header_files = Path(common_header_files)

    # Source files
    extensions_dir = csrc_source_files / "extensions"
    sources = [
        csrc_source_files / "utils.cu",
    ] + all_files_in_dir(extensions_dir, ".cpp")

    # Header files
    cuda_home, _ = cuda_path()
    jax_ffi_include = ffi.include_dir()
    include_dirs = [
        cuda_home / "include",
        common_header_files,
        common_header_files / "common",
        common_header_files / "common" / "include",
        csrc_header_files,
        jax_ffi_include,
    ]

    # Compile flags
    cxx_flags = ["-O3"]
    nvcc_flags = ["-O3"]

    # Define TE/JAX as a Pybind11Extension
    # Imported here rather than at module level so that importing this module
    # does not itself require pybind11.
    from pybind11.setup_helpers import Pybind11Extension

    class Pybind11CUDAExtension(Pybind11Extension):
        """Modified Pybind11Extension to allow combined CXX + NVCC compile flags."""

        def _add_cflags(self, flags: List[str]) -> None:
            """Add ``flags`` to the C++ compile flags.

            When ``extra_compile_args`` is a dict, the flags are appended to
            its ``"cxx"`` entry (created if missing); otherwise they are
            inserted at the front of the flat flag list.
            """
            if isinstance(self.extra_compile_args, dict):
                existing_cxx = self.extra_compile_args.pop("cxx", [])
                existing_cxx += flags
                self.extra_compile_args["cxx"] = existing_cxx
            else:
                self.extra_compile_args[:0] = flags

    return Pybind11CUDAExtension(
        "transformer_engine_jax",
        sources=[str(path) for path in sources],
        include_dirs=[str(path) for path in include_dirs],
        extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
    )