Commit 8fd1c9c0 authored by rusty1s's avatar rusty1s
Browse files

cpu implementation

parent ceb73a8c
...@@ -6,8 +6,93 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, ...@@ -6,8 +6,93 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::vector<int64_t>, std::vector<int64_t>> std::vector<int64_t>, std::vector<int64_t>>
padded_index_cpu(torch::Tensor rowptr, torch::Tensor col, padded_index_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor rowcount, torch::Tensor binptr) { torch::Tensor rowcount, torch::Tensor binptr) {
std::vector<int64_t> bla = {1}; CHECK_CPU(rowptr);
return std::make_tuple(col, col, col, col, bla, bla); CHECK_CPU(col);
CHECK_CPU(rowcount);
CHECK_CPU(binptr);
CHECK_INPUT(rowptr.numel() == rowcount.numel() + 1);
ptrdiff_t B = binptr.numel() - 1;
ptrdiff_t N = rowcount.numel();
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto rowcount_data = rowcount.data_ptr<int64_t>();
auto binptr_data = binptr.data_ptr<int64_t>();
auto bin = torch::empty(N, col.options());
auto bin_data = bin.data_ptr<int64_t>();
auto idx = torch::empty(N, col.options());
auto idx_data = idx.data_ptr<int64_t>();
std::vector<int64_t> node_sizes(B), edge_sizes(B), max_degs(B),
node_offsets(B + 1), edge_offsets(B + 1);
int64_t deg, bin_idx = -1;
for (ptrdiff_t n = 0; n < N; n++) {
deg = rowcount_data[n];
for (ptrdiff_t b = 1; b <= B; b++) {
if (deg < binptr_data[b]) {
bin_idx = b - 1;
break;
}
}
if (bin_idx == -1) {
bin_idx = B - 1;
}
bin_data[n] = bin_idx;
idx_data[n] = node_sizes[bin_idx];
node_sizes[bin_idx] += 1;
max_degs[bin_idx] = std::max(max_degs[bin_idx], deg);
}
for (ptrdiff_t b = 0; b < B; b++) {
edge_sizes[b] = node_sizes[b] * max_degs[b];
node_offsets[b + 1] = node_offsets[b] + node_sizes[b];
edge_offsets[b + 1] = edge_offsets[b] + edge_sizes[b];
}
auto node_perm = torch::empty(N, col.options());
auto node_perm_data = node_perm.data_ptr<int64_t>();
auto E = edge_offsets[B];
auto row_perm = torch::empty(E, col.options());
auto row_perm_data = row_perm.data_ptr<int64_t>();
auto col_perm = torch::empty(E, col.options());
auto col_perm_data = col_perm.data_ptr<int64_t>();
auto edge_mask = torch::empty(E, col.options().dtype(torch::kBool));
auto edge_mask_data = edge_mask.data_ptr<bool>();
int64_t row_start = rowptr_data[0], row_end, edge_offset, offset;
for (ptrdiff_t n = 0; n < N; n++) {
bin_idx = bin_data[n];
offset = idx_data[n];
node_perm_data[node_offsets[bin_idx] + offset] = n;
row_end = rowptr_data[n + 1];
edge_offset = edge_offsets[bin_idx] + offset * max_degs[bin_idx];
for (ptrdiff_t e = 0; e < row_end - row_start; e++) {
row_perm_data[edge_offset + e] = n;
col_perm_data[edge_offset + e] = col_data[row_start + e];
edge_mask_data[edge_offset + e] = false;
}
for (ptrdiff_t e = row_end - row_start; e < max_degs[bin_data[n]]; e++) {
row_perm_data[edge_offset + e] = -1;
col_perm_data[edge_offset + e] = -1;
edge_mask_data[edge_offset + e] = true;
}
row_start = row_end;
}
return std::make_tuple(node_perm, row_perm, col_perm, edge_mask, node_sizes,
edge_sizes);
} }
torch::Tensor padded_index_select_cpu(torch::Tensor src, torch::Tensor index, torch::Tensor padded_index_select_cpu(torch::Tensor src, torch::Tensor index,
......
...@@ -136,8 +136,8 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor col, ...@@ -136,8 +136,8 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
size_t B = binptr.numel() - 1; size_t B = binptr.numel() - 1;
size_t N = rowcount.numel(); size_t N = rowcount.numel();
auto bin = torch::empty(N, rowptr.options()); auto bin = torch::empty(N, col.options());
auto idx = torch::empty(N, rowptr.options()); auto idx = torch::empty(N, col.options());
auto d_info = torch::zeros(5 * B + 2, col.options().dtype(torch::kInt)); auto d_info = torch::zeros(5 * B + 2, col.options().dtype(torch::kInt));
auto d_node_size = d_info.narrow(0, 0, B); auto d_node_size = d_info.narrow(0, 0, B);
...@@ -156,7 +156,7 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor col, ...@@ -156,7 +156,7 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
d_edge_size.data_ptr<int>(), d_node_offset.data_ptr<int>(), d_edge_size.data_ptr<int>(), d_node_offset.data_ptr<int>(),
d_edge_offset.data_ptr<int>(), B); d_edge_offset.data_ptr<int>(), B);
auto node_perm = torch::empty(N, rowptr.options()); auto node_perm = torch::empty(N, col.options());
node_perm_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>( node_perm_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(), bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment