import math
import os
import sys
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import random
from torch.nn.parameter import Parameter
from tqdm.auto import tqdm
from utils import *

import dgl
import dgl.function as fn


def get_graph(network_data, vocab):
    """Build a graph, treating all nodes as a single node type.

    Parameters
    ----------
    network_data : dict
        Keys are the edge types, values are the edge lists.
    vocab : dict
        Maps node IDs to node indices.

    Returns
    -------
    DGLGraph
        A heterogeneous graph with one node type and multiple edge types.
    """
    node_type = "_N"  # '_N' can be replaced by an arbitrary name
    data_dict = dict()
    num_nodes_dict = {node_type: len(vocab)}

    for edge_type in network_data:
        tmp_data = network_data[edge_type]
        src = []
        dst = []
        # Insert each edge in both directions so the graph is undirected.
        for edge in tmp_data:
            src.extend([vocab[edge[0]], vocab[edge[1]]])
            dst.extend([vocab[edge[1]], vocab[edge[0]]])
        data_dict[(node_type, edge_type, node_type)] = (src, dst)
    graph = dgl.heterograph(data_dict, num_nodes_dict)
    return graph


class NeighborSampler(object):
    def __init__(self, g, num_fanouts):
        self.g = g
        self.num_fanouts = num_fanouts

    def sample(self, pairs):
        heads, tails, types = zip(*pairs)
        seeds, head_invmap = torch.unique(
            torch.LongTensor(heads), return_inverse=True
        )
        blocks = []
        for fanout in reversed(self.num_fanouts):
            sampled_graph = dgl.sampling.sample_neighbors(
                self.g, seeds, fanout
            )
            sampled_block = dgl.to_block(sampled_graph, seeds)
            seeds = sampled_block.srcdata[dgl.NID]
            blocks.insert(0, sampled_block)
        return (
            blocks,
            torch.LongTensor(head_invmap),
            torch.LongTensor(tails),
            torch.LongTensor(types),
        )


class DGLGATNE(nn.Module):
    def __init__(
        self,
        num_nodes,
        embedding_size,
        embedding_u_size,
        edge_types,
        edge_type_count,
        dim_a,
    ):
        super(DGLGATNE, self).__init__()
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_types = edge_types
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a

        self.node_embeddings = Parameter(
            torch.FloatTensor(num_nodes, embedding_size)
        )
        self.node_type_embeddings = Parameter(
            torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)
        )
        self.trans_weights = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)
        )
        self.trans_weights_s1 = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
        )
        self.trans_weights_s2 = Parameter(
            torch.FloatTensor(edge_type_count, dim_a, 1)
        )

        self.reset_parameters()

    def reset_parameters(self):
        self.node_embeddings.data.uniform_(-1.0, 1.0)
        self.node_type_embeddings.data.uniform_(-1.0, 1.0)
        self.trans_weights.data.normal_(
            std=1.0 / math.sqrt(self.embedding_size)
        )
        self.trans_weights_s1.data.normal_(
            std=1.0 / math.sqrt(self.embedding_size)
        )
        self.trans_weights_s2.data.normal_(
            std=1.0 / math.sqrt(self.embedding_size)
        )

    def forward(self, block):
        input_nodes = block.srcdata[dgl.NID]
        output_nodes = block.dstdata[dgl.NID]
        batch_size = block.number_of_dst_nodes()
        node_embed = self.node_embeddings
        node_type_embed = []

        with block.local_scope():
            # Aggregate neighbor embeddings separately per edge type.
            for i in range(self.edge_type_count):
                edge_type = self.edge_types[i]
                block.srcdata[edge_type] = self.node_type_embeddings[
                    input_nodes, i
                ]
                block.dstdata[edge_type] = self.node_type_embeddings[
                    output_nodes, i
                ]
                block.update_all(
                    fn.copy_u(edge_type, "m"),
                    fn.sum("m", edge_type),
                    etype=edge_type,
                )
                node_type_embed.append(block.dstdata[edge_type])

            # [batch_size, edge_type_count, embedding_u_size]
            node_type_embed = torch.stack(node_type_embed, 1)
            tmp_node_type_embed = node_type_embed.unsqueeze(2).view(
                -1, 1, self.embedding_u_size
            )
            trans_w = (
                self.trans_weights.unsqueeze(0)
                .repeat(batch_size, 1, 1, 1)
                .view(-1, self.embedding_u_size, self.embedding_size)
            )
            trans_w_s1 = (
                self.trans_weights_s1.unsqueeze(0)
                .repeat(batch_size, 1, 1, 1)
                .view(-1, self.embedding_u_size, self.dim_a)
            )
            trans_w_s2 = (
                self.trans_weights_s2.unsqueeze(0)
                .repeat(batch_size, 1, 1, 1)
                .view(-1, self.dim_a, 1)
            )

            # Self-attention over the edge-type-specific embeddings.
            attention = (
                F.softmax(
                    torch.matmul(
                        torch.tanh(
                            torch.matmul(tmp_node_type_embed, trans_w_s1)
                        ),
                        trans_w_s2,
                    )
                    .squeeze(2)
                    .view(-1, self.edge_type_count),
                    dim=1,
                )
                .unsqueeze(1)
                .repeat(1, self.edge_type_count, 1)
            )

            node_type_embed = torch.matmul(attention, node_type_embed).view(
                -1, 1, self.embedding_u_size
            )
            node_embed = node_embed[output_nodes].unsqueeze(1).repeat(
                1, self.edge_type_count, 1
            ) + torch.matmul(node_type_embed, trans_w).view(
                -1, self.edge_type_count, self.embedding_size
            )
            last_node_embed = F.normalize(node_embed, dim=2)

            # [batch_size, edge_type_count, embedding_size]
            return last_node_embed
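
# --- Illustrative sketch (not part of the original example) ------------------
# A minimal smoke test for the model above. The toy edges, vocab, and sizes
# are arbitrary assumptions chosen only for this demo: it builds a graph with
# two edge types, samples one block for two (head, tail, type) training pairs,
# and checks the output shape of a forward pass.
def _demo_gatne_forward():
    demo_network = {
        "1": [("a", "b"), ("b", "c")],
        "2": [("a", "c"), ("b", "c")],
    }
    demo_vocab = {"a": 0, "b": 1, "c": 2}
    demo_graph = get_graph(demo_network, demo_vocab)

    sampler = NeighborSampler(demo_graph, [2])  # one layer, fanout 2
    blocks, head_invmap, tails, types = sampler.sample(
        [(0, 1, 0), (2, 0, 1)]  # (head, tail, edge-type index)
    )

    model = DGLGATNE(
        num_nodes=3,
        embedding_size=8,
        embedding_u_size=4,
        edge_types=["1", "2"],
        edge_type_count=2,
        dim_a=2,
    )
    embs = model(blocks[0])
    # One row per unique head node, one slice per edge type.
    assert embs.shape == (2, 2, 8)
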
class NSLoss(nn.Module):
    def __init__(self, num_nodes, num_sampled, embedding_size):
        super(NSLoss, self).__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        # Negative-sampling distribution:
        # [(log(k + 2) - log(k + 1)) / log(num_nodes + 1)]
        self.sample_weights = F.normalize(
            torch.Tensor(
                [
                    (math.log(k + 2) - math.log(k + 1))
                    / math.log(num_nodes + 1)
                    for k in range(num_nodes)
                ]
            ),
            dim=0,
        )

        self.reset_parameters()

    def reset_parameters(self):
        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        n = input.shape[0]
        log_target = torch.log(
            torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1))
        )
        negs = torch.multinomial(
            self.sample_weights, self.num_sampled * n, replacement=True
        ).view(n, self.num_sampled)
        noise = torch.neg(self.weights[negs])
        sum_log_sampled = torch.sum(
            torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1
        ).squeeze()

        loss = log_target + sum_log_sampled
        return -loss.sum() / n
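
# --- Illustrative sketch (not part of the original example) ------------------
# Exercising NSLoss in isolation with random inputs; every size here is an
# arbitrary assumption. `input` contributes only its batch size, `embs` are
# the head-node embeddings produced by the model, and `label` holds the
# positive context nodes drawn from the random walks.
def _demo_nsloss():
    loss_fn = NSLoss(num_nodes=100, num_sampled=5, embedding_size=8)
    heads = torch.randint(0, 100, (4,))     # 4 training pairs
    embs = torch.randn(4, 8)                # head-node embeddings
    contexts = torch.randint(0, 100, (4,))  # positive context nodes
    loss = loss_fn(heads, embs, contexts)
    print(float(loss))                      # scalar negative-sampling loss
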
def train_model(network_data):
    index2word, vocab, type_nodes = generate_vocab(network_data)
    edge_types = list(network_data.keys())

    num_nodes = len(index2word)
    edge_type_count = len(edge_types)

    epochs = args.epoch
    batch_size = args.batch_size
    embedding_size = args.dimensions
    embedding_u_size = args.edge_dim
    u_num = edge_type_count
    num_sampled = args.negative_samples
    dim_a = args.att_dim
    att_head = 1
    neighbor_samples = args.neighbor_samples
    num_workers = args.workers

    device = torch.device(
        "cuda" if args.gpu is not None and torch.cuda.is_available() else "cpu"
    )

    g = get_graph(network_data, vocab)

    all_walks = []
    for i in range(edge_type_count):
        nodes = torch.LongTensor(type_nodes[i] * args.num_walks)
        traces, types = dgl.sampling.random_walk(
            g, nodes, metapath=[edge_types[i]] * (neighbor_samples - 1)
        )
        all_walks.append(traces)

    train_pairs = generate_pairs(all_walks, args.window_size, num_workers)
    neighbor_sampler = NeighborSampler(g, [neighbor_samples])
    train_dataloader = torch.utils.data.DataLoader(
        train_pairs,
        batch_size=batch_size,
        collate_fn=neighbor_sampler.sample,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    model = DGLGATNE(
        num_nodes,
        embedding_size,
        embedding_u_size,
        edge_types,
        edge_type_count,
        dim_a,
    )
    nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
    model.to(device)
    nsloss.to(device)

    optimizer = torch.optim.Adam(
        [{"params": model.parameters()}, {"params": nsloss.parameters()}],
        lr=1e-3,
    )

    best_score = 0
    patience = 0
    for epoch in range(epochs):
        model.train()
        random.shuffle(train_pairs)
        data_iter = tqdm(
            train_dataloader,
            desc="epoch %d" % epoch,
            total=(len(train_pairs) + (batch_size - 1)) // batch_size,
        )
        avg_loss = 0.0

        for i, (block, head_invmap, tails, block_types) in enumerate(
            data_iter
        ):
            optimizer.zero_grad()
            # embs: [batch_size, edge_type_count, embedding_size]
            block_types = block_types.to(device)
            embs = model(block[0].to(device))[head_invmap]
            embs = embs.gather(
                1,
                block_types.view(-1, 1, 1).expand(
                    embs.shape[0], 1, embs.shape[2]
                ),
            )[:, 0]
            loss = nsloss(
                block[0].dstdata[dgl.NID][head_invmap].to(device),
                embs,
                tails.to(device),
            )
            loss.backward()
            optimizer.step()
            avg_loss += loss.item()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "loss": loss.item(),
            }
            data_iter.set_postfix(post_fix)

        model.eval()
        # e.g. {'1': {}, '2': {}}
        final_model = dict(
            zip(edge_types, [dict() for _ in range(edge_type_count)])
        )
        for i in range(num_nodes):
            train_inputs = (
                torch.tensor([i for _ in range(edge_type_count)])
                .unsqueeze(1)
                .to(device)
            )  # [i, i, ...]
            train_types = (
                torch.tensor(list(range(edge_type_count)))
                .unsqueeze(1)
                .to(device)
            )  # [0, 1, ...]
            pairs = torch.cat(
                (train_inputs, train_inputs, train_types), dim=1
            )  # shape: (edge_type_count, 3)
            (
                train_blocks,
                train_invmap,
                fake_tails,
                train_types,
            ) = neighbor_sampler.sample(pairs)

            node_emb = model(train_blocks[0].to(device))[train_invmap]
            node_emb = node_emb.gather(
                1,
                train_types.to(device)
                .view(-1, 1, 1)
                .expand(node_emb.shape[0], 1, node_emb.shape[2]),
            )[:, 0]

            for j in range(edge_type_count):
                final_model[edge_types[j]][index2word[i]] = (
                    node_emb[j].cpu().detach().numpy()
                )

        valid_aucs, valid_f1s, valid_prs = [], [], []
        test_aucs, test_f1s, test_prs = [], [], []
        for i in range(edge_type_count):
            if args.eval_type == "all" or edge_types[i] in args.eval_type.split(
                ","
            ):
                tmp_auc, tmp_f1, tmp_pr = evaluate(
                    final_model[edge_types[i]],
                    valid_true_data_by_edge[edge_types[i]],
                    valid_false_data_by_edge[edge_types[i]],
                    num_workers,
                )
                valid_aucs.append(tmp_auc)
                valid_f1s.append(tmp_f1)
                valid_prs.append(tmp_pr)

                tmp_auc, tmp_f1, tmp_pr = evaluate(
                    final_model[edge_types[i]],
                    testing_true_data_by_edge[edge_types[i]],
                    testing_false_data_by_edge[edge_types[i]],
                    num_workers,
                )
                test_aucs.append(tmp_auc)
                test_f1s.append(tmp_f1)
                test_prs.append(tmp_pr)
        print("valid auc:", np.mean(valid_aucs))
        print("valid pr:", np.mean(valid_prs))
        print("valid f1:", np.mean(valid_f1s))

        average_auc = np.mean(test_aucs)
        average_f1 = np.mean(test_f1s)
        average_pr = np.mean(test_prs)

        cur_score = np.mean(valid_aucs)
        if cur_score > best_score:
            best_score = cur_score
            patience = 0
        else:
            patience += 1
            if patience > args.patience:
                print("Early Stopping")
                break
    return average_auc, average_f1, average_pr


if __name__ == "__main__":
    args = parse_args()
    file_name = args.input
    print(args)

    training_data_by_type = load_training_data(file_name + "/train.txt")
    valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(
        file_name + "/valid.txt"
    )
    testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(
        file_name + "/test.txt"
    )

    start = time.time()
    average_auc, average_f1, average_pr = train_model(training_data_by_type)
    end = time.time()

    print("Overall ROC-AUC:", average_auc)
    print("Overall PR-AUC:", average_pr)
    print("Overall F1:", average_f1)
    print("Training Time:", end - start)
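
# Usage note (assumptions, for illustration only): utils.parse_args defines
# the actual command-line flags, which are not shown in this file. Whatever
# the input flag is called, the directory it points to must contain
# train.txt, valid.txt, and test.txt, since load_training_data and
# load_testing_data read those fixed file names. A run might look like:
#   python main.py --input data/amazon
# where both the flag name and the dataset path are hypothetical.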