Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
706a0bd9
Commit
706a0bd9
authored
Feb 21, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 296470381
parent
a2e95a60
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
15 deletions
+20
-15
official/vision/detection/configs/maskrcnn_config.py
official/vision/detection/configs/maskrcnn_config.py
+3
-1
official/vision/detection/configs/retinanet_config.py
official/vision/detection/configs/retinanet_config.py
+3
-3
official/vision/detection/modeling/base_model.py
official/vision/detection/modeling/base_model.py
+13
-9
official/vision/detection/modeling/retinanet_model.py
official/vision/detection/modeling/retinanet_model.py
+1
-2
No files found.
official/vision/detection/configs/maskrcnn_config.py
View file @
706a0bd9
...
...
@@ -14,8 +14,9 @@
# ==============================================================================
"""Config template to train Mask R-CNN."""
from
official.vision.detection.configs
import
base_config
from
official.modeling.hyperparams
import
params_dict
from
official.vision.detection.configs
import
base_config
# pylint: disable=line-too-long
MASKRCNN_CFG
=
params_dict
.
ParamsDict
(
base_config
.
BASE_CFG
)
...
...
@@ -23,6 +24,7 @@ MASKRCNN_CFG.override({
'type'
:
'mask_rcnn'
,
'eval'
:
{
'type'
:
'box_and_mask'
,
'num_images_to_visualize'
:
0
,
},
'architecture'
:
{
'parser'
:
'maskrcnn_parser'
,
...
...
official/vision/detection/configs/retinanet_config.py
View file @
706a0bd9
...
...
@@ -23,9 +23,8 @@
# need to be fine-tuned for the detection task.
# Note that we need to trailing `/` to avoid the incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET50_FROZEN_VAR_PREFIX
=
r
'(resnet\d+/)conv2d(|_([1-9]|10))\/'
RESNET_FROZEN_VAR_PREFIX
=
r
'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
REGULARIZATION_VAR_REGEX
=
r
'.*(kernel|weight):0$'
# pylint: disable=line-too-long
RETINANET_CFG
=
{
...
...
@@ -54,10 +53,11 @@ RETINANET_CFG = {
'path'
:
''
,
'prefix'
:
''
,
},
'frozen_variable_prefix'
:
RESNET
50
_FROZEN_VAR_PREFIX
,
'frozen_variable_prefix'
:
RESNET_FROZEN_VAR_PREFIX
,
'train_file_pattern'
:
''
,
# TODO(b/142174042): Support transpose_input option.
'transpose_input'
:
False
,
'regularization_variable_regex'
:
REGULARIZATION_VAR_REGEX
,
'l2_weight_decay'
:
0.0001
,
'input_sharding'
:
False
,
},
...
...
official/vision/detection/modeling/base_model.py
View file @
706a0bd9
...
...
@@ -18,11 +18,9 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
abc
import
functools
import
re
import
six
from
absl
import
logging
import
tensorflow.compat.v2
as
tf
...
...
@@ -53,7 +51,7 @@ class OptimizerFactory(object):
self
.
_optimizer
=
tf
.
keras
.
optimizers
.
Adagrad
elif
params
.
type
==
'rmsprop'
:
self
.
_optimizer
=
functools
.
partial
(
tf
.
keras
.
optimizers
.
RMS
P
rop
,
momentum
=
params
.
momentum
)
tf
.
keras
.
optimizers
.
RMS
p
rop
,
momentum
=
params
.
momentum
)
else
:
raise
ValueError
(
'Unsupported optimizer type %s.'
%
self
.
_optimizer
)
...
...
@@ -104,6 +102,7 @@ class Model(object):
params
.
train
.
learning_rate
)
self
.
_frozen_variable_prefix
=
params
.
train
.
frozen_variable_prefix
self
.
_regularization_var_regex
=
params
.
train
.
regularization_variable_regex
self
.
_l2_weight_decay
=
params
.
train
.
l2_weight_decay
# Checkpoint restoration.
...
...
@@ -146,12 +145,17 @@ class Model(object):
"""
return
_make_filter_trainable_variables_fn
(
self
.
_frozen_variable_prefix
)
def
weight_decay_loss
(
self
,
l2_weight_decay
,
trainable_variables
):
return
l2_weight_decay
*
tf
.
add_n
([
tf
.
nn
.
l2_loss
(
v
)
for
v
in
trainable_variables
if
'batch_normalization'
not
in
v
.
name
and
'bias'
not
in
v
.
name
])
def
weight_decay_loss
(
self
,
trainable_variables
):
reg_variables
=
[
v
for
v
in
trainable_variables
if
self
.
_regularization_var_regex
is
None
or
re
.
match
(
self
.
_regularization_var_regex
,
v
.
name
)
]
logging
.
info
(
'Regularization Variables: %s'
,
[
v
.
name
for
v
in
reg_variables
])
return
self
.
_l2_weight_decay
*
tf
.
add_n
(
[
tf
.
nn
.
l2_loss
(
v
)
for
v
in
reg_variables
])
def
make_restore_checkpoint_fn
(
self
):
"""Returns scaffold function to restore parameters from v1 checkpoint."""
...
...
official/vision/detection/modeling/retinanet_model.py
View file @
706a0bd9
...
...
@@ -106,8 +106,7 @@ class RetinanetModel(base_model.Model):
labels
[
'box_targets'
],
labels
[
'num_positives'
])
model_loss
=
cls_loss
+
self
.
_box_loss_weight
*
box_loss
l2_regularization_loss
=
self
.
weight_decay_loss
(
self
.
_l2_weight_decay
,
trainable_variables
)
l2_regularization_loss
=
self
.
weight_decay_loss
(
trainable_variables
)
total_loss
=
model_loss
+
l2_regularization_loss
return
{
'total_loss'
:
total_loss
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment