Commit ebeb4509 authored by rusty1s

update with root_n_id

parent 54b0a095
@@ -8,9 +8,9 @@ inline torch::Tensor vec2tensor(std::vector<int64_t> vec) {
   return torch::from_blob(vec.data(), {(int64_t)vec.size()}, at::kLong).clone();
 }
-// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`
+// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`, `root_n_id`
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
-           torch::Tensor>
+           torch::Tensor, torch::Tensor>
 ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
                          torch::Tensor idx, int64_t depth,
                          int64_t num_neighbors, bool replace) {
@@ -19,12 +19,13 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
   std::vector<torch::Tensor> out_cols(idx.numel());
   std::vector<torch::Tensor> out_n_ids(idx.numel());
   std::vector<torch::Tensor> out_e_ids(idx.numel());
+  auto out_root_n_id = torch::empty({idx.numel()}, at::kLong);
   out_rowptrs[0] = torch::zeros({1}, at::kLong);
   auto rowptr_data = rowptr.data_ptr<int64_t>();
   auto col_data = col.data_ptr<int64_t>();
   auto idx_data = idx.data_ptr<int64_t>();
+  auto out_root_n_id_data = out_root_n_id.data_ptr<int64_t>();
   at::parallel_for(0, idx.numel(), 1, [&](int64_t begin, int64_t end) {
     int64_t row_start, row_end, row_count, vec_start, vec_end, v, w;
@@ -82,6 +83,8 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
       i++;
     }
+    out_root_n_id_data[g] = n_id_map[idx_data[g]];
     std::vector<int64_t> rowptrs, cols, e_ids;
     for (int64_t v : n_ids) {
       row_start = rowptr_data[v], row_end = rowptr_data[v + 1];
@@ -114,11 +117,12 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
     out_rowptrs[g + 1].add_(edge_cumsum);
     out_cols[g].add_(node_cumsum);
     out_ptr_data[g] = node_cumsum;
+    out_root_n_id_data[g] += node_cumsum;
   }
   node_cumsum += out_n_ids[idx.numel() - 1].numel();
   out_ptr_data[idx.numel()] = node_cumsum;
   return std::make_tuple(torch::cat(out_rowptrs, 0), torch::cat(out_cols, 0),
                          torch::cat(out_n_ids, 0), torch::cat(out_e_ids, 0),
-                         out_ptr);
+                         out_ptr, out_root_n_id);
 }
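The `root_n_id` bookkeeping above happens in two steps: inside the per-seed loop each ego-graph records the root's local index within its own sampled node set (`n_id_map[idx_data[g]]`), and when the per-seed results are concatenated that index is shifted by the running node count (`+= node_cumsum`). A rough Python sketch of that offset logic, using the subgraph sizes implied by the test further below as illustrative inputs:

    # Rough sketch of the root_n_id offsetting (illustrative values only).
    local_roots = [0, 1]   # root's index inside its own sampled n_id chunk
    ego_sizes = [4, 3]     # number of nodes sampled per ego-graph
    node_cumsum = 0
    root_n_id = []
    for local, size in zip(local_roots, ego_sizes):
        root_n_id.append(local + node_cumsum)  # shift by nodes of preceding ego-graphs
        node_cumsum += size
    assert root_n_id == [0, 5]  # matches the output asserted in the test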
@@ -3,7 +3,7 @@
 #include <torch/extension.h>
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
-           torch::Tensor>
+           torch::Tensor, torch::Tensor>
 ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
                          torch::Tensor idx, int64_t depth,
                          int64_t num_neighbors, bool replace);
@@ -11,9 +11,9 @@ PyMODINIT_FUNC PyInit__ego_sample_cpu(void) { return NULL; }
 #endif
 #endif
-// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`
+// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`, `root_n_id`
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
-           torch::Tensor>
+           torch::Tensor, torch::Tensor>
 ego_k_hop_sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
                      int64_t depth, int64_t num_neighbors, bool replace) {
   if (rowptr.device().is_cuda()) {
@@ -8,7 +8,15 @@ def test_ego_k_hop_sample_adj():
     col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])
     _ = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))
-    idx = torch.tensor([2])
+    nid = torch.tensor([0, 1])
     fn = torch.ops.torch_sparse.ego_k_hop_sample_adj
-    fn(rowptr, col, idx, 1, 3, False)
+    out = fn(rowptr, col, nid, 1, 3, False)
+    rowptr, col, nid, eid, ptr, root_n_id = out
+    assert nid.tolist() == [0, 1, 2, 3, 0, 1, 2]
+    assert rowptr.tolist() == [0, 3, 5, 7, 8, 10, 12, 14]
+    # row:                   [0, 0, 0, 1, 1, 2, 2, 3, 4, 4, 5, 5, 6, 6]
+    assert col.tolist() == [1, 2, 3, 0, 2, 0, 1, 0, 5, 6, 4, 6, 4, 5]
+    assert eid.tolist() == [0, 1, 2, 3, 4, 5, 6, 9, 0, 1, 3, 4, 5, 6]
+    assert ptr.tolist() == [0, 4, 7]
+    assert root_n_id.tolist() == [0, 5]
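Put differently, `root_n_id[g]` is the position of the g-th seed node inside the concatenated `n_id`, while `ptr[g]` is the node offset of ego-graph g. A small illustrative snippet, reusing only the values asserted above rather than re-running the op:

    import torch

    # Values copied from the test assertions above.
    n_id = torch.tensor([0, 1, 2, 3, 0, 1, 2])
    ptr = torch.tensor([0, 4, 7])
    root_n_id = torch.tensor([0, 5])
    seeds = torch.tensor([0, 1])

    # Indexing the concatenated n_id with root_n_id recovers the seed nodes,
    assert n_id[root_n_id].tolist() == seeds.tolist()
    # and subtracting ptr gives each seed's local index within its own ego-graph.
    assert (root_n_id - ptr[:-1]).tolist() == [0, 1]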