Unverified Commit 23a5e674 authored by Peiqi Yin's avatar Peiqi Yin Committed by GitHub
Browse files

[fix] device cannot be Nonetype object. (#3822)



* fix device = none to cpu.

* fix singleton-comparison

* modify device is none

* fix use uva

* add docs

* add edge dataloader

* add doc

* fix trailing-whitespace

* modify default
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 43b9d40e
......@@ -589,7 +589,9 @@ class DataLoader(torch.utils.data.DataLoader):
The device of the generated MFGs in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
By default this value is the same as the device of :attr:`g`.
By default this value is None. If :attr:`use_uva` is True, MFGs and graphs will
generated in torch.cuda.current_device(), otherwise generated in the same device
of :attr:`g`.
use_ddp : boolean, optional
If True, tells the DataLoader to split the training set for each
participating process appropriately using
......@@ -686,7 +688,7 @@ class DataLoader(torch.utils.data.DataLoader):
- Otherwise, both the sampling and subgraph construction will take place on the CPU.
"""
def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False,
def __init__(self, graph, indices, graph_sampler, device=None, use_ddp=False,
ddp_seed=0, batch_size=1, drop_last=False, shuffle=False,
use_prefetch_thread=None, use_alternate_streams=None,
pin_prefetcher=None, use_uva=False, **kwargs):
......@@ -719,11 +721,18 @@ class DataLoader(torch.utils.data.DataLoader):
except: # pylint: disable=bare-except
# ignore when it fails to convert to torch Tensors.
pass
if indices_device is None:
if not hasattr(indices, 'device'):
raise AttributeError('Custom indices dataset requires a \"device\" \
attribute indicating where the indices is.')
indices_device = indices.device
if device is None:
if use_uva:
device = torch.cuda.current_device()
else:
device = self.graph.device
self.device = _get_device(device)
# Sanity check - we only check for DGLGraphs.
......@@ -868,12 +877,18 @@ class EdgeDataLoader(DataLoader):
g, train_eid, sampler,
batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
"""
def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False,
def __init__(self, graph, indices, graph_sampler, device=None, use_ddp=False,
ddp_seed=0, batch_size=1, drop_last=False, shuffle=False,
use_prefetch_thread=False, use_alternate_streams=True,
pin_prefetcher=False,
exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None,
use_uva=False, **kwargs):
if device is None:
if use_uva:
device = torch.cuda.current_device()
else:
device = self.graph.device
device = _get_device(device)
if isinstance(graph_sampler, BlockSampler):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment