Unverified Commit ad228bb8 authored by Rex Ying's avatar Rex Ying Committed by GitHub
Browse files

Minor refactor for temporal sampling (#257)



* disable undirected for temporal sampling

* disjoint sampling for temporal

* fix repeated node index

* compile fix

* Update csrc/cpu/neighbor_sample_cpu.cpp
Co-authored-by: default avatarZecheng Zhang <zecheng@kumo.ai>

* Update csrc/cpu/neighbor_sample_cpu.cpp
Co-authored-by: default avatarZecheng Zhang <zecheng@kumo.ai>

* Update csrc/cpu/neighbor_sample_cpu.cpp
Co-authored-by: default avatarZecheng Zhang <zecheng@kumo.ai>

* comments on directed to be true

* add directed in API

* comments

* minor function signature fix

* Update csrc/cpu/neighbor_sample_cpu.cpp

* Update csrc/neighbor_sample.cpp

* minor refactor

* minor refactor
Co-authored-by: default avatarZecheng Zhang <zecheng@kumo.ai>
Co-authored-by: default avatarMatthias Fey <matthias.fey@tu-dortmund.de>
parent d670cddb
......@@ -242,10 +242,11 @@ hetero_sample(const vector<node_t> &node_types,
// note that the sampling always needs to have directed=True
// for temporal case
// to_local_src_node is not used for temporal / directed case
const int64_t sample_idx = src_samples.size();
src_samples.push_back(v);
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(src_samples.size() - 1);
rows.push_back(sample_idx);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
......@@ -271,10 +272,11 @@ hetero_sample(const vector<node_t> &node_types,
// force disjoint of computation tree
// note that the sampling always needs to have directed=True
// for temporal case
const int64_t sample_idx = src_samples.size();
src_samples.push_back(v);
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(src_samples.size() - 1);
rows.push_back(sample_idx);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
......@@ -305,10 +307,11 @@ hetero_sample(const vector<node_t> &node_types,
// force disjoint of computation tree
// note that the sampling always needs to have directed=True
// for temporal case
const int64_t sample_idx = src_samples.size();
src_samples.push_back(v);
src_root_time.push_back(dst_time);
cols.push_back(i);
rows.push_back(src_samples.size() - 1);
rows.push_back(sample_idx);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
......@@ -431,7 +434,7 @@ hetero_temporal_neighbor_sample_cpu(
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops, const bool replace, const bool directed) {
AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling")
AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling");
if (replace) {
// We assume that directed = True for temporal sampling
// The current implementation uses disjoint computation trees
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment