Unverified Commit caf7ddde authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Temporal `neighbor_sample` adjustments (#225)

* version up

* formatting

* fix

* reset

* revert

* temporal neighbor sampling adjustments

* typo
parent 15c56351
...@@ -114,16 +114,13 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row, ...@@ -114,16 +114,13 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
from_vector<int64_t>(cols), from_vector<int64_t>(edges)); from_vector<int64_t>(cols), from_vector<int64_t>(edges));
} }
bool satisfy_time_constraint( inline bool satisfy_time(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const node_t &src_node_type, const int64_t &dst_time, const node_t &src_node_type, const int64_t &dst_time,
const int64_t &src_node) { const int64_t &src_node) {
// whether src -> dst obeys the time constraint try { // Check whether src -> dst obeys the time constraint:
try {
auto src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>(); auto src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
return dst_time < src_time[src_node]; return dst_time < src_time[src_node];
} catch (int err) { } catch (int err) { // If no time is given, fall back to normal sampling:
// if the node type does not have timestamp, fall back to normal sampling
return true; return true;
} }
} }
...@@ -137,8 +134,9 @@ hetero_sample(const vector<node_t> &node_types, ...@@ -137,8 +134,9 @@ hetero_sample(const vector<node_t> &node_types,
const c10::Dict<rel_t, torch::Tensor> &row_dict, const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict, const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict, const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict) { const int64_t num_hops) {
// Create a mapping to convert single string relations to edge type triplets: // Create a mapping to convert single string relations to edge type triplets:
unordered_map<rel_t, edge_t> to_edge_type; unordered_map<rel_t, edge_t> to_edge_type;
for (const auto &k : edge_types) for (const auto &k : edge_types)
...@@ -155,8 +153,6 @@ hetero_sample(const vector<node_t> &node_types, ...@@ -155,8 +153,6 @@ hetero_sample(const vector<node_t> &node_types,
unordered_map<node_t, vector<int64_t>> samples_dict; unordered_map<node_t, vector<int64_t>> samples_dict;
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict; unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
// The timestamp of the center node whose neighborhood that the sampled node
// belongs to. It maps node_type to empty vector in non-temporal sampling.
unordered_map<node_t, vector<int64_t>> root_time_dict; unordered_map<node_t, vector<int64_t>> root_time_dict;
for (const auto &node_type : node_types) { for (const auto &node_type : node_types) {
samples_dict[node_type]; samples_dict[node_type];
...@@ -169,10 +165,7 @@ hetero_sample(const vector<node_t> &node_types, ...@@ -169,10 +165,7 @@ hetero_sample(const vector<node_t> &node_types,
const auto &node_type = kv.key(); const auto &node_type = kv.key();
const torch::Tensor &input_node = kv.value(); const torch::Tensor &input_node = kv.value();
const auto *input_node_data = input_node.data_ptr<int64_t>(); const auto *input_node_data = input_node.data_ptr<int64_t>();
// dummy value. will be reset to root time if is_temporal==true
int64_t *node_time_data; int64_t *node_time_data;
// root_time[i] stores the timestamp of the computation tree root
// of the node samples[i]
if (temporal) { if (temporal) {
torch::Tensor node_time = node_time_dict.at(node_type); torch::Tensor node_time = node_time_dict.at(node_type);
node_time_data = node_time.data_ptr<int64_t>(); node_time_data = node_time.data_ptr<int64_t>();
...@@ -185,11 +178,10 @@ hetero_sample(const vector<node_t> &node_types, ...@@ -185,11 +178,10 @@ hetero_sample(const vector<node_t> &node_types,
const auto &v = input_node_data[i]; const auto &v = input_node_data[i];
samples.push_back(v); samples.push_back(v);
to_local_node.insert({v, i}); to_local_node.insert({v, i});
if (temporal) { if (temporal)
root_time.push_back(node_time_data[v]); root_time.push_back(node_time_data[v]);
} }
} }
}
unordered_map<node_t, pair<int64_t, int64_t>> slice_dict; unordered_map<node_t, pair<int64_t, int64_t>> slice_dict;
for (const auto &kv : samples_dict) for (const auto &kv : samples_dict)
...@@ -217,11 +209,12 @@ hetero_sample(const vector<node_t> &node_types, ...@@ -217,11 +209,12 @@ hetero_sample(const vector<node_t> &node_types,
const auto &begin = slice_dict.at(dst_node_type).first; const auto &begin = slice_dict.at(dst_node_type).first;
const auto &end = slice_dict.at(dst_node_type).second; const auto &end = slice_dict.at(dst_node_type).second;
if (begin == end) {
if (begin == end)
continue; continue;
}
// for temporal sampling, sampled src node cannot have timestamp greater // For temporal sampling, sampled nodes cannot have a timestamp greater
// than its corresponding dst_root_time // than the timestamp of the root nodes.
const auto &dst_root_time = root_time_dict.at(dst_node_type); const auto &dst_root_time = root_time_dict.at(dst_node_type);
auto &src_root_time = root_time_dict.at(src_node_type); auto &src_root_time = root_time_dict.at(src_node_type);
...@@ -236,16 +229,13 @@ hetero_sample(const vector<node_t> &node_types, ...@@ -236,16 +229,13 @@ hetero_sample(const vector<node_t> &node_types,
continue; continue;
if ((num_samples < 0) || (!replace && (num_samples >= col_count))) { if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
// select all neighbors // Select all neighbors:
for (int64_t offset = col_start; offset < col_end; offset++) { for (int64_t offset = col_start; offset < col_end; offset++) {
const int64_t &v = row_data[offset]; const int64_t &v = row_data[offset];
bool time_constraint = true;
if (temporal) { if (temporal) {
time_constraint = satisfy_time_constraint( if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
node_time_dict, src_node_type, dst_time, v);
}
if (!time_constraint)
continue; continue;
}
const auto res = to_local_src_node.insert({v, src_samples.size()}); const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) { if (res.second) {
src_samples.push_back(v); src_samples.push_back(v);
...@@ -259,18 +249,16 @@ hetero_sample(const vector<node_t> &node_types, ...@@ -259,18 +249,16 @@ hetero_sample(const vector<node_t> &node_types,
} }
} }
} else if (replace) { } else if (replace) {
// sample with replacement // Sample with replacement:
int64_t num_neighbors = 0; int64_t num_neighbors = 0;
while (num_neighbors < num_samples) { while (num_neighbors < num_samples) {
const int64_t offset = col_start + uniform_randint(col_count); const int64_t offset = col_start + uniform_randint(col_count);
const int64_t &v = row_data[offset]; const int64_t &v = row_data[offset];
bool time_constraint = true;
if (temporal) { if (temporal) {
time_constraint = satisfy_time_constraint( // TODO Infinity loop if no neighbor satisfies time constraint:
node_time_dict, src_node_type, dst_time, v); if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
}
if (!time_constraint)
continue; continue;
}
const auto res = to_local_src_node.insert({v, src_samples.size()}); const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) { if (res.second) {
src_samples.push_back(v); src_samples.push_back(v);
...@@ -285,7 +273,7 @@ hetero_sample(const vector<node_t> &node_types, ...@@ -285,7 +273,7 @@ hetero_sample(const vector<node_t> &node_types,
num_neighbors += 1; num_neighbors += 1;
} }
} else { } else {
// sample without replacement // Sample without replacement:
unordered_set<int64_t> rnd_indices; unordered_set<int64_t> rnd_indices;
for (int64_t j = col_count - num_samples; j < col_count; j++) { for (int64_t j = col_count - num_samples; j < col_count; j++) {
int64_t rnd = uniform_randint(j); int64_t rnd = uniform_randint(j);
...@@ -295,13 +283,10 @@ hetero_sample(const vector<node_t> &node_types, ...@@ -295,13 +283,10 @@ hetero_sample(const vector<node_t> &node_types,
} }
const int64_t offset = col_start + rnd; const int64_t offset = col_start + rnd;
const int64_t &v = row_data[offset]; const int64_t &v = row_data[offset];
bool time_constraint = true;
if (temporal) { if (temporal) {
time_constraint = satisfy_time_constraint( if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
node_time_dict, src_node_type, dst_time, v);
}
if (!time_constraint)
continue; continue;
}
const auto res = to_local_src_node.insert({v, src_samples.size()}); const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) { if (res.second) {
src_samples.push_back(v); src_samples.push_back(v);
...@@ -364,22 +349,6 @@ hetero_sample(const vector<node_t> &node_types, ...@@ -364,22 +349,6 @@ hetero_sample(const vector<node_t> &node_types,
from_vector<rel_t, int64_t>(edges_dict)); from_vector<rel_t, int64_t>(edges_dict));
} }
template <bool replace, bool directed>
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample_random(
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops) {
c10::Dict<node_t, torch::Tensor> empty_dict;
return hetero_sample<replace, directed, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, empty_dict);
}
} // namespace } // namespace
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
...@@ -409,28 +378,30 @@ hetero_neighbor_sample_cpu( ...@@ -409,28 +378,30 @@ hetero_neighbor_sample_cpu(
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict, const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const bool replace, const bool directed) { const int64_t num_hops, const bool replace, const bool directed) {
c10::Dict<node_t, torch::Tensor> node_time_dict; // Empty dictionary.
if (replace && directed) { if (replace && directed) {
return hetero_sample_random<true, true>(node_types, edge_types, colptr_dict, return hetero_sample<true, true, false>(
row_dict, input_node_dict, node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops); num_neighbors_dict, node_time_dict, num_hops);
} else if (replace && !directed) { } else if (replace && !directed) {
return hetero_sample_random<true, false>( return hetero_sample<true, false, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict, node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops); num_neighbors_dict, node_time_dict, num_hops);
} else if (!replace && directed) { } else if (!replace && directed) {
return hetero_sample_random<false, true>( return hetero_sample<false, true, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict, node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops); num_neighbors_dict, node_time_dict, num_hops);
} else { } else {
return hetero_sample_random<false, false>( return hetero_sample<false, false, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict, node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops); num_neighbors_dict, node_time_dict, num_hops);
} }
} }
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>, tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>> c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_temporal_sample_cpu( hetero_temporal_neighbor_sample_cpu(
const vector<node_t> &node_types, const vector<edge_t> &edge_types, const vector<node_t> &node_types, const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict, const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict, const c10::Dict<rel_t, torch::Tensor> &row_dict,
...@@ -442,18 +413,18 @@ hetero_neighbor_temporal_sample_cpu( ...@@ -442,18 +413,18 @@ hetero_neighbor_temporal_sample_cpu(
if (replace && directed) { if (replace && directed) {
return hetero_sample<true, true, true>( return hetero_sample<true, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict, node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict); num_neighbors_dict, node_time_dict, num_hops);
} else if (replace && !directed) { } else if (replace && !directed) {
return hetero_sample<true, false, true>( return hetero_sample<true, false, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict, node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict); num_neighbors_dict, node_time_dict, num_hops);
} else if (!replace && directed) { } else if (!replace && directed) {
return hetero_sample<false, true, true>( return hetero_sample<false, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict, node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict); num_neighbors_dict, node_time_dict, num_hops);
} else { } else {
return hetero_sample<false, false, true>( return hetero_sample<false, false, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict, node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict); num_neighbors_dict, node_time_dict, num_hops);
} }
} }
...@@ -25,7 +25,7 @@ hetero_neighbor_sample_cpu( ...@@ -25,7 +25,7 @@ hetero_neighbor_sample_cpu(
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>, std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>> c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_temporal_sample_cpu( hetero_temporal_neighbor_sample_cpu(
const std::vector<node_t> &node_types, const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types, const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict, const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
......
...@@ -16,7 +16,8 @@ PyMODINIT_FUNC PyInit__neighbor_sample_cpu(void) { return NULL; } ...@@ -16,7 +16,8 @@ PyMODINIT_FUNC PyInit__neighbor_sample_cpu(void) { return NULL; }
#endif #endif
// Returns 'output_node', 'row', 'col', 'output_edge' // Returns 'output_node', 'row', 'col', 'output_edge'
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> SPARSE_API
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row, neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row,
const torch::Tensor &input_node, const torch::Tensor &input_node,
const std::vector<int64_t> num_neighbors, const bool replace, const std::vector<int64_t> num_neighbors, const bool replace,
...@@ -25,7 +26,8 @@ neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row, ...@@ -25,7 +26,8 @@ neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row,
directed); directed);
} }
SPARSE_API std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>, SPARSE_API
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>> c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample( hetero_neighbor_sample(
const std::vector<node_t> &node_types, const std::vector<node_t> &node_types,
...@@ -42,7 +44,7 @@ hetero_neighbor_sample( ...@@ -42,7 +44,7 @@ hetero_neighbor_sample(
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>, std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>> c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_temporal_sample( hetero_temporal_neighbor_sample(
const std::vector<node_t> &node_types, const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types, const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict, const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
...@@ -51,7 +53,7 @@ hetero_neighbor_temporal_sample( ...@@ -51,7 +53,7 @@ hetero_neighbor_temporal_sample(
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict, const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict, const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops, const bool replace, const bool directed) { const int64_t num_hops, const bool replace, const bool directed) {
return hetero_neighbor_temporal_sample_cpu( return hetero_temporal_neighbor_sample_cpu(
node_types, edge_types, colptr_dict, row_dict, input_node_dict, node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops, replace, directed); num_neighbors_dict, node_time_dict, num_hops, replace, directed);
} }
...@@ -60,4 +62,5 @@ static auto registry = ...@@ -60,4 +62,5 @@ static auto registry =
torch::RegisterOperators() torch::RegisterOperators()
.op("torch_sparse::neighbor_sample", &neighbor_sample) .op("torch_sparse::neighbor_sample", &neighbor_sample)
.op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample) .op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample)
.op("torch_sparse::hetero_neighbor_temporal_sample", &hetero_neighbor_temporal_sample); .op("torch_sparse::hetero_temporal_neighbor_sample",
&hetero_temporal_neighbor_sample);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment