Commit 48024c15 authored by rusty1s

added cpu segment implementation

parent 5817fb9d
@@ -11,9 +11,6 @@ import torch_scatter
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr
short_rows = [
    ('DIMACS10', 'citationCiteseer'),
    ('SNAP', 'web-Stanford'),
@@ -216,6 +213,9 @@ if __name__ == '__main__':
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()
    args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
    iters = 1 if args.device == 'cpu' else 20
    sizes = [1, 16, 32, 64, 128, 256, 512]
    sizes = sizes[:3] if args.device == 'cpu' else sizes
    for _ in range(10):  # Warmup.
        torch.randn(100, 100, device=args.device).sum()
...
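The benchmark now also covers the CPU backend, with a single timing iteration and only the three smallest feature sizes. A minimal sketch of a timing helper matching this setup (the `measure` name and signature are illustrative, not part of the benchmark script):

```python
import time

import torch


def measure(fn, device, iters):
    # Warmup, mirroring the loop in the benchmark above.
    for _ in range(10):
        torch.randn(100, 100, device=device).sum()
    if device == 'cuda':
        torch.cuda.synchronize()  # flush pending kernels before timing
    t = time.perf_counter()
    for _ in range(iters):
        fn()
    if device == 'cuda':
        torch.cuda.synchronize()  # CUDA launches are asynchronous
    return (time.perf_counter() - t) / iters
```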
#pragma once

#include <torch/extension.h>

#include "compat.h"

#define MAX_TENSORINFO_DIMS 25

template <typename scalar_t> struct TensorInfo {
  TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS],
             int st[MAX_TENSORINFO_DIMS]) {
    data = p;
    dims = dim;
    AT_ASSERT(dims < MAX_TENSORINFO_DIMS);
    for (int i = 0; i < dim; ++i) {
      sizes[i] = sz[i];
      strides[i] = st[i];
    }
  }

  scalar_t *data;
  int dims;
  int sizes[MAX_TENSORINFO_DIMS];
  int strides[MAX_TENSORINFO_DIMS];
};

template <typename scalar_t>
TensorInfo<scalar_t> getTensorInfo(const at::Tensor &tensor) {
  int sizes[MAX_TENSORINFO_DIMS];
  int strides[MAX_TENSORINFO_DIMS];

  int dims = tensor.dim();
  for (int i = 0; i < dims; ++i) {
    sizes[i] = tensor.size(i);
    strides[i] = tensor.stride(i);
  }

  return TensorInfo<scalar_t>(tensor.DATA_PTR<scalar_t>(), dims, sizes,
                              strides);
}

template <typename scalar_t> struct IndexToOffset {
  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
    int offset = 0;
    for (int i = info.dims - 1; i >= 0; --i) {
      offset += (idx % info.sizes[i]) * info.strides[i];
      idx /= info.sizes[i];
    }
    return offset;
  }
};

template <typename scalar_t> struct IndexPtrToOffset {
  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
    int offset = idx % (info.sizes[info.dims - 1] - 1);
    offset *= info.strides[info.dims - 1];
    idx /= info.sizes[info.dims - 1] - 1;
    for (int i = info.dims - 2; i >= 0; --i) {
      offset += (idx % info.sizes[i]) * info.strides[i];
      idx /= info.sizes[i];
    }
    return offset;
  }
};
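Both offset helpers walk the dimensions from innermost to outermost, peeling coordinates off a flat loop index. `IndexPtrToOffset` treats the last dimension of `indptr` as having `size - 1` usable positions, since each CSR row is bounded by two consecutive pointers. A rough Python sketch of the same arithmetic (illustrative names, not part of the extension):

```python
def index_to_offset(idx, sizes, strides):
    # Mirrors IndexToOffset::get: walk dimensions from innermost to outermost.
    offset = 0
    for size, stride in zip(reversed(sizes), reversed(strides)):
        offset += (idx % size) * stride
        idx //= size
    return offset


def indptr_to_offset(idx, sizes, strides):
    # Mirrors IndexPtrToOffset::get: the last dimension holds `rows + 1`
    # boundary pointers, so only `sizes[-1] - 1` positions start a row.
    offset = (idx % (sizes[-1] - 1)) * strides[-1]
    idx //= sizes[-1] - 1
    for size, stride in zip(reversed(sizes[:-1]), reversed(strides[:-1])):
        offset += (idx % size) * stride
        idx //= size
    return offset


# Example: a (2, 5) `indptr` (4 rows per batch) expanded over a batch of 2;
# stride 0 in the first dimension means the same pointers are reused.
sizes, strides = [2, 5], [0, 1]
print([indptr_to_offset(n, sizes, strides) for n in range(8)])  # [0, 1, 2, 3, 0, 1, 2, 3]
```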
#include <torch/extension.h>

#include "compat.h"
#include "index_info.h"

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
enum ReductionType { ADD, MEAN, MIN, MAX };

#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    if (reduce == "add") {                                                     \
      const ReductionType REDUCE = ADD;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "mean") {                                             \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "min") {                                              \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "max") {                                              \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
  }()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline void update(scalar_t *val, scalar_t new_val) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
    }
  }

  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
                            int64_t new_arg) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  static inline void write(scalar_t *address, scalar_t val,
                           int64_t *arg_address, int64_t arg, int count) {
    if (REDUCE == ADD) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (count > 0 ? count : (scalar_t)1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }
};
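In short, `init` gives the neutral element, `update` folds values in (recording the winning position for min/max), and `write` finalizes: mean divides by the element count clamped to one, and an empty min/max segment is written as zero with the argmin/argmax left at its fill value (`src.size(dim)`, see the `at::full_like` call below). A hedged Python sketch of one segment's reduction (illustrative names, not the extension's API):

```python
import math


def reduce_segment(values, reduce, fill_arg):
    # values: (value, position) pairs of one segment; fill_arg: src.size(dim).
    val = math.inf if reduce == 'min' else -math.inf if reduce == 'max' else 0.0
    arg = fill_arg
    for v, pos in values:
        if reduce in ('add', 'mean'):
            val += v
        elif (reduce == 'min' and v < val) or (reduce == 'max' and v > val):
            val, arg = v, pos
    count = len(values)
    if reduce == 'mean':
        return val / max(count, 1), arg
    if reduce in ('min', 'max') and count == 0:
        return 0.0, arg  # empty segment: value 0, arg keeps the fill value
    return val, arg


print(reduce_segment([(3.0, 0), (1.0, 1), (2.0, 2)], 'min', fill_arg=3))  # (1.0, 1)
print(reduce_segment([], 'max', fill_arg=3))                              # (0.0, 3)
```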
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
            std::string reduce) {
@@ -9,8 +79,74 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
  CHECK_CPU(indptr);
  if (out_opt.has_value())
    CHECK_CPU(out_opt.value());

  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
  // Broadcasting `indptr` via `expand`.
  auto sizes = indptr.sizes().vec();
  for (int i = 0; i < indptr.dim() - 1; i++) {
    sizes[i] = src.size(i);
  }
  indptr = indptr.expand(sizes);

  src = src.contiguous();
  auto reduce_dim = indptr.dim() - 1;

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != reduce_dim)
        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
               "Input mismatch");
  } else {
    sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
    out = at::empty(sizes, src.options());
  }

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);

  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    scalar_t val;
    int64_t row_start, row_end, arg;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int n = 0; n < N; n++) {
        int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
        row_start = indptr_info.data[offset];
        row_end = indptr_info.data[offset + stride];

        offset = (n / (indptr.size(-1) - 1)) * E * K;
        for (int k = 0; k < K; k++) {
          val = Reducer<scalar_t, REDUCE>::init();
          for (int64_t e = row_start; e < row_end; e++) {
            Reducer<scalar_t, REDUCE>::update(
                &val, src_data[offset + e * K + k], &arg, e);
          }
          Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, val,
                                           arg_out_data + n * K + k, arg,
                                           row_end - row_start);
        }
      }
    });
  });

  return std::make_tuple(out, arg_out);
}
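As a sanity check, the CSR kernel above should agree with a plain Python loop over the row boundaries. A hedged, 2-D-only reference sketch (the function name and shapes are illustrative; it assumes a sorted `indptr` of shape `(N + 1,)` indexing dimension 0 of `src`):

```python
import torch


def segment_csr_reference(src, indptr, reduce='add'):
    # src: (E, K) values; indptr: (N + 1,) row boundaries into dimension 0.
    N, K = indptr.numel() - 1, src.size(1)
    out = torch.empty(N, K, dtype=src.dtype)
    for n in range(N):
        row = src[indptr[n]:indptr[n + 1]]
        if row.numel() == 0:
            out[n] = 0  # empty segments reduce to zero, as in `Reducer::write`
        elif reduce == 'add':
            out[n] = row.sum(dim=0)
        elif reduce == 'mean':
            out[n] = row.mean(dim=0)
        elif reduce == 'min':
            out[n] = row.min(dim=0).values
        else:  # 'max'
            out[n] = row.max(dim=0).values
    return out


src = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
indptr = torch.tensor([0, 2, 3])
print(segment_csr_reference(src, indptr, 'max'))  # tensor([[3., 4.], [5., 6.]])
```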

std::tuple<at::Tensor, at::optional<at::Tensor>>
@@ -19,8 +155,84 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_CPU(out);

  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
  // Broadcasting `index` via `expand`.
  auto sizes = index.sizes().vec();
  for (int i = 0; i < index.dim(); i++) {
    sizes[i] = src.size(i);
  }
  index = index.expand(sizes);

  src = src.contiguous();
  out = out.contiguous();
  auto reduce_dim = index.dim() - 1;
  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto E_1 = index.numel() / src.size(reduce_dim);
  auto E_2 = src.size(reduce_dim);
  auto K = src.numel() / index.numel();
  auto N = out.size(reduce_dim);

  auto index_info = getTensorInfo<int64_t>(index);
  auto stride = index_info.strides[index_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    scalar_t val;
    int64_t idx, next_idx, row_start, arg;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int e_1 = 0; e_1 < E_1; e_1++) {
        int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);

        for (int k = 0; k < K; k++) {
          idx = index_info.data[offset];
          row_start = 0;
          val = out_data[e_1 * N * K + k];

          for (int e_2 = 0; e_2 < E_2; e_2++) {
            Reducer<scalar_t, REDUCE>::update(
                &val, src_data[e_1 * E_2 * K + e_2 * K + k], &arg, e_2);

            if (e_2 == E_2 - 1) {
              Reducer<scalar_t, REDUCE>::write(
                  out_data + e_1 * N * K + idx * K + k, val,
                  arg_out_data + e_1 * N * K + idx * K + k, arg,
                  e_2 + 1 - row_start);
            } else {
              next_idx = index_info.data[offset + (e_2 + 1) * stride];
              if (idx != next_idx) {
                Reducer<scalar_t, REDUCE>::write(
                    out_data + e_1 * N * K + idx * K + k, val,
                    arg_out_data + e_1 * N * K + idx * K + k, arg,
                    e_2 + 1 - row_start);

                row_start = e_2 + 1;
                val = out_data[e_1 * N * K + next_idx * K + k];
              }
              idx = next_idx;
            }
          }
        }
      }
    });
  });

  return std::make_tuple(out, arg_out);
}
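The COO kernel reduces consecutive runs of equal (sorted) indices into `out`, folding in the values already stored there. A hedged 1-D reference sketch of that behavior (illustrative names, not the extension's API):

```python
import torch


def segment_coo_reference(src, index, out, reduce='add'):
    # src, index: (E,) with `index` sorted; out: (N,) pre-filled destination.
    out, count = out.clone(), torch.zeros_like(out)
    for e in range(src.numel()):
        i = index[e]
        if reduce in ('add', 'mean'):
            out[i] += src[e]
        elif reduce == 'min':
            out[i] = torch.min(out[i], src[e])
        else:  # 'max'
            out[i] = torch.max(out[i], src[e])
        count[i] += 1
    return out / count.clamp(min=1) if reduce == 'mean' else out


src = torch.tensor([1., 3., 2., 4.])
index = torch.tensor([0, 0, 1, 2])
print(segment_coo_reference(src, index, torch.zeros(3), 'add'))  # tensor([4., 2., 4.])
```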

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -178,7 +178,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");

  // Broadcasting `indptr` via `expand`.
  auto sizes = indptr.sizes().vec();
  for (int i = 0; i < indptr.dim() - 1; i++) {
    sizes[i] = src.size(i);
@@ -379,7 +379,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");

  // Broadcasting `index` via `expand`.
  auto sizes = index.sizes().vec();
  for (int i = 0; i < index.dim(); i++) {
    sizes[i] = src.size(i);
...
@@ -10,7 +10,7 @@ from .utils import tensor, dtypes
reductions = ['add', 'mean', 'min', 'max']
grad_reductions = ['add', 'mean']
devices = [torch.device('cpu')]
tests = [
    {
@@ -82,7 +82,6 @@ tests = [
]
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_forward(test, reduce, dtype, device):
@@ -119,7 +118,6 @@ def test_backward(test, reduce, device):
    assert gradcheck(segment_csr, (src, indptr, None, reduce)) is True
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_segment_out(test, reduce, dtype, device):
@@ -153,7 +151,6 @@ def test_segment_out(test, reduce, dtype, device):
    assert torch.all(out == expected)
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_non_contiguous_segment(test, reduce, dtype, device):
...