Commit bb124157 authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

parents 2e9bb539 0edeb7f6
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,34 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util functions to integrate with Keras internals."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend
try:
  from tensorflow.python.keras.engine import keras_tensor  # pylint: disable=g-import-not-at-top,unused-import
  keras_tensor.disable_keras_tensors()
except ImportError:
  keras_tensor = None
class NoOpContextManager(object):

  def __enter__(self):
    pass

  def __exit__(self, *args):
    pass


def maybe_enter_backend_graph():
  if (keras_tensor is not None) and keras_tensor.keras_tensors_enabled():
    return NoOpContextManager()
  else:
    return backend.get_graph().as_default()
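
# For orientation, a minimal usage sketch of the helper above; `layer` and
# `inputs` are illustrative placeholders, not names from this module.
def build_outputs(layer, inputs):
  # Enters the Keras backend graph only when KerasTensors are disabled;
  # with KerasTensors enabled the context manager is a no-op.
  with maybe_enter_backend_graph():
    return layer(inputs)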
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub. r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub.
This tool creates preprocessor and encoder SavedModels suitable for uploading This tool creates preprocessor and encoder SavedModels suitable for uploading
...@@ -145,7 +145,7 @@ flags.DEFINE_integer( ...@@ -145,7 +145,7 @@ flags.DEFINE_integer(
"sequence length for the bert_pack_inputs subobject." "sequence length for the bert_pack_inputs subobject."
"Needed for --export_type preprocessing.") "Needed for --export_type preprocessing.")
flags.DEFINE_bool( flags.DEFINE_bool(
"tokenize_with_offsets", False, # Broken by b/149576200. "tokenize_with_offsets", False, # TODO(b/181866850)
"Whether to export a .tokenize_with_offsets subobject for " "Whether to export a .tokenize_with_offsets subobject for "
"--export_type preprocessing.") "--export_type preprocessing.")
flags.DEFINE_multi_string( flags.DEFINE_multi_string(
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of components of export_tfhub.py. See docstring there for more."""
import contextlib
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests export_tfhub_lib."""
import os
@@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from sentencepiece import SentencePieceTrainer
from official.modeling import tf_utils
@@ -32,11 +33,11 @@ from official.nlp.tools import export_tfhub_lib
def _get_bert_config_or_encoder_config(use_bert_config, hidden_size,
                                       num_hidden_layers, vocab_size=100):
  """Returns config args for export_tfhub_lib._create_model()."""
  if use_bert_config:
    bert_config = configs.BertConfig(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=32,
        max_position_embeddings=128,
@@ -48,7 +49,7 @@ def _get_bert_config_or_encoder_config(use_bert_config, hidden_size,
    encoder_config = encoders.EncoderConfig(
        type="albert",
        albert=encoders.AlbertEncoderConfig(
            vocab_size=vocab_size,
            embedding_width=16,
            hidden_size=hidden_size,
            intermediate_size=32,
@@ -450,11 +451,12 @@ _STRING_NOT_TO_LEAK = "private_path_component_"
class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):

  def _make_vocab_file(self, vocab, filename="vocab.txt", add_mask_token=False):
    """Creates wordpiece vocab file with given words plus special tokens.

    The tokens of the resulting model are, in this order:
        [PAD], [UNK], [CLS], [SEP], [MASK]*, ...vocab...
    *=if requested by args.

    This function also accepts wordpieces that start with the ## continuation
    marker, but avoiding those makes this function interchangeable with
@@ -465,11 +467,13 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
        model's vocabulary. Do not include special tokens here.
      filename: Optionally, a filename (relative to the temporary directory
        created by this function).
      add_mask_token: an optional bool, whether to include a [MASK] token.

    Returns:
      The absolute filename of the created vocab file.
    """
    full_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"
                 ] + ["[MASK]"]*add_mask_token + vocab
    path = os.path.join(
        tempfile.mkdtemp(dir=self.get_temp_dir(),  # New subdir each time.
                         prefix=_STRING_NOT_TO_LEAK),
@@ -478,11 +482,12 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
      f.write("\n".join(full_vocab + [""]))
    return path
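
  # Illustrative note (not part of the original test): with vocab=["d", "ef"]
  # and add_mask_token=True, the file written above contains, one token per
  # line:
  #   [PAD], [UNK], [CLS], [SEP], [MASK], d, ef
  # so "[MASK]" receives ID 4, which the MLM test below checks via
  # get_special_tokens_dict().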
  def _make_sp_model_file(self, vocab, prefix="spm", add_mask_token=False):
    """Creates Sentencepiece word model with given words plus special tokens.

    The tokens of the resulting model are, in this order:
        <pad>, <unk>, [CLS], [SEP], [MASK]*, ...vocab..., <s>, </s>
    *=if requested by args.

    The words in the input vocab are plain text, without the whitespace marker.
    That makes this function interchangeable with _make_vocab_file().
@@ -492,6 +497,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
        vocabulary. Do not include special tokens here.
      prefix: an optional string, to change the filename prefix for the model
        (relative to the temporary directory created by this function).
      add_mask_token: an optional bool, whether to include a [MASK] token.

    Returns:
      The absolute filename of the created Sentencepiece model file.
@@ -507,12 +513,16 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
      input_text.append(" ".join([token] * (len(vocab) - i)))
    with tf.io.gfile.GFile(input_file, "w") as f:
      f.write("\n".join(input_text + [""]))
    control_symbols = "[CLS],[SEP]"
    full_vocab_size = len(vocab) + 6  # <pad>, <unk>, [CLS], [SEP], <s>, </s>.
    if add_mask_token:
      control_symbols += ",[MASK]"
      full_vocab_size += 1
    flags = dict(
        model_prefix=model_prefix,
        model_type="word",
        input=input_file,
        pad_id=0, unk_id=1, control_symbols=control_symbols,
        vocab_size=full_vocab_size,
        bos_id=full_vocab_size-2, eos_id=full_vocab_size-1)
    SentencePieceTrainer.Train(
@@ -521,14 +531,15 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
  def _do_export(self, vocab, do_lower_case, default_seq_length=128,
                 tokenize_with_offsets=True, use_sp_model=False,
                 experimental_disable_assert=False, add_mask_token=False):
    """Runs SavedModel export and returns the export_path."""
    export_path = tempfile.mkdtemp(dir=self.get_temp_dir())
    vocab_file = sp_model_file = None
    if use_sp_model:
      sp_model_file = self._make_sp_model_file(vocab,
                                               add_mask_token=add_mask_token)
    else:
      vocab_file = self._make_vocab_file(vocab, add_mask_token=add_mask_token)
    export_tfhub_lib.export_preprocessing(
        export_path,
        vocab_file=vocab_file,
@@ -553,7 +564,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
  def test_exported_callables(self, use_sp_model):
    preprocess = tf.saved_model.load(self._do_export(
        ["d", "ef", "abc", "xy"], do_lower_case=True,
        tokenize_with_offsets=not use_sp_model,  # TODO(b/181866850): drop this.
        experimental_disable_assert=True,  # TODO(b/175369555): drop this.
        use_sp_model=use_sp_model))
@@ -579,7 +590,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
    # .tokenize_with_offsets()
    if use_sp_model:
      # TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
      self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
    else:
      token_ids, start_offsets, limit_offsets = (
@@ -680,7 +691,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
  def test_shapes(self, use_sp_model):
    preprocess = tf.saved_model.load(self._do_export(
        ["abc", "def"], do_lower_case=True,
        tokenize_with_offsets=not use_sp_model,  # TODO(b/181866850): drop this.
        experimental_disable_assert=True,  # TODO(b/175369555): drop this.
        use_sp_model=use_sp_model))
@@ -700,7 +711,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
              tf.TensorSpec([batch_size], tf.string)),
          token_out_shape,
          "with batch_size=%s" % batch_size)
    # TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
    if use_sp_model:
      self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
    else:
@@ -751,6 +762,137 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
@parameterized.named_parameters(("Bert", True), ("Albert", False))
def test_preprocessing_for_mlm(self, use_bert):
"""Combines both SavedModel types and TF.text helpers for MLM."""
# Create the preprocessing SavedModel with a [MASK] token.
non_special_tokens = ["hello", "world",
"nice", "movie", "great", "actors",
"quick", "fox", "lazy", "dog"]
preprocess = tf.saved_model.load(self._do_export(
non_special_tokens, do_lower_case=True,
tokenize_with_offsets=use_bert, # TODO(b/181866850): drop this.
experimental_disable_assert=True, # TODO(b/175369555): drop this.
add_mask_token=True, use_sp_model=not use_bert))
vocab_size = len(non_special_tokens) + (5 if use_bert else 7)
# Create the encoder SavedModel with an .mlm subobject.
hidden_size = 16
num_hidden_layers = 2
bert_config, encoder_config = _get_bert_config_or_encoder_config(
use_bert, hidden_size, num_hidden_layers, vocab_size)
_, pretrainer = export_tfhub_lib._create_model(
bert_config=bert_config, encoder_config=encoder_config, with_mlm=True)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy( # Not used below.
self.get_temp_dir(), use_sp_model=not use_bert)
encoder_export_path = os.path.join(self.get_temp_dir(), "encoder_export")
export_tfhub_lib.export_model(
export_path=encoder_export_path,
bert_config=bert_config,
encoder_config=encoder_config,
model_checkpoint_path=model_checkpoint_path,
with_mlm=True,
vocab_file=vocab_file,
sp_model_file=sp_model_file,
do_lower_case=True)
encoder = tf.saved_model.load(encoder_export_path)
# Get special tokens from the vocab (and vocab size).
special_tokens_dict = preprocess.tokenize.get_special_tokens_dict()
self.assertEqual(int(special_tokens_dict["vocab_size"]), vocab_size)
padding_id = int(special_tokens_dict["padding_id"])
self.assertEqual(padding_id, 0)
start_of_sequence_id = int(special_tokens_dict["start_of_sequence_id"])
self.assertEqual(start_of_sequence_id, 2)
end_of_segment_id = int(special_tokens_dict["end_of_segment_id"])
self.assertEqual(end_of_segment_id, 3)
mask_id = int(special_tokens_dict["mask_id"])
self.assertEqual(mask_id, 4)
# A batch of 3 segment pairs.
raw_segments = [tf.constant(["hello", "nice movie", "quick fox"]),
tf.constant(["world", "great actors", "lazy dog"])]
batch_size = 3
# Misc hyperparameters.
seq_length = 10
max_selections_per_seq = 2
# Tokenize inputs.
tokenized_segments = [preprocess.tokenize(s) for s in raw_segments]
# Trim inputs to eventually fit seq_lentgh.
num_special_tokens = len(raw_segments) + 1
trimmed_segments = text.WaterfallTrimmer(
seq_length - num_special_tokens).trim(tokenized_segments)
# Combine input segments into one input sequence.
input_ids, segment_ids = text.combine_segments(
trimmed_segments,
start_of_sequence_id=start_of_sequence_id,
end_of_segment_id=end_of_segment_id)
# Apply random masking controlled by policy objects.
(masked_input_ids, masked_lm_positions,
masked_ids) = text.mask_language_model(
input_ids=input_ids,
item_selector=text.RandomItemSelector(
max_selections_per_seq,
selection_rate=0.5, # Adjusted for the short test examples.
unselectable_ids=[start_of_sequence_id, end_of_segment_id]),
mask_values_chooser=text.MaskValuesChooser(
vocab_size=vocab_size, mask_token=mask_id,
# Always put [MASK] to have a predictable result.
mask_token_rate=1.0, random_token_rate=0.0))
# Pad to fixed-length Transformer encoder inputs.
input_word_ids, _ = text.pad_model_inputs(masked_input_ids,
seq_length,
pad_value=padding_id)
input_type_ids, input_mask = text.pad_model_inputs(segment_ids, seq_length,
pad_value=0)
masked_lm_positions, _ = text.pad_model_inputs(masked_lm_positions,
max_selections_per_seq,
pad_value=0)
masked_lm_positions = tf.cast(masked_lm_positions, tf.int32)
num_predictions = int(tf.shape(masked_lm_positions)[1])
# Test transformer inputs.
self.assertEqual(num_predictions, max_selections_per_seq)
expected_word_ids = np.array([
# [CLS] hello [SEP] world [SEP]
[2, 5, 3, 6, 3, 0, 0, 0, 0, 0],
# [CLS] nice movie [SEP] great actors [SEP]
[2, 7, 8, 3, 9, 10, 3, 0, 0, 0],
# [CLS] brown fox [SEP] lazy dog [SEP]
[2, 11, 12, 3, 13, 14, 3, 0, 0, 0]])
for i in range(batch_size):
for j in range(num_predictions):
k = int(masked_lm_positions[i, j])
if k != 0:
expected_word_ids[i, k] = 4 # [MASK]
self.assertAllEqual(input_word_ids, expected_word_ids)
# Call the MLM head of the Transformer encoder.
mlm_inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids,
masked_lm_positions=masked_lm_positions,
)
mlm_outputs = encoder.mlm(mlm_inputs)
self.assertEqual(mlm_outputs["pooled_output"].shape,
(batch_size, hidden_size))
self.assertEqual(mlm_outputs["sequence_output"].shape,
(batch_size, seq_length, hidden_size))
self.assertEqual(mlm_outputs["mlm_logits"].shape,
(batch_size, num_predictions, vocab_size))
self.assertLen(mlm_outputs["encoder_outputs"], num_hidden_layers)
# A real trainer would now compute the loss of mlm_logits
# trying to predict the masked_ids.
del masked_ids # Unused.
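
    # Hypothetical sketch (not part of the original test) of that loss
    # computation, assuming masked_ids were kept instead of deleted above.
    # pad_model_inputs would align the ragged masked_ids with the padded
    # masked_lm_positions used earlier:
    #   masked_ids_dense, weights = text.pad_model_inputs(
    #       masked_ids, max_selections_per_seq, pad_value=0)
    #   loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    #       from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    #   per_position = loss_fn(masked_ids_dense, mlm_outputs["mlm_logits"])
    #   weights = tf.cast(weights, per_position.dtype)
    #   loss = tf.reduce_sum(per_position * weights) / tf.reduce_sum(weights)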
@parameterized.named_parameters(("Bert", False), ("Sentencepiece", True)) @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
def test_special_tokens_in_estimator(self, use_sp_model): def test_special_tokens_in_estimator(self, use_sp_model):
"""Tests getting special tokens without an Eager init context.""" """Tests getting special tokens without an Eager init context."""
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM common training driver."""
from absl import app
@@ -47,7 +46,8 @@ def main(_):
  # dtype is float16
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
                                           params.runtime.loss_scale,
                                           use_experimental_api=True)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM continuous finetuning+eval training driver."""
from absl import app
from absl import flags
@@ -39,8 +38,8 @@ def main(_):
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  train_utils.serialize_config(params, model_dir)
  continuous_finetune_lib.run_continuous_finetune(
      FLAGS.mode, params, model_dir, pretrain_steps=FLAGS.pretrain_steps)
  train_utils.save_gin_config(FLAGS.mode, model_dir)
...
@@ -3,9 +3,11 @@
This is an implementation of the Transformer translation model as described in
the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper. The
implementation leverages tf.keras and makes sure it is compatible with TF 2.x.

**Warning: the features in the `transformer/` folder have been fully integrated
into nlp/modeling. Due to its dependencies, we will remove this folder after
the model garden 2.5 release. The model in
`nlp/modeling/models/seq2seq_transformer.py` is identical to the model in this
folder.**

## Contents
* [Contents](#contents)
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of multiheaded attention and self-attention layers."""
import math
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search to find the translated sequence with the highest probability."""
import tensorflow.compat.v1 as tf
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to compute official BLEU score.

Source:
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test functions in compute_bleu.py."""
import tempfile
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and preprocess WMT17 ende training and evaluation datasets."""
import os
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input pipeline for the transformer model to read, filter, and batch examples.

Two things to note in the pipeline:
@@ -242,7 +242,7 @@ def _read_and_batch_from_files(file_pattern,
      num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)

  # Parse each tf.Example into a dictionary
  # TODO: Look into prefetch_input_elements for performance optimization.  # pylint: disable=g-bad-todo
  dataset = dataset.map(
      _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of embedding layer with shared weights."""
import tensorflow as tf
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of fully connected network."""
import tensorflow as tf
@@ -62,8 +62,6 @@ class FeedForwardNetwork(tf.keras.layers.Layer):
      tensor with shape [batch_size, length, hidden_size]
    """
    # Retrieve dynamically known shapes
    batch_size = tf.shape(x)[0]
    length = tf.shape(x)[1]

    output = self.filter_dense_layer(x)
    if training:
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for calculating loss, accuracy, and other model metrics.

Metrics:
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Misc for Transformer."""
# pylint: disable=g-bad-import-order
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines Transformer model parameters."""
import collections

BASE_PARAMS = collections.defaultdict(
    lambda: None,  # Set default value to None.

    # Input params
...
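
# Illustration (not from this file) of what the defaultdict buys here: any
# hyperparameter a given params set never defines reads back as None via the
# `lambda: None` factory, instead of raising KeyError.
params = BASE_PARAMS.copy()  # copy() preserves the default factory
assert params["not_a_defined_hyperparameter"] is None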
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer model helper methods."""
import math
...