Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
b18a3383
Unverified
Commit
b18a3383
authored
Jul 07, 2020
by
Yuanhao Zhu
Committed by
GitHub
Jul 07, 2020
Browse files
fix macOS compile (#386)
parent
e2ee171a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
64 additions
and
63 deletions
+64
-63
README.rst
README.rst
+6
-0
mmcv/ops/csrc/pytorch/nms.cpp
mmcv/ops/csrc/pytorch/nms.cpp
+1
-1
setup.py
setup.py
+57
-62
No files found.
README.rst
View file @
b18a3383
...
@@ -54,6 +54,12 @@ or install from source
...
@@ -54,6 +54,12 @@ or install from source
cd mmcv
cd mmcv
pip install -e .
pip install -e .
If you are on macOS, replace the last command with
.. code::
CC=lang CXX=clang++ CFLAGS='-stdlib=libc++' pip install -e .
Note: If you would like to use :code:`opencv-python-headless` instead of :code:`opencv-python`,
Note: If you would like to use :code:`opencv-python-headless` instead of :code:`opencv-python`,
e.g., in a minimum container environment or servers without GUI,
e.g., in a minimum container environment or servers without GUI,
you can first install it before installing MMCV to skip the installation of :code:`opencv-python`.
you can first install it before installing MMCV to skip the installation of :code:`opencv-python`.
mmcv/ops/csrc/pytorch/nms.cpp
View file @
b18a3383
...
@@ -102,7 +102,7 @@ Tensor softnms_cpu(Tensor boxes, Tensor scores, Tensor dets,
...
@@ -102,7 +102,7 @@ Tensor softnms_cpu(Tensor boxes, Tensor scores, Tensor dets,
int64_t
pos
=
0
;
int64_t
pos
=
0
;
Tensor
inds_t
=
at
::
arange
(
nboxes
,
boxes
.
options
().
dtype
(
at
::
kLong
));
Tensor
inds_t
=
at
::
arange
(
nboxes
,
boxes
.
options
().
dtype
(
at
::
kLong
));
auto
inds
=
inds_t
.
data_ptr
<
long
>
();
auto
inds
=
inds_t
.
data_ptr
<
int64_t
>
();
for
(
int64_t
i
=
0
;
i
<
nboxes
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
nboxes
;
i
++
)
{
auto
max_score
=
sc
[
i
];
auto
max_score
=
sc
[
i
];
...
...
setup.py
View file @
b18a3383
import
glob
import
glob
import
os
import
os
import
platform
import
re
import
re
import
setuptools
import
setuptools
from
pkg_resources
import
DistributionNotFound
,
get_distribution
from
pkg_resources
import
DistributionNotFound
,
get_distribution
...
@@ -10,7 +9,19 @@ dist.Distribution().fetch_build_eggs(['Cython', 'numpy>=1.11.1'])
...
@@ -10,7 +9,19 @@ dist.Distribution().fetch_build_eggs(['Cython', 'numpy>=1.11.1'])
import
numpy
# NOQA: E402 # isort:skip
import
numpy
# NOQA: E402 # isort:skip
from
Cython.Build
import
cythonize
# NOQA: E402 # isort:skip
from
Cython.Build
import
cythonize
# NOQA: E402 # isort:skip
from
Cython.Distutils
import
build_ext
as
build_cmd
# NOQA: E402 # isort:skip
EXT_TYPE
=
''
try
:
import
torch
if
torch
.
__version__
==
'parrots'
:
from
parrots.utils.build_extension
import
BuildExtension
EXT_TYPE
=
'parrots'
else
:
from
torch.utils.cpp_extension
import
BuildExtension
EXT_TYPE
=
'pytorch'
except
ModuleNotFoundError
:
from
Cython.Distutils
import
build_ext
as
BuildExtension
print
(
'Skip building ext ops due to the absence of torch.'
)
def
choose_requirement
(
primary
,
secondary
):
def
choose_requirement
(
primary
,
secondary
):
...
@@ -125,13 +136,6 @@ for main, secondary in CHOOSE_INSTALL_REQUIRES:
...
@@ -125,13 +136,6 @@ for main, secondary in CHOOSE_INSTALL_REQUIRES:
def
get_extensions
():
def
get_extensions
():
extensions
=
[]
extensions
=
[]
if
platform
.
system
()
==
'Darwin'
:
extra_compile_args
=
[
'-stdlib=libc++'
]
extra_link_args
=
[
'-stdlib=libc++'
]
else
:
extra_compile_args
=
[]
extra_link_args
=
[]
ext_flow
=
setuptools
.
Extension
(
ext_flow
=
setuptools
.
Extension
(
name
=
'mmcv._flow_warp_ext'
,
name
=
'mmcv._flow_warp_ext'
,
sources
=
[
sources
=
[
...
@@ -139,63 +143,54 @@ def get_extensions():
...
@@ -139,63 +143,54 @@ def get_extensions():
'./mmcv/video/optflow_warp/flow_warp_module.pyx'
'./mmcv/video/optflow_warp/flow_warp_module.pyx'
],
],
include_dirs
=
[
numpy
.
get_include
()],
include_dirs
=
[
numpy
.
get_include
()],
language
=
'c++'
,
language
=
'c++'
)
extra_compile_args
=
extra_compile_args
,
extra_link_args
=
extra_link_args
)
extensions
.
extend
(
cythonize
(
ext_flow
))
extensions
.
extend
(
cythonize
(
ext_flow
))
try
:
if
EXT_TYPE
==
'parrots'
:
import
torch
ext_name
=
'mmcv._ext'
from
parrots.utils.build_extension
import
Extension
define_macros
=
[(
'MMCV_USE_PARROTS'
,
None
)]
op_files
=
glob
.
glob
(
'./mmcv/ops/csrc/parrots/*'
)
include_path
=
os
.
path
.
abspath
(
'./mmcv/ops/csrc'
)
cuda_args
=
os
.
getenv
(
'MMCV_CUDA_ARGS'
)
ext_ops
=
Extension
(
name
=
ext_name
,
sources
=
op_files
,
include_dirs
=
[
include_path
],
define_macros
=
define_macros
,
extra_compile_args
=
{
'nvcc'
:
[
cuda_args
]
if
cuda_args
else
[],
'cxx'
:
[],
},
cuda
=
True
)
extensions
.
append
(
ext_ops
)
elif
EXT_TYPE
==
'pytorch'
:
ext_name
=
'mmcv._ext'
ext_name
=
'mmcv._ext'
if
torch
.
__version__
==
'parrots'
:
from
torch.utils.cpp_extension
import
(
CUDAExtension
,
CppExtension
)
from
parrots.utils.build_extension
import
BuildExtension
,
Extension
# prevent ninja from using too many resources
define_macros
=
[(
'MMCV_USE_PARROTS'
,
None
)]
os
.
environ
.
setdefault
(
'MAX_JOBS'
,
'4'
)
op_files
=
glob
.
glob
(
'./mmcv/ops/csrc/parrots/*'
)
define_macros
=
[]
include_path
=
os
.
path
.
abspath
(
'./mmcv/ops/csrc'
)
extra_compile_args
=
{
'cxx'
:
[]}
if
torch
.
cuda
.
is_available
()
or
os
.
getenv
(
'FORCE_CUDA'
,
'0'
)
==
'1'
:
define_macros
+=
[(
'MMCV_WITH_CUDA'
,
None
)]
cuda_args
=
os
.
getenv
(
'MMCV_CUDA_ARGS'
)
cuda_args
=
os
.
getenv
(
'MMCV_CUDA_ARGS'
)
ext_ops
=
Extension
(
extra_compile_args
[
'nvcc'
]
=
[
cuda_args
]
if
cuda_args
else
[]
name
=
ext_name
,
op_files
=
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/*'
)
sources
=
op_files
,
extension
=
CUDAExtension
include_dirs
=
[
include_path
],
define_macros
=
define_macros
,
extra_compile_args
=
{
'nvcc'
:
[
cuda_args
]
if
cuda_args
else
[],
'cxx'
:
[],
},
cuda
=
True
)
extensions
.
append
(
ext_ops
)
else
:
else
:
from
torch.utils.cpp_extension
import
(
BuildExtension
,
print
(
f
'Compiling
{
ext_name
}
without CUDA'
)
CUDAExtension
,
CppExtension
)
op_files
=
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/*.cpp'
)
# prevent ninja from using too many resources
extension
=
CppExtension
os
.
environ
.
setdefault
(
'MAX_JOBS'
,
'4'
)
define_macros
=
[]
include_path
=
os
.
path
.
abspath
(
'./mmcv/ops/csrc'
)
extra_compile_args
=
{
'cxx'
:
[]}
ext_ops
=
extension
(
name
=
ext_name
,
if
(
torch
.
cuda
.
is_available
()
sources
=
op_files
,
or
os
.
getenv
(
'FORCE_CUDA'
,
'0'
)
==
'1'
):
include_dirs
=
[
include_path
],
define_macros
+=
[(
'MMCV_WITH_CUDA'
,
None
)]
define_macros
=
define_macros
,
cuda_args
=
os
.
getenv
(
'MMCV_CUDA_ARGS'
)
extra_compile_args
=
extra_compile_args
)
extra_compile_args
[
'nvcc'
]
=
[
cuda_args
]
if
cuda_args
else
[]
extensions
.
append
(
ext_ops
)
op_files
=
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/*'
)
extension
=
CUDAExtension
else
:
print
(
f
'Compiling
{
ext_name
}
without CUDA'
)
op_files
=
glob
.
glob
(
'./mmcv/ops/csrc/pytorch/*.cpp'
)
extension
=
CppExtension
include_path
=
os
.
path
.
abspath
(
'./mmcv/ops/csrc'
)
ext_ops
=
extension
(
name
=
ext_name
,
sources
=
op_files
,
include_dirs
=
[
include_path
],
define_macros
=
define_macros
,
extra_compile_args
=
extra_compile_args
)
extensions
.
append
(
ext_ops
)
global
build_cmd
build_cmd
=
BuildExtension
except
ModuleNotFoundError
:
print
(
'Skip building ext ops due to the absence of torch.'
)
return
extensions
return
extensions
...
@@ -224,5 +219,5 @@ setup(
...
@@ -224,5 +219,5 @@ setup(
tests_require
=
[
'pytest'
],
tests_require
=
[
'pytest'
],
install_requires
=
install_requires
,
install_requires
=
install_requires
,
ext_modules
=
get_extensions
(),
ext_modules
=
get_extensions
(),
cmdclass
=
{
'build_ext'
:
b
uild
_cmd
},
cmdclass
=
{
'build_ext'
:
B
uild
Extension
},
zip_safe
=
False
)
zip_safe
=
False
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment