#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import os
import random
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from models import GAT
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn

import dgl
import dgl.function as fn
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    MultiLayerNeighborSampler,
)

device = None
dataset = "ogbn-proteins"
n_node_feats, n_edge_feats, n_classes = 0, 8, 112


def seed(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    dgl.random.seed(seed)


def load_data(dataset):
    data = DglNodePropPredDataset(name=dataset)
    evaluator = Evaluator(name=dataset)

    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = (
        splitted_idx["train"],
        splitted_idx["valid"],
        splitted_idx["test"],
    )
    graph, labels = data[0]
    graph.ndata["labels"] = labels

    return graph, labels, train_idx, val_idx, test_idx, evaluator


def preprocess(graph, labels, train_idx):
    global n_node_feats

    # The sum of the weights of adjacent edges is used as node features.
    graph.update_all(fn.copy_e("feat", "feat_copy"), fn.sum("feat_copy", "feat"))
    n_node_feats = graph.ndata["feat"].shape[-1]

    # Only the labels in the training set are used as features, while others are filled with zeros.
    graph.ndata["train_labels_onehot"] = torch.zeros(
        graph.number_of_nodes(), n_classes
    )
    graph.ndata["train_labels_onehot"][train_idx, labels[train_idx, 0]] = 1
    graph.ndata["deg"] = graph.out_degrees().float().clamp(min=1)

    graph.create_formats_()

    return graph, labels


def gen_model(args):
    if args.use_labels:
        n_node_feats_ = n_node_feats + n_classes
    else:
        n_node_feats_ = n_node_feats

    model = GAT(
        n_node_feats_,
        n_edge_feats,
        n_classes,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        n_hidden=args.n_hidden,
        edge_emb=16,
        activation=F.relu,
        dropout=args.dropout,
        input_drop=args.input_drop,
        attn_drop=args.attn_drop,
        edge_drop=args.edge_drop,
        use_attn_dst=not args.no_attn_dst,
    )

    return model


def add_labels(graph, idx):
    feat = graph.srcdata["feat"]
    train_labels_onehot = torch.zeros([feat.shape[0], n_classes], device=device)
    train_labels_onehot[idx] = graph.srcdata["train_labels_onehot"][idx]
    graph.srcdata["feat"] = torch.cat([feat, train_labels_onehot], dim=-1)
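
# Note on label reuse (--use-labels): in a DGL block (MFG), the destination
# (output) nodes are placed first among the source (input) nodes, i.e. at
# indices [0, len(output_nodes)). train() below therefore attaches label
# features only to the remaining source nodes at indices
# [len(output_nodes), len(input_nodes)), so a seed node never sees its own
# label while it is being predicted.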

def train(
    args, model, dataloader, _labels, _train_idx, criterion, optimizer, _evaluator
):
    model.train()

    loss_sum, total = 0, 0

    for input_nodes, output_nodes, subgraphs in dataloader:
        subgraphs = [b.to(device) for b in subgraphs]
        new_train_idx = torch.arange(len(output_nodes), device=device)

        if args.use_labels:
            train_labels_idx = torch.arange(
                len(output_nodes), len(input_nodes), device=device
            )
            train_pred_idx = new_train_idx

            add_labels(subgraphs[0], train_labels_idx)
        else:
            train_pred_idx = new_train_idx

        pred = model(subgraphs)
        loss = criterion(
            pred[train_pred_idx],
            subgraphs[-1].dstdata["labels"][train_pred_idx].float(),
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        count = len(train_pred_idx)
        loss_sum += loss.item() * count
        total += count

        # torch.cuda.empty_cache()

    return loss_sum / total


@torch.no_grad()
def evaluate(
    args, model, dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator
):
    model.eval()

    preds = torch.zeros(labels.shape).to(device)

    # Due to memory constraints, we use sampling for inference and average the
    # predictions over 'eval_times' passes.
    eval_times = 1

    for _ in range(eval_times):
        for input_nodes, output_nodes, subgraphs in dataloader:
            subgraphs = [b.to(device) for b in subgraphs]
            new_train_idx = list(range(len(input_nodes)))

            if args.use_labels:
                add_labels(subgraphs[0], new_train_idx)

            pred = model(subgraphs)
            preds[output_nodes] += pred

            # torch.cuda.empty_cache()

    preds /= eval_times

    train_loss = criterion(preds[train_idx], labels[train_idx].float()).item()
    val_loss = criterion(preds[val_idx], labels[val_idx].float()).item()
    test_loss = criterion(preds[test_idx], labels[test_idx].float()).item()

    return (
        evaluator(preds[train_idx], labels[train_idx]),
        evaluator(preds[val_idx], labels[val_idx]),
        evaluator(preds[test_idx], labels[test_idx]),
        train_loss,
        val_loss,
        test_loss,
        preds,
    )
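
# Mini-batch setup in run(): the training batch size (len(train_idx) + 9) // 10
# is a ceiling division, so each epoch covers the training nodes in roughly ten
# mini-batches. Training samples 32 neighbors per layer, while evaluation uses a
# larger fanout of 100 per layer; full-neighbor sampling is left commented out
# as a memory-heavier alternative.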

def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
    evaluator_wrapper = lambda pred, labels: evaluator.eval(
        {"y_pred": pred, "y_true": labels}
    )["rocauc"]

    train_batch_size = (len(train_idx) + 9) // 10
    # batch_size = len(train_idx)
    train_sampler = MultiLayerNeighborSampler([32 for _ in range(args.n_layers)])
    # sampler = MultiLayerFullNeighborSampler(args.n_layers)
    train_dataloader = DataLoader(
        graph.cpu(),
        train_idx.cpu(),
        train_sampler,
        batch_size=train_batch_size,
        num_workers=10,
    )

    eval_sampler = MultiLayerNeighborSampler([100 for _ in range(args.n_layers)])
    # sampler = MultiLayerFullNeighborSampler(args.n_layers)
    eval_dataloader = DataLoader(
        graph.cpu(),
        torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),
        eval_sampler,
        batch_size=65536,
        num_workers=10,
    )

    criterion = nn.BCEWithLogitsLoss()

    model = gen_model(args).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.75, patience=50, verbose=True
    )

    total_time = 0
    val_score, best_val_score, final_test_score = 0, 0, 0

    train_scores, val_scores, test_scores = [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []
    final_pred = None

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        loss = train(
            args,
            model,
            train_dataloader,
            labels,
            train_idx,
            criterion,
            optimizer,
            evaluator_wrapper,
        )

        toc = time.time()
        total_time += toc - tic

        if (
            epoch == args.n_epochs
            or epoch % args.eval_every == 0
            or epoch % args.log_every == 0
        ):
            (
                train_score,
                val_score,
                test_score,
                train_loss,
                val_loss,
                test_loss,
                pred,
            ) = evaluate(
                args,
                model,
                eval_dataloader,
                labels,
                train_idx,
                val_idx,
                test_idx,
                criterion,
                evaluator_wrapper,
            )

            if val_score > best_val_score:
                best_val_score = val_score
                final_test_score = test_score
                final_pred = pred

            if epoch % args.log_every == 0:
                print(
                    f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}s"
                )
                print(
                    f"Loss: {loss:.4f}\n"
                    f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                    f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
                )

            for l, e in zip(
                [
                    train_scores,
                    val_scores,
                    test_scores,
                    losses,
                    train_losses,
                    val_losses,
                    test_losses,
                ],
                [
                    train_score,
                    val_score,
                    test_score,
                    loss,
                    train_loss,
                    val_loss,
                    test_loss,
                ],
            ):
                l.append(e)

        lr_scheduler.step(val_score)

    print("*" * 50)
    print(f"Best val score: {best_val_score}, Final test score: {final_test_score}")
    print("*" * 50)

    if args.plot:
        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.set_yticks(np.linspace(0, 1.0, 101))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(
            [train_scores, val_scores, test_scores],
            ["train score", "val score", "test score"],
        ):
            plt.plot(
                range(1, args.n_epochs + 1, args.log_every),
                y,
                label=label,
                linewidth=1,
            )
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.01))
        ax.yaxis.set_minor_locator(AutoMinorLocator(2))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_score_{n_running}.png")

        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(
            [losses, train_losses, val_losses, test_losses],
            ["loss", "train loss", "val loss", "test loss"],
        ):
            plt.plot(
                range(1, args.n_epochs + 1, args.log_every),
                y,
                label=label,
                linewidth=1,
            )
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.1))
        ax.yaxis.set_minor_locator(AutoMinorLocator(5))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_loss_{n_running}.png")

    if args.save_pred:
        os.makedirs("./output", exist_ok=True)
        torch.save(F.softmax(final_pred, dim=1), f"./output/{n_running}.pt")

    return best_val_score, final_test_score


def count_parameters(args):
    model = gen_model(args)
    return sum(
        [np.prod(p.size()) for p in model.parameters() if p.requires_grad]
    )
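
# Overall flow of main(): load ogbn-proteins, turn edge features into node
# features in preprocess(), then train --n-runs times with seeds seed, seed + 1,
# ... and report the mean and standard deviation of the val/test ROC-AUC.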

def main():
    global device

    argparser = argparse.ArgumentParser(
        "GAT implementation on ogbn-proteins",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    argparser.add_argument(
        "--cpu", action="store_true", help="CPU mode. This option overrides '--gpu'."
    )
    argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
    argparser.add_argument("--seed", type=int, default=0, help="random seed")
    argparser.add_argument("--n-runs", type=int, default=10, help="running times")
    argparser.add_argument("--n-epochs", type=int, default=1200, help="number of epochs")
    argparser.add_argument(
        "--use-labels",
        action="store_true",
        help="Use labels in the training set as input features.",
    )
    argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.")
    argparser.add_argument("--n-heads", type=int, default=6, help="number of heads")
    argparser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    argparser.add_argument("--n-layers", type=int, default=6, help="number of layers")
    argparser.add_argument("--n-hidden", type=int, default=80, help="number of hidden units")
    argparser.add_argument("--dropout", type=float, default=0.25, help="dropout rate")
    argparser.add_argument("--input-drop", type=float, default=0.1, help="input drop rate")
    argparser.add_argument("--attn-drop", type=float, default=0.0, help="attention dropout rate")
    argparser.add_argument("--edge-drop", type=float, default=0.1, help="edge drop rate")
    argparser.add_argument("--wd", type=float, default=0, help="weight decay")
    argparser.add_argument("--eval-every", type=int, default=5, help="evaluate every EVAL_EVERY epochs")
    argparser.add_argument("--log-every", type=int, default=5, help="log every LOG_EVERY epochs")
    argparser.add_argument("--plot", action="store_true", help="plot learning curves")
    argparser.add_argument("--save-pred", action="store_true", help="save final predictions")
    args = argparser.parse_args()

    if args.cpu:
        device = torch.device("cpu")
    else:
        device = torch.device(f"cuda:{args.gpu}")

    # load data & preprocess
    print("Loading data")
    graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)

    print("Preprocessing")
    graph, labels = preprocess(graph, labels, train_idx)
    labels, train_idx, val_idx, test_idx = map(
        lambda x: x.to(device), (labels, train_idx, val_idx, test_idx)
    )

    # run
    val_scores, test_scores = [], []

    for i in range(args.n_runs):
        print("Running", i)
        seed(args.seed + i)
        val_score, test_score = run(
            args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1
        )
        val_scores.append(val_score)
        test_scores.append(test_score)

    print(" ".join(sys.argv))
    print(args)
    print(f"Ran {args.n_runs} times")
    print("Val scores:", val_scores)
    print("Test scores:", test_scores)
    print(f"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}")
    print(f"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}")
    print(f"Number of params: {count_parameters(args)}")


if __name__ == "__main__":
    main()

# Namespace(attn_drop=0.0, cpu=False, dropout=0.25, edge_drop=0.1, eval_every=5, gpu=6, input_drop=0.1, log_every=5, lr=0.01, n_epochs=1200, n_heads=6, n_hidden=80, n_layers=6, n_runs=10, no_attn_dst=False, plot=True, save_pred=False, seed=0, use_labels=False, wd=0)
# Ran 10 times
# Val scores: [0.927741031859485, 0.9272113161947824, 0.9271363901359605, 0.9275579074100136, 0.9264291968462317, 0.9275278541203443, 0.9286381790529751, 0.9288245051991526, 0.9269289529175155, 0.9278177920224489]
# Test scores: [0.8754403567694566, 0.8749781870941457, 0.8735933245353141, 0.8759835445000637, 0.8745950242855286, 0.8742530369108132, 0.8784892022402326, 0.873345314887444, 0.8724393129004984, 0.874077975765639]
# Average val score: 0.927581312575891 ± 0.0006953509986591492
# Average test score: 0.8747195279889135 ± 0.001593598488797452
# Number of params: 2475232
# Namespace(attn_drop=0.0, cpu=False, dropout=0.25, edge_drop=0.1, eval_every=5, gpu=7, input_drop=0.1, log_every=5, lr=0.01, n_epochs=1200, n_heads=6, n_hidden=80, n_layers=6, n_runs=10, no_attn_dst=False, plot=True, save_pred=False, seed=0, use_labels=True, wd=0)
# Ran 10 times
# Val scores: [0.9293776332568928, 0.9281066322254939, 0.9286775378440911, 0.9270252685136046, 0.9267937838323375, 0.9277731792338011, 0.9285615428437761, 0.9270819730221879, 0.9276822010553241, 0.9287115722177839]
# Test scores: [0.8761623033485811, 0.8773002619440896, 0.8756680817047869, 0.8751873860287073, 0.875781797307807, 0.8764533839446703, 0.8771202308989311, 0.8765888651476396, 0.8773581283481205, 0.8777751912293709]
# Average val score: 0.9279791324045293 ± 0.0008115348697502517
# Average test score: 0.8765395629902706 ± 0.0008016806017700173
# Number of params: 2484192
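
# Example invocations (a sketch; the file name "gat.py" is an assumption):
#   python gat.py --gpu 0 --n-runs 10 --plot                # plain features
#   python gat.py --gpu 0 --n-runs 10 --use-labels --plot   # with label reuse
# The two result blocks above were produced without and with --use-labels, respectively.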