Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
eb6e0ac4
Commit
eb6e0ac4
authored
May 09, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 447598773
parent
d801c959
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
76 additions
and
7 deletions
+76
-7
official/vision/beta/projects/yolo/serving/export_module_factory.py
...ision/beta/projects/yolo/serving/export_module_factory.py
+76
-7
No files found.
official/vision/beta/projects/yolo/serving/export_module_factory.py
View file @
eb6e0ac4
...
@@ -14,11 +14,12 @@
...
@@ -14,11 +14,12 @@
"""Factory for YOLO export modules."""
"""Factory for YOLO export modules."""
from
typing
import
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Text
,
Union
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.core
import
export_base
from
official.vision
import
configs
from
official.vision
import
configs
from
official.vision.beta.projects.yolo.configs.yolo
import
YoloTask
from
official.vision.beta.projects.yolo.configs.yolo
import
YoloTask
from
official.vision.beta.projects.yolo.modeling
import
factory
as
yolo_factory
from
official.vision.beta.projects.yolo.modeling
import
factory
as
yolo_factory
...
@@ -27,16 +28,84 @@ from official.vision.beta.projects.yolo.modeling.decoders import yolo_decoder #
...
@@ -27,16 +28,84 @@ from official.vision.beta.projects.yolo.modeling.decoders import yolo_decoder #
from
official.vision.beta.projects.yolo.serving
import
model_fn
as
yolo_model_fn
from
official.vision.beta.projects.yolo.serving
import
model_fn
as
yolo_model_fn
from
official.vision.dataloaders
import
classification_input
from
official.vision.dataloaders
import
classification_input
from
official.vision.modeling
import
factory
from
official.vision.modeling
import
factory
from
official.vision.serving
import
export_base_v2
as
export_base
from
official.vision.serving
import
export_utils
from
official.vision.serving
import
export_utils
class
ExportModule
(
export_base
.
ExportModule
):
"""Base Export Module."""
def
__init__
(
self
,
params
:
cfg
.
ExperimentConfig
,
model
:
tf
.
keras
.
Model
,
input_signature
:
Union
[
tf
.
TensorSpec
,
Dict
[
str
,
tf
.
TensorSpec
]],
preprocessor
:
Optional
[
Callable
[...,
Any
]]
=
None
,
inference_step
:
Optional
[
Callable
[...,
Any
]]
=
None
,
postprocessor
:
Optional
[
Callable
[...,
Any
]]
=
None
,
eval_postprocessor
:
Optional
[
Callable
[...,
Any
]]
=
None
):
"""Initializes a module for export.
Args:
params: A dataclass for parameters to the module.
model: A tf.keras.Model instance to be exported.
input_signature: tf.TensorSpec, e.g.
tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
preprocessor: An optional callable to preprocess the inputs.
inference_step: An optional callable to forward-pass the model.
postprocessor: An optional callable to postprocess the model outputs.
eval_postprocessor: An optional callable to postprocess model outputs
used for model evaluation.
"""
super
().
__init__
(
params
,
model
=
model
,
preprocessor
=
preprocessor
,
inference_step
=
inference_step
,
postprocessor
=
postprocessor
)
self
.
input_signature
=
input_signature
@
tf
.
function
def
serve
(
self
,
inputs
:
Any
)
->
Any
:
x
=
self
.
preprocessor
(
inputs
=
inputs
)
if
self
.
preprocessor
else
inputs
x
=
self
.
inference_step
(
x
)
x
=
self
.
postprocessor
(
x
)
if
self
.
postprocessor
else
x
return
x
@
tf
.
function
def
serve_eval
(
self
,
inputs
:
Any
)
->
Any
:
x
=
self
.
preprocessor
(
inputs
=
inputs
)
if
self
.
preprocessor
else
inputs
x
=
self
.
inference_step
(
x
)
x
=
self
.
eval_postprocessor
(
x
)
if
self
.
eval_postprocessor
else
x
return
x
def
get_inference_signatures
(
self
,
function_keys
:
Dict
[
Text
,
Text
])
->
dict
[
Text
,
Any
]:
"""Gets defined function signatures.
Args:
function_keys: A dictionary with keys as the function to create signature
for and values as the signature keys when returns.
Returns:
A dictionary with key as signature key and value as concrete functions
that can be used for tf.saved_model.save.
"""
signatures
=
{}
for
_
,
def_name
in
function_keys
.
items
():
if
'eval'
in
def_name
and
self
.
eval_postprocessor
:
signatures
[
def_name
]
=
self
.
serve_eval
.
get_concrete_function
(
self
.
input_signature
)
else
:
signatures
[
def_name
]
=
self
.
serve
.
get_concrete_function
(
self
.
input_signature
)
return
signatures
def
create_classification_export_module
(
def
create_classification_export_module
(
params
:
cfg
.
ExperimentConfig
,
params
:
cfg
.
ExperimentConfig
,
input_type
:
str
,
input_type
:
str
,
batch_size
:
int
,
batch_size
:
int
,
input_image_size
:
List
[
int
],
input_image_size
:
List
[
int
],
num_channels
:
int
=
3
)
->
export_base
.
ExportModule
:
num_channels
:
int
=
3
)
->
ExportModule
:
"""Creates classification export module."""
"""Creates classification export module."""
input_signature
=
export_utils
.
get_image_input_signatures
(
input_signature
=
export_utils
.
get_image_input_signatures
(
input_type
,
batch_size
,
input_image_size
,
num_channels
)
input_type
,
batch_size
,
input_image_size
,
num_channels
)
...
@@ -71,7 +140,7 @@ def create_classification_export_module(
...
@@ -71,7 +140,7 @@ def create_classification_export_module(
probs
=
tf
.
nn
.
softmax
(
logits
)
probs
=
tf
.
nn
.
softmax
(
logits
)
return
{
'logits'
:
logits
,
'probs'
:
probs
}
return
{
'logits'
:
logits
,
'probs'
:
probs
}
export_module
=
export_base
.
ExportModule
(
export_module
=
ExportModule
(
params
,
params
,
model
=
model
,
model
=
model
,
input_signature
=
input_signature
,
input_signature
=
input_signature
,
...
@@ -85,7 +154,7 @@ def create_yolo_export_module(
...
@@ -85,7 +154,7 @@ def create_yolo_export_module(
input_type
:
str
,
input_type
:
str
,
batch_size
:
int
,
batch_size
:
int
,
input_image_size
:
List
[
int
],
input_image_size
:
List
[
int
],
num_channels
:
int
=
3
)
->
export_base
.
ExportModule
:
num_channels
:
int
=
3
)
->
ExportModule
:
"""Creates YOLO export module."""
"""Creates YOLO export module."""
input_signature
=
export_utils
.
get_image_input_signatures
(
input_signature
=
export_utils
.
get_image_input_signatures
(
input_type
,
batch_size
,
input_image_size
,
num_channels
)
input_type
,
batch_size
,
input_image_size
,
num_channels
)
...
@@ -144,7 +213,7 @@ def create_yolo_export_module(
...
@@ -144,7 +213,7 @@ def create_yolo_export_module(
return
final_outputs
return
final_outputs
export_module
=
export_base
.
ExportModule
(
export_module
=
ExportModule
(
params
,
params
,
model
=
model
,
model
=
model
,
input_signature
=
input_signature
,
input_signature
=
input_signature
,
...
@@ -158,7 +227,7 @@ def get_export_module(params: cfg.ExperimentConfig,
...
@@ -158,7 +227,7 @@ def get_export_module(params: cfg.ExperimentConfig,
input_type
:
str
,
input_type
:
str
,
batch_size
:
Optional
[
int
],
batch_size
:
Optional
[
int
],
input_image_size
:
List
[
int
],
input_image_size
:
List
[
int
],
num_channels
:
int
=
3
)
->
export_base
.
ExportModule
:
num_channels
:
int
=
3
)
->
ExportModule
:
"""Factory for export modules."""
"""Factory for export modules."""
if
isinstance
(
params
.
task
,
if
isinstance
(
params
.
task
,
configs
.
image_classification
.
ImageClassificationTask
):
configs
.
image_classification
.
ImageClassificationTask
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment