import argparse
import random
import warnings

import numpy as np
import torch as th
import torch.nn as nn

import dgl

warnings.filterwarnings("ignore")

from dataset import process_dataset, process_dataset_appnp
from model import MVGRL, LogReg

parser = argparse.ArgumentParser(description="mvgrl")

parser.add_argument(
    "--dataname", type=str, default="cora", help="Name of dataset."
)
parser.add_argument(
    "--gpu", type=int, default=-1, help="GPU index. Default: -1, using cpu."
)
parser.add_argument("--epochs", type=int, default=500, help="Training epochs.")
parser.add_argument(
    "--patience",
    type=int,
    default=20,
    help="Patience epochs to wait before early stopping.",
)
parser.add_argument(
    "--lr1", type=float, default=0.001, help="Learning rate of mvgrl."
)
parser.add_argument(
    "--lr2", type=float, default=0.01, help="Learning rate of linear evaluator."
)
parser.add_argument(
    "--wd1", type=float, default=0.0, help="Weight decay of mvgrl."
)
parser.add_argument(
    "--wd2", type=float, default=0.0, help="Weight decay of linear evaluator."
)
parser.add_argument(
    "--epsilon",
    type=float,
    default=0.01,
    help="Edge mask threshold of diffusion graph.",
)
parser.add_argument(
    "--hid_dim", type=int, default=512, help="Hidden layer dim."
)
parser.add_argument(
    "--sample_size", type=int, default=2000, help="Subgraph size."
)

args = parser.parse_args()

# check cuda
if args.gpu != -1 and th.cuda.is_available():
    args.device = "cuda:{}".format(args.gpu)
else:
    args.device = "cpu"

if __name__ == "__main__":
    print(args)

    # Step 1: Prepare data =================================================================== #
    if args.dataname == "pubmed":
        (
            graph,
            diff_graph,
            feat,
            label,
            train_idx,
            val_idx,
            test_idx,
            edge_weight,
        ) = process_dataset_appnp(args.epsilon)
    else:
        (
            graph,
            diff_graph,
            feat,
            label,
            train_idx,
            val_idx,
            test_idx,
            edge_weight,
        ) = process_dataset(args.dataname, args.epsilon)

    edge_weight = th.tensor(edge_weight).float()
    graph.ndata["feat"] = feat
    diff_graph.edata["edge_weight"] = edge_weight

    n_feat = feat.shape[1]
    n_classes = np.unique(label).shape[0]

    train_idx = train_idx.to(args.device)
    val_idx = val_idx.to(args.device)
    test_idx = test_idx.to(args.device)

    n_node = graph.number_of_nodes()
    sample_size = args.sample_size

    # Discriminator targets: 1 for true (graph, diffusion) pairs,
    # 0 for pairs built from shuffled features.
    lbl1 = th.ones(sample_size * 2)
    lbl2 = th.zeros(sample_size * 2)
    lbl = th.cat((lbl1, lbl2))
    lbl = lbl.to(args.device)

    # Step 2: Create model =================================================================== #
    model = MVGRL(n_feat, args.hid_dim)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr1, weight_decay=args.wd1
    )
    loss_fn = nn.BCEWithLogitsLoss()
    node_list = list(range(n_node))

    # Step 4: Training epochs ================================================================ #
    best = float("inf")
    cnt_wait = 0

    for epoch in range(args.epochs):
        model.train()
        optimizer.zero_grad()

        # Sample the same node set from both views so the subgraphs stay aligned.
        sample_idx = random.sample(node_list, sample_size)
        g = dgl.node_subgraph(graph, sample_idx)
        dg = dgl.node_subgraph(diff_graph, sample_idx)

        f = g.ndata.pop("feat")
        ew = dg.edata.pop("edge_weight")

        # Row-shuffled features serve as the negative samples.
        shuf_idx = np.random.permutation(sample_size)
        sf = f[shuf_idx, :]

        g = g.to(args.device)
        dg = dg.to(args.device)
        f = f.to(args.device)
        ew = ew.to(args.device)
        sf = sf.to(args.device)

        out = model(g, dg, f, sf, ew)
        loss = loss_fn(out, lbl)

        loss.backward()
        optimizer.step()

        print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, loss.item()))

        # Early stopping on the training loss; checkpoint the best model.
        if loss < best:
            best = loss
            cnt_wait = 0
            th.save(model.state_dict(), "model.pkl")
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break

    model.load_state_dict(th.load("model.pkl"))

    graph = graph.to(args.device)
    diff_graph = diff_graph.to(args.device)
    feat = feat.to(args.device)
    edge_weight = edge_weight.to(args.device)

    embeds = model.get_embedding(graph, diff_graph, feat, edge_weight)

    train_embs = embeds[train_idx]
    test_embs = embeds[test_idx]

    label = label.to(args.device)
    train_labels = label[train_idx]
    test_labels = label[test_idx]

    accs = []

    # Step 5: Linear evaluation ============================================================== #
    # Train a fresh linear probe on the frozen embeddings 5 times and
    # report mean/std test accuracy.
    for _ in range(5):
        logreg = LogReg(args.hid_dim, n_classes)
        opt = th.optim.Adam(
            logreg.parameters(), lr=args.lr2, weight_decay=args.wd2
        )

        logreg = logreg.to(args.device)
        loss_fn = nn.CrossEntropyLoss()

        for epoch in range(300):
            logreg.train()
            opt.zero_grad()
            logits = logreg(train_embs)
            loss = loss_fn(logits, train_labels)
            loss.backward()
            opt.step()

        logreg.eval()
        logits = logreg(test_embs)
        preds = th.argmax(logits, dim=1)
        acc = th.sum(preds == test_labels).float() / test_labels.shape[0]
        accs.append(acc * 100)

    accs = th.stack(accs)
    print(accs.mean().item(), accs.std().item())
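
# Example invocation (hypothetical file name; the flags match the argparse
# definitions above, so defaults can be omitted):
#
#   python main.py --dataname cora --gpu 0 --sample_size 2000
#
# Set --gpu -1 (the default) to run on CPU.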