OpenDAS / dgl · Commit 44638b93 (unverified)
Authored Mar 07, 2022 by Quan (Andy) Gan; committed via GitHub on Mar 07, 2022
fix ddp dataloader in heterogeneous cases (#3801)
Parent: bb6cec23

Showing 2 changed files with 9 additions and 7 deletions:

- examples/pytorch/graphsage/multi_gpu_node_classification.py (+2, -2)
- python/dgl/dataloading/dataloader.py (+7, -5)
examples/pytorch/graphsage/multi_gpu_node_classification.py (view file @ 44638b93)

@@ -36,7 +36,7 @@ class SAGE(nn.Module):
         # example is that the intermediate results can also benefit from prefetching.
         g.ndata['h'] = g.ndata['feat']
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
-        dataloader = dgl.dataloading.NodeDataLoader(
+        dataloader = dgl.dataloading.DataLoader(
                 g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                 batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers,
                 persistent_workers=(num_workers > 0))
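The only change in this hunk is the class name: dgl.dataloading.NodeDataLoader becomes the unified dgl.dataloading.DataLoader. Below is a minimal, self-contained sketch of the updated call, assuming a toy homogeneous graph built with dgl.rand_graph stands in for the example's OGB graph; variable names mirror the script.

import dgl
import torch

g = dgl.rand_graph(1000, 5000)   # hypothetical stand-in for the example's graph
device = 'cpu'
num_workers = 0

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
    batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers,
    persistent_workers=(num_workers > 0))

for input_nodes, output_nodes, blocks in dataloader:
    pass   # each iteration yields the sampled message flow graphs ("blocks")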
@@ -77,7 +77,7 @@ def train(rank, world_size, graph, num_classes, split_idx):
         graph, train_idx, sampler,
         device='cuda', batch_size=1000, shuffle=True, drop_last=False,
         num_workers=0, use_ddp=True, use_uva=True)
-    valid_dataloader = dgl.dataloading.NodeDataLoader(
+    valid_dataloader = dgl.dataloading.DataLoader(
         graph, valid_idx, sampler, device='cuda', batch_size=1024, shuffle=True,
         drop_last=False, num_workers=0, use_uva=True)
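Same rename on the training side. The use_ddp=True flag is what routes the seed indices through DDPTensorizedDataset (patched in the next file) so that each rank iterates a disjoint shard. A hedged single-process sketch, assuming a one-rank gloo process group stands in for the example's multi-GPU setup (use_uva and device='cuda' are dropped so it runs on CPU):

import os
import dgl
import torch
import torch.distributed as dist

# One-rank process group so the DDP-aware dataloader can query rank/world size.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

g = dgl.rand_graph(1000, 5000)            # toy stand-in for the OGB graph
train_idx = torch.arange(g.num_nodes())
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)

dataloader = dgl.dataloading.DataLoader(
    g, train_idx, sampler,
    batch_size=1000, shuffle=True, drop_last=False,
    num_workers=0, use_ddp=True)

for epoch in range(2):
    dataloader.set_epoch(epoch)           # keep shuffling consistent across ranks
    for input_nodes, output_nodes, blocks in dataloader:
        pass

dist.destroy_process_group()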
python/dgl/dataloading/dataloader.py (view file @ 44638b93)

@@ -169,8 +169,10 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
     def __init__(self, indices, batch_size, drop_last, ddp_seed):
         if isinstance(indices, Mapping):
             self._mapping_keys = list(indices.keys())
+            len_indices = sum(len(v) for v in indices.values())
         else:
             self._mapping_keys = None
+            len_indices = len(indices)
 
         self.rank = dist.get_rank()
         self.num_replicas = dist.get_world_size()
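This is the heart of the fix. For heterogeneous graphs the seed indices arrive as a Mapping from node type to an ID tensor, and len(indices) on a Mapping counts node types, not seeds; len_indices sums the per-type tensor lengths instead. A tiny illustration with hypothetical node types:

import torch

# Two node types, 100 + 50 seed nodes.
indices = {'user': torch.arange(100), 'item': torch.arange(50)}

assert len(indices) == 2                              # counts node types: wrong size
len_indices = sum(len(v) for v in indices.values())
assert len_indices == 150                             # counts seeds: what DDP needs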
@@ -179,17 +181,17 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
         self.batch_size = batch_size
         self.drop_last = drop_last
 
-        if self.drop_last and len(indices) % self.num_replicas != 0:
-            self.num_samples = math.ceil((len(indices) - self.num_replicas) / self.num_replicas)
+        if self.drop_last and len_indices % self.num_replicas != 0:
+            self.num_samples = math.ceil((len_indices - self.num_replicas) / self.num_replicas)
         else:
-            self.num_samples = math.ceil(len(indices) / self.num_replicas)
+            self.num_samples = math.ceil(len_indices / self.num_replicas)
         self.total_size = self.num_samples * self.num_replicas
         # If drop_last is True, we create a shared memory array larger than the number
         # of indices since we will need to pad it after shuffling to make it evenly
         # divisible before every epoch. If drop_last is False, we create an array
         # with the same size as the indices so we can trim it later.
-        self.shared_mem_size = self.total_size if not self.drop_last else len(indices)
+        self.shared_mem_size = self.total_size if not self.drop_last else len_indices
 
-        self.num_indices = len(indices)
+        self.num_indices = len_indices
 
         if isinstance(indices, Mapping):
             self._device = next(iter(indices.values())).device
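Worked numbers for the sizing arithmetic above, assuming the 150 heterogeneous seeds from the previous sketch split across 4 replicas:

import math

len_indices, num_replicas = 150, 4

# drop_last=True: 150 % 4 != 0, so each rank takes
num_samples = math.ceil((len_indices - num_replicas) / num_replicas)   # 37
total_size = num_samples * num_replicas                                # 148; 2 seeds dropped

# drop_last=False: every seed is kept and the tail is padded
num_samples = math.ceil(len_indices / num_replicas)                    # 38
total_size = num_samples * num_replicas                                # 152; padded by 2

Before the fix, len(indices) evaluated to 2 here (the number of node types), so num_samples and total_size were computed from the type count rather than the 150 seeds; with len_indices the per-rank shards cover the full seed set.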