Commit 4a1164b8 authored by rusty1s's avatar rusty1s
Browse files

template reducer

parent ea9d68bd
......@@ -14,34 +14,34 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
static constexpr ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
static constexpr ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
static constexpr ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
static constexpr ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
static constexpr ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
static constexpr ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t> struct Reducer {
static inline scalar_t init(ReductionType REDUCE) {
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
......@@ -52,8 +52,8 @@ template <typename scalar_t> struct Reducer {
return (scalar_t)0;
}
static inline void update(ReductionType REDUCE, scalar_t *val,
scalar_t new_val, int64_t *arg, int64_t new_arg) {
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
......@@ -67,9 +67,8 @@ template <typename scalar_t> struct Reducer {
}
}
static inline void write(ReductionType REDUCE, scalar_t *address,
scalar_t val, int64_t *arg_address, int64_t arg,
int count) {
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
......
......@@ -63,7 +63,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t>::init(REDUCE);
vals[k] = Reducer<scalar_t, REDUCE>::init();
auto offset = b * N * K;
for (auto e = row_start; e < row_end; e++) {
......@@ -72,20 +72,19 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
val = value_data[e];
for (auto k = 0; k < K; k++) {
if (HAS_VALUE)
Reducer<scalar_t>::update(REDUCE, &vals[k],
val * mat_data[offset + c * K + k],
&args[k], e);
Reducer<scalar_t, REDUCE>::update(
&vals[k], val * mat_data[offset + c * K + k], &args[k],
e);
else
Reducer<scalar_t>::update(REDUCE, &vals[k],
mat_data[offset + c * K + k],
&args[k], e);
Reducer<scalar_t, REDUCE>::update(
&vals[k], mat_data[offset + c * K + k], &args[k], e);
}
}
offset = b * M * K + m * K;
for (auto k = 0; k < K; k++)
Reducer<scalar_t>::write(REDUCE, out_data + offset + k, vals[k],
arg_out_data + offset + k, args[k],
row_end - row_start);
Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
arg_out_data + offset + k,
args[k], row_end - row_start);
}
}
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment