Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
cac0c44a
Unverified
Commit
cac0c44a
authored
Oct 11, 2023
by
q.yao
Committed by
GitHub
Oct 11, 2023
Browse files
temporarily disable mps ops for torch2.1.0 (#2958)
parent
ea53ed02
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
4 deletions
+9
-4
setup.py
setup.py
+5
-3
tests/test_ops/test_bbox.py
tests/test_ops/test_bbox.py
+4
-1
No files found.
setup.py
View file @
cac0c44a
...
@@ -384,8 +384,10 @@ def get_extensions():
...
@@ -384,8 +384,10 @@ def get_extensions():
extra_compile_args
[
'cxx'
]
+=
[
'-ObjC++'
]
extra_compile_args
[
'cxx'
]
+=
[
'-ObjC++'
]
# src
# src
op_files
=
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/*.cpp'
)
+
\
op_files
=
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/*.cpp'
)
+
\
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/cpu/*.cpp'
)
+
\
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/cpu/*.cpp'
)
glob
.
glob
(
'./mmcv/ops/csrc/common/mps/*.mm'
)
+
\
# TODO: support mps ops on torch>=2.1.0
if
parse_version
(
torch
.
__version__
)
<
parse_version
(
'2.1.0'
):
op_files
+=
glob
.
glob
(
'./mmcv/ops/csrc/common/mps/*.mm'
)
+
\
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/mps/*.mm'
)
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/mps/*.mm'
)
extension
=
CppExtension
extension
=
CppExtension
include_dirs
.
append
(
os
.
path
.
abspath
(
'./mmcv/ops/csrc/common'
))
include_dirs
.
append
(
os
.
path
.
abspath
(
'./mmcv/ops/csrc/common'
))
...
...
tests/test_ops/test_bbox.py
View file @
cac0c44a
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
from
mmengine.utils
import
digit_version
from
mmcv.utils
import
(
IS_CUDA_AVAILABLE
,
IS_MLU_AVAILABLE
,
IS_MPS_AVAILABLE
,
from
mmcv.utils
import
(
IS_CUDA_AVAILABLE
,
IS_MLU_AVAILABLE
,
IS_MPS_AVAILABLE
,
IS_NPU_AVAILABLE
)
IS_NPU_AVAILABLE
)
...
@@ -56,7 +57,9 @@ class TestBBox:
...
@@ -56,7 +57,9 @@ class TestBBox:
pytest
.
param
(
pytest
.
param
(
'mps'
,
'mps'
,
marks
=
pytest
.
mark
.
skipif
(
marks
=
pytest
.
mark
.
skipif
(
not
IS_MPS_AVAILABLE
,
reason
=
'requires MPS support'
)),
not
IS_MPS_AVAILABLE
or
digit_version
(
torch
.
__version__
)
>=
digit_version
(
'2.1.0'
),
reason
=
'requires MPS support'
)),
pytest
.
param
(
pytest
.
param
(
'npu'
,
'npu'
,
marks
=
pytest
.
mark
.
skipif
(
marks
=
pytest
.
mark
.
skipif
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment