Commit 039a711d authored by Hao Zhang's avatar Hao Zhang Committed by Minjie Wang
Browse files

[Model] fix self-edge bug in GCN and GAT. (#482)

* Update gcn_mp.py

* Update train.py

* Update train.py

* Update train.py

* Update gat_batch.py

* Update gat_batch.py

* Update gcn_mp.py
parent fa887f69
......@@ -178,10 +178,11 @@ def main(args):
test_mask = test_mask.as_in_context(ctx)
val_mask = val_mask.as_in_context(ctx)
# create graph
g = DGLGraph(data.graph)
g = data.graph
# add self-loop
g.remove_edges_from(g.selfloop_edges())
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())
# create model
model = GAT(g,
args.num_layers,
......
......@@ -52,9 +52,11 @@ def main(args):
test_mask = test_mask.as_in_context(ctx)
# create GCN model
g = DGLGraph(data.graph)
g = data.graph
if args.self_loop:
g.add_edges(g.nodes(), g.nodes())
g.remove_edges_from(g.selfloop_edges())
g.add_edges_from(zip(g.nodes(), g.nodes()))
g = DGLGraph(g)
# normalization
degs = g.in_degrees().astype('float32')
norm = mx.nd.power(degs, -0.5)
......
......@@ -65,11 +65,12 @@ def main(args):
val_mask = val_mask.cuda()
test_mask = test_mask.cuda()
# create DGL graph
g = DGLGraph(data.graph)
n_edges = g.number_of_edges()
g = data.graph
# add self loop
g.remove_edges_from(g.selfloop_edges())
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())
n_edges = g.number_of_edges()
# create model
heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
model = GAT(g,
......
......@@ -147,10 +147,12 @@ def main(args):
test_mask = test_mask.cuda()
# graph preprocess and calculate normalization factor
g = DGLGraph(data.graph)
n_edges = g.number_of_edges()
g = data.graph
g.remove_edges_from(g.selfloop_edges())
g = DGLGraph(g)
# add self loop
g.add_edges(g.nodes(), g.nodes())
n_edges = g.number_of_edges()
# normalization
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
......
......@@ -54,11 +54,13 @@ def main(args):
test_mask = test_mask.cuda()
# graph preprocess and calculate normalization factor
g = DGLGraph(data.graph)
n_edges = g.number_of_edges()
g = data.graph
# add self loop
if args.self_loop:
g.add_edges(g.nodes(), g.nodes())
g.remove_edges_from(g.selfloop_edges())
g.add_edges_from(zip(g.nodes(), g.nodes()))
g = DGLGraph(g)
n_edges = g.number_of_edges()
# normalization
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment