Unverified Commit d7062980 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

fix a bug. (#646)

parent e16e895d
......@@ -441,9 +441,20 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
&tmp_sampled_edge_list,
&time_seed);
}
if (add_self_loop) {
// If we need to add self loop and it doesn't exist in the sampled neighbor list.
if (add_self_loop && std::find(tmp_sampled_src_list.begin(), tmp_sampled_src_list.end(),
dst_id) == tmp_sampled_src_list.end()) {
tmp_sampled_src_list.push_back(dst_id);
tmp_sampled_edge_list.push_back(-1);
const dgl_id_t *src_list = col_list + *(indptr + dst_id);
const dgl_id_t *eid_list = val_list + *(indptr + dst_id);
// TODO(zhengda) this operation has O(N) complexity. It can be pretty slow.
const dgl_id_t *src = std::find(src_list, src_list + ver_len, dst_id);
// If there doesn't exist a self loop in the graph.
// we have to add -1 as the edge id for the self-loop edge.
if (src == src_list + ver_len)
tmp_sampled_edge_list.push_back(-1);
else
tmp_sampled_edge_list.push_back(eid_list[src - src_list]);
}
CHECK_EQ(tmp_sampled_src_list.size(), tmp_sampled_edge_list.size());
neigh_pos.emplace_back(dst_id, neighbor_list.size(), tmp_sampled_src_list.size());
......
......@@ -10,7 +10,7 @@ import dgl.function as fn
from functools import partial
import itertools
def generate_rand_graph(n, connect_more=False, complete=False):
def generate_rand_graph(n, connect_more=False, complete=False, add_self_loop=False):
if complete:
cord = [(i,j) for i, j in itertools.product(range(n), range(n)) if i != j]
row = [t[0] for t in cord]
......@@ -23,7 +23,13 @@ def generate_rand_graph(n, connect_more=False, complete=False):
if connect_more:
arr[0] = 1
arr[:,0] = 1
g = dgl.DGLGraph(arr, readonly=True)
if add_self_loop:
g = dgl.DGLGraph(arr, readonly=False)
nodes = np.arange(g.number_of_nodes())
g.add_edges(nodes, nodes)
g.readonly()
else:
g = dgl.DGLGraph(arr, readonly=True)
g.ndata['h1'] = F.randn((g.number_of_nodes(), 10))
g.edata['h2'] = F.randn((g.number_of_edges(), 3))
return g
......@@ -39,6 +45,18 @@ def test_self_loop():
deg = F.copy_to(F.ones(in_deg.shape, dtype=F.int64), F.cpu()) * n
assert_array_equal(F.asnumpy(in_deg), F.asnumpy(deg))
g = generate_rand_graph(n, complete=True, add_self_loop=True)
g = dgl.to_simple_graph(g)
nf = create_mini_batch(g, num_hops, add_self_loop=True)
for i in range(nf.num_blocks):
parent_eid = F.asnumpy(nf.block_parent_eid(i))
parent_nid = F.asnumpy(nf.layer_parent_nid(i + 1))
# The loop eid in the parent graph must exist in the block parent eid.
parent_loop_eid = F.asnumpy(g.edge_ids(parent_nid, parent_nid))
assert len(parent_loop_eid) == len(parent_nid)
for eid in parent_loop_eid:
assert eid in parent_eid
def create_mini_batch(g, num_hops, add_self_loop=False):
seed_ids = np.array([1, 2, 0, 3])
sampler = NeighborSampler(g, batch_size=4, expand_factor=g.number_of_nodes(),
......@@ -59,7 +77,6 @@ def check_basic(g, nf):
assert nf.number_of_edges() == num_edges
assert len(nf) == num_nodes
assert nf.is_readonly
assert not nf.is_multigraph
assert np.all(F.asnumpy(nf.has_nodes(list(range(num_nodes)))))
for i in range(num_nodes):
......@@ -131,6 +148,11 @@ def test_basic():
assert nf.num_layers == num_layers + 1
check_basic(g, nf)
g = generate_rand_graph(100, add_self_loop=True)
nf = create_mini_batch(g, num_layers, add_self_loop=True)
assert nf.num_layers == num_layers + 1
check_basic(g, nf)
def check_apply_nodes(create_node_flow, use_negative_block_id):
num_layers = 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment