Commit 354ef5e5 authored by rusty1s

DONE

parent af2325bb
@@ -11,8 +11,10 @@
__global__ void bin_kernel(const int64_t *__restrict__ rowcount,
const int64_t *__restrict__ binptr,
int64_t *__restrict__ bin, int64_t *__restrict__ idx,
-int *__restrict__ size, int *__restrict__ length,
-const size_t B, const size_t N) {
+int *__restrict__ node_size,
+int *__restrict__ max_deg, const size_t B,
+const size_t N) {
for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
thread_idx < N; thread_idx += gridDim.x * blockDim.x) {
@@ -24,45 +26,71 @@ __global__ void bin_kernel(const int64_t *__restrict__ rowcount,
}
}
-if (bin_idx == -1)
+if (bin_idx == -1) {
bin_idx = B - 1;
+}
-int old = atomicAdd(size + bin_idx, 1);
-atomicMax(length + bin_idx, deg);
+int old = atomicAdd(node_size + bin_idx, 1);
+atomicMax(max_deg + bin_idx, deg);
bin[thread_idx] = bin_idx;
idx[thread_idx] = old;
}
}
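For reference, a minimal CPU sketch (PyTorch; the helper name and the exact bin-boundary convention are assumptions, since the comparison loop is collapsed in this diff) of what `bin_kernel` computes: each node is assigned to a degree bin, `idx` records its slot within that bin, and `node_size`/`max_deg` accumulate the per-bin node count and maximum degree.

```python
import torch

def bin_nodes(rowcount, binptr):
    # Hypothetical CPU reference for bin_kernel. Assumes a node with degree
    # `deg` falls into the first bin b with deg < binptr[b + 1]; anything
    # larger spills into the last bin B - 1 (the `bin_idx == -1` branch).
    B = binptr.numel() - 1
    node_bin = torch.full_like(rowcount, B - 1)
    for b in range(B - 1, -1, -1):  # iterate backwards so the first match wins
        node_bin[rowcount < binptr[b + 1]] = b
    idx = torch.zeros_like(rowcount)
    node_size = torch.zeros(B, dtype=torch.long)
    max_deg = torch.zeros(B, dtype=torch.long)
    for n in range(rowcount.numel()):
        b = int(node_bin[n])
        idx[n] = node_size[b]  # slot within the bin (atomicAdd on the GPU)
        node_size[b] += 1
        max_deg[b] = max(int(max_deg[b]), int(rowcount[n]))  # atomicMax on the GPU
    return node_bin, idx, node_size, max_deg
```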
-__global__ void offset_kernel(const int *__restrict__ size,
-const int *__restrict__ length,
-int *__restrict__ offset, const size_t B) {
+__global__ void info_kernel(const int *__restrict__ node_size,
+const int *__restrict__ max_deg,
+int *__restrict__ edge_size,
+int *__restrict__ node_offset,
+int *__restrict__ edge_offset, const size_t B) {
-int bin_idx = threadIdx.x / 32;
-int lane_idx = threadIdx.x % 32;
+int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+int bin_idx = thread_idx / 32;
+int lane_idx = thread_idx % 32;
-if (bin_idx <= B) {
-int tmp = 0;
+if (bin_idx <= B) { // Computes `node_offset` and `edge_offset`.
+int node_tmp = 0;
+int edge_tmp = 0;
for (int i = lane_idx; i < bin_idx; i += 32) {
-tmp += size[i] * length[i];
+node_tmp += node_size[i];
+edge_tmp += node_size[i] * max_deg[i];
}
for (int i = 32 / 2; i > 0; i /= 2) {
-tmp += __shfl_down_sync(FULL_MASK, tmp, i);
+node_tmp += __shfl_down_sync(FULL_MASK, node_tmp, i);
+edge_tmp += __shfl_down_sync(FULL_MASK, edge_tmp, i);
}
-if (lane_idx == 0)
-offset[bin_idx] = tmp;
+if (lane_idx == 0) {
+node_offset[bin_idx] = node_tmp;
+edge_offset[bin_idx] = edge_tmp;
+}
+} else if (bin_idx == B + 1) { // Computes `edge_size`.
+for (int i = lane_idx; i < B; i += 32) {
+edge_size[i] = node_size[i] * max_deg[i];
+}
}
}
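In PyTorch terms, `info_kernel` materializes exclusive prefix sums over the per-bin counters, one warp per bin (a strided partial sum followed by a `__shfl_down_sync` tree reduction), with one extra warp writing `edge_size`. A rough equivalent, assuming the layout above:

```python
import torch

def bin_info(node_size, max_deg):
    # edge_size[b]: padded edge count of bin b (every node padded to max_deg[b]).
    edge_size = node_size * max_deg
    zero = torch.zeros(1, dtype=node_size.dtype)
    # Exclusive prefix sums; entry B holds the grand total, which the host
    # reads back to size the output tensors.
    node_offset = torch.cat([zero, node_size]).cumsum(0)  # length B + 1
    edge_offset = torch.cat([zero, edge_size]).cumsum(0)  # length B + 1
    return edge_size, node_offset, edge_offset
```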
+__global__ void node_perm_kernel(const int64_t *__restrict__ bin,
+const int64_t *__restrict__ idx,
+const int *__restrict__ node_offset,
+int64_t *__restrict__ out, const size_t N) {
+for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+thread_idx < N; thread_idx += gridDim.x * blockDim.x) {
+out[__ldg(node_offset + bin[thread_idx]) + idx[thread_idx]] = thread_idx;
+}
+}
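The new `node_perm_kernel` then scatters node ids into bin-sorted order. Roughly, with the hypothetical helper names from the sketches above:

```python
import torch

def node_perm(node_bin, idx, node_offset):
    # Node n lands at node_offset[bin(n)] + idx[n]: nodes end up grouped by
    # bin; the order inside a bin follows the atomicAdd order of bin_kernel.
    out = torch.empty_like(node_bin)
    out[node_offset[node_bin] + idx] = torch.arange(node_bin.numel())
    return out
```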
template <int TB>
__global__ void padded_index_kernel(
-const int64_t *__restrict__ rowptr, const int64_t *__restrict__ rowcount,
-const int64_t *__restrict__ bin, const int64_t *__restrict__ idx,
-int64_t *__restrict__ out, bool *__restrict__ mask,
-const int *__restrict__ length, const int *__restrict__ offset,
+const int64_t *__restrict__ rowptr, const int64_t *__restrict__ col,
+const int64_t *__restrict__ rowcount, const int64_t *__restrict__ bin,
+const int64_t *__restrict__ idx, const int *__restrict__ max_deg,
+const int *__restrict__ edge_offset, int64_t *__restrict__ row_perm,
+int64_t *__restrict__ col_perm, bool *__restrict__ edge_mask,
const size_t B, const size_t N) {
for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -72,26 +100,33 @@ __global__ void padded_index_kernel(
int lane_idx = thread_idx % TB;
int64_t bin_idx = bin[row_idx];
-int len = __ldg(length + bin_idx);
-int off = __ldg(offset + bin_idx) + len * idx[row_idx];
+int len = __ldg(max_deg + bin_idx);
+int off = __ldg(edge_offset + bin_idx) + len * idx[row_idx];
int64_t row_start = rowptr[row_idx], deg = rowcount[row_idx];
-int64_t tmp;
+int64_t row_tmp, col_tmp;
for (int i = lane_idx; i < len; i += TB) {
-tmp = -1;
-if (i < deg)
-tmp = row_start + i;
-out[off + i] = tmp;
-mask[off + i] = tmp == -1;
+row_tmp = -1, col_tmp = -1;
+if (i < deg) {
+row_tmp = row_idx;
+col_tmp = col[row_start + i];
+}
+row_perm[off + i] = row_tmp;
+col_perm[off + i] = col_tmp;
+edge_mask[off + i] = row_tmp == -1;
}
}
}
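`padded_index_kernel` gives every node a fixed-width slot of `max_deg[bin]` entries inside its bin's segment and fills it with the node's edges, padding with `-1`. A loop-based sketch of the same computation (hypothetical helper, serial instead of `TB` threads per row):

```python
import torch

def padded_edges(rowptr, col, rowcount, node_bin, idx, max_deg, edge_offset):
    E = int(edge_offset[-1])  # total padded edge count
    row_perm = torch.full((E,), -1, dtype=torch.long)
    col_perm = torch.full((E,), -1, dtype=torch.long)
    for n in range(rowcount.numel()):
        b = int(node_bin[n])
        width = int(max_deg[b])                # slot width of this bin
        off = int(edge_offset[b]) + width * int(idx[n])
        deg = int(rowcount[n])
        row_perm[off:off + deg] = n            # real edges fill the front
        col_perm[off:off + deg] = col[int(rowptr[n]):int(rowptr[n]) + deg]
    edge_mask = row_perm == -1                 # True on padded slots
    return row_perm, col_perm, edge_mask
```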
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
-torch::Tensor>
-padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount,
-torch::Tensor binptr) {
-// TODO: Add checks
+std::vector<int64_t>, std::vector<int64_t>>
+padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
+torch::Tensor rowcount, torch::Tensor binptr) {
+CHECK_CUDA(rowptr);
+CHECK_CUDA(col);
+CHECK_CUDA(rowcount);
+CHECK_CUDA(binptr);
+CHECK_INPUT(rowptr.numel() == rowcount.numel() + 1);
cudaSetDevice(rowcount.get_device());
auto stream = at::cuda::getCurrentCUDAStream();
@@ -103,45 +138,62 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount,
auto bin = torch::empty(N, rowptr.options());
auto idx = torch::empty(N, rowptr.options());
-auto tmp = torch::zeros(B + B + B + 1, rowcount.options().dtype(torch::kInt));
-auto size = tmp.narrow(0, 0, B);
-auto length = tmp.narrow(0, B, B);
-auto offset = tmp.narrow(0, 2 * B, B + 1);
+auto d_info = torch::zeros(5 * B + 2, col.options().dtype(torch::kInt));
+auto d_node_size = d_info.narrow(0, 0, B);
+auto d_edge_size = d_info.narrow(0, B, B);
+auto d_max_deg = d_info.narrow(0, 2 * B, B);
+auto d_node_offset = d_info.narrow(0, 3 * B, B + 1);
+auto d_edge_offset = d_info.narrow(0, 4 * B + 1, B + 1);
bin_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
rowcount.data_ptr<int64_t>(), binptr.data_ptr<int64_t>(),
-bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(), size.data_ptr<int>(),
-length.data_ptr<int>(), B, N);
+bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
+d_node_size.data_ptr<int>(), d_max_deg.data_ptr<int>(), B, N);
-offset_kernel<<<BLOCKS(32 * (B + 1)), THREADS, 0, stream>>>(
-size.data_ptr<int>(), length.data_ptr<int>(), offset.data_ptr<int>(), B);
+info_kernel<<<BLOCKS(32 * (B + 2)), THREADS, 0, stream>>>(
+d_node_size.data_ptr<int>(), d_max_deg.data_ptr<int>(),
+d_edge_size.data_ptr<int>(), d_node_offset.data_ptr<int>(),
+d_edge_offset.data_ptr<int>(), B);
-auto h_tmp = torch::empty(
-{tmp.numel()}, tmp.options().device(torch::kCPU).pinned_memory(true));
-cudaMemcpy(h_tmp.data_ptr<int>(), tmp.data_ptr<int>(),
-tmp.numel() * sizeof(int), cudaMemcpyDeviceToHost);
-auto out = torch::empty({h_tmp.data_ptr<int>()[3 * B]}, rowptr.options());
-auto mask = torch::empty({out.numel()}, rowptr.options().dtype(torch::kBool));
+auto node_perm = torch::empty(N, rowptr.options());
+node_perm_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
+bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
+d_node_offset.data_ptr<int>(), node_perm.data_ptr<int64_t>(), N);
+auto h_info = torch::empty(
+d_info.numel(), d_info.options().device(torch::kCPU).pinned_memory(true));
+cudaMemcpy(h_info.data_ptr<int>(), d_info.data_ptr<int>(),
+d_info.numel() * sizeof(int), cudaMemcpyDeviceToHost);
+size_t E = h_info.data_ptr<int>()[5 * B + 1];
+auto row_perm = torch::empty(E, col.options());
+auto col_perm = torch::empty(E, col.options());
+auto edge_mask = torch::empty(E, col.options().dtype(torch::kBool));
padded_index_kernel<8>
<<<std::min(BLOCKS(N * 8), mpc * 8), THREADS, 0, stream>>>(
-rowptr.data_ptr<int64_t>(), rowcount.data_ptr<int64_t>(),
-bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
-out.data_ptr<int64_t>(), mask.data_ptr<bool>(),
-length.data_ptr<int>(), offset.data_ptr<int>(), B, N);
+rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
+rowcount.data_ptr<int64_t>(), bin.data_ptr<int64_t>(),
+idx.data_ptr<int64_t>(), d_max_deg.data_ptr<int>(),
+d_edge_offset.data_ptr<int>(), row_perm.data_ptr<int64_t>(),
+col_perm.data_ptr<int64_t>(), edge_mask.data_ptr<bool>(), B, N);
-return std::make_tuple(out, mask, h_tmp.narrow(0, 0, B),
-h_tmp.narrow(0, B, B), h_tmp.narrow(0, 2 * B, B + 1));
+h_info = h_info.to(torch::kLong);
+auto h_info_data = h_info.data_ptr<int64_t>();
+std::vector<int64_t> node_sizes(h_info_data, h_info_data + B);
+std::vector<int64_t> edge_sizes(h_info_data + B, h_info_data + 2 * B);
+return std::make_tuple(node_perm, row_perm, col_perm, edge_mask, node_sizes,
+edge_sizes);
}
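Putting the host function together, an end-to-end call might look as follows (a sketch on a tiny made-up graph; the op name and return order are taken from the updated test below, everything else is illustrative):

```python
import torch
from torch_sparse import SparseTensor

# Tiny example graph; values are illustrative only.
row = torch.tensor([0, 0, 0, 1, 1, 2]).cuda()
col = torch.tensor([1, 2, 3, 0, 2, 1]).cuda()
adj = SparseTensor(row=row, col=col)
binptr = torch.tensor([0, 2, 4], device='cuda')  # two degree bins

node_perm, row_perm, col_perm, edge_mask, node_sizes, edge_sizes = \
    torch.ops.torch_sparse.padded_index(adj.storage.rowptr(),
                                        adj.storage.col(),
                                        adj.storage.rowcount(), binptr)
# node_sizes/edge_sizes come back as Python lists, so the padded edge tensors
# can be split into per-bin chunks without another device synchronization.
chunks = row_perm.split(edge_sizes)
```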
template <typename scalar_t>
__global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
-const int64_t *__restrict__ col,
const int64_t *__restrict__ index,
scalar_t *__restrict__ out,
const scalar_t fill_value,
-const size_t F, const size_t E) {
+const size_t E, const size_t F) {
for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
thread_idx < E * F; thread_idx += gridDim.x * blockDim.x) {
@@ -152,17 +204,19 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
scalar_t tmp = fill_value;
if (index_idx != -1) {
-tmp = src[__ldg(col + index_idx) * F + lane_idx];
+tmp = src[index_idx * F + lane_idx];
}
out[thread_idx] = tmp;
}
}
-torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col,
-torch::Tensor index,
+torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
torch::Tensor fill_value) {
-// TODO: Add checks
+CHECK_CUDA(src);
+CHECK_CUDA(index);
+CHECK_INPUT(src.dim() == 2);
+CHECK_INPUT(index.dim() == 1);
cudaSetDevice(src.get_device());
auto stream = at::cuda::getCurrentCUDAStream();
@@ -185,8 +239,8 @@ torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col,
padded_index_select_kernel<scalar_t>
<<<std::min(BLOCKS(E * F), mpc * 8), THREADS, 0, stream>>>(
-src.data_ptr<scalar_t>(), col.data_ptr<int64_t>(),
-index.data_ptr<int64_t>(), out.data_ptr<scalar_t>(), fill[0], F, E);
+src.data_ptr<scalar_t>(), index.data_ptr<int64_t>(),
+out.data_ptr<scalar_t>(), fill[0], E, F);
});
return out;
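Semantically, the updated `padded_index_select` is a masked gather: `index` is now the precomputed `col_perm` (source row ids, `-1` on padded slots), so the extra `col` indirection of the old kernel is gone. A rough PyTorch equivalent:

```python
import torch

def padded_index_select(src, index, fill_value=0.):
    # src: [N, F] features; index: flat padded index with -1 marking padding.
    out = src.new_full((index.numel(), src.size(1)), fill_value)
    valid = index != -1
    out[valid] = src[index[valid]]  # gather real rows, keep fill elsewhere
    return out
```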
@@ -3,10 +3,9 @@
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
-torch::Tensor>
-padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount,
-torch::Tensor binptr);
+std::vector<int64_t>, std::vector<int64_t>>
+padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
+torch::Tensor rowcount, torch::Tensor binptr);
-torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col,
-torch::Tensor index,
+torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
torch::Tensor fill_value);
@@ -10,16 +10,15 @@ PyMODINIT_FUNC PyInit__padding(void) { return NULL; }
#endif
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
-torch::Tensor>
-padded_index(torch::Tensor rowptr, torch::Tensor rowcount,
+std::vector<int64_t>, std::vector<int64_t>>
+padded_index(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
torch::Tensor binptr) {
-return padded_index_cuda(rowptr, rowcount, binptr);
+return padded_index_cuda(rowptr, col, rowcount, binptr);
}
-torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor col,
-torch::Tensor index,
+torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor index,
torch::Tensor fill_value) {
-return padded_index_select_cuda(src, col, index, fill_value);
+return padded_index_select_cuda(src, index, fill_value);
}
static auto registry =
@@ -16,23 +16,24 @@ def test_padded_index_select(device):
col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2])
idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
adj = SparseTensor(row=row, col=col).to(device)
binptr = torch.tensor([0, 3, 5], device=device)
-idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
-adj.storage.rowptr(), adj.storage.rowcount(), binptr)
-print(size)
-print(length)
-print(offset)
-print(idx)
-print(mask)
-x = torch.tensor([[0], [1], [2], [3]], dtype=torch.float, device=device)
-out = torch.ops.torch_sparse.padded_index_select(x, adj.storage.col(), idx,
-torch.tensor(0.))
-print(out)
+data = torch.ops.torch_sparse.padded_index(adj.storage.rowptr(),
+adj.storage.col(),
+adj.storage.rowcount(), binptr)
+node_perm, row_perm, col_perm, mask, size, length = data
+print('node perm', node_perm)
+print('row perm', row_perm)
+print('col perm', col_perm)
+print('mask', mask)
+print('size', size)
+print('length', length)
+# x = torch.tensor([[0], [1], [2], [3]], dtype=torch.float, device=device)
+# out = torch.ops.torch_sparse.padded_index_select(x, adj.storage.col(), idx,
+#                                                  torch.tensor(0.))
+# print(out)
dataset = Planetoid('/tmp/Planetoid', name='PubMed')
data = dataset[0]
@@ -41,12 +42,10 @@ def test_padded_index_select(device):
adj = SparseTensor(row=row, col=col)
rowcount = adj.storage.rowcount().to(device)
rowptr = adj.storage.rowptr().to(device)
-bin_strategy = torch.tensor([[1, 4], [4, 11], [11, 30]]).to(device)
+binptr = torch.tensor([0, 4, 11, 30, 50, 80, 120, 140, 2000]).to(device)
-deg = degree(row, dtype=torch.long)
-bins = torch.bincount(deg)
+# deg = degree(row, dtype=torch.long)
+# bins = torch.bincount(deg)
# print(bins.size())
# print(bins[:200])
# for i in range(110):
@@ -57,23 +56,24 @@ def test_padded_index_select(device):
# end.record()
# torch.cuda.synchronize()
# print('bin assignment', start.elapsed_time(end))
-idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
-rowptr, rowcount, binptr)
-print(size)
-print(length)
-print(offset)
-print(mask[:10])
-print(idx[:10])
-x = torch.randn(data.num_nodes, 256).to(device)
+# idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
+#     rowptr, rowcount, binptr)
+# print(size)
+# print(length)
+# print(offset)
+# print(mask[:10])
+# print(idx[:10])
for i in range(110):
if i == 10:
start.record()
-torch.ops.torch_sparse.padded_index(rowptr, rowcount, binptr)
+torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
end.record()
torch.cuda.synchronize()
print('padded index', start.elapsed_time(end))
+return
+x = torch.randn(data.num_nodes, 512).to(device)
for i in range(110):
if i == 10: