"examples/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "5008af2210cc08a72719d5eb81c8e080dc3de085"
Commit efbbce74 authored by rusty1s

working example

parent 2bea1c3c
```diff
@@ -7,84 +7,72 @@
 #define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS
 
-__global__ void bin_kernel(const int64_t *rowcount, const int64_t *bin_strategy,
-                           int64_t *bin, int64_t *one_hot, int64_t num_bins,
-                           int64_t numel) {
-  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  if (thread_idx < numel) {
-    auto count = rowcount[thread_idx];
-
-    int64_t b = -1;
-    for (int64_t i = 0; i < num_bins; i++) {
-      if (count >= __ldg(bin_strategy + 2 * i) &&
-          count <= __ldg(bin_strategy + 2 * i + 1)) {
-        b = i;
-        break;
-      }
-    }
-
-    bin[thread_idx] = b;
-    if (b >= 0) {
-      one_hot[b * numel + thread_idx] = 1;
-    }
-  }
-}
-
-__global__ void index_kernel(const int64_t *bin, const int64_t *cumsum,
-                             const int64_t *nodes_per_bin, int64_t *index,
-                             int64_t num_bins, int64_t numel) {
-  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  if (thread_idx < numel) {
-    auto b = bin[thread_idx];
-    if (b >= 0) {
-      auto idx = cumsum[b * numel + thread_idx] - 1;
-      for (int64_t i = 0; i < b; i++) {
-        idx += __ldg(nodes_per_bin + i);
-      }
-      index[idx] = thread_idx;
-    }
-  }
-}
+__global__ void sizes_kernel(const int64_t *__restrict__ sorted_rowcount,
+                             const int64_t *__restrict__ binptr,
+                             int64_t *__restrict__ size,
+                             int64_t *__restrict__ length,
+                             const int64_t num_bins, const int64_t numel) {
+  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+       thread_idx < numel - 1; thread_idx += gridDim.x * blockDim.x) {
+
+    int64_t deg1 = sorted_rowcount[thread_idx];
+    int64_t deg2 = sorted_rowcount[thread_idx + 1];
+
+    if (deg1 != deg2) {
+      for (int64_t b = 1; b <= num_bins; b++) {
+        if (deg1 < __ldg(binptr + b) && deg2 >= __ldg(binptr + b)) {
+          size[b] = thread_idx + 1;
+          length[b - 1] = deg1;
+        }
+      }
+    }
+
+    if (thread_idx + 1 == numel - 1) {
+      size[num_bins] = numel;
+      length[num_bins - 1] = deg2;
+    }
+  }
+}
```
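The rewrite drops the per-node interval matching of `bin_kernel` (and the one-hot/cumsum bookkeeping that fed `index_kernel`) in favor of a single grid-stride pass over the *sorted* degree array: whenever two adjacent entries straddle a bin boundary `binptr[b]`, the kernel records the split position in `size[b]` and the largest degree of the preceding bin in `length[b - 1]`. The invariant this establishes is that `size[b]` ends up as the number of nodes whose degree is below `binptr[b]` (up to corner cases for boundaries beyond the maximum degree, which the zero-initialized buffers cover). A one-line PyTorch sketch of that invariant, using a hypothetical helper name:

```python
import torch

def split_positions(sorted_rowcount: torch.Tensor, binptr: torch.Tensor):
    # Left insertion points of the interior bin boundaries in the sorted
    # degree array == counts of nodes with degree < binptr[b], which is
    # what sizes_kernel writes into size[1:].
    return torch.searchsorted(sorted_rowcount, binptr[1:])
```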
```diff
-std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
-                                               torch::Tensor bin_strategy) {
+std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
+bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr) {
   CHECK_CUDA(rowcount);
-  CHECK_CUDA(bin_strategy);
+  CHECK_CUDA(binptr);
   CHECK_INPUT(rowcount.dim() == 1);
-  CHECK_INPUT(bin_strategy.dim() == 2 && bin_strategy.size(1) == 2);
+  CHECK_INPUT(binptr.dim() == 1);
   cudaSetDevice(rowcount.get_device());
 
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+
-  int64_t num_bins = bin_strategy.size(0);
-  auto bin = torch::empty({rowcount.numel()}, rowcount.options());
-  auto one_hot = torch::zeros({num_bins, rowcount.numel()}, rowcount.options());
+  torch::Tensor sorted_rowcount, perm;
+  std::tie(sorted_rowcount, perm) = rowcount.sort();
 
-  auto stream = at::cuda::getCurrentCUDAStream();
-  bin_kernel<<<BLOCKS(rowcount.numel()), THREADS, 0, stream>>>(
-      rowcount.data_ptr<int64_t>(), bin_strategy.data_ptr<int64_t>(),
-      bin.data_ptr<int64_t>(), one_hot.data_ptr<int64_t>(), num_bins,
-      rowcount.numel());
+  auto size = torch::zeros({binptr.numel()}, binptr.options());
+  auto length = torch::zeros({binptr.numel() - 1}, binptr.options());
 
-  auto cumsum = one_hot.cumsum(1);
-  auto d_nodes_per_bin = cumsum.select(1, rowcount.numel() - 1).contiguous();
-  auto h_nodes_per_bin = d_nodes_per_bin.cpu();
+  sizes_kernel<<<std::min(BLOCKS(rowcount.numel() - 1), mpc * 8), THREADS, 0,
+                 stream>>>(sorted_rowcount.data_ptr<int64_t>(),
+                           binptr.data_ptr<int64_t>(), size.data_ptr<int64_t>(),
+                           length.data_ptr<int64_t>(), length.numel(),
+                           rowcount.numel());
 
-  auto h_size = h_nodes_per_bin.sum().data_ptr<int64_t>()[0];
-  auto index = torch::empty({h_size}, rowcount.options());
+  size = size.cpu();
+  size = size.narrow(0, 1, length.numel()) - size.narrow(0, 0, length.numel());
+  auto sizes = at::IntArrayRef(size.data_ptr<int64_t>(), size.numel());
 
-  index_kernel<<<BLOCKS(bin.numel()), THREADS, 0, stream>>>(
-      bin.data_ptr<int64_t>(), cumsum.data_ptr<int64_t>(),
-      d_nodes_per_bin.data_ptr<int64_t>(), index.data_ptr<int64_t>(), num_bins,
-      rowcount.numel());
+  length = length.cpu();
+  int64_t *length_data = length.data_ptr<int64_t>();
+  std::vector<int64_t> lengths(length.numel());
+  std::copy(length_data, length_data + length.numel(), lengths.begin());
 
-  auto sizes = at::IntArrayRef(h_nodes_per_bin.data_ptr<int64_t>(), num_bins);
-  return index.split_with_sizes(sizes);
+  return std::make_tuple(perm.split_with_sizes(sizes), lengths);
 }
```
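On the host side, `bin_assignment_cuda` now sorts `rowcount` once, launches `sizes_kernel` with the grid capped at `mpc * 8` blocks, turns the cumulative split positions into per-bin counts by differencing two `narrow` views, and splits the sort permutation `perm` into one index tensor per degree bin, returning the per-bin padding lengths alongside. A rough CPU reference of the whole routine, assuming half-open bins `[binptr[b], binptr[b + 1])` (a sketch for illustration, not the library's API):

```python
import torch

def bin_assignment_reference(rowcount: torch.Tensor, binptr: torch.Tensor):
    # Sort nodes by degree; `perm` maps sorted slots back to node ids.
    sorted_rowcount, perm = rowcount.sort()
    # Cumulative node counts at the interior bin boundaries.
    cum = torch.searchsorted(sorted_rowcount, binptr[1:])
    # Adjacent differences give per-bin counts, mirroring the two
    # `narrow` views differenced on the host.
    sizes = torch.diff(cum, prepend=cum.new_zeros(1)).tolist()
    # Largest degree inside each non-empty bin is its padding length.
    lengths = [int(sorted_rowcount[c - 1]) if s > 0 else 0
               for c, s in zip(cum.tolist(), sizes)]
    return perm.split(sizes), lengths
```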
```diff
-__global__ void padded_mask_select_kernel(const int64_t *rowptr,
-                                          const int64_t *col,
-                                          const int64_t *index,
-                                          int64_t *out_idx, bool *mask,
-                                          int64_t length, int64_t numel) {
+__global__ void padded_mask_select_kernel(
+    const int64_t *__restrict__ rowptr, const int64_t *__restrict__ col,
+    const int64_t *__restrict__ index, int64_t *__restrict__ out_idx,
+    bool *__restrict__ mask, const int64_t length, const int64_t numel) {
   int64_t lane_idx, row_idx, row_start, row_end, col_idx;
   for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -104,10 +92,11 @@ __global__ void padded_mask_select_kernel(const int64_t *rowptr,
 }
 
 template <typename scalar_t>
-__global__ void padded_index_select_kernel(const scalar_t *src,
-                                           const int64_t *index, scalar_t *out,
-                                           scalar_t fill_value, int64_t dim,
-                                           int64_t numel) {
+__global__ void
+padded_index_select_kernel(const scalar_t *__restrict__ src,
+                           const int64_t *__restrict__ index,
+                           scalar_t *__restrict__ out, scalar_t fill_value,
+                           const int64_t dim, const int64_t numel) {
   int64_t index_idx, dim_idx, col;
   for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
```
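Both select kernels change only in signature here: `__restrict__`-qualified pointers and `const` scalars let the compiler assume no aliasing and cache reads more aggressively. Together the two kernels implement a padded gather: for each node in `index`, up to `length` neighbor feature rows are copied into a dense block, and a boolean mask flags the slots that hold real neighbors. A dense reference of that contract (a sketch; the real op takes `fill_value` as a tensor and relies on the binning to bound each selected row's degree by `length`):

```python
import torch

def padded_index_select_reference(x, rowptr, col, index, length, fill_value):
    # out[i, j] = feature row of the j-th neighbor of index[i] for
    # j < degree; remaining slots keep fill_value and mask stays False.
    out = x.new_full((index.numel(), length, x.size(-1)), fill_value)
    mask = torch.zeros(index.numel(), length, 1, dtype=torch.bool,
                       device=x.device)
    for i, row in enumerate(index.tolist()):
        start, end = int(rowptr[row]), int(rowptr[row + 1])
        deg = min(end - start, length)
        out[i, :deg] = x[col[start:start + deg]]
        mask[i, :deg] = True
    return out, mask
```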
```diff
@@ -136,22 +125,22 @@ padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
   CHECK_INPUT(col.dim() == 1);
   CHECK_INPUT(index.dim() == 1);
   CHECK_INPUT(fill_value.numel() == 1);
   cudaSetDevice(src.get_device());
 
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+
   auto out_idx = torch::empty({index.size(0), length}, index.options());
-  auto out = torch::empty({index.size(0), length, src.size(-1)}, src.options());
   auto mask = torch::empty({index.size(0), length, 1},
                            src.options().dtype(torch::kBool));
 
-  auto stream = at::cuda::getCurrentCUDAStream();
-  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
   padded_mask_select_kernel<<<
       std::min((out_idx.numel() + THREADS - 1) / THREADS, mpc * 8), THREADS, 0,
       stream>>>(rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
                 index.data_ptr<int64_t>(), out_idx.data_ptr<int64_t>(),
                 mask.data_ptr<bool>(), length, out_idx.numel());
 
+  auto out = torch::empty({index.size(0), length, src.size(-1)}, src.options());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
     scalar_t *fill;
     if (fill_value.is_cuda()) {
...
```
```diff
@@ -2,9 +2,15 @@
 #include <torch/extension.h>
 
-std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
-                                               torch::Tensor bin_strategy);
+std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
+bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr);
 
 std::tuple<torch::Tensor, torch::Tensor>
 padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                          torch::Tensor col, torch::Tensor index, int64_t length,
                          torch::Tensor fill_value);
+
+// std::tuple<torch::Tensor, torch::Tensor> padded_index_select_cuda2(
+//     torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
+//     torch::Tensor bin, torch::Tensor index, std::vector<int64_t> node_counts,
+//     std::vector<int64_t> lengths, torch::Tensor fill_value);
```
```diff
@@ -9,11 +9,11 @@
 PyMODINIT_FUNC PyInit__degree_padding(void) { return NULL; }
 #endif
 
-std::vector<torch::Tensor> bin_assignment(torch::Tensor rowcount,
-                                          torch::Tensor bin_strategy) {
+std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
+bin_assignment(torch::Tensor rowcount, torch::Tensor binptr) {
   if (rowcount.device().is_cuda()) {
 #ifdef WITH_CUDA
-    return bin_assignment_cuda(rowcount, bin_strategy);
+    return bin_assignment_cuda(rowcount, binptr);
 #else
     AT_ERROR("Not compiled with CUDA support");
 #endif
@@ -38,7 +38,28 @@ padded_index_select(torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
   }
 }
 
+// std::tuple<torch::Tensor, torch::Tensor>
+// padded_index_select2(torch::Tensor src, torch::Tensor rowptr,
+//                      torch::Tensor col, torch::Tensor bin,
+//                      torch::Tensor index, std::vector<int64_t> node_counts,
+//                      std::vector<int64_t> lengths, torch::Tensor fill_value) {
+//   if (src.device().is_cuda()) {
+// #ifdef WITH_CUDA
+//     return padded_index_select_cuda2(src, rowptr, col, bin, index,
+//                                      node_counts, lengths, fill_value);
+// #else
+//     AT_ERROR("Not compiled with CUDA support");
+// #endif
+//   } else {
+//     AT_ERROR("Not implemented yet");
+//   }
+// }
+
 static auto registry =
     torch::RegisterOperators()
         .op("torch_sparse::bin_assignment", &bin_assignment)
         .op("torch_sparse::padded_index_select", &padded_index_select);
+//         .op("torch_sparse::padded_index_select2", &padded_index_select2);
```
```diff
@@ -9,34 +9,71 @@ devices = [torch.device('cuda')]
 @pytest.mark.parametrize('device', devices)
 def test_padded_index_select(device):
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+
     dataset = Planetoid('/tmp/Planetoid', name='PubMed')
     data = dataset[0]
     row, col = data.edge_index.to(device)
 
-    row = torch.arange(data.num_nodes).view(-1, 1).repeat(1, 4).view(-1)
-    col = torch.randint(0, data.num_nodes, (row.size(0), ))
-    row, col = row.to(device), col.to(device)
-
     adj = SparseTensor(row=row, col=col)
     rowcount = adj.storage.rowcount().to(device)
     rowptr = adj.storage.rowptr().to(device)
 
-    bin_strategy = torch.tensor([[1, 4]]).to(device)
-    # bin_strategy = torch.tensor([[1, 5], [6, 12], [13, 19], [20, 30]],
-    #                             device=device)
-    perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
-    lengths = bin_strategy[:, 1].view(-1).tolist()
-    print(lengths)
+    bin_strategy = torch.tensor([[1, 4], [4, 11], [11, 30]]).to(device)
+    binptr = torch.tensor([0, 4, 11, 30, 50, 80, 120, 140, 2000]).to(device)
 
     deg = degree(row, dtype=torch.long)
+    print(deg.size(), deg.min(), deg.float().mean(), deg.max())
     bins = torch.bincount(deg)
-    print(bins)
-    nonzero = bins.nonzero().flatten()
-    print(nonzero)
-    print(bins[nonzero])
-
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
+    print(bins.size())
+    print(bins[:200])
+
+    for i in range(110):
+        if i == 10:
+            start.record()
+        perms, lengths = torch.ops.torch_sparse.bin_assignment(
+            rowcount, binptr)
+    end.record()
+    torch.cuda.synchronize()
+    print(start.elapsed_time(end))
+    return
+
+    for i in range(110):
+        if i == 10:
+            start.record()
+        rowcount.sort()
+    end.record()
+    torch.cuda.synchronize()
+    print(start.elapsed_time(end))
+
+    x = torch.randn(data.num_nodes, 128).to(device)
+
+    for i in range(110):
+        if i == 10:
+            start.record()
+        x.index_select(0, col)
+    end.record()
+    torch.cuda.synchronize()
+    print(start.elapsed_time(end))
+
+    for i in range(110):
+        if i == 10:
+            start.record()
+        for perm, length in zip(perms, lengths):
+            torch.ops.torch_sparse.padded_index_select(x, rowptr, col,
+                                                       perm, length,
+                                                       torch.tensor(0.))
+    end.record()
+    torch.cuda.synchronize()
+    print(start.elapsed_time(end))
+
+    for perm, length in zip(perms, lengths):
+        out, mask = torch.ops.torch_sparse.padded_index_select(
+            x, rowptr, col, perm, length, torch.tensor(0.))
+        print(out.size(), mask.size(), out.numel(), (out != 0).sum().item())
+    return
+
+    lengths = bin_strategy[:, 1].view(-1).tolist()
 
     for dim in [32, 64, 128, 256, 512, 1024]:
         print(f'--- Dim: {dim} ---')
@@ -45,6 +82,10 @@ def test_padded_index_select(device):
         for i in range(110):
             if i == 10:
                 start.record()
+            perms = torch.ops.torch_sparse.bin_assignment(
+                rowcount, bin_strategy)
+            print(perms)
+            return
             for perm, length in zip(perms, lengths):
                 out1, _ = torch.ops.torch_sparse.padded_index_select(
                     x, rowptr, col, perm, length, torch.tensor(0.))
@@ -67,4 +108,3 @@ def test_padded_index_select(device):
         end.record()
         torch.cuda.synchronize()
         print(start.elapsed_time(end))
-        print(torch.allclose(out1.view(-1, dim), out3))
```
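All benchmark loops above share one warm-up-then-measure idiom: iterate 110 times, start the CUDA event at iteration 10, and synchronize before reading the elapsed time. Factored out as a helper (hypothetical, not part of the test file):

```python
import torch

def time_cuda(fn, warmup=10, iters=100):
    # Discard the first `warmup` runs, then time `iters` runs with CUDA
    # events; returns total elapsed milliseconds.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for i in range(warmup + iters):
        if i == warmup:
            start.record()
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)
```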