Unverified Commit a4dc2a72 authored by pc's avatar pc Committed by GitHub
Browse files

support device dispatch in parrots (#1588)

parent 0bcbeadb
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -189,13 +189,14 @@ def get_extensions(): ...@@ -189,13 +189,14 @@ def get_extensions():
define_macros = [] define_macros = []
include_dirs = [] include_dirs = []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') +\ op_files = glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') +\
glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') +\
glob.glob('./mmcv/ops/csrc/parrots/*.cpp') glob.glob('./mmcv/ops/csrc/parrots/*.cpp')
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
cuda_args = os.getenv('MMCV_CUDA_ARGS') cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args = { extra_compile_args = {
'nvcc': [cuda_args] if cuda_args else [], 'nvcc': [cuda_args, '-std=c++14'] if cuda_args else ['-std=c++14'],
'cxx': [], 'cxx': ['-std=c++14'],
} }
if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('MMCV_WITH_CUDA', None)] define_macros += [('MMCV_WITH_CUDA', None)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment