"git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "cc75d0e80c24a06e2b9881081f06acca413e8c5e"
Unverified Commit 6143af21 authored by Dong Wang, committed by GitHub

use batch idx and node id as unique key for dedup in temporal sampling (#267)


Co-authored-by: Dong Wang <d@dongs-mbp.lan>
parent 916ba55b
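The commit title captures the core change: during temporal sampling, sampled nodes are deduplicated per seed (batch idx) rather than globally, so every seed keeps a disjoint computation tree. A minimal self-contained sketch of that idea follows; it uses `std::map` instead of the library's `phmap::flat_hash_map` purely so the example compiles on its own, and all node ids are made up:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

int main() {
  // Two seeds (batch 0 and batch 1) both reach node 7 during sampling.
  std::vector<std::pair<std::int64_t, std::int64_t>> hits = {
      {7, 0}, {3, 0}, {7, 1}, {7, 0}};  // (node id, batch idx)

  // Dedup keyed on node id only: node 7 is merged across batches,
  // so the two seeds end up sharing one sample (trees not disjoint).
  std::map<std::int64_t, std::int64_t> by_node;
  // Dedup keyed on (node id, batch idx): node 7 is kept once per batch.
  std::map<std::pair<std::int64_t, std::int64_t>, std::int64_t> by_node_and_batch;

  for (const auto &h : hits) {
    by_node.insert({h.first, (std::int64_t)by_node.size()});
    by_node_and_batch.insert({h, (std::int64_t)by_node_and_batch.size()});
  }
  std::cout << by_node.size() << " vs " << by_node_and_batch.size()
            << "\n";  // prints "2 vs 3"
  return 0;
}
```

Keying on the pair means the same node reached from two different seeds is stored twice, which is exactly what disjoint per-seed subgraphs require.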
@@ -10,6 +10,8 @@ using namespace std;
 namespace {
+
+typedef phmap::flat_hash_map<pair<int64_t, int64_t>, int64_t> temporarl_edge_dict;
 template <bool replace, bool directed>
 tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 sample(const torch::Tensor &colptr, const torch::Tensor &row,
@@ -146,11 +148,15 @@ hetero_sample(const vector<node_t> &node_types,
   // Initialize some data structures for the sampling process:
   phmap::flat_hash_map<node_t, vector<int64_t>> samples_dict;
+  phmap::flat_hash_map<node_t, vector<pair<int64_t, int64_t>>> temp_samples_dict;
   phmap::flat_hash_map<node_t, phmap::flat_hash_map<int64_t, int64_t>> to_local_node_dict;
+  phmap::flat_hash_map<node_t, temporarl_edge_dict> temp_to_local_node_dict;
   phmap::flat_hash_map<node_t, vector<int64_t>> root_time_dict;
   for (const auto &node_type : node_types) {
     samples_dict[node_type];
+    temp_samples_dict[node_type];
     to_local_node_dict[node_type];
+    temp_to_local_node_dict[node_type];
     root_time_dict[node_type];
   }
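The bare `samples_dict[node_type];` statements in this hunk look like no-ops, but `operator[]` default-constructs the missing entry, so every later `.at(node_type)` lookup is guaranteed not to throw. A small sketch of the idiom, again with `std::map` standing in for phmap and hypothetical node type names:

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

int main() {
  std::map<std::string, std::vector<std::int64_t>> samples_dict;
  for (const char *node_type : {"paper", "author"})
    samples_dict[node_type];  // default-construct an empty vector, value unused

  // .at() would throw std::out_of_range for a missing key; after the
  // pre-touch above it is safe for every known node type.
  assert(samples_dict.at("author").empty());
  return 0;
}
```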
@@ -175,20 +181,33 @@ hetero_sample(const vector<node_t> &node_types,
     }
     auto &samples = samples_dict.at(node_type);
+    auto &temp_samples = temp_samples_dict.at(node_type);
     auto &to_local_node = to_local_node_dict.at(node_type);
+    auto &temp_to_local_node = temp_to_local_node_dict.at(node_type);
     auto &root_time = root_time_dict.at(node_type);
     for (int64_t i = 0; i < input_node.numel(); i++) {
       const auto &v = input_node_data[i];
-      samples.push_back(v);
-      to_local_node.insert({v, i});
+      if (temporal) {
+        temp_samples.push_back({v, i});
+        temp_to_local_node.insert({{v, i}, i});
+      } else {
+        samples.push_back(v);
+        to_local_node.insert({v, i});
+      }
       if (temporal)
         root_time.push_back(node_time_data[v]);
     }
   }
 
   phmap::flat_hash_map<node_t, pair<int64_t, int64_t>> slice_dict;
-  for (const auto &kv : samples_dict)
-    slice_dict[kv.first] = {0, kv.second.size()};
+  if (temporal) {
+    for (const auto &kv : temp_samples_dict) {
+      slice_dict[kv.first] = {0, kv.second.size()};
+    }
+  } else {
+    for (const auto &kv : samples_dict)
+      slice_dict[kv.first] = {0, kv.second.size()};
+  }
 
   vector<rel_t> all_rel_types;
   for (const auto &kv : num_neighbors_dict) {
@@ -203,8 +222,11 @@ hetero_sample(const vector<node_t> &node_types,
       const auto &dst_node_type = get<2>(edge_type);
       const auto num_samples = num_neighbors_dict.at(rel_type)[ell];
       const auto &dst_samples = samples_dict.at(dst_node_type);
+      const auto &temp_dst_samples = temp_samples_dict.at(dst_node_type);
       auto &src_samples = samples_dict.at(src_node_type);
+      auto &temp_src_samples = temp_samples_dict.at(src_node_type);
       auto &to_local_src_node = to_local_node_dict.at(src_node_type);
+      auto &temp_to_local_src_node = temp_to_local_node_dict.at(src_node_type);
 
       const torch::Tensor &colptr = colptr_dict.at(rel_type);
       const auto *colptr_data = colptr.data_ptr<int64_t>();
@@ -223,7 +245,8 @@ hetero_sample(const vector<node_t> &node_types,
       const auto &begin = slice_dict.at(dst_node_type).first;
       const auto &end = slice_dict.at(dst_node_type).second;
       for (int64_t i = begin; i < end; i++) {
-        const auto &w = dst_samples[i];
+        const auto &w = temporal ? temp_dst_samples[i].first : dst_samples[i];
+        const int64_t root_w = temporal ? temp_dst_samples[i].second : -1;
         int64_t dst_time = 0;
         if (temporal)
           dst_time = dst_root_time[i];
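The `begin`/`end` pair pulled from `slice_dict` restricts each hop to the nodes appended by the previous hop, so the sampler expands a frontier instead of revisiting everything seen so far. A toy sketch of that bookkeeping with made-up node ids (the `+ 100` stands in for actually sampling a neighbor):

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  std::vector<std::int64_t> samples = {10, 11};          // hop-0 seed nodes
  std::pair<std::int64_t, std::int64_t> slice = {0, 2};  // frontier = seeds

  for (int hop = 0; hop < 2; hop++) {
    std::int64_t old_size = (std::int64_t)samples.size();
    // Expand only the nodes added by the previous hop:
    for (std::int64_t i = slice.first; i < slice.second; i++)
      samples.push_back(samples[i] + 100);  // stand-in for a sampled neighbor
    slice = {old_size, (std::int64_t)samples.size()};  // next frontier
  }
  std::cout << "sampled " << samples.size() << " nodes\n";  // prints 6
  return 0;
}
```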
@@ -241,15 +264,18 @@ hetero_sample(const vector<node_t> &node_types,
           if (temporal) {
             if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
               continue;
-            // force disjoint of computation tree
+            // force disjoint of computation tree based on source batch idx.
+            // note that the sampling always needs to have directed=True
+            // for temporal case
+            // to_local_src_node is not used for temporal / directed case
-            const int64_t sample_idx = src_samples.size();
-            src_samples.push_back(v);
-            src_root_time.push_back(dst_time);
+            const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
+            if (res.second) {
+              temp_src_samples.push_back({v, root_w});
+              src_root_time.push_back(dst_time);
+            }
             cols.push_back(i);
-            rows.push_back(sample_idx);
+            rows.push_back(res.first->second);
             edges.push_back(offset);
           } else {
             const auto res = to_local_src_node.insert({v, src_samples.size()});
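The replacement code leans on the fact that `insert` on these maps returns an `{iterator, bool}` pair: on a duplicate `(node, batch)` key nothing is appended, and the existing local index is reused for `rows`. A self-contained sketch of that pattern, with `std::map` in place of phmap and hypothetical keys:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

int main() {
  using Key = std::pair<std::int64_t, std::int64_t>;  // (node id, batch idx)
  std::map<Key, std::int64_t> to_local;               // key -> local sample index
  std::vector<Key> samples;
  std::vector<std::int64_t> rows;

  for (const Key &key : std::vector<Key>{{7, 0}, {7, 0}, {7, 1}}) {
    const auto res = to_local.insert({key, (std::int64_t)samples.size()});
    if (res.second)  // first occurrence of this (node, batch) pair
      samples.push_back(key);
    rows.push_back(res.first->second);  // fresh or previously assigned index
  }

  std::cout << samples.size() << " samples, rows:";
  for (auto r : rows)
    std::cout << ' ' << r;  // prints "2 samples, rows: 0 0 1"
  std::cout << '\n';
  return 0;
}
```

Note the duplicate edge is still recorded in `rows`/`cols`; only the sampled node itself is deduplicated.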
@@ -272,14 +298,17 @@ hetero_sample(const vector<node_t> &node_types,
             // TODO Infinity loop if no neighbor satisfies time constraint:
             if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
               continue;
-            // force disjoint of computation tree
+            // force disjoint of computation tree based on source batch idx.
+            // note that the sampling always needs to have directed=True
+            // for temporal case
-            const int64_t sample_idx = src_samples.size();
-            src_samples.push_back(v);
-            src_root_time.push_back(dst_time);
+            const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
+            if (res.second) {
+              temp_src_samples.push_back({v, root_w});
+              src_root_time.push_back(dst_time);
+            }
             cols.push_back(i);
-            rows.push_back(sample_idx);
+            rows.push_back(res.first->second);
             edges.push_back(offset);
           } else {
             const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -307,14 +336,17 @@ hetero_sample(const vector<node_t> &node_types,
           if (temporal) {
             if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
               continue;
-            // force disjoint of computation tree
+            // force disjoint of computation tree based on source batch idx.
+            // note that the sampling always needs to have directed=True
+            // for temporal case
-            const int64_t sample_idx = src_samples.size();
-            src_samples.push_back(v);
-            src_root_time.push_back(dst_time);
+            const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
+            if (res.second) {
+              temp_src_samples.push_back({v, root_w});
+              src_root_time.push_back(dst_time);
+            }
             cols.push_back(i);
-            rows.push_back(sample_idx);
+            rows.push_back(res.first->second);
             edges.push_back(offset);
           } else {
             const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -331,11 +363,18 @@ hetero_sample(const vector<node_t> &node_types,
       }
     }
 
-    for (const auto &kv : samples_dict) {
-      slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
+    if (temporal) {
+      for (const auto &kv : temp_samples_dict) {
+        slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
+      }
+    } else {
+      for (const auto &kv : samples_dict) {
+        slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
+      }
     }
   }
 
+  // Temporal sample disable undirected
+  assert(!(temporal && !directed));
   if (!directed) { // Construct the subgraph among the sampled nodes:
     phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
     for (const auto &kv : colptr_dict) {
@@ -371,6 +410,18 @@ hetero_sample(const vector<node_t> &node_types,
     }
   }
 
+  // Construct samples dictionary from temporal sample dictionary.
+  if (temporal) {
+    for (const auto &kv : temp_samples_dict) {
+      const auto &node_type = kv.first;
+      const auto &samples = kv.second;
+      samples_dict[node_type].reserve(samples.size());
+      for (const auto &v : samples) {
+        samples_dict[node_type].push_back(v.first);
+      }
+    }
+  }
+
   return make_tuple(from_vector<node_t, int64_t>(samples_dict),
                     from_vector<rel_t, int64_t>(rows_dict),
                     from_vector<rel_t, int64_t>(cols_dict),
...
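After sampling, the temporal pass only needs to strip the batch component from each `(node id, batch idx)` pair to recover the plain per-type node lists that the returned tensors are built from. A sketch of that last conversion with hypothetical data:

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // (node id, batch idx) pairs tracked by the temporal pass:
  std::vector<std::pair<std::int64_t, std::int64_t>> temp_samples = {
      {7, 0}, {3, 0}, {7, 1}};

  std::vector<std::int64_t> samples;
  samples.reserve(temp_samples.size());
  for (const auto &v : temp_samples)
    samples.push_back(v.first);  // keep the node id, drop the batch idx

  for (auto s : samples)
    std::cout << s << ' ';  // prints "7 3 7"
  std::cout << '\n';
  return 0;
}
```

Node 7 appears twice in the output on purpose: the duplicate is what keeps the batch-0 and batch-1 computation trees disjoint.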