Commit 1adc8a71 authored by rusty1s

segment any

parent 3c89ebc2
@@ -48,11 +48,11 @@ def correctness(dataset):
     for size in sizes:
         try:
             x = torch.randn((row.size(0), size), device=device)
-            x = x.unsqueeze(-1) if size == 1 else x
+            x = x.squeeze(-1) if size == 1 else x
 
             out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
-            out2 = segment_coo(x, row, dim_size=dim_size)
-            out3 = segment_csr(x, rowptr)
+            out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
+            out3 = segment_csr(x, rowptr, reduce='add')
 
             assert torch.allclose(out1, out2, atol=1e-4)
             assert torch.allclose(out1, out3, atol=1e-4)
@@ -74,7 +74,7 @@ def timing(dataset):
     for size in sizes:
         try:
             x = torch.randn((row.size(0), size), device=device)
-            x = x.unsqueeze(-1) if size == 1 else x
+            x = x.squeeze(-1) if size == 1 else x
 
             try:
                 torch.cuda.synchronize()
@@ -104,7 +104,7 @@ def timing(dataset):
                 torch.cuda.synchronize()
                 t = time.perf_counter()
                 for _ in range(iters):
-                    out = segment_coo(x, row, dim_size=dim_size)
+                    out = segment_coo(x, row, dim_size=dim_size, reduce='any')
                 del out
                 torch.cuda.synchronize()
                 t3.append(time.perf_counter() - t)
@@ -116,7 +116,7 @@ def timing(dataset):
                 torch.cuda.synchronize()
                 t = time.perf_counter()
                 for _ in range(iters):
-                    out = segment_csr(x, rowptr)
+                    out = segment_csr(x, rowptr, reduce='any')
                 del out
                 torch.cuda.synchronize()
                 t4.append(time.perf_counter() - t)
@@ -134,7 +134,7 @@ def timing(dataset):
         try:
             x = torch.randn((dim_size, int(avg_row_len + 1), size),
                            device=device)
-            x = x.unsqueeze(-1) if size == 1 else x
+            x = x.squeeze(-1) if size == 1 else x
 
             try:
                 torch.cuda.synchronize()
@@ -149,7 +149,7 @@ def timing(dataset):
                 t5.append(float('inf'))
 
             x = x.view(dim_size, size, int(avg_row_len + 1))
-            x = x.unsqueeze(-2) if size == 1 else x
+            x = x.squeeze(-2) if size == 1 else x
 
             try:
                 torch.cuda.synchronize()
...
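The benchmark now also exercises the new reduce='any' mode, which returns a single representative element per segment instead of combining them. A minimal sketch of the calls involved (import path, device, and shapes are assumptions, mirroring the wrapper further down):

import torch
from torch_scatter import segment_coo, segment_csr  # import path assumed

device = 'cuda'  # the 'any' code paths in this commit are CUDA-only
row = torch.tensor([0, 0, 1, 1, 1], device=device)   # sorted COO index: segments of length 2, 3, 0
rowptr = torch.tensor([0, 2, 5, 5], device=device)   # equivalent CSR pointer
x = torch.randn(row.size(0), 16, device=device)

out_coo = segment_coo(x, row, dim_size=3, reduce='any')  # some element per segment
out_csr = segment_csr(x, rowptr, reduce='any')           # first element per segment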
#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
                           at::optional<at::Tensor> out_opt);

at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
                      at::optional<at::Tensor> out_opt) {
  CHECK_CUDA(src);
  CHECK_CUDA(indptr);
  if (out_opt.has_value())
    CHECK_CUDA(out_opt.value());
  return gather_csr_cuda(src, indptr, out_opt);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gather_csr", &gather_csr, "Gather CSR (CUDA)");
}
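The binding above only checks devices and forwards to the CUDA implementation. During development, one plausible way to compile and call it is PyTorch's JIT extension loader; the source file names here are hypothetical:

import torch
from torch.utils.cpp_extension import load

# File names are hypothetical; point these at the binding and kernel sources.
gather = load(name='gather_csr_ext',
              sources=['gather_csr.cpp', 'gather_csr_kernel.cu'])

src = torch.randn(4, 8, device='cuda')
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')
out = gather.gather_csr(src, indptr, None)  # output length is indptr[-1] == 6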
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "atomics.cuh"
#include "compat.cuh"

#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff

at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
                           at::optional<at::Tensor> out_opt) {
  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
  src = src.contiguous();

  auto gather_dim = indptr.dim() - 1;
  AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
             "Input mismatch");

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != gather_dim)
        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
  } else {
    // The output length along `gather_dim` is the last entry of `indptr`,
    // which lives on the GPU, so copy that single value back to the host.
    int64_t *d_gather_size = indptr.flatten()[-1].DATA_PTR<int64_t>();
    int64_t h_gather_size;
    cudaMemcpy(&h_gather_size, d_gather_size, sizeof(int64_t),
               cudaMemcpyDeviceToHost);
    auto sizes = src.sizes().vec();
    sizes[gather_dim] = h_gather_size;
    out = at::empty(sizes, src.options());
  }

  return out;
}
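For intuition, the size computation above implies that gather_csr is the inverse of segment_csr: each input row is replicated once per position in its segment, so the output length equals indptr[-1]. A rough eager-mode reference of those semantics (not the kernel itself; torch.repeat_interleave is a stand-in):

import torch

def gather_csr_reference(src, indptr):
    # Row i of `src` is repeated (indptr[i+1] - indptr[i]) times.
    counts = indptr[1:] - indptr[:-1]
    return src.repeat_interleave(counts, dim=0)

src = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
indptr = torch.tensor([0, 2, 5, 5, 6])
print(gather_csr_reference(src, indptr))
# rows: [1,2],[1,2],[3,4],[3,4],[3,4],[7,8]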
@@ -11,7 +11,6 @@
 #define FULL_MASK 0xffffffff
 
 enum ReductionType { ADD, MEAN, MIN, MAX };
 
 #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
   [&] {                                                                        \
     if (reduce == "add") {                                                     \
@@ -42,12 +41,12 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
   static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                 int64_t *arg, int64_t new_arg) {
-    if ((REDUCE == MIN && new_val < *val) ||
+    if (REDUCE == ADD || REDUCE == MEAN) {
+      *val = *val + new_val;
+    } else if ((REDUCE == MIN && new_val < *val) ||
         (REDUCE == MAX && new_val > *val)) {
       *val = new_val;
       *arg = new_arg;
-    } else {
-      *val = *val + new_val;
     }
   }
@@ -220,6 +219,22 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
 
+  if (reduce == "any") {
+    auto index = indptr.narrow(reduce_dim, 0, indptr.size(reduce_dim) - 1);
+    auto index2 = indptr.narrow(reduce_dim, 1, indptr.size(reduce_dim) - 1);
+    auto mask = (index2 - index) == 0;
+    for (int i = reduce_dim + 1; i < src.dim(); i++) {
+      index = index.unsqueeze(-1);
+      mask = mask.unsqueeze(-1);
+    }
+    at::gather_out(out, src, reduce_dim, index.expand(out.sizes()));
+    out.masked_fill_(mask.expand(out.sizes()), 0);
+    return std::make_tuple(out, arg_out);
+  }
+
   auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
   auto K = out.numel() / N;
   auto E = src.size(reduce_dim);
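The new 'any' branch in segment_csr_cuda avoids a custom kernel launch entirely: it gathers the first element of every segment via indptr and zeroes out empty segments. In plain PyTorch, the same gather-and-mask trick looks roughly like this (1-D case for brevity; the clamp is an extra guard of mine for trailing empty segments):

import torch

def segment_csr_any(src, indptr):
    index = indptr[:-1].clamp(max=src.size(0) - 1)  # first position of each segment
    empty = (indptr[1:] - indptr[:-1]) == 0         # segments with no elements
    out = src.gather(0, index)
    return out.masked_fill(empty, 0)

src = torch.tensor([10., 20., 30., 40., 50., 60.])
indptr = torch.tensor([0, 2, 5, 5, 6])
print(segment_csr_any(src, indptr))  # tensor([10., 30., 0., 60.])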
@@ -351,6 +366,14 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
 
+  if (reduce == "any") {
+    for (int i = reduce_dim + 1; i < src.dim(); i++) {
+      index = index.unsqueeze(-1);
+    }
+    out.scatter_(reduce_dim, index.expand(src.sizes()), src);
+    return std::make_tuple(out, arg_out);
+  }
+
   auto E = index.numel();
   auto K = src.numel() / index.numel();
   auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
...
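For the COO variant, 'any' is a single scatter_: when duplicate indices collide, which write survives is unspecified on the GPU, which is exactly why the reduction is called 'any'. A minimal 1-D equivalent (on CPU, the last sequential write typically wins):

import torch

src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 0, 1, 1, 1, 3])

out = torch.zeros(4)
out.scatter_(0, index, src)  # one (unspecified) element per segment; 0 for empty
print(out)  # tensor([2., 5., 0., 6.]) here, with last-write-wins ordering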
@@ -16,16 +16,16 @@ def test_forward(dtype, device):
     src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
                  device)
-    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
-    src.requires_grad_()
+    # src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
+    # src.requires_grad_()
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
 
-    out = segment_csr(src, indptr, reduce='max')
-    out = out[0] if isinstance(out, tuple) else out
+    out = segment_csr(src, indptr, reduce='any')
     print('CSR', out)
-    out.backward(torch.randn_like(out))
+    # out = out[0] if isinstance(out, tuple) else out
+    # out.backward(torch.randn_like(out))
 
     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
-    out = segment_coo(src, index, reduce='add')
+    out = segment_coo(src, index, reduce='any')
     print('COO', out)
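Worked by hand, the test above should print the first row of each CSR segment (zeros for the empty third segment), while the COO printout is one arbitrary row per segment. A small check of the deterministic CSR case, reusing the gather-and-mask trick from the kernel with the test's values:

import torch

src = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]).float()
indptr = torch.tensor([0, 2, 5, 5, 6])

first = src[indptr[:-1].clamp(max=src.size(0) - 1)]
first[(indptr[1:] - indptr[:-1]) == 0] = 0
print(first)  # [[1, 2], [5, 6], [0, 0], [11, 12]]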
@@ -7,7 +7,7 @@ if torch.cuda.is_available():
 class SegmentCSR(torch.autograd.Function):
     @staticmethod
     def forward(ctx, src, indptr, out, reduce):
-        assert reduce in ['add', 'mean', 'min', 'max']
+        assert reduce in ['any', 'add', 'mean', 'min', 'max']
         assert indptr.dtype == torch.long
         if out is not None:
@@ -30,12 +30,12 @@ class SegmentCSR(torch.autograd.Function):
 def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
-    assert reduce in ['add', 'mean', 'min', 'max']
+    assert reduce in ['any', 'add', 'mean', 'min', 'max']
     if out is None:
         dim_size = index.max().item() + 1 if dim_size is None else dim_size
         size = list(src.size())
         size[index.dim() - 1] = dim_size
-        out = src.new_zeros(size)  # TODO: DEPENDENT ON REDUCE
+        out = src.new_zeros(size)  # TODO: DEPENDS ON REDUCE
     assert index.dtype == torch.long and src.dtype == out.dtype
     out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
     return out if arg_out is None else (out, arg_out)
...
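The TODO: DEPENDS ON REDUCE comment flags a real caveat: zero-initializing out is only correct for 'add', 'mean', and 'any'; 'min'/'max' need identity values so that real inputs always win the comparison. A hedged sketch of what such initialization might look like (this helper is mine, not code from the commit):

import torch

def init_out(src, size, reduce):
    # Identity element per reduction: zeros work as an accumulator start for
    # 'any', 'add', and 'mean'; 'min'/'max' need +/- infinity instead.
    if reduce in ('any', 'add', 'mean'):
        return src.new_zeros(size)
    if reduce == 'min':
        return src.new_full(size, float('inf'))
    if reduce == 'max':
        return src.new_full(size, float('-inf'))
    raise ValueError(reduce)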