import argparse
import itertools

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from tqdm import tqdm

import dgl
import dgl.nn as dglnn
from dgl import AddReverse, Compose, ToSimple
from dgl.nn import HeteroEmbedding


def prepare_data(args):
    # Load ogbn-mag, deduplicate parallel edges, and add reverse edges so that
    # messages can flow in both directions.
    dataset = DglNodePropPredDataset(name="ogbn-mag")
    split_idx = dataset.get_idx_split()
    # graph: dgl graph object, label: torch tensor of shape (num_nodes, num_tasks)
    g, labels = dataset[0]
    labels = labels["paper"].flatten()

    transform = Compose([ToSimple(), AddReverse()])
    g = transform(g)

    print("Loaded graph: {}".format(g))

    logger = Logger(args.runs)

    # train sampler: two fanouts, one per R-GCN layer
    sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])
    train_loader = dgl.dataloading.DataLoader(
        g,
        split_idx["train"],
        sampler,
        batch_size=1024,
        shuffle=True,
        num_workers=0,
    )

    return g, labels, dataset.num_classes, split_idx, logger, train_loader


def extract_embed(node_embed, input_nodes):
    # Look up the learned embeddings for every featureless node type in the
    # batch; "paper" nodes are skipped because they carry raw input features.
    emb = node_embed(
        {ntype: input_nodes[ntype] for ntype in input_nodes if ntype != "paper"}
    )
    return emb


def rel_graph_embed(graph, embed_size):
    # Learnable embeddings for all node types except "paper".
    node_num = {}
    for ntype in graph.ntypes:
        if ntype == "paper":
            continue
        node_num[ntype] = graph.num_nodes(ntype)
    embeds = HeteroEmbedding(node_num, embed_size)
    return embeds


class RelGraphConvLayer(nn.Module):
    """R-GCN layer: one GraphConv per relation plus a per-node-type self-loop."""

    def __init__(
        self, in_feat, out_feat, ntypes, rel_names, activation=None, dropout=0.0
    ):
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.ntypes = ntypes
        self.rel_names = rel_names
        self.activation = activation

        self.conv = dglnn.HeteroGraphConv(
            {
                rel: dglnn.GraphConv(
                    in_feat, out_feat, norm="right", weight=False, bias=False
                )
                for rel in rel_names
            }
        )

        # per-relation weight
        self.weight = nn.ModuleDict(
            {
                rel_name: nn.Linear(in_feat, out_feat, bias=False)
                for rel_name in self.rel_names
            }
        )

        # weight for self loop
        self.loop_weights = nn.ModuleDict(
            {
                ntype: nn.Linear(in_feat, out_feat, bias=True)
                for ntype in self.ntypes
            }
        )

        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.weight.values():
            layer.reset_parameters()

        for layer in self.loop_weights.values():
            layer.reset_parameters()

    def forward(self, g, inputs):
        """
        Parameters
        ----------
        g : DGLHeteroGraph
            Input graph.
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
""" g = g.local_var() wdict = { rel_name: {"weight": self.weight[rel_name].weight.T} for rel_name in self.rel_names } inputs_dst = { k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items() } hs = self.conv(g, inputs, mod_kwargs=wdict) def _apply(ntype, h): h = h + self.loop_weights[ntype](inputs_dst[ntype]) if self.activation: h = self.activation(h) return self.dropout(h) return {ntype: _apply(ntype, h) for ntype, h in hs.items()} class EntityClassify(nn.Module): def __init__(self, g, in_dim, out_dim): super(EntityClassify, self).__init__() self.in_dim = in_dim self.h_dim = 64 self.out_dim = out_dim self.rel_names = list(set(g.etypes)) self.rel_names.sort() self.dropout = 0.5 self.layers = nn.ModuleList() # i2h self.layers.append( RelGraphConvLayer( self.in_dim, self.h_dim, g.ntypes, self.rel_names, activation=F.relu, dropout=self.dropout, ) ) # h2o self.layers.append( RelGraphConvLayer( self.h_dim, self.out_dim, g.ntypes, self.rel_names, activation=None, ) ) def reset_parameters(self): for layer in self.layers: layer.reset_parameters() def forward(self, h, blocks): for layer, block in zip(self.layers, blocks): h = layer(block, h) return h class Logger(object): r""" This class was taken directly from the PyG implementation and can be found here: https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/logger.py This was done to ensure that performance was measured in precisely the same way """ def __init__(self, runs): self.results = [[] for _ in range(runs)] def add_result(self, run, result): assert len(result) == 3 assert run >= 0 and run < len(self.results) self.results[run].append(result) def print_statistics(self, run=None): if run is not None: result = 100 * th.tensor(self.results[run]) argmax = result[:, 1].argmax().item() print(f"Run {run + 1:02d}:") print(f"Highest Train: {result[:, 0].max():.2f}") print(f"Highest Valid: {result[:, 1].max():.2f}") print(f" Final Train: {result[argmax, 0]:.2f}") print(f" Final Test: {result[argmax, 2]:.2f}") else: result = 100 * th.tensor(self.results) best_results = [] for r in result: train1 = r[:, 0].max().item() valid = r[:, 1].max().item() train2 = r[r[:, 1].argmax(), 0].item() test = r[r[:, 1].argmax(), 2].item() best_results.append((train1, valid, train2, test)) best_result = th.tensor(best_results) print(f"All runs:") r = best_result[:, 0] print(f"Highest Train: {r.mean():.2f} ± {r.std():.2f}") r = best_result[:, 1] print(f"Highest Valid: {r.mean():.2f} ± {r.std():.2f}") r = best_result[:, 2] print(f" Final Train: {r.mean():.2f} ± {r.std():.2f}") r = best_result[:, 3] print(f" Final Test: {r.mean():.2f} ± {r.std():.2f}") def train( g, model, node_embed, optimizer, train_loader, split_idx, labels, logger, device, run, ): print("start training...") category = "paper" for epoch in range(3): num_train = split_idx["train"][category].shape[0] pbar = tqdm(total=num_train) pbar.set_description(f"Epoch {epoch:02d}") model.train() total_loss = 0 for input_nodes, seeds, blocks in train_loader: blocks = [blk.to(device) for blk in blocks] seeds = seeds[ category ] # we only predict the nodes with type "category" batch_size = seeds.shape[0] emb = extract_embed(node_embed, input_nodes) # Add the batch's raw "paper" features emb.update( {"paper": g.ndata["feat"]["paper"][input_nodes["paper"]]} ) emb = {k: e.to(device) for k, e in emb.items()} lbl = labels[seeds].to(device) optimizer.zero_grad() logits = model(emb, blocks)[category] y_hat = logits.log_softmax(dim=-1) loss = F.nll_loss(y_hat, lbl) loss.backward() optimizer.step() 
            total_loss += loss.item() * batch_size
            pbar.update(batch_size)

        pbar.close()
        loss = total_loss / num_train

        result = test(g, model, node_embed, labels, device, split_idx)
        logger.add_result(run, result)
        train_acc, valid_acc, test_acc = result
        print(
            f"Run: {run + 1:02d}, "
            f"Epoch: {epoch + 1:02d}, "
            f"Loss: {loss:.4f}, "
            f"Train: {100 * train_acc:.2f}%, "
            f"Valid: {100 * valid_acc:.2f}%, "
            f"Test: {100 * test_acc:.2f}%"
        )

    return logger


@th.no_grad()
def test(g, model, node_embed, y_true, device, split_idx):
    model.eval()
    category = "paper"
    evaluator = Evaluator(name="ogbn-mag")

    # full-neighbor sampling over both GNN layers for evaluation
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    loader = dgl.dataloading.DataLoader(
        g,
        {"paper": th.arange(g.num_nodes("paper"))},
        sampler,
        batch_size=16384,
        shuffle=False,
        num_workers=0,
    )

    pbar = tqdm(total=y_true.size(0))
    pbar.set_description("Inference")

    y_hats = list()

    for input_nodes, seeds, blocks in loader:
        blocks = [blk.to(device) for blk in blocks]
        # we only predict the nodes with type "category"
        seeds = seeds[category]
        batch_size = seeds.shape[0]

        emb = extract_embed(node_embed, input_nodes)
        # Get the batch's raw "paper" features
        emb.update({"paper": g.ndata["feat"]["paper"][input_nodes["paper"]]})

        emb = {k: e.to(device) for k, e in emb.items()}

        logits = model(emb, blocks)[category]
        y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdim=True)
        y_hats.append(y_hat.cpu())

        pbar.update(batch_size)

    pbar.close()

    y_pred = th.cat(y_hats, dim=0)
    y_true = th.unsqueeze(y_true, 1)

    train_acc = evaluator.eval(
        {
            "y_true": y_true[split_idx["train"]["paper"]],
            "y_pred": y_pred[split_idx["train"]["paper"]],
        }
    )["acc"]
    valid_acc = evaluator.eval(
        {
            "y_true": y_true[split_idx["valid"]["paper"]],
            "y_pred": y_pred[split_idx["valid"]["paper"]],
        }
    )["acc"]
    test_acc = evaluator.eval(
        {
            "y_true": y_true[split_idx["test"]["paper"]],
            "y_pred": y_pred[split_idx["test"]["paper"]],
        }
    )["acc"]

    return train_acc, valid_acc, test_acc


def main(args):
    device = "cuda:0" if th.cuda.is_available() else "cpu"
    g, labels, num_classes, split_idx, logger, train_loader = prepare_data(args)

    embed_layer = rel_graph_embed(g, 128)
    model = EntityClassify(g, 128, num_classes).to(device)

    print(
        f"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}"
    )
    print(
        f"Number of model parameters: {sum(p.numel() for p in model.parameters())}"
    )

    for run in range(args.runs):
        embed_layer.reset_parameters()
        model.reset_parameters()

        # optimizer over both the model weights and the node-type embeddings
        all_params = itertools.chain(
            model.parameters(), embed_layer.parameters()
        )
        optimizer = th.optim.Adam(all_params, lr=0.01)

        logger = train(
            g,
            model,
            embed_layer,
            optimizer,
            train_loader,
            split_idx,
            labels,
            logger,
            device,
            run,
        )
        logger.print_statistics(run)

    print("Final performance: ")
    logger.print_statistics()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    parser.add_argument("--runs", type=int, default=10)
    args = parser.parse_args()
    main(args)
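

# Usage sketch (an assumption, not part of the original source): the script
# filename "hetero_rgcn.py" below is hypothetical. --runs is the only flag
# defined by the argument parser above (default: 10 runs of 3 epochs each).
#
#   python hetero_rgcn.py --runs 2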