Commit 631df924 authored by rusty1s

clean up

parent 56de8a6b
#include "degree_padding_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

__global__ void sizes_kernel(const int64_t *__restrict__ sorted_rowcount,
                             const int64_t *__restrict__ binptr,
                             int64_t *__restrict__ size,
                             int64_t *__restrict__ length,
                             const int64_t num_bins, const int64_t numel) {

  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < numel - 1; thread_idx += gridDim.x * blockDim.x) {

    int64_t deg1 = sorted_rowcount[thread_idx];
    int64_t deg2 = sorted_rowcount[thread_idx + 1];

    if (deg1 != deg2) {
      for (int64_t b = 1; b <= num_bins; b++) {
        if (deg1 < __ldg(binptr + b) && deg2 >= __ldg(binptr + b)) {
          size[b] = thread_idx + 1;
          length[b - 1] = deg1;
        }
      }
    }

    if (thread_idx + 1 == numel - 1) {
      size[num_bins] = numel;
      length[num_bins - 1] = deg2;
    }
  }
}
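
// Hand-worked sketch of what `sizes_kernel` computes (added for illustration;
// the values below are made up): with degrees sorted ascendingly and bin
// boundaries `binptr`, every position where the degree crosses a boundary
// binptr[b] writes the prefix count into size[b] and the largest degree of
// bin b - 1 into length[b - 1]; the last thread closes the final bin.
// Assuming
//   sorted_rowcount = [1, 2, 2, 3, 5, 6] and binptr = [0, 3, 8],
// the kernel writes
//   size   = [0, 3, 6]  // prefix counts; adjacent differences = bin sizes
//   length = [2, 6]     // per-bin maximum degree, i.e. the padding length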

std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr) {
  CHECK_CUDA(rowcount);
  CHECK_CUDA(binptr);
  CHECK_INPUT(rowcount.dim() == 1);
  CHECK_INPUT(binptr.dim() == 1);

  cudaSetDevice(rowcount.get_device());
  auto stream = at::cuda::getCurrentCUDAStream();
  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  torch::Tensor sorted_rowcount, perm;
  std::tie(sorted_rowcount, perm) = rowcount.sort();

  auto size = torch::zeros({binptr.numel()}, binptr.options());
  auto length = torch::zeros({binptr.numel() - 1}, binptr.options());

  sizes_kernel<<<std::min(BLOCKS(rowcount.numel() - 1), mpc * 8), THREADS, 0,
                 stream>>>(sorted_rowcount.data_ptr<int64_t>(),
                           binptr.data_ptr<int64_t>(), size.data_ptr<int64_t>(),
                           length.data_ptr<int64_t>(), length.numel(),
                           rowcount.numel());

  size = size.cpu();
  size = size.narrow(0, 1, length.numel()) - size.narrow(0, 0, length.numel());
  auto sizes = at::IntArrayRef(size.data_ptr<int64_t>(), size.numel());

  length = length.cpu();
  int64_t *length_data = length.data_ptr<int64_t>();
  std::vector<int64_t> lengths(length.numel());
  std::copy(length_data, length_data + length.numel(), lengths.begin());

  return std::make_tuple(perm.split_with_sizes(sizes), lengths);
}
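
// A hedged usage sketch of `bin_assignment_cuda` (illustrative only, not part
// of the original file; the tensors below are made up): nodes are partitioned
// into one permutation per degree bin, together with each bin's padding
// length.
//
//   auto opts = torch::dtype(torch::kLong).device(torch::kCUDA);
//   auto rowcount = torch::tensor({2, 3, 6, 4, 5, 7, 8, 1}, opts);
//   auto binptr = torch::tensor({0, 4, 8}, opts);
//   std::vector<torch::Tensor> perms;
//   std::vector<int64_t> lengths;
//   std::tie(perms, lengths) = bin_assignment_cuda(rowcount, binptr);
//   // perms[b]:   indices of the nodes falling into degree bin b
//   // lengths[b]: number of slots each node of bin b is padded to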

__global__ void padded_mask_select_kernel(
    const int64_t *__restrict__ rowptr, const int64_t *__restrict__ col,
    const int64_t *__restrict__ index, int64_t *__restrict__ out_idx,
    bool *__restrict__ mask, const int64_t length, const int64_t numel) {

  int64_t lane_idx, row_idx, row_start, row_end, col_idx;
  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < numel; thread_idx += gridDim.x * blockDim.x) {

    lane_idx = thread_idx % length;
    row_idx = index[thread_idx / length];
    row_start = rowptr[row_idx];
    row_end = rowptr[row_idx + 1];

    col_idx = -1;
    if (lane_idx < row_end - row_start) {
      col_idx = col[row_start + lane_idx];
    }

    out_idx[thread_idx] = col_idx;
    mask[thread_idx] = col_idx == -1;
  }
}
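
// Hand-worked sketch of `padded_mask_select_kernel` (added for illustration;
// the CSR data below is made up): every row listed in `index` contributes
// `length` slots, filled with its column indices and padded with -1. Assuming
//   rowptr = [0, 2, 5, 8], col = [0, 1, 2, 3, 4, 5, 6, 7],
//   index = [1, 2], length = 4,
// the kernel writes
//   out_idx = [[2, 3, 4, -1], [5, 6, 7, -1]]
//   mask    = [[0, 0, 0,  1], [0, 0, 0,  1]]  // true exactly at padding slots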

template <typename scalar_t>
__global__ void
padded_index_select_kernel(const scalar_t *__restrict__ src,
                           const int64_t *__restrict__ index,
                           scalar_t *__restrict__ out,
                           const scalar_t fill_value, const int64_t dim,
                           const int64_t numel) {

  int64_t index_idx, dim_idx, col;
  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < numel; thread_idx += gridDim.x * blockDim.x) {

    index_idx = thread_idx / dim;
    dim_idx = thread_idx % dim;

    col = __ldg(index + index_idx);
    // Keep the fill value in a local: a grid-stride thread visits multiple
    // elements, and overwriting the parameter (as before) would leak a
    // previously gathered value into later padding slots.
    scalar_t val = fill_value;
    if (col >= 0) {
      val = src[col * dim + dim_idx];
    }
    out[thread_idx] = val;
  }
}
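
// In effect this computes, for each output element at flat offset
// i * dim + d:
//   out[i][d] = index[i] >= 0 ? src[index[i]][d] : fill_value;
// i.e. a row-wise gather in which padding slots (index -1) receive the fill
// value.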

std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index,
                         int64_t length, torch::Tensor fill_value) {
  CHECK_CUDA(src);
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(index.dim() == 1);
  CHECK_INPUT(fill_value.numel() == 1);

  cudaSetDevice(src.get_device());
  auto stream = at::cuda::getCurrentCUDAStream();
  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  auto out_idx = torch::empty({index.size(0), length}, index.options());
  auto out = torch::empty({index.size(0), length, src.size(-1)}, src.options());
  auto mask = torch::empty({index.size(0), length, 1},
                           src.options().dtype(torch::kBool));

  padded_mask_select_kernel<<<
      std::min((out_idx.numel() + THREADS - 1) / THREADS, mpc * 8), THREADS, 0,
      stream>>>(rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
                index.data_ptr<int64_t>(), out_idx.data_ptr<int64_t>(),
                mask.data_ptr<bool>(), length, out_idx.numel());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
    // Fetch the scalar fill value on the host (copying it off the GPU if
    // needed) instead of `malloc`ing a buffer that was never freed.
    scalar_t fill;
    if (fill_value.is_cuda()) {
      cudaMemcpy(&fill, fill_value.data_ptr<scalar_t>(), sizeof(scalar_t),
                 cudaMemcpyDeviceToHost);
    } else {
      fill = fill_value.data_ptr<scalar_t>()[0];
    }

    padded_index_select_kernel<scalar_t>
        <<<std::min((out.numel() + THREADS - 1) / THREADS, mpc * 8), THREADS, 0,
           stream>>>(src.data_ptr<scalar_t>(), out_idx.data_ptr<int64_t>(),
                     out.data_ptr<scalar_t>(), fill, src.size(-1),
                     out.numel());
  });

  return std::make_tuple(out, mask);
}
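
// Hedged usage sketch (illustrative, not part of the original file; assumes
// `rowptr`/`col` form a CSR adjacency, `x` holds one feature row per node,
// and `perm`/`length` come from bin_assignment_cuda):
//
//   torch::Tensor out, mask;
//   std::tie(out, mask) = padded_index_select_cuda(
//       x, rowptr, col, perm, length, torch::tensor(0.0f));
//   // out:  [perm.size(0), length, x.size(-1)], padded with the fill value
//   // mask: [perm.size(0), length, 1], true at the padded positions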

#pragma once

#include <torch/extension.h>

std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr);

std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index,
                         int64_t length, torch::Tensor fill_value);
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include "cuda/degree_padding_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__degree_padding(void) { return NULL; }
#endif

std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
bin_assignment(torch::Tensor rowcount, torch::Tensor binptr) {
  if (rowcount.device().is_cuda()) {
#ifdef WITH_CUDA
    return bin_assignment_cuda(rowcount, binptr);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    AT_ERROR("Not implemented yet");
  }
}

std::tuple<torch::Tensor, torch::Tensor>
padded_index_select(torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
                    torch::Tensor index, int64_t length,
                    torch::Tensor fill_value) {
  if (src.device().is_cuda()) {
#ifdef WITH_CUDA
    return padded_index_select_cuda(src, rowptr, col, index, length,
                                    fill_value);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    AT_ERROR("Not implemented yet");
  }
}

// std::tuple<torch::Tensor, torch::Tensor>
// padded_index_select2(torch::Tensor src, torch::Tensor rowptr,
//                      torch::Tensor col, torch::Tensor bin,
//                      torch::Tensor index, std::vector<int64_t> node_counts,
//                      std::vector<int64_t> lengths,
//                      torch::Tensor fill_value) {
//   if (src.device().is_cuda()) {
// #ifdef WITH_CUDA
//     return padded_index_select_cuda2(src, rowptr, col, bin, index,
//                                      node_counts, lengths, fill_value);
// #else
//     AT_ERROR("Not compiled with CUDA support");
// #endif
//   } else {
//     AT_ERROR("Not implemented yet");
//   }
// }

static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::bin_assignment", &bin_assignment)
        .op("torch_sparse::padded_index_select", &padded_index_select);
// .op("torch_sparse::padded_index_select2", &padded_index_select2);
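
// Once this extension is loaded from Python (e.g. via `import torch_sparse`),
// the operators registered above become callable as
// torch.ops.torch_sparse.bin_assignment and
// torch.ops.torch_sparse.padded_index_select, which is how the tests below
// invoke them.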

@@ -80,7 +80,7 @@ tests_require = ['pytest', 'pytest-cov']

 setup(
     name='torch_sparse',
-    version='0.6.1',
+    version='0.6.2',
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
     url='https://github.com/rusty1s/pytorch_sparse',

import pytest
import torch
from torch_sparse import SparseTensor
from torch_geometric.datasets import Planetoid

devices = [torch.device('cuda')]


@pytest.mark.parametrize('device', devices)
def test_bin_assignment(device):
    rowcount = torch.tensor([2, 3, 6, 4, 5, 7, 8, 1], device=device)
    bin_strategy = torch.tensor([[1, 4], [5, 8]], device=device)
    # The op returns both the per-bin permutations and the per-bin lengths.
    perms, lengths = torch.ops.torch_sparse.bin_assignment(
        rowcount, bin_strategy)
    print()
    print(perms)

    dataset = Planetoid('/tmp/Planetoid', name='PubMed')
    row, col = dataset[0].edge_index
    adj = SparseTensor(row=row, col=col)
    rowcount = adj.storage.rowcount().to(device)
    # bin_strategy = torch.tensor([[1, 7], [8, 12]], device=device)
    bin_strategy = torch.tensor([[1, 4], [5, 13], [14, 22]], device=device)
    bin_count = [4, 13, 22]
    # src = torch.tensor([
    #     [1, 1],
    #     [2, 2],
    #     [3, 3],
    #     [4, 4],
    #     [5, 5],
    #     [6, 6],
    #     [7, 7],
    #     [8, 8],
    # ], dtype=torch.float, device=device)
    # rowptr = torch.tensor([0, 2, 5, 8, 10], device=device)
    # col = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 1], device=device)
    # index = torch.tensor([1, 2, 3], device=device)

    # out, mask = torch.ops.torch_sparse.padded_index_select(
    #     src, rowptr, col, index, 4)
    # print(out)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for i in range(102):
        if i == 2:
            start.record()
        perms, lengths = torch.ops.torch_sparse.bin_assignment(
            rowcount, bin_strategy)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))
    print('-------------')

    x = torch.randn(dataset[0].num_nodes, 512).to(device)
    col = col.to(device)
    for i in range(102):
        if i == 2:
            start.record()
        x = x.index_select(0, col)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    x = torch.randn(dataset[0].num_nodes, 512).to(device)
    rowptr = adj.storage.rowptr().to(device)
    col = col.to(device)
    for i in range(102):
        if i == 2:
            start.record()
        for perm, count in zip(perms, bin_count):
            torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perm,
                                                       count,
                                                       torch.tensor(0.))
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))
    print('-----------')

    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[0],
                                                   bin_count[0],
                                                   torch.tensor(0.))
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[1],
                                                   bin_count[1],
                                                   torch.tensor(0.))
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[2],
                                                   bin_count[2],
                                                   torch.tensor(0.))
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

+from itertools import product
+
 import pytest
 import torch
 from torch_sparse import SparseTensor
-from torch_geometric.datasets import Planetoid
-from torch_geometric.utils import degree
+
+from .utils import grad_dtypes, tensor

 devices = [torch.device('cuda')]

-@pytest.mark.parametrize('device', devices)
-def test_padded_index_select(device):
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_padded_index_select(dtype, device):
     row = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3])
     col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2])
-    idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
     adj = SparseTensor(row=row, col=col).to(device)
+    rowptr, col, _ = adj.csr()
+    rowcount = adj.storage.rowcount()

     binptr = torch.tensor([0, 3, 5], device=device)
-    data = torch.ops.torch_sparse.padded_index(adj.storage.rowptr(),
-                                               adj.storage.col(),
-                                               adj.storage.rowcount(), binptr)
-    node_perm, row_perm, col_perm, mask, size, length = data
+    data = torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
+    node_perm, row_perm, col_perm, mask, node_size, edge_size = data

-    print('node perm', node_perm)
-    print('row perm', row_perm)
-    print('col perm', col_perm)
-    print('mask', mask)
-    print('size', size)
-    print('length', length)
+    assert node_perm.tolist() == [2, 3, 0, 1]
+    assert row_perm.tolist() == [2, 2, 3, -1, 0, 0, 0, 0, 1, 1, 1, -1]
+    assert col_perm.tolist() == [1, 3, 2, -1, 0, 1, 2, 3, 0, 2, 3, -1]
+    assert mask.long().tolist() == [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
+    assert node_size == [2, 2]
+    assert edge_size == [4, 8]

-    x = torch.tensor([[0], [1], [2], [3]], dtype=torch.float, device=device)
-    x.requires_grad_()
-    out = torch.ops.torch_sparse.padded_index_select(x, col_perm,
-                                                     torch.tensor(0.))
-    print(out)
+    x = tensor([0, 1, 2, 3], dtype, device).view(-1, 1).requires_grad_()
+    fill_value = torch.tensor(0., dtype=dtype)
+    out = torch.ops.torch_sparse.padded_index_select(x, col_perm, fill_value)
+    assert out.flatten().tolist() == [1, 3, 2, 0, 0, 1, 2, 3, 0, 2, 3, 0]

-    grad_out = torch.tensor(
-        [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11]],
-        dtype=torch.float, device=device)
-    out.backward(grad_out)
-    print(x.grad)
+    grad_out = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype, device)
+    out.backward(grad_out.view(-1, 1))
+    assert x.grad.flatten().tolist() == [12, 5, 17, 18]

+@pytest.mark.parametrize('device', devices)
+def test_padded_index_select_runtime(device):
+    return
+
+    from torch_geometric.datasets import Planetoid

     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)

     dataset = Planetoid('/tmp/Planetoid', name='PubMed')
     data = dataset[0]
@@ -51,26 +56,6 @@ def test_padded_index_select(device):
     rowptr = adj.storage.rowptr().to(device)
     binptr = torch.tensor([0, 4, 11, 30, 50, 80, 120, 140, 2000]).to(device)

-    # deg = degree(row, dtype=torch.long)
-    # bins = torch.bincount(deg)
-    # print(bins.size())
-    # print(bins[:200])
-
-    # for i in range(110):
-    #     if i == 10:
-    #         start.record()
-    #     perms, lengths = torch.ops.torch_sparse.bin_assignment(
-    #         rowcount, binptr)
-    # end.record()
-    # torch.cuda.synchronize()
-    # print('bin assignment', start.elapsed_time(end))
-
-    # idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
-    #     rowptr, rowcount, binptr)
-    # print(size)
-    # print(length)
-    # print(offset)
-    # print(mask[:10])
-    # print(idx[:10])

     x = torch.randn(adj.size(0), 512).to(device)
     data = torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
@@ -100,39 +85,6 @@ def test_padded_index_select(device):
     torch.cuda.synchronize()
     print('padded index select', start.elapsed_time(end))

-    for i in range(110):
-        if i == 10:
-            start.record()
-        torch.repeat_interleave(rowcount, rowcount)
-    end.record()
-    torch.cuda.synchronize()
-    print('repeat', start.elapsed_time(end))
-
-    for i in range(110):
-        if i == 10:
-            start.record()
-        rowcount.cumsum(0)
-    end.record()
-    torch.cuda.synchronize()
-    print('cumsum', start.elapsed_time(end))
-
-    rowcount2 = rowcount.unsqueeze(1).repeat(1, 5).contiguous()
-    for i in range(110):
-        if i == 10:
-            start.record()
-        rowcount2.cumsum(0)
-    end.record()
-    torch.cuda.synchronize()
-    print('cumsum', start.elapsed_time(end))
-
-    for i in range(110):
-        if i == 10:
-            start.record()
-        rowcount.sort()
-    end.record()
-    torch.cuda.synchronize()
-    print('sort', start.elapsed_time(end))
-
-    for i in range(110):
-        if i == 10:
-            start.record()
@@ -140,56 +92,3 @@ def test_padded_index_select(device):
     end.record()
     torch.cuda.synchronize()
     print('index_select', start.elapsed_time(end))
-
-    return
-
-    for i in range(110):
-        if i == 10:
-            start.record()
-        for perm, length in zip(perms, lengths):
-            torch.ops.torch_sparse.padded_index_select(x, rowptr, col,
-                                                       perm, length,
-                                                       torch.tensor(0.))
-    end.record()
-    torch.cuda.synchronize()
-    print('padded_index_select', start.elapsed_time(end))
-
-    for perm, length in zip(perms, lengths):
-        out, mask = torch.ops.torch_sparse.padded_index_select(
-            x, rowptr, col, perm, length, torch.tensor(0.))
-        print(out.size(), mask.size(), out.numel(), (out != 0).sum().item())
-
-    lengths = bin_strategy[:, 1].view(-1).tolist()
-    for dim in [32, 64, 128, 256, 512, 1024]:
-        print(f'--- Dim: {dim} ---')
-        x = torch.randn(adj.size(0), dim).to(device)
-
-        for i in range(110):
-            if i == 10:
-                start.record()
-            perms = torch.ops.torch_sparse.bin_assignment(
-                rowcount, bin_strategy)
-            print(perms)
-            return
-            for perm, length in zip(perms, lengths):
-                out1, _ = torch.ops.torch_sparse.padded_index_select(
-                    x, rowptr, col, perm, length, torch.tensor(0.))
-        end.record()
-        torch.cuda.synchronize()
-        print(start.elapsed_time(end))
-
-        for i in range(110):
-            if i == 10:
-                start.record()
-            out2 = x.index_select(0, row)
-        end.record()
-        torch.cuda.synchronize()
-        print(start.elapsed_time(end))
-
-        for i in range(110):
-            if i == 10:
-                start.record()
-            out3 = x.index_select(0, col)
-        end.record()
-        torch.cuda.synchronize()
-        print(start.elapsed_time(end))
@@ -3,7 +3,7 @@ import os.path as osp

 import torch

-__version__ = '0.6.1'
+__version__ = '0.6.2'
 expected_torch_version = (1, 4)

 try: