Unverified Commit f931c6ba authored by Serge Panev's avatar Serge Panev Committed by GitHub
Browse files

[Examples] Add pure gpu mode in the GraphSAGE node classification and link prediction (#3856)


Signed-off-by: default avatarSerge Panev <spanev@nvidia.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 5fcd7f29
import argparse
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -11,6 +12,11 @@ import tqdm ...@@ -11,6 +12,11 @@ import tqdm
# (This is a long-standing issue) # (This is a long-standing issue)
from ogb.linkproppred import DglLinkPropPredDataset from ogb.linkproppred import DglLinkPropPredDataset
parser = argparse.ArgumentParser()
parser.add_argument('--pure-gpu', action='store_true',
help='Perform both sampling and training on GPU.')
args = parser.parse_args()
device = 'cuda' device = 'cuda'
def to_bidirected_with_reverse_mapping(g): def to_bidirected_with_reverse_mapping(g):
...@@ -69,8 +75,8 @@ class SAGE(nn.Module): ...@@ -69,8 +75,8 @@ class SAGE(nn.Module):
def inference(self, g, device, batch_size, num_workers, buffer_device=None): def inference(self, g, device, batch_size, num_workers, buffer_device=None):
# The difference between this inference function and the one in the official # The difference between this inference function and the one in the official
# example is that the intermediate results can also benefit from prefetching. # example is that the intermediate results can also benefit from prefetching.
g.ndata['h'] = g.ndata['feat'] feat = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h']) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device, g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers) batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers)
...@@ -78,14 +84,17 @@ class SAGE(nn.Module): ...@@ -78,14 +84,17 @@ class SAGE(nn.Module):
buffer_device = device buffer_device = device
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
y = torch.zeros(g.num_nodes(), self.n_hidden, device=buffer_device) y = torch.zeros(g.num_nodes(), self.n_hidden, device=buffer_device,
pin_memory=args.pure_gpu)
feat = feat.to(device)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = blocks[0].srcdata['h'] x = feat[input_nodes]
h = layer(blocks[0], x) h = layer(blocks[0], x)
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = F.relu(h) h = F.relu(h)
y[output_nodes] = h.to(buffer_device) y[output_nodes] = h.to(buffer_device)
g.ndata['h'] = y feat = y
return y return y
...@@ -118,6 +127,7 @@ def evaluate(model, edge_split, device, num_workers): ...@@ -118,6 +127,7 @@ def evaluate(model, edge_split, device, num_workers):
dataset = DglLinkPropPredDataset('ogbl-citation2') dataset = DglLinkPropPredDataset('ogbl-citation2')
graph = dataset[0] graph = dataset[0]
graph, reverse_eids = to_bidirected_with_reverse_mapping(graph) graph, reverse_eids = to_bidirected_with_reverse_mapping(graph)
graph = graph.to('cuda' if args.pure_gpu else 'cpu')
reverse_eids = reverse_eids.to(device) reverse_eids = reverse_eids.to(device)
seed_edges = torch.arange(graph.num_edges()).to(device) seed_edges = torch.arange(graph.num_edges()).to(device)
edge_split = dataset.get_edge_split() edge_split = dataset.get_edge_split()
...@@ -132,7 +142,7 @@ sampler = dgl.dataloading.as_edge_prediction_sampler( ...@@ -132,7 +142,7 @@ sampler = dgl.dataloading.as_edge_prediction_sampler(
dataloader = dgl.dataloading.DataLoader( dataloader = dgl.dataloading.DataLoader(
graph, seed_edges, sampler, graph, seed_edges, sampler,
device=device, batch_size=512, shuffle=True, device=device, batch_size=512, shuffle=True,
drop_last=False, num_workers=0, use_uva=True) drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
durations = [] durations = []
for epoch in range(10): for epoch in range(10):
...@@ -159,6 +169,6 @@ for epoch in range(10): ...@@ -159,6 +169,6 @@ for epoch in range(10):
break break
if epoch % 10 == 0: if epoch % 10 == 0:
model.eval() model.eval()
valid_mrr, test_mrr = evaluate(model, edge_split, device, 12) valid_mrr, test_mrr = evaluate(model, edge_split, device, 0 if args.pure_gpu else 12)
print('Validation MRR:', valid_mrr.item(), 'Test MRR:', test_mrr.item()) print('Validation MRR:', valid_mrr.item(), 'Test MRR:', test_mrr.item())
print(np.mean(durations[4:]), np.std(durations[4:])) print(np.mean(durations[4:]), np.std(durations[4:]))
...@@ -10,6 +10,11 @@ from ogb.nodeproppred import DglNodePropPredDataset ...@@ -10,6 +10,11 @@ from ogb.nodeproppred import DglNodePropPredDataset
import tqdm import tqdm
import argparse import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--pure-gpu', action='store_true',
help='Perform both sampling and training on GPU.')
args = parser.parse_args()
class SAGE(nn.Module): class SAGE(nn.Module):
def __init__(self, in_feats, n_hidden, n_classes): def __init__(self, in_feats, n_hidden, n_classes):
super().__init__() super().__init__()
...@@ -31,29 +36,32 @@ class SAGE(nn.Module): ...@@ -31,29 +36,32 @@ class SAGE(nn.Module):
return h return h
def inference(self, g, device, batch_size, num_workers, buffer_device=None): def inference(self, g, device, batch_size, num_workers, buffer_device=None):
# The difference between this inference function and the one in the official feat = g.ndata['feat']
# example is that the intermediate results can also benefit from prefetching. sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
g.ndata['h'] = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device, g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, batch_size=batch_size, shuffle=False, drop_last=False,
persistent_workers=(num_workers > 0)) num_workers=num_workers)
if buffer_device is None: if buffer_device is None:
buffer_device = device buffer_device = device
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
y = torch.zeros( y = torch.empty(
g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes, g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
device=buffer_device) device=buffer_device, pin_memory=True)
feat = feat.to(device)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = blocks[0].srcdata['h'] # use an explicitly contiguous slice
x = feat[input_nodes]
h = layer(blocks[0], x) h = layer(blocks[0], x)
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = F.relu(h) h = F.relu(h)
h = self.dropout(h) h = self.dropout(h)
y[output_nodes] = h.to(buffer_device) # be design, our output nodes are contiguous so we can take
g.ndata['h'] = y # advantage of that here
y[output_nodes[0]:output_nodes[-1]+1] = h.to(buffer_device)
feat = y
return y return y
dataset = DglNodePropPredDataset('ogbn-products') dataset = DglNodePropPredDataset('ogbn-products')
...@@ -65,6 +73,9 @@ train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_i ...@@ -65,6 +73,9 @@ train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_i
device = 'cuda' device = 'cuda'
train_idx = train_idx.to(device) train_idx = train_idx.to(device)
valid_idx = valid_idx.to(device) valid_idx = valid_idx.to(device)
test_idx = test_idx.to(device)
graph = graph.to('cuda' if args.pure_gpu else 'cpu')
model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to(device) model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
...@@ -73,10 +84,10 @@ sampler = dgl.dataloading.NeighborSampler( ...@@ -73,10 +84,10 @@ sampler = dgl.dataloading.NeighborSampler(
[15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label']) [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
train_dataloader = dgl.dataloading.DataLoader( train_dataloader = dgl.dataloading.DataLoader(
graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True, graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=True) drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
valid_dataloader = dgl.dataloading.NodeDataLoader( valid_dataloader = dgl.dataloading.NodeDataLoader(
graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True, graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=True) drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
durations = [] durations = []
for _ in range(10): for _ in range(10):
...@@ -114,8 +125,8 @@ print(np.mean(durations[4:]), np.std(durations[4:])) ...@@ -114,8 +125,8 @@ print(np.mean(durations[4:]), np.std(durations[4:]))
# Test accuracy and offline inference of all nodes # Test accuracy and offline inference of all nodes
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
pred = model.inference(graph, device, 4096, 12, graph.device) pred = model.inference(graph, device, 4096, 0, 'cpu')
pred = pred[test_idx] pred = pred[test_idx].to(device)
label = graph.ndata['label'][test_idx] label = graph.ndata['label'][test_idx]
acc = MF.accuracy(pred, label) acc = MF.accuracy(pred, label)
print('Test acc:', acc.item()) print('Test acc:', acc.item())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment