ModelZoo / ResNet50_tensorflow / Commit 319589aa

Commit 319589aa authored May 07, 2021 by vedanshu

Merge branch 'master' of https://github.com/tensorflow/models

Parents: 64f323b1, eaeea071

Showing 20 changed files with 1289 additions and 81 deletions (+1289, -81)
official/core/base_trainer.py (+7, -1)
official/nlp/serving/export_savedmodel.py (+141, -0)
official/nlp/serving/export_savedmodel_test.py (+164, -0)
official/nlp/serving/export_savedmodel_util.py (+45, -0)
official/nlp/serving/serving_modules.py (+391, -0)
official/nlp/serving/serving_modules_test.py (+317, -0)
official/vision/beta/data/tfrecord_lib.py (+2, -1)
official/vision/beta/modeling/backbones/spinenet_mobile.py (+6, -2)
official/vision/beta/projects/keypoint/README.md (+0, -3)
research/object_detection/builders/model_builder.py (+12, -5)
research/object_detection/builders/model_builder_tf2_test.py (+45, -1)
research/object_detection/core/preprocessor.py (+21, -4)
research/object_detection/core/preprocessor_test.py (+93, -0)
research/object_detection/meta_architectures/center_net_meta_arch.py (+20, -31)
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py (+6, -8)
research/object_detection/meta_architectures/faster_rcnn_meta_arch.py (+2, -0)
research/object_detection/model_lib_v2.py (+4, -0)
research/object_detection/models/center_net_hourglass_feature_extractor.py (+0, -10)
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py (+2, -11)
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py (+11, -4)
official/core/base_trainer.py  View file @ 319589aa

...
@@ -370,7 +370,13 @@ class Trainer(_AsyncTrainer):
       logs[metric.name] = metric.result()
       metric.reset_states()
     if callable(self.optimizer.learning_rate):
-      logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
+      # Maybe a self-implemented optimizer does not have `optimizer.iterations`.
+      # So just to be safe here.
+      if hasattr(self.optimizer, "iterations"):
+        logs["learning_rate"] = self.optimizer.learning_rate(
+            self.optimizer.iterations)
+      else:
+        logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
     else:
       logs["learning_rate"] = self.optimizer.learning_rate
     return logs
...
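The guard added above matters because stock Keras optimizers expose an `iterations` counter while a hand-rolled optimizer may not. A minimal sketch of both cases (the `SimpleSGD` class below is a hypothetical stand-in, not part of this commit):

import tensorflow as tf

# A stock Keras optimizer exposes `iterations`, so a callable learning-rate
# schedule can be evaluated at the optimizer's own step counter.
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.9)
opt = tf.keras.optimizers.SGD(learning_rate=schedule)
print(float(opt.learning_rate(opt.iterations)))  # 0.1 at step 0

# A self-implemented optimizer may lack `iterations`; the trainer then falls
# back to its own global step, which is exactly what the hasattr() selects.
class SimpleSGD:  # hypothetical minimal optimizer without `iterations`
  learning_rate = schedule

global_step = tf.Variable(0, dtype=tf.int64)
opt2 = SimpleSGD()
step = opt2.iterations if hasattr(opt2, "iterations") else global_step
print(float(opt2.learning_rate(step)))  # evaluated at global_step instead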
official/nlp/serving/export_savedmodel.py  0 → 100644  View file @ 319589aa
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A binary/library to export TF-NLP serving `SavedModel`."""
import os
from typing import Any, Dict, Text

from absl import app
from absl import flags
import dataclasses
import yaml

from official.core import base_task
from official.core import task_factory
from official.modeling import hyperparams
from official.modeling.hyperparams import base_config
from official.nlp.serving import export_savedmodel_util
from official.nlp.serving import serving_modules
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging

FLAGS = flags.FLAGS

SERVING_MODULES = {
    sentence_prediction.SentencePredictionTask:
        serving_modules.SentencePrediction,
    masked_lm.MaskedLMTask: serving_modules.MaskedLM,
    question_answering.QuestionAnsweringTask: serving_modules.QuestionAnswering,
    tagging.TaggingTask: serving_modules.Tagging
}


def define_flags():
  """Defines flags."""
  flags.DEFINE_string("task_name", "SentencePrediction", "The task to export.")
  flags.DEFINE_string("config_file", None,
                      "The path to task/experiment yaml config file.")
  flags.DEFINE_string(
      "checkpoint_path", None,
      "Object-based checkpoint path, from the training model directory.")
  flags.DEFINE_string("export_savedmodel_dir", None,
                      "Output saved model directory.")
  flags.DEFINE_string(
      "serving_params", None,
      "A YAML/JSON string or csv string for the serving parameters.")
  flags.DEFINE_string(
      "function_keys", None,
      "A string key to retrieve pre-defined serving signatures.")
  flags.DEFINE_bool("convert_tpu", False,
                    "Whether to also convert the exported model for TPU.")
  flags.DEFINE_multi_integer("allowed_batch_size", None,
                             "Allowed batch sizes for batching ops.")


def lookup_export_module(task: base_task.Task):
  export_module_cls = SERVING_MODULES.get(task.__class__, None)
  if export_module_cls is None:
    raise ValueError("No registered export module for the task: %s" %
                     task.__class__)
  return export_module_cls


def create_export_module(*, task_name: Text, config_file: Text,
                         serving_params: Dict[Text, Any]):
  """Creates an ExportModule."""
  task_config_cls = None
  task_cls = None
  # pylint: disable=protected-access
  for key, value in task_factory._REGISTERED_TASK_CLS.items():
    print(key.__name__)
    if task_name in key.__name__:
      task_config_cls, task_cls = key, value
      break
  if task_cls is None:
    raise ValueError("Failed to identify the task class. The provided task "
                     f"name is {task_name}")
  # pylint: enable=protected-access
  # TODO(hongkuny): Figure out how to separate the task config from experiments.

  @dataclasses.dataclass
  class Dummy(base_config.Config):
    task: task_config_cls = task_config_cls()

  dummy_exp = Dummy()
  dummy_exp = hyperparams.override_params_dict(
      dummy_exp, config_file, is_strict=False)
  dummy_exp.task.validation_data = None
  task = task_cls(dummy_exp.task)
  model = task.build_model()
  export_module_cls = lookup_export_module(task)
  params = export_module_cls.Params(**serving_params)
  return export_module_cls(params=params, model=model)


def main(_):
  serving_params = yaml.load(
      hyperparams.nested_csv_str_to_json_str(FLAGS.serving_params),
      Loader=yaml.FullLoader)
  export_module = create_export_module(
      task_name=FLAGS.task_name,
      config_file=FLAGS.config_file,
      serving_params=serving_params)
  export_dir = export_savedmodel_util.export(
      export_module,
      function_keys=[FLAGS.function_keys],
      checkpoint_path=FLAGS.checkpoint_path,
      export_savedmodel_dir=FLAGS.export_savedmodel_dir)

  if FLAGS.convert_tpu:
    # pylint: disable=g-import-not-at-top
    from cloud_tpu.inference_converter import converter_cli
    from cloud_tpu.inference_converter import converter_options_pb2

    tpu_dir = os.path.join(export_dir, "tpu")
    options = converter_options_pb2.ConverterOptions()
    if FLAGS.allowed_batch_size is not None:
      allowed_batch_sizes = sorted(FLAGS.allowed_batch_size)
      options.batch_options.num_batch_threads = 4
      options.batch_options.max_batch_size = allowed_batch_sizes[-1]
      options.batch_options.batch_timeout_micros = 100000
      options.batch_options.allowed_batch_sizes[:] = allowed_batch_sizes
      options.batch_options.max_enqueued_batches = 1000
    converter_cli.ConvertSavedModel(
        export_dir,
        tpu_dir,
        function_alias="tpu_candidate",
        options=options,
        graph_rewrite_only=True)


if __name__ == "__main__":
  define_flags()
  app.run(main)
official/nlp/serving/export_savedmodel_test.py  0 → 100644  View file @ 319589aa
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.export_saved_model."""
from absl.testing import parameterized
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import export_savedmodel
from official.nlp.serving import export_savedmodel_util
from official.nlp.tasks import masked_lm
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging


class ExportSavedModelTest(tf.test.TestCase, parameterized.TestCase):

  def test_create_export_module(self):
    export_module = export_savedmodel.create_export_module(
        task_name="SentencePrediction",
        config_file=None,
        serving_params={
            "inputs_only": False,
            "parse_sequence_length": 10
        })
    self.assertEqual(export_module.name, "sentence_prediction")
    self.assertFalse(export_module.params.inputs_only)
    self.assertEqual(export_module.params.parse_sequence_length, 10)

  def test_sentence_prediction(self):
    config = sentence_prediction.SentencePredictionConfig(
        model=sentence_prediction.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            num_classes=2))
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_path = ckpt.save(self.get_temp_dir())
    export_module_cls = export_savedmodel.lookup_export_module(task)
    serving_params = {"inputs_only": False}
    params = export_module_cls.Params(**serving_params)
    export_module = export_module_cls(params=params, model=model)
    export_dir = export_savedmodel_util.export(
        export_module,
        function_keys=["serve"],
        checkpoint_path=ckpt_path,
        export_savedmodel_dir=self.get_temp_dir())
    imported = tf.saved_model.load(export_dir)
    serving_fn = imported.signatures["serving_default"]

    dummy_ids = tf.ones((1, 5), dtype=tf.int32)
    inputs = dict(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    ref_outputs = model(inputs)
    outputs = serving_fn(**inputs)
    self.assertAllClose(ref_outputs, outputs["outputs"])
    self.assertEqual(outputs["outputs"].shape, (1, 2))

  def test_masked_lm(self):
    config = masked_lm.MaskedLMConfig(
        model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            cls_heads=[
                bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="foo")
            ]))
    task = masked_lm.MaskedLMTask(config)
    model = task.build_model()
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_path = ckpt.save(self.get_temp_dir())
    export_module_cls = export_savedmodel.lookup_export_module(task)
    serving_params = {
        "cls_head_name": "foo",
        "parse_sequence_length": 10,
        "max_predictions_per_seq": 5
    }
    params = export_module_cls.Params(**serving_params)
    export_module = export_module_cls(params=params, model=model)
    export_dir = export_savedmodel_util.export(
        export_module,
        function_keys={
            "serve": "serving_default",
            "serve_examples": "serving_examples"
        },
        checkpoint_path=ckpt_path,
        export_savedmodel_dir=self.get_temp_dir())
    imported = tf.saved_model.load(export_dir)
    self.assertSameElements(imported.signatures.keys(),
                            ["serving_default", "serving_examples"])
    serving_fn = imported.signatures["serving_default"]
    dummy_ids = tf.ones((1, 10), dtype=tf.int32)
    dummy_pos = tf.ones((1, 5), dtype=tf.int32)
    outputs = serving_fn(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids,
        masked_lm_positions=dummy_pos)
    self.assertEqual(outputs["classification"].shape, (1, 2))

  @parameterized.parameters(True, False)
  def test_tagging(self, output_encoder_outputs):
    hidden_size = 768
    num_classes = 3
    config = tagging.TaggingConfig(
        model=tagging.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(hidden_size=hidden_size,
                                                num_layers=1))),
        class_names=["class_0", "class_1", "class_2"])
    task = tagging.TaggingTask(config)
    model = task.build_model()
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_path = ckpt.save(self.get_temp_dir())
    export_module_cls = export_savedmodel.lookup_export_module(task)
    serving_params = {
        "parse_sequence_length": 10,
    }
    params = export_module_cls.Params(
        **serving_params, output_encoder_outputs=output_encoder_outputs)
    export_module = export_module_cls(params=params, model=model)
    export_dir = export_savedmodel_util.export(
        export_module,
        function_keys={
            "serve": "serving_default",
            "serve_examples": "serving_examples"
        },
        checkpoint_path=ckpt_path,
        export_savedmodel_dir=self.get_temp_dir())
    imported = tf.saved_model.load(export_dir)
    self.assertCountEqual(imported.signatures.keys(),
                          ["serving_default", "serving_examples"])

    serving_fn = imported.signatures["serving_default"]
    dummy_ids = tf.ones((1, 5), dtype=tf.int32)
    inputs = dict(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    outputs = serving_fn(**inputs)
    self.assertEqual(outputs["logits"].shape, (1, 5, num_classes))
    if output_encoder_outputs:
      self.assertEqual(outputs["encoder_outputs"].shape, (1, 5, hidden_size))


if __name__ == "__main__":
  tf.test.main()
official/nlp/serving/export_savedmodel_util.py  0 → 100644  View file @ 319589aa
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common library to export a SavedModel from the export module."""
from typing import Dict, List, Optional, Text, Union

import tensorflow as tf

from official.core import export_base


def export(export_module: export_base.ExportModule,
           function_keys: Union[List[Text], Dict[Text, Text]],
           export_savedmodel_dir: Text,
           checkpoint_path: Optional[Text] = None,
           timestamped: bool = True) -> Text:
  """Exports to SavedModel format.

  Args:
    export_module: an ExportModule with the keras Model and serving
      tf.functions.
    function_keys: a list of string keys to retrieve pre-defined serving
      signatures. The signature keys will be set with defaults. If a dictionary
      is provided, the values will be used as signature keys.
    export_savedmodel_dir: Output saved model directory.
    checkpoint_path: Object-based checkpoint path or directory.
    timestamped: Whether to export the savedmodel to a timestamped directory.

  Returns:
    The savedmodel directory path.
  """
  save_options = tf.saved_model.SaveOptions(function_aliases={
      "tpu_candidate": export_module.serve,
  })
  return export_base.export(export_module, function_keys,
                            export_savedmodel_dir, checkpoint_path,
                            timestamped, save_options)
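The `tpu_candidate` function alias registered above is the name that the optional TPU conversion step in `export_savedmodel.py` later passes as `function_alias`. A minimal sketch of attaching such an alias at save time (the `Demo` module and the output path are made up for illustration, not part of the commit):

import tensorflow as tf

class Demo(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def serve(self, x):
    return {"outputs": x * 2.0}

demo = Demo()
options = tf.saved_model.SaveOptions(
    function_aliases={"tpu_candidate": demo.serve})
# Downstream tools can now look the function up by its alias instead of a
# signature key.
tf.saved_model.save(demo, "/tmp/alias_demo", options=options)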
official/nlp/serving/serving_modules.py  0 → 100644  View file @ 319589aa
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring
from typing import Dict, List, Optional, Text
import dataclasses
import tensorflow as tf

from official.core import export_base
from official.modeling.hyperparams import base_config
from official.nlp.data import sentence_prediction_dataloader


def features_to_int32(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Converts tf.int64 features to tf.int32, keeping other features the same.

  tf.Example only supports tf.int64, but the TPU only supports tf.int32.

  Args:
    features: Input tensor dictionary.

  Returns:
    Features with tf.int64 converted to tf.int32.
  """
  converted_features = {}
  for name, tensor in features.items():
    if tensor.dtype == tf.int64:
      converted_features[name] = tf.cast(tensor, tf.int32)
    else:
      converted_features[name] = tensor
  return converted_features


class SentencePrediction(export_base.ExportModule):
  """The export module for the sentence prediction task."""

  @dataclasses.dataclass
  class Params(base_config.Config):
    inputs_only: bool = True
    parse_sequence_length: Optional[int] = None
    use_v2_feature_names: bool = True

    # For text input processing.
    text_fields: Optional[List[str]] = None
    # Either specify these values for preprocessing by Python code...
    tokenization: str = "WordPiece"  # WordPiece or SentencePiece
    # Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
    # file if tokenization is SentencePiece.
    vocab_file: str = ""
    lower_case: bool = True
    # ...or load preprocessing from a SavedModel at this location.
    preprocessing_hub_module_url: str = ""

  def __init__(self, params, model: tf.keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

    if params.text_fields:
      self._text_processor = sentence_prediction_dataloader.TextProcessor(
          seq_length=params.parse_sequence_length,
          vocab_file=params.vocab_file,
          tokenization=params.tokenization,
          lower_case=params.lower_case,
          preprocessing_hub_module_url=params.preprocessing_hub_module_url)

  @tf.function
  def serve(self,
            input_word_ids,
            input_mask=None,
            input_type_ids=None) -> Dict[str, tf.Tensor]:
    if input_type_ids is None:
      # Requires CLS token is the first token of inputs.
      input_type_ids = tf.zeros_like(input_word_ids)
    if input_mask is None:
      # The mask has 1 for real tokens and 0 for padding tokens.
      input_mask = tf.where(
          tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
          tf.ones_like(input_word_ids))
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids)
    return dict(outputs=self.inference_step(inputs))

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    inputs_only = self.params.inputs_only
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
    }
    if not inputs_only:
      name_to_features.update({
          "input_mask":
              tf.io.FixedLenFeature([sequence_length], tf.int64),
          self.input_type_ids_field:
              tf.io.FixedLenFeature([sequence_length], tf.int64)
      })
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        features[self.input_word_ids_field],
        input_mask=None if inputs_only else features["input_mask"],
        input_type_ids=None
        if inputs_only else features[self.input_type_ids_field])

  @tf.function
  def serve_text_examples(self, inputs) -> Dict[str, tf.Tensor]:
    name_to_features = {}
    for text_field in self.params.text_fields:
      name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
    features = tf.io.parse_example(inputs, name_to_features)
    segments = [features[x] for x in self.params.text_fields]
    model_inputs = self._text_processor(segments)
    if self.params.inputs_only:
      return self.serve(input_word_ids=model_inputs["input_word_ids"])
    return self.serve(**model_inputs)

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples", "serve_text_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        if self.params.inputs_only:
          signatures[signature_key] = self.serve.get_concrete_function(
              input_word_ids=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_word_ids"))
        else:
          signatures[signature_key] = self.serve.get_concrete_function(
              input_word_ids=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_word_ids"),
              input_mask=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_mask"),
              input_type_ids=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_type_ids"))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
      if func_key == "serve_text_examples":
        signatures[
            signature_key] = self.serve_text_examples.get_concrete_function(
                tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures


class MaskedLM(export_base.ExportModule):
  """The export module for the Bert Pretrain (MaskedLM) task."""

  def __init__(self, params, model: tf.keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

  @dataclasses.dataclass
  class Params(base_config.Config):
    cls_head_name: str = "next_sentence"
    use_v2_feature_names: bool = True
    parse_sequence_length: Optional[int] = None
    max_predictions_per_seq: Optional[int] = None

  @tf.function
  def serve(self, input_word_ids, input_mask, input_type_ids,
            masked_lm_positions) -> Dict[str, tf.Tensor]:
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids,
        masked_lm_positions=masked_lm_positions)
    outputs = self.inference_step(inputs)
    return dict(classification=outputs[self.params.cls_head_name])

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    max_predictions_per_seq = self.params.max_predictions_per_seq
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "input_mask":
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        self.input_type_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "masked_lm_positions":
            tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64)
    }
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field],
        masked_lm_positions=features["masked_lm_positions"])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_word_ids"),
            input_mask=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_mask"),
            input_type_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_type_ids"),
            masked_lm_positions=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32,
                name="masked_lm_positions"))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures


class QuestionAnswering(export_base.ExportModule):
  """The export module for the question answering task."""

  @dataclasses.dataclass
  class Params(base_config.Config):
    parse_sequence_length: Optional[int] = None
    use_v2_feature_names: bool = True

  def __init__(self, params, model: tf.keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

  @tf.function
  def serve(self,
            input_word_ids,
            input_mask=None,
            input_type_ids=None) -> Dict[str, tf.Tensor]:
    if input_mask is None:
      # The mask has 1 for real tokens and 0 for padding tokens.
      input_mask = tf.where(
          tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
          tf.ones_like(input_word_ids))
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids)
    outputs = self.inference_step(inputs)
    return dict(start_logits=outputs[0], end_logits=outputs[1])

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "input_mask":
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        self.input_type_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64)
    }
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_word_ids"),
            input_mask=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_mask"),
            input_type_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_type_ids"))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures


class Tagging(export_base.ExportModule):
  """The export module for the tagging task."""

  @dataclasses.dataclass
  class Params(base_config.Config):
    parse_sequence_length: Optional[int] = None
    use_v2_feature_names: bool = True
    output_encoder_outputs: bool = False

  def __init__(self, params, model: tf.keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

  @tf.function
  def serve(self, input_word_ids, input_mask,
            input_type_ids) -> Dict[str, tf.Tensor]:
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids)
    outputs = self.inference_step(inputs)
    if self.params.output_encoder_outputs:
      return dict(
          logits=outputs["logits"], encoder_outputs=outputs["encoder_outputs"])
    else:
      return dict(logits=outputs["logits"])

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "input_mask":
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        self.input_type_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64)
    }
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32,
                name=self.input_word_ids_field),
            input_mask=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_mask"),
            input_type_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32,
                name=self.input_type_ids_field))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures
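Every `serve_examples` method above leans on `features_to_int32`: `tf.io.parse_example` can only yield `tf.int64` for integer features, while TPU serving wants `tf.int32`. A quick standalone check of that round trip (the token values are made up):

import tensorflow as tf

# tf.train only offers Int64List for integers, so parsing returns tf.int64.
example = tf.train.Example(features=tf.train.Features(feature={
    "input_word_ids": tf.train.Feature(
        int64_list=tf.train.Int64List(value=[101, 7592, 102, 0, 0])),
}))
serialized = tf.constant([example.SerializeToString()])
features = tf.io.parse_example(
    serialized, {"input_word_ids": tf.io.FixedLenFeature([5], tf.int64)})
print(features["input_word_ids"].dtype)  # tf.int64

# The same per-tensor cast that features_to_int32 performs:
converted = {name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
             for name, t in features.items()}
print(converted["input_word_ids"].dtype)  # tf.int32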
official/nlp/serving/serving_modules_test.py  0 → 100644  View file @ 319589aa
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.serving_modules."""
import os
from absl.testing import parameterized
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import serving_modules
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging


def _create_fake_serialized_examples(features_dict):
  """Creates a fake dataset."""

  def create_int_feature(values):
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f

  def create_str_feature(value):
    f = tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    return f

  examples = []
  for _ in range(10):
    features = {}
    for key, values in features_dict.items():
      if isinstance(values, bytes):
        features[key] = create_str_feature(values)
      else:
        features[key] = create_int_feature(values)
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    examples.append(tf_example.SerializeToString())
  return tf.constant(examples)


def _create_fake_vocab_file(vocab_file_path):
  tokens = ["[PAD]"]
  for i in range(1, 100):
    tokens.append("[unused%d]" % i)
  tokens.extend(["[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world"])
  with tf.io.gfile.GFile(vocab_file_path, "w") as outfile:
    outfile.write("\n".join(tokens))


class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      # use_v2_feature_names
      True,
      False)
  def test_sentence_prediction(self, use_v2_feature_names):
    if use_v2_feature_names:
      input_word_ids_field = "input_word_ids"
      input_type_ids_field = "input_type_ids"
    else:
      input_word_ids_field = "input_ids"
      input_type_ids_field = "segment_ids"

    config = sentence_prediction.SentencePredictionConfig(
        model=sentence_prediction.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            num_classes=2))
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    params = serving_modules.SentencePrediction.Params(
        inputs_only=True,
        parse_sequence_length=10,
        use_v2_feature_names=use_v2_feature_names)
    export_module = serving_modules.SentencePrediction(params=params,
                                                       model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    self.assertSameElements(functions.keys(),
                            ["serving_default", "serving_examples"])
    dummy_ids = tf.ones((10, 10), dtype=tf.int32)
    outputs = functions["serving_default"](dummy_ids)
    self.assertEqual(outputs["outputs"].shape, (10, 2))

    params = serving_modules.SentencePrediction.Params(
        inputs_only=False,
        parse_sequence_length=10,
        use_v2_feature_names=use_v2_feature_names)
    export_module = serving_modules.SentencePrediction(params=params,
                                                       model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    outputs = functions["serving_default"](
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    self.assertEqual(outputs["outputs"].shape, (10, 2))
    dummy_ids = tf.ones((10,), dtype=tf.int32)
    examples = _create_fake_serialized_examples({
        input_word_ids_field: dummy_ids,
        "input_mask": dummy_ids,
        input_type_ids_field: dummy_ids
    })
    outputs = functions["serving_examples"](examples)
    self.assertEqual(outputs["outputs"].shape, (10, 2))

    with self.assertRaises(ValueError):
      _ = export_module.get_inference_signatures({"foo": None})

  @parameterized.parameters(
      # inputs_only
      True,
      False)
  def test_sentence_prediction_text(self, inputs_only):
    vocab_file_path = os.path.join(self.get_temp_dir(), "vocab.txt")
    _create_fake_vocab_file(vocab_file_path)
    config = sentence_prediction.SentencePredictionConfig(
        model=sentence_prediction.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            num_classes=2))
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    params = serving_modules.SentencePrediction.Params(
        inputs_only=inputs_only,
        parse_sequence_length=10,
        text_fields=["foo", "bar"],
        vocab_file=vocab_file_path)
    export_module = serving_modules.SentencePrediction(params=params,
                                                       model=model)
    examples = _create_fake_serialized_examples({
        "foo": b"hello world",
        "bar": b"hello world"
    })
    functions = export_module.get_inference_signatures({
        "serve_text_examples": "serving_default",
    })
    outputs = functions["serving_default"](examples)
    self.assertEqual(outputs["outputs"].shape, (10, 2))

  @parameterized.parameters(
      # use_v2_feature_names
      True,
      False)
  def test_masked_lm(self, use_v2_feature_names):
    if use_v2_feature_names:
      input_word_ids_field = "input_word_ids"
      input_type_ids_field = "input_type_ids"
    else:
      input_word_ids_field = "input_ids"
      input_type_ids_field = "segment_ids"
    config = masked_lm.MaskedLMConfig(
        model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]))
    task = masked_lm.MaskedLMTask(config)
    model = task.build_model()
    params = serving_modules.MaskedLM.Params(
        parse_sequence_length=10,
        max_predictions_per_seq=5,
        use_v2_feature_names=use_v2_feature_names)
    export_module = serving_modules.MaskedLM(params=params, model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    self.assertSameElements(functions.keys(),
                            ["serving_default", "serving_examples"])
    dummy_ids = tf.ones((10, 10), dtype=tf.int32)
    dummy_pos = tf.ones((10, 5), dtype=tf.int32)
    outputs = functions["serving_default"](
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids,
        masked_lm_positions=dummy_pos)
    self.assertEqual(outputs["classification"].shape, (10, 2))
    dummy_ids = tf.ones((10,), dtype=tf.int32)
    dummy_pos = tf.ones((5,), dtype=tf.int32)
    examples = _create_fake_serialized_examples({
        input_word_ids_field: dummy_ids,
        "input_mask": dummy_ids,
        input_type_ids_field: dummy_ids,
        "masked_lm_positions": dummy_pos
    })
    outputs = functions["serving_examples"](examples)
    self.assertEqual(outputs["classification"].shape, (10, 2))

  @parameterized.parameters(
      # use_v2_feature_names
      True,
      False)
  def test_question_answering(self, use_v2_feature_names):
    if use_v2_feature_names:
      input_word_ids_field = "input_word_ids"
      input_type_ids_field = "input_type_ids"
    else:
      input_word_ids_field = "input_ids"
      input_type_ids_field = "segment_ids"
    config = question_answering.QuestionAnsweringConfig(
        model=question_answering.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1))),
        validation_data=None)
    task = question_answering.QuestionAnsweringTask(config)
    model = task.build_model()
    params = serving_modules.QuestionAnswering.Params(
        parse_sequence_length=10,
        use_v2_feature_names=use_v2_feature_names)
    export_module = serving_modules.QuestionAnswering(params=params,
                                                      model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    self.assertSameElements(functions.keys(),
                            ["serving_default", "serving_examples"])
    dummy_ids = tf.ones((10, 10), dtype=tf.int32)
    outputs = functions["serving_default"](
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    self.assertEqual(outputs["start_logits"].shape, (10, 10))
    self.assertEqual(outputs["end_logits"].shape, (10, 10))
    dummy_ids = tf.ones((10,), dtype=tf.int32)
    examples = _create_fake_serialized_examples({
        input_word_ids_field: dummy_ids,
        "input_mask": dummy_ids,
        input_type_ids_field: dummy_ids
    })
    outputs = functions["serving_examples"](examples)
    self.assertEqual(outputs["start_logits"].shape, (10, 10))
    self.assertEqual(outputs["end_logits"].shape, (10, 10))

  @parameterized.parameters(
      # (use_v2_feature_names, output_encoder_outputs)
      (True, True),
      (False, False))
  def test_tagging(self, use_v2_feature_names, output_encoder_outputs):
    if use_v2_feature_names:
      input_word_ids_field = "input_word_ids"
      input_type_ids_field = "input_type_ids"
    else:
      input_word_ids_field = "input_ids"
      input_type_ids_field = "segment_ids"
    hidden_size = 768
    num_classes = 3
    config = tagging.TaggingConfig(
        model=tagging.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(hidden_size=hidden_size,
                                                num_layers=1))),
        class_names=["class_0", "class_1", "class_2"])
    task = tagging.TaggingTask(config)
    model = task.build_model()
    params = serving_modules.Tagging.Params(
        parse_sequence_length=10,
        use_v2_feature_names=use_v2_feature_names,
        output_encoder_outputs=output_encoder_outputs)
    export_module = serving_modules.Tagging(params=params, model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    dummy_ids = tf.ones((10, 10), dtype=tf.int32)
    outputs = functions["serving_default"](
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    self.assertEqual(outputs["logits"].shape, (10, 10, num_classes))
    if output_encoder_outputs:
      self.assertEqual(outputs["encoder_outputs"].shape,
                       (10, 10, hidden_size))
    dummy_ids = tf.ones((10,), dtype=tf.int32)
    examples = _create_fake_serialized_examples({
        input_word_ids_field: dummy_ids,
        "input_mask": dummy_ids,
        input_type_ids_field: dummy_ids
    })
    outputs = functions["serving_examples"](examples)
    self.assertEqual(outputs["logits"].shape, (10, 10, num_classes))
    if output_encoder_outputs:
      self.assertEqual(outputs["encoder_outputs"].shape,
                       (10, 10, hidden_size))
    with self.assertRaises(ValueError):
      _ = export_module.get_inference_signatures({"foo": None})


if __name__ == "__main__":
  tf.test.main()
official/vision/beta/data/tfrecord_lib.py  View file @ 319589aa

...
@@ -63,12 +63,14 @@ def convert_to_feature(value, value_type=None):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

   elif value_type == 'int64_list':
     value = np.asarray(value).astype(np.int64).reshape(-1)
     return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

   elif value_type == 'float':
     return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

   elif value_type == 'float_list':
     value = np.asarray(value).astype(np.float32).reshape(-1)
     return tf.train.Feature(float_list=tf.train.FloatList(value=value))

   elif value_type == 'bytes':
...
@@ -172,4 +174,3 @@ def check_and_make_dir(directory):
   """Creates the directory if it doesn't exist."""
   if not tf.io.gfile.isdir(directory):
     tf.io.gfile.makedirs(directory)
official/vision/beta/modeling/backbones/spinenet_mobile.py  View file @ 319589aa

...
@@ -320,6 +320,9 @@ class SpineNetMobile(tf.keras.Model):
     endpoints = {}
     for i, block_spec in enumerate(self._block_specs):
+      # Update block level if it is larger than max_level to avoid building
+      # blocks smaller than requested.
+      block_spec.level = min(block_spec.level, self._max_level)
       # Find out specs for the target block.
       target_width = int(math.ceil(input_width / 2**block_spec.level))
       target_num_filters = int(FILTER_SIZE_MAP[block_spec.level] *
...
@@ -392,8 +395,9 @@ class SpineNetMobile(tf.keras.Model):
             block_spec.level))
       if (block_spec.level < self._min_level or
           block_spec.level > self._max_level):
-        raise ValueError('Output level is out of range [{}, {}]'.format(
-            self._min_level, self._max_level))
+        logging.warning(
+            'SpineNet output level out of range [min_level, max_level] = '
+            '[%s, %s] will not be used for further processing.',
+            self._min_level, self._max_level)
       endpoints[str(block_spec.level)] = x

     return endpoints
...
official/vision/beta/projects/keypoint/README.md  deleted 100644 → 0  View file @ 64f323b1

-# Keypoint Detection Models.
-TBD
research/object_detection/builders/model_builder.py  View file @ 319589aa

...
@@ -1130,11 +1130,18 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
       feature_extractor_config.use_separable_conv or
       feature_extractor_config.type == 'mobilenet_v2_fpn_sep_conv')
   kwargs = {
-      'channel_means': list(feature_extractor_config.channel_means),
-      'channel_stds': list(feature_extractor_config.channel_stds),
-      'bgr_ordering': feature_extractor_config.bgr_ordering,
-      'depth_multiplier': feature_extractor_config.depth_multiplier,
-      'use_separable_conv': use_separable_conv,
+      'channel_means':
+          list(feature_extractor_config.channel_means),
+      'channel_stds':
+          list(feature_extractor_config.channel_stds),
+      'bgr_ordering':
+          feature_extractor_config.bgr_ordering,
+      'depth_multiplier':
+          feature_extractor_config.depth_multiplier,
+      'use_separable_conv':
+          use_separable_conv,
+      'upsampling_interpolation':
+          feature_extractor_config.upsampling_interpolation,
   }
...
research/object_detection/builders/model_builder_tf2_test.py  View file @ 319589aa

...
@@ -398,7 +398,7 @@ class ModelBuilderTF2Test(
     }
   """
   # Set up the configuration proto.
-  config = text_format.Merge(proto_txt, model_pb2.DetectionModel())
+  config = text_format.Parse(proto_txt, model_pb2.DetectionModel())
   # Only add object center and keypoint estimation configs here.
   config.center_net.object_center_params.CopyFrom(
       self.get_fake_object_center_from_keypoints_proto())
...
@@ -422,6 +422,50 @@ class ModelBuilderTF2Test(
     self.assertEqual(kp_params.keypoint_labels,
                      ['nose', 'left_shoulder', 'right_shoulder', 'hip'])

+  def test_create_center_net_model_mobilenet(self):
+    """Test building a CenterNet model using bilinear interpolation."""
+    proto_txt = """
+      center_net {
+        num_classes: 10
+        feature_extractor {
+          type: "mobilenet_v2_fpn"
+          depth_multiplier: 1.0
+          use_separable_conv: true
+          upsampling_interpolation: "bilinear"
+        }
+        image_resizer {
+          keep_aspect_ratio_resizer {
+            min_dimension: 512
+            max_dimension: 512
+            pad_to_max_dimension: true
+          }
+        }
+      }
+    """
+    # Set up the configuration proto.
+    config = text_format.Parse(proto_txt, model_pb2.DetectionModel())
+    # Only add object center and keypoint estimation configs here.
+    config.center_net.object_center_params.CopyFrom(
+        self.get_fake_object_center_from_keypoints_proto())
+    config.center_net.keypoint_estimation_task.append(
+        self.get_fake_keypoint_proto())
+    config.center_net.keypoint_label_map_path = (
+        self.get_fake_label_map_file_path())
+
+    # Build the model from the configuration.
+    model = model_builder.build(config, is_training=True)
+
+    feature_extractor = model._feature_extractor
+    # Verify the upsampling layers in the FPN use 'bilinear' interpolation.
+    fpn = feature_extractor.get_layer('model_1')
+    num_up_sampling2d_layers = 0
+    for layer in fpn.layers:
+      if 'up_sampling2d' in layer.name:
+        num_up_sampling2d_layers += 1
+        self.assertEqual('bilinear', layer.interpolation)
+    # Verify that there are up_sampling2d layers.
+    self.assertGreater(num_up_sampling2d_layers, 0)

 if __name__ == '__main__':
   tf.test.main()
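The one-line fix in the first hunk is subtle: `text_format.Merge` and `text_format.Parse` both fill a protobuf message from text, but `Merge` silently keeps the last value when a non-repeated scalar field appears twice, whereas `Parse` raises. A sketch of the difference using a stock protobuf wrapper message (any message with a scalar field behaves the same way):

from google.protobuf import text_format
from google.protobuf import wrappers_pb2

txt = "value: 1 value: 2"  # the scalar field is assigned twice

msg = text_format.Merge(txt, wrappers_pb2.Int32Value())
print(msg.value)  # 2 -- Merge silently keeps the last assignment

try:
  text_format.Parse(txt, wrappers_pb2.Int32Value())
except text_format.ParseError as e:
  print(e)  # Parse rejects the duplicate scalar field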
research/object_detection/core/preprocessor.py  View file @ 319589aa

...
@@ -1776,6 +1776,7 @@ def random_pad_image(image,
                      min_image_size=None,
                      max_image_size=None,
                      pad_color=None,
+                     center_pad=False,
                      seed=None,
                      preprocess_vars_cache=None):
   """Randomly pads the image.
...
@@ -1814,6 +1815,8 @@ def random_pad_image(image,
     pad_color: padding color. A rank 1 tensor of [channels] with dtype=
         tf.float32. if set as None, it will be set to average color of
         the input image.
+    center_pad: whether the original image will be padded to the center, or
+        randomly padded (which is default).
     seed: random seed.
     preprocess_vars_cache: PreprocessorCache object that records previously
         performed augmentations. Updated in-place. If this
...
@@ -1870,6 +1873,12 @@ def random_pad_image(image,
       lambda: _random_integer(0, target_width - image_width, seed),
       lambda: tf.constant(0, dtype=tf.int32))

+  if center_pad:
+    offset_height = tf.cast(tf.floor((target_height - image_height) / 2),
+                            tf.int32)
+    offset_width = tf.cast(tf.floor((target_width - image_width) / 2),
+                           tf.int32)
+
   gen_func = lambda: (target_height, target_width, offset_height, offset_width)
   params = _get_or_create_preprocess_rand_vars(
       gen_func, preprocessor_cache.PreprocessorCache.PAD_IMAGE,
...
@@ -2113,7 +2122,7 @@ def random_crop_pad_image(image,
       max_padded_size_ratio, dtype=tf.int32)

-  padded_image, padded_boxes = random_pad_image(
+  padded_image, padded_boxes = random_pad_image(  # pylint: disable=unbalanced-tuple-unpacking
       cropped_image,
       cropped_boxes,
       min_image_size=min_image_size,
...
@@ -2153,6 +2162,7 @@ def random_crop_to_aspect_ratio(image,
                                 aspect_ratio=1.0,
                                 overlap_thresh=0.3,
                                 clip_boxes=True,
+                                center_crop=False,
                                 seed=None,
                                 preprocess_vars_cache=None):
   """Randomly crops an image to the specified aspect ratio.
...
@@ -2191,6 +2201,7 @@ def random_crop_to_aspect_ratio(image,
     overlap_thresh: minimum overlap thresh with new cropped
                     image to keep the box.
     clip_boxes: whether to clip the boxes to the cropped image.
+    center_crop: whether to take the center crop or a random crop.
     seed: random seed.
     preprocess_vars_cache: PreprocessorCache object that records previously
                            performed augmentations. Updated in-place. If this
...
@@ -2247,8 +2258,14 @@ def random_crop_to_aspect_ratio(image,
     # either offset_height = 0 and offset_width is randomly chosen from
     # [0, offset_width - target_width), or else offset_width = 0 and
     # offset_height is randomly chosen from [0, offset_height - target_height)
-    offset_height = _random_integer(0, orig_height - target_height + 1, seed)
-    offset_width = _random_integer(0, orig_width - target_width + 1, seed)
+    if center_crop:
+      offset_height = tf.cast(
+          tf.math.floor((orig_height - target_height) / 2), tf.int32)
+      offset_width = tf.cast(
+          tf.math.floor((orig_width - target_width) / 2), tf.int32)
+    else:
+      offset_height = _random_integer(0, orig_height - target_height + 1, seed)
+      offset_width = _random_integer(0, orig_width - target_width + 1, seed)

     generator_func = lambda: (offset_height, offset_width)
     offset_height, offset_width = _get_or_create_preprocess_rand_vars(
...
@@ -2979,7 +2996,7 @@ def resize_to_range(image,
             'per-channel pad value.')
     new_image = tf.stack(
         [
-            tf.pad(
+            tf.pad(  # pylint: disable=g-complex-comprehension
                 channels[i], [[0, max_dimension - new_size[0]],
                               [0, max_dimension - new_size[1]]],
                 constant_values=per_channel_pad_value[i])
...
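Both new flags compute the top-left offset as half the size difference, floored, so the original image lands in the middle of the padded (or cropped) canvas. A quick numeric check of that arithmetic, with sizes chosen arbitrarily:

import tensorflow as tf

image_height, image_width = 300, 200
target_height, target_width = 400, 400

# Same expression as the center_pad branch above.
offset_height = tf.cast(tf.floor((target_height - image_height) / 2), tf.int32)
offset_width = tf.cast(tf.floor((target_width - image_width) / 2), tf.int32)
print(int(offset_height), int(offset_width))  # 50 100 -> centered padding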
research/object_detection/core/preprocessor_test.py  View file @ 319589aa

...
@@ -2194,6 +2194,54 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
                         expected_boxes.flatten())
     self.assertAllEqual(distorted_masks_.shape, [1, 200, 200])

+  def testRunRandomCropToAspectRatioCenterCrop(self):
+
+    def graph_fn():
+      image = self.createColorfulTestImage()
+      boxes = self.createTestBoxes()
+      labels = self.createTestLabels()
+      weights = self.createTestGroundtruthWeights()
+      masks = tf.random_uniform([2, 200, 400], dtype=tf.float32)
+      tensor_dict = {
+          fields.InputDataFields.image: image,
+          fields.InputDataFields.groundtruth_boxes: boxes,
+          fields.InputDataFields.groundtruth_classes: labels,
+          fields.InputDataFields.groundtruth_weights: weights,
+          fields.InputDataFields.groundtruth_instance_masks: masks
+      }
+      preprocessor_arg_map = preprocessor.get_default_func_arg_map(
+          include_instance_masks=True)
+      preprocessing_options = [(preprocessor.random_crop_to_aspect_ratio, {
+          'center_crop': True
+      })]
+      with mock.patch.object(preprocessor,
+                             '_random_integer') as mock_random_integer:
+        mock_random_integer.return_value = tf.constant(0, dtype=tf.int32)
+        distorted_tensor_dict = preprocessor.preprocess(
+            tensor_dict,
+            preprocessing_options,
+            func_arg_map=preprocessor_arg_map)
+        distorted_image = distorted_tensor_dict[fields.InputDataFields.image]
+        distorted_boxes = distorted_tensor_dict[
+            fields.InputDataFields.groundtruth_boxes]
+        distorted_labels = distorted_tensor_dict[
+            fields.InputDataFields.groundtruth_classes]
+        return [distorted_image, distorted_boxes, distorted_labels]
+
+    (distorted_image_, distorted_boxes_,
+     distorted_labels_) = self.execute_cpu(graph_fn, [])
+    expected_boxes = np.array(
+        [[0.0, 0.0, 0.75, 1.0], [0.25, 0.5, 0.75, 1.0]], dtype=np.float32)
+    self.assertAllEqual(distorted_image_.shape, [1, 200, 200, 3])
+    self.assertAllEqual(distorted_labels_, [1, 2])
+    self.assertAllClose(distorted_boxes_.flatten(),
+                        expected_boxes.flatten())

   def testRunRandomCropToAspectRatioWithKeypoints(self):

     def graph_fn():
       image = self.createColorfulTestImage()
...
@@ -2433,6 +2481,51 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
     self.assertTrue(
         np.all((boxes_[:, 3] - boxes_[:, 1]) >=
                (padded_boxes_[:, 3] - padded_boxes_[:, 1])))

+  def testRandomPadImageCenterPad(self):
+
+    def graph_fn():
+      preprocessing_options = [(preprocessor.normalize_image, {
+          'original_minval': 0,
+          'original_maxval': 255,
+          'target_minval': 0,
+          'target_maxval': 1
+      })]
+      images = self.createColorfulTestImage()
+      boxes = self.createTestBoxes()
+      labels = self.createTestLabels()
+      tensor_dict = {
+          fields.InputDataFields.image: images,
+          fields.InputDataFields.groundtruth_boxes: boxes,
+          fields.InputDataFields.groundtruth_classes: labels,
+      }
+      tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options)
+      images = tensor_dict[fields.InputDataFields.image]
+
+      preprocessing_options = [(preprocessor.random_pad_image, {
+          'center_pad': True,
+          'min_image_size': [400, 400],
+          'max_image_size': [400, 400],
+      })]
+      padded_tensor_dict = preprocessor.preprocess(tensor_dict,
+                                                   preprocessing_options)
+
+      padded_images = padded_tensor_dict[fields.InputDataFields.image]
+      padded_boxes = padded_tensor_dict[
+          fields.InputDataFields.groundtruth_boxes]
+      padded_labels = padded_tensor_dict[
+          fields.InputDataFields.groundtruth_classes]
+      return [padded_images, padded_boxes, padded_labels]

+    (padded_images_, padded_boxes_,
+     padded_labels_) = self.execute_cpu(graph_fn, [])
+    expected_boxes = np.array(
+        [[0.25, 0.25, 0.625, 1.0], [0.375, 0.5, 0.625, 1.0]],
+        dtype=np.float32)
+    self.assertAllEqual(padded_images_.shape, [1, 400, 400, 3])
+    self.assertAllEqual(padded_labels_, [1, 2])
+    self.assertAllClose(padded_boxes_.flatten(), expected_boxes.flatten())

   @parameterized.parameters(
       {'include_dense_pose': False},
   )
...
research/object_detection/meta_architectures/center_net_meta_arch.py View file @ 319589aa
...
...
@@ -117,23 +117,9 @@ class CenterNetFeatureExtractor(tf.keras.Model):
    pass

  @property
  @abc.abstractmethod
  def supported_sub_model_types(self):
    """Valid sub model types supported by the get_sub_model function."""
    pass

  @abc.abstractmethod
  def get_sub_model(self, sub_model_type):
    """Returns the underlying keras model for the given sub_model_type.

    This function is useful when we only want to get a subset of weights to
    be restored from a checkpoint.

    Args:
      sub_model_type: string, the type of sub model. Currently, CenterNet
        feature extractors support 'detection' and 'classification'.
    """
    pass

  def classification_backbone(self):
    raise NotImplementedError(
        'Classification backbone not supported for {}'.format(type(self)))


def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
...
...
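This hunk replaces the abstract supported_sub_model_types/get_sub_model pair with classification_backbone, which raises NotImplementedError unless an extractor overrides it. A minimal sketch of a conforming override, with illustrative class and attribute names (the concrete extractors further down show the real ones):

class MyFeatureExtractor(CenterNetFeatureExtractor):  # illustrative subclass

  @property  # accessed as an attribute, as restore_from_objects does below
  def classification_backbone(self):
    # Return the classification network so its weights can be restored
    # from a classification checkpoint.
    return self._network  # assumed attribute holding the backbone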
@@ -4200,25 +4186,28 @@ class CenterNetMetaArch(model.DetectionModel):
      A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
    """
-    supported_types = self._feature_extractor.supported_sub_model_types
-    supported_types += ['fine_tune']
-    if fine_tune_checkpoint_type not in supported_types:
-      message = ('Checkpoint type "{}" not supported for {}. '
-                 'Supported types are {}')
-      raise ValueError(
-          message.format(fine_tune_checkpoint_type,
-                         self._feature_extractor.__class__.__name__,
-                         supported_types))
-    elif fine_tune_checkpoint_type == 'fine_tune':
+    if fine_tune_checkpoint_type == 'detection':
      feature_extractor_model = tf.train.Checkpoint(
          _feature_extractor=self._feature_extractor)
      return {'model': feature_extractor_model}
+    elif fine_tune_checkpoint_type == 'classification':
+      return {
+          'feature_extractor':
+              self._feature_extractor.classification_backbone
+      }
+    elif fine_tune_checkpoint_type == 'full':
+      return {'model': self}
+    elif fine_tune_checkpoint_type == 'fine_tune':
+      raise ValueError(('"fine_tune" is no longer supported for CenterNet. '
+                        'Please set fine_tune_checkpoint_type to "detection"'
+                        ' which has the same functionality. If you are using'
+                        ' the ExtremeNet checkpoint, download the new version'
+                        ' from the model zoo.'))
    else:
-      return {
-          'feature_extractor':
-              self._feature_extractor.get_sub_model(fine_tune_checkpoint_type)
-      }
+      raise ValueError('Unknown fine tune checkpoint type {}'.format(
+          fine_tune_checkpoint_type))

  def updates(self):
    if tf_version.is_tf2():
...
...
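For context, the dict returned by restore_from_objects maps names to trackables that the training loop hands to tf.train.Checkpoint. A minimal sketch of that consumption pattern (the real call sites are in model_lib_v2.py; the helper name here is ours):

import tensorflow as tf

def restore_for_fine_tuning(model, checkpoint_path,
                            checkpoint_type='detection'):
  restore_map = model.restore_from_objects(
      fine_tune_checkpoint_type=checkpoint_type)
  # Dict keys become the attribute names recorded against the checkpoint.
  ckpt = tf.train.Checkpoint(**restore_map)
  # expect_partial() silences warnings about values (e.g. optimizer slots)
  # that a fine-tuning restore deliberately leaves unrestored.
  ckpt.restore(checkpoint_path).expect_partial()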
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py View file @ 319589aa
...
...
@@ -17,7 +17,6 @@
from __future__ import division

import functools
import re
import unittest

from absl.testing import parameterized
...
...
@@ -2887,15 +2886,14 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
    self.assertIsInstance(restore_from_objects_map['feature_extractor'],
                          tf.keras.Model)

-  def test_retore_map_error(self):
-    """Test that restoring unsupported checkpoint type raises an error."""
+  def test_retore_map_detection(self):
+    """Test that detection checkpoints can be restored."""
    model = build_center_net_meta_arch(build_resnet=True)
-    msg = ("Checkpoint type \"detection\" not supported for "
-           "CenterNetResnetFeatureExtractor. Supported types are "
-           "['classification', 'fine_tune']")
-    with self.assertRaisesRegex(ValueError, re.escape(msg)):
-      model.restore_from_objects('detection')
+    restore_from_objects_map = model.restore_from_objects('detection')
+    self.assertIsInstance(restore_from_objects_map['model']._feature_extractor,
+                          tf.keras.Model)


class DummyFeatureExtractor(cnma.CenterNetFeatureExtractor):
...
...
research/object_detection/meta_architectures/faster_rcnn_meta_arch.py View file @ 319589aa
...
...
@@ -2896,6 +2896,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
          _feature_extractor_for_proposal_features=self._feature_extractor_for_proposal_features)
      return {'model': fake_model}
    elif fine_tune_checkpoint_type == 'full':
      return {'model': self}
    else:
      raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
          fine_tune_checkpoint_type))
...
...
research/object_detection/model_lib_v2.py View file @ 319589aa
...
...
@@ -35,6 +35,7 @@ from object_detection.protos import train_pb2
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import ops
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vutils
...
...
@@ -587,6 +588,9 @@ def train_loop(
          lambda: global_step % num_steps_per_iteration == 0):
      # Load a fine-tuning checkpoint.
      if train_config.fine_tune_checkpoint:
        variables_helper.ensure_checkpoint_supported(
            train_config.fine_tune_checkpoint, fine_tune_checkpoint_type,
            model_dir)
        load_fine_tune_checkpoint(
            detection_model, train_config.fine_tune_checkpoint,
            fine_tune_checkpoint_type, fine_tune_checkpoint_version,
...
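ensure_checkpoint_supported runs before any weights are loaded, so an incompatible checkpoint fails fast rather than restoring silently. Its implementation is not part of this hunk; as a rough illustration of the idea only (the helper name and prefix below are ours, not the variables_helper API), a checkpoint's variable names can be listed and checked up front:

import tensorflow as tf

def checkpoint_has_prefix(checkpoint_path, required_prefix):
  # List the variables stored in the checkpoint and look for an expected key.
  reader = tf.train.load_checkpoint(checkpoint_path)
  names = reader.get_variable_to_shape_map().keys()
  return any(name.startswith(required_prefix) for name in names)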
research/object_detection/models/center_net_hourglass_feature_extractor.py View file @ 319589aa
...
...
@@ -62,16 +62,6 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
return
self
.
_network
.
num_hourglasses
@
property
def
supported_sub_model_types
(
self
):
return
[
'detection'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
hourglass_10
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The Hourglass-10 backbone for CenterNet."""
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py View file @ 319589aa
...
...
@@ -83,9 +83,6 @@ class CenterNetMobileNetV2FeatureExtractor(
  def load_feature_extractor_weights(self, path):
    self._network.load_weights(path)

  def get_base_model(self):
    return self._network

  def call(self, inputs):
    return [self._network(inputs)]
...
...
@@ -100,14 +97,8 @@ class CenterNetMobileNetV2FeatureExtractor(
    return 1

  @property
  def supported_sub_model_types(self):
    return ['detection']

  def get_sub_model(self, sub_model_type):
    if sub_model_type == 'detection':
      return self._network
    else:
      raise ValueError('Sub model type "{}" not supported.'.format(
          sub_model_type))

  def classification_backbone(self):
    return self._network


def mobilenet_v2(channel_means, channel_stds, bgr_ordering,
...
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py View file @ 319589aa
...
...
@@ -39,7 +39,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
               channel_means=(0., 0., 0.),
               channel_stds=(1., 1., 1.),
               bgr_ordering=False,
-               use_separable_conv=False):
+               use_separable_conv=False,
+               upsampling_interpolation='nearest'):
    """Initializes the feature extractor.

    Args:
...
...
@@ -52,6 +53,9 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
        [blue, red, green] order.
      use_separable_conv: If set to True, all convolutional layers in the FPN
        network will be replaced by separable convolutions.
      upsampling_interpolation: A string (one of 'nearest' or 'bilinear')
        indicating which interpolation method to use for the upsampling ops in
        the FPN.
    """
    super(CenterNetMobileNetV2FPNFeatureExtractor, self).__init__(
...
...
@@ -84,7 +88,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
    for i, num_filters in enumerate(num_filters_list):
      level_ind = len(num_filters_list) - 1 - i
      # Upsample.
-      upsample_op = tf.keras.layers.UpSampling2D(2, interpolation='nearest')
+      upsample_op = tf.keras.layers.UpSampling2D(
+          2, interpolation=upsampling_interpolation)
      top_down = upsample_op(top_down)

      # Residual (skip-connection) from bottom-up pathway.
...
...
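The behavioural difference between the two interpolation modes accepted by upsampling_interpolation can be seen with plain Keras, independent of this extractor:

import tensorflow as tf

x = tf.reshape(tf.range(4, dtype=tf.float32), [1, 2, 2, 1])
nearest = tf.keras.layers.UpSampling2D(2, interpolation='nearest')(x)
bilinear = tf.keras.layers.UpSampling2D(2, interpolation='bilinear')(x)
# 'nearest' copies each pixel into a 2x2 block; 'bilinear' interpolates
# between neighbours, giving smoother upsampled feature maps.
print(nearest[0, :, :, 0].numpy())
print(bilinear[0, :, :, 0].numpy())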
@@ -144,7 +149,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
-                     use_separable_conv=False, depth_multiplier=1.0, **kwargs):
+                     use_separable_conv=False, depth_multiplier=1.0,
+                     upsampling_interpolation='nearest', **kwargs):
  """The MobileNetV2+FPN backbone for CenterNet."""
  del kwargs
...
...
@@ -159,4 +165,5 @@ def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
      channel_means=channel_means,
      channel_stds=channel_stds,
      bgr_ordering=bgr_ordering,
-      use_separable_conv=use_separable_conv)
+      use_separable_conv=use_separable_conv,
+      upsampling_interpolation=upsampling_interpolation)