OpenDAS / Megatron-LM · Commits

Commit 8b1da95a, authored Apr 16, 2020 by Neel Kant
Build simple mlm examples
Parent: 81c71789

Showing 5 changed files with 182 additions and 86 deletions.
  megatron/data/dataset_utils.py             +39   -0
  megatron/data/ict_dataset.py               +70   -85
  megatron/data/realm_dataset.py (new file)  +60   -0
  megatron/model/bert_model.py               +5    -1
  megatron/tokenizer/bert_tokenization.py    +8    -0
megatron/data/dataset_utils.py  (+39 -0)

@@ -15,6 +15,8 @@
 import collections
+import itertools
 import numpy as np
 ...

@@ -80,6 +82,33 @@ def build_training_sample(sample,
     return train_sample


+def build_simple_training_sample(sample, target_seq_length, max_seq_length,
+                                 vocab_id_list, vocab_id_to_token_dict,
+                                 cls_id, sep_id, mask_id, pad_id,
+                                 masked_lm_prob, np_rng):
+    tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
+    tokens, tokentypes = create_single_tokens_and_tokentypes(tokens)
+
+    max_predictions_per_seq = masked_lm_prob * max_seq_length
+    (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
+        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
+        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
+
+    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
+        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+                                   masked_labels, pad_id, max_seq_length)
+
+    train_sample = {
+        'text': tokens_np,
+        'types': tokentypes_np,
+        'labels': labels_np,
+        'loss_mask': loss_mask_np,
+        'padding_mask': padding_mask_np}
+    return train_sample
+
+
 def get_a_and_b_segments(sample, np_rng):
     """Divide sample into a and b segments."""
 ...

@@ -132,6 +161,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
         tokens.pop()
     return True

 def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
     """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
 ...

@@ -158,6 +188,15 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
     return tokens, tokentypes


+def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
+    tokens = []
+    tokens.append(cls_id)
+    tokens.extend(list(_tokens))
+    tokens.append(sep_id)
+    tokentypes = [0] * len(tokens)
+    return tokens, tokentypes
+
+
 MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                           ["index", "label"])
 ...
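The new build_simple_training_sample builds a single-segment MLM example (there is no A/B sentence pair, so the returned dict has no 'is_random' or 'truncated' fields). Note that it calls create_single_tokens_and_tokentypes(tokens) with a single argument, even though the new helper also declares cls_id and sep_id parameters. Below is a minimal, self-contained sketch of the single-segment preprocessing performed before masking; the token ids, special-token values, and sequence length are toy placeholders, not Megatron's real vocabulary:

import itertools

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0   # hypothetical special-token ids
MAX_SEQ_LENGTH = 8

# A "sample" is a list of sentences, each a list of token ids.
sample = [[7, 8, 9], [10, 11], [12, 13, 14, 15]]

# Flatten the sentences and truncate, leaving room for [CLS] and [SEP].
tokens = list(itertools.chain(*sample))[:MAX_SEQ_LENGTH - 2]

# Single-segment equivalent of create_single_tokens_and_tokentypes:
# one [CLS], one [SEP], and all token types 0 because there is no B segment.
tokens = [CLS_ID] + tokens + [SEP_ID]
tokentypes = [0] * len(tokens)

# Pad out to the fixed sequence length, as pad_and_convert_to_numpy does.
padding = [PAD_ID] * (MAX_SEQ_LENGTH - len(tokens))
print(tokens + padding)                   # [101, 7, 8, 9, 10, 11, 12, 102]
print(tokentypes + [0] * len(padding))    # [0, 0, 0, 0, 0, 0, 0, 0]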
megatron/data/ict_dataset.py  (+70 -85)

 import itertools
 import random
 import os
-import sys
 import time
 import numpy as np
 ...

@@ -27,14 +26,8 @@ class InverseClozeDataset(Dataset):
         self.short_seq_prob = short_seq_prob
         self.rng = random.Random(self.seed)

-        self.samples_mapping = self.get_samples_mapping(
-            data_prefix, num_epochs, max_num_samples)
+        self.samples_mapping = get_samples_mapping(self.context_dataset,
+                                                   self.titles_dataset,
+                                                   data_prefix, num_epochs,
+                                                   max_num_samples,
+                                                   self.max_seq_length,
+                                                   self.seed, self.name)

         tokenizer = get_tokenizer()
         self.vocab_id_list = list(tokenizer.inv_vocab.keys())
         self.vocab_id_to_token_list = tokenizer.inv_vocab
 ...

@@ -97,82 +90,74 @@ class InverseClozeDataset(Dataset):
         token_types = [0] * self.max_seq_length
         return tokens, token_types, pad_mask

In this hunk the method InverseClozeDataset.get_samples_mapping(self, data_prefix, num_epochs, max_num_samples) becomes a module-level function that takes the context and titles datasets, maximum sequence length, seed, and dataset name as explicit arguments; the body is otherwise unchanged apart from dropping the self. references:

def get_samples_mapping(context_dataset,
                        titles_dataset,
                        data_prefix,
                        num_epochs,
                        max_num_samples,
                        max_seq_length,
                        seed,
                        name):
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert context_dataset.doc_idx.dtype == np.int64
        assert context_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))
        samples_mapping = helpers.build_blocks_mapping(
            context_dataset.doc_idx,
            context_dataset.sizes,
            titles_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose)
        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
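For reference, a small runnable sketch of how get_samples_mapping composes the cached index-map filename from its arguments; the data prefix, epoch count, and other values below are made up, and the sentinel checks mirror the "unlimited" defaults np.iinfo(np.int32).max - 1 and np.iinfo(np.int64).max - 1 used above:

import numpy as np

data_prefix = 'my-corpus_text_sentence'        # hypothetical dataset prefix
name, num_epochs = 'ict', 3
max_num_samples = np.iinfo(np.int64).max - 1   # i.e. "not limited"
max_seq_length, seed = 256, 1234

indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
    indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
    indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'

print(indexmap_filename)
# my-corpus_text_sentence_ict_indexmap_3ep_256msl_1234s.npy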
megatron/data/realm_dataset.py  (new file, +60 -0)

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.bert_dataset import get_samples_mapping_
from megatron.data.dataset_utils import build_simple_training_sample


class RealmDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
                                                    self.max_seq_length,
                                                    short_seq_prob,
                                                    self.seed,
                                                    self.name)

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]

        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        np_rng = np.random.RandomState(seed=(self.seed + idx))

        return build_simple_training_sample(sample, seq_length,
                                            self.max_seq_length,  # needed for padding
                                            self.vocab_id_list,
                                            self.vocab_id_to_token_dict,
                                            self.cls_id, self.sep_id,
                                            self.mask_id, self.pad_id,
                                            self.masked_lm_prob, np_rng)
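The comment in __getitem__ is the key to reproducibility: each sample index gets its own numpy RandomState seeded with (seed + idx), so the masking drawn for a given sample is identical no matter which worker or epoch requests it, and numpy's randint (exclusive upper bound) is used instead of Python's inclusive random.randint. A small sketch under toy values:

import numpy as np

def draw_mask_positions(seed, idx, seq_len, n_masks):
    # Fresh RandomState per sample index, as in RealmDataset.__getitem__.
    np_rng = np.random.RandomState(seed=(seed + idx))
    # numpy's randint upper bound is exclusive, so seq_len itself is never drawn.
    return sorted(np_rng.randint(0, seq_len, size=n_masks))

print(draw_mask_positions(1234, 7, seq_len=128, n_masks=5))
print(draw_mask_positions(1234, 7, seq_len=128, n_masks=5))  # identical draw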
megatron/model/bert_model.py  (+5 -1)

@@ -214,10 +214,14 @@ class BertModel(MegatronModule):
                 state_dict[self._ict_head_key], strict=strict)


+# REALMBertModel is just BertModel without binary head.
+# needs a different kind of dataset though
 class ICTBertModel(MegatronModule):
     def __init__(self,
                  ict_head_size,
-                 num_tokentypes=0,
+                 num_tokentypes=2,
                  parallel_output=True):
         super(ICTBertModel, self).__init__()
         bert_args = dict(
 ...
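For context on the num_tokentypes default changing from 0 to 2: in a BERT-style embedding layer this value sizes the token-type (segment) embedding table, so a model built with num_tokentypes=0 has no table to look segment ids up in, while 2 covers the usual A/B segments. A toy sketch (not Megatron code; hidden size and ids are arbitrary):

import torch
import torch.nn as nn

hidden_size, num_tokentypes = 16, 2
tokentype_embeddings = nn.Embedding(num_tokentypes, hidden_size)

tokentype_ids = torch.tensor([[0, 0, 0, 1, 1]])    # segment A then segment B
print(tokentype_embeddings(tokentype_ids).shape)   # torch.Size([1, 5, 16])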
megatron/tokenizer/bert_tokenization.py  (+8 -0)

@@ -178,6 +178,14 @@ class FullTokenizer(object):
     return convert_by_vocab(self.vocab, tokens)

   def convert_ids_to_tokens(self, ids):
+    train_sample = {
+        'text': tokens_np,
+        'types': tokentypes_np,
+        'labels': labels_np,
+        'is_random': int(is_next_random),
+        'loss_mask': loss_mask_np,
+        'padding_mask': padding_mask_np,
+        'truncated': int(truncated)}
     return convert_by_vocab(self.inv_vocab, ids)

   def vocab_size(self):
 ...
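The context lines above are the usual id/token round trip through convert_by_vocab; a toy illustration with a made-up vocabulary (not the real WordPiece vocab):

vocab = {'[PAD]': 0, '[CLS]': 101, '[SEP]': 102, '[MASK]': 103, 'hello': 7592}
inv_vocab = {token_id: token for token, token_id in vocab.items()}

def convert_by_vocab(vocab, items):
    # Look every item up in the given mapping (tokens->ids or ids->tokens).
    return [vocab[item] for item in items]

ids = convert_by_vocab(vocab, ['[CLS]', 'hello', '[SEP]'])
print(ids)                               # [101, 7592, 102]
print(convert_by_vocab(inv_vocab, ids))  # ['[CLS]', 'hello', '[SEP]']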