import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.function as fn
import dgl.nn as dglnn
from dgl.dataloading import NeighborSampler, DataLoader
from dgl import apply_each
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm


class HeteroGAT(nn.Module):
    def __init__(self, etypes, in_size, hid_size, out_size, n_heads=4):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three GAT layers, each wrapped in a HeteroGraphConv that runs one
        # GATConv per edge type and aggregates the per-relation results on
        # each destination node type.
        self.layers.append(dglnn.HeteroGraphConv({
            etype: dglnn.GATConv(in_size, hid_size // n_heads, n_heads)
            for etype in etypes}))
        self.layers.append(dglnn.HeteroGraphConv({
            etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
            for etype in etypes}))
        self.layers.append(dglnn.HeteroGraphConv({
            etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
            for etype in etypes}))
        self.dropout = nn.Dropout(0.5)
        # Only 'paper' nodes are classified, so a plain nn.Linear suffices
        # here; predicting on every node type would require a HeteroLinear
        # (see the sketch at the end of this file).
        self.linear = nn.Linear(hid_size, out_size)

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            # GATConv outputs (num_dst_nodes, n_heads, out_feats); flatten the
            # head dimension. Note that h may contain tensors with zero rows
            # when a node type has no destination nodes in the block, in which
            # case x.view(x.shape[0], -1) would fail, so the target shape is
            # spelled out explicitly.
            h = apply_each(
                h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2]))
            if l != len(self.layers) - 1:
                h = apply_each(h, F.relu)
                h = apply_each(h, self.dropout)
        return self.linear(h['paper'])


def evaluate(model, dataloader, desc):
    preds = []
    labels = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, desc=desc):
            x = blocks[0].srcdata['feat']
            y = blocks[-1].dstdata['label']['paper'][:, 0]
            y_hat = model(blocks, x)
            preds.append(y_hat.cpu())
            labels.append(y.cpu())
    preds = torch.cat(preds, 0)
    labels = torch.cat(labels, 0)
    # torchmetrics >= 0.11 requires the task to be stated explicitly.
    acc = MF.accuracy(preds, labels, task='multiclass',
                      num_classes=preds.shape[1])
    return acc


def train(train_dataloader, val_dataloader, test_dataloader, model):
    # loss function and optimizer
    loss_fcn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    # training loop
    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(
                tqdm.tqdm(train_dataloader, desc="Train")):
            x = blocks[0].srcdata['feat']
            y = blocks[-1].dstdata['label']['paper'][:, 0]
            y_hat = model(blocks, x)
            loss = loss_fcn(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        model.eval()
        val_acc = evaluate(model, val_dataloader, 'Val. ')
        test_acc = evaluate(model, test_dataloader, 'Test ')
        print(f'Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | '
              f'Validation Acc. {val_acc.item():.4f} | '
              f'Test Acc. {test_acc.item():.4f}')


if __name__ == '__main__':
    print('Training with DGL built-in HeteroGraphConv using GATConv as its '
          'convolution sub-modules')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load and preprocess dataset
    print('Loading data')
    dataset = DglNodePropPredDataset('ogbn-mag')
    graph, labels = dataset[0]
    graph.ndata['label'] = labels
    # Add reverse edges in the "cites" relation, and add reverse edge types
    # (e.g. 'rev_writes') for the remaining relations.
    graph = dgl.AddReverse()(graph)
    # Only 'paper' nodes ship with input features; precompute author, topic,
    # and institution features by mean-pooling their neighbors' features.
    # The order matters: institution features depend on the author features
    # computed in the first step.
    graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'),
                     etype='rev_writes')
    graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'),
                     etype='has_topic')
    graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'),
                     etype='affiliated_with')

    # find train/val/test indices (dicts keyed by node type)
    split_idx = dataset.get_idx_split()
    train_idx, val_idx, test_idx = (
        split_idx['train'], split_idx['valid'], split_idx['test'])
    train_idx = apply_each(train_idx, lambda x: x.to(device))
    val_idx = apply_each(val_idx, lambda x: x.to(device))
    test_idx = apply_each(test_idx, lambda x: x.to(device))

    # create RGAT model
    in_size = graph.ndata['feat']['paper'].shape[1]
    out_size = dataset.num_classes
    model = HeteroGAT(graph.etypes, in_size, 256, out_size).to(device)

    # dataloaders + model training + testing
    train_sampler = NeighborSampler(
        [5, 5, 5],
        prefetch_node_feats={k: ['feat'] for k in graph.ntypes},
        prefetch_labels={'paper': ['label']})
    val_sampler = NeighborSampler(
        [10, 10, 10],
        prefetch_node_feats={k: ['feat'] for k in graph.ntypes},
        prefetch_labels={'paper': ['label']})
    train_dataloader = DataLoader(
        graph, train_idx, train_sampler, device=device, batch_size=1000,
        shuffle=True, drop_last=False, num_workers=0,
        use_uva=torch.cuda.is_available())
    val_dataloader = DataLoader(
        graph, val_idx, val_sampler, device=device, batch_size=1000,
        shuffle=False, drop_last=False, num_workers=0,
        use_uva=torch.cuda.is_available())
    test_dataloader = DataLoader(
        graph, test_idx, val_sampler, device=device, batch_size=1000,
        shuffle=False, drop_last=False, num_workers=0,
        use_uva=torch.cuda.is_available())

    train(train_dataloader, val_dataloader, test_dataloader, model)
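
# A minimal sketch of the HeteroLinear variant that the note on self.linear
# above refers to: dglnn.HeteroLinear applies one linear layer per dictionary
# key, so the model could emit logits for every node type instead of only
# 'paper'. The extra `ntypes` argument is an assumption of this sketch, not
# part of the original constructor, and the sketch is not exercised by the
# script:
#
#     self.linear = dglnn.HeteroLinear(
#         {ntype: hid_size for ntype in ntypes}, out_size)
#     ...
#     return self.linear(h)  # dict: node type -> (num_nodes, out_size) logits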