Unverified Commit e0aa8389, authored by Andrei Ivanov, committed by GitHub

Improving the CLUSTER_GAT example. (#6059)

parent d20db1ec
@@ -112,17 +112,18 @@ class GAT(nn.Module):
                 num_workers=args.num_workers,
             )
-            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
-                block = blocks[0].int().to(device)
-                h = x[input_nodes].to(device)
-                if l < self.n_layers - 1:
-                    h = layer(block, h).flatten(1)
-                else:
-                    h = layer(block, h)
-                    h = h.mean(1)
-                    h = h.log_softmax(dim=-1)
-                y[output_nodes] = h.cpu()
+            with dataloader.enable_cpu_affinity():
+                for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
+                    block = blocks[0].int().to(device)
+                    h = x[input_nodes].to(device)
+                    if l < self.n_layers - 1:
+                        h = layer(block, h).flatten(1)
+                    else:
+                        h = layer(block, h)
+                        h = h.mean(1)
+                        h = h.log_softmax(dim=-1)
+                    y[output_nodes] = h.cpu()
             x = y
         return y
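The first hunk wraps the layer-wise inference loop in dataloader.enable_cpu_affinity(), DGL's context manager that pins DataLoader worker processes to CPU cores while sampling. Below is a minimal, self-contained sketch of the pattern, not the example's full inference method: the toy graph, sampler, and batch size are placeholders, and affinity pinning only takes effect when num_workers > 0.

# Sketch of the enable_cpu_affinity() pattern; graph/sampler are placeholders.
import dgl
import torch as th

g = dgl.rand_graph(1000, 5000)  # hypothetical toy graph
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
    g,
    th.arange(g.num_nodes()),
    sampler,
    batch_size=128,
    shuffle=False,
    drop_last=False,
    num_workers=4,  # affinity only helps when workers > 0
)

# Pin the dataloader workers to CPU cores for the duration of the loop.
with dataloader.enable_cpu_affinity():
    for input_nodes, output_nodes, blocks in dataloader:
        pass  # per-block computation goes here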
@@ -279,7 +280,9 @@ def run(args, device, data, nfeat):
                     best_eval_acc, best_test_acc
                 )
             )
-    print("Avg epoch time: {}".format(avg / (epoch - 4)))
+    if epoch >= 5:
+        print("Avg epoch time: {}".format(avg / (epoch - 4)))
     return best_test_acc.to(th.device("cpu"))
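The second hunk guards the average-epoch-time report with if epoch >= 5, so runs of fewer than five epochs no longer divide by zero or a negative number. A hedged sketch of the surrounding timing logic follows, assuming the example accumulates per-epoch time into avg only after a four/five-epoch warm-up (that accumulation sits outside this hunk).

# Sketch of the timing pattern the guard protects; variable names mirror the
# example but the accumulation shown here is an assumption, not the diff itself.
import time

avg = 0
num_epochs = 20  # mirrors the example's default --num_epochs
for epoch in range(num_epochs):
    tic = time.time()
    # ... one epoch of training would run here ...
    toc = time.time()
    if epoch >= 5:  # skip the first epochs as warm-up
        avg += toc - tic

if epoch >= 5:  # the new guard: only report when warmed-up epochs exist
    print("Avg epoch time: {}".format(avg / (epoch - 4)))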
@@ -291,22 +294,22 @@ if __name__ == "__main__":
         default=0,
         help="GPU device ID. Use -1 for CPU training",
     )
-    argparser.add_argument("--num-epochs", type=int, default=20)
-    argparser.add_argument("--num-hidden", type=int, default=128)
-    argparser.add_argument("--num-layers", type=int, default=3)
-    argparser.add_argument("--num-heads", type=int, default=8)
-    argparser.add_argument("--batch-size", type=int, default=32)
-    argparser.add_argument("--val-batch-size", type=int, default=2000)
-    argparser.add_argument("--log-every", type=int, default=20)
-    argparser.add_argument("--eval-every", type=int, default=1)
+    argparser.add_argument("--num_epochs", type=int, default=20)
+    argparser.add_argument("--num_hidden", type=int, default=128)
+    argparser.add_argument("--num_layers", type=int, default=3)
+    argparser.add_argument("--num_heads", type=int, default=8)
+    argparser.add_argument("--batch_size", type=int, default=32)
+    argparser.add_argument("--val_batch_size", type=int, default=2000)
+    argparser.add_argument("--log_every", type=int, default=20)
+    argparser.add_argument("--eval_every", type=int, default=1)
     argparser.add_argument("--lr", type=float, default=0.001)
     argparser.add_argument("--dropout", type=float, default=0.5)
-    argparser.add_argument("--save-pred", type=str, default="")
+    argparser.add_argument("--save_pred", type=str, default="")
     argparser.add_argument("--wd", type=float, default=0)
     argparser.add_argument("--num_partitions", type=int, default=15000)
-    argparser.add_argument("--num-workers", type=int, default=0)
+    argparser.add_argument("--num_workers", type=int, default=4)
     argparser.add_argument(
-        "--data-cpu",
+        "--data_cpu",
         action="store_true",
         help="By default the script puts all node features and labels "
         "on GPU when using it to save time for data copy. This may "
@@ -352,7 +355,7 @@ if __name__ == "__main__":
         batch_size=args.batch_size,
         shuffle=True,
         pin_memory=True,
-        num_workers=4,
+        num_workers=args.num_workers,
         collate_fn=partial(subgraph_collate_fn, graph),
     )
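Here the hard-coded num_workers=4 in the cluster-partition DataLoader is replaced by the new --num_workers argument, so the worker count is controlled in one place. A minimal sketch of the wiring, with a placeholder dataset and batch size standing in for the example's partition ids and collate function:

# Wiring a CLI argument into torch's DataLoader worker count (placeholders only).
import argparse
from torch.utils.data import DataLoader

argparser = argparse.ArgumentParser()
argparser.add_argument("--num_workers", type=int, default=4)
args = argparser.parse_args([])

cluster_ids = list(range(100))  # stands in for the example's partition ids
dataloader = DataLoader(
    cluster_ids,
    batch_size=32,
    shuffle=True,
    pin_memory=True,
    num_workers=args.num_workers,  # previously hard-coded to 4
)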
@@ -375,6 +378,5 @@ if __name__ == "__main__":
     nfeat = graph.ndata.pop("feat").to(device)
     for i in range(10):
         test_accs.append(run(args, device, data, nfeat))
-    print(
-        "Average test accuracy:", np.mean(test_accs), "±", np.std(test_accs)
-    )
+    print("Average test accuracy:", np.mean(test_accs), "±", np.std(test_accs))
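The final hunk only collapses the closing print onto one line: the script still runs training ten times and reports the mean and standard deviation of the test accuracy. For reference, the aggregation amounts to the snippet below; the accuracy values are made up.

# Aggregating repeated-run results (values are hypothetical).
import numpy as np

test_accs = [0.781, 0.785, 0.779]  # per-run test accuracies from run(...)
print("Average test accuracy:", np.mean(test_accs), "±", np.std(test_accs))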