Unverified commit 79535f30, authored by Rex Ying, committed by GitHub
Browse files

Temporal sampling (#202)



* temporal sample

* hetero_neighbor_sample temporal

* api

* remove redundant function

* refactor

* remove catch output

* debug compile

* testing env

* debug

* debug

* debug

* debug

* debug

* revert

* node time data should not be const

* Update csrc/cpu/neighbor_sample_cpu.cpp
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* Update csrc/cpu/neighbor_sample_cpu.cpp
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* Update csrc/cpu/neighbor_sample_cpu.cpp
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* Update csrc/cpu/neighbor_sample_cpu.cpp
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* temporal template

* compilation fixes
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
parent eafcfe0a
......@@ -114,16 +114,34 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
from_vector<int64_t>(cols), from_vector<int64_t>(edges));
}
template <bool replace, bool directed>
// Decides whether `sampled_node` (a node of type `src_node_type`) may be
// sampled as a neighbor of a destination whose root timestamp is `dst_time`.
//
// Returns true when the source node's timestamp is not greater than
// `dst_time`, i.e. the sampled neighbor does not come from the future of the
// computation-tree root. If `node_time_dict` has no entry for
// `src_node_type`, the constraint does not apply and true is returned, so
// that node type falls back to ordinary sampling.
//
// NOTE(review): assumes the stored tensor is a contiguous int64 tensor that
// can be indexed by `sampled_node`; no bounds check is performed here.
bool satisfy_time_constraint(
    const c10::Dict<node_t, torch::Tensor> &node_time_dict,
    const std::string &src_node_type, const int64_t &dst_time,
    const int64_t &sampled_node) {
  // Look the node type up without exceptions. This runs once per candidate
  // edge, and c10::Dict::at reports a missing key by throwing
  // std::out_of_range, which a `catch (int)` handler would never match.
  const auto entry = node_time_dict.find(src_node_type);
  if (entry == node_time_dict.end()) {
    // No timestamps for this node type: nothing to enforce.
    return true;
  }

  // Keep the tensor handle alive while its data pointer is being read.
  const torch::Tensor src_time_tensor = entry->value();
  const auto *src_time = src_time_tensor.data_ptr<int64_t>();

  // src -> dst is admissible only if src is not newer than the root time.
  return src_time[sampled_node] <= dst_time;
}
template <bool replace, bool directed, bool temporal>
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample(const vector<node_t> &node_types,
const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops) {
const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops,
const c10::Dict<node_t, torch::Tensor> &node_time_dict) {
//bool temporal = (!node_time_dict.empty());
// Create a mapping to convert single string relations to edge type triplets:
unordered_map<rel_t, edge_t> to_edge_type;
......@@ -131,13 +149,6 @@ hetero_sample(const vector<node_t> &node_types,
to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;
// Initialize some data structures for the sampling process:
unordered_map<node_t, vector<int64_t>> samples_dict;
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
for (const auto &node_type : node_types) {
samples_dict[node_type];
to_local_node_dict[node_type];
}
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
for (const auto &kv : colptr_dict) {
const auto &rel_type = kv.key();
......@@ -146,18 +157,40 @@ hetero_sample(const vector<node_t> &node_types,
edges_dict[rel_type];
}
unordered_map<node_t, vector<int64_t>> samples_dict;
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
// For every sampled node, the timestamp of the root (center) node whose
// neighborhood it was sampled into. In non-temporal sampling each node type
// maps to an empty vector.
unordered_map<node_t, vector<int64_t>> root_time_dict;
for (const auto &node_type : node_types) {
samples_dict[node_type];
to_local_node_dict[node_type];
root_time_dict[node_type];
}
// Add the input nodes to the output nodes:
for (const auto &kv : input_node_dict) {
const auto &node_type = kv.key();
const torch::Tensor &input_node = kv.value();
const auto *input_node_data = input_node.data_ptr<int64_t>();
// Placeholder pointer; re-pointed at this node type's timestamps below when
// temporal == true.
auto *node_time_data = input_node.data_ptr<int64_t>();
// root_time[i] stores the timestamp of the computation tree root
// of the node samples[i]
if (temporal) {
node_time_data = node_time_dict.at(node_type).data_ptr<int64_t>();
}
auto &samples = samples_dict.at(node_type);
auto &to_local_node = to_local_node_dict.at(node_type);
auto &root_time = root_time_dict.at(node_type);
for (int64_t i = 0; i < input_node.numel(); i++) {
const auto &v = input_node_data[i];
samples.push_back(v);
to_local_node.insert({v, i});
if (temporal) {
root_time.push_back(node_time_data[v]);
}
}
}
......@@ -187,8 +220,17 @@ hetero_sample(const vector<node_t> &node_types,
const auto &begin = slice_dict.at(dst_node_type).first;
const auto &end = slice_dict.at(dst_node_type).second;
if (begin == end){
continue;
}
// for temporal sampling, sampled src node cannot have timestamp greater
// than its corresponding dst_root_time
const auto &dst_root_time = root_time_dict.at(dst_node_type);
auto &src_root_time = root_time_dict.at(src_node_type);
for (int64_t i = begin; i < end; i++) {
const auto &w = dst_samples[i];
const auto &dst_time = dst_root_time[i];
const auto &col_start = colptr_data[w];
const auto &col_end = colptr_data[w + 1];
const auto col_count = col_end - col_start;
......@@ -197,11 +239,22 @@ hetero_sample(const vector<node_t> &node_types,
continue;
if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
// select all neighbors
for (int64_t offset = col_start; offset < col_end; offset++) {
const int64_t &v = row_data[offset];
bool time_constraint = true;
if (temporal) {
time_constraint = satisfy_time_constraint(
node_time_dict, src_node_type, dst_time, v);
}
if (!time_constraint)
continue;
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
if (res.second) {
src_samples.push_back(v);
if (temporal)
src_root_time.push_back(dst_time);
}
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
......@@ -209,19 +262,33 @@ hetero_sample(const vector<node_t> &node_types,
}
}
} else if (replace) {
for (int64_t j = 0; j < num_samples; j++) {
// sample with replacement
int64_t num_neighbors = 0;
while (num_neighbors < num_samples) {
const int64_t offset = col_start + uniform_randint(col_count);
const int64_t &v = row_data[offset];
bool time_constraint = true;
if (temporal) {
time_constraint = satisfy_time_constraint(
node_time_dict, src_node_type, dst_time, v);
}
if (!time_constraint)
continue;
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
if (res.second) {
src_samples.push_back(v);
if (temporal)
src_root_time.push_back(dst_time);
}
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
num_neighbors += 1;
}
} else {
// sample without replacement
unordered_set<int64_t> rnd_indices;
for (int64_t j = col_count - num_samples; j < col_count; j++) {
int64_t rnd = uniform_randint(j);
......@@ -231,9 +298,19 @@ hetero_sample(const vector<node_t> &node_types,
}
const int64_t offset = col_start + rnd;
const int64_t &v = row_data[offset];
bool time_constraint = true;
if (temporal) {
time_constraint = satisfy_time_constraint(
node_time_dict, src_node_type, dst_time, v);
}
if (!time_constraint)
continue;
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
if (res.second) {
src_samples.push_back(v);
if (temporal)
src_root_time.push_back(dst_time);
}
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
......@@ -290,6 +367,27 @@ hetero_sample(const vector<node_t> &node_types,
from_vector<rel_t, int64_t>(edges_dict));
}
// Non-temporal entry point for heterogeneous neighbor sampling.
//
// Forwards all arguments unchanged to hetero_sample with the `temporal`
// template flag fixed to false, and supplies a default-constructed node-time
// dictionary for hetero_sample's trailing parameter.
template <bool replace, bool directed>
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample_random(
    const vector<node_t> &node_types, const vector<edge_t> &edge_types,
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict,
    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
    const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
    const int64_t num_hops) {
  // Stand-in for the timestamp dictionary, passed with temporal == false.
  c10::Dict<node_t, torch::Tensor> no_node_time;

  return hetero_sample<replace, directed, /*temporal=*/false>(
      node_types, edge_types, colptr_dict, row_dict, input_node_dict,
      num_neighbors_dict, num_hops, no_node_time);
}
} // namespace
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
......@@ -320,20 +418,58 @@ hetero_neighbor_sample_cpu(
const int64_t num_hops, const bool replace, const bool directed) {
if (replace && directed) {
return hetero_sample<true, true>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
return hetero_sample_random<true, true>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
} else if (replace && !directed) {
return hetero_sample_random<true, false>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
} else if (!replace && directed) {
return hetero_sample_random<false, true>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
} else {
return hetero_sample_random<false, false>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
}
}
// CPU implementation entry point for temporal heterogeneous neighbor
// sampling.
//
// Maps the runtime flags `replace` and `directed` onto the matching
// compile-time instantiation of hetero_sample with `temporal` fixed to true,
// and forwards `node_time_dict` as the per-node-type timestamp source.
// Every branch issues exactly one call, and every call passes
// `node_time_dict`.
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_temporal_sample_cpu(
    const vector<node_t> &node_types, const vector<edge_t> &edge_types,
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict,
    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
    const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
    const c10::Dict<node_t, torch::Tensor> &node_time_dict,
    const int64_t num_hops, const bool replace, const bool directed) {

  if (replace && directed) {
    return hetero_sample<true, true, true>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, num_hops, node_time_dict);
  } else if (replace && !directed) {
    return hetero_sample<true, false, true>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, num_hops, node_time_dict);
  } else if (!replace && directed) {
    return hetero_sample<false, true, true>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, num_hops, node_time_dict);
  } else {
    return hetero_sample<false, false, true>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, num_hops, node_time_dict);
  }
}
......@@ -22,3 +22,15 @@ hetero_neighbor_sample_cpu(
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const bool replace, const bool directed);
// Declaration of the CPU entry point for temporal heterogeneous neighbor
// sampling.
//
// Takes the same arguments as hetero_neighbor_sample_cpu plus
// `node_time_dict`, a dictionary from node type to a tensor, inserted before
// `num_hops`. Returns one node-type-keyed tensor dictionary followed by three
// relation-keyed tensor dictionaries.
// NOTE(review): presumably `node_time_dict` holds one int64 timestamp per
// node of each type — confirm against the implementation.
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_temporal_sample_cpu(
const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops, const bool replace, const bool directed);
\ No newline at end of file
......@@ -40,7 +40,24 @@ hetero_neighbor_sample(
num_neighbors_dict, num_hops, replace, directed);
}
// Operator-level wrapper for temporal heterogeneous neighbor sampling.
//
// Hands every argument, in the same order, to
// hetero_neighbor_temporal_sample_cpu and returns its result unchanged.
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
           c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_temporal_sample(
    const std::vector<node_t> &node_types,
    const std::vector<edge_t> &edge_types,
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict,
    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
    const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
    const c10::Dict<node_t, torch::Tensor> &node_time_dict,
    const int64_t num_hops, const bool replace, const bool directed) {
  // Only a CPU implementation is dispatched to from here.
  auto sampled = hetero_neighbor_temporal_sample_cpu(
      node_types,
      edge_types,
      colptr_dict,
      row_dict,
      input_node_dict,
      num_neighbors_dict,
      node_time_dict,
      num_hops,
      replace,
      directed);
  return sampled;
}
static auto registry =
torch::RegisterOperators()
.op("torch_sparse::neighbor_sample", &neighbor_sample)
.op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample);
.op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample)
.op("torch_sparse::hetero_neighbor_temporal_sample", &hetero_neighbor_temporal_sample);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment