"benchmarks/vscode:/vscode.git/clone" did not exist on "4177f729fc65096814aebb6838ccd8ec5c0f04a8"
Unverified commit c01f9bae, authored by Matthias Fey and committed by GitHub

Merge pull request #105 from rusty1s/traceable

[WIP] traceable functions
parents 2520670a 02a47c46
@@ -3,10 +3,5 @@ source=torch_scatter
 [report]
 exclude_lines =
     pragma: no cover
-    cuda
-    forward
-    backward
-    apply
+    torch.jit.script
     raise
-    min_value
-    max_value
@@ -39,6 +39,7 @@ install:
 - pip install codecov
 - pip install sphinx
 - pip install sphinx_rtd_theme
+- pip install sphinx-autodoc-typehints
 script:
 - python -c "import torch; print(torch.__version__)"
 - pycodestyle .
...
-Copyright (c) 2019 Matthias Fey <matthias.fey@tu-dortmund.de>
+Copyright (c) 2020 Matthias Fey <matthias.fey@tu-dortmund.de>
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
...
@@ -22,34 +22,27 @@
 **[Documentation](https://pytorch-scatter.readthedocs.io)**
-This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package.
-Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor.
-The package consists of the following operations:
-* [**Scatter Add**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html)
-* [**Scatter Sub**](https://pytorch-scatter.readthedocs.io/en/latest/functions/sub.html)
-* [**Scatter Mul**](https://pytorch-scatter.readthedocs.io/en/latest/functions/mul.html)
-* [**Scatter Div**](https://pytorch-scatter.readthedocs.io/en/latest/functions/div.html)
-* [**Scatter Mean**](https://pytorch-scatter.readthedocs.io/en/latest/functions/mean.html)
-* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
-* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
-* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
-* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
-In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
-* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
-* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
-All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
+This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package.
+Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
+Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements.
+The package consists of the following operations with reduction types `"sum"|"mean"|"min"|"max"`:
+* [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment.html) based on arbitrary indices
+* [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices
+* [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers
+In addition, we provide the following **composite functions** which make use of `scatter_*` operations under the hood: `scatter_std`, `scatter_logsumexp`, `scatter_softmax` and `scatter_log_softmax`.
+All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.
 ## Installation
-Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
+Ensure that at least PyTorch 1.3.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
 ```
 $ python -c "import torch; print(torch.__version__)"
->>> 1.1.0
+>>> 1.3.0
 $ echo $PATH
 >>> /usr/local/cuda/bin:...
@@ -81,17 +74,17 @@ from torch_scatter import scatter_max
 src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
 index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
-out, argmax = scatter_max(src, index, fill_value=0)
+out, argmax = scatter_max(src, index, dim=-1)
 ```
 ```
 print(out)
-tensor([[ 0, 0, 4, 3, 2, 0],
-        [ 2, 4, 3, 0, 0, 0]])
+tensor([[0, 0, 4, 3, 2, 0],
+        [2, 4, 3, 0, 0, 0]])
 print(argmax)
-tensor([[-1, -1, 3, 4, 0, 1],
-        [ 1, 4, 3, -1, -1, -1]])
+tensor([[5, 5, 3, 4, 0, 1],
+        [1, 4, 3, 5, 5, 5]])
 ```
 ## Running tests
...
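As a quick illustration of the unified `scatter` interface introduced by this branch, the following minimal sketch calls the same entry point with different `reduce` arguments (the tensors are made-up example data, not taken from the repository):

```python
import torch
from torch_scatter import scatter

src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 2, 0, 1, 2, 1])  # arbitrary, unsorted group indices

# A single entry point; the reduction is selected via the `reduce` argument.
print(scatter(src, index, dim=0, reduce='sum'))   # values [4., 10., 7.]
print(scatter(src, index, dim=0, reduce='mean'))  # values [2., 5., 3.5]
print(scatter(src, index, dim=0, reduce='max'))   # values [3., 6., 5.]
```

The `scatter_min`/`scatter_max` helpers shown in the README example above still return both the reduced values and the argument positions.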
@@ -7,9 +7,7 @@ import wget
 import torch
 from scipy.io import loadmat
-import torch_scatter
-from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
-from torch_scatter import segment_coo, segment_csr
+from torch_scatter import scatter, segment_coo, segment_csr

 short_rows = [
     ('DIMACS10', 'citationCiteseer'),
@@ -47,34 +45,30 @@ def correctness(dataset):
     x = torch.randn((row.size(0), size), device=args.device)
     x = x.squeeze(-1) if size == 1 else x

-    out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add')
     out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
     out3 = segment_csr(x, rowptr, reduce='add')
     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)

-    out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean')
     out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
     out3 = segment_csr(x, rowptr, reduce='mean')
     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)

-    x = x.abs_().mul_(-1)
-    out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
-    out2, _ = segment_coo(x, row, reduce='min')
-    out3, _ = segment_csr(x, rowptr, reduce='min')
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min')
+    out2 = segment_coo(x, row, reduce='min')
+    out3 = segment_csr(x, rowptr, reduce='min')
     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)

-    x = x.abs_()
-    out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
-    out2, _ = segment_coo(x, row, reduce='max')
-    out3, _ = segment_csr(x, rowptr, reduce='max')
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max')
+    out2 = segment_coo(x, row, reduce='max')
+    out3 = segment_csr(x, rowptr, reduce='max')
     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)
@@ -117,17 +111,15 @@ def timing(dataset):
     mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
     rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
     row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
-    row_perm = row[torch.randperm(row.size(0))]
+    row2 = row[torch.randperm(row.size(0))]
     dim_size = rowptr.size(0) - 1
     avg_row_len = row.size(0) / dim_size

     def sca_row(x):
-        op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
-        return op(x, row, dim=0, dim_size=dim_size)
+        return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)

     def sca_col(x):
-        op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
-        return op(x, row_perm, dim=0, dim_size=dim_size)
+        return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)

     def seg_coo(x):
         return segment_coo(x, row, reduce=args.reduce)
@@ -205,11 +197,10 @@ def timing(dataset):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--reduce', type=str, required=True,
-                        choices=['sum', 'mean', 'min', 'max'])
+                        choices=['sum', 'add', 'mean', 'min', 'max'])
     parser.add_argument('--with_backward', action='store_true')
     parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
-    args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce
     iters = 1 if args.device == 'cpu' else 20
     sizes = [1, 16, 32, 64, 128, 256, 512]
     sizes = sizes[:3] if args.device == 'cpu' else sizes
...
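To make the relation between the three code paths exercised by `correctness()` concrete, here is a small self-contained sketch; the index tensors are made up for illustration, only the API usage mirrors the benchmark above:

```python
import torch
from torch_scatter import scatter, segment_coo, segment_csr

# A sorted "group-index" vector and its CSR pointer representation:
# rowptr[i]:rowptr[i + 1] delimits the entries belonging to group i.
row = torch.tensor([0, 0, 1, 2, 2, 2])
rowptr = torch.tensor([0, 2, 3, 6])
dim_size = rowptr.numel() - 1

x = torch.randn(row.size(0), 16)

out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add')
out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
out3 = segment_csr(x, rowptr, reduce='add')

# All three produce the same (3, 16) result; only the index format differs.
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
```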
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#pragma once
#include <torch/extension.h>
#include "compat.h"
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
[&] { \
TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\
TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\
TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\
auto dims = TENSOR1.dim(); \
auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \
auto counter = zeros.DATA_PTR<int64_t>(); \
bool has_finished = false; \
\
while (!has_finished) { \
CODE; \
if (dims == 1) \
break; \
\
for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
if (cur_dim == DIM) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} \
continue; \
} \
\
counter[cur_dim]++; \
TENSOR1##_data += TENSOR1.stride(cur_dim); \
TENSOR2##_data += TENSOR2.stride(cur_dim); \
TENSOR3##_data += TENSOR3.stride(cur_dim); \
\
if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} else { \
TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
counter[cur_dim] = 0; \
} \
} else \
break; \
} \
} \
}()
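// DIM_APPLY4 is the four-tensor variant of DIM_APPLY3; the extra tensor is
// used below to record argument positions (e.g. argmax/argmin) alongside the
// reduced values.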
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
TENSOR4, DIM, CODE) \
[&] { \
TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\
TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\
TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\
TYPE4 *TENSOR4##_data = TENSOR4.DATA_PTR<TYPE4>(); \
auto TENSOR4##_size = TENSOR4.size(DIM); \
auto TENSOR4##_stride = TENSOR4.stride(DIM); \
\
auto dims = TENSOR1.dim(); \
auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \
auto counter = zeros.DATA_PTR<int64_t>(); \
bool has_finished = false; \
\
while (!has_finished) { \
CODE; \
if (dims == 1) \
break; \
\
for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
if (cur_dim == DIM) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} \
continue; \
} \
\
counter[cur_dim]++; \
TENSOR1##_data += TENSOR1.stride(cur_dim); \
TENSOR2##_data += TENSOR2.stride(cur_dim); \
TENSOR3##_data += TENSOR3.stride(cur_dim); \
TENSOR4##_data += TENSOR4.stride(cur_dim); \
\
if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
if (cur_dim == dims - 1) { \
has_finished = true; \
break; \
} else { \
TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim); \
counter[cur_dim] = 0; \
} \
} else \
break; \
} \
} \
}()
#include <torch/script.h>
#include "compat.h"
#include "index_info.h"
#include <vector>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
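// gather_csr expands values along CSR segments: for every row n, the feature
// vector src[n] is copied to all output positions e in
// [indptr[n], indptr[n + 1]). It is the gather counterpart of segment_csr.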
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (out_opt.has_value())
CHECK_CPU(out_opt.value());
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
for (int i = 0; i < indptr.dim() - 1; i++)
AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
src = src.contiguous();
auto gather_dim = indptr.dim() - 1;
AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
"Input mismatch");
torch::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != gather_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
} else {
auto sizes = src.sizes().vec();
sizes[gather_dim] = *indptr.flatten()[-1].DATA_PTR<int64_t>();
out = torch::empty(sizes, src.options());
}
auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(gather_dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
for (int n = 0; n < N; n++) {
int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
for (int k = 0; k < K; k++) {
vals[k] = src_data[n * K + k];
}
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) {
out_data[offset + e * K + k] = vals[k];
}
}
}
});
return out;
}
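// gather_coo computes out[e] = src[index[e]] along the last index dimension.
// The index is assumed to be sorted (asserted below), so the gathered feature
// vector only needs to be reloaded when the index value changes.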
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> out_opt) {
CHECK_CPU(src);
CHECK_CPU(index);
if (out_opt.has_value())
CHECK_CPU(out_opt.value());
AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
for (int i = 0; i < index.dim() - 1; i++)
AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
src = src.contiguous();
auto gather_dim = index.dim() - 1;
torch::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch");
for (int i = index.dim() + 1; i < src.dim(); i++)
AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch");
} else {
auto sizes = src.sizes().vec();
sizes[gather_dim] = index.size(gather_dim);
out = torch::empty(sizes, src.options());
}
auto E_1 = index.numel() / out.size(gather_dim);
auto E_2 = index.size(gather_dim);
auto K = out.numel() / index.numel();
auto N = src.size(gather_dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t idx, next_idx;
for (int e_1 = 0; e_1 < E_1; e_1++) {
int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
idx = index_info.data[offset];
for (int k = 0; k < K; k++) {
vals[k] = src_data[e_1 * N * K + idx * K + k];
}
for (int e_2 = 0; e_2 < E_2; e_2++) {
for (int k = 0; k < K; k++) {
out_data[e_1 * E_2 * K + e_2 * K + k] = vals[k];
}
if (e_2 < E_2 - 1) {
next_idx = index_info.data[offset + (e_2 + 1) * stride];
assert(idx <= next_idx);
if (idx != next_idx) {
idx = next_idx;
for (int k = 0; k < K; k++) {
vals[k] = src_data[e_1 * N * K + idx * K + k];
}
}
}
}
}
});
return out;
}
static auto registry =
torch::RegisterOperators("torch_scatter_cpu::gather_csr", &gather_csr)
.op("torch_scatter_cpu::gather_coo", &gather_coo);
#include <torch/script.h>
#include "dim_apply.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
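// In-place scatter kernels driven by the DIM_APPLY macros: for every position
// i along `dim`, out[index[i]] is combined with src[i] using multiplication,
// division, maximum or minimum; the max/min variants additionally store the
// winning source position in `arg`.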
void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul", [&] {
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
out_data[idx * out_stride] *= src_data[i * src_stride];
}
});
});
}
void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div", [&] {
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
out_data[idx * out_stride] /= src_data[i * src_stride];
}
});
});
}
void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max", [&] {
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
{
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
if (src_data[i * src_stride] >= out_data[idx * out_stride]) {
out_data[idx * out_stride] = src_data[i * src_stride];
arg_data[idx * arg_stride] = i;
}
}
});
});
}
void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
torch::Tensor arg, int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
CHECK_CPU(arg);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min", [&] {
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
{
for (i = 0; i < elems_per_row; i++) {
idx = index_data[i * index_stride];
if (src_data[i * src_stride] <= out_data[idx * out_stride]) {
out_data[idx * out_stride] = src_data[i * src_stride];
arg_data[idx * arg_stride] = i;
}
}
});
});
}
static auto registry =
torch::RegisterOperators("torch_scatter_cpu::scatter_mul", &scatter_mul)
.op("torch_scatter_cpu::scatter_div", &scatter_div)
.op("torch_scatter_cpu::scatter_max", &scatter_max)
.op("torch_scatter_cpu::scatter_min", &scatter_min);
#include <torch/script.h>
#include "compat.h"
#include "index_info.h"
#include <vector>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
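// Reducer bundles the three steps shared by all reductions:
//   init()   - the identity element (0 for sum/mean, the largest/lowest
//              representable value for min/max),
//   update() - fold one value into the running result, tracking the argument
//              position for min/max,
//   write()  - finalize into the output (divide by the element count for
//              mean; write 0 and keep the sentinel argument for empty
//              segments under min/max).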
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
return std::numeric_limits<scalar_t>::lowest();
} else {
return (scalar_t)0;
}
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM) {
*address = val;
} else if (REDUCE == MEAN) {
*address = val / (count > 0 ? count : (scalar_t)1);
} else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else {
*address = (scalar_t)0;
}
}
}
};
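// segment_csr reduces src over the segments delimited by `indptr`: output row
// n is the reduction of src[indptr[n]:indptr[n + 1]]. For "min"/"max" an
// argument tensor is returned as well, initialized to src.size(dim) as a
// sentinel for empty segments.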
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> out_opt, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (out_opt.has_value())
CHECK_CPU(out_opt.value());
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
// Broadcasting `indptr` via `expand`.
auto sizes = indptr.sizes().vec();
for (int i = 0; i < indptr.dim() - 1; i++) {
sizes[i] = src.size(i);
}
indptr = indptr.expand(sizes);
src = src.contiguous();
auto reduce_dim = indptr.dim() - 1;
torch::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
"Input mismatch");
} else {
sizes = src.sizes().vec();
sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(reduce_dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int n = 0; n < N; n++) {
int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int k = 0; k < K; k++) {
vals[k] = Reducer<scalar_t, REDUCE>::init();
}
for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[offset + e * K + k], &args[k], e);
}
}
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
}
});
});
return std::make_tuple(out, arg_out);
}
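// segment_coo reduces consecutive src rows that share the same (sorted) index
// value into out[index]. The sorted assumption lets the kernel keep a running
// value per feature and flush it whenever the index changes or the row ends.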
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
// Broadcasting `index` via `expand`.
auto sizes = index.sizes().vec();
for (int i = 0; i < index.dim(); i++) {
sizes[i] = src.size(i);
}
index = index.expand(sizes);
src = src.contiguous();
out = out.contiguous();
auto reduce_dim = index.dim() - 1;
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(reduce_dim), index.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto E_1 = index.numel() / src.size(reduce_dim);
auto E_2 = src.size(reduce_dim);
auto K = src.numel() / index.numel();
auto N = out.size(reduce_dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t idx, next_idx, row_start;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int e_1 = 0; e_1 < E_1; e_1++) {
int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
idx = index_info.data[offset];
for (int k = 0; k < K; k++) {
vals[k] = out_data[e_1 * N * K + k];
}
row_start = 0;
for (int e_2 = 0; e_2 < E_2; e_2++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
}
if (e_2 == E_2 - 1) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start);
}
} else {
next_idx = index_info.data[offset + (e_2 + 1) * stride];
assert(idx <= next_idx);
if (idx != next_idx) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
out_data + e_1 * N * K + idx * K + k, vals[k],
arg_out_data + e_1 * N * K + idx * K + k, args[k],
e_2 + 1 - row_start);
vals[k] = out_data[e_1 * N * K + next_idx * K + k];
}
row_start = e_2 + 1;
}
idx = next_idx;
}
}
}
});
});
return std::make_tuple(out, arg_out);
}
static auto registry =
torch::RegisterOperators("torch_scatter_cpu::segment_csr", &segment_csr)
.op("torch_scatter_cpu::segment_coo", &segment_coo);
@@ -2,8 +2,6 @@
 #include <torch/extension.h>
-#include "compat.h"
 #define MAX_TENSORINFO_DIMS 25
 template <typename scalar_t> struct TensorInfo {
@@ -36,7 +34,7 @@ TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
     strides[i] = tensor.stride(i);
   }
-  return TensorInfo<scalar_t>(tensor.DATA_PTR<scalar_t>(), dims, sizes,
+  return TensorInfo<scalar_t>(tensor.data_ptr<scalar_t>(), dims, sizes,
                               strides);
 }
...
#pragma once
#include <limits>
#include <map>
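// Shared reduction machinery for the CPU kernels: the `reduce` string
// argument ("sum"/"mean"/"mul"/"div"/"min"/"max") is mapped to a compile-time
// ReductionType so that AT_DISPATCH_REDUCTION_TYPES can instantiate the
// Reducer template below for the selected reduction.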
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (count > 0 ? count : (scalar_t)1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
};
#include "scatter_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() == index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) >= index.size(i));
if (dim < 0)
dim = src.dim() + dim;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else
sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto B = 1;
for (auto i = 0; i < dim; i++)
B *= src.size(i);
auto E = src.size(dim);
auto K = src.numel() / (B * E);
auto N = out.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
int64_t i, idx;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
for (auto b = 0; b < B; b++) {
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++) {
i = b * E * K + e * K + k;
idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
Reducer<scalar_t, REDUCE>::update(
out_data + b * N * K + idx * K + k, src_data[i],
arg_out_data + b * N * K + idx * K + k, e);
}
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
});
});
return std::make_tuple(out, arg_out);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
#include "segment_coo_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= index.dim());
auto sizes = index.sizes().vec();
for (auto i = 0; i < index.dim(); i++)
sizes[i] = src.size(i);
index = index.expand(sizes);
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else
sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
torch::optional<torch::Tensor> count = torch::nullopt;
if (reduce2REDUCE.at(reduce) == MEAN) {
auto sizes = index.sizes().vec();
sizes[dim] = out.size(dim);
count = torch::zeros(sizes, out.options());
}
auto B = index.numel() / src.size(dim);
auto E = src.size(dim);
auto K = src.numel() / index.numel();
auto N = out.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t *count_data = nullptr;
std::vector<scalar_t> vals(K);
int64_t idx, next_idx, row_start;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
if (REDUCE == MEAN)
count_data = count.value().data_ptr<scalar_t>();
for (auto b = 0; b < B; b++) {
auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
idx = index_info.data[offset];
for (auto k = 0; k < K; k++)
vals[k] = out_data[b * N * K + k];
row_start = 0;
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[b * E * K + e * K + k], &args[k], e);
if (e == E - 1) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
if (REDUCE == MEAN)
count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
} else {
next_idx = index_info.data[offset + (e + 1) * stride];
assert(idx <= next_idx);
if (idx != next_idx) {
for (auto k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
vals[k] = out_data[b * N * K + next_idx * K + k];
}
if (REDUCE == MEAN)
count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
row_start = e + 1;
}
idx = next_idx;
}
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MEAN)
arg_out = count;
});
});
return std::make_tuple(out, arg_out);
}
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) == index.size(i));
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < src.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = index.size(dim);
out = torch::empty(sizes, src.options());
}
auto B = index.numel() / out.size(dim);
auto E = index.size(dim);
auto K = out.numel() / index.numel();
auto N = src.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t idx, next_idx;
for (auto b = 0; b < B; b++) {
auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
idx = index_info.data[offset];
for (auto k = 0; k < K; k++)
vals[k] = src_data[b * N * K + idx * K + k];
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++)
out_data[b * E * K + e * K + k] = vals[k];
if (e < E - 1) {
next_idx = index_info.data[offset + (e + 1) * stride];
CHECK_INPUT(idx <= next_idx);
if (idx != next_idx) {
idx = next_idx;
for (auto k = 0; k < K; k++)
vals[k] = src_data[b * N * K + idx * K + k];
}
}
}
}
});
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out);
#include "segment_csr_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
} else {
sizes = src.sizes().vec();
sizes[dim] = indptr.size(dim) - 1;
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (auto n = 0; n < N; n++) {
auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
for (auto e = row_start; e < row_end; e++)
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[offset + e * K + k], &args[k], e);
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
});
});
return std::make_tuple(out, arg_out);
}
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
out = torch::empty(sizes, src.options());
}
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
for (auto n = 0; n < N; n++) {
auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
for (auto k = 0; k < K; k++)
vals[k] = src_data[n * K + k];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto e = row_start; e < row_end; e++)
for (auto k = 0; k < K; k++)
out_data[offset + e * K + k] = vals[k];
}
});
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce);
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")