"git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "e013bab5674e8d35d1998a050e1fa239ac9a747d"
Commit 3cf59da2 authored by rusty1s's avatar rusty1s
Browse files

add to sum, REDUCE to template

parent 7aa701b1
...@@ -122,11 +122,11 @@ def timing(dataset): ...@@ -122,11 +122,11 @@ def timing(dataset):
avg_row_len = row.size(0) / dim_size avg_row_len = row.size(0) / dim_size
def sca_row(x): def sca_row(x):
op = getattr(torch_scatter, f'scatter_{args.reduce}') op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
return op(x, row, dim=0, dim_size=dim_size) return op(x, row, dim=0, dim_size=dim_size)
def sca_col(x): def sca_col(x):
op = getattr(torch_scatter, f'scatter_{args.reduce}') op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
return op(x, row_perm, dim=0, dim_size=dim_size) return op(x, row_perm, dim=0, dim_size=dim_size)
def seg_coo(x): def seg_coo(x):
...@@ -136,10 +136,10 @@ def timing(dataset): ...@@ -136,10 +136,10 @@ def timing(dataset):
return segment_csr(x, rowptr, reduce=args.reduce) return segment_csr(x, rowptr, reduce=args.reduce)
def dense1(x): def dense1(x):
return getattr(torch, args.dense_reduce)(x, dim=-2) return getattr(torch, args.reduce)(x, dim=-2)
def dense2(x): def dense2(x):
return getattr(torch, args.dense_reduce)(x, dim=-1) return getattr(torch, args.reduce)(x, dim=-1)
t1, t2, t3, t4, t5, t6 = [], [], [], [], [], [] t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []
...@@ -204,15 +204,12 @@ def timing(dataset): ...@@ -204,15 +204,12 @@ def timing(dataset):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument('--reduce', type=str, required=True,
'--reduce', choices=['sum', 'mean', 'min', 'max'])
type=str,
required=True,
choices=['add', 'mean', 'min', 'max'])
parser.add_argument('--with_backward', action='store_true') parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args() args = parser.parse_args()
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce
iters = 1 if args.device == 'cpu' else 20 iters = 1 if args.device == 'cpu' else 20
sizes = [1, 16, 32, 64, 128, 256, 512] sizes = [1, 16, 32, 64, 128, 256, 512]
sizes = sizes[:3] if args.device == 'cpu' else sizes sizes = sizes[:3] if args.device == 'cpu' else sizes
......
...@@ -7,28 +7,36 @@ ...@@ -7,28 +7,36 @@
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor") #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
enum ReductionType { ADD, MEAN, MIN, MAX }; enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \ [&] { \
ReductionType REDUCE = ADD; \ switch (reduce2REDUCE.at(reduce)) { \
if (reduce == "add") { \ case SUM: { \
REDUCE = ADD; \ const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (reduce == "mean") { \ } \
REDUCE = MEAN; \ case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (reduce == "min") { \ } \
REDUCE = MIN; \ case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (reduce == "max") { \ } \
REDUCE = MAX; \ case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} \ } \
} \
}() }()
template <typename scalar_t> struct Reducer { template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init(ReductionType REDUCE) { static inline scalar_t init() {
if (REDUCE == MIN) { if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max(); return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) { } else if (REDUCE == MAX) {
...@@ -38,18 +46,9 @@ template <typename scalar_t> struct Reducer { ...@@ -38,18 +46,9 @@ template <typename scalar_t> struct Reducer {
} }
} }
static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val) { static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
}
}
static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) { int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) { if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val; *val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) || } else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) { (REDUCE == MAX && new_val > *val)) {
...@@ -58,9 +57,9 @@ template <typename scalar_t> struct Reducer { ...@@ -58,9 +57,9 @@ template <typename scalar_t> struct Reducer {
} }
} }
static inline void write(ReductionType REDUCE, scalar_t *address, scalar_t val, static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) { int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == ADD) { if (REDUCE == SUM) {
*address = val; *address = val;
} else if (REDUCE == MEAN) { } else if (REDUCE == MEAN) {
*address = val / (count > 0 ? count : (scalar_t)1); *address = val / (count > 0 ? count : (scalar_t)1);
...@@ -111,7 +110,7 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt, ...@@ -111,7 +110,7 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
at::optional<at::Tensor> arg_out = at::nullopt; at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr; int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") { if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, src.size(reduce_dim), indptr.options()); arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>(); arg_out_data = arg_out.value().DATA_PTR<int64_t>();
} }
...@@ -137,16 +136,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt, ...@@ -137,16 +136,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
offset = (n / (indptr.size(-1) - 1)) * E * K; offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
vals[k] = Reducer<scalar_t>::init(REDUCE); vals[k] = Reducer<scalar_t, REDUCE>::init();
} }
for (int64_t e = row_start; e < row_end; e++) { for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
Reducer<scalar_t>::update(REDUCE, Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[offset + e * K + k], &args[k], e); &vals[k], src_data[offset + e * K + k], &args[k], e);
} }
} }
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k], Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k], arg_out_data + n * K + k, args[k],
row_end - row_start); row_end - row_start);
} }
...@@ -183,7 +182,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -183,7 +182,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
at::optional<at::Tensor> arg_out = at::nullopt; at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr; int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") { if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, src.size(reduce_dim), index.options()); arg_out = at::full_like(out, src.size(reduce_dim), index.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>(); arg_out_data = arg_out.value().DATA_PTR<int64_t>();
} }
...@@ -215,13 +214,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -215,13 +214,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
for (int e_2 = 0; e_2 < E_2; e_2++) { for (int e_2 = 0; e_2 < E_2; e_2++) {
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
Reducer<scalar_t>::update(REDUCE, Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2); &vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
} }
if (e_2 == E_2 - 1) { if (e_2 == E_2 - 1) {
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
Reducer<scalar_t>::write(REDUCE, Reducer<scalar_t, REDUCE>::write(
out_data + e_1 * N * K + idx * K + k, vals[k], out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k], arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start); e_2 + 1 - row_start);
...@@ -232,7 +231,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -232,7 +231,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
if (idx != next_idx) { if (idx != next_idx) {
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
Reducer<scalar_t>::write(REDUCE, Reducer<scalar_t, REDUCE>::write(
out_data + e_1 * N * K + idx * K + k, vals[k], out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k], arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start); e_2 + 1 - row_start);
......
...@@ -11,23 +11,32 @@ ...@@ -11,23 +11,32 @@
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS #define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff #define FULL_MASK 0xffffffff
enum ReductionType { ADD, MEAN, MIN, MAX }; enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \ [&] { \
if (reduce == "add") { \ switch (reduce2REDUCE.at(reduce)) { \
const ReductionType REDUCE = ADD; \ case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (reduce == "mean") { \ } \
case MEAN: { \
const ReductionType REDUCE = MEAN; \ const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (reduce == "min") { \ } \
case MIN: { \
const ReductionType REDUCE = MIN; \ const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (reduce == "max") { \ } \
case MAX: { \
const ReductionType REDUCE = MAX; \ const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} \ } \
} \
}() }()
template <typename scalar_t, ReductionType REDUCE> struct Reducer { template <typename scalar_t, ReductionType REDUCE> struct Reducer {
...@@ -43,7 +52,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -43,7 +52,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ void update(scalar_t *val, static inline __host__ __device__ void update(scalar_t *val,
scalar_t new_val) { scalar_t new_val) {
if (REDUCE == ADD || REDUCE == MEAN) { if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val; *val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) || } else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) { (REDUCE == MAX && new_val > *val)) {
...@@ -53,7 +62,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -53,7 +62,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val, static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) { int64_t *arg, int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) { if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val; *val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) || } else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) { (REDUCE == MAX && new_val > *val)) {
...@@ -65,7 +74,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -65,7 +74,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ void write(scalar_t *address, scalar_t val, static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t *arg_address,
int64_t arg, int count) { int64_t arg, int count) {
if (REDUCE == ADD) { if (REDUCE == SUM) {
*address = val; *address = val;
} else if (REDUCE == MEAN) { } else if (REDUCE == MEAN) {
*address = val / (scalar_t)max(count, 1); *address = val / (scalar_t)max(count, 1);
...@@ -80,7 +89,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -80,7 +89,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
} }
static inline __device__ void atomic_write(scalar_t *address, scalar_t val) { static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
if (REDUCE == ADD || REDUCE == MEAN) { if (REDUCE == SUM || REDUCE == MEAN) {
atomAdd(address, val); atomAdd(address, val);
} else if (REDUCE == MIN && val < *address) { } else if (REDUCE == MIN && val < *address) {
atomMin(address, val); atomMin(address, val);
...@@ -204,7 +213,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr, ...@@ -204,7 +213,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> arg_out = at::nullopt; at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr; int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") { if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, src.size(reduce_dim), indptr.options()); arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>(); arg_out_data = arg_out.value().DATA_PTR<int64_t>();
} }
...@@ -396,7 +405,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -396,7 +405,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::optional<at::Tensor> arg_out = at::nullopt; at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr; int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") { if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, src.size(reduce_dim), index.options()); arg_out = at::full_like(out, src.size(reduce_dim), index.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>(); arg_out_data = arg_out.value().DATA_PTR<int64_t>();
} }
...@@ -455,14 +464,14 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -455,14 +464,14 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
}); });
}); });
if (reduce == "mean") { if (reduce2REDUCE.at(reduce) == MEAN) {
auto sizes = index.sizes().vec(); auto sizes = index.sizes().vec();
sizes[reduce_dim] = out.size(reduce_dim); sizes[reduce_dim] = out.size(reduce_dim);
auto count = at::zeros(sizes, out.options()); auto count = at::zeros(sizes, out.options());
AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] { AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
auto count_data = count.DATA_PTR<scalar_t>(); auto count_data = count.DATA_PTR<scalar_t>();
segment_coo_kernel<scalar_t, ADD, false> segment_coo_kernel<scalar_t, SUM, false>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info, <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
count_data, E, N); count_data, E, N);
}); });
......
...@@ -7,15 +7,15 @@ from torch_scatter import segment_coo, segment_csr ...@@ -7,15 +7,15 @@ from torch_scatter import segment_coo, segment_csr
from .utils import tensor, dtypes, devices from .utils import tensor, dtypes, devices
reductions = ['add', 'mean', 'min', 'max'] reductions = ['sum', 'mean', 'min', 'max']
grad_reductions = ['add', 'mean'] grad_reductions = ['sum', 'mean']
tests = [ tests = [
{ {
'src': [1, 2, 3, 4, 5, 6], 'src': [1, 2, 3, 4, 5, 6],
'index': [0, 0, 1, 1, 1, 3], 'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6], 'indptr': [0, 2, 5, 5, 6],
'add': [3, 12, 0, 6], 'sum': [3, 12, 0, 6],
'mean': [1.5, 4, 0, 6], 'mean': [1.5, 4, 0, 6],
'min': [1, 3, 0, 6], 'min': [1, 3, 0, 6],
'arg_min': [0, 2, 6, 5], 'arg_min': [0, 2, 6, 5],
...@@ -26,7 +26,7 @@ tests = [ ...@@ -26,7 +26,7 @@ tests = [
'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], 'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
'index': [0, 0, 1, 1, 1, 3], 'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6], 'indptr': [0, 2, 5, 5, 6],
'add': [[4, 6], [21, 24], [0, 0], [11, 12]], 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
'mean': [[2, 3], [7, 8], [0, 0], [11, 12]], 'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
'min': [[1, 2], [5, 6], [0, 0], [11, 12]], 'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]], 'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]],
...@@ -37,7 +37,7 @@ tests = [ ...@@ -37,7 +37,7 @@ tests = [
'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]], 'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]], 'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]], 'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
'add': [[4, 21, 0, 11], [12, 18, 12, 0]], 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
'mean': [[2, 7, 0, 11], [4, 9, 12, 0]], 'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
'min': [[1, 5, 0, 11], [2, 8, 12, 0]], 'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]], 'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]],
...@@ -48,7 +48,7 @@ tests = [ ...@@ -48,7 +48,7 @@ tests = [
'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]], 'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
'index': [[0, 0, 1], [0, 2, 2]], 'index': [[0, 0, 1], [0, 2, 2]],
'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]], 'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], 'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], 'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]], 'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]],
...@@ -59,7 +59,7 @@ tests = [ ...@@ -59,7 +59,7 @@ tests = [
'src': [[1, 3], [2, 4]], 'src': [[1, 3], [2, 4]],
'index': [[0, 0], [0, 0]], 'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]], 'indptr': [[0, 2], [0, 2]],
'add': [[4], [6]], 'sum': [[4], [6]],
'mean': [[2], [3]], 'mean': [[2], [3]],
'min': [[1], [2]], 'min': [[1], [2]],
'arg_min': [[0], [0]], 'arg_min': [[0], [0]],
...@@ -70,7 +70,7 @@ tests = [ ...@@ -70,7 +70,7 @@ tests = [
'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]], 'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
'index': [[0, 0], [0, 0]], 'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]], 'indptr': [[0, 2], [0, 2]],
'add': [[[4, 4]], [[6, 6]]], 'sum': [[[4, 4]], [[6, 6]]],
'mean': [[[2, 2]], [[3, 3]]], 'mean': [[[2, 2]], [[3, 3]]],
'min': [[[1, 1]], [[2, 2]]], 'min': [[[1, 1]], [[2, 2]]],
'arg_min': [[[0, 0]], [[0, 0]]], 'arg_min': [[[0, 0]], [[0, 0]]],
...@@ -134,7 +134,7 @@ def test_segment_out(test, reduce, dtype, device): ...@@ -134,7 +134,7 @@ def test_segment_out(test, reduce, dtype, device):
segment_coo(src, index, out, reduce=reduce) segment_coo(src, index, out, reduce=reduce)
if reduce == 'add': if reduce == 'sum':
expected = expected - 2 expected = expected - 2
elif reduce == 'mean': elif reduce == 'mean':
expected = out # We can not really test this here. expected = out # We can not really test this here.
......
...@@ -31,7 +31,7 @@ class GatherCOO(torch.autograd.Function): ...@@ -31,7 +31,7 @@ class GatherCOO(torch.autograd.Function):
grad_src = None grad_src = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
grad_src, _ = seg(grad_out.is_cuda).segment_coo( grad_src, _ = seg(grad_out.is_cuda).segment_coo(
grad_out, index, grad_out.new_zeros(src_size), 'add') grad_out, index, grad_out.new_zeros(src_size), 'sum')
return grad_src, None, None return grad_src, None, None
...@@ -53,7 +53,7 @@ class GatherCSR(torch.autograd.Function): ...@@ -53,7 +53,7 @@ class GatherCSR(torch.autograd.Function):
grad_src = None grad_src = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
grad_src, _ = seg(grad_out.is_cuda).segment_csr( grad_src, _ = seg(grad_out.is_cuda).segment_csr(
grad_out, indptr, grad_out.new_empty(src_size), 'add') grad_out, indptr, grad_out.new_empty(src_size), 'sum')
return grad_src, None, None return grad_src, None, None
......
...@@ -18,7 +18,7 @@ def gat(is_cuda): ...@@ -18,7 +18,7 @@ def gat(is_cuda):
class SegmentCOO(torch.autograd.Function): class SegmentCOO(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, src, index, out, dim_size, reduce): def forward(ctx, src, index, out, dim_size, reduce):
assert reduce in ['add', 'mean', 'min', 'max'] assert reduce in ['sum', 'add', 'mean', 'min', 'max']
if out is not None: if out is not None:
ctx.mark_dirty(out) ctx.mark_dirty(out)
ctx.reduce = reduce ctx.reduce = reduce
...@@ -55,7 +55,7 @@ class SegmentCOO(torch.autograd.Function): ...@@ -55,7 +55,7 @@ class SegmentCOO(torch.autograd.Function):
grad_src = None grad_src = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
if ctx.reduce == 'add': if ctx.reduce == 'sum' or ctx.reduce == 'add':
grad_src = gat(grad_out.is_cuda).gather_coo( grad_src = gat(grad_out.is_cuda).gather_coo(
grad_out, index, grad_out.new_empty(src_size)) grad_out, index, grad_out.new_empty(src_size))
elif ctx.reduce == 'mean': elif ctx.reduce == 'mean':
...@@ -68,7 +68,7 @@ class SegmentCOO(torch.autograd.Function): ...@@ -68,7 +68,7 @@ class SegmentCOO(torch.autograd.Function):
size[-1] = grad_out.size(index.dim() - 1) size[-1] = grad_out.size(index.dim() - 1)
count = segment_cpu.segment_coo( count = segment_cpu.segment_coo(
torch.ones_like(index, dtype=grad_out.dtype), index, torch.ones_like(index, dtype=grad_out.dtype), index,
grad_out.new_zeros(size), 'add')[0].clamp_(min=1) grad_out.new_zeros(size), 'sum')[0].clamp_(min=1)
count = gat(grad_out.is_cuda).gather_coo( count = gat(grad_out.is_cuda).gather_coo(
count, index, count.new_empty(src_size[:index.dim()])) count, index, count.new_empty(src_size[:index.dim()]))
...@@ -88,7 +88,7 @@ class SegmentCOO(torch.autograd.Function): ...@@ -88,7 +88,7 @@ class SegmentCOO(torch.autograd.Function):
class SegmentCSR(torch.autograd.Function): class SegmentCSR(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, src, indptr, out, reduce): def forward(ctx, src, indptr, out, reduce):
assert reduce in ['add', 'mean', 'min', 'max'] assert reduce in ['sum', 'add', 'mean', 'min', 'max']
if out is not None: if out is not None:
ctx.mark_dirty(out) ctx.mark_dirty(out)
...@@ -105,7 +105,7 @@ class SegmentCSR(torch.autograd.Function): ...@@ -105,7 +105,7 @@ class SegmentCSR(torch.autograd.Function):
grad_src = None grad_src = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
if ctx.reduce == 'add': if ctx.reduce == 'sum' or ctx.reduce == 'add':
grad_src = gat(grad_out.is_cuda).gather_csr( grad_src = gat(grad_out.is_cuda).gather_csr(
grad_out, indptr, grad_out.new_empty(src_size)) grad_out, indptr, grad_out.new_empty(src_size))
elif ctx.reduce == 'mean': elif ctx.reduce == 'mean':
...@@ -129,7 +129,7 @@ class SegmentCSR(torch.autograd.Function): ...@@ -129,7 +129,7 @@ class SegmentCSR(torch.autograd.Function):
return grad_src, None, None, None return grad_src, None, None, None
def segment_coo(src, index, out=None, dim_size=None, reduce="add"): def segment_coo(src, index, out=None, dim_size=None, reduce="sum"):
r""" r"""
| |
...@@ -158,7 +158,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"): ...@@ -158,7 +158,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
:math:`y - 1` in ascending order. :math:`y - 1` in ascending order.
The :attr:`index` tensor supports broadcasting in case its dimensions do The :attr:`index` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`. not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="add"`, the operation For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes computes
.. math:: .. math::
...@@ -196,9 +196,9 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"): ...@@ -196,9 +196,9 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
If :attr:`dim_size` is not given, a minimal sized output tensor If :attr:`dim_size` is not given, a minimal sized output tensor
according to :obj:`index.max() + 1` is returned. according to :obj:`index.max() + 1` is returned.
(default: :obj:`None`) (default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"add"`, reduce (string, optional): The reduce operation (:obj:`"sum"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"add"`) (default: :obj:`"sum"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)* :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
...@@ -210,7 +210,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"): ...@@ -210,7 +210,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
index = torch.tensor([0, 0, 1, 1, 1, 2]) index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1) # Broadcasting in the first and last dim. index = index.view(1, -1) # Broadcasting in the first and last dim.
out = segment_coo(src, index, reduce="add") out = segment_coo(src, index, reduce="sum")
print(out.size()) print(out.size())
...@@ -221,7 +221,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"): ...@@ -221,7 +221,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
return SegmentCOO.apply(src, index, out, dim_size, reduce) return SegmentCOO.apply(src, index, out, dim_size, reduce)
def segment_csr(src, indptr, out=None, reduce="add"): def segment_csr(src, indptr, out=None, reduce="sum"):
r""" r"""
Reduces all values from the :attr:`src` tensor into :attr:`out` within the Reduces all values from the :attr:`src` tensor into :attr:`out` within the
ranges specified in the :attr:`indptr` tensor along the last dimension of ranges specified in the :attr:`indptr` tensor along the last dimension of
...@@ -242,7 +242,7 @@ def segment_csr(src, indptr, out=None, reduce="add"): ...@@ -242,7 +242,7 @@ def segment_csr(src, indptr, out=None, reduce="add"):
:math:`x_m` in ascending order. :math:`x_m` in ascending order.
The :attr:`indptr` tensor supports broadcasting in case its dimensions do The :attr:`indptr` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`. not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="add"`, the operation For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes computes
.. math:: .. math::
...@@ -267,9 +267,9 @@ def segment_csr(src, indptr, out=None, reduce="add"): ...@@ -267,9 +267,9 @@ def segment_csr(src, indptr, out=None, reduce="add"):
The number of dimensions of :attr:`index` needs to be less than or The number of dimensions of :attr:`index` needs to be less than or
equal to :attr:`src`. equal to :attr:`src`.
out (Tensor, optional): The destination tensor. (default: :obj:`None`) out (Tensor, optional): The destination tensor. (default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"add"`, reduce (string, optional): The reduce operation (:obj:`"sum"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"add"`) (default: :obj:`"sum"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)* :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
...@@ -281,7 +281,7 @@ def segment_csr(src, indptr, out=None, reduce="add"): ...@@ -281,7 +281,7 @@ def segment_csr(src, indptr, out=None, reduce="add"):
indptr = torch.tensor([0, 2, 5, 6]) indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1) # Broadcasting in the first and last dim. indptr = indptr.view(1, -1) # Broadcasting in the first and last dim.
out = segment_csr(src, indptr, reduce="add") out = segment_csr(src, indptr, reduce="sum")
print(out.size()) print(out.size())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment