Unverified commit f931c6ba, authored by Serge Panev, committed by GitHub
Browse files

[Examples] Add pure gpu mode in the GraphSAGE node classification and link prediction (#3856)


Signed-off-by: Serge Panev <spanev@nvidia.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 5fcd7f29
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -11,6 +12,11 @@ import tqdm
# (This is a long-standing issue)
from ogb.linkproppred import DglLinkPropPredDataset
parser = argparse.ArgumentParser()
parser.add_argument('--pure-gpu', action='store_true',
help='Perform both sampling and training on GPU.')
args = parser.parse_args()
device = 'cuda'
def to_bidirected_with_reverse_mapping(g):
......@@ -69,8 +75,8 @@ class SAGE(nn.Module):
def inference(self, g, device, batch_size, num_workers, buffer_device=None):
# The difference between this inference function and the one in the official
# example is that the intermediate results can also benefit from prefetching.
g.ndata['h'] = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
feat = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers)
......@@ -78,14 +84,17 @@ class SAGE(nn.Module):
buffer_device = device
for l, layer in enumerate(self.layers):
y = torch.zeros(g.num_nodes(), self.n_hidden, device=buffer_device)
y = torch.zeros(g.num_nodes(), self.n_hidden, device=buffer_device,
pin_memory=args.pure_gpu)
feat = feat.to(device)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = blocks[0].srcdata['h']
x = feat[input_nodes]
h = layer(blocks[0], x)
if l != len(self.layers) - 1:
h = F.relu(h)
y[output_nodes] = h.to(buffer_device)
g.ndata['h'] = y
feat = y
return y
......@@ -118,6 +127,7 @@ def evaluate(model, edge_split, device, num_workers):
dataset = DglLinkPropPredDataset('ogbl-citation2')
graph = dataset[0]
graph, reverse_eids = to_bidirected_with_reverse_mapping(graph)
graph = graph.to('cuda' if args.pure_gpu else 'cpu')
reverse_eids = reverse_eids.to(device)
seed_edges = torch.arange(graph.num_edges()).to(device)
edge_split = dataset.get_edge_split()
......@@ -132,7 +142,7 @@ sampler = dgl.dataloading.as_edge_prediction_sampler(
dataloader = dgl.dataloading.DataLoader(
graph, seed_edges, sampler,
device=device, batch_size=512, shuffle=True,
drop_last=False, num_workers=0, use_uva=True)
drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
durations = []
for epoch in range(10):
......@@ -159,6 +169,6 @@ for epoch in range(10):
break
if epoch % 10 == 0:
model.eval()
valid_mrr, test_mrr = evaluate(model, edge_split, device, 12)
valid_mrr, test_mrr = evaluate(model, edge_split, device, 0 if args.pure_gpu else 12)
print('Validation MRR:', valid_mrr.item(), 'Test MRR:', test_mrr.item())
print(np.mean(durations[4:]), np.std(durations[4:]))
......@@ -10,6 +10,11 @@ from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--pure-gpu', action='store_true',
help='Perform both sampling and training on GPU.')
args = parser.parse_args()
class SAGE(nn.Module):
def __init__(self, in_feats, n_hidden, n_classes):
super().__init__()
......@@ -31,29 +36,32 @@ class SAGE(nn.Module):
return h
def inference(self, g, device, batch_size, num_workers, buffer_device=None):
# The difference between this inference function and the one in the official
# example is that the intermediate results can also benefit from prefetching.
g.ndata['h'] = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
feat = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
persistent_workers=(num_workers > 0))
batch_size=batch_size, shuffle=False, drop_last=False,
num_workers=num_workers)
if buffer_device is None:
buffer_device = device
for l, layer in enumerate(self.layers):
y = torch.zeros(
y = torch.empty(
g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
device=buffer_device)
device=buffer_device, pin_memory=True)
feat = feat.to(device)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = blocks[0].srcdata['h']
# use an explicitly contiguous slice
x = feat[input_nodes]
h = layer(blocks[0], x)
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
y[output_nodes] = h.to(buffer_device)
g.ndata['h'] = y
# By design, our output nodes are contiguous, so we can take
# advantage of that here.
y[output_nodes[0]:output_nodes[-1]+1] = h.to(buffer_device)
feat = y
return y
dataset = DglNodePropPredDataset('ogbn-products')
......@@ -65,6 +73,9 @@ train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_i
device = 'cuda'
train_idx = train_idx.to(device)
valid_idx = valid_idx.to(device)
test_idx = test_idx.to(device)
graph = graph.to('cuda' if args.pure_gpu else 'cpu')
model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
......@@ -73,10 +84,10 @@ sampler = dgl.dataloading.NeighborSampler(
[15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
train_dataloader = dgl.dataloading.DataLoader(
graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=True)
drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
valid_dataloader = dgl.dataloading.NodeDataLoader(
graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=True)
drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
durations = []
for _ in range(10):
......@@ -114,8 +125,8 @@ print(np.mean(durations[4:]), np.std(durations[4:]))
# Test accuracy and offline inference of all nodes
model.eval()
with torch.no_grad():
pred = model.inference(graph, device, 4096, 12, graph.device)
pred = pred[test_idx]
pred = model.inference(graph, device, 4096, 0, 'cpu')
pred = pred[test_idx].to(device)
label = graph.ndata['label'][test_idx]
acc = MF.accuracy(pred, label)
print('Test acc:', acc.item())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment