Commit 3cf59da2 authored by rusty1s's avatar rusty1s
Browse files

Rename the `add` reduction to `sum` (keeping `add` as an accepted alias), and make `REDUCE` a compile-time template parameter of `Reducer`

parent 7aa701b1
......@@ -122,11 +122,11 @@ def timing(dataset):
avg_row_len = row.size(0) / dim_size
def sca_row(x):
op = getattr(torch_scatter, f'scatter_{args.reduce}')
op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
return op(x, row, dim=0, dim_size=dim_size)
def sca_col(x):
op = getattr(torch_scatter, f'scatter_{args.reduce}')
op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
return op(x, row_perm, dim=0, dim_size=dim_size)
def seg_coo(x):
......@@ -136,10 +136,10 @@ def timing(dataset):
return segment_csr(x, rowptr, reduce=args.reduce)
def dense1(x):
return getattr(torch, args.dense_reduce)(x, dim=-2)
return getattr(torch, args.reduce)(x, dim=-2)
def dense2(x):
return getattr(torch, args.dense_reduce)(x, dim=-1)
return getattr(torch, args.reduce)(x, dim=-1)
t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []
......@@ -204,15 +204,12 @@ def timing(dataset):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--reduce',
type=str,
required=True,
choices=['add', 'mean', 'min', 'max'])
parser.add_argument('--reduce', type=str, required=True,
choices=['sum', 'mean', 'min', 'max'])
parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce
iters = 1 if args.device == 'cpu' else 20
sizes = [1, 16, 32, 64, 128, 256, 512]
sizes = sizes[:3] if args.device == 'cpu' else sizes
......
......@@ -7,28 +7,36 @@
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
enum ReductionType { ADD, MEAN, MIN, MAX };
enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
ReductionType REDUCE = ADD; \
if (reduce == "add") { \
REDUCE = ADD; \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} else if (reduce == "mean") { \
REDUCE = MEAN; \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} else if (reduce == "min") { \
REDUCE = MIN; \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} else if (reduce == "max") { \
REDUCE = MAX; \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t> struct Reducer {
static inline scalar_t init(ReductionType REDUCE) {
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
......@@ -38,18 +46,9 @@ template <typename scalar_t> struct Reducer {
}
}
static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val) {
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
}
}
static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val, int64_t *arg,
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) {
if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
......@@ -58,9 +57,9 @@ template <typename scalar_t> struct Reducer {
}
}
static inline void write(ReductionType REDUCE, scalar_t *address, scalar_t val,
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == ADD) {
if (REDUCE == SUM) {
*address = val;
} else if (REDUCE == MEAN) {
*address = val / (count > 0 ? count : (scalar_t)1);
......@@ -111,7 +110,7 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") {
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
......@@ -137,16 +136,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int k = 0; k < K; k++) {
vals[k] = Reducer<scalar_t>::init(REDUCE);
vals[k] = Reducer<scalar_t, REDUCE>::init();
}
for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t>::update(REDUCE,
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[offset + e * K + k], &args[k], e);
}
}
for (int k = 0; k < K; k++) {
Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k],
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
......@@ -183,7 +182,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") {
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, src.size(reduce_dim), index.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
......@@ -215,13 +214,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
for (int e_2 = 0; e_2 < E_2; e_2++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t>::update(REDUCE,
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
}
if (e_2 == E_2 - 1) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t>::write(REDUCE,
Reducer<scalar_t, REDUCE>::write(
out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start);
......@@ -232,7 +231,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
if (idx != next_idx) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t>::write(REDUCE,
Reducer<scalar_t, REDUCE>::write(
out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start);
......
......@@ -11,23 +11,32 @@
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
enum ReductionType { ADD, MEAN, MIN, MAX };
enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
if (reduce == "add") { \
const ReductionType REDUCE = ADD; \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} else if (reduce == "mean") { \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} else if (reduce == "min") { \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} else if (reduce == "max") { \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
......@@ -43,7 +52,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ void update(scalar_t *val,
scalar_t new_val) {
if (REDUCE == ADD || REDUCE == MEAN) {
if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
......@@ -53,7 +62,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) {
if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
......@@ -65,7 +74,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == ADD) {
if (REDUCE == SUM) {
*address = val;
} else if (REDUCE == MEAN) {
*address = val / (scalar_t)max(count, 1);
......@@ -80,7 +89,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
if (REDUCE == ADD || REDUCE == MEAN) {
if (REDUCE == SUM || REDUCE == MEAN) {
atomAdd(address, val);
} else if (REDUCE == MIN && val < *address) {
atomMin(address, val);
......@@ -204,7 +213,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") {
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
......@@ -396,7 +405,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") {
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, src.size(reduce_dim), index.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
......@@ -455,14 +464,14 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
});
});
if (reduce == "mean") {
if (reduce2REDUCE.at(reduce) == MEAN) {
auto sizes = index.sizes().vec();
sizes[reduce_dim] = out.size(reduce_dim);
auto count = at::zeros(sizes, out.options());
AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
auto count_data = count.DATA_PTR<scalar_t>();
segment_coo_kernel<scalar_t, ADD, false>
segment_coo_kernel<scalar_t, SUM, false>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
count_data, E, N);
});
......
......@@ -7,15 +7,15 @@ from torch_scatter import segment_coo, segment_csr
from .utils import tensor, dtypes, devices
reductions = ['add', 'mean', 'min', 'max']
grad_reductions = ['add', 'mean']
reductions = ['sum', 'mean', 'min', 'max']
grad_reductions = ['sum', 'mean']
tests = [
{
'src': [1, 2, 3, 4, 5, 6],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'add': [3, 12, 0, 6],
'sum': [3, 12, 0, 6],
'mean': [1.5, 4, 0, 6],
'min': [1, 3, 0, 6],
'arg_min': [0, 2, 6, 5],
......@@ -26,7 +26,7 @@ tests = [
'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]],
......@@ -37,7 +37,7 @@ tests = [
'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]],
......@@ -48,7 +48,7 @@ tests = [
'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
'index': [[0, 0, 1], [0, 2, 2]],
'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]],
......@@ -59,7 +59,7 @@ tests = [
'src': [[1, 3], [2, 4]],
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'add': [[4], [6]],
'sum': [[4], [6]],
'mean': [[2], [3]],
'min': [[1], [2]],
'arg_min': [[0], [0]],
......@@ -70,7 +70,7 @@ tests = [
'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'add': [[[4, 4]], [[6, 6]]],
'sum': [[[4, 4]], [[6, 6]]],
'mean': [[[2, 2]], [[3, 3]]],
'min': [[[1, 1]], [[2, 2]]],
'arg_min': [[[0, 0]], [[0, 0]]],
......@@ -134,7 +134,7 @@ def test_segment_out(test, reduce, dtype, device):
segment_coo(src, index, out, reduce=reduce)
if reduce == 'add':
if reduce == 'sum':
expected = expected - 2
elif reduce == 'mean':
expected = out # We can not really test this here.
......
......@@ -31,7 +31,7 @@ class GatherCOO(torch.autograd.Function):
grad_src = None
if ctx.needs_input_grad[0]:
grad_src, _ = seg(grad_out.is_cuda).segment_coo(
grad_out, index, grad_out.new_zeros(src_size), 'add')
grad_out, index, grad_out.new_zeros(src_size), 'sum')
return grad_src, None, None
......@@ -53,7 +53,7 @@ class GatherCSR(torch.autograd.Function):
grad_src = None
if ctx.needs_input_grad[0]:
grad_src, _ = seg(grad_out.is_cuda).segment_csr(
grad_out, indptr, grad_out.new_empty(src_size), 'add')
grad_out, indptr, grad_out.new_empty(src_size), 'sum')
return grad_src, None, None
......
......@@ -18,7 +18,7 @@ def gat(is_cuda):
class SegmentCOO(torch.autograd.Function):
@staticmethod
def forward(ctx, src, index, out, dim_size, reduce):
assert reduce in ['add', 'mean', 'min', 'max']
assert reduce in ['sum', 'add', 'mean', 'min', 'max']
if out is not None:
ctx.mark_dirty(out)
ctx.reduce = reduce
......@@ -55,7 +55,7 @@ class SegmentCOO(torch.autograd.Function):
grad_src = None
if ctx.needs_input_grad[0]:
if ctx.reduce == 'add':
if ctx.reduce == 'sum' or ctx.reduce == 'add':
grad_src = gat(grad_out.is_cuda).gather_coo(
grad_out, index, grad_out.new_empty(src_size))
elif ctx.reduce == 'mean':
......@@ -68,7 +68,7 @@ class SegmentCOO(torch.autograd.Function):
size[-1] = grad_out.size(index.dim() - 1)
count = segment_cpu.segment_coo(
torch.ones_like(index, dtype=grad_out.dtype), index,
grad_out.new_zeros(size), 'add')[0].clamp_(min=1)
grad_out.new_zeros(size), 'sum')[0].clamp_(min=1)
count = gat(grad_out.is_cuda).gather_coo(
count, index, count.new_empty(src_size[:index.dim()]))
......@@ -88,7 +88,7 @@ class SegmentCOO(torch.autograd.Function):
class SegmentCSR(torch.autograd.Function):
@staticmethod
def forward(ctx, src, indptr, out, reduce):
assert reduce in ['add', 'mean', 'min', 'max']
assert reduce in ['sum', 'add', 'mean', 'min', 'max']
if out is not None:
ctx.mark_dirty(out)
......@@ -105,7 +105,7 @@ class SegmentCSR(torch.autograd.Function):
grad_src = None
if ctx.needs_input_grad[0]:
if ctx.reduce == 'add':
if ctx.reduce == 'sum' or ctx.reduce == 'add':
grad_src = gat(grad_out.is_cuda).gather_csr(
grad_out, indptr, grad_out.new_empty(src_size))
elif ctx.reduce == 'mean':
......@@ -129,7 +129,7 @@ class SegmentCSR(torch.autograd.Function):
return grad_src, None, None, None
def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
def segment_coo(src, index, out=None, dim_size=None, reduce="sum"):
r"""
|
......@@ -158,7 +158,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
:math:`y - 1` in ascending order.
The :attr:`index` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="add"`, the operation
For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes
.. math::
......@@ -196,9 +196,9 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
If :attr:`dim_size` is not given, a minimal sized output tensor
according to :obj:`index.max() + 1` is returned.
(default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"add"`,
reduce (string, optional): The reduce operation (:obj:`"sum"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"add"`)
(default: :obj:`"sum"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
......@@ -210,7 +210,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1) # Broadcasting in the first and last dim.
out = segment_coo(src, index, reduce="add")
out = segment_coo(src, index, reduce="sum")
print(out.size())
......@@ -221,7 +221,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
return SegmentCOO.apply(src, index, out, dim_size, reduce)
def segment_csr(src, indptr, out=None, reduce="add"):
def segment_csr(src, indptr, out=None, reduce="sum"):
r"""
Reduces all values from the :attr:`src` tensor into :attr:`out` within the
ranges specified in the :attr:`indptr` tensor along the last dimension of
......@@ -242,7 +242,7 @@ def segment_csr(src, indptr, out=None, reduce="add"):
:math:`x_m` in ascending order.
The :attr:`indptr` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="add"`, the operation
For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes
.. math::
......@@ -267,9 +267,9 @@ def segment_csr(src, indptr, out=None, reduce="add"):
The number of dimensions of :attr:`index` needs to be less than or
equal to :attr:`src`.
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"add"`,
reduce (string, optional): The reduce operation (:obj:`"sum"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"add"`)
(default: :obj:`"sum"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
......@@ -281,7 +281,7 @@ def segment_csr(src, indptr, out=None, reduce="add"):
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1) # Broadcasting in the first and last dim.
out = segment_csr(src, indptr, reduce="add")
out = segment_csr(src, indptr, reduce="sum")
print(out.size())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment