Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
6289b6f9
Unverified
Commit
6289b6f9
authored
Mar 21, 2022
by
q.yao
Committed by
GitHub
Mar 21, 2022
Browse files
[Fix] Fix rocm support (#1704)
parent
68a2c0a1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
20 deletions
+7
-20
setup.py
setup.py
+4
-20
tests/test_ops/test_onnx.py
tests/test_ops/test_onnx.py
+3
-0
No files found.
setup.py
View file @
6289b6f9
...
@@ -274,26 +274,10 @@ def get_extensions():
...
@@ -274,26 +274,10 @@ def get_extensions():
except
ImportError
:
except
ImportError
:
pass
pass
project_dir
=
'mmcv/ops/csrc/'
if
is_rocm_pytorch
or
torch
.
cuda
.
is_available
()
or
os
.
getenv
(
if
is_rocm_pytorch
:
'FORCE_CUDA'
,
'0'
)
==
'1'
:
from
torch.utils.hipify
import
hipify_python
if
is_rocm_pytorch
:
define_macros
+=
[(
'HIP_DIFF'
,
None
)]
hipify_python
.
hipify
(
project_directory
=
project_dir
,
output_directory
=
project_dir
,
includes
=
'mmcv/ops/csrc/*'
,
show_detailed
=
True
,
is_pytorch_extension
=
True
,
)
define_macros
+=
[(
'MMCV_WITH_CUDA'
,
None
)]
define_macros
+=
[(
'HIP_DIFF'
,
None
)]
cuda_args
=
os
.
getenv
(
'MMCV_CUDA_ARGS'
)
extra_compile_args
[
'nvcc'
]
=
[
cuda_args
]
if
cuda_args
else
[]
op_files
=
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/hip/*'
)
\
+
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/cpu/hip/*'
)
extension
=
CUDAExtension
include_dirs
.
append
(
os
.
path
.
abspath
(
'./mmcv/ops/csrc/common/hip'
))
elif
torch
.
cuda
.
is_available
()
or
os
.
getenv
(
'FORCE_CUDA'
,
'0'
)
==
'1'
:
define_macros
+=
[(
'MMCV_WITH_CUDA'
,
None
)]
define_macros
+=
[(
'MMCV_WITH_CUDA'
,
None
)]
cuda_args
=
os
.
getenv
(
'MMCV_CUDA_ARGS'
)
cuda_args
=
os
.
getenv
(
'MMCV_CUDA_ARGS'
)
extra_compile_args
[
'nvcc'
]
=
[
cuda_args
]
if
cuda_args
else
[]
extra_compile_args
[
'nvcc'
]
=
[
cuda_args
]
if
cuda_args
else
[]
...
...
tests/test_ops/test_onnx.py
View file @
6289b6f9
...
@@ -656,6 +656,9 @@ def test_modulated_deform_conv2d():
...
@@ -656,6 +656,9 @@ def test_modulated_deform_conv2d():
pytest
.
skip
(
'modulated_deform_conv op is not successfully compiled'
)
pytest
.
skip
(
'modulated_deform_conv op is not successfully compiled'
)
ort_custom_op_path
=
get_onnxruntime_op_path
()
ort_custom_op_path
=
get_onnxruntime_op_path
()
if
not
os
.
path
.
exists
(
ort_custom_op_path
):
pytest
.
skip
(
'custom ops for onnxruntime are not compiled.'
)
# modulated deform conv config
# modulated deform conv config
in_channels
=
3
in_channels
=
3
out_channels
=
64
out_channels
=
64
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment