Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
319589aa
Commit
319589aa
authored
May 07, 2021
by
vedanshu
Browse files
Merge branch 'master' of
https://github.com/tensorflow/models
parents
64f323b1
eaeea071
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1289 additions
and
81 deletions
+1289
-81
official/core/base_trainer.py
official/core/base_trainer.py
+7
-1
official/nlp/serving/export_savedmodel.py
official/nlp/serving/export_savedmodel.py
+141
-0
official/nlp/serving/export_savedmodel_test.py
official/nlp/serving/export_savedmodel_test.py
+164
-0
official/nlp/serving/export_savedmodel_util.py
official/nlp/serving/export_savedmodel_util.py
+45
-0
official/nlp/serving/serving_modules.py
official/nlp/serving/serving_modules.py
+391
-0
official/nlp/serving/serving_modules_test.py
official/nlp/serving/serving_modules_test.py
+317
-0
official/vision/beta/data/tfrecord_lib.py
official/vision/beta/data/tfrecord_lib.py
+2
-1
official/vision/beta/modeling/backbones/spinenet_mobile.py
official/vision/beta/modeling/backbones/spinenet_mobile.py
+6
-2
official/vision/beta/projects/keypoint/README.md
official/vision/beta/projects/keypoint/README.md
+0
-3
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+12
-5
research/object_detection/builders/model_builder_tf2_test.py
research/object_detection/builders/model_builder_tf2_test.py
+45
-1
research/object_detection/core/preprocessor.py
research/object_detection/core/preprocessor.py
+21
-4
research/object_detection/core/preprocessor_test.py
research/object_detection/core/preprocessor_test.py
+93
-0
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+20
-31
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+6
-8
research/object_detection/meta_architectures/faster_rcnn_meta_arch.py
...ect_detection/meta_architectures/faster_rcnn_meta_arch.py
+2
-0
research/object_detection/model_lib_v2.py
research/object_detection/model_lib_v2.py
+4
-0
research/object_detection/models/center_net_hourglass_feature_extractor.py
...etection/models/center_net_hourglass_feature_extractor.py
+0
-10
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
...ction/models/center_net_mobilenet_v2_feature_extractor.py
+2
-11
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py
...n/models/center_net_mobilenet_v2_fpn_feature_extractor.py
+11
-4
No files found.
official/core/base_trainer.py
View file @
319589aa
...
@@ -370,7 +370,13 @@ class Trainer(_AsyncTrainer):
...
@@ -370,7 +370,13 @@ class Trainer(_AsyncTrainer):
logs
[
metric
.
name
]
=
metric
.
result
()
logs
[
metric
.
name
]
=
metric
.
result
()
metric
.
reset_states
()
metric
.
reset_states
()
if
callable
(
self
.
optimizer
.
learning_rate
):
if
callable
(
self
.
optimizer
.
learning_rate
):
logs
[
"learning_rate"
]
=
self
.
optimizer
.
learning_rate
(
self
.
global_step
)
# Maybe a self-implemented optimizer does not have `optimizer.iterations`.
# So just to be safe here.
if
hasattr
(
self
.
optimizer
,
"iterations"
):
logs
[
"learning_rate"
]
=
self
.
optimizer
.
learning_rate
(
self
.
optimizer
.
iterations
)
else
:
logs
[
"learning_rate"
]
=
self
.
optimizer
.
learning_rate
(
self
.
global_step
)
else
:
else
:
logs
[
"learning_rate"
]
=
self
.
optimizer
.
learning_rate
logs
[
"learning_rate"
]
=
self
.
optimizer
.
learning_rate
return
logs
return
logs
...
...
official/nlp/serving/export_savedmodel.py
0 → 100644
View file @
319589aa
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A binary/library to export TF-NLP serving `SavedModel`."""
import
os
from
typing
import
Any
,
Dict
,
Text
from
absl
import
app
from
absl
import
flags
import
dataclasses
import
yaml
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.modeling
import
hyperparams
from
official.modeling.hyperparams
import
base_config
from
official.nlp.serving
import
export_savedmodel_util
from
official.nlp.serving
import
serving_modules
from
official.nlp.tasks
import
masked_lm
from
official.nlp.tasks
import
question_answering
from
official.nlp.tasks
import
sentence_prediction
from
official.nlp.tasks
import
tagging
FLAGS
=
flags
.
FLAGS
SERVING_MODULES
=
{
sentence_prediction
.
SentencePredictionTask
:
serving_modules
.
SentencePrediction
,
masked_lm
.
MaskedLMTask
:
serving_modules
.
MaskedLM
,
question_answering
.
QuestionAnsweringTask
:
serving_modules
.
QuestionAnswering
,
tagging
.
TaggingTask
:
serving_modules
.
Tagging
}
def
define_flags
():
"""Defines flags."""
flags
.
DEFINE_string
(
"task_name"
,
"SentencePrediction"
,
"The task to export."
)
flags
.
DEFINE_string
(
"config_file"
,
None
,
"The path to task/experiment yaml config file."
)
flags
.
DEFINE_string
(
"checkpoint_path"
,
None
,
"Object-based checkpoint path, from the training model directory."
)
flags
.
DEFINE_string
(
"export_savedmodel_dir"
,
None
,
"Output saved model directory."
)
flags
.
DEFINE_string
(
"serving_params"
,
None
,
"a YAML/JSON string or csv string for the serving parameters."
)
flags
.
DEFINE_string
(
"function_keys"
,
None
,
"A string key to retrieve pre-defined serving signatures."
)
flags
.
DEFINE_bool
(
"convert_tpu"
,
False
,
""
)
flags
.
DEFINE_multi_integer
(
"allowed_batch_size"
,
None
,
"Allowed batch sizes for batching ops."
)
def
lookup_export_module
(
task
:
base_task
.
Task
):
export_module_cls
=
SERVING_MODULES
.
get
(
task
.
__class__
,
None
)
if
export_module_cls
is
None
:
ValueError
(
"No registered export module for the task: %s"
,
task
.
__class__
)
return
export_module_cls
def
create_export_module
(
*
,
task_name
:
Text
,
config_file
:
Text
,
serving_params
:
Dict
[
Text
,
Any
]):
"""Creates a ExportModule."""
task_config_cls
=
None
task_cls
=
None
# pylint: disable=protected-access
for
key
,
value
in
task_factory
.
_REGISTERED_TASK_CLS
.
items
():
print
(
key
.
__name__
)
if
task_name
in
key
.
__name__
:
task_config_cls
,
task_cls
=
key
,
value
break
if
task_cls
is
None
:
raise
ValueError
(
"Failed to identify the task class. The provided task "
f
"name is
{
task_name
}
"
)
# pylint: enable=protected-access
# TODO(hongkuny): Figure out how to separate the task config from experiments.
@
dataclasses
.
dataclass
class
Dummy
(
base_config
.
Config
):
task
:
task_config_cls
=
task_config_cls
()
dummy_exp
=
Dummy
()
dummy_exp
=
hyperparams
.
override_params_dict
(
dummy_exp
,
config_file
,
is_strict
=
False
)
dummy_exp
.
task
.
validation_data
=
None
task
=
task_cls
(
dummy_exp
.
task
)
model
=
task
.
build_model
()
export_module_cls
=
lookup_export_module
(
task
)
params
=
export_module_cls
.
Params
(
**
serving_params
)
return
export_module_cls
(
params
=
params
,
model
=
model
)
def
main
(
_
):
serving_params
=
yaml
.
load
(
hyperparams
.
nested_csv_str_to_json_str
(
FLAGS
.
serving_params
),
Loader
=
yaml
.
FullLoader
)
export_module
=
create_export_module
(
task_name
=
FLAGS
.
task_name
,
config_file
=
FLAGS
.
config_file
,
serving_params
=
serving_params
)
export_dir
=
export_savedmodel_util
.
export
(
export_module
,
function_keys
=
[
FLAGS
.
function_keys
],
checkpoint_path
=
FLAGS
.
checkpoint_path
,
export_savedmodel_dir
=
FLAGS
.
export_savedmodel_dir
)
if
FLAGS
.
convert_tpu
:
# pylint: disable=g-import-not-at-top
from
cloud_tpu.inference_converter
import
converter_cli
from
cloud_tpu.inference_converter
import
converter_options_pb2
tpu_dir
=
os
.
path
.
join
(
export_dir
,
"tpu"
)
options
=
converter_options_pb2
.
ConverterOptions
()
if
FLAGS
.
allowed_batch_size
is
not
None
:
allowed_batch_sizes
=
sorted
(
FLAGS
.
allowed_batch_size
)
options
.
batch_options
.
num_batch_threads
=
4
options
.
batch_options
.
max_batch_size
=
allowed_batch_sizes
[
-
1
]
options
.
batch_options
.
batch_timeout_micros
=
100000
options
.
batch_options
.
allowed_batch_sizes
[:]
=
allowed_batch_sizes
options
.
batch_options
.
max_enqueued_batches
=
1000
converter_cli
.
ConvertSavedModel
(
export_dir
,
tpu_dir
,
function_alias
=
"tpu_candidate"
,
options
=
options
,
graph_rewrite_only
=
True
)
if
__name__
==
"__main__"
:
define_flags
()
app
.
run
(
main
)
official/nlp/serving/export_savedmodel_test.py
0 → 100644
View file @
319589aa
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.export_saved_model."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.serving
import
export_savedmodel
from
official.nlp.serving
import
export_savedmodel_util
from
official.nlp.tasks
import
masked_lm
from
official.nlp.tasks
import
sentence_prediction
from
official.nlp.tasks
import
tagging
class
ExportSavedModelTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
test_create_export_module
(
self
):
export_module
=
export_savedmodel
.
create_export_module
(
task_name
=
"SentencePrediction"
,
config_file
=
None
,
serving_params
=
{
"inputs_only"
:
False
,
"parse_sequence_length"
:
10
})
self
.
assertEqual
(
export_module
.
name
,
"sentence_prediction"
)
self
.
assertFalse
(
export_module
.
params
.
inputs_only
)
self
.
assertEqual
(
export_module
.
params
.
parse_sequence_length
,
10
)
def
test_sentence_prediction
(
self
):
config
=
sentence_prediction
.
SentencePredictionConfig
(
model
=
sentence_prediction
.
ModelConfig
(
encoder
=
encoders
.
EncoderConfig
(
bert
=
encoders
.
BertEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
)),
num_classes
=
2
))
task
=
sentence_prediction
.
SentencePredictionTask
(
config
)
model
=
task
.
build_model
()
ckpt
=
tf
.
train
.
Checkpoint
(
model
=
model
)
ckpt_path
=
ckpt
.
save
(
self
.
get_temp_dir
())
export_module_cls
=
export_savedmodel
.
lookup_export_module
(
task
)
serving_params
=
{
"inputs_only"
:
False
}
params
=
export_module_cls
.
Params
(
**
serving_params
)
export_module
=
export_module_cls
(
params
=
params
,
model
=
model
)
export_dir
=
export_savedmodel_util
.
export
(
export_module
,
function_keys
=
[
"serve"
],
checkpoint_path
=
ckpt_path
,
export_savedmodel_dir
=
self
.
get_temp_dir
())
imported
=
tf
.
saved_model
.
load
(
export_dir
)
serving_fn
=
imported
.
signatures
[
"serving_default"
]
dummy_ids
=
tf
.
ones
((
1
,
5
),
dtype
=
tf
.
int32
)
inputs
=
dict
(
input_word_ids
=
dummy_ids
,
input_mask
=
dummy_ids
,
input_type_ids
=
dummy_ids
)
ref_outputs
=
model
(
inputs
)
outputs
=
serving_fn
(
**
inputs
)
self
.
assertAllClose
(
ref_outputs
,
outputs
[
"outputs"
])
self
.
assertEqual
(
outputs
[
"outputs"
].
shape
,
(
1
,
2
))
def
test_masked_lm
(
self
):
config
=
masked_lm
.
MaskedLMConfig
(
model
=
bert
.
PretrainerConfig
(
encoder
=
encoders
.
EncoderConfig
(
bert
=
encoders
.
BertEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
)),
cls_heads
=
[
bert
.
ClsHeadConfig
(
inner_dim
=
10
,
num_classes
=
2
,
name
=
"foo"
)
]))
task
=
masked_lm
.
MaskedLMTask
(
config
)
model
=
task
.
build_model
()
ckpt
=
tf
.
train
.
Checkpoint
(
model
=
model
)
ckpt_path
=
ckpt
.
save
(
self
.
get_temp_dir
())
export_module_cls
=
export_savedmodel
.
lookup_export_module
(
task
)
serving_params
=
{
"cls_head_name"
:
"foo"
,
"parse_sequence_length"
:
10
,
"max_predictions_per_seq"
:
5
}
params
=
export_module_cls
.
Params
(
**
serving_params
)
export_module
=
export_module_cls
(
params
=
params
,
model
=
model
)
export_dir
=
export_savedmodel_util
.
export
(
export_module
,
function_keys
=
{
"serve"
:
"serving_default"
,
"serve_examples"
:
"serving_examples"
},
checkpoint_path
=
ckpt_path
,
export_savedmodel_dir
=
self
.
get_temp_dir
())
imported
=
tf
.
saved_model
.
load
(
export_dir
)
self
.
assertSameElements
(
imported
.
signatures
.
keys
(),
[
"serving_default"
,
"serving_examples"
])
serving_fn
=
imported
.
signatures
[
"serving_default"
]
dummy_ids
=
tf
.
ones
((
1
,
10
),
dtype
=
tf
.
int32
)
dummy_pos
=
tf
.
ones
((
1
,
5
),
dtype
=
tf
.
int32
)
outputs
=
serving_fn
(
input_word_ids
=
dummy_ids
,
input_mask
=
dummy_ids
,
input_type_ids
=
dummy_ids
,
masked_lm_positions
=
dummy_pos
)
self
.
assertEqual
(
outputs
[
"classification"
].
shape
,
(
1
,
2
))
@
parameterized
.
parameters
(
True
,
False
)
def
test_tagging
(
self
,
output_encoder_outputs
):
hidden_size
=
768
num_classes
=
3
config
=
tagging
.
TaggingConfig
(
model
=
tagging
.
ModelConfig
(
encoder
=
encoders
.
EncoderConfig
(
bert
=
encoders
.
BertEncoderConfig
(
hidden_size
=
hidden_size
,
num_layers
=
1
))),
class_names
=
[
"class_0"
,
"class_1"
,
"class_2"
])
task
=
tagging
.
TaggingTask
(
config
)
model
=
task
.
build_model
()
ckpt
=
tf
.
train
.
Checkpoint
(
model
=
model
)
ckpt_path
=
ckpt
.
save
(
self
.
get_temp_dir
())
export_module_cls
=
export_savedmodel
.
lookup_export_module
(
task
)
serving_params
=
{
"parse_sequence_length"
:
10
,
}
params
=
export_module_cls
.
Params
(
**
serving_params
,
output_encoder_outputs
=
output_encoder_outputs
)
export_module
=
export_module_cls
(
params
=
params
,
model
=
model
)
export_dir
=
export_savedmodel_util
.
export
(
export_module
,
function_keys
=
{
"serve"
:
"serving_default"
,
"serve_examples"
:
"serving_examples"
},
checkpoint_path
=
ckpt_path
,
export_savedmodel_dir
=
self
.
get_temp_dir
())
imported
=
tf
.
saved_model
.
load
(
export_dir
)
self
.
assertCountEqual
(
imported
.
signatures
.
keys
(),
[
"serving_default"
,
"serving_examples"
])
serving_fn
=
imported
.
signatures
[
"serving_default"
]
dummy_ids
=
tf
.
ones
((
1
,
5
),
dtype
=
tf
.
int32
)
inputs
=
dict
(
input_word_ids
=
dummy_ids
,
input_mask
=
dummy_ids
,
input_type_ids
=
dummy_ids
)
outputs
=
serving_fn
(
**
inputs
)
self
.
assertEqual
(
outputs
[
"logits"
].
shape
,
(
1
,
5
,
num_classes
))
if
output_encoder_outputs
:
self
.
assertEqual
(
outputs
[
"encoder_outputs"
].
shape
,
(
1
,
5
,
hidden_size
))
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/nlp/serving/export_savedmodel_util.py
0 → 100644
View file @
319589aa
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common library to export a SavedModel from the export module."""
from
typing
import
Dict
,
List
,
Optional
,
Text
,
Union
import
tensorflow
as
tf
from
official.core
import
export_base
def
export
(
export_module
:
export_base
.
ExportModule
,
function_keys
:
Union
[
List
[
Text
],
Dict
[
Text
,
Text
]],
export_savedmodel_dir
:
Text
,
checkpoint_path
:
Optional
[
Text
]
=
None
,
timestamped
:
bool
=
True
)
->
Text
:
"""Exports to SavedModel format.
Args:
export_module: a ExportModule with the keras Model and serving tf.functions.
function_keys: a list of string keys to retrieve pre-defined serving
signatures. The signaute keys will be set with defaults. If a dictionary
is provided, the values will be used as signature keys.
export_savedmodel_dir: Output saved model directory.
checkpoint_path: Object-based checkpoint path or directory.
timestamped: Whether to export the savedmodel to a timestamped directory.
Returns:
The savedmodel directory path.
"""
save_options
=
tf
.
saved_model
.
SaveOptions
(
function_aliases
=
{
"tpu_candidate"
:
export_module
.
serve
,
})
return
export_base
.
export
(
export_module
,
function_keys
,
export_savedmodel_dir
,
checkpoint_path
,
timestamped
,
save_options
)
official/nlp/serving/serving_modules.py
0 → 100644
View file @
319589aa
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring
from
typing
import
Dict
,
List
,
Optional
,
Text
import
dataclasses
import
tensorflow
as
tf
from
official.core
import
export_base
from
official.modeling.hyperparams
import
base_config
from
official.nlp.data
import
sentence_prediction_dataloader
def
features_to_int32
(
features
:
Dict
[
str
,
tf
.
Tensor
])
->
Dict
[
str
,
tf
.
Tensor
]:
"""Converts tf.int64 features to tf.int32, keep other features the same.
tf.Example only supports tf.int64, but the TPU only supports tf.int32.
Args:
features: Input tensor dictionary.
Returns:
Features with tf.int64 converted to tf.int32.
"""
converted_features
=
{}
for
name
,
tensor
in
features
.
items
():
if
tensor
.
dtype
==
tf
.
int64
:
converted_features
[
name
]
=
tf
.
cast
(
tensor
,
tf
.
int32
)
else
:
converted_features
[
name
]
=
tensor
return
converted_features
class
SentencePrediction
(
export_base
.
ExportModule
):
"""The export module for the sentence prediction task."""
@
dataclasses
.
dataclass
class
Params
(
base_config
.
Config
):
inputs_only
:
bool
=
True
parse_sequence_length
:
Optional
[
int
]
=
None
use_v2_feature_names
:
bool
=
True
# For text input processing.
text_fields
:
Optional
[
List
[
str
]]
=
None
# Either specify these values for preprocessing by Python code...
tokenization
:
str
=
"WordPiece"
# WordPiece or SentencePiece
# Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
# file if tokenization is SentencePiece.
vocab_file
:
str
=
""
lower_case
:
bool
=
True
# ...or load preprocessing from a SavedModel at this location.
preprocessing_hub_module_url
:
str
=
""
def
__init__
(
self
,
params
,
model
:
tf
.
keras
.
Model
,
inference_step
=
None
):
super
().
__init__
(
params
,
model
,
inference_step
)
if
params
.
use_v2_feature_names
:
self
.
input_word_ids_field
=
"input_word_ids"
self
.
input_type_ids_field
=
"input_type_ids"
else
:
self
.
input_word_ids_field
=
"input_ids"
self
.
input_type_ids_field
=
"segment_ids"
if
params
.
text_fields
:
self
.
_text_processor
=
sentence_prediction_dataloader
.
TextProcessor
(
seq_length
=
params
.
parse_sequence_length
,
vocab_file
=
params
.
vocab_file
,
tokenization
=
params
.
tokenization
,
lower_case
=
params
.
lower_case
,
preprocessing_hub_module_url
=
params
.
preprocessing_hub_module_url
)
@
tf
.
function
def
serve
(
self
,
input_word_ids
,
input_mask
=
None
,
input_type_ids
=
None
)
->
Dict
[
str
,
tf
.
Tensor
]:
if
input_type_ids
is
None
:
# Requires CLS token is the first token of inputs.
input_type_ids
=
tf
.
zeros_like
(
input_word_ids
)
if
input_mask
is
None
:
# The mask has 1 for real tokens and 0 for padding tokens.
input_mask
=
tf
.
where
(
tf
.
equal
(
input_word_ids
,
0
),
tf
.
zeros_like
(
input_word_ids
),
tf
.
ones_like
(
input_word_ids
))
inputs
=
dict
(
input_word_ids
=
input_word_ids
,
input_mask
=
input_mask
,
input_type_ids
=
input_type_ids
)
return
dict
(
outputs
=
self
.
inference_step
(
inputs
))
@
tf
.
function
def
serve_examples
(
self
,
inputs
)
->
Dict
[
str
,
tf
.
Tensor
]:
sequence_length
=
self
.
params
.
parse_sequence_length
inputs_only
=
self
.
params
.
inputs_only
name_to_features
=
{
self
.
input_word_ids_field
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
),
}
if
not
inputs_only
:
name_to_features
.
update
({
"input_mask"
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
),
self
.
input_type_ids_field
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
)
})
features
=
tf
.
io
.
parse_example
(
inputs
,
name_to_features
)
features
=
features_to_int32
(
features
)
return
self
.
serve
(
features
[
self
.
input_word_ids_field
],
input_mask
=
None
if
inputs_only
else
features
[
"input_mask"
],
input_type_ids
=
None
if
inputs_only
else
features
[
self
.
input_type_ids_field
])
@
tf
.
function
def
serve_text_examples
(
self
,
inputs
)
->
Dict
[
str
,
tf
.
Tensor
]:
name_to_features
=
{}
for
text_field
in
self
.
params
.
text_fields
:
name_to_features
[
text_field
]
=
tf
.
io
.
FixedLenFeature
([],
tf
.
string
)
features
=
tf
.
io
.
parse_example
(
inputs
,
name_to_features
)
segments
=
[
features
[
x
]
for
x
in
self
.
params
.
text_fields
]
model_inputs
=
self
.
_text_processor
(
segments
)
if
self
.
params
.
inputs_only
:
return
self
.
serve
(
input_word_ids
=
model_inputs
[
"input_word_ids"
])
return
self
.
serve
(
**
model_inputs
)
def
get_inference_signatures
(
self
,
function_keys
:
Dict
[
Text
,
Text
]):
signatures
=
{}
valid_keys
=
(
"serve"
,
"serve_examples"
,
"serve_text_examples"
)
for
func_key
,
signature_key
in
function_keys
.
items
():
if
func_key
not
in
valid_keys
:
raise
ValueError
(
"Invalid function key for the module: %s with key %s. "
"Valid keys are: %s"
%
(
self
.
__class__
,
func_key
,
valid_keys
))
if
func_key
==
"serve"
:
if
self
.
params
.
inputs_only
:
signatures
[
signature_key
]
=
self
.
serve
.
get_concrete_function
(
input_word_ids
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_word_ids"
))
else
:
signatures
[
signature_key
]
=
self
.
serve
.
get_concrete_function
(
input_word_ids
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_word_ids"
),
input_mask
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_mask"
),
input_type_ids
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_type_ids"
))
if
func_key
==
"serve_examples"
:
signatures
[
signature_key
]
=
self
.
serve_examples
.
get_concrete_function
(
tf
.
TensorSpec
(
shape
=
[
None
],
dtype
=
tf
.
string
,
name
=
"examples"
))
if
func_key
==
"serve_text_examples"
:
signatures
[
signature_key
]
=
self
.
serve_text_examples
.
get_concrete_function
(
tf
.
TensorSpec
(
shape
=
[
None
],
dtype
=
tf
.
string
,
name
=
"examples"
))
return
signatures
class
MaskedLM
(
export_base
.
ExportModule
):
"""The export module for the Bert Pretrain (MaskedLM) task."""
def
__init__
(
self
,
params
,
model
:
tf
.
keras
.
Model
,
inference_step
=
None
):
super
().
__init__
(
params
,
model
,
inference_step
)
if
params
.
use_v2_feature_names
:
self
.
input_word_ids_field
=
"input_word_ids"
self
.
input_type_ids_field
=
"input_type_ids"
else
:
self
.
input_word_ids_field
=
"input_ids"
self
.
input_type_ids_field
=
"segment_ids"
@
dataclasses
.
dataclass
class
Params
(
base_config
.
Config
):
cls_head_name
:
str
=
"next_sentence"
use_v2_feature_names
:
bool
=
True
parse_sequence_length
:
Optional
[
int
]
=
None
max_predictions_per_seq
:
Optional
[
int
]
=
None
@
tf
.
function
def
serve
(
self
,
input_word_ids
,
input_mask
,
input_type_ids
,
masked_lm_positions
)
->
Dict
[
str
,
tf
.
Tensor
]:
inputs
=
dict
(
input_word_ids
=
input_word_ids
,
input_mask
=
input_mask
,
input_type_ids
=
input_type_ids
,
masked_lm_positions
=
masked_lm_positions
)
outputs
=
self
.
inference_step
(
inputs
)
return
dict
(
classification
=
outputs
[
self
.
params
.
cls_head_name
])
@
tf
.
function
def
serve_examples
(
self
,
inputs
)
->
Dict
[
str
,
tf
.
Tensor
]:
sequence_length
=
self
.
params
.
parse_sequence_length
max_predictions_per_seq
=
self
.
params
.
max_predictions_per_seq
name_to_features
=
{
self
.
input_word_ids_field
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
),
"input_mask"
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
),
self
.
input_type_ids_field
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
),
"masked_lm_positions"
:
tf
.
io
.
FixedLenFeature
([
max_predictions_per_seq
],
tf
.
int64
)
}
features
=
tf
.
io
.
parse_example
(
inputs
,
name_to_features
)
features
=
features_to_int32
(
features
)
return
self
.
serve
(
input_word_ids
=
features
[
self
.
input_word_ids_field
],
input_mask
=
features
[
"input_mask"
],
input_type_ids
=
features
[
self
.
input_word_ids_field
],
masked_lm_positions
=
features
[
"masked_lm_positions"
])
def
get_inference_signatures
(
self
,
function_keys
:
Dict
[
Text
,
Text
]):
signatures
=
{}
valid_keys
=
(
"serve"
,
"serve_examples"
)
for
func_key
,
signature_key
in
function_keys
.
items
():
if
func_key
not
in
valid_keys
:
raise
ValueError
(
"Invalid function key for the module: %s with key %s. "
"Valid keys are: %s"
%
(
self
.
__class__
,
func_key
,
valid_keys
))
if
func_key
==
"serve"
:
signatures
[
signature_key
]
=
self
.
serve
.
get_concrete_function
(
input_word_ids
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_word_ids"
),
input_mask
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_mask"
),
input_type_ids
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_type_ids"
),
masked_lm_positions
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"masked_lm_positions"
))
if
func_key
==
"serve_examples"
:
signatures
[
signature_key
]
=
self
.
serve_examples
.
get_concrete_function
(
tf
.
TensorSpec
(
shape
=
[
None
],
dtype
=
tf
.
string
,
name
=
"examples"
))
return
signatures
class
QuestionAnswering
(
export_base
.
ExportModule
):
"""The export module for the question answering task."""
@
dataclasses
.
dataclass
class
Params
(
base_config
.
Config
):
parse_sequence_length
:
Optional
[
int
]
=
None
use_v2_feature_names
:
bool
=
True
def
__init__
(
self
,
params
,
model
:
tf
.
keras
.
Model
,
inference_step
=
None
):
super
().
__init__
(
params
,
model
,
inference_step
)
if
params
.
use_v2_feature_names
:
self
.
input_word_ids_field
=
"input_word_ids"
self
.
input_type_ids_field
=
"input_type_ids"
else
:
self
.
input_word_ids_field
=
"input_ids"
self
.
input_type_ids_field
=
"segment_ids"
@
tf
.
function
def
serve
(
self
,
input_word_ids
,
input_mask
=
None
,
input_type_ids
=
None
)
->
Dict
[
str
,
tf
.
Tensor
]:
if
input_mask
is
None
:
# The mask has 1 for real tokens and 0 for padding tokens.
input_mask
=
tf
.
where
(
tf
.
equal
(
input_word_ids
,
0
),
tf
.
zeros_like
(
input_word_ids
),
tf
.
ones_like
(
input_word_ids
))
inputs
=
dict
(
input_word_ids
=
input_word_ids
,
input_mask
=
input_mask
,
input_type_ids
=
input_type_ids
)
outputs
=
self
.
inference_step
(
inputs
)
return
dict
(
start_logits
=
outputs
[
0
],
end_logits
=
outputs
[
1
])
@
tf
.
function
def
serve_examples
(
self
,
inputs
)
->
Dict
[
str
,
tf
.
Tensor
]:
sequence_length
=
self
.
params
.
parse_sequence_length
name_to_features
=
{
self
.
input_word_ids_field
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
),
"input_mask"
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
),
self
.
input_type_ids_field
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
)
}
features
=
tf
.
io
.
parse_example
(
inputs
,
name_to_features
)
features
=
features_to_int32
(
features
)
return
self
.
serve
(
input_word_ids
=
features
[
self
.
input_word_ids_field
],
input_mask
=
features
[
"input_mask"
],
input_type_ids
=
features
[
self
.
input_type_ids_field
])
def
get_inference_signatures
(
self
,
function_keys
:
Dict
[
Text
,
Text
]):
signatures
=
{}
valid_keys
=
(
"serve"
,
"serve_examples"
)
for
func_key
,
signature_key
in
function_keys
.
items
():
if
func_key
not
in
valid_keys
:
raise
ValueError
(
"Invalid function key for the module: %s with key %s. "
"Valid keys are: %s"
%
(
self
.
__class__
,
func_key
,
valid_keys
))
if
func_key
==
"serve"
:
signatures
[
signature_key
]
=
self
.
serve
.
get_concrete_function
(
input_word_ids
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_word_ids"
),
input_mask
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_mask"
),
input_type_ids
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_type_ids"
))
if
func_key
==
"serve_examples"
:
signatures
[
signature_key
]
=
self
.
serve_examples
.
get_concrete_function
(
tf
.
TensorSpec
(
shape
=
[
None
],
dtype
=
tf
.
string
,
name
=
"examples"
))
return
signatures
class
Tagging
(
export_base
.
ExportModule
):
"""The export module for the tagging task."""
@
dataclasses
.
dataclass
class
Params
(
base_config
.
Config
):
parse_sequence_length
:
Optional
[
int
]
=
None
use_v2_feature_names
:
bool
=
True
output_encoder_outputs
:
bool
=
False
def
__init__
(
self
,
params
,
model
:
tf
.
keras
.
Model
,
inference_step
=
None
):
super
().
__init__
(
params
,
model
,
inference_step
)
if
params
.
use_v2_feature_names
:
self
.
input_word_ids_field
=
"input_word_ids"
self
.
input_type_ids_field
=
"input_type_ids"
else
:
self
.
input_word_ids_field
=
"input_ids"
self
.
input_type_ids_field
=
"segment_ids"
@
tf
.
function
def
serve
(
self
,
input_word_ids
,
input_mask
,
input_type_ids
)
->
Dict
[
str
,
tf
.
Tensor
]:
inputs
=
dict
(
input_word_ids
=
input_word_ids
,
input_mask
=
input_mask
,
input_type_ids
=
input_type_ids
)
outputs
=
self
.
inference_step
(
inputs
)
if
self
.
params
.
output_encoder_outputs
:
return
dict
(
logits
=
outputs
[
"logits"
],
encoder_outputs
=
outputs
[
"encoder_outputs"
])
else
:
return
dict
(
logits
=
outputs
[
"logits"
])
@
tf
.
function
def
serve_examples
(
self
,
inputs
)
->
Dict
[
str
,
tf
.
Tensor
]:
sequence_length
=
self
.
params
.
parse_sequence_length
name_to_features
=
{
self
.
input_word_ids_field
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
),
"input_mask"
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
),
self
.
input_type_ids_field
:
tf
.
io
.
FixedLenFeature
([
sequence_length
],
tf
.
int64
)
}
features
=
tf
.
io
.
parse_example
(
inputs
,
name_to_features
)
features
=
features_to_int32
(
features
)
return
self
.
serve
(
input_word_ids
=
features
[
self
.
input_word_ids_field
],
input_mask
=
features
[
"input_mask"
],
input_type_ids
=
features
[
self
.
input_type_ids_field
])
def
get_inference_signatures
(
self
,
function_keys
:
Dict
[
Text
,
Text
]):
signatures
=
{}
valid_keys
=
(
"serve"
,
"serve_examples"
)
for
func_key
,
signature_key
in
function_keys
.
items
():
if
func_key
not
in
valid_keys
:
raise
ValueError
(
"Invalid function key for the module: %s with key %s. "
"Valid keys are: %s"
%
(
self
.
__class__
,
func_key
,
valid_keys
))
if
func_key
==
"serve"
:
signatures
[
signature_key
]
=
self
.
serve
.
get_concrete_function
(
input_word_ids
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
self
.
input_word_ids_field
),
input_mask
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
"input_mask"
),
input_type_ids
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
],
dtype
=
tf
.
int32
,
name
=
self
.
input_type_ids_field
))
if
func_key
==
"serve_examples"
:
signatures
[
signature_key
]
=
self
.
serve_examples
.
get_concrete_function
(
tf
.
TensorSpec
(
shape
=
[
None
],
dtype
=
tf
.
string
,
name
=
"examples"
))
return
signatures
official/nlp/serving/serving_modules_test.py
0 → 100644
View file @
319589aa
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.serving_modules."""
import
os
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.serving
import
serving_modules
from
official.nlp.tasks
import
masked_lm
from
official.nlp.tasks
import
question_answering
from
official.nlp.tasks
import
sentence_prediction
from
official.nlp.tasks
import
tagging
def
_create_fake_serialized_examples
(
features_dict
):
"""Creates a fake dataset."""
def
create_int_feature
(
values
):
f
=
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
list
(
values
)))
return
f
def
create_str_feature
(
value
):
f
=
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
value
]))
return
f
examples
=
[]
for
_
in
range
(
10
):
features
=
{}
for
key
,
values
in
features_dict
.
items
():
if
isinstance
(
values
,
bytes
):
features
[
key
]
=
create_str_feature
(
values
)
else
:
features
[
key
]
=
create_int_feature
(
values
)
tf_example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
features
))
examples
.
append
(
tf_example
.
SerializeToString
())
return
tf
.
constant
(
examples
)
def
_create_fake_vocab_file
(
vocab_file_path
):
tokens
=
[
"[PAD]"
]
for
i
in
range
(
1
,
100
):
tokens
.
append
(
"[unused%d]"
%
i
)
tokens
.
extend
([
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"[MASK]"
,
"hello"
,
"world"
])
with
tf
.
io
.
gfile
.
GFile
(
vocab_file_path
,
"w"
)
as
outfile
:
outfile
.
write
(
"
\n
"
.
join
(
tokens
))
class
ServingModulesTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
@
parameterized
.
parameters
(
# use_v2_feature_names
True
,
False
)
def
test_sentence_prediction
(
self
,
use_v2_feature_names
):
if
use_v2_feature_names
:
input_word_ids_field
=
"input_word_ids"
input_type_ids_field
=
"input_type_ids"
else
:
input_word_ids_field
=
"input_ids"
input_type_ids_field
=
"segment_ids"
config
=
sentence_prediction
.
SentencePredictionConfig
(
model
=
sentence_prediction
.
ModelConfig
(
encoder
=
encoders
.
EncoderConfig
(
bert
=
encoders
.
BertEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
)),
num_classes
=
2
))
task
=
sentence_prediction
.
SentencePredictionTask
(
config
)
model
=
task
.
build_model
()
params
=
serving_modules
.
SentencePrediction
.
Params
(
inputs_only
=
True
,
parse_sequence_length
=
10
,
use_v2_feature_names
=
use_v2_feature_names
)
export_module
=
serving_modules
.
SentencePrediction
(
params
=
params
,
model
=
model
)
functions
=
export_module
.
get_inference_signatures
({
"serve"
:
"serving_default"
,
"serve_examples"
:
"serving_examples"
})
self
.
assertSameElements
(
functions
.
keys
(),
[
"serving_default"
,
"serving_examples"
])
dummy_ids
=
tf
.
ones
((
10
,
10
),
dtype
=
tf
.
int32
)
outputs
=
functions
[
"serving_default"
](
dummy_ids
)
self
.
assertEqual
(
outputs
[
"outputs"
].
shape
,
(
10
,
2
))
params
=
serving_modules
.
SentencePrediction
.
Params
(
inputs_only
=
False
,
parse_sequence_length
=
10
,
use_v2_feature_names
=
use_v2_feature_names
)
export_module
=
serving_modules
.
SentencePrediction
(
params
=
params
,
model
=
model
)
functions
=
export_module
.
get_inference_signatures
({
"serve"
:
"serving_default"
,
"serve_examples"
:
"serving_examples"
})
outputs
=
functions
[
"serving_default"
](
input_word_ids
=
dummy_ids
,
input_mask
=
dummy_ids
,
input_type_ids
=
dummy_ids
)
self
.
assertEqual
(
outputs
[
"outputs"
].
shape
,
(
10
,
2
))
dummy_ids
=
tf
.
ones
((
10
,),
dtype
=
tf
.
int32
)
examples
=
_create_fake_serialized_examples
({
input_word_ids_field
:
dummy_ids
,
"input_mask"
:
dummy_ids
,
input_type_ids_field
:
dummy_ids
})
outputs
=
functions
[
"serving_examples"
](
examples
)
self
.
assertEqual
(
outputs
[
"outputs"
].
shape
,
(
10
,
2
))
with
self
.
assertRaises
(
ValueError
):
_
=
export_module
.
get_inference_signatures
({
"foo"
:
None
})
@
parameterized
.
parameters
(
# inputs_only
True
,
False
)
def
test_sentence_prediction_text
(
self
,
inputs_only
):
vocab_file_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"vocab.txt"
)
_create_fake_vocab_file
(
vocab_file_path
)
config
=
sentence_prediction
.
SentencePredictionConfig
(
model
=
sentence_prediction
.
ModelConfig
(
encoder
=
encoders
.
EncoderConfig
(
bert
=
encoders
.
BertEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
)),
num_classes
=
2
))
task
=
sentence_prediction
.
SentencePredictionTask
(
config
)
model
=
task
.
build_model
()
params
=
serving_modules
.
SentencePrediction
.
Params
(
inputs_only
=
inputs_only
,
parse_sequence_length
=
10
,
text_fields
=
[
"foo"
,
"bar"
],
vocab_file
=
vocab_file_path
)
export_module
=
serving_modules
.
SentencePrediction
(
params
=
params
,
model
=
model
)
examples
=
_create_fake_serialized_examples
({
"foo"
:
b
"hello world"
,
"bar"
:
b
"hello world"
})
functions
=
export_module
.
get_inference_signatures
({
"serve_text_examples"
:
"serving_default"
,
})
outputs
=
functions
[
"serving_default"
](
examples
)
self
.
assertEqual
(
outputs
[
"outputs"
].
shape
,
(
10
,
2
))
@
parameterized
.
parameters
(
# use_v2_feature_names
True
,
False
)
def
test_masked_lm
(
self
,
use_v2_feature_names
):
if
use_v2_feature_names
:
input_word_ids_field
=
"input_word_ids"
input_type_ids_field
=
"input_type_ids"
else
:
input_word_ids_field
=
"input_ids"
input_type_ids_field
=
"segment_ids"
config
=
masked_lm
.
MaskedLMConfig
(
model
=
bert
.
PretrainerConfig
(
encoder
=
encoders
.
EncoderConfig
(
bert
=
encoders
.
BertEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
)),
cls_heads
=
[
bert
.
ClsHeadConfig
(
inner_dim
=
10
,
num_classes
=
2
,
name
=
"next_sentence"
)
]))
task
=
masked_lm
.
MaskedLMTask
(
config
)
model
=
task
.
build_model
()
params
=
serving_modules
.
MaskedLM
.
Params
(
parse_sequence_length
=
10
,
max_predictions_per_seq
=
5
,
use_v2_feature_names
=
use_v2_feature_names
)
export_module
=
serving_modules
.
MaskedLM
(
params
=
params
,
model
=
model
)
functions
=
export_module
.
get_inference_signatures
({
"serve"
:
"serving_default"
,
"serve_examples"
:
"serving_examples"
})
self
.
assertSameElements
(
functions
.
keys
(),
[
"serving_default"
,
"serving_examples"
])
dummy_ids
=
tf
.
ones
((
10
,
10
),
dtype
=
tf
.
int32
)
dummy_pos
=
tf
.
ones
((
10
,
5
),
dtype
=
tf
.
int32
)
outputs
=
functions
[
"serving_default"
](
input_word_ids
=
dummy_ids
,
input_mask
=
dummy_ids
,
input_type_ids
=
dummy_ids
,
masked_lm_positions
=
dummy_pos
)
self
.
assertEqual
(
outputs
[
"classification"
].
shape
,
(
10
,
2
))
dummy_ids
=
tf
.
ones
((
10
,),
dtype
=
tf
.
int32
)
dummy_pos
=
tf
.
ones
((
5
,),
dtype
=
tf
.
int32
)
examples
=
_create_fake_serialized_examples
({
input_word_ids_field
:
dummy_ids
,
"input_mask"
:
dummy_ids
,
input_type_ids_field
:
dummy_ids
,
"masked_lm_positions"
:
dummy_pos
})
outputs
=
functions
[
"serving_examples"
](
examples
)
self
.
assertEqual
(
outputs
[
"classification"
].
shape
,
(
10
,
2
))
@
parameterized
.
parameters
(
# use_v2_feature_names
True
,
False
)
def
test_question_answering
(
self
,
use_v2_feature_names
):
if
use_v2_feature_names
:
input_word_ids_field
=
"input_word_ids"
input_type_ids_field
=
"input_type_ids"
else
:
input_word_ids_field
=
"input_ids"
input_type_ids_field
=
"segment_ids"
config
=
question_answering
.
QuestionAnsweringConfig
(
model
=
question_answering
.
ModelConfig
(
encoder
=
encoders
.
EncoderConfig
(
bert
=
encoders
.
BertEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
))),
validation_data
=
None
)
task
=
question_answering
.
QuestionAnsweringTask
(
config
)
model
=
task
.
build_model
()
params
=
serving_modules
.
QuestionAnswering
.
Params
(
parse_sequence_length
=
10
,
use_v2_feature_names
=
use_v2_feature_names
)
export_module
=
serving_modules
.
QuestionAnswering
(
params
=
params
,
model
=
model
)
functions
=
export_module
.
get_inference_signatures
({
"serve"
:
"serving_default"
,
"serve_examples"
:
"serving_examples"
})
self
.
assertSameElements
(
functions
.
keys
(),
[
"serving_default"
,
"serving_examples"
])
dummy_ids
=
tf
.
ones
((
10
,
10
),
dtype
=
tf
.
int32
)
outputs
=
functions
[
"serving_default"
](
input_word_ids
=
dummy_ids
,
input_mask
=
dummy_ids
,
input_type_ids
=
dummy_ids
)
self
.
assertEqual
(
outputs
[
"start_logits"
].
shape
,
(
10
,
10
))
self
.
assertEqual
(
outputs
[
"end_logits"
].
shape
,
(
10
,
10
))
dummy_ids
=
tf
.
ones
((
10
,),
dtype
=
tf
.
int32
)
examples
=
_create_fake_serialized_examples
({
input_word_ids_field
:
dummy_ids
,
"input_mask"
:
dummy_ids
,
input_type_ids_field
:
dummy_ids
})
outputs
=
functions
[
"serving_examples"
](
examples
)
self
.
assertEqual
(
outputs
[
"start_logits"
].
shape
,
(
10
,
10
))
self
.
assertEqual
(
outputs
[
"end_logits"
].
shape
,
(
10
,
10
))
@
parameterized
.
parameters
(
# (use_v2_feature_names, output_encoder_outputs)
(
True
,
True
),
(
False
,
False
))
def
test_tagging
(
self
,
use_v2_feature_names
,
output_encoder_outputs
):
if
use_v2_feature_names
:
input_word_ids_field
=
"input_word_ids"
input_type_ids_field
=
"input_type_ids"
else
:
input_word_ids_field
=
"input_ids"
input_type_ids_field
=
"segment_ids"
hidden_size
=
768
num_classes
=
3
config
=
tagging
.
TaggingConfig
(
model
=
tagging
.
ModelConfig
(
encoder
=
encoders
.
EncoderConfig
(
bert
=
encoders
.
BertEncoderConfig
(
hidden_size
=
hidden_size
,
num_layers
=
1
))),
class_names
=
[
"class_0"
,
"class_1"
,
"class_2"
])
task
=
tagging
.
TaggingTask
(
config
)
model
=
task
.
build_model
()
params
=
serving_modules
.
Tagging
.
Params
(
parse_sequence_length
=
10
,
use_v2_feature_names
=
use_v2_feature_names
,
output_encoder_outputs
=
output_encoder_outputs
)
export_module
=
serving_modules
.
Tagging
(
params
=
params
,
model
=
model
)
functions
=
export_module
.
get_inference_signatures
({
"serve"
:
"serving_default"
,
"serve_examples"
:
"serving_examples"
})
dummy_ids
=
tf
.
ones
((
10
,
10
),
dtype
=
tf
.
int32
)
outputs
=
functions
[
"serving_default"
](
input_word_ids
=
dummy_ids
,
input_mask
=
dummy_ids
,
input_type_ids
=
dummy_ids
)
self
.
assertEqual
(
outputs
[
"logits"
].
shape
,
(
10
,
10
,
num_classes
))
if
output_encoder_outputs
:
self
.
assertEqual
(
outputs
[
"encoder_outputs"
].
shape
,
(
10
,
10
,
hidden_size
))
dummy_ids
=
tf
.
ones
((
10
,),
dtype
=
tf
.
int32
)
examples
=
_create_fake_serialized_examples
({
input_word_ids_field
:
dummy_ids
,
"input_mask"
:
dummy_ids
,
input_type_ids_field
:
dummy_ids
})
outputs
=
functions
[
"serving_examples"
](
examples
)
self
.
assertEqual
(
outputs
[
"logits"
].
shape
,
(
10
,
10
,
num_classes
))
if
output_encoder_outputs
:
self
.
assertEqual
(
outputs
[
"encoder_outputs"
].
shape
,
(
10
,
10
,
hidden_size
))
with
self
.
assertRaises
(
ValueError
):
_
=
export_module
.
get_inference_signatures
({
"foo"
:
None
})
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/vision/beta/data/tfrecord_lib.py
View file @
319589aa
...
@@ -63,12 +63,14 @@ def convert_to_feature(value, value_type=None):
...
@@ -63,12 +63,14 @@ def convert_to_feature(value, value_type=None):
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
[
value
]))
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
[
value
]))
elif
value_type
==
'int64_list'
:
elif
value_type
==
'int64_list'
:
value
=
np
.
asarray
(
value
).
astype
(
np
.
int64
).
reshape
(
-
1
)
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
value
))
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
value
))
elif
value_type
==
'float'
:
elif
value_type
==
'float'
:
return
tf
.
train
.
Feature
(
float_list
=
tf
.
train
.
FloatList
(
value
=
[
value
]))
return
tf
.
train
.
Feature
(
float_list
=
tf
.
train
.
FloatList
(
value
=
[
value
]))
elif
value_type
==
'float_list'
:
elif
value_type
==
'float_list'
:
value
=
np
.
asarray
(
value
).
astype
(
np
.
float32
).
reshape
(
-
1
)
return
tf
.
train
.
Feature
(
float_list
=
tf
.
train
.
FloatList
(
value
=
value
))
return
tf
.
train
.
Feature
(
float_list
=
tf
.
train
.
FloatList
(
value
=
value
))
elif
value_type
==
'bytes'
:
elif
value_type
==
'bytes'
:
...
@@ -172,4 +174,3 @@ def check_and_make_dir(directory):
...
@@ -172,4 +174,3 @@ def check_and_make_dir(directory):
"""Creates the directory if it doesn't exist."""
"""Creates the directory if it doesn't exist."""
if
not
tf
.
io
.
gfile
.
isdir
(
directory
):
if
not
tf
.
io
.
gfile
.
isdir
(
directory
):
tf
.
io
.
gfile
.
makedirs
(
directory
)
tf
.
io
.
gfile
.
makedirs
(
directory
)
official/vision/beta/modeling/backbones/spinenet_mobile.py
View file @
319589aa
...
@@ -320,6 +320,9 @@ class SpineNetMobile(tf.keras.Model):
...
@@ -320,6 +320,9 @@ class SpineNetMobile(tf.keras.Model):
endpoints
=
{}
endpoints
=
{}
for
i
,
block_spec
in
enumerate
(
self
.
_block_specs
):
for
i
,
block_spec
in
enumerate
(
self
.
_block_specs
):
# Update block level if it is larger than max_level to avoid building
# blocks smaller than requested.
block_spec
.
level
=
min
(
block_spec
.
level
,
self
.
_max_level
)
# Find out specs for the target block.
# Find out specs for the target block.
target_width
=
int
(
math
.
ceil
(
input_width
/
2
**
block_spec
.
level
))
target_width
=
int
(
math
.
ceil
(
input_width
/
2
**
block_spec
.
level
))
target_num_filters
=
int
(
FILTER_SIZE_MAP
[
block_spec
.
level
]
*
target_num_filters
=
int
(
FILTER_SIZE_MAP
[
block_spec
.
level
]
*
...
@@ -392,8 +395,9 @@ class SpineNetMobile(tf.keras.Model):
...
@@ -392,8 +395,9 @@ class SpineNetMobile(tf.keras.Model):
block_spec
.
level
))
block_spec
.
level
))
if
(
block_spec
.
level
<
self
.
_min_level
or
if
(
block_spec
.
level
<
self
.
_min_level
or
block_spec
.
level
>
self
.
_max_level
):
block_spec
.
level
>
self
.
_max_level
):
raise
ValueError
(
'Output level is out of range [{}, {}]'
.
format
(
logging
.
warning
(
self
.
_min_level
,
self
.
_max_level
))
'SpineNet output level out of range [min_level, max_levle] = [%s, %s] will not be used for further processing.'
,
self
.
_min_level
,
self
.
_max_level
)
endpoints
[
str
(
block_spec
.
level
)]
=
x
endpoints
[
str
(
block_spec
.
level
)]
=
x
return
endpoints
return
endpoints
...
...
official/vision/beta/projects/keypoint/README.md
deleted
100644 → 0
View file @
64f323b1
# Keypoint Detection Models.
TBD
research/object_detection/builders/model_builder.py
View file @
319589aa
...
@@ -1130,11 +1130,18 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
...
@@ -1130,11 +1130,18 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
feature_extractor_config
.
use_separable_conv
or
feature_extractor_config
.
use_separable_conv
or
feature_extractor_config
.
type
==
'mobilenet_v2_fpn_sep_conv'
)
feature_extractor_config
.
type
==
'mobilenet_v2_fpn_sep_conv'
)
kwargs
=
{
kwargs
=
{
'channel_means'
:
list
(
feature_extractor_config
.
channel_means
),
'channel_means'
:
'channel_stds'
:
list
(
feature_extractor_config
.
channel_stds
),
list
(
feature_extractor_config
.
channel_means
),
'bgr_ordering'
:
feature_extractor_config
.
bgr_ordering
,
'channel_stds'
:
'depth_multiplier'
:
feature_extractor_config
.
depth_multiplier
,
list
(
feature_extractor_config
.
channel_stds
),
'use_separable_conv'
:
use_separable_conv
,
'bgr_ordering'
:
feature_extractor_config
.
bgr_ordering
,
'depth_multiplier'
:
feature_extractor_config
.
depth_multiplier
,
'use_separable_conv'
:
use_separable_conv
,
'upsampling_interpolation'
:
feature_extractor_config
.
upsampling_interpolation
,
}
}
...
...
research/object_detection/builders/model_builder_tf2_test.py
View file @
319589aa
...
@@ -398,7 +398,7 @@ class ModelBuilderTF2Test(
...
@@ -398,7 +398,7 @@ class ModelBuilderTF2Test(
}
}
"""
"""
# Set up the configuration proto.
# Set up the configuration proto.
config
=
text_format
.
Merg
e
(
proto_txt
,
model_pb2
.
DetectionModel
())
config
=
text_format
.
Pars
e
(
proto_txt
,
model_pb2
.
DetectionModel
())
# Only add object center and keypoint estimation configs here.
# Only add object center and keypoint estimation configs here.
config
.
center_net
.
object_center_params
.
CopyFrom
(
config
.
center_net
.
object_center_params
.
CopyFrom
(
self
.
get_fake_object_center_from_keypoints_proto
())
self
.
get_fake_object_center_from_keypoints_proto
())
...
@@ -422,6 +422,50 @@ class ModelBuilderTF2Test(
...
@@ -422,6 +422,50 @@ class ModelBuilderTF2Test(
self
.
assertEqual
(
kp_params
.
keypoint_labels
,
self
.
assertEqual
(
kp_params
.
keypoint_labels
,
[
'nose'
,
'left_shoulder'
,
'right_shoulder'
,
'hip'
])
[
'nose'
,
'left_shoulder'
,
'right_shoulder'
,
'hip'
])
def
test_create_center_net_model_mobilenet
(
self
):
"""Test building a CenterNet model using bilinear interpolation."""
proto_txt
=
"""
center_net {
num_classes: 10
feature_extractor {
type: "mobilenet_v2_fpn"
depth_multiplier: 1.0
use_separable_conv: true
upsampling_interpolation: "bilinear"
}
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 512
max_dimension: 512
pad_to_max_dimension: true
}
}
}
"""
# Set up the configuration proto.
config
=
text_format
.
Parse
(
proto_txt
,
model_pb2
.
DetectionModel
())
# Only add object center and keypoint estimation configs here.
config
.
center_net
.
object_center_params
.
CopyFrom
(
self
.
get_fake_object_center_from_keypoints_proto
())
config
.
center_net
.
keypoint_estimation_task
.
append
(
self
.
get_fake_keypoint_proto
())
config
.
center_net
.
keypoint_label_map_path
=
(
self
.
get_fake_label_map_file_path
())
# Build the model from the configuration.
model
=
model_builder
.
build
(
config
,
is_training
=
True
)
feature_extractor
=
model
.
_feature_extractor
# Verify the upsampling layers in the FPN use 'bilinear' interpolation.
fpn
=
feature_extractor
.
get_layer
(
'model_1'
)
num_up_sampling2d_layers
=
0
for
layer
in
fpn
.
layers
:
if
'up_sampling2d'
in
layer
.
name
:
num_up_sampling2d_layers
+=
1
self
.
assertEqual
(
'bilinear'
,
layer
.
interpolation
)
# Verify that there are up_sampling2d layers.
self
.
assertGreater
(
num_up_sampling2d_layers
,
0
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
research/object_detection/core/preprocessor.py
View file @
319589aa
...
@@ -1776,6 +1776,7 @@ def random_pad_image(image,
...
@@ -1776,6 +1776,7 @@ def random_pad_image(image,
min_image_size
=
None
,
min_image_size
=
None
,
max_image_size
=
None
,
max_image_size
=
None
,
pad_color
=
None
,
pad_color
=
None
,
center_pad
=
False
,
seed
=
None
,
seed
=
None
,
preprocess_vars_cache
=
None
):
preprocess_vars_cache
=
None
):
"""Randomly pads the image.
"""Randomly pads the image.
...
@@ -1814,6 +1815,8 @@ def random_pad_image(image,
...
@@ -1814,6 +1815,8 @@ def random_pad_image(image,
pad_color: padding color. A rank 1 tensor of [channels] with dtype=
pad_color: padding color. A rank 1 tensor of [channels] with dtype=
tf.float32. if set as None, it will be set to average color of
tf.float32. if set as None, it will be set to average color of
the input image.
the input image.
center_pad: whether the original image will be padded to the center, or
randomly padded (which is default).
seed: random seed.
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
performed augmentations. Updated in-place. If this
...
@@ -1870,6 +1873,12 @@ def random_pad_image(image,
...
@@ -1870,6 +1873,12 @@ def random_pad_image(image,
lambda
:
_random_integer
(
0
,
target_width
-
image_width
,
seed
),
lambda
:
_random_integer
(
0
,
target_width
-
image_width
,
seed
),
lambda
:
tf
.
constant
(
0
,
dtype
=
tf
.
int32
))
lambda
:
tf
.
constant
(
0
,
dtype
=
tf
.
int32
))
if
center_pad
:
offset_height
=
tf
.
cast
(
tf
.
floor
((
target_height
-
image_height
)
/
2
),
tf
.
int32
)
offset_width
=
tf
.
cast
(
tf
.
floor
((
target_width
-
image_width
)
/
2
),
tf
.
int32
)
gen_func
=
lambda
:
(
target_height
,
target_width
,
offset_height
,
offset_width
)
gen_func
=
lambda
:
(
target_height
,
target_width
,
offset_height
,
offset_width
)
params
=
_get_or_create_preprocess_rand_vars
(
params
=
_get_or_create_preprocess_rand_vars
(
gen_func
,
preprocessor_cache
.
PreprocessorCache
.
PAD_IMAGE
,
gen_func
,
preprocessor_cache
.
PreprocessorCache
.
PAD_IMAGE
,
...
@@ -2113,7 +2122,7 @@ def random_crop_pad_image(image,
...
@@ -2113,7 +2122,7 @@ def random_crop_pad_image(image,
max_padded_size_ratio
,
max_padded_size_ratio
,
dtype
=
tf
.
int32
)
dtype
=
tf
.
int32
)
padded_image
,
padded_boxes
=
random_pad_image
(
padded_image
,
padded_boxes
=
random_pad_image
(
# pylint: disable=unbalanced-tuple-unpacking
cropped_image
,
cropped_image
,
cropped_boxes
,
cropped_boxes
,
min_image_size
=
min_image_size
,
min_image_size
=
min_image_size
,
...
@@ -2153,6 +2162,7 @@ def random_crop_to_aspect_ratio(image,
...
@@ -2153,6 +2162,7 @@ def random_crop_to_aspect_ratio(image,
aspect_ratio
=
1.0
,
aspect_ratio
=
1.0
,
overlap_thresh
=
0.3
,
overlap_thresh
=
0.3
,
clip_boxes
=
True
,
clip_boxes
=
True
,
center_crop
=
False
,
seed
=
None
,
seed
=
None
,
preprocess_vars_cache
=
None
):
preprocess_vars_cache
=
None
):
"""Randomly crops an image to the specified aspect ratio.
"""Randomly crops an image to the specified aspect ratio.
...
@@ -2191,6 +2201,7 @@ def random_crop_to_aspect_ratio(image,
...
@@ -2191,6 +2201,7 @@ def random_crop_to_aspect_ratio(image,
overlap_thresh: minimum overlap thresh with new cropped
overlap_thresh: minimum overlap thresh with new cropped
image to keep the box.
image to keep the box.
clip_boxes: whether to clip the boxes to the cropped image.
clip_boxes: whether to clip the boxes to the cropped image.
center_crop: whether to take the center crop or a random crop.
seed: random seed.
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
performed augmentations. Updated in-place. If this
...
@@ -2247,8 +2258,14 @@ def random_crop_to_aspect_ratio(image,
...
@@ -2247,8 +2258,14 @@ def random_crop_to_aspect_ratio(image,
# either offset_height = 0 and offset_width is randomly chosen from
# either offset_height = 0 and offset_width is randomly chosen from
# [0, offset_width - target_width), or else offset_width = 0 and
# [0, offset_width - target_width), or else offset_width = 0 and
# offset_height is randomly chosen from [0, offset_height - target_height)
# offset_height is randomly chosen from [0, offset_height - target_height)
offset_height
=
_random_integer
(
0
,
orig_height
-
target_height
+
1
,
seed
)
if
center_crop
:
offset_width
=
_random_integer
(
0
,
orig_width
-
target_width
+
1
,
seed
)
offset_height
=
tf
.
cast
(
tf
.
math
.
floor
((
orig_height
-
target_height
)
/
2
),
tf
.
int32
)
offset_width
=
tf
.
cast
(
tf
.
math
.
floor
((
orig_width
-
target_width
)
/
2
),
tf
.
int32
)
else
:
offset_height
=
_random_integer
(
0
,
orig_height
-
target_height
+
1
,
seed
)
offset_width
=
_random_integer
(
0
,
orig_width
-
target_width
+
1
,
seed
)
generator_func
=
lambda
:
(
offset_height
,
offset_width
)
generator_func
=
lambda
:
(
offset_height
,
offset_width
)
offset_height
,
offset_width
=
_get_or_create_preprocess_rand_vars
(
offset_height
,
offset_width
=
_get_or_create_preprocess_rand_vars
(
...
@@ -2979,7 +2996,7 @@ def resize_to_range(image,
...
@@ -2979,7 +2996,7 @@ def resize_to_range(image,
'per-channel pad value.'
)
'per-channel pad value.'
)
new_image
=
tf
.
stack
(
new_image
=
tf
.
stack
(
[
[
tf
.
pad
(
tf
.
pad
(
# pylint: disable=g-complex-comprehension
channels
[
i
],
[[
0
,
max_dimension
-
new_size
[
0
]],
channels
[
i
],
[[
0
,
max_dimension
-
new_size
[
0
]],
[
0
,
max_dimension
-
new_size
[
1
]]],
[
0
,
max_dimension
-
new_size
[
1
]]],
constant_values
=
per_channel_pad_value
[
i
])
constant_values
=
per_channel_pad_value
[
i
])
...
...
research/object_detection/core/preprocessor_test.py
View file @
319589aa
...
@@ -2194,6 +2194,54 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
...
@@ -2194,6 +2194,54 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
expected_boxes
.
flatten
())
expected_boxes
.
flatten
())
self
.
assertAllEqual
(
distorted_masks_
.
shape
,
[
1
,
200
,
200
])
self
.
assertAllEqual
(
distorted_masks_
.
shape
,
[
1
,
200
,
200
])
def
testRunRandomCropToAspectRatioCenterCrop
(
self
):
def
graph_fn
():
image
=
self
.
createColorfulTestImage
()
boxes
=
self
.
createTestBoxes
()
labels
=
self
.
createTestLabels
()
weights
=
self
.
createTestGroundtruthWeights
()
masks
=
tf
.
random_uniform
([
2
,
200
,
400
],
dtype
=
tf
.
float32
)
tensor_dict
=
{
fields
.
InputDataFields
.
image
:
image
,
fields
.
InputDataFields
.
groundtruth_boxes
:
boxes
,
fields
.
InputDataFields
.
groundtruth_classes
:
labels
,
fields
.
InputDataFields
.
groundtruth_weights
:
weights
,
fields
.
InputDataFields
.
groundtruth_instance_masks
:
masks
}
preprocessor_arg_map
=
preprocessor
.
get_default_func_arg_map
(
include_instance_masks
=
True
)
preprocessing_options
=
[(
preprocessor
.
random_crop_to_aspect_ratio
,
{
'center_crop'
:
True
})]
with
mock
.
patch
.
object
(
preprocessor
,
'_random_integer'
)
as
mock_random_integer
:
mock_random_integer
.
return_value
=
tf
.
constant
(
0
,
dtype
=
tf
.
int32
)
distorted_tensor_dict
=
preprocessor
.
preprocess
(
tensor_dict
,
preprocessing_options
,
func_arg_map
=
preprocessor_arg_map
)
distorted_image
=
distorted_tensor_dict
[
fields
.
InputDataFields
.
image
]
distorted_boxes
=
distorted_tensor_dict
[
fields
.
InputDataFields
.
groundtruth_boxes
]
distorted_labels
=
distorted_tensor_dict
[
fields
.
InputDataFields
.
groundtruth_classes
]
return
[
distorted_image
,
distorted_boxes
,
distorted_labels
]
(
distorted_image_
,
distorted_boxes_
,
distorted_labels_
)
=
self
.
execute_cpu
(
graph_fn
,
[])
expected_boxes
=
np
.
array
([[
0.0
,
0.0
,
0.75
,
1.0
],
[
0.25
,
0.5
,
0.75
,
1.0
]],
dtype
=
np
.
float32
)
self
.
assertAllEqual
(
distorted_image_
.
shape
,
[
1
,
200
,
200
,
3
])
self
.
assertAllEqual
(
distorted_labels_
,
[
1
,
2
])
self
.
assertAllClose
(
distorted_boxes_
.
flatten
(),
expected_boxes
.
flatten
())
def
testRunRandomCropToAspectRatioWithKeypoints
(
self
):
def
testRunRandomCropToAspectRatioWithKeypoints
(
self
):
def
graph_fn
():
def
graph_fn
():
image
=
self
.
createColorfulTestImage
()
image
=
self
.
createColorfulTestImage
()
...
@@ -2433,6 +2481,51 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
...
@@ -2433,6 +2481,51 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
self
.
assertTrue
(
np
.
all
((
boxes_
[:,
3
]
-
boxes_
[:,
1
])
>=
(
self
.
assertTrue
(
np
.
all
((
boxes_
[:,
3
]
-
boxes_
[:,
1
])
>=
(
padded_boxes_
[:,
3
]
-
padded_boxes_
[:,
1
])))
padded_boxes_
[:,
3
]
-
padded_boxes_
[:,
1
])))
def
testRandomPadImageCenterPad
(
self
):
def
graph_fn
():
preprocessing_options
=
[(
preprocessor
.
normalize_image
,
{
'original_minval'
:
0
,
'original_maxval'
:
255
,
'target_minval'
:
0
,
'target_maxval'
:
1
})]
images
=
self
.
createColorfulTestImage
()
boxes
=
self
.
createTestBoxes
()
labels
=
self
.
createTestLabels
()
tensor_dict
=
{
fields
.
InputDataFields
.
image
:
images
,
fields
.
InputDataFields
.
groundtruth_boxes
:
boxes
,
fields
.
InputDataFields
.
groundtruth_classes
:
labels
,
}
tensor_dict
=
preprocessor
.
preprocess
(
tensor_dict
,
preprocessing_options
)
images
=
tensor_dict
[
fields
.
InputDataFields
.
image
]
preprocessing_options
=
[(
preprocessor
.
random_pad_image
,
{
'center_pad'
:
True
,
'min_image_size'
:
[
400
,
400
],
'max_image_size'
:
[
400
,
400
],
})]
padded_tensor_dict
=
preprocessor
.
preprocess
(
tensor_dict
,
preprocessing_options
)
padded_images
=
padded_tensor_dict
[
fields
.
InputDataFields
.
image
]
padded_boxes
=
padded_tensor_dict
[
fields
.
InputDataFields
.
groundtruth_boxes
]
padded_labels
=
padded_tensor_dict
[
fields
.
InputDataFields
.
groundtruth_classes
]
return
[
padded_images
,
padded_boxes
,
padded_labels
]
(
padded_images_
,
padded_boxes_
,
padded_labels_
)
=
self
.
execute_cpu
(
graph_fn
,
[])
expected_boxes
=
np
.
array
([[
0.25
,
0.25
,
0.625
,
1.0
],
[
0.375
,
0.5
,
.
625
,
1.0
]],
dtype
=
np
.
float32
)
self
.
assertAllEqual
(
padded_images_
.
shape
,
[
1
,
400
,
400
,
3
])
self
.
assertAllEqual
(
padded_labels_
,
[
1
,
2
])
self
.
assertAllClose
(
padded_boxes_
.
flatten
(),
expected_boxes
.
flatten
())
@
parameterized
.
parameters
(
@
parameterized
.
parameters
(
{
'include_dense_pose'
:
False
},
{
'include_dense_pose'
:
False
},
)
)
...
...
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
319589aa
...
@@ -117,23 +117,9 @@ class CenterNetFeatureExtractor(tf.keras.Model):
...
@@ -117,23 +117,9 @@ class CenterNetFeatureExtractor(tf.keras.Model):
pass
pass
@
property
@
property
@
abc
.
abstractmethod
def
classification_backbone
(
self
):
def
supported_sub_model_types
(
self
):
raise
NotImplementedError
(
"""Valid sub model types supported by the get_sub_model function."""
'Classification backbone not supported for {}'
.
format
(
type
(
self
)))
pass
@
abc
.
abstractmethod
def
get_sub_model
(
self
,
sub_model_type
):
"""Returns the underlying keras model for the given sub_model_type.
This function is useful when we only want to get a subset of weights to
be restored from a checkpoint.
Args:
sub_model_type: string, the type of sub model. Currently, CenterNet
feature extractors support 'detection' and 'classification'.
"""
pass
def
make_prediction_net
(
num_out_channels
,
kernel_sizes
=
(
3
),
num_filters
=
(
256
),
def
make_prediction_net
(
num_out_channels
,
kernel_sizes
=
(
3
),
num_filters
=
(
256
),
...
@@ -4200,25 +4186,28 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -4200,25 +4186,28 @@ class CenterNetMetaArch(model.DetectionModel):
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
"""
supported_types
=
self
.
_feature_extractor
.
supported_sub_model_types
if
fine_tune_checkpoint_type
==
'detection'
:
supported_types
+=
[
'fine_tune'
]
if
fine_tune_checkpoint_type
not
in
supported_types
:
message
=
(
'Checkpoint type "{}" not supported for {}. '
'Supported types are {}'
)
raise
ValueError
(
message
.
format
(
fine_tune_checkpoint_type
,
self
.
_feature_extractor
.
__class__
.
__name__
,
supported_types
))
elif
fine_tune_checkpoint_type
==
'fine_tune'
:
feature_extractor_model
=
tf
.
train
.
Checkpoint
(
feature_extractor_model
=
tf
.
train
.
Checkpoint
(
_feature_extractor
=
self
.
_feature_extractor
)
_feature_extractor
=
self
.
_feature_extractor
)
return
{
'model'
:
feature_extractor_model
}
return
{
'model'
:
feature_extractor_model
}
elif
fine_tune_checkpoint_type
==
'classification'
:
return
{
'feature_extractor'
:
self
.
_feature_extractor
.
classification_backbone
}
elif
fine_tune_checkpoint_type
==
'full'
:
return
{
'model'
:
self
}
elif
fine_tune_checkpoint_type
==
'fine_tune'
:
raise
ValueError
((
'"fine_tune" is no longer supported for CenterNet. '
'Please set fine_tune_checkpoint_type to "detection"'
' which has the same functionality. If you are using'
' the ExtremeNet checkpoint, download the new version'
' from the model zoo.'
))
else
:
else
:
r
eturn
{
'feature_extractor'
:
self
.
_feature_extractor
.
get_sub_model
(
r
aise
ValueError
(
'Unknown fine tune checkpoint type {}'
.
format
(
fine_tune_checkpoint_type
)
}
fine_tune_checkpoint_type
)
)
def
updates
(
self
):
def
updates
(
self
):
if
tf_version
.
is_tf2
():
if
tf_version
.
is_tf2
():
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
319589aa
...
@@ -17,7 +17,6 @@
...
@@ -17,7 +17,6 @@
from
__future__
import
division
from
__future__
import
division
import
functools
import
functools
import
re
import
unittest
import
unittest
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
...
@@ -2887,15 +2886,14 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
...
@@ -2887,15 +2886,14 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
self
.
assertIsInstance
(
restore_from_objects_map
[
'feature_extractor'
],
self
.
assertIsInstance
(
restore_from_objects_map
[
'feature_extractor'
],
tf
.
keras
.
Model
)
tf
.
keras
.
Model
)
def
test_retore_map_
error
(
self
):
def
test_retore_map_
detection
(
self
):
"""Test that
restoring unsupported checkpoint type raise
s an
error
."""
"""Test that
detection checkpoint
s
c
an
be restored
."""
model
=
build_center_net_meta_arch
(
build_resnet
=
True
)
model
=
build_center_net_meta_arch
(
build_resnet
=
True
)
msg
=
(
"Checkpoint type
\"
detection
\"
not supported for "
restore_from_objects_map
=
model
.
restore_from_objects
(
'detection'
)
"CenterNetResnetFeatureExtractor. Supported types are "
"['classification', 'fine_tune']"
)
self
.
assertIsInstance
(
restore_from_objects_map
[
'model'
].
_feature_extractor
,
with
self
.
assertRaisesRegex
(
ValueError
,
re
.
escape
(
msg
)):
tf
.
keras
.
Model
)
model
.
restore_from_objects
(
'detection'
)
class
DummyFeatureExtractor
(
cnma
.
CenterNetFeatureExtractor
):
class
DummyFeatureExtractor
(
cnma
.
CenterNetFeatureExtractor
):
...
...
research/object_detection/meta_architectures/faster_rcnn_meta_arch.py
View file @
319589aa
...
@@ -2896,6 +2896,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
...
@@ -2896,6 +2896,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
_feature_extractor_for_proposal_features
=
_feature_extractor_for_proposal_features
=
self
.
_feature_extractor_for_proposal_features
)
self
.
_feature_extractor_for_proposal_features
)
return
{
'model'
:
fake_model
}
return
{
'model'
:
fake_model
}
elif
fine_tune_checkpoint_type
==
'full'
:
return
{
'model'
:
self
}
else
:
else
:
raise
ValueError
(
'Not supported fine_tune_checkpoint_type: {}'
.
format
(
raise
ValueError
(
'Not supported fine_tune_checkpoint_type: {}'
.
format
(
fine_tune_checkpoint_type
))
fine_tune_checkpoint_type
))
...
...
research/object_detection/model_lib_v2.py
View file @
319589aa
...
@@ -35,6 +35,7 @@ from object_detection.protos import train_pb2
...
@@ -35,6 +35,7 @@ from object_detection.protos import train_pb2
from
object_detection.utils
import
config_util
from
object_detection.utils
import
config_util
from
object_detection.utils
import
label_map_util
from
object_detection.utils
import
label_map_util
from
object_detection.utils
import
ops
from
object_detection.utils
import
ops
from
object_detection.utils
import
variables_helper
from
object_detection.utils
import
visualization_utils
as
vutils
from
object_detection.utils
import
visualization_utils
as
vutils
...
@@ -587,6 +588,9 @@ def train_loop(
...
@@ -587,6 +588,9 @@ def train_loop(
lambda
:
global_step
%
num_steps_per_iteration
==
0
):
lambda
:
global_step
%
num_steps_per_iteration
==
0
):
# Load a fine-tuning checkpoint.
# Load a fine-tuning checkpoint.
if
train_config
.
fine_tune_checkpoint
:
if
train_config
.
fine_tune_checkpoint
:
variables_helper
.
ensure_checkpoint_supported
(
train_config
.
fine_tune_checkpoint
,
fine_tune_checkpoint_type
,
model_dir
)
load_fine_tune_checkpoint
(
load_fine_tune_checkpoint
(
detection_model
,
train_config
.
fine_tune_checkpoint
,
detection_model
,
train_config
.
fine_tune_checkpoint
,
fine_tune_checkpoint_type
,
fine_tune_checkpoint_version
,
fine_tune_checkpoint_type
,
fine_tune_checkpoint_version
,
...
...
research/object_detection/models/center_net_hourglass_feature_extractor.py
View file @
319589aa
...
@@ -62,16 +62,6 @@ class CenterNetHourglassFeatureExtractor(
...
@@ -62,16 +62,6 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
"""Ther number of feature outputs returned by the feature extractor."""
return
self
.
_network
.
num_hourglasses
return
self
.
_network
.
num_hourglasses
@
property
def
supported_sub_model_types
(
self
):
return
[
'detection'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
hourglass_10
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
def
hourglass_10
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The Hourglass-10 backbone for CenterNet."""
"""The Hourglass-10 backbone for CenterNet."""
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
View file @
319589aa
...
@@ -83,9 +83,6 @@ class CenterNetMobileNetV2FeatureExtractor(
...
@@ -83,9 +83,6 @@ class CenterNetMobileNetV2FeatureExtractor(
def
load_feature_extractor_weights
(
self
,
path
):
def
load_feature_extractor_weights
(
self
,
path
):
self
.
_network
.
load_weights
(
path
)
self
.
_network
.
load_weights
(
path
)
def
get_base_model
(
self
):
return
self
.
_network
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
return
[
self
.
_network
(
inputs
)]
return
[
self
.
_network
(
inputs
)]
...
@@ -100,14 +97,8 @@ class CenterNetMobileNetV2FeatureExtractor(
...
@@ -100,14 +97,8 @@ class CenterNetMobileNetV2FeatureExtractor(
return
1
return
1
@
property
@
property
def
supported_sub_model_types
(
self
):
def
classification_backbone
(
self
):
return
[
'detection'
]
return
self
.
_network
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
,
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
,
...
...
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py
View file @
319589aa
...
@@ -39,7 +39,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
...
@@ -39,7 +39,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel_means
=
(
0.
,
0.
,
0.
),
channel_means
=
(
0.
,
0.
,
0.
),
channel_stds
=
(
1.
,
1.
,
1.
),
channel_stds
=
(
1.
,
1.
,
1.
),
bgr_ordering
=
False
,
bgr_ordering
=
False
,
use_separable_conv
=
False
):
use_separable_conv
=
False
,
upsampling_interpolation
=
'nearest'
):
"""Intializes the feature extractor.
"""Intializes the feature extractor.
Args:
Args:
...
@@ -52,6 +53,9 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
...
@@ -52,6 +53,9 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
[blue, red, green] order.
[blue, red, green] order.
use_separable_conv: If set to True, all convolutional layers in the FPN
use_separable_conv: If set to True, all convolutional layers in the FPN
network will be replaced by separable convolutions.
network will be replaced by separable convolutions.
upsampling_interpolation: A string (one of 'nearest' or 'bilinear')
indicating which interpolation method to use for the upsampling ops in
the FPN.
"""
"""
super
(
CenterNetMobileNetV2FPNFeatureExtractor
,
self
).
__init__
(
super
(
CenterNetMobileNetV2FPNFeatureExtractor
,
self
).
__init__
(
...
@@ -84,7 +88,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
...
@@ -84,7 +88,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
for
i
,
num_filters
in
enumerate
(
num_filters_list
):
for
i
,
num_filters
in
enumerate
(
num_filters_list
):
level_ind
=
len
(
num_filters_list
)
-
1
-
i
level_ind
=
len
(
num_filters_list
)
-
1
-
i
# Upsample.
# Upsample.
upsample_op
=
tf
.
keras
.
layers
.
UpSampling2D
(
2
,
interpolation
=
'nearest'
)
upsample_op
=
tf
.
keras
.
layers
.
UpSampling2D
(
2
,
interpolation
=
upsampling_interpolation
)
top_down
=
upsample_op
(
top_down
)
top_down
=
upsample_op
(
top_down
)
# Residual (skip-connection) from bottom-up pathway.
# Residual (skip-connection) from bottom-up pathway.
...
@@ -144,7 +149,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
...
@@ -144,7 +149,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
def
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
def
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
use_separable_conv
=
False
,
depth_multiplier
=
1.0
,
**
kwargs
):
use_separable_conv
=
False
,
depth_multiplier
=
1.0
,
upsampling_interpolation
=
'nearest'
,
**
kwargs
):
"""The MobileNetV2+FPN backbone for CenterNet."""
"""The MobileNetV2+FPN backbone for CenterNet."""
del
kwargs
del
kwargs
...
@@ -159,4 +165,5 @@ def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
...
@@ -159,4 +165,5 @@ def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
channel_means
=
channel_means
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
,
bgr_ordering
=
bgr_ordering
,
use_separable_conv
=
use_separable_conv
)
use_separable_conv
=
use_separable_conv
,
upsampling_interpolation
=
upsampling_interpolation
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment