Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
02d00c0c
Commit
02d00c0c
authored
Aug 04, 2022
by
A. Unique TensorFlower
Browse files
Enable multiple inputs for tf.Vision RetinaNet model.
PiperOrigin-RevId: 465434866
parent
d4efa810
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
6 deletions
+17
-6
official/vision/modeling/retinanet_model.py
official/vision/modeling/retinanet_model.py
+17
-6
No files found.
official/vision/modeling/retinanet_model.py
View file @
02d00c0c
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
"""RetinaNet."""
from
typing
import
Any
,
Mapping
,
List
,
Optional
,
Union
from
typing
import
Any
,
Mapping
,
List
,
Optional
,
Union
,
Sequence
# Import libraries
import
tensorflow
as
tf
...
...
@@ -74,7 +74,7 @@ class RetinaNetModel(tf.keras.Model):
self
.
_detection_generator
=
detection_generator
def
call
(
self
,
images
:
tf
.
Tensor
,
images
:
Union
[
tf
.
Tensor
,
Sequence
[
tf
.
Tensor
]],
image_shape
:
Optional
[
tf
.
Tensor
]
=
None
,
anchor_boxes
:
Optional
[
Mapping
[
str
,
tf
.
Tensor
]]
=
None
,
output_intermediate_features
:
bool
=
False
,
...
...
@@ -82,8 +82,10 @@ class RetinaNetModel(tf.keras.Model):
"""Forward pass of the RetinaNet model.
Args:
images: `Tensor`, the input batched images, whose shape is
[batch, height, width, 3].
images: `Tensor` or a sequence of `Tensor`, the input batched images to
the backbone network, whose shape(s) is [batch, height, width, 3]. If it
is a sequence of `Tensor`, we will assume the anchors are generated
based on the shape of the first image(s).
image_shape: `Tensor`, the actual shape of the input images, whose shape
is [batch, 2] where the last dimension is [height, width]. Note that
this is the actual image shape excluding paddings. For example, images
...
...
@@ -141,7 +143,16 @@ class RetinaNetModel(tf.keras.Model):
else
:
# Generate anchor boxes for this batch if not provided.
if
anchor_boxes
is
None
:
_
,
image_height
,
image_width
,
_
=
images
.
get_shape
().
as_list
()
if
isinstance
(
images
,
Sequence
):
primary_images
=
images
[
0
]
elif
isinstance
(
images
,
tf
.
Tensor
):
primary_images
=
images
else
:
raise
ValueError
(
'Input should be a tf.Tensor or a sequence of tf.Tensor, not {}.'
.
format
(
type
(
images
)))
_
,
image_height
,
image_width
,
_
=
primary_images
.
get_shape
().
as_list
()
anchor_boxes
=
anchor
.
Anchor
(
min_level
=
self
.
_config_dict
[
'min_level'
],
max_level
=
self
.
_config_dict
[
'max_level'
],
...
...
@@ -152,7 +163,7 @@ class RetinaNetModel(tf.keras.Model):
for
l
in
anchor_boxes
:
anchor_boxes
[
l
]
=
tf
.
tile
(
tf
.
expand_dims
(
anchor_boxes
[
l
],
axis
=
0
),
[
tf
.
shape
(
images
)[
0
],
1
,
1
,
1
])
[
tf
.
shape
(
primary_
images
)[
0
],
1
,
1
,
1
])
# Post-processing.
final_results
=
self
.
detection_generator
(
raw_boxes
,
raw_scores
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment