Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
833e6939
Commit
833e6939
authored
Nov 06, 2019
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 278969067
parent
bf0dc049
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
3 deletions
+17
-3
official/vision/detection/main.py
official/vision/detection/main.py
+5
-0
official/vision/detection/modeling/base_model.py
official/vision/detection/modeling/base_model.py
+0
-1
official/vision/detection/modeling/postprocess.py
official/vision/detection/modeling/postprocess.py
+2
-1
official/vision/detection/modeling/retinanet_model.py
official/vision/detection/modeling/retinanet_model.py
+10
-1
No files found.
official/vision/detection/main.py
View file @
833e6939
...
...
@@ -64,6 +64,11 @@ def run_executor(params,
callbacks
=
None
):
"""Runs Retinanet model on distribution strategy defined by the user."""
if
params
.
architecture
.
use_bfloat16
:
policy
=
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
Policy
(
'mixed_bfloat16'
)
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
model_builder
=
model_factory
.
model_generator
(
params
)
if
FLAGS
.
mode
==
'train'
:
...
...
official/vision/detection/modeling/base_model.py
View file @
833e6939
...
...
@@ -85,7 +85,6 @@ class Model(object):
def
__init__
(
self
,
params
):
self
.
_use_bfloat16
=
params
.
architecture
.
use_bfloat16
assert
not
self
.
_use_bfloat16
,
'bfloat16 is not supported in Keras yet.'
# Optimization.
self
.
_optimizer_fn
=
OptimizerFactory
(
params
.
train
.
optimizer
)
...
...
official/vision/detection/modeling/postprocess.py
View file @
833e6939
...
...
@@ -318,7 +318,8 @@ class GenerateOneStageDetections(tf.keras.layers.Layer):
boxes
=
tf
.
expand_dims
(
boxes
,
axis
=
2
)
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
self
.
_generate_detections
(
boxes
,
scores
)
valid_detections
)
=
self
.
_generate_detections
(
tf
.
cast
(
boxes
,
tf
.
float32
),
tf
.
cast
(
scores
,
tf
.
float32
))
# Adds 1 to offset the background class which has index 0.
nmsed_classes
+=
1
return
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
official/vision/detection/modeling/retinanet_model.py
View file @
833e6939
...
...
@@ -92,7 +92,9 @@ class RetinanetModel(base_model.Model):
input_shape
=
(
params
.
retinanet_parser
.
output_size
+
[
params
.
retinanet_parser
.
num_channels
])
self
.
_input_layer
=
tf
.
keras
.
layers
.
Input
(
shape
=
input_shape
,
name
=
''
)
self
.
_input_layer
=
tf
.
keras
.
layers
.
Input
(
shape
=
input_shape
,
name
=
''
,
dtype
=
tf
.
bfloat16
if
self
.
_use_bfloat16
else
tf
.
float32
)
def
build_outputs
(
self
,
inputs
,
mode
):
backbone_features
=
self
.
_backbone_fn
(
...
...
@@ -101,6 +103,13 @@ class RetinanetModel(base_model.Model):
backbone_features
,
is_training
=
(
mode
==
mode_keys
.
TRAIN
))
cls_outputs
,
box_outputs
=
self
.
_head_fn
(
fpn_features
,
is_training
=
(
mode
==
mode_keys
.
TRAIN
))
if
self
.
_use_bfloat16
:
levels
=
cls_outputs
.
keys
()
for
level
in
levels
:
cls_outputs
[
level
]
=
tf
.
cast
(
cls_outputs
[
level
],
tf
.
float32
)
box_outputs
[
level
]
=
tf
.
cast
(
box_outputs
[
level
],
tf
.
float32
)
model_outputs
=
{
'cls_outputs'
:
cls_outputs
,
'box_outputs'
:
box_outputs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment