Commit 32e4ca51 authored by qianyj

Update code to v2.11.0

parents 9485aa1d 71060f67
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -71,8 +71,8 @@ from absl import app
 from absl import flags
 import gin
+from official.legacy.bert import configs
 from official.modeling import hyperparams
-from official.nlp.bert import configs
 from official.nlp.configs import encoders
 from official.nlp.tools import export_tfhub_lib
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,8 +28,8 @@ import tensorflow as tf
 from tensorflow.core.protobuf import saved_model_pb2
 from tensorflow.python.ops import control_flow_ops
 # pylint: enable=g-direct-tensorflow-import
+from official.legacy.bert import configs
 from official.modeling import tf_utils
-from official.nlp.bert import configs
 from official.nlp.configs import encoders
 from official.nlp.modeling import layers
 from official.nlp.modeling import models
@@ -84,13 +84,13 @@ def _create_model(
   """Creates the model to export and the model to restore the checkpoint.

   Args:
-    bert_config: A legacy `BertConfig` to create a `BertEncoder` object.
-      Exactly one of encoder_config and bert_config must be set.
+    bert_config: A legacy `BertConfig` to create a `BertEncoder` object. Exactly
+      one of encoder_config and bert_config must be set.
     encoder_config: An `EncoderConfig` to create an encoder of the configured
       type (`BertEncoder` or other).
-    with_mlm: A bool to control the second component of the result.
-      If True, will create a `BertPretrainerV2` object; otherwise, will
-      create a `BertEncoder` object.
+    with_mlm: A bool to control the second component of the result. If True,
+      will create a `BertPretrainerV2` object; otherwise, will create a
+      `BertEncoder` object.

   Returns:
     A Tuple of (1) a Keras model that will be exported, (2) a `BertPretrainerV2`
@@ -110,7 +110,11 @@ def _create_model(
   # Convert from list of named inputs to dict of inputs keyed by name.
   # Only the latter accepts a dict of inputs after restoring from SavedModel.
-  encoder_inputs_dict = {x.name: x for x in encoder.inputs}
+  if isinstance(encoder.inputs, list) or isinstance(encoder.inputs, tuple):
+    encoder_inputs_dict = {x.name: x for x in encoder.inputs}
+  else:
+    # encoder.inputs by default is dict for BertEncoderV2.
+    encoder_inputs_dict = encoder.inputs
   encoder_output_dict = encoder(encoder_inputs_dict)
   # For interchangeability with other text representations,
   # add "default" as an alias for BERT's whole-input representations.
@@ -129,7 +133,10 @@
         encoder_network=encoder,
         mlm_activation=tf_utils.get_activation(hidden_act))
-    pretrainer_inputs_dict = {x.name: x for x in pretrainer.inputs}
+    if isinstance(pretrainer.inputs, dict):
+      pretrainer_inputs_dict = pretrainer.inputs
+    else:
+      pretrainer_inputs_dict = {x.name: x for x in pretrainer.inputs}
     pretrainer_output_dict = pretrainer(pretrainer_inputs_dict)
     mlm_model = tf.keras.Model(
         inputs=pretrainer_inputs_dict, outputs=pretrainer_output_dict)
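
Note: both new `isinstance` branches exist for the same reason. `BertEncoder` builds its Keras model from a list of named inputs, while `BertEncoderV2` already exposes `.inputs` as a dict keyed by name (per the comment in the hunk above). A minimal standalone sketch of the normalization, assuming only plain Keras semantics:

def _inputs_as_dict(model_inputs):
  """Returns model inputs as a dict keyed by input name."""
  if isinstance(model_inputs, (list, tuple)):
    # List/tuple of named Keras inputs -> dict keyed by tensor name.
    return {x.name: x for x in model_inputs}
  # Dict-valued inputs (BertEncoderV2 style) pass through unchanged.
  return model_inputs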
@@ -206,26 +213,28 @@ def export_model(export_path: Text,
     encoder_config: An optional `encoders.EncoderConfig` object.
     model_checkpoint_path: The path to the checkpoint.
     with_mlm: Whether to export the additional mlm sub-object.
-    copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer
-      used in the next sentence prediction task to the encoder.
+    copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer used
+      in the next sentence prediction task to the encoder.
     vocab_file: The path to the wordpiece vocab file, or None.
-    sp_model_file: The path to the sentencepiece model file, or None.
-      Exactly one of vocab_file and sp_model_file must be set.
+    sp_model_file: The path to the sentencepiece model file, or None. Exactly
+      one of vocab_file and sp_model_file must be set.
     do_lower_case: Whether to lower-case text before tokenization.
   """
   if with_mlm:
-    core_model, pretrainer = _create_model(bert_config=bert_config,
-                                           encoder_config=encoder_config,
-                                           with_mlm=with_mlm)
+    core_model, pretrainer = _create_model(
+        bert_config=bert_config,
+        encoder_config=encoder_config,
+        with_mlm=with_mlm)
     encoder = pretrainer.encoder_network
     # It supports both the new pretrainer checkpoint produced by TF-NLP and
     # the checkpoint converted from TF1 (original BERT, SmallBERTs).
     checkpoint_items = pretrainer.checkpoint_items
     checkpoint = tf.train.Checkpoint(**checkpoint_items)
   else:
-    core_model, encoder = _create_model(bert_config=bert_config,
-                                        encoder_config=encoder_config,
-                                        with_mlm=with_mlm)
+    core_model, encoder = _create_model(
+        bert_config=bert_config,
+        encoder_config=encoder_config,
+        with_mlm=with_mlm)
     checkpoint = tf.train.Checkpoint(
         model=encoder,  # Legacy checkpoints.
         encoder=encoder)
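
Note: a hedged usage sketch of the exporter above; the checkpoint and vocab paths are hypothetical, and the argument names follow the docstring in this hunk:

from official.nlp.configs import encoders
from official.nlp.tools import export_tfhub_lib

export_tfhub_lib.export_model(
    export_path="/tmp/bert_savedmodel",  # hypothetical
    bert_config=None,  # exactly one of bert_config / encoder_config is set
    encoder_config=encoders.EncoderConfig(type="bert_v2"),
    model_checkpoint_path="/tmp/pretrained/bert_ckpt",  # hypothetical
    with_mlm=True,  # also exports the .mlm subobject
    copy_pooler_dense_to_encoder=False,
    vocab_file="/tmp/vocab.txt",  # exactly one of vocab_file / sp_model_file
    sp_model_file=None,
    do_lower_case=True)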
...@@ -279,21 +288,26 @@ class BertPackInputsSavedModelWrapper(tf.train.Checkpoint): ...@@ -279,21 +288,26 @@ class BertPackInputsSavedModelWrapper(tf.train.Checkpoint):
# overridable. Having this dynamically determined default argument # overridable. Having this dynamically determined default argument
# requires self.__call__ to be defined in this indirect way. # requires self.__call__ to be defined in this indirect way.
default_seq_length = bert_pack_inputs.seq_length default_seq_length = bert_pack_inputs.seq_length
@tf.function(autograph=False) @tf.function(autograph=False)
def call(inputs, seq_length=default_seq_length): def call(inputs, seq_length=default_seq_length):
return layers.BertPackInputs.bert_pack_inputs( return layers.BertPackInputs.bert_pack_inputs(
inputs, seq_length=seq_length, inputs,
seq_length=seq_length,
start_of_sequence_id=bert_pack_inputs.start_of_sequence_id, start_of_sequence_id=bert_pack_inputs.start_of_sequence_id,
end_of_segment_id=bert_pack_inputs.end_of_segment_id, end_of_segment_id=bert_pack_inputs.end_of_segment_id,
padding_id=bert_pack_inputs.padding_id) padding_id=bert_pack_inputs.padding_id)
self.__call__ = call self.__call__ = call
for ragged_rank in range(1, 3): for ragged_rank in range(1, 3):
for num_segments in range(1, 3): for num_segments in range(1, 3):
_ = self.__call__.get_concrete_function( _ = self.__call__.get_concrete_function([
[tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32) tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32)
for _ in range(num_segments)], for _ in range(num_segments)
seq_length=tf.TensorSpec([], tf.int32)) ],
seq_length=tf.TensorSpec(
[], tf.int32))
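
Note: the loop above pre-traces concrete functions so that, after tf.saved_model.load, the packer accepts one or two ragged segments of ragged rank 1 or 2 plus a scalar int32 seq_length. A sketch of the resulting call, with a hypothetical export path and illustrative token ids:

import tensorflow as tf

preprocess = tf.saved_model.load("/tmp/preprocessing_export")  # hypothetical
segments = [tf.ragged.constant([[2, 6, 4], [2, 7]], dtype=tf.int32)]
packed = preprocess.bert_pack_inputs(
    segments, seq_length=tf.constant(8, dtype=tf.int32))
# packed is a dict with "input_word_ids", "input_mask", "input_type_ids".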
 def create_preprocessing(*,
@@ -311,14 +325,14 @@ def create_preprocessing(*,

   Args:
     vocab_file: The path to the wordpiece vocab file, or None.
-    sp_model_file: The path to the sentencepiece model file, or None.
-      Exactly one of vocab_file and sp_model_file must be set.
-      This determines the type of tokenzer that is used.
+    sp_model_file: The path to the sentencepiece model file, or None. Exactly
+      one of vocab_file and sp_model_file must be set. This determines the type
+      of tokenizer that is used.
     do_lower_case: Whether to do lower case.
     tokenize_with_offsets: Whether to include the .tokenize_with_offsets
       subobject.
-    default_seq_length: The sequence length of preprocessing results from
-      root callable. This is also the default sequence length for the
+    default_seq_length: The sequence length of preprocessing results from root
+      callable. This is also the default sequence length for the
       bert_pack_inputs subobject.

   Returns:
@@ -378,7 +392,8 @@ def create_preprocessing(*,
 def _move_to_tmpdir(file_path: Optional[Text], tmpdir: Text) -> Optional[Text]:
   """Returns new path with same basename and hash of original path."""
-  if file_path is None: return None
+  if file_path is None:
+    return None
   olddir, filename = os.path.split(file_path)
   hasher = hashlib.sha1()
   hasher.update(olddir.encode("utf-8"))
@@ -460,12 +475,17 @@ def _check_no_assert(saved_model_path):
   assert_nodes = []
   graph_def = saved_model.meta_graphs[0].graph_def
-  assert_nodes += ["node '{}' in global graph".format(n.name)
-                   for n in graph_def.node if n.op == "Assert"]
+  assert_nodes += [
+      "node '{}' in global graph".format(n.name)
+      for n in graph_def.node
+      if n.op == "Assert"
+  ]
   for fdef in graph_def.library.function:
     assert_nodes += [
         "node '{}' in function '{}'".format(n.name, fdef.signature.name)
-        for n in fdef.node_def if n.op == "Assert"]
+        for n in fdef.node_def
+        if n.op == "Assert"
+    ]
   if assert_nodes:
     raise AssertionError(
         "Internal tool error: "
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,21 +20,39 @@ import tempfile
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
+from tensorflow import estimator as tf_estimator
 import tensorflow_hub as hub
 import tensorflow_text as text
 from sentencepiece import SentencePieceTrainer
+from official.legacy.bert import configs
 from official.modeling import tf_utils
-from official.nlp.bert import configs
 from official.nlp.configs import encoders
 from official.nlp.modeling import layers
 from official.nlp.modeling import models
 from official.nlp.tools import export_tfhub_lib

-def _get_bert_config_or_encoder_config(use_bert_config, hidden_size,
-                                       num_hidden_layers, vocab_size=100):
-  """Returns config args for export_tfhub_lib._create_model()."""
+def _get_bert_config_or_encoder_config(use_bert_config,
+                                       hidden_size,
+                                       num_hidden_layers,
+                                       encoder_type="albert",
+                                       vocab_size=100):
+  """Generates config args for export_tfhub_lib._create_model().
+
+  Args:
+    use_bert_config: bool. If True, returns legacy BertConfig.
+    hidden_size: int.
+    num_hidden_layers: int.
+    encoder_type: str. Can be ['albert', 'bert', 'bert_v2']. If use_bert_config
+      == True, then model_type is not used.
+    vocab_size: int.
+
+  Returns:
+    bert_config, encoder_config. Only one is not None. If
+    `use_bert_config` == True, the first config is valid. Otherwise
+    `bert_config` == None.
+  """
   if use_bert_config:
     bert_config = configs.BertConfig(
         vocab_size=vocab_size,
@@ -46,17 +64,31 @@ def _get_bert_config_or_encoder_config(use_bert_config, hidden_size,
     encoder_config = None
   else:
     bert_config = None
-    encoder_config = encoders.EncoderConfig(
-        type="albert",
-        albert=encoders.AlbertEncoderConfig(
-            vocab_size=vocab_size,
-            embedding_width=16,
-            hidden_size=hidden_size,
-            intermediate_size=32,
-            max_position_embeddings=128,
-            num_attention_heads=2,
-            num_layers=num_hidden_layers,
-            dropout_rate=0.1))
+    if encoder_type == "albert":
+      encoder_config = encoders.EncoderConfig(
+          type="albert",
+          albert=encoders.AlbertEncoderConfig(
+              vocab_size=vocab_size,
+              embedding_width=16,
+              hidden_size=hidden_size,
+              intermediate_size=32,
+              max_position_embeddings=128,
+              num_attention_heads=2,
+              num_layers=num_hidden_layers,
+              dropout_rate=0.1))
+    else:
+      # encoder_type can be 'bert' or 'bert_v2'.
+      model_config = encoders.BertEncoderConfig(
+          vocab_size=vocab_size,
+          embedding_size=16,
+          hidden_size=hidden_size,
+          intermediate_size=32,
+          max_position_embeddings=128,
+          num_attention_heads=2,
+          num_layers=num_hidden_layers,
+          dropout_rate=0.1)
+      kwargs = {"type": encoder_type, encoder_type: model_config}
+      encoder_config = encoders.EncoderConfig(**kwargs)

   return bert_config, encoder_config
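
Note on the `kwargs` trick above: `encoders.EncoderConfig` is a oneof-style config, so the encoder type name doubles as the name of the sub-config field. With illustrative values:

encoder_type = "bert_v2"
kwargs = {"type": encoder_type, encoder_type: model_config}
# Expands to encoders.EncoderConfig(type="bert_v2", bert_v2=model_config),
# i.e. the same call the `albert` branch spells out explicitly.
encoder_config = encoders.EncoderConfig(**kwargs)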
@@ -105,13 +137,18 @@ class ExportModelTest(tf.test.TestCase, parameterized.TestCase):
   alternative to BertTokenizer).
   """

-  @parameterized.named_parameters(("Bert", True), ("Albert", False))
-  def test_export_model(self, use_bert):
+  @parameterized.named_parameters(
+      ("Bert_Legacy", True, None), ("Albert", False, "albert"),
+      ("BertEncoder", False, "bert"), ("BertEncoderV2", False, "bert_v2"))
+  def test_export_model(self, use_bert, encoder_type):
     # Create the encoder and export it.
     hidden_size = 16
     num_hidden_layers = 1
     bert_config, encoder_config = _get_bert_config_or_encoder_config(
-        use_bert, hidden_size, num_hidden_layers)
+        use_bert,
+        hidden_size=hidden_size,
+        num_hidden_layers=num_hidden_layers,
+        encoder_type=encoder_type)
     bert_model, encoder = export_tfhub_lib._create_model(
         bert_config=bert_config, encoder_config=encoder_config, with_mlm=False)
     self.assertEmpty(
@@ -151,8 +188,8 @@ class ExportModelTest(tf.test.TestCase, parameterized.TestCase):
         _read_asset(hub_layer.resolved_object.sp_model_file))

     # Check restored weights.
-    self.assertEqual(len(bert_model.trainable_weights),
-                     len(hub_layer.trainable_weights))
+    self.assertEqual(
+        len(bert_model.trainable_weights), len(hub_layer.trainable_weights))
     for source_weight, hub_weight in zip(bert_model.trainable_weights,
                                          hub_layer.trainable_weights):
       self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
@@ -334,8 +371,8 @@ class ExportModelWithMLMTest(tf.test.TestCase, parameterized.TestCase):
     # Note that we set `_auto_track_sub_layers` to False when exporting the
     # SavedModel, so hub_layer has the same number of weights as bert_model;
     # otherwise, hub_layer will have extra weights from its `mlm` subobject.
-    self.assertEqual(len(bert_model.trainable_weights),
-                     len(hub_layer.trainable_weights))
+    self.assertEqual(
+        len(bert_model.trainable_weights), len(hub_layer.trainable_weights))
     for source_weight, hub_weight in zip(bert_model.trainable_weights,
                                          hub_layer.trainable_weights):
       self.assertAllClose(source_weight, hub_weight)
@@ -473,10 +510,11 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
       The absolute filename of the created vocab file.
     """
     full_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"
-                 ] + ["[MASK]"]*add_mask_token + vocab
+                 ] + ["[MASK]"] * add_mask_token + vocab
     path = os.path.join(
-        tempfile.mkdtemp(dir=self.get_temp_dir(),  # New subdir each time.
-                         prefix=_STRING_NOT_TO_LEAK),
+        tempfile.mkdtemp(
+            dir=self.get_temp_dir(),  # New subdir each time.
+            prefix=_STRING_NOT_TO_LEAK),
         filename)
     with tf.io.gfile.GFile(path, "w") as f:
       f.write("\n".join(full_vocab + [""]))
@@ -522,22 +560,30 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
         model_prefix=model_prefix,
         model_type="word",
         input=input_file,
-        pad_id=0, unk_id=1, control_symbols=control_symbols,
+        pad_id=0,
+        unk_id=1,
+        control_symbols=control_symbols,
         vocab_size=full_vocab_size,
-        bos_id=full_vocab_size-2, eos_id=full_vocab_size-1)
-    SentencePieceTrainer.Train(
-        " ".join(["--{}={}".format(k, v) for k, v in flags.items()]))
+        bos_id=full_vocab_size - 2,
+        eos_id=full_vocab_size - 1)
+    SentencePieceTrainer.Train(" ".join(
+        ["--{}={}".format(k, v) for k, v in flags.items()]))
     return model_prefix + ".model"
-  def _do_export(self, vocab, do_lower_case, default_seq_length=128,
-                 tokenize_with_offsets=True, use_sp_model=False,
-                 experimental_disable_assert=False, add_mask_token=False):
+  def _do_export(self,
+                 vocab,
+                 do_lower_case,
+                 default_seq_length=128,
+                 tokenize_with_offsets=True,
+                 use_sp_model=False,
+                 experimental_disable_assert=False,
+                 add_mask_token=False):
     """Runs SavedModel export and returns the export_path."""
     export_path = tempfile.mkdtemp(dir=self.get_temp_dir())
     vocab_file = sp_model_file = None
     if use_sp_model:
-      sp_model_file = self._make_sp_model_file(vocab,
-                                               add_mask_token=add_mask_token)
+      sp_model_file = self._make_sp_model_file(
+          vocab, add_mask_token=add_mask_token)
     else:
       vocab_file = self._make_vocab_file(vocab, add_mask_token=add_mask_token)
     export_tfhub_lib.export_preprocessing(
@@ -554,19 +600,24 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):

   def test_no_leaks(self):
     """Tests not leaking the path to the original vocab file."""
-    path = self._do_export(
-        ["d", "ef", "abc", "xy"], do_lower_case=True, use_sp_model=False)
+    path = self._do_export(["d", "ef", "abc", "xy"],
+                           do_lower_case=True,
+                           use_sp_model=False)
     with tf.io.gfile.GFile(os.path.join(path, "saved_model.pb"), "rb") as f:
       self.assertFalse(  # pylint: disable=g-generic-assert
           _STRING_NOT_TO_LEAK.encode("ascii") in f.read())

   @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
   def test_exported_callables(self, use_sp_model):
-    preprocess = tf.saved_model.load(self._do_export(
-        ["d", "ef", "abc", "xy"], do_lower_case=True,
-        tokenize_with_offsets=not use_sp_model,  # TODO(b/181866850): drop this.
-        experimental_disable_assert=True,  # TODO(b/175369555): drop this.
-        use_sp_model=use_sp_model))
+    preprocess = tf.saved_model.load(
+        self._do_export(
+            ["d", "ef", "abc", "xy"],
+            do_lower_case=True,
+            # TODO(b/181866850): drop this.
+            tokenize_with_offsets=not use_sp_model,
+            # TODO(b/175369555): drop this.
+            experimental_disable_assert=True,
+            use_sp_model=use_sp_model))
     def fold_dim(rt):
       """Removes the word/subword distinction of BertTokenizer."""
@@ -575,18 +626,20 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
     # .tokenize()
     inputs = tf.constant(["abc d ef", "ABC D EF d"])
     token_ids = preprocess.tokenize(inputs)
-    self.assertAllEqual(fold_dim(token_ids),
-                        tf.ragged.constant([[6, 4, 5],
-                                            [6, 4, 5, 4]]))
+    self.assertAllEqual(
+        fold_dim(token_ids), tf.ragged.constant([[6, 4, 5], [6, 4, 5, 4]]))

     special_tokens_dict = {
         k: v.numpy().item()  # Expecting eager Tensor, converting to Python.
-        for k, v in preprocess.tokenize.get_special_tokens_dict().items()}
-    self.assertDictEqual(special_tokens_dict,
-                         dict(padding_id=0,
-                              start_of_sequence_id=2,
-                              end_of_segment_id=3,
-                              vocab_size=4+6 if use_sp_model else 4+4))
+        for k, v in preprocess.tokenize.get_special_tokens_dict().items()
+    }
+    self.assertDictEqual(
+        special_tokens_dict,
+        dict(
+            padding_id=0,
+            start_of_sequence_id=2,
+            end_of_segment_id=3,
+            vocab_size=4 + 6 if use_sp_model else 4 + 4))
     # .tokenize_with_offsets()
     if use_sp_model:
@@ -595,92 +648,104 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
     else:
       token_ids, start_offsets, limit_offsets = (
          preprocess.tokenize_with_offsets(inputs))
-      self.assertAllEqual(fold_dim(token_ids),
-                          tf.ragged.constant([[6, 4, 5],
-                                              [6, 4, 5, 4]]))
-      self.assertAllEqual(fold_dim(start_offsets),
-                          tf.ragged.constant([[0, 4, 6],
-                                              [0, 4, 6, 9]]))
-      self.assertAllEqual(fold_dim(limit_offsets),
-                          tf.ragged.constant([[3, 5, 8],
-                                              [3, 5, 8, 10]]))
+      self.assertAllEqual(
+          fold_dim(token_ids), tf.ragged.constant([[6, 4, 5], [6, 4, 5, 4]]))
+      self.assertAllEqual(
+          fold_dim(start_offsets), tf.ragged.constant([[0, 4, 6], [0, 4, 6,
+                                                                   9]]))
+      self.assertAllEqual(
+          fold_dim(limit_offsets), tf.ragged.constant([[3, 5, 8], [3, 5, 8,
+                                                                   10]]))
       self.assertIs(preprocess.tokenize.get_special_tokens_dict,
                     preprocess.tokenize_with_offsets.get_special_tokens_dict)
     # Root callable.
     bert_inputs = preprocess(inputs)
     self.assertAllEqual(bert_inputs["input_word_ids"].shape.as_list(), [2, 128])
-    self.assertAllEqual(bert_inputs["input_word_ids"][:, :10],
-                        tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
-                                     [2, 6, 4, 5, 4, 3, 0, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_word_ids"][:, :10],
+        tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
+                     [2, 6, 4, 5, 4, 3, 0, 0, 0, 0]]))
     self.assertAllEqual(bert_inputs["input_mask"].shape.as_list(), [2, 128])
-    self.assertAllEqual(bert_inputs["input_mask"][:, :10],
-                        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
-                                     [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_mask"][:, :10],
+        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+                     [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]))
     self.assertAllEqual(bert_inputs["input_type_ids"].shape.as_list(), [2, 128])
-    self.assertAllEqual(bert_inputs["input_type_ids"][:, :10],
-                        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_type_ids"][:, :10],
+        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

     # .bert_pack_inputs()
     inputs_2 = tf.constant(["d xy", "xy abc"])
     token_ids_2 = preprocess.tokenize(inputs_2)
-    bert_inputs = preprocess.bert_pack_inputs(
-        [token_ids, token_ids_2], seq_length=256)
+    bert_inputs = preprocess.bert_pack_inputs([token_ids, token_ids_2],
+                                              seq_length=256)
     self.assertAllEqual(bert_inputs["input_word_ids"].shape.as_list(), [2, 256])
-    self.assertAllEqual(bert_inputs["input_word_ids"][:, :10],
-                        tf.constant([[2, 6, 4, 5, 3, 4, 7, 3, 0, 0],
-                                     [2, 6, 4, 5, 4, 3, 7, 6, 3, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_word_ids"][:, :10],
+        tf.constant([[2, 6, 4, 5, 3, 4, 7, 3, 0, 0],
+                     [2, 6, 4, 5, 4, 3, 7, 6, 3, 0]]))
     self.assertAllEqual(bert_inputs["input_mask"].shape.as_list(), [2, 256])
-    self.assertAllEqual(bert_inputs["input_mask"][:, :10],
-                        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
-                                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_mask"][:, :10],
+        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
+                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]))
     self.assertAllEqual(bert_inputs["input_type_ids"].shape.as_list(), [2, 256])
-    self.assertAllEqual(bert_inputs["input_type_ids"][:, :10],
-                        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
-                                     [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_type_ids"][:, :10],
+        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
+                     [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]]))
   # For BertTokenizer only: repeat relevant parts for do_lower_case=False,
   # default_seq_length=10, experimental_disable_assert=False,
   # tokenize_with_offsets=False, and without folding the word/subword dimension.
   def test_cased_length10(self):
-    preprocess = tf.saved_model.load(self._do_export(
-        ["d", "##ef", "abc", "ABC"],
-        do_lower_case=False, default_seq_length=10,
-        tokenize_with_offsets=False,
-        use_sp_model=False,
-        experimental_disable_assert=False))
+    preprocess = tf.saved_model.load(
+        self._do_export(["d", "##ef", "abc", "ABC"],
+                        do_lower_case=False,
+                        default_seq_length=10,
+                        tokenize_with_offsets=False,
+                        use_sp_model=False,
+                        experimental_disable_assert=False))
     inputs = tf.constant(["abc def", "ABC DEF"])
     token_ids = preprocess.tokenize(inputs)
-    self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
-                                                       [[7], [1]]]))
+    self.assertAllEqual(token_ids,
+                        tf.ragged.constant([[[6], [4, 5]], [[7], [1]]]))

     self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))

     bert_inputs = preprocess(inputs)
-    self.assertAllEqual(bert_inputs["input_word_ids"],
-                        tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
-                                     [2, 7, 1, 3, 0, 0, 0, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_word_ids"],
+        tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
+                     [2, 7, 1, 3, 0, 0, 0, 0, 0, 0]]))
-    self.assertAllEqual(bert_inputs["input_mask"],
-                        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
-                                     [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_mask"],
+        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+                     [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]))
-    self.assertAllEqual(bert_inputs["input_type_ids"],
-                        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_type_ids"],
+        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

     inputs_2 = tf.constant(["d ABC", "ABC abc"])
     token_ids_2 = preprocess.tokenize(inputs_2)
     bert_inputs = preprocess.bert_pack_inputs([token_ids, token_ids_2])
     # Test default seq_length=10.
-    self.assertAllEqual(bert_inputs["input_word_ids"],
-                        tf.constant([[2, 6, 4, 5, 3, 4, 7, 3, 0, 0],
-                                     [2, 7, 1, 3, 7, 6, 3, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_word_ids"],
+        tf.constant([[2, 6, 4, 5, 3, 4, 7, 3, 0, 0],
+                     [2, 7, 1, 3, 7, 6, 3, 0, 0, 0]]))
-    self.assertAllEqual(bert_inputs["input_mask"],
-                        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
-                                     [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_mask"],
+        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
+                     [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]))
-    self.assertAllEqual(bert_inputs["input_type_ids"],
-                        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
-                                     [0, 0, 0, 0, 1, 1, 1, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_type_ids"],
+        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
+                     [0, 0, 0, 0, 1, 1, 1, 0, 0, 0]]))
     # XLA requires fixed shapes for tensors found in graph mode.
     # Statically known shapes in Python are a particularly firm way to
@@ -689,16 +754,21 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
     # inference when applied to fully or partially known input shapes.

   @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
   def test_shapes(self, use_sp_model):
-    preprocess = tf.saved_model.load(self._do_export(
-        ["abc", "def"], do_lower_case=True,
-        tokenize_with_offsets=not use_sp_model,  # TODO(b/181866850): drop this.
-        experimental_disable_assert=True,  # TODO(b/175369555): drop this.
-        use_sp_model=use_sp_model))
+    preprocess = tf.saved_model.load(
+        self._do_export(
+            ["abc", "def"],
+            do_lower_case=True,
+            # TODO(b/181866850): drop this.
+            tokenize_with_offsets=not use_sp_model,
+            # TODO(b/175369555): drop this.
+            experimental_disable_assert=True,
+            use_sp_model=use_sp_model))

     def expected_bert_input_shapes(batch_size, seq_length):
-      return dict(input_word_ids=[batch_size, seq_length],
-                  input_mask=[batch_size, seq_length],
-                  input_type_ids=[batch_size, seq_length])
+      return dict(
+          input_word_ids=[batch_size, seq_length],
+          input_mask=[batch_size, seq_length],
+          input_type_ids=[batch_size, seq_length])

     for batch_size in [7, None]:
       if use_sp_model:
@@ -706,11 +776,9 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
       else:
         token_out_shape = [batch_size, None, None]
       self.assertEqual(
-          _result_shapes_in_tf_function(
-              preprocess.tokenize,
-              tf.TensorSpec([batch_size], tf.string)),
-          token_out_shape,
-          "with batch_size=%s" % batch_size)
+          _result_shapes_in_tf_function(preprocess.tokenize,
+                                        tf.TensorSpec([batch_size], tf.string)),
+          token_out_shape, "with batch_size=%s" % batch_size)
       # TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
       if use_sp_model:
         self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
@@ -718,8 +786,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
         self.assertEqual(
             _result_shapes_in_tf_function(
                 preprocess.tokenize_with_offsets,
-                tf.TensorSpec([batch_size], tf.string)),
-            [token_out_shape] * 3,
+                tf.TensorSpec([batch_size], tf.string)), [token_out_shape] * 3,
             "with batch_size=%s" % batch_size)
       self.assertEqual(
           _result_shapes_in_tf_function(
@@ -737,7 +804,9 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
   def test_reexport(self, use_sp_model):
     """Test that preprocess keeps working after another save/load cycle."""
     path1 = self._do_export(
-        ["d", "ef", "abc", "xy"], do_lower_case=True, default_seq_length=10,
+        ["d", "ef", "abc", "xy"],
+        do_lower_case=True,
+        default_seq_length=10,
         tokenize_with_offsets=False,
         experimental_disable_assert=True,  # TODO(b/175369555): drop this.
         use_sp_model=use_sp_model)
@@ -752,35 +821,46 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
     inputs = tf.constant(["abc d ef", "ABC D EF d"])
     bert_inputs = model2(inputs)

-    self.assertAllEqual(bert_inputs["input_word_ids"],
-                        tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
-                                     [2, 6, 4, 5, 4, 3, 0, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_word_ids"],
+        tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
+                     [2, 6, 4, 5, 4, 3, 0, 0, 0, 0]]))
-    self.assertAllEqual(bert_inputs["input_mask"],
-                        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
-                                     [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_mask"],
+        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+                     [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]))
-    self.assertAllEqual(bert_inputs["input_type_ids"],
-                        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
+    self.assertAllEqual(
+        bert_inputs["input_type_ids"],
+        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
   @parameterized.named_parameters(("Bert", True), ("Albert", False))
   def test_preprocessing_for_mlm(self, use_bert):
     """Combines both SavedModel types and TF.text helpers for MLM."""
     # Create the preprocessing SavedModel with a [MASK] token.
-    non_special_tokens = ["hello", "world",
-                          "nice", "movie", "great", "actors",
-                          "quick", "fox", "lazy", "dog"]
-    preprocess = tf.saved_model.load(self._do_export(
-        non_special_tokens, do_lower_case=True,
-        tokenize_with_offsets=use_bert,  # TODO(b/181866850): drop this.
-        experimental_disable_assert=True,  # TODO(b/175369555): drop this.
-        add_mask_token=True, use_sp_model=not use_bert))
+    non_special_tokens = [
+        "hello", "world", "nice", "movie", "great", "actors", "quick", "fox",
+        "lazy", "dog"
+    ]
+    preprocess = tf.saved_model.load(
+        self._do_export(
+            non_special_tokens,
+            do_lower_case=True,
+            tokenize_with_offsets=use_bert,  # TODO(b/181866850): drop this.
+            experimental_disable_assert=True,  # TODO(b/175369555): drop this.
+            add_mask_token=True,
+            use_sp_model=not use_bert))
     vocab_size = len(non_special_tokens) + (5 if use_bert else 7)

     # Create the encoder SavedModel with an .mlm subobject.
     hidden_size = 16
     num_hidden_layers = 2
     bert_config, encoder_config = _get_bert_config_or_encoder_config(
-        use_bert, hidden_size, num_hidden_layers, vocab_size)
+        use_bert_config=use_bert,
+        hidden_size=hidden_size,
+        num_hidden_layers=num_hidden_layers,
+        vocab_size=vocab_size)
     _, pretrainer = export_tfhub_lib._create_model(
         bert_config=bert_config, encoder_config=encoder_config, with_mlm=True)
     model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
@@ -814,8 +894,10 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEqual(mask_id, 4)

     # A batch of 3 segment pairs.
-    raw_segments = [tf.constant(["hello", "nice movie", "quick fox"]),
-                    tf.constant(["world", "great actors", "lazy dog"])]
+    raw_segments = [
+        tf.constant(["hello", "nice movie", "quick fox"]),
+        tf.constant(["world", "great actors", "lazy dog"])
+    ]
     batch_size = 3

     # Misc hyperparameters.
@@ -842,18 +924,18 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
             selection_rate=0.5,  # Adjusted for the short test examples.
             unselectable_ids=[start_of_sequence_id, end_of_segment_id]),
         mask_values_chooser=text.MaskValuesChooser(
-            vocab_size=vocab_size, mask_token=mask_id,
+            vocab_size=vocab_size,
+            mask_token=mask_id,
             # Always put [MASK] to have a predictable result.
-            mask_token_rate=1.0, random_token_rate=0.0))
+            mask_token_rate=1.0,
+            random_token_rate=0.0))

     # Pad to fixed-length Transformer encoder inputs.
-    input_word_ids, _ = text.pad_model_inputs(masked_input_ids,
-                                              seq_length,
-                                              pad_value=padding_id)
-    input_type_ids, input_mask = text.pad_model_inputs(segment_ids, seq_length,
-                                                       pad_value=0)
-    masked_lm_positions, _ = text.pad_model_inputs(masked_lm_positions,
-                                                   max_selections_per_seq,
-                                                   pad_value=0)
+    input_word_ids, _ = text.pad_model_inputs(
+        masked_input_ids, seq_length, pad_value=padding_id)
+    input_type_ids, input_mask = text.pad_model_inputs(
+        segment_ids, seq_length, pad_value=0)
+    masked_lm_positions, _ = text.pad_model_inputs(
+        masked_lm_positions, max_selections_per_seq, pad_value=0)
     masked_lm_positions = tf.cast(masked_lm_positions, tf.int32)
     num_predictions = int(tf.shape(masked_lm_positions)[1])
@@ -865,7 +947,8 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
         # [CLS] nice movie [SEP] great actors [SEP]
         [2, 7, 8, 3, 9, 10, 3, 0, 0, 0],
        # [CLS] brown fox [SEP] lazy dog [SEP]
-        [2, 11, 12, 3, 13, 14, 3, 0, 0, 0]])
+        [2, 11, 12, 3, 13, 14, 3, 0, 0, 0]
+    ])
     for i in range(batch_size):
       for j in range(num_predictions):
         k = int(masked_lm_positions[i, j])
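
Note: the MLM block above is a standard TF.text pipeline (select maskable items, choose mask values, pad to dense). A condensed, self-contained sketch with illustrative ids and sizes (the helper names are real tensorflow_text APIs; all values here are made up):

import tensorflow as tf
import tensorflow_text as text

segments = tf.ragged.constant([[2, 5, 6, 3], [2, 7, 8, 3]])  # packed ids
masked_ids, masked_pos, masked_orig = text.mask_language_model(
    segments,
    item_selector=text.RandomItemSelector(
        max_selections_per_batch=2,
        selection_rate=0.5,
        unselectable_ids=[2, 3]),  # never mask [CLS]/[SEP]
    mask_values_chooser=text.MaskValuesChooser(
        vocab_size=10, mask_token=4,
        mask_token_rate=1.0, random_token_rate=0.0))
input_word_ids, input_mask = text.pad_model_inputs(
    masked_ids, 8, pad_value=0)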
@@ -896,15 +979,17 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
   @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
   def test_special_tokens_in_estimator(self, use_sp_model):
     """Tests getting special tokens without an Eager init context."""
-    preprocess_export_path = self._do_export(
-        ["d", "ef", "abc", "xy"], do_lower_case=True,
-        use_sp_model=use_sp_model, tokenize_with_offsets=False)
+    preprocess_export_path = self._do_export(["d", "ef", "abc", "xy"],
+                                             do_lower_case=True,
+                                             use_sp_model=use_sp_model,
+                                             tokenize_with_offsets=False)

     def _get_special_tokens_dict(obj):
       """Returns special tokens of restored tokenizer as Python values."""
       if tf.executing_eagerly():
-        special_tokens_numpy = {k: v.numpy()
-                                for k, v in obj.get_special_tokens_dict()}
+        special_tokens_numpy = {
+            k: v.numpy() for k, v in obj.get_special_tokens_dict()
+        }
       else:
         with tf.Graph().as_default():
           # This code expects `get_special_tokens_dict()` to be a tf.function
@@ -913,8 +998,10 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
           special_tokens_tensors = obj.get_special_tokens_dict()
           with tf.compat.v1.Session() as sess:
             special_tokens_numpy = sess.run(special_tokens_tensors)
-      return {k: v.item()  # Numpy to Python.
-              for k, v in special_tokens_numpy.items()}
+      return {
+          k: v.item()  # Numpy to Python.
+          for k, v in special_tokens_numpy.items()
+      }

     def input_fn():
       self.assertFalse(tf.executing_eagerly())
@@ -927,7 +1014,8 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
         self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
       tokens = tokenize(sentences)
       packed_inputs = layers.BertPackInputs(
-          4, special_tokens_dict=special_tokens_dict)(tokens)
+          4, special_tokens_dict=special_tokens_dict)(
+              tokens)
       preprocessing = tf.keras.Model(sentences, packed_inputs)
       # Map the dataset.
       ds = tf.data.Dataset.from_tensors(
@@ -937,22 +1025,22 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):

     def model_fn(features, labels, mode):
       del labels  # Unused.
-      return tf.estimator.EstimatorSpec(mode=mode,
-                                        predictions=features["input_word_ids"])
+      return tf_estimator.EstimatorSpec(
+          mode=mode, predictions=features["input_word_ids"])

-    estimator = tf.estimator.Estimator(model_fn=model_fn)
+    estimator = tf_estimator.Estimator(model_fn=model_fn)
     outputs = list(estimator.predict(input_fn))
-    self.assertAllEqual(outputs, np.array([[2, 6, 3, 0],
-                                           [2, 4, 5, 3]]))
+    self.assertAllEqual(outputs, np.array([[2, 6, 3, 0], [2, 4, 5, 3]]))
   # TODO(b/175369555): Remove that code and its test.
   @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
   def test_check_no_assert(self, use_sp_model):
     """Tests the self-check during export without assertions."""
-    preprocess_export_path = self._do_export(
-        ["d", "ef", "abc", "xy"], do_lower_case=True,
-        use_sp_model=use_sp_model, tokenize_with_offsets=False,
-        experimental_disable_assert=False)
+    preprocess_export_path = self._do_export(["d", "ef", "abc", "xy"],
+                                             do_lower_case=True,
+                                             use_sp_model=use_sp_model,
+                                             tokenize_with_offsets=False,
+                                             experimental_disable_assert=False)
     with self.assertRaisesRegex(AssertionError,
                                 r"failed to suppress \d+ Assert ops"):
       export_tfhub_lib._check_no_assert(preprocess_export_path)
@@ -963,8 +1051,8 @@ def _result_shapes_in_tf_function(fn, *args, **kwargs):

   Args:
     fn: A callable.
-    *args: TensorSpecs for Tensor-valued arguments and actual values
-      for Python-valued arguments to fn.
+    *args: TensorSpecs for Tensor-valued arguments and actual values for
+      Python-valued arguments to fn.
     **kwargs: Same for keyword arguments.

   Returns:
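
Note: `_result_shapes_in_tf_function` is only partially visible in this diff. A rough sketch of the idea (not the exact implementation): trace `fn` under tf.function with TensorSpec placeholders and record the static shapes of the results.

import tensorflow as tf

def result_shapes_in_tf_function(fn, *specs):
  """Returns static result shapes from tracing fn; nothing is executed."""
  shapes = []

  @tf.function
  def traced(*args):
    result = fn(*args)
    shapes.extend(t.shape.as_list() for t in tf.nest.flatten(result))
    return result

  traced.get_concrete_function(*specs)  # trace only
  return shapes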
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""
import numpy as np
import tensorflow.compat.v1 as tf # TF 1.x
# Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern.
BERT_NAME_REPLACEMENTS = (
("bert", "bert_model"),
("embeddings/word_embeddings", "word_embeddings/embeddings"),
("embeddings/token_type_embeddings",
"embedding_postprocessor/type_embeddings"),
("embeddings/position_embeddings",
"embedding_postprocessor/position_embeddings"),
("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
("attention/self", "self_attention"),
("attention/output/dense", "self_attention_output"),
("attention/output/LayerNorm", "self_attention_layer_norm"),
("intermediate/dense", "intermediate"),
("output/dense", "output"),
("output/LayerNorm", "output_layer_norm"),
("pooler/dense", "pooler_transform"),
)
BERT_V2_NAME_REPLACEMENTS = (
("bert/", ""),
("encoder", "transformer"),
("embeddings/word_embeddings", "word_embeddings/embeddings"),
("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
("embeddings/position_embeddings", "position_embedding/embeddings"),
("embeddings/LayerNorm", "embeddings/layer_norm"),
("attention/self", "self_attention"),
("attention/output/dense", "self_attention/attention_output"),
("attention/output/LayerNorm", "self_attention_layer_norm"),
("intermediate/dense", "intermediate"),
("output/dense", "output"),
("output/LayerNorm", "output_layer_norm"),
("pooler/dense", "pooler_transform"),
("cls/predictions", "bert/cls/predictions"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
("cls/seq_relationship/output_weights",
"predictions/transform/logits/kernel"),
)
BERT_PERMUTATIONS = ()
BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)
def _bert_name_replacement(var_name, name_replacements):
"""Gets the variable name replacement."""
for src_pattern, tgt_pattern in name_replacements:
if src_pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(src_pattern, tgt_pattern)
tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
return var_name
def _has_exclude_patterns(name, exclude_patterns):
"""Checks if a string contains substrings that match patterns to exclude."""
for p in exclude_patterns:
if p in name:
return True
return False
def _get_permutation(name, permutations):
"""Checks whether a variable requires transposition by pattern matching."""
for src_pattern, permutation in permutations:
if src_pattern in name:
tf.logging.info("Permuted: %s --> %s", name, permutation)
return permutation
return None
def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching."""
if "self_attention/attention_output/kernel" in name:
return tuple([num_heads, shape[0] // num_heads, shape[1]])
if "self_attention/attention_output/bias" in name:
return shape
patterns = [
"self_attention/query", "self_attention/value", "self_attention/key"
]
for pattern in patterns:
if pattern in name:
if "kernel" in name:
return tuple([shape[0], num_heads, shape[1] // num_heads])
if "bias" in name:
return tuple([num_heads, shape[0] // num_heads])
return None
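
A quick illustration of _get_new_shape, computed directly from the branches above (sizes hypothetical: hidden size 768 and 12 heads, so 64 per head):

print(_get_new_shape("layer_0/self_attention/query/kernel", (768, 768), 12))
# -> (768, 12, 64)
print(_get_new_shape("layer_0/self_attention/query/bias", (768,), 12))
# -> (12, 64)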
def create_v2_checkpoint(model,
src_checkpoint,
output_path,
checkpoint_model_name="model"):
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
# Uses streaming-restore in eager model to read V1 name-based checkpoints.
model.load_weights(src_checkpoint).assert_existing_objects_matched()
if hasattr(model, "checkpoint_items"):
checkpoint_items = model.checkpoint_items
else:
checkpoint_items = {}
checkpoint_items[checkpoint_model_name] = model
checkpoint = tf.train.Checkpoint(**checkpoint_items)
checkpoint.save(output_path)
def convert(checkpoint_from_path,
checkpoint_to_path,
num_heads,
name_replacements,
permutations,
exclude_patterns=None):
"""Migrates the names of variables within a checkpoint.
Args:
checkpoint_from_path: Path to source checkpoint to be read in.
checkpoint_to_path: Path to checkpoint to be written out.
num_heads: The number of heads of the model.
name_replacements: A list of tuples of the form (match_str, replace_str)
describing variable names to adjust.
permutations: A list of tuples of the form (match_str, permutation)
describing permutations to apply to given variables. Note that match_str
should match the original variable name, not the replaced one.
exclude_patterns: A list of string patterns to exclude variables from
checkpoint conversion.
Returns:
A dictionary that maps the new variable names to the Variable objects.
A dictionary that maps the old variable names to the new variable names.
"""
with tf.Graph().as_default():
tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
reader = tf.train.NewCheckpointReader(checkpoint_from_path)
name_shape_map = reader.get_variable_to_shape_map()
new_variable_map = {}
conversion_map = {}
for var_name in name_shape_map:
if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
continue
# Get the original tensor data.
tensor = reader.get_tensor(var_name)
# Look up the new variable name, if any.
new_var_name = _bert_name_replacement(var_name, name_replacements)
# See if we need to reshape the underlying tensor.
new_shape = None
if num_heads > 0:
new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
if new_shape:
tf.logging.info("Veriable %s has a shape change from %s to %s",
var_name, tensor.shape, new_shape)
tensor = np.reshape(tensor, new_shape)
# See if we need to permute the underlying tensor.
permutation = _get_permutation(var_name, permutations)
if permutation:
tensor = np.transpose(tensor, permutation)
# Create a new variable with the possibly-reshaped or transposed tensor.
var = tf.Variable(tensor, name=var_name)
# Save the variable into the new variable map.
new_variable_map[new_var_name] = var
# Keep a list of converter variables for sanity checking.
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
saver = tf.train.Saver(new_variable_map)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
saver.save(sess, checkpoint_to_path, write_meta_graph=False)
tf.logging.info("Summary:")
tf.logging.info(" Converted %d variable name(s).", len(new_variable_map))
tf.logging.info(" Converted: %s", str(conversion_map))
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -23,11 +23,11 @@ from absl import app
from absl import flags
import tensorflow as tf
from official.legacy.albert import configs
from official.modeling import tf_utils
from official.nlp.modeling import models
from official.nlp.modeling import networks
from official.nlp.tools import tf1_bert_checkpoint_converter_lib
FLAGS = flags.FLAGS
...@@ -128,12 +128,12 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
# Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
tf1_bert_checkpoint_converter_lib.convert(
checkpoint_from_path=v1_checkpoint,
checkpoint_to_path=temporary_checkpoint,
num_heads=bert_config.num_attention_heads,
name_replacements=ALBERT_NAME_REPLACEMENTS,
permutations=tf1_bert_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
exclude_patterns=["adam", "Adam"])
# Create a V2 checkpoint from the temporary checkpoint.
...@@ -144,9 +144,8 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
else:
raise ValueError("Unsupported converted_model: %s" % converted_model)
tf1_bert_checkpoint_converter_lib.create_v2_checkpoint(
model, temporary_checkpoint, output_path, checkpoint_model_name)
# Clean up the temporary checkpoint, if it exists.
try:
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.
The conversion will yield an object-oriented checkpoint that can be used
to restore a BertEncoder or BertPretrainerV2 object (see the `converted_model`
FLAG below).
"""
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.legacy.bert import configs
from official.modeling import tf_utils
from official.nlp.modeling import models
from official.nlp.modeling import networks
from official.nlp.tools import tf1_bert_checkpoint_converter_lib
FLAGS = flags.FLAGS
flags.DEFINE_string("bert_config_file", None,
"Bert configuration file to define core bert layers.")
flags.DEFINE_string(
"checkpoint_to_convert", None,
"Initial checkpoint from a pretrained BERT model core (that is, only the "
"BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None,
"Name for the created object-based V2 checkpoint.")
flags.DEFINE_string("checkpoint_model_name", "encoder",
"The name of the model when saving the checkpoint, i.e., "
"the checkpoint will be saved using: "
"tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
flags.DEFINE_enum(
"converted_model", "encoder", ["encoder", "pretrainer"],
"Whether to convert the checkpoint to a `BertEncoder` model or a "
"`BertPretrainerV2` model (with mlm but without classification heads).")
def _create_bert_model(cfg):
"""Creates a BERT keras core model from BERT configuration.
Args:
cfg: A `BertConfig` to create the core model.
Returns:
A BertEncoder network.
"""
bert_encoder = networks.BertEncoder(
vocab_size=cfg.vocab_size,
hidden_size=cfg.hidden_size,
num_layers=cfg.num_hidden_layers,
num_attention_heads=cfg.num_attention_heads,
intermediate_size=cfg.intermediate_size,
activation=tf_utils.get_activation(cfg.hidden_act),
dropout_rate=cfg.hidden_dropout_prob,
attention_dropout_rate=cfg.attention_probs_dropout_prob,
max_sequence_length=cfg.max_position_embeddings,
type_vocab_size=cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range),
embedding_width=cfg.embedding_size)
return bert_encoder
def _create_bert_pretrainer_model(cfg):
"""Creates a BERT keras core model from BERT configuration.
Args:
cfg: A `BertConfig` to create the core model.
Returns:
A BertPretrainerV2 model.
"""
bert_encoder = _create_bert_model(cfg)
pretrainer = models.BertPretrainerV2(
encoder_network=bert_encoder,
mlm_activation=tf_utils.get_activation(cfg.hidden_act),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range))
# Makes sure the pretrainer variables are created.
_ = pretrainer(pretrainer.inputs)
return pretrainer
def convert_checkpoint(bert_config,
output_path,
v1_checkpoint,
checkpoint_model_name="model",
converted_model="encoder"):
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
output_dir, _ = os.path.split(output_path)
tf.io.gfile.makedirs(output_dir)
# Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
tf1_bert_checkpoint_converter_lib.convert(
checkpoint_from_path=v1_checkpoint,
checkpoint_to_path=temporary_checkpoint,
num_heads=bert_config.num_attention_heads,
name_replacements=(
tf1_bert_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS),
permutations=tf1_bert_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
exclude_patterns=["adam", "Adam"])
if converted_model == "encoder":
model = _create_bert_model(bert_config)
elif converted_model == "pretrainer":
model = _create_bert_pretrainer_model(bert_config)
else:
raise ValueError("Unsupported converted_model: %s" % converted_model)
# Create a V2 checkpoint from the temporary checkpoint.
tf1_bert_checkpoint_converter_lib.create_v2_checkpoint(
model, temporary_checkpoint, output_path, checkpoint_model_name)
# Clean up the temporary checkpoint, if it exists.
try:
tf.io.gfile.rmtree(temporary_checkpoint_dir)
except tf.errors.OpError:
# If it doesn't exist, we don't need to clean it up; continue.
pass
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert
checkpoint_model_name = FLAGS.checkpoint_model_name
converted_model = FLAGS.converted_model
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
convert_checkpoint(
bert_config=bert_config,
output_path=output_path,
v1_checkpoint=v1_checkpoint,
checkpoint_model_name=checkpoint_model_name,
converted_model=converted_model)
if __name__ == "__main__":
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -18,7 +18,7 @@ import tempfile
import six
import tensorflow as tf
from official.nlp.tools import tokenization
class TokenizationTest(tf.test.TestCase):
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Transformer Translation Model
This is an implementation of the Transformer translation model as described in
the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper. The
implementation leverages tf.keras and makes sure it is compatible with TF 2.x.
**Warning: the features in the `transformer/` folder have been fully integrated
into nlp/modeling.
Due to its dependencies, we will remove this folder after the model
garden 2.5 release. The model in `nlp/modeling/models/seq2seq_transformer.py` is
identical to the model in this folder.**
## Contents
* [Contents](#contents)
* [Walkthrough](#walkthrough)
* [Detailed instructions](#detailed-instructions)
* [Environment preparation](#environment-preparation)
* [Download and preprocess datasets](#download-and-preprocess-datasets)
* [Model training and evaluation](#model-training-and-evaluation)
* [Implementation overview](#implementation-overview)
* [Model Definition](#model-definition)
* [Model Trainer](#model-trainer)
* [Test dataset](#test-dataset)
## Walkthrough
Below are the commands for running the Transformer model. See the
[Detailed instructions](#detailed-instructions) for more details on running the
model.
```
# Ensure that PYTHONPATH is correctly defined as described in
# https://github.com/tensorflow/models/tree/master/official#requirements
export PYTHONPATH="$PYTHONPATH:/path/to/models"
cd /path/to/models/official/nlp/transformer
# Export variables
PARAM_SET=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAM_SET
VOCAB_FILE=$DATA_DIR/vocab.ende.32768
# Download training/evaluation/test datasets
python3 data_download.py --data_dir=$DATA_DIR
# Train the model for 100000 steps and evaluate every 5000 steps on a single GPU.
# Each train step takes 4096 tokens as a batch budget with 64 as the maximal
# sequence length.
python3 transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--vocab_file=$VOCAB_FILE --param_set=$PARAM_SET \
--train_steps=100000 --steps_between_evals=5000 \
--batch_size=4096 --max_length=64 \
--bleu_source=$DATA_DIR/newstest2014.en \
--bleu_ref=$DATA_DIR/newstest2014.de \
--num_gpus=1 \
--enable_time_history=false
# Run during training in a separate process to get continuous updates,
# or after training is complete.
tensorboard --logdir=$MODEL_DIR
```
## Detailed instructions
0. ### Environment preparation
#### Add models repo to PYTHONPATH
Follow the instructions described in the [Requirements](https://github.com/tensorflow/models/tree/master/official#requirements) section to add the models folder to the python path.
#### Export variables (optional)
Export the following variables, or modify the values in each of the snippets below:
```shell
PARAM_SET=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAM_SET
VOCAB_FILE=$DATA_DIR/vocab.ende.32768
```
1. ### Download and preprocess datasets
[data_download.py](data_download.py) downloads and preprocesses the training and evaluation WMT datasets. After the data is downloaded and extracted, the training data is used to generate a vocabulary of subtokens. The evaluation and training strings are tokenized, and the resulting data is sharded, shuffled, and saved as TFRecords.
1.75GB of compressed data will be downloaded. In total, the raw files (compressed, extracted, and combined files) take up 8.4GB of disk space. The resulting TFRecord and vocabulary files are 722MB. The script takes around 40 minutes to run, with the bulk of the time spent downloading and ~15 minutes spent on preprocessing.
Command to run:
```
python3 data_download.py --data_dir=$DATA_DIR
```
Arguments:
* `--data_dir`: Path where the preprocessed TFRecord data and vocab file will be saved.
* Use the `--help` or `-h` flag to get a full list of possible arguments.
2. ### Model training and evaluation
[transformer_main.py](transformer_main.py) creates a Transformer keras model,
and trains it using keras model.fit().
Users need to adjust `batch_size` and `num_gpus` to get good performance when
running on multiple GPUs.
**Note that:**
when using multiple GPUs or TPUs, `batch_size` is the global batch size for
all devices. For example, if the batch size is `4096*4` and there are 4
devices, each device will take 4096 tokens as a batch budget.
Command to run:
```
python3 transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--vocab_file=$VOCAB_FILE --param_set=$PARAM_SET
```
Arguments:
* `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument.
* `--model_dir`: Directory to save Transformer model training checkpoints.
* `--vocab_file`: Path to subtoken vocabulary file. If data_download was used, you may find the file in `data_dir`.
* `--param_set`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
* `--enable_time_history`: Whether to add the TimeHistory callback. If so, `--log_steps` must be specified.
* `--batch_size`: The number of tokens to consider in a batch. Combined with
`--max_length`, these determine how many sequences are used per batch.
* Use the `--help` or `-h` flag to get a full list of possible arguments.
#### Using multiple GPUs
You can train these models on multiple GPUs using `tf.distribute.Strategy` API.
You can read more about them in this
[guide](https://www.tensorflow.org/guide/distribute_strategy).
In this example, we have made it easier to use with just the command-line flag
`--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA,
and 0 otherwise.
- `--num_gpus=0`: Uses tf.distribute.OneDeviceStrategy with CPU as the device.
- `--num_gpus=1`: Uses tf.distribute.OneDeviceStrategy with GPU as the device.
- `--num_gpus=2+`: Uses tf.distribute.MirroredStrategy to run synchronous
distributed training across the GPUs (see the sketch below).
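Conceptually, the flag maps to a strategy roughly as in the sketch below. This
is a simplified illustration, not the model garden's actual flag-handling code:

```python
import tensorflow as tf

def get_distribution_strategy(num_gpus):
  """Toy mapping from --num_gpus to a tf.distribute.Strategy."""
  if num_gpus == 0:
    return tf.distribute.OneDeviceStrategy("device:CPU:0")
  if num_gpus == 1:
    return tf.distribute.OneDeviceStrategy("device:GPU:0")
  # Synchronous data-parallel training across all requested GPUs.
  return tf.distribute.MirroredStrategy(
      devices=["device:GPU:%d" % i for i in range(num_gpus)])
```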
#### Using Cloud TPUs
You can train the Transformer model on Cloud TPUs using
`tf.distribute.TPUStrategy`. If you are not familiar with Cloud TPUs, it is
strongly recommended that you go through the
[quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
create a TPU and GCE VM.
To run the Transformer model on a TPU, you must set
`--distribution_strategy=tpu`, `--tpu=$TPU_NAME`, and `--use_ctl=True`, where
`$TPU_NAME` is the name of your TPU in the Cloud Console.
An example command to run Transformer on a v2-8 or v3-8 TPU would be:
```bash
python transformer_main.py \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--vocab_file=$DATA_DIR/vocab.ende.32768 \
--bleu_source=$DATA_DIR/newstest2014.en \
--bleu_ref=$DATA_DIR/newstest2014.de \
--batch_size=6144 \
--train_steps=2000 \
--static_batch=true \
--use_ctl=true \
--param_set=big \
--max_length=64 \
--decode_batch_size=32 \
--decode_max_length=97 \
--padded_decode=true \
--distribution_strategy=tpu
```
Note: `$MODEL_DIR` and `$DATA_DIR` must be GCS paths.
#### Customizing training schedule
By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags:
* Training with steps:
* `--train_steps`: sets the total number of training steps to run.
* `--steps_between_evals`: Number of training steps to run between evaluations.
#### Compute BLEU score during model evaluation
Use these flags to compute the BLEU score when the model evaluates:
* `--bleu_source`: Path to file containing text to translate.
* `--bleu_ref`: Path to file containing the reference translation.
When running `transformer_main.py`, use the flags: `--bleu_source=$DATA_DIR/newstest2014.en --bleu_ref=$DATA_DIR/newstest2014.de`
#### Tensorboard
Training and evaluation metrics (loss, accuracy, approximate BLEU score, etc.) are logged, and can be displayed in the browser using Tensorboard.
```
tensorboard --logdir=$MODEL_DIR
```
The values are displayed at [localhost:6006](localhost:6006).
## Implementation overview
A brief look at each component in the code:
### Model Definition
* [transformer.py](transformer.py): Defines a tf.keras.Model: `Transformer`.
* [embedding_layer.py](embedding_layer.py): Contains the layer that calculates the embeddings. The embedding weights are also used to calculate the pre-softmax probabilities from the decoder output.
* [attention_layer.py](attention_layer.py): Defines the multi-headed and self attention layers that are used in the encoder/decoder stacks.
* [ffn_layer.py](ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.
Other files:
* [beam_search.py](beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.
### Model Trainer
[transformer_main.py](transformer_main.py) creates a `TransformerTask` to train and evaluate the model using tf.keras.
### Test dataset
The [newstest2014 files](https://storage.googleapis.com/tf-perf-public/official_transformer/test_data/newstest2014.tgz)
are extracted from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data).
The raw text files are converted from the SGM format of the
[WMT 2016](http://www.statmt.org/wmt16/translation-task.html) test sets. The
newstest2014 files are put into the `$DATA_DIR` when executing `data_download.py`.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of multiheaded attention and self-attention layers."""
import math
import tensorflow as tf
class Attention(tf.keras.layers.Layer):
"""Multi-headed attention layer."""
def __init__(self, hidden_size, num_heads, attention_dropout):
"""Initialize Attention.
Args:
hidden_size: int, output dim of hidden layer.
num_heads: int, number of heads to repeat the same attention structure.
attention_dropout: float, dropout rate inside attention for training.
"""
if hidden_size % num_heads:
raise ValueError(
"Hidden size ({}) must be divisible by the number of heads ({})."
.format(hidden_size, num_heads))
super(Attention, self).__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.attention_dropout = attention_dropout
def build(self, input_shape):
"""Builds the layer."""
# Layers for linearly projecting the queries, keys, and values.
size_per_head = self.hidden_size // self.num_heads
def _glorot_initializer(fan_in, fan_out):
limit = math.sqrt(6.0 / (fan_in + fan_out))
return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
self.hidden_size)
self.query_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=attention_initializer,
bias_axes=None,
name="query")
self.key_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=attention_initializer,
bias_axes=None,
name="key")
self.value_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=attention_initializer,
bias_axes=None,
name="value")
output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
self.output_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTNH,NHE->BTE",
output_shape=(None, self.hidden_size),
kernel_initializer=output_initializer,
bias_axes=None,
name="output_transform")
super(Attention, self).build(input_shape)
def get_config(self):
return {
"hidden_size": self.hidden_size,
"num_heads": self.num_heads,
"attention_dropout": self.attention_dropout,
}
def call(self,
query_input,
source_input,
bias,
training,
cache=None,
decode_loop_step=None):
"""Apply attention mechanism to query_input and source_input.
Args:
query_input: A tensor with shape [batch_size, length_query, hidden_size].
source_input: A tensor with shape [batch_size, length_source,
hidden_size].
bias: A tensor with shape [batch_size, 1, length_query, length_source],
the attention bias that will be added to the result of the dot product.
training: A bool, whether in training mode or not.
cache: (Used during prediction) A dictionary with tensors containing
results of previous attentions. The dictionary must have the items:
{"k": tensor with shape [batch_size, i, heads, dim_per_head],
"v": tensor with shape [batch_size, i, heads, dim_per_head]} where
i is the current decoded length for non-padded decode, or max
sequence length for padded decode.
decode_loop_step: An integer, step number of the decoding loop. Used only
for autoregressive inference on TPU.
Returns:
Attention layer output with shape [batch_size, length_query, hidden_size]
"""
# Linearly project the query, key and value using different learned
# projections. Splitting heads is automatically done during the linear
# projections --> [batch_size, length, num_heads, dim_per_head].
query = self.query_dense_layer(query_input)
key = self.key_dense_layer(source_input)
value = self.value_dense_layer(source_input)
if cache is not None:
# Combine cached keys and values with new keys and values.
if decode_loop_step is not None:
cache_k_shape = cache["k"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
[1, cache_k_shape[1], 1, 1])
key = cache["k"] + key * indices
cache_v_shape = cache["v"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
[1, cache_v_shape[1], 1, 1])
value = cache["v"] + value * indices
else:
key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)
# Update cache
cache["k"] = key
cache["v"] = value
# Scale query to prevent the dot product between query and key from growing
# too large.
depth = (self.hidden_size // self.num_heads)
query *= depth**-0.5
# Calculate dot product attention
logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
logits += bias
# Note that softmax internally performs math operations using float32
# for numeric stability. When training with float16, we keep the input
# and output in float16 for better performance.
weights = tf.nn.softmax(logits, name="attention_weights")
if training:
weights = tf.nn.dropout(weights, rate=self.attention_dropout)
attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)
# Run the outputs through another linear projection layer. Recombining heads
# is automatically done --> [batch_size, length, hidden_size]
attention_output = self.output_dense_layer(attention_output)
return attention_output
class SelfAttention(Attention):
"""Multiheaded self-attention layer."""
def call(self,
query_input,
bias,
training,
cache=None,
decode_loop_step=None):
return super(SelfAttention, self).call(query_input, query_input, bias,
training, cache, decode_loop_step)
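# Example usage (hypothetical shapes): self-attention over a batch of 2
# sequences of length 5, hidden size 64, 4 heads; a zero bias applies no
# masking.
#
#   layer = SelfAttention(hidden_size=64, num_heads=4, attention_dropout=0.1)
#   x = tf.ones([2, 5, 64])
#   bias = tf.zeros([2, 1, 5, 5])
#   y = layer(x, bias, training=False)  # -> shape [2, 5, 64]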
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Beam search to find the translated sequence with the highest probability."""
import tensorflow.compat.v1 as tf
from official.nlp.modeling.ops import beam_search
_StateKeys = beam_search._StateKeys # pylint: disable=protected-access
class SequenceBeamSearch(beam_search.SequenceBeamSearch):
"""Implementation of beam search loop."""
def _process_finished_state(self, finished_state):
alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]
# Account for corner case where there are no finished sequences for a
# particular batch item. In that case, return alive sequences for that batch
# item.
finished_seq = tf.where(
tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
finished_scores = tf.where(
tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
return finished_seq, finished_scores
def sequence_beam_search(symbols_to_logits_fn,
initial_ids,
initial_cache,
vocab_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode=False):
"""Search for sequence of subtoken ids with the largest probability.
Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape: ids -> A tensor with
shape [batch_size * beam_size, index]. index -> A scalar. cache -> A
nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and new cache: logits -> A
tensor with shape [batch * beam_size, vocab_size]. new cache -> A nested
dictionary with the same shape/structure as the inputted cache.
initial_ids: An int32 tensor with shape [batch_size]. Starting ids for each
batch item.
initial_cache: A dictionary, containing starting decoder variables
information.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
beam_size: An integer, the number of beams.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum length to decode a sequence.
eos_id: An integer, ID of eos token, used to determine when a sequence has
finished.
padded_decode: A bool, indicating if max_sequence_length padding is used for
beam search.
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
"""
sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, beam_size, alpha,
max_decode_length, eos_id, padded_decode)
return sbs.search(initial_ids, initial_cache)
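# Example (hypothetical toy setup): decode 2 sequences with beam_size=3 over a
# 10-token vocabulary, using a dummy symbols_to_logits_fn that scores all
# tokens uniformly and keeps the cache unchanged:
#
#   def symbols_to_logits_fn(ids, i, cache):
#     return tf.zeros([tf.shape(ids)[0], 10]), cache
#
#   seqs, scores = sequence_beam_search(
#       symbols_to_logits_fn, initial_ids=tf.zeros([2], tf.int32),
#       initial_cache={}, vocab_size=10, beam_size=3, alpha=0.6,
#       max_decode_length=5, eos_id=1)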
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to compute official BLEU score.
Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
"""
import re
import sys
import unicodedata
from absl import app
from absl import flags
from absl import logging
import six
from six.moves import range
import tensorflow as tf
from official.nlp.transformer.utils import metrics
from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
class UnicodeRegex(object):
"""Ad-hoc hack to recognize all punctuation and symbols."""
def __init__(self):
punctuation = self.property_chars("P")
self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
self.symbol_re = re.compile("([" + self.property_chars("S") + "])")
def property_chars(self, prefix):
return "".join(
six.unichr(x)
for x in range(sys.maxunicode)
if unicodedata.category(six.unichr(x)).startswith(prefix))
uregex = UnicodeRegex()
def bleu_tokenize(string):
r"""Tokenize a string following the official BLEU implementation.
See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
In our case, the input string is expected to be just one line
and no HTML entities de-escaping is needed.
So we just tokenize on punctuation and symbols,
except when a punctuation is preceded and followed by a digit
(e.g. a comma/dot as a thousand/decimal separator).
Note that a number (e.g. a year) followed by a dot at the end of a sentence
is NOT tokenized,
i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
does not match this case (unless we add a space after each sentence).
However, this error is already in the original mteval-v14.pl
and we want to be consistent with it.
Args:
string: the input string
Returns:
a list of tokens
"""
string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
string = uregex.symbol_re.sub(r" \1 ", string)
return string.split()
def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
"""Compute BLEU for two files (reference and hypothesis translation)."""
ref_lines = tokenizer.native_to_unicode(
tf.io.gfile.GFile(ref_filename).read()).strip().splitlines()
hyp_lines = tokenizer.native_to_unicode(
tf.io.gfile.GFile(hyp_filename).read()).strip().splitlines()
return bleu_on_list(ref_lines, hyp_lines, case_sensitive)
def bleu_on_list(ref_lines, hyp_lines, case_sensitive=False):
"""Compute BLEU for two list of strings (reference and hypothesis)."""
if len(ref_lines) != len(hyp_lines):
raise ValueError(
"Reference and translation files have different number of "
"lines (%d VS %d). If training only a few steps (100-200), the "
"translation may be empty." % (len(ref_lines), len(hyp_lines)))
if not case_sensitive:
ref_lines = [x.lower() for x in ref_lines]
hyp_lines = [x.lower() for x in hyp_lines]
ref_tokens = [bleu_tokenize(x) for x in ref_lines]
hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
return metrics.compute_bleu(ref_tokens, hyp_tokens) * 100
def main(unused_argv):
if FLAGS.bleu_variant in ("both", "uncased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
logging.info("Case-insensitive results: %f", score)
if FLAGS.bleu_variant in ("both", "cased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
logging.info("Case-sensitive results: %f", score)
def define_compute_bleu_flags():
"""Add flags for computing BLEU score."""
flags.DEFINE_string(
name="translation",
default=None,
help=flags_core.help_wrap("File containing translated text."))
flags.mark_flag_as_required("translation")
flags.DEFINE_string(
name="reference",
default=None,
help=flags_core.help_wrap("File containing reference translation."))
flags.mark_flag_as_required("reference")
flags.DEFINE_enum(
name="bleu_variant",
short_name="bv",
default="both",
enum_values=["both", "uncased", "cased"],
case_sensitive=False,
help=flags_core.help_wrap(
"Specify one or more BLEU variants to calculate. Variants: \"cased\""
", \"uncased\", or \"both\"."))
if __name__ == "__main__":
define_compute_bleu_flags()
FLAGS = flags.FLAGS
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test functions in compute_blue.py."""
import tempfile
import tensorflow as tf
from official.nlp.transformer import compute_bleu
class ComputeBleuTest(tf.test.TestCase):
def _create_temp_file(self, text):
temp_file = tempfile.NamedTemporaryFile(delete=False)
with tf.io.gfile.GFile(temp_file.name, "w") as w:
w.write(text)
return temp_file.name
def test_bleu_same(self):
ref = self._create_temp_file("test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nmore tests!")
uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
self.assertEqual(100, uncased_score)
self.assertEqual(100, cased_score)
def test_bleu_same_different_case(self):
ref = self._create_temp_file("Test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nMore tests!")
uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
self.assertEqual(100, uncased_score)
self.assertLess(cased_score, 100)
def test_bleu_different(self):
ref = self._create_temp_file("Testing\nmore tests!")
hyp = self._create_temp_file("Dog\nCat")
uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
self.assertLess(uncased_score, 100)
self.assertLess(cased_score, 100)
def test_bleu_tokenize(self):
s = "Test0, 1 two, 3"
tokenized = compute_bleu.bleu_tokenize(s)
self.assertEqual(["Test0", ",", "1", "two", ",", "3"], tokenized)
def test_bleu_list(self):
ref = ["test 1 two 3", "more tests!"]
hyp = ["test 1 two 3", "More tests!"]
uncased_score = compute_bleu.bleu_on_list(ref, hyp, False)
cased_score = compute_bleu.bleu_on_list(ref, hyp, True)
self.assertEqual(uncased_score, 100)
self.assertLess(cased_score, 100)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Download and preprocess WMT17 ende training and evaluation datasets."""
import os
import random
import tarfile
# pylint: disable=g-bad-import-order
from absl import app
from absl import flags
from absl import logging
import six
from six.moves import range
from six.moves import urllib
from six.moves import zip
import tensorflow.compat.v1 as tf
from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
# pylint: enable=g-bad-import-order
# Data sources for training/evaluating the transformer translation model.
# If any of the training sources are changed, then either:
# 1) use the flag `--search` to find the best min count or
# 2) update the _TRAIN_DATA_MIN_COUNT constant.
# min_count is the minimum number of times a token must appear in the data
# before it is added to the vocabulary. "Best min count" refers to the value
# that generates a vocabulary set that is closest in size to _TARGET_VOCAB_SIZE.
_TRAIN_DATA_SOURCES = [
{
"url": "http://data.statmt.org/wmt17/translation-task/"
"training-parallel-nc-v12.tgz",
"input": "news-commentary-v12.de-en.en",
"target": "news-commentary-v12.de-en.de",
},
{
"url": "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
"input": "commoncrawl.de-en.en",
"target": "commoncrawl.de-en.de",
},
{
"url": "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
"input": "europarl-v7.de-en.en",
"target": "europarl-v7.de-en.de",
},
]
# Use pre-defined minimum count to generate subtoken vocabulary.
_TRAIN_DATA_MIN_COUNT = 6
_EVAL_DATA_SOURCES = [{
"url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
"input": "newstest2013.en",
"target": "newstest2013.de",
}]
_TEST_DATA_SOURCES = [{
"url": ("https://storage.googleapis.com/cloud-tpu-test-datasets/"
"transformer_data/newstest2014.tgz"),
"input": "newstest2014.en",
"target": "newstest2014.de",
}]
# Vocabulary constants
_TARGET_VOCAB_SIZE = 32768 # Number of subtokens in the vocabulary list.
_TARGET_THRESHOLD = 327 # Accept vocabulary if size is within this threshold
VOCAB_FILE = "vocab.ende.%d" % _TARGET_VOCAB_SIZE
# Strings to include in the generated files.
_PREFIX = "wmt32k"
_TRAIN_TAG = "train"
_EVAL_TAG = "dev" # Following WMT and Tensor2Tensor conventions, in which the
# evaluation datasets are tagged as "dev" for development.
# Number of files to split train and evaluation data
_TRAIN_SHARDS = 100
_EVAL_SHARDS = 1
def find_file(path, filename, max_depth=5):
"""Returns full filepath if the file is in path or a subdirectory."""
for root, dirs, files in os.walk(path):
if filename in files:
return os.path.join(root, filename)
# Don't search past max_depth
depth = root[len(path) + 1:].count(os.sep)
if depth > max_depth:
del dirs[:] # Clear dirs
return None
###############################################################################
# Download and extraction functions
###############################################################################
def get_raw_files(raw_dir, data_source):
"""Return raw files from source.
Downloads/extracts if needed.
Args:
raw_dir: string directory to store raw files
data_source: dictionary with
{"url": url of compressed dataset containing input and target files
"input": file with data in input language
"target": file with data in target language}
Returns:
dictionary with
{"inputs": list of files containing data in input language
"targets": list of files containing corresponding data in target language
}
"""
raw_files = {
"inputs": [],
"targets": [],
} # keys
for d in data_source:
input_file, target_file = download_and_extract(raw_dir, d["url"],
d["input"], d["target"])
raw_files["inputs"].append(input_file)
raw_files["targets"].append(target_file)
return raw_files
def download_report_hook(count, block_size, total_size):
"""Report hook for download progress.
Args:
count: current block number
block_size: block size
total_size: total size
"""
percent = int(count * block_size * 100 / total_size)
print(six.ensure_str("\r%d%%" % percent) + " completed", end="\r")
def download_from_url(path, url):
"""Download content from a url.
Args:
path: string directory where file will be downloaded
url: string url
Returns:
Full path to downloaded file
"""
filename = six.ensure_str(url).split("/")[-1]
found_file = find_file(path, filename, max_depth=0)
if found_file is None:
filename = os.path.join(path, filename)
logging.info("Downloading from %s to %s.", url, filename)
inprogress_filepath = six.ensure_str(filename) + ".incomplete"
inprogress_filepath, _ = urllib.request.urlretrieve(
url, inprogress_filepath, reporthook=download_report_hook)
# Print newline to clear the carriage return from the download progress.
print()
tf.gfile.Rename(inprogress_filepath, filename)
return filename
else:
logging.info("Already downloaded: %s (at %s).", url, found_file)
return found_file
def download_and_extract(path, url, input_filename, target_filename):
"""Extract files from downloaded compressed archive file.
Args:
path: string directory where the files will be downloaded
url: url containing the compressed input and target files
input_filename: name of file containing data in source language
target_filename: name of file containing data in target language
Returns:
Full paths to extracted input and target files.
Raises:
OSError: if the download/extraction fails.
"""
# Check if extracted files already exist in path
input_file = find_file(path, input_filename)
target_file = find_file(path, target_filename)
if input_file and target_file:
logging.info("Already downloaded and extracted %s.", url)
return input_file, target_file
# Download archive file if it doesn't already exist.
compressed_file = download_from_url(path, url)
# Extract compressed files
logging.info("Extracting %s.", compressed_file)
with tarfile.open(compressed_file, "r:gz") as corpus_tar:
corpus_tar.extractall(path)
# Return file paths of the requested files.
input_file = find_file(path, input_filename)
target_file = find_file(path, target_filename)
if input_file and target_file:
return input_file, target_file
raise OSError("Download/extraction failed for url %s to path %s" %
(url, path))
def txt_line_iterator(path):
"""Iterate through lines of file."""
with tf.io.gfile.GFile(path) as f:
for line in f:
yield line.strip()
def compile_files(raw_dir, raw_files, tag):
"""Compile raw files into a single file for each language.
Args:
raw_dir: Directory containing downloaded raw files.
raw_files: Dict containing filenames of input and target data.
{"inputs": list of files containing data in input language
"targets": list of files containing corresponding data in target language
}
tag: String to append to the compiled filename.
Returns:
Full path of compiled input and target files.
"""
logging.info("Compiling files with tag %s.", tag)
filename = "%s-%s" % (_PREFIX, tag)
input_compiled_file = os.path.join(raw_dir,
six.ensure_str(filename) + ".lang1")
target_compiled_file = os.path.join(raw_dir,
six.ensure_str(filename) + ".lang2")
with tf.io.gfile.GFile(input_compiled_file, mode="w") as input_writer:
with tf.io.gfile.GFile(target_compiled_file, mode="w") as target_writer:
for i in range(len(raw_files["inputs"])):
input_file = raw_files["inputs"][i]
target_file = raw_files["targets"][i]
logging.info("Reading files %s and %s.", input_file, target_file)
write_file(input_writer, input_file)
write_file(target_writer, target_file)
return input_compiled_file, target_compiled_file
def write_file(writer, filename):
"""Write all of lines from file using the writer."""
for line in txt_line_iterator(filename):
writer.write(line)
writer.write("\n")
###############################################################################
# Data preprocessing
###############################################################################
def encode_and_save_files(subtokenizer, data_dir, raw_files, tag, total_shards):
"""Save data from files as encoded Examples in TFrecord format.
Args:
subtokenizer: Subtokenizer object that will be used to encode the strings.
data_dir: The directory in which to write the examples
raw_files: A tuple of (input, target) data files. Each line in the input and
the corresponding line in target file will be saved in a tf.Example.
tag: String that will be added onto the file names.
total_shards: Number of files to divide the data into.
Returns:
List of all files produced.
"""
# Create a file for each shard.
filepaths = [
shard_filename(data_dir, tag, n + 1, total_shards)
for n in range(total_shards)
]
if all_exist(filepaths):
logging.info("Files with tag %s already exist.", tag)
return filepaths
logging.info("Saving files with tag %s.", tag)
input_file = raw_files[0]
target_file = raw_files[1]
# Write examples to each shard in round robin order.
tmp_filepaths = [six.ensure_str(fname) + ".incomplete" for fname in filepaths]
writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
counter, shard = 0, 0
for counter, (input_line, target_line) in enumerate(
zip(txt_line_iterator(input_file), txt_line_iterator(target_file))):
if counter > 0 and counter % 100000 == 0:
logging.info("\tSaving case %d.", counter)
example = dict_to_example({
"inputs": subtokenizer.encode(input_line, add_eos=True),
"targets": subtokenizer.encode(target_line, add_eos=True)
})
writers[shard].write(example.SerializeToString())
shard = (shard + 1) % total_shards
for writer in writers:
writer.close()
for tmp_name, final_name in zip(tmp_filepaths, filepaths):
tf.gfile.Rename(tmp_name, final_name)
logging.info("Saved %d Examples", counter + 1)
return filepaths
def shard_filename(path, tag, shard_num, total_shards):
"""Create filename for data shard."""
return os.path.join(
path, "%s-%s-%.5d-of-%.5d" % (_PREFIX, tag, shard_num, total_shards))
def shuffle_records(fname):
"""Shuffle records in a single file."""
logging.info("Shuffling records in file %s", fname)
# Rename file prior to shuffling
tmp_fname = six.ensure_str(fname) + ".unshuffled"
tf.gfile.Rename(fname, tmp_fname)
reader = tf.io.tf_record_iterator(tmp_fname)
records = []
for record in reader:
records.append(record)
if len(records) % 100000 == 0:
logging.info("\tRead: %d", len(records))
random.shuffle(records)
# Write shuffled records to original file name
with tf.python_io.TFRecordWriter(fname) as w:
for count, record in enumerate(records):
w.write(record)
if count > 0 and count % 100000 == 0:
logging.info("\tWriting record: %d", count)
tf.gfile.Remove(tmp_fname)
def dict_to_example(dictionary):
"""Converts a dictionary of string->int to a tf.Example."""
features = {}
for k, v in six.iteritems(dictionary):
features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
return tf.train.Example(features=tf.train.Features(feature=features))
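# Example (hypothetical token ids): dict_to_example({"inputs": [5, 1],
# "targets": [7, 2, 1]}) returns a tf.Example whose "inputs" feature holds
# the int64 list [5, 1] and whose "targets" feature holds [7, 2, 1].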
def all_exist(filepaths):
"""Returns true if all files in the list exist."""
for fname in filepaths:
if not tf.gfile.Exists(fname):
return False
return True
def make_dir(path):
if not tf.gfile.Exists(path):
logging.info("Creating directory %s", path)
tf.gfile.MakeDirs(path)
def main(unused_argv):
"""Obtain training and evaluation data for the Transformer model."""
make_dir(FLAGS.raw_dir)
make_dir(FLAGS.data_dir)
# Download test_data
logging.info("Step 1/5: Downloading test data")
get_raw_files(FLAGS.data_dir, _TEST_DATA_SOURCES)
# Get paths of download/extracted training and evaluation files.
logging.info("Step 2/5: Downloading data from source")
train_files = get_raw_files(FLAGS.raw_dir, _TRAIN_DATA_SOURCES)
eval_files = get_raw_files(FLAGS.raw_dir, _EVAL_DATA_SOURCES)
# Create subtokenizer based on the training files.
logging.info("Step 3/5: Creating subtokenizer and building vocabulary")
train_files_flat = train_files["inputs"] + train_files["targets"]
vocab_file = os.path.join(FLAGS.data_dir, VOCAB_FILE)
subtokenizer = tokenizer.Subtokenizer.init_from_files(
vocab_file,
train_files_flat,
_TARGET_VOCAB_SIZE,
_TARGET_THRESHOLD,
min_count=None if FLAGS.search else _TRAIN_DATA_MIN_COUNT)
logging.info("Step 4/5: Compiling training and evaluation data")
compiled_train_files = compile_files(FLAGS.raw_dir, train_files, _TRAIN_TAG)
compiled_eval_files = compile_files(FLAGS.raw_dir, eval_files, _EVAL_TAG)
# Tokenize and save data as Examples in the TFRecord format.
logging.info("Step 5/5: Preprocessing and saving data")
train_tfrecord_files = encode_and_save_files(subtokenizer, FLAGS.data_dir,
compiled_train_files, _TRAIN_TAG,
_TRAIN_SHARDS)
encode_and_save_files(subtokenizer, FLAGS.data_dir, compiled_eval_files,
_EVAL_TAG, _EVAL_SHARDS)
for fname in train_tfrecord_files:
shuffle_records(fname)
def define_data_download_flags():
"""Add flags specifying data download arguments."""
flags.DEFINE_string(
name="data_dir",
short_name="dd",
default="/tmp/translate_ende",
help=flags_core.help_wrap(
"Directory for where the translate_ende_wmt32k dataset is saved."))
flags.DEFINE_string(
name="raw_dir",
short_name="rd",
default="/tmp/translate_ende_raw",
help=flags_core.help_wrap(
"Path where the raw data will be downloaded and extracted."))
flags.DEFINE_bool(
name="search",
default=False,
help=flags_core.help_wrap(
"If set, use binary search to find the vocabulary set with size"
"closest to the target size (%d)." % _TARGET_VOCAB_SIZE))
if __name__ == "__main__":
logging.set_verbosity(logging.INFO)
define_data_download_flags()
FLAGS = flags.FLAGS
app.run(main)