Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
39181113
Commit
39181113
authored
Dec 11, 2020
by
mshoeybi
Committed by
Deepak Narayanan
Dec 19, 2020
Browse files
Last epoch should not be globally shuffled
parent
56243e19
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
74 additions
and
14 deletions
+74
-14
megatron/data/gpt2_dataset.py
megatron/data/gpt2_dataset.py
+74
-14
No files found.
megatron/data/gpt2_dataset.py
View file @
39181113
...
...
@@ -219,9 +219,47 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
print_rank_0
(
' > WARNING: could not find index map files, building '
'the indices on rank 0 ...'
)
# For the last epoch, decide whether to include the entire epoch
# in the global shuffle or not.
# If we need only one epoch, then separating last epoch does
# not mean anything.
if
num_epochs
==
1
:
separate_last_epoch
=
False
print
(
' > only one epoch required, setting '
'separate_last_epoch to False'
,
flush
=
True
)
else
:
# Get the number of samples for the last epoch
num_samples_from_epochs_minus_one
=
(
(
num_epochs
-
1
)
*
tokens_per_epoch
-
1
)
//
seq_length
last_epoch_num_samples
=
num_samples
-
\
num_samples_from_epochs_minus_one
assert
last_epoch_num_samples
>=
0
,
\
'last epoch number of samples should be non-negative.'
num_samples_per_epoch
=
(
tokens_per_epoch
-
1
)
//
seq_length
assert
last_epoch_num_samples
<
(
num_samples_per_epoch
+
1
),
\
'last epoch number of samples exceeded max value.'
# If we have fewer than 80% of the samples for the last epoch,
# separate out the epoch and treat it differently.
separate_last_epoch
=
(
last_epoch_num_samples
<
int
(
0.80
*
num_samples_per_epoch
))
if
separate_last_epoch
:
string
=
' > last epoch number of samples ({}) is smaller '
\
'than 80% of number of samples per epoch ({}), '
\
'setting separate_last_epoch to True'
else
:
string
=
' > last epoch number of samples ({}) is larger '
\
'than 80% of number of samples per epoch ({}), '
\
'setting separate_last_epoch to False'
print
(
string
.
format
(
last_epoch_num_samples
,
num_samples_per_epoch
),
flush
=
True
)
# doc-idx.
start_time
=
time
.
time
()
doc_idx
=
_build_doc_idx
(
documents
,
num_epochs
,
np_rng
)
doc_idx
=
_build_doc_idx
(
documents
,
num_epochs
,
np_rng
,
separate_last_epoch
)
np
.
save
(
doc_idx_filename
,
doc_idx
,
allow_pickle
=
True
)
print_rank_0
(
' > elasped time to build and save doc-idx mapping '
'(seconds): {:4f}'
.
format
(
time
.
time
()
-
start_time
))
...
...
@@ -245,7 +283,12 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
start_time
=
time
.
time
()
# -1 is due to data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
shuffle_idx
=
_build_shuffle_idx
(
sample_idx
.
shape
[
0
]
-
1
,
np_rng
)
if
separate_last_epoch
:
num_samples_
=
num_samples_from_epochs_minus_one
else
:
num_samples_
=
sample_idx
.
shape
[
0
]
-
1
shuffle_idx
=
_build_shuffle_idx
(
num_samples_
,
sample_idx
.
shape
[
0
]
-
1
,
np_rng
)
np
.
save
(
shuffle_idx_filename
,
shuffle_idx
,
allow_pickle
=
True
)
print_rank_0
(
' > elasped time to build and save shuffle-idx mapping'
' (seconds): {:4f}'
.
format
(
time
.
time
()
-
start_time
))
...
...
@@ -300,15 +343,20 @@ def _num_epochs(tokens_per_epoch, seq_length, num_samples):
return
num_epochs
def
_build_doc_idx
(
documents
,
num_epochs
,
np_rng
):
def
_build_doc_idx
(
documents
,
num_epochs
,
np_rng
,
separate_last_epoch
):
"""Build an array with length = number-of-epochs * number-of-documents.
Each index is mapped to a corresponding document."""
doc_idx
=
np
.
mgrid
[
0
:
num_epochs
,
0
:
len
(
documents
)][
1
]
doc_idx
[:]
=
documents
doc_idx
=
doc_idx
.
reshape
(
-
1
)
doc_idx
=
doc_idx
.
astype
(
np
.
int32
)
np_rng
.
shuffle
(
doc_idx
)
return
doc_idx
if
not
separate_last_epoch
or
num_epochs
==
1
:
doc_idx
=
np
.
mgrid
[
0
:
num_epochs
,
0
:
len
(
documents
)][
1
]
doc_idx
[:]
=
documents
doc_idx
=
doc_idx
.
reshape
(
-
1
)
doc_idx
=
doc_idx
.
astype
(
np
.
int32
)
np_rng
.
shuffle
(
doc_idx
)
return
doc_idx
doc_idx_first
=
_build_doc_idx
(
documents
,
num_epochs
-
1
,
np_rng
,
False
)
doc_idx_last
=
_build_doc_idx
(
documents
,
1
,
np_rng
,
False
)
return
np
.
concatenate
((
doc_idx_first
,
doc_idx_last
))
def
_build_sample_idx
(
sizes
,
doc_idx
,
seq_length
,
...
...
@@ -360,11 +408,23 @@ def _build_sample_idx(sizes, doc_idx, seq_length,
return
sample_idx
def
_build_shuffle_idx
(
size
,
np_rng
):
def
def
_build_shuffle_idx
(
num_samples
,
total_
size
,
np_rng
):
"""Build the range [0, total_size), shuffling [0, num_samples) and
[num_samples, total_size) independently of each other."""
print
(
' > building shuffle index with split [0, {}) and [{}, {}) '
'...'
.
format
(
num_samples
,
num_samples
,
total_size
),
flush
=
True
)
dtype_
=
np
.
uint32
if
size
>=
(
np
.
iinfo
(
np
.
uint32
).
max
-
1
):
if
total_
size
>=
(
np
.
iinfo
(
np
.
uint32
).
max
-
1
):
dtype_
=
np
.
int64
shuffle_idx
=
np
.
arange
(
start
=
0
,
stop
=
size
,
step
=
1
,
dtype
=
dtype_
)
np_rng
.
shuffle
(
shuffle_idx
)
return
shuffle_idx
shuffle_idx_first
=
np
.
arange
(
start
=
0
,
stop
=
num_samples
,
step
=
1
,
dtype
=
dtype_
)
np_rng
.
shuffle
(
shuffle_idx_first
)
if
num_samples
==
total_size
:
return
shuffle_idx_first
shuffle_idx_last
=
np
.
arange
(
start
=
num_samples
,
stop
=
total_size
,
step
=
1
,
dtype
=
dtype_
)
np_rng
.
shuffle
(
shuffle_idx_last
)
return
np
.
concatenate
((
shuffle_idx_first
,
shuffle_idx_last
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment