OpenDAS / dgl · Commit 44638b93 (Unverified)

fix ddp dataloader in heterogeneous cases (#3801)

Authored Mar 07, 2022 by Quan (Andy) Gan, committed by GitHub on Mar 07, 2022
Parent: bb6cec23
Showing 2 changed files with 9 additions and 7 deletions:

examples/pytorch/graphsage/multi_gpu_node_classification.py (+2, -2)
python/dgl/dataloading/dataloader.py (+7, -5)
examples/pytorch/graphsage/multi_gpu_node_classification.py
@@ -36,7 +36,7 @@ class SAGE(nn.Module):
         # example is that the intermediate results can also benefit from prefetching.
         g.ndata['h'] = g.ndata['feat']
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
-        dataloader = dgl.dataloading.NodeDataLoader(
+        dataloader = dgl.dataloading.DataLoader(
                 g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                 batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers,
                 persistent_workers=(num_workers > 0))
@@ -77,7 +77,7 @@ def train(rank, world_size, graph, num_classes, split_idx):
             graph, train_idx, sampler,
             device='cuda', batch_size=1000, shuffle=True, drop_last=False,
             num_workers=0, use_ddp=True, use_uva=True)
-    valid_dataloader = dgl.dataloading.NodeDataLoader(
+    valid_dataloader = dgl.dataloading.DataLoader(
             graph, valid_idx, sampler, device='cuda', batch_size=1024, shuffle=True,
             drop_last=False, num_workers=0, use_uva=True)
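The example changes only the class name: around DGL 0.8, the old NodeDataLoader entry point was folded into the unified dgl.dataloading.DataLoader. The heterogeneous case this commit targets is the one where the seed indices passed to that loader are a Mapping from node type to an index tensor rather than a single tensor. A minimal sketch of that call shape, assuming a toy two-type graph (the graph, node types, and batch size below are illustrative, not from the commit):

import dgl
import torch

# Hypothetical heterogeneous graph (illustrative only).
hg = dgl.heterograph({
    ('user', 'clicks', 'item'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])),
})
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)

# In the heterogeneous case, the seed indices are a Mapping from node type
# to an index tensor -- the case the DDP path mishandled before this fix.
train_idx = {'user': torch.arange(hg.num_nodes('user'))}
dataloader = dgl.dataloading.DataLoader(
    hg, train_idx, sampler, batch_size=2, shuffle=True, drop_last=False)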
python/dgl/dataloading/dataloader.py
@@ -169,8 +169,10 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
     def __init__(self, indices, batch_size, drop_last, ddp_seed):
         if isinstance(indices, Mapping):
             self._mapping_keys = list(indices.keys())
+            len_indices = sum(len(v) for v in indices.values())
         else:
             self._mapping_keys = None
+            len_indices = len(indices)
         self.rank = dist.get_rank()
         self.num_replicas = dist.get_world_size()
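This hunk is the root cause fix: for a heterogeneous graph, indices is a Mapping, and the code below this hunk previously called len(indices) on it, which counts node types rather than seed nodes. The new len_indices sums the per-type tensor lengths instead. A quick illustration with assumed toy tensors:

import torch

# Hypothetical per-node-type seed indices for a heterogeneous graph.
indices = {'user': torch.arange(6), 'item': torch.arange(4)}

print(len(indices))                           # 2  -- number of node types (the bug)
print(sum(len(v) for v in indices.values()))  # 10 -- total seeds (the fixed len_indices)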
@@ -179,17 +181,17 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
         self.batch_size = batch_size
         self.drop_last = drop_last
-        if self.drop_last and len(indices) % self.num_replicas != 0:
-            self.num_samples = math.ceil((len(indices) - self.num_replicas) / self.num_replicas)
+        if self.drop_last and len_indices % self.num_replicas != 0:
+            self.num_samples = math.ceil((len_indices - self.num_replicas) / self.num_replicas)
         else:
-            self.num_samples = math.ceil(len(indices) / self.num_replicas)
+            self.num_samples = math.ceil(len_indices / self.num_replicas)
         self.total_size = self.num_samples * self.num_replicas
         # If drop_last is True, we create a shared memory array larger than the number
         # of indices since we will need to pad it after shuffling to make it evenly
         # divisible before every epoch. If drop_last is False, we create an array
         # with the same size as the indices so we can trim it later.
-        self.shared_mem_size = self.total_size if not self.drop_last else len(indices)
-        self.num_indices = len(indices)
+        self.shared_mem_size = self.total_size if not self.drop_last else len_indices
+        self.num_indices = len_indices
         if isinstance(indices, Mapping):
             self._device = next(iter(indices.values())).device
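With len_indices correct, the per-replica bookkeeping in this hunk also becomes correct for Mapping indices. A worked example of the arithmetic, assuming 10 total seeds split across 4 DDP replicas (the numbers are illustrative, not from the commit):

import math

len_indices, num_replicas = 10, 4

# drop_last=False: round up, pad to a multiple of num_replicas, trim later.
num_samples = math.ceil(len_indices / num_replicas)                   # 3 per replica
total_size = num_samples * num_replicas                               # 12 (2 padded)

# drop_last=True (and 10 % 4 != 0): effectively round down.
num_samples = math.ceil((len_indices - num_replicas) / num_replicas)  # 2 per replica
total_size = num_samples * num_replicas                               # 8 (2 dropped)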