Unverified Commit 9e46423e authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Bugfix] Fix cluster-gat examples (#4068)


Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent d9c25521
......@@ -51,7 +51,7 @@ class GAT(nn.Module):
attn_drop=dropout,
activation=None,
negative_slope=0.2))
def forward(self, g, x):
h = x
for l, conv in enumerate(self.layers):
......@@ -119,7 +119,8 @@ def evaluate(model, g, nfeat, labels, val_nid, test_nid, batch_size, device):
with th.no_grad():
pred = model.inference(g, nfeat, batch_size, device)
model.train()
return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(pred[test_nid], labels[test_nid]), pred
labels_cpu = labels.to(th.device('cpu'))
return compute_acc(pred[val_nid], labels_cpu[val_nid]), compute_acc(pred[test_nid], labels_cpu[test_nid]), pred
def model_param_summary(model):
""" Count the model parameters """
......@@ -127,11 +128,10 @@ def model_param_summary(model):
print("Total Params {}".format(cnt))
#### Entry point
def run(args, device, data):
def run(args, device, data, nfeat):
# Unpack data
train_nid, val_nid, test_nid, in_feats, labels, n_classes, g, cluster_iterator = data
labels = labels.to(device)
nfeat = g.ndata.pop('feat').to(device)
# Define model and optimizer
model = GAT(in_feats, args.num_heads, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
......@@ -200,7 +200,7 @@ def run(args, device, data):
best_test_acc = test_acc
print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))
print('Avg epoch time: {}'.format(avg / (epoch - 4)))
return best_test_acc
return best_test_acc.to(th.device('cpu'))
if __name__ == '__main__':
argparser = argparse.ArgumentParser("multi-gpu training")
......@@ -265,6 +265,7 @@ if __name__ == '__main__':
# Run 10 times
test_accs = []
nfeat = graph.ndata.pop('feat').to(device)
for i in range(10):
test_accs.append(run(args, device, data))
test_accs.append(run(args, device, data, nfeat))
print('Average test accuracy:', np.mean(test_accs), '±', np.std(test_accs))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment