Commit efbbce74 authored by rusty1s

working example

parent 2bea1c3c
@@ -7,84 +7,72 @@
 #define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS

-__global__ void bin_kernel(const int64_t *rowcount, const int64_t *bin_strategy,
-                           int64_t *bin, int64_t *one_hot, int64_t num_bins,
-                           int64_t numel) {
-  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  if (thread_idx < numel) {
-    auto count = rowcount[thread_idx];
-    int64_t b = -1;
-    for (int64_t i = 0; i < num_bins; i++) {
-      if (count >= __ldg(bin_strategy + 2 * i) &&
-          count <= __ldg(bin_strategy + 2 * i + 1)) {
-        b = i;
-        break;
-      }
-    }
-    bin[thread_idx] = b;
-    if (b >= 0) {
-      one_hot[b * numel + thread_idx] = 1;
-    }
-  }
-}
-
-__global__ void index_kernel(const int64_t *bin, const int64_t *cumsum,
-                             const int64_t *nodes_per_bin, int64_t *index,
-                             int64_t num_bins, int64_t numel) {
-  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  if (thread_idx < numel) {
-    auto b = bin[thread_idx];
-    if (b >= 0) {
-      auto idx = cumsum[b * numel + thread_idx] - 1;
-      for (int64_t i = 0; i < b; i++) {
-        idx += __ldg(nodes_per_bin + i);
-      }
-      index[idx] = thread_idx;
-    }
-  }
-}
+__global__ void sizes_kernel(const int64_t *__restrict__ sorted_rowcount,
+                             const int64_t *__restrict__ binptr,
+                             int64_t *__restrict__ size,
+                             int64_t *__restrict__ length,
+                             const int64_t num_bins, const int64_t numel) {
+  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+       thread_idx < numel - 1; thread_idx += gridDim.x * blockDim.x) {
+
+    int64_t deg1 = sorted_rowcount[thread_idx];
+    int64_t deg2 = sorted_rowcount[thread_idx + 1];
+
+    if (deg1 != deg2) {
+      for (int64_t b = 1; b <= num_bins; b++) {
+        if (deg1 < __ldg(binptr + b) && deg2 >= __ldg(binptr + b)) {
+          size[b] = thread_idx + 1;
+          length[b - 1] = deg1;
+        }
+      }
+    }
+
+    if (thread_idx + 1 == numel - 1) {
+      size[num_bins] = numel;
+      length[num_bins - 1] = deg2;
+    }
+  }
+}

-std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
-                                               torch::Tensor bin_strategy) {
+std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
+bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr) {
   CHECK_CUDA(rowcount);
-  CHECK_CUDA(bin_strategy);
+  CHECK_CUDA(binptr);
   CHECK_INPUT(rowcount.dim() == 1);
-  CHECK_INPUT(bin_strategy.dim() == 2 && bin_strategy.size(1) == 2);
+  CHECK_INPUT(binptr.dim() == 1);
   cudaSetDevice(rowcount.get_device());

+  auto stream = at::cuda::getCurrentCUDAStream();
+  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+
-  int64_t num_bins = bin_strategy.size(0);
-
-  auto bin = torch::empty({rowcount.numel()}, rowcount.options());
-  auto one_hot = torch::zeros({num_bins, rowcount.numel()}, rowcount.options());
+  torch::Tensor sorted_rowcount, perm;
+  std::tie(sorted_rowcount, perm) = rowcount.sort();

-  auto stream = at::cuda::getCurrentCUDAStream();
-  bin_kernel<<<BLOCKS(rowcount.numel()), THREADS, 0, stream>>>(
-      rowcount.data_ptr<int64_t>(), bin_strategy.data_ptr<int64_t>(),
-      bin.data_ptr<int64_t>(), one_hot.data_ptr<int64_t>(), num_bins,
-      rowcount.numel());
+  auto size = torch::zeros({binptr.numel()}, binptr.options());
+  auto length = torch::zeros({binptr.numel() - 1}, binptr.options());

-  auto cumsum = one_hot.cumsum(1);
-  auto d_nodes_per_bin = cumsum.select(1, rowcount.numel() - 1).contiguous();
-  auto h_nodes_per_bin = d_nodes_per_bin.cpu();
+  sizes_kernel<<<std::min(BLOCKS(rowcount.numel() - 1), mpc * 8), THREADS, 0,
+                 stream>>>(sorted_rowcount.data_ptr<int64_t>(),
+                           binptr.data_ptr<int64_t>(), size.data_ptr<int64_t>(),
+                           length.data_ptr<int64_t>(), length.numel(),
+                           rowcount.numel());

-  auto h_size = h_nodes_per_bin.sum().data_ptr<int64_t>()[0];
-  auto index = torch::empty({h_size}, rowcount.options());
+  size = size.cpu();
+  size = size.narrow(0, 1, length.numel()) - size.narrow(0, 0, length.numel());
+  auto sizes = at::IntArrayRef(size.data_ptr<int64_t>(), size.numel());

-  index_kernel<<<BLOCKS(bin.numel()), THREADS, 0, stream>>>(
-      bin.data_ptr<int64_t>(), cumsum.data_ptr<int64_t>(),
-      d_nodes_per_bin.data_ptr<int64_t>(), index.data_ptr<int64_t>(), num_bins,
-      rowcount.numel());
+  length = length.cpu();
+  int64_t *length_data = length.data_ptr<int64_t>();
+  std::vector<int64_t> lengths(length.numel());
+  std::copy(length_data, length_data + length.numel(), lengths.begin());

-  auto sizes = at::IntArrayRef(h_nodes_per_bin.data_ptr<int64_t>(), num_bins);
-  return index.split_with_sizes(sizes);
+  return std::make_tuple(perm.split_with_sizes(sizes), lengths);
 }

-__global__ void padded_mask_select_kernel(const int64_t *rowptr,
-                                          const int64_t *col,
-                                          const int64_t *index,
-                                          int64_t *out_idx, bool *mask,
-                                          int64_t length, int64_t numel) {
+__global__ void padded_mask_select_kernel(
+    const int64_t *__restrict__ rowptr, const int64_t *__restrict__ col,
+    const int64_t *__restrict__ index, int64_t *__restrict__ out_idx,
+    bool *__restrict__ mask, const int64_t length, const int64_t numel) {
   int64_t lane_idx, row_idx, row_start, row_end, col_idx;

   for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -104,10 +92,11 @@ __global__ void padded_mask_select_kernel(const int64_t *rowptr,
 }

 template <typename scalar_t>
-__global__ void padded_index_select_kernel(const scalar_t *src,
-                                           const int64_t *index, scalar_t *out,
-                                           scalar_t fill_value, int64_t dim,
-                                           int64_t numel) {
+__global__ void
+padded_index_select_kernel(const scalar_t *__restrict__ src,
+                           const int64_t *__restrict__ index,
+                           scalar_t *__restrict__ out, scalar_t fill_value,
+                           const int64_t dim, const int64_t numel) {
   int64_t index_idx, dim_idx, col;

   for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -136,22 +125,22 @@ padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
   CHECK_INPUT(col.dim() == 1);
   CHECK_INPUT(index.dim() == 1);
   CHECK_INPUT(fill_value.numel() == 1);
   cudaSetDevice(src.get_device());

+  auto stream = at::cuda::getCurrentCUDAStream();
+  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+
   auto out_idx = torch::empty({index.size(0), length}, index.options());
-  auto out = torch::empty({index.size(0), length, src.size(-1)}, src.options());
   auto mask = torch::empty({index.size(0), length, 1},
                            src.options().dtype(torch::kBool));

-  auto stream = at::cuda::getCurrentCUDAStream();
-  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
   padded_mask_select_kernel<<<
       std::min((out_idx.numel() + THREADS - 1) / THREADS, mpc * 8), THREADS, 0,
       stream>>>(rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
                 index.data_ptr<int64_t>(), out_idx.data_ptr<int64_t>(),
                 mask.data_ptr<bool>(), length, out_idx.numel());

+  auto out = torch::empty({index.size(0), length, src.size(-1)}, src.options());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
     scalar_t *fill;
     if (fill_value.is_cuda()) {
......
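For orientation, here is a rough CPU-side sketch (not part of the commit) of what the new `binptr`-based binning is meant to compute: nodes are sorted by degree, the sorted order is split at the `binptr` boundaries, and each bin is later padded to the largest degree it contains. Function and variable names below are illustrative, and edge cases are handled loosely rather than mirroring `sizes_kernel` exactly:

```python
import torch


def bin_assignment_reference(rowcount, binptr):
    # Sort nodes by degree; `perm` maps sorted positions back to node ids.
    sorted_rowcount, perm = rowcount.sort()
    # size[b] = number of nodes whose degree lies below boundary binptr[b].
    size = (sorted_rowcount.view(-1, 1) < binptr.view(1, -1)).long().sum(dim=0)
    size[0], size[-1] = 0, rowcount.numel()  # first/last bin catch everything
    sizes = (size[1:] - size[:-1]).tolist()
    # Each bin is padded to the largest degree it contains (0 if empty).
    lengths = [int(c.max()) if c.numel() > 0 else 0
               for c in sorted_rowcount.split(sizes)]
    return list(perm.split(sizes)), lengths
```

For example, `rowcount = [1, 1, 2, 5, 9]` with `binptr = [0, 4, 11]` yields two bins of sizes 3 and 2 with padded lengths 2 and 9.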
@@ -2,9 +2,15 @@
 #include <torch/extension.h>

-std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
-                                               torch::Tensor bin_strategy);
+std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
+bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr);

 std::tuple<torch::Tensor, torch::Tensor>
 padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                          torch::Tensor col, torch::Tensor index, int64_t length,
                          torch::Tensor fill_value);
+
+// std::tuple<torch::Tensor, torch::Tensor> padded_index_select_cuda2(
+//     torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
+//     torch::Tensor bin, torch::Tensor index, std::vector<int64_t> node_counts,
+//     std::vector<int64_t> lengths, torch::Tensor fill_value);
@@ -9,11 +9,11 @@
 PyMODINIT_FUNC PyInit__degree_padding(void) { return NULL; }
 #endif

-std::vector<torch::Tensor> bin_assignment(torch::Tensor rowcount,
-                                          torch::Tensor bin_strategy) {
+std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
+bin_assignment(torch::Tensor rowcount, torch::Tensor binptr) {
   if (rowcount.device().is_cuda()) {
 #ifdef WITH_CUDA
-    return bin_assignment_cuda(rowcount, bin_strategy);
+    return bin_assignment_cuda(rowcount, binptr);
 #else
     AT_ERROR("Not compiled with CUDA support");
 #endif
@@ -38,7 +38,28 @@ padded_index_select(torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
   }
 }

+// std::tuple<torch::Tensor, torch::Tensor>
+// padded_index_select2(torch::Tensor src, torch::Tensor rowptr,
+//                      torch::Tensor col, torch::Tensor bin,
+//                      torch::Tensor index, std::vector<int64_t> node_counts,
+//                      std::vector<int64_t> lengths, torch::Tensor fill_value) {
+//   if (src.device().is_cuda()) {
+// #ifdef WITH_CUDA
+//     return padded_index_select_cuda2(src, rowptr, col, bin, index,
+//                                      node_counts, lengths, fill_value);
+// #else
+//     AT_ERROR("Not compiled with CUDA support");
+// #endif
+//   } else {
+//     AT_ERROR("Not implemented yet");
+//   }
+// }
+
 static auto registry =
     torch::RegisterOperators()
         .op("torch_sparse::bin_assignment", &bin_assignment)
         .op("torch_sparse::padded_index_select", &padded_index_select);
+// .op("torch_sparse::padded_index_select2", &padded_index_select2);
@@ -9,34 +9,71 @@ devices = [torch.device('cuda')]
@pytest.mark.parametrize('device', devices)
def test_padded_index_select(device):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    dataset = Planetoid('/tmp/Planetoid', name='PubMed')
    data = dataset[0]
    row, col = data.edge_index.to(device)
    row = torch.arange(data.num_nodes).view(-1, 1).repeat(1, 4).view(-1)
    col = torch.randint(0, data.num_nodes, (row.size(0), ))
    row, col = row.to(device), col.to(device)

    adj = SparseTensor(row=row, col=col)
    rowcount = adj.storage.rowcount().to(device)
    rowptr = adj.storage.rowptr().to(device)

    bin_strategy = torch.tensor([[1, 4]]).to(device)
    # bin_strategy = torch.tensor([[1, 5], [6, 12], [13, 19], [20, 30]],
    #                             device=device)
    perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    lengths = bin_strategy[:, 1].view(-1).tolist()
    print(lengths)
    bin_strategy = torch.tensor([[1, 4], [4, 11], [11, 30]]).to(device)
    binptr = torch.tensor([0, 4, 11, 30, 50, 80, 120, 140, 2000]).to(device)

    deg = degree(row, dtype=torch.long)
    print(deg.size(), deg.min(), deg.float().mean(), deg.max())
    bins = torch.bincount(deg)
    print(bins)
    nonzero = bins.nonzero().flatten()
    print(nonzero)
    print(bins[nonzero])
    print(bins.size())
    print(bins[:200])

    for i in range(110):
        if i == 10:
            start.record()
        perms, lengths = torch.ops.torch_sparse.bin_assignment(
            rowcount, binptr)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))
    return

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for i in range(110):
        if i == 10:
            start.record()
        rowcount.sort()
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    x = torch.randn(data.num_nodes, 128).to(device)

    for i in range(110):
        if i == 10:
            start.record()
        x.index_select(0, col)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    for i in range(110):
        if i == 10:
            start.record()
        for perm, length in zip(perms, lengths):
            torch.ops.torch_sparse.padded_index_select(x, rowptr, col,
                                                       perm, length,
                                                       torch.tensor(0.))
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    for perm, length in zip(perms, lengths):
        out, mask = torch.ops.torch_sparse.padded_index_select(
            x, rowptr, col, perm, length, torch.tensor(0.))
        print(out.size(), mask.size(), out.numel(), (out != 0).sum().item())
    return

    lengths = bin_strategy[:, 1].view(-1).tolist()

    for dim in [32, 64, 128, 256, 512, 1024]:
        print(f'--- Dim: {dim} ---')
@@ -45,6 +82,10 @@ def test_padded_index_select(device):
        for i in range(110):
            if i == 10:
                start.record()
            perms = torch.ops.torch_sparse.bin_assignment(
                rowcount, bin_strategy)
            print(perms)
            return
            for perm, length in zip(perms, lengths):
                out1, _ = torch.ops.torch_sparse.padded_index_select(
                    x, rowptr, col, perm, length, torch.tensor(0.))
@@ -67,4 +108,3 @@ def test_padded_index_select(device):
        end.record()
        torch.cuda.synchronize()
        print(start.elapsed_time(end))
        print(torch.allclose(out1.view(-1, dim), out3))