import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics.functional as MF
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy

import dgl
import dgl.nn.pytorch as dglnn


class SAGE(LightningModule):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.save_hyperparameters()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the official
        # example is that the intermediate results can also benefit from prefetching.
        g.ndata["h"] = g.ndata["feat"]
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
            1, prefetch_node_feats=["h"]
        )
        dataloader = dgl.dataloading.DataLoader(
            g,
            torch.arange(g.num_nodes()).to(g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
            persistent_workers=(num_workers > 0),
        )
        if buffer_device is None:
            buffer_device = device
        for l, layer in enumerate(self.layers):
            y = torch.zeros(
                g.num_nodes(),
                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
                device=buffer_device,
            )
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                y[output_nodes] = h.to(buffer_device)
            g.ndata["h"] = y
        return y

    def training_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
        x = blocks[0].srcdata["feat"]
        y = blocks[-1].dstdata["label"]
        y_hat = self(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
        x = blocks[0].srcdata["feat"]
        y = blocks[-1].dstdata["label"]
        y_hat = self(blocks, x)
        self.val_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "val_acc",
            self.val_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=True,
            sync_dist=True,
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=0.001, weight_decay=5e-4
        )
        return optimizer


class DataModule(LightningDataModule):
    def __init__(
        self, graph, train_idx, val_idx, fanouts, batch_size, n_classes
    ):
        super().__init__()
        sampler = dgl.dataloading.NeighborSampler(
            fanouts, prefetch_node_feats=["feat"], prefetch_labels=["label"]
        )
        self.g = graph
        self.train_idx, self.val_idx = train_idx, val_idx
        self.sampler = sampler
        self.batch_size = batch_size
        self.in_feats = graph.ndata["feat"].shape[1]
        self.n_classes = n_classes

    def train_dataloader(self):
        return dgl.dataloading.DataLoader(
            self.g,
            self.train_idx.to("cuda"),
            self.sampler,
            device="cuda",
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            # For CPU sampling, set num_workers to nonzero and use_uva=False.
            # Set use_ddp to False for single GPU.
            num_workers=0,
            use_uva=True,
            use_ddp=True,
        )

    def val_dataloader(self):
        return dgl.dataloading.DataLoader(
            self.g,
            self.val_idx.to("cuda"),
            self.sampler,
            device="cuda",
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=0,
            use_uva=True,
        )


if __name__ == "__main__":
    dataset = DglNodePropPredDataset("ogbn-products")
    graph, labels = dataset[0]
    graph.ndata["label"] = labels.squeeze()
    graph.create_formats_()
    split_idx = dataset.get_idx_split()
    train_idx, val_idx, test_idx = (
        split_idx["train"],
        split_idx["valid"],
        split_idx["test"],
    )
    datamodule = DataModule(
        graph, train_idx, val_idx, [15, 10, 5], 1024, dataset.num_classes
    )
    model = SAGE(datamodule.in_feats, 256, datamodule.n_classes)

    # Train
    checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1)
    # Use this for single GPU
    # trainer = Trainer(gpus=[0], max_epochs=10, callbacks=[checkpoint_callback])
    trainer = Trainer(
        gpus=[0, 1, 2, 3],
        max_epochs=10,
        callbacks=[checkpoint_callback],
        strategy="ddp_spawn",
    )
    trainer.fit(model, datamodule=datamodule)

    # Test
    dirs = glob.glob("./lightning_logs/*")
    version = max([int(os.path.split(x)[-1].split("_")[-1]) for x in dirs])
    logdir = "./lightning_logs/version_%d" % version
    print("Evaluating model in", logdir)
    ckpt = glob.glob(os.path.join(logdir, "checkpoints", "*"))[0]
    model = SAGE.load_from_checkpoint(
        checkpoint_path=ckpt, hparams_file=os.path.join(logdir, "hparams.yaml")
    ).to("cuda")
    with torch.no_grad():
        pred = model.inference(graph, "cuda", 4096, 12, graph.device)
        pred = pred[test_idx]
        label = graph.ndata["label"][test_idx]
        acc = MF.accuracy(pred, label)
        print("Test accuracy:", acc)
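
# Compatibility note (an assumption about newer library versions, not part of
# the original example): this script targets pre-2.0 PyTorch Lightning and
# pre-0.11 torchmetrics. On torchmetrics >= 0.11 the metrics require a task
# argument, and on Lightning >= 2.0 the Trainer no longer accepts `gpus`.
# A hedged sketch of the equivalent calls:
#
#     self.train_acc = Accuracy(task="multiclass", num_classes=n_classes)
#     acc = MF.accuracy(
#         pred, label, task="multiclass", num_classes=dataset.num_classes
#     )
#     trainer = Trainer(
#         accelerator="gpu",
#         devices=[0, 1, 2, 3],
#         max_epochs=10,
#         callbacks=[checkpoint_callback],
#         strategy="ddp_spawn",
#     )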