Unverified Commit 3f138eba authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Bug fixes in new dataloader (#3727)



* fixes

* fix

* more fixes

* update

* oops

* lint?

* temporarily revert - will fix in another PR

* more fixes

* skipping mxnet test

* address comments

* fix DDP

* fix edge dataloader exclusion problems

* stupid bug

* fix

* use_uvm option

* fix

* fixes

* fixes

* fixes

* fixes

* add evaluation for cluster gcn and ddp

* stupid bug again

* fixes

* move sanity checks to only support DGLGraphs

* pytorch lightning compatibility fixes

* remove

* poke

* more fixes

* fix

* fix

* disable test

* docstrings

* why is it getting a memory leak?

* fix

* update

* updates and temporarily disable forkingpickler

* update

* fix?

* fix?

* oops

* oops

* fix

* lint

* huh

* uh

* update

* fix

* made it memory efficient

* refine exclude interface

* fix tutorial

* fix tutorial

* fix graph duplication in CPU dataloader workers

* lint

* lint

* Revert "lint"

This reverts commit 805484dd553695111b5fb37f2125214a6b7276e9.

* Revert "lint"

This reverts commit 0bce411b2b415c2ab770343949404498436dc8b2.

* Revert "fix graph duplication in CPU dataloader workers"

This reverts commit 9e3a8cf34c175d3093c773f6bb023b155f2bd27f.
Co-authored-by: default avatarxiny <xiny@nvidia.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 7b9afbfa
...@@ -8,8 +8,6 @@ import time ...@@ -8,8 +8,6 @@ import time
import numpy as np import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
USE_WRAPPER = True
class SAGE(nn.Module): class SAGE(nn.Module):
def __init__(self, in_feats, n_hidden, n_classes): def __init__(self, in_feats, n_hidden, n_classes):
super().__init__() super().__init__()
...@@ -40,11 +38,6 @@ graph.ndata['test_mask'] = torch.zeros(graph.num_nodes(), dtype=torch.bool).inde ...@@ -40,11 +38,6 @@ graph.ndata['test_mask'] = torch.zeros(graph.num_nodes(), dtype=torch.bool).inde
model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda() model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
if USE_WRAPPER:
import dglnew
graph.create_formats_()
graph = dglnew.graph.wrapper.DGLGraphStorage(graph)
num_partitions = 1000 num_partitions = 1000
sampler = dgl.dataloading.ClusterGCNSampler( sampler = dgl.dataloading.ClusterGCNSampler(
graph, num_partitions, graph, num_partitions,
...@@ -61,14 +54,13 @@ dataloader = dgl.dataloading.DataLoader( ...@@ -61,14 +54,13 @@ dataloader = dgl.dataloading.DataLoader(
batch_size=100, batch_size=100,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
pin_memory=True, num_workers=0,
num_workers=8, use_uva=True)
persistent_workers=True,
use_prefetch_thread=True) # TBD: could probably remove this argument
durations = [] durations = []
for _ in range(10): for _ in range(10):
t0 = time.time() t0 = time.time()
model.train()
for it, sg in enumerate(dataloader): for it, sg in enumerate(dataloader):
x = sg.ndata['feat'] x = sg.ndata['feat']
y = sg.ndata['label'][:, 0] y = sg.ndata['label'][:, 0]
...@@ -85,4 +77,27 @@ for _ in range(10): ...@@ -85,4 +77,27 @@ for _ in range(10):
tt = time.time() tt = time.time()
print(tt - t0) print(tt - t0)
durations.append(tt - t0) durations.append(tt - t0)
model.eval()
with torch.no_grad():
val_preds, test_preds = [], []
val_labels, test_labels = [], []
for it, sg in enumerate(dataloader):
x = sg.ndata['feat']
y = sg.ndata['label'][:, 0]
m_val = sg.ndata['valid_mask']
m_test = sg.ndata['test_mask']
y_hat = model(sg, x)
val_preds.append(y_hat[m_val])
val_labels.append(y[m_val])
test_preds.append(y_hat[m_test])
test_labels.append(y[m_test])
val_preds = torch.cat(val_preds, 0)
val_labels = torch.cat(val_labels, 0)
test_preds = torch.cat(test_preds, 0)
test_labels = torch.cat(test_labels, 0)
val_acc = MF.accuracy(val_preds, val_labels)
test_acc = MF.accuracy(test_preds, test_labels)
print('Validation acc:', val_acc.item(), 'Test acc:', test_acc.item())
print(np.mean(durations[4:]), np.std(durations[4:])) print(np.mean(durations[4:]), np.std(durations[4:]))
...@@ -9,8 +9,9 @@ import dgl.nn as dglnn ...@@ -9,8 +9,9 @@ import dgl.nn as dglnn
import time import time
import numpy as np import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
USE_WRAPPER = False USE_UVA = True
class SAGE(nn.Module): class SAGE(nn.Module):
def __init__(self, in_feats, n_hidden, n_classes): def __init__(self, in_feats, n_hidden, n_classes):
...@@ -20,6 +21,8 @@ class SAGE(nn.Module): ...@@ -20,6 +21,8 @@ class SAGE(nn.Module):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
self.n_hidden = n_hidden
self.n_classes = n_classes
def forward(self, blocks, x): def forward(self, blocks, x):
h = x h = x
...@@ -30,41 +33,66 @@ class SAGE(nn.Module): ...@@ -30,41 +33,66 @@ class SAGE(nn.Module):
h = self.dropout(h) h = self.dropout(h)
return h return h
def inference(self, g, device, batch_size, num_workers, buffer_device=None):
# The difference between this inference function and the one in the official
# example is that the intermediate results can also benefit from prefetching.
g.ndata['h'] = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
dataloader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers,
persistent_workers=(num_workers > 0))
if buffer_device is None:
buffer_device = device
def train(rank, world_size, graph, num_classes, split_idx): for l, layer in enumerate(self.layers):
y = torch.zeros(
g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
device=buffer_device)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = blocks[0].srcdata['h']
h = layer(blocks[0], x)
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
y[output_nodes] = h.to(buffer_device)
g.ndata['h'] = y
return y
def train(rank, world_size, shared_memory_name, features, num_classes, split_idx):
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
dist.init_process_group('nccl', 'tcp://127.0.0.1:12347', world_size=world_size, rank=rank) dist.init_process_group('nccl', 'tcp://127.0.0.1:12347', world_size=world_size, rank=rank)
graph = dgl.hetero_from_shared_memory(shared_memory_name)
feat, labels = features
graph.ndata['feat'] = feat
graph.ndata['label'] = labels
model = SAGE(graph.ndata['feat'].shape[1], 256, num_classes).cuda() model = SAGE(graph.ndata['feat'].shape[1], 256, num_classes).cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank) model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
if USE_WRAPPER:
import dglnew if USE_UVA:
graph = dglnew.graph.wrapper.DGLGraphStorage(graph) train_idx = train_idx.to('cuda')
sampler = dgl.dataloading.NeighborSampler( sampler = dgl.dataloading.NeighborSampler(
[5, 5, 5], output_device='cpu', prefetch_node_feats=['feat'], [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
prefetch_labels=['label']) train_dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.NodeDataLoader( graph, train_idx, sampler,
graph, device='cuda', batch_size=1000, shuffle=True, drop_last=False,
train_idx, num_workers=0, use_ddp=True, use_uva=USE_UVA)
sampler, valid_dataloader = dgl.dataloading.NodeDataLoader(
device='cuda', graph, valid_idx, sampler, device='cuda', batch_size=1024, shuffle=True,
batch_size=1000, drop_last=False, num_workers=0, use_uva=USE_UVA)
shuffle=True,
drop_last=False,
pin_memory=True,
num_workers=4,
persistent_workers=True,
use_ddp=True,
use_prefetch_thread=True) # TBD: could probably remove this argument
durations = [] durations = []
for _ in range(10): for _ in range(10):
model.train()
t0 = time.time() t0 = time.time()
for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
x = blocks[0].srcdata['feat'] x = blocks[0].srcdata['feat']
y = blocks[-1].dstdata['label'][:, 0] y = blocks[-1].dstdata['label'][:, 0]
y_hat = model(blocks, x) y_hat = model(blocks, x)
...@@ -80,27 +108,38 @@ def train(rank, world_size, graph, num_classes, split_idx): ...@@ -80,27 +108,38 @@ def train(rank, world_size, graph, num_classes, split_idx):
if rank == 0: if rank == 0:
print(tt - t0) print(tt - t0)
durations.append(tt - t0) durations.append(tt - t0)
model.eval()
ys = []
y_hats = []
for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader):
with torch.no_grad():
x = blocks[0].srcdata['feat']
ys.append(blocks[-1].dstdata['label'])
y_hats.append(model.module(blocks, x))
acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys))
print('Validation acc:', acc.item())
dist.barrier()
if rank == 0: if rank == 0:
print(np.mean(durations[4:]), np.std(durations[4:])) print(np.mean(durations[4:]), np.std(durations[4:]))
model.eval()
with torch.no_grad():
pred = model.module.inference(graph, 'cuda', 1000, 12, graph.device)
acc = MF.accuracy(pred.to(graph.device), graph.ndata['label'])
print('Test acc:', acc.item())
if __name__ == '__main__': if __name__ == '__main__':
dataset = DglNodePropPredDataset('ogbn-products') dataset = DglNodePropPredDataset('ogbn-products')
graph, labels = dataset[0] graph, labels = dataset[0]
graph.ndata['label'] = labels shared_memory_name = 'shm' # can be any string
graph.create_formats_() feat = graph.ndata['feat']
graph = graph.shared_memory(shared_memory_name)
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
num_classes = dataset.num_classes num_classes = dataset.num_classes
n_procs = 4 n_procs = 4
# Tested with mp.spawn and fork. Both worked and got 4s per epoch with 4 GPUs # Tested with mp.spawn and fork. Both worked and got 4s per epoch with 4 GPUs
# and 3.86s per epoch with 8 GPUs on p2.8x, compared to 5.2s from official examples. # and 3.86s per epoch with 8 GPUs on p2.8x, compared to 5.2s from official examples.
#import torch.multiprocessing as mp import torch.multiprocessing as mp
#mp.spawn(train, args=(n_procs, graph, num_classes, split_idx), nprocs=n_procs) mp.spawn(train, args=(n_procs, shared_memory_name, (feat, labels), num_classes, split_idx), nprocs=n_procs)
import dgl.multiprocessing as mp
procs = []
for i in range(n_procs):
p = mp.Process(target=train, args=(i, n_procs, graph, num_classes, split_idx))
p.start()
procs.append(p)
for p in procs:
p.join()
...@@ -6,20 +6,54 @@ import dgl ...@@ -6,20 +6,54 @@ import dgl
import dgl.nn as dglnn import dgl.nn as dglnn
import time import time
import numpy as np import numpy as np
import tqdm
# OGB must follow DGL if both DGL and PyG are installed. Otherwise DataLoader will hang. # OGB must follow DGL if both DGL and PyG are installed. Otherwise DataLoader will hang.
# (This is a long-standing issue) # (This is a long-standing issue)
from ogb.nodeproppred import DglNodePropPredDataset from ogb.linkproppred import DglLinkPropPredDataset
USE_WRAPPER = True USE_UVA = False
device = 'cuda'
def to_bidirected_with_reverse_mapping(g):
"""Makes a graph bidirectional, and returns a mapping array ``mapping`` where ``mapping[i]``
is the reverse edge of edge ID ``i``.
Does not work with graphs that have self-loops.
"""
g_simple, mapping = dgl.to_simple(
dgl.add_reverse_edges(g), return_counts='count', writeback_mapping=True)
c = g_simple.edata['count']
num_edges = g.num_edges()
mapping_offset = torch.zeros(g_simple.num_edges() + 1, dtype=g_simple.idtype)
mapping_offset[1:] = c.cumsum(0)
idx = mapping.argsort()
idx_uniq = idx[mapping_offset[:-1]]
reverse_idx = torch.where(idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges)
reverse_mapping = mapping[reverse_idx]
# Correctness check
src1, dst1 = g_simple.edges()
src2, dst2 = g_simple.find_edges(reverse_mapping)
assert torch.equal(src1, dst2)
assert torch.equal(src2, dst1)
return g_simple, reverse_mapping
class SAGE(nn.Module): class SAGE(nn.Module):
def __init__(self, in_feats, n_hidden, n_classes): def __init__(self, in_feats, n_hidden):
super().__init__() super().__init__()
self.n_hidden = n_hidden
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
self.dropout = nn.Dropout(0.5) self.predictor = nn.Sequential(
nn.Linear(n_hidden, n_hidden),
nn.ReLU(),
nn.Linear(n_hidden, n_hidden),
nn.ReLU(),
nn.Linear(n_hidden, 1))
def predict(self, h_src, h_dst):
return self.predictor(h_src * h_dst)
def forward(self, pair_graph, neg_pair_graph, blocks, x): def forward(self, pair_graph, neg_pair_graph, blocks, x):
h = x h = x
...@@ -27,50 +61,88 @@ class SAGE(nn.Module): ...@@ -27,50 +61,88 @@ class SAGE(nn.Module):
h = layer(block, h) h = layer(block, h)
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = F.relu(h) h = F.relu(h)
h = self.dropout(h) pos_src, pos_dst = pair_graph.edges()
with pair_graph.local_scope(), neg_pair_graph.local_scope(): neg_src, neg_dst = neg_pair_graph.edges()
pair_graph.ndata['h'] = neg_pair_graph.ndata['h'] = h h_pos = self.predict(h[pos_src], h[pos_dst])
pair_graph.apply_edges(dgl.function.u_dot_v('h', 'h', 's')) h_neg = self.predict(h[neg_src], h[neg_dst])
neg_pair_graph.apply_edges(dgl.function.u_dot_v('h', 'h', 's')) return h_pos, h_neg
return pair_graph.edata['s'], neg_pair_graph.edata['s']
def inference(self, g, device, batch_size, num_workers, buffer_device=None):
dataset = DglNodePropPredDataset('ogbn-products') # The difference between this inference function and the one in the official
graph, labels = dataset[0] # example is that the intermediate results can also benefit from prefetching.
graph.ndata['label'] = labels g.ndata['h'] = g.ndata['feat']
split_idx = dataset.get_idx_split() sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] dataloader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda() batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers)
if buffer_device is None:
buffer_device = device
for l, layer in enumerate(self.layers):
y = torch.zeros(g.num_nodes(), self.n_hidden, device=buffer_device)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = blocks[0].srcdata['h']
h = layer(blocks[0], x)
if l != len(self.layers) - 1:
h = F.relu(h)
y[output_nodes] = h.to(buffer_device)
g.ndata['h'] = y
return y
def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500):
rr = torch.zeros(src.shape[0])
for start in tqdm.trange(0, src.shape[0], batch_size):
end = min(start + batch_size, src.shape[0])
all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
h_src = node_emb[src[start:end]][:, None, :].to(device)
h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
pred = model.predict(h_src, h_dst).squeeze(-1)
relevance = torch.zeros(*pred.shape, dtype=torch.bool)
relevance[:, 0] = True
rr[start:end] = MF.retrieval_reciprocal_rank(pred, relevance)
return rr.mean()
def evaluate(model, edge_split, device, num_workers):
with torch.no_grad():
node_emb = model.inference(graph, device, 4096, num_workers, 'cpu')
results = []
for split in ['valid', 'test']:
src = edge_split[split]['source_node'].to(device)
dst = edge_split[split]['target_node'].to(device)
neg_dst = edge_split[split]['target_node_neg'].to(device)
results.append(compute_mrr(model, node_emb, src, dst, neg_dst, device))
return results
dataset = DglLinkPropPredDataset('ogbl-citation2')
graph = dataset[0]
graph, reverse_eids = to_bidirected_with_reverse_mapping(graph)
seed_edges = torch.arange(graph.num_edges())
edge_split = dataset.get_edge_split()
model = SAGE(graph.ndata['feat'].shape[1], 256).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
num_edges = graph.num_edges() if not USE_UVA:
train_eids = torch.arange(num_edges) graph = graph.to(device)
if USE_WRAPPER: reverse_eids = reverse_eids.to(device)
import dglnew seed_edges = torch.arange(graph.num_edges()).to(device)
graph.create_formats_()
graph = dglnew.graph.wrapper.DGLGraphStorage(graph)
sampler = dgl.dataloading.NeighborSampler( sampler = dgl.dataloading.NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])
[5, 5, 5], output_device='cpu', prefetch_node_feats=['feat'],
prefetch_labels=['label'])
dataloader = dgl.dataloading.EdgeDataLoader( dataloader = dgl.dataloading.EdgeDataLoader(
graph, graph, seed_edges, sampler,
train_eids, device=device, batch_size=512, shuffle=True,
sampler, drop_last=False, num_workers=0,
device='cuda',
batch_size=1000,
shuffle=True,
drop_last=False,
pin_memory=True,
num_workers=8,
persistent_workers=True,
use_prefetch_thread=True, # TBD: could probably remove this argument
exclude='reverse_id', exclude='reverse_id',
reverse_eids=torch.arange(num_edges) ^ 1, reverse_eids=reverse_eids,
negative_sampler=dgl.dataloading.negative_sampler.Uniform(5)) negative_sampler=dgl.dataloading.negative_sampler.Uniform(1),
use_uva=USE_UVA)
durations = [] durations = []
for _ in range(10): for epoch in range(10):
model.train()
t0 = time.time() t0 = time.time()
for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader): for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader):
x = blocks[0].srcdata['feat'] x = blocks[0].srcdata['feat']
...@@ -83,12 +155,16 @@ for _ in range(10): ...@@ -83,12 +155,16 @@ for _ in range(10):
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
opt.step() opt.step()
if it % 20 == 0: if (it + 1) % 20 == 0:
acc = MF.auroc(score, labels.long())
mem = torch.cuda.max_memory_allocated() / 1000000 mem = torch.cuda.max_memory_allocated() / 1000000
print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB') print('Loss', loss.item(), 'GPU Mem', mem, 'MB')
tt = time.time() if (it + 1) == 1000:
print(tt - t0) tt = time.time()
t0 = time.time() print(tt - t0)
durations.append(tt - t0) durations.append(tt - t0)
break
if epoch % 10 == 0:
model.eval()
valid_mrr, test_mrr = evaluate(model, edge_split, device, 12)
print('Validation MRR:', valid_mrr.item(), 'Test MRR:', test_mrr.item())
print(np.mean(durations[4:]), np.std(durations[4:])) print(np.mean(durations[4:]), np.std(durations[4:]))
...@@ -7,8 +7,10 @@ import dgl.nn as dglnn ...@@ -7,8 +7,10 @@ import dgl.nn as dglnn
import time import time
import numpy as np import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
import argparse
USE_WRAPPER = True USE_UVA = True # Set to True for UVA sampling
class SAGE(nn.Module): class SAGE(nn.Module):
def __init__(self, in_feats, n_hidden, n_classes): def __init__(self, in_feats, n_hidden, n_classes):
...@@ -18,6 +20,8 @@ class SAGE(nn.Module): ...@@ -18,6 +20,8 @@ class SAGE(nn.Module):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
self.n_hidden = n_hidden
self.n_classes = n_classes
def forward(self, blocks, x): def forward(self, blocks, x):
h = x h = x
...@@ -28,42 +32,64 @@ class SAGE(nn.Module): ...@@ -28,42 +32,64 @@ class SAGE(nn.Module):
h = self.dropout(h) h = self.dropout(h)
return h return h
def inference(self, g, device, batch_size, num_workers, buffer_device=None):
# The difference between this inference function and the one in the official
# example is that the intermediate results can also benefit from prefetching.
g.ndata['h'] = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
dataloader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
persistent_workers=(num_workers > 0))
if buffer_device is None:
buffer_device = device
for l, layer in enumerate(self.layers):
y = torch.zeros(
g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
device=buffer_device)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = blocks[0].srcdata['h']
h = layer(blocks[0], x)
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
y[output_nodes] = h.to(buffer_device)
g.ndata['h'] = y
return y
dataset = DglNodePropPredDataset('ogbn-products') dataset = DglNodePropPredDataset('ogbn-products')
graph, labels = dataset[0] graph, labels = dataset[0]
graph.ndata['label'] = labels graph.ndata['label'] = labels.squeeze()
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda() if not USE_UVA:
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) graph = graph.to('cuda')
train_idx = train_idx.to('cuda')
valid_idx = valid_idx.to('cuda')
test_idx = test_idx.to('cuda')
device = 'cuda'
if USE_WRAPPER: model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to(device)
import dglnew opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
graph.create_formats_()
graph = dglnew.graph.wrapper.DGLGraphStorage(graph)
sampler = dgl.dataloading.NeighborSampler( sampler = dgl.dataloading.NeighborSampler(
[5, 5, 5], output_device='cpu', prefetch_node_feats=['feat'], [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
prefetch_labels=['label']) train_dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.NodeDataLoader( graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
graph, drop_last=False, num_workers=0, use_uva=USE_UVA)
train_idx, valid_dataloader = dgl.dataloading.NodeDataLoader(
sampler, graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
device='cuda', drop_last=False, num_workers=0, use_uva=USE_UVA)
batch_size=1000,
shuffle=True,
drop_last=False,
pin_memory=True,
num_workers=16,
persistent_workers=True,
use_prefetch_thread=True) # TBD: could probably remove this argument
durations = [] durations = []
for _ in range(10): for _ in range(10):
model.train()
t0 = time.time() t0 = time.time()
for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
x = blocks[0].srcdata['feat'] x = blocks[0].srcdata['feat']
y = blocks[-1].dstdata['label'][:, 0] y = blocks[-1].dstdata['label']
y_hat = model(blocks, x) y_hat = model(blocks, x)
loss = F.cross_entropy(y_hat, y) loss = F.cross_entropy(y_hat, y)
opt.zero_grad() opt.zero_grad()
...@@ -76,4 +102,23 @@ for _ in range(10): ...@@ -76,4 +102,23 @@ for _ in range(10):
tt = time.time() tt = time.time()
print(tt - t0) print(tt - t0)
durations.append(tt - t0) durations.append(tt - t0)
model.eval()
ys = []
y_hats = []
for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader):
with torch.no_grad():
x = blocks[0].srcdata['feat']
ys.append(blocks[-1].dstdata['label'])
y_hats.append(model(blocks, x))
acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys))
print('Validation acc:', acc.item())
print(np.mean(durations[4:]), np.std(durations[4:])) print(np.mean(durations[4:]), np.std(durations[4:]))
# Test accuracy and offline inference of all nodes
model.eval()
with torch.no_grad():
pred = model.inference(graph, device, 4096, 12 if USE_UVA else 0, graph.device)
acc = MF.accuracy(pred.to(graph.device), graph.ndata['label'])
print('Test acc:', acc.item())
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
import tqdm import tqdm
USE_WRAPPER = True USE_UVA = False
class HeteroGAT(nn.Module): class HeteroGAT(nn.Module):
def __init__(self, etypes, in_feats, n_hidden, n_classes, n_heads=4): def __init__(self, etypes, in_feats, n_hidden, n_classes, n_heads=4):
...@@ -52,44 +52,62 @@ graph = dgl.AddReverse()(graph) ...@@ -52,44 +52,62 @@ graph = dgl.AddReverse()(graph)
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='rev_writes') graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='rev_writes')
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='has_topic') graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='has_topic')
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='affiliated_with') graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='affiliated_with')
graph.edges['cites'].data['weight'] = torch.ones(graph.num_edges('cites')) # dummy edge weights
model = HeteroGAT(graph.etypes, graph.ndata['feat']['paper'].shape[1], 256, dataset.num_classes).cuda() model = HeteroGAT(graph.etypes, graph.ndata['feat']['paper'].shape[1], 256, dataset.num_classes).cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
if USE_WRAPPER:
import dglnew
graph.create_formats_()
graph = dglnew.graph.wrapper.DGLGraphStorage(graph)
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
sampler = dgl.dataloading.NeighborSampler( if not USE_UVA:
[5, 5, 5], output_device='cpu', graph = graph.to('cuda')
train_idx = recursive_apply(train_idx, lambda x: x.to('cuda'))
valid_idx = recursive_apply(valid_idx, lambda x: x.to('cuda'))
test_idx = recursive_apply(test_idx, lambda x: x.to('cuda'))
train_sampler = dgl.dataloading.NeighborSampler(
[5, 5, 5],
prefetch_node_feats={k: ['feat'] for k in graph.ntypes},
prefetch_labels={'paper': ['label']})
valid_sampler = dgl.dataloading.NeighborSampler(
[10, 10, 10], # Slightly more
prefetch_node_feats={k: ['feat'] for k in graph.ntypes}, prefetch_node_feats={k: ['feat'] for k in graph.ntypes},
prefetch_labels={'paper': ['label']}, prefetch_labels={'paper': ['label']})
prefetch_edge_feats={'cites': ['weight']}) train_dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.NodeDataLoader( graph, train_idx, train_sampler,
graph, device='cuda', batch_size=1000, shuffle=True,
train_idx, drop_last=False, num_workers=0, use_uva=USE_UVA)
sampler, valid_dataloader = dgl.dataloading.NodeDataLoader(
device='cuda', graph, valid_idx, valid_sampler,
batch_size=1000, device='cuda', batch_size=1000, shuffle=False,
shuffle=True, drop_last=False, num_workers=0, use_uva=USE_UVA)
drop_last=False, test_dataloader = dgl.dataloading.NodeDataLoader(
pin_memory=True, graph, test_idx, valid_sampler,
num_workers=8, device='cuda', batch_size=1000, shuffle=False,
persistent_workers=True, drop_last=False, num_workers=0, use_uva=USE_UVA)
use_prefetch_thread=True) # TBD: could probably remove this argument
def evaluate(model, dataloader):
preds = []
labels = []
with torch.no_grad():
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = blocks[0].srcdata['feat']
y = blocks[-1].dstdata['label']['paper'][:, 0]
y_hat = model(blocks, x)
preds.append(y_hat)
labels.append(y)
preds = torch.cat(preds, 0)
labels = torch.cat(labels, 0)
acc = MF.accuracy(preds, labels)
return acc
durations = [] durations = []
for _ in range(10): for _ in range(10):
model.train()
t0 = time.time() t0 = time.time()
for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
x = blocks[0].srcdata['feat'] x = blocks[0].srcdata['feat']
y = blocks[-1].dstdata['label']['paper'][:, 0] y = blocks[-1].dstdata['label']['paper'][:, 0]
assert y.min() >= 0 and y.max() < dataset.num_classes
y_hat = model(blocks, x) y_hat = model(blocks, x)
loss = F.cross_entropy(y_hat, y) loss = F.cross_entropy(y_hat, y)
opt.zero_grad() opt.zero_grad()
...@@ -102,4 +120,9 @@ for _ in range(10): ...@@ -102,4 +120,9 @@ for _ in range(10):
tt = time.time() tt = time.time()
print(tt - t0) print(tt - t0)
durations.append(tt - t0) durations.append(tt - t0)
model.eval()
valid_acc = evaluate(model, valid_dataloader)
test_acc = evaluate(model, test_dataloader)
print('Validation acc:', valid_acc, 'Test acc:', test_acc)
print(np.mean(durations[4:]), np.std(durations[4:])) print(np.mean(durations[4:]), np.std(durations[4:]))
...@@ -134,7 +134,7 @@ struct COOMatrix { ...@@ -134,7 +134,7 @@ struct COOMatrix {
* \brief Pin the row, col and data (if not Null) of the matrix. * \brief Pin the row, col and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context, * \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned; * kDLCPU: will be pinned;
* kDLCPUPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDLGPU: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
...@@ -149,7 +149,7 @@ struct COOMatrix { ...@@ -149,7 +149,7 @@ struct COOMatrix {
/*! /*!
* \brief Unpin the row, col and data (if not Null) of the matrix. * \brief Unpin the row, col and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context, * \note This is an in-place method. Behavior depends on the current context,
* kDLCPUPinned: will be unpinned; * IsPinned: will be unpinned;
* others: directly return. * others: directly return.
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
......
...@@ -127,7 +127,7 @@ struct CSRMatrix { ...@@ -127,7 +127,7 @@ struct CSRMatrix {
* \brief Pin the indptr, indices and data (if not Null) of the matrix. * \brief Pin the indptr, indices and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context, * \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned; * kDLCPU: will be pinned;
* kDLCPUPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDLGPU: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
...@@ -142,7 +142,7 @@ struct CSRMatrix { ...@@ -142,7 +142,7 @@ struct CSRMatrix {
/*! /*!
* \brief Unpin the indptr, indices and data (if not Null) of the matrix. * \brief Unpin the indptr, indices and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context, * \note This is an in-place method. Behavior depends on the current context,
* kDLCPUPinned: will be unpinned; * IsPinned: will be unpinned;
* others: directly return. * others: directly return.
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
......
...@@ -43,7 +43,7 @@ ...@@ -43,7 +43,7 @@
*/ */
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \ #define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \
if ((val) == kDLCPU || (val) == kDLCPUPinned) { \ if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \ constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if ((val) == kDLGPU) { \ } else if ((val) == kDLGPU) { \
...@@ -233,7 +233,7 @@ ...@@ -233,7 +233,7 @@
}); });
#define CHECK_VALID_CONTEXT(VAR1, VAR2) \ #define CHECK_VALID_CONTEXT(VAR1, VAR2) \
CHECK(((VAR1)->ctx == (VAR2)->ctx) || ((VAR1)->ctx.device_type == kDLCPUPinned)) \ CHECK(((VAR1)->ctx == (VAR2)->ctx) || (VAR1).IsPinned()) \
<< "Expected " << (#VAR2) << "(" << (VAR2)->ctx << ")" << " to have the same device " \ << "Expected " << (#VAR2) << "(" << (VAR2)->ctx << ")" << " to have the same device " \
<< "context as " << (#VAR1) << "(" << (VAR1)->ctx << "). " \ << "context as " << (#VAR1) << "(" << (VAR1)->ctx << "). " \
<< "Or " << (#VAR1) << "(" << (VAR1)->ctx << ")" << " is pinned"; << "Or " << (#VAR1) << "(" << (VAR1)->ctx << ")" << " is pinned";
...@@ -246,7 +246,7 @@ ...@@ -246,7 +246,7 @@
* If csr is pinned, array's context will conduct the actual operation. * If csr is pinned, array's context will conduct the actual operation.
*/ */
#define ATEN_CSR_SWITCH_CUDA_UVA(csr, array, XPU, IdType, op, ...) do { \ #define ATEN_CSR_SWITCH_CUDA_UVA(csr, array, XPU, IdType, op, ...) do { \
CHECK_VALID_CONTEXT(csr.indices, array); \ CHECK_VALID_CONTEXT(csr.indices, array); \
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, op, { \ ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \ ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \
{__VA_ARGS__} \ {__VA_ARGS__} \
...@@ -264,7 +264,7 @@ ...@@ -264,7 +264,7 @@
}); });
// Macro to dispatch according to device context and index type. // Macro to dispatch according to device context and index type.
#define ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, op, ...) \ #define ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH_CUDA((coo).row->ctx.device_type, XPU, op, { \ ATEN_XPU_SWITCH_CUDA((coo).row->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \ ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \
{__VA_ARGS__} \ {__VA_ARGS__} \
......
...@@ -23,6 +23,7 @@ namespace dgl { ...@@ -23,6 +23,7 @@ namespace dgl {
// Forward declaration // Forward declaration
class BaseHeteroGraph; class BaseHeteroGraph;
class HeteroPickleStates;
typedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr; typedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr;
struct FlattenedHeteroGraph; struct FlattenedHeteroGraph;
...@@ -436,6 +437,21 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -436,6 +437,21 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0; virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0;
/*!
* \brief Set the COO matrix representation for a given edge type.
*/
virtual void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) = 0;
/*!
* \brief Set the CSR matrix representation for a given edge type.
*/
virtual void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) = 0;
/*!
* \brief Set the CSC matrix representation for a given edge type.
*/
virtual void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) = 0;
/*! /*!
* \brief Extract the induced subgraph by the given vertices. * \brief Extract the induced subgraph by the given vertices.
* *
...@@ -864,6 +880,25 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph); ...@@ -864,6 +880,25 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph);
*/ */
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states); HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);
/*!
* \brief Create heterograph from pickling states pickled by ForkingPickler.
*
* This is different from HeteroUnpickle where
* (1) Backward compatibility is not required,
* (2) All graph formats are pickled instead of only one.
*/
HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);
/*!
* \brief Get the pickling states of the relation graph structure in backend tensors for
* ForkingPickler.
*
* This is different from HeteroPickle where
* (1) Backward compatibility is not required,
* (2) All graph formats are pickled instead of only one.
*/
HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph);
#define FORMAT_HAS_CSC(format) \ #define FORMAT_HAS_CSC(format) \
((format) & CSC_CODE) ((format) & CSC_CODE)
......
...@@ -160,6 +160,13 @@ class DeviceAPI { ...@@ -160,6 +160,13 @@ class DeviceAPI {
*/ */
DGL_DLL virtual void UnpinData(void* ptr); DGL_DLL virtual void UnpinData(void* ptr);
/*!
* \brief Check whether the memory is in pinned memory.
*/
DGL_DLL virtual bool IsPinned(const void* ptr) {
return false;
}
/*! /*!
* \brief Allocate temporal workspace for backend execution. * \brief Allocate temporal workspace for backend execution.
* *
......
...@@ -176,7 +176,7 @@ class NDArray { ...@@ -176,7 +176,7 @@ class NDArray {
* on the underlying DLTensor. * on the underlying DLTensor.
* \note This is an in-place method. Behavior depends on the current context, * \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned; * kDLCPU: will be pinned;
* kDLCPUPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDLGPU: invalid, will throw an error.
*/ */
inline void PinMemory_(); inline void PinMemory_();
...@@ -184,7 +184,7 @@ class NDArray { ...@@ -184,7 +184,7 @@ class NDArray {
* \brief In-place method to unpin the current array by calling UnpinData * \brief In-place method to unpin the current array by calling UnpinData
* on the underlying DLTensor. * on the underlying DLTensor.
* \note This is an in-place method. Behavior depends on the current context, * \note This is an in-place method. Behavior depends on the current context,
* kDLCPUPinned: will be unpinned; * IsPinned: will be unpinned;
* others: directly return. * others: directly return.
*/ */
inline void UnpinMemory_(); inline void UnpinMemory_();
...@@ -299,7 +299,7 @@ class NDArray { ...@@ -299,7 +299,7 @@ class NDArray {
* \note Data of the given array will be pinned inplace. * \note Data of the given array will be pinned inplace.
* Behavior depends on the current context, * Behavior depends on the current context,
* kDLCPU: will be pinned; * kDLCPU: will be pinned;
* kDLCPUPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDLGPU: invalid, will throw an error.
*/ */
DGL_DLL static void PinData(DLTensor* tensor); DGL_DLL static void PinData(DLTensor* tensor);
...@@ -309,11 +309,18 @@ class NDArray { ...@@ -309,11 +309,18 @@ class NDArray {
* \param tensor The array to be unpinned. * \param tensor The array to be unpinned.
* \note Data of the given array will be unpinned inplace. * \note Data of the given array will be unpinned inplace.
* Behavior depends on the current context, * Behavior depends on the current context,
* kDLCPUPinned: will be unpinned; * IsPinned: will be unpinned;
* others: directly return. * others: directly return.
*/ */
DGL_DLL static void UnpinData(DLTensor* tensor); DGL_DLL static void UnpinData(DLTensor* tensor);
/*!
* \brief Function check if the data of a DLTensor is pinned.
* \param tensor The array to be checked.
* \return true if pinned.
*/
DGL_DLL static bool IsDataPinned(DLTensor* tensor);
// internal namespace // internal namespace
struct Internal; struct Internal;
private: private:
...@@ -485,7 +492,7 @@ inline void NDArray::UnpinMemory_() { ...@@ -485,7 +492,7 @@ inline void NDArray::UnpinMemory_() {
inline bool NDArray::IsPinned() const { inline bool NDArray::IsPinned() const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
return data_->dl_tensor.ctx.device_type == kDLCPUPinned; return IsDataPinned(&(data_->dl_tensor));
} }
inline int NDArray::use_count() const { inline int NDArray::use_count() const {
......
...@@ -82,8 +82,6 @@ class NDArrayBase(object): ...@@ -82,8 +82,6 @@ class NDArrayBase(object):
Indicates the alignment requirement when converting to dlpack. Will copy to a Indicates the alignment requirement when converting to dlpack. Will copy to a
new tensor if the alignment requirement is not satisfied. new tensor if the alignment requirement is not satisfied.
0 means no alignment requirement. 0 means no alignment requirement.
Will copy to a new tensor if the array is pinned because some backends,
e.g., pytorch, do not support kDLCPUPinned device type.
Returns Returns
......
...@@ -316,25 +316,15 @@ class NDArrayBase(_NDArrayBase): ...@@ -316,25 +316,15 @@ class NDArrayBase(_NDArrayBase):
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
return target return target
def pin_memory_(self, ctx): def pin_memory_(self):
"""Pin host memory and map into GPU address space (in-place) """Pin host memory and map into GPU address space (in-place)
Parameters
----------
ctx : DGLContext
The target GPU to map the host memory space
""" """
check_call(_LIB.DGLArrayPinData(self.handle, ctx)) check_call(_LIB.DGLArrayPinData(self.handle))
def unpin_memory_(self, ctx): def unpin_memory_(self):
"""Unpin host memory pinned by pin_memory_() """Unpin host memory pinned by pin_memory_()
Parameters
----------
ctx : DGLContext
The target GPU to map the host memory space
""" """
check_call(_LIB.DGLArrayUnpinData(self.handle, ctx)) check_call(_LIB.DGLArrayUnpinData(self.handle))
def free_extension_handle(handle, type_code): def free_extension_handle(handle, type_code):
......
...@@ -330,6 +330,21 @@ def copy_to(input, ctx, **kwargs): ...@@ -330,6 +330,21 @@ def copy_to(input, ctx, **kwargs):
""" """
pass pass
def is_pinned(input):
"""Check whether the tensor is in pinned memory.
Parameters
----------
input : Tensor
The tensor.
Returns
-------
bool
Whether the tensor is in pinned memory.
"""
pass
############################################################################### ###############################################################################
# Tensor functions on feature data # Tensor functions on feature data
# -------------------------------- # --------------------------------
......
...@@ -144,6 +144,9 @@ def asnumpy(input): ...@@ -144,6 +144,9 @@ def asnumpy(input):
def copy_to(input, ctx, **kwargs): def copy_to(input, ctx, **kwargs):
return input.as_in_context(ctx) return input.as_in_context(ctx)
def is_pinned(input):
return input.context == mx.cpu_pinned()
def sum(input, dim, keepdims=False): def sum(input, dim, keepdims=False):
if len(input) == 0: if len(input) == 0:
return nd.array([0.], dtype=input.dtype, ctx=input.context) return nd.array([0.], dtype=input.dtype, ctx=input.context)
......
...@@ -120,6 +120,9 @@ def copy_to(input, ctx, **kwargs): ...@@ -120,6 +120,9 @@ def copy_to(input, ctx, **kwargs):
else: else:
raise RuntimeError('Invalid context', ctx) raise RuntimeError('Invalid context', ctx)
def is_pinned(input):
return input.is_pinned()
def sum(input, dim, keepdims=False): def sum(input, dim, keepdims=False):
return th.sum(input, dim=dim, keepdim=keepdims) return th.sum(input, dim=dim, keepdim=keepdims)
......
...@@ -162,6 +162,8 @@ def copy_to(input, ctx, **kwargs): ...@@ -162,6 +162,8 @@ def copy_to(input, ctx, **kwargs):
new_tensor = tf.identity(input) new_tensor = tf.identity(input)
return new_tensor return new_tensor
def is_pinned(input):
return False # not sure how to do this
def sum(input, dim, keepdims=False): def sum(input, dim, keepdims=False):
if input.dtype == tf.bool: if input.dtype == tf.bool:
......
...@@ -75,7 +75,7 @@ class UnifiedTensor: #UnifiedTensor ...@@ -75,7 +75,7 @@ class UnifiedTensor: #UnifiedTensor
self._array = F.zerocopy_to_dgl_ndarray(self._input) self._array = F.zerocopy_to_dgl_ndarray(self._input)
self._device = device self._device = device
self._array.pin_memory_(utils.to_dgl_context(self._device)) self._array.pin_memory_()
def __len__(self): def __len__(self):
return len(self._array) return len(self._array)
...@@ -105,7 +105,7 @@ class UnifiedTensor: #UnifiedTensor ...@@ -105,7 +105,7 @@ class UnifiedTensor: #UnifiedTensor
def __del__(self): def __del__(self):
if hasattr(self, '_array') and self._array != None: if hasattr(self, '_array') and self._array != None:
self._array.unpin_memory_(utils.to_dgl_context(self._device)) self._array.unpin_memory_()
self._array = None self._array = None
if hasattr(self, '_input'): if hasattr(self, '_input'):
......
...@@ -124,6 +124,8 @@ def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map): ...@@ -124,6 +124,8 @@ def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
def _find_exclude_eids(g, exclude_mode, eids, **kwargs): def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
if exclude_mode is None: if exclude_mode is None:
return None return None
elif callable(exclude_mode):
return exclude_mode(eids)
elif F.is_tensor(exclude_mode) or ( elif F.is_tensor(exclude_mode) or (
isinstance(exclude_mode, Mapping) and isinstance(exclude_mode, Mapping) and
all(F.is_tensor(v) for v in exclude_mode.values())): all(F.is_tensor(v) for v in exclude_mode.values())):
...@@ -151,9 +153,6 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes= ...@@ -151,9 +153,6 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=
None (default) None (default)
Does not exclude any edge. Does not exclude any edge.
Tensor or dict[etype, Tensor]
Exclude the given edge IDs.
'self' 'self'
Exclude the given edges themselves but nothing else. Exclude the given edges themselves but nothing else.
...@@ -176,6 +175,10 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes= ...@@ -176,6 +175,10 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=
This mode assumes that the reverse of an edge with ID ``e`` and type ``etype`` This mode assumes that the reverse of an edge with ID ``e`` and type ``etype``
will have ID ``e`` and type ``reverse_etype_map[etype]``. will have ID ``e`` and type ``reverse_etype_map[etype]``.
callable
Any function that takes in a single argument :attr:`seed_edges` and returns
a tensor or dict of tensors.
eids : Tensor or dict[etype, Tensor] eids : Tensor or dict[etype, Tensor]
The edge IDs. The edge IDs.
reverse_eids : Tensor or dict[etype, Tensor] reverse_eids : Tensor or dict[etype, Tensor]
...@@ -191,9 +194,8 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes= ...@@ -191,9 +194,8 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=
seed_edges, seed_edges,
reverse_eid_map=reverse_eids, reverse_eid_map=reverse_eids,
reverse_etype_map=reverse_etypes) reverse_etype_map=reverse_etypes)
if exclude_eids is not None: if exclude_eids is not None and output_device is not None:
exclude_eids = recursive_apply( exclude_eids = recursive_apply(exclude_eids, lambda x: F.copy_to(x, output_device))
exclude_eids, lambda x: x.to(output_device))
return exclude_eids return exclude_eids
...@@ -202,8 +204,8 @@ class EdgeBlockSampler(object): ...@@ -202,8 +204,8 @@ class EdgeBlockSampler(object):
classification and link prediction. classification and link prediction.
""" """
def __init__(self, block_sampler, exclude=None, reverse_eids=None, def __init__(self, block_sampler, exclude=None, reverse_eids=None,
reverse_etypes=None, negative_sampler=None, prefetch_node_feats=None, reverse_etypes=None, negative_sampler=None,
prefetch_labels=None, prefetch_edge_feats=None): prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None,):
self.reverse_eids = reverse_eids self.reverse_eids = reverse_eids
self.reverse_etypes = reverse_etypes self.reverse_etypes = reverse_etypes
self.exclude = exclude self.exclude = exclude
...@@ -249,6 +251,8 @@ class EdgeBlockSampler(object): ...@@ -249,6 +251,8 @@ class EdgeBlockSampler(object):
If :attr:`negative_sampler` is given, also returns another graph containing the If :attr:`negative_sampler` is given, also returns another graph containing the
negative pairs as edges. negative pairs as edges.
""" """
if isinstance(seed_edges, Mapping):
seed_edges = {g.to_canonical_etype(k): v for k, v in seed_edges.items()}
exclude = self.exclude exclude = self.exclude
pair_graph = g.edge_subgraph( pair_graph = g.edge_subgraph(
seed_edges, relabel_nodes=False, output_device=self.output_device) seed_edges, relabel_nodes=False, output_device=self.output_device)
......
...@@ -55,7 +55,7 @@ class ClusterGCNSampler(object): ...@@ -55,7 +55,7 @@ class ClusterGCNSampler(object):
partition_node_ids = np.argsort(partition_ids) partition_node_ids = np.argsort(partition_ids)
partition_size = F.zerocopy_from_numpy(np.bincount(partition_ids, minlength=k)) partition_size = F.zerocopy_from_numpy(np.bincount(partition_ids, minlength=k))
partition_offset = F.zerocopy_from_numpy(np.insert(np.cumsum(partition_size), 0, 0)) partition_offset = F.zerocopy_from_numpy(np.insert(np.cumsum(partition_size), 0, 0))
partition_node_ids = F.zerocopy_from_numpy(partition_ids) partition_node_ids = F.zerocopy_from_numpy(partition_node_ids)
with open(cache_path, 'wb') as f: with open(cache_path, 'wb') as f:
pickle.dump((partition_offset, partition_node_ids), f) pickle.dump((partition_offset, partition_node_ids), f)
self.partition_offset = partition_offset self.partition_offset = partition_offset
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment