Unverified Commit 017d9d40 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

[Fix][Readability] Improving the Capsule Network example. (#5985)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 416e2425
......@@ -68,15 +68,16 @@ def squash(s, dim=1):
def init_graph(in_nodes, out_nodes, f_size, device="cpu"):
g = dgl.DGLGraph()
g.set_n_initializer(dgl.frame.zero_initializer)
all_nodes = in_nodes + out_nodes
g.add_nodes(all_nodes)
src, dst = [], []
in_indx = list(range(in_nodes))
out_indx = list(range(in_nodes, in_nodes + out_nodes))
# add edges use edge broadcasting
for u in in_indx:
g.add_edges(u, out_indx)
src += [u] * len(out_indx)
dst += out_indx
g = dgl.graph((src, dst)) # dgl.graph once;
g.set_n_initializer(dgl.frame.zero_initializer)
g = g.to(device)
g.edata["b"] = th.zeros(in_nodes * out_nodes, 1).to(device)
return g
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment