import argparse

import torch as th
import dgl
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader

from model import InfoGraph
from evaluate_embedding import evaluate_embedding


def argument():
    parser = argparse.ArgumentParser(description='InfoGraph')

    # data source params
    parser.add_argument('--dataname', type=str, default='MUTAG', help='Name of dataset.')

    # training params
    parser.add_argument('--gpu', type=int, default=-1, help='GPU index. Default: -1, using CPU.')
    parser.add_argument('--epochs', type=int, default=20, help='Training epochs.')
    parser.add_argument('--batch_size', type=int, default=128, help='Training batch size.')
    parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.')
    parser.add_argument('--log_interval', type=int, default=1, help='Interval between two evaluations.')

    # model params
    parser.add_argument('--n_layers', type=int, default=3, help='Number of graph convolution layers before each pooling.')
    parser.add_argument('--hid_dim', type=int, default=32, help='Hidden layer dimensionalities.')

    args = parser.parse_args()

    # check cuda
    if args.gpu != -1 and th.cuda.is_available():
        args.device = 'cuda:{}'.format(args.gpu)
    else:
        args.device = 'cpu'

    return args


def collate(samples):
    '''Collate function for building the graph dataloader.'''
    graphs, labels = map(list, zip(*samples))

    # generate batched graphs and labels
    batched_graph = dgl.batch(graphs)
    batched_labels = th.tensor(labels)

    # assign each node the ID of the graph it belongs to
    n_graphs = len(graphs)
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)
    batched_graph.ndata['graph_id'] = graph_id

    return batched_graph, batched_labels


if __name__ == '__main__':

    # Step 1: Prepare graph data ============================================= #
    args = argument()
    print(args)

    # load dataset from dgl.data.GINDataset
    dataset = GINDataset(args.dataname, self_loop=False)

    # get graphs and labels
    graphs, labels = map(list, zip(*dataset))

    # generate a full graph with all examples for evaluation
    wholegraph = dgl.batch(graphs)
    wholegraph.ndata['attr'] = wholegraph.ndata['attr'].to(th.float32)

    # create dataloader for batch training
    dataloader = GraphDataLoader(dataset,
                                 batch_size=args.batch_size,
                                 collate_fn=collate,
                                 drop_last=False,
                                 shuffle=True)

    in_dim = wholegraph.ndata['attr'].shape[1]

    # Step 2: Create model =================================================== #
    model = InfoGraph(in_dim, args.hid_dim, args.n_layers)
    model = model.to(args.device)

    # Step 3: Create training components ===================================== #
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    print('===== Before training =====')

    wholegraph = wholegraph.to(args.device)
    wholefeat = wholegraph.ndata['attr']

    # evaluate the initialized embeddings using logistic regression and (non-linear) SVM
    emb = model.get_embedding(wholegraph, wholefeat).cpu()
    res = evaluate_embedding(emb, labels, args.device)
    print('logreg {:.4f}, svc {:.4f}'.format(res[0], res[1]))

    best_logreg = 0
    best_logreg_epoch = 0
    best_svc = 0
    best_svc_epoch = 0

    # Step 4: Training epochs ================================================ #
    for epoch in range(args.epochs):
        loss_all = 0
        model.train()

        for graph, label in dataloader:
            graph = graph.to(args.device)
            feat = graph.ndata['attr']
            graph_id = graph.ndata['graph_id']

            n_graph = label.shape[0]

            optimizer.zero_grad()
            loss = model(graph, feat, graph_id)
            loss.backward()
            optimizer.step()
            loss_all += loss.item()

        print('Epoch {}, Loss {:.4f}'.format(epoch, loss_all))

        if epoch % args.log_interval == 0:
            # evaluate embeddings
            model.eval()
            emb = model.get_embedding(wholegraph, wholefeat).cpu()
            res = evaluate_embedding(emb, labels, args.device)

            if res[0] > best_logreg:
                best_logreg = res[0]
                best_logreg_epoch = epoch
            if res[1] > best_svc:
                best_svc = res[1]
                best_svc_epoch = epoch

            print('best logreg {:.4f}, epoch {} | best svc {:.4f}, epoch {}'.format(
                best_logreg, best_logreg_epoch, best_svc, best_svc_epoch))

    print('Training End')
    print('best logreg {:.4f}, best svc {:.4f}'.format(best_logreg, best_svc))
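
# Usage sketch (assumptions: this script is saved as main.py next to the model.py and
# evaluate_embedding.py modules it imports, and DGL with GINDataset support is installed).
# The flags below mirror the argparse defaults defined in argument(); --gpu 0 would
# select the first CUDA device instead of the CPU:
#
#   python main.py --dataname MUTAG --gpu -1 --epochs 20 --batch_size 128 --lr 0.01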