chenpangpang / transformers / Commits / 784c0ed8

Unverified commit 784c0ed8, authored Jun 11, 2019 by Thomas Wolf, committed via GitHub on Jun 11, 2019.

Merge pull request #668 from jeonsworld/patch-2

apply Whole Word Masking technique

Parents: ee0308f7, a3a604ce
Changes: 1 changed file with 59 additions and 23 deletions.

examples/lm_finetuning/pregenerate_training_data.py (+59, -23)
@@ -4,11 +4,11 @@ from tqdm import tqdm, trange
from tempfile import TemporaryDirectory
import shelve

-from random import random, randrange, randint, shuffle, choice, sample
+from random import random, randrange, randint, shuffle, choice
from pytorch_pretrained_bert.tokenization import BertTokenizer
import numpy as np
import json
+import collections


class DocumentDatabase:
    def __init__(self, reduce_memory=False):
@@ -98,22 +98,53 @@ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
        else:
            trunc_tokens.pop()


+MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
+                                          ["index", "label"])
+
+
-def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list):
+def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list):
    """Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but
    with several refactors to clean it up and remove a lot of unnecessary variables."""
    cand_indices = []
    for (i, token) in enumerate(tokens):
        if token == "[CLS]" or token == "[SEP]":
            continue
-        cand_indices.append(i)
+        # Whole Word Masking means that if we mask all of the wordpieces
+        # corresponding to an original word. When a word has been split into
+        # WordPieces, the first token does not have any marker and any subsequence
+        # tokens are prefixed with ##. So whenever we see the ## token, we
+        # append it to the previous set of word indexes.
+        #
+        # Note that Whole Word Masking does *not* change the training code
+        # at all -- we still predict each WordPiece independently, softmaxed
+        # over the entire vocabulary.
+        if (whole_word_mask and len(cand_indices) >= 1 and token.startswith("##")):
+            cand_indices[-1].append(i)
+        else:
+            cand_indices.append([i])

    num_to_mask = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
    shuffle(cand_indices)
-    mask_indices = sorted(sample(cand_indices, num_to_mask))
-    masked_token_labels = []
-    for index in mask_indices:
+    masked_lms = []
+    covered_indexes = set()
+    for index_set in cand_indices:
+        if len(masked_lms) >= num_to_mask:
+            break
+        # If adding a whole-word mask would exceed the maximum number of
+        # predictions, then just skip this candidate.
+        if len(masked_lms) + len(index_set) > num_to_mask:
+            continue
+        is_any_index_covered = False
+        for index in index_set:
+            if index in covered_indexes:
+                is_any_index_covered = True
+                break
+        if is_any_index_covered:
+            continue
+        for index in index_set:
+            covered_indexes.add(index)
+
        masked_token = None
        # 80% of the time, replace with [MASK]
        if random() < 0.8:
            masked_token = "[MASK]"
@@ -124,16 +155,20 @@ def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq
            # 10% of the time, replace with random word
            else:
                masked_token = choice(vocab_list)
-        masked_token_labels.append(tokens[index])
-        # Once we've saved the true label for that token, we can overwrite it with the masked version
-        tokens[index] = masked_token
+        masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+        tokens[index] = masked_token
+
+    assert len(masked_lms) <= num_to_mask
+    masked_lms = sorted(masked_lms, key=lambda x: x.index)
+    mask_indices = [p.index for p in masked_lms]
+    masked_token_labels = [p.label for p in masked_lms]

    return tokens, mask_indices, masked_token_labels


def create_instances_from_document(
        doc_database, doc_idx, max_seq_length, short_seq_prob,
-        masked_lm_prob, max_predictions_per_seq, vocab_list):
+        masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list):
    """This code is mostly a duplicate of the equivalent function from Google BERT's repo.
    However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function.
    Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence
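A hedged usage sketch of the new create_masked_lm_predictions signature. The import assumes the script's directory is importable; the vocabulary and tokens below are made up and not part of the commit:

# Sketch only: assumes examples/lm_finetuning is on sys.path so the script can be imported.
from pregenerate_training_data import create_masked_lm_predictions

vocab_list = ["the", "un", "##afford", "##able", "fox", "jumps"]   # toy vocabulary
tokens = ["[CLS]", "the", "un", "##afford", "##able", "fox", "jumps", "[SEP]"]

tokens, mask_indices, masked_token_labels = create_masked_lm_predictions(
    tokens,
    masked_lm_prob=0.15,
    max_predictions_per_seq=20,
    whole_word_mask=True,   # new argument introduced by this commit
    vocab_list=vocab_list)

# mask_indices and masked_token_labels stay aligned: masked_token_labels[k] is the
# original token that stood at position mask_indices[k] before it was overwritten.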
@@ -213,7 +248,7 @@ def create_instances_from_document(
        segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)]

        tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(
-            tokens, masked_lm_prob, max_predictions_per_seq, vocab_list)
+            tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list)

        instance = {
            "tokens": tokens,
@@ -237,7 +272,8 @@ def main():
                        choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                 "bert-base-multilingual", "bert-base-chinese"])
    parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--do_whole_word_mask", action="store_true", help="Whether to use whole word masking rather than per-WordPiece masking.")
    parser.add_argument("--reduce_memory", action="store_true",
                        help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
@@ -284,7 +320,7 @@ def main():
                doc_instances = create_instances_from_document(
                    docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
                    masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
-                    vocab_list=vocab_list)
+                    whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
                doc_instances = [json.dumps(instance) for instance in doc_instances]
                for instance in doc_instances:
                    epoch_file.write(instance + '\n')
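Since each training instance is written as one JSON object per line, a consumer can stream the file back with json.loads. A minimal reader sketch; the output directory and the epoch_0.json file name are assumptions, not taken from the diff above:

import json
from pathlib import Path

# Sketch only: reads one JSON-serialized instance per line, as written by the loop above.
epoch_path = Path("training_data") / "epoch_0.json"
with epoch_path.open() as f:
    for line in f:
        instance = json.loads(line)
        # The dict keys follow the instance built in create_instances_from_document,
        # e.g. "tokens" as shown in the earlier hunk.
        print(instance["tokens"][:10])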