Commit 3639bfab authored by rusty1s's avatar rusty1s
Browse files

initial commit

parent ae928282
#include "degree_padding_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 256
#define BLOCKS(N) (N + THREADS - 1) / THREADS
// Assigns each row to a degree bin and builds a per-bin one-hot matrix.
//
// rowcount:     [numel] per-row entry count.
// bin_strategy: [num_bins, 2] inclusive (min, max) degree range per bin.
// bin:          [numel] output; bin id of each row, or -1 if no bin matches.
// one_hot:      [num_bins, numel] output; one_hot[b * numel + i] = 1 iff
//               bin[i] == b (caller must pass it zero-initialized).
//
// Launch: 1D grid, one thread per row.
__global__ void bin_kernel(const int64_t *rowcount, const int64_t *bin_strategy,
                           int64_t *bin, int64_t *one_hot, int64_t num_bins,
                           int64_t numel) {
  // Widen before multiplying: `blockDim.x * blockIdx.x` is computed in 32-bit
  // unsigned arithmetic and can overflow for very large grids.
  int64_t thread_idx = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
  if (thread_idx < numel) {
    auto count = rowcount[thread_idx];
    int64_t b = -1; // -1 marks "no bin covers this degree".
    for (int64_t i = 0; i < num_bins; i++) {
      // Both range endpoints are inclusive; first matching bin wins.
      if (count >= __ldg(bin_strategy + 2 * i) &&
          count <= __ldg(bin_strategy + 2 * i + 1)) {
        b = i;
        break;
      }
    }
    bin[thread_idx] = b;
    if (b >= 0) {
      one_hot[b * numel + thread_idx] = 1;
    }
  }
}
// Scatters each binned row index into its final position inside the packed
// `index` output: (sum of sizes of all earlier bins) + (rank within own bin).
//
// bin:           [numel] bin id per row (-1 = unbinned, skipped).
// cumsum:        [num_bins, numel] inclusive prefix sum of the one-hot
//                matrix along dim 1; cumsum[b * numel + i] is the 1-based
//                rank of row i within bin b.
// nodes_per_bin: [num_bins] total rows per bin.
// index:         output; concatenated per-bin row-index lists.
// num_bins:      unused here; kept so both kernels share a launch signature.
//
// Launch: 1D grid, one thread per row.
__global__ void index_kernel(const int64_t *bin, const int64_t *cumsum,
                             const int64_t *nodes_per_bin, int64_t *index,
                             int64_t num_bins, int64_t numel) {
  // Widen before multiplying to avoid 32-bit overflow on large grids.
  int64_t thread_idx = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
  if (thread_idx < numel) {
    auto b = bin[thread_idx];
    if (b >= 0) {
      // cumsum is 1-based (inclusive scan), hence the -1.
      auto idx = cumsum[b * numel + thread_idx] - 1;
      for (int64_t i = 0; i < b; i++) {
        idx += __ldg(nodes_per_bin + i); // offset past all earlier bins
      }
      index[idx] = thread_idx;
    }
  }
}
// Partitions node indices into groups according to `bin_strategy`.
//
// rowcount:     1-D int64 CUDA tensor of per-node degrees.
// bin_strategy: [num_bins, 2] int64 CUDA tensor of inclusive (min, max)
//               degree ranges.
//
// Returns one index tensor per bin, each listing the nodes whose degree
// falls into that bin's range. Nodes matching no bin are dropped.
std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
                                               torch::Tensor bin_strategy) {
  CHECK_CUDA(rowcount);
  CHECK_CUDA(bin_strategy);
  CHECK_INPUT(rowcount.dim() == 1);
  CHECK_INPUT(bin_strategy.dim() == 2 && bin_strategy.size(1) == 2);
  cudaSetDevice(rowcount.get_device());

  int64_t num_bins = bin_strategy.size(0);

  auto bin = torch::empty({rowcount.numel()}, rowcount.options());
  // one_hot[b][i] == 1 iff node i falls into bin b; zero-initialized so the
  // kernel only has to write the matching entries.
  auto one_hot = torch::zeros({num_bins, rowcount.numel()}, rowcount.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  bin_kernel<<<BLOCKS(rowcount.numel()), THREADS, 0, stream>>>(
      rowcount.data_ptr<int64_t>(), bin_strategy.data_ptr<int64_t>(),
      bin.data_ptr<int64_t>(), one_hot.data_ptr<int64_t>(), num_bins,
      rowcount.numel());

  // Inclusive prefix sum per bin: cumsum[b][i] is the 1-based rank of node i
  // within bin b; the last column therefore holds each bin's total count.
  auto cumsum = one_hot.cumsum(1);
  auto d_nodes_per_bin = cumsum.select(1, rowcount.numel() - 1).contiguous();
  // NOTE(review): relies on .cpu() performing a blocking device-to-host copy
  // on the current stream, which also orders it after bin_kernel — confirm.
  auto h_nodes_per_bin = d_nodes_per_bin.cpu();

  // Total number of binned nodes (unbinned nodes are excluded).
  auto h_size = h_nodes_per_bin.sum().data_ptr<int64_t>()[0];
  auto index = torch::empty({h_size}, rowcount.options());

  // Scatter each node to "offset of its bin + its rank within the bin".
  index_kernel<<<BLOCKS(bin.numel()), THREADS, 0, stream>>>(
      bin.data_ptr<int64_t>(), cumsum.data_ptr<int64_t>(),
      d_nodes_per_bin.data_ptr<int64_t>(), index.data_ptr<int64_t>(), num_bins,
      rowcount.numel());

  // Host-side sizes are needed because split_with_sizes takes CPU integers.
  auto sizes = at::IntArrayRef(h_nodes_per_bin.data_ptr<int64_t>(), num_bins);
  return index.split_with_sizes(sizes);
}
// Gathers the neighbor features of each selected row into a dense, padded
// output, and marks which padded slots were actually filled.
//
// src:    [num_nodes, dim] feature matrix.
// rowptr: CSR row pointers; col: CSR column indices.
// index:  selected row ids; out: [num_indices, length, dim] (zero-filled by
//         the caller, so unused slots stay zero-padded).
// mask:   [num_indices, length] bool, zero-filled by the caller; set to true
//         for every (row, slot) pair that received data.
//
// Launch: 1D grid with TB * dim threads per selected row — `dim` channels
// times TB cooperating lanes that stride over the row's neighbors.
template <typename scalar_t, int64_t TB>
__global__ void
padded_index_select_kernel(const scalar_t *src, const int64_t *rowptr,
                           const int64_t *col, const int64_t *index,
                           scalar_t *out, bool *mask, int64_t length,
                           int64_t dim, int64_t numel) {
  // Widen before multiplying to avoid 32-bit overflow on large grids.
  int64_t thread_idx = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
  if (thread_idx < numel) {
    auto dim_idx = thread_idx % dim;          // feature channel
    auto lane_idx = (thread_idx / dim) % TB;  // cooperating lane within row
    auto index_idx = thread_idx / (TB * dim); // which selected row
    auto row_idx = __ldg(index + index_idx);
    auto row_start = __ldg(rowptr + row_idx);
    // Clamp to `length`: a row with more than `length` neighbors must not
    // write past its padded slice of `out`/`mask`.
    auto row_deg = __ldg(rowptr + row_idx + 1) - row_start;
    if (row_deg > length) {
      row_deg = length;
    }
    // The TB lanes interleave over the row's neighbor slots.
    for (int64_t c = lane_idx; c < row_deg; c += TB) {
      auto x = src[__ldg(col + row_start + c) * dim + dim_idx];
      out[index_idx * dim * length + c * dim + dim_idx] = x;
      // `mask` has no feature dimension ([num_indices, length]); write it
      // once per (row, slot) instead of once per channel. Previously this
      // write was commented out, so the returned mask was always all-false.
      if (dim_idx == 0) {
        mask[index_idx * length + c] = true;
      }
    }
  }
}
#define TB 4
// Gathers neighbor features for each node in `index` into a dense
// [num_indices, length, feature_dim] tensor, zero-padded per row, together
// with a zero-initialized [num_indices, length] bool mask that is passed to
// the kernel.
//
// src:    [num_nodes, feature_dim] CUDA tensor.
// rowptr: CSR row pointer (1-D); col: CSR column indices (1-D).
// index:  1-D tensor of selected row ids.
// length: padded neighbor count per selected row.
std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index,
                         int64_t length) {
  CHECK_CUDA(src);
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(index.dim() == 1);
  cudaSetDevice(src.get_device());

  // Zero-filled so slots the kernel does not write remain zero padding.
  auto out = torch::zeros({index.size(0), length, src.size(-1)}, src.options());
  auto mask =
      torch::zeros({index.size(0), length}, src.options().dtype(torch::kBool));

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
    // One thread per (selected row, lane, channel) triple: TB lanes per row
    // cooperate across the row's neighbors, one channel per thread.
    padded_index_select_kernel<scalar_t, TB>
        <<<BLOCKS(index.numel() * src.size(-1) * TB), THREADS, 0, stream>>>(
            src.data_ptr<scalar_t>(), rowptr.data_ptr<int64_t>(),
            col.data_ptr<int64_t>(), index.data_ptr<int64_t>(),
            out.data_ptr<scalar_t>(), mask.data_ptr<bool>(), length,
            src.size(-1), index.numel() * src.size(-1) * TB);
  });

  return std::make_tuple(out, mask);
}
#pragma once

#include <torch/extension.h>

// Partitions node indices into per-bin index tensors according to the
// inclusive degree ranges in `bin_strategy` ([num_bins, 2]).
std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
                                               torch::Tensor bin_strategy);

// Gathers up to `length` neighbor feature rows (CSR graph given by
// `rowptr`/`col`) for each node in `index` into a zero-padded output;
// also returns a boolean mask tensor of shape [index.size(0), length].
std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index,
                         int64_t length);
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include "cuda/degree_padding_cuda.h"
#endif
#ifdef _WIN32
// Stub module-init symbol: on Windows, Python extension DLLs are expected to
// export PyInit_<name> even though all functionality is registered via the
// TorchScript operator registry below.
PyMODINIT_FUNC PyInit__degree_padding(void) { return NULL; }
#endif
// Device dispatcher for bin assignment: forwards CUDA tensors to
// bin_assignment_cuda; the CPU path is not implemented.
std::vector<torch::Tensor> bin_assignment(torch::Tensor rowcount,
                                          torch::Tensor bin_strategy) {
  if (rowcount.device().is_cuda()) {
#ifdef WITH_CUDA
    return bin_assignment_cuda(rowcount, bin_strategy);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    AT_ERROR("Not implemented yet");
  }
}
// Device dispatcher for padded index select: forwards CUDA tensors to
// padded_index_select_cuda; the CPU path is not implemented.
std::tuple<torch::Tensor, torch::Tensor>
padded_index_select(torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
                    torch::Tensor index, int64_t length) {
  if (src.device().is_cuda()) {
#ifdef WITH_CUDA
    return padded_index_select_cuda(src, rowptr, col, index, length);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    AT_ERROR("Not implemented yet");
  }
}
// Register both operators under the `torch_sparse` namespace so Python can
// call them as torch.ops.torch_sparse.bin_assignment / .padded_index_select.
static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::bin_assignment", &bin_assignment)
        .op("torch_sparse::padded_index_select", &padded_index_select);
