Commit cd84568b authored by Koch's avatar Koch
Browse files

fix: fix errors regarding Reducer functionalities in segment.cpp

parent 1eabf7f1
...@@ -11,23 +11,24 @@ enum ReductionType { ADD, MEAN, MIN, MAX }; ...@@ -11,23 +11,24 @@ enum ReductionType { ADD, MEAN, MIN, MAX };
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \ [&] { \
ReductionType REDUCE = ADD; \
if (reduce == "add") { \ if (reduce == "add") { \
const ReductionType REDUCE = ADD; \ REDUCE = ADD; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (reduce == "mean") { \ } else if (reduce == "mean") { \
const ReductionType REDUCE = MEAN; \ REDUCE = MEAN; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (reduce == "min") { \ } else if (reduce == "min") { \
const ReductionType REDUCE = MIN; \ REDUCE = MIN; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (reduce == "max") { \ } else if (reduce == "max") { \
const ReductionType REDUCE = MAX; \ REDUCE = MAX; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} \ } \
}() }()
template <typename scalar_t, ReductionType REDUCE> struct Reducer { template <typename scalar_t> struct Reducer {
static inline scalar_t init() { static inline scalar_t init(ReductionType REDUCE) {
if (REDUCE == MIN) { if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max(); return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) { } else if (REDUCE == MAX) {
...@@ -37,7 +38,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -37,7 +38,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
} }
} }
static inline void update(scalar_t *val, scalar_t new_val) { static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val) {
if (REDUCE == ADD || REDUCE == MEAN) { if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val; *val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) || } else if ((REDUCE == MIN && new_val < *val) ||
...@@ -46,7 +47,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -46,7 +47,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
} }
} }
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg, static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) { int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) { if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val; *val = *val + new_val;
...@@ -57,7 +58,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -57,7 +58,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
} }
} }
static inline void write(scalar_t *address, scalar_t val, static inline void write(ReductionType REDUCE, scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) { int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == ADD) { if (REDUCE == ADD) {
*address = val; *address = val;
...@@ -136,16 +137,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt, ...@@ -136,16 +137,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
offset = (n / (indptr.size(-1) - 1)) * E * K; offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
vals[k] = Reducer<scalar_t, REDUCE>::init(); vals[k] = Reducer<scalar_t>::init(REDUCE);
} }
for (int64_t e = row_start; e < row_end; e++) { for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update( Reducer<scalar_t>::update(REDUCE,
&vals[k], src_data[offset + e * K + k], &args[k], e); &vals[k], src_data[offset + e * K + k], &args[k], e);
} }
} }
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k], Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k], arg_out_data + n * K + k, args[k],
row_end - row_start); row_end - row_start);
} }
...@@ -214,13 +215,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -214,13 +215,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
for (int e_2 = 0; e_2 < E_2; e_2++) { for (int e_2 = 0; e_2 < E_2; e_2++) {
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update( Reducer<scalar_t>::update(REDUCE,
&vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2); &vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
} }
if (e_2 == E_2 - 1) { if (e_2 == E_2 - 1) {
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write( Reducer<scalar_t>::write(REDUCE,
out_data + e_1 * N * K + idx * K + k, vals[k], out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k], arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start); e_2 + 1 - row_start);
...@@ -231,7 +232,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -231,7 +232,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
if (idx != next_idx) { if (idx != next_idx) {
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write( Reducer<scalar_t>::write(REDUCE,
out_data + e_1 * N * K + idx * K + k, vals[k], out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k], arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start); e_2 + 1 - row_start);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment