Unverified commit 5bef48df, authored by Xin Yao, committed by GitHub

[Performance] Optimize the use of alternative streams in dataloader (#4177)

* fix the use of alternative streams

* use an alternative stream for subgraph transferring

* fix StreamContext when stream is None
parent 5640b129
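
In short, the patch stops creating a fresh CUDA stream per batch; instead each prefetching iterator owns one long-lived alternative stream, runs both feature fetching and subgraph transfer on it, and hands the consumer a recorded event to wait on. A minimal sketch of that pattern (illustrative names, not DGL's API):

    import torch

    side_stream = torch.cuda.Stream(device='cuda:0')  # created once, reused for every batch

    def prefetch(cpu_batch):
        # queue the host-to-device copy on the side stream, not the default stream
        with torch.cuda.stream(side_stream):
            gpu_batch = cpu_batch.to('cuda:0', non_blocking=True)
        # the consumer must wait on this event before touching gpu_batch
        return gpu_batch, side_stream.record_event()

    # consumer, on the default stream:
    #   batch, event = prefetch(cpu_batch)
    #   event.wait()  # orders the default stream after the copy; no host sync
    #   train_step(batch)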
python/dgl/_ffi/streams.py:

@@ -25,15 +25,20 @@ class StreamContext(object):
         Parameters
         ----------
-        cuda_stream : torch.cuda.Stream
+        cuda_stream : torch.cuda.Stream. This manager is a no-op if it's ``None``.
             target stream will be set.
         """
+        if cuda_stream is None:
+            self.curr_cuda_stream = None
+        else:
             self.ctx = to_dgl_context(cuda_stream.device)
             self.curr_cuda_stream = cuda_stream.cuda_stream
 
     def __enter__(self):
         """ get previous stream and set target stream as current.
         """
+        if self.curr_cuda_stream is None:
+            return
         self.prev_cuda_stream = DGLStreamHandle()
         check_call(_LIB.DGLGetStream(
             self.ctx.device_type, self.ctx.device_id, ctypes.byref(self.prev_cuda_stream)))
@@ -43,6 +48,8 @@ class StreamContext(object):
     def __exit__(self, exc_type, exc_value, exc_traceback):
         """ restore previous stream when exiting.
         """
+        if self.curr_cuda_stream is None:
+            return
         check_call(_LIB.DGLSetStream(
             self.ctx.device_type, self.ctx.device_id, self.prev_cuda_stream))
@@ -52,7 +59,7 @@ def stream(cuda_stream):
     Parameters
     ----------
-    stream : torch.cuda.Stream
+    stream : torch.cuda.Stream. This manager is a no-op if it's ``None``.
         target stream will be set.
     """
     return StreamContext(cuda_stream)
""" CUDA wrappers """
from . import nccl
from .._ffi.streams import stream
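
With the ``None`` guard above, callers can pass their stream through unconditionally. A usage sketch of the re-exported manager (``run_dgl_ops`` is a placeholder for any DGL kernel launches):

    import torch
    from dgl.cuda import stream as dgl_stream

    s = torch.cuda.Stream() if torch.cuda.is_available() else None
    with dgl_stream(s):  # sets DGL's current CUDA stream, or no-ops when s is None
        run_dgl_ops()    # placeholder; the same code path works on CPU-only machines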
python/dgl/dataloading/dataloader.py:

@@ -28,6 +28,7 @@ from .base import BlockSampler, as_edge_prediction_sampler
 from .. import backend as F
 from ..distributed import DistGraph
 from ..multiprocessing import call_once_and_share
+from ..cuda import stream as dgl_stream
 
 PYTHON_EXIT_STATUS = False
 def _set_python_exit_flag():
@@ -314,10 +315,14 @@ def _prefetch(batch, dataloader, stream):
     #
     # Once the futures are fetched, this function waits for them to complete by
     # calling its wait() method.
-    with torch.cuda.stream(stream):
+    with torch.cuda.stream(stream), dgl_stream(stream):
+        # fetch node/edge features
         feats = recursive_apply(batch, _prefetch_for, dataloader)
         feats = recursive_apply(feats, _await_or_return)
-    return feats
+
+        # transfer input nodes/seed nodes/sampled subgraph
+        batch = recursive_apply(batch, lambda x: x.to(dataloader.device, non_blocking=True))
+    stream_event = stream.record_event() if stream is not None else None
+    return batch, feats, stream_event
 
 def _assign_for(item, feat):
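
The doubled context manager is the crux of this hunk: ``torch.cuda.stream(stream)`` only redirects PyTorch's own kernels and copies, while DGL's C library tracks its current stream separately, which the new ``dgl_stream(stream)`` wrapper sets. A sketch of the combined effect, assuming ``feats_cpu`` is a pinned host tensor and ``g`` a DGL graph (both placeholders):

    import torch
    from dgl.cuda import stream as dgl_stream

    s = torch.cuda.Stream(device='cuda:0')
    with torch.cuda.stream(s), dgl_stream(s):
        feats_gpu = feats_cpu.to('cuda:0', non_blocking=True)  # PyTorch copy, queued on s
        g_gpu = g.to('cuda:0')                                 # DGL graph transfer, also on s
    event = s.record_event()  # everything queued on s is ordered before this event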
@@ -344,18 +349,11 @@ def _put_if_event_not_set(queue, result, event):
         continue
 
 def _prefetcher_entry(
-        dataloader_it, dataloader, queue, num_threads, use_alternate_streams,
-        done_event):
+        dataloader_it, dataloader, queue, num_threads, stream, done_event):
     # PyTorch will set the number of threads to 1 which slows down pin_memory() calls
     # in main process if a prefetching thread is created.
     if num_threads is not None:
         torch.set_num_threads(num_threads)
-    if use_alternate_streams:
-        stream = (
-            torch.cuda.Stream(device=dataloader.device)
-            if dataloader.device.type == 'cuda' else None)
-    else:
-        stream = None
 
     try:
         while not done_event.is_set():
@@ -364,17 +362,8 @@ def _prefetcher_entry(
             except StopIteration:
                 break
             batch = recursive_apply(batch, restore_parent_storage_columns, dataloader.graph)
-            feats = _prefetch(batch, dataloader, stream)
-            _put_if_event_not_set(queue, (
-                # batch will be already in pinned memory as per the behavior of
-                # PyTorch DataLoader.
-                recursive_apply(
-                    batch, lambda x: x.to(dataloader.device, non_blocking=True)),
-                feats,
-                stream.record_event() if stream is not None else None,
-                None),
-                done_event)
+            batch, feats, stream_event = _prefetch(batch, dataloader, stream)
+            _put_if_event_not_set(queue, (batch, feats, stream_event, None), done_event)
         _put_if_event_not_set(queue, (None, None, None, None), done_event)
     except: # pylint: disable=bare-except
         _put_if_event_not_set(
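
The thread body around this hunk follows a standard bounded-queue producer shape: a sentinel tuple signals exhaustion, and exceptions are forwarded through the queue for the consumer to re-raise rather than dying silently in the thread. A generic sketch with hypothetical names:

    from queue import Queue

    def producer(it, queue, done_event):
        try:
            while not done_event.is_set():
                try:
                    item = next(it)
                except StopIteration:
                    break
                queue.put((item, None))  # (payload, exception)
            queue.put((None, None))      # sentinel: iterator exhausted
        except Exception as exc:         # mirror of the bare except above
            queue.put((None, exc))       # consumer re-raises on its own thread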
@@ -432,24 +421,26 @@ def restore_parent_storage_columns(item, g):
 class _PrefetchingIter(object):
-    def __init__(self, dataloader, dataloader_it, use_thread=False, use_alternate_streams=True,
-                 num_threads=None):
+    def __init__(self, dataloader, dataloader_it, num_threads=None):
         self.queue = Queue(1)
         self.dataloader_it = dataloader_it
         self.dataloader = dataloader
         self.graph_sampler = self.dataloader.graph_sampler
         self.pin_prefetcher = self.dataloader.pin_prefetcher
         self.num_threads = num_threads
-        self.use_thread = use_thread
-        self.use_alternate_streams = use_alternate_streams
+        self.use_thread = dataloader.use_prefetch_thread
+        self.use_alternate_streams = dataloader.use_alternate_streams
         self.device = self.dataloader.device
+        if self.use_alternate_streams and self.device.type == 'cuda':
+            self.stream = torch.cuda.Stream(device=self.device)
+        else:
+            self.stream = None
 
         self._shutting_down = False
-        if use_thread:
+        if self.use_thread:
             self._done_event = threading.Event()
             thread = threading.Thread(
                 target=_prefetcher_entry,
                 args=(dataloader_it, dataloader, self.queue, num_threads,
-                      use_alternate_streams, self._done_event),
+                      self.stream, self._done_event),
                 daemon=True)
             thread.start()
             self.thread = thread
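
Hoisting the stream into ``__init__`` is the core optimization: previously both the prefetcher thread and ``_next_non_threaded`` built a fresh ``torch.cuda.Stream`` per iterator or per batch. A condensed sketch of the new shape (illustrative class, not DGL's):

    import torch

    class PrefetchIterSketch:
        def __init__(self, device):
            # one long-lived side stream, shared by threaded and non-threaded paths
            self.stream = (torch.cuda.Stream(device=device)
                           if device.type == 'cuda' else None)

        def next_batch(self, cpu_batch):
            if self.stream is None:  # CPU fallback: plain synchronous move
                return cpu_batch, None
            with torch.cuda.stream(self.stream):
                batch = cpu_batch.to(self.stream.device, non_blocking=True)
            return batch, self.stream.record_event()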
@@ -485,14 +476,7 @@ class _PrefetchingIter(object):
     def _next_non_threaded(self):
         batch = next(self.dataloader_it)
         batch = recursive_apply(batch, restore_parent_storage_columns, self.dataloader.graph)
-        device = self.dataloader.device
-        if self.use_alternate_streams:
-            stream = torch.cuda.Stream(device=device) if device.type == 'cuda' else None
-        else:
-            stream = None
-        feats = _prefetch(batch, self.dataloader, stream)
-        batch = recursive_apply(batch, lambda x: x.to(device, non_blocking=True))
-        stream_event = stream.record_event() if stream is not None else None
+        batch, feats, stream_event = _prefetch(batch, self.dataloader, self.stream)
         return batch, feats, stream_event
 
     def _next_threaded(self):
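
Both paths now return the same ``(batch, feats, stream_event)`` triple, so the iterator's ``__next__`` (outside this hunk) can order the default stream after the side stream before handing data to the user. A sketch of that consumer step, with hypothetical names:

    def consume(batch, feats, stream_event):
        # device-side dependency, not a host synchronization: the current
        # (default) stream waits until the side stream's copies have finished
        if stream_event is not None:
            stream_event.wait()
        return batch, feats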
@@ -880,9 +864,7 @@ class DataLoader(torch.utils.data.DataLoader):
         # When using multiprocessing PyTorch sometimes set the number of PyTorch threads to 1
         # when spawning new Python threads. This drastically slows down pinning features.
         num_threads = torch.get_num_threads() if self.num_workers > 0 else None
-        return _PrefetchingIter(
-            self, super().__iter__(), use_thread=self.use_prefetch_thread,
-            use_alternate_streams=self.use_alternate_streams, num_threads=num_threads)
+        return _PrefetchingIter(self, super().__iter__(), num_threads=num_threads)
 
     def worker_init_function(self, worker_id):
         """Worker init default function.
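
From user code nothing changes; the flags read above are the DataLoader's existing constructor arguments. A usage sketch (``g`` and ``train_nids`` are placeholders), assuming DGL's public keyword names:

    import torch
    import dgl

    sampler = dgl.dataloading.NeighborSampler([10, 10])
    loader = dgl.dataloading.DataLoader(
        g, train_nids, sampler,
        device=torch.device('cuda:0'),  # where batches and features should land
        use_prefetch_thread=True,       # prefetch on a separate Python thread
        use_alternate_streams=True,     # copies ride the side stream from this patch
        batch_size=1024, shuffle=True, num_workers=4)

    for input_nodes, output_nodes, blocks in loader:
        pass  # each batch is already on the GPU when the loop body runs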