Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8b641b13
Unverified
Commit
8b641b13
authored
Mar 26, 2022
by
Srihari Humbarwadi
Committed by
GitHub
Mar 26, 2022
Browse files
Merge branch 'tensorflow:master' into panoptic-deeplab
parents
7cffacfe
357fa547
Changes
411
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
29 additions
and
251 deletions
+29
-251
official/vision/beta/projects/yolo/configs/yolo.py
official/vision/beta/projects/yolo/configs/yolo.py
+1
-1
official/vision/beta/projects/yolo/dataloaders/classification_input.py
...on/beta/projects/yolo/dataloaders/classification_input.py
+2
-2
official/vision/beta/projects/yolo/dataloaders/tf_example_decoder.py
...sion/beta/projects/yolo/dataloaders/tf_example_decoder.py
+1
-1
official/vision/beta/projects/yolo/dataloaders/yolo_input.py
official/vision/beta/projects/yolo/dataloaders/yolo_input.py
+4
-4
official/vision/beta/projects/yolo/modeling/backbones/darknet.py
...l/vision/beta/projects/yolo/modeling/backbones/darknet.py
+2
-1
official/vision/beta/projects/yolo/modeling/backbones/darknet_test.py
...ion/beta/projects/yolo/modeling/backbones/darknet_test.py
+0
-1
official/vision/beta/projects/yolo/modeling/decoders/yolo_decoder.py
...sion/beta/projects/yolo/modeling/decoders/yolo_decoder.py
+2
-2
official/vision/beta/projects/yolo/modeling/decoders/yolo_decoder_test.py
...beta/projects/yolo/modeling/decoders/yolo_decoder_test.py
+0
-1
official/vision/beta/projects/yolo/modeling/factory.py
official/vision/beta/projects/yolo/modeling/factory.py
+2
-2
official/vision/beta/projects/yolo/modeling/heads/yolo_head_test.py
...ision/beta/projects/yolo/modeling/heads/yolo_head_test.py
+0
-1
official/vision/beta/projects/yolo/modeling/layers/detection_generator.py
...beta/projects/yolo/modeling/layers/detection_generator.py
+1
-1
official/vision/beta/projects/yolo/modeling/layers/nn_blocks.py
...al/vision/beta/projects/yolo/modeling/layers/nn_blocks.py
+1
-1
official/vision/beta/projects/yolo/modeling/layers/nn_blocks_test.py
...sion/beta/projects/yolo/modeling/layers/nn_blocks_test.py
+0
-1
official/vision/beta/projects/yolo/ops/mosaic.py
official/vision/beta/projects/yolo/ops/mosaic.py
+3
-2
official/vision/beta/projects/yolo/ops/preprocessing_ops.py
official/vision/beta/projects/yolo/ops/preprocessing_ops.py
+1
-1
official/vision/beta/projects/yolo/ops/preprocessing_ops_test.py
...l/vision/beta/projects/yolo/ops/preprocessing_ops_test.py
+1
-1
official/vision/beta/projects/yolo/tasks/image_classification.py
...l/vision/beta/projects/yolo/tasks/image_classification.py
+4
-4
official/vision/beta/projects/yolo/tasks/yolo.py
official/vision/beta/projects/yolo/tasks/yolo.py
+4
-4
official/vision/beta/serving/__init__.py
official/vision/beta/serving/__init__.py
+0
-14
official/vision/beta/serving/detection.py
official/vision/beta/serving/detection.py
+0
-206
No files found.
Too many changes to show.
To preserve performance only
411 of 411+
files are displayed.
Plain diff
Email patch
official/vision/beta/projects/yolo/configs/yolo.py
View file @
8b641b13
...
@@ -22,10 +22,10 @@ import numpy as np
...
@@ -22,10 +22,10 @@ import numpy as np
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.vision.beta.configs
import
common
from
official.vision.beta.projects.yolo
import
optimization
from
official.vision.beta.projects.yolo
import
optimization
from
official.vision.beta.projects.yolo.configs
import
backbones
from
official.vision.beta.projects.yolo.configs
import
backbones
from
official.vision.beta.projects.yolo.configs
import
decoders
from
official.vision.beta.projects.yolo.configs
import
decoders
from
official.vision.configs
import
common
# pytype: disable=annotation-type-mismatch
# pytype: disable=annotation-type-mismatch
...
...
official/vision/beta/projects/yolo/dataloaders/classification_input.py
View file @
8b641b13
...
@@ -14,8 +14,8 @@
...
@@ -14,8 +14,8 @@
"""Classification decoder and parser."""
"""Classification decoder and parser."""
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.
beta.
dataloaders
import
classification_input
from
official.vision.dataloaders
import
classification_input
from
official.vision.
beta.
ops
import
preprocess_ops
from
official.vision.ops
import
preprocess_ops
class
Parser
(
classification_input
.
Parser
):
class
Parser
(
classification_input
.
Parser
):
...
...
official/vision/beta/projects/yolo/dataloaders/tf_example_decoder.py
View file @
8b641b13
...
@@ -19,7 +19,7 @@ protos for object detection.
...
@@ -19,7 +19,7 @@ protos for object detection.
"""
"""
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.
beta.
dataloaders
import
tf_example_decoder
from
official.vision.dataloaders
import
tf_example_decoder
def
_coco91_to_80
(
classif
,
box
,
areas
,
iscrowds
):
def
_coco91_to_80
(
classif
,
box
,
areas
,
iscrowds
):
...
...
official/vision/beta/projects/yolo/dataloaders/yolo_input.py
View file @
8b641b13
...
@@ -15,12 +15,12 @@
...
@@ -15,12 +15,12 @@
"""Detection Data parser and processing for YOLO."""
"""Detection Data parser and processing for YOLO."""
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.dataloaders
import
utils
from
official.vision.beta.ops
import
box_ops
as
bbox_ops
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.projects.yolo.ops
import
anchor
from
official.vision.beta.projects.yolo.ops
import
anchor
from
official.vision.beta.projects.yolo.ops
import
preprocessing_ops
from
official.vision.beta.projects.yolo.ops
import
preprocessing_ops
from
official.vision.dataloaders
import
parser
from
official.vision.dataloaders
import
utils
from
official.vision.ops
import
box_ops
as
bbox_ops
from
official.vision.ops
import
preprocess_ops
class
Parser
(
parser
.
Parser
):
class
Parser
(
parser
.
Parser
):
...
...
official/vision/beta/projects/yolo/modeling/backbones/darknet.py
View file @
8b641b13
...
@@ -36,11 +36,12 @@ Darknets are used mainly for object detection in:
...
@@ -36,11 +36,12 @@ Darknets are used mainly for object detection in:
"""
"""
import
collections
import
collections
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.projects.yolo.modeling.layers
import
nn_blocks
from
official.vision.beta.projects.yolo.modeling.layers
import
nn_blocks
from
official.vision.modeling.backbones
import
factory
class
BlockConfig
:
class
BlockConfig
:
...
...
official/vision/beta/projects/yolo/modeling/backbones/darknet_test.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Tests for yolo."""
"""Tests for yolo."""
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
...
...
official/vision/beta/projects/yolo/modeling/decoders/yolo_decoder.py
View file @
8b641b13
...
@@ -13,13 +13,13 @@
...
@@ -13,13 +13,13 @@
# limitations under the License.
# limitations under the License.
"""Feature Pyramid Network and Path Aggregation variants used in YOLO."""
"""Feature Pyramid Network and Path Aggregation variants used in YOLO."""
from
typing
import
Mapping
,
Union
,
Optional
from
typing
import
Mapping
,
Optional
,
Union
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.vision.beta.modeling.decoders
import
factory
from
official.vision.beta.projects.yolo.modeling.layers
import
nn_blocks
from
official.vision.beta.projects.yolo.modeling.layers
import
nn_blocks
from
official.vision.modeling.decoders
import
factory
# model configurations
# model configurations
# the structure is as follows. model version, {v3, v4, v#, ... etc}
# the structure is as follows. model version, {v3, v4, v#, ... etc}
...
...
official/vision/beta/projects/yolo/modeling/decoders/yolo_decoder_test.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Tests for YOLO."""
"""Tests for YOLO."""
# Import libraries
# Import libraries
...
...
official/vision/beta/projects/yolo/modeling/factory.py
View file @
8b641b13
...
@@ -15,13 +15,13 @@
...
@@ -15,13 +15,13 @@
"""Contains common factory functions yolo neural networks."""
"""Contains common factory functions yolo neural networks."""
from
absl
import
logging
from
absl
import
logging
from
official.vision.beta.modeling.backbones
import
factory
as
backbone_factory
from
official.vision.beta.modeling.decoders
import
factory
as
decoder_factory
from
official.vision.beta.projects.yolo.configs
import
yolo
from
official.vision.beta.projects.yolo.configs
import
yolo
from
official.vision.beta.projects.yolo.modeling
import
yolo_model
from
official.vision.beta.projects.yolo.modeling
import
yolo_model
from
official.vision.beta.projects.yolo.modeling.heads
import
yolo_head
from
official.vision.beta.projects.yolo.modeling.heads
import
yolo_head
from
official.vision.beta.projects.yolo.modeling.layers
import
detection_generator
from
official.vision.beta.projects.yolo.modeling.layers
import
detection_generator
from
official.vision.modeling.backbones
import
factory
as
backbone_factory
from
official.vision.modeling.decoders
import
factory
as
decoder_factory
def
build_yolo_detection_generator
(
model_config
:
yolo
.
Yolo
,
anchor_boxes
):
def
build_yolo_detection_generator
(
model_config
:
yolo
.
Yolo
,
anchor_boxes
):
...
...
official/vision/beta/projects/yolo/modeling/heads/yolo_head_test.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Tests for yolo heads."""
"""Tests for yolo heads."""
# Import libraries
# Import libraries
...
...
official/vision/beta/projects/yolo/modeling/layers/detection_generator.py
View file @
8b641b13
...
@@ -15,10 +15,10 @@
...
@@ -15,10 +15,10 @@
"""Contains common building blocks for yolo layer (detection layer)."""
"""Contains common building blocks for yolo layer (detection layer)."""
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.modeling.layers
import
detection_generator
from
official.vision.beta.projects.yolo.losses
import
yolo_loss
from
official.vision.beta.projects.yolo.losses
import
yolo_loss
from
official.vision.beta.projects.yolo.ops
import
box_ops
from
official.vision.beta.projects.yolo.ops
import
box_ops
from
official.vision.beta.projects.yolo.ops
import
loss_utils
from
official.vision.beta.projects.yolo.ops
import
loss_utils
from
official.vision.modeling.layers
import
detection_generator
class
YoloLayer
(
tf
.
keras
.
Model
):
class
YoloLayer
(
tf
.
keras
.
Model
):
...
...
official/vision/beta/projects/yolo/modeling/layers/nn_blocks.py
View file @
8b641b13
...
@@ -18,7 +18,7 @@ from typing import Callable, List, Tuple
...
@@ -18,7 +18,7 @@ from typing import Callable, List, Tuple
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.vision.
beta.
ops
import
spatial_transform_ops
from
official.vision.ops
import
spatial_transform_ops
class
Identity
(
tf
.
keras
.
layers
.
Layer
):
class
Identity
(
tf
.
keras
.
layers
.
Layer
):
...
...
official/vision/beta/projects/yolo/modeling/layers/nn_blocks_test.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
official/vision/beta/projects/yolo/ops/mosaic.py
View file @
8b641b13
...
@@ -14,12 +14,13 @@
...
@@ -14,12 +14,13 @@
"""Mosaic op."""
"""Mosaic op."""
import
random
import
random
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_addons
as
tfa
import
tensorflow_addons
as
tfa
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.projects.yolo.ops
import
preprocessing_ops
from
official.vision.beta.projects.yolo.ops
import
preprocessing_ops
from
official.vision.ops
import
box_ops
from
official.vision.ops
import
preprocess_ops
class
Mosaic
:
class
Mosaic
:
...
...
official/vision/beta/projects/yolo/ops/preprocessing_ops.py
View file @
8b641b13
...
@@ -19,7 +19,7 @@ import numpy as np
...
@@ -19,7 +19,7 @@ import numpy as np
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_addons
as
tfa
import
tensorflow_addons
as
tfa
from
official.vision.
beta.
ops
import
box_ops
as
bbox_ops
from
official.vision.ops
import
box_ops
as
bbox_ops
PAD_VALUE
=
114
PAD_VALUE
=
114
GLOBAL_SEED_SET
=
False
GLOBAL_SEED_SET
=
False
...
...
official/vision/beta/projects/yolo/ops/preprocessing_ops_test.py
View file @
8b641b13
...
@@ -17,8 +17,8 @@ from absl.testing import parameterized
...
@@ -17,8 +17,8 @@ from absl.testing import parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.ops
import
box_ops
as
bbox_ops
from
official.vision.beta.projects.yolo.ops
import
preprocessing_ops
from
official.vision.beta.projects.yolo.ops
import
preprocessing_ops
from
official.vision.ops
import
box_ops
as
bbox_ops
class
InputUtilsTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
InputUtilsTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/vision/beta/projects/yolo/tasks/image_classification.py
View file @
8b641b13
...
@@ -15,12 +15,12 @@
...
@@ -15,12 +15,12 @@
"""Image classification task definition."""
"""Image classification task definition."""
from
official.common
import
dataset_fn
from
official.common
import
dataset_fn
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.vision.beta.dataloaders
import
classification_input
as
classification_input_base
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
tfds_factory
from
official.vision.beta.projects.yolo.configs
import
darknet_classification
as
exp_cfg
from
official.vision.beta.projects.yolo.configs
import
darknet_classification
as
exp_cfg
from
official.vision.beta.projects.yolo.dataloaders
import
classification_input
from
official.vision.beta.projects.yolo.dataloaders
import
classification_input
from
official.vision.beta.tasks
import
image_classification
from
official.vision.dataloaders
import
classification_input
as
classification_input_base
from
official.vision.dataloaders
import
input_reader_factory
from
official.vision.dataloaders
import
tfds_factory
from
official.vision.tasks
import
image_classification
@
task_factory
.
register_task_cls
(
exp_cfg
.
ImageClassificationTask
)
@
task_factory
.
register_task_cls
(
exp_cfg
.
ImageClassificationTask
)
...
...
official/vision/beta/projects/yolo/tasks/yolo.py
View file @
8b641b13
...
@@ -26,10 +26,6 @@ from official.core import config_definitions
...
@@ -26,10 +26,6 @@ from official.core import config_definitions
from
official.core
import
input_reader
from
official.core
import
input_reader
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.vision.beta.dataloaders
import
tfds_factory
from
official.vision.beta.dataloaders
import
tf_example_label_map_decoder
from
official.vision.beta.evaluation
import
coco_evaluator
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.projects.yolo
import
optimization
from
official.vision.beta.projects.yolo
import
optimization
from
official.vision.beta.projects.yolo.configs
import
yolo
as
exp_cfg
from
official.vision.beta.projects.yolo.configs
import
yolo
as
exp_cfg
from
official.vision.beta.projects.yolo.dataloaders
import
tf_example_decoder
from
official.vision.beta.projects.yolo.dataloaders
import
tf_example_decoder
...
@@ -39,6 +35,10 @@ from official.vision.beta.projects.yolo.ops import kmeans_anchors
...
@@ -39,6 +35,10 @@ from official.vision.beta.projects.yolo.ops import kmeans_anchors
from
official.vision.beta.projects.yolo.ops
import
mosaic
from
official.vision.beta.projects.yolo.ops
import
mosaic
from
official.vision.beta.projects.yolo.ops
import
preprocessing_ops
from
official.vision.beta.projects.yolo.ops
import
preprocessing_ops
from
official.vision.beta.projects.yolo.tasks
import
task_utils
from
official.vision.beta.projects.yolo.tasks
import
task_utils
from
official.vision.dataloaders
import
tfds_factory
from
official.vision.dataloaders.google
import
tf_example_label_map_decoder
from
official.vision.evaluation
import
coco_evaluator
from
official.vision.ops
import
box_ops
OptimizationConfig
=
optimization
.
OptimizationConfig
OptimizationConfig
=
optimization
.
OptimizationConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
...
...
official/vision/beta/serving/__init__.py
deleted
100644 → 0
View file @
7cffacfe
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/beta/serving/detection.py
deleted
100644 → 0
View file @
7cffacfe
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Detection input and model functions for serving/inference."""
from
typing
import
Mapping
,
Text
import
tensorflow
as
tf
from
official.vision.beta
import
configs
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.ops
import
anchor
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.serving
import
export_base
MEAN_RGB
=
(
0.485
*
255
,
0.456
*
255
,
0.406
*
255
)
STDDEV_RGB
=
(
0.229
*
255
,
0.224
*
255
,
0.225
*
255
)
class
DetectionModule
(
export_base
.
ExportModule
):
"""Detection Module."""
def
_build_model
(
self
):
if
self
.
_batch_size
is
None
:
raise
ValueError
(
'batch_size cannot be None for detection models.'
)
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
self
.
_batch_size
]
+
self
.
_input_image_size
+
[
3
])
if
isinstance
(
self
.
params
.
task
.
model
,
configs
.
maskrcnn
.
MaskRCNN
):
model
=
factory
.
build_maskrcnn
(
input_specs
=
input_specs
,
model_config
=
self
.
params
.
task
.
model
)
elif
isinstance
(
self
.
params
.
task
.
model
,
configs
.
retinanet
.
RetinaNet
):
model
=
factory
.
build_retinanet
(
input_specs
=
input_specs
,
model_config
=
self
.
params
.
task
.
model
)
else
:
raise
ValueError
(
'Detection module not implemented for {} model.'
.
format
(
type
(
self
.
params
.
task
.
model
)))
return
model
def
_build_anchor_boxes
(
self
):
"""Builds and returns anchor boxes."""
model_params
=
self
.
params
.
task
.
model
input_anchor
=
anchor
.
build_anchor_generator
(
min_level
=
model_params
.
min_level
,
max_level
=
model_params
.
max_level
,
num_scales
=
model_params
.
anchor
.
num_scales
,
aspect_ratios
=
model_params
.
anchor
.
aspect_ratios
,
anchor_size
=
model_params
.
anchor
.
anchor_size
)
return
input_anchor
(
image_size
=
(
self
.
_input_image_size
[
0
],
self
.
_input_image_size
[
1
]))
def
_build_inputs
(
self
,
image
):
"""Builds detection model inputs for serving."""
model_params
=
self
.
params
.
task
.
model
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
image
,
image_info
=
preprocess_ops
.
resize_and_crop_image
(
image
,
self
.
_input_image_size
,
padded_size
=
preprocess_ops
.
compute_padded_size
(
self
.
_input_image_size
,
2
**
model_params
.
max_level
),
aug_scale_min
=
1.0
,
aug_scale_max
=
1.0
)
anchor_boxes
=
self
.
_build_anchor_boxes
()
return
image
,
anchor_boxes
,
image_info
def
preprocess
(
self
,
images
:
tf
.
Tensor
)
->
(
tf
.
Tensor
,
Mapping
[
Text
,
tf
.
Tensor
],
tf
.
Tensor
):
"""Preprocess inputs to be suitable for the model.
Args:
images: The images tensor.
Returns:
images: The images tensor cast to float.
anchor_boxes: Dict mapping anchor levels to anchor boxes.
image_info: Tensor containing the details of the image resizing.
"""
model_params
=
self
.
params
.
task
.
model
with
tf
.
device
(
'cpu:0'
):
images
=
tf
.
cast
(
images
,
dtype
=
tf
.
float32
)
# Tensor Specs for map_fn outputs (images, anchor_boxes, and image_info).
images_spec
=
tf
.
TensorSpec
(
shape
=
self
.
_input_image_size
+
[
3
],
dtype
=
tf
.
float32
)
num_anchors
=
model_params
.
anchor
.
num_scales
*
len
(
model_params
.
anchor
.
aspect_ratios
)
*
4
anchor_shapes
=
[]
for
level
in
range
(
model_params
.
min_level
,
model_params
.
max_level
+
1
):
anchor_level_spec
=
tf
.
TensorSpec
(
shape
=
[
self
.
_input_image_size
[
0
]
//
2
**
level
,
self
.
_input_image_size
[
1
]
//
2
**
level
,
num_anchors
],
dtype
=
tf
.
float32
)
anchor_shapes
.
append
((
str
(
level
),
anchor_level_spec
))
image_info_spec
=
tf
.
TensorSpec
(
shape
=
[
4
,
2
],
dtype
=
tf
.
float32
)
images
,
anchor_boxes
,
image_info
=
tf
.
nest
.
map_structure
(
tf
.
identity
,
tf
.
map_fn
(
self
.
_build_inputs
,
elems
=
images
,
fn_output_signature
=
(
images_spec
,
dict
(
anchor_shapes
),
image_info_spec
),
parallel_iterations
=
32
))
return
images
,
anchor_boxes
,
image_info
def
serve
(
self
,
images
:
tf
.
Tensor
):
"""Cast image to float and run inference.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
Tensor holding detection output logits.
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
if
self
.
_input_type
!=
'tflite'
:
images
,
anchor_boxes
,
image_info
=
self
.
preprocess
(
images
)
else
:
with
tf
.
device
(
'cpu:0'
):
anchor_boxes
=
self
.
_build_anchor_boxes
()
# image_info is a 3D tensor of shape [batch_size, 4, 2]. It is in the
# format of [[original_height, original_width],
# [desired_height, desired_width], [y_scale, x_scale],
# [y_offset, x_offset]]. When input_type is tflite, input image is
# supposed to be preprocessed already.
image_info
=
tf
.
convert_to_tensor
([[
self
.
_input_image_size
,
self
.
_input_image_size
,
[
1.0
,
1.0
],
[
0
,
0
]
]],
dtype
=
tf
.
float32
)
input_image_shape
=
image_info
[:,
1
,
:]
# To overcome keras.Model extra limitation to save a model with layers that
# have multiple inputs, we use `model.call` here to trigger the forward
# path. Note that, this disables some keras magics happens in `__call__`.
detections
=
self
.
model
.
call
(
images
=
images
,
image_shape
=
input_image_shape
,
anchor_boxes
=
anchor_boxes
,
training
=
False
)
if
self
.
params
.
task
.
model
.
detection_generator
.
apply_nms
:
# For RetinaNet model, apply export_config.
# TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
if
isinstance
(
self
.
params
.
task
.
model
,
configs
.
retinanet
.
RetinaNet
):
export_config
=
self
.
params
.
task
.
export_config
# Normalize detection box coordinates to [0, 1].
if
export_config
.
output_normalized_coordinates
:
detection_boxes
=
(
detections
[
'detection_boxes'
]
/
tf
.
tile
(
image_info
[:,
2
:
3
,
:],
[
1
,
1
,
2
]))
detections
[
'detection_boxes'
]
=
box_ops
.
normalize_boxes
(
detection_boxes
,
image_info
[:,
0
:
1
,
:])
# Cast num_detections and detection_classes to float. This allows the
# model inference to work on chain (go/chain) as chain requires floating
# point outputs.
if
export_config
.
cast_num_detections_to_float
:
detections
[
'num_detections'
]
=
tf
.
cast
(
detections
[
'num_detections'
],
dtype
=
tf
.
float32
)
if
export_config
.
cast_detection_classes_to_float
:
detections
[
'detection_classes'
]
=
tf
.
cast
(
detections
[
'detection_classes'
],
dtype
=
tf
.
float32
)
final_outputs
=
{
'detection_boxes'
:
detections
[
'detection_boxes'
],
'detection_scores'
:
detections
[
'detection_scores'
],
'detection_classes'
:
detections
[
'detection_classes'
],
'num_detections'
:
detections
[
'num_detections'
]
}
else
:
final_outputs
=
{
'decoded_boxes'
:
detections
[
'decoded_boxes'
],
'decoded_box_scores'
:
detections
[
'decoded_box_scores'
]
}
if
'detection_masks'
in
detections
.
keys
():
final_outputs
[
'detection_masks'
]
=
detections
[
'detection_masks'
]
final_outputs
.
update
({
'image_info'
:
image_info
})
return
final_outputs
Prev
1
…
16
17
18
19
20
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment