Unverified Commit 82ca7819 authored by Xin Yao, committed by GitHub

[Bug] Record stream when using another CUDA stream for data transfer (#4250)

* record stream when using another cuda stream for data transfer

* fix linting

* fix None stream
parent 2f322a94
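
For context on the change below: when a tensor is produced or copied on a side CUDA stream, PyTorch's caching allocator ties its memory to that stream, so any use on another stream must be registered with `Tensor.record_stream()`; otherwise the allocator may hand the memory out again while the consumer stream is still reading it. A minimal sketch of the pattern, illustrative only and not DGL code (`copy_stream`, `main_stream`, and the tensor shapes are made up):

```python
import torch

# Sketch of the record_stream pattern this commit applies (assumes a CUDA device).
copy_stream = torch.cuda.Stream()          # hypothetical side stream for H2D copies
main_stream = torch.cuda.current_stream()  # the stream that will consume the data

src = torch.randn(1024).pin_memory()       # pinned host tensor enables async copy
with torch.cuda.stream(copy_stream):
    dst = src.to("cuda", non_blocking=True)
    # Mark `dst` as used by main_stream: without this, freeing `dst` lets the
    # caching allocator reuse its memory before main_stream is done reading it.
    dst.record_stream(main_stream)

# Order the consumer after the copy before touching `dst` on main_stream.
main_stream.wait_stream(copy_stream)
out = dst * 2
```
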
@@ -28,7 +28,6 @@ from .base import BlockSampler, as_edge_prediction_sampler
 from .. import backend as F
 from ..distributed import DistGraph
 from ..multiprocessing import call_once_and_share
-from ..cuda import stream as dgl_stream

 PYTORCH_VER = LooseVersion(torch.__version__)
 PYTHON_EXIT_STATUS = False
@@ -305,6 +304,18 @@ def _await_or_return(x):
     else:
         return x

+def _record_stream(x, stream):
+    if stream is None:
+        return x
+    if isinstance(x, torch.Tensor):
+        x.record_stream(stream)
+        return x
+    elif isinstance(x, _PrefetchedGraphFeatures):
+        node_feats = recursive_apply(x.node_feats, _record_stream, stream)
+        edge_feats = recursive_apply(x.edge_feats, _record_stream, stream)
+        return _PrefetchedGraphFeatures(node_feats, edge_feats)
+    else:
+        return x
+
 def _prefetch(batch, dataloader, stream):
     # feats has the same nested structure of batch, except that
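
The new `_record_stream` helper reaches every tensor through DGL's `recursive_apply`, which walks nested containers. A hedged sketch of that traversal idea for plain dicts and lists (`record_nested` is a hypothetical stand-in, not DGL's helper):

```python
import torch

def record_nested(x, stream):
    # Simplified stand-in for recursive_apply(x, _record_stream, stream):
    # record every tensor leaf on the consumer stream, pass the rest through.
    if isinstance(x, torch.Tensor):
        x.record_stream(stream)
        return x
    if isinstance(x, dict):
        return {k: record_nested(v, stream) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return type(x)(record_nested(v, stream) for v in x)
    return x

# Example (assumes a CUDA device): features allocated on a side stream,
# then recorded against the stream that will consume them.
side = torch.cuda.Stream()
with torch.cuda.stream(side):
    feats = {"user": torch.randn(4, 16, device="cuda"),
             "item": [torch.randn(4, 8, device="cuda")]}
record_nested(feats, torch.cuda.current_stream())
```
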
@@ -316,12 +327,21 @@ def _prefetch(batch, dataloader, stream):
     #
     # Once the futures are fetched, this function waits for them to complete by
     # calling its wait() method.
-    with torch.cuda.stream(stream), dgl_stream(stream):
+    if stream is not None:
+        current_stream = torch.cuda.current_stream()
+        current_stream.wait_stream(stream)
+    else:
+        current_stream = None
+    with torch.cuda.stream(stream):
         # fetch node/edge features
         feats = recursive_apply(batch, _prefetch_for, dataloader)
         feats = recursive_apply(feats, _await_or_return)
-        # transfer input nodes/seed nodes/sampled subgraph
+        feats = recursive_apply(feats, _record_stream, current_stream)
+        # transfer input nodes/seed nodes
+        # TODO(Xin): sampled subgraph is transferred in the default stream
+        # because heterograph doesn't support .record_stream() for now
         batch = recursive_apply(batch, lambda x: x.to(dataloader.device, non_blocking=True))
+        batch = recursive_apply(batch, _record_stream, current_stream)
     stream_event = stream.record_event() if stream is not None else None
     return batch, feats, stream_event
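
On the consumer side, the `stream_event` returned by `_prefetch` is what orders the training stream after the prefetch stream. A hedged sketch of how a caller might use it (the loop and `train_step` are hypothetical; the real iterator logic lives elsewhere in dataloader.py):

```python
# Hypothetical consumer of _prefetch's result.
batch, feats, stream_event = _prefetch(batch, dataloader, stream)
if stream_event is not None:
    # Queue a wait on the current stream; this is a device-side ordering
    # constraint, the host thread does not block here.
    stream_event.wait()
# The transfer is now ordered before any subsequent work on this stream.
train_step(batch, feats)  # train_step is a made-up placeholder
```
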