build_cffi.py 994 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# This file contains the cffi-extension call to build the custom
# kernel used by amp.
# For mysterious reasons, it needs to live at the top-level directory.
# TODO: remove this when we move to cpp-extension.


import os
import torch
from torch.utils.ffi import create_extension

# Absolute directory containing this script (symlinks resolved).
abs_path = os.path.dirname(os.path.realpath(__file__))

# Inputs handed to the cffi extension builder. The object file is given
# as an absolute path under `abs_path`; it is presumably produced by a
# separate compile step before this script runs -- TODO confirm.
extra_objects = [os.path.join(abs_path, 'build/scale_kernel.o')]
headers = ['apex/amp/src/scale_cuda.h']
sources = ['apex/amp/src/scale_cuda.c']

# CUDA is always enabled for this extension.
with_cuda = True
defines = [('WITH_CUDA', None)]

# `package` must be False when this file is executed directly
# (`python build_cffi.py`) and True when it is loaded through
# `cffi_modules` in setup.py.
package = not (__name__ == '__main__')

# Keyword arguments for create_extension, collected in one place.
_extension_options = dict(
    package=package,
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
    extra_objects=extra_objects,
)

extension = create_extension('apex.amp._C.scale_lib', **_extension_options)

# Only build immediately when run as a script; when imported, the
# caller is expected to pick up `extension` itself.
if not package:
    extension.build()