"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "236ffa0f63561bd27ec4e0aadd565211621f8fc2"
Unverified commit 23a5e674 authored by Peiqi Yin, committed by GitHub

[fix] device cannot be a NoneType object. (#3822)



* fix device = none to cpu.

* fix singleton-comparison

* modify device is none

* fix use uva

* add docs

* add edge dataloader

* add doc

* fix trailing-whitespace

* modify default
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 43b9d40e
@@ -589,7 +589,9 @@ class DataLoader(torch.utils.data.DataLoader):
         The device of the generated MFGs in each iteration, which should be a
         PyTorch device object (e.g., ``torch.device``).
-        By default this value is the same as the device of :attr:`g`.
+        By default this value is None. If :attr:`use_uva` is True, MFGs and graphs
+        are generated on torch.cuda.current_device(); otherwise they are generated
+        on the same device as :attr:`g`.
     use_ddp : boolean, optional
         If True, tells the DataLoader to split the training set for each
         participating process appropriately using
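To make the new default concrete, here is a minimal sketch of the CPU path (assuming DGL >= 0.8 with PyTorch; `g` and `train_nids` are toy placeholders, not part of the commit). Leaving `device` as None with `use_uva=False` keeps the MFGs on the graph's own device:

    import torch
    import dgl
    import dgl.dataloading as dl

    g = dgl.rand_graph(1000, 5000)          # toy CPU graph (placeholder)
    train_nids = torch.arange(100)          # toy seed nodes (placeholder)
    sampler = dl.NeighborSampler([10, 10])  # two-layer neighbor sampling

    # device is left as None: with use_uva=False it resolves to g.device (CPU here).
    loader = dl.DataLoader(g, train_nids, sampler, batch_size=32, shuffle=True)
    input_nodes, output_nodes, blocks = next(iter(loader))
    assert blocks[0].device == g.device     # MFGs stay on the graph's device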
@@ -686,7 +688,7 @@ class DataLoader(torch.utils.data.DataLoader):
         - Otherwise, both the sampling and subgraph construction will take place on the CPU.
     """
-    def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False,
+    def __init__(self, graph, indices, graph_sampler, device=None, use_ddp=False,
                  ddp_seed=0, batch_size=1, drop_last=False, shuffle=False,
                  use_prefetch_thread=None, use_alternate_streams=None,
                  pin_prefetcher=None, use_uva=False, **kwargs):
@@ -719,11 +721,18 @@ class DataLoader(torch.utils.data.DataLoader):
             except:     # pylint: disable=bare-except
                 # ignore when it fails to convert to torch Tensors.
                 pass
         if indices_device is None:
             if not hasattr(indices, 'device'):
                 raise AttributeError('Custom indices dataset requires a \"device\" \
                     attribute indicating where the indices is.')
             indices_device = indices.device
+        if device is None:
+            if use_uva:
+                device = torch.cuda.current_device()
+            else:
+                device = self.graph.device
         self.device = _get_device(device)
         # Sanity check - we only check for DGLGraphs.
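The branch added above can be restated as a tiny standalone helper (hypothetical, for illustration only; `resolve_device` is not part of DGL's API): None resolves to the current CUDA device under UVA, to the graph's device otherwise, and an explicit `device` argument always wins.

    import torch

    def resolve_device(device, use_uva, graph_device):
        # Mirrors the commit's rule for a device argument that is None.
        if device is None:
            if use_uva:
                device = torch.device('cuda', torch.cuda.current_device())
            else:
                device = graph_device
        return device if isinstance(device, torch.device) else torch.device(device)

    assert resolve_device(None, False, torch.device('cpu')) == torch.device('cpu')
    assert resolve_device('cpu', True, torch.device('cpu')).type == 'cpu'  # explicit wins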
@@ -868,12 +877,18 @@ class EdgeDataLoader(DataLoader):
             g, train_eid, sampler,
             batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
     """
-    def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False,
+    def __init__(self, graph, indices, graph_sampler, device=None, use_ddp=False,
                  ddp_seed=0, batch_size=1, drop_last=False, shuffle=False,
                  use_prefetch_thread=False, use_alternate_streams=True,
                  pin_prefetcher=False,
                  exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None,
                  use_uva=False, **kwargs):
+        if device is None:
+            if use_uva:
+                device = torch.cuda.current_device()
+            else:
+                device = graph.device
         device = _get_device(device)
         if isinstance(graph_sampler, BlockSampler):
...
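Mirroring the EdgeDataLoader docstring example above, the same default now applies to edge loading. A hedged sketch, assuming DGL 0.8 where EdgeDataLoader is still available and wraps a BlockSampler itself; `g`, `train_eid`, and `sampler` are toy placeholders, and with no negative sampler each iteration yields `(input_nodes, pair_graph, blocks)`:

    import torch
    import dgl
    import dgl.dataloading as dl

    g = dgl.rand_graph(1000, 5000)           # toy CPU graph (placeholder)
    train_eid = torch.arange(g.num_edges())  # all edge IDs as seeds (placeholder)
    sampler = dl.NeighborSampler([10, 10])

    # device omitted (None): with use_uva=False everything stays on g.device.
    edge_loader = dl.EdgeDataLoader(g, train_eid, sampler,
                                    batch_size=1024, shuffle=True, drop_last=False)
    input_nodes, pair_graph, blocks = next(iter(edge_loader))
    assert pair_graph.device == g.device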