Commit 73a911c0 authored by Jeremiah Harmsen's avatar Jeremiah Harmsen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 321119040
parent ad166183
...@@ -18,6 +18,7 @@ from __future__ import division ...@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import itertools
import random import random
from absl import app from absl import app
...@@ -48,6 +49,12 @@ flags.DEFINE_bool( ...@@ -48,6 +49,12 @@ flags.DEFINE_bool(
"do_whole_word_mask", False, "do_whole_word_mask", False,
"Whether to use whole word masking rather than per-WordPiece masking.") "Whether to use whole word masking rather than per-WordPiece masking.")
flags.DEFINE_integer(
"max_ngram_size", None,
"Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
"weighting scheme to favor shorter n-grams. "
"Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
flags.DEFINE_bool( flags.DEFINE_bool(
"gzip_compress", False, "gzip_compress", False,
"Whether to use `GZIP` compress option to get compressed TFRecord files.") "Whether to use `GZIP` compress option to get compressed TFRecord files.")
...@@ -192,7 +199,8 @@ def create_training_instances(input_files, ...@@ -192,7 +199,8 @@ def create_training_instances(input_files,
masked_lm_prob, masked_lm_prob,
max_predictions_per_seq, max_predictions_per_seq,
rng, rng,
do_whole_word_mask=False): do_whole_word_mask=False,
max_ngram_size=None):
"""Create `TrainingInstance`s from raw text.""" """Create `TrainingInstance`s from raw text."""
all_documents = [[]] all_documents = [[]]
...@@ -229,7 +237,7 @@ def create_training_instances(input_files, ...@@ -229,7 +237,7 @@ def create_training_instances(input_files,
create_instances_from_document( create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob, all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask)) do_whole_word_mask, max_ngram_size))
rng.shuffle(instances) rng.shuffle(instances)
return instances return instances
...@@ -238,7 +246,8 @@ def create_training_instances(input_files, ...@@ -238,7 +246,8 @@ def create_training_instances(input_files,
def create_instances_from_document( def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob, all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask=False): do_whole_word_mask=False,
max_ngram_size=None):
"""Creates `TrainingInstance`s for a single document.""" """Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index] document = all_documents[document_index]
...@@ -337,7 +346,7 @@ def create_instances_from_document( ...@@ -337,7 +346,7 @@ def create_instances_from_document(
(tokens, masked_lm_positions, (tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions( masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng, tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask) do_whole_word_mask, max_ngram_size)
instance = TrainingInstance( instance = TrainingInstance(
tokens=tokens, tokens=tokens,
segment_ids=segment_ids, segment_ids=segment_ids,
...@@ -355,72 +364,238 @@ def create_instances_from_document( ...@@ -355,72 +364,238 @@ def create_instances_from_document(
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"]) ["index", "label"])
# A _Gram is a [half-open) interval of token indices which form a word.
# E.g.,
# words: ["The", "doghouse"]
# tokens: ["The", "dog", "##house"]
# grams: [(0,1), (1,3)]
_Gram = collections.namedtuple("_Gram", ["begin", "end"])
def _window(iterable, size):
"""Helper to create a sliding window iterator with a given size.
E.g.,
input = [1, 2, 3, 4]
_window(input, 1) => [1], [2], [3], [4]
_window(input, 2) => [1, 2], [2, 3], [3, 4]
_window(input, 3) => [1, 2, 3], [2, 3, 4]
_window(input, 4) => [1, 2, 3, 4]
_window(input, 5) => None
Arguments:
iterable: elements to iterate over.
size: size of the window.
Yields:
Elements of `iterable` batched into a sliding window of length `size`.
"""
i = iter(iterable)
window = []
try:
for e in range(0, size):
window.append(next(i))
yield window
except StopIteration:
# handle the case where iterable's length is less than the window size.
return
for e in i:
window = window[1:] + [e]
yield window
def _contiguous(sorted_grams):
"""Test whether a sequence of grams is contiguous.
Arguments:
sorted_grams: _Grams which are sorted in increasing order.
Returns:
True if `sorted_grams` are touching each other.
E.g.,
_contiguous([(1, 4), (4, 5), (5, 10)]) == True
_contiguous([(1, 2), (4, 5)]) == False
"""
for a, b in _window(sorted_grams, 2):
if a.end != b.begin:
return False
return True
def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
"""Create a list of masking {1, ..., n}-grams from a list of one-grams.
This is an extention of 'whole word masking' to mask multiple, contiguous
words such as (e.g., "the red boat").
Each input gram represents the token indices of a single word,
words: ["the", "red", "boat"]
tokens: ["the", "red", "boa", "##t"]
grams: [(0,1), (1,2), (2,4)]
For a `max_ngram_size` of three, possible outputs masks include:
1-grams: (0,1), (1,2), (2,4)
2-grams: (0,2), (1,4)
3-grams; (0,4)
Output masks will not overlap and contain less than `max_masked_tokens` total
tokens. E.g., for the example above with `max_masked_tokens` as three,
valid outputs are,
[(0,1), (1,2)] # "the", "red" covering two tokens
[(1,2), (2,4)] # "red", "boa", "##t" covering three tokens
The length of the selected n-gram follows a zipf weighting to
favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
Arguments:
grams: List of one-grams.
max_ngram_size: Maximum number of contiguous one-grams combined to create
an n-gram.
max_masked_tokens: Maximum total number of tokens to be masked.
rng: `random.Random` generator.
Returns:
A list of n-grams to be used as masks.
"""
if not grams:
return None
grams = sorted(grams)
num_tokens = grams[-1].end
# Ensure our grams are valid (i.e., they don't overlap).
for a, b in _window(grams, 2):
if a.end > b.begin:
raise ValueError("overlapping grams: {}".format(grams))
# Build map from n-gram length to list of n-grams.
ngrams = {i: [] for i in range(1, max_ngram_size+1)}
for gram_size in range(1, max_ngram_size+1):
for g in _window(grams, gram_size):
if _contiguous(g):
# Add an n-gram which spans these one-grams.
ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
# Shuffle each list of n-grams.
for v in ngrams.values():
rng.shuffle(v)
# Create the weighting for n-gram length selection.
# Stored cummulatively for `random.choices` below.
cummulative_weights = list(
itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))
output_ngrams = []
# Keep a bitmask of which tokens have been masked.
masked_tokens = [False] * num_tokens
# Loop until we have enough masked tokens or there are no more candidate
# n-grams of any length.
# Each code path should ensure one or more elements from `ngrams` are removed
# to guarentee this loop terminates.
while (sum(masked_tokens) < max_masked_tokens and
sum(len(s) for s in ngrams.values())):
# Pick an n-gram size based on our weights.
sz = random.choices(range(1, max_ngram_size+1),
cum_weights=cummulative_weights)[0]
# Ensure this size doesn't result in too many masked tokens.
# E.g., a two-gram contains _at least_ two tokens.
if sum(masked_tokens) + sz > max_masked_tokens:
# All n-grams of this length are too long and can be removed from
# consideration.
ngrams[sz].clear()
continue
def create_masked_lm_predictions(tokens, masked_lm_prob, # All of the n-grams of this size have been used.
max_predictions_per_seq, vocab_words, rng, if not ngrams[sz]:
do_whole_word_mask): continue
"""Creates the predictions for the masked LM objective."""
# Choose a random n-gram of the given size.
gram = ngrams[sz].pop()
num_gram_tokens = gram.end-gram.begin
# Check if this would add too many tokens.
if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
continue
# Check if any of the tokens in this gram have already been masked.
if sum(masked_tokens[gram.begin:gram.end]):
continue
cand_indexes = [] # Found a usable n-gram! Mark its tokens as masked and add it to return.
for (i, token) in enumerate(tokens): masked_tokens[gram.begin:gram.end] = [True] * (gram.end-gram.begin)
if token == "[CLS]" or token == "[SEP]": output_ngrams.append(gram)
return output_ngrams
def _wordpieces_to_grams(tokens):
"""Reconstitue grams (words) from `tokens`.
E.g.,
tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
grams: [ [1,2), [2, 4), [4,5) , [5, 6)]
Arguments:
tokens: list of wordpieces
Returns:
List of _Grams representing spans of whole words
(without "[CLS]" and "[SEP]").
"""
grams = []
gram_start_pos = None
for i, token in enumerate(tokens):
if gram_start_pos is not None and token.startswith("##"):
continue continue
# Whole Word Masking means that if we mask all of the wordpieces if gram_start_pos is not None:
# corresponding to an original word. When a word has been split into grams.append(_Gram(gram_start_pos, i))
# WordPieces, the first token does not have any marker and any subsequence if token not in ["[CLS]", "[SEP]"]:
# tokens are prefixed with ##. So whenever we see the ## token, we gram_start_pos = i
# append it to the previous set of word indexes.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else: else:
cand_indexes.append([i]) gram_start_pos = None
if gram_start_pos is not None:
grams.append(_Gram(gram_start_pos, len(tokens)))
return grams
rng.shuffle(cand_indexes)
output_tokens = list(tokens) def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask,
max_ngram_size=None):
"""Creates the predictions for the masked LM objective."""
if do_whole_word_mask:
grams = _wordpieces_to_grams(tokens)
else:
# Here we consider each token to be a word to allow for sub-word masking.
if max_ngram_size:
raise ValueError("cannot use ngram masking without whole word masking")
grams = [_Gram(i, i+1) for i in range(0, len(tokens))
if tokens[i] not in ["[CLS]", "[SEP]"]]
num_to_predict = min(max_predictions_per_seq, num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob)))) max(1, int(round(len(tokens) * masked_lm_prob))))
# Generate masks. If `max_ngram_size` in [0, None] it means we're doing
# whole word masking or token level masking. Both of these can be treated
# as the `max_ngram_size=1` case.
masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
num_to_predict, rng)
masked_lms = [] masked_lms = []
covered_indexes = set() output_tokens = list(tokens)
for index_set in cand_indexes: for gram in masked_grams:
if len(masked_lms) >= num_to_predict: # 80% of the time, replace all n-gram tokens with [MASK]
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8: if rng.random() < 0.8:
masked_token = "[MASK]" replacement_action = lambda idx: "[MASK]"
else: else:
# 10% of the time, keep original # 10% of the time, keep all the original n-gram tokens.
if rng.random() < 0.5: if rng.random() < 0.5:
masked_token = tokens[index] replacement_action = lambda idx: tokens[idx]
# 10% of the time, replace with random word # 10% of the time, replace each n-gram token with a random word.
else: else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] replacement_action = lambda idx: rng.choice(vocab_words)
output_tokens[index] = masked_token for idx in range(gram.begin, gram.end):
output_tokens[idx] = replacement_action(idx)
masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lms = sorted(masked_lms, key=lambda x: x.index)
...@@ -467,7 +642,7 @@ def main(_): ...@@ -467,7 +642,7 @@ def main(_):
instances = create_training_instances( instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng, FLAGS.do_whole_word_mask) rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
output_files = FLAGS.output_file.split(",") output_files = FLAGS.output_file.split(",")
logging.info("*** Writing to output files ***") logging.info("*** Writing to output files ***")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment