Unverified commit f91a97d0, authored by Quan (Andy) Gan, committed by GitHub
Browse files

[Dataloader] Fix compatibility of DistributedSampler for older PyTorch versions (#2997)

* fix compatibility

* fix

* lint
parent a7fe461c
"""DGL PyTorch DataLoaders"""
import inspect
import math
from distutils.version import LooseVersion
import torch as th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
......@@ -12,6 +13,22 @@ from ...ndarray import NDArray as DGLNDArray
from ... import backend as F
from ...base import DGLError
# Version of the installed PyTorch, parsed once so that it can be compared
# against the release thresholds below.
PYTORCH_VER = LooseVersion(th.__version__)
# Feature gates: True when the installed PyTorch is at least 1.6.0 / 1.7.0.
PYTORCH_16, PYTORCH_17 = (
    PYTORCH_VER >= LooseVersion(minimum) for minimum in ("1.6.0", "1.7.0")
)
def _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed):
    """Create a ``DistributedSampler`` configured from DataLoader keyword arguments.

    The options that the sampler takes over are moved out of
    ``dataloader_kwargs``: ``shuffle`` is always handed to the sampler, and
    ``drop_last`` is handed to it only when ``PYTORCH_17`` is set.  Each
    option that was moved is reset to ``False`` in ``dataloader_kwargs``.

    Note: ``dataloader_kwargs`` is modified in place.

    Parameters
    ----------
    dataset
        The dataset passed to ``DistributedSampler``.
    dataloader_kwargs : dict
        Keyword arguments of the data loader.  The ``'shuffle'`` and
        ``'drop_last'`` entries are read; a missing entry is treated as
        ``False``.
    ddp_seed
        Forwarded to the sampler as ``seed`` when ``PYTORCH_16`` is set;
        ignored otherwise.

    Returns
    -------
    DistributedSampler
        The sampler built over ``dataset``.
    """
    # Use .get() so that a caller who did not specify the option gets the
    # ``False`` default instead of a KeyError.
    dist_sampler_kwargs = {'shuffle': dataloader_kwargs.get('shuffle', False)}
    dataloader_kwargs['shuffle'] = False
    if PYTORCH_16:
        # The ``seed`` argument is only passed when the version gate allows it.
        dist_sampler_kwargs['seed'] = ddp_seed
    if PYTORCH_17:
        # Likewise ``drop_last``; when the gate is off the entry is left
        # untouched in ``dataloader_kwargs``.
        dist_sampler_kwargs['drop_last'] = dataloader_kwargs.get('drop_last', False)
        dataloader_kwargs['drop_last'] = False
    return DistributedSampler(dataset, **dist_sampler_kwargs)
class _ScalarDataBatcherIter:
def __init__(self, dataset, batch_size, drop_last):
self.dataset = dataset
......@@ -449,13 +466,7 @@ class NodeDataLoader:
self.use_ddp = use_ddp
self.use_scalar_batcher = use_scalar_batcher
if use_ddp and not use_scalar_batcher:
self.dist_sampler = DistributedSampler(
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'],
seed=ddp_seed)
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader(
......@@ -724,13 +735,7 @@ class EdgeDataLoader:
self.use_ddp = use_ddp
if use_ddp:
self.dist_sampler = DistributedSampler(
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'],
seed=ddp_seed)
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader(
......@@ -835,13 +840,7 @@ class GraphDataLoader:
self.use_ddp = use_ddp
if use_ddp:
self.dist_sampler = DistributedSampler(
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'],
seed=ddp_seed)
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader(dataset=dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment