"scripts/vscode:/vscode.git/clone" did not exist on "77098aea7b2156b69ee140a3b8a1873e931166e2"
Commit f82bfbac authored by rusty1s's avatar rusty1s
Browse files

added experimental flag

parent e30538b1
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
// We need our own `IndexToOffset` implementation since we do not want to // We need our own `IndexToOffset` implementation since we do not want to
// access the last element of the `indexptr`. // access the last element of the `indexptr`.
template <typename scalar_t> struct IndexPtrToOffset { template <typename scalar_t> struct IndexPtrToOffset {
static inline __device__ int static inline __host__ __device__ int
get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) { get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1); int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1]; offset *= info.strides[info.dims - 1];
......
...@@ -30,7 +30,7 @@ enum ReductionType { ADD, MEAN, MIN, MAX }; ...@@ -30,7 +30,7 @@ enum ReductionType { ADD, MEAN, MIN, MAX };
}() }()
template <typename scalar_t, ReductionType REDUCE> struct Reducer { template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __device__ scalar_t init() { static inline __host__ __device__ scalar_t init() {
if (REDUCE == MIN) { if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max(); return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) { } else if (REDUCE == MAX) {
...@@ -40,7 +40,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -40,7 +40,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
} }
} }
static inline __device__ void update(scalar_t *val, scalar_t new_val, static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) { int64_t *arg, int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) { if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val; *val = *val + new_val;
...@@ -51,9 +51,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -51,9 +51,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
} }
} }
static inline __device__ void write(scalar_t *address, scalar_t val, static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int64_t *arg_address,
int count) { int64_t arg, int count) {
if (REDUCE == ADD) { if (REDUCE == ADD) {
*address = val; *address = val;
} else if (REDUCE == MEAN) { } else if (REDUCE == MEAN) {
......
...@@ -12,7 +12,7 @@ if '--cpu' in argv: ...@@ -12,7 +12,7 @@ if '--cpu' in argv:
USE_GPU = False USE_GPU = False
cxx_extra_compile_args = [] cxx_extra_compile_args = []
nvcc_extra_compile_args = ['-arch=sm_35'] nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr']
if platform.system() != 'Windows': if platform.system() != 'Windows':
cxx_extra_compile_args += ['-Wno-unused-variable'] cxx_extra_compile_args += ['-Wno-unused-variable']
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split('.')[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment