import argparse
import warnings

import numpy as np
import torch as th
import torch.nn as nn

from dataset import process_dataset
from model import MVGRL, LogReg

warnings.filterwarnings('ignore')

parser = argparse.ArgumentParser(description='mvgrl')

parser.add_argument('--dataname', type=str, default='cora', help='Name of dataset.')
parser.add_argument('--gpu', type=int, default=0, help='GPU index. Set to -1 to use CPU.')
parser.add_argument('--epochs', type=int, default=500, help='Training epochs.')
parser.add_argument('--patience', type=int, default=20, help='Patience epochs to wait before early stopping.')
parser.add_argument('--lr1', type=float, default=0.001, help='Learning rate of mvgrl.')
parser.add_argument('--lr2', type=float, default=0.01, help='Learning rate of linear evaluator.')
parser.add_argument('--wd1', type=float, default=0., help='Weight decay of mvgrl.')
parser.add_argument('--wd2', type=float, default=0., help='Weight decay of linear evaluator.')
parser.add_argument('--epsilon', type=float, default=0.01, help='Edge mask threshold of diffusion graph.')
parser.add_argument('--hid_dim', type=int, default=512, help='Hidden layer dim.')

args = parser.parse_args()

# Check CUDA availability.
if args.gpu != -1 and th.cuda.is_available():
    args.device = 'cuda:{}'.format(args.gpu)
else:
    args.device = 'cpu'

if __name__ == '__main__':
    print(args)

    # Step 1: Prepare data =================================================================== #
    graph, diff_graph, feat, label, train_idx, val_idx, test_idx, edge_weight = process_dataset(args.dataname, args.epsilon)
    n_feat = feat.shape[1]
    n_classes = np.unique(label).shape[0]

    graph = graph.to(args.device)
    diff_graph = diff_graph.to(args.device)
    feat = feat.to(args.device)
    edge_weight = th.tensor(edge_weight).float().to(args.device)

    train_idx = train_idx.to(args.device)
    val_idx = val_idx.to(args.device)
    test_idx = test_idx.to(args.device)

    n_node = graph.number_of_nodes()

    # Discriminator targets: ones for the original (positive) samples of both
    # views, zeros for the shuffled (negative) samples.
    lbl1 = th.ones(n_node * 2)
    lbl2 = th.zeros(n_node * 2)
    lbl = th.cat((lbl1, lbl2))

    # Step 2: Create model =================================================================== #
    model = MVGRL(n_feat, args.hid_dim)
    model = model.to(args.device)
    lbl = lbl.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1)
    loss_fn = nn.BCEWithLogitsLoss()

    # Step 4: Training epochs ================================================================ #
    best = float('inf')
    cnt_wait = 0

    for epoch in range(args.epochs):
        model.train()
        optimizer.zero_grad()

        # Generate negative samples by row-shuffling the node features.
        # feat is already on args.device, so shuf_feat is too.
        shuf_idx = np.random.permutation(n_node)
        shuf_feat = feat[shuf_idx, :]

        out = model(graph, diff_graph, feat, shuf_feat, edge_weight)
        loss = loss_fn(out, lbl)

        loss.backward()
        optimizer.step()

        print('Epoch: {0}, Loss: {1:0.4f}'.format(epoch, loss.item()))

        # Early stopping on the training loss; checkpoint the best model so far.
        if loss.item() < best:
            best = loss.item()
            cnt_wait = 0
            th.save(model.state_dict(), 'model.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print('Early stopping')
            break

    model.load_state_dict(th.load('model.pkl'))
    embeds = model.get_embedding(graph, diff_graph, feat, edge_weight)

    train_embs = embeds[train_idx]
    test_embs = embeds[test_idx]

    label = label.to(args.device)
    train_labels = label[train_idx]
    test_labels = label[test_idx]

    accs = []
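
    # Linear evaluation protocol: the trained encoder stays frozen (embeddings
    # were extracted once above), and only a logistic-regression head is fitted
    # on them. Accuracy is averaged over 5 independent runs to smooth over the
    # head's random initialization.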
    # Step 5: Linear evaluation ============================================================== #
    for _ in range(5):
        # Renamed from `model` to `clf` to avoid shadowing the MVGRL encoder.
        clf = LogReg(args.hid_dim, n_classes).to(args.device)
        opt = th.optim.Adam(clf.parameters(), lr=args.lr2, weight_decay=args.wd2)
        loss_fn = nn.CrossEntropyLoss()

        for epoch in range(300):
            clf.train()
            opt.zero_grad()
            logits = clf(train_embs)
            loss = loss_fn(logits, train_labels)
            loss.backward()
            opt.step()

        clf.eval()
        logits = clf(test_embs)
        preds = th.argmax(logits, dim=1)
        acc = th.sum(preds == test_labels).float() / test_labels.shape[0]
        accs.append(acc * 100)

    accs = th.stack(accs)
    print(accs.mean().item(), accs.std().item())
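
# Example invocation (the filename 'main.py' is an assumption; dataset.py and
# model.py must live alongside this script and provide process_dataset, MVGRL,
# and LogReg as imported above):
#
#   python main.py --dataname cora --gpu 0 --epochs 500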