Unverified commit 0172aeb3, authored by Matthias Fey, committed by GitHub

Temporal neighbor sampling adjustments (part2) (#226)

* temporal neighbor sampling adjustments (part2)

* fix
parent caf7ddde
@@ -115,11 +115,11 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
 }
 
 inline bool satisfy_time(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
-                         const node_t &src_node_type, const int64_t &dst_time,
-                         const int64_t &src_node) {
+                         const node_t &src_node_type, int64_t dst_time,
+                         int64_t src_node) {
   try { // Check whether src -> dst obeys the time constraint:
-    auto src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
-    return dst_time < src_time[src_node];
+    const torch::Tensor &src_node_time = node_time_dict.at(src_node_type);
+    return src_node_time.data_ptr<int64_t>()[src_node] <= dst_time;
   } catch (int err) { // If no time is given, fall back to normal sampling:
     return true;
   }
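Note: the comparison direction also changed here; the old code kept a neighbor only if `dst_time < src_time[src_node]`, whereas the new code keeps it when the source node is not newer than the root. As a minimal, self-contained sketch of the constraint `satisfy_time` now enforces (illustrative only; `neighbor_allowed` and the plain std::vector timestamp store are hypothetical stand-ins for `node_time_dict`, not the library's API):

// Illustrative sketch only; `neighbor_allowed` and the std::vector-based
// timestamp store are hypothetical stand-ins for node_time_dict.
#include <cstdint>
#include <iostream>
#include <vector>

// A source node is admissible if its timestamp is not later than the
// destination (root) node's timestamp, mirroring
// `src_node_time.data_ptr<int64_t>()[src_node] <= dst_time`:
bool neighbor_allowed(const std::vector<int64_t> &src_node_time,
                      int64_t src_node, int64_t dst_time) {
  return src_node_time[src_node] <= dst_time;
}

int main() {
  const std::vector<int64_t> src_node_time = {5, 10, 20};
  std::cout << neighbor_allowed(src_node_time, 1, 15) << "\n"; // 1 (10 <= 15)
  std::cout << neighbor_allowed(src_node_time, 2, 15) << "\n"; // 0 (20 >  15)
  return 0;
}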
@@ -143,14 +143,6 @@ hetero_sample(const vector<node_t> &node_types,
     to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;
 
   // Initialize some data structures for the sampling process:
-  unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
-  for (const auto &kv : colptr_dict) {
-    const auto &rel_type = kv.key();
-    rows_dict[rel_type];
-    cols_dict[rel_type];
-    edges_dict[rel_type];
-  }
-
   unordered_map<node_t, vector<int64_t>> samples_dict;
   unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
   unordered_map<node_t, vector<int64_t>> root_time_dict;
@@ -160,14 +152,23 @@ hetero_sample(const vector<node_t> &node_types,
     root_time_dict[node_type];
   }
 
+  unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
+  for (const auto &kv : colptr_dict) {
+    const auto &rel_type = kv.key();
+    rows_dict[rel_type];
+    cols_dict[rel_type];
+    edges_dict[rel_type];
+  }
+
   // Add the input nodes to the output nodes:
   for (const auto &kv : input_node_dict) {
     const auto &node_type = kv.key();
     const torch::Tensor &input_node = kv.value();
     const auto *input_node_data = input_node.data_ptr<int64_t>();
 
     int64_t *node_time_data;
     if (temporal) {
-      torch::Tensor node_time = node_time_dict.at(node_type);
+      const torch::Tensor &node_time = node_time_dict.at(node_type);
       node_time_data = node_time.data_ptr<int64_t>();
     }
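The initialization block moved here relies on relation types being addressable by a flat string key, built earlier via `to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k`, and on `operator[]` default-constructing an empty per-relation buffer. A minimal sketch of that keying scheme with standard containers (illustrative only; the `paper__cites__paper` triple is a made-up example, not taken from the library):

// Illustrative sketch only: rel_t keys are plain "src__rel__dst" strings
// built from an edge-type triple, and operator[] default-constructs an empty
// buffer per relation type, which is what the initialization loop relies on.
#include <cstdint>
#include <iostream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

using rel_t = std::string;
using edge_t = std::tuple<std::string, std::string, std::string>;

int main() {
  const edge_t k = {"paper", "cites", "paper"};
  const rel_t rel_type =
      std::get<0>(k) + "__" + std::get<1>(k) + "__" + std::get<2>(k);

  std::unordered_map<rel_t, std::vector<int64_t>> rows_dict, cols_dict, edges_dict;
  rows_dict[rel_type]; // creates an empty vector for this relation type
  cols_dict[rel_type];
  edges_dict[rel_type];

  std::cout << rel_type << ": " << rows_dict.at(rel_type).size() << " rows\n";
  return 0;
}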
@@ -198,29 +199,27 @@ hetero_sample(const vector<node_t> &node_types,
       auto &src_samples = samples_dict.at(src_node_type);
       auto &to_local_src_node = to_local_node_dict.at(src_node_type);
 
-      const auto *colptr_data =
-          ((torch::Tensor)colptr_dict.at(rel_type)).data_ptr<int64_t>();
-      const auto *row_data =
-          ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
+      const torch::Tensor &colptr = colptr_dict.at(rel_type);
+      const auto *colptr_data = colptr.data_ptr<int64_t>();
+      const torch::Tensor &row = row_dict.at(rel_type);
+      const auto *row_data = row.data_ptr<int64_t>();
 
       auto &rows = rows_dict.at(rel_type);
       auto &cols = cols_dict.at(rel_type);
       auto &edges = edges_dict.at(rel_type);
 
-      const auto &begin = slice_dict.at(dst_node_type).first;
-      const auto &end = slice_dict.at(dst_node_type).second;
-      if (begin == end)
-        continue;
-
       // For temporal sampling, sampled nodes cannot have a timestamp greater
-      // than the timestamp of the root nodes.
+      // than the timestamp of the root nodes:
       const auto &dst_root_time = root_time_dict.at(dst_node_type);
       auto &src_root_time = root_time_dict.at(src_node_type);
 
+      const auto &begin = slice_dict.at(dst_node_type).first;
+      const auto &end = slice_dict.at(dst_node_type).second;
+
       for (int64_t i = begin; i < end; i++) {
         const auto &w = dst_samples[i];
-        const auto &dst_time = dst_root_time[i];
+        int64_t dst_time = 0;
+        if (temporal)
+          dst_time = dst_root_time[i];
 
         const auto &col_start = colptr_data[w];
         const auto &col_end = colptr_data[w + 1];
         const auto col_count = col_end - col_start;
...
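The loop above walks the CSC structure per relation type: for every already-sampled destination node `w`, its candidate neighbors live in `row_data[colptr_data[w] .. colptr_data[w + 1])`, and in temporal mode each candidate is presumably checked against `dst_time` (via `satisfy_time`) in the truncated remainder of the loop body. A self-contained sketch of that access pattern (illustrative only; `collect_neighbors` and the toy graph are hypothetical, not part of the library):

// Illustrative sketch only: a CSC-style neighbor walk with the temporal
// filter applied, using plain std containers instead of torch tensors and a
// hypothetical helper name `collect_neighbors`.
#include <cstdint>
#include <iostream>
#include <vector>

// colptr/row encode the graph in CSC form: the in-neighbors of column `w`
// are row[colptr[w]] .. row[colptr[w + 1] - 1].
std::vector<int64_t> collect_neighbors(const std::vector<int64_t> &colptr,
                                       const std::vector<int64_t> &row,
                                       const std::vector<int64_t> &node_time,
                                       int64_t w, int64_t dst_time,
                                       bool temporal) {
  std::vector<int64_t> out;
  const int64_t col_start = colptr[w], col_end = colptr[w + 1];
  for (int64_t e = col_start; e < col_end; e++) {
    const int64_t v = row[e];
    // Temporal constraint: skip neighbors newer than the root's timestamp.
    if (temporal && node_time[v] > dst_time)
      continue;
    out.push_back(v);
  }
  return out;
}

int main() {
  // Three nodes; node 2 has incoming edges from nodes 0 and 1:
  const std::vector<int64_t> colptr = {0, 0, 0, 2};
  const std::vector<int64_t> row = {0, 1};
  const std::vector<int64_t> node_time = {5, 30, 10};

  for (int64_t v : collect_neighbors(colptr, row, node_time, /*w=*/2,
                                     /*dst_time=*/10, /*temporal=*/true))
    std::cout << v << " "; // prints "0" only: node 1 (time 30) is filtered out
  std::cout << "\n";
  return 0;
}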