# -*- coding: utf-8 -*-
"""HAN mini-batch training with RandomWalkNeighborSampler.

Note: this demo uses RandomWalkNeighborSampler to sample neighbors. Since it is
hard to gather all neighbors at validation/test time, twice as many neighbors
are sampled during validation and testing as during training.
"""
import argparse

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from model_hetero import SemanticAttention
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from utils import EarlyStopping, set_random_seed

import dgl
from dgl.nn.pytorch import GATConv
from dgl.sampling import RandomWalkNeighborSampler


class HANLayer(torch.nn.Module):
    """HAN layer.

    Arguments
    ---------
    num_metapath : number of metapath-based sub-graphs
    in_size : input feature dimension
    out_size : output feature dimension
    layer_num_heads : number of attention heads
    dropout : dropout probability

    Inputs
    ------
    g : DGLHeteroGraph
        The heterogeneous graph
    h : tensor
        Input features

    Outputs
    -------
    tensor
        The output feature
    """

    def __init__(
        self, num_metapath, in_size, out_size, layer_num_heads, dropout
    ):
        super(HANLayer, self).__init__()

        # One GAT layer for each metapath-based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(num_metapath):
            self.gat_layers.append(
                GATConv(
                    in_size,
                    out_size,
                    layer_num_heads,
                    dropout,
                    dropout,
                    activation=F.elu,
                    allow_zero_in_degree=True,
                )
            )
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )
        self.num_metapath = num_metapath

    def forward(self, block_list, h_list):
        semantic_embeddings = []
        for i, block in enumerate(block_list):
            semantic_embeddings.append(
                self.gat_layers[i](block, h_list[i]).flatten(1)
            )
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)


class HAN(nn.Module):
    def __init__(
        self, num_metapath, in_size, hidden_size, out_size, num_heads, dropout
    ):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(
            HANLayer(num_metapath, in_size, hidden_size, num_heads[0], dropout)
        )
        for l in range(1, len(num_heads)):
            self.layers.append(
                HANLayer(
                    num_metapath,
                    hidden_size * num_heads[l - 1],
                    hidden_size,
                    num_heads[l],
                    dropout,
                )
            )
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
        for gnn in self.layers:
            h = gnn(g, h)

        return self.predict(h)


class HANSampler(object):
    def __init__(self, g, metapath_list, num_neighbors):
        self.sampler_list = []
        for metapath in metapath_list:
            # Note: a random walk may repeat a route (the same edge); duplicate
            # edges are removed in the sampled graph, so the sampled graph may
            # have fewer than num_random_walks (num_neighbors) edges.
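            # Added note (a reading aid, not from the original script):
            # RandomWalkNeighborSampler builds a frontier for each seed from
            # metapath-guided random walks; with num_traversals=1 and
            # termination_prob=0, each walk follows the metapath exactly once
            # and returns a node of the seed's type. See the DGL documentation
            # for the exact semantics.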
            self.sampler_list.append(
                RandomWalkNeighborSampler(
                    G=g,
                    num_traversals=1,
                    termination_prob=0,
                    num_random_walks=num_neighbors,
                    num_neighbors=num_neighbors,
                    metapath=metapath,
                )
            )

    def sample_blocks(self, seeds):
        block_list = []
        for sampler in self.sampler_list:
            frontier = sampler(seeds)
            # add self loops
            frontier = dgl.remove_self_loop(frontier)
            frontier.add_edges(torch.tensor(seeds), torch.tensor(seeds))
            block = dgl.to_block(frontier, seeds)
            block_list.append(block)
        return seeds, block_list


def score(logits, labels):
    _, indices = torch.max(logits, dim=1)
    prediction = indices.long().cpu().numpy()
    labels = labels.cpu().numpy()
    accuracy = (prediction == labels).sum() / len(prediction)
    micro_f1 = f1_score(labels, prediction, average="micro")
    macro_f1 = f1_score(labels, prediction, average="macro")

    return accuracy, micro_f1, macro_f1


def evaluate(
    model,
    g,
    metapath_list,
    num_neighbors,
    features,
    labels,
    val_nid,
    loss_fcn,
    batch_size,
):
    model.eval()
    # Sample twice as many neighbors as during training (see module docstring).
    han_valid_sampler = HANSampler(
        g, metapath_list, num_neighbors=num_neighbors * 2
    )
    dataloader = DataLoader(
        dataset=val_nid,
        batch_size=batch_size,
        collate_fn=han_valid_sampler.sample_blocks,
        shuffle=False,
        drop_last=False,
        num_workers=4,
    )
    correct = total = 0
    prediction_list = []
    labels_list = []
    with torch.no_grad():
        for step, (seeds, blocks) in enumerate(dataloader):
            h_list = load_subtensors(blocks, features)
            blocks = [block.to(args["device"]) for block in blocks]
            hs = [h.to(args["device"]) for h in h_list]

            logits = model(blocks, hs)
            loss = loss_fcn(
                logits, labels[numpy.asarray(seeds)].to(args["device"])
            )
            # get the predicted label of each node
            _, indices = torch.max(logits, dim=1)
            prediction = indices.long().cpu().numpy()
            labels_batch = labels[numpy.asarray(seeds)].cpu().numpy()

            prediction_list.append(prediction)
            labels_list.append(labels_batch)

            correct += (prediction == labels_batch).sum()
            total += prediction.shape[0]

    total_prediction = numpy.concatenate(prediction_list)
    total_labels = numpy.concatenate(labels_list)
    micro_f1 = f1_score(total_labels, total_prediction, average="micro")
    macro_f1 = f1_score(total_labels, total_prediction, average="macro")
    accuracy = correct / total

    return loss, accuracy, micro_f1, macro_f1


def load_subtensors(blocks, features):
    h_list = []
    for block in blocks:
        input_nodes = block.srcdata[dgl.NID]
        h_list.append(features[input_nodes])
    return h_list


def main(args):
    # ACM data
    if args["dataset"] == "ACMRaw":
        from utils import load_data

        (
            g,
            features,
            labels,
            n_classes,
            train_nid,
            val_nid,
            test_nid,
            train_mask,
            val_mask,
            test_mask,
        ) = load_data("ACMRaw")
        # Metapaths: paper-author-paper ("pa", "ap") and paper-field-paper ("pf", "fp")
        metapath_list = [["pa", "ap"], ["pf", "fp"]]
    else:
        raise NotImplementedError(
            "Unsupported dataset {}".format(args["dataset"])
        )

    # Is a different number of neighbors needed for each metapath-based graph?
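    # (One possible extension, not implemented here: let HANSampler accept a
    # list of per-metapath fanouts instead of a single num_neighbors value.)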
    num_neighbors = args["num_neighbors"]
    han_sampler = HANSampler(g, metapath_list, num_neighbors)
    # Create a PyTorch DataLoader for constructing blocks
    dataloader = DataLoader(
        dataset=train_nid,
        batch_size=args["batch_size"],
        collate_fn=han_sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=4,
    )

    model = HAN(
        num_metapath=len(metapath_list),
        in_size=features.shape[1],
        hidden_size=args["hidden_units"],
        out_size=n_classes,
        num_heads=args["num_heads"],
        dropout=args["dropout"],
    ).to(args["device"])

    total_params = sum(p.numel() for p in model.parameters())
    print("total_params: {:d}".format(total_params))
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print("total trainable params: {:d}".format(total_trainable_params))

    stopper = EarlyStopping(patience=args["patience"])
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
    )

    for epoch in range(args["num_epochs"]):
        model.train()
        for step, (seeds, blocks) in enumerate(dataloader):
            h_list = load_subtensors(blocks, features)
            blocks = [block.to(args["device"]) for block in blocks]
            hs = [h.to(args["device"]) for h in h_list]

            logits = model(blocks, hs)
            loss = loss_fn(
                logits, labels[numpy.asarray(seeds)].to(args["device"])
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # print training metrics for each batch
            train_acc, train_micro_f1, train_macro_f1 = score(
                logits, labels[numpy.asarray(seeds)]
            )
            print(
                "Epoch {:d} | loss: {:.4f} | train_acc: {:.4f} | train_micro_f1: {:.4f} | train_macro_f1: {:.4f}".format(
                    epoch + 1, loss, train_acc, train_micro_f1, train_macro_f1
                )
            )

        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(
            model,
            g,
            metapath_list,
            num_neighbors,
            features,
            labels,
            val_nid,
            loss_fn,
            args["batch_size"],
        )
        early_stop = stopper.step(val_loss.data.item(), val_acc, model)
        print(
            "Epoch {:d} | Val loss {:.4f} | Val Accuracy {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}".format(
                epoch + 1, val_loss.item(), val_acc, val_micro_f1, val_macro_f1
            )
        )
        if early_stop:
            break

    stopper.load_checkpoint(model)
    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(
        model,
        g,
        metapath_list,
        num_neighbors,
        features,
        labels,
        test_nid,
        loss_fn,
        args["batch_size"],
    )
    print(
        "Test loss {:.4f} | Test Accuracy {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}".format(
            test_loss.item(), test_acc, test_micro_f1, test_macro_f1
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser("mini-batch HAN")
    parser.add_argument("-s", "--seed", type=int, default=1, help="Random seed")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_neighbors", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.001)
    # accept one or more head counts from the command line, e.g. --num_heads 8 8
    parser.add_argument("--num_heads", type=int, nargs="+", default=[8])
    parser.add_argument("--hidden_units", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.6)
    parser.add_argument("--weight_decay", type=float, default=0.001)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--patience", type=int, default=10)
    parser.add_argument("--dataset", type=str, default="ACMRaw")
    parser.add_argument("--device", type=str, default="cuda:0")

    args = parser.parse_args().__dict__

    # set_random_seed(args['seed'])

    main(args)
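
# Example invocation (a sketch, not part of the original script; the file name
# "train_sampling.py" and the availability of the ACM data loaded by
# utils.load_data are assumptions):
#
#   python train_sampling.py --dataset ACMRaw --batch_size 32 \
#       --num_neighbors 20 --device cuda:0
#
# On a machine without a GPU, pass --device cpu instead.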