Commit 2bea1c3c authored by rusty1s

degree padding super fast

parent 3639bfab
@@ -4,7 +4,7 @@

 #include "utils.cuh"

-#define THREADS 256
+#define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS

 __global__ void bin_kernel(const int64_t *rowcount, const int64_t *bin_strategy,
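The jump to 1024 threads pairs with the grid-stride kernels introduced below: the grid is capped at 8 blocks per SM and each thread strides over any leftover elements. A rough sketch of that sizing rule in Python (the names here are mine, not from the commit):

    THREADS = 1024

    def grid_size(numel, num_sms):
        # Enough blocks to cover numel once, but never more than 8 per SM;
        # remaining elements are handled by the kernel's grid-stride loop.
        return min((numel + THREADS - 1) // THREADS, num_sms * 8)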
@@ -80,37 +80,53 @@ std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
   return index.split_with_sizes(sizes);
 }

-template <typename scalar_t, int64_t TB>
-__global__ void
-padded_index_select_kernel(const scalar_t *src, const int64_t *rowptr,
-                           const int64_t *col, const int64_t *index,
-                           scalar_t *out, bool *mask, int64_t length,
-                           int64_t dim, int64_t numel) {
-
-  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  auto dim_idx = thread_idx % dim;
-  auto lane_idx = (thread_idx / dim) % TB;
-  auto index_idx = thread_idx / (TB * dim);
-
-  if (thread_idx < numel) {
-    auto row_idx = __ldg(index + index_idx);
-    auto row_start = __ldg(rowptr + row_idx);
-    auto row_end = __ldg(rowptr + row_idx + 1);
-
-    for (int64_t c = lane_idx; c < row_end - row_start; c += TB) {
-      auto x = src[__ldg(col + row_start + c) * dim + dim_idx];
-      out[index_idx * dim * length + c * dim + dim_idx] = x;
-      // mask[index_idx * dim * length + c * dim + dim_idx] = true;
-    }
-  }
-}
-
-#define TB 4
-
+__global__ void padded_mask_select_kernel(const int64_t *rowptr,
+                                          const int64_t *col,
+                                          const int64_t *index,
+                                          int64_t *out_idx, bool *mask,
+                                          int64_t length, int64_t numel) {
+
+  int64_t lane_idx, row_idx, row_start, row_end, col_idx;
+  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+       thread_idx < numel; thread_idx += gridDim.x * blockDim.x) {
+    lane_idx = thread_idx % length;
+    row_idx = index[thread_idx / length];
+    row_start = rowptr[row_idx];
+    row_end = rowptr[row_idx + 1];
+
+    col_idx = -1;
+    if (lane_idx < row_end - row_start) {
+      col_idx = col[row_start + lane_idx];
+    }
+
+    out_idx[thread_idx] = col_idx;
+    mask[thread_idx] = col_idx == -1;
+  }
+}
+
+template <typename scalar_t>
+__global__ void padded_index_select_kernel(const scalar_t *src,
+                                           const int64_t *index, scalar_t *out,
+                                           scalar_t fill_value, int64_t dim,
+                                           int64_t numel) {
+
+  int64_t index_idx, dim_idx, col;
+  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+       thread_idx < numel; thread_idx += gridDim.x * blockDim.x) {
+    index_idx = thread_idx / dim;
+    dim_idx = thread_idx % dim;
+    col = __ldg(index + index_idx);
+    if (col >= 0) {
+      fill_value = src[col * dim + dim_idx];
+    }
+    out[thread_idx] = fill_value;
+  }
+}
+
 std::tuple<torch::Tensor, torch::Tensor>
 padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
-                         torch::Tensor col, torch::Tensor index,
-                         int64_t length) {
+                         torch::Tensor col, torch::Tensor index, int64_t length,
+                         torch::Tensor fill_value) {
   CHECK_CUDA(src);
   CHECK_CUDA(rowptr);
   CHECK_CUDA(col);
@@ -119,20 +135,38 @@ padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
   CHECK_INPUT(rowptr.dim() == 1);
   CHECK_INPUT(col.dim() == 1);
   CHECK_INPUT(index.dim() == 1);
+  CHECK_INPUT(fill_value.numel() == 1);
   cudaSetDevice(src.get_device());

-  auto out = torch::zeros({index.size(0), length, src.size(-1)}, src.options());
-  auto mask =
-      torch::zeros({index.size(0), length}, src.options().dtype(torch::kBool));
+  auto out_idx = torch::empty({index.size(0), length}, index.options());
+  auto mask = torch::empty({index.size(0), length, 1},
+                           src.options().dtype(torch::kBool));

   auto stream = at::cuda::getCurrentCUDAStream();
+  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+
+  padded_mask_select_kernel<<<
+      std::min((out_idx.numel() + THREADS - 1) / THREADS, mpc * 8), THREADS, 0,
+      stream>>>(rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
+                index.data_ptr<int64_t>(), out_idx.data_ptr<int64_t>(),
+                mask.data_ptr<bool>(), length, out_idx.numel());
+
+  auto out = torch::empty({index.size(0), length, src.size(-1)}, src.options());
+
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
-    padded_index_select_kernel<scalar_t, TB>
-        <<<BLOCKS(index.numel() * src.size(-1) * TB), THREADS, 0, stream>>>(
-            src.data_ptr<scalar_t>(), rowptr.data_ptr<int64_t>(),
-            col.data_ptr<int64_t>(), index.data_ptr<int64_t>(),
-            out.data_ptr<scalar_t>(), mask.data_ptr<bool>(), length,
-            src.size(-1), index.numel() * src.size(-1) * TB);
+    scalar_t *fill;
+    if (fill_value.is_cuda()) {
+      fill = (scalar_t *)malloc(sizeof(scalar_t));
+      cudaMemcpy(fill, fill_value.data_ptr<scalar_t>(), sizeof(scalar_t),
+                 cudaMemcpyDeviceToHost);
+    } else {
+      fill = fill_value.data_ptr<scalar_t>();
+    }
+
+    padded_index_select_kernel<scalar_t>
+        <<<std::min((out.numel() + THREADS - 1) / THREADS, mpc * 8), THREADS, 0,
+           stream>>>(src.data_ptr<scalar_t>(), out_idx.data_ptr<int64_t>(),
+                     out.data_ptr<scalar_t>(), fill[0], src.size(-1),
+                     out.numel());
   });

   return std::make_tuple(out, mask);
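For intuition, here is a minimal PyTorch sketch of what the two kernels compute together; the helper name padded_index_select_reference and the loop-based layout are mine, not part of the commit, and the sketch takes fill_value as a Python scalar rather than a one-element tensor. Pass 1 pads each selected row's CSR neighborhood to length column indices, with -1 marking empty slots; pass 2 gathers feature rows by those indices and writes fill_value into the padded slots.

    import torch

    def padded_index_select_reference(src, rowptr, col, index, length,
                                      fill_value):
        # Pass 1 (mirrors padded_mask_select_kernel): build a padded
        # (num_index, length) tensor of column indices, -1 wherever a row
        # has fewer than `length` neighbors.
        out_idx = torch.full((index.numel(), length), -1, dtype=torch.long)
        for i, row in enumerate(index.tolist()):
            start, end = rowptr[row].item(), rowptr[row + 1].item()
            deg = min(end - start, length)
            out_idx[i, :deg] = col[start:start + deg]
        mask = (out_idx == -1).unsqueeze(-1)

        # Pass 2 (mirrors padded_index_select_kernel): gather feature rows,
        # filling padded slots with the scalar fill_value.
        out = src[out_idx.clamp(min=0)]
        out[mask.expand(-1, -1, src.size(-1))] = fill_value
        return out, mask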
@@ -6,5 +6,5 @@ std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
                                                torch::Tensor bin_strategy);

 std::tuple<torch::Tensor, torch::Tensor>
 padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
-                         torch::Tensor col, torch::Tensor index,
-                         int64_t length);
+                         torch::Tensor col, torch::Tensor index, int64_t length,
+                         torch::Tensor fill_value);
@@ -24,10 +24,12 @@ std::vector<torch::Tensor> bin_assignment(torch::Tensor rowcount,

 std::tuple<torch::Tensor, torch::Tensor>
 padded_index_select(torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
-                    torch::Tensor index, int64_t length) {
+                    torch::Tensor index, int64_t length,
+                    torch::Tensor fill_value) {
   if (src.device().is_cuda()) {
 #ifdef WITH_CUDA
-    return padded_index_select_cuda(src, rowptr, col, index, length);
+    return padded_index_select_cuda(src, rowptr, col, index, length,
+                                    fill_value);
 #else
     AT_ERROR("Not compiled with CUDA support");
 #endif
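With the dispatcher updated, the Python-facing op now takes the fill value as a one-element tensor. A minimal usage sketch, on a toy graph made up for illustration:

    import torch
    from torch_sparse import SparseTensor

    device = torch.device('cuda')
    row = torch.tensor([0, 0, 1], device=device)
    col = torch.tensor([1, 2, 2], device=device)
    adj = SparseTensor(row=row, col=col, sparse_sizes=(3, 3))
    rowptr, col = adj.storage.rowptr(), adj.storage.col()

    x = torch.randn(3, 8, device=device)    # node features
    index = torch.arange(3, device=device)  # select all rows

    # Pad every neighborhood to length 2; empty slots are filled with 0.
    out, mask = torch.ops.torch_sparse.padded_index_select(
        x, rowptr, col, index, 2, torch.tensor(0.))
    # out: [3, 2, 8]; mask: [3, 2, 1], True where a slot is padding
    # (node 2 has no neighbors, so its two slots are fully padded).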
@@ -68,11 +68,24 @@ def test_bin_assignment(device):
     x = torch.randn(dataset[0].num_nodes, 512).to(device)
     rowptr = adj.storage.rowptr().to(device)
     col = col.to(device)
+
+    for i in range(102):
+        if i == 2:
+            start.record()
+        for perm, count in zip(perms, bin_count):
+            torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perm,
+                                                       count, torch.tensor(0.))
+    end.record()
+    torch.cuda.synchronize()
+    print(start.elapsed_time(end))
+
+    print('-----------')
+
     for i in range(102):
         if i == 2:
             start.record()
         torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[0],
-                                                   bin_count[0])
+                                                   bin_count[0],
+                                                   torch.tensor(0.))
     end.record()
     torch.cuda.synchronize()
     print(start.elapsed_time(end))
@@ -80,7 +93,8 @@ def test_bin_assignment(device):
         if i == 2:
             start.record()
         torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[1],
-                                                   bin_count[1])
+                                                   bin_count[1],
+                                                   torch.tensor(0.))
     end.record()
     torch.cuda.synchronize()
     print(start.elapsed_time(end))
@@ -88,7 +102,8 @@ def test_bin_assignment(device):
         if i == 2:
             start.record()
         torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[2],
-                                                   bin_count[2])
+                                                   bin_count[2],
+                                                   torch.tensor(0.))
     end.record()
     torch.cuda.synchronize()
     print(start.elapsed_time(end))
import pytest
import torch
from torch_sparse import SparseTensor
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import degree

devices = [torch.device('cuda')]


@pytest.mark.parametrize('device', devices)
def test_padded_index_select(device):
    dataset = Planetoid('/tmp/Planetoid', name='PubMed')
    data = dataset[0]
    row, col = data.edge_index.to(device)

    row = torch.arange(data.num_nodes).view(-1, 1).repeat(1, 4).view(-1)
    col = torch.randint(0, data.num_nodes, (row.size(0), ))
    row, col = row.to(device), col.to(device)

    adj = SparseTensor(row=row, col=col)
    rowcount = adj.storage.rowcount().to(device)
    rowptr = adj.storage.rowptr().to(device)

    bin_strategy = torch.tensor([[1, 4]]).to(device)
    # bin_strategy = torch.tensor([[1, 5], [6, 12], [13, 19], [20, 30]],
    #                             device=device)
    perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    lengths = bin_strategy[:, 1].view(-1).tolist()
    print(lengths)

    deg = degree(row, dtype=torch.long)
    print(deg.size(), deg.min(), deg.float().mean(), deg.max())
    bins = torch.bincount(deg)
    print(bins)
    nonzero = bins.nonzero().flatten()
    print(nonzero)
    print(bins[nonzero])

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for dim in [32, 64, 128, 256, 512, 1024]:
        print(f'--- Dim: {dim} ---')
        x = torch.randn(adj.size(0), dim).to(device)

        for i in range(110):
            if i == 10:
                start.record()
            for perm, length in zip(perms, lengths):
                out1, _ = torch.ops.torch_sparse.padded_index_select(
                    x, rowptr, col, perm, length, torch.tensor(0.))
        end.record()
        torch.cuda.synchronize()
        print(start.elapsed_time(end))

        for i in range(110):
            if i == 10:
                start.record()
            out2 = x.index_select(0, row)
        end.record()
        torch.cuda.synchronize()
        print(start.elapsed_time(end))

        for i in range(110):
            if i == 10:
                start.record()
            out3 = x.index_select(0, col)
        end.record()
        torch.cuda.synchronize()
        print(start.elapsed_time(end))

        print(torch.allclose(out1.view(-1, dim), out3))