Commit 319589aa authored by vedanshu

Merge branch 'master' of https://github.com/tensorflow/models

parents 64f323b1 eaeea071
@@ -370,7 +370,13 @@ class Trainer(_AsyncTrainer):
       logs[metric.name] = metric.result()
       metric.reset_states()
     if callable(self.optimizer.learning_rate):
-      logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
+      # A self-implemented optimizer may not have `optimizer.iterations`,
+      # so fall back to `self.global_step` to be safe.
+      if hasattr(self.optimizer, "iterations"):
+        logs["learning_rate"] = self.optimizer.learning_rate(
+            self.optimizer.iterations)
+      else:
+        logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
     else:
       logs["learning_rate"] = self.optimizer.learning_rate
     return logs
......
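For context on the hunk above: a callable `learning_rate` is a schedule evaluated at a step counter, and built-in Keras optimizers track their own `iterations` variable, which a hand-rolled optimizer may lack. A minimal sketch of the behavior being logged (illustrative, not part of this commit):

import tensorflow as tf

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.9)
optimizer = tf.keras.optimizers.SGD(learning_rate=schedule)

# `optimizer.iterations` counts apply_gradients() calls; evaluating the
# schedule at that counter yields the learning rate for the current step.
if hasattr(optimizer, "iterations"):
  lr = optimizer.learning_rate(optimizer.iterations)
print(float(lr))  # 0.1 before any training steps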
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A binary/library to export TF-NLP serving `SavedModel`."""
import os
from typing import Any, Dict, Text
from absl import app
from absl import flags
import dataclasses
import yaml
from official.core import base_task
from official.core import task_factory
from official.modeling import hyperparams
from official.modeling.hyperparams import base_config
from official.nlp.serving import export_savedmodel_util
from official.nlp.serving import serving_modules
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
FLAGS = flags.FLAGS
SERVING_MODULES = {
sentence_prediction.SentencePredictionTask:
serving_modules.SentencePrediction,
masked_lm.MaskedLMTask:
serving_modules.MaskedLM,
question_answering.QuestionAnsweringTask:
serving_modules.QuestionAnswering,
tagging.TaggingTask:
serving_modules.Tagging
}
def define_flags():
"""Defines flags."""
flags.DEFINE_string("task_name", "SentencePrediction", "The task to export.")
flags.DEFINE_string("config_file", None,
"The path to task/experiment yaml config file.")
flags.DEFINE_string(
"checkpoint_path", None,
"Object-based checkpoint path, from the training model directory.")
flags.DEFINE_string("export_savedmodel_dir", None,
"Output saved model directory.")
flags.DEFINE_string(
"serving_params", None,
"a YAML/JSON string or csv string for the serving parameters.")
flags.DEFINE_string(
"function_keys", None,
"A string key to retrieve pre-defined serving signatures.")
flags.DEFINE_bool("convert_tpu", False, "")
flags.DEFINE_multi_integer("allowed_batch_size", None,
"Allowed batch sizes for batching ops.")
def lookup_export_module(task: base_task.Task):
export_module_cls = SERVING_MODULES.get(task.__class__, None)
if export_module_cls is None:
ValueError("No registered export module for the task: %s", task.__class__)
return export_module_cls
def create_export_module(*, task_name: Text, config_file: Text,
serving_params: Dict[Text, Any]):
"""Creates a ExportModule."""
task_config_cls = None
task_cls = None
# pylint: disable=protected-access
for key, value in task_factory._REGISTERED_TASK_CLS.items():
print(key.__name__)
if task_name in key.__name__:
task_config_cls, task_cls = key, value
break
if task_cls is None:
raise ValueError("Failed to identify the task class. The provided task "
f"name is {task_name}")
# pylint: enable=protected-access
# TODO(hongkuny): Figure out how to separate the task config from experiments.
@dataclasses.dataclass
class Dummy(base_config.Config):
task: task_config_cls = task_config_cls()
dummy_exp = Dummy()
dummy_exp = hyperparams.override_params_dict(
dummy_exp, config_file, is_strict=False)
dummy_exp.task.validation_data = None
task = task_cls(dummy_exp.task)
model = task.build_model()
export_module_cls = lookup_export_module(task)
params = export_module_cls.Params(**serving_params)
return export_module_cls(params=params, model=model)
def main(_):
serving_params = yaml.load(
hyperparams.nested_csv_str_to_json_str(FLAGS.serving_params),
Loader=yaml.FullLoader)
export_module = create_export_module(
task_name=FLAGS.task_name,
config_file=FLAGS.config_file,
serving_params=serving_params)
export_dir = export_savedmodel_util.export(
export_module,
function_keys=[FLAGS.function_keys],
checkpoint_path=FLAGS.checkpoint_path,
export_savedmodel_dir=FLAGS.export_savedmodel_dir)
if FLAGS.convert_tpu:
# pylint: disable=g-import-not-at-top
from cloud_tpu.inference_converter import converter_cli
from cloud_tpu.inference_converter import converter_options_pb2
tpu_dir = os.path.join(export_dir, "tpu")
options = converter_options_pb2.ConverterOptions()
if FLAGS.allowed_batch_size is not None:
allowed_batch_sizes = sorted(FLAGS.allowed_batch_size)
options.batch_options.num_batch_threads = 4
options.batch_options.max_batch_size = allowed_batch_sizes[-1]
options.batch_options.batch_timeout_micros = 100000
options.batch_options.allowed_batch_sizes[:] = allowed_batch_sizes
options.batch_options.max_enqueued_batches = 1000
converter_cli.ConvertSavedModel(
export_dir, tpu_dir, function_alias="tpu_candidate", options=options,
graph_rewrite_only=True)
if __name__ == "__main__":
define_flags()
app.run(main)
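A hypothetical invocation of this binary (module path, file paths, and parameter values are placeholders; the flags are those defined above):

python3 -m official.nlp.serving.export_savedmodel \
  --task_name=SentencePrediction \
  --config_file=/path/to/experiment.yaml \
  --checkpoint_path=/path/to/ckpt \
  --export_savedmodel_dir=/path/to/export \
  --serving_params="inputs_only=false,parse_sequence_length=128" \
  --function_keys=serve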
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.export_saved_model."""
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import export_savedmodel
from official.nlp.serving import export_savedmodel_util
from official.nlp.tasks import masked_lm
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
class ExportSavedModelTest(tf.test.TestCase, parameterized.TestCase):
def test_create_export_module(self):
export_module = export_savedmodel.create_export_module(
task_name="SentencePrediction",
config_file=None,
serving_params={
"inputs_only": False,
"parse_sequence_length": 10
})
self.assertEqual(export_module.name, "sentence_prediction")
self.assertFalse(export_module.params.inputs_only)
self.assertEqual(export_module.params.parse_sequence_length, 10)
def test_sentence_prediction(self):
config = sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
num_classes=2))
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = ckpt.save(self.get_temp_dir())
export_module_cls = export_savedmodel.lookup_export_module(task)
serving_params = {"inputs_only": False}
params = export_module_cls.Params(**serving_params)
export_module = export_module_cls(params=params, model=model)
export_dir = export_savedmodel_util.export(
export_module,
function_keys=["serve"],
checkpoint_path=ckpt_path,
export_savedmodel_dir=self.get_temp_dir())
imported = tf.saved_model.load(export_dir)
serving_fn = imported.signatures["serving_default"]
dummy_ids = tf.ones((1, 5), dtype=tf.int32)
inputs = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
ref_outputs = model(inputs)
outputs = serving_fn(**inputs)
self.assertAllClose(ref_outputs, outputs["outputs"])
self.assertEqual(outputs["outputs"].shape, (1, 2))
def test_masked_lm(self):
config = masked_lm.MaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
cls_heads=[
bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="foo")
]))
task = masked_lm.MaskedLMTask(config)
model = task.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = ckpt.save(self.get_temp_dir())
export_module_cls = export_savedmodel.lookup_export_module(task)
serving_params = {
"cls_head_name": "foo",
"parse_sequence_length": 10,
"max_predictions_per_seq": 5
}
params = export_module_cls.Params(**serving_params)
export_module = export_module_cls(params=params, model=model)
export_dir = export_savedmodel_util.export(
export_module,
function_keys={
"serve": "serving_default",
"serve_examples": "serving_examples"
},
checkpoint_path=ckpt_path,
export_savedmodel_dir=self.get_temp_dir())
imported = tf.saved_model.load(export_dir)
self.assertSameElements(imported.signatures.keys(),
["serving_default", "serving_examples"])
serving_fn = imported.signatures["serving_default"]
dummy_ids = tf.ones((1, 10), dtype=tf.int32)
dummy_pos = tf.ones((1, 5), dtype=tf.int32)
outputs = serving_fn(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_pos)
self.assertEqual(outputs["classification"].shape, (1, 2))
@parameterized.parameters(True, False)
def test_tagging(self, output_encoder_outputs):
hidden_size = 768
num_classes = 3
config = tagging.TaggingConfig(
model=tagging.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(
hidden_size=hidden_size, num_layers=1))),
class_names=["class_0", "class_1", "class_2"])
task = tagging.TaggingTask(config)
model = task.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = ckpt.save(self.get_temp_dir())
export_module_cls = export_savedmodel.lookup_export_module(task)
serving_params = {
"parse_sequence_length": 10,
}
params = export_module_cls.Params(
**serving_params, output_encoder_outputs=output_encoder_outputs)
export_module = export_module_cls(params=params, model=model)
export_dir = export_savedmodel_util.export(
export_module,
function_keys={
"serve": "serving_default",
"serve_examples": "serving_examples"
},
checkpoint_path=ckpt_path,
export_savedmodel_dir=self.get_temp_dir())
imported = tf.saved_model.load(export_dir)
self.assertCountEqual(imported.signatures.keys(),
["serving_default", "serving_examples"])
serving_fn = imported.signatures["serving_default"]
dummy_ids = tf.ones((1, 5), dtype=tf.int32)
inputs = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
outputs = serving_fn(**inputs)
self.assertEqual(outputs["logits"].shape, (1, 5, num_classes))
if output_encoder_outputs:
self.assertEqual(outputs["encoder_outputs"].shape, (1, 5, hidden_size))
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common library to export a SavedModel from the export module."""
from typing import Dict, List, Optional, Text, Union
import tensorflow as tf
from official.core import export_base
def export(export_module: export_base.ExportModule,
function_keys: Union[List[Text], Dict[Text, Text]],
export_savedmodel_dir: Text,
checkpoint_path: Optional[Text] = None,
timestamped: bool = True) -> Text:
"""Exports to SavedModel format.
Args:
    export_module: an ExportModule with the keras Model and serving
      tf.functions.
    function_keys: a list of string keys to retrieve pre-defined serving
      signatures. The signature keys will be set with defaults. If a dictionary
is provided, the values will be used as signature keys.
export_savedmodel_dir: Output saved model directory.
checkpoint_path: Object-based checkpoint path or directory.
timestamped: Whether to export the savedmodel to a timestamped directory.
Returns:
The savedmodel directory path.
"""
save_options = tf.saved_model.SaveOptions(function_aliases={
"tpu_candidate": export_module.serve,
})
return export_base.export(export_module, function_keys, export_savedmodel_dir,
checkpoint_path, timestamped, save_options)
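A minimal usage sketch (assuming `export_module` and `ckpt_path` were created as in export_savedmodel.py above; names are illustrative):

# Maps the module's `serve` tf.function to the default serving signature and
# writes a (timestamped) SavedModel under the output directory.
export_dir = export(
    export_module,
    function_keys={"serve": "serving_default"},
    export_savedmodel_dir="/tmp/exported_model",
    checkpoint_path=ckpt_path,  # optional: restore weights before saving
    timestamped=True)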
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring
from typing import Dict, List, Optional, Text
import dataclasses
import tensorflow as tf
from official.core import export_base
from official.modeling.hyperparams import base_config
from official.nlp.data import sentence_prediction_dataloader
def features_to_int32(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
"""Converts tf.int64 features to tf.int32, keep other features the same.
tf.Example only supports tf.int64, but the TPU only supports tf.int32.
Args:
features: Input tensor dictionary.
Returns:
Features with tf.int64 converted to tf.int32.
"""
converted_features = {}
for name, tensor in features.items():
if tensor.dtype == tf.int64:
converted_features[name] = tf.cast(tensor, tf.int32)
else:
converted_features[name] = tensor
return converted_features
class SentencePrediction(export_base.ExportModule):
"""The export module for the sentence prediction task."""
@dataclasses.dataclass
class Params(base_config.Config):
inputs_only: bool = True
parse_sequence_length: Optional[int] = None
use_v2_feature_names: bool = True
# For text input processing.
text_fields: Optional[List[str]] = None
# Either specify these values for preprocessing by Python code...
tokenization: str = "WordPiece" # WordPiece or SentencePiece
# Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
# file if tokenization is SentencePiece.
vocab_file: str = ""
lower_case: bool = True
# ...or load preprocessing from a SavedModel at this location.
preprocessing_hub_module_url: str = ""
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
if params.text_fields:
self._text_processor = sentence_prediction_dataloader.TextProcessor(
seq_length=params.parse_sequence_length,
vocab_file=params.vocab_file,
tokenization=params.tokenization,
lower_case=params.lower_case,
preprocessing_hub_module_url=params.preprocessing_hub_module_url)
@tf.function
def serve(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]:
if input_type_ids is None:
      # Requires the CLS token to be the first token of the inputs.
input_type_ids = tf.zeros_like(input_word_ids)
if input_mask is None:
# The mask has 1 for real tokens and 0 for padding tokens.
input_mask = tf.where(
tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
tf.ones_like(input_word_ids))
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
return dict(outputs=self.inference_step(inputs))
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
inputs_only = self.params.inputs_only
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
}
if not inputs_only:
name_to_features.update({
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64)
})
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
features[self.input_word_ids_field],
input_mask=None if inputs_only else features["input_mask"],
input_type_ids=None
if inputs_only else features[self.input_type_ids_field])
@tf.function
def serve_text_examples(self, inputs) -> Dict[str, tf.Tensor]:
name_to_features = {}
for text_field in self.params.text_fields:
name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
features = tf.io.parse_example(inputs, name_to_features)
segments = [features[x] for x in self.params.text_fields]
model_inputs = self._text_processor(segments)
if self.params.inputs_only:
return self.serve(input_word_ids=model_inputs["input_word_ids"])
return self.serve(**model_inputs)
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples", "serve_text_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
if self.params.inputs_only:
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"))
else:
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_type_ids"))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
if func_key == "serve_text_examples":
signatures[
signature_key] = self.serve_text_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
class MaskedLM(export_base.ExportModule):
"""The export module for the Bert Pretrain (MaskedLM) task."""
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
@dataclasses.dataclass
class Params(base_config.Config):
cls_head_name: str = "next_sentence"
use_v2_feature_names: bool = True
parse_sequence_length: Optional[int] = None
max_predictions_per_seq: Optional[int] = None
@tf.function
def serve(self, input_word_ids, input_mask, input_type_ids,
masked_lm_positions) -> Dict[str, tf.Tensor]:
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids,
masked_lm_positions=masked_lm_positions)
outputs = self.inference_step(inputs)
return dict(classification=outputs[self.params.cls_head_name])
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
max_predictions_per_seq = self.params.max_predictions_per_seq
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"masked_lm_positions":
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64)
}
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
input_word_ids=features[self.input_word_ids_field],
input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field],
masked_lm_positions=features["masked_lm_positions"])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_type_ids"),
masked_lm_positions=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="masked_lm_positions"))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
class QuestionAnswering(export_base.ExportModule):
"""The export module for the question answering task."""
@dataclasses.dataclass
class Params(base_config.Config):
parse_sequence_length: Optional[int] = None
use_v2_feature_names: bool = True
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
@tf.function
def serve(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]:
if input_mask is None:
# The mask has 1 for real tokens and 0 for padding tokens.
input_mask = tf.where(
tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
tf.ones_like(input_word_ids))
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
outputs = self.inference_step(inputs)
return dict(start_logits=outputs[0], end_logits=outputs[1])
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64)
}
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
input_word_ids=features[self.input_word_ids_field],
input_mask=features["input_mask"],
input_type_ids=features[self.input_type_ids_field])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_type_ids"))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
class Tagging(export_base.ExportModule):
"""The export module for the tagging task."""
@dataclasses.dataclass
class Params(base_config.Config):
parse_sequence_length: Optional[int] = None
use_v2_feature_names: bool = True
output_encoder_outputs: bool = False
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
@tf.function
def serve(self, input_word_ids, input_mask,
input_type_ids) -> Dict[str, tf.Tensor]:
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
outputs = self.inference_step(inputs)
if self.params.output_encoder_outputs:
return dict(
logits=outputs["logits"], encoder_outputs=outputs["encoder_outputs"])
else:
return dict(logits=outputs["logits"])
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64)
}
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
input_word_ids=features[self.input_word_ids_field],
input_mask=features["input_mask"],
input_type_ids=features[self.input_type_ids_field])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None],
dtype=tf.int32,
name=self.input_word_ids_field),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None],
dtype=tf.int32,
name=self.input_type_ids_field))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.serving_modules."""
import os
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import serving_modules
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
def _create_fake_serialized_examples(features_dict):
"""Creates a fake dataset."""
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_str_feature(value):
f = tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
return f
examples = []
for _ in range(10):
features = {}
for key, values in features_dict.items():
if isinstance(values, bytes):
features[key] = create_str_feature(values)
else:
features[key] = create_int_feature(values)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
examples.append(tf_example.SerializeToString())
return tf.constant(examples)
def _create_fake_vocab_file(vocab_file_path):
tokens = ["[PAD]"]
for i in range(1, 100):
tokens.append("[unused%d]" % i)
tokens.extend(["[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world"])
with tf.io.gfile.GFile(vocab_file_path, "w") as outfile:
outfile.write("\n".join(tokens))
class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
# use_v2_feature_names
True,
False)
def test_sentence_prediction(self, use_v2_feature_names):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
config = sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
num_classes=2))
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
params = serving_modules.SentencePrediction.Params(
inputs_only=True,
parse_sequence_length=10,
use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.SentencePrediction(
params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
self.assertSameElements(functions.keys(),
["serving_default", "serving_examples"])
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
outputs = functions["serving_default"](dummy_ids)
self.assertEqual(outputs["outputs"].shape, (10, 2))
params = serving_modules.SentencePrediction.Params(
inputs_only=False,
parse_sequence_length=10,
use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.SentencePrediction(
params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
self.assertEqual(outputs["outputs"].shape, (10, 2))
dummy_ids = tf.ones((10,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["outputs"].shape, (10, 2))
with self.assertRaises(ValueError):
_ = export_module.get_inference_signatures({"foo": None})
@parameterized.parameters(
# inputs_only
True,
False)
def test_sentence_prediction_text(self, inputs_only):
vocab_file_path = os.path.join(self.get_temp_dir(), "vocab.txt")
_create_fake_vocab_file(vocab_file_path)
config = sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
num_classes=2))
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
params = serving_modules.SentencePrediction.Params(
inputs_only=inputs_only,
parse_sequence_length=10,
text_fields=["foo", "bar"],
vocab_file=vocab_file_path)
export_module = serving_modules.SentencePrediction(
params=params, model=model)
examples = _create_fake_serialized_examples({
"foo": b"hello world",
"bar": b"hello world"
})
functions = export_module.get_inference_signatures({
"serve_text_examples": "serving_default",
})
outputs = functions["serving_default"](examples)
self.assertEqual(outputs["outputs"].shape, (10, 2))
@parameterized.parameters(
# use_v2_feature_names
True,
False)
def test_masked_lm(self, use_v2_feature_names):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
config = masked_lm.MaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
]))
task = masked_lm.MaskedLMTask(config)
model = task.build_model()
params = serving_modules.MaskedLM.Params(
parse_sequence_length=10,
max_predictions_per_seq=5,
use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.MaskedLM(params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
self.assertSameElements(functions.keys(),
["serving_default", "serving_examples"])
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
dummy_pos = tf.ones((10, 5), dtype=tf.int32)
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_pos)
self.assertEqual(outputs["classification"].shape, (10, 2))
dummy_ids = tf.ones((10,), dtype=tf.int32)
dummy_pos = tf.ones((5,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids,
"masked_lm_positions": dummy_pos
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["classification"].shape, (10, 2))
@parameterized.parameters(
# use_v2_feature_names
True,
False)
def test_question_answering(self, use_v2_feature_names):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
config = question_answering.QuestionAnsweringConfig(
model=question_answering.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1))),
validation_data=None)
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
params = serving_modules.QuestionAnswering.Params(
parse_sequence_length=10, use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.QuestionAnswering(
params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
self.assertSameElements(functions.keys(),
["serving_default", "serving_examples"])
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
self.assertEqual(outputs["start_logits"].shape, (10, 10))
self.assertEqual(outputs["end_logits"].shape, (10, 10))
dummy_ids = tf.ones((10,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["start_logits"].shape, (10, 10))
self.assertEqual(outputs["end_logits"].shape, (10, 10))
@parameterized.parameters(
# (use_v2_feature_names, output_encoder_outputs)
(True, True),
(False, False))
def test_tagging(self, use_v2_feature_names, output_encoder_outputs):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
hidden_size = 768
num_classes = 3
config = tagging.TaggingConfig(
model=tagging.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(
hidden_size=hidden_size, num_layers=1))),
class_names=["class_0", "class_1", "class_2"])
task = tagging.TaggingTask(config)
model = task.build_model()
params = serving_modules.Tagging.Params(
parse_sequence_length=10,
use_v2_feature_names=use_v2_feature_names,
output_encoder_outputs=output_encoder_outputs)
export_module = serving_modules.Tagging(params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
self.assertEqual(outputs["logits"].shape, (10, 10, num_classes))
if output_encoder_outputs:
self.assertEqual(outputs["encoder_outputs"].shape, (10, 10, hidden_size))
dummy_ids = tf.ones((10,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["logits"].shape, (10, 10, num_classes))
if output_encoder_outputs:
self.assertEqual(outputs["encoder_outputs"].shape, (10, 10, hidden_size))
with self.assertRaises(ValueError):
_ = export_module.get_inference_signatures({"foo": None})
if __name__ == "__main__":
tf.test.main()
@@ -63,12 +63,14 @@ def convert_to_feature(value, value_type=None):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
   elif value_type == 'int64_list':
+    value = np.asarray(value).astype(np.int64).reshape(-1)
     return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
   elif value_type == 'float':
     return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
   elif value_type == 'float_list':
+    value = np.asarray(value).astype(np.float32).reshape(-1)
     return tf.train.Feature(float_list=tf.train.FloatList(value=value))
   elif value_type == 'bytes':
@@ -172,4 +174,3 @@ def check_and_make_dir(directory):
   """Creates the directory if it doesn't exist."""
   if not tf.io.gfile.isdir(directory):
     tf.io.gfile.makedirs(directory)
-
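A quick illustration of what the added `reshape(-1)` buys (values are illustrative): multi-dimensional inputs are flattened to the rank-1 list that `tf.train.Int64List` requires, instead of failing.

import numpy as np
import tensorflow as tf

value = [[1, 2], [3, 4]]  # a 2-D input that Int64List alone would reject
flat = np.asarray(value).astype(np.int64).reshape(-1)  # -> [1, 2, 3, 4]
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=flat))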
@@ -320,6 +320,9 @@ class SpineNetMobile(tf.keras.Model):
     endpoints = {}
     for i, block_spec in enumerate(self._block_specs):
+      # Update the block level if it is larger than max_level to avoid building
+      # blocks smaller than requested.
+      block_spec.level = min(block_spec.level, self._max_level)
       # Find out specs for the target block.
       target_width = int(math.ceil(input_width / 2**block_spec.level))
       target_num_filters = int(FILTER_SIZE_MAP[block_spec.level] *
@@ -392,8 +395,9 @@ class SpineNetMobile(tf.keras.Model):
                        block_spec.level))
       if (block_spec.level < self._min_level or
           block_spec.level > self._max_level):
-        raise ValueError('Output level is out of range [{}, {}]'.format(
-            self._min_level, self._max_level))
+        logging.warning(
+            'SpineNet output level out of range [min_level, max_level] = '
+            '[%s, %s]; the output will not be used for further processing.',
+            self._min_level, self._max_level)
       endpoints[str(block_spec.level)] = x

     return endpoints
......
@@ -1130,11 +1130,18 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
       feature_extractor_config.use_separable_conv or
       feature_extractor_config.type == 'mobilenet_v2_fpn_sep_conv')
   kwargs = {
-      'channel_means': list(feature_extractor_config.channel_means),
-      'channel_stds': list(feature_extractor_config.channel_stds),
-      'bgr_ordering': feature_extractor_config.bgr_ordering,
-      'depth_multiplier': feature_extractor_config.depth_multiplier,
-      'use_separable_conv': use_separable_conv,
+      'channel_means':
+          list(feature_extractor_config.channel_means),
+      'channel_stds':
+          list(feature_extractor_config.channel_stds),
+      'bgr_ordering':
+          feature_extractor_config.bgr_ordering,
+      'depth_multiplier':
+          feature_extractor_config.depth_multiplier,
+      'use_separable_conv':
+          use_separable_conv,
+      'upsampling_interpolation':
+          feature_extractor_config.upsampling_interpolation,
   }
......
@@ -398,7 +398,7 @@ class ModelBuilderTF2Test(
     }
   """
   # Set up the configuration proto.
-  config = text_format.Merge(proto_txt, model_pb2.DetectionModel())
+  config = text_format.Parse(proto_txt, model_pb2.DetectionModel())
   # Only add object center and keypoint estimation configs here.
   config.center_net.object_center_params.CopyFrom(
       self.get_fake_object_center_from_keypoints_proto())
@@ -422,6 +422,50 @@ class ModelBuilderTF2Test(
     self.assertEqual(kp_params.keypoint_labels,
                      ['nose', 'left_shoulder', 'right_shoulder', 'hip'])

+  def test_create_center_net_model_mobilenet(self):
+    """Test building a CenterNet model using bilinear interpolation."""
+    proto_txt = """
+      center_net {
+        num_classes: 10
+        feature_extractor {
+          type: "mobilenet_v2_fpn"
+          depth_multiplier: 1.0
+          use_separable_conv: true
+          upsampling_interpolation: "bilinear"
+        }
+        image_resizer {
+          keep_aspect_ratio_resizer {
+            min_dimension: 512
+            max_dimension: 512
+            pad_to_max_dimension: true
+          }
+        }
+      }
+    """
+    # Set up the configuration proto.
+    config = text_format.Parse(proto_txt, model_pb2.DetectionModel())
+    # Only add object center and keypoint estimation configs here.
+    config.center_net.object_center_params.CopyFrom(
+        self.get_fake_object_center_from_keypoints_proto())
+    config.center_net.keypoint_estimation_task.append(
+        self.get_fake_keypoint_proto())
+    config.center_net.keypoint_label_map_path = (
+        self.get_fake_label_map_file_path())
+
+    # Build the model from the configuration.
+    model = model_builder.build(config, is_training=True)
+    feature_extractor = model._feature_extractor
+
+    # Verify the upsampling layers in the FPN use 'bilinear' interpolation.
+    fpn = feature_extractor.get_layer('model_1')
+    num_up_sampling2d_layers = 0
+    for layer in fpn.layers:
+      if 'up_sampling2d' in layer.name:
+        num_up_sampling2d_layers += 1
+        self.assertEqual('bilinear', layer.interpolation)
+    # Verify that there are up_sampling2d layers.
+    self.assertGreater(num_up_sampling2d_layers, 0)
+

 if __name__ == '__main__':
   tf.test.main()
@@ -1776,6 +1776,7 @@ def random_pad_image(image,
                      min_image_size=None,
                      max_image_size=None,
                      pad_color=None,
+                     center_pad=False,
                      seed=None,
                      preprocess_vars_cache=None):
   """Randomly pads the image.
@@ -1814,6 +1815,8 @@ def random_pad_image(image,
     pad_color: padding color. A rank 1 tensor of [channels] with dtype=
       tf.float32. if set as None, it will be set to average color of
       the input image.
+    center_pad: whether the original image will be padded to the center, or
+      randomly padded (the default).
     seed: random seed.
     preprocess_vars_cache: PreprocessorCache object that records previously
       performed augmentations. Updated in-place. If this
@@ -1870,6 +1873,12 @@ def random_pad_image(image,
       lambda: _random_integer(0, target_width - image_width, seed),
       lambda: tf.constant(0, dtype=tf.int32))

+  if center_pad:
+    offset_height = tf.cast(tf.floor((target_height - image_height) / 2),
+                            tf.int32)
+    offset_width = tf.cast(tf.floor((target_width - image_width) / 2),
+                           tf.int32)
+
   gen_func = lambda: (target_height, target_width, offset_height,
                       offset_width)
   params = _get_or_create_preprocess_rand_vars(
       gen_func, preprocessor_cache.PreprocessorCache.PAD_IMAGE,
@@ -2113,7 +2122,7 @@ def random_crop_pad_image(image,
       max_padded_size_ratio,
       dtype=tf.int32)

-  padded_image, padded_boxes = random_pad_image(
+  padded_image, padded_boxes = random_pad_image(  # pylint: disable=unbalanced-tuple-unpacking
       cropped_image,
       cropped_boxes,
       min_image_size=min_image_size,
@@ -2153,6 +2162,7 @@ def random_crop_to_aspect_ratio(image,
                                 aspect_ratio=1.0,
                                 overlap_thresh=0.3,
                                 clip_boxes=True,
+                                center_crop=False,
                                 seed=None,
                                 preprocess_vars_cache=None):
   """Randomly crops an image to the specified aspect ratio.
@@ -2191,6 +2201,7 @@ def random_crop_to_aspect_ratio(image,
     overlap_thresh: minimum overlap thresh with new cropped
       image to keep the box.
     clip_boxes: whether to clip the boxes to the cropped image.
+    center_crop: whether to take the center crop or a random crop.
     seed: random seed.
     preprocess_vars_cache: PreprocessorCache object that records previously
       performed augmentations. Updated in-place. If this
@@ -2247,8 +2258,14 @@ def random_crop_to_aspect_ratio(image,
     # either offset_height = 0 and offset_width is randomly chosen from
     # [0, offset_width - target_width), or else offset_width = 0 and
     # offset_height is randomly chosen from [0, offset_height - target_height)
-    offset_height = _random_integer(0, orig_height - target_height + 1, seed)
-    offset_width = _random_integer(0, orig_width - target_width + 1, seed)
+    if center_crop:
+      offset_height = tf.cast(tf.math.floor((orig_height - target_height) / 2),
+                              tf.int32)
+      offset_width = tf.cast(tf.math.floor((orig_width - target_width) / 2),
+                             tf.int32)
+    else:
+      offset_height = _random_integer(0, orig_height - target_height + 1, seed)
+      offset_width = _random_integer(0, orig_width - target_width + 1, seed)

     generator_func = lambda: (offset_height, offset_width)
     offset_height, offset_width = _get_or_create_preprocess_rand_vars(
@@ -2979,7 +2996,7 @@ def resize_to_range(image,
                            'per-channel pad value.')
       new_image = tf.stack(
           [
-              tf.pad(
+              tf.pad(  # pylint: disable=g-complex-comprehension
                   channels[i], [[0, max_dimension - new_size[0]],
                                 [0, max_dimension - new_size[1]]],
                   constant_values=per_channel_pad_value[i])
......
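As a worked check of the centering math introduced above (numbers illustrative): padding a 200x400 image to 400x400 with `center_pad=True` yields offsets (100, 0), so the original image lands in the middle.

import tensorflow as tf

target_height, image_height = 400, 200
target_width, image_width = 400, 400
offset_height = tf.cast(tf.floor((target_height - image_height) / 2), tf.int32)
offset_width = tf.cast(tf.floor((target_width - image_width) / 2), tf.int32)
print(int(offset_height), int(offset_width))  # 100 0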
@@ -2194,6 +2194,54 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
                         expected_boxes.flatten())
     self.assertAllEqual(distorted_masks_.shape, [1, 200, 200])

+  def testRunRandomCropToAspectRatioCenterCrop(self):
+    def graph_fn():
+      image = self.createColorfulTestImage()
+      boxes = self.createTestBoxes()
+      labels = self.createTestLabels()
+      weights = self.createTestGroundtruthWeights()
+      masks = tf.random_uniform([2, 200, 400], dtype=tf.float32)
+      tensor_dict = {
+          fields.InputDataFields.image: image,
+          fields.InputDataFields.groundtruth_boxes: boxes,
+          fields.InputDataFields.groundtruth_classes: labels,
+          fields.InputDataFields.groundtruth_weights: weights,
+          fields.InputDataFields.groundtruth_instance_masks: masks
+      }
+      preprocessor_arg_map = preprocessor.get_default_func_arg_map(
+          include_instance_masks=True)
+      preprocessing_options = [(preprocessor.random_crop_to_aspect_ratio, {
+          'center_crop': True
+      })]
+
+      with mock.patch.object(preprocessor,
+                             '_random_integer') as mock_random_integer:
+        mock_random_integer.return_value = tf.constant(0, dtype=tf.int32)
+        distorted_tensor_dict = preprocessor.preprocess(
+            tensor_dict,
+            preprocessing_options,
+            func_arg_map=preprocessor_arg_map)
+
+      distorted_image = distorted_tensor_dict[fields.InputDataFields.image]
+      distorted_boxes = distorted_tensor_dict[
+          fields.InputDataFields.groundtruth_boxes]
+      distorted_labels = distorted_tensor_dict[
+          fields.InputDataFields.groundtruth_classes]
+      return [
+          distorted_image, distorted_boxes, distorted_labels
+      ]
+
+    (distorted_image_, distorted_boxes_, distorted_labels_) = self.execute_cpu(
+        graph_fn, [])
+
+    expected_boxes = np.array([[0.0, 0.0, 0.75, 1.0],
+                               [0.25, 0.5, 0.75, 1.0]], dtype=np.float32)
+    self.assertAllEqual(distorted_image_.shape, [1, 200, 200, 3])
+    self.assertAllEqual(distorted_labels_, [1, 2])
+    self.assertAllClose(distorted_boxes_.flatten(),
+                        expected_boxes.flatten())
+
   def testRunRandomCropToAspectRatioWithKeypoints(self):
     def graph_fn():
       image = self.createColorfulTestImage()
@@ -2433,6 +2481,51 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
     self.assertTrue(np.all((boxes_[:, 3] - boxes_[:, 1]) >= (
         padded_boxes_[:, 3] - padded_boxes_[:, 1])))

+  def testRandomPadImageCenterPad(self):
+    def graph_fn():
+      preprocessing_options = [(preprocessor.normalize_image, {
+          'original_minval': 0,
+          'original_maxval': 255,
+          'target_minval': 0,
+          'target_maxval': 1
+      })]
+      images = self.createColorfulTestImage()
+      boxes = self.createTestBoxes()
+      labels = self.createTestLabels()
+
+      tensor_dict = {
+          fields.InputDataFields.image: images,
+          fields.InputDataFields.groundtruth_boxes: boxes,
+          fields.InputDataFields.groundtruth_classes: labels,
+      }
+      tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options)
+      images = tensor_dict[fields.InputDataFields.image]
+
+      preprocessing_options = [(preprocessor.random_pad_image, {
+          'center_pad': True,
+          'min_image_size': [400, 400],
+          'max_image_size': [400, 400],
+      })]
+      padded_tensor_dict = preprocessor.preprocess(tensor_dict,
+                                                   preprocessing_options)
+
+      padded_images = padded_tensor_dict[fields.InputDataFields.image]
+      padded_boxes = padded_tensor_dict[
+          fields.InputDataFields.groundtruth_boxes]
+      padded_labels = padded_tensor_dict[
+          fields.InputDataFields.groundtruth_classes]
+      return [padded_images, padded_boxes, padded_labels]
+
+    (padded_images_, padded_boxes_, padded_labels_) = self.execute_cpu(
+        graph_fn, [])
+
+    expected_boxes = np.array([[0.25, 0.25, 0.625, 1.0],
+                               [0.375, 0.5, .625, 1.0]], dtype=np.float32)
+    self.assertAllEqual(padded_images_.shape, [1, 400, 400, 3])
+    self.assertAllEqual(padded_labels_, [1, 2])
+    self.assertAllClose(padded_boxes_.flatten(),
+                        expected_boxes.flatten())
+
   @parameterized.parameters(
       {'include_dense_pose': False},
   )
......
@@ -117,23 +117,9 @@ class CenterNetFeatureExtractor(tf.keras.Model):
     pass

   @property
-  @abc.abstractmethod
-  def supported_sub_model_types(self):
-    """Valid sub model types supported by the get_sub_model function."""
-    pass
-
-  @abc.abstractmethod
-  def get_sub_model(self, sub_model_type):
-    """Returns the underlying keras model for the given sub_model_type.
-
-    This function is useful when we only want to get a subset of weights to
-    be restored from a checkpoint.
-
-    Args:
-      sub_model_type: string, the type of sub model. Currently, CenterNet
-        feature extractors support 'detection' and 'classification'.
-    """
-    pass
+  def classification_backbone(self):
+    raise NotImplementedError(
+        'Classification backbone not supported for {}'.format(type(self)))


 def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
@@ -4200,25 +4186,28 @@ class CenterNetMetaArch(model.DetectionModel):
       A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
     """

-    supported_types = self._feature_extractor.supported_sub_model_types
-    supported_types += ['fine_tune']
-
-    if fine_tune_checkpoint_type not in supported_types:
-      message = ('Checkpoint type "{}" not supported for {}. '
-                 'Supported types are {}')
-      raise ValueError(
-          message.format(fine_tune_checkpoint_type,
-                         self._feature_extractor.__class__.__name__,
-                         supported_types))
-
-    elif fine_tune_checkpoint_type == 'fine_tune':
+    if fine_tune_checkpoint_type == 'detection':
       feature_extractor_model = tf.train.Checkpoint(
           _feature_extractor=self._feature_extractor)
       return {'model': feature_extractor_model}

+    elif fine_tune_checkpoint_type == 'classification':
+      return {
+          'feature_extractor':
+              self._feature_extractor.classification_backbone
+      }
+    elif fine_tune_checkpoint_type == 'full':
+      return {'model': self}
+    elif fine_tune_checkpoint_type == 'fine_tune':
+      raise ValueError(('"fine_tune" is no longer supported for CenterNet. '
+                        'Please set fine_tune_checkpoint_type to "detection"'
+                        ' which has the same functionality. If you are using'
+                        ' the ExtremeNet checkpoint, download the new version'
+                        ' from the model zoo.'))
     else:
-      return {'feature_extractor': self._feature_extractor.get_sub_model(
-          fine_tune_checkpoint_type)}
+      raise ValueError('Unknown fine tune checkpoint type {}'.format(
+          fine_tune_checkpoint_type))

   def updates(self):
     if tf_version.is_tf2():
......
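In summary, the `fine_tune_checkpoint_type` values CenterNet supports after this change map as follows (a sketch; `model` stands for a built CenterNetMetaArch):

# 'detection': the feature extractor wrapped in a tf.train.Checkpoint.
objects = model.restore_from_objects('detection')       # {'model': ...}
# 'classification': the backbone exposed directly.
objects = model.restore_from_objects('classification')  # {'feature_extractor': ...}
# 'full': the whole detection model.
objects = model.restore_from_objects('full')            # {'model': model}
# 'fine_tune' now raises ValueError directing users to 'detection'.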
@@ -17,7 +17,6 @@
 from __future__ import division

 import functools
-import re
 import unittest

 from absl.testing import parameterized
@@ -2887,15 +2886,14 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
     self.assertIsInstance(restore_from_objects_map['feature_extractor'],
                           tf.keras.Model)

-  def test_restore_map_error(self):
-    """Test that restoring unsupported checkpoint type raises an error."""
+  def test_restore_map_detection(self):
+    """Test that detection checkpoints can be restored."""

     model = build_center_net_meta_arch(build_resnet=True)
-    msg = ("Checkpoint type \"detection\" not supported for "
-           "CenterNetResnetFeatureExtractor. Supported types are "
-           "['classification', 'fine_tune']")
-    with self.assertRaisesRegex(ValueError, re.escape(msg)):
-      model.restore_from_objects('detection')
+    restore_from_objects_map = model.restore_from_objects('detection')
+
+    self.assertIsInstance(restore_from_objects_map['model']._feature_extractor,
+                          tf.keras.Model)


 class DummyFeatureExtractor(cnma.CenterNetFeatureExtractor):
......
...@@ -2896,6 +2896,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
           _feature_extractor_for_proposal_features=
           self._feature_extractor_for_proposal_features)
       return {'model': fake_model}
+    elif fine_tune_checkpoint_type == 'full':
+      return {'model': self}
     else:
       raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
           fine_tune_checkpoint_type))
...
...@@ -35,6 +35,7 @@ from object_detection.protos import train_pb2
 from object_detection.utils import config_util
 from object_detection.utils import label_map_util
 from object_detection.utils import ops
+from object_detection.utils import variables_helper
 from object_detection.utils import visualization_utils as vutils
...@@ -587,6 +588,9 @@ def train_loop(
         lambda: global_step % num_steps_per_iteration == 0):

       # Load a fine-tuning checkpoint.
       if train_config.fine_tune_checkpoint:
+        variables_helper.ensure_checkpoint_supported(
+            train_config.fine_tune_checkpoint, fine_tune_checkpoint_type,
+            model_dir)
         load_fine_tune_checkpoint(
             detection_model, train_config.fine_tune_checkpoint,
             fine_tune_checkpoint_type, fine_tune_checkpoint_version,
...
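ensure_checkpoint_supported now runs before load_fine_tune_checkpoint, so an incompatible checkpoint fails fast instead of partway through training. As a rough illustration of this style of pre-flight check (not the actual variables_helper implementation, which may differ), a checkpoint's variable map can be inspected without building a model:

import tensorflow as tf

def sanity_check_checkpoint(checkpoint_path):
  # Read the checkpoint index directly; no model variables are created.
  reader = tf.train.load_checkpoint(checkpoint_path)
  var_shapes = reader.get_variable_to_shape_map()
  if not var_shapes:
    raise ValueError(
        'Checkpoint {} contains no variables.'.format(checkpoint_path))
  return var_shapes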
...@@ -62,16 +62,6 @@ class CenterNetHourglassFeatureExtractor(
     """The number of feature outputs returned by the feature extractor."""
     return self._network.num_hourglasses

-  @property
-  def supported_sub_model_types(self):
-    return ['detection']
-
-  def get_sub_model(self, sub_model_type):
-    if sub_model_type == 'detection':
-      return self._network
-    else:
-      ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
-
 def hourglass_10(channel_means, channel_stds, bgr_ordering, **kwargs):
   """The Hourglass-10 backbone for CenterNet."""
...
...@@ -83,9 +83,6 @@ class CenterNetMobileNetV2FeatureExtractor(
   def load_feature_extractor_weights(self, path):
     self._network.load_weights(path)

-  def get_base_model(self):
-    return self._network
-
   def call(self, inputs):
     return [self._network(inputs)]
...@@ -100,14 +97,8 @@ class CenterNetMobileNetV2FeatureExtractor(
     return 1

   @property
-  def supported_sub_model_types(self):
-    return ['detection']
-
-  def get_sub_model(self, sub_model_type):
-    if sub_model_type == 'detection':
-      return self._network
-    else:
-      ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
+  def classification_backbone(self):
+    return self._network

 def mobilenet_v2(channel_means, channel_stds, bgr_ordering,
...
...@@ -39,7 +39,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
                channel_means=(0., 0., 0.),
                channel_stds=(1., 1., 1.),
                bgr_ordering=False,
-               use_separable_conv=False):
+               use_separable_conv=False,
+               upsampling_interpolation='nearest'):
     """Initializes the feature extractor.

     Args:
...@@ -52,6 +53,9 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
         [blue, red, green] order.
       use_separable_conv: If set to True, all convolutional layers in the FPN
         network will be replaced by separable convolutions.
+      upsampling_interpolation: A string (one of 'nearest' or 'bilinear')
+        indicating which interpolation method to use for the upsampling ops in
+        the FPN.
     """
     super(CenterNetMobileNetV2FPNFeatureExtractor, self).__init__(
...@@ -84,7 +88,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
     for i, num_filters in enumerate(num_filters_list):
       level_ind = len(num_filters_list) - 1 - i
       # Upsample.
-      upsample_op = tf.keras.layers.UpSampling2D(2, interpolation='nearest')
+      upsample_op = tf.keras.layers.UpSampling2D(
+          2, interpolation=upsampling_interpolation)
       top_down = upsample_op(top_down)

       # Residual (skip-connection) from bottom-up pathway.
...@@ -144,7 +149,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
 def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
-                     use_separable_conv=False, depth_multiplier=1.0, **kwargs):
+                     use_separable_conv=False, depth_multiplier=1.0,
+                     upsampling_interpolation='nearest', **kwargs):
   """The MobileNetV2+FPN backbone for CenterNet."""
   del kwargs
...@@ -159,4 +165,5 @@ def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
       channel_means=channel_means,
       channel_stds=channel_stds,
       bgr_ordering=bgr_ordering,
-      use_separable_conv=use_separable_conv)
+      use_separable_conv=use_separable_conv,
+      upsampling_interpolation=upsampling_interpolation)
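The new upsampling_interpolation argument is passed straight through to tf.keras.layers.UpSampling2D, which accepts both modes. A standalone sketch of the two options:

import tensorflow as tf

feature_map = tf.random.uniform([1, 8, 8, 256])  # [batch, height, width, channels]
nearest = tf.keras.layers.UpSampling2D(2, interpolation='nearest')
bilinear = tf.keras.layers.UpSampling2D(2, interpolation='bilinear')
# Both double the spatial size; only how the new pixels are filled in differs.
print(nearest(feature_map).shape)   # (1, 16, 16, 256)
print(bilinear(feature_map).shape)  # (1, 16, 16, 256)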