Commit b5aa7bc0 authored by rusty1s

version 2

parent efbbce74
#include "padding_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define FULL_MASK 0xffffffff
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
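// bin_kernel: assigns every row to the first degree bin b with
// deg < binptr[b + 1] (falling back to the last bin), reserves the row's slot
// inside that bin via an atomic counter on `size`, and tracks the maximum
// degree per bin in `length`.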
__global__ void bin_kernel(const int64_t *__restrict__ rowcount,
                           const int64_t *__restrict__ binptr,
                           int64_t *__restrict__ bin, int64_t *__restrict__ idx,
                           int *__restrict__ size, int *__restrict__ length,
                           const size_t B, const size_t N) {
  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < N; thread_idx += gridDim.x * blockDim.x) {
    int bin_idx = -1, deg = rowcount[thread_idx];
    for (ptrdiff_t b = 1; b <= B; b++) {
      if (deg < __ldg(binptr + b)) {
        bin_idx = b - 1;
        break;
      }
    }
    if (bin_idx == -1)
      bin_idx = B - 1;
    int old = atomicAdd(size + bin_idx, 1);
    atomicMax(length + bin_idx, deg);
    bin[thread_idx] = bin_idx;
    idx[thread_idx] = old;
  }
}
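// offset_kernel: one warp per bin computes that bin's start offset in the
// padded output as the sum of size[i] * length[i] over all preceding bins,
// reduced with warp shuffles.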
__global__ void offset_kernel(const int *__restrict__ size,
                              const int *__restrict__ length,
                              int *__restrict__ offset, const size_t B) {
  // Use the global thread index so that bins beyond the first block are
  // covered as well (the launch spans BLOCKS(32 * (B + 1)) blocks).
  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  int bin_idx = thread_idx / 32;
  int lane_idx = thread_idx % 32;
  if (bin_idx <= B) {
    int tmp = 0;
    for (int i = lane_idx; i < bin_idx; i += 32) {
      tmp += size[i] * length[i];
    }
    for (int i = 32 / 2; i > 0; i /= 2) {
      tmp += __shfl_down_sync(FULL_MASK, tmp, i);
    }
    if (lane_idx == 0)
      offset[bin_idx] = tmp;
  }
}
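// padded_index_kernel: TB threads cooperate on each row and write the CSR
// positions rowptr[row] .. rowptr[row] + deg - 1 into the row's padded slot of
// width length[bin], filling the remainder with -1 and setting `mask` to true
// exactly at the padded positions.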
template <int TB>
__global__ void padded_index_kernel(
    const int64_t *__restrict__ rowptr, const int64_t *__restrict__ rowcount,
    const int64_t *__restrict__ bin, const int64_t *__restrict__ idx,
    int64_t *__restrict__ out, bool *__restrict__ mask,
    const int *__restrict__ length, const int *__restrict__ offset,
    const size_t B, const size_t N) {
  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < TB * N; thread_idx += gridDim.x * blockDim.x) {
    int row_idx = thread_idx / TB;
    int lane_idx = thread_idx % TB;
    int64_t bin_idx = bin[row_idx];
    int len = __ldg(length + bin_idx);
    int off = __ldg(offset + bin_idx) + len * idx[row_idx];
    int64_t row_start = rowptr[row_idx], deg = rowcount[row_idx];
    int64_t tmp;
    for (int i = lane_idx; i < len; i += TB) {
      tmp = -1;
      if (i < deg)
        tmp = row_start + i;
      out[off + i] = tmp;
      mask[off + i] = tmp == -1;
    }
  }
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           torch::Tensor>
padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount,
                  torch::Tensor binptr) {
  cudaSetDevice(rowcount.get_device());
  auto stream = at::cuda::getCurrentCUDAStream();
  size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  size_t B = binptr.numel() - 1;
  size_t N = rowcount.numel();
  auto bin = torch::empty(N, rowptr.options());
  auto idx = torch::empty(N, rowptr.options());
  // Pack the per-bin counters as [size (B) | length (B) | offset (B + 1)] so
  // that a single device-to-host copy retrieves all of them at once.
  auto tmp = torch::zeros(B + B + B + 1, rowcount.options().dtype(torch::kInt));
  auto size = tmp.narrow(0, 0, B);
  auto length = tmp.narrow(0, B, B);
  auto offset = tmp.narrow(0, 2 * B, B + 1);
  bin_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
      rowcount.data_ptr<int64_t>(), binptr.data_ptr<int64_t>(),
      bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(), size.data_ptr<int>(),
      length.data_ptr<int>(), B, N);
  offset_kernel<<<BLOCKS(32 * (B + 1)), THREADS, 0, stream>>>(
      size.data_ptr<int>(), length.data_ptr<int>(), offset.data_ptr<int>(), B);
  // Blocking copy of the per-bin metadata; the total padded length ends up in
  // h_tmp[3 * B] (== offset[B]).
  auto h_tmp = torch::empty(
      {tmp.numel()}, tmp.options().device(torch::kCPU).pinned_memory(true));
  cudaMemcpy(h_tmp.data_ptr<int>(), tmp.data_ptr<int>(),
             tmp.numel() * sizeof(int), cudaMemcpyDeviceToHost);
  auto out = torch::empty({h_tmp.data_ptr<int>()[3 * B]}, rowptr.options());
  auto mask = torch::empty({out.numel()}, rowptr.options().dtype(torch::kBool));
  padded_index_kernel<8>
      <<<std::min(BLOCKS(N * 8), mpc * 8), THREADS, 0, stream>>>(
          rowptr.data_ptr<int64_t>(), rowcount.data_ptr<int64_t>(),
          bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
          out.data_ptr<int64_t>(), mask.data_ptr<bool>(),
          length.data_ptr<int>(), offset.data_ptr<int>(), B, N);
  return std::make_tuple(out, mask, h_tmp.narrow(0, 0, B),
                         h_tmp.narrow(0, B, B), h_tmp.narrow(0, 2 * B, B + 1));
}
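// padded_index_select_kernel: gathers one length-F feature vector per padded
// slot from `src`, addressed by `col[index]`; slots whose index is -1
// (padding) receive `fill_value`.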
template <typename scalar_t>
__global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
                                           const int64_t *__restrict__ col,
                                           const int64_t *__restrict__ index,
                                           scalar_t *__restrict__ out,
                                           const scalar_t fill_value,
                                           const size_t F, const size_t E) {
  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < E * F; thread_idx += gridDim.x * blockDim.x) {
    int64_t row_idx = thread_idx / F;
    int64_t lane_idx = thread_idx % F;
    int64_t index_idx = __ldg(index + row_idx);
    scalar_t tmp = fill_value;
    if (index_idx != -1) {
      // Gather feature `lane_idx` of the source row addressed by
      // `col[index_idx]`; padded slots keep `fill_value`.
      tmp = src[F * __ldg(col + index_idx) + lane_idx];
    }
    out[thread_idx] = tmp;
  }
}
torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col,
                                       torch::Tensor index,
                                       torch::Tensor fill_value) {
  cudaSetDevice(src.get_device());
  auto stream = at::cuda::getCurrentCUDAStream();
  size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  size_t E = index.numel();
  size_t F = src.size(-1);
  auto out = torch::empty(E * F, src.options());
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
    // Read the scalar fill value on the host (works for both CPU and
    // CUDA-resident `fill_value` tensors).
    scalar_t fill = fill_value.item<scalar_t>();
    padded_index_select_kernel<scalar_t>
        <<<std::min(BLOCKS(E * F), mpc * 8), THREADS, 0, stream>>>(
            src.data_ptr<scalar_t>(), col.data_ptr<int64_t>(),
            index.data_ptr<int64_t>(), out.data_ptr<scalar_t>(), fill, F, E);
  });
  return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           torch::Tensor>
padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount,
                  torch::Tensor binptr);
torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col,
                                       torch::Tensor index,
                                       torch::Tensor fill_value);
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include "cuda/padding_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__padding(void) { return NULL; }
#endif
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           torch::Tensor>
padded_index(torch::Tensor rowptr, torch::Tensor rowcount,
             torch::Tensor binptr) {
  return padded_index_cuda(rowptr, rowcount, binptr);
}
torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor col,
                                  torch::Tensor index,
                                  torch::Tensor fill_value) {
  // Forward to the CUDA implementation.
  return padded_index_select_cuda(src, col, index, fill_value);
}
static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::padded_index", &padded_index)
        .op("torch_sparse::padded_index_select", &padded_index_select);
@@ -25,27 +25,75 @@ def test_padded_index_select(device):
     deg = degree(row, dtype=torch.long)
     bins = torch.bincount(deg)
-    print(bins.size())
-    print(bins[:200])
+    # print(bins.size())
+    # print(bins[:200])
+    # for i in range(110):
+    #     if i == 10:
+    #         start.record()
+    #     perms, lengths = torch.ops.torch_sparse.bin_assignment(
+    #         rowcount, binptr)
+    # end.record()
+    # torch.cuda.synchronize()
+    # print('bin assignment', start.elapsed_time(end))
+    idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
+        rowptr, rowcount, binptr)
+    print(size)
+    print(length)
+    print(offset)
+    print(mask[:10])
+    print(idx[:10])
+    x = torch.randn(data.num_nodes, 128).to(device)
     for i in range(110):
         if i == 10:
             start.record()
-        perms, lengths = torch.ops.torch_sparse.bin_assignment(
-            rowcount, binptr)
+        torch.ops.torch_sparse.padded_index(rowptr, rowcount, binptr)
     end.record()
     torch.cuda.synchronize()
-    print(start.elapsed_time(end))
-    return
+    print('padded index', start.elapsed_time(end))
     for i in range(110):
         if i == 10:
             start.record()
-        rowcount.sort()
+        torch.ops.torch_sparse.padded_index_select(x, col, idx,
+                                                   torch.tensor(0.))
     end.record()
     torch.cuda.synchronize()
-    print(start.elapsed_time(end))
-    x = torch.randn(data.num_nodes, 128).to(device)
+    print('padded index select', start.elapsed_time(end))
+    for i in range(110):
+        if i == 10:
+            start.record()
+        torch.repeat_interleave(rowcount, rowcount)
+    end.record()
+    torch.cuda.synchronize()
+    print('repeat', start.elapsed_time(end))
+    for i in range(110):
+        if i == 10:
+            start.record()
+        rowcount.cumsum(0)
+    end.record()
+    torch.cuda.synchronize()
+    print('cumsum', start.elapsed_time(end))
+    rowcount2 = rowcount.unsqueeze(1).repeat(1, 5).contiguous()
+    for i in range(110):
+        if i == 10:
+            start.record()
+        rowcount2.cumsum(0)
+    end.record()
+    torch.cuda.synchronize()
+    print('cumsum', start.elapsed_time(end))
+    for i in range(110):
+        if i == 10:
+            start.record()
+        rowcount.sort()
+    end.record()
+    torch.cuda.synchronize()
+    print('sort', start.elapsed_time(end))
     for i in range(110):
         if i == 10:
@@ -53,7 +101,8 @@ def test_padded_index_select(device):
         x.index_select(0, col)
     end.record()
     torch.cuda.synchronize()
-    print(start.elapsed_time(end))
+    print('index_select', start.elapsed_time(end))
+    return
     for i in range(110):
         if i == 10:
@@ -64,15 +113,13 @@ def test_padded_index_select(device):
             torch.tensor(0.))
     end.record()
     torch.cuda.synchronize()
-    print(start.elapsed_time(end))
+    print('padded_index_select', start.elapsed_time(end))
     for perm, length in zip(perms, lengths):
         out, mask = torch.ops.torch_sparse.padded_index_select(
             x, rowptr, col, perm, length, torch.tensor(0.))
         print(out.size(), mask.size(), out.numel(), (out != 0).sum().item())
-    return
     lengths = bin_strategy[:, 1].view(-1).tolist()
     for dim in [32, 64, 128, 256, 512, 1024]:
@@ -9,7 +9,7 @@ expected_torch_version = (1, 4)
 try:
     for library in [
             '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis',
-            '_rw', '_saint', '_degree_padding'
+            '_rw', '_saint', '_padding'
     ]:
         torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
             library, [osp.dirname(__file__)]).origin)