Unverified Commit f85da5f5 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Bugfix] Turn off `use_alternate_stream` when TensorAdaptor is not enabled (#4987)

* turn off use_alternate_stream when tensoradaptor is not enabled

* add docstring
parent a2defae2
......@@ -148,3 +148,8 @@ def load_tensor_adapter(backend, version):
if not tensor_adapter_loaded:
logger = logging.getLogger("dgl-core")
logger.debug("Memory optimization with PyTorch is not enabled.")
def is_tensor_adaptor_enabled() -> bool:
"""Check whether TensorAdaptor is enabled."""
return tensor_adapter_loaded
......@@ -19,6 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
from ..base import NID, EID, dgl_warning, DGLError
from ..batch import batch as batch_graphs
from .._ffi.base import is_tensor_adaptor_enabled
from ..heterograph import DGLGraph
from ..utils import (
recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads, get_num_threads,
......@@ -821,8 +822,15 @@ class DataLoader(torch.utils.data.DataLoader):
# Check use_alternate_streams
if use_alternate_streams is None:
use_alternate_streams = (
self.device.type == 'cuda' and self.graph.device.type == 'cpu' and
not use_uva)
self.device.type == "cuda"
and self.graph.device.type == "cpu"
and not use_uva
and is_tensor_adaptor_enabled()
)
elif use_alternate_streams and not is_tensor_adaptor_enabled():
dgl_warning("use_alternate_streams is turned off because "
"TensorAdaptor is not available.")
use_alternate_streams = False
if (torch.is_tensor(indices) or (
isinstance(indices, Mapping) and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment