Unverified Commit a6bd96aa authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

Fix example crashes due to DGL API update (#4194)


Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent f7dae453
......@@ -45,7 +45,7 @@ def main(args):
test_mask = g.ndata['test_mask']
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......
......@@ -21,6 +21,7 @@ def evaluate(model, features, labels, mask):
def main(args):
# load and preprocess dataset
data = load_data(args)
g = data[0]
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
......@@ -33,7 +34,7 @@ def main(args):
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
if args.gpu < 0:
cuda = False
......@@ -46,13 +47,10 @@ def main(args):
val_mask = val_mask.cuda()
test_mask = test_mask.cuda()
# graph preprocess
g = data.graph
# add self loop
if args.self_loop:
g.remove_edges_from(nx.selfloop_edges(g))
g.add_edges_from(zip(g.nodes(), g.nodes()))
g = DGLGraph(g)
n_edges = g.number_of_edges()
if args.gpu >= 0:
......
......@@ -62,7 +62,7 @@ def main(args):
test_mask = g.ndata['test_mask']
num_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......
......@@ -48,7 +48,7 @@ def accuracy(logits, labels):
return correct.item() * 1.0 / len(labels)
def evaluate(model, g, features, labels, mask):
def evaluate(g, model, features, labels, mask):
model.eval()
with torch.no_grad():
logits = model(g, features)
......@@ -82,7 +82,7 @@ def main(args):
test_mask = g.ndata['test_mask']
num_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......@@ -156,7 +156,7 @@ def main(args):
print()
if args.early_stop:
model.load_state_dict(torch.load('es_checkpoint.pt'))
acc = evaluate(model, features, labels, test_mask)
acc = evaluate(g, model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))
......
......@@ -140,7 +140,7 @@ def main(args):
test_mask = g.ndata['test_mask']
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......
......@@ -47,7 +47,7 @@ def main(args):
test_mask = g.ndata['test_mask']
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......
......@@ -70,7 +70,7 @@ def main(args):
test_mask = g.ndata['test_mask']
in_feats = features.shape[1]
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......
......@@ -60,7 +60,7 @@ def main(args):
test_mask = g.ndata['test_mask']
num_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......
......@@ -57,7 +57,7 @@ def main(args):
test_mask = g.ndata['test_mask']
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......
......@@ -78,7 +78,7 @@ def main(args):
test_mask = g.ndata['test_mask']
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......
......@@ -51,7 +51,7 @@ def main(args):
test_mask = g.ndata['test_mask']
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......
......@@ -31,7 +31,7 @@ def main(args):
# load and preprocess dataset
args.dataset = "reddit-self-loop"
data = load_data(args)
g = data.graph
g = data[0]
if args.gpu < 0:
cuda = False
else:
......@@ -45,7 +45,7 @@ def main(args):
test_mask = g.ndata['test_mask']
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment