"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e5674015f3d16f55495afc222dc54a65a9089600"
Commit 977ae4db authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 372153802
parent 7907ba50
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A binary/library to export TF-NLP serving `SavedModel`."""
import os
from typing import Any, Dict, Text
from absl import app
from absl import flags
import dataclasses
import yaml
from official.core import base_task
from official.core import task_factory
from official.modeling import hyperparams
from official.modeling.hyperparams import base_config
from official.nlp.serving import export_savedmodel_util
from official.nlp.serving import serving_modules
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
FLAGS = flags.FLAGS
SERVING_MODULES = {
sentence_prediction.SentencePredictionTask:
serving_modules.SentencePrediction,
masked_lm.MaskedLMTask:
serving_modules.MaskedLM,
question_answering.QuestionAnsweringTask:
serving_modules.QuestionAnswering,
tagging.TaggingTask:
serving_modules.Tagging
}
def define_flags():
"""Defines flags."""
flags.DEFINE_string("task_name", "SentencePrediction", "The task to export.")
flags.DEFINE_string("config_file", None,
"The path to task/experiment yaml config file.")
flags.DEFINE_string(
"checkpoint_path", None,
"Object-based checkpoint path, from the training model directory.")
flags.DEFINE_string("export_savedmodel_dir", None,
"Output saved model directory.")
flags.DEFINE_string(
"serving_params", None,
"a YAML/JSON string or csv string for the serving parameters.")
flags.DEFINE_string(
"function_keys", None,
"A string key to retrieve pre-defined serving signatures.")
flags.DEFINE_bool("convert_tpu", False, "")
flags.DEFINE_multi_integer("allowed_batch_size", None,
"Allowed batch sizes for batching ops.")
def lookup_export_module(task: base_task.Task):
  """Looks up the export module class registered for the given task."""
  export_module_cls = SERVING_MODULES.get(task.__class__, None)
  if export_module_cls is None:
    raise ValueError("No registered export module for the task: %s" %
                     task.__class__)
  return export_module_cls
def create_export_module(*, task_name: Text, config_file: Text,
serving_params: Dict[Text, Any]):
"""Creates a ExportModule."""
task_config_cls = None
task_cls = None
# pylint: disable=protected-access
for key, value in task_factory._REGISTERED_TASK_CLS.items():
if task_name in key.__name__:
task_config_cls, task_cls = key, value
break
if task_cls is None:
raise ValueError("Failed to identify the task class. The provided task "
f"name is {task_name}")
# pylint: enable=protected-access
# TODO(hongkuny): Figure out how to separate the task config from experiments.
@dataclasses.dataclass
class Dummy(base_config.Config):
task: task_config_cls = task_config_cls()
dummy_exp = Dummy()
dummy_exp = hyperparams.override_params_dict(
dummy_exp, config_file, is_strict=False)
dummy_exp.task.validation_data = None
task = task_cls(dummy_exp.task)
model = task.build_model()
export_module_cls = lookup_export_module(task)
params = export_module_cls.Params(**serving_params)
return export_module_cls(params=params, model=model)
def main(_):
serving_params = yaml.load(
hyperparams.nested_csv_str_to_json_str(FLAGS.serving_params),
Loader=yaml.FullLoader)
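  # Accepted --serving_params formats: a YAML/JSON string such as
  # '{inputs_only: false}', or a nested CSV string such as
  # 'inputs_only=false,parse_sequence_length=128', which
  # nested_csv_str_to_json_str normalizes to JSON before parsing.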
export_module = create_export_module(
task_name=FLAGS.task_name,
config_file=FLAGS.config_file,
serving_params=serving_params)
export_dir = export_savedmodel_util.export(
export_module,
function_keys=[FLAGS.function_keys],
checkpoint_path=FLAGS.checkpoint_path,
export_savedmodel_dir=FLAGS.export_savedmodel_dir)
if FLAGS.convert_tpu:
# pylint: disable=g-import-not-at-top
from cloud_tpu.inference_converter import converter_cli
from cloud_tpu.inference_converter import converter_options_pb2
tpu_dir = os.path.join(export_dir, "tpu")
options = converter_options_pb2.ConverterOptions()
if FLAGS.allowed_batch_size is not None:
allowed_batch_sizes = sorted(FLAGS.allowed_batch_size)
options.batch_options.num_batch_threads = 4
options.batch_options.max_batch_size = allowed_batch_sizes[-1]
options.batch_options.batch_timeout_micros = 100000
options.batch_options.allowed_batch_sizes[:] = allowed_batch_sizes
options.batch_options.max_enqueued_batches = 1000
converter_cli.ConvertSavedModel(
export_dir, tpu_dir, function_alias="tpu_candidate", options=options,
graph_rewrite_only=True)
if __name__ == "__main__":
define_flags()
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.export_saved_model."""
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import export_savedmodel
from official.nlp.serving import export_savedmodel_util
from official.nlp.tasks import masked_lm
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
class ExportSavedModelTest(tf.test.TestCase, parameterized.TestCase):
def test_create_export_module(self):
export_module = export_savedmodel.create_export_module(
task_name="SentencePrediction",
config_file=None,
serving_params={
"inputs_only": False,
"parse_sequence_length": 10
})
self.assertEqual(export_module.name, "sentence_prediction")
self.assertFalse(export_module.params.inputs_only)
self.assertEqual(export_module.params.parse_sequence_length, 10)
def test_sentence_prediction(self):
config = sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
num_classes=2))
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = ckpt.save(self.get_temp_dir())
export_module_cls = export_savedmodel.lookup_export_module(task)
serving_params = {"inputs_only": False}
params = export_module_cls.Params(**serving_params)
export_module = export_module_cls(params=params, model=model)
export_dir = export_savedmodel_util.export(
export_module,
function_keys=["serve"],
checkpoint_path=ckpt_path,
export_savedmodel_dir=self.get_temp_dir())
imported = tf.saved_model.load(export_dir)
serving_fn = imported.signatures["serving_default"]
dummy_ids = tf.ones((1, 5), dtype=tf.int32)
inputs = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
ref_outputs = model(inputs)
outputs = serving_fn(**inputs)
self.assertAllClose(ref_outputs, outputs["outputs"])
self.assertEqual(outputs["outputs"].shape, (1, 2))
def test_masked_lm(self):
config = masked_lm.MaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
cls_heads=[
bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="foo")
]))
task = masked_lm.MaskedLMTask(config)
model = task.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = ckpt.save(self.get_temp_dir())
export_module_cls = export_savedmodel.lookup_export_module(task)
serving_params = {
"cls_head_name": "foo",
"parse_sequence_length": 10,
"max_predictions_per_seq": 5
}
params = export_module_cls.Params(**serving_params)
export_module = export_module_cls(params=params, model=model)
export_dir = export_savedmodel_util.export(
export_module,
function_keys={
"serve": "serving_default",
"serve_examples": "serving_examples"
},
checkpoint_path=ckpt_path,
export_savedmodel_dir=self.get_temp_dir())
imported = tf.saved_model.load(export_dir)
self.assertSameElements(imported.signatures.keys(),
["serving_default", "serving_examples"])
serving_fn = imported.signatures["serving_default"]
dummy_ids = tf.ones((1, 10), dtype=tf.int32)
dummy_pos = tf.ones((1, 5), dtype=tf.int32)
outputs = serving_fn(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_pos)
self.assertEqual(outputs["classification"].shape, (1, 2))
@parameterized.parameters(True, False)
def test_tagging(self, output_encoder_outputs):
hidden_size = 768
num_classes = 3
config = tagging.TaggingConfig(
model=tagging.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(
hidden_size=hidden_size, num_layers=1))),
class_names=["class_0", "class_1", "class_2"])
task = tagging.TaggingTask(config)
model = task.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = ckpt.save(self.get_temp_dir())
export_module_cls = export_savedmodel.lookup_export_module(task)
serving_params = {
"parse_sequence_length": 10,
}
params = export_module_cls.Params(
**serving_params, output_encoder_outputs=output_encoder_outputs)
export_module = export_module_cls(params=params, model=model)
export_dir = export_savedmodel_util.export(
export_module,
function_keys={
"serve": "serving_default",
"serve_examples": "serving_examples"
},
checkpoint_path=ckpt_path,
export_savedmodel_dir=self.get_temp_dir())
imported = tf.saved_model.load(export_dir)
self.assertCountEqual(imported.signatures.keys(),
["serving_default", "serving_examples"])
serving_fn = imported.signatures["serving_default"]
dummy_ids = tf.ones((1, 5), dtype=tf.int32)
inputs = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
outputs = serving_fn(**inputs)
self.assertEqual(outputs["logits"].shape, (1, 5, num_classes))
if output_encoder_outputs:
self.assertEqual(outputs["encoder_outputs"].shape, (1, 5, hidden_size))
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common library to export a SavedModel from the export module."""
from typing import Dict, List, Optional, Text, Union
import tensorflow as tf
from official.core import export_base
def export(export_module: export_base.ExportModule,
function_keys: Union[List[Text], Dict[Text, Text]],
export_savedmodel_dir: Text,
checkpoint_path: Optional[Text] = None,
timestamped: bool = True) -> Text:
"""Exports to SavedModel format.
Args:
    export_module: an ExportModule with the Keras model and serving
      tf.functions.
    function_keys: a list of string keys to retrieve pre-defined serving
      signatures. The signature keys will be set with defaults. If a
      dictionary is provided, the values will be used as signature keys.
export_savedmodel_dir: Output saved model directory.
checkpoint_path: Object-based checkpoint path or directory.
timestamped: Whether to export the savedmodel to a timestamped directory.
Returns:
The savedmodel directory path.
"""
save_options = tf.saved_model.SaveOptions(function_aliases={
"tpu_candidate": export_module.serve,
})
return export_base.export(export_module, function_keys, export_savedmodel_dir,
checkpoint_path, timestamped, save_options)
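# A minimal usage sketch (assumes `module` is a built ExportModule and
# `ckpt_path` points to a compatible object-based checkpoint):
#
#   export_dir = export(
#       module,
#       function_keys={"serve": "serving_default"},
#       export_savedmodel_dir="/tmp/exported",
#       checkpoint_path=ckpt_path)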
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring
from typing import Dict, List, Optional, Text
import dataclasses
import tensorflow as tf
from official.core import export_base
from official.modeling.hyperparams import base_config
from official.nlp.data import sentence_prediction_dataloader
def features_to_int32(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
"""Converts tf.int64 features to tf.int32, keep other features the same.
tf.Example only supports tf.int64, but the TPU only supports tf.int32.
Args:
features: Input tensor dictionary.
Returns:
Features with tf.int64 converted to tf.int32.
"""
converted_features = {}
for name, tensor in features.items():
if tensor.dtype == tf.int64:
converted_features[name] = tf.cast(tensor, tf.int32)
else:
converted_features[name] = tensor
return converted_features
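# For example (illustrative values), given
#   {"ids": tf.constant([1], tf.int64), "text": tf.constant(["a"])}
# features_to_int32 returns the "ids" tensor cast to tf.int32 and leaves the
# tf.string "text" tensor unchanged.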
class SentencePrediction(export_base.ExportModule):
"""The export module for the sentence prediction task."""
@dataclasses.dataclass
class Params(base_config.Config):
inputs_only: bool = True
parse_sequence_length: Optional[int] = None
use_v2_feature_names: bool = True
# For text input processing.
text_fields: Optional[List[str]] = None
# Either specify these values for preprocessing by Python code...
tokenization: str = "WordPiece" # WordPiece or SentencePiece
# Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
# file if tokenization is SentencePiece.
vocab_file: str = ""
lower_case: bool = True
# ...or load preprocessing from a SavedModel at this location.
preprocessing_hub_module_url: str = ""
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
if params.text_fields:
self._text_processor = sentence_prediction_dataloader.TextProcessor(
seq_length=params.parse_sequence_length,
vocab_file=params.vocab_file,
tokenization=params.tokenization,
lower_case=params.lower_case,
preprocessing_hub_module_url=params.preprocessing_hub_module_url)
@tf.function
def serve(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]:
if input_type_ids is None:
      # Requires that the CLS token be the first token of the inputs.
input_type_ids = tf.zeros_like(input_word_ids)
if input_mask is None:
# The mask has 1 for real tokens and 0 for padding tokens.
input_mask = tf.where(
tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
tf.ones_like(input_word_ids))
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
return dict(outputs=self.inference_step(inputs))
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
inputs_only = self.params.inputs_only
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
}
if not inputs_only:
name_to_features.update({
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64)
})
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
features[self.input_word_ids_field],
input_mask=None if inputs_only else features["input_mask"],
input_type_ids=None
if inputs_only else features[self.input_type_ids_field])
@tf.function
def serve_text_examples(self, inputs) -> Dict[str, tf.Tensor]:
name_to_features = {}
for text_field in self.params.text_fields:
name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
features = tf.io.parse_example(inputs, name_to_features)
segments = [features[x] for x in self.params.text_fields]
model_inputs = self._text_processor(segments)
if self.params.inputs_only:
return self.serve(input_word_ids=model_inputs["input_word_ids"])
return self.serve(**model_inputs)
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples", "serve_text_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
if self.params.inputs_only:
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"))
else:
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_type_ids"))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
if func_key == "serve_text_examples":
signatures[
signature_key] = self.serve_text_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
class MaskedLM(export_base.ExportModule):
"""The export module for the Bert Pretrain (MaskedLM) task."""
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
@dataclasses.dataclass
class Params(base_config.Config):
cls_head_name: str = "next_sentence"
use_v2_feature_names: bool = True
parse_sequence_length: Optional[int] = None
max_predictions_per_seq: Optional[int] = None
@tf.function
def serve(self, input_word_ids, input_mask, input_type_ids,
masked_lm_positions) -> Dict[str, tf.Tensor]:
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids,
masked_lm_positions=masked_lm_positions)
outputs = self.inference_step(inputs)
return dict(classification=outputs[self.params.cls_head_name])
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
max_predictions_per_seq = self.params.max_predictions_per_seq
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"masked_lm_positions":
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64)
}
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
input_word_ids=features[self.input_word_ids_field],
input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field],
masked_lm_positions=features["masked_lm_positions"])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_type_ids"),
masked_lm_positions=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="masked_lm_positions"))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
class QuestionAnswering(export_base.ExportModule):
"""The export module for the question answering task."""
@dataclasses.dataclass
class Params(base_config.Config):
parse_sequence_length: Optional[int] = None
use_v2_feature_names: bool = True
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
@tf.function
def serve(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]:
    if input_type_ids is None:
      # Defaults to all-zero segment ids, mirroring SentencePrediction.serve.
      input_type_ids = tf.zeros_like(input_word_ids)
    if input_mask is None:
      # The mask has 1 for real tokens and 0 for padding tokens.
      input_mask = tf.where(
          tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
          tf.ones_like(input_word_ids))
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
outputs = self.inference_step(inputs)
return dict(start_logits=outputs[0], end_logits=outputs[1])
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64)
}
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
input_word_ids=features[self.input_word_ids_field],
input_mask=features["input_mask"],
input_type_ids=features[self.input_type_ids_field])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_type_ids"))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
class Tagging(export_base.ExportModule):
"""The export module for the tagging task."""
@dataclasses.dataclass
class Params(base_config.Config):
parse_sequence_length: Optional[int] = None
use_v2_feature_names: bool = True
output_encoder_outputs: bool = False
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
@tf.function
def serve(self, input_word_ids, input_mask,
input_type_ids) -> Dict[str, tf.Tensor]:
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
outputs = self.inference_step(inputs)
if self.params.output_encoder_outputs:
return dict(
logits=outputs["logits"], encoder_outputs=outputs["encoder_outputs"])
else:
return dict(logits=outputs["logits"])
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64)
}
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
input_word_ids=features[self.input_word_ids_field],
input_mask=features["input_mask"],
input_type_ids=features[self.input_type_ids_field])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None],
dtype=tf.int32,
name=self.input_word_ids_field),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None],
dtype=tf.int32,
name=self.input_type_ids_field))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
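# export_savedmodel.SERVING_MODULES maps task classes to the export modules
# above; a sketch of wiring in a new task (hypothetical MyTask /
# MyExportModule names):
#
#   from official.nlp.serving import export_savedmodel
#   export_savedmodel.SERVING_MODULES[MyTask] = MyExportModule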
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.serving_modules."""
import os
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import serving_modules
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
def _create_fake_serialized_examples(features_dict):
"""Creates a fake dataset."""
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_str_feature(value):
f = tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
return f
examples = []
for _ in range(10):
features = {}
for key, values in features_dict.items():
if isinstance(values, bytes):
features[key] = create_str_feature(values)
else:
features[key] = create_int_feature(values)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
examples.append(tf_example.SerializeToString())
return tf.constant(examples)
def _create_fake_vocab_file(vocab_file_path):
tokens = ["[PAD]"]
for i in range(1, 100):
tokens.append("[unused%d]" % i)
tokens.extend(["[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world"])
with tf.io.gfile.GFile(vocab_file_path, "w") as outfile:
outfile.write("\n".join(tokens))
class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
# use_v2_feature_names
True,
False)
def test_sentence_prediction(self, use_v2_feature_names):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
config = sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
num_classes=2))
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
params = serving_modules.SentencePrediction.Params(
inputs_only=True,
parse_sequence_length=10,
use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.SentencePrediction(
params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
self.assertSameElements(functions.keys(),
["serving_default", "serving_examples"])
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
outputs = functions["serving_default"](dummy_ids)
self.assertEqual(outputs["outputs"].shape, (10, 2))
params = serving_modules.SentencePrediction.Params(
inputs_only=False,
parse_sequence_length=10,
use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.SentencePrediction(
params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
self.assertEqual(outputs["outputs"].shape, (10, 2))
dummy_ids = tf.ones((10,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["outputs"].shape, (10, 2))
with self.assertRaises(ValueError):
_ = export_module.get_inference_signatures({"foo": None})
@parameterized.parameters(
# inputs_only
True,
False)
def test_sentence_prediction_text(self, inputs_only):
vocab_file_path = os.path.join(self.get_temp_dir(), "vocab.txt")
_create_fake_vocab_file(vocab_file_path)
config = sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
num_classes=2))
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
params = serving_modules.SentencePrediction.Params(
inputs_only=inputs_only,
parse_sequence_length=10,
text_fields=["foo", "bar"],
vocab_file=vocab_file_path)
export_module = serving_modules.SentencePrediction(
params=params, model=model)
examples = _create_fake_serialized_examples({
"foo": b"hello world",
"bar": b"hello world"
})
functions = export_module.get_inference_signatures({
"serve_text_examples": "serving_default",
})
outputs = functions["serving_default"](examples)
self.assertEqual(outputs["outputs"].shape, (10, 2))
@parameterized.parameters(
# use_v2_feature_names
True,
False)
def test_masked_lm(self, use_v2_feature_names):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
config = masked_lm.MaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
]))
task = masked_lm.MaskedLMTask(config)
model = task.build_model()
params = serving_modules.MaskedLM.Params(
parse_sequence_length=10,
max_predictions_per_seq=5,
use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.MaskedLM(params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
self.assertSameElements(functions.keys(),
["serving_default", "serving_examples"])
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
dummy_pos = tf.ones((10, 5), dtype=tf.int32)
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_pos)
self.assertEqual(outputs["classification"].shape, (10, 2))
dummy_ids = tf.ones((10,), dtype=tf.int32)
dummy_pos = tf.ones((5,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids,
"masked_lm_positions": dummy_pos
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["classification"].shape, (10, 2))
@parameterized.parameters(
# use_v2_feature_names
True,
False)
def test_question_answering(self, use_v2_feature_names):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
config = question_answering.QuestionAnsweringConfig(
model=question_answering.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1))),
validation_data=None)
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
params = serving_modules.QuestionAnswering.Params(
parse_sequence_length=10, use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.QuestionAnswering(
params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
self.assertSameElements(functions.keys(),
["serving_default", "serving_examples"])
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
self.assertEqual(outputs["start_logits"].shape, (10, 10))
self.assertEqual(outputs["end_logits"].shape, (10, 10))
dummy_ids = tf.ones((10,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["start_logits"].shape, (10, 10))
self.assertEqual(outputs["end_logits"].shape, (10, 10))
@parameterized.parameters(
# (use_v2_feature_names, output_encoder_outputs)
(True, True),
(False, False))
def test_tagging(self, use_v2_feature_names, output_encoder_outputs):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
hidden_size = 768
num_classes = 3
config = tagging.TaggingConfig(
model=tagging.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(
hidden_size=hidden_size, num_layers=1))),
class_names=["class_0", "class_1", "class_2"])
task = tagging.TaggingTask(config)
model = task.build_model()
params = serving_modules.Tagging.Params(
parse_sequence_length=10,
use_v2_feature_names=use_v2_feature_names,
output_encoder_outputs=output_encoder_outputs)
export_module = serving_modules.Tagging(params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
self.assertEqual(outputs["logits"].shape, (10, 10, num_classes))
if output_encoder_outputs:
self.assertEqual(outputs["encoder_outputs"].shape, (10, 10, hidden_size))
dummy_ids = tf.ones((10,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["logits"].shape, (10, 10, num_classes))
if output_encoder_outputs:
self.assertEqual(outputs["encoder_outputs"].shape, (10, 10, hidden_size))
with self.assertRaises(ValueError):
_ = export_module.get_inference_signatures({"foo": None})
if __name__ == "__main__":
tf.test.main()