import argparse
import warnings

import torch as th

import dgl
from dgl.dataloading import GraphDataLoader

from dataset import load
from model import MVGRL
from utils import linearsvc

warnings.filterwarnings('ignore')

parser = argparse.ArgumentParser(description='mvgrl')

parser.add_argument('--dataname', type=str, default='MUTAG', help='Name of dataset.')
parser.add_argument('--gpu', type=int, default=-1, help='GPU index. Default: -1, using cpu.')
parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs.')
parser.add_argument('--patience', type=int, default=20, help='Early stopping steps.')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate of mvgrl.')
parser.add_argument('--wd', type=float, default=0., help='Weight decay of mvgrl.')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size.')
parser.add_argument('--n_layers', type=int, default=4, help='Number of GNN layers.')
parser.add_argument('--hid_dim', type=int, default=32, help='Hidden layer dim.')

args = parser.parse_args()

# check cuda
if args.gpu != -1 and th.cuda.is_available():
    args.device = 'cuda:{}'.format(args.gpu)
else:
    args.device = 'cpu'


def collate(samples):
    '''Collate function for building the graph dataloader.'''
    graphs, diff_graphs, labels = map(list, zip(*samples))

    # generate batched graphs and labels
    batched_graph = dgl.batch(graphs)
    batched_labels = th.tensor(labels)
    batched_diff_graph = dgl.batch(diff_graphs)

    # tag every node with the id of the graph it belongs to,
    # so per-graph readout can be done on the batched graph
    n_graphs = len(graphs)
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)

    batched_graph.ndata['graph_id'] = graph_id

    return batched_graph, batched_diff_graph, batched_labels


if __name__ == '__main__':
    # Step 1: Prepare data =================================================================== #
    dataset = load(args.dataname)
    graphs, diff_graphs, labels = map(list, zip(*dataset))
    print('Number of graphs:', len(graphs))

    # generate a full graph with all examples for evaluation
    wholegraph = dgl.batch(graphs)
    whole_dg = dgl.batch(diff_graphs)

    # create dataloader for batch training
    dataloader = GraphDataLoader(dataset,
                                 batch_size=args.batch_size,
                                 collate_fn=collate,
                                 drop_last=False,
                                 shuffle=True)

    in_dim = wholegraph.ndata['feat'].shape[1]

    # Step 2: Create model =================================================================== #
    model = MVGRL(in_dim, args.hid_dim, args.n_layers)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    print('===== Before training =====')

    wholegraph = wholegraph.to(args.device)
    whole_dg = whole_dg.to(args.device)
    wholefeat = wholegraph.ndata.pop('feat')
    whole_weight = whole_dg.edata.pop('edge_weight')

    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)
    lbls = th.LongTensor(labels)

    acc_mean, acc_std = linearsvc(embs, lbls)
    print('accuracy_mean, {:.4f}, accuracy_std, {:.4f}'.format(acc_mean, acc_std))

    best = float('inf')
    cnt_wait = 0

    # Step 4: Training epochs =============================================================== #
    for epoch in range(args.epochs):
        loss_all = 0
        model.train()

        for graph, diff_graph, label in dataloader:
            graph = graph.to(args.device)
            diff_graph = diff_graph.to(args.device)

            feat = graph.ndata['feat']
            graph_id = graph.ndata['graph_id']
            edge_weight = diff_graph.edata['edge_weight']
            n_graph = label.shape[0]

            optimizer.zero_grad()
            loss = model(graph, diff_graph, feat, edge_weight, graph_id)
            loss_all += loss.item()
            loss.backward()
            optimizer.step()
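        # End of epoch: report the accumulated loss, checkpoint the model whenever
        # it improves, and stop early if there is no improvement for `patience` epochs.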
        print('Epoch {}, Loss {:.4f}'.format(epoch, loss_all))

        # compare the accumulated epoch loss (a float), not the last batch's loss tensor
        if loss_all < best:
            best = loss_all
            best_t = epoch
            cnt_wait = 0
            th.save(model.state_dict(), f'{args.dataname}.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print('Early stopping')
            break

    print('Training End')

    # Step 5: Linear evaluation ========================================================== #
    model.load_state_dict(th.load(f'{args.dataname}.pkl'))

    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)
    acc_mean, acc_std = linearsvc(embs, lbls)
    print('accuracy_mean, {:.4f}, accuracy_std, {:.4f}'.format(acc_mean, acc_std))
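
# A minimal usage sketch (assuming this script is saved as main.py; the flag
# names and defaults are exactly the ones defined by the argparse setup above):
#   python main.py --dataname MUTAG --gpu 0 --epochs 200 --batch_size 64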