Commit 89e1c2cb authored by Stefan Ivanov's avatar Stefan Ivanov
Browse files

Revert "potential windows fix"

This reverts commit 0be33ffa.
parent c4fdd99f
...@@ -40,8 +40,8 @@ const std::map<std::string, ReductionType> reduce2REDUCE = { ...@@ -40,8 +40,8 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
} \ } \
}() }()
template <typename scalar_t> struct Reducer { template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init(ReductionType REDUCE) { static inline scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV) if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1; return (scalar_t)1;
else if (REDUCE == MIN) else if (REDUCE == MIN)
...@@ -52,8 +52,8 @@ template <typename scalar_t> struct Reducer { ...@@ -52,8 +52,8 @@ template <typename scalar_t> struct Reducer {
return (scalar_t)0; return (scalar_t)0;
} }
static inline void update(ReductionType REDUCE, scalar_t *val, static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
scalar_t new_val, int64_t *arg, int64_t new_arg) { int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN) if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val; *val = *val + new_val;
else if (REDUCE == MUL) else if (REDUCE == MUL)
...@@ -67,9 +67,8 @@ template <typename scalar_t> struct Reducer { ...@@ -67,9 +67,8 @@ template <typename scalar_t> struct Reducer {
} }
} }
static inline void write(ReductionType REDUCE, scalar_t *address, static inline void write(scalar_t *address, scalar_t val,
scalar_t val, int64_t *arg_address, int64_t arg, int64_t *arg_address, int64_t arg, int count) {
int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV) if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val; *address = val;
else if (REDUCE == MEAN) else if (REDUCE == MEAN)
......
...@@ -61,22 +61,22 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -61,22 +61,22 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
int64_t i, idx; int64_t i, idx;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value()) if (!optional_out.has_value())
out.fill_(Reducer<scalar_t>::init(REDUCE)); out.fill_(Reducer<scalar_t, REDUCE>::init());
for (auto b = 0; b < B; b++) { for (auto b = 0; b < B; b++) {
for (auto e = 0; e < E; e++) { for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++) { for (auto k = 0; k < K; k++) {
i = b * E * K + e * K + k; i = b * E * K + e * K + k;
idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)]; idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
Reducer<scalar_t>::update( Reducer<scalar_t, REDUCE>::update(
REDUCE, out_data + b * N * K + idx * K + k, src_data[i], out_data + b * N * K + idx * K + k, src_data[i],
arg_out_data + b * N * K + idx * K + k, e); arg_out_data + b * N * K + idx * K + k, e);
} }
} }
} }
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE), (scalar_t)0); out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
}); });
}); });
......
...@@ -72,7 +72,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -72,7 +72,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
int64_t idx, next_idx, row_start; int64_t idx, next_idx, row_start;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value()) if (!optional_out.has_value())
out.fill_(Reducer<scalar_t>::init(REDUCE)); out.fill_(Reducer<scalar_t, REDUCE>::init());
if (REDUCE == MEAN) if (REDUCE == MEAN)
count_data = arg_out.value().data_ptr<scalar_t>(); count_data = arg_out.value().data_ptr<scalar_t>();
...@@ -87,13 +87,13 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -87,13 +87,13 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
for (auto e = 0; e < E; e++) { for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++) for (auto k = 0; k < K; k++)
Reducer<scalar_t>::update( Reducer<scalar_t, REDUCE>::update(
REDUCE, &vals[k], src_data[b * E * K + e * K + k], &args[k], e); &vals[k], src_data[b * E * K + e * K + k], &args[k], e);
if (e == E - 1) { if (e == E - 1) {
for (auto k = 0; k < K; k++) for (auto k = 0; k < K; k++)
Reducer<scalar_t>::write( Reducer<scalar_t, REDUCE>::write(
REDUCE, out_data + b * N * K + idx * K + k, vals[k], out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k], arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start); e + 1 - row_start);
if (REDUCE == MEAN) if (REDUCE == MEAN)
...@@ -104,8 +104,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -104,8 +104,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
if (idx != next_idx) { if (idx != next_idx) {
for (auto k = 0; k < K; k++) { for (auto k = 0; k < K; k++) {
Reducer<scalar_t>::write( Reducer<scalar_t, REDUCE>::write(
REDUCE, out_data + b * N * K + idx * K + k, vals[k], out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k], arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start); e + 1 - row_start);
...@@ -121,7 +121,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -121,7 +121,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
} }
} }
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE), (scalar_t)0); out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MEAN) if (REDUCE == MEAN)
arg_out.value().clamp_(1); arg_out.value().clamp_(1);
......
...@@ -68,15 +68,15 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, ...@@ -68,15 +68,15 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
offset = (n / (indptr.size(-1) - 1)) * E * K; offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto k = 0; k < K; k++) for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t>::init(REDUCE); vals[k] = Reducer<scalar_t, REDUCE>::init();
for (auto e = row_start; e < row_end; e++) for (auto e = row_start; e < row_end; e++)
for (auto k = 0; k < K; k++) for (auto k = 0; k < K; k++)
Reducer<scalar_t>::update( Reducer<scalar_t, REDUCE>::update(
REDUCE, &vals[k], src_data[offset + e * K + k], &args[k], e); &vals[k], src_data[offset + e * K + k], &args[k], e);
for (auto k = 0; k < K; k++) for (auto k = 0; k < K; k++)
Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k], Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k], arg_out_data + n * K + k, args[k],
row_end - row_start); row_end - row_start);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment