Unverified Commit 31f1b30c authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the GGNN example. (#6055)

parent 83437e67
......@@ -104,7 +104,7 @@ def _ns_dataloader(
node_ids.append(s)
if t not in node_ids:
node_ids.append(t)
g = dgl.DGLGraph()
g = dgl.graph([])
g.add_nodes(len(node_ids))
g.ndata["node_id"] = torch.tensor(node_ids, dtype=torch.long)
......@@ -224,7 +224,7 @@ def _gc_dataloader(
node_ids.append(s)
if t not in node_ids:
node_ids.append(t)
g = dgl.DGLGraph()
g = dgl.graph([])
g.add_nodes(len(node_ids))
g.ndata["node_id"] = torch.tensor(node_ids, dtype=torch.long)
......@@ -346,7 +346,7 @@ def _path_finding_dataloader(
node_ids.append(s)
if t not in node_ids:
node_ids.append(t)
g = dgl.DGLGraph()
g = dgl.graph([])
g.add_nodes(len(node_ids))
g.ndata["node_id"] = torch.tensor(node_ids, dtype=torch.long)
......
......@@ -59,7 +59,7 @@ def main(args):
labels = labels.data.numpy().tolist()
dev_preds += preds
dev_labels += labels
acc = np.equal(dev_labels, dev_preds).astype(np.float).tolist()
acc = np.equal(dev_labels, dev_preds).astype(float).tolist()
acc = sum(acc) / len(acc)
print(f"Epoch {epoch}, Dev acc {acc}")
......@@ -81,7 +81,7 @@ def main(args):
labels = labels.data.numpy().tolist()
test_preds += preds
test_labels += labels
acc = np.equal(test_labels, test_preds).astype(np.float).tolist()
acc = np.equal(test_labels, test_preds).astype(float).tolist()
acc = sum(acc) / len(acc)
test_acc_list.append(acc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment