import argparse
import time
import os
import sys
import math
import random

from tqdm import tqdm
import numpy as np
import torch
from torch.nn import ModuleList, Linear, Conv1d, MaxPool1d, Embedding, BCEWithLogitsLoss
import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv, SortPooling
from dgl.sampling import global_uniform_negative_sampling
from dgl.dataloading import Sampler, DataLoader
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path


class Logger(object):
    """Collect per-run ``(val_score, test_score)`` pairs and print summaries.

    The "best" epoch of a run is the one with the highest validation score;
    the reported test score is the one recorded at that epoch.
    """

    def __init__(self, runs, info=None):
        # `info` is stored for reference only; nothing in this file reads it.
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Append one evaluation ``(val_score, test_score)`` to run ``run``."""
        # result is in the format of (val_score, test_score)
        assert len(result) == 2
        assert run >= 0 and run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None, f=sys.stdout):
        """Print statistics for one run, or mean ± std over all runs when ``run`` is None.

        Scores are multiplied by 100 before printing. When summarising all
        runs, every run must contain the same number of results (they are
        stacked into one tensor).
        """
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            argmax = result[:, 0].argmax().item()
            print(f'Run {run + 1:02d}:', file=f)
            print(f'Highest Valid: {result[:, 0].max():.2f}', file=f)
            print(f'Highest Eval Point: {argmax + 1}', file=f)
            print(f'   Final Test: {result[argmax, 1]:.2f}', file=f)
        else:
            result = 100 * torch.tensor(self.results)

            best_results = []
            for r in result:
                valid = r[:, 0].max().item()
                test = r[r[:, 0].argmax(), 1].item()
                best_results.append((valid, test))

            best_result = torch.tensor(best_results)

            print(f'All runs:', file=f)
            r = best_result[:, 0]
            print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}', file=f)
            r = best_result[:, 1]
            print(f'   Final Test: {r.mean():.2f} ± {r.std():.2f}', file=f)


class SealSampler(Sampler):
    """Build a SEAL enclosing subgraph for every seed edge.

    For each seed edge id of the augmented graph ``aug_g`` passed to
    :meth:`sample`, the endpoints are looked up in ``aug_g``; the enclosing
    subgraph itself is extracted from ``self.g`` (the message-passing graph).
    The target link is removed from the subgraph, nodes receive Double
    Radius Node Labeling (DRNL) in ``ndata['z']`` and self loops are added.
    """

    def __init__(self, g, num_hops=1, sample_ratio=1., directed=False,
                 prefetch_node_feats=None, prefetch_edge_feats=None):
        super().__init__()
        self.g = g
        self.num_hops = num_hops
        self.sample_ratio = sample_ratio
        self.directed = directed
        # Kept for interface compatibility. `g.subgraph` already copies
        # node/edge features into each subgraph, so no lazy prefetching is
        # registered on the sampled output.
        self.prefetch_node_feats = prefetch_node_feats
        self.prefetch_edge_feats = prefetch_edge_feats

    def _double_radius_node_labeling(self, adj, src=0, dst=1):
        """Return DRNL labels (long tensor, one per node) for target pair ``src``/``dst``.

        ``adj`` is the scipy sparse adjacency of the subgraph with the target
        link(s) removed; ``src``/``dst`` are local node indices (defaults keep
        the historical "targets are nodes 0 and 1" convention). Distance to
        each target is measured with the *other* target removed, as in SEAL.
        """
        if src > dst:
            src, dst = dst, src
        N = adj.shape[0]

        idx_wo_src = list(range(src)) + list(range(src + 1, N))
        adj_wo_src = adj[idx_wo_src, :][:, idx_wo_src]
        idx_wo_dst = list(range(dst)) + list(range(dst + 1, N))
        adj_wo_dst = adj[idx_wo_dst, :][:, idx_wo_dst]

        # src < dst, so src keeps its index in adj_wo_dst ...
        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        # ... while dst shifts down by one in adj_wo_src.
        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = torch.div(dist, 2, rounding_mode='floor'), dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        # unreachable nodes have inf distances, which turn into NaN above
        z[torch.isnan(z)] = 0.

        return z.to(torch.long)

    def sample(self, aug_g, seed_edges):
        """Return ``(batched enclosing subgraphs, aug_g.edata['y'][seed_edges])``."""
        g = self.g
        subgraphs = []
        # construct k-hop enclosing graph for each link
        for eid in seed_edges:
            src, dst = map(int, aug_g.find_edges(eid))
            # construct the enclosing graph
            visited, nodes, fringe = [np.unique([src, dst]) for _ in range(3)]
            for _ in range(self.num_hops):
                if not self.directed:
                    _, fringe = g.out_edges(fringe)
                else:
                    _, out_neighbors = g.out_edges(fringe)
                    in_neighbors, _ = g.in_edges(fringe)
                    fringe = np.union1d(in_neighbors, out_neighbors)
                fringe = np.setdiff1d(fringe, visited)
                visited = np.union1d(visited, fringe)
                if self.sample_ratio < 1.:
                    fringe = np.random.choice(
                        fringe, int(self.sample_ratio * len(fringe)), replace=False)
                if len(fringe) == 0:
                    break
                nodes = np.union1d(nodes, fringe)
            subg = g.subgraph(nodes, store_ids=True)

            # `nodes` is sorted by np.union1d, so the targets are NOT
            # necessarily local nodes 0/1; locate them via the stored parent IDs.
            parent_ids = subg.ndata[dgl.NID]
            u = int((parent_ids == src).nonzero()[0, 0])
            v = int((parent_ids == dst).nonzero()[0, 0])

            # remove edges to predict (every parallel copy, both directions)
            e_src, e_dst, e_ids = subg.out_edges([u, v], form='all')
            is_target = ((e_src == u) & (e_dst == v)) | ((e_src == v) & (e_dst == u))
            subg.remove_edges(torch.unique(e_ids[is_target]))

            # add double radius node labeling
            subg.ndata['z'] = self._double_radius_node_labeling(
                subg.adj(scipy_fmt='csr'), u, v)
            subg_aug = subg.add_self_loop()
            if 'weight' in subg.edata:
                # self-loop edges are appended after the original edges; give
                # them weight 1 (scalar fill works for (E,) and (E, 1) weights)
                weight = subg_aug.edata['weight']
                weight[subg.num_edges():] = 1.
                subg_aug.edata['weight'] = weight
            subgraphs.append(subg_aug)

        subgraphs = dgl.batch(subgraphs)
        return subgraphs, aug_g.edata['y'][seed_edges]


# An end-to-end deep learning architecture for graph classification, AAAI-18.
class DGCNN(torch.nn.Module):
    """DGCNN link scorer: GNN stack -> SortPooling(k) -> 1-D convs -> MLP -> one logit per graph."""

    def __init__(self, hidden_channels, num_layers, k, GNN=GraphConv, feature_dim=0):
        super().__init__()

        self.feature_dim = feature_dim
        self.k = k
        self.sort_pool = SortPooling(k=k)

        # DRNL labels are embedded; labels >= max_z would be out of range
        self.max_z = 1000
        self.z_embedding = Embedding(self.max_z, hidden_channels)

        self.convs = ModuleList()
        initial_channels = hidden_channels + self.feature_dim

        self.convs.append(GNN(initial_channels, hidden_channels))
        for _ in range(0, num_layers - 1):
            self.convs.append(GNN(hidden_channels, hidden_channels))
        # final 1-channel layer, whose output SortPooling sorts on
        self.convs.append(GNN(hidden_channels, 1))

        conv1d_channels = [16, 32]
        # concatenated outputs of all conv layers (num_layers * hidden + 1)
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        # kernel == stride == total_latent_dim: one output per sorted node
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1)
        dense_dim = int((self.k - 2) / 2 + 1)
        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.lin1 = Linear(dense_dim, 128)
        self.lin2 = Linear(128, 1)

    def forward(self, g, z, x=None, edge_weight=None):
        """Return logits of shape ``[num_graphs, 1]`` for a batched graph ``g``."""
        z_emb = self.z_embedding(z)
        if z_emb.ndim == 3:  # in case z has multiple integer labels
            z_emb = z_emb.sum(dim=1)

        if x is not None:
            x = torch.cat([z_emb, x.to(torch.float)], 1)
        else:
            x = z_emb

        xs = [x]
        for conv in self.convs:
            xs += [torch.tanh(conv(g, xs[-1], edge_weight=edge_weight))]
        x = torch.cat(xs[1:], dim=-1)

        # global pooling
        x = self.sort_pool(g, x)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        # MLP.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x


def get_pos_neg_edges(split, split_edge, g, percent=100):
    """Return ``(pos_edge, neg_edge)`` for ``split``, subsampled to ``percent`` %.

    Positive edges are ``[Np, 2]``. Training negatives are drawn uniformly
    from ``g`` (self loops excluded); other splits use the provided
    ``edge_neg``. A 3-D ``edge_neg`` (``[Np, Nn, 2]``, per-positive
    negatives) is subsampled with the same permutation as the positives and
    flattened to ``[Np' * Nn, 2]``. The global NumPy RNG is reseeded with 123.
    """
    pos_edge = split_edge[split]['edge']
    if split == 'train':
        neg_edge = torch.stack(global_uniform_negative_sampling(
            g, num_samples=pos_edge.size(0),
            exclude_self_loops=True
        ), dim=1)
    else:
        neg_edge = split_edge[split]['edge_neg']

    # sampling according to the percent param
    np.random.seed(123)
    # pos sampling
    num_pos = pos_edge.size(0)
    perm = np.random.permutation(num_pos)
    perm = perm[:int(percent / 100 * num_pos)]
    pos_edge = pos_edge[perm]
    # neg sampling
    if neg_edge.dim() > 2:  # [Np, Nn, 2]
        neg_edge = neg_edge[perm].view(-1, 2)
    else:
        np.random.seed(123)
        num_neg = neg_edge.size(0)
        perm = np.random.permutation(num_neg)
        perm = perm[:int(percent / 100 * num_neg)]
        neg_edge = neg_edge[perm]

    return pos_edge, neg_edge  # ([Np, 2], [Nn, 2])


def train():
    """Train the global ``model`` for one epoch; return mean loss per subgraph."""
    model.train()
    loss_fnt = BCEWithLogitsLoss()
    total_loss = 0
    total = 0
    pbar = tqdm(train_loader, ncols=70)
    for gs, y in pbar:
        optimizer.zero_grad()
        logits = model(gs, gs.ndata['z'], gs.ndata.get('feat', None),
                       edge_weight=gs.edata.get('weight', None))
        loss = loss_fnt(logits.view(-1), y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * gs.batch_size
        total += gs.batch_size

    return total_loss / total


@torch.no_grad()
def _predict(loader):
    """Score every batch of ``loader`` with the global ``model``; return (pos_pred, neg_pred)."""
    y_pred, y_true = [], []
    for gs, y in tqdm(loader, ncols=70):
        logits = model(gs, gs.ndata['z'], gs.ndata.get('feat', None),
                       edge_weight=gs.edata.get('weight', None))
        y_pred.append(logits.view(-1).cpu())
        y_true.append(y.view(-1).cpu().to(torch.float))
    pred, true = torch.cat(y_pred), torch.cat(y_true)
    return pred[true == 1], pred[true == 0]


@torch.no_grad()
def test():
    """Evaluate the global ``model`` on validation and test loaders.

    Returns a dict mapping metric name to ``(valid_score, test_score)``.

    Raises:
        ValueError: if ``args.eval_metric`` is neither 'hits' nor 'mrr'.
    """
    model.eval()

    pos_val_pred, neg_val_pred = _predict(val_loader)
    pos_test_pred, neg_test_pred = _predict(test_loader)

    if args.eval_metric == 'hits':
        results = evaluate_hits(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred)
    elif args.eval_metric == 'mrr':
        results = evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred)
    else:
        raise ValueError(f'unknown eval_metric: {args.eval_metric}')

    return results


def evaluate_hits(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
    """Return ``{'Hits@K': (valid, test)}`` for K in 20/50/100 via the global OGB evaluator."""
    results = {}
    for K in [20, 50, 100]:
        evaluator.K = K
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_val_pred,
            'y_pred_neg': neg_val_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (valid_hits, test_hits)

    return results


def evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
    """Return ``{'MRR': (valid, test)}``.

    Negative predictions are reshaped to ``[num_pos, -1]``, which relies on
    the non-shuffled loaders emitting negatives grouped per positive.
    """
    print(pos_val_pred.size(), neg_val_pred.size(), pos_test_pred.size(), neg_test_pred.size())
    neg_val_pred = neg_val_pred.view(pos_val_pred.shape[0], -1)
    neg_test_pred = neg_test_pred.view(pos_test_pred.shape[0], -1)
    results = {}
    valid_mrr = evaluator.eval({
        'y_pred_pos': pos_val_pred,
        'y_pred_neg': neg_val_pred,
    })['mrr_list'].mean().item()

    test_mrr = evaluator.eval({
        'y_pred_pos': pos_test_pred,
        'y_pred_neg': neg_test_pred,
    })['mrr_list'].mean().item()

    results['MRR'] = (valid_mrr, test_mrr)

    return results


if __name__ == '__main__':
    # Data settings
    parser = argparse.ArgumentParser(description='OGBL (SEAL)')
    parser.add_argument('--dataset', type=str, default='ogbl-collab')
    # GNN settings
    parser.add_argument('--sortpool_k', type=float, default=0.6,
                        help="SortPooling k: a percentile of training subgraph "
                             "sizes if <= 1, otherwise an absolute number")
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=32)
    parser.add_argument('--batch_size', type=int, default=32)
    # Subgraph extraction settings
    parser.add_argument('--ratio_per_hop', type=float, default=1.0)
    parser.add_argument('--use_feature', action='store_true',
                        help="whether to use raw node features as GNN input")
    parser.add_argument('--use_edge_weight', action='store_true',
                        help="whether to consider edge weight in GNN")
    # Training settings
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--train_percent', type=float, default=100)
    parser.add_argument('--val_percent', type=float, default=100)
    parser.add_argument('--test_percent', type=float, default=100)
    parser.add_argument('--num_workers', type=int, default=8,
                        help="number of workers for dynamic dataloaders")
    # Testing settings
    parser.add_argument('--use_valedges_as_input', action='store_true')
    parser.add_argument('--eval_steps', type=int, default=1)
    args = parser.parse_args()

    args.res_dir = os.path.join(
        'results', '{}_{}'.format(args.dataset, time.strftime("%Y%m%d%H%M%S")))
    print('Results will be saved in ' + args.res_dir)
    os.makedirs(args.res_dir, exist_ok=True)
    log_file = os.path.join(args.res_dir, 'log.txt')
    # Save command line input.
    cmd_input = 'python ' + ' '.join(sys.argv) + '\n'
    with open(os.path.join(args.res_dir, 'cmd_input.txt'), 'a') as f:
        f.write(cmd_input)
    print('Command line input: ' + cmd_input + ' is saved.')
    with open(log_file, 'a') as f:
        f.write('\n' + cmd_input)

    dataset = DglLinkPropPredDataset(name=args.dataset)
    split_edge = dataset.get_edge_split()
    graph = dataset[0]

    # re-format the data of citation2
    if args.dataset == 'ogbl-citation2':
        for split_name in ['train', 'valid', 'test']:
            src = split_edge[split_name]['source_node']
            tgt = split_edge[split_name]['target_node']
            split_edge[split_name]['edge'] = torch.stack([src, tgt], dim=1)
            if split_name != 'train':
                tgt_neg = split_edge[split_name]['target_node_neg']
                split_edge[split_name]['edge_neg'] = torch.stack([
                    src[:, None].repeat(1, tgt_neg.size(1)),
                    tgt_neg
                ], dim=-1)  # [Ns, Nt, 2]

    # reconstruct the graph for ogbl-collab data for validation edge augmentation and coalesce
    if args.dataset == 'ogbl-collab':
        if args.use_valedges_as_input:
            val_edges = split_edge['valid']['edge']
            row, col = val_edges.t()
            # float edata for to_simple transform
            graph.edata.pop('year')
            graph.edata['weight'] = graph.edata['weight'].to(torch.float)
            val_weights = torch.ones(size=(val_edges.size(0), 1))
            graph.add_edges(torch.cat([row, col]), torch.cat([col, row]), {'weight': val_weights})
            graph = graph.to_simple(copy_edata=True, aggregator='sum')

    if 'weight' in graph.edata:
        if not args.use_edge_weight:
            graph.edata.pop('weight')
        else:
            # GraphConv multiplies messages by edge weights, which must be float
            graph.edata['weight'] = graph.edata['weight'].to(torch.float)

    if not args.use_feature and 'feat' in graph.ndata:
        graph.ndata.pop('feat')

    if args.dataset.startswith('ogbl-citation'):
        args.eval_metric = 'mrr'
        directed = True
    else:
        args.eval_metric = 'hits'
        directed = False

    evaluator = Evaluator(name=args.dataset)
    if args.eval_metric == 'hits':
        loggers = {
            'Hits@20': Logger(args.runs, args),
            'Hits@50': Logger(args.runs, args),
            'Hits@100': Logger(args.runs, args),
        }
    elif args.eval_metric == 'mrr':
        loggers = {
            'MRR': Logger(args.runs, args),
        }

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    loaders = []
    prefetch_node_feats = ['feat'] if 'feat' in graph.ndata else None
    prefetch_edge_feats = ['weight'] if 'weight' in graph.edata else None

    train_edge, train_edge_neg = get_pos_neg_edges('train', split_edge, graph, args.train_percent)
    val_edge, val_edge_neg = get_pos_neg_edges('valid', split_edge, graph, args.val_percent)
    test_edge, test_edge_neg = get_pos_neg_edges('test', split_edge, graph, args.test_percent)

    # create an augmented graph for sampling: the original edges (label 1)
    # followed by val/test positives (label 1) and all negatives (label 0)
    aug_g = dgl.graph(graph.edges(), num_nodes=graph.num_nodes())
    aug_g.edata['y'] = torch.ones(aug_g.num_edges())
    aug_edges = torch.cat([val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg])
    aug_labels = torch.cat([
        torch.ones(len(val_edge) + len(test_edge)),
        torch.zeros(len(train_edge_neg) + len(val_edge_neg) + len(test_edge_neg))
    ])
    aug_g.add_edges(aug_edges[:, 0], aug_edges[:, 1], {'y': aug_labels})

    # eids for sampling; split_len follows the concatenation order above
    split_len = [graph.num_edges()] + \
        list(map(len, [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg]))
    train_eids = torch.cat([
        graph.edge_ids(train_edge[:, 0], train_edge[:, 1]),
        torch.arange(sum(split_len[:3]), sum(split_len[:4]))
    ])
    val_eids = torch.cat([
        torch.arange(sum(split_len[:1]), sum(split_len[:2])),
        torch.arange(sum(split_len[:4]), sum(split_len[:5]))
    ])
    test_eids = torch.cat([
        torch.arange(sum(split_len[:2]), sum(split_len[:3])),
        torch.arange(sum(split_len[:5]), sum(split_len[:6]))
    ])
    sampler = SealSampler(graph, 1, args.ratio_per_hop, directed,
                          prefetch_node_feats, prefetch_edge_feats)
    # force to be dynamic for consistent dataloading
    for split, shuffle, eids in zip(
        ['train', 'valid', 'test'],
        [True, False, False],
        [train_eids, val_eids, test_eids]
    ):
        data_loader = DataLoader(aug_g, eids, sampler, shuffle=shuffle, device=device,
                                 batch_size=args.batch_size, num_workers=args.num_workers)
        loaders.append(data_loader)
    train_loader, val_loader, test_loader = loaders

    # convert sortpool_k from percentile to number (values > 1 are absolute).
    if args.sortpool_k <= 1:
        num_nodes = []
        for subgs, _ in train_loader:
            if len(num_nodes) > 1000:
                break
            num_nodes.extend(subgs.batch_num_nodes().tolist())
        num_nodes = sorted(num_nodes)
        idx = max(int(math.ceil(args.sortpool_k * len(num_nodes))) - 1, 0)
        k = max(num_nodes[idx], 10)
    else:
        k = int(args.sortpool_k)

    for run in range(args.runs):
        model = DGCNN(args.hidden_channels, args.num_layers, k,
                      feature_dim=graph.ndata['feat'].size(1) if args.use_feature else 0).to(device)
        parameters = list(model.parameters())
        optimizer = torch.optim.Adam(params=parameters, lr=args.lr)
        total_params = sum(p.numel() for p in parameters)
        print(f'Total number of parameters is {total_params}')
        print(f'SortPooling k is set to {k}')
        with open(log_file, 'a') as f:
            print(f'Total number of parameters is {total_params}', file=f)
            print(f'SortPooling k is set to {k}', file=f)

        start_epoch = 1
        # Training starts
        for epoch in range(start_epoch, start_epoch + args.epochs):
            loss = train()

            if epoch % args.eval_steps == 0:
                results = test()
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                model_name = os.path.join(
                    args.res_dir, 'run{}_model_checkpoint{}.pth'.format(run + 1, epoch))
                optimizer_name = os.path.join(
                    args.res_dir, 'run{}_optimizer_checkpoint{}.pth'.format(run + 1, epoch))
                torch.save(model.state_dict(), model_name)
                torch.save(optimizer.state_dict(), optimizer_name)

                for key, result in results.items():
                    valid_res, test_res = result
                    to_print = (f'Run: {run + 1:02d}, Epoch: {epoch:02d}, ' +
                                f'Loss: {loss:.4f}, Valid: {100 * valid_res:.2f}%, ' +
                                f'Test: {100 * test_res:.2f}%')
                    print(key)
                    print(to_print)
                    with open(log_file, 'a') as f:
                        print(key, file=f)
                        print(to_print, file=f)

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)
            with open(log_file, 'a') as f:
                print(key, file=f)
                loggers[key].print_statistics(run, f=f)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
        with open(log_file, 'a') as f:
            print(key, file=f)
            loggers[key].print_statistics(f=f)
    print(f'Total number of parameters is {total_params}')
    print(f'Results are saved in {args.res_dir}')