import pytest
import torch
from torch_sparse import SparseTensor
from torch_geometric.datasets import Planetoid
# CUDA only: the C++ dispatcher raises "Not implemented yet" for CPU tensors.
devices = [torch.device('cuda')]
@pytest.mark.parametrize('device', devices)
def test_bin_assignment(device):
    # Smoke test + micro-benchmark for the custom `bin_assignment` and
    # `padded_index_select` operators. NOTE(review): this mostly prints
    # timings and makes no assertions — it verifies the ops run, not that
    # their output is correct.

    # Tiny hand-built example: degrees 1-4 go to bin 0, degrees 5-8 to bin 1.
    rowcount = torch.tensor([2, 3, 6, 4, 5, 7, 8, 1], device=device)
    bin_strategy = torch.tensor([[1, 4], [5, 8]], device=device)
    perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    print()
    print(perms)

    # Realistic workload: PubMed citation graph, binned by node degree.
    dataset = Planetoid('/tmp/Planetoid', name='PubMed')
    row, col = dataset[0].edge_index
    adj = SparseTensor(row=row, col=col)
    rowcount = adj.storage.rowcount().to(device)
    # bin_strategy = torch.tensor([[1, 7], [8, 12]], device=device)
    bin_strategy = torch.tensor([[1, 4], [5, 13], [14, 22]], device=device)
    # Padded length used for each bin (upper bound of the bin's degree range).
    bin_count = [4, 13, 22]

    # src = torch.tensor([
    #     [1, 1],
    #     [2, 2],
    #     [3, 3],
    #     [4, 4],
    #     [5, 5],
    #     [6, 6],
    #     [7, 7],
    #     [8, 8],
    # ], dtype=torch.float, device=device)
    # rowptr = torch.tensor([0, 2, 5, 8, 10], device=device)
    # col = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 1], device=device)
    # index = torch.tensor([1, 2, 3], device=device)
    # out, mask = torch.ops.torch_sparse.padded_index_select(
    #     src, rowptr, col, index, 4)
    # print(out)

    # CUDA-event timing: 100 timed iterations, the first 2 act as warm-up
    # (the timer only starts at i == 2).
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Benchmark bin_assignment.
    for i in range(102):
        if i == 2:
            start.record()
        perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    print('-------------')

    # Baseline: plain index_select over the same column indices.
    x = torch.randn(dataset[0].num_nodes, 512).to(device)
    col = col.to(device)
    for i in range(102):
        if i == 2:
            start.record()
        x = x.index_select(0, col)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    # Benchmark padded_index_select per bin.
    x = torch.randn(dataset[0].num_nodes, 512).to(device)
    rowptr = adj.storage.rowptr().to(device)
    col = col.to(device)
    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[0],
                                                   bin_count[0])
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[1],
                                                   bin_count[1])
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[2],
                                                   bin_count[2])
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))
@@ -9,7 +9,7 @@ expected_torch_version = (1, 4)
 try:
     for library in [
         '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis',
-        '_rw', '_saint'
+        '_rw', '_saint', '_degree_padding'
     ]:
         torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
             library, [osp.dirname(__file__)]).origin)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment