import argparse
import os
import random
import time

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from data_utils import pre_process
from model.encoder import DiffPool

import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import tu

global_train_time_per_epoch = []


def arg_parse():
    """
    argument parser
    """
    parser = argparse.ArgumentParser(description="DiffPool arguments")
    parser.add_argument("--dataset", dest="dataset", help="Input Dataset")
    parser.add_argument(
        "--pool_ratio", dest="pool_ratio", type=float, help="pooling ratio"
    )
    parser.add_argument(
        "--num_pool", dest="num_pool", type=int, help="number of pooling layers"
    )
    parser.add_argument(
        "--no_link_pred",
        dest="linkpred",
        action="store_false",
        help="switch off the link prediction objective",
    )
    parser.add_argument("--cuda", dest="cuda", type=int, help="switch cuda")
    parser.add_argument("--lr", dest="lr", type=float, help="learning rate")
    parser.add_argument(
        "--clip", dest="clip", type=float, help="gradient clipping"
    )
    parser.add_argument(
        "--batch-size", dest="batch_size", type=int, help="batch size"
    )
    parser.add_argument(
        "--epochs", dest="epoch", type=int, help="number of epochs"
    )
    parser.add_argument(
        "--train-ratio",
        dest="train_ratio",
        type=float,
        help="ratio of the training dataset split",
    )
    parser.add_argument(
        "--test-ratio",
        dest="test_ratio",
        type=float,
        help="ratio of the testing dataset split",
    )
    parser.add_argument(
        "--num_workers",
        dest="n_worker",
        type=int,
        help="number of workers for data loading",
    )
    parser.add_argument(
        "--gc-per-block",
        dest="gc_per_block",
        type=int,
        help="number of graph conv layers per block",
    )
    parser.add_argument(
        "--bn",
        dest="bn",
        action="store_const",
        const=True,
        default=True,
        help="switch for batch normalization",
    )
    parser.add_argument(
        "--dropout", dest="dropout", type=float, help="dropout rate"
    )
    parser.add_argument(
        "--bias",
        dest="bias",
        action="store_const",
        const=True,
        default=True,
        help="switch for bias",
    )
    parser.add_argument(
        "--save_dir",
        dest="save_dir",
        help="model saving directory: SAVE_DIR/DATASET",
    )
    parser.add_argument(
        "--load_epoch",
        dest="load_epoch",
        type=int,
        help="load trained model params from SAVE_DIR/DATASET/model.iter-LOAD_EPOCH",
    )
    parser.add_argument(
        "--data_mode",
        dest="data_mode",
        help="data preprocessing mode: default, id, degree, or one-hot vector of degree number",
        choices=["default", "id", "deg", "deg_num"],
    )
    parser.set_defaults(
        dataset="ENZYMES",
        pool_ratio=0.15,
        num_pool=1,
        cuda=1,
        lr=1e-3,
        clip=2.0,
        batch_size=20,
        epoch=4000,
        train_ratio=0.7,
        test_ratio=0.1,
        n_worker=1,
        gc_per_block=3,
        dropout=0.0,
        method="diffpool",
        bn=True,
        bias=True,
        save_dir="./model_param",
        load_epoch=-1,
        data_mode="default",
    )
    return parser.parse_args()


def prepare_data(dataset, prog_args, train=False, pre_process=None):
    """
    preprocess TU dataset according to DiffPool's paper setting and load
    the dataset into a dataloader
    """
    shuffle = True if train else False

    if pre_process:
        pre_process(dataset, prog_args)

    # dataset.set_fold(fold)
    return dgl.dataloading.GraphDataLoader(
        dataset,
        batch_size=prog_args.batch_size,
        shuffle=shuffle,
        num_workers=prog_args.n_worker,
    )


def graph_classify_task(prog_args):
    """
    perform graph classification task
    """
    dataset = tu.LegacyTUDataset(name=prog_args.dataset)
    train_size = int(prog_args.train_ratio * len(dataset))
    test_size = int(prog_args.test_ratio * len(dataset))
    val_size = int(len(dataset) - train_size - test_size)

    dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
        dataset, (train_size, val_size, test_size)
    )
    train_dataloader = prepare_data(
        dataset_train, prog_args, train=True, pre_process=pre_process
    )
    val_dataloader = prepare_data(
        dataset_val, prog_args, train=False, pre_process=pre_process
    )
    test_dataloader = prepare_data(
        dataset_test, prog_args, train=False, pre_process=pre_process
    )
    input_dim, label_dim, max_num_node = dataset.statistics()
    print("++++++++++STATISTICS ABOUT THE DATASET")
    print("dataset feature dimension is", input_dim)
    print("dataset label dimension is", label_dim)
    print("the max num node is", max_num_node)
    print("number of graphs is", len(dataset))
    # assert len(dataset) % prog_args.batch_size == 0, "training set not divisible by batch size"

    hidden_dim = 64  # used to be 64
    embedding_dim = 64

    # calculate assignment dimension: pool_ratio * the maximum number of
    # nodes of the largest graph in the dataset
    assign_dim = int(max_num_node * prog_args.pool_ratio)
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
    print("initial batched pool graph dim is", assign_dim)
    activation = F.relu

    # initialize model
    # 'diffpool' : diffpool
    model = DiffPool(
        input_dim,
        hidden_dim,
        embedding_dim,
        label_dim,
        activation,
        prog_args.gc_per_block,
        prog_args.dropout,
        prog_args.num_pool,
        prog_args.linkpred,
        prog_args.batch_size,
        "meanpool",
        assign_dim,
        prog_args.pool_ratio,
    )

    if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:
        model.load_state_dict(
            torch.load(
                prog_args.save_dir
                + "/"
                + prog_args.dataset
                + "/model.iter-"
                + str(prog_args.load_epoch)
            )
        )

    print("model init finished")
    print("MODEL:::::::", prog_args.method)
    if prog_args.cuda:
        model = model.cuda()
    logger = train(
        train_dataloader, model, prog_args, val_dataset=val_dataloader
    )
    result = evaluate(test_dataloader, model, prog_args, logger)
    print("test accuracy {:.2f}%".format(result * 100))


def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
    """
    training function
    """
    dir = prog_args.save_dir + "/" + prog_args.dataset
    if not os.path.exists(dir):
        os.makedirs(dir)
    dataloader = dataset
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=0.001
    )
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}

    if prog_args.cuda > 0:
        torch.cuda.set_device(0)
    for epoch in range(prog_args.epoch):
        begin_time = time.time()
        model.train()
        accum_correct = 0
        total = 0
        print("\nEPOCH ###### {} ######".format(epoch))
        computation_time = 0.0
        for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
            for (key, value) in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if torch.cuda.is_available():
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()

            model.zero_grad()
            compute_start = time.time()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct = torch.sum(indi == graph_labels).item()
            accum_correct += correct
            total += graph_labels.size()[0]
            loss = model.loss(ypred, graph_labels)
            loss.backward()
            batch_compute_time = time.time() - compute_start
            computation_time += batch_compute_time
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()

        train_accu = accum_correct / total
        print(
            "train accuracy for this epoch {} is {:.2f}%".format(
                epoch, train_accu * 100
            )
        )
        elapsed_time = time.time() - begin_time
        print(
            "loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
                loss.item(), elapsed_time, computation_time
            )
        )
        global_train_time_per_epoch.append(elapsed_time)
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
            print("validation accuracy {:.2f}%".format(result * 100))
            # checkpoint when validation accuracy improves without
            # exceeding the current training accuracy
            if (
                result >= early_stopping_logger["val_acc"]
                and result <= train_accu
            ):
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                if prog_args.save_dir is not None:
                    torch.save(
                        model.state_dict(),
                        prog_args.save_dir
                        + "/"
                        + prog_args.dataset
                        + "/model.iter-"
                        + str(early_stopping_logger["best_epoch"]),
                    )
            print(
                "best epoch is EPOCH {}, val_acc is {:.2f}%".format(
                    early_stopping_logger["best_epoch"],
                    early_stopping_logger["val_acc"] * 100,
                )
            )
        torch.cuda.empty_cache()
    return early_stopping_logger


def evaluate(dataloader, model, prog_args, logger=None):
    """
    evaluate function
    """
    if logger is not None and prog_args.save_dir is not None:
        model.load_state_dict(
            torch.load(
                prog_args.save_dir
                + "/"
                + prog_args.dataset
                + "/model.iter-"
                + str(logger["best_epoch"])
            )
        )
    model.eval()
    correct_label = 0
    with torch.no_grad():
        for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
            for (key, value) in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if torch.cuda.is_available():
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct = torch.sum(indi == graph_labels)
            correct_label += correct.item()
    # note: this assumes every batch holds batch_size graphs; a smaller
    # final batch slightly underestimates the accuracy
    result = correct_label / (len(dataloader) * prog_args.batch_size)
    return result


def main():
    """
    main
    """
    prog_args = arg_parse()
    print(prog_args)
    graph_classify_task(prog_args)
    print(
        "Train time per epoch: {:.4f}".format(
            sum(global_train_time_per_epoch) / len(global_train_time_per_epoch)
        )
    )
    print(
        "Max memory usage: {:.4f}".format(
            torch.cuda.max_memory_allocated(0) / (1024 * 1024)
        )
    )


if __name__ == "__main__":
    main()
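

# A minimal invocation sketch, assuming this script is saved as train.py and that
# the companion modules it imports (data_utils.py and model/encoder.py) sit in the
# same example directory. Every flag falls back to the defaults set in arg_parse();
# the values below are illustrative, not prescriptive:
#
#   python train.py --dataset ENZYMES --pool_ratio 0.15 --num_pool 1 \
#       --epochs 1000 --batch-size 20 --save_dir ./model_param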