"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "18200240053b1ef5f7beb0584c01dd6677927e84"
Unverified Commit d8375c10 authored by Peiqi Yin's avatar Peiqi Yin Committed by GitHub
Browse files

[BugFix] involve indices_devices for other dataset object. (#3810)

* involve indices_devices for other dataset object.

* modify raise error when device not found.

* remove empty line

* fix line too long
parent fbbca994
......@@ -707,6 +707,7 @@ class DataLoader(torch.utils.data.DataLoader):
self.indices = indices # For PyTorch-Lightning
num_workers = kwargs.get('num_workers', 0)
indices_device = None
try:
if isinstance(indices, Mapping):
indices = {k: (torch.tensor(v) if not torch.is_tensor(v) else v)
......@@ -718,7 +719,11 @@ class DataLoader(torch.utils.data.DataLoader):
except: # pylint: disable=bare-except
# ignore when it fails to convert to torch Tensors.
pass
if indices_device is None:
if not hasattr(indices, 'device'):
raise AttributeError('Custom indices dataset requires a \"device\" \
attribute indicating where the indices is.')
indices_device = indices.device
self.device = _get_device(device)
# Sanity check - we only check for DGLGraphs.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment