Unverified Commit e053df79 authored by Rhett Ying, committed by GitHub

[Feature] enable async transfer in NodeDataLoader for homograph (#3407)

* [Feature] enable async transfer in NodeDataLoader for homograph

* fix lint issues

* fix device choice when creating stream

* fix test on cpu only machine

* fix pin_memory config

* support homo only

* avoid creating stream in each step and sync via event

* fix lint

* enable graph copy on non-default stream

* fix lint

* refine arg description

* fix conflicts
parent 65fdfad6
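
At its core, the change applies the standard CUDA producer/consumer pattern: slice features into pinned host memory, issue the host-to-device copy on a side stream, record an event, and have the consumer's stream wait on that event. A minimal standalone sketch of that pattern (illustrative only, not DGL API; requires a CUDA device):

    import torch

    side_stream = torch.cuda.Stream()
    src = torch.randn(1024, 16, pin_memory=True)   # pinned host buffer

    # Producer: stage the host-to-device copy on the side stream.
    with torch.cuda.stream(side_stream):
        dst = src.to('cuda', non_blocking=True)    # async: src is pinned
        event = side_stream.record_event()

    # Consumer: queue the wait on the default stream instead of blocking the
    # host; kernels launched afterwards see the fully transferred data.
    event.wait(torch.cuda.default_stream())
    out = dst.sum()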
"""DGL PyTorch DataLoaders"""
import inspect
import math
import threading
import queue
from distutils.version import LooseVersion
import torch as th
from torch.utils.data import DataLoader, IterableDataset
@@ -13,6 +15,7 @@ from ...ndarray import NDArray as DGLNDArray
from ... import backend as F
from ...base import DGLError
from ...utils import to_dgl_context
from ..._ffi import streams as FS
__all__ = ['NodeDataLoader', 'EdgeDataLoader', 'GraphDataLoader',
# Temporary exposure.
@@ -330,23 +333,89 @@ def _to_device(data, device):
data = data.to(device)
return data
def _index_select(in_tensor, idx, pin_memory):
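    # Gather the rows of ``in_tensor`` indexed by ``idx`` into a fresh tensor.
    # When requested, the output is allocated in pinned (page-locked) host
    # memory so a later host-to-device copy can run asynchronously.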
idx = idx.to(in_tensor.device)
shape = list(in_tensor.shape)
shape[0] = len(idx)
out_tensor = th.empty(*shape, dtype=in_tensor.dtype, pin_memory=pin_memory)
th.index_select(in_tensor, 0, idx, out=out_tensor)
return out_tensor
def _next(dl_iter, graph, device, load_input, load_output, stream=None):
    # input_nodes, output_nodes, blocks
input_nodes, output_nodes, blocks = next(dl_iter)
_restore_storages(blocks, graph)
input_data = {}
for tag, data in load_input.items():
sliced = _index_select(data, input_nodes, data.device != device)
input_data[tag] = sliced
output_data = {}
for tag, data in load_output.items():
sliced = _index_select(data, output_nodes, data.device != device)
output_data[tag] = sliced
result_ = (input_nodes, output_nodes, blocks, input_data, output_data)
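    # If a side stream is given, issue the device copies on it and record an
    # event for the consumer to wait on. The pre-transfer tuple is returned
    # alongside the copies to keep the host tensors alive during the async copy.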
if stream is not None:
with th.cuda.stream(stream):
with FS.stream(stream):
result = [_to_device(data, device)
for data in result_], result_, stream.record_event()
else:
result = [_to_device(data, device) for data in result_]
return result
def _background_node_dataloader(dl_iter, g, device, results, load_input, load_output):
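    # Producer loop running in a daemon thread: prefetch the next batch, stage
    # its device copies on a dedicated CUDA stream, and hand the result to the
    # consumer through a bounded queue. A (None, None, None) sentinel signals
    # that the wrapped iterator is exhausted.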
dev = None
if device.type == 'cuda':
dev = device
elif g.device.type == 'cuda':
dev = g.device
stream = th.cuda.Stream(device=dev)
try:
while True:
results.put(_next(dl_iter, g, device, load_input, load_output, stream))
except StopIteration:
results.put((None, None, None))
class _NodeDataLoaderIter:
def __init__(self, node_dataloader):
self.device = node_dataloader.device
self.node_dataloader = node_dataloader
self.iter_ = iter(node_dataloader.dataloader)
self.async_load = node_dataloader.async_load and (
F.device_type(self.device) == 'cuda')
if self.async_load:
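            # A queue of capacity 1 bounds prefetching: the background thread
            # stays at most one batch ahead of the training loop.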
self.results = queue.Queue(1)
threading.Thread(target=_background_node_dataloader, args=(
self.iter_, self.node_dataloader.collator.g, self.device,
self.results, node_dataloader.load_input, node_dataloader.load_output
), daemon=True).start()
# Make this an iterator for PyTorch Lightning compatibility
def __iter__(self):
return self
def __next__(self):
res = ()
if self.async_load:
res, _, event = self.results.get()
if res is None:
raise StopIteration
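            # Queue the wait on the default stream rather than blocking the
            # host thread; kernels launched afterwards see the transferred data.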
event.wait(th.cuda.default_stream())
else:
res = _next(self.iter_, self.node_dataloader.collator.g, self.device,
self.node_dataloader.load_input, self.node_dataloader.load_output)
input_nodes, output_nodes, blocks, input_data, output_data = res
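        # Attach the pre-sliced (and already transferred) tensors to the
        # corresponding blocks so downstream code can read them as usual.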
if input_data:
for tag, data in input_data.items():
blocks[0].srcdata[tag] = data
if output_data:
for tag, data in output_data.items():
blocks[-1].dstdata[tag] = data
return input_nodes, output_nodes, blocks
class _EdgeDataLoaderIter:
def __init__(self, edge_dataloader):
@@ -459,6 +528,19 @@ class NodeDataLoader:
:class:`torch.utils.data.distributed.DistributedSampler`.
Only effective when :attr:`use_ddp` is True.
    load_input : dict[tag, Tensor], optional
        Tensors to be sliced according to ``blocks[0].srcdata[dgl.NID]`` and
        attached to ``blocks[0].srcdata`` under the given tags.
    load_output : dict[tag, Tensor], optional
        Tensors to be sliced according to ``blocks[-1].dstdata[dgl.NID]`` and
        attached to ``blocks[-1].dstdata`` under the given tags.
    async_load : boolean, optional
        If True, data (including the graph and the sliced tensors) is
        transferred between devices asynchronously. This is transparent to
        end users and can speed up model training, especially when large
        amounts of data need to be transferred. As a drawback, the underlying
        ``to_block`` can no longer run on the GPU, which could decrease
        performance. This is a trade-off; profiling is needed to decide
        whether to enable it.
kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
@@ -516,7 +598,8 @@
"""
collator_arglist = inspect.getfullargspec(NodeCollator).args
def __init__(self, g, nids, graph_sampler, device=None, use_ddp=False, ddp_seed=0,
load_input=None, load_output=None, async_load=False, **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
@@ -543,10 +626,21 @@
# default to the same device the graph is on
device = th.device(g.device)
if not g.is_homogeneous:
if load_input or load_output:
raise DGLError('load_input/load_output not supported for heterograph yet.')
self.load_input = {} if load_input is None else load_input
self.load_output = {} if load_output is None else load_output
self.async_load = async_load
        # If the sampler supports it, tell it to output to the specified device.
        # But if async_load is enabled, set_output_context should be skipped so
        # that no graph/data transfers across devices happen inside the sampler.
        # Such transfers are handled in the dataloader instead.
num_workers = dataloader_kwargs.get('num_workers', 0)
if ((not async_load) and
callable(getattr(graph_sampler, "set_output_context", None)) and
num_workers == 0):
graph_sampler.set_output_context(to_dgl_context(device))
self.collator = _NodeCollator(g, nids, graph_sampler, **collator_kwargs)
......
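For reference, end-user usage of the new arguments roughly looks as follows (a sketch, not part of the diff; assumes ``import dgl`` and ``import torch``, a homogeneous CPU graph ``g``, and a CUDA device; names and sizes are illustrative):

    g.ndata['feat'] = torch.randn(g.num_nodes(), 16)
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    dataloader = dgl.dataloading.NodeDataLoader(
        g, g.nodes(), sampler, device='cuda',
        load_input={'feat': g.ndata['feat']},
        async_load=True, batch_size=1024)
    for input_nodes, output_nodes, blocks in dataloader:
        feat = blocks[0].srcdata['feat']  # sliced and transferred by the loader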
import os
import dgl
import dgl.ops as OPS
import backend as F
import unittest
import torch
@@ -291,18 +292,33 @@ def _check_device(data):
def test_node_dataloader(sampler_name):
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
g1.ndata['label'] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())
for load_input, load_output in [(None, None), ({'feat': g1.ndata['feat']}, {'label': g1.ndata['label']})]:
for async_load in [False, True]:
for num_workers in [0, 1, 2]:
sampler = {
'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
'shadow': dgl.dataloading.ShaDowKHopSampler([3, 3])}[sampler_name]
dataloader = dgl.dataloading.NodeDataLoader(
g1, g1.nodes(), sampler, device=F.ctx(),
load_input=load_input,
load_output=load_output,
async_load=async_load,
batch_size=g1.num_nodes(),
num_workers=num_workers)
for input_nodes, output_nodes, blocks in dataloader:
_check_device(input_nodes)
_check_device(output_nodes)
_check_device(blocks)
if load_input:
_check_device(blocks[0].srcdata['feat'])
OPS.copy_u_sum(blocks[0], blocks[0].srcdata['feat'])
if load_output:
_check_device(blocks[-1].dstdata['label'])
OPS.copy_u_sum(blocks[-1], blocks[-1].dstdata['label'])
g2 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
@@ -319,14 +335,24 @@ def test_node_dataloader(sampler_name):
'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
'shadow': dgl.dataloading.ShaDowKHopSampler([{etype: 3 for etype in g2.etypes}] * 2)}[sampler_name]
for async_load in [False, True]:
dataloader = dgl.dataloading.NodeDataLoader(
g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
sampler, device=F.ctx(), async_load=async_load, batch_size=batch_size)
assert isinstance(iter(dataloader), Iterator)
for input_nodes, output_nodes, blocks in dataloader:
_check_device(input_nodes)
_check_device(output_nodes)
_check_device(blocks)
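    # load_input/load_output are not yet supported for heterographs;
    # constructing the dataloader with load_input must raise DGLError.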
status = False
try:
dgl.dataloading.NodeDataLoader(
g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
sampler, device=F.ctx(), load_input={'feat': g1.ndata['feat']}, batch_size=batch_size)
except dgl.DGLError:
status = True
assert status
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'shadow'])
......