Commit 27432c85 authored by xiabo

dtk2210.1 torch1.8.0: add ROCm (HIP) build support

parent b8c09f3b
......@@ -7,7 +7,12 @@
#include "pytorch_cuda_helper.hpp"
#endif
#ifdef HIP_DIFF
// HIP wavefronts (ROCm/DCU) are 64 lanes wide; CUDA warps are 32.
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#define THREADS_PER_PIXEL 32
#define MAX_SHARED_MEMORY 49152
#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144
......@@ -25,6 +30,7 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
  return index;
}
/* TODO: move this to a common place */
#ifndef HIP_DIFF
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
  return a < b ? a : b;
......@@ -34,19 +40,28 @@ template <typename scalar_t>
__device__ inline scalar_t max(scalar_t a, scalar_t b) {
  return a > b ? a : b;
}
#endif
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
  for (int offset = 16; offset > 0; offset /= 2)
#ifdef HIP_DIFF
    val += __shfl_down(val, offset);
#else
    val += __shfl_down_sync(FULL_MASK, val, offset);
#endif
  return val;
}
template <>
__device__ __forceinline__ phalf warpReduceSum(phalf val) {
  for (int offset = 16; offset > 0; offset /= 2)
#ifdef HIP_DIFF
    // HIP's __shfl_down takes no mask argument.
    __PHALF(val) += __shfl_down(val, offset);
#else
    __PHALF(val) +=
        __shfl_down_sync(FULL_MASK, static_cast<__half>(__PHALF(val)), offset);
#endif
  return val;
}
......@@ -302,7 +317,11 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
      output_val += top_diff[top_id] * bottom_data[bottom_id];
    }
  }
#ifdef HIP_DIFF
  // __syncwarp is not available under HIP; use a block-wide barrier instead.
  __syncthreads();
#else
  __syncwarp();
#endif
  output_val = warpReduceSum(output_val);
  if (lane_id == 0) {
    const int mask_id =
......
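For context, a minimal Python smoke test could exercise the patched CARAFE kernels after the build. This is a sketch assuming the op is exposed as mmcv.ops.carafe with the upstream positional signature (features, masks, kernel_size, group_size, scale_factor); that binding is not shown in this diff:

import torch
from mmcv.ops import carafe

# Hypothetical 2x upsampling of a 1x16x8x8 feature map.
feats = torch.randn(1, 16, 8, 8, device='cuda')
# One reassembly group, 5x5 kernels; masks live at the output resolution.
masks = torch.softmax(torch.randn(1, 25, 16, 16, device='cuda'), dim=1)
out = carafe(feats, masks, 5, 1, 2)
print(out.shape)  # expected: torch.Size([1, 16, 16, 16])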
......@@ -3,12 +3,15 @@
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#include <cuda_runtime_api.h>
int get_cudart_version() { return CUDART_VERSION; }
#endif
#endif
std::string get_compiling_cuda_version() {
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
  std::ostringstream oss;
  // copied from
  // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
......@@ -20,6 +23,9 @@ std::string get_compiling_cuda_version() {
  };
  printCudaStyleVersion(get_cudart_version());
  return oss.str();
#else
  return std::string("rocm not available");
#endif
#else
  return std::string("not available");
#endif
......
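Once the extension is built, the reported versions can be checked from Python. A sketch, assuming these bindings are re-exported through mmcv.ops as in upstream mmcv (not part of this diff):

from mmcv.ops import get_compiler_version, get_compiling_cuda_version

print(get_compiler_version())
# Under a ROCm (HIP_DIFF) build the CUDA runtime version cannot be queried,
# so this returns the "rocm not available" fallback added above.
print(get_compiling_cuda_version())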
......@@ -3,6 +3,9 @@ import os
import re
from pkg_resources import DistributionNotFound, get_distribution
from setuptools import find_packages, setup
import subprocess
from typing import Optional, Union
from pathlib import Path
EXT_TYPE = ''
try:
......@@ -30,8 +33,30 @@ def choose_requirement(primary, secondary):
    return str(primary)


def get_sha(pytorch_root: Union[str, Path]) -> str:
    try:
        return subprocess.check_output(
            ['git', 'rev-parse', 'HEAD'],
            cwd=pytorch_root).decode('ascii').strip()
    except Exception:
        return 'Unknown'


def get_version_add(sha: Optional[str] = None) -> None:
    mmcv_root = os.path.dirname(os.path.abspath(__file__))
    add_version_path = os.path.join(mmcv_root, 'mmcv', 'version.py')
    if sha != 'Unknown':
        if sha is None:
            sha = get_sha(mmcv_root)
        version = 'git' + sha[:7]
        if os.getenv('MMCV_BUILD_VERSION'):
            version_dtk = os.getenv('MMCV_BUILD_VERSION', '')
            version += '.' + version_dtk
        # Append a local version suffix (git SHA plus DTK build tag) to
        # mmcv/version.py so that __version__ records this build.
        with open(add_version_path, encoding='utf-8', mode='a') as file:
            file.write("__version__=__version__+'+{}'\n".format(version))


def get_version():
    get_version_add()
    version_file = 'mmcv/version.py'
    with open(version_file, 'r', encoding='utf-8') as f:
        exec(compile(f.read(), version_file, 'exec'))
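To illustrate what get_version_add() appends (hypothetical SHA, and MMCV_BUILD_VERSION set to the DTK release from the commit title), mmcv/version.py would gain a line such as:

# Appended by get_version_add() when HEAD is abcdef1... and
# MMCV_BUILD_VERSION=dtk2210.1:
__version__=__version__+'+gitabcdef1.dtk2210.1'

get_version() then exec()s the file, so the installed package reports the base version plus this local suffix.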
......@@ -220,7 +245,19 @@ def get_extensions():
    define_macros = []
    extra_compile_args = {'cxx': []}

    # Detect a ROCm build of PyTorch so the HIP code paths can be enabled.
    is_rocm_pytorch = False
    try:
        from torch.utils.cpp_extension import ROCM_HOME
        is_rocm_pytorch = True if ((torch.version.hip is not None) and
                                   (ROCM_HOME is not None)) else False
    except ImportError:
        pass

    # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
    if is_rocm_pytorch or torch.cuda.is_available() or os.getenv(
            'FORCE_CUDA', '0') == '1':
        if is_rocm_pytorch:
            define_macros += [('HIP_DIFF', None)]
        define_macros += [('MMCV_WITH_CUDA', None)]
        cuda_args = os.getenv('MMCV_CUDA_ARGS')
        extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
......
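A quick check of which branch get_extensions() will take on a given machine, mirroring the detection logic added above (a standalone sketch):

import torch

try:
    from torch.utils.cpp_extension import ROCM_HOME
except ImportError:
    ROCM_HOME = None

# True on a ROCm/DTK build of PyTorch (HIP_DIFF gets defined);
# False on a plain CUDA build, which keeps the original code paths.
is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)
print('is_rocm_pytorch:', is_rocm_pytorch)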