import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset


class SAGE(nn.Module):
    """Three-layer GraphSAGE encoder with mean aggregation.

    Hidden layers use ReLU + dropout(0.5); the final layer emits raw
    class logits.
    """

    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        # Layer widths: in -> hidden -> hidden -> classes.
        dims = [in_feats, n_hidden, n_hidden, n_classes]
        self.layers = nn.ModuleList(
            dglnn.SAGEConv(src_dim, dst_dim, 'mean')
            for src_dim, dst_dim in zip(dims[:-1], dims[1:])
        )
        self.dropout = nn.Dropout(0.5)

    def forward(self, sg, x):
        """Propagate node features `x` over subgraph `sg` through all layers."""
        h = x
        last = len(self.layers) - 1
        for idx, conv in enumerate(self.layers):
            h = conv(sg, h)
            # Nonlinearity + dropout on every layer except the output one.
            if idx != last:
                h = self.dropout(F.relu(h))
        return h


dataset = dgl.data.AsNodePredDataset(DglNodePropPredDataset('ogbn-products'))
graph = dataset[0]  # already prepares ndata['label'/'train_mask'/'val_mask'/'test_mask']

model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

num_partitions = 1000
sampler = dgl.dataloading.ClusterGCNSampler(
    graph, num_partitions,
    prefetch_ndata=['feat', 'label', 'train_mask', 'val_mask', 'test_mask'])
# DataLoader for generic dataloading with a graph, a set of indices (any indices, like
# partition IDs here), and a graph sampler.
dataloader = dgl.dataloading.DataLoader(
    graph,
    torch.arange(num_partitions).to('cuda'),  # iterate over partition IDs
    sampler,
    device='cuda',
    batch_size=100,
    shuffle=True,
    drop_last=False,
    num_workers=0,
    use_uva=True)

durations = []
for _ in range(10):
    t0 = time.time()

    # --- training pass over all partition batches ---
    model.train()
    for it, sg in enumerate(dataloader):
        feats = sg.ndata['feat']
        labels = sg.ndata['label']
        train_mask = sg.ndata['train_mask'].bool()
        logits = model(sg, feats)
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Periodic progress report: loss, training accuracy, peak GPU memory.
        if it % 20 == 0:
            acc = MF.accuracy(logits[train_mask], labels[train_mask])
            mem = torch.cuda.max_memory_allocated() / 1000000
            print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')

    elapsed = time.time() - t0
    print(elapsed)
    durations.append(elapsed)

    # --- evaluation pass: gather masked predictions across all batches ---
    model.eval()
    with torch.no_grad():
        val_preds, test_preds = [], []
        val_labels, test_labels = [], []
        for sg in dataloader:
            feats = sg.ndata['feat']
            labels = sg.ndata['label']
            val_mask = sg.ndata['val_mask'].bool()
            test_mask = sg.ndata['test_mask'].bool()
            logits = model(sg, feats)
            val_preds.append(logits[val_mask])
            val_labels.append(labels[val_mask])
            test_preds.append(logits[test_mask])
            test_labels.append(labels[test_mask])
        val_acc = MF.accuracy(torch.cat(val_preds, 0), torch.cat(val_labels, 0))
        test_acc = MF.accuracy(torch.cat(test_preds, 0), torch.cat(test_labels, 0))
        print('Validation acc:', val_acc.item(), 'Test acc:', test_acc.item())

# Report mean/std epoch time, skipping the first few warm-up epochs.
print(np.mean(durations[4:]), np.std(durations[4:]))