Unverified Commit af79e540 authored by Nicolas Castet's avatar Nicolas Castet Committed by GitHub
Browse files

[Bugfix] Fix Column.record_stream(...) for unmaterialized tensors (#5240)



* Fix Column.record_stream(...) for unmaterialized tensors

* lint

* Record streams for indices on non-materialized columns

* add docstring

* fix for cpu index

* Record also stream on storage

* Always call self.storage.record_stream when storage is on GPU

* fix lint

---------
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent d49a3019
......@@ -5,7 +5,7 @@ from collections import namedtuple
from collections.abc import MutableMapping
from . import backend as F
from .base import DGLError, dgl_warning
from .base import dgl_warning, DGLError
from .init import zero_initializer
from .storages import TensorStorage
from .utils import gather_pinned_tensor_rows, pin_memory_inplace
......@@ -41,6 +41,17 @@ class _LazyIndex(object):
flat_index = F.gather_row(flat_index, index)
return flat_index
def record_stream(self, stream):
"""Record stream for index.
Parameters
----------
stream : torch.cuda.Stream.
"""
for index in self._indices:
if F.context(index) != F.cpu():
index.record_stream(stream)
class LazyFeature(object):
"""Placeholder for feature prefetching.
......@@ -548,7 +559,13 @@ class Column(TensorStorage):
"""
if F.get_preferred_backend() != "pytorch":
raise DGLError("record_stream only supports the PyTorch backend.")
self.data.record_stream(stream)
if self.index is not None and (
isinstance(self.index, _LazyIndex)
or F.context(self.index) != F.cpu()
):
self.index.record_stream(stream)
if F.context(self.storage) != F.cpu():
self.storage.record_stream(stream)
class Frame(MutableMapping):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment