ModelZoo / ResNet50_tensorflow / Commits

Commit 977ae4db
Authored May 05, 2021 by Hongkun Yu; committed by A. Unique TensorFlower, May 05, 2021

Internal change

PiperOrigin-RevId: 372153802
Parent: 7907ba50

5 changed files with 1058 additions and 0 deletions (+1058 -0):

  official/nlp/serving/export_savedmodel.py       +141 -0
  official/nlp/serving/export_savedmodel_test.py  +164 -0
  official/nlp/serving/export_savedmodel_util.py   +45 -0
  official/nlp/serving/serving_modules.py         +391 -0
  official/nlp/serving/serving_modules_test.py    +317 -0
official/nlp/serving/export_savedmodel.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A binary/library to export TF-NLP serving `SavedModel`."""
import os
from typing import Any, Dict, Text

from absl import app
from absl import flags
import dataclasses
import yaml

from official.core import base_task
from official.core import task_factory
from official.modeling import hyperparams
from official.modeling.hyperparams import base_config
from official.nlp.serving import export_savedmodel_util
from official.nlp.serving import serving_modules
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging

FLAGS = flags.FLAGS

SERVING_MODULES = {
    sentence_prediction.SentencePredictionTask:
        serving_modules.SentencePrediction,
    masked_lm.MaskedLMTask:
        serving_modules.MaskedLM,
    question_answering.QuestionAnsweringTask:
        serving_modules.QuestionAnswering,
    tagging.TaggingTask:
        serving_modules.Tagging
}
def define_flags():
  """Defines flags."""
  flags.DEFINE_string("task_name", "SentencePrediction", "The task to export.")
  flags.DEFINE_string("config_file", None,
                      "The path to the task/experiment YAML config file.")
  flags.DEFINE_string(
      "checkpoint_path", None,
      "Object-based checkpoint path, from the training model directory.")
  flags.DEFINE_string("export_savedmodel_dir", None,
                      "Output SavedModel directory.")
  flags.DEFINE_string(
      "serving_params", None,
      "A YAML/JSON string or CSV string for the serving parameters.")
  flags.DEFINE_string(
      "function_keys", None,
      "A string key to retrieve pre-defined serving signatures.")
  flags.DEFINE_bool(
      "convert_tpu", False,
      "Whether to also convert the exported SavedModel for TPU inference.")
  flags.DEFINE_multi_integer("allowed_batch_size", None,
                             "Allowed batch sizes for batching ops.")
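# For orientation, an invocation of this exporter might look like the
# following sketch (not part of the original file; all paths and values are
# hypothetical placeholders, and the CSV form of --serving_params assumes
# hyperparams.nested_csv_str_to_json_str accepts key=value pairs):
#
#   python3 export_savedmodel.py \
#     --task_name=SentencePrediction \
#     --config_file=/path/to/experiment.yaml \
#     --checkpoint_path=/path/to/ckpt-123 \
#     --export_savedmodel_dir=/tmp/exported_model \
#     --serving_params="inputs_only=true,parse_sequence_length=128" \
#     --function_keys=serve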
def lookup_export_module(task: base_task.Task):
  export_module_cls = SERVING_MODULES.get(task.__class__, None)
  if export_module_cls is None:
    raise ValueError("No registered export module for the task: %s" %
                     task.__class__)
  return export_module_cls
def create_export_module(*, task_name: Text, config_file: Text,
                         serving_params: Dict[Text, Any]):
  """Creates an ExportModule."""
  task_config_cls = None
  task_cls = None
  # pylint: disable=protected-access
  for key, value in task_factory._REGISTERED_TASK_CLS.items():
    print(key.__name__)
    if task_name in key.__name__:
      task_config_cls, task_cls = key, value
      break
  if task_cls is None:
    raise ValueError("Failed to identify the task class. The provided task "
                     f"name is {task_name}")
  # pylint: enable=protected-access
  # TODO(hongkuny): Figure out how to separate the task config from experiments.

  @dataclasses.dataclass
  class Dummy(base_config.Config):
    task: task_config_cls = task_config_cls()

  dummy_exp = Dummy()
  dummy_exp = hyperparams.override_params_dict(
      dummy_exp, config_file, is_strict=False)
  dummy_exp.task.validation_data = None
  task = task_cls(dummy_exp.task)
  model = task.build_model()
  export_module_cls = lookup_export_module(task)
  params = export_module_cls.Params(**serving_params)
  return export_module_cls(params=params, model=model)
def main(_):
  serving_params = yaml.load(
      hyperparams.nested_csv_str_to_json_str(FLAGS.serving_params),
      Loader=yaml.FullLoader)
  export_module = create_export_module(
      task_name=FLAGS.task_name,
      config_file=FLAGS.config_file,
      serving_params=serving_params)
  export_dir = export_savedmodel_util.export(
      export_module,
      function_keys=[FLAGS.function_keys],
      checkpoint_path=FLAGS.checkpoint_path,
      export_savedmodel_dir=FLAGS.export_savedmodel_dir)

  if FLAGS.convert_tpu:
    # pylint: disable=g-import-not-at-top
    from cloud_tpu.inference_converter import converter_cli
    from cloud_tpu.inference_converter import converter_options_pb2
    tpu_dir = os.path.join(export_dir, "tpu")
    options = converter_options_pb2.ConverterOptions()
    if FLAGS.allowed_batch_size is not None:
      allowed_batch_sizes = sorted(FLAGS.allowed_batch_size)
      options.batch_options.num_batch_threads = 4
      options.batch_options.max_batch_size = allowed_batch_sizes[-1]
      options.batch_options.batch_timeout_micros = 100000
      options.batch_options.allowed_batch_sizes[:] = allowed_batch_sizes
      options.batch_options.max_enqueued_batches = 1000
    converter_cli.ConvertSavedModel(
        export_dir,
        tpu_dir,
        function_alias="tpu_candidate",
        options=options,
        graph_rewrite_only=True)


if __name__ == "__main__":
  define_flags()
  app.run(main)
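The same flow can be driven as a library rather than through the flags-based
binary. A minimal sketch using the functions defined in this commit (paths
are hypothetical placeholders, not from the commit):

  from official.nlp.serving import export_savedmodel
  from official.nlp.serving import export_savedmodel_util

  module = export_savedmodel.create_export_module(
      task_name="SentencePrediction",
      config_file="/path/to/experiment.yaml",  # placeholder path
      serving_params={"inputs_only": True, "parse_sequence_length": 128})
  export_dir = export_savedmodel_util.export(
      module,
      function_keys=["serve"],
      checkpoint_path="/path/to/ckpt-123",  # placeholder path
      export_savedmodel_dir="/tmp/exported_model")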
official/nlp/serving/export_savedmodel_test.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.export_saved_model."""
from absl.testing import parameterized
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import export_savedmodel
from official.nlp.serving import export_savedmodel_util
from official.nlp.tasks import masked_lm
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging


class ExportSavedModelTest(tf.test.TestCase, parameterized.TestCase):

  def test_create_export_module(self):
    export_module = export_savedmodel.create_export_module(
        task_name="SentencePrediction",
        config_file=None,
        serving_params={
            "inputs_only": False,
            "parse_sequence_length": 10
        })
    self.assertEqual(export_module.name, "sentence_prediction")
    self.assertFalse(export_module.params.inputs_only)
    self.assertEqual(export_module.params.parse_sequence_length, 10)

  def test_sentence_prediction(self):
    config = sentence_prediction.SentencePredictionConfig(
        model=sentence_prediction.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            num_classes=2))
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_path = ckpt.save(self.get_temp_dir())
    export_module_cls = export_savedmodel.lookup_export_module(task)
    serving_params = {"inputs_only": False}
    params = export_module_cls.Params(**serving_params)
    export_module = export_module_cls(params=params, model=model)
    export_dir = export_savedmodel_util.export(
        export_module,
        function_keys=["serve"],
        checkpoint_path=ckpt_path,
        export_savedmodel_dir=self.get_temp_dir())
    imported = tf.saved_model.load(export_dir)
    serving_fn = imported.signatures["serving_default"]

    dummy_ids = tf.ones((1, 5), dtype=tf.int32)
    inputs = dict(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    ref_outputs = model(inputs)
    outputs = serving_fn(**inputs)
    self.assertAllClose(ref_outputs, outputs["outputs"])
    self.assertEqual(outputs["outputs"].shape, (1, 2))

  def test_masked_lm(self):
    config = masked_lm.MaskedLMConfig(
        model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            cls_heads=[
                bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="foo")
            ]))
    task = masked_lm.MaskedLMTask(config)
    model = task.build_model()
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_path = ckpt.save(self.get_temp_dir())
    export_module_cls = export_savedmodel.lookup_export_module(task)
    serving_params = {
        "cls_head_name": "foo",
        "parse_sequence_length": 10,
        "max_predictions_per_seq": 5
    }
    params = export_module_cls.Params(**serving_params)
    export_module = export_module_cls(params=params, model=model)
    export_dir = export_savedmodel_util.export(
        export_module,
        function_keys={
            "serve": "serving_default",
            "serve_examples": "serving_examples"
        },
        checkpoint_path=ckpt_path,
        export_savedmodel_dir=self.get_temp_dir())
    imported = tf.saved_model.load(export_dir)
    self.assertSameElements(imported.signatures.keys(),
                            ["serving_default", "serving_examples"])
    serving_fn = imported.signatures["serving_default"]
    dummy_ids = tf.ones((1, 10), dtype=tf.int32)
    dummy_pos = tf.ones((1, 5), dtype=tf.int32)
    outputs = serving_fn(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids,
        masked_lm_positions=dummy_pos)
    self.assertEqual(outputs["classification"].shape, (1, 2))

  @parameterized.parameters(True, False)
  def test_tagging(self, output_encoder_outputs):
    hidden_size = 768
    num_classes = 3
    config = tagging.TaggingConfig(
        model=tagging.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(
                    hidden_size=hidden_size, num_layers=1))),
        class_names=["class_0", "class_1", "class_2"])
    task = tagging.TaggingTask(config)
    model = task.build_model()
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_path = ckpt.save(self.get_temp_dir())
    export_module_cls = export_savedmodel.lookup_export_module(task)
    serving_params = {
        "parse_sequence_length": 10,
    }
    params = export_module_cls.Params(
        **serving_params, output_encoder_outputs=output_encoder_outputs)
    export_module = export_module_cls(params=params, model=model)
    export_dir = export_savedmodel_util.export(
        export_module,
        function_keys={
            "serve": "serving_default",
            "serve_examples": "serving_examples"
        },
        checkpoint_path=ckpt_path,
        export_savedmodel_dir=self.get_temp_dir())
    imported = tf.saved_model.load(export_dir)
    self.assertCountEqual(imported.signatures.keys(),
                          ["serving_default", "serving_examples"])
    serving_fn = imported.signatures["serving_default"]
    dummy_ids = tf.ones((1, 5), dtype=tf.int32)
    inputs = dict(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    outputs = serving_fn(**inputs)
    self.assertEqual(outputs["logits"].shape, (1, 5, num_classes))
    if output_encoder_outputs:
      self.assertEqual(outputs["encoder_outputs"].shape, (1, 5, hidden_size))


if __name__ == "__main__":
  tf.test.main()
official/nlp/serving/export_savedmodel_util.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common library to export a SavedModel from the export module."""
from typing import Dict, List, Optional, Text, Union

import tensorflow as tf
from official.core import export_base


def export(export_module: export_base.ExportModule,
           function_keys: Union[List[Text], Dict[Text, Text]],
           export_savedmodel_dir: Text,
           checkpoint_path: Optional[Text] = None,
           timestamped: bool = True) -> Text:
"""Exports to SavedModel format.
Args:
export_module: a ExportModule with the keras Model and serving tf.functions.
function_keys: a list of string keys to retrieve pre-defined serving
signatures. The signaute keys will be set with defaults. If a dictionary
is provided, the values will be used as signature keys.
export_savedmodel_dir: Output saved model directory.
checkpoint_path: Object-based checkpoint path or directory.
timestamped: Whether to export the savedmodel to a timestamped directory.
Returns:
The savedmodel directory path.
"""
  save_options = tf.saved_model.SaveOptions(function_aliases={
      "tpu_candidate": export_module.serve,
  })
  return export_base.export(export_module, function_keys,
                            export_savedmodel_dir, checkpoint_path,
                            timestamped, save_options)
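As the docstring notes, function_keys accepts either a list (default
signature keys assigned by export_base) or a dict mapping function keys to
signature keys. A minimal sketch of the two forms, both of which appear in
the tests in this commit (export_module, ckpt_path, and the output directory
are placeholders):

  # List form: default signature keys.
  export(export_module, function_keys=["serve"],
         export_savedmodel_dir="/tmp/out", checkpoint_path=ckpt_path)

  # Dict form: values become the exported signature keys.
  export(export_module,
         function_keys={"serve": "serving_default",
                        "serve_examples": "serving_examples"},
         export_savedmodel_dir="/tmp/out", checkpoint_path=ckpt_path)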
official/nlp/serving/serving_modules.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring
from typing import Dict, List, Optional, Text
import dataclasses
import tensorflow as tf
from official.core import export_base
from official.modeling.hyperparams import base_config
from official.nlp.data import sentence_prediction_dataloader
def features_to_int32(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Converts tf.int64 features to tf.int32; keeps other features the same.

  tf.Example only supports tf.int64, but the TPU only supports tf.int32.

  Args:
    features: Input tensor dictionary.

  Returns:
    Features with tf.int64 converted to tf.int32.
  """
  converted_features = {}
  for name, tensor in features.items():
    if tensor.dtype == tf.int64:
      converted_features[name] = tf.cast(tensor, tf.int32)
    else:
      converted_features[name] = tensor
  return converted_features
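# A quick illustration of the conversion above (a sketch, not part of the
# original file): int64 features are cast, everything else passes through.
#
#   features = {
#       "input_word_ids": tf.constant([[1, 2, 3]], dtype=tf.int64),
#       "scores": tf.constant([[0.5]], dtype=tf.float32),
#   }
#   converted = features_to_int32(features)
#   # converted["input_word_ids"].dtype == tf.int32; "scores" is unchanged.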
class SentencePrediction(export_base.ExportModule):
  """The export module for the sentence prediction task."""

  @dataclasses.dataclass
  class Params(base_config.Config):
    inputs_only: bool = True
    parse_sequence_length: Optional[int] = None
    use_v2_feature_names: bool = True

    # For text input processing.
    text_fields: Optional[List[str]] = None
    # Either specify these values for preprocessing by Python code...
    tokenization: str = "WordPiece"  # WordPiece or SentencePiece
    # Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
    # file if tokenization is SentencePiece.
    vocab_file: str = ""
    lower_case: bool = True
    # ...or load preprocessing from a SavedModel at this location.
    preprocessing_hub_module_url: str = ""

  def __init__(self, params, model: tf.keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

    if params.text_fields:
      self._text_processor = sentence_prediction_dataloader.TextProcessor(
          seq_length=params.parse_sequence_length,
          vocab_file=params.vocab_file,
          tokenization=params.tokenization,
          lower_case=params.lower_case,
          preprocessing_hub_module_url=params.preprocessing_hub_module_url)

  @tf.function
  def serve(self, input_word_ids, input_mask=None,
            input_type_ids=None) -> Dict[str, tf.Tensor]:
    if input_type_ids is None:
      # Requires the CLS token to be the first token of the inputs.
      input_type_ids = tf.zeros_like(input_word_ids)
    if input_mask is None:
      # The mask has 1 for real tokens and 0 for padding tokens.
      input_mask = tf.where(
          tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
          tf.ones_like(input_word_ids))
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids)
    return dict(outputs=self.inference_step(inputs))

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    inputs_only = self.params.inputs_only
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
    }
    if not inputs_only:
      name_to_features.update({
          "input_mask":
              tf.io.FixedLenFeature([sequence_length], tf.int64),
          self.input_type_ids_field:
              tf.io.FixedLenFeature([sequence_length], tf.int64)
      })
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        features[self.input_word_ids_field],
        input_mask=None if inputs_only else features["input_mask"],
        input_type_ids=None
        if inputs_only else features[self.input_type_ids_field])

  @tf.function
  def serve_text_examples(self, inputs) -> Dict[str, tf.Tensor]:
    name_to_features = {}
    for text_field in self.params.text_fields:
      name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
    features = tf.io.parse_example(inputs, name_to_features)
    segments = [features[x] for x in self.params.text_fields]
    model_inputs = self._text_processor(segments)
    if self.params.inputs_only:
      return self.serve(input_word_ids=model_inputs["input_word_ids"])
    return self.serve(**model_inputs)

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples", "serve_text_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        if self.params.inputs_only:
          signatures[signature_key] = self.serve.get_concrete_function(
              input_word_ids=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_word_ids"))
        else:
          signatures[signature_key] = self.serve.get_concrete_function(
              input_word_ids=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_word_ids"),
              input_mask=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_mask"),
              input_type_ids=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_type_ids"))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
      if func_key == "serve_text_examples":
        signatures[
            signature_key] = self.serve_text_examples.get_concrete_function(
                tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures
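# Putting SentencePrediction together, a sketch of building signatures and
# calling one directly (not part of the original file; `model` stands in for
# a task-built classifier, as in the unit tests):
#
#   params = SentencePrediction.Params(
#       inputs_only=True, parse_sequence_length=128)
#   module = SentencePrediction(params=params, model=model)
#   signatures = module.get_inference_signatures({"serve": "serving_default"})
#   outputs = signatures["serving_default"](
#       tf.ones((2, 128), dtype=tf.int32))["outputs"]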
class MaskedLM(export_base.ExportModule):
  """The export module for the Bert Pretrain (MaskedLM) task."""

  def __init__(self, params, model: tf.keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

  @dataclasses.dataclass
  class Params(base_config.Config):
    cls_head_name: str = "next_sentence"
    use_v2_feature_names: bool = True
    parse_sequence_length: Optional[int] = None
    max_predictions_per_seq: Optional[int] = None

  @tf.function
  def serve(self, input_word_ids, input_mask, input_type_ids,
            masked_lm_positions) -> Dict[str, tf.Tensor]:
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids,
        masked_lm_positions=masked_lm_positions)
    outputs = self.inference_step(inputs)
    return dict(classification=outputs[self.params.cls_head_name])

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    max_predictions_per_seq = self.params.max_predictions_per_seq
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "input_mask":
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        self.input_type_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "masked_lm_positions":
            tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64)
    }
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field],
        masked_lm_positions=features["masked_lm_positions"])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_word_ids"),
            input_mask=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_mask"),
            input_type_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_type_ids"),
            masked_lm_positions=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32,
                name="masked_lm_positions"))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures
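# Note that MaskedLM.serve returns only the classification head selected by
# Params.cls_head_name. A sketch (not part of the original file; the head
# name must match a cls_heads entry in the pretrainer config, and
# `pretrainer_model` is a placeholder for a task-built model):
#
#   params = MaskedLM.Params(
#       cls_head_name="next_sentence",
#       parse_sequence_length=128,
#       max_predictions_per_seq=20)
#   module = MaskedLM(params=params, model=pretrainer_model)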
class QuestionAnswering(export_base.ExportModule):
  """The export module for the question answering task."""

  @dataclasses.dataclass
  class Params(base_config.Config):
    parse_sequence_length: Optional[int] = None
    use_v2_feature_names: bool = True

  def __init__(self, params, model: tf.keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

  @tf.function
  def serve(self, input_word_ids, input_mask=None,
            input_type_ids=None) -> Dict[str, tf.Tensor]:
    if input_mask is None:
      # The mask has 1 for real tokens and 0 for padding tokens.
      input_mask = tf.where(
          tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
          tf.ones_like(input_word_ids))
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids)
    outputs = self.inference_step(inputs)
    return dict(start_logits=outputs[0], end_logits=outputs[1])

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "input_mask":
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        self.input_type_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64)
    }
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_word_ids"),
            input_mask=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_mask"),
            input_type_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_type_ids"))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures
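# QuestionAnswering.serve maps the model's positional outputs to named
# tensors. A sketch of calling the default signature (not part of the
# original file; `span_model`, `ids`, `mask`, and `segs` are placeholders):
#
#   params = QuestionAnswering.Params(parse_sequence_length=384)
#   module = QuestionAnswering(params=params, model=span_model)
#   sigs = module.get_inference_signatures({"serve": "serving_default"})
#   out = sigs["serving_default"](
#       input_word_ids=ids, input_mask=mask, input_type_ids=segs)
#   # out["start_logits"] and out["end_logits"]: [batch, seq_len]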
class Tagging(export_base.ExportModule):
  """The export module for the tagging task."""

  @dataclasses.dataclass
  class Params(base_config.Config):
    parse_sequence_length: Optional[int] = None
    use_v2_feature_names: bool = True
    output_encoder_outputs: bool = False

  def __init__(self, params, model: tf.keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

  @tf.function
  def serve(self, input_word_ids, input_mask,
            input_type_ids) -> Dict[str, tf.Tensor]:
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids)
    outputs = self.inference_step(inputs)
    if self.params.output_encoder_outputs:
      return dict(
          logits=outputs["logits"],
          encoder_outputs=outputs["encoder_outputs"])
    else:
      return dict(logits=outputs["logits"])

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "input_mask":
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        self.input_type_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64)
    }
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32,
                name=self.input_word_ids_field),
            input_mask=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_mask"),
            input_type_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32,
                name=self.input_type_ids_field))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures
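With output_encoder_outputs=True, the Tagging signature returns the encoder's
sequence outputs alongside the per-token logits. A minimal sketch (assuming a
task-built tagging model and placeholder input tensors; shapes follow the
assertions in the tests below):

  params = serving_modules.Tagging.Params(
      parse_sequence_length=128, output_encoder_outputs=True)
  module = serving_modules.Tagging(params=params, model=tagging_model)
  sigs = module.get_inference_signatures({"serve": "serving_default"})
  out = sigs["serving_default"](
      input_word_ids=ids, input_mask=mask, input_type_ids=segs)
  # out["logits"]: [batch, seq_len, num_classes]
  # out["encoder_outputs"]: [batch, seq_len, hidden_size]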
official/nlp/serving/serving_modules_test.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.serving.serving_modules."""
import os

from absl.testing import parameterized
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import serving_modules
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging


def _create_fake_serialized_examples(features_dict):
  """Creates a fake dataset."""

  def create_int_feature(values):
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f

  def create_str_feature(value):
    f = tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    return f

  examples = []
  for _ in range(10):
    features = {}
    for key, values in features_dict.items():
      if isinstance(values, bytes):
        features[key] = create_str_feature(values)
      else:
        features[key] = create_int_feature(values)
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    examples.append(tf_example.SerializeToString())
  return tf.constant(examples)


def _create_fake_vocab_file(vocab_file_path):
  tokens = ["[PAD]"]
  for i in range(1, 100):
    tokens.append("[unused%d]" % i)
  tokens.extend(["[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world"])
  with tf.io.gfile.GFile(vocab_file_path, "w") as outfile:
    outfile.write("\n".join(tokens))


class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      # use_v2_feature_names
      True,
      False)
  def test_sentence_prediction(self, use_v2_feature_names):
    if use_v2_feature_names:
      input_word_ids_field = "input_word_ids"
      input_type_ids_field = "input_type_ids"
    else:
      input_word_ids_field = "input_ids"
      input_type_ids_field = "segment_ids"

    config = sentence_prediction.SentencePredictionConfig(
        model=sentence_prediction.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            num_classes=2))
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    params = serving_modules.SentencePrediction.Params(
        inputs_only=True,
        parse_sequence_length=10,
        use_v2_feature_names=use_v2_feature_names)
    export_module = serving_modules.SentencePrediction(
        params=params, model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    self.assertSameElements(functions.keys(),
                            ["serving_default", "serving_examples"])
    dummy_ids = tf.ones((10, 10), dtype=tf.int32)
    outputs = functions["serving_default"](dummy_ids)
    self.assertEqual(outputs["outputs"].shape, (10, 2))

    params = serving_modules.SentencePrediction.Params(
        inputs_only=False,
        parse_sequence_length=10,
        use_v2_feature_names=use_v2_feature_names)
    export_module = serving_modules.SentencePrediction(
        params=params, model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    outputs = functions["serving_default"](
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    self.assertEqual(outputs["outputs"].shape, (10, 2))

    dummy_ids = tf.ones((10,), dtype=tf.int32)
    examples = _create_fake_serialized_examples({
        input_word_ids_field: dummy_ids,
        "input_mask": dummy_ids,
        input_type_ids_field: dummy_ids
    })
    outputs = functions["serving_examples"](examples)
    self.assertEqual(outputs["outputs"].shape, (10, 2))

    with self.assertRaises(ValueError):
      _ = export_module.get_inference_signatures({"foo": None})

  @parameterized.parameters(
      # inputs_only
      True,
      False)
  def test_sentence_prediction_text(self, inputs_only):
    vocab_file_path = os.path.join(self.get_temp_dir(), "vocab.txt")
    _create_fake_vocab_file(vocab_file_path)
    config = sentence_prediction.SentencePredictionConfig(
        model=sentence_prediction.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            num_classes=2))
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    params = serving_modules.SentencePrediction.Params(
        inputs_only=inputs_only,
        parse_sequence_length=10,
        text_fields=["foo", "bar"],
        vocab_file=vocab_file_path)
    export_module = serving_modules.SentencePrediction(
        params=params, model=model)
    examples = _create_fake_serialized_examples({
        "foo": b"hello world",
        "bar": b"hello world"
    })
    functions = export_module.get_inference_signatures({
        "serve_text_examples": "serving_default",
    })
    outputs = functions["serving_default"](examples)
    self.assertEqual(outputs["outputs"].shape, (10, 2))

  @parameterized.parameters(
      # use_v2_feature_names
      True,
      False)
  def test_masked_lm(self, use_v2_feature_names):
    if use_v2_feature_names:
      input_word_ids_field = "input_word_ids"
      input_type_ids_field = "input_type_ids"
    else:
      input_word_ids_field = "input_ids"
      input_type_ids_field = "segment_ids"
    config = masked_lm.MaskedLMConfig(
        model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]))
    task = masked_lm.MaskedLMTask(config)
    model = task.build_model()
    params = serving_modules.MaskedLM.Params(
        parse_sequence_length=10,
        max_predictions_per_seq=5,
        use_v2_feature_names=use_v2_feature_names)
    export_module = serving_modules.MaskedLM(params=params, model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    self.assertSameElements(functions.keys(),
                            ["serving_default", "serving_examples"])
    dummy_ids = tf.ones((10, 10), dtype=tf.int32)
    dummy_pos = tf.ones((10, 5), dtype=tf.int32)
    outputs = functions["serving_default"](
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids,
        masked_lm_positions=dummy_pos)
    self.assertEqual(outputs["classification"].shape, (10, 2))

    dummy_ids = tf.ones((10,), dtype=tf.int32)
    dummy_pos = tf.ones((5,), dtype=tf.int32)
    examples = _create_fake_serialized_examples({
        input_word_ids_field: dummy_ids,
        "input_mask": dummy_ids,
        input_type_ids_field: dummy_ids,
        "masked_lm_positions": dummy_pos
    })
    outputs = functions["serving_examples"](examples)
    self.assertEqual(outputs["classification"].shape, (10, 2))

  @parameterized.parameters(
      # use_v2_feature_names
      True,
      False)
  def test_question_answering(self, use_v2_feature_names):
    if use_v2_feature_names:
      input_word_ids_field = "input_word_ids"
      input_type_ids_field = "input_type_ids"
    else:
      input_word_ids_field = "input_ids"
      input_type_ids_field = "segment_ids"
    config = question_answering.QuestionAnsweringConfig(
        model=question_answering.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1))),
        validation_data=None)
    task = question_answering.QuestionAnsweringTask(config)
    model = task.build_model()
    params = serving_modules.QuestionAnswering.Params(
        parse_sequence_length=10,
        use_v2_feature_names=use_v2_feature_names)
    export_module = serving_modules.QuestionAnswering(
        params=params, model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    self.assertSameElements(functions.keys(),
                            ["serving_default", "serving_examples"])
    dummy_ids = tf.ones((10, 10), dtype=tf.int32)
    outputs = functions["serving_default"](
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    self.assertEqual(outputs["start_logits"].shape, (10, 10))
    self.assertEqual(outputs["end_logits"].shape, (10, 10))

    dummy_ids = tf.ones((10,), dtype=tf.int32)
    examples = _create_fake_serialized_examples({
        input_word_ids_field: dummy_ids,
        "input_mask": dummy_ids,
        input_type_ids_field: dummy_ids
    })
    outputs = functions["serving_examples"](examples)
    self.assertEqual(outputs["start_logits"].shape, (10, 10))
    self.assertEqual(outputs["end_logits"].shape, (10, 10))

  @parameterized.parameters(
      # (use_v2_feature_names, output_encoder_outputs)
      (True, True),
      (False, False))
  def test_tagging(self, use_v2_feature_names, output_encoder_outputs):
    if use_v2_feature_names:
      input_word_ids_field = "input_word_ids"
      input_type_ids_field = "input_type_ids"
    else:
      input_word_ids_field = "input_ids"
      input_type_ids_field = "segment_ids"

    hidden_size = 768
    num_classes = 3
    config = tagging.TaggingConfig(
        model=tagging.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(
                    hidden_size=hidden_size, num_layers=1))),
        class_names=["class_0", "class_1", "class_2"])
    task = tagging.TaggingTask(config)
    model = task.build_model()

    params = serving_modules.Tagging.Params(
        parse_sequence_length=10,
        use_v2_feature_names=use_v2_feature_names,
        output_encoder_outputs=output_encoder_outputs)
    export_module = serving_modules.Tagging(params=params, model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    dummy_ids = tf.ones((10, 10), dtype=tf.int32)
    outputs = functions["serving_default"](
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    self.assertEqual(outputs["logits"].shape, (10, 10, num_classes))
    if output_encoder_outputs:
      self.assertEqual(outputs["encoder_outputs"].shape,
                       (10, 10, hidden_size))

    dummy_ids = tf.ones((10,), dtype=tf.int32)
    examples = _create_fake_serialized_examples({
        input_word_ids_field: dummy_ids,
        "input_mask": dummy_ids,
        input_type_ids_field: dummy_ids
    })
    outputs = functions["serving_examples"](examples)
    self.assertEqual(outputs["logits"].shape, (10, 10, num_classes))
    if output_encoder_outputs:
      self.assertEqual(outputs["encoder_outputs"].shape,
                       (10, 10, hidden_size))

    with self.assertRaises(ValueError):
      _ = export_module.get_inference_signatures({"foo": None})


if __name__ == "__main__":
  tf.test.main()