"""Build script for the ``gnmt`` package.

Declares two compiled extensions (``seq2seq.pack_utils._C`` and
``seq2seq.attn_score._C``) built from C++ and ``.hip`` kernel sources, reads
the install requirements from ``requirements.txt`` and hands everything to
``setuptools.setup``.
"""
import sys

from setuptools import find_packages, setup

# Run the interpreter check before importing torch so that an unsupported
# interpreter gets this message instead of an import failure.
if sys.version_info < (3,):
    sys.exit('Sorry, Python3 is required for gnmt.')

from torch.utils.cpp_extension import BuildExtension, CUDAExtension  # noqa: E402

# NOTE(review): these two names are not referenced anywhere else in this
# script. Presumably they record the ROCm toolchain location this build was
# written against (or were meant to be environment variables) -- confirm.
ROCM_HOME = '/opt/rocm-3.3.0/'
nvcc = '/opt/rocm-3.3.0/hip/bin/hipcc'


def _read_requirements(path='requirements.txt'):
    """Return the requirement specifiers listed in *path*.

    Each line is stripped of surrounding whitespace (including a trailing
    ``\\r`` from CRLF files). Blank lines and lines starting with ``#`` are
    skipped, so an empty file yields ``[]`` rather than ``['']``.
    """
    with open(path, encoding='utf-8') as f:
        stripped = (line.strip() for line in f.read().splitlines())
        return [line for line in stripped if line and not line.startswith('#')]


# Per-compiler flag lists shared by both extensions. The 'nvcc' entry carries
# AMD GPU flags (target gfx906); the original kept a CUDA alternative
# (`--gpu-architecture=sm_70`) commented out next to it.
extra_cuda_compile_args = {
    'cxx': ['-O2', '-DHCC_ENABLE_ACCELERATOR_PRINTF'],
    'nvcc': ['-O3', '-fno-gpu-rdc', '--amdgpu-target=gfx906'],
}

cat_utils = CUDAExtension(
    name='seq2seq.pack_utils._C',
    sources=[
        'seq2seq/csrc/pack_utils.cpp',
        # HIP kernel source; the CUDA `pack_utils_kernel.cu` is not built.
        'seq2seq/csrc/pack_utils_kernel.hip',
    ],
    extra_compile_args=extra_cuda_compile_args,
)

attn_score = CUDAExtension(
    name='seq2seq.attn_score._C',
    sources=[
        'seq2seq/csrc/attn_score_cuda.cpp',
        # HIP kernel source; the CUDA `attn_score_cuda_kernel.cu` is not built.
        'seq2seq/csrc/attn_score_hip_kernel.hip',
    ],
    extra_compile_args=extra_cuda_compile_args,
)

setup(
    name='gnmt',
    version='0.7.0',
    description='GNMT',
    install_requires=_read_requirements(),
    packages=find_packages(),
    ext_modules=[cat_utils, attn_score],
    cmdclass={
        'build_ext': BuildExtension,
    },
    test_suite='tests',
)