OpenDAS / Megatron-LM · Commits

Commit c3691006, authored Apr 08, 2020 by Mohammad

    testing new gpt2 dataset

parent 836c6776

Showing 2 changed files with 30 additions and 85 deletions:
  megatron/data/bert_dataset.py        +1  -1
  megatron/data/new_gpt2_dataset.py    +29 -84
megatron/data/bert_dataset.py

@@ -24,7 +24,6 @@ from torch.utils.data import Dataset
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.data import helpers
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0
@@ -249,6 +248,7 @@ def get_samples_mapping_(indexed_dataset,
     start_time = time.time()
     print_rank_0(' > building sapmles index mapping for {} ...'.format(name))
+    from megatron.data import helpers
     samples_mapping = helpers.build_mapping(
         indexed_dataset.doc_idx,
         indexed_dataset.sizes,
megatron/data/new_gpt2_dataset.py

@@ -13,26 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""GPT2 Style dataset."""
+"""GPT2 style dataset."""

 import os
 import time

 import numpy as np
 import torch
 from torch.utils.data import Dataset

-import helpers
-#from bert_dataset import get_train_valid_test_split_
-
-def print_rank_0(message):
-    print(message)
+from megatron import print_rank_0
+from megatron import mpu
+from megatron.data.bert_dataset import get_train_valid_test_split_
+from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


 def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup):
     """Build train, valid, and test datasets."""

     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix,
@@ -56,7 +54,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     def build_dataset(index, name):
         dataset = None
         if splits[index + 1] > splits[index]:
-            documents = np.arange(start=splits[index], end=splits[index + 1],
+            documents = np.arange(start=splits[index], stop=splits[index + 1],
                                   step=1, dtype=np.int32)
             dataset = GPT2Dataset(name, data_prefix, documents, indexed_dataset,
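Note: the one-word change above fixes a real bug rather than style. np.arange takes the keyword stop, not end, so the old call raises a TypeError at runtime. A minimal sketch with hypothetical split points:

import numpy as np

# Hypothetical split points over 100 documents: 80/10/10 train/valid/test.
splits = [0, 80, 90, 100]

# Old: np.arange(start=..., end=...) raises
#   TypeError: arange() got an unexpected keyword argument 'end'
# 'stop' is the keyword NumPy actually accepts:
documents = np.arange(start=splits[1], stop=splits[2], step=1, dtype=np.int32)
print(documents)  # [80 81 82 83 84 85 86 87 88 89]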
@@ -72,7 +70,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
     """Build indexed dataset."""
     print_rank_0(' > building dataset index ...')
     start_time = time.time()
@@ -81,25 +79,18 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
                                            skip_warmup)
     print_rank_0(' > finished creating indexed dataset in {:4f} '
                  'seconds'.format(time.time() - start_time))
     print_rank_0(' > indexed dataset stats:')
     print_rank_0('    number of documents: {}'.format(
         indexed_dataset.sizes.shape[0]))

     return indexed_dataset


-class GPT2Dataset(Dataset):
+class GPT2Dataset(torch.utils.data.Dataset):

     def __init__(self, name, data_prefix, documents, indexed_dataset,
                  num_samples, seq_length, seed):

         self.name = name
-        self.data_prefix = data_prefix
-        self.num_samples = num_samples
-        self.seq_length = seq_length
-        self.seed = seed
         self.indexed_dataset = indexed_dataset

         # Checks
@@ -107,11 +98,9 @@ class GPT2Dataset(Dataset):
         assert np.max(documents) < indexed_dataset.sizes.shape[0]

         # Build index mappings.
-        self.num_epochs, self.doc_idx, self.sample_idx, self.shuffle_idx \
-            = _build_index_mappings(self.name, self.data_prefix, documents,
-                                    self.indexed_dataset.sizes,
-                                    self.num_samples, self.seq_length,
-                                    self.seed)
+        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
+            self.name, data_prefix, documents, self.indexed_dataset.sizes,
+            num_samples, seq_length, seed)

     def __len__(self):
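Note: the three arrays unpacked here compose at lookup time: shuffle_idx randomizes sample order, sample_idx maps a sample to a (position in doc_idx, token offset) pair, and doc_idx maps positions to shuffled document ids. A sketch of how __getitem__ plausibly composes them, assuming sample_idx rows bound each sample and that indexed_dataset.get(doc, offset=..., length=...) returns tokens (names taken from the surrounding hunks, not guaranteed to match the file exactly):

import numpy as np

def get_sample(idx, doc_idx, sample_idx, shuffle_idx, indexed_dataset):
    # A sketch, not the file's exact code.
    idx = shuffle_idx[idx]                       # randomized sample order
    doc_index_f, offset_f = sample_idx[idx]      # where sample idx starts
    doc_index_l, offset_l = sample_idx[idx + 1]  # where the next one starts
    if doc_index_f == doc_index_l:
        # The whole sample lives inside one document.
        return indexed_dataset.get(doc_idx[doc_index_f], offset=offset_f,
                                   length=offset_l - offset_f + 1)
    # Otherwise stitch pieces across document boundaries.
    sample_list = [indexed_dataset.get(doc_idx[doc_index_f], offset=offset_f)]
    for i in range(doc_index_f + 1, doc_index_l):
        sample_list.append(indexed_dataset.get(doc_idx[i]))
    sample_list.append(indexed_dataset.get(doc_idx[doc_index_l],
                                           length=offset_l + 1))
    return np.concatenate(sample_list)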
@@ -144,7 +133,7 @@ class GPT2Dataset(Dataset):
                                                     length=offset_l + 1))
             sample = np.concatenate(sample_list)

-        return sample
+        return {'text': np.array(sample, dtype=np.int64)}
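Note: returning {'text': ...} instead of a bare array means PyTorch's default collate_fn yields a dict of batched tensors, which downstream batch-building code can index by key. A hypothetical consumption sketch (train_ds and the batch size are made up); each sample carries seq_length + 1 tokens so inputs and labels come from a one-token shift:

import torch

# train_ds: a GPT2Dataset instance (hypothetical).
loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
batch = next(iter(loader))   # {'text': LongTensor of shape [8, seq_length + 1]}
tokens = batch['text'][:, :-1].contiguous()   # model inputs
labels = batch['text'][:, 1:].contiguous()    # next-token targets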
@@ -156,7 +145,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
     num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
     # rng state
     np_rng = np.random.RandomState(seed=seed)

     # Filename of the index mappings.
     _filename = data_prefix
     _filename += '_{}_indexmap'.format(name)
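Note: _num_epochs is not shown in this diff; its job is to find the smallest number of passes over the corpus that supplies num_samples samples of seq_length tokens. A sketch consistent with how it is called above (an assumption, not the file's verified body; the - 1 leaves room for the one-token label shift):

def _num_epochs(tokens_per_epoch, seq_length, num_samples):
    # Smallest epoch count whose cumulative token total covers all samples.
    num_epochs = 0
    total_tokens = 0
    while True:
        num_epochs += 1
        total_tokens += tokens_per_epoch
        if ((total_tokens - 1) // seq_length) >= num_samples:
            return num_epochs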
@@ -168,11 +157,11 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
     shuffle_idx_filename = _filename + '_shuffle_idx.npy'

     # Build the indexed mapping if not exist.
-    if True: # torch.distributed.get_rank() == 0:
+    if torch.distributed.get_rank() == 0:
         if (not os.path.isfile(doc_idx_filename)) or \
            (not os.path.isfile(sample_idx_filename)) or \
            (not os.path.isfile(shuffle_idx_filename)):
             print_rank_0(' > WARNING: could not find index map files, building '
                          'the indices on rank 0 ...')

             # doc-idx.
@@ -183,7 +172,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
                          '(seconds): {:4f}'.format(time.time() - start_time))
             # sample-idx.
             start_time = time.time()
-            import helpers
+            # Use C++ implementation for speed.
+            from megatron.data import helpers
+            assert doc_idx.dtype == np.int32
+            assert sizes.dtype == np.int32
             sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                                   num_epochs, tokens_per_epoch)
             #sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
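Note: helpers.build_sample_idx is the compiled C++ path (hence the dtype asserts added above), and the commented-out _build_sample_idx on the last line is its Python fallback. For readers without the compiled extension, a pure-NumPy sketch of what it computes, assuming row i holds the doc_idx position and token offset where sample i begins, so that consecutive rows bound one sample of seq_length + 1 tokens:

import numpy as np

def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):
    # Total samples one pass of num_epochs epochs can supply.
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
    sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)

    sample_index = 0
    doc_idx_index = 0   # position in doc_idx
    doc_offset = 0      # token offset inside the current document
    sample_idx[sample_index] = (doc_idx_index, doc_offset)
    sample_index += 1
    while sample_index <= num_samples:
        # Consume seq_length tokens (+1 for the label shift).
        remaining_seq_length = seq_length + 1
        while remaining_seq_length != 0:
            doc_id = doc_idx[doc_idx_index]
            doc_length = sizes[doc_id] - doc_offset
            remaining_seq_length -= doc_length
            if remaining_seq_length <= 0:
                # Sample ends inside this document; the next sample starts
                # one token back so its first token is this sample's label.
                doc_offset += remaining_seq_length + doc_length - 1
                remaining_seq_length = 0
            else:
                # Document exhausted; continue into the next one.
                doc_idx_index += 1
                doc_offset = 0
        sample_idx[sample_index] = (doc_idx_index, doc_offset)
        sample_index += 1
    return sample_idx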
@@ -202,9 +194,9 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
     # device_index=rank which is not the case for model
     # parallel case
     counts = torch.cuda.LongTensor([1])
-    # torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    # assert counts[0].item() == torch.distributed.get_world_size(
-    #     group=mpu.get_data_parallel_group())
+    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+    assert counts[0].item() == torch.distributed.get_world_size(
+        group=mpu.get_data_parallel_group())

     # Load mappings.
     start_time = time.time()
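Note: re-enabling the all_reduce turns it into a barrier with a sanity check. Rank 0 builds and saves the .npy mappings; every other rank blocks on the collective until rank 0 joins it, after which the files are guaranteed to exist on shared storage. The pattern in isolation (a sketch, assuming an initialized process group; build_fn is a placeholder):

import torch

def build_then_sync(build_fn, group):
    # Rank 0 writes the index-map files; everyone else skips the build.
    if torch.distributed.get_rank() == 0:
        build_fn()
    # all_reduce doubles as a barrier: no rank returns until all ranks
    # (including rank 0, now done writing) have entered the collective.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=group)
    assert counts[0].item() == torch.distributed.get_world_size(group=group)
    # Safe for every rank to np.load the mappings from here on.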
@@ -221,8 +213,9 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         time.time() - start_time))
     print_rank_0('    total number of samples: {}'.format(
         sample_idx.shape[0]))
     print_rank_0('    total number of epochs: {}'.format(num_epochs))

-    return num_epochs, doc_idx, sample_idx, shuffle_idx
+    return doc_idx, sample_idx, shuffle_idx


 def _num_tokens(documents, sizes):
@@ -311,10 +304,11 @@ def _build_shuffle_idx(size, np_rng):
     if size >= (np.iinfo(np.uint32).max - 1):
         dtype_ = np.int64
     shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
-    # np_rng.shuffle(shuffle_idx)
+    np_rng.shuffle(shuffle_idx)
     return shuffle_idx


 '''
 class IndexedDataset:
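Note: the dtype selection above keeps the shuffle map small: uint32 halves the memory of the index array whenever the sample count fits in 32 bits, falling back to int64 only near the uint32 limit. Re-enabling np_rng.shuffle is what actually randomizes sample order, and with a fixed seed the shuffle is reproducible across restarts. The full helper as a sketch (the uint32 default is an assumption implied by the condition, as only the int64 branch is visible in the hunk):

import numpy as np

def _build_shuffle_idx(size, np_rng):
    # uint32 unless size approaches the 32-bit limit (assumed default).
    dtype_ = np.uint32
    if size >= (np.iinfo(np.uint32).max - 1):
        dtype_ = np.int64
    shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
    np_rng.shuffle(shuffle_idx)   # the line this commit re-enables
    return shuffle_idx

print(_build_shuffle_idx(10, np.random.RandomState(seed=1234)))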
@@ -399,53 +393,4 @@ if __name__ == '__main__':
     test(seed, data_prefix, seq_length, num_samples,
          num_docs, min_doc_length, max_doc_length)
     exit()
 '''
-
-'''
-num_docs = 5
-min_doc_length = 2
-max_doc_length = 10
-num_samples = 9
-seq_length = 4
-seed = 1234
-
-np.random.seed(seed)
-indexed_dataset = IndexedDataset(num_docs, min_doc_length,
-                                 max_doc_length, seq_length)
-print('> indexed dataset:')
-for s in indexed_dataset.tokens:
-    print('    {}'.format(s))
-
-documents = np.array([1,2,3], dtype=np.int32)
-dataset = GPT2Dataset('gpt2', documents, indexed_dataset,
-                      num_samples, seq_length, seed)
-
-indexed_dataset.build_tokens_flat(dataset.doc_idx)
-print(indexed_dataset.get_sample(6))
-print(dataset[6])
-'''
-
-'''
-myds = MyDataset(ds, num_samples, seq_length)
-num_docs = myds._num_docs()
-print('> number of document: {}'.format(num_docs))
-tokens_per_epoch = myds._num_tokens()
-print('> number of tokens: {}'.format(tokens_per_epoch))
-num_epochs = myds._num_epochs(tokens_per_epoch)
-print('> number of epochs: {}'.format(num_epochs))
-doc_idx = myds._build_doc_idx(num_docs, num_epochs)
-print('> doc_idx: {}'.format(doc_idx))
-
-ds.build_tokens_flat(doc_idx)
-sample_idx = myds._build_sample_idx(num_epochs, tokens_per_epoch, doc_idx)
-for s in sample_idx:
-    print(s)
-print(ds.tokens_flat)
-
-print(myds.get_sample(8))
-print(ds.get_sample(8))
-'''