Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
03ca11f5
Unverified
Commit
03ca11f5
authored
Jan 22, 2024
by
Rhett Ying
Committed by
GitHub
Jan 22, 2024
Browse files
[GraphBolt] fix random generator for shuffle among all workers (#6982)
parent
351c860a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
1 deletion
+8
-1
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+8
-1
No files found.
python/dgl/graphbolt/item_sampler.py
View file @
03ca11f5
...
...
@@ -115,6 +115,8 @@ class ItemShufflerAndBatcher:
rank : int
The rank of the current replica. Applies only when `distributed` is
True.
rng : np.random.Generator
The random number generator to use for shuffling.
"""
def
__init__
(
...
...
@@ -128,6 +130,7 @@ class ItemShufflerAndBatcher:
drop_uneven_inputs
:
Optional
[
bool
]
=
False
,
world_size
:
Optional
[
int
]
=
1
,
rank
:
Optional
[
int
]
=
0
,
rng
:
Optional
[
np
.
random
.
Generator
]
=
None
,
):
self
.
_item_set
=
item_set
self
.
_shuffle
=
shuffle
...
...
@@ -142,6 +145,7 @@ class ItemShufflerAndBatcher:
self
.
_drop_uneven_inputs
=
drop_uneven_inputs
self
.
_num_replicas
=
world_size
self
.
_rank
=
rank
self
.
_rng
=
rng
def
_collate_batch
(
self
,
buffer
,
indices
,
offsets
=
None
):
"""Collate a batch from the buffer. For internal use only."""
...
...
@@ -216,7 +220,7 @@ class ItemShufflerAndBatcher:
buffer
=
self
.
_item_set
[
start_offset
+
start
:
start_offset
+
end
]
indices
=
torch
.
arange
(
end
-
start
)
if
self
.
_shuffle
:
np
.
random
.
shuffle
(
indices
.
numpy
())
self
.
_rng
.
shuffle
(
indices
.
numpy
())
offsets
=
self
.
_calculate_offsets
(
buffer
)
for
i
in
range
(
0
,
len
(
indices
),
self
.
_batch_size
):
if
output_count
<=
0
:
...
...
@@ -494,6 +498,7 @@ class ItemSampler(IterDataPipe):
self
.
_drop_uneven_inputs
=
False
self
.
_world_size
=
None
self
.
_rank
=
None
self
.
_rng
=
np
.
random
.
default_rng
()
def
_organize_items
(
self
,
data_pipe
)
->
None
:
# Shuffle before batch.
...
...
@@ -529,6 +534,7 @@ class ItemSampler(IterDataPipe):
def
__iter__
(
self
)
->
Iterator
:
if
self
.
_use_indexing
:
seed
=
self
.
_rng
.
integers
(
0
,
np
.
iinfo
(
np
.
int32
).
max
)
data_pipe
=
IterableWrapper
(
ItemShufflerAndBatcher
(
self
.
_item_set
,
...
...
@@ -540,6 +546,7 @@ class ItemSampler(IterDataPipe):
drop_uneven_inputs
=
self
.
_drop_uneven_inputs
,
world_size
=
self
.
_world_size
,
rank
=
self
.
_rank
,
rng
=
np
.
random
.
default_rng
(
seed
),
)
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment