Commit 319589aa authored by vedanshu's avatar vedanshu

Merge branch 'master' of https://github.com/tensorflow/models

parents 64f323b1 eaeea071
......@@ -370,6 +370,12 @@ class Trainer(_AsyncTrainer):
logs[metric.name] = metric.result()
metric.reset_states()
if callable(self.optimizer.learning_rate):
# A custom (self-implemented) optimizer may not define `optimizer.iterations`,
# so guard the attribute access here.
if hasattr(self.optimizer, "iterations"):
logs["learning_rate"] = self.optimizer.learning_rate(
self.optimizer.iterations)
else:
logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
else:
logs["learning_rate"] = self.optimizer.learning_rate
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A binary/library to export TF-NLP serving `SavedModel`."""
import os
from typing import Any, Dict, Text
from absl import app
from absl import flags
import dataclasses
import yaml
from official.core import base_task
from official.core import task_factory
from official.modeling import hyperparams
from official.modeling.hyperparams import base_config
from official.nlp.serving import export_savedmodel_util
from official.nlp.serving import serving_modules
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
FLAGS = flags.FLAGS
SERVING_MODULES = {
sentence_prediction.SentencePredictionTask:
serving_modules.SentencePrediction,
masked_lm.MaskedLMTask:
serving_modules.MaskedLM,
question_answering.QuestionAnsweringTask:
serving_modules.QuestionAnswering,
tagging.TaggingTask:
serving_modules.Tagging
}
def define_flags():
"""Defines flags."""
flags.DEFINE_string("task_name", "SentencePrediction", "The task to export.")
flags.DEFINE_string("config_file", None,
"The path to task/experiment yaml config file.")
flags.DEFINE_string(
"checkpoint_path", None,
"Object-based checkpoint path, from the training model directory.")
flags.DEFINE_string("export_savedmodel_dir", None,
"Output saved model directory.")
flags.DEFINE_string(
"serving_params", None,
"a YAML/JSON string or csv string for the serving parameters.")
flags.DEFINE_string(
"function_keys", None,
"A string key to retrieve pre-defined serving signatures.")
flags.DEFINE_bool("convert_tpu", False, "")
flags.DEFINE_multi_integer("allowed_batch_size", None,
"Allowed batch sizes for batching ops.")
def lookup_export_module(task: base_task.Task):
"""Looks up the export module class registered for the given task."""
export_module_cls = SERVING_MODULES.get(task.__class__, None)
if export_module_cls is None:
raise ValueError("No registered export module for the task: %s" %
task.__class__)
return export_module_cls
def create_export_module(*, task_name: Text, config_file: Text,
serving_params: Dict[Text, Any]):
"""Creates a ExportModule."""
task_config_cls = None
task_cls = None
# pylint: disable=protected-access
for key, value in task_factory._REGISTERED_TASK_CLS.items():
if task_name in key.__name__:
task_config_cls, task_cls = key, value
break
if task_cls is None:
raise ValueError("Failed to identify the task class. The provided task "
f"name is {task_name}")
# pylint: enable=protected-access
# TODO(hongkuny): Figure out how to separate the task config from experiments.
@dataclasses.dataclass
class Dummy(base_config.Config):
task: task_config_cls = task_config_cls()
dummy_exp = Dummy()
dummy_exp = hyperparams.override_params_dict(
dummy_exp, config_file, is_strict=False)
dummy_exp.task.validation_data = None
task = task_cls(dummy_exp.task)
model = task.build_model()
export_module_cls = lookup_export_module(task)
params = export_module_cls.Params(**serving_params)
return export_module_cls(params=params, model=model)
def main(_):
serving_params = yaml.load(
hyperparams.nested_csv_str_to_json_str(FLAGS.serving_params),
Loader=yaml.FullLoader)
export_module = create_export_module(
task_name=FLAGS.task_name,
config_file=FLAGS.config_file,
serving_params=serving_params)
export_dir = export_savedmodel_util.export(
export_module,
function_keys=[FLAGS.function_keys],
checkpoint_path=FLAGS.checkpoint_path,
export_savedmodel_dir=FLAGS.export_savedmodel_dir)
if FLAGS.convert_tpu:
# pylint: disable=g-import-not-at-top
from cloud_tpu.inference_converter import converter_cli
from cloud_tpu.inference_converter import converter_options_pb2
tpu_dir = os.path.join(export_dir, "tpu")
options = converter_options_pb2.ConverterOptions()
if FLAGS.allowed_batch_size is not None:
allowed_batch_sizes = sorted(FLAGS.allowed_batch_size)
options.batch_options.num_batch_threads = 4
options.batch_options.max_batch_size = allowed_batch_sizes[-1]
options.batch_options.batch_timeout_micros = 100000
options.batch_options.allowed_batch_sizes[:] = allowed_batch_sizes
options.batch_options.max_enqueued_batches = 1000
converter_cli.ConvertSavedModel(
export_dir, tpu_dir, function_alias="tpu_candidate", options=options,
graph_rewrite_only=True)
if __name__ == "__main__":
define_flags()
app.run(main)
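As a hedged usage sketch of this binary's building blocks (the paths and parameter values below are hypothetical):
serving_params = {"inputs_only": False, "parse_sequence_length": 128}
export_module = create_export_module(
    task_name="SentencePrediction",
    config_file="/tmp/experiment.yaml",  # hypothetical config path
    serving_params=serving_params)
export_dir = export_savedmodel_util.export(
    export_module,
    function_keys=["serve"],
    checkpoint_path="/tmp/model_dir/ckpt-1",  # hypothetical checkpoint
    export_savedmodel_dir="/tmp/exported_model")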
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.export_saved_model."""
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import export_savedmodel
from official.nlp.serving import export_savedmodel_util
from official.nlp.tasks import masked_lm
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
class ExportSavedModelTest(tf.test.TestCase, parameterized.TestCase):
def test_create_export_module(self):
export_module = export_savedmodel.create_export_module(
task_name="SentencePrediction",
config_file=None,
serving_params={
"inputs_only": False,
"parse_sequence_length": 10
})
self.assertEqual(export_module.name, "sentence_prediction")
self.assertFalse(export_module.params.inputs_only)
self.assertEqual(export_module.params.parse_sequence_length, 10)
def test_sentence_prediction(self):
config = sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
num_classes=2))
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = ckpt.save(self.get_temp_dir())
export_module_cls = export_savedmodel.lookup_export_module(task)
serving_params = {"inputs_only": False}
params = export_module_cls.Params(**serving_params)
export_module = export_module_cls(params=params, model=model)
export_dir = export_savedmodel_util.export(
export_module,
function_keys=["serve"],
checkpoint_path=ckpt_path,
export_savedmodel_dir=self.get_temp_dir())
imported = tf.saved_model.load(export_dir)
serving_fn = imported.signatures["serving_default"]
dummy_ids = tf.ones((1, 5), dtype=tf.int32)
inputs = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
ref_outputs = model(inputs)
outputs = serving_fn(**inputs)
self.assertAllClose(ref_outputs, outputs["outputs"])
self.assertEqual(outputs["outputs"].shape, (1, 2))
def test_masked_lm(self):
config = masked_lm.MaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
cls_heads=[
bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="foo")
]))
task = masked_lm.MaskedLMTask(config)
model = task.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = ckpt.save(self.get_temp_dir())
export_module_cls = export_savedmodel.lookup_export_module(task)
serving_params = {
"cls_head_name": "foo",
"parse_sequence_length": 10,
"max_predictions_per_seq": 5
}
params = export_module_cls.Params(**serving_params)
export_module = export_module_cls(params=params, model=model)
export_dir = export_savedmodel_util.export(
export_module,
function_keys={
"serve": "serving_default",
"serve_examples": "serving_examples"
},
checkpoint_path=ckpt_path,
export_savedmodel_dir=self.get_temp_dir())
imported = tf.saved_model.load(export_dir)
self.assertSameElements(imported.signatures.keys(),
["serving_default", "serving_examples"])
serving_fn = imported.signatures["serving_default"]
dummy_ids = tf.ones((1, 10), dtype=tf.int32)
dummy_pos = tf.ones((1, 5), dtype=tf.int32)
outputs = serving_fn(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_pos)
self.assertEqual(outputs["classification"].shape, (1, 2))
@parameterized.parameters(True, False)
def test_tagging(self, output_encoder_outputs):
hidden_size = 768
num_classes = 3
config = tagging.TaggingConfig(
model=tagging.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(
hidden_size=hidden_size, num_layers=1))),
class_names=["class_0", "class_1", "class_2"])
task = tagging.TaggingTask(config)
model = task.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = ckpt.save(self.get_temp_dir())
export_module_cls = export_savedmodel.lookup_export_module(task)
serving_params = {
"parse_sequence_length": 10,
}
params = export_module_cls.Params(
**serving_params, output_encoder_outputs=output_encoder_outputs)
export_module = export_module_cls(params=params, model=model)
export_dir = export_savedmodel_util.export(
export_module,
function_keys={
"serve": "serving_default",
"serve_examples": "serving_examples"
},
checkpoint_path=ckpt_path,
export_savedmodel_dir=self.get_temp_dir())
imported = tf.saved_model.load(export_dir)
self.assertCountEqual(imported.signatures.keys(),
["serving_default", "serving_examples"])
serving_fn = imported.signatures["serving_default"]
dummy_ids = tf.ones((1, 5), dtype=tf.int32)
inputs = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
outputs = serving_fn(**inputs)
self.assertEqual(outputs["logits"].shape, (1, 5, num_classes))
if output_encoder_outputs:
self.assertEqual(outputs["encoder_outputs"].shape, (1, 5, hidden_size))
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common library to export a SavedModel from the export module."""
from typing import Dict, List, Optional, Text, Union
import tensorflow as tf
from official.core import export_base
def export(export_module: export_base.ExportModule,
function_keys: Union[List[Text], Dict[Text, Text]],
export_savedmodel_dir: Text,
checkpoint_path: Optional[Text] = None,
timestamped: bool = True) -> Text:
"""Exports to SavedModel format.
Args:
export_module: an ExportModule with the Keras model and serving tf.functions.
function_keys: a list of string keys to retrieve pre-defined serving
signatures. The signature keys will be set with defaults. If a dictionary
is provided, the values will be used as signature keys.
export_savedmodel_dir: output SavedModel directory.
checkpoint_path: object-based checkpoint path or directory.
timestamped: whether to export the SavedModel to a timestamped directory.
Returns:
The SavedModel directory path.
"""
save_options = tf.saved_model.SaveOptions(function_aliases={
"tpu_candidate": export_module.serve,
})
return export_base.export(export_module, function_keys, export_savedmodel_dir,
checkpoint_path, timestamped, save_options)
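For illustration, the two accepted forms of `function_keys`, assuming `export_module` was built as in export_savedmodel.py above:
# List form: default signature keys are derived from the function keys.
export(export_module, function_keys=["serve"],
       export_savedmodel_dir="/tmp/out")
# Dict form: the dict values become the signature keys.
export(export_module,
       function_keys={"serve": "serving_default",
                      "serve_examples": "serving_examples"},
       export_savedmodel_dir="/tmp/out")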
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring
from typing import Dict, List, Optional, Text
import dataclasses
import tensorflow as tf
from official.core import export_base
from official.modeling.hyperparams import base_config
from official.nlp.data import sentence_prediction_dataloader
def features_to_int32(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
"""Converts tf.int64 features to tf.int32, keep other features the same.
tf.Example only supports tf.int64, but the TPU only supports tf.int32.
Args:
features: Input tensor dictionary.
Returns:
Features with tf.int64 converted to tf.int32.
"""
converted_features = {}
for name, tensor in features.items():
if tensor.dtype == tf.int64:
converted_features[name] = tf.cast(tensor, tf.int32)
else:
converted_features[name] = tensor
return converted_features
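A quick hedged example of the cast behavior:
features = {"ids": tf.constant([1, 2], dtype=tf.int64),
            "text": tf.constant(["a", "b"])}
converted = features_to_int32(features)
assert converted["ids"].dtype == tf.int32
assert converted["text"].dtype == tf.string  # non-int64 tensors pass through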
class SentencePrediction(export_base.ExportModule):
"""The export module for the sentence prediction task."""
@dataclasses.dataclass
class Params(base_config.Config):
inputs_only: bool = True
parse_sequence_length: Optional[int] = None
use_v2_feature_names: bool = True
# For text input processing.
text_fields: Optional[List[str]] = None
# Either specify these values for preprocessing by Python code...
tokenization: str = "WordPiece" # WordPiece or SentencePiece
# Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
# file if tokenization is SentencePiece.
vocab_file: str = ""
lower_case: bool = True
# ...or load preprocessing from a SavedModel at this location.
preprocessing_hub_module_url: str = ""
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
if params.text_fields:
self._text_processor = sentence_prediction_dataloader.TextProcessor(
seq_length=params.parse_sequence_length,
vocab_file=params.vocab_file,
tokenization=params.tokenization,
lower_case=params.lower_case,
preprocessing_hub_module_url=params.preprocessing_hub_module_url)
@tf.function
def serve(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]:
if input_type_ids is None:
# Requires the CLS token to be the first token of the inputs.
input_type_ids = tf.zeros_like(input_word_ids)
if input_mask is None:
# The mask has 1 for real tokens and 0 for padding tokens.
input_mask = tf.where(
tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
tf.ones_like(input_word_ids))
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
return dict(outputs=self.inference_step(inputs))
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
inputs_only = self.params.inputs_only
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
}
if not inputs_only:
name_to_features.update({
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64)
})
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
features[self.input_word_ids_field],
input_mask=None if inputs_only else features["input_mask"],
input_type_ids=None
if inputs_only else features[self.input_type_ids_field])
@tf.function
def serve_text_examples(self, inputs) -> Dict[str, tf.Tensor]:
name_to_features = {}
for text_field in self.params.text_fields:
name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
features = tf.io.parse_example(inputs, name_to_features)
segments = [features[x] for x in self.params.text_fields]
model_inputs = self._text_processor(segments)
if self.params.inputs_only:
return self.serve(input_word_ids=model_inputs["input_word_ids"])
return self.serve(**model_inputs)
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples", "serve_text_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
if self.params.inputs_only:
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"))
else:
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_type_ids"))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
if func_key == "serve_text_examples":
signatures[
signature_key] = self.serve_text_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
class MaskedLM(export_base.ExportModule):
"""The export module for the Bert Pretrain (MaskedLM) task."""
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
@dataclasses.dataclass
class Params(base_config.Config):
cls_head_name: str = "next_sentence"
use_v2_feature_names: bool = True
parse_sequence_length: Optional[int] = None
max_predictions_per_seq: Optional[int] = None
@tf.function
def serve(self, input_word_ids, input_mask, input_type_ids,
masked_lm_positions) -> Dict[str, tf.Tensor]:
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids,
masked_lm_positions=masked_lm_positions)
outputs = self.inference_step(inputs)
return dict(classification=outputs[self.params.cls_head_name])
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
max_predictions_per_seq = self.params.max_predictions_per_seq
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"masked_lm_positions":
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64)
}
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
input_word_ids=features[self.input_word_ids_field],
input_mask=features["input_mask"],
input_type_ids=features[self.input_type_ids_field],
masked_lm_positions=features["masked_lm_positions"])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_type_ids"),
masked_lm_positions=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="masked_lm_positions"))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
class QuestionAnswering(export_base.ExportModule):
"""The export module for the question answering task."""
@dataclasses.dataclass
class Params(base_config.Config):
parse_sequence_length: Optional[int] = None
use_v2_feature_names: bool = True
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
@tf.function
def serve(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]:
if input_mask is None:
# The mask has 1 for real tokens and 0 for padding tokens.
input_mask = tf.where(
tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
tf.ones_like(input_word_ids))
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
outputs = self.inference_step(inputs)
return dict(start_logits=outputs[0], end_logits=outputs[1])
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64)
}
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
input_word_ids=features[self.input_word_ids_field],
input_mask=features["input_mask"],
input_type_ids=features[self.input_type_ids_field])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_word_ids"),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_type_ids"))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
class Tagging(export_base.ExportModule):
"""The export module for the tagging task."""
@dataclasses.dataclass
class Params(base_config.Config):
parse_sequence_length: Optional[int] = None
use_v2_feature_names: bool = True
output_encoder_outputs: bool = False
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
if params.use_v2_feature_names:
self.input_word_ids_field = "input_word_ids"
self.input_type_ids_field = "input_type_ids"
else:
self.input_word_ids_field = "input_ids"
self.input_type_ids_field = "segment_ids"
@tf.function
def serve(self, input_word_ids, input_mask,
input_type_ids) -> Dict[str, tf.Tensor]:
inputs = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
outputs = self.inference_step(inputs)
if self.params.output_encoder_outputs:
return dict(
logits=outputs["logits"], encoder_outputs=outputs["encoder_outputs"])
else:
return dict(logits=outputs["logits"])
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
sequence_length = self.params.parse_sequence_length
name_to_features = {
self.input_word_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64),
"input_mask":
tf.io.FixedLenFeature([sequence_length], tf.int64),
self.input_type_ids_field:
tf.io.FixedLenFeature([sequence_length], tf.int64)
}
features = tf.io.parse_example(inputs, name_to_features)
features = features_to_int32(features)
return self.serve(
input_word_ids=features[self.input_word_ids_field],
input_mask=features["input_mask"],
input_type_ids=features[self.input_type_ids_field])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve", "serve_examples")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve":
signatures[signature_key] = self.serve.get_concrete_function(
input_word_ids=tf.TensorSpec(
shape=[None, None],
dtype=tf.int32,
name=self.input_word_ids_field),
input_mask=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="input_mask"),
input_type_ids=tf.TensorSpec(
shape=[None, None],
dtype=tf.int32,
name=self.input_type_ids_field))
if func_key == "serve_examples":
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
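A hedged sketch of driving one of these modules end to end, using Tagging and assuming `model` is a tagging model built by the task (this mirrors the tests below):
params = Tagging.Params(parse_sequence_length=10)
module = Tagging(params=params, model=model)
signatures = module.get_inference_signatures({"serve": "serving_default"})
serving_fn = signatures["serving_default"]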
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.serving_modules."""
import os
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import serving_modules
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
def _create_fake_serialized_examples(features_dict):
"""Creates a fake dataset."""
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_str_feature(value):
f = tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
return f
examples = []
for _ in range(10):
features = {}
for key, values in features_dict.items():
if isinstance(values, bytes):
features[key] = create_str_feature(values)
else:
features[key] = create_int_feature(values)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
examples.append(tf_example.SerializeToString())
return tf.constant(examples)
def _create_fake_vocab_file(vocab_file_path):
tokens = ["[PAD]"]
for i in range(1, 100):
tokens.append("[unused%d]" % i)
tokens.extend(["[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world"])
with tf.io.gfile.GFile(vocab_file_path, "w") as outfile:
outfile.write("\n".join(tokens))
class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
# use_v2_feature_names
True,
False)
def test_sentence_prediction(self, use_v2_feature_names):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
config = sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
num_classes=2))
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
params = serving_modules.SentencePrediction.Params(
inputs_only=True,
parse_sequence_length=10,
use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.SentencePrediction(
params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
self.assertSameElements(functions.keys(),
["serving_default", "serving_examples"])
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
outputs = functions["serving_default"](dummy_ids)
self.assertEqual(outputs["outputs"].shape, (10, 2))
params = serving_modules.SentencePrediction.Params(
inputs_only=False,
parse_sequence_length=10,
use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.SentencePrediction(
params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
self.assertEqual(outputs["outputs"].shape, (10, 2))
dummy_ids = tf.ones((10,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["outputs"].shape, (10, 2))
with self.assertRaises(ValueError):
_ = export_module.get_inference_signatures({"foo": None})
@parameterized.parameters(
# inputs_only
True,
False)
def test_sentence_prediction_text(self, inputs_only):
vocab_file_path = os.path.join(self.get_temp_dir(), "vocab.txt")
_create_fake_vocab_file(vocab_file_path)
config = sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
num_classes=2))
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
params = serving_modules.SentencePrediction.Params(
inputs_only=inputs_only,
parse_sequence_length=10,
text_fields=["foo", "bar"],
vocab_file=vocab_file_path)
export_module = serving_modules.SentencePrediction(
params=params, model=model)
examples = _create_fake_serialized_examples({
"foo": b"hello world",
"bar": b"hello world"
})
functions = export_module.get_inference_signatures({
"serve_text_examples": "serving_default",
})
outputs = functions["serving_default"](examples)
self.assertEqual(outputs["outputs"].shape, (10, 2))
@parameterized.parameters(
# use_v2_feature_names
True,
False)
def test_masked_lm(self, use_v2_feature_names):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
config = masked_lm.MaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
]))
task = masked_lm.MaskedLMTask(config)
model = task.build_model()
params = serving_modules.MaskedLM.Params(
parse_sequence_length=10,
max_predictions_per_seq=5,
use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.MaskedLM(params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
self.assertSameElements(functions.keys(),
["serving_default", "serving_examples"])
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
dummy_pos = tf.ones((10, 5), dtype=tf.int32)
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_pos)
self.assertEqual(outputs["classification"].shape, (10, 2))
dummy_ids = tf.ones((10,), dtype=tf.int32)
dummy_pos = tf.ones((5,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids,
"masked_lm_positions": dummy_pos
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["classification"].shape, (10, 2))
@parameterized.parameters(
# use_v2_feature_names
True,
False)
def test_question_answering(self, use_v2_feature_names):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
config = question_answering.QuestionAnsweringConfig(
model=question_answering.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1))),
validation_data=None)
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
params = serving_modules.QuestionAnswering.Params(
parse_sequence_length=10, use_v2_feature_names=use_v2_feature_names)
export_module = serving_modules.QuestionAnswering(
params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
self.assertSameElements(functions.keys(),
["serving_default", "serving_examples"])
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
self.assertEqual(outputs["start_logits"].shape, (10, 10))
self.assertEqual(outputs["end_logits"].shape, (10, 10))
dummy_ids = tf.ones((10,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["start_logits"].shape, (10, 10))
self.assertEqual(outputs["end_logits"].shape, (10, 10))
@parameterized.parameters(
# (use_v2_feature_names, output_encoder_outputs)
(True, True),
(False, False))
def test_tagging(self, use_v2_feature_names, output_encoder_outputs):
if use_v2_feature_names:
input_word_ids_field = "input_word_ids"
input_type_ids_field = "input_type_ids"
else:
input_word_ids_field = "input_ids"
input_type_ids_field = "segment_ids"
hidden_size = 768
num_classes = 3
config = tagging.TaggingConfig(
model=tagging.ModelConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(
hidden_size=hidden_size, num_layers=1))),
class_names=["class_0", "class_1", "class_2"])
task = tagging.TaggingTask(config)
model = task.build_model()
params = serving_modules.Tagging.Params(
parse_sequence_length=10,
use_v2_feature_names=use_v2_feature_names,
output_encoder_outputs=output_encoder_outputs)
export_module = serving_modules.Tagging(params=params, model=model)
functions = export_module.get_inference_signatures({
"serve": "serving_default",
"serve_examples": "serving_examples"
})
dummy_ids = tf.ones((10, 10), dtype=tf.int32)
outputs = functions["serving_default"](
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
self.assertEqual(outputs["logits"].shape, (10, 10, num_classes))
if output_encoder_outputs:
self.assertEqual(outputs["encoder_outputs"].shape, (10, 10, hidden_size))
dummy_ids = tf.ones((10,), dtype=tf.int32)
examples = _create_fake_serialized_examples({
input_word_ids_field: dummy_ids,
"input_mask": dummy_ids,
input_type_ids_field: dummy_ids
})
outputs = functions["serving_examples"](examples)
self.assertEqual(outputs["logits"].shape, (10, 10, num_classes))
if output_encoder_outputs:
self.assertEqual(outputs["encoder_outputs"].shape, (10, 10, hidden_size))
with self.assertRaises(ValueError):
_ = export_module.get_inference_signatures({"foo": None})
if __name__ == "__main__":
tf.test.main()
......@@ -63,12 +63,14 @@ def convert_to_feature(value, value_type=None):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
elif value_type == 'int64_list':
value = np.asarray(value).astype(np.int64).reshape(-1)
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
elif value_type == 'float':
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
elif value_type == 'float_list':
value = np.asarray(value).astype(np.float32).reshape(-1)
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
elif value_type == 'bytes':
......@@ -172,4 +174,3 @@ def check_and_make_dir(directory):
"""Creates the directory if it doesn't exist."""
if not tf.io.gfile.isdir(directory):
tf.io.gfile.makedirs(directory)
......@@ -320,6 +320,9 @@ class SpineNetMobile(tf.keras.Model):
endpoints = {}
for i, block_spec in enumerate(self._block_specs):
# Update block level if it is larger than max_level to avoid building
# blocks smaller than requested.
block_spec.level = min(block_spec.level, self._max_level)
# Find out specs for the target block.
target_width = int(math.ceil(input_width / 2**block_spec.level))
target_num_filters = int(FILTER_SIZE_MAP[block_spec.level] *
......@@ -392,8 +395,9 @@ class SpineNetMobile(tf.keras.Model):
block_spec.level))
if (block_spec.level < self._min_level or
block_spec.level > self._max_level):
raise ValueError('Output level is out of range [{}, {}]'.format(
self._min_level, self._max_level))
logging.warning(
'SpineNet output level is out of range [min_level, max_level] = [%s, %s]; '
'it will not be used for further processing.',
self._min_level, self._max_level)
endpoints[str(block_spec.level)] = x
return endpoints
......
......@@ -1130,11 +1130,18 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
feature_extractor_config.use_separable_conv or
feature_extractor_config.type == 'mobilenet_v2_fpn_sep_conv')
kwargs = {
'channel_means': list(feature_extractor_config.channel_means),
'channel_stds': list(feature_extractor_config.channel_stds),
'bgr_ordering': feature_extractor_config.bgr_ordering,
'depth_multiplier': feature_extractor_config.depth_multiplier,
'use_separable_conv': use_separable_conv,
'channel_means':
list(feature_extractor_config.channel_means),
'channel_stds':
list(feature_extractor_config.channel_stds),
'bgr_ordering':
feature_extractor_config.bgr_ordering,
'depth_multiplier':
feature_extractor_config.depth_multiplier,
'use_separable_conv':
use_separable_conv,
'upsampling_interpolation':
feature_extractor_config.upsampling_interpolation,
}
......
......@@ -398,7 +398,7 @@ class ModelBuilderTF2Test(
}
"""
# Set up the configuration proto.
config = text_format.Merge(proto_txt, model_pb2.DetectionModel())
config = text_format.Parse(proto_txt, model_pb2.DetectionModel())
# Only add object center and keypoint estimation configs here.
config.center_net.object_center_params.CopyFrom(
self.get_fake_object_center_from_keypoints_proto())
......@@ -422,6 +422,50 @@ class ModelBuilderTF2Test(
self.assertEqual(kp_params.keypoint_labels,
['nose', 'left_shoulder', 'right_shoulder', 'hip'])
def test_create_center_net_model_mobilenet(self):
"""Test building a CenterNet model using bilinear interpolation."""
proto_txt = """
center_net {
num_classes: 10
feature_extractor {
type: "mobilenet_v2_fpn"
depth_multiplier: 1.0
use_separable_conv: true
upsampling_interpolation: "bilinear"
}
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 512
max_dimension: 512
pad_to_max_dimension: true
}
}
}
"""
# Set up the configuration proto.
config = text_format.Parse(proto_txt, model_pb2.DetectionModel())
# Only add object center and keypoint estimation configs here.
config.center_net.object_center_params.CopyFrom(
self.get_fake_object_center_from_keypoints_proto())
config.center_net.keypoint_estimation_task.append(
self.get_fake_keypoint_proto())
config.center_net.keypoint_label_map_path = (
self.get_fake_label_map_file_path())
# Build the model from the configuration.
model = model_builder.build(config, is_training=True)
feature_extractor = model._feature_extractor
# Verify the upsampling layers in the FPN use 'bilinear' interpolation.
fpn = feature_extractor.get_layer('model_1')
num_up_sampling2d_layers = 0
for layer in fpn.layers:
if 'up_sampling2d' in layer.name:
num_up_sampling2d_layers += 1
self.assertEqual('bilinear', layer.interpolation)
# Verify that there are up_sampling2d layers.
self.assertGreater(num_up_sampling2d_layers, 0)
if __name__ == '__main__':
tf.test.main()
......@@ -1776,6 +1776,7 @@ def random_pad_image(image,
min_image_size=None,
max_image_size=None,
pad_color=None,
center_pad=False,
seed=None,
preprocess_vars_cache=None):
"""Randomly pads the image.
......@@ -1814,6 +1815,8 @@ def random_pad_image(image,
pad_color: padding color. A rank 1 tensor of [channels] with dtype=
tf.float32. if set as None, it will be set to average color of
the input image.
center_pad: whether to pad the original image to the center, instead of
padding it randomly (the default).
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
......@@ -1870,6 +1873,12 @@ def random_pad_image(image,
lambda: _random_integer(0, target_width - image_width, seed),
lambda: tf.constant(0, dtype=tf.int32))
if center_pad:
offset_height = tf.cast(tf.floor((target_height - image_height) / 2),
tf.int32)
offset_width = tf.cast(tf.floor((target_width - image_width) / 2),
tf.int32)
gen_func = lambda: (target_height, target_width, offset_height, offset_width)
params = _get_or_create_preprocess_rand_vars(
gen_func, preprocessor_cache.PreprocessorCache.PAD_IMAGE,
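For intuition, a hedged worked example of the center-pad offsets when padding a 200x400 image to 400x400:
import tensorflow as tf
target_height, image_height = 400, 200
target_width, image_width = 400, 400
offset_height = tf.cast(tf.floor((target_height - image_height) / 2), tf.int32)
offset_width = tf.cast(tf.floor((target_width - image_width) / 2), tf.int32)
# offset_height == 100 and offset_width == 0: the image ends up vertically
# centered in the padded canvas.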
......@@ -2113,7 +2122,7 @@ def random_crop_pad_image(image,
max_padded_size_ratio,
dtype=tf.int32)
padded_image, padded_boxes = random_pad_image(
padded_image, padded_boxes = random_pad_image( # pylint: disable=unbalanced-tuple-unpacking
cropped_image,
cropped_boxes,
min_image_size=min_image_size,
......@@ -2153,6 +2162,7 @@ def random_crop_to_aspect_ratio(image,
aspect_ratio=1.0,
overlap_thresh=0.3,
clip_boxes=True,
center_crop=False,
seed=None,
preprocess_vars_cache=None):
"""Randomly crops an image to the specified aspect ratio.
......@@ -2191,6 +2201,7 @@ def random_crop_to_aspect_ratio(image,
overlap_thresh: minimum overlap thresh with new cropped
image to keep the box.
clip_boxes: whether to clip the boxes to the cropped image.
center_crop: whether to take the center crop or a random crop.
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
......@@ -2247,6 +2258,12 @@ def random_crop_to_aspect_ratio(image,
# either offset_height = 0 and offset_width is randomly chosen from
# [0, orig_width - target_width), or else offset_width = 0 and
# offset_height is randomly chosen from [0, orig_height - target_height).
if center_crop:
offset_height = tf.cast(tf.math.floor((orig_height - target_height) / 2),
tf.int32)
offset_width = tf.cast(tf.math.floor((orig_width - target_width) / 2),
tf.int32)
else:
offset_height = _random_integer(0, orig_height - target_height + 1, seed)
offset_width = _random_integer(0, orig_width - target_width + 1, seed)
......@@ -2979,7 +2996,7 @@ def resize_to_range(image,
'per-channel pad value.')
new_image = tf.stack(
[
tf.pad(
tf.pad( # pylint: disable=g-complex-comprehension
channels[i], [[0, max_dimension - new_size[0]],
[0, max_dimension - new_size[1]]],
constant_values=per_channel_pad_value[i])
......
......@@ -2194,6 +2194,54 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
expected_boxes.flatten())
self.assertAllEqual(distorted_masks_.shape, [1, 200, 200])
def testRunRandomCropToAspectRatioCenterCrop(self):
def graph_fn():
image = self.createColorfulTestImage()
boxes = self.createTestBoxes()
labels = self.createTestLabels()
weights = self.createTestGroundtruthWeights()
masks = tf.random_uniform([2, 200, 400], dtype=tf.float32)
tensor_dict = {
fields.InputDataFields.image: image,
fields.InputDataFields.groundtruth_boxes: boxes,
fields.InputDataFields.groundtruth_classes: labels,
fields.InputDataFields.groundtruth_weights: weights,
fields.InputDataFields.groundtruth_instance_masks: masks
}
preprocessor_arg_map = preprocessor.get_default_func_arg_map(
include_instance_masks=True)
preprocessing_options = [(preprocessor.random_crop_to_aspect_ratio, {
'center_crop': True
})]
with mock.patch.object(preprocessor,
'_random_integer') as mock_random_integer:
mock_random_integer.return_value = tf.constant(0, dtype=tf.int32)
distorted_tensor_dict = preprocessor.preprocess(
tensor_dict,
preprocessing_options,
func_arg_map=preprocessor_arg_map)
distorted_image = distorted_tensor_dict[fields.InputDataFields.image]
distorted_boxes = distorted_tensor_dict[
fields.InputDataFields.groundtruth_boxes]
distorted_labels = distorted_tensor_dict[
fields.InputDataFields.groundtruth_classes]
return [
distorted_image, distorted_boxes, distorted_labels
]
(distorted_image_, distorted_boxes_, distorted_labels_) = self.execute_cpu(
graph_fn, [])
expected_boxes = np.array([[0.0, 0.0, 0.75, 1.0],
[0.25, 0.5, 0.75, 1.0]], dtype=np.float32)
self.assertAllEqual(distorted_image_.shape, [1, 200, 200, 3])
self.assertAllEqual(distorted_labels_, [1, 2])
self.assertAllClose(distorted_boxes_.flatten(),
expected_boxes.flatten())
def testRunRandomCropToAspectRatioWithKeypoints(self):
def graph_fn():
image = self.createColorfulTestImage()
......@@ -2433,6 +2481,51 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
self.assertTrue(np.all((boxes_[:, 3] - boxes_[:, 1]) >= (
padded_boxes_[:, 3] - padded_boxes_[:, 1])))
def testRandomPadImageCenterPad(self):
def graph_fn():
preprocessing_options = [(preprocessor.normalize_image, {
'original_minval': 0,
'original_maxval': 255,
'target_minval': 0,
'target_maxval': 1
})]
images = self.createColorfulTestImage()
boxes = self.createTestBoxes()
labels = self.createTestLabels()
tensor_dict = {
fields.InputDataFields.image: images,
fields.InputDataFields.groundtruth_boxes: boxes,
fields.InputDataFields.groundtruth_classes: labels,
}
tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options)
images = tensor_dict[fields.InputDataFields.image]
preprocessing_options = [(preprocessor.random_pad_image, {
'center_pad': True,
'min_image_size': [400, 400],
'max_image_size': [400, 400],
})]
padded_tensor_dict = preprocessor.preprocess(tensor_dict,
preprocessing_options)
padded_images = padded_tensor_dict[fields.InputDataFields.image]
padded_boxes = padded_tensor_dict[
fields.InputDataFields.groundtruth_boxes]
padded_labels = padded_tensor_dict[
fields.InputDataFields.groundtruth_classes]
return [padded_images, padded_boxes, padded_labels]
(padded_images_, padded_boxes_, padded_labels_) = self.execute_cpu(
graph_fn, [])
expected_boxes = np.array([[0.25, 0.25, 0.625, 1.0],
[0.375, 0.5, .625, 1.0]], dtype=np.float32)
self.assertAllEqual(padded_images_.shape, [1, 400, 400, 3])
self.assertAllEqual(padded_labels_, [1, 2])
self.assertAllClose(padded_boxes_.flatten(),
expected_boxes.flatten())
@parameterized.parameters(
{'include_dense_pose': False},
)
......
......@@ -117,23 +117,9 @@ class CenterNetFeatureExtractor(tf.keras.Model):
pass
@property
@abc.abstractmethod
def supported_sub_model_types(self):
"""Valid sub model types supported by the get_sub_model function."""
pass
@abc.abstractmethod
def get_sub_model(self, sub_model_type):
"""Returns the underlying keras model for the given sub_model_type.
This function is useful when we only want to get a subset of weights to
be restored from a checkpoint.
Args:
sub_model_type: string, the type of sub model. Currently, CenterNet
feature extractors support 'detection' and 'classification'.
"""
pass
def classification_backbone(self):
raise NotImplementedError(
'Classification backbone not supported for {}'.format(type(self)))
def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
......@@ -4200,25 +4186,28 @@ class CenterNetMetaArch(model.DetectionModel):
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
supported_types = self._feature_extractor.supported_sub_model_types
supported_types += ['fine_tune']
if fine_tune_checkpoint_type not in supported_types:
message = ('Checkpoint type "{}" not supported for {}. '
'Supported types are {}')
raise ValueError(
message.format(fine_tune_checkpoint_type,
self._feature_extractor.__class__.__name__,
supported_types))
elif fine_tune_checkpoint_type == 'fine_tune':
if fine_tune_checkpoint_type == 'detection':
feature_extractor_model = tf.train.Checkpoint(
_feature_extractor=self._feature_extractor)
return {'model': feature_extractor_model}
elif fine_tune_checkpoint_type == 'classification':
return {
'feature_extractor':
self._feature_extractor.classification_backbone
}
elif fine_tune_checkpoint_type == 'full':
return {'model': self}
elif fine_tune_checkpoint_type == 'fine_tune':
raise ValueError(('"fine_tune" is no longer supported for CenterNet. '
'Please set fine_tune_checkpoint_type to "detection"'
' which has the same functionality. If you are using'
' the ExtremeNet checkpoint, download the new version'
' from the model zoo.'))
else:
return {'feature_extractor': self._feature_extractor.get_sub_model(
fine_tune_checkpoint_type)}
raise ValueError('Unknown fine tune checkpoint type {}'.format(
fine_tune_checkpoint_type))
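A hedged usage sketch of the resulting restore map (the checkpoint path is hypothetical):
restore_map = model.restore_from_objects('detection')
ckpt = tf.train.Checkpoint(**restore_map)
ckpt.restore('/tmp/ckpt-1').expect_partial()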
def updates(self):
if tf_version.is_tf2():
......
......@@ -17,7 +17,6 @@
from __future__ import division
import functools
import re
import unittest
from absl.testing import parameterized
......@@ -2887,15 +2886,14 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
self.assertIsInstance(restore_from_objects_map['feature_extractor'],
tf.keras.Model)
def test_restore_map_error(self):
"""Test that restoring unsupported checkpoint type raises an error."""
def test_restore_map_detection(self):
"""Test that detection checkpoints can be restored."""
model = build_center_net_meta_arch(build_resnet=True)
msg = ("Checkpoint type \"detection\" not supported for "
"CenterNetResnetFeatureExtractor. Supported types are "
"['classification', 'fine_tune']")
with self.assertRaisesRegex(ValueError, re.escape(msg)):
model.restore_from_objects('detection')
restore_from_objects_map = model.restore_from_objects('detection')
self.assertIsInstance(restore_from_objects_map['model']._feature_extractor,
tf.keras.Model)
class DummyFeatureExtractor(cnma.CenterNetFeatureExtractor):
......
......@@ -2896,6 +2896,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
_feature_extractor_for_proposal_features=
self._feature_extractor_for_proposal_features)
return {'model': fake_model}
elif fine_tune_checkpoint_type == 'full':
return {'model': self}
else:
raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
fine_tune_checkpoint_type))
......
......@@ -35,6 +35,7 @@ from object_detection.protos import train_pb2
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import ops
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vutils
......@@ -587,6 +588,9 @@ def train_loop(
lambda: global_step % num_steps_per_iteration == 0):
# Load a fine-tuning checkpoint.
if train_config.fine_tune_checkpoint:
variables_helper.ensure_checkpoint_supported(
train_config.fine_tune_checkpoint, fine_tune_checkpoint_type,
model_dir)
load_fine_tune_checkpoint(
detection_model, train_config.fine_tune_checkpoint,
fine_tune_checkpoint_type, fine_tune_checkpoint_version,
......
......@@ -62,16 +62,6 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
return self._network.num_hourglasses
@property
def supported_sub_model_types(self):
return ['detection']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network
else:
raise ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def hourglass_10(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-10 backbone for CenterNet."""
......
......@@ -83,9 +83,6 @@ class CenterNetMobileNetV2FeatureExtractor(
def load_feature_extractor_weights(self, path):
self._network.load_weights(path)
def get_base_model(self):
return self._network
def call(self, inputs):
return [self._network(inputs)]
......@@ -100,14 +97,8 @@ class CenterNetMobileNetV2FeatureExtractor(
return 1
@property
def supported_sub_model_types(self):
return ['detection']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
def classification_backbone(self):
return self._network
else:
raise ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering,
......
......@@ -39,7 +39,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel_means=(0., 0., 0.),
channel_stds=(1., 1., 1.),
bgr_ordering=False,
use_separable_conv=False):
use_separable_conv=False,
upsampling_interpolation='nearest'):
"""Intializes the feature extractor.
Args:
......@@ -52,6 +53,9 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
[blue, red, green] order.
use_separable_conv: If set to True, all convolutional layers in the FPN
network will be replaced by separable convolutions.
upsampling_interpolation: A string (one of 'nearest' or 'bilinear')
indicating which interpolation method to use for the upsampling ops in
the FPN.
"""
super(CenterNetMobileNetV2FPNFeatureExtractor, self).__init__(
......@@ -84,7 +88,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
for i, num_filters in enumerate(num_filters_list):
level_ind = len(num_filters_list) - 1 - i
# Upsample.
upsample_op = tf.keras.layers.UpSampling2D(2, interpolation='nearest')
upsample_op = tf.keras.layers.UpSampling2D(
2, interpolation=upsampling_interpolation)
top_down = upsample_op(top_down)
# Residual (skip-connection) from bottom-up pathway.
......@@ -144,7 +149,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
use_separable_conv=False, depth_multiplier=1.0, **kwargs):
use_separable_conv=False, depth_multiplier=1.0,
upsampling_interpolation='nearest', **kwargs):
"""The MobileNetV2+FPN backbone for CenterNet."""
del kwargs
......@@ -159,4 +165,5 @@ def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering,
use_separable_conv=use_separable_conv)
use_separable_conv=use_separable_conv,
upsampling_interpolation=upsampling_interpolation)
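For reference, a hedged construction sketch showing the new option reaching the FPN's UpSampling2D layers:
extractor = mobilenet_v2_fpn(
    channel_means=(0., 0., 0.),
    channel_stds=(1., 1., 1.),
    bgr_ordering=False,
    upsampling_interpolation='bilinear')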