Commit 7f48de02 authored by xiabo's avatar xiabo
Browse files

适配rocm

parent 0fa8f7af
......@@ -313,6 +313,9 @@ def get_extensions():
define_macros += [('MMCV_WITH_CUDA', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
if is_rocm_pytorch and platform.system() != 'Windows':
extra_compile_args['nvcc'] += \
['--gpu-max-threads-per-block=1024']
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') + \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment