Unverified Commit 6289b6f9 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Fix] Fix rocm support (#1704)

parent 68a2c0a1
...@@ -274,26 +274,10 @@ def get_extensions(): ...@@ -274,26 +274,10 @@ def get_extensions():
except ImportError: except ImportError:
pass pass
project_dir = 'mmcv/ops/csrc/' if is_rocm_pytorch or torch.cuda.is_available() or os.getenv(
if is_rocm_pytorch: 'FORCE_CUDA', '0') == '1':
from torch.utils.hipify import hipify_python if is_rocm_pytorch:
define_macros += [('HIP_DIFF', None)]
hipify_python.hipify(
project_directory=project_dir,
output_directory=project_dir,
includes='mmcv/ops/csrc/*',
show_detailed=True,
is_pytorch_extension=True,
)
define_macros += [('MMCV_WITH_CUDA', None)]
define_macros += [('HIP_DIFF', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/hip/*') \
+ glob.glob('./mmcv/ops/csrc/pytorch/cpu/hip/*')
extension = CUDAExtension
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/hip'))
elif torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('MMCV_WITH_CUDA', None)] define_macros += [('MMCV_WITH_CUDA', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS') cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else [] extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
......
...@@ -656,6 +656,9 @@ def test_modulated_deform_conv2d(): ...@@ -656,6 +656,9 @@ def test_modulated_deform_conv2d():
pytest.skip('modulated_deform_conv op is not successfully compiled') pytest.skip('modulated_deform_conv op is not successfully compiled')
ort_custom_op_path = get_onnxruntime_op_path() ort_custom_op_path = get_onnxruntime_op_path()
if not os.path.exists(ort_custom_op_path):
pytest.skip('custom ops for onnxruntime are not compiled.')
# modulated deform conv config # modulated deform conv config
in_channels = 3 in_channels = 3
out_channels = 64 out_channels = 64
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment