Commit 04fe0806 authored by rusty1s

gather kernels

parent 1adc8a71
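For orientation, a minimal pure-PyTorch sketch (for illustration only, not code from this commit) of what the two new gather kernels compute, using the same toy values as the new test file below: gather_coo(src, index) copies row index[e] of src into output row e, while gather_csr(src, indptr) repeats row i of src once per element of CSR segment i.

import torch

# Same toy values as the new test: four source rows, CSR boundaries [0, 2, 5, 5, 6].
src = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
indptr = torch.tensor([0, 2, 5, 5, 6])
index = torch.tensor([0, 0, 1, 1, 1, 3])

out_coo = src.index_select(0, index)            # what gather_coo(src, index) returns
counts = indptr[1:] - indptr[:-1]               # segment lengths: [2, 3, 0, 1]
out_csr = src.repeat_interleave(counts, dim=0)  # what gather_csr(src, indptr) returns

assert torch.equal(out_coo, out_csr)            # both: rows 0, 0, 1, 1, 1, 3 of src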
import time
import itertools

import torch
from scipy.io import loadmat
from torch_scatter import gather_coo, gather_csr

from scatter_segment import iters, device, sizes
from scatter_segment import short_rows, long_rows, download, bold


@torch.no_grad()
def correctness(dataset):
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes[1:]:
        try:
            x = torch.randn((dim_size, size), device=device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = x.index_select(0, row)
            out2 = gather_coo(x, row)
            out3 = gather_csr(x, rowptr)

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)
        except RuntimeError:
            torch.cuda.empty_cache()


@torch.no_grad()
def timing(dataset):
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
    dim_size = rowptr.size(0) - 1
    avg_row_len = row.size(0) / dim_size

    t1, t2, t3, t4 = [], [], [], []
    for size in sizes:
        try:
            x = torch.randn((dim_size, size), device=device)
            row_expand = row.view(-1, 1).expand(-1, x.size(-1))
            x = x.squeeze(-1) if size == 1 else x
            row_expand = row_expand.squeeze(-1) if size == 1 else row_expand

            try:
                torch.cuda.synchronize()
                t = time.perf_counter()
                for _ in range(iters):
                    out = x.index_select(0, row)
                del out
                torch.cuda.synchronize()
                t1.append(time.perf_counter() - t)
            except RuntimeError:
                torch.cuda.empty_cache()
                t1.append(float('inf'))

            try:
                torch.cuda.synchronize()
                t = time.perf_counter()
                for _ in range(iters):
                    out = x.gather(0, row_expand)
                del out
                torch.cuda.synchronize()
                t2.append(time.perf_counter() - t)
            except RuntimeError:
                torch.cuda.empty_cache()
                t2.append(float('inf'))

            try:
                torch.cuda.synchronize()
                t = time.perf_counter()
                for _ in range(iters):
                    out = gather_coo(x, row)
                del out
                torch.cuda.synchronize()
                t3.append(time.perf_counter() - t)
            except RuntimeError:
                torch.cuda.empty_cache()
                t3.append(float('inf'))

            try:
                torch.cuda.synchronize()
                t = time.perf_counter()
                for _ in range(iters):
                    out = gather_csr(x, rowptr)
                del out
                torch.cuda.synchronize()
                t4.append(time.perf_counter() - t)
            except RuntimeError:
                torch.cuda.empty_cache()
                t4.append(float('inf'))

            del x
        except RuntimeError:
            torch.cuda.empty_cache()
            for t in (t1, t2, t3, t4):
                t.append(float('inf'))
    ts = torch.tensor([t1, t2, t3, t4])
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    winner = winner.tolist()

    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
    print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
    print('\t'.join([bold('SELECT ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
    print('\t'.join([bold('GAT    ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
    print('\t'.join([bold('GAT_COO')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
    print('\t'.join([bold('GAT_CSR')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
    print()


if __name__ == '__main__':
    for _ in range(10):  # Warmup.
        torch.randn(100, 100, device=device).sum()

    for dataset in itertools.chain(short_rows, long_rows):
        download(dataset)
        correctness(dataset)
        timing(dataset)
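A side note on the GAT baseline above: torch.gather needs an index tensor with the same shape as its output, unlike index_select, which is why the benchmark expands row across the feature dimension. A small sketch (toy shapes assumed for illustration) of that equivalence:

import torch

x = torch.randn(4, 3)
row = torch.tensor([0, 0, 2])

out1 = x.index_select(0, row)                              # 1-D row index
out2 = x.gather(0, row.view(-1, 1).expand(-1, x.size(-1)))  # index expanded to output shape
assert torch.equal(out1, out2)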
@@ -3,35 +3,33 @@ import os.path as osp
 import itertools
 
 import wget
-from scipy.io import loadmat
 import torch
+from scipy.io import loadmat
-from torch_scatter import scatter_add
-from torch_scatter import segment_csr, segment_coo
+from torch_scatter import scatter_add, segment_csr, segment_coo
 
 iters = 20
 device = 'cuda'
 sizes = [1, 16, 32, 64, 128, 256, 512]
 
-long_rows = [
-    ('Janna', 'StocF-1465'),
-    ('GHS_psdef', 'ldoor'),
-]
 short_rows = [
     ('DIMACS10', 'citationCiteseer'),
     ('SNAP', 'web-Stanford'),
 ]
+long_rows = [
+    ('Janna', 'StocF-1465'),
+    ('GHS_psdef', 'ldoor'),
+]
 
-url = 'https://sparse.tamu.edu/mat/{}/{}.mat'
-for group, name in itertools.chain(long_rows, short_rows):
-    if not osp.exists(f'{name}.mat'):
-        print(f'Downloading {group}/{name}:')
-        wget.download(url.format(group, name))
-        print('')
-
-for _ in range(10):  # Warmup.
-    torch.randn(100, 100, device=device).sum()
+def download(dataset):
+    url = 'https://sparse.tamu.edu/mat/{}/{}.mat'
+    for group, name in itertools.chain(long_rows, short_rows):
+        if not osp.exists(f'{name}.mat'):
+            print(f'Downloading {group}/{name}:')
+            wget.download(url.format(group, name))
+            print('')
 
 
 def bold(text, flag=True):
     return f'\033[1m{text}\033[0m' if flag else text
@@ -193,6 +191,10 @@ def timing(dataset):
     print()
 
-for dataset in itertools.chain(short_rows, long_rows):
+if __name__ == '__main__':
+    for _ in range(10):  # Warmup.
+        torch.randn(100, 100, device=device).sum()
+
+    for dataset in itertools.chain(short_rows, long_rows):
+        download(dataset)
         correctness(dataset)
         timing(dataset)
@@ -4,6 +4,8 @@
 at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
                            at::optional<at::Tensor> out_opt);
+at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
+                           at::optional<at::Tensor> out_opt);
 
 at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
                       at::optional<at::Tensor> out_opt) {
@@ -14,6 +16,16 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   return gather_csr_cuda(src, indptr, out_opt);
 }
 
+at::Tensor gather_coo(at::Tensor src, at::Tensor index,
+                      at::optional<at::Tensor> out_opt) {
+  CHECK_CUDA(src);
+  CHECK_CUDA(index);
+  if (out_opt.has_value())
+    CHECK_CUDA(out_opt.value());
+  return gather_coo_cuda(src, index, out_opt);
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gather_csr", &gather_csr, "Gather CSR (CUDA)");
+  m.def("gather_coo", &gather_coo, "Gather COO (CUDA)");
 }
@@ -3,12 +3,60 @@
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
 
-#include "atomics.cuh"
 #include "compat.cuh"
+#include "indptr.cuh"
 
 #define THREADS 256
 #define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
-#define FULL_MASK 0xffffffff
+
+template <typename scalar_t, int TB>
+__global__ void
+gather_csr_kernel(const scalar_t *src_data,
+                  const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
+                  scalar_t *out_data, size_t N, size_t E) {
+
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int row_idx = thread_idx / TB;
+  int lane_idx = thread_idx % TB;
+
+  if (row_idx < N) {
+    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
+    int row_start = __ldg(indptr_info.data + offset);
+    int row_end = __ldg(indptr_info.data + offset +
+                        indptr_info.strides[indptr_info.dims - 1]);
+
+    scalar_t val = __ldg(src_data + row_idx);
+
+    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
+    for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) {
+      out_data[offset + out_idx] = val; // "Mostly" coalesced.
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void gather_csr_broadcast_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
+    scalar_t *out_data, size_t N, size_t K, size_t E) {
+
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int row_idx = thread_idx / K;
+  int lane_idx = thread_idx % K;
+
+  if (thread_idx < N * K) {
+    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
+    int row_start = __ldg(indptr_info.data + offset);
+    int row_end = __ldg(indptr_info.data + offset +
+                        indptr_info.strides[indptr_info.dims - 1]);
+
+    scalar_t val = src_data[thread_idx]; // Coalesced.
+
+    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
+    for (int out_idx = row_start; out_idx < row_end; out_idx++) {
+      out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced.
+    }
+  }
+}
 
 at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
                            at::optional<at::Tensor> out_opt) {
@@ -28,8 +76,8 @@ at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
       if (i != gather_dim)
         AT_ASSERTM(src.size(i) == out.size(i));
   } else {
-    int64_t *d_gather_size = indptr.flatten()[-1].DATA_PTR<int64_t>();
-    int64_t *h_gather_size;
+    auto d_gather_size = indptr.flatten()[-1].DATA_PTR<int64_t>();
+    auto h_gather_size = (int64_t *)malloc(sizeof(int64_t));
     cudaMemcpy(h_gather_size, d_gather_size, sizeof(int64_t),
                cudaMemcpyDeviceToHost);
@@ -38,5 +86,113 @@ at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
     out = at::empty(sizes, src.options());
   }
 
+  auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
+  auto K = src.numel() / N;
+  auto E = out.size(gather_dim);
+
+  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    if (K == 1) {
+      gather_csr_kernel<scalar_t, 4><<<BLOCKS(1, 4 * N), THREADS, 0, stream>>>(
+          src_data, indptr_info, out_data, N, E);
+    } else {
+      gather_csr_broadcast_kernel<scalar_t>
+          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                     out_data, N, K, E);
+    }
+  });
+
+  return out;
+}
+
+template <typename scalar_t>
+__global__ void
+gather_coo_kernel(const scalar_t *src_data,
+                  const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+                  scalar_t *out_data, size_t E, size_t N) {
+
+  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (row_idx < E) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int row = index_info.data[offset];
+
+    offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N;
+    scalar_t val = __ldg(src_data + offset + row);
+
+    out_data[row_idx] = val;
+  }
+}
+
+template <typename scalar_t>
+__global__ void gather_coo_broadcast_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+    scalar_t *out_data, size_t E, size_t K, size_t N) {
+
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int row_idx = thread_idx / K;
+  int col_idx = thread_idx % K;
+
+  if (thread_idx < E * K) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int row = index_info.data[offset];
+
+    offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K;
+    scalar_t val = __ldg(src_data + offset + K * row + col_idx);
+
+    out_data[thread_idx] = val;
+  }
+}
+
+at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
+                           at::optional<at::Tensor> out_opt) {
+  AT_ASSERTM(src.dim() >= index.dim());
+  for (int i = 0; i < index.dim() - 1; i++)
+    AT_ASSERTM(src.size(i) == index.size(i));
+
+  src = src.contiguous();
+  auto gather_dim = index.dim() - 1;
+
+  at::Tensor out;
+  if (out_opt.has_value()) {
+    out = out_opt.value().contiguous();
+    for (int i = 0; i < index.dim(); i++)
+      AT_ASSERTM(out.size(i) == index.size(i));
+    for (int i = index.dim() + 1; i < src.dim(); i++)
+      AT_ASSERTM(out.size(i) == src.size(i));
+  } else {
+    auto sizes = src.sizes().vec();
+    sizes[gather_dim] = index.size(gather_dim);
+    out = at::empty(sizes, src.options());
+  }
+
+  auto E = index.numel();
+  auto K = out.numel() / E;
+  auto N = src.size(gather_dim);
+
+  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    if (K == 1) {
+      gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
+          src_data, index_info, out_data, E, N);
+    } else {
+      gather_coo_broadcast_kernel<scalar_t>
+          <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
+                                                     out_data, E, K, N);
+    }
+  });
 
   return out;
 }
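A rough worked example (shapes assumed for illustration, not from the commit) of the launch bookkeeping in gather_csr_cuda: for src of shape (B, n, K) and indptr of shape (B, n + 1), N counts source rows across the batch, K is the broadcast width, E is the gathered output length, and the broadcast kernel launches one thread per output element:

# Assumed example shapes, mirroring the arithmetic in gather_csr_cuda.
B, n, K = 2, 3, 8   # indptr rows, segments per row, feature width
E = 10              # gathered rows per batch entry, i.e. indptr[..., -1]
THREADS = 256

N = n * B           # src.size(gather_dim) * (indptr.numel() // indptr.size(-1))
blocks_plain = (4 * N + THREADS - 1) // THREADS      # BLOCKS(1, 4 * N), K == 1 path
blocks_broadcast = (N * K + THREADS - 1) // THREADS  # BLOCKS(1, N * K), broadcast path
print(blocks_plain, blocks_broadcast)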
#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/detail/TensorInfo.cuh>

// We need our own `IndexToOffset` implementation since we do not want to
// access the last element of the `indptr`.
template <typename scalar_t> struct IndexPtrToOffset {
  static inline __host__ __device__ int
  get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
    int offset = idx % (info.sizes[info.dims - 1] - 1);
    offset *= info.strides[info.dims - 1];
    idx /= info.sizes[info.dims - 1] - 1;
    for (int i = info.dims - 2; i >= 0; --i) {
      offset += (idx % info.sizes[i]) * info.strides[i];
      idx /= info.sizes[i];
    }
    return offset;
  }
};
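A rough Python model (for illustration only, not code from the commit) of what IndexPtrToOffset computes: for an indptr tensor whose last dimension holds n + 1 boundaries, it maps a flat segment index to the storage offset of that segment's left boundary, skipping each row's final entry so that offset and offset + stride always bracket a single segment:

def indptr_to_offset(idx, sizes, strides):
    # sizes/strides describe the indptr tensor; its last dimension has n + 1 entries.
    n = sizes[-1] - 1                     # segments per indptr row
    offset = (idx % n) * strides[-1]
    idx //= n
    for size, stride in zip(reversed(sizes[:-1]), reversed(strides[:-1])):
        offset += (idx % size) * stride
        idx //= size
    return offset

# Example: indptr of shape (2, 4), i.e. two rows of three segments, contiguous strides (4, 1).
assert indptr_to_offset(4, [2, 4], [4, 1]) == 5   # segment 4 -> row 1, boundary entry 1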
@@ -5,6 +5,7 @@
 #include "atomics.cuh"
 #include "compat.cuh"
+#include "indptr.cuh"
 
 #define THREADS 256
 #define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
@@ -90,22 +91,6 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
   }
 };
 
-// We need our own `IndexToOffset` implementation since we do not want to
-// access the last element of the `indexptr`.
-template <typename scalar_t> struct IndexPtrToOffset {
-  static inline __host__ __device__ int
-  get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
-    int offset = idx % (info.sizes[info.dims - 1] - 1);
-    offset *= info.strides[info.dims - 1];
-    idx /= info.sizes[info.dims - 1] - 1;
-    for (int i = info.dims - 2; i >= 0; --i) {
-      offset += (idx % info.sizes[i]) * info.strides[i];
-      idx /= info.sizes[i];
-    }
-    return offset;
-  }
-};
-
 template <typename scalar_t, ReductionType REDUCE, int TB>
 __global__ void
 segment_csr_kernel(const scalar_t *src_data,
@@ -313,7 +298,7 @@ __global__ void segment_coo_broadcast_kernel(
   // read and write is performed in column-major order. The intermediate
   // results are written via atomics.
 
-  int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
+  int row_start = blockIdx.x * (blockDim.y + threadIdx.y) * TB;
   int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
 
   if (row_start < E && col_idx < K) {
@@ -375,7 +360,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
   }
 
   auto E = index.numel();
-  auto K = src.numel() / index.numel();
+  auto K = src.numel() / E;
   auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
 
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
from itertools import product

import pytest
import torch
from torch_scatter import gather_coo, gather_csr

from .utils import tensor

dtypes = [torch.float]
devices = [torch.device('cuda')]


@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
    src = tensor([1, 2, 3, 4], dtype, device)
    src = tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype, device)
    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)

    out = gather_csr(src, indptr)
    print('CSR', out)

    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
    out = gather_coo(src, index)
    print('COO', out)

    out = src.index_select(0, index)
    print('Expected', out)
@@ -9,6 +9,7 @@ from .min import scatter_min
 from .logsumexp import scatter_logsumexp
 from .segment import segment_coo, segment_csr
+from .gather import gather_coo, gather_csr
 
 import torch_scatter.composite
@@ -26,6 +27,8 @@ __all__ = [
     'scatter_logsumexp',
     'segment_coo',
     'segment_csr',
+    'gather_coo',
+    'gather_csr',
     'torch_scatter',
     '__version__',
 ]
import torch

if torch.cuda.is_available():
    from torch_scatter import gather_cuda


def gather_coo(src, index, out=None):
    return gather_cuda.gather_coo(src, index, out)


def gather_csr(src, indptr, out=None):
    return gather_cuda.gather_csr(src, indptr, out)
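A hedged usage sketch for the new Python wrappers (assumes a CUDA build of the gather_cuda extension): both accept an optional pre-allocated out tensor, mirroring the C++ signatures above.

import torch
from torch_scatter import gather_coo, gather_csr

src = torch.randn(4, 16, device='cuda')
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')
index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')

out = torch.empty(6, 16, device='cuda')
gather_csr(src, indptr, out)               # fills `out` with one row per segment element
assert torch.equal(out, gather_coo(src, index))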
@@ -8,7 +8,6 @@ class SegmentCSR(torch.autograd.Function):
     @staticmethod
     def forward(ctx, src, indptr, out, reduce):
         assert reduce in ['any', 'add', 'mean', 'min', 'max']
-        assert indptr.dtype == torch.long
 
         if out is not None:
             ctx.mark_dirty(out)
@@ -31,21 +30,14 @@ class SegmentCSR(torch.autograd.Function):
 def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
     assert reduce in ['any', 'add', 'mean', 'min', 'max']
 
-    if out is None:
+    if out is None:  # TODO: MOVE TO CPP
         dim_size = index.max().item() + 1 if dim_size is None else dim_size
         size = list(src.size())
         size[index.dim() - 1] = dim_size
         out = src.new_zeros(size)  # TODO: DEPENDS ON REDUCE
 
-    assert index.dtype == torch.long and src.dtype == out.dtype
-
     out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
     return out if arg_out is None else (out, arg_out)
 
 
 def segment_csr(src, indptr, out=None, reduce='add'):
     return SegmentCSR.apply(src, indptr, out, reduce)
-    # assert reduce in ['add', 'mean', 'min', 'max']
-    # assert indptr.dtype == torch.long
-    # out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
-    # return out if arg_out is None else (out, arg_out)