Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f91a97d0
Unverified
Commit
f91a97d0
authored
Jun 09, 2021
by
Quan (Andy) Gan
Committed by
GitHub
Jun 09, 2021
Browse files
[Dataloader] Fix compatibility of DistributedSampler for older PyTorch versions (#2997)
* fix compatibility * fix * lint
parent
a7fe461c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
21 deletions
+20
-21
python/dgl/dataloading/pytorch/__init__.py
python/dgl/dataloading/pytorch/__init__.py
+20
-21
No files found.
python/dgl/dataloading/pytorch/__init__.py
View file @
f91a97d0
"""DGL PyTorch DataLoaders"""
"""DGL PyTorch DataLoaders"""
import
inspect
import
inspect
import
math
import
math
from
distutils.version
import
LooseVersion
import
torch
as
th
import
torch
as
th
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
torch.utils.data.distributed
import
DistributedSampler
from
torch.utils.data.distributed
import
DistributedSampler
...
@@ -12,6 +13,22 @@ from ...ndarray import NDArray as DGLNDArray
...
@@ -12,6 +13,22 @@ from ...ndarray import NDArray as DGLNDArray
from
...
import
backend
as
F
from
...
import
backend
as
F
from
...base
import
DGLError
from
...base
import
DGLError
PYTORCH_VER
=
LooseVersion
(
th
.
__version__
)
PYTORCH_16
=
PYTORCH_VER
>=
LooseVersion
(
"1.6.0"
)
PYTORCH_17
=
PYTORCH_VER
>=
LooseVersion
(
"1.7.0"
)
def
_create_dist_sampler
(
dataset
,
dataloader_kwargs
,
ddp_seed
):
# Note: will change the content of dataloader_kwargs
dist_sampler_kwargs
=
{
'shuffle'
:
dataloader_kwargs
[
'shuffle'
]}
dataloader_kwargs
[
'shuffle'
]
=
False
if
PYTORCH_16
:
dist_sampler_kwargs
[
'seed'
]
=
ddp_seed
if
PYTORCH_17
:
dist_sampler_kwargs
[
'drop_last'
]
=
dataloader_kwargs
[
'drop_last'
]
dataloader_kwargs
[
'drop_last'
]
=
False
return
DistributedSampler
(
dataset
,
**
dist_sampler_kwargs
)
class
_ScalarDataBatcherIter
:
class
_ScalarDataBatcherIter
:
def
__init__
(
self
,
dataset
,
batch_size
,
drop_last
):
def
__init__
(
self
,
dataset
,
batch_size
,
drop_last
):
self
.
dataset
=
dataset
self
.
dataset
=
dataset
...
@@ -449,13 +466,7 @@ class NodeDataLoader:
...
@@ -449,13 +466,7 @@ class NodeDataLoader:
self
.
use_ddp
=
use_ddp
self
.
use_ddp
=
use_ddp
self
.
use_scalar_batcher
=
use_scalar_batcher
self
.
use_scalar_batcher
=
use_scalar_batcher
if
use_ddp
and
not
use_scalar_batcher
:
if
use_ddp
and
not
use_scalar_batcher
:
self
.
dist_sampler
=
DistributedSampler
(
self
.
dist_sampler
=
_create_dist_sampler
(
dataset
,
dataloader_kwargs
,
ddp_seed
)
dataset
,
shuffle
=
dataloader_kwargs
[
'shuffle'
],
drop_last
=
dataloader_kwargs
[
'drop_last'
],
seed
=
ddp_seed
)
dataloader_kwargs
[
'shuffle'
]
=
False
dataloader_kwargs
[
'drop_last'
]
=
False
dataloader_kwargs
[
'sampler'
]
=
self
.
dist_sampler
dataloader_kwargs
[
'sampler'
]
=
self
.
dist_sampler
self
.
dataloader
=
DataLoader
(
self
.
dataloader
=
DataLoader
(
...
@@ -724,13 +735,7 @@ class EdgeDataLoader:
...
@@ -724,13 +735,7 @@ class EdgeDataLoader:
self
.
use_ddp
=
use_ddp
self
.
use_ddp
=
use_ddp
if
use_ddp
:
if
use_ddp
:
self
.
dist_sampler
=
DistributedSampler
(
self
.
dist_sampler
=
_create_dist_sampler
(
dataset
,
dataloader_kwargs
,
ddp_seed
)
dataset
,
shuffle
=
dataloader_kwargs
[
'shuffle'
],
drop_last
=
dataloader_kwargs
[
'drop_last'
],
seed
=
ddp_seed
)
dataloader_kwargs
[
'shuffle'
]
=
False
dataloader_kwargs
[
'drop_last'
]
=
False
dataloader_kwargs
[
'sampler'
]
=
self
.
dist_sampler
dataloader_kwargs
[
'sampler'
]
=
self
.
dist_sampler
self
.
dataloader
=
DataLoader
(
self
.
dataloader
=
DataLoader
(
...
@@ -835,13 +840,7 @@ class GraphDataLoader:
...
@@ -835,13 +840,7 @@ class GraphDataLoader:
self
.
use_ddp
=
use_ddp
self
.
use_ddp
=
use_ddp
if
use_ddp
:
if
use_ddp
:
self
.
dist_sampler
=
DistributedSampler
(
self
.
dist_sampler
=
_create_dist_sampler
(
dataset
,
dataloader_kwargs
,
ddp_seed
)
dataset
,
shuffle
=
dataloader_kwargs
[
'shuffle'
],
drop_last
=
dataloader_kwargs
[
'drop_last'
],
seed
=
ddp_seed
)
dataloader_kwargs
[
'shuffle'
]
=
False
dataloader_kwargs
[
'drop_last'
]
=
False
dataloader_kwargs
[
'sampler'
]
=
self
.
dist_sampler
dataloader_kwargs
[
'sampler'
]
=
self
.
dist_sampler
self
.
dataloader
=
DataLoader
(
dataset
=
dataset
,
self
.
dataloader
=
DataLoader
(
dataset
=
dataset
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment