Unverified Commit f91a97d0 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Dataloader] Fix compatibility of DistributedSampler for older PyTorch versions (#2997)

* fix compatibility

* fix

* lint
parent a7fe461c
"""DGL PyTorch DataLoaders""" """DGL PyTorch DataLoaders"""
import inspect import inspect
import math import math
from distutils.version import LooseVersion
import torch as th import torch as th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
...@@ -12,6 +13,22 @@ from ...ndarray import NDArray as DGLNDArray ...@@ -12,6 +13,22 @@ from ...ndarray import NDArray as DGLNDArray
from ... import backend as F from ... import backend as F
from ...base import DGLError from ...base import DGLError
PYTORCH_VER = LooseVersion(th.__version__)
PYTORCH_16 = PYTORCH_VER >= LooseVersion("1.6.0")
PYTORCH_17 = PYTORCH_VER >= LooseVersion("1.7.0")
def _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed):
# Note: will change the content of dataloader_kwargs
dist_sampler_kwargs = {'shuffle': dataloader_kwargs['shuffle']}
dataloader_kwargs['shuffle'] = False
if PYTORCH_16:
dist_sampler_kwargs['seed'] = ddp_seed
if PYTORCH_17:
dist_sampler_kwargs['drop_last'] = dataloader_kwargs['drop_last']
dataloader_kwargs['drop_last'] = False
return DistributedSampler(dataset, **dist_sampler_kwargs)
class _ScalarDataBatcherIter: class _ScalarDataBatcherIter:
def __init__(self, dataset, batch_size, drop_last): def __init__(self, dataset, batch_size, drop_last):
self.dataset = dataset self.dataset = dataset
...@@ -449,13 +466,7 @@ class NodeDataLoader: ...@@ -449,13 +466,7 @@ class NodeDataLoader:
self.use_ddp = use_ddp self.use_ddp = use_ddp
self.use_scalar_batcher = use_scalar_batcher self.use_scalar_batcher = use_scalar_batcher
if use_ddp and not use_scalar_batcher: if use_ddp and not use_scalar_batcher:
self.dist_sampler = DistributedSampler( self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'],
seed=ddp_seed)
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
dataloader_kwargs['sampler'] = self.dist_sampler dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader( self.dataloader = DataLoader(
...@@ -724,13 +735,7 @@ class EdgeDataLoader: ...@@ -724,13 +735,7 @@ class EdgeDataLoader:
self.use_ddp = use_ddp self.use_ddp = use_ddp
if use_ddp: if use_ddp:
self.dist_sampler = DistributedSampler( self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'],
seed=ddp_seed)
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
dataloader_kwargs['sampler'] = self.dist_sampler dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader( self.dataloader = DataLoader(
...@@ -835,13 +840,7 @@ class GraphDataLoader: ...@@ -835,13 +840,7 @@ class GraphDataLoader:
self.use_ddp = use_ddp self.use_ddp = use_ddp
if use_ddp: if use_ddp:
self.dist_sampler = DistributedSampler( self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'],
seed=ddp_seed)
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
dataloader_kwargs['sampler'] = self.dist_sampler dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader(dataset=dataset, self.dataloader = DataLoader(dataset=dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment