Unverified Commit 6143af21, authored by Dong Wang, committed by GitHub

use batch idx and node id as unique key for dedup in temporal sampling (#267)


Co-authored-by: Dong Wang <d@dongs-mbp.lan>
parent 916ba55b
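
The change in a nutshell: in temporal sampling, neighbors are now deduplicated on the composite key (node id, batch idx) rather than appended unconditionally, so each seed keeps its own disjoint computation tree while a node reached twice inside the same tree maps back to its existing row. A minimal self-contained sketch of that keying scheme (using std::map and toy names, not the phmap types of the actual code):

#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

int main() {
  // One entry per materialized (node id, seed/batch idx) pair:
  std::vector<std::pair<int64_t, int64_t>> samples;
  // Dedup key (node id, batch idx) -> local row index:
  std::map<std::pair<int64_t, int64_t>, int64_t> to_local;

  auto add = [&](int64_t node, int64_t batch) {
    const auto res = to_local.insert({{node, batch}, (int64_t)samples.size()});
    if (res.second)            // first time we see (node, batch)
      samples.push_back({node, batch});
    return res.first->second;  // local row, newly created or reused
  };

  add(7, 0);              // node 7 reached from seed 0 -> row 0
  add(7, 1);              // same node from seed 1 -> new row 1
  int64_t r = add(7, 0);  // duplicate within seed 0 -> reuses row 0
  std::cout << samples.size() << " rows, duplicate mapped to row " << r << "\n";
}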
@@ -10,6 +10,8 @@ using namespace std;
namespace {

typedef phmap::flat_hash_map<pair<int64_t, int64_t>, int64_t> temporarl_edge_dict;

template <bool replace, bool directed>
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample(const torch::Tensor &colptr, const torch::Tensor &row,
@@ -146,11 +148,15 @@ hetero_sample(const vector<node_t> &node_types,
  // Initialize some data structures for the sampling process:
  phmap::flat_hash_map<node_t, vector<int64_t>> samples_dict;
  phmap::flat_hash_map<node_t, vector<pair<int64_t, int64_t>>> temp_samples_dict;
  phmap::flat_hash_map<node_t, phmap::flat_hash_map<int64_t, int64_t>> to_local_node_dict;
  phmap::flat_hash_map<node_t, temporarl_edge_dict> temp_to_local_node_dict;
  phmap::flat_hash_map<node_t, vector<int64_t>> root_time_dict;
  for (const auto &node_type : node_types) {
    samples_dict[node_type];
    temp_samples_dict[node_type];
    to_local_node_dict[node_type];
    temp_to_local_node_dict[node_type];
    root_time_dict[node_type];
  }
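
A side note on the initialization loop above: the bare samples_dict[node_type]; statements are not dead code. operator[] default-constructs an empty value for each node type, which guarantees that the later .at(...) lookups cannot throw. A small sketch of the idiom (toy types, not the actual node_t):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, std::vector<long>> samples_dict;
  samples_dict["paper"];  // operator[] inserts a default-constructed vector
  // Without the line above, .at() on a missing key throws std::out_of_range:
  std::cout << samples_dict.at("paper").size() << "\n";  // prints 0
}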
@@ -175,20 +181,33 @@ hetero_sample(const vector<node_t> &node_types,
    }
    auto &samples = samples_dict.at(node_type);
    auto &temp_samples = temp_samples_dict.at(node_type);
    auto &to_local_node = to_local_node_dict.at(node_type);
    auto &temp_to_local_node = temp_to_local_node_dict.at(node_type);
    auto &root_time = root_time_dict.at(node_type);
    for (int64_t i = 0; i < input_node.numel(); i++) {
      const auto &v = input_node_data[i];
      if (temporal) {
        temp_samples.push_back({v, i});
        temp_to_local_node.insert({{v, i}, i});
      } else {
        samples.push_back(v);
        to_local_node.insert({v, i});
      }
      if (temporal)
        root_time.push_back(node_time_data[v]);
    }
  }

  phmap::flat_hash_map<node_t, pair<int64_t, int64_t>> slice_dict;
  if (temporal) {
    for (const auto &kv : temp_samples_dict) {
      slice_dict[kv.first] = {0, kv.second.size()};
    }
  } else {
    for (const auto &kv : samples_dict)
      slice_dict[kv.first] = {0, kv.second.size()};
  }

  vector<rel_t> all_rel_types;
  for (const auto &kv : num_neighbors_dict) {
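
slice_dict tracks, per node type, the half-open range [begin, end) of the current frontier: only samples appended during the previous hop are expanded in the next one, and the slice is then advanced past the freshly appended tail. A minimal sketch of the pattern, with a toy stand-in for the real neighbor lookup:

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  std::vector<int64_t> samples = {10, 11};  // seed nodes
  std::pair<int64_t, int64_t> slice = {0, (int64_t)samples.size()};

  for (int hop = 0; hop < 2; hop++) {
    // Expand only the frontier [slice.first, slice.second):
    for (int64_t i = slice.first; i < slice.second; i++)
      samples.push_back(samples[i] * 100 + hop);  // stand-in for sampling
    // Advance the slice to the nodes appended in this hop:
    slice = {slice.second, (int64_t)samples.size()};
  }
  std::cout << samples.size() << " nodes after 2 hops\n";  // prints 6
}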
@@ -203,8 +222,11 @@ hetero_sample(const vector<node_t> &node_types,
      const auto &dst_node_type = get<2>(edge_type);
      const auto num_samples = num_neighbors_dict.at(rel_type)[ell];
      const auto &dst_samples = samples_dict.at(dst_node_type);
      const auto &temp_dst_samples = temp_samples_dict.at(dst_node_type);
      auto &src_samples = samples_dict.at(src_node_type);
      auto &temp_src_samples = temp_samples_dict.at(src_node_type);
      auto &to_local_src_node = to_local_node_dict.at(src_node_type);
      auto &temp_to_local_src_node = temp_to_local_node_dict.at(src_node_type);

      const torch::Tensor &colptr = colptr_dict.at(rel_type);
      const auto *colptr_data = colptr.data_ptr<int64_t>();
@@ -223,7 +245,8 @@ hetero_sample(const vector<node_t> &node_types,
      const auto &begin = slice_dict.at(dst_node_type).first;
      const auto &end = slice_dict.at(dst_node_type).second;
      for (int64_t i = begin; i < end; i++) {
        const auto &w = temporal ? temp_dst_samples[i].first : dst_samples[i];
        const int64_t root_w = temporal ? temp_dst_samples[i].second : -1;
        int64_t dst_time = 0;
        if (temporal)
          dst_time = dst_root_time[i];
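
root_w is the batch idx from which the destination node was originally reached; every sampled neighbor inherits it, so a node reached from two different seeds is materialized once per seed and the per-seed trees stay disjoint. A sketch of that propagation (std::map and toy data, not the actual phmap types):

#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

int main() {
  // Each frontier entry carries (node id, seed idx it originated from):
  std::vector<std::pair<int64_t, int64_t>> frontier = {{10, 0}, {20, 1}};
  std::map<std::pair<int64_t, int64_t>, int64_t> to_local;
  std::vector<std::pair<int64_t, int64_t>> next;

  for (const auto &dst : frontier) {
    const int64_t root = dst.second;  // inherited, never recomputed
    const int64_t neighbor = 99;      // both frontier nodes share this neighbor
    if (to_local.insert({{neighbor, root}, (int64_t)next.size()}).second)
      next.push_back({neighbor, root});
  }
  // Node 99 appears once per seed, so the trees stay disjoint:
  std::cout << next.size() << " rows for node 99\n";  // prints 2
}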
@@ -241,15 +264,18 @@ hetero_sample(const vector<node_t> &node_types,
          if (temporal) {
            if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
              continue;
            // Force disjoint computation trees, keyed on the source batch idx.
            // Note that temporal sampling always needs directed=True;
            // to_local_src_node is not used in the temporal/directed case.
            const auto res = temp_to_local_src_node.insert(
                {{v, root_w}, (int64_t)temp_src_samples.size()});
            if (res.second) {
              temp_src_samples.push_back({v, root_w});
              src_root_time.push_back(dst_time);
            }
            cols.push_back(i);
            rows.push_back(res.first->second);
            edges.push_back(offset);
          } else {
            const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -272,14 +298,17 @@ hetero_sample(const vector<node_t> &node_types,
            // TODO: Infinite loop if no neighbor satisfies the time constraint:
            if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
              continue;
            // Force disjoint computation trees, keyed on the source batch idx.
            // Note that temporal sampling always needs directed=True.
            const auto res = temp_to_local_src_node.insert(
                {{v, root_w}, (int64_t)temp_src_samples.size()});
            if (res.second) {
              temp_src_samples.push_back({v, root_w});
              src_root_time.push_back(dst_time);
            }
            cols.push_back(i);
            rows.push_back(res.first->second);
            edges.push_back(offset);
          } else {
            const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -307,14 +336,17 @@ hetero_sample(const vector<node_t> &node_types,
          if (temporal) {
            if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
              continue;
            // Force disjoint computation trees, keyed on the source batch idx.
            // Note that temporal sampling always needs directed=True.
            const auto res = temp_to_local_src_node.insert(
                {{v, root_w}, (int64_t)temp_src_samples.size()});
            if (res.second) {
              temp_src_samples.push_back({v, root_w});
              src_root_time.push_back(dst_time);
            }
            cols.push_back(i);
            rows.push_back(res.first->second);
            edges.push_back(offset);
          } else {
            const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -331,11 +363,18 @@ hetero_sample(const vector<node_t> &node_types,
      }
    }
    if (temporal) {
      for (const auto &kv : temp_samples_dict) {
        slice_dict[kv.first] = {slice_dict.at(kv.first).second,
                                kv.second.size()};
      }
    } else {
      for (const auto &kv : samples_dict) {
        slice_dict[kv.first] = {slice_dict.at(kv.first).second,
                                kv.second.size()};
      }
    }
  }

  // Temporal sampling does not support the undirected case:
  assert(!(temporal && !directed));
  if (!directed) { // Construct the subgraph among the sampled nodes:
    phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
    for (const auto &kv : colptr_dict) {
@@ -371,6 +410,18 @@ hetero_sample(const vector<node_t> &node_types,
    }
  }

  // Construct the samples dictionary from the temporal samples dictionary:
  if (temporal) {
    for (const auto &kv : temp_samples_dict) {
      const auto &node_type = kv.first;
      const auto &samples = kv.second;
      samples_dict[node_type].reserve(samples.size());
      for (const auto &v : samples) {
        samples_dict[node_type].push_back(v.first);
      }
    }
  }
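
The projection above simply drops the batch idx from every (node id, batch idx) pair before the result tensors are built. An equivalent standalone version (toy data, using std::transform):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<int64_t, int64_t>> temp_samples = {{7, 0}, {7, 1}, {9, 0}};
  std::vector<int64_t> samples;
  samples.reserve(temp_samples.size());
  // Keep the node id only; the batch idx was needed just for dedup:
  std::transform(temp_samples.begin(), temp_samples.end(),
                 std::back_inserter(samples),
                 [](const std::pair<int64_t, int64_t> &p) { return p.first; });
  for (auto v : samples)
    std::cout << v << " ";  // prints: 7 7 9
  std::cout << "\n";
}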
  return make_tuple(from_vector<node_t, int64_t>(samples_dict),
                    from_vector<rel_t, int64_t>(rows_dict),
                    from_vector<rel_t, int64_t>(cols_dict),
                    ...