import argparse
import os
import random

import numpy as np
import torch
import torch.optim as optim
from gnn import GNN
from ogb.lsc import DglPCQM4MDataset, PCQM4MEvaluator
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import dgl

reg_criterion = torch.nn.L1Loss()


def collate_dgl(samples):
    # Merge a list of (graph, label) pairs into one batched DGLGraph
    # and a stacked label tensor.
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    labels = torch.stack(labels)
    return batched_graph, labels


def train(model, device, loader, optimizer):
    # Train for one epoch; returns the average L1 loss (= train MAE).
    model.train()
    loss_accum = 0

    for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")):
        bg = bg.to(device)
        x = bg.ndata.pop("feat")
        edge_attr = bg.edata.pop("feat")
        labels = labels.to(device)

        pred = model(bg, x, edge_attr).view(-1)
        optimizer.zero_grad()
        loss = reg_criterion(pred, labels)
        loss.backward()
        optimizer.step()

        loss_accum += loss.detach().cpu().item()

    return loss_accum / (step + 1)


def eval(model, device, loader, evaluator):
    # Compute the validation MAE with the official OGB evaluator.
    model.eval()
    y_true = []
    y_pred = []

    for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")):
        bg = bg.to(device)
        x = bg.ndata.pop("feat")
        edge_attr = bg.edata.pop("feat")
        labels = labels.to(device)

        with torch.no_grad():
            pred = model(bg, x, edge_attr).view(-1)

        y_true.append(labels.view(pred.shape).detach().cpu())
        y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)

    input_dict = {"y_true": y_true, "y_pred": y_pred}

    return evaluator.eval(input_dict)["mae"]


def test(model, device, loader):
    # Predict on the (unlabeled) test split; labels are ignored.
    model.eval()
    y_pred = []

    for step, (bg, _) in enumerate(tqdm(loader, desc="Iteration")):
        bg = bg.to(device)
        x = bg.ndata.pop("feat")
        edge_attr = bg.edata.pop("feat")

        with torch.no_grad():
            pred = model(bg, x, edge_attr).view(-1)

        y_pred.append(pred.detach().cpu())

    y_pred = torch.cat(y_pred, dim=0)

    return y_pred
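
# Note on the model interface: GNN is imported from the local `gnn` module,
# which is not shown here. From the way it is called in train/eval/test above,
# the assumed forward signature is
#     model(batched_graph, node_feat, edge_feat) -> FloatTensor
# returning one scalar prediction per graph in the batch, e.g.
#     pred = model(bg, bg.ndata["feat"], bg.edata["feat"]).view(-1)
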
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description="GNN baselines on pcqm4m with DGL"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="random seed to use (default: 42)"
    )
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="which gpu to use if any (default: 0)",
    )
    parser.add_argument(
        "--gnn",
        type=str,
        default="gin-virtual",
        help="GNN to use, which can be from "
        "[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)",
    )
    parser.add_argument(
        "--graph_pooling",
        type=str,
        default="sum",
        help="graph pooling strategy mean or sum (default: sum)",
    )
    parser.add_argument(
        "--drop_ratio", type=float, default=0, help="dropout ratio (default: 0)"
    )
    parser.add_argument(
        "--num_layers",
        type=int,
        default=5,
        help="number of GNN message passing layers (default: 5)",
    )
    parser.add_argument(
        "--emb_dim",
        type=int,
        default=600,
        help="dimensionality of hidden units in GNNs (default: 600)",
    )
    parser.add_argument(
        "--train_subset",
        action="store_true",
        # "%%" because argparse applies %-formatting to help strings.
        help="use 10%% of the training set for training",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="input batch size for training (default: 256)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="number of epochs to train (default: 100)",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="number of workers (default: 0)",
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="",
        help="tensorboard log directory. If not specified, "
        "tensorboard will not be used.",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="",
        help="directory to save checkpoint",
    )
    parser.add_argument(
        "--save_test_dir",
        type=str,
        default="",
        help="directory to save test submission file",
    )
    args = parser.parse_args()

    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")

    ### automatic dataloading and splitting
    dataset = DglPCQM4MDataset(root="dataset/")

    # split_idx['train'], split_idx['valid'], split_idx['test']
    # separately gives a 1D int64 tensor
    split_idx = dataset.get_idx_split()

    ### automatic evaluator.
    evaluator = PCQM4MEvaluator()

    if args.train_subset:
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(split_idx["train"]))[
            : int(subset_ratio * len(split_idx["train"]))
        ]
        train_loader = DataLoader(
            dataset[split_idx["train"][subset_idx]],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=collate_dgl,
        )
    else:
        train_loader = DataLoader(
            dataset[split_idx["train"]],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=collate_dgl,
        )

    valid_loader = DataLoader(
        dataset[split_idx["valid"]],
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_dgl,
    )

    if args.save_test_dir != "":
        test_loader = DataLoader(
            dataset[split_idx["test"]],
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_dgl,
        )

    if args.checkpoint_dir != "":
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        "num_layers": args.num_layers,
        "emb_dim": args.emb_dim,
        "drop_ratio": args.drop_ratio,
        "graph_pooling": args.graph_pooling,
    }

    if args.gnn == "gin":
        model = GNN(gnn_type="gin", virtual_node=False, **shared_params).to(
            device
        )
    elif args.gnn == "gin-virtual":
        model = GNN(gnn_type="gin", virtual_node=True, **shared_params).to(
            device
        )
    elif args.gnn == "gcn":
        model = GNN(gnn_type="gcn", virtual_node=False, **shared_params).to(
            device
        )
    elif args.gnn == "gcn-virtual":
        model = GNN(gnn_type="gcn", virtual_node=True, **shared_params).to(
            device
        )
    else:
        raise ValueError("Invalid GNN type")

    num_params = sum(p.numel() for p in model.parameters())
    print(f"#Params: {num_params}")

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != "":
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)
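
    # scheduler.step() is called once per epoch, so with the default
    # schedule (step_size=30, gamma=0.25) and Adam's initial lr of 1e-3:
    #     lr for epoch e = 1e-3 * 0.25 ** ((e - 1) // 30)
    # i.e. epochs 1-30 run at 1e-3, epochs 31-60 at 2.5e-4, and so on.
    # With --train_subset, the decay is stretched to every 300 of 1000 epochs.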
    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print("Training...")
        train_mae = train(model, device, train_loader, optimizer)

        print("Evaluating...")
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({"Train": train_mae, "Validation": valid_mae})

        if args.log_dir != "":
            writer.add_scalar("valid/mae", valid_mae, epoch)
            writer.add_scalar("train/mae", train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != "":
                print("Saving checkpoint...")
                checkpoint = {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "best_val_mae": best_valid_mae,
                    "num_params": num_params,
                }
                torch.save(
                    checkpoint,
                    os.path.join(args.checkpoint_dir, "checkpoint.pt"),
                )

            if args.save_test_dir != "":
                print("Predicting on test data...")
                y_pred = test(model, device, test_loader)
                print("Saving test submission file...")
                evaluator.save_test_submission(
                    {"y_pred": y_pred}, args.save_test_dir
                )

        scheduler.step()

        print(f"Best validation MAE so far: {best_valid_mae}")

    if args.log_dir != "":
        writer.close()


if __name__ == "__main__":
    main()
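
# Example invocation (the file name below is illustrative; the flags match
# the parser above):
#
#     python main_dgl.py --gnn gin-virtual --log_dir log/ \
#         --checkpoint_dir checkpoint/ --save_test_dir saved_test/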