cuda_ext.py 3.99 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os

from setuptools.command.build_ext import build_ext
from distutils.unixccompiler import UnixCCompiler

# This is based on
# https://github.com/rmcgibbo/npcuda-example/blob/dd2768d8ccb5688c0f08678dd8f1ad5afe3e4332/cython/setup.py
# published under BSD 2-Clause "Simplified" License


def find_in_path(name, path):
    """Find a file in a search path

    ``path`` is a string of directories separated by ``os.pathsep`` (the
    format of the PATH environment variable). The directories are checked
    in order and the absolute path of the first regular file called
    ``name`` is returned. Returns None if no directory contains such a file.
    """

    # Adapted from http://code.activestate.com/recipes/52224
    for directory in path.split(os.pathsep):
        candidate = os.path.join(directory, name)
        # isfile() rather than exists(): a directory that merely shares the
        # requested name must not shadow the real file later in the path.
        if os.path.isfile(candidate):
            return os.path.abspath(candidate)

    return None


def locate_cuda():
    """Locate the CUDA environment on the system

    Returns a dict with keys 'cuda_root', 'nvcc', 'include', 'lib64', and
    'cflags'. 'cuda_root' and 'nvcc' are path strings, 'include' and 'lib64'
    are lists holding the respective directory below 'cuda_root', and
    'cflags' is a list of command-line flags.

    Starts by looking for the CUDAHOME and CUDA_ROOT env variables (in that
    order). If not found, everything is based on finding 'nvcc' in the PATH.

    If CUDA cannot be located at all (including when PATH is not set), a
    placeholder dict is returned: empty 'cuda_root', the bare command name
    "nvcc", and empty lists for the remaining keys.

    If the CUDA_COMPUTE_CAPABILITY env variable is set, a matching -gencode
    flag is appended to 'cflags'; otherwise a warning is printed.
    """

    # First check if any common env variable is in use
    cuda_root = next(
        (os.environ[var] for var in ("CUDAHOME", "CUDA_ROOT") if var in os.environ),
        None,
    )

    if cuda_root is not None:
        nvcc = os.path.join(cuda_root, "bin", "nvcc")
    else:
        # Otherwise, search the PATH for NVCC. PATH itself may be missing
        # from the environment, which simply means nvcc cannot be found.
        search_path = os.environ.get("PATH")
        nvcc = None
        if search_path is not None:
            nvcc = find_in_path("nvcc", search_path)
        if nvcc is not None:
            # nvcc lives in <cuda_root>/bin/nvcc, so go two levels up
            cuda_root = os.path.dirname(os.path.dirname(nvcc))

    if cuda_root is None:
        return {
            "cuda_root": "",
            "nvcc": "nvcc",
            "include": [],
            "lib64": [],
            "cflags": [],
        }

    cflags = ["-c", "--compiler-options", "'-fPIC'", "-std=c++11"]

    cm = os.environ.get("CUDA_COMPUTE_CAPABILITY")
    if cm is not None:
        cflags.append("-gencode=arch=compute_{cm},code=compute_{cm}".format(cm=cm))
    else:
        print(
            "Warning: Consider settings the CUDA_COMPUTE_CAPABILITY environment "
            "variable to your GPU's compute capability."
        )

    return {
        "cuda_root": cuda_root,
        "nvcc": nvcc,
        "include": [os.path.join(cuda_root, "include")],
        "lib64": [os.path.join(cuda_root, "lib64")],
        "cflags": cflags,
    }


def customize_compiler_for_nvcc(self):
    """Patch the compiler object ``self`` in place so .cu files go to nvcc

    The compiler's ``_compile`` method is replaced by a wrapper that accepts
    ``extra_postargs`` as a dict with the keys "gcc" and "nvcc" and forwards
    only the entry matching the source file to the original ``_compile``.
    Any non-dict ``extra_postargs`` (e.g. a plain list of flags) is forwarded
    unchanged.

    If ``self`` is not a ``UnixCCompiler``, no nvcc support is added: the
    wrapper then always forwards the "gcc" entry.
    """

    def select_postargs(extra_postargs, key):
        # Only the dict form carries per-compiler flags; anything else is
        # already in the shape the original _compile expects.
        if isinstance(extra_postargs, dict):
            return extra_postargs[key]
        return extra_postargs

    # Save a reference to the default _compile method
    default_compile = self._compile

    if not isinstance(self, UnixCCompiler):
        # Just give up on nvcc and only unpack the host compiler arguments
        def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
            postargs = select_postargs(extra_postargs, "gcc")
            return default_compile(obj, src, ext, cc_args, postargs, pp_opts)

        self._compile = _compile
        return

    # Tell the compiler it can process .cu (without adding duplicate
    # entries if this function runs more than once)
    if ".cu" not in self.src_extensions:
        self.src_extensions.append(".cu")

    # Save a reference to the default compiler_so so it can be restored
    default_compiler_so = self.compiler_so

    # Now redefine the _compile method. This gets executed for each
    # object but distutils doesn't have the ability to change compilers
    # based on source extension: we add it.
    def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
        if os.path.splitext(src)[1] == ".cu":
            # use the cuda for .cu files
            self.set_executable("compiler_so", cuda_info["nvcc"])
            # use only a subset of the extra_postargs, which are 1-1
            # translated from the extra_compile_args in the Extension class
            postargs = select_postargs(extra_postargs, "nvcc")
        else:
            postargs = select_postargs(extra_postargs, "gcc")

        try:
            return default_compile(obj, src, ext, cc_args, postargs, pp_opts)
        finally:
            # Reset the default compiler_so, which we might have changed for
            # cuda. Done in `finally` so a failed compilation does not leave
            # nvcc configured for the next source file.
            self.compiler_so = default_compiler_so

    # Inject our redefined _compile method into the compiler object
    self._compile = _compile


# Run the customize_compiler
class custom_build_ext(build_ext):
    """build_ext command that prepares its compiler for nvcc before building"""

    def build_extensions(self):
        """Apply the nvcc customization to the compiler, then build as usual"""
        compiler = self.compiler
        customize_compiler_for_nvcc(compiler)
        super().build_extensions()


# CUDA configuration dict, resolved once when this module is imported.
# customize_compiler_for_nvcc reads cuda_info["nvcc"] to pick the compiler
# executable for .cu files.
cuda_info = locate_cuda()