Commit e5e3985d authored by rusty1s

rowptr implementation + reduce cleanup

parent 143938b7
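The new rowptr kernels convert a sorted COO row-index vector into a CSR row-pointer array of length size + 1, where rowptr[r] counts the entries with row index smaller than r. For orientation only (not part of this commit), an equivalent pure-PyTorch reference of the conversion, checked against the first test case below, could look like this:

import torch


def rowptr_reference(row, size):
    # row holds sorted COO row indices; the result is the CSR row pointer of
    # length size + 1, with rowptr[r] = number of entries with row index < r.
    count = torch.bincount(row, minlength=size)  # entries per row
    out = torch.zeros(size + 1, dtype=torch.long)
    torch.cumsum(count, dim=0, out=out[1:])  # inclusive prefix sum, shifted by one slot
    return out


row = torch.tensor([0, 0, 1, 1, 1, 2, 2])
print(rowptr_reference(row, 5).tolist())  # [0, 2, 5, 7, 7, 7]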
#include <torch/extension.h>
#include "compat.h"
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
at::Tensor rowptr(at::Tensor row, int64_t size) {
  CHECK_CPU(row);
  AT_ASSERTM(row.dim() == 1, "Row needs to be one-dimensional");

  auto out = at::empty(size + 1, row.options());
  auto row_data = row.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();

  int64_t numel = row.numel(), idx = row_data[0], next_idx;

  // Rows up to (and including) the first occupied row start at offset 0.
  for (int64_t i = 0; i <= idx; i++)
    out_data[i] = 0;

  // Each jump between consecutive row indices opens one or more new rows.
  for (int64_t i = 0; i < numel - 1; i++) {
    next_idx = row_data[i + 1];
    for (int64_t j = idx; j < next_idx; j++)
      out_data[j + 1] = i + 1;
    idx = next_idx;
  }

  // Trailing empty rows all point past the last element.
  for (int64_t i = idx + 1; i < size + 1; i++)
    out_data[i] = numel;

  return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rowptr", &rowptr, "Rowptr (CPU)");
}
#include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
at::Tensor rowptr_cuda(at::Tensor row, int64_t size);
at::Tensor rowptr(at::Tensor row, int64_t size) {
  CHECK_CUDA(row);
  return rowptr_cuda(row, size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rowptr", &rowptr, "Rowptr (CUDA)");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "compat.cuh"
#define THREADS 256
__global__ void rowptr_kernel(const int64_t *row_data, int64_t *out_data,
                              int64_t numel, int64_t size) {
  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  if (thread_idx == 0) {
    // Rows up to (and including) the first occupied row start at offset 0.
    for (int64_t i = 0; i <= row_data[0]; i++)
      out_data[i] = 0;
  } else if (thread_idx < numel) {
    // Each jump between consecutive row indices opens one or more new rows.
    for (int64_t i = row_data[thread_idx - 1]; i < row_data[thread_idx]; i++)
      out_data[i + 1] = thread_idx;
  } else if (thread_idx == numel) {
    // Trailing empty rows all point past the last element.
    for (int64_t i = row_data[numel - 1] + 1; i < size + 1; i++)
      out_data[i] = numel;
  }
}
at::Tensor rowptr_cuda(at::Tensor row, int64_t size) {
  AT_ASSERTM(row.dim() == 1, "Row needs to be one-dimensional");

  auto out = at::empty(size + 1, row.options());
  auto row_data = row.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  // One thread per element plus one trailing thread for the empty tail rows.
  rowptr_kernel<<<(row.numel() + 2 + THREADS - 1) / THREADS, THREADS, 0,
                  stream>>>(row_data, out_data, row.numel(), size);
  return out;
}
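The kernel above parallelizes the same conversion with one thread per element plus one trailing thread: thread 0 zero-fills the pointers up to the first occupied row, thread t (0 < t < numel) fills the rows opened by the jump from row[t - 1] to row[t], and thread numel fills the trailing empty rows. A small CPU emulation of that decomposition (an illustrative sketch, not part of this commit) reproduces the third test case below:

import torch


def rowptr_emulated(row, size):
    r = row.tolist()
    numel = len(r)
    out = torch.empty(size + 1, dtype=torch.long)
    for t in range(numel + 1):  # one "thread" per element, plus one trailing thread
        if t == 0:
            out[0:r[0] + 1] = 0  # rows up to the first occupied row start at 0
        elif t < numel:
            out[r[t - 1] + 1:r[t] + 1] = t  # rows opened by the jump r[t-1] -> r[t]
        else:
            out[r[numel - 1] + 1:size + 1] = numel  # trailing rows point past the end
    return out


print(rowptr_emulated(torch.tensor([2, 2, 4, 4]), 7).tolist())  # [0, 0, 0, 2, 2, 4, 4, 4]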
from itertools import product
import pytest
import torch
from torch_sparse import rowptr_cpu
from .utils import tensor, devices
if torch.cuda.is_available():
    from torch_sparse import rowptr_cuda
tests = [
    {
        'row': [0, 0, 1, 1, 1, 2, 2],
        'size': 5,
        'rowptr': [0, 2, 5, 7, 7, 7],
    },
    {
        'row': [0, 0, 1, 1, 1, 4, 4],
        'size': 5,
        'rowptr': [0, 2, 5, 5, 5, 7],
    },
    {
        'row': [2, 2, 4, 4],
        'size': 7,
        'rowptr': [0, 0, 0, 2, 2, 4, 4, 4],
    },
]
def rowptr(row, size):
    if row.is_cuda:
        return rowptr_cuda.rowptr(row, size)
    else:
        return rowptr_cpu.rowptr(row, size)
@pytest.mark.parametrize('test,device', product(tests, devices))
def test_rowptr(test, device):
    row = tensor(test['row'], torch.long, device)
    size = test['size']
    expected = tensor(test['rowptr'], torch.long, device)

    out = rowptr(row, size)
    assert torch.all(out == expected)
import torch

-dtypes = [torch.float]
+dtypes = [torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.float, torch.double]

devices = [torch.device('cpu')]
-# if torch.cuda.is_available():
-# devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]
+if torch.cuda.is_available():
+    devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]


def tensor(x, dtype, device):
-    return torch.tensor(x, dtype=dtype, device=device)
+    return None if x is None else torch.tensor(x, dtype=dtype, device=device)
import torch
import torch_scatter
from torch_scatter import segment_csr

-def reduce(src, dim=None, reduce='add', deterministic=False):
-    if dim is None and src.has_value():
-        op = getattr(torch, 'sum' if reduce == 'add' else reduce)
-        return op(src.storage.value)
+def __reduce__(src, dim=None, reduce='add', deterministic=False):
+    if dim is None and src.has_value():
+        func = getattr(torch, 'sum' if reduce == 'add' else reduce)
+        return func(src.storage.value)

    if dim is None and not src.has_value():
        assert reduce in ['add', 'mean', 'min', 'max']
@@ -19,8 +22,8 @@ def reduce(src, dim=None, reduce='add', deterministic=False):
    dense_dims = tuple(set([d - 1 for d in dims if d > 1]))

    if len(sparse_dims) == 2 and src.has_value():
-        op = getattr(torch, 'sum' if reduce == 'add' else reduce)
-        return op(value, dim=(0, ) + dense_dims)
+        func = getattr(torch, 'sum' if reduce == 'add' else reduce)
+        return func(value, dim=(0, ) + dense_dims)

    if len(sparse_dims) == 2 and not src.has_value():
        assert reduce in ['add', 'mean', 'min', 'max']
@@ -28,17 +31,17 @@ def reduce(src, dim=None, reduce='add', deterministic=False):
        return torch.tensor(value, device=src.device)

    if len(dense_dims) > 0 and len(sparse_dims) == 0:
-        op = getattr(torch, 'sum' if reduce == 'add' else reduce)
+        func = getattr(torch, 'sum' if reduce == 'add' else reduce)
        dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
-        value = op(value, dim=dense_dims)
+        value = func(value, dim=dense_dims)
        if isinstance(value, tuple):
-            return (src.set_value(value[0], layout='csr'),) + value[1:]
+            return (src.set_value(value[0], layout='csr'), ) + value[1:]
        return src.set_value(value, layout='csr')

    if len(dense_dims) > 0 and len(sparse_dims) > 0:
-        op = getattr(torch, 'sum' if reduce == 'add' else reduce)
+        func = getattr(torch, 'sum' if reduce == 'add' else reduce)
        dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
-        value = op(value, dim=dense_dims)
+        value = func(value, dim=dense_dims)
        value = value[0] if isinstance(value, tuple) else value

    if sparse_dims[0] == 0:
@@ -53,7 +56,23 @@ def reduce(src, dim=None, reduce='add', deterministic=False):
        return out

    if sparse_dims[0] == 1:
-        op = getattr(torch_scatter, f'scatter_{reduce}')
-        out = op(value, col, dim=0, dim_size=src.sparse_size(0))
+        func = getattr(torch_scatter, f'scatter_{reduce}')
+        out = func(value, col, dim=0, dim_size=src.sparse_size(0))
        out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
        return out
+
+
+def sum(src, dim=None, deterministic=False):
+    return __reduce__(src, dim, reduce='add', deterministic=deterministic)
+
+
+def mean(src, dim=None, deterministic=False):
+    return __reduce__(src, dim, reduce='mean', deterministic=deterministic)
+
+
+def min(src, dim=None, deterministic=False):
+    return __reduce__(src, dim, reduce='min', deterministic=deterministic)
+
+
+def max(src, dim=None, deterministic=False):
+    return __reduce__(src, dim, reduce='max', deterministic=deterministic)
@@ -3,6 +3,11 @@ import warnings
import torch
from torch_scatter import segment_csr

+from torch_sparse import rowptr_cpu
+
+if torch.cuda.is_available():
+    from torch_sparse import rowptr_cuda
+
__cache__ = {'enabled': True}
@@ -196,31 +201,23 @@ class SparseStorage(object):
    @cached_property
    def rowcount(self):
-        # TODO
-        one = torch.ones_like(self.row)
-        return segment_add(one, self.row, dim=0, dim_size=self._sparse_size[0])
+        rowptr = self.rowptr
+        return rowptr[1:] - rowptr[:-1]

    @cached_property
    def rowptr(self):
-        # TODO
-        rowcount = self.rowcount
-        rowptr = rowcount.new_zeros(rowcount.numel() + 1)
-        torch.cumsum(rowcount, dim=0, out=rowptr[1:])
-        return rowptr
+        func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
+        return func.rowptr(self.row, self.sparse_size(0))

    @cached_property
    def colcount(self):
-        # TODO
-        one = torch.ones_like(self.col)
-        return scatter_add(one, self.col, dim=0, dim_size=self._sparse_size[1])
+        colptr = self.colptr
+        return colptr[1:] - colptr[:-1]

    @cached_property
    def colptr(self):
-        # TODO
-        colcount = self.colcount
-        colptr = colcount.new_zeros(colcount.numel() + 1)
-        torch.cumsum(colcount, dim=0, out=colptr[1:])
-        return colptr
+        func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
+        return func.rowptr(self.col[self.csr2csc], self.sparse_size(1))

    @cached_property
    def csr2csc(self):
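With rowptr available as a primitive, rowcount (and analogously colcount) reduces to a difference of adjacent pointer entries. A quick check, not part of this commit, using the first test case from above:

import torch

rowptr = torch.tensor([0, 2, 5, 7, 7, 7])  # rowptr of row = [0, 0, 1, 1, 1, 2, 2] with size = 5
rowcount = rowptr[1:] - rowptr[:-1]
print(rowcount.tolist())  # [2, 3, 2, 0, 0] entries in rows 0..4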