OpenDAS / Megatron-LM

Commit fe207787 authored Apr 06, 2020 by Mohammad
added new gpt2 dataloader
parent 1eccd462
Showing 1 changed file with 385 additions and 0 deletions

megatron/data/new_gpt2_dataset.py (new file, 0 → 100644, +385 -0)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 Style dataset."""

import os
import time

import numpy as np
import torch
from torch.utils.data import Dataset


def print_rank_0(message):
    print(message)


class GPT2Dataset(Dataset):

    def __init__(self, name, data_prefix, documents, indexed_dataset,
                 num_samples, seq_length, seed):

        self.name = name
        self.data_prefix = data_prefix
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.seed = seed
        self.indexed_dataset = indexed_dataset

        # Checks
        assert np.min(documents) >= 0
        assert np.max(documents) < indexed_dataset.sizes.shape[0]

        # Build index mappings.
        self.num_epochs, self.doc_idx, self.sample_idx, self.shuffle_idx \
            = _build_index_mappings(self.name, self.data_prefix, documents,
                                    self.indexed_dataset.sizes,
                                    self.num_samples, self.seq_length,
                                    self.seed)
    def __len__(self):
        # -1 because sample_idx has num_samples + 1 rows: sample i is read
        # from sample_idx[i] up to sample_idx[i + 1].
        return self.sample_idx.shape[0] - 1
    def __getitem__(self, idx):
        # Get the shuffled index.
        idx = self.shuffle_idx[idx]
        # Start and end documents and offsets.
        doc_index_f = self.sample_idx[idx][0]
        doc_index_l = self.sample_idx[idx + 1][0]
        offset_f = self.sample_idx[idx][1]
        offset_l = self.sample_idx[idx + 1][1]
        # If we are within the same document, just extract the chunk.
        if doc_index_f == doc_index_l:
            sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                              offset=offset_f,
                                              length=offset_l - offset_f + 1)
        else:
            # Otherwise, get the rest of the initial document.
            sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                                    offset=offset_f)]
            # Loop over all in between documents and add the entire document.
            for i in range(doc_index_f + 1, doc_index_l):
                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
            # And finally add the relevant portion of last document.
            sample_list.append(self.indexed_dataset.get(
                self.doc_idx[doc_index_l],
                length=offset_l + 1))
            sample = np.concatenate(sample_list)
        return sample
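
# A worked example of the stitching above (illustrative values only):
# suppose sample_idx[idx] = [2, 5] and sample_idx[idx + 1] = [4, 3], so
# offset_f = 5 and offset_l = 3. The sample is then
#     doc_idx[2] from token 5 to the end of that document,
#     all of doc_idx[3],
#     doc_idx[4] from token 0 through token 3 (length offset_l + 1 = 4),
# concatenated into one contiguous array of seq_length + 1 tokens.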


def _build_index_mappings(name, data_prefix, documents, sizes,
                          num_samples, seq_length, seed):
    """Build the doc-idx, sample-idx, and shuffle-idx mappings."""
    # Number of tokens in each epoch and number of required epochs.
    tokens_per_epoch = _num_tokens(documents, sizes)
    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
    # rng state
    np_rng = np.random.RandomState(seed=seed)

    # Filename of the index mappings.
    _filename = data_prefix
    _filename += '_{}_indexmap'.format(name)
    _filename += '_{}ns'.format(num_samples)
    _filename += '_{}sl'.format(seq_length)
    _filename += '_{}s'.format(seed)
    doc_idx_filename = _filename + '_doc_idx.npy'
    sample_idx_filename = _filename + '_sample_idx.npy'
    shuffle_idx_filename = _filename + '_shuffle_idx.npy'

    # Build the indexed mapping if it does not exist.
    if True:  # torch.distributed.get_rank() == 0:
        if (not os.path.isfile(doc_idx_filename)) or \
           (not os.path.isfile(sample_idx_filename)) or \
           (not os.path.isfile(shuffle_idx_filename)):
            print_rank_0(' > WARNING: could not find index map files, building '
                         'the indices on rank 0 ...')
            # doc-idx.
            start_time = time.time()
            doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save doc-idx mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # sample-idx.
            start_time = time.time()
            sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
                                           num_epochs, tokens_per_epoch)
            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save sample-idx mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # shuffle-idx.
            start_time = time.time()
            # -1 matches __len__: there are sample_idx.shape[0] - 1 samples.
            shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
                         ' (seconds): {:4f}'.format(time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    # torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    # assert counts[0].item() == torch.distributed.get_world_size(
    #     group=mpu.get_data_parallel_group())

    # Load mappings.
    start_time = time.time()
    print_rank_0(' > loading doc-idx mapping from {}'.format(
        doc_idx_filename))
    doc_idx = np.load(doc_idx_filename, allow_pickle=True)
    print_rank_0(' > loading sample-idx mapping from {}'.format(
        sample_idx_filename))
    sample_idx = np.load(sample_idx_filename, allow_pickle=True)
    print_rank_0(' > loading shuffle-idx mapping from {}'.format(
        shuffle_idx_filename))
    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True)
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        sample_idx.shape[0]))

    return num_epochs, doc_idx, sample_idx, shuffle_idx
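
# Illustration of the cache naming above, with assumed values
# data_prefix = 'junk/gpt2', name = 'gpt2', num_samples = 100,
# seq_length = 16, seed = 1234; the three mappings would be saved as:
#     junk/gpt2_gpt2_indexmap_100ns_16sl_1234s_doc_idx.npy
#     junk/gpt2_gpt2_indexmap_100ns_16sl_1234s_sample_idx.npy
#     junk/gpt2_gpt2_indexmap_100ns_16sl_1234s_shuffle_idx.npy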


def _num_tokens(documents, sizes):
    """Total number of tokens in the dataset."""
    return np.sum(sizes[documents])


def _num_epochs(tokens_per_epoch, seq_length, num_samples):
    """Based on the number of samples and sequence length, calculate how many
    epochs will be needed."""
    num_epochs = 0
    total_tokens = 0
    while True:
        num_epochs += 1
        total_tokens += tokens_per_epoch
        # -1 is because we need to retrieve seq_length + 1 token each time
        # but the last token will overlap with the first token of the next
        # sample except for the last sample.
        if ((total_tokens - 1) // seq_length) >= num_samples:
            return num_epochs
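
# Worked example with assumed values: tokens_per_epoch = 10, seq_length = 4,
# num_samples = 5. After epoch 1, (10 - 1) // 4 = 2 samples; after epoch 2,
# 19 // 4 = 4; after epoch 3, 29 // 4 = 7 >= 5, so three epochs are needed.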


def _build_doc_idx(documents, num_epochs, np_rng):
    """Build an array with length = number-of-epochs * number-of-documents.
    Each index is mapped to a corresponding document."""
    doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
    doc_idx[:] = documents
    doc_idx = doc_idx.reshape(-1)
    np_rng.shuffle(doc_idx)
    return doc_idx
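
# Example with assumed values: documents = [5, 6, 7], num_epochs = 2. The
# mgrid trick yields [[5, 6, 7], [5, 6, 7]], flattened to [5, 6, 7, 5, 6, 7]
# and shuffled in place, so each document id appears exactly num_epochs times
# in random order.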


def _build_sample_idx(sizes, doc_idx, seq_length,
                      num_epochs, tokens_per_epoch):
    """Sample index mapping is a 2D array with sizes
    [number-of-samples + 1, 2] where [..., 0] contains
    the index into `doc_idx` and [..., 1] is the
    starting offset in that document."""
    # Total number of samples. For -1 see comments in `_num_epochs`.
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
    sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)

    # Index into sample_idx.
    sample_index = 0
    # Index into doc_idx.
    doc_idx_index = 0
    # Beginning offset for each document.
    doc_offset = 0
    # Start with first document and no offset.
    sample_idx[sample_index][0] = doc_idx_index
    sample_idx[sample_index][1] = doc_offset
    sample_index += 1

    while sample_index <= num_samples:
        # Start with a fresh sequence.
        remaining_seq_length = seq_length + 1
        while remaining_seq_length != 0:
            # Get the document length.
            doc_id = doc_idx[doc_idx_index]
            doc_length = sizes[doc_id] - doc_offset
            # And add it to the current sequence.
            remaining_seq_length -= doc_length
            # If we have more than a full sequence, adjust offset and set
            # remaining length to zero so we return from the while loop.
            # Note that -1 here is for the same reason we have -1 in
            # `_num_epochs` calculations.
            if remaining_seq_length <= 0:
                doc_offset += (remaining_seq_length + doc_length - 1)
                remaining_seq_length = 0
            else:
                # Otherwise, start from the beginning of the next document.
                doc_idx_index += 1
                doc_offset = 0
        # Record the sequence.
        sample_idx[sample_index][0] = doc_idx_index
        sample_idx[sample_index][1] = doc_offset
        sample_index += 1

    return sample_idx
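
# A short trace with assumed values: sizes = [3, 5], doc_idx = [0, 1],
# seq_length = 4, num_epochs = 1, tokens_per_epoch = 8, so
# num_samples = (8 - 1) // 4 = 1 and sample_idx[0] = [0, 0]. The first
# sample needs seq_length + 1 = 5 tokens: document 0 contributes 3, so we
# advance to doc_idx_index = 1, where the remaining 2 tokens end at offset 1;
# hence sample_idx[1] = [1, 1], i.e. sample 0 spans all of document 0 plus
# tokens 0..1 of document 1.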


def _build_shuffle_idx(size, np_rng):
    """Build the range [0, size) and shuffle."""
    dtype_ = np.uint32
    if size >= (np.iinfo(np.uint32).max - 1):
        dtype_ = np.int64
    shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
    # np_rng.shuffle(shuffle_idx)
    return shuffle_idx


class IndexedDataset:
    """Mock indexed dataset filled with random token documents, for testing."""

    def __init__(self, num_docs, min_doc_length, max_doc_length, seq_length):
        self.seq_length = seq_length
        assert min_doc_length > 0
        self.tokens = []
        self.sizes = np.zeros(num_docs, dtype=np.int32)
        for i in range(num_docs):
            size = np.random.randint(low=min_doc_length, high=max_doc_length,
                                     size=1, dtype=np.uint32)[0]
            tokens_ = np.random.randint(low=1, high=60000, size=size,
                                        dtype=np.uint32)
            tokens_[-1] = 0
            self.sizes[i] = size
            self.tokens.append(tokens_)
        self.tokens_flat = None

    def get(self, doc_idx, offset=None, length=None):
        if length is None:
            if offset is None:
                return self.tokens[doc_idx]
            else:
                return self.tokens[doc_idx][offset:]
        if offset is None:
            return self.tokens[doc_idx][0:length]
        return self.tokens[doc_idx][offset:(offset + length)]

    def get_sample(self, index):
        start = index * self.seq_length
        end = start + self.seq_length + 1
        return self.tokens_flat[start:end]

    def build_tokens_flat(self, doc_idx):
        self.tokens_flat = np.concatenate([self.tokens[i] for i in doc_idx])


def test(seed, data_prefix, seq_length, num_samples, num_docs,
         min_doc_length, max_doc_length):
    print('testing for seed: {}, seq-length: {}, num-samples: {}, '
          'num-docs: {}, min-doc-length: {}, max-doc-length: {}'.format(
              seed, seq_length, num_samples, num_docs,
              min_doc_length, max_doc_length))
    np.random.seed(seed)
    indexed_dataset = IndexedDataset(num_docs, min_doc_length,
                                     max_doc_length, seq_length)
    indices = np.random.randint(indexed_dataset.sizes.shape[0] - 2, size=2)
    documents = np.arange(np.min(indices), np.max(indices) + 1)
    dataset = GPT2Dataset('gpt2', data_prefix, documents, indexed_dataset,
                          num_samples, seq_length, seed)
    print(' > number of epochs:', dataset.num_epochs)
    indexed_dataset.build_tokens_flat(dataset.doc_idx)
    for idx in range(num_samples):
        a = dataset[idx]
        b = indexed_dataset.get_sample(idx)
        assert np.sum(a - b) == 0
    print('passed')


if __name__ == '__main__':

    print('gpt2 dataset ...')
    import random

    data_prefix = 'junk/'
    for seed in range(1234, 1240):
        random.seed(seed)
        num_docs = random.randint(1, 999)
        min_doc_length = random.randint(1, 99)
        max_doc_length = random.randint(100, 9999)
        num_samples = random.randint(num_docs, 100 * num_docs)
        seq_length = random.randint(min_doc_length, max_doc_length)
        test(seed, data_prefix, seq_length, num_samples, num_docs,
             min_doc_length, max_doc_length)
    exit()
'''
num_docs = 5
min_doc_length = 2
max_doc_length = 10
num_samples = 9
seq_length = 4
seed = 1234
np.random.seed(seed)
indexed_dataset = IndexedDataset(num_docs, min_doc_length,
                                 max_doc_length, seq_length)
print('> indexed dataset:')
for s in indexed_dataset.tokens:
    print('    {}'.format(s))
documents = np.array([1, 2, 3], dtype=np.int32)
dataset = GPT2Dataset('gpt2', 'junk/', documents, indexed_dataset,
                      num_samples, seq_length, seed)
indexed_dataset.build_tokens_flat(dataset.doc_idx)
print(indexed_dataset.get_sample(6))
print(dataset[6])
'''
'''
myds = MyDataset(ds, num_samples, seq_length)
num_docs = myds._num_docs()
print('> number of document: {}'.format(num_docs))
tokens_per_epoch = myds._num_tokens()
print('> number of tokens: {}'.format(tokens_per_epoch))
num_epochs = myds._num_epochs(tokens_per_epoch)
print('> number of epochs: {}'.format(num_epochs))
doc_idx = myds._build_doc_idx(num_docs, num_epochs)
print('> doc_idx: {}'.format(doc_idx))
ds.build_tokens_flat(doc_idx)
sample_idx = myds._build_sample_idx(num_epochs, tokens_per_epoch, doc_idx)
for s in sample_idx:
    print(s)
print(ds.tokens_flat)
print(myds.get_sample(8))
print(ds.get_sample(8))
'''
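
# A minimal usage sketch with illustrative values (the mock IndexedDataset
# above stands in for Megatron's real binary-backed indexed dataset, and the
# 'junk/' cache prefix is assumed to exist):
#
#     np.random.seed(1234)
#     indexed = IndexedDataset(num_docs=10, min_doc_length=5,
#                              max_doc_length=50, seq_length=16)
#     documents = np.arange(indexed.sizes.shape[0])
#     dataset = GPT2Dataset('gpt2', 'junk/', documents, indexed,
#                           num_samples=100, seq_length=16, seed=1234)
#     sample = dataset[0]  # np.uint32 array of seq_length + 1 = 17 token ids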