"...test_cli/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "1b71bb9309de2857bc152b94138a2296c7df0e68"
Commit f82bfbac authored by rusty1s's avatar rusty1s
Browse files

added experimental flag

parent e30538b1
......@@ -6,7 +6,7 @@
// We need our own `IndexToOffset` implementation since we do not want to
// access the last element of the `indexptr`.
template <typename scalar_t> struct IndexPtrToOffset {
static inline __device__ int
static inline __host__ __device__ int
get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
......
......@@ -30,7 +30,7 @@ enum ReductionType { ADD, MEAN, MIN, MAX };
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __device__ scalar_t init() {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
......@@ -40,8 +40,8 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static inline __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
......@@ -51,9 +51,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static inline __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg,
int count) {
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == ADD) {
*address = val;
} else if (REDUCE == MEAN) {
......
......@@ -12,7 +12,7 @@ if '--cpu' in argv:
USE_GPU = False
cxx_extra_compile_args = []
nvcc_extra_compile_args = ['-arch=sm_35']
nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr']
if platform.system() != 'Windows':
cxx_extra_compile_args += ['-Wno-unused-variable']
TORCH_MAJOR = int(torch.__version__.split('.')[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment