Unverified Commit d670cddb authored by Rex Ying's avatar Rex Ying Committed by GitHub
Browse files

Temporal sample disable undirected (#247)



* disable undirected for temporal sampling

* disjoint sampling for temporal

* fix repeated node index

* compile fix

* Update csrc/cpu/neighbor_sample_cpu.cpp
Co-authored-by: default avatarZecheng Zhang <zecheng@kumo.ai>

* Update csrc/cpu/neighbor_sample_cpu.cpp
Co-authored-by: default avatarZecheng Zhang <zecheng@kumo.ai>

* Update csrc/cpu/neighbor_sample_cpu.cpp
Co-authored-by: default avatarZecheng Zhang <zecheng@kumo.ai>

* comments on directed to be true

* add directed in API

* comments

* minor function signature fix

* Update csrc/cpu/neighbor_sample_cpu.cpp

* Update csrc/neighbor_sample.cpp
Co-authored-by: default avatarZecheng Zhang <zecheng@kumo.ai>
Co-authored-by: default avatarMatthias Fey <matthias.fey@tu-dortmund.de>
parent 7dbc51cd
......@@ -238,17 +238,24 @@ hetero_sample(const vector<node_t> &node_types,
if (temporal) {
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
}
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) {
// force disjoint of computation tree
// note that the sampling always needs to have directed=True
// for temporal case
// to_local_src_node is not used for temporal / directed case
src_samples.push_back(v);
if (temporal)
src_root_time.push_back(dst_time);
}
if (directed) {
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(res.first->second);
rows.push_back(src_samples.size() - 1);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
}
} else if (replace) {
......@@ -261,17 +268,23 @@ hetero_sample(const vector<node_t> &node_types,
// TODO Infinity loop if no neighbor satisfies time constraint:
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
}
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) {
// force disjoint of computation tree
// note that the sampling always needs to have directed=True
// for temporal case
src_samples.push_back(v);
if (temporal)
src_root_time.push_back(dst_time);
}
if (directed) {
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(res.first->second);
rows.push_back(src_samples.size() - 1);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
num_neighbors += 1;
}
......@@ -289,17 +302,23 @@ hetero_sample(const vector<node_t> &node_types,
if (temporal) {
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
}
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) {
// force disjoint of computation tree
// note that the sampling always needs to have directed=True
// for temporal case
src_samples.push_back(v);
if (temporal)
src_root_time.push_back(dst_time);
}
if (directed) {
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(res.first->second);
rows.push_back(src_samples.size() - 1);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
}
}
}
}
......@@ -412,21 +431,19 @@ hetero_temporal_neighbor_sample_cpu(
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops, const bool replace, const bool directed) {
if (replace && directed) {
AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling")
if (replace) {
// We assume that directed = True for temporal sampling
// The current implementation uses disjoint computation trees
// to tackle the case of the same node sampled having different
// root time constraint.
// In future, we could extend to directed = False case,
// allowing additional edges within each computation tree.
return hetero_sample<true, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else if (replace && !directed) {
return hetero_sample<true, false, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else if (!replace && directed) {
return hetero_sample<false, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else {
return hetero_sample<false, false, true>(
return hetero_sample<false, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment