"docs/vscode:/vscode.git/clone" did not exist on "e1957c6ebdd4860f832c26ae4de4195d10803723"
Commit cd84568b authored by Koch's avatar Koch
Browse files

fix: fix errors regarding Reducer functionalities in segment.cpp

parent 1eabf7f1
......@@ -11,23 +11,24 @@ enum ReductionType { ADD, MEAN, MIN, MAX };
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
ReductionType REDUCE = ADD; \
if (reduce == "add") { \
const ReductionType REDUCE = ADD; \
REDUCE = ADD; \
return __VA_ARGS__(); \
} else if (reduce == "mean") { \
const ReductionType REDUCE = MEAN; \
REDUCE = MEAN; \
return __VA_ARGS__(); \
} else if (reduce == "min") { \
const ReductionType REDUCE = MIN; \
REDUCE = MIN; \
return __VA_ARGS__(); \
} else if (reduce == "max") { \
const ReductionType REDUCE = MAX; \
REDUCE = MAX; \
return __VA_ARGS__(); \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
template <typename scalar_t> struct Reducer {
static inline scalar_t init(ReductionType REDUCE) {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
......@@ -37,7 +38,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static inline void update(scalar_t *val, scalar_t new_val) {
static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val) {
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
......@@ -46,7 +47,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
......@@ -57,7 +58,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static inline void write(scalar_t *address, scalar_t val,
static inline void write(ReductionType REDUCE, scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == ADD) {
*address = val;
......@@ -136,16 +137,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int k = 0; k < K; k++) {
vals[k] = Reducer<scalar_t, REDUCE>::init();
vals[k] = Reducer<scalar_t>::init(REDUCE);
}
for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update(
Reducer<scalar_t>::update(REDUCE,
&vals[k], src_data[offset + e * K + k], &args[k], e);
}
}
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
......@@ -214,13 +215,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
for (int e_2 = 0; e_2 < E_2; e_2++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update(
Reducer<scalar_t>::update(REDUCE,
&vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
}
if (e_2 == E_2 - 1) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
Reducer<scalar_t>::write(REDUCE,
out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start);
......@@ -231,7 +232,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
if (idx != next_idx) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
Reducer<scalar_t>::write(REDUCE,
out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment