Commit 341f959a authored by rusty1s's avatar rusty1s
Browse files

ego sampler

parent 4d49d44e
#include "ego_sample_cpu.h"
#include <ATen/Parallel.h>
#include "utils.h"
// Copies a std::vector<int64_t> into a freshly allocated CPU Long tensor.
//
// Takes the vector by const reference: the original by-value signature copied
// the whole buffer once on the call and a second time in `clone()`. Here
// `from_blob` merely borrows `vec`'s storage (it does not own or free it) and
// `clone()` performs the single required copy, so the temporary view never
// outlives `vec`. The `const_cast` is safe because the borrowed data is only
// read before being cloned.
inline torch::Tensor vec2tensor(const std::vector<int64_t> &vec) {
return torch::from_blob(const_cast<int64_t *>(vec.data()),
{(int64_t)vec.size()}, at::kLong)
.clone();
}
// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`
//
// Samples a `depth`-hop ego network around every seed node in `idx` of the
// CSR graph (`rowptr`, `col`), keeping at most `num_neighbors` neighbors per
// visited node per hop (sampled with or without replacement). Each seed's
// sampled node set is relabeled to consecutive local ids and the induced
// subgraph among those nodes is extracted. The per-seed results are then
// concatenated:
//   * `rowptr`/`col` : block-diagonal CSR over all ego subgraphs (column
//                      indices offset into the concatenated node space),
//   * `n_id`         : original (global) ids of the sampled nodes,
//   * `e_id`         : original edge positions into the input `col`,
//   * `ptr`          : offsets delimiting each seed's slice of `n_id`.
//
// NOTE(review): assumes idx.numel() >= 1 — with an empty `idx` the final
// cumsum lines index out_n_ids[idx.numel() - 1] out of bounds; confirm
// callers never pass an empty seed tensor.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor idx, int64_t depth,
int64_t num_neighbors, bool replace) {
// One result slot per seed; out_rowptrs gets an extra leading `[0]` so the
// final torch::cat yields one valid CSR row-pointer array.
std::vector<torch::Tensor> out_rowptrs(idx.numel() + 1);
std::vector<torch::Tensor> out_cols(idx.numel());
std::vector<torch::Tensor> out_n_ids(idx.numel());
std::vector<torch::Tensor> out_e_ids(idx.numel());
out_rowptrs[0] = torch::zeros({1}, at::kLong);
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto idx_data = idx.data_ptr<int64_t>();
// Seeds are independent, so they are processed in parallel (grain size 1).
// NOTE(review): `rand()` below is shared global state — not thread-safe or
// reproducible across threads; consider a per-thread RNG.
at::parallel_for(0, idx.numel(), 1, [&](int64_t begin, int64_t end) {
int64_t row_start, row_end, row_count, vec_start, vec_end, v, w;
for (int64_t g = begin; g < end; g++) {
// `n_id_set` holds the unique sampled nodes; `n_ids` additionally records
// the expansion order (with possible duplicates) to drive the next hop.
std::set<int64_t> n_id_set;
n_id_set.insert(idx_data[g]);
std::vector<int64_t> n_ids;
n_ids.push_back(idx_data[g]);
// [vec_start, vec_end) delimits the current hop's frontier inside n_ids.
vec_start = 0, vec_end = n_ids.size();
for (int64_t d = 0; d < depth; d++) {
for (int64_t i = vec_start; i < vec_end; i++) {
v = n_ids[i];
row_start = rowptr_data[v], row_end = rowptr_data[v + 1];
row_count = row_end - row_start;
if (row_count <= num_neighbors) {
// Fewer neighbors than requested: take all of them.
for (int64_t e = row_start; e < row_end; e++) {
w = col_data[e];
n_id_set.insert(w);
n_ids.push_back(w);
}
} else if (replace) {
// Sampling with replacement: independent uniform draws.
for (int64_t j = 0; j < num_neighbors; j++) {
w = col_data[row_start + (rand() % row_count)];
n_id_set.insert(w);
n_ids.push_back(w);
}
} else {
// Sampling without replacement, Floyd-style selection of
// `num_neighbors` distinct offsets in [0, row_count).
// NOTE(review): classic Floyd draws from [0, j] inclusive, i.e.
// rand() % (j + 1); this uses rand() % j — verify the intended
// distribution / off-by-one against the reference algorithm.
std::unordered_set<int64_t> perm;
for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
if (!perm.insert(rand() % j).second) {
perm.insert(j);
}
}
for (int64_t j : perm) {
w = col_data[row_start + j];
n_id_set.insert(w);
n_ids.push_back(w);
}
}
}
// Advance the frontier window to the nodes appended during this hop.
vec_start = vec_end;
vec_end = n_ids.size();
}
// Relabel: assign consecutive local ids to the unique sampled nodes
// (std::set iteration yields them in ascending global-id order).
n_ids.clear();
std::unordered_map<int64_t, int64_t> n_id_map;
std::unordered_map<int64_t, int64_t>::iterator iter;
int64_t i = 0;
for (int64_t v : n_id_set) {
n_ids.push_back(v);
n_id_map[v] = i;
i++;
}
// Induced subgraph: keep only edges whose target was sampled, recording
// local CSR structure plus each kept edge's position in the input `col`.
std::vector<int64_t> rowptrs, cols, e_ids;
for (int64_t v : n_ids) {
row_start = rowptr_data[v], row_end = rowptr_data[v + 1];
for (int64_t e = row_start; e < row_end; e++) {
w = col_data[e];
iter = n_id_map.find(w);
if (iter != n_id_map.end()) {
cols.push_back(iter->second);
e_ids.push_back(e);
}
}
// Row-end offset; the leading zero is supplied by out_rowptrs[0].
rowptrs.push_back(cols.size());
}
out_rowptrs[g + 1] = vec2tensor(rowptrs);
out_cols[g] = vec2tensor(cols);
out_n_ids[g] = vec2tensor(n_ids);
out_e_ids[g] = vec2tensor(e_ids);
}
});
// Stitch the per-seed results into one block-diagonal graph: shift seed
// g+1's rowptr by the edge count, and its column indices by the node count,
// of all preceding seeds. `ptr` records where each seed's node slice starts.
auto out_ptr = torch::empty({idx.numel() + 1}, at::kLong);
auto out_ptr_data = out_ptr.data_ptr<int64_t>();
out_ptr_data[0] = 0;
int64_t node_cumsum = 0, edge_cumsum = 0;
for (int64_t g = 0; g < idx.numel() - 1; g++) {
node_cumsum += out_n_ids[g].numel();
edge_cumsum += out_cols[g].numel();
out_rowptrs[g + 2].add_(edge_cumsum);
out_cols[g + 1].add_(node_cumsum);
out_ptr_data[g + 1] = node_cumsum;
}
node_cumsum += out_n_ids[idx.numel() - 1].numel();
out_ptr_data[idx.numel()] = node_cumsum;
return std::make_tuple(torch::cat(out_rowptrs, 0), torch::cat(out_cols, 0),
torch::cat(out_n_ids, 0), torch::cat(out_e_ids, 0),
out_ptr);
}
#pragma once
#include <torch/extension.h>
// Samples `depth`-hop ego networks around the seed nodes `idx` of the CSR
// graph (`rowptr`, `col`), keeping at most `num_neighbors` neighbors per node
// per hop (with or without `replace`ment).
// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr` — see ego_sample_cpu.cpp
// for the layout of the concatenated per-seed results.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor idx, int64_t depth,
int64_t num_neighbors, bool replace);
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "utils.h" #include "utils.h"
// Returns `rowptr`, `col`, `n_id`, `e_id`, // Returns `rowptr`, `col`, `n_id`, `e_id`
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
int64_t num_neighbors, bool replace) { int64_t num_neighbors, bool replace) {
......
#include <Python.h>
#include <torch/script.h>
#include "cpu/ego_sample_cpu.h"
// NOTE(review): on Windows these no-op PyInit_<module> stubs appear to exist
// so the extension DLL exports the symbol the Python loader looks for, even
// though the operator is registered via TorchScript rather than as a real
// Python module — standard pattern across this project's extension TUs.
#ifdef _WIN32
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__ego_sample_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__ego_sample_cpu(void) { return NULL; }
#endif
#endif
// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`
//
// Device dispatcher for the ego-sampling op: CPU tensors are routed to the
// CPU kernel; CUDA tensors are rejected because no CUDA kernel exists.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           torch::Tensor>
ego_k_hop_sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
                     int64_t depth, int64_t num_neighbors, bool replace) {
  // Guard clause: the common CPU path returns immediately.
  if (!rowptr.device().is_cuda()) {
    return ego_k_hop_sample_adj_cpu(rowptr, col, idx, depth, num_neighbors,
                                    replace);
  }
#ifdef WITH_CUDA
  AT_ERROR("No CUDA version supported");
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
// Register the op with the dispatcher so Python reaches it via
// torch.ops.torch_sparse.ego_k_hop_sample_adj.
static auto registry = torch::RegisterOperators().op(
"torch_sparse::ego_k_hop_sample_adj", &ego_k_hop_sample_adj);
import torch
from torch_sparse import SparseTensor
def test_ego_k_hop_sample_adj():
    """Smoke-test the native ego sampler on a small 6-node CSR graph."""
    rowptr = torch.tensor([0, 3, 5, 9, 10, 12, 14])
    row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5])
    col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])
    # Constructing the SparseTensor only validates (row, col) as a graph;
    # its value is deliberately unused.
    _ = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))
    seeds = torch.tensor([2])
    sampler = torch.ops.torch_sparse.ego_k_hop_sample_adj
    # depth=1, num_neighbors=3, without replacement.
    sampler(rowptr, col, seeds, 1, 3, False)
...@@ -9,7 +9,7 @@ suffix = 'cuda' if torch.cuda.is_available() else 'cpu' ...@@ -9,7 +9,7 @@ suffix = 'cuda' if torch.cuda.is_available() else 'cpu'
for library in [ for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw', '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
'_saint', '_sample', '_relabel' '_saint', '_sample', '_ego_sample', '_relabel'
]: ]:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec( torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
f'{library}_{suffix}', [osp.dirname(__file__)]).origin) f'{library}_{suffix}', [osp.dirname(__file__)]).origin)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment