OpenDAS / dgl · Commit 4c147814 (unverified)
Authored May 14, 2022 by Quan (Andy) Gan; committed by GitHub, May 14, 2022
Parent: 65e6b04d

[Optimization] Memory consumption optimization on index shuffling in dataloader (#3980)

* fix
* revert
* Update dataloader.py
Showing 1 changed file with 15 additions and 11 deletions:

python/dgl/dataloading/dataloader.py (+15, -11)
@@ -11,6 +11,7 @@ import atexit
 import os
 import psutil
+import numpy as np
 import torch
 import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler
@@ -134,15 +135,13 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
         # Use a shared memory array to permute indices for shuffling. This is to make sure that
         # the worker processes can see it when persistent_workers=True, where self._indices
         # would not be duplicated every epoch.
-        self._indices = torch.empty(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
-        self._indices[:] = torch.arange(self._id_tensor.shape[0])
+        self._indices = torch.arange(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
         self.batch_size = batch_size
         self.drop_last = drop_last

     def shuffle(self):
         """Shuffle the dataset."""
-        # TODO: may need an in-place shuffle kernel
-        self._indices[:] = self._indices[torch.randperm(self._indices.shape[0])]
+        np.random.shuffle(self._indices.numpy())

     def __iter__(self):
         indices = _divide_by_worker(self._indices, self.batch_size, self.drop_last)
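For readers skimming the hunk above: the shuffle now mutates the shared-memory index tensor in place through its NumPy view, instead of gathering through a torch.randperm permutation. The following standalone sketch (not part of the commit; the size N is an arbitrary example) contrasts the two approaches.

import numpy as np
import torch

# Standalone sketch: a shared-memory index tensor like TensorizedDataset._indices.
N = 1_000_000  # arbitrary example size
indices = torch.arange(N, dtype=torch.int64).share_memory_()

# Gather-based shuffle: torch.randperm allocates an N-element permutation and
# the advanced indexing allocates another N-element result before the copy-back,
# so peak memory is a few times the size of the index array itself.
indices[:] = indices[torch.randperm(indices.shape[0])]

# In-place shuffle: Tensor.numpy() returns a view sharing the same buffer, and
# np.random.shuffle permutes that buffer in place, so the shared-memory storage
# seen by persistent workers is updated without extra N-element copies.
np.random.shuffle(indices.numpy())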
@@ -203,16 +202,20 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
     def _create_shared_indices(self):
         indices = torch.empty(self.shared_mem_size, dtype=torch.int64)
         num_ids = self._id_tensor.shape[0]
-        indices[:num_ids] = torch.arange(num_ids)
-        indices[num_ids:] = torch.arange(self.shared_mem_size - num_ids)
+        torch.arange(num_ids, out=indices[:num_ids])
+        torch.arange(self.shared_mem_size - num_ids, out=indices[num_ids:])
         return indices

     def shuffle(self):
         """Shuffles the dataset."""
         # Only rank 0 does the actual shuffling. The other ranks wait for it.
         if self.rank == 0:
-            self._indices[:self.num_indices] = self._indices[
-                torch.randperm(self.num_indices, device=self._device)]
+            if self._device == torch.device('cpu'):
+                np.random.shuffle(self._indices[:self.num_indices].numpy())
+            else:
+                self._indices[:self.num_indices] = self._indices[
+                    torch.randperm(self.num_indices, device=self._device)]
             if not self.drop_last:
                 # pad extra
                 self._indices[self.num_indices:] = \
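Two details in the hunk above are easy to miss. First, torch.arange(..., out=...) writes the range directly into a slice of the preallocated buffer instead of materializing a temporary tensor and copying it in. Second, on CPU the rank-0 shuffle reuses the in-place NumPy path, while indices on other devices keep the torch.randperm gather. Below is a minimal, standalone sketch of the out= form with made-up sizes (not DGL code).

import torch

# Standalone sketch: fill slices of a preallocated buffer without temporaries.
shared_mem_size, num_ids = 12, 10  # made-up example sizes

indices = torch.empty(shared_mem_size, dtype=torch.int64)

# Assignment form: builds a temporary arange tensor, then copies it into the slice.
indices[:num_ids] = torch.arange(num_ids)

# out= form: the slice is a view of `indices`, so the range is written directly
# into the existing storage and no separate num_ids-sized tensor is created.
torch.arange(num_ids, out=indices[:num_ids])
torch.arange(shared_mem_size - num_ids, out=indices[num_ids:])

print(indices)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1])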
@@ -514,9 +517,10 @@ class CollateWrapper(object):
         self.device = device

     def __call__(self, items):
-        if self.use_uva or (self.g.device != torch.device('cpu')):
-            # Only copy the indices to the given device if in UVA mode or the graph is not on
-            # CPU.
+        graph_device = getattr(self.g, 'device', None)
+        if self.use_uva or (graph_device != torch.device('cpu')):
+            # Only copy the indices to the given device if in UVA mode or the graph
+            # is not on CPU.
             items = recursive_apply(items, lambda x: x.to(self.device))
         batch = self.sample_func(self.g, items)
         return recursive_apply(batch, remove_parent_storage_columns, self.g)
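The CollateWrapper change replaces the direct self.g.device access with getattr(self.g, 'device', None). A plausible motivation, though the commit message does not state it, is tolerating a wrapper whose g is None; the sketch below is purely illustrative of that mechanic and is not DGL code.

import torch

# Illustrative only: how the getattr form behaves when the graph object is absent.
g = None
use_uva = False

graph_device = getattr(g, 'device', None)   # None instead of raising AttributeError
copy_needed = use_uva or (graph_device != torch.device('cpu'))
print(graph_device, copy_needed)            # None True

# The previous form would fail on the same input:
# g.device  # AttributeError: 'NoneType' object has no attribute 'device'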