Unverified Commit 0172aeb3 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Temporal neighbor sampling adjustments (part2) (#226)

* temporal neighbor sampling adjustments (part2)

* fix
parent caf7ddde
......@@ -115,11 +115,11 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
}
inline bool satisfy_time(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const node_t &src_node_type, const int64_t &dst_time,
const int64_t &src_node) {
const node_t &src_node_type, int64_t dst_time,
int64_t src_node) {
try { // Check whether src -> dst obeys the time constraint:
auto src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
return dst_time < src_time[src_node];
const torch::Tensor &src_node_time = node_time_dict.at(src_node_type);
return src_node_time.data_ptr<int64_t>()[src_node] <= dst_time;
} catch (int err) { // If no time is given, fall back to normal sampling:
return true;
}
......@@ -143,14 +143,6 @@ hetero_sample(const vector<node_t> &node_types,
to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;
// Initialize some data structures for the sampling process:
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
for (const auto &kv : colptr_dict) {
const auto &rel_type = kv.key();
rows_dict[rel_type];
cols_dict[rel_type];
edges_dict[rel_type];
}
unordered_map<node_t, vector<int64_t>> samples_dict;
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
unordered_map<node_t, vector<int64_t>> root_time_dict;
......@@ -160,14 +152,23 @@ hetero_sample(const vector<node_t> &node_types,
root_time_dict[node_type];
}
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
for (const auto &kv : colptr_dict) {
const auto &rel_type = kv.key();
rows_dict[rel_type];
cols_dict[rel_type];
edges_dict[rel_type];
}
// Add the input nodes to the output nodes:
for (const auto &kv : input_node_dict) {
const auto &node_type = kv.key();
const torch::Tensor &input_node = kv.value();
const auto *input_node_data = input_node.data_ptr<int64_t>();
int64_t *node_time_data;
if (temporal) {
torch::Tensor node_time = node_time_dict.at(node_type);
const torch::Tensor &node_time = node_time_dict.at(node_type);
node_time_data = node_time.data_ptr<int64_t>();
}
......@@ -198,29 +199,27 @@ hetero_sample(const vector<node_t> &node_types,
auto &src_samples = samples_dict.at(src_node_type);
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
const auto *colptr_data =
((torch::Tensor)colptr_dict.at(rel_type)).data_ptr<int64_t>();
const auto *row_data =
((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
const torch::Tensor &colptr = colptr_dict.at(rel_type);
const auto *colptr_data = colptr.data_ptr<int64_t>();
const torch::Tensor &row = row_dict.at(rel_type);
const auto *row_data = row.data_ptr<int64_t>();
auto &rows = rows_dict.at(rel_type);
auto &cols = cols_dict.at(rel_type);
auto &edges = edges_dict.at(rel_type);
const auto &begin = slice_dict.at(dst_node_type).first;
const auto &end = slice_dict.at(dst_node_type).second;
if (begin == end)
continue;
// For temporal sampling, sampled nodes cannot have a timestamp greater
// than the timestamp of the root nodes.
// than the timestamp of the root nodes:
const auto &dst_root_time = root_time_dict.at(dst_node_type);
auto &src_root_time = root_time_dict.at(src_node_type);
const auto &begin = slice_dict.at(dst_node_type).first;
const auto &end = slice_dict.at(dst_node_type).second;
for (int64_t i = begin; i < end; i++) {
const auto &w = dst_samples[i];
const auto &dst_time = dst_root_time[i];
int64_t dst_time = 0;
if (temporal)
dst_time = dst_root_time[i];
const auto &col_start = colptr_data[w];
const auto &col_end = colptr_data[w + 1];
const auto col_count = col_end - col_start;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment