Commit 014c4bae authored by rusty1s

hetero neighbor sampling

parent 9532032e
csrc/cpu/hgt_sample_cpu.cpp
@@ -102,7 +102,7 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
               const int64_t num_hops) {
  // Create a mapping to convert single string relations to edge type triplets:
  unordered_map<rel_t, edge_t> to_edge_type;
  for (const auto &kv : colptr_dict) {
    const auto &rel_type = kv.key();
    to_edge_type[rel_type] = split(rel_type);
...
#include "neighbor_sample_cpu.h"
#include "utils.h"
using namespace std;
namespace {
template <bool replace, bool directed>
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample(const torch::Tensor &colptr, const torch::Tensor &row,
const torch::Tensor &input_node, const vector<int64_t> num_neighbors) {
// Initialize some data structures for the sampling process:
vector<int64_t> samples;
unordered_map<int64_t, int64_t> to_local_node;
auto *colptr_data = colptr.data_ptr<int64_t>();
auto *row_data = row.data_ptr<int64_t>();
auto *input_node_data = input_node.data_ptr<int64_t>();
for (int64_t i = 0; i < input_node.numel(); i++) {
const auto &v = input_node_data[i];
samples.push_back(v);
to_local_node.insert({v, i});
}
vector<int64_t> rows, cols, edges;
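  // [begin, end) always spans the nodes added in the previous hop, i.e. the
  // frontier whose neighborhoods are expanded next: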
int64_t begin = 0, end = samples.size();
for (int64_t ell = 0; ell < (int64_t)num_neighbors.size(); ell++) {
const auto &num_samples = num_neighbors[ell];
for (int64_t i = begin; i < end; i++) {
const auto &w = samples[i];
const auto &col_start = colptr_data[w];
const auto &col_end = colptr_data[w + 1];
const auto col_count = col_end - col_start;
if (col_count == 0)
continue;
if (replace) {
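        // Sample with replacement: `num_samples` independent uniform draws
        // from the column of `w`: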
for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = col_start + rand() % col_count;
const int64_t &v = row_data[offset];
const auto res = to_local_node.insert({v, samples.size()});
if (res.second)
samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
} else if (num_samples >= col_count) {
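        // The column holds at most `num_samples` neighbors: take all of them: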
for (int64_t offset = col_start; offset < col_end; offset++) {
const int64_t &v = row_data[offset];
const auto res = to_local_node.insert({v, samples.size()});
if (res.second)
samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
} else {
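        // Sample without replacement via Robert Floyd's algorithm: for each j
        // in [col_count - num_samples, col_count), draw a random index below j
        // and fall back to j itself on collision, yielding `num_samples`
        // distinct offsets: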
unordered_set<int64_t> rnd_indices;
for (int64_t j = col_count - num_samples; j < col_count; j++) {
int64_t rnd = rand() % j;
if (!rnd_indices.insert(rnd).second) {
rnd = j;
rnd_indices.insert(j);
}
const int64_t offset = col_start + rnd;
const int64_t &v = row_data[offset];
const auto res = to_local_node.insert({v, samples.size()});
if (res.second)
samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
}
}
begin = end, end = samples.size();
}
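  // If `directed` is false, construct the full subgraph among all sampled
  // nodes rather than returning only the edges traversed during sampling: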
if (!directed) {
unordered_map<int64_t, int64_t>::iterator iter;
for (int64_t i = 0; i < (int64_t)samples.size(); i++) {
const auto &w = samples[i];
const auto &col_start = colptr_data[w];
const auto &col_end = colptr_data[w + 1];
for (int64_t offset = col_start; offset < col_end; offset++) {
const auto &v = row_data[offset];
iter = to_local_node.find(v);
if (iter != to_local_node.end()) {
rows.push_back(iter->second);
cols.push_back(i);
edges.push_back(offset);
}
}
}
}
return make_tuple(from_vector<int64_t>(samples), from_vector<int64_t>(rows),
from_vector<int64_t>(cols), from_vector<int64_t>(edges));
}
template <bool replace, bool directed>
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample(const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops) {
// Create a mapping to convert single string relations to edge type triplets:
unordered_map<rel_t, edge_t> to_edge_type;
for (const auto &k : edge_types)
to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;
// Initialize some data structures for the sampling process:
unordered_map<node_t, vector<int64_t>> samples_dict;
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
for (const auto &k : node_types) {
samples_dict[k];
to_local_node_dict[k];
}
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
for (const auto &kv : colptr_dict) {
const auto &rel_type = kv.key();
rows_dict[rel_type];
cols_dict[rel_type];
edges_dict[rel_type];
}
// Add the input nodes to the output nodes:
for (const auto &kv : input_node_dict) {
const auto &node_type = kv.key();
const auto &input_node = kv.value();
const auto *input_node_data = input_node.data_ptr<int64_t>();
auto &samples = samples_dict.at(node_type);
auto &to_local_node = to_local_node_dict.at(node_type);
for (int64_t i = 0; i < input_node.numel(); i++) {
const auto &v = input_node_data[i];
samples.push_back(v);
to_local_node.insert({v, i});
}
}
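  // Track, for each node type, the slice [begin, end) of nodes that were
  // added in the previous hop: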
unordered_map<node_t, pair<int64_t, int64_t>> slice_dict;
for (const auto &kv : samples_dict)
slice_dict[kv.first] = {0, kv.second.size()};
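  // Sample neighbors hop-by-hop, handling each relation type independently: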
for (int64_t ell = 0; ell < num_hops; ell++) {
for (const auto &kv : num_neighbors_dict) {
const auto &rel_type = kv.key();
const auto &edge_type = to_edge_type[rel_type];
const auto &src_node_type = get<0>(edge_type);
const auto &dst_node_type = get<2>(edge_type);
const auto &num_samples = kv.value()[ell];
const auto &dst_samples = samples_dict.at(dst_node_type);
auto &src_samples = samples_dict.at(src_node_type);
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
const auto *colptr_data = colptr_dict.at(rel_type).data_ptr<int64_t>();
const auto *row_data = row_dict.at(rel_type).data_ptr<int64_t>();
auto &rows = rows_dict.at(rel_type);
auto &cols = cols_dict.at(rel_type);
auto &edges = edges_dict.at(rel_type);
const auto &begin = slice_dict.at(dst_node_type).first;
const auto &end = slice_dict.at(dst_node_type).second;
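      // Expand the previous hop's destination nodes via the same three
      // sampling strategies as in the homogeneous case above: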
for (int64_t i = begin; i < end; i++) {
const auto &w = dst_samples[i];
const auto &col_start = colptr_data[w];
const auto &col_end = colptr_data[w + 1];
const auto col_count = col_end - col_start;
if (col_count == 0)
continue;
if (replace) {
for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = col_start + rand() % col_count;
const int64_t &v = row_data[offset];
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
} else if (num_samples >= col_count) {
for (int64_t offset = col_start; offset < col_end; offset++) {
const int64_t &v = row_data[offset];
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
} else {
unordered_set<int64_t> rnd_indices;
for (int64_t j = col_count - num_samples; j < col_count; j++) {
int64_t rnd = rand() % j;
if (!rnd_indices.insert(rnd).second) {
rnd = j;
rnd_indices.insert(j);
}
const int64_t offset = col_start + rnd;
const int64_t &v = row_data[offset];
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
}
}
}
for (const auto &kv : samples_dict) {
slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
}
}
if (!directed) { // Construct the subgraph among the sampled nodes:
unordered_map<int64_t, int64_t>::iterator iter;
for (const auto &kv : colptr_dict) {
const auto &rel_type = kv.key();
const auto &edge_type = to_edge_type[rel_type];
const auto &src_node_type = get<0>(edge_type);
const auto &dst_node_type = get<2>(edge_type);
const auto &dst_samples = samples_dict.at(dst_node_type);
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
const auto *colptr_data = kv.value().data_ptr<int64_t>();
const auto *row_data = row_dict.at(rel_type).data_ptr<int64_t>();
auto &rows = rows_dict.at(rel_type);
auto &cols = cols_dict.at(rel_type);
auto &edges = edges_dict.at(rel_type);
for (int64_t i = 0; i < (int64_t)dst_samples.size(); i++) {
const auto &w = dst_samples[i];
const auto &col_start = colptr_data[w];
const auto &col_end = colptr_data[w + 1];
for (int64_t offset = col_start; offset < col_end; offset++) {
const auto &v = row_data[offset];
iter = to_local_src_node.find(v);
if (iter != to_local_src_node.end()) {
rows.push_back(iter->second);
cols.push_back(i);
edges.push_back(offset);
}
}
}
}
}
return make_tuple(from_vector<node_t, int64_t>(samples_dict),
from_vector<rel_t, int64_t>(rows_dict),
from_vector<rel_t, int64_t>(cols_dict),
from_vector<rel_t, int64_t>(edges_dict));
}
} // namespace
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
const torch::Tensor &input_node,
const vector<int64_t> num_neighbors, const bool replace,
const bool directed) {
if (replace && directed) {
return sample<true, true>(colptr, row, input_node, num_neighbors);
} else if (replace && !directed) {
return sample<true, false>(colptr, row, input_node, num_neighbors);
} else if (!replace && directed) {
return sample<false, true>(colptr, row, input_node, num_neighbors);
} else {
return sample<false, false>(colptr, row, input_node, num_neighbors);
}
}
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample_cpu(
const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const bool replace, const bool directed) {
if (replace && directed) {
return hetero_sample<true, true>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
} else if (replace && !directed) {
return hetero_sample<true, false>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
} else if (!replace && directed) {
return hetero_sample<false, true>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
} else {
return hetero_sample<false, false>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
}
}
csrc/cpu/neighbor_sample_cpu.h
#pragma once
#include <torch/extension.h>
typedef std::string node_t;
typedef std::tuple<std::string, std::string, std::string> edge_t;
typedef std::string rel_t;
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
const torch::Tensor &input_node,
const std::vector<int64_t> num_neighbors,
const bool replace, const bool directed);
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample_cpu(
const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const bool replace, const bool directed);
csrc/cpu/utils.h
@@ -25,10 +25,23 @@ inline torch::Tensor from_vector(const std::vector<scalar_t> &vec,
  return inplace ? out : out.clone();
}
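// Converts a dictionary of vectors into a dictionary of tensors: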
template <typename key_t, typename scalar_t>
inline c10::Dict<key_t, torch::Tensor>
from_vector(const std::unordered_map<key_t, std::vector<scalar_t>> &vec_dict,
bool inplace = false) {
c10::Dict<key_t, torch::Tensor> out_dict;
for (const auto &kv : vec_dict)
out_dict.insert(kv.first, from_vector<scalar_t>(kv.second, inplace));
return out_dict;
}
inline torch::Tensor
choice(int64_t population, int64_t num_samples, bool replace = false,
       torch::optional<torch::Tensor> weight = torch::nullopt) {

  if (population == 0 || num_samples == 0)
    return torch::empty({0}, at::kLong);

  if (!replace && num_samples >= population)
    return torch::arange(population, at::kLong);
@@ -47,18 +60,53 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
  // Sample without replacement via Robert Floyd algorithm:
  // https://www.nowherenearithaca.com/2013/05/
  // robert-floyds-tiny-and-beautiful.html
  const auto out = torch::empty(num_samples, at::kLong);
  auto *out_data = out.data_ptr<int64_t>();

  std::unordered_set<int64_t> samples;
  for (int64_t i = population - num_samples; i < population; i++) {
    int64_t sample = rand() % i;
    if (!samples.insert(sample).second) {
      sample = i;
      samples.insert(sample);
    }
    out_data[i - population + num_samples] = sample;
  }
  return out;
}
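// Uniformly samples `num_samples` elements of `idx_data` (with or without
// replacement), registering each previously unseen node in `to_local_node`: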
template <bool replace>
inline void
uniform_choice(const int64_t population, const int64_t num_samples,
const int64_t *idx_data, std::vector<int64_t> *samples,
std::unordered_map<int64_t, int64_t> *to_local_node) {
if (population == 0 || num_samples == 0)
return;
if (replace) {
for (int64_t i = 0; i < num_samples; i++) {
const int64_t &v = idx_data[rand() % population];
if (to_local_node->insert({v, samples->size()}).second)
samples->push_back(v);
}
} else if (num_samples >= population) {
for (int64_t i = 0; i < population; i++) {
const int64_t &v = idx_data[i];
if (to_local_node->insert({v, samples->size()}).second)
samples->push_back(v);
}
} else {
std::unordered_set<int64_t> indices;
for (int64_t i = population - num_samples; i < population; i++) {
int64_t j = rand() % i;
if (!indices.insert(j).second) {
j = i;
indices.insert(j);
}
const int64_t &v = idx_data[j];
if (to_local_node->insert({v, samples->size()}).second)
samples->push_back(v);
}
}
}
csrc/hgt_sample.cpp
@@ -11,6 +11,7 @@ PyMODINIT_FUNC PyInit__hgt_sample_cpu(void) { return NULL; }
#endif
#endif

// Returns 'output_node_dict', 'row_dict', 'col_dict', 'output_edge_dict'
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
           c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hgt_sample(const c10::Dict<std::string, torch::Tensor> &colptr_dict,
...
csrc/neighbor_sample.cpp
#include <Python.h>
#include <torch/script.h>
#include "cpu/neighbor_sample_cpu.h"
#ifdef _WIN32
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__neighbor_sample_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__neighbor_sample_cpu(void) { return NULL; }
#endif
#endif
// Returns 'output_node', 'row', 'col', 'output_edge'
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row,
const torch::Tensor &input_node,
const std::vector<int64_t> num_neighbors, const bool replace,
const bool directed) {
return neighbor_sample_cpu(colptr, row, input_node, num_neighbors, replace,
directed);
}
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample(
const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const bool replace, const bool directed) {
return hetero_neighbor_sample_cpu(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, replace, directed);
}
static auto registry =
torch::RegisterOperators()
.op("torch_sparse::neighbor_sample", &neighbor_sample)
.op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample);
torch_sparse/__init__.py
@@ -9,7 +9,8 @@ suffix = 'cuda' if torch.cuda.is_available() else 'cpu'
for library in [
    '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
    '_saint', '_sample', '_ego_sample', '_hgt_sample', '_neighbor_sample',
    '_relabel'
]:
    torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
        f'{library}_{suffix}', [osp.dirname(__file__)]).origin)
...