import argparse
import os
import time
import warnings

import torch
import torch.nn.functional as F
from config import CONFIG
from modules import GCNNet
from sampler import SAINTEdgeSampler, SAINTNodeSampler, SAINTRandomWalkSampler
from torch.utils.data import DataLoader
from utils import Logger, calc_f1, evaluate, load_data, save_log_dir


def main(args, task):
    warnings.filterwarnings("ignore")

    multilabel_data = {"ppi", "yelp", "amazon"}
    multilabel = args.dataset in multilabel_data

    # This flag is set for datasets (e.g. amazon) whose graph is too large to
    # fit on a single GPU. For such datasets we need to
    # 1. keep the whole graph on CPU, and move only the sampled subgraphs to
    #    GPU during training
    # 2. keep the model on GPU during training, and move it to CPU for
    #    validation/testing on the full graph
    # Check cpu_flag together with cuda (set below) whenever the model is
    # moved between CPU and GPU.
    if args.dataset in ["amazon"]:
        cpu_flag = True
    else:
        cpu_flag = False

    # load and preprocess dataset
    data = load_data(args, multilabel)
    g = data.g
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    labels = g.ndata["label"]
    train_nid = data.train_nid

    in_feats = g.ndata["feat"].shape[1]
    n_classes = data.num_classes
    n_nodes = g.num_nodes()
    n_edges = g.num_edges()

    n_train_samples = train_mask.int().sum().item()
    n_val_samples = val_mask.int().sum().item()
    n_test_samples = test_mask.int().sum().item()

    print(
        """----Data statistics------
        #Nodes %d
        #Edges %d
        #Classes/Labels (multi binary labels) %d
        #Train samples %d
        #Val samples %d
        #Test samples %d"""
        % (
            n_nodes,
            n_edges,
            n_classes,
            n_train_samples,
            n_val_samples,
            n_test_samples,
        )
    )

    # load sampler
    kwargs = {
        "dn": args.dataset,
        "g": g,
        "train_nid": train_nid,
        "num_workers_sampler": args.num_workers_sampler,
        "num_subg_sampler": args.num_subg_sampler,
        "batch_size_sampler": args.batch_size_sampler,
        "online": args.online,
        "num_subg": args.num_subg,
        "full": args.full,
    }
    if args.sampler == "node":
        saint_sampler = SAINTNodeSampler(args.node_budget, **kwargs)
    elif args.sampler == "edge":
        saint_sampler = SAINTEdgeSampler(args.edge_budget, **kwargs)
    elif args.sampler == "rw":
        saint_sampler = SAINTRandomWalkSampler(
            args.num_roots, args.length, **kwargs
        )
    else:
        raise NotImplementedError

    # each "batch" from the loader is one pre-sampled subgraph
    loader = DataLoader(
        saint_sampler,
        collate_fn=saint_sampler.__collate_fn__,
        batch_size=1,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=False,
    )

    # set device for dataset tensors
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        if not cpu_flag:
            g = g.to("cuda:{}".format(args.gpu))

    print("labels shape:", g.ndata["label"].shape)
    print("features shape:", g.ndata["feat"].shape)

    model = GCNNet(
        in_dim=in_feats,
        hid_dim=args.n_hidden,
        out_dim=n_classes,
        arch=args.arch,
        dropout=args.dropout,
        batch_norm=not args.no_batch_norm,
        aggr=args.aggr,
    )

    if cuda:
        model.cuda()

    # logger and so on
    log_dir = save_log_dir(args)
    logger = Logger(os.path.join(log_dir, "loggings"))
    logger.write(args)

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # set train_nid to a cuda tensor
    if cuda:
        train_nid = torch.from_numpy(train_nid).cuda()
        print(
            "GPU memory allocated before training(MB)",
            torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024,
        )

    start_time = time.time()
    best_f1 = -1

    for epoch in range(args.n_epochs):
        for j, subg in enumerate(loader):
            if cuda:
                subg = subg.to(torch.cuda.current_device())
            model.train()
            # forward
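            # GraphSAINT-style loss normalization: the sampler attaches
            # per-node weights "l_n" to each subgraph, and scaling the loss
            # by these weights corrects the bias introduced by subgraph
            # sampling (they play the role of the loss-normalization
            # coefficients from the GraphSAINT paper).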
            pred = model(subg)
            batch_labels = subg.ndata["label"]

            if multilabel:
                loss = F.binary_cross_entropy_with_logits(
                    pred,
                    batch_labels,
                    reduction="sum",
                    weight=subg.ndata["l_n"].unsqueeze(1),
                )
            else:
                loss = F.cross_entropy(pred, batch_labels, reduction="none")
                loss = (subg.ndata["l_n"] * loss).sum()

            optimizer.zero_grad()
            loss.backward()
            # clip_grad_norm is deprecated; use the in-place clip_grad_norm_
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

            # report training metrics on the last subgraph of the epoch
            if j == len(loader) - 1:
                model.eval()
                with torch.no_grad():
                    train_f1_mic, train_f1_mac = calc_f1(
                        batch_labels.cpu().numpy(),
                        pred.cpu().numpy(),
                        multilabel,
                    )
                    print(
                        f"epoch:{epoch + 1}/{args.n_epochs}, Iteration {j + 1}/"
                        f"{len(loader)}: training loss",
                        loss.item(),
                    )
                    print(
                        "Train F1-mic {:.4f}, Train F1-mac {:.4f}".format(
                            train_f1_mic, train_f1_mac
                        )
                    )

        # evaluate
        model.eval()
        if epoch % args.val_every == 0:
            # Validation runs on the full graph; if the graph lives on CPU,
            # move the model back to CPU first.
            if cpu_flag and cuda:
                model = model.to("cpu")
            val_f1_mic, val_f1_mac = evaluate(
                model, g, labels, val_mask, multilabel
            )
            print(
                "Val F1-mic {:.4f}, Val F1-mac {:.4f}".format(
                    val_f1_mic, val_f1_mac
                )
            )
            if val_f1_mic > best_f1:
                best_f1 = val_f1_mic
                print("new best val f1:", best_f1)
                torch.save(
                    model.state_dict(),
                    os.path.join(log_dir, "best_model_{}.pkl".format(task)),
                )
            if cpu_flag and cuda:
                model.cuda()

    end_time = time.time()
    print(f"training time: {end_time - start_time:.2f} s")

    # test
    if args.use_val:
        model.load_state_dict(
            torch.load(os.path.join(log_dir, "best_model_{}.pkl".format(task)))
        )
    if cpu_flag and cuda:
        model = model.to("cpu")
    test_f1_mic, test_f1_mac = evaluate(model, g, labels, test_mask, multilabel)
    print(
        "Test F1-mic {:.4f}, Test F1-mac {:.4f}".format(
            test_f1_mic, test_f1_mac
        )
    )


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    parser = argparse.ArgumentParser(description="GraphSAINT")
    parser.add_argument(
        "--task",
        type=str,
        default="ppi_n",
        help="task name, used as a key into CONFIG (e.g. ppi_n)",
    )
    parser.add_argument(
        "--online",
        dest="online",
        action="store_true",
        help="sample subgraphs online during training instead of reusing "
        "pre-sampled ones",
    )
    parser.add_argument("--gpu", type=int, default=0, help="the gpu index")
    cli_args = parser.parse_args()

    # CLI only selects the task and overrides a few fields; all remaining
    # hyperparameters come from CONFIG in config.py
    args = argparse.Namespace(**CONFIG[cli_args.task])
    args.online = cli_args.online
    args.gpu = cli_args.gpu
    print(args)
    main(args, task=cli_args.task)
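
# A minimal usage sketch. The file name "train_sampling.py" is an assumption
# (the script does not depend on it), and CONFIG in config.py must define an
# entry for the chosen task:
#
#   python train_sampling.py --task ppi_n --gpu 0            # train on GPU 0
#   python train_sampling.py --task ppi_n --gpu -1 --online  # CPU run, online sampling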