Unverified Commit b79dae36 authored by caojy1998's avatar caojy1998 Committed by GitHub
Browse files

[Example] Modify the output of gated_gcn example to fit benchmark test (#6405)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-21-37.ap-northeast-1.compute.internal>
parent 32be4a8e
......@@ -15,7 +15,7 @@ How to run
----------
```bash
python train.py
python3 train.py --dataset ogbg-molhiv --num_gpus 0 --num_epochs 50
```
## Summary
......
"""
Gated Graph Convolutional Network module for graph classification tasks
"""
import argparse
import time
import torch
import torch.nn as nn
......@@ -98,12 +100,37 @@ def evaluate(model, device, data_loader, evaluator):
return evaluator.eval({"y_true": y_true, "y_pred": y_pred})["rocauc"]
def main():
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
default="ogbg-molhiv",
help="Dataset name ('ogbg-molhiv', 'ogbg-molbace', 'ogbg-molmuv').",
)
parser.add_argument(
"--num_epochs",
type=int,
default=200,
help="Number of epochs for train.",
)
parser.add_argument(
"--num_gpus",
type=int,
default=0,
help="Number of GPUs used for train and evaluation.",
)
args = parser.parse_args()
print("Training with DGL built-in GATConv module.")
# Load ogb dataset & evaluator.
dataset = DglGraphPropPredDataset(name="ogbg-molhiv")
evaluator = Evaluator(name="ogbg-molhiv")
dataset = DglGraphPropPredDataset(name=args.dataset)
evaluator = Evaluator(name=args.dataset)
if args.num_gpus > 0 and torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
n_classes = dataset.num_tasks
......@@ -125,18 +152,16 @@ def main():
loss_fn = nn.BCEWithLogitsLoss()
print("---------- Training ----------")
for epoch in range(50):
for epoch in range(args.num_epochs):
# Kick off training.
t0 = time.time()
loss = train(model, device, train_loader, opt, loss_fn)
t1 = time.time()
# Evaluate the prediction.
valid_acc = evaluate(model, device, valid_loader, evaluator)
test_acc = evaluate(model, device, test_loader, evaluator)
val_acc = evaluate(model, device, valid_loader, evaluator)
print(
f"In epoch {epoch}, loss: {loss:.3f}, val acc: {valid_acc:.3f}, test"
f" acc: {test_acc:.3f}"
f"Epoch {epoch:05d} | Loss {loss:.4f} | Accuracy {val_acc:.4f} | "
f"Time {t1 - t0:.4f}"
)
if __name__ == "__main__":
main()
acc = evaluate(model, device, test_loader, evaluator)
print(f"Test accuracy {acc:.4f}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment