nvrtc.py 9.47 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
from ctypes import byref, c_char, c_char_p, c_int, c_size_t, c_void_p, POINTER
from enum import IntEnum
from numba.core import config
from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
                                      NvrtcSupportError)

import functools
import os
import threading
import warnings

# ctypes type used for an nvrtcProgram: an opaque handle to a compilation
# unit, represented as a void pointer.
nvrtc_program = c_void_p

# ctypes type used for the nvrtcResult status code that every function listed
# in NVRTC._PROTOTYPES returns.
nvrtc_result = c_int


class NvrtcResult(IntEnum):
    """
    Integer status codes returned by the NVRTC functions wrapped in this file.

    The wrappers installed by ``NVRTC.__new__`` compare each call's return
    value against these members: ``NVRTC_SUCCESS`` means the call succeeded,
    ``NVRTC_ERROR_COMPILATION`` is raised as ``NvrtcCompilationError``, and any
    other value is raised as ``NvrtcError`` using the member name in the
    message.
    """
    # The call completed without error
    NVRTC_SUCCESS = 0
    NVRTC_ERROR_OUT_OF_MEMORY = 1
    NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2
    NVRTC_ERROR_INVALID_INPUT = 3
    NVRTC_ERROR_INVALID_PROGRAM = 4
    NVRTC_ERROR_INVALID_OPTION = 5
    # Handled separately from the other errors so that callers can retrieve
    # the compile log
    NVRTC_ERROR_COMPILATION = 6
    NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7
    NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8
    NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9
    NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10
    NVRTC_ERROR_INTERNAL_ERROR = 11


# Held by NVRTC.__new__ while the singleton NVRTC instance is checked for and,
# if necessary, created.
_nvrtc_lock = threading.Lock()


class NvrtcProgram:
    """
    Owner of a single nvrtcProgram handle.

    The wrapper keeps the handle together with the NVRTC interface object it
    was created with. When the wrapper is finalized and the handle is still
    truthy, the program is released through that interface's
    ``destroy_program`` method.
    """
    def __init__(self, nvrtc, handle):
        self._nvrtc, self._handle = nvrtc, handle

    @property
    def handle(self):
        """The wrapped program handle, exactly as given to the constructor."""
        return self._handle

    def __del__(self):
        # A falsy (null) handle means there is nothing left to destroy.
        if not self._handle:
            return
        self._nvrtc.destroy_program(self)


class NVRTC:
    """
    Provides a Pythonic interface to the NVRTC APIs, abstracting away the C API
    calls.

    The sole instance of this class is a process-wide singleton, similar to the
    NVVM interface. Initialization is protected by a lock and uses the standard
    (for Numba) open_cudalib function to load the NVRTC library.

    Every key of ``_PROTOTYPES`` is bound on the instance as a callable of the
    same name. These callables return nothing; a non-success result code is
    reported by raising ``NvrtcCompilationError`` (for
    ``NVRTC_ERROR_COMPILATION``) or ``NvrtcError`` (for anything else).
    """
    # Maps each NVRTC function name to (restype, *argtypes)
    _PROTOTYPES = {
        # nvrtcResult nvrtcVersion(int *major, int *minor)
        'nvrtcVersion': (nvrtc_result, POINTER(c_int), POINTER(c_int)),
        # nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
        #                                const char *src,
        #                                const char *name,
        #                                int numHeaders,
        #                                const char * const *headers,
        #                                const char * const *includeNames)
        'nvrtcCreateProgram': (nvrtc_result, POINTER(nvrtc_program), c_char_p,
                               c_char_p, c_int, POINTER(c_char_p),
                               POINTER(c_char_p)),
        # nvrtcResult nvrtcDestroyProgram(nvrtcProgram *prog);
        'nvrtcDestroyProgram': (nvrtc_result, POINTER(nvrtc_program)),
        # nvrtcResult nvrtcCompileProgram(nvrtcProgram prog,
        #                                 int numOptions,
        #                                 const char * const *options)
        'nvrtcCompileProgram': (nvrtc_result, nvrtc_program, c_int,
                                POINTER(c_char_p)),
        # nvrtcResult nvrtcGetPTXSize(nvrtcProgram prog, size_t *ptxSizeRet);
        'nvrtcGetPTXSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
        # nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx);
        'nvrtcGetPTX': (nvrtc_result, nvrtc_program, c_char_p),
        # nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog,
        #                               size_t *cubinSizeRet);
        'nvrtcGetCUBINSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
        # nvrtcResult nvrtcGetCUBIN(nvrtcProgram prog, char *cubin);
        'nvrtcGetCUBIN': (nvrtc_result, nvrtc_program, c_char_p),
        # nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog,
        #                                    size_t *logSizeRet);
        'nvrtcGetProgramLogSize': (nvrtc_result, nvrtc_program,
                                   POINTER(c_size_t)),
        # nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
        'nvrtcGetProgramLog': (nvrtc_result, nvrtc_program, c_char_p),
    }

    # Singleton reference
    __INSTANCE = None

    def __new__(cls):
        """
        Return the singleton instance, loading NVRTC on first use.

        :raises NvrtcSupportError: if the NVRTC library cannot be opened.
        :raises AttributeError: if the loaded library lacks one of the
                                functions named in ``_PROTOTYPES``.
        """
        with _nvrtc_lock:
            if cls.__INSTANCE is None:
                from numba.cuda.cudadrv.libs import open_cudalib
                try:
                    lib = open_cudalib('nvrtc')
                except OSError as e:
                    raise NvrtcSupportError("NVRTC cannot be loaded") from e

                inst = object.__new__(cls)

                # Find & populate functions
                for name, proto in inst._PROTOTYPES.items():
                    func = getattr(lib, name)
                    func.restype = proto[0]
                    func.argtypes = proto[1:]

                    # func and name are bound as defaults so that each wrapper
                    # keeps the values from its own loop iteration.
                    @functools.wraps(func)
                    def checked_call(*args, func=func, name=name):
                        error = func(*args)
                        if error == NvrtcResult.NVRTC_ERROR_COMPILATION:
                            raise NvrtcCompilationError()
                        elif error != NvrtcResult.NVRTC_SUCCESS:
                            try:
                                error_name = NvrtcResult(error).name
                            except ValueError:
                                error_name = ('Unknown nvrtc_result '
                                              f'(error code: {error})')
                            msg = f'Failed to call {name}: {error_name}'
                            raise NvrtcError(msg)

                    setattr(inst, name, checked_call)

                # Publish the instance only once it is fully populated. If
                # anything above raised (for example, a function missing from
                # the library), no half-initialized singleton is left behind
                # for later callers to pick up.
                cls.__INSTANCE = inst

        return cls.__INSTANCE

    def get_version(self):
        """
        Get the NVRTC version as a tuple (major, minor).
        """
        major = c_int()
        minor = c_int()
        self.nvrtcVersion(byref(major), byref(minor))
        return major.value, minor.value

    def create_program(self, src, name):
        """
        Create an NVRTC program with managed lifetime.

        :param src: The program source, as ``str`` or ``bytes``
        :param name: The program name, as ``str`` or ``bytes``
        :return: An ``NvrtcProgram`` owning the newly created program
        :raises NvrtcError: if NVRTC reports a failure
        """
        if isinstance(src, str):
            src = src.encode()
        if isinstance(name, str):
            name = name.encode()

        handle = nvrtc_program()

        # The final three arguments are for passing the contents of headers -
        # this is not supported, so there are 0 headers and the header names
        # and contents are null.
        self.nvrtcCreateProgram(byref(handle), src, name, 0, None, None)
        return NvrtcProgram(self, handle)

    def compile_program(self, program, options):
        """
        Compile an NVRTC program. Compilation may fail due to a user error in
        the source; this function returns ``True`` if there is a compilation
        error and ``False`` on success.

        :param program: The ``NvrtcProgram`` to compile
        :param options: An iterable of compiler options, each ``str`` or
                        ``bytes``
        :raises NvrtcError: if NVRTC reports any failure other than a
                            compilation error
        """
        # We hold a list of encoded options to ensure they can't be collected
        # prior to the call to nvrtcCompileProgram
        encoded_options = [opt.encode() if isinstance(opt, str) else opt
                           for opt in options]
        num_options = len(encoded_options)
        c_options = (c_char_p * num_options)(*encoded_options)
        try:
            self.nvrtcCompileProgram(program.handle, num_options, c_options)
            return False
        except NvrtcCompilationError:
            return True

    def destroy_program(self, program):
        """
        Destroy an NVRTC program.

        :param program: The ``NvrtcProgram`` whose handle is passed, by
                        reference, to nvrtcDestroyProgram
        """
        self.nvrtcDestroyProgram(byref(program.handle))

    def get_compile_log(self, program):
        """
        Get the compile log as a Python string.
        """
        log_size = c_size_t()
        self.nvrtcGetProgramLogSize(program.handle, byref(log_size))

        # Buffer of the size reported by NVRTC for the log to be written into
        log = (c_char * log_size.value)()
        self.nvrtcGetProgramLog(program.handle, log)

        # .value yields the buffer contents up to the first NUL byte
        return log.value.decode()

    def get_ptx(self, program):
        """
        Get the compiled PTX as a Python string.
        """
        ptx_size = c_size_t()
        self.nvrtcGetPTXSize(program.handle, byref(ptx_size))

        # Buffer of the size reported by NVRTC for the PTX to be written into
        ptx = (c_char * ptx_size.value)()
        self.nvrtcGetPTX(program.handle, ptx)

        # .value yields the buffer contents up to the first NUL byte
        return ptx.value.decode()


def compile(src, name, cc):
    """
    Build PTX from CUDA C/C++ source with NVRTC, targeting one compute
    capability.

    :param src: CUDA C/C++ source text to compile
    :type src: str
    :param name: The filename of the source; it is passed to NVRTC as the
                 program name and appears in error and warning messages
    :type name: str
    :param cc: The compute capability, as a tuple ``(major, minor)``
    :type cc: tuple
    :return: ``(ptx, log)`` - the generated PTX and the compilation log
    :rtype: tuple
    :raises NvrtcError: if compilation fails; the message contains the log
    """
    nvrtc = NVRTC()
    program = nvrtc.create_program(src, name)

    # Options handed to NVRTC:
    # - the virtual architecture matching the requested compute capability,
    # - the CUDA include path from the Numba config,
    # - the parent of this file's directory as a further include path,
    # - Relocatable Device Code (rdc), which is needed to prevent device
    #   functions being optimized away.
    major, minor = cc
    arch_flag = f'--gpu-architecture=compute_{major}{minor}'
    cuda_include_flag = f'-I{config.CUDA_INCLUDE_PATH}'
    this_file = os.path.abspath(__file__)
    numba_cuda_dir = os.path.dirname(os.path.dirname(this_file))
    options = [
        arch_flag,
        cuda_include_flag,
        f'-I{numba_cuda_dir}',
        '-rdc',
        'true',
    ]

    # The log is fetched whether or not compilation succeeded, because it is
    # needed both for the failure message and for the warning.
    failed = nvrtc.compile_program(program, options)
    log = nvrtc.get_compile_log(program)

    if failed:
        raise NvrtcError(
            f'NVRTC Compilation failure whilst compiling {name}:\n\n{log}'
        )

    # A successful compile with a non-empty log is surfaced as a warning
    if log:
        warnings.warn(f"NVRTC log messages whilst compiling {name}:\n\n{log}")

    return nvrtc.get_ptx(program), log