"""Train and evaluate CARE-GNN with neighbor sampling on a DGL fraud dataset."""
import dgl
import argparse
import torch as th
import torch.optim as optim
from torch.nn.functional import softmax
from sklearn.metrics import roc_auc_score, recall_score

from utils import EarlyStopping
from model_sampling import CAREGNN, CARESampler, _l1_dist

# Default weight of the similarity loss term (lambda 1). Shared by the CLI
# default and by ``evaluate`` so the two cannot drift apart.
DEFAULT_SIM_WEIGHT = 2


def evaluate(model, loss_fn, dataloader, device='cpu', sim_weight=DEFAULT_SIM_WEIGHT):
    """Compute batch-averaged recall, AUC and loss of ``model`` over ``dataloader``.

    Parameters
    ----------
    model : callable
        Called as ``model(blocks, feature)``; must return the pair
        ``(logits_gnn, logits_sim)``.
    loss_fn : callable
        Loss applied separately to both logits against the batch labels.
    dataloader : iterable
        Yields ``(input_nodes, output_nodes, blocks)`` triples; only ``blocks``
        is used. Input features are read from ``blocks[0].srcdata['feature']``
        and labels from ``blocks[-1].dstdata['label']``.
    device : str
        Device every block is moved to before the forward pass.
    sim_weight : float
        Multiplier of the similarity loss term. Passed explicitly so this
        function no longer depends on a module-global ``args``.

    Returns
    -------
    tuple
        ``(recall, auc, loss)``, each the plain mean of the per-batch values
        (batches are weighted equally regardless of their size).

    Raises
    ------
    ZeroDivisionError
        If ``dataloader`` yields no batches.
    """
    loss = 0
    auc = 0
    recall = 0
    num_blocks = 0

    # Evaluation never back-propagates, so skip building autograd graphs.
    with th.no_grad():
        for _, _, blocks in dataloader:
            blocks = [b.to(device) for b in blocks]
            feature = blocks[0].srcdata['feature']
            label = blocks[-1].dstdata['label']
            logits_gnn, logits_sim = model(blocks, feature)

            # compute loss
            loss += loss_fn(logits_gnn, label).item() + sim_weight * loss_fn(logits_sim, label).item()

            label_cpu = label.cpu()
            recall += recall_score(label_cpu, logits_gnn.argmax(dim=1).cpu())
            auc += roc_auc_score(label_cpu, softmax(logits_gnn, dim=1)[:, 1].cpu())
            num_blocks += 1

    return recall / num_blocks, auc / num_blocks, loss / num_blocks


def _build_sampler(model, graph, layers_feat, num_layers):
    """Compute per-layer edge distances and build a ``CARESampler`` from them.

    For each layer, node embeddings ``tanh(layer.MLP(features))`` are stored
    temporarily in ``graph.ndata['nd']``, ``_l1_dist`` is applied on every
    canonical edge type, and the resulting ``'ed'`` edge data is moved to CPU.

    Returns
    -------
    tuple
        ``(sampler, dists)`` where ``dists`` is a list (one entry per layer) of
        ``{canonical_etype: distance tensor on CPU}`` dictionaries.
    """
    dists = []
    p = []

    # The distances are detached right away, so no autograd graph is needed.
    with th.no_grad():
        for i in range(num_layers):
            layer = model.layers[i]
            graph.ndata['nd'] = th.tanh(layer.MLP(layers_feat[i]))

            dist = {}
            for etype in graph.canonical_etypes:
                graph.apply_edges(_l1_dist, etype=etype)
                dist[etype] = graph.edges[etype].data.pop('ed').detach().cpu()

            dists.append(dist)
            p.append(layer.p)

    # Remove the temporary node field so it does not leak into later batches.
    graph.ndata.pop('nd')
    return CARESampler(p, dists, num_layers), dists


def main(args):
    """Train CARE-GNN with mini-batches, validate every epoch, then test.

    ``args`` is the namespace produced by the argument parser at the bottom of
    this file; the fields read here are ``dataset``, ``gpu``, ``num_workers``,
    ``num_layers``, ``hid_dim``, ``step_size``, ``lr``, ``weight_decay``,
    ``early_stop``, ``max_epoch``, ``batch_size`` and ``sim_weight``.
    """
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load dataset
    dataset = dgl.data.FraudDataset(args.dataset, train_size=0.4)
    graph = dataset[0]
    num_classes = dataset.num_classes

    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = 'cuda:{}'.format(args.gpu)
        # The graph lives on the GPU in this case, so sample in the main process.
        args.num_workers = 0
    else:
        device = 'cpu'

    # retrieve labels of ground truth
    labels = graph.ndata['label'].to(device)

    # Extract node features
    feat = graph.ndata['feature'].to(device)
    # One (non-copying) view of the input features per layer.
    layers_feat = feat.expand(args.num_layers, -1, -1)

    # retrieve masks for train/validation/test
    train_mask = graph.ndata['train_mask']
    val_mask = graph.ndata['val_mask']
    test_mask = graph.ndata['test_mask']

    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)

    # Reinforcement learning module only for positive training nodes
    rl_idx = th.nonzero(train_mask.to(device) & labels.bool(), as_tuple=False).squeeze(1)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = CAREGNN(in_dim=feat.shape[-1],
                    num_classes=num_classes,
                    hid_dim=args.hid_dim,
                    num_layers=args.num_layers,
                    activation=th.tanh,
                    step_size=args.step_size,
                    edges=graph.canonical_etypes)
    model = model.to(device)

    # Step 3: Create training components ===================================================== #
    # Weight each class by the inverse of its frequency.
    _, cnt = th.unique(labels, return_counts=True)
    loss_fn = th.nn.CrossEntropyLoss(weight=1 / cnt)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.early_stop:
        stopper = EarlyStopping(patience=100)

    # Step 4: training epochs =============================================================== #
    sampler = None
    for epoch in range(args.max_epoch):
        # calculate the distance of each edges and sample based on the distance
        sampler, dists = _build_sampler(model, graph, layers_feat, args.num_layers)

        # train
        model.train()
        tr_loss = 0
        tr_recall = 0
        tr_auc = 0
        tr_blk = 0
        train_dataloader = dgl.dataloading.DataLoader(
            graph, train_idx, sampler,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=args.num_workers)

        for _, _, blocks in train_dataloader:
            blocks = [b.to(device) for b in blocks]
            train_feature = blocks[0].srcdata['feature']
            train_label = blocks[-1].dstdata['label']
            logits_gnn, logits_sim = model(blocks, train_feature)

            # compute loss
            blk_loss = loss_fn(logits_gnn, train_label) + args.sim_weight * loss_fn(logits_sim, train_label)
            tr_loss += blk_loss.item()

            train_label_cpu = train_label.cpu()
            tr_recall += recall_score(train_label_cpu, logits_gnn.argmax(dim=1).detach().cpu())
            tr_auc += roc_auc_score(train_label_cpu, softmax(logits_gnn, dim=1)[:, 1].detach().cpu())
            tr_blk += 1

            # backward
            optimizer.zero_grad()
            blk_loss.backward()
            optimizer.step()

        # Reinforcement learning module
        model.RLModule(graph, epoch, rl_idx, dists)

        # validation
        model.eval()
        val_dataloader = dgl.dataloading.DataLoader(
            graph, val_idx, sampler,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=args.num_workers)

        val_recall, val_auc, val_loss = evaluate(
            model, loss_fn, val_dataloader, device, sim_weight=args.sim_weight)

        # Print out performance
        print("In epoch {}, Train Recall: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; "
              "Valid Recall: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}".
              format(epoch, tr_recall / tr_blk, tr_auc / tr_blk, tr_loss / tr_blk, val_recall, val_auc, val_loss))

        if args.early_stop:
            if stopper.step(val_auc, model):
                break

    # Test with mini batch after all epoch
    trained = sampler is not None
    if not trained:
        # No epoch ran (max_epoch <= 0): build a sampler so testing still works.
        sampler, _ = _build_sampler(model, graph, layers_feat, args.num_layers)

    model.eval()
    # Only restore the checkpoint when this run had a chance to write one.
    if args.early_stop and trained:
        model.load_state_dict(th.load('es_checkpoint.pt'))

    test_dataloader = dgl.dataloading.DataLoader(
        graph, test_idx, sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    test_recall, test_auc, test_loss = evaluate(
        model, loss_fn, test_dataloader, device, sim_weight=args.sim_weight)

    print("Test Recall: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}".format(test_recall, test_auc, test_loss))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN-based Anti-Spam Model')
    parser.add_argument("--dataset", type=str, default="amazon",
                        help="DGL dataset for this model (yelp, or amazon)")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="GPU index. Default: -1, using CPU.")
    parser.add_argument("--hid_dim", type=int, default=64,
                        help="Hidden layer dimension")
    parser.add_argument("--num_layers", type=int, default=1,
                        help="Number of layers")
    parser.add_argument("--batch_size", type=int, default=256,
                        help="Size of mini-batch")
    parser.add_argument("--max_epoch", type=int, default=30,
                        help="The max number of epochs. Default: 30")
    parser.add_argument("--lr", type=float, default=0.01,
                        help="Learning rate. Default: 0.01")
    parser.add_argument("--weight_decay", type=float, default=0.001,
                        help="Weight decay. Default: 0.001")
    parser.add_argument("--step_size", type=float, default=0.02,
                        help="RL action step size (lambda 2). Default: 0.02")
    # The help text previously advertised 0.001 while the real default is 2.
    parser.add_argument("--sim_weight", type=float, default=DEFAULT_SIM_WEIGHT,
                        help="Similarity loss weight (lambda 1). Default: {}".format(DEFAULT_SIM_WEIGHT))
    parser.add_argument("--num_workers", type=int, default=4,
                        help="Number of node dataloader")
    parser.add_argument('--early-stop', action='store_true', default=False,
                        help="indicates whether to use early stop")

    args = parser.parse_args()
    th.manual_seed(717)
    print(args)
    main(args)