Unverified commit e7389d7c, authored by Da Zheng, committed by GitHub
Browse files

[BUGFIX] fix the order of the seed nodes in NodeFlow (#626)

* fix.

* add comments.
parent 8bf97719
......@@ -276,13 +276,17 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
size_t out_node_idx = 0;
for (int layer_id = num_hops - 1; layer_id >= 0; layer_id--) {
// We sort the vertices in a layer so that we don't need to sort the neighbor Ids
// after remap to a subgraph.
// after remap to a subgraph. However, we do not sort the first layer
// because we want the order of the nodes in the first layer to be the same as
// the order of the input seed nodes.
if (layer_id > 0) {
std::sort(sub_vers->begin() + layer_offsets[layer_id],
sub_vers->begin() + layer_offsets[layer_id + 1],
[](const std::pair<dgl_id_t, dgl_id_t> &a1,
const std::pair<dgl_id_t, dgl_id_t> &a2) {
return a1.first < a2.first;
});
}
// Save the sampled vertices and their layer Ids.
for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1]; i++) {
......@@ -305,11 +309,15 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
layer_off_data[1] = layer_offsets[num_hops] - layer_offsets[num_hops - 1];
int out_layer_idx = 1;
for (int layer_id = num_hops - 2; layer_id >= 0; layer_id--) {
// Because we don't sort the vertices in the first layer above, we can't sort
// the neighbor positions of the vertices in the first layer either.
if (layer_id > 0) {
std::sort(neigh_pos->begin() + layer_offsets[layer_id],
neigh_pos->begin() + layer_offsets[layer_id + 1],
[](const neighbor_info &a1, const neighbor_info &a2) {
return a1.id < a2.id;
});
}
for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1]; i++) {
dgl_id_t dst_id = sub_vers->at(i).first;
......
......@@ -39,11 +39,12 @@ def test_self_loop():
assert F.array_equal(in_deg, deg)
def create_mini_batch(g, num_hops, add_self_loop=False):
    """Sample a single NodeFlow from ``g`` using a fixed set of four seed nodes.

    Parameters
    ----------
    g : graph object passed straight to ``NeighborSampler``; must provide
        ``number_of_nodes()``.
    num_hops : forwarded to ``NeighborSampler`` as ``num_hops``.
    add_self_loop : forwarded to ``NeighborSampler`` as ``add_self_loop``.

    Returns
    -------
    The only NodeFlow produced by the sampler.

    Raises
    ------
    AssertionError
        If the sampler does not yield exactly one NodeFlow, or if the parent
        node Ids of the last layer differ (in value or order) from the seeds.
    """
    # The seeds are deliberately not in ascending order (this used to be
    # [0, 1, 2, 3]) so that the check below fails if the sampler reorders them.
    seed_ids = np.array([1, 2, 0, 3])
    # batch_size equals the number of seeds, so all seeds go into one batch;
    # expand_factor is set to the total number of nodes in the graph.
    sampler = NeighborSampler(g, batch_size=4, expand_factor=g.number_of_nodes(),
                              num_hops=num_hops, seed_nodes=seed_ids,
                              add_self_loop=add_self_loop)
    nfs = list(sampler)
    assert len(nfs) == 1
    # The last layer must list the seed nodes in exactly the order given.
    assert np.array_equal(F.asnumpy(nfs[0].layer_parent_nid(-1)), seed_ids)
    return nfs[0]
def check_basic(g, nf):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment