Commit 0be33ffa authored by rusty1s
Browse files

potential windows fix

parent feca30d1
......@@ -40,8 +40,8 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
template <typename scalar_t> struct Reducer {
static inline scalar_t init(ReductionType REDUCE) {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
......@@ -52,8 +52,8 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
return (scalar_t)0;
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
static inline void update(ReductionType REDUCE, scalar_t *val,
scalar_t new_val, int64_t *arg, int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
......@@ -67,8 +67,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
static inline void write(ReductionType REDUCE, scalar_t *address,
scalar_t val, int64_t *arg_address, int64_t arg,
int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
......
......@@ -61,22 +61,22 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
int64_t i, idx;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
out.fill_(Reducer<scalar_t>::init(REDUCE));
for (auto b = 0; b < B; b++) {
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++) {
i = b * E * K + e * K + k;
idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
Reducer<scalar_t, REDUCE>::update(
out_data + b * N * K + idx * K + k, src_data[i],
Reducer<scalar_t>::update(
REDUCE, out_data + b * N * K + idx * K + k, src_data[i],
arg_out_data + b * N * K + idx * K + k, e);
}
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE), (scalar_t)0);
});
});
......
......@@ -72,7 +72,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
int64_t idx, next_idx, row_start;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
out.fill_(Reducer<scalar_t>::init(REDUCE));
if (REDUCE == MEAN)
count_data = arg_out.value().data_ptr<scalar_t>();
......@@ -87,13 +87,13 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[b * E * K + e * K + k], &args[k], e);
Reducer<scalar_t>::update(
REDUCE, &vals[k], src_data[b * E * K + e * K + k], &args[k], e);
if (e == E - 1) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
Reducer<scalar_t>::write(
REDUCE, out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
if (REDUCE == MEAN)
......@@ -104,8 +104,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
if (idx != next_idx) {
for (auto k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
Reducer<scalar_t>::write(
REDUCE, out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
......@@ -121,7 +121,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE), (scalar_t)0);
if (REDUCE == MEAN)
arg_out.value().clamp_(1);
......
......@@ -68,17 +68,17 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
vals[k] = Reducer<scalar_t>::init(REDUCE);
for (auto e = row_start; e < row_end; e++)
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[offset + e * K + k], &args[k], e);
Reducer<scalar_t>::update(
REDUCE, &vals[k], src_data[offset + e * K + k], &args[k], e);
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
});
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment