Commit d7eabefa authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[translation] Add text2text export module.

PiperOrigin-RevId: 418559537
parent 439d515a
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
# limitations under the License. # limitations under the License.
"""A binary/library to export TF-NLP serving `SavedModel`.""" """A binary/library to export TF-NLP serving `SavedModel`."""
import dataclasses
import os import os
from typing import Any, Dict, Text from typing import Any, Dict, Text
from absl import app from absl import app
from absl import flags from absl import flags
import dataclasses
import yaml import yaml
from official.core import base_task from official.core import base_task
from official.core import task_factory from official.core import task_factory
from official.modeling import hyperparams from official.modeling import hyperparams
...@@ -29,6 +31,7 @@ from official.nlp.tasks import masked_lm ...@@ -29,6 +31,7 @@ from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging from official.nlp.tasks import tagging
from official.nlp.tasks import translation
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -40,7 +43,9 @@ SERVING_MODULES = { ...@@ -40,7 +43,9 @@ SERVING_MODULES = {
question_answering.QuestionAnsweringTask: question_answering.QuestionAnsweringTask:
serving_modules.QuestionAnswering, serving_modules.QuestionAnswering,
tagging.TaggingTask: tagging.TaggingTask:
serving_modules.Tagging serving_modules.Tagging,
translation.TranslationTask:
serving_modules.Translation
} }
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
"""Serving export modules for TF Model Garden NLP models.""" """Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring # pylint:disable=missing-class-docstring
import dataclasses
from typing import Dict, List, Optional, Text from typing import Dict, List, Optional, Text
import dataclasses
import tensorflow as tf import tensorflow as tf
import tensorflow_text as tf_text
from official.core import export_base from official.core import export_base
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.nlp.data import sentence_prediction_dataloader from official.nlp.data import sentence_prediction_dataloader
...@@ -407,3 +409,48 @@ class Tagging(export_base.ExportModule): ...@@ -407,3 +409,48 @@ class Tagging(export_base.ExportModule):
signatures[signature_key] = self.serve_examples.get_concrete_function( signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples")) tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures return signatures
class Translation(export_base.ExportModule):
"""The export module for the translation task."""
@dataclasses.dataclass
class Params(base_config.Config):
sentencepiece_model_path: str = ""
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
self._sp_tokenizer = tf_text.SentencepieceTokenizer(
model=tf.io.gfile.GFile(params.sentencepiece_model_path, "rb").read(),
add_eos=True)
try:
empty_str_tokenized = self._sp_tokenizer.tokenize("").numpy()
except tf.errors.InternalError:
raise ValueError(
"EOS token not in tokenizer vocab."
"Please make sure the tokenizer generates a single token for an "
"empty string.")
self._eos_id = empty_str_tokenized.item()
@tf.function
def serve(self, inputs) -> Dict[str, tf.Tensor]:
return self.inference_step(inputs)
@tf.function
def serve_text(self, text: tf.Tensor) -> Dict[str, tf.Tensor]:
tokenized = self._sp_tokenizer.tokenize(text).to_tensor(0)
return self._sp_tokenizer.detokenize(
self.serve({"inputs": tokenized})["outputs"])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve_text")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve_text":
signatures[signature_key] = self.serve_text.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="text"))
return signatures
...@@ -15,8 +15,11 @@ ...@@ -15,8 +15,11 @@
"""Tests for nlp.serving.serving_modules.""" """Tests for nlp.serving.serving_modules."""
import os import os
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from sentencepiece import SentencePieceTrainer
from official.nlp.configs import bert from official.nlp.configs import bert
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.serving import serving_modules from official.nlp.serving import serving_modules
...@@ -24,6 +27,7 @@ from official.nlp.tasks import masked_lm ...@@ -24,6 +27,7 @@ from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging from official.nlp.tasks import tagging
from official.nlp.tasks import translation
def _create_fake_serialized_examples(features_dict): def _create_fake_serialized_examples(features_dict):
...@@ -59,6 +63,33 @@ def _create_fake_vocab_file(vocab_file_path): ...@@ -59,6 +63,33 @@ def _create_fake_vocab_file(vocab_file_path):
outfile.write("\n".join(tokens)) outfile.write("\n".join(tokens))
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
argstr = " ".join([
f"--input={input_path}", f"--vocab_size={vocab_size}",
"--character_coverage=0.995",
f"--model_prefix={model_path}", "--model_type=bpe",
"--bos_id=-1", "--pad_id=0", f"--eos_id={eos_id}", "--unk_id=2"
])
SentencePieceTrainer.Train(argstr)
def _generate_line_file(filepath, lines):
with tf.io.gfile.GFile(filepath, "w") as f:
for l in lines:
f.write("{}\n".format(l))
def _make_sentencepeice(output_dir):
src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
tgt_lines = ["dd cc a ef g", "bcd ef a g", "gef cd ba"]
sentencepeice_input_path = os.path.join(output_dir, "inputs.txt")
_generate_line_file(sentencepeice_input_path, src_lines + tgt_lines)
sentencepeice_model_prefix = os.path.join(output_dir, "sp")
_train_sentencepiece(sentencepeice_input_path, 11, sentencepeice_model_prefix)
sentencepeice_model_path = "{}.model".format(sentencepeice_model_prefix)
return sentencepeice_model_path
class ServingModulesTest(tf.test.TestCase, parameterized.TestCase): class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters( @parameterized.parameters(
...@@ -312,6 +343,31 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase): ...@@ -312,6 +343,31 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = export_module.get_inference_signatures({"foo": None}) _ = export_module.get_inference_signatures({"foo": None})
def test_translation(self):
sp_path = _make_sentencepeice(self.get_temp_dir())
encdecoder = translation.EncDecoder(
num_attention_heads=4, intermediate_size=256)
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=encdecoder,
decoder=encdecoder,
embedding_width=256,
padded_decode=False,
decode_max_length=100),
sentencepiece_model_path=sp_path,
)
task = translation.TranslationTask(config)
model = task.build_model()
params = serving_modules.Translation.Params(
sentencepiece_model_path=sp_path)
export_module = serving_modules.Translation(params=params, model=model)
functions = export_module.get_inference_signatures({
"serve_text": "serving_default"
})
outputs = functions["serving_default"](tf.constant(["abcd", "ef gh"]))
self.assertEqual(outputs.shape, (2,))
self.assertEqual(outputs.dtype, tf.string)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment