OpenDAS / Megatron-LM · Commits

Commit f6a6811f, authored Nov 20, 2019 by Mohammad Shoeybi
fixed padding issue
parent 9ea9d50f

Showing 5 changed files with 115 additions and 112 deletions:

  megatron/data/albert_dataset.py   +98  -98
  megatron/data/dataset_utils.py    +13  -10
  megatron/model/bert_model.py       +1   -1
  pretrain_albert.py                 +2   -2
  pretrain_bert.py                   +1   -1
megatron/data/albert_dataset.py  (view file @ f6a6811f)

 """TO BE ADDED """
-import os
 import random
 import time
+import os
 import numpy as np
 import torch
 from torch.utils.data import Dataset
-from .dataset_utils import build_training_sample
-#from data.mapping import build_training_samples_mapping
-from . import helpers
+from megatron.data import helpers
 from megatron.data import FullBertTokenizer
+from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron.utils import print_rank_0


 class AlbertDataset(Dataset):

-    def __init__(self, vocab_file, data_prefix, data_impl, skip_warmup,
-                 num_epochs, max_num_samples, masked_lm_prob, max_seq_length,
-                 short_seq_prob, seed):
+    def __init__(self,
+                 vocab_file, data_prefix, data_impl, skip_warmup,
+                 num_epochs, max_num_samples, masked_lm_prob, max_seq_length,
+                 short_seq_prob, seed):

         # Params to store.
         self.seed = seed
...
@@ -32,25 +28,26 @@ class AlbertDataset(Dataset):
         self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)

         # Indexed dataset.
-        self.indexed_dataset = self._get_indexed_dataset(data_prefix, data_impl,
-                                                         skip_warmup)
+        self.indexed_dataset = get_indexed_dataset_(data_prefix,
+                                                    data_impl,
+                                                    skip_warmup)

         # Build the samples mapping.
-        self.samples_mapping = self._get_samples_mapping(self.indexed_dataset,
-                                                         data_prefix,
-                                                         num_epochs,
-                                                         max_num_samples,
-                                                         self.max_seq_length,
-                                                         short_seq_prob,
-                                                         self.seed)
+        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
+                                                    data_prefix,
+                                                    num_epochs,
+                                                    max_num_samples,
+                                                    self.max_seq_length,
+                                                    short_seq_prob,
+                                                    self.seed)

         # Vocab stuff.
-        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
-        self.vocab_id_to_token_dict = tokenizer.inv_vocab
-        self.cls_id = tokenizer.vocab['[CLS]']
-        self.sep_id = tokenizer.vocab['[SEP]']
-        self.mask_id = tokenizer.vocab['[MASK]']
-        self.pad_id = tokenizer.vocab['[PAD]']
+        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_dict = self.tokenizer.inv_vocab
+        self.cls_id = self.tokenizer.vocab['[CLS]']
+        self.sep_id = self.tokenizer.vocab['[SEP]']
+        self.mask_id = self.tokenizer.vocab['[MASK]']
+        self.pad_id = self.tokenizer.vocab['[PAD]']
         exit()
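For orientation, a hedged sketch of constructing the dataset with the signature shown above. Only the keyword names come from this file; the paths, the data_impl backend, and the numeric values are assumptions (the seq-length/short-seq/seed numbers simply echo the commented-out __main__ block at the bottom of this file). This is not code from the commit.

# Hypothetical usage sketch; the paths and values below are assumptions.
from megatron.data.albert_dataset import AlbertDataset

train_dataset = AlbertDataset(vocab_file='bert-vocab.txt',             # assumed
                              data_prefix='my-corpus_text_sentence',   # assumed
                              data_impl='mmap',                        # assumed
                              skip_warmup=False,
                              num_epochs=None,
                              max_num_samples=1000000,
                              masked_lm_prob=0.15,
                              max_seq_length=512,
                              short_seq_prob=0.1,
                              seed=1234)
# Indexing yields one training sample dict built by build_training_sample
# (see megatron/data/dataset_utils.py below).
sample = train_dataset[0]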
...
@@ -64,6 +61,8 @@ class AlbertDataset(Dataset):
     def __getitem__(self, idx):
+        # Note that this rng state should be python and not numpy since
+        # python randint is inclusive whereas the numpy one is exclusive.
         rng = random.Random(self.seed + idx)
         start_index, end_index, seq_length = self.samples_mapping[idx]
         sample = []
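The new comment above is the crux of the rng handling: Python's random.Random.randint includes its upper bound, while NumPy's randint excludes it. A minimal standalone illustration (not part of the commit):

import random
import numpy as np

py_rng = random.Random(1234)
np_rng = np.random.RandomState(1234)

# Python's randint(a, b) is inclusive of b ...
print(sorted({py_rng.randint(0, 3) for _ in range(1000)}))        # [0, 1, 2, 3]
# ... while NumPy's randint(low, high) never returns high.
print(sorted({int(np_rng.randint(0, 3)) for _ in range(1000)}))   # [0, 1, 2]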
...
@@ -82,82 +81,81 @@ class AlbertDataset(Dataset):
(In this hunk the private methods AlbertDataset._get_indexed_dataset and AlbertDataset._get_samples_mapping become the module-level functions get_indexed_dataset_ and get_samples_mapping_; apart from dropping self and the resulting de-indent and re-wrapping, the body below is shared by both versions.)

def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    start_time = time.time()
    print_rank_0("> Reading dataset index ...")
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    print_rank_0("> Finished creating indexed dataset in {:4f} "
                 "seconds".format(time.time() - start_time))
    return indexed_dataset


def get_samples_mapping_(indexed_dataset,
                         data_prefix,
                         num_epochs,
                         max_num_samples,
                         max_seq_length,
                         short_seq_prob,
                         seed):
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_indexmap'
    indexmap_filename += '_{}ep'.format(num_epochs)
    indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
        print('WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))
        # Make sure the types match the helpers input types.
        assert indexed_dataset.doc_idx.dtype == np.int64
        assert indexed_dataset.sizes.dtype == np.int32
        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            short_seq_prob,
            seed,
            verbose)
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0('> elasped time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))
    # Make sure all the ranks have built the mapping
    torch.distributed.barrier()

    # Load indexed dataset.
    print_rank_0('> loading indexed mapping from {}'.format(indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
...
@@ -274,6 +272,7 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
     return samples_np
 '''
+'''
 # WILL BE REPLACED WITH JARED'S
 class JaredDataset(object):
...
@@ -395,3 +394,4 @@ if __name__ == '__main__':
         max_seq_length=512,
         short_seq_prob=0.1,
         seed=1234)
+'''
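For reference, the index-map filename produced by get_samples_mapping_ above can be traced by hand. This standalone sketch mirrors the string-building logic with hypothetical argument values; only the format fragments come from the function itself, the inputs are made up.

# Mirror of the indexmap filename construction in get_samples_mapping_,
# with assumed inputs.
data_prefix = 'my-corpus_text_sentence'     # hypothetical
num_epochs, max_num_samples = 10, 1000000   # hypothetical
max_seq_length, short_seq_prob, seed = 512, 0.1, 1234

indexmap_filename = data_prefix
indexmap_filename += '_indexmap'
indexmap_filename += '_{}ep'.format(num_epochs)
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'

print(indexmap_filename)
# my-corpus_text_sentence_indexmap_10ep_1000000mns_512msl_0.10ssp_1234s.npy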
megatron/data/dataset_utils.py  (view file @ f6a6811f)

...
@@ -24,7 +24,9 @@ def build_training_sample(sample,
         mask_id: Mask token id.
         pad_id: Padding token id.
         masked_lm_prob: Probability to mask tokens.
-        rng: Random number genenrator.
+        rng: Random number genenrator. Note that this rng state should be
+             python and not numpy since python randint is inclusive for
+             the opper bound whereas the numpy one is exclusive.
     """

     # We assume that we have at least two sentences in the sample
...
@@ -36,8 +38,8 @@ def build_training_sample(sample,
     # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
-    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b),
-                                  max_num_tokens, rng)
+    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
+                                  len(tokens_b), max_num_tokens, rng)

     # Build tokens and toketypes.
     tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
...
@@ -50,17 +52,17 @@ def build_training_sample(sample,
                             cls_id, sep_id, mask_id, max_predictions_per_seq, rng)

     # Padding.
-    tokens_np, tokentypes_np, labels, padding_mask, loss_mask \
+    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
         = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                    masked_labels, pad_id, max_seq_length)

     train_sample = {
         'text': tokens_np,
         'types': tokentypes_np,
-        'labels': labels,
+        'labels': labels_np,
         'is_random': int(is_next_random),
-        'loss_mask': loss_mask,
-        'padding_mask': padding_mask,
+        'loss_mask': loss_mask_np,
+        'padding_mask': padding_mask_np,
         'truncated': int(truncated)}
     return train_sample
...
@@ -357,7 +359,8 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

     # Padding mask.
-    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)
+    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
+                               dtype=np.int64)

     # Lables and loss mask.
     labels = [-1] * max_seq_length
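The padding-mask line rewrapped above encodes the convention used throughout this commit: 1 marks a real token, 0 marks padding. A minimal sketch of that single step on toy values; the rest of pad_and_convert_to_numpy is not shown in this hunk, so the token ids and pad_id here are assumptions.

import numpy as np

max_seq_length = 8
tokens = [101, 7592, 2088, 102]   # hypothetical ids: [CLS] hello world [SEP]
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens

# Pad the token ids and build the mask exactly as in the hunk above:
# 1 for real tokens, 0 for padding.
pad_id = 0                        # assumed
tokens_np = np.array(tokens + [pad_id] * padding_length, dtype=np.int64)
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                           dtype=np.int64)

print(tokens_np)         # 101, 7592, 2088, 102, 0, 0, 0, 0
print(padding_mask_np)   # 1, 1, 1, 1, 0, 0, 0, 0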
...
@@ -372,8 +375,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np

-'''
 if __name__ == '__main__':
...
@@ -469,3 +471,4 @@ if __name__ == '__main__':
             string += '{:5d}'.format(tokentype)
             string += '{:5d}'.format(padding_mask)
         print(string)
+'''
megatron/model/bert_model.py  (view file @ f6a6811f)

...
@@ -145,7 +145,7 @@ class BertModel(MegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(init_method_std,
                                                          num_layers),
-            residual_connection_post_layernorm=True)
+            residual_connection_post_layernorm=False)

         self.lm_head = BertLMHead(
             self.language_model.embedding.word_embeddings.weight.size(0),
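The flag name suggests a choice of where the residual branch of each transformer sublayer is taken from: the layernorm output or the raw block input. A generic, heavily hedged sketch of those two conventions follows; it is not Megatron's transformer code, and the toy module and names are invented for illustration only.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Illustrates the two residual conventions the flag name suggests;
    NOT Megatron's implementation."""

    def __init__(self, hidden, residual_connection_post_layernorm):
        super().__init__()
        self.norm = nn.LayerNorm(hidden)
        self.mlp = nn.Linear(hidden, hidden)   # stand-in for attention/MLP
        self.post_ln_residual = residual_connection_post_layernorm

    def forward(self, x):
        x_ln = self.norm(x)
        if self.post_ln_residual:
            # Residual branch taken after the layernorm.
            return x_ln + self.mlp(x_ln)
        # Residual branch taken from the raw block input.
        return x + self.mlp(x_ln)

block = ToyBlock(hidden=16, residual_connection_post_layernorm=False)
out = block(torch.randn(2, 4, 16))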
pretrain_albert.py  (view file @ f6a6811f)

...
@@ -73,7 +73,7 @@ def get_batch(data_iterator, timers):
     sentence_order = data_b['is_random'].long()
     loss_mask = data_b['loss_mask'].float()
     lm_labels = data_b['labels'].long()
-    padding_mask = data_b['padding_mask'].byte()
+    padding_mask = data_b['padding_mask'].long()

     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
...
@@ -88,7 +88,7 @@ def forward_step(data_iterator, model, args, timers):
     timers('batch generator').stop()

     # Forward model.
-    lm_logits, sop_logits = model(tokens, 1 - padding_mask, tokentype_ids=types)
+    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
     sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
                                sentence_order.view(-1).contiguous(),
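Taken together with the dataset_utils.py changes, the two edits above switch padding_mask from a byte tensor that was inverted at the call site (1 - padding_mask) to a long tensor passed through unchanged, with any inversion or expansion presumably handled inside the model. A toy illustration of the arithmetic on the same mask (not repository code):

import torch

# 1 = real token, 0 = padding, matching pad_and_convert_to_numpy above.
padding_mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.long)

# Old call site: the mask was flipped before being handed to the model,
# so after the flip the padding positions carried a 1.
old_style_mask = 1 - padding_mask   # tensor([[0, 0, 0, 1, 1]])

# New call site: the mask is passed as-is.
new_style_mask = padding_mask       # tensor([[1, 1, 1, 0, 0]])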
pretrain_bert.py  (view file @ f6a6811f)

...
@@ -72,7 +72,7 @@ def get_batch(data_iterator, timers):
     next_sentence = data_b['is_random'].long()
     loss_mask = data_b['mask'].float()
     lm_labels = data_b['mask_labels'].long()
-    padding_mask = data_b['pad_mask'].byte()
+    padding_mask = data_b['pad_mask'].long()

     return tokens, types, next_sentence, loss_mask, lm_labels, padding_mask