Unverified Commit b76d0ed1 authored by Chang Liu, committed by GitHub

[Example][Refactor] Regolden graphsage example for future guide (#4186)



* Regolden graphsage example to guide others

* update golden

* update

* Update example and propagate to original folder

* Update to remove ^M (windows DOS) character

* update

* Merge file changes and update README

* Minor comment update
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent a6bd96aa
@@ -2,10 +2,9 @@ Inductive Representation Learning on Large Graphs (GraphSAGE)
 ============
 
 - Paper link: [http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf](http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf)
-- Author's code repo: [https://github.com/williamleif/graphsage-simple](https://github.com/williamleif/graphsage-simple). Note that the original code is
-simple reference implementation of GraphSAGE.
+- Author's code repo: [https://github.com/williamleif/graphsage-simple](https://github.com/williamleif/graphsage-simple)
 
-Advanced usages, including how to run pure GPU sampling, how to train with PyTorch Lightning, etc., are in the `advanced` directory.
+For advanced usage, including multi-GPU/multi-node training and PyTorch Lightning, more examples can be found in the [advanced](https://github.com/dmlc/dgl/tree/master/examples/pytorch/graphsage/advanced) and [dist](https://github.com/dmlc/dgl/tree/master/examples/pytorch/graphsage/dist) directories.
 
 Requirements
 ------------
@@ -14,7 +13,7 @@ Requirements
 pip install requests torchmetrics
 ```
 
-Results
+How to run
 -------
 
 ### Full graph training
@@ -24,19 +23,27 @@ Run with following (available dataset: "cora", "citeseer", "pubmed")
 python3 train_full.py --dataset cora --gpu 0    # full graph
 ```
 
+Results:
+```
 * cora: ~0.8330
 * citeseer: ~0.7110
 * pubmed: ~0.7830
+```
 
 ### Minibatch training for node classification
 
-Train w/ mini-batch sampling for node classification on OGB-products:
+Train w/ mini-batch sampling in mixed mode (CPU+GPU) for node classification on "ogbn-products":
 
 ```bash
 python3 node_classification.py
 python3 multi_gpu_node_classification.py
 ```
 
+Results:
+```
+Test Accuracy: 0.7632
+```
+
 ### PyTorch Lightning for node classification
 
 Train w/ mini-batch sampling for node classification with PyTorch Lightning on OGB-products.
......
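The regolden `node_classification.py` (diffed below) selects devices through a `--mode` flag rather than the old `--pure-gpu` switch. A typical set of invocations, assuming a CUDA-capable machine, might look like:

```bash
# mixed mode (default): graph and features stay on CPU, sampling/training run on GPU via UVA
python3 node_classification.py --mode mixed

# pure GPU: the whole graph is first copied to GPU memory
python3 node_classification.py --mode puregpu

# CPU only (also the automatic fallback when CUDA is unavailable)
python3 node_classification.py --mode cpu
```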
@@ -4,27 +4,22 @@ import torch.nn.functional as F
 import torchmetrics.functional as MF
 import dgl
 import dgl.nn as dglnn
-import time
-import numpy as np
+from dgl.data import AsNodePredDataset
+from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler
 from ogb.nodeproppred import DglNodePropPredDataset
 import tqdm
 import argparse
 
-parser = argparse.ArgumentParser()
-parser.add_argument('--pure-gpu', action='store_true',
-                    help='Perform both sampling and training on GPU.')
-args = parser.parse_args()
-
 class SAGE(nn.Module):
-    def __init__(self, in_feats, n_hidden, n_classes):
+    def __init__(self, in_size, hid_size, out_size):
         super().__init__()
         self.layers = nn.ModuleList()
-        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
-        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
-        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
+        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
+        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
+        self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'mean'))
         self.dropout = nn.Dropout(0.5)
-        self.n_hidden = n_hidden
-        self.n_classes = n_classes
+        self.hid_size = hid_size
+        self.out_size = out_size
 
     def forward(self, blocks, x):
         h = x
@@ -35,98 +30,115 @@ class SAGE(nn.Module):
             h = self.dropout(h)
         return h
 
-    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
+    def inference(self, g, device, batch_size):
+        """Conduct layer-wise inference to get all the node embeddings."""
         feat = g.ndata['feat']
-        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
-        dataloader = dgl.dataloading.DataLoader(
+        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
+        dataloader = DataLoader(
             g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
             batch_size=batch_size, shuffle=False, drop_last=False,
-            num_workers=num_workers)
-        if buffer_device is None:
-            buffer_device = device
+            num_workers=0)
+        buffer_device = 'cpu'
+        pin_memory = (buffer_device != device)
 
         for l, layer in enumerate(self.layers):
             y = torch.empty(
-                g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
-                device=buffer_device, pin_memory=True)
+                g.num_nodes(), self.hid_size if l != len(self.layers) - 1 else self.out_size,
+                device=buffer_device, pin_memory=pin_memory)
             feat = feat.to(device)
             for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
-                # use an explicitly contiguous slice
                 x = feat[input_nodes]
-                h = layer(blocks[0], x)
+                h = layer(blocks[0], x)  # len(blocks) = 1
                 if l != len(self.layers) - 1:
                     h = F.relu(h)
                     h = self.dropout(h)
-                # be design, our output nodes are contiguous so we can take
-                # advantage of that here
+                # by design, our output nodes are contiguous
                 y[output_nodes[0]:output_nodes[-1]+1] = h.to(buffer_device)
             feat = y
         return y
 
-dataset = DglNodePropPredDataset('ogbn-products')
-graph, labels = dataset[0]
-graph.ndata['label'] = labels.squeeze()
-split_idx = dataset.get_idx_split()
-train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
-
-device = 'cuda'
-train_idx = train_idx.to(device)
-valid_idx = valid_idx.to(device)
-test_idx = test_idx.to(device)
-graph = graph.to('cuda' if args.pure_gpu else 'cpu')
-
-model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to(device)
-opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
-
-sampler = dgl.dataloading.NeighborSampler(
-    [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
-train_dataloader = dgl.dataloading.DataLoader(
-    graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
-    drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
-valid_dataloader = dgl.dataloading.DataLoader(
-    graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
-    drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
-
-durations = []
-for _ in range(10):
-    model.train()
-    t0 = time.time()
-    for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
-        x = blocks[0].srcdata['feat']
-        y = blocks[-1].dstdata['label']
-        y_hat = model(blocks, x)
-        loss = F.cross_entropy(y_hat, y)
-        opt.zero_grad()
-        loss.backward()
-        opt.step()
-        if it % 20 == 0:
-            acc = MF.accuracy(y_hat, y)
-            mem = torch.cuda.max_memory_allocated() / 1000000
-            print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
-    tt = time.time()
-    print(tt - t0)
-    durations.append(tt - t0)
-
+def evaluate(model, graph, dataloader):
     model.eval()
     ys = []
     y_hats = []
-    for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader):
+    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
         with torch.no_grad():
             x = blocks[0].srcdata['feat']
             ys.append(blocks[-1].dstdata['label'])
            y_hats.append(model(blocks, x))
-    acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys))
-    print('Validation acc:', acc.item())
+    return MF.accuracy(torch.cat(y_hats), torch.cat(ys))
+
+def layerwise_infer(args, device, graph, nid, model, batch_size):
+    model.eval()
+    with torch.no_grad():
+        pred = model.inference(graph, device, batch_size)
+        pred = pred[nid]
+        label = graph.ndata['label'][nid]
+        return MF.accuracy(pred, label)
+
+def train(args, device, g, dataset, model):
+    # create sampler & dataloader
+    train_idx = dataset.train_idx.to(device)
+    val_idx = dataset.val_idx.to(device)
+    sampler = NeighborSampler([10, 10, 10],  # fanout for layer-0, layer-1 and layer-2
+                              prefetch_node_feats=['feat'],
+                              prefetch_labels=['label'])
+    use_uva = (args.mode == 'mixed')
+    train_dataloader = DataLoader(g, train_idx, sampler, device=device,
+                                  batch_size=1024, shuffle=True,
+                                  drop_last=False, num_workers=0,
+                                  use_uva=use_uva)
+    val_dataloader = DataLoader(g, val_idx, sampler, device=device,
+                                batch_size=1024, shuffle=True,
+                                drop_last=False, num_workers=0,
+                                use_uva=use_uva)
+    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
+
+    for epoch in range(10):
+        model.train()
+        total_loss = 0
+        for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
+            x = blocks[0].srcdata['feat']
+            y = blocks[-1].dstdata['label']
+            y_hat = model(blocks, x)
+            loss = F.cross_entropy(y_hat, y)
+            opt.zero_grad()
+            loss.backward()
+            opt.step()
+            total_loss += loss.item()
+        acc = evaluate(model, g, val_dataloader)
+        print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} "
+              .format(epoch, total_loss / (it+1), acc.item()))
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--mode", default='mixed', choices=['cpu', 'mixed', 'puregpu'],
+                        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
+                             "'puregpu' for pure-GPU training.")
+    args = parser.parse_args()
+    if not torch.cuda.is_available():
+        args.mode = 'cpu'
+    print(f'Training in {args.mode} mode.')
+
+    # load and preprocess dataset
+    print('Loading data')
+    dataset = AsNodePredDataset(DglNodePropPredDataset('ogbn-products'))
+    g = dataset[0]
+    g = g.to('cuda' if args.mode == 'puregpu' else 'cpu')
+    device = torch.device('cpu' if args.mode == 'cpu' else 'cuda')
+
+    # create GraphSAGE model
+    in_size = g.ndata['feat'].shape[1]
+    out_size = dataset.num_classes
+    model = SAGE(in_size, 256, out_size).to(device)
 
-print(np.mean(durations[4:]), np.std(durations[4:]))
+    # model training
+    print('Training')
+    train(args, device, g, dataset, model)
 
-# Test accuracy and offline inference of all nodes
-model.eval()
-with torch.no_grad():
-    pred = model.inference(graph, device, 4096, 0, 'cpu')
-    pred = pred[test_idx].to(device)
-    label = graph.ndata['label'][test_idx].to(device)
-    acc = MF.accuracy(pred, label)
-    print('Test acc:', acc.item())
+    # test the model
+    print('Testing')
+    acc = layerwise_infer(args, device, g, dataset.test_idx.to(device), model, batch_size=4096)
+    print("Test Accuracy {:.4f}".format(acc.item()))
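For readers following the training loop above: `blocks[0].srcdata['feat']` and `blocks[-1].dstdata['label']` work because the sampler produces one block (message flow graph) per GNN layer and prefetches the requested features onto it. A minimal, self-contained sketch of that pattern on a toy random graph (hypothetical sizes and feature names, not part of the example):

```python
import torch
import dgl
from dgl.dataloading import DataLoader, NeighborSampler

# toy graph with made-up features/labels, only to illustrate the mini-batch structure
g = dgl.rand_graph(100, 500)
g.ndata['feat'] = torch.randn(100, 8)
g.ndata['label'] = torch.randint(0, 3, (100,))

sampler = NeighborSampler([5, 5, 5],                    # one fanout per layer -> 3 blocks
                          prefetch_node_feats=['feat'],
                          prefetch_labels=['label'])
loader = DataLoader(g, torch.arange(g.num_nodes()), sampler,
                    batch_size=32, shuffle=True)

input_nodes, output_nodes, blocks = next(iter(loader))
# blocks[0] is the outermost layer: its source nodes carry the input features;
# blocks[-1] is the innermost layer: its destination (seed) nodes carry the labels.
print(blocks[0].srcdata['feat'].shape, blocks[-1].dstdata['label'].shape)
```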
@@ -773,8 +773,7 @@ class DataLoader(torch.utils.data.DataLoader):
         else:
             if self.graph.device != indices_device:
                 raise ValueError(
-                    'Expect graph and indices to be on the same device. '
-                    'If you wish to use UVA sampling, please set use_uva=True.')
+                    'Expect graph and indices to be on the same device when use_uva=False. ')
             if self.graph.device.type == 'cuda' and num_workers > 0:
                 raise ValueError('num_workers must be 0 if graph and indices are on CUDA.')
             if self.graph.device.type == 'cpu' and num_workers > 0:
......
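The tightened error message above reflects the device rule the loader enforces: with `use_uva=False`, graph and indices must live on the same device, while UVA sampling keeps the graph on CPU (pinned) and the seed indices on GPU. A small sketch of the allowed mixed-mode combination, assuming a CUDA device and hypothetical toy inputs:

```python
import torch
import dgl
from dgl.dataloading import DataLoader, NeighborSampler

g = dgl.rand_graph(1000, 5000)                 # graph stays on CPU
train_nids = torch.arange(100).to('cuda')      # seed nodes on GPU
sampler = NeighborSampler([10, 10])

# CPU graph + CUDA indices is only valid with use_uva=True;
# with use_uva=False this combination raises the ValueError shown above.
loader = DataLoader(g, train_nids, sampler, device='cuda',
                    batch_size=64, shuffle=True, num_workers=0,
                    use_uva=True)
```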