Unverified Commit caf7ddde authored by Matthias Fey, committed by GitHub
Browse files

Temporal `neighbor_sample` adjustments (#225)

* version up

* formatting

* fix

* reset

* revert

* temporal neighbor sampling adjustments

* typo
parent 15c56351
......@@ -114,16 +114,13 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
from_vector<int64_t>(cols), from_vector<int64_t>(edges));
}
bool satisfy_time_constraint(
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const node_t &src_node_type, const int64_t &dst_time,
const int64_t &src_node) {
// whether src -> dst obeys the time constraint
try {
inline bool satisfy_time(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const node_t &src_node_type, const int64_t &dst_time,
const int64_t &src_node) {
try { // Check whether src -> dst obeys the time constraint:
auto src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
return dst_time < src_time[src_node];
} catch (int err) {
// if the node type does not have timestamp, fall back to normal sampling
} catch (int err) { // If no time is given, fall back to normal sampling:
return true;
}
}
......@@ -137,8 +134,9 @@ hetero_sample(const vector<node_t> &node_types,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops,
const c10::Dict<node_t, torch::Tensor> &node_time_dict) {
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops) {
// Create a mapping to convert single string relations to edge type triplets:
unordered_map<rel_t, edge_t> to_edge_type;
for (const auto &k : edge_types)
......@@ -155,8 +153,6 @@ hetero_sample(const vector<node_t> &node_types,
unordered_map<node_t, vector<int64_t>> samples_dict;
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
// The timestamp of the center node whose neighborhood the sampled node
// belongs to. It maps node_type to an empty vector in non-temporal sampling.
unordered_map<node_t, vector<int64_t>> root_time_dict;
for (const auto &node_type : node_types) {
samples_dict[node_type];
......@@ -169,10 +165,7 @@ hetero_sample(const vector<node_t> &node_types,
const auto &node_type = kv.key();
const torch::Tensor &input_node = kv.value();
const auto *input_node_data = input_node.data_ptr<int64_t>();
// dummy value. will be reset to root time if is_temporal==true
int64_t *node_time_data;
// root_time[i] stores the timestamp of the computation tree root
// of the node samples[i]
if (temporal) {
torch::Tensor node_time = node_time_dict.at(node_type);
node_time_data = node_time.data_ptr<int64_t>();
......@@ -185,9 +178,8 @@ hetero_sample(const vector<node_t> &node_types,
const auto &v = input_node_data[i];
samples.push_back(v);
to_local_node.insert({v, i});
if (temporal) {
if (temporal)
root_time.push_back(node_time_data[v]);
}
}
}
......@@ -217,11 +209,12 @@ hetero_sample(const vector<node_t> &node_types,
const auto &begin = slice_dict.at(dst_node_type).first;
const auto &end = slice_dict.at(dst_node_type).second;
if (begin == end) {
if (begin == end)
continue;
}
// for temporal sampling, sampled src node cannot have timestamp greater
// than its corresponding dst_root_time
// For temporal sampling, sampled nodes cannot have a timestamp greater
// than the timestamp of the root nodes.
const auto &dst_root_time = root_time_dict.at(dst_node_type);
auto &src_root_time = root_time_dict.at(src_node_type);
......@@ -236,16 +229,13 @@ hetero_sample(const vector<node_t> &node_types,
continue;
if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
// select all neighbors
// Select all neighbors:
for (int64_t offset = col_start; offset < col_end; offset++) {
const int64_t &v = row_data[offset];
bool time_constraint = true;
if (temporal) {
time_constraint = satisfy_time_constraint(
node_time_dict, src_node_type, dst_time, v);
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
}
if (!time_constraint)
continue;
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) {
src_samples.push_back(v);
......@@ -259,18 +249,16 @@ hetero_sample(const vector<node_t> &node_types,
}
}
} else if (replace) {
// sample with replacement
// Sample with replacement:
int64_t num_neighbors = 0;
while (num_neighbors < num_samples) {
const int64_t offset = col_start + uniform_randint(col_count);
const int64_t &v = row_data[offset];
bool time_constraint = true;
if (temporal) {
time_constraint = satisfy_time_constraint(
node_time_dict, src_node_type, dst_time, v);
// TODO Infinite loop if no neighbor satisfies the time constraint:
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
}
if (!time_constraint)
continue;
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) {
src_samples.push_back(v);
......@@ -285,7 +273,7 @@ hetero_sample(const vector<node_t> &node_types,
num_neighbors += 1;
}
} else {
// sample without replacement
// Sample without replacement:
unordered_set<int64_t> rnd_indices;
for (int64_t j = col_count - num_samples; j < col_count; j++) {
int64_t rnd = uniform_randint(j);
......@@ -295,13 +283,10 @@ hetero_sample(const vector<node_t> &node_types,
}
const int64_t offset = col_start + rnd;
const int64_t &v = row_data[offset];
bool time_constraint = true;
if (temporal) {
time_constraint = satisfy_time_constraint(
node_time_dict, src_node_type, dst_time, v);
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
}
if (!time_constraint)
continue;
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) {
src_samples.push_back(v);
......@@ -364,22 +349,6 @@ hetero_sample(const vector<node_t> &node_types,
from_vector<rel_t, int64_t>(edges_dict));
}
template <bool replace, bool directed>
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample_random(
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops) {
c10::Dict<node_t, torch::Tensor> empty_dict;
return hetero_sample<replace, directed, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, empty_dict);
}
} // namespace
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
......@@ -409,28 +378,30 @@ hetero_neighbor_sample_cpu(
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const bool replace, const bool directed) {
c10::Dict<node_t, torch::Tensor> node_time_dict; // Empty dictionary.
if (replace && directed) {
return hetero_sample_random<true, true>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
return hetero_sample<true, true, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else if (replace && !directed) {
return hetero_sample_random<true, false>(
return hetero_sample<true, false, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops);
num_neighbors_dict, node_time_dict, num_hops);
} else if (!replace && directed) {
return hetero_sample_random<false, true>(
return hetero_sample<false, true, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops);
num_neighbors_dict, node_time_dict, num_hops);
} else {
return hetero_sample_random<false, false>(
return hetero_sample<false, false, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops);
num_neighbors_dict, node_time_dict, num_hops);
}
}
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_temporal_sample_cpu(
hetero_temporal_neighbor_sample_cpu(
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
......@@ -442,18 +413,18 @@ hetero_neighbor_temporal_sample_cpu(
if (replace && directed) {
return hetero_sample<true, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict);
num_neighbors_dict, node_time_dict, num_hops);
} else if (replace && !directed) {
return hetero_sample<true, false, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict);
num_neighbors_dict, node_time_dict, num_hops);
} else if (!replace && directed) {
return hetero_sample<false, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict);
num_neighbors_dict, node_time_dict, num_hops);
} else {
return hetero_sample<false, false, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict);
num_neighbors_dict, node_time_dict, num_hops);
}
}
......@@ -25,12 +25,12 @@ hetero_neighbor_sample_cpu(
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_temporal_sample_cpu(
const std::vector<node_t> &node_types,
hetero_temporal_neighbor_sample_cpu(
const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops, const bool replace, const bool directed);
\ No newline at end of file
const int64_t num_hops, const bool replace, const bool directed);
......@@ -16,7 +16,8 @@ PyMODINIT_FUNC PyInit__neighbor_sample_cpu(void) { return NULL; }
#endif
// Returns 'output_node', 'row', 'col', 'output_edge'
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
SPARSE_API
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row,
const torch::Tensor &input_node,
const std::vector<int64_t> num_neighbors, const bool replace,
......@@ -25,7 +26,8 @@ neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row,
directed);
}
SPARSE_API std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
SPARSE_API
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample(
const std::vector<node_t> &node_types,
......@@ -42,7 +44,7 @@ hetero_neighbor_sample(
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_temporal_sample(
hetero_temporal_neighbor_sample(
const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
......@@ -51,7 +53,7 @@ hetero_neighbor_temporal_sample(
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops, const bool replace, const bool directed) {
return hetero_neighbor_temporal_sample_cpu(
return hetero_temporal_neighbor_sample_cpu(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops, replace, directed);
}
......@@ -60,4 +62,5 @@ static auto registry =
torch::RegisterOperators()
.op("torch_sparse::neighbor_sample", &neighbor_sample)
.op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample)
.op("torch_sparse::hetero_neighbor_temporal_sample", &hetero_neighbor_temporal_sample);
.op("torch_sparse::hetero_temporal_neighbor_sample",
&hetero_temporal_neighbor_sample);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment