"docs/source/vscode:/vscode.git/clone" did not exist on "652f4c07437df79d19eef812cdeb87f5d2d984d4"
Commit ca48787a authored by Hao Zhang's avatar Hao Zhang Committed by HQ
Browse files

fix tutorial (#506)

* Update 9_gat.py

* Update 1_gcn.py
parent 3b96299a
...@@ -106,7 +106,11 @@ def load_cora_data(): ...@@ -106,7 +106,11 @@ def load_cora_data():
features = th.FloatTensor(data.features) features = th.FloatTensor(data.features)
labels = th.LongTensor(data.labels) labels = th.LongTensor(data.labels)
mask = th.ByteTensor(data.train_mask) mask = th.ByteTensor(data.train_mask)
g = DGLGraph(data.graph) g = data.graph
# add self loop
g.remove_edges_from(g.selfloop_edges())
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())
return g, features, labels, mask return g, features, labels, mask
############################################################################### ###############################################################################
......
...@@ -280,7 +280,11 @@ def load_cora_data(): ...@@ -280,7 +280,11 @@ def load_cora_data():
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask) mask = torch.ByteTensor(data.train_mask)
g = DGLGraph(data.graph) g = data.graph
# add self loop
g.remove_edges_from(g.selfloop_edges())
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())
return g, features, labels, mask return g, features, labels, mask
############################################################################## ##############################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment