Unverified Commit c01f9bae authored by Matthias Fey, committed by GitHub

Merge pull request #105 from rusty1s/traceable

[WIP] traceable functions
parents 2520670a 02a47c46
......@@ -3,10 +3,5 @@ source=torch_scatter
[report]
exclude_lines =
pragma: no cover
cuda
forward
backward
apply
torch.jit.script
raise
min_value
max_value
......@@ -39,6 +39,7 @@ install:
- pip install codecov
- pip install sphinx
- pip install sphinx_rtd_theme
- pip install sphinx-autodoc-typehints
script:
- python -c "import torch; print(torch.__version__)"
- pycodestyle .
......
Copyright (c) 2019 Matthias Fey <matthias.fey@tu-dortmund.de>
Copyright (c) 2020 Matthias Fey <matthias.fey@tu-dortmund.de>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
......
......@@ -22,34 +22,27 @@
**[Documentation](https://pytorch-scatter.readthedocs.io)**
This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package.
Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor.
The package consists of the following operations:
This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for use in [PyTorch](http://pytorch.org/), which are missing in the main package.
Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to this requirement.
* [**Scatter Add**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html)
* [**Scatter Sub**](https://pytorch-scatter.readthedocs.io/en/latest/functions/sub.html)
* [**Scatter Mul**](https://pytorch-scatter.readthedocs.io/en/latest/functions/mul.html)
* [**Scatter Div**](https://pytorch-scatter.readthedocs.io/en/latest/functions/div.html)
* [**Scatter Mean**](https://pytorch-scatter.readthedocs.io/en/latest/functions/mean.html)
* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
The package consists of the following operations with reduction types `"sum"|"mean"|"min"|"max"`:
In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
* [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment.html) based on arbitrary indices
* [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices
* [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers
* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
In addition, we provide the following **composite functions** which make use of `scatter_*` operations under the hood: `scatter_std`, `scatter_logsumexp`, `scatter_softmax` and `scatter_log_softmax`.
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.
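As a quick illustration (a minimal sketch with made-up values; the call signatures mirror the benchmark script further down in this diff), the three entry points compute the same reduction when the index is sorted:
```
import torch
from torch_scatter import scatter, segment_coo, segment_csr

src = torch.tensor([10., 20., 30., 40.])
index = torch.tensor([0, 0, 1, 2])    # sorted "group-index"
indptr = torch.tensor([0, 2, 3, 4])   # the same grouping as compressed pointers

out1 = scatter(src, index, dim=0, dim_size=3, reduce='sum')
out2 = segment_coo(src, index, dim_size=3, reduce='sum')
out3 = segment_csr(src, indptr, reduce='sum')
# out1 == out2 == out3 == tensor([30., 30., 40.])
```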
## Installation
Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
Ensure that at least PyTorch 1.3.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
```
$ python -c "import torch; print(torch.__version__)"
>>> 1.1.0
>>> 1.3.0
$ echo $PATH
>>> /usr/local/cuda/bin:...
......@@ -81,17 +74,17 @@ from torch_scatter import scatter_max
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index, fill_value=0)
out, argmax = scatter_max(src, index, dim=-1)
```
```
print(out)
tensor([[ 0, 0, 4, 3, 2, 0],
[ 2, 4, 3, 0, 0, 0]])
tensor([[0, 0, 4, 3, 2, 0],
[2, 4, 3, 0, 0, 0]])
print(argmax)
tensor([[-1, -1, 3, 4, 0, 1]
[ 1, 4, 3, -1, -1, -1]])
tensor([[5, 5, 3, 4, 0, 1],
[1, 4, 3, 5, 5, 5]])
```
## Running tests
......
......@@ -7,9 +7,7 @@ import wget
import torch
from scipy.io import loadmat
import torch_scatter
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr
from torch_scatter import scatter, segment_coo, segment_csr
short_rows = [
('DIMACS10', 'citationCiteseer'),
......@@ -47,34 +45,30 @@ def correctness(dataset):
x = torch.randn((row.size(0), size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add')
out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
out3 = segment_csr(x, rowptr, reduce='add')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean')
out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
out3 = segment_csr(x, rowptr, reduce='mean')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
x = x.abs_().mul_(-1)
out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
out2, _ = segment_coo(x, row, reduce='min')
out3, _ = segment_csr(x, rowptr, reduce='min')
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min')
out2 = segment_coo(x, row, reduce='min')
out3 = segment_csr(x, rowptr, reduce='min')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
x = x.abs_()
out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
out2, _ = segment_coo(x, row, reduce='max')
out3, _ = segment_csr(x, rowptr, reduce='max')
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max')
out2 = segment_coo(x, row, reduce='max')
out3 = segment_csr(x, rowptr, reduce='max')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
......@@ -117,17 +111,15 @@ def timing(dataset):
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
row_perm = row[torch.randperm(row.size(0))]
row2 = row[torch.randperm(row.size(0))]
dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size
def sca_row(x):
op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
return op(x, row, dim=0, dim_size=dim_size)
return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)
def sca_col(x):
op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
return op(x, row_perm, dim=0, dim_size=dim_size)
return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)
def seg_coo(x):
return segment_coo(x, row, reduce=args.reduce)
......@@ -205,11 +197,10 @@ def timing(dataset):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True,
choices=['sum', 'mean', 'min', 'max'])
choices=['sum', 'add', 'mean', 'min', 'max'])
parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce
iters = 1 if args.device == 'cpu' else 20
sizes = [1, 16, 32, 64, 128, 256, 512]
sizes = sizes[:3] if args.device == 'cpu' else sizes
......
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#pragma once
#include <torch/extension.h>
#include "compat.h"
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
[&] { \
TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\
TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\
TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\
auto dims = TENSOR1.dim(); \
auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \
auto counter = zeros.DATA_PTR<int64_t>(); \
bool has_finished = false; \
\
while (!has_finished) { \
CODE; \
if (dims == 1) \
break; \
\
for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
if (cur_dim == DIM) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} \
continue; \
} \
\
counter[cur_dim]++; \
TENSOR1##_data += TENSOR1.stride(cur_dim); \
TENSOR2##_data += TENSOR2.stride(cur_dim); \
TENSOR3##_data += TENSOR3.stride(cur_dim); \
\
if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} else { \
TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
counter[cur_dim] = 0; \
} \
} else \
break; \
} \
} \
}()
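// DIM_APPLY4: identical to DIM_APPLY3, but co-iterates a fourth tensor
// (used for the argmin/argmax output in scatter_min/scatter_max).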
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
TENSOR4, DIM, CODE) \
[&] { \
TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\
TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\
TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\
TYPE4 *TENSOR4##_data = TENSOR4.DATA_PTR<TYPE4>(); \
auto TENSOR4##_size = TENSOR4.size(DIM); \
auto TENSOR4##_stride = TENSOR4.stride(DIM); \
\
auto dims = TENSOR1.dim(); \
auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \
auto counter = zeros.DATA_PTR<int64_t>(); \
bool has_finished = false; \
\
while (!has_finished) { \
CODE; \
if (dims == 1) \
break; \
\
for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
if (cur_dim == DIM) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} \
continue; \
} \
\
counter[cur_dim]++; \
TENSOR1##_data += TENSOR1.stride(cur_dim); \
TENSOR2##_data += TENSOR2.stride(cur_dim); \
TENSOR3##_data += TENSOR3.stride(cur_dim); \
TENSOR4##_data += TENSOR4.stride(cur_dim); \
\
if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} else { \
TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim); \
counter[cur_dim] = 0; \
} \
} else \
break; \
} \
} \
}()
#include <torch/script.h>
#include "compat.h"
#include "index_info.h"
#include <vector>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (out_opt.has_value())
CHECK_CPU(out_opt.value());
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
for (int i = 0; i < indptr.dim() - 1; i++)
AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
src = src.contiguous();
auto gather_dim = indptr.dim() - 1;
AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
"Input mismatch");
torch::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != gather_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
} else {
auto sizes = src.sizes().vec();
sizes[gather_dim] = *indptr.flatten()[-1].DATA_PTR<int64_t>();
out = torch::empty(sizes, src.options());
}
auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(gather_dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
for (int n = 0; n < N; n++) {
int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
for (int k = 0; k < K; k++) {
vals[k] = src_data[n * K + k];
}
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) {
out_data[offset + e * K + k] = vals[k];
}
}
}
});
return out;
}
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> out_opt) {
CHECK_CPU(src);
CHECK_CPU(index);
if (out_opt.has_value())
CHECK_CPU(out_opt.value());
AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
for (int i = 0; i < index.dim() - 1; i++)
AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
src = src.contiguous();
auto gather_dim = index.dim() - 1;
torch::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch");
for (int i = index.dim() + 1; i < src.dim(); i++)
AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch");
} else {
auto sizes = src.sizes().vec();
sizes[gather_dim] = index.size(gather_dim);
out = torch::empty(sizes, src.options());
}
auto E_1 = index.numel() / out.size(gather_dim);
auto E_2 = index.size(gather_dim);
auto K = out.numel() / index.numel();
auto N = src.size(gather_dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t idx, next_idx;
for (int e_1 = 0; e_1 < E_1; e_1++) {
int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
idx = index_info.data[offset];
for (int k = 0; k < K; k++) {
vals[k] = src_data[e_1 * N * K + idx * K + k];
}
for (int e_2 = 0; e_2 < E_2; e_2++) {
for (int k = 0; k < K; k++) {
out_data[e_1 * E_2 * K + e_2 * K + k] = vals[k];
}
if (e_2 < E_2 - 1) {
next_idx = index_info.data[offset + (e_2 + 1) * stride];
assert(idx <= next_idx);
if (idx != next_idx) {
idx = next_idx;
for (int k = 0; k < K; k++) {
vals[k] = src_data[e_1 * N * K + idx * K + k];
}
}
}
}
}
});
return out;
}
static auto registry =
torch::RegisterOperators("torch_scatter_cpu::gather_csr", &gather_csr)
.op("torch_scatter_cpu::gather_coo", &gather_coo);
#include <torch/script.h>
#include "dim_apply.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul", [&] {
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
out_data[idx * out_stride] *= src_data[i * src_stride];
}
});
});
}
void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div", [&] {
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
out_data[idx * out_stride] /= src_data[i * src_stride];
}
});
});
}
void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max", [&] {
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
{
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
if (src_data[i * src_stride] >= out_data[idx * out_stride]) {
out_data[idx * out_stride] = src_data[i * src_stride];
arg_data[idx * arg_stride] = i;
}
}
});
});
}
void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
CHECK_CPU(arg);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min", [&] {
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
{
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
if (src_data[i * src_stride] <= out_data[idx * out_stride]) {
out_data[idx * out_stride] = src_data[i * src_stride];
arg_data[idx * arg_stride] = i;
}
}
});
});
}
static auto registry =
torch::RegisterOperators("torch_scatter_cpu::scatter_mul", &scatter_mul)
.op("torch_scatter_cpu::scatter_div", &scatter_div)
.op("torch_scatter_cpu::scatter_max", &scatter_max)
.op("torch_scatter_cpu::scatter_min", &scatter_min);
#include <torch/script.h>
#include "compat.h"
#include "index_info.h"
#include <vector>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
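// Reducer bundles the per-reduction-type logic: `init()` returns the neutral
// starting value, `update()` folds in one element (recording its index for
// min/max), and `write()` stores the final result (dividing by the element
// count for mean, and falling back to zero for empty min/max groups).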
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
return std::numeric_limits<scalar_t>::lowest();
} else {
return (scalar_t)0;
}
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM) {
*address = val;
} else if (REDUCE == MEAN) {
*address = val / (count > 0 ? count : (scalar_t)1);
} else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else {
*address = (scalar_t)0;
}
}
}
};
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (out_opt.has_value())
CHECK_CPU(out_opt.value());
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
// Broadcasting `indptr` via `expand`.
auto sizes = indptr.sizes().vec();
for (int i = 0; i < indptr.dim() - 1; i++) {
sizes[i] = src.size(i);
}
indptr = indptr.expand(sizes);
src = src.contiguous();
auto reduce_dim = indptr.dim() - 1;
torch::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
"Input mismatch");
} else {
sizes = src.sizes().vec();
sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(reduce_dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int n = 0; n < N; n++) {
int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int k = 0; k < K; k++) {
vals[k] = Reducer<scalar_t, REDUCE>::init();
}
for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[offset + e * K + k], &args[k], e);
}
}
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
}
});
});
return std::make_tuple(out, arg_out);
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
// Broadcasting `index` via `expand`.
auto sizes = index.sizes().vec();
for (int i = 0; i < index.dim(); i++) {
sizes[i] = src.size(i);
}
index = index.expand(sizes);
src = src.contiguous();
out = out.contiguous();
auto reduce_dim = index.dim() - 1;
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(reduce_dim), index.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto E_1 = index.numel() / src.size(reduce_dim);
auto E_2 = src.size(reduce_dim);
auto K = src.numel() / index.numel();
auto N = out.size(reduce_dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t idx, next_idx, row_start;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int e_1 = 0; e_1 < E_1; e_1++) {
int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
idx = index_info.data[offset];
for (int k = 0; k < K; k++) {
vals[k] = out_data[e_1 * N * K + k];
}
row_start = 0;
for (int e_2 = 0; e_2 < E_2; e_2++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
}
if (e_2 == E_2 - 1) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start);
}
} else {
next_idx = index_info.data[offset + (e_2 + 1) * stride];
assert(idx <= next_idx);
if (idx != next_idx) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start);
vals[k] = out_data[e_1 * N * K + next_idx * K + k];
}
row_start = e_2 + 1;
}
idx = next_idx;
}
}
}
});
});
return std::make_tuple(out, arg_out);
}
static auto registry =
torch::RegisterOperators("torch_scatter_cpu::segment_csr", &segment_csr)
.op("torch_scatter_cpu::segment_coo", &segment_coo);
......@@ -2,8 +2,6 @@
#include <torch/extension.h>
#include "compat.h"
#define MAX_TENSORINFO_DIMS 25
template <typename scalar_t> struct TensorInfo {
......@@ -36,7 +34,7 @@ TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
strides[i] = tensor.stride(i);
}
return TensorInfo<scalar_t>(tensor.DATA_PTR<scalar_t>(), dims, sizes,
return TensorInfo<scalar_t>(tensor.data_ptr<scalar_t>(), dims, sizes,
strides);
}
......
#pragma once
#include <limits>
#include <map>
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (count > 0 ? count : (scalar_t)1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
};
#include "scatter_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() == index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) >= index.size(i));
if (dim < 0)
dim = src.dim() + dim;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else
sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto B = 1;
for (auto i = 0; i < dim; i++)
B *= src.size(i);
auto E = src.size(dim);
auto K = src.numel() / (B * E);
auto N = out.size(dim);
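// B: product of the dimensions before `dim`, E: length of `dim` in `src`,
// K: product of the dimensions after `dim`, N: length of `dim` in `out`.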
auto index_info = getTensorInfo<int64_t>(index);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
int64_t i, idx;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
for (auto b = 0; b < B; b++) {
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++) {
i = b * E * K + e * K + k;
idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
Reducer<scalar_t, REDUCE>::update(
out_data + b * N * K + idx * K + k, src_data[i],
arg_out_data + b * N * K + idx * K + k, e);
}
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
});
});
return std::make_tuple(out, arg_out);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
#include "segment_coo_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= index.dim());
auto sizes = index.sizes().vec();
for (auto i = 0; i < index.dim(); i++)
sizes[i] = src.size(i);
index = index.expand(sizes);
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else
sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
torch::optional<torch::Tensor> count = torch::nullopt;
if (reduce2REDUCE.at(reduce) == MEAN) {
auto sizes = index.sizes().vec();
sizes[dim] = out.size(dim);
count = torch::zeros(sizes, out.options());
}
auto B = index.numel() / src.size(dim);
auto E = src.size(dim);
auto K = src.numel() / index.numel();
auto N = out.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t *count_data = nullptr;
std::vector<scalar_t> vals(K);
int64_t idx, next_idx, row_start;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
if (REDUCE == MEAN)
count_data = count.value().data_ptr<scalar_t>();
for (auto b = 0; b < B; b++) {
auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
idx = index_info.data[offset];
for (auto k = 0; k < K; k++)
vals[k] = out_data[b * N * K + k];
row_start = 0;
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[b * E * K + e * K + k], &args[k], e);
if (e == E - 1) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
if (REDUCE == MEAN)
count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
} else {
next_idx = index_info.data[offset + (e + 1) * stride];
assert(idx <= next_idx);
if (idx != next_idx) {
for (auto k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
vals[k] = out_data[b * N * K + next_idx * K + k];
}
if (REDUCE == MEAN)
count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
row_start = e + 1;
}
idx = next_idx;
}
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MEAN)
arg_out = count;
});
});
return std::make_tuple(out, arg_out);
}
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) == index.size(i));
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < src.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = index.size(dim);
out = torch::empty(sizes, src.options());
}
auto B = index.numel() / out.size(dim);
auto E = index.size(dim);
auto K = out.numel() / index.numel();
auto N = src.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t idx, next_idx;
for (auto b = 0; b < B; b++) {
auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
idx = index_info.data[offset];
for (auto k = 0; k < K; k++)
vals[k] = src_data[b * N * K + idx * K + k];
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++)
out_data[b * E * K + e * K + k] = vals[k];
if (e < E - 1) {
next_idx = index_info.data[offset + (e + 1) * stride];
CHECK_INPUT(idx <= next_idx);
if (idx != next_idx) {
idx = next_idx;
for (auto k = 0; k < K; k++)
vals[k] = src_data[b * N * K + idx * K + k];
}
}
}
}
});
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out);
#include "segment_csr_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
} else {
sizes = src.sizes().vec();
sizes[dim] = indptr.size(dim) - 1;
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (auto n = 0; n < N; n++) {
auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
for (auto e = row_start; e < row_end; e++)
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[offset + e * K + k], &args[k], e);
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
});
});
return std::make_tuple(out, arg_out);
}
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
out = torch::empty(sizes, src.options());
}
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
for (auto n = 0; n < N; n++) {
auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
for (auto k = 0; k < K; k++)
vals[k] = src_data[n * K + k];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto e = row_start; e < row_end; e++)
for (auto k = 0; k < K; k++)
out_data[offset + e * K + k] = vals[k];
}
});
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce);
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")