Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
706a0bd9
Commit
706a0bd9
authored
Feb 21, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 296470381
parent
a2e95a60
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
15 deletions
+20
-15
official/vision/detection/configs/maskrcnn_config.py
official/vision/detection/configs/maskrcnn_config.py
+3
-1
official/vision/detection/configs/retinanet_config.py
official/vision/detection/configs/retinanet_config.py
+3
-3
official/vision/detection/modeling/base_model.py
official/vision/detection/modeling/base_model.py
+13
-9
official/vision/detection/modeling/retinanet_model.py
official/vision/detection/modeling/retinanet_model.py
+1
-2
No files found.
official/vision/detection/configs/maskrcnn_config.py
View file @
706a0bd9
...
@@ -14,8 +14,9 @@
...
@@ -14,8 +14,9 @@
# ==============================================================================
# ==============================================================================
"""Config template to train Mask R-CNN."""
"""Config template to train Mask R-CNN."""
from
official.vision.detection.configs
import
base_config
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.hyperparams
import
params_dict
from
official.vision.detection.configs
import
base_config
# pylint: disable=line-too-long
# pylint: disable=line-too-long
MASKRCNN_CFG
=
params_dict
.
ParamsDict
(
base_config
.
BASE_CFG
)
MASKRCNN_CFG
=
params_dict
.
ParamsDict
(
base_config
.
BASE_CFG
)
...
@@ -23,6 +24,7 @@ MASKRCNN_CFG.override({
...
@@ -23,6 +24,7 @@ MASKRCNN_CFG.override({
'type'
:
'mask_rcnn'
,
'type'
:
'mask_rcnn'
,
'eval'
:
{
'eval'
:
{
'type'
:
'box_and_mask'
,
'type'
:
'box_and_mask'
,
'num_images_to_visualize'
:
0
,
},
},
'architecture'
:
{
'architecture'
:
{
'parser'
:
'maskrcnn_parser'
,
'parser'
:
'maskrcnn_parser'
,
...
...
official/vision/detection/configs/retinanet_config.py
View file @
706a0bd9
...
@@ -23,9 +23,8 @@
...
@@ -23,9 +23,8 @@
# need to be fine-tuned for the detection task.
# need to be fine-tuned for the detection task.
# Note that we need to trailing `/` to avoid the incorrect match.
# Note that we need to trailing `/` to avoid the incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET50_FROZEN_VAR_PREFIX
=
r
'(resnet\d+/)conv2d(|_([1-9]|10))\/'
RESNET_FROZEN_VAR_PREFIX
=
r
'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
RESNET_FROZEN_VAR_PREFIX
=
r
'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
REGULARIZATION_VAR_REGEX
=
r
'.*(kernel|weight):0$'
# pylint: disable=line-too-long
# pylint: disable=line-too-long
RETINANET_CFG
=
{
RETINANET_CFG
=
{
...
@@ -54,10 +53,11 @@ RETINANET_CFG = {
...
@@ -54,10 +53,11 @@ RETINANET_CFG = {
'path'
:
''
,
'path'
:
''
,
'prefix'
:
''
,
'prefix'
:
''
,
},
},
'frozen_variable_prefix'
:
RESNET
50
_FROZEN_VAR_PREFIX
,
'frozen_variable_prefix'
:
RESNET_FROZEN_VAR_PREFIX
,
'train_file_pattern'
:
''
,
'train_file_pattern'
:
''
,
# TODO(b/142174042): Support transpose_input option.
# TODO(b/142174042): Support transpose_input option.
'transpose_input'
:
False
,
'transpose_input'
:
False
,
'regularization_variable_regex'
:
REGULARIZATION_VAR_REGEX
,
'l2_weight_decay'
:
0.0001
,
'l2_weight_decay'
:
0.0001
,
'input_sharding'
:
False
,
'input_sharding'
:
False
,
},
},
...
...
official/vision/detection/modeling/base_model.py
View file @
706a0bd9
...
@@ -18,11 +18,9 @@ from __future__ import absolute_import
...
@@ -18,11 +18,9 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
abc
import
abc
import
functools
import
functools
import
re
import
re
import
six
from
absl
import
logging
from
absl
import
logging
import
tensorflow.compat.v2
as
tf
import
tensorflow.compat.v2
as
tf
...
@@ -53,7 +51,7 @@ class OptimizerFactory(object):
...
@@ -53,7 +51,7 @@ class OptimizerFactory(object):
self
.
_optimizer
=
tf
.
keras
.
optimizers
.
Adagrad
self
.
_optimizer
=
tf
.
keras
.
optimizers
.
Adagrad
elif
params
.
type
==
'rmsprop'
:
elif
params
.
type
==
'rmsprop'
:
self
.
_optimizer
=
functools
.
partial
(
self
.
_optimizer
=
functools
.
partial
(
tf
.
keras
.
optimizers
.
RMS
P
rop
,
momentum
=
params
.
momentum
)
tf
.
keras
.
optimizers
.
RMS
p
rop
,
momentum
=
params
.
momentum
)
else
:
else
:
raise
ValueError
(
'Unsupported optimizer type %s.'
%
self
.
_optimizer
)
raise
ValueError
(
'Unsupported optimizer type %s.'
%
self
.
_optimizer
)
...
@@ -104,6 +102,7 @@ class Model(object):
...
@@ -104,6 +102,7 @@ class Model(object):
params
.
train
.
learning_rate
)
params
.
train
.
learning_rate
)
self
.
_frozen_variable_prefix
=
params
.
train
.
frozen_variable_prefix
self
.
_frozen_variable_prefix
=
params
.
train
.
frozen_variable_prefix
self
.
_regularization_var_regex
=
params
.
train
.
regularization_variable_regex
self
.
_l2_weight_decay
=
params
.
train
.
l2_weight_decay
self
.
_l2_weight_decay
=
params
.
train
.
l2_weight_decay
# Checkpoint restoration.
# Checkpoint restoration.
...
@@ -146,12 +145,17 @@ class Model(object):
...
@@ -146,12 +145,17 @@ class Model(object):
"""
"""
return
_make_filter_trainable_variables_fn
(
self
.
_frozen_variable_prefix
)
return
_make_filter_trainable_variables_fn
(
self
.
_frozen_variable_prefix
)
def
weight_decay_loss
(
self
,
l2_weight_decay
,
trainable_variables
):
def
weight_decay_loss
(
self
,
trainable_variables
):
return
l2_weight_decay
*
tf
.
add_n
([
reg_variables
=
[
tf
.
nn
.
l2_loss
(
v
)
v
for
v
in
trainable_variables
for
v
in
trainable_variables
if
self
.
_regularization_var_regex
is
None
if
'batch_normalization'
not
in
v
.
name
and
'bias'
not
in
v
.
name
or
re
.
match
(
self
.
_regularization_var_regex
,
v
.
name
)
])
]
logging
.
info
(
'Regularization Variables: %s'
,
[
v
.
name
for
v
in
reg_variables
])
return
self
.
_l2_weight_decay
*
tf
.
add_n
(
[
tf
.
nn
.
l2_loss
(
v
)
for
v
in
reg_variables
])
def
make_restore_checkpoint_fn
(
self
):
def
make_restore_checkpoint_fn
(
self
):
"""Returns scaffold function to restore parameters from v1 checkpoint."""
"""Returns scaffold function to restore parameters from v1 checkpoint."""
...
...
official/vision/detection/modeling/retinanet_model.py
View file @
706a0bd9
...
@@ -106,8 +106,7 @@ class RetinanetModel(base_model.Model):
...
@@ -106,8 +106,7 @@ class RetinanetModel(base_model.Model):
labels
[
'box_targets'
],
labels
[
'box_targets'
],
labels
[
'num_positives'
])
labels
[
'num_positives'
])
model_loss
=
cls_loss
+
self
.
_box_loss_weight
*
box_loss
model_loss
=
cls_loss
+
self
.
_box_loss_weight
*
box_loss
l2_regularization_loss
=
self
.
weight_decay_loss
(
self
.
_l2_weight_decay
,
l2_regularization_loss
=
self
.
weight_decay_loss
(
trainable_variables
)
trainable_variables
)
total_loss
=
model_loss
+
l2_regularization_loss
total_loss
=
model_loss
+
l2_regularization_loss
return
{
return
{
'total_loss'
:
total_loss
,
'total_loss'
:
total_loss
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment