Unverified Commit d9549fba authored by zhuyuanhao's avatar zhuyuanhao Committed by GitHub
Browse files

fix cpp header error (#371)

* 1. use macro USE_PARROTS control header include
2. add clang-format google style in pre-commit

* use MMCV_ macros
parent 2c6fc5fd
#ifndef SYNC_BN_KERNEL_CUH
#define SYNC_BN_KERNEL_CUH
#ifndef SYNCBN_CUDA_KERNEL_CUH
#define SYNCBN_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void sync_bn_forward_mean_cuda_kernel(const T *input, float *mean,
......@@ -321,4 +327,4 @@ __global__ void sync_bn_backward_data_cuda_kernel(
}
}
#endif // SYNC_BN_KERNEL_CUH
#endif // SYNCBN_CUDA_KERNEL_CUH
......@@ -150,22 +150,23 @@ def get_extensions():
try:
import torch
cuda_args = [
'-gencode=arch=compute_52,code=sm_52',
'-gencode=arch=compute_60,code=sm_60',
'-gencode=arch=compute_61,code=sm_61',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_70,code=compute_70'
]
ext_name = 'mmcv._ext'
if torch.__version__ == 'parrots':
from parrots.utils.build_extension import BuildExtension, Extension
cuda_args = [
'-gencode=arch=compute_60,code=sm_60',
'-gencode=arch=compute_61,code=sm_61',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_70,code=compute_70'
]
define_macros = [('MMCV_USE_PARROTS', None)]
op_files = glob.glob('./mmcv/ops/csrc/parrots/*')
include_path = os.path.abspath('./mmcv/ops/csrc')
ext_ops = Extension(
name=ext_name,
sources=op_files,
include_dirs=[include_path],
define_macros=define_macros,
extra_compile_args={
'nvcc': cuda_args,
'cxx': [],
......@@ -177,12 +178,19 @@ def get_extensions():
CUDAExtension, CppExtension)
# prevent ninja from using too many resources
os.environ.setdefault('MAX_JOBS', '4')
cuda_args = [
'-gencode=arch=compute_52,code=sm_52',
'-gencode=arch=compute_60,code=sm_60',
'-gencode=arch=compute_61,code=sm_61',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_70,code=compute_70'
]
define_macros = []
extra_compile_args = {'cxx': []}
if (torch.cuda.is_available()
or os.getenv('FORCE_CUDA', '0') == '1'):
define_macros += [('WITH_CUDA', None)]
define_macros += [('MMCV_WITH_CUDA', None)]
extra_compile_args['nvcc'] = cuda_args
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
extension = CUDAExtension
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment