Commit 4a1164b8 authored by rusty1s

template reducer

parent ea9d68bd
@@ -14,34 +14,34 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
   [&] {                                                      \
     switch (reduce2REDUCE.at(reduce)) {                      \
     case SUM: {                                              \
-      const ReductionType REDUCE = SUM;                      \
+      static constexpr ReductionType REDUCE = SUM;           \
       return __VA_ARGS__();                                  \
     }                                                        \
     case MEAN: {                                             \
-      const ReductionType REDUCE = MEAN;                     \
+      static constexpr ReductionType REDUCE = MEAN;          \
       return __VA_ARGS__();                                  \
     }                                                        \
     case MUL: {                                              \
-      const ReductionType REDUCE = MUL;                      \
+      static constexpr ReductionType REDUCE = MUL;           \
       return __VA_ARGS__();                                  \
     }                                                        \
     case DIV: {                                              \
-      const ReductionType REDUCE = DIV;                      \
+      static constexpr ReductionType REDUCE = DIV;           \
       return __VA_ARGS__();                                  \
     }                                                        \
     case MIN: {                                              \
-      const ReductionType REDUCE = MIN;                      \
+      static constexpr ReductionType REDUCE = MIN;           \
      return __VA_ARGS__();                                   \
     }                                                        \
     case MAX: {                                              \
-      const ReductionType REDUCE = MAX;                      \
+      static constexpr ReductionType REDUCE = MAX;           \
       return __VA_ARGS__();                                  \
     }                                                        \
     }                                                        \
   }()
 
-template <typename scalar_t> struct Reducer {
-  static inline scalar_t init(ReductionType REDUCE) {
+template <typename scalar_t, ReductionType REDUCE> struct Reducer {
+  static inline scalar_t init() {
     if (REDUCE == MUL || REDUCE == DIV)
       return (scalar_t)1;
     else if (REDUCE == MIN)
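In the hunk above, the per-case constant changes from `const` to `static constexpr`. This matters because, with `Reducer` now templated on the reduction (see the following hunks), `REDUCE` is consumed as a template argument inside the lambda that `__VA_ARGS__` expands to, so it has to be a compile-time constant that the nested lambda can use without capturing it. A minimal, self-contained sketch of the pattern (not the repository's code; `combine` and `DISPATCH_REDUCE` are made-up names for illustration):

// Sketch of the dispatch pattern: a runtime enum value is turned into a
// compile-time constant inside each switch case, so the callback passed via
// __VA_ARGS__ can use it as a template argument.
#include <cstdio>

enum ReductionType { SUM, MAX };

template <ReductionType REDUCE> int combine(int a, int b) {
  // Resolved per instantiation; no runtime branch on the reduction type.
  if (REDUCE == SUM)
    return a + b;
  return a > b ? a : b;
}

#define DISPATCH_REDUCE(reduce, ...)                          \
  [&] {                                                       \
    switch (reduce) {                                         \
    case SUM: {                                               \
      static constexpr ReductionType REDUCE = SUM;            \
      return __VA_ARGS__();                                   \
    }                                                         \
    default: {                                                \
      static constexpr ReductionType REDUCE = MAX;            \
      return __VA_ARGS__();                                   \
    }                                                         \
    }                                                         \
  }()

int main() {
  ReductionType reduce = SUM; // only known at runtime
  int out = DISPATCH_REDUCE(reduce, [&] { return combine<REDUCE>(3, 4); });
  std::printf("%d\n", out); // prints 7
}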
@@ -52,8 +52,8 @@ template <typename scalar_t> struct Reducer {
       return (scalar_t)0;
   }
 
-  static inline void update(ReductionType REDUCE, scalar_t *val,
-                            scalar_t new_val, int64_t *arg, int64_t new_arg) {
+  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
+                            int64_t new_arg) {
     if (REDUCE == SUM || REDUCE == MEAN)
       *val = *val + new_val;
     else if (REDUCE == MUL)
@@ -67,9 +67,8 @@ template <typename scalar_t> struct Reducer {
     }
   }
 
-  static inline void write(ReductionType REDUCE, scalar_t *address,
-                           scalar_t val, int64_t *arg_address, int64_t arg,
-                           int count) {
+  static inline void write(scalar_t *address, scalar_t val,
+                           int64_t *arg_address, int64_t arg, int count) {
     if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
       *address = val;
     else if (REDUCE == MEAN)
......
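The other half of the change: `REDUCE` moves from a runtime function argument to a template parameter of `Reducer`, so every `if (REDUCE == ...)` inside `init`, `update`, and `write` becomes a constant condition per instantiation and the compiler can discard the dead branches. A deliberately simplified sketch of the idea (only SUM and MAX, no MEAN normalization or arg tracking; not the repository's actual `Reducer`):

#include <cstdio>
#include <limits>

enum ReductionType { SUM, MAX };

// Simplified reducer: the reduction is fixed per instantiation, so each branch
// below is a compile-time constant and the dead one can be optimized away.
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline scalar_t init() {
    if (REDUCE == MAX)
      return std::numeric_limits<scalar_t>::lowest();
    return scalar_t(0);
  }
  static inline void update(scalar_t *val, scalar_t new_val) {
    if (REDUCE == SUM)
      *val = *val + new_val;
    else if (new_val > *val)
      *val = new_val;
  }
  static inline void write(scalar_t *address, scalar_t val) { *address = val; }
};

int main() {
  float row[4] = {1.f, -2.f, 3.f, 0.5f}, out;
  auto acc = Reducer<float, SUM>::init();
  for (float v : row)
    Reducer<float, SUM>::update(&acc, v);
  Reducer<float, SUM>::write(&out, acc);
  std::printf("sum = %f\n", out); // 2.5
}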
@@ -63,7 +63,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
         row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
 
         for (auto k = 0; k < K; k++)
-          vals[k] = Reducer<scalar_t>::init(REDUCE);
+          vals[k] = Reducer<scalar_t, REDUCE>::init();
 
         auto offset = b * N * K;
         for (auto e = row_start; e < row_end; e++) {
@@ -72,20 +72,19 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
             val = value_data[e];
           for (auto k = 0; k < K; k++) {
             if (HAS_VALUE)
-              Reducer<scalar_t>::update(REDUCE, &vals[k],
-                                        val * mat_data[offset + c * K + k],
-                                        &args[k], e);
+              Reducer<scalar_t, REDUCE>::update(
+                  &vals[k], val * mat_data[offset + c * K + k], &args[k],
+                  e);
             else
-              Reducer<scalar_t>::update(REDUCE, &vals[k],
-                                        mat_data[offset + c * K + k],
-                                        &args[k], e);
+              Reducer<scalar_t, REDUCE>::update(
+                  &vals[k], mat_data[offset + c * K + k], &args[k], e);
           }
         }
         offset = b * M * K + m * K;
         for (auto k = 0; k < K; k++)
-          Reducer<scalar_t>::write(REDUCE, out_data + offset + k, vals[k],
-                                   arg_out_data + offset + k, args[k],
-                                   row_end - row_start);
+          Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
+                                           arg_out_data + offset + k,
+                                           args[k], row_end - row_start);
       }
     }
   });
......
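At the call sites in `spmm_cpu`, the runtime `REDUCE` argument disappears from the innermost loops; the reduction is baked into the `Reducer<scalar_t, REDUCE>` instantiation chosen once by the dispatch macro, while `write` still receives the row length (`row_end - row_start`), presumably so a MEAN reduction can normalize. A rough, self-contained sketch of that init → update → write row loop over a CSR structure (the `rowptr`/`row_start`/`row_end` names follow the diff; the `reduce_rows` helper and the MEAN division by `count` are assumptions for illustration):

#include <cstdio>
#include <vector>

enum ReductionType { SUM, MEAN };

// Minimal stand-in for the templated Reducer in the diff above.
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline scalar_t init() { return scalar_t(0); }
  static inline void update(scalar_t *val, scalar_t new_val) {
    *val = *val + new_val; // SUM and MEAN both accumulate (as in the diff)
  }
  static inline void write(scalar_t *address, scalar_t val, int count) {
    if (REDUCE == MEAN) // assumed normalization; the diff's MEAN branch is truncated
      *address = val / scalar_t(count > 0 ? count : 1);
    else
      *address = val;
  }
};

// Reduce the values of each CSR row, mirroring the init -> update -> write
// structure of the spmm_cpu loop (without the dense `mat` factor).
template <ReductionType REDUCE>
std::vector<float> reduce_rows(const std::vector<int> &rowptr,
                               const std::vector<float> &value) {
  std::vector<float> out(rowptr.size() - 1);
  for (int m = 0; m + 1 < (int)rowptr.size(); m++) {
    int row_start = rowptr[m], row_end = rowptr[m + 1];
    auto val = Reducer<float, REDUCE>::init();
    for (int e = row_start; e < row_end; e++)
      Reducer<float, REDUCE>::update(&val, value[e]);
    Reducer<float, REDUCE>::write(&out[m], val, row_end - row_start);
  }
  return out;
}

int main() {
  std::vector<int> rowptr = {0, 2, 5};        // two rows
  std::vector<float> value = {1, 3, 2, 4, 6}; // row 0: {1, 3}, row 1: {2, 4, 6}
  auto mean = reduce_rows<MEAN>(rowptr, value);
  std::printf("%f %f\n", mean[0], mean[1]);   // 2.0 4.0
}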