Commit d4dd827f authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 359773626
parent 52531231
@@ -766,12 +766,15 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
   def test_preprocessing_for_mlm(self, use_bert):
     """Combines both SavedModel types and TF.text helpers for MLM."""
     # Create the preprocessing SavedModel with a [MASK] token.
+    non_special_tokens = ["hello", "world",
+                          "nice", "movie", "great", "actors",
+                          "quick", "fox", "lazy", "dog"]
     preprocess = tf.saved_model.load(self._do_export(
-        ["d", "ef", "abc", "xy"], do_lower_case=True,
+        non_special_tokens, do_lower_case=True,
         tokenize_with_offsets=use_bert,  # TODO(b/149576200): drop this.
         experimental_disable_assert=True,  # TODO(b/175369555): drop this.
         add_mask_token=True, use_sp_model=not use_bert))
-    vocab_size = 4+5 if use_bert else 4+7
+    vocab_size = len(non_special_tokens) + (5 if use_bert else 7)
     # Create the encoder SavedModel with an .mlm subobject.
     hidden_size = 16
@@ -811,12 +814,12 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEqual(mask_id, 4)
     # A batch of 3 segment pairs.
-    raw_segments = [tf.constant(["hello", "nice movie", "quick brown fox"]),
+    raw_segments = [tf.constant(["hello", "nice movie", "quick fox"]),
                     tf.constant(["world", "great actors", "lazy dog"])]
     batch_size = 3
     # Misc hyperparameters.
-    seq_length = 12
+    seq_length = 10
     max_selections_per_seq = 2
     # Tokenize inputs.
@@ -836,12 +839,12 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
         input_ids=input_ids,
         item_selector=text.RandomItemSelector(
             max_selections_per_seq,
-            selection_rate=0.15,
+            selection_rate=0.5,  # Adjusted for the short test examples.
             unselectable_ids=[start_of_sequence_id, end_of_segment_id]),
-        mask_values_chooser=text.MaskValuesChooser(vocab_size=vocab_size,
-                                                   mask_token=mask_id,
-                                                   mask_token_rate=0.8,
-                                                   random_token_rate=0.1))
+        mask_values_chooser=text.MaskValuesChooser(
+            vocab_size=vocab_size, mask_token=mask_id,
+            # Always put [MASK] to have a predictable result.
+            mask_token_rate=1.0, random_token_rate=0.0))
     # Pad to fixed-length Transformer encoder inputs.
     input_word_ids, _ = text.pad_model_inputs(masked_input_ids,
                                               seq_length,
@@ -854,6 +857,22 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
     masked_lm_positions = tf.cast(masked_lm_positions, tf.int32)
     num_predictions = int(tf.shape(masked_lm_positions)[1])
+    # Test transformer inputs.
+    self.assertEqual(num_predictions, max_selections_per_seq)
+    expected_word_ids = np.array([
+        # [CLS] hello [SEP] world [SEP]
+        [2, 5, 3, 6, 3, 0, 0, 0, 0, 0],
+        # [CLS] nice movie [SEP] great actors [SEP]
+        [2, 7, 8, 3, 9, 10, 3, 0, 0, 0],
+        # [CLS] quick fox [SEP] lazy dog [SEP]
+        [2, 11, 12, 3, 13, 14, 3, 0, 0, 0]])
+    for i in range(batch_size):
+      for j in range(num_predictions):
+        k = int(masked_lm_positions[i, j])
+        if k != 0:
+          expected_word_ids[i, k] = 4  # [MASK]
+    self.assertAllEqual(input_word_ids, expected_word_ids)
     # Call the MLM head of the Transformer encoder.
     mlm_inputs = dict(
         input_word_ids=input_word_ids,
...
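
Note for readers verifying the new assertions: the sketch below (not part of the commit) illustrates why setting mask_token_rate=1.0 and random_token_rate=0.0 makes the masked positions predictable. It uses the same TF.text calls as the test; the token ids assume the vocab layout the test's assertions imply, namely 5 special tokens ([PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4) followed by the 10 non-special tokens (hello=5 ... dog=14).

# Minimal, standalone sketch of deterministic MLM masking with TF.text.
import tensorflow as tf
import tensorflow_text as text

# One example: [CLS] hello [SEP] world [SEP], as a ragged batch.
input_ids = tf.ragged.constant([[2, 5, 3, 6, 3]])
masked_ids, masked_pos, masked_lm_ids = text.mask_language_model(
    input_ids,
    item_selector=text.RandomItemSelector(
        max_selections_per_batch=2,
        selection_rate=0.5,
        unselectable_ids=[2, 3]),  # Never select [CLS] or [SEP].
    mask_values_chooser=text.MaskValuesChooser(
        vocab_size=15, mask_token=4,
        mask_token_rate=1.0, random_token_rate=0.0))
# With rates 1.0/0.0, every selected position holds exactly id 4 ([MASK]),
# so expected_word_ids can be reconstructed from masked_lm_positions alone.
# With the library defaults (0.8 mask / 0.1 random / 0.1 keep-original), a
# selected position could keep its original id or receive a random vocab
# id, which would make the assertAllEqual in this test flaky.

Which positions get selected is still random (RandomItemSelector), which is why the test patches expected_word_ids from masked_lm_positions instead of hard-coding the [MASK] locations.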