Commit 9466dff7 authored by zhangwenwei's avatar zhangwenwei
Browse files

Refactor setup.py

parent bcb9d628
...@@ -3,6 +3,6 @@ line_length = 79 ...@@ -3,6 +3,6 @@ line_length = 79
multi_line_output = 0 multi_line_output = 0
known_standard_library = setuptools known_standard_library = setuptools
known_first_party = mmdet,mmdet3d known_first_party = mmdet,mmdet3d
known_third_party = Cython,cv2,mmcv,numba,numpy,nuscenes,pycocotools,pyquaternion,scipy,shapely,six,skimage,terminaltables,torch,torchvision known_third_party = cv2,mmcv,numba,numpy,nuscenes,pycocotools,pyquaternion,scipy,shapely,six,skimage,terminaltables,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY default_section = THIRDPARTY
import os import os
import platform
import subprocess import subprocess
import time import time
from setuptools import Extension, find_packages, setup from setuptools import find_packages, setup
import numpy as np import torch
from Cython.Build import cythonize from torch.utils.cpp_extension import (BuildExtension, CppExtension,
from torch.utils.cpp_extension import BuildExtension, CUDAExtension CUDAExtension)
def readme(): def readme():
...@@ -16,10 +15,13 @@ def readme(): ...@@ -16,10 +15,13 @@ def readme():
MAJOR = 0 MAJOR = 0
MINOR = 1 MINOR = 0
PATCH = '' PATCH = 0
SUFFIX = 'rc0' SUFFIX = 'rc0'
SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX) if PATCH != '':
SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX)
else:
SHORT_VERSION = '{}.{}{}'.format(MAJOR, MINOR, SUFFIX)
version_file = 'mmdet3d/version.py' version_file = 'mmdet3d/version.py'
...@@ -84,38 +86,36 @@ def get_version(): ...@@ -84,38 +86,36 @@ def get_version():
return locals()['__version__'] return locals()['__version__']
def make_cuda_ext(name, module, sources, extra_args=[], extra_include_path=[]): def make_cuda_ext(name,
return CUDAExtension( module,
sources,
sources_cuda=[],
extra_args=[],
extra_include_path=[]):
define_macros = []
extra_compile_args = {'cxx': [] + extra_args}
if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('WITH_CUDA', None)]
extension = CUDAExtension
extra_compile_args['nvcc'] = [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
]
sources += sources_cuda
else:
print('Compiling {} without CUDA'.format(name))
extension = CppExtension
# raise EnvironmentError('CUDA is required to compile MMDetection!')
return extension(
name='{}.{}'.format(module, name), name='{}.{}'.format(module, name),
define_macros=[('WITH_CUDA', None)],
sources=[os.path.join(*module.split('.'), p) for p in sources], sources=[os.path.join(*module.split('.'), p) for p in sources],
include_dirs=extra_include_path, include_dirs=extra_include_path,
extra_compile_args={ define_macros=define_macros,
'cxx': [] + extra_args,
'nvcc':
extra_args + [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
]
})
def make_cython_ext(name, module, sources):
extra_compile_args = None
if platform.system() != 'Windows':
extra_compile_args = {
'cxx': ['-Wno-unused-function', '-Wno-write-strings']
}
extension = Extension(
'{}.{}'.format(module, name),
[os.path.join(*module.split('.'), p) for p in sources],
include_dirs=[np.get_include()],
language='c++',
extra_compile_args=extra_compile_args) extra_compile_args=extra_compile_args)
extension, = cythonize(extension)
return extension
def parse_requirements(fname='requirements.txt', with_version=True): def parse_requirements(fname='requirements.txt', with_version=True):
...@@ -210,7 +210,6 @@ if __name__ == '__main__': ...@@ -210,7 +210,6 @@ if __name__ == '__main__':
'License :: OSI Approved :: Apache Software License', 'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent', 'Operating System :: OS Independent',
'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment