setup.py 1.35 KB
Newer Older
huchen's avatar
huchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import sys

from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Locations of the ROCm 3.3.0 install and its hipcc compiler driver.
# NOTE(review): neither name is referenced anywhere else in this script, and
# both are bound only after ``torch.utils.cpp_extension`` has been imported at
# the top of the file, so they do not appear to influence which ROCm install or
# compiler the extension build actually uses. Confirm whether they are still
# needed (e.g. were they meant to be exported as environment variables before
# the torch import?).
ROCM_HOME, nvcc = '/opt/rocm-3.3.0/', '/opt/rocm-3.3.0/hip/bin/hipcc'

# Abort early with a readable message when run under a Python 2 interpreter.
# NOTE(review): the torch import at the top of this file executes before this
# guard, so under Python 2 that import would presumably fail first and this
# message may never be reached.
if sys.version_info[0] < 3:
    sys.exit('Sorry, Python3 is required for gnmt.')
# Raw text of requirements.txt; it is split into install_requires further down.
# The file is located next to this setup.py (not relative to the current
# working directory), so ``python path/to/setup.py ...`` also works when
# invoked from another directory. It is decoded as UTF-8 explicitly rather than
# with the platform's locale-dependent default encoding.
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
                       'requirements.txt'),
          encoding='utf-8') as f:
    reqs = f.read()

# Per-compiler flags passed to every CUDAExtension below as extra_compile_args.
#   'cxx'  - flags for the host C++ sources (-O2, plus the
#            HCC_ENABLE_ACCELERATOR_PRINTF preprocessor define).
#   'nvcc' - flags for the GPU sources. The options here are AMD/ROCm-specific:
#            optimise with -O3, pass -fno-gpu-rdc, and generate code only for
#            the gfx906 GPU architecture. A CUDA build of this project used
#            ['--gpu-architecture=sm_70'] under this key instead.
extra_cuda_compile_args = dict(
    cxx=['-O2', '-DHCC_ENABLE_ACCELERATOR_PRINTF'],
    nvcc=['-O3', '-fno-gpu-rdc', '--amdgpu-target=gfx906'],
)

# Native extension importable as ``seq2seq.pack_utils._C`` (the Python name
# ``cat_utils`` differs from the module name but is what ext_modules uses).
# It is built from a C++ source and a HIP kernel source; the CUDA build of this
# project used seq2seq/csrc/pack_utils_kernel.cu in place of the .hip file.
cat_utils = CUDAExtension(
    'seq2seq.pack_utils._C',
    ['seq2seq/csrc/pack_utils.cpp', 'seq2seq/csrc/pack_utils_kernel.hip'],
    extra_compile_args=extra_cuda_compile_args,
)

# Native extension importable as ``seq2seq.attn_score._C``, built from a C++
# source and a HIP kernel source. The CUDA build of this project used
# seq2seq/csrc/attn_score_cuda_kernel.cu in place of the .hip file.
attn_score = CUDAExtension(
    'seq2seq.attn_score._C',
    ['seq2seq/csrc/attn_score_cuda.cpp', 'seq2seq/csrc/attn_score_hip_kernel.hip'],
    extra_compile_args=extra_cuda_compile_args,
)

# Package definition for ``gnmt``.
#   install_requires - one entry per line of the requirements text read above.
#   ext_modules      - both native extensions defined above.
#   cmdclass         - routes ``build_ext`` through torch's BuildExtension so
#                      the extensions are compiled with the flags set per
#                      compiler in extra_compile_args.
setup(
    name='gnmt',
    version='0.7.0',
    description='GNMT',
    install_requires=reqs.strip().split('\n'),
    packages=find_packages(),
    ext_modules=[cat_utils, attn_score],
    cmdclass=dict(build_ext=BuildExtension),
    test_suite='tests',
)