Commit 706a0bd9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 296470381
parent a2e95a60
......@@ -14,8 +14,9 @@
# ==============================================================================
"""Config template to train Mask R-CNN."""
from official.vision.detection.configs import base_config
from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import base_config
# pylint: disable=line-too-long
MASKRCNN_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
......@@ -23,6 +24,7 @@ MASKRCNN_CFG.override({
'type': 'mask_rcnn',
'eval': {
'type': 'box_and_mask',
'num_images_to_visualize': 0,
},
'architecture': {
'parser': 'maskrcnn_parser',
......
......@@ -23,9 +23,8 @@
# need to be fine-tuned for the detection task.
# Note that we need the trailing `/` to avoid an incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
# Old pattern: freezes only the backbone conv kernels (conv2d, conv2d_1 ...
# conv2d_10); the trailing `\/` stops e.g. `conv2d_11` from matching.
RESNET50_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
# Replacement pattern: additionally freezes the corresponding batch
# normalization layers (batch_normalization, batch_normalization_1 ...
# batch_normalization_10) alongside the conv layers.
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
# Selects kernel/weight variables (names ending in `kernel:0` or `weight:0`)
# for L2 regularization; biases and batch-norm params are thereby excluded.
REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
# pylint: disable=line-too-long
RETINANET_CFG = {
......@@ -54,10 +53,11 @@ RETINANET_CFG = {
'path': '',
'prefix': '',
},
'frozen_variable_prefix': RESNET50_FROZEN_VAR_PREFIX,
'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
'train_file_pattern': '',
# TODO(b/142174042): Support transpose_input option.
'transpose_input': False,
'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
'l2_weight_decay': 0.0001,
'input_sharding': False,
},
......
......@@ -18,11 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import functools
import re
import six
from absl import logging
import tensorflow.compat.v2 as tf
......@@ -53,7 +51,7 @@ class OptimizerFactory(object):
self._optimizer = tf.keras.optimizers.Adagrad
elif params.type == 'rmsprop':
self._optimizer = functools.partial(
tf.keras.optimizers.RMSProp, momentum=params.momentum)
tf.keras.optimizers.RMSprop, momentum=params.momentum)
else:
raise ValueError('Unsupported optimizer type %s.' % self._optimizer)
......@@ -104,6 +102,7 @@ class Model(object):
params.train.learning_rate)
self._frozen_variable_prefix = params.train.frozen_variable_prefix
self._regularization_var_regex = params.train.regularization_variable_regex
self._l2_weight_decay = params.train.l2_weight_decay
# Checkpoint restoration.
......@@ -146,12 +145,17 @@ class Model(object):
"""
return _make_filter_trainable_variables_fn(self._frozen_variable_prefix)
def weight_decay_loss(self, l2_weight_decay, trainable_variables):
  """Computes the scaled L2 regularization loss over trainable variables.

  Variables whose names contain 'batch_normalization' or 'bias' are left
  out of the penalty; every remaining variable contributes tf.nn.l2_loss.

  Args:
    l2_weight_decay: float scalar multiplier applied to the summed penalty.
    trainable_variables: iterable of variables to (potentially) regularize.

  Returns:
    l2_weight_decay * sum of tf.nn.l2_loss over the regularized variables.
  """

  def _is_regularized(variable):
    # Skip normalization parameters and biases.
    name = variable.name
    return 'batch_normalization' not in name and 'bias' not in name

  penalties = [
      tf.nn.l2_loss(v) for v in trainable_variables if _is_regularized(v)
  ]
  return l2_weight_decay * tf.add_n(penalties)
def weight_decay_loss(self, trainable_variables):
  """Returns the L2 weight-decay loss for the regularized variables.

  A variable is regularized when `self._regularization_var_regex` is None
  (regularize everything) or its name matches the regex via `re.match`.

  Args:
    trainable_variables: iterable of trainable variables to consider.

  Returns:
    A scalar: `self._l2_weight_decay` times the sum of `tf.nn.l2_loss`
    over the regularized variables, or 0.0 when no variable matches.
  """
  reg_variables = [
      v for v in trainable_variables
      if self._regularization_var_regex is None
      or re.match(self._regularization_var_regex, v.name)
  ]
  logging.info('Regularization Variables: %s',
               [v.name for v in reg_variables])
  if not reg_variables:
    # tf.add_n raises ValueError on an empty list; no matching variables
    # simply means no weight-decay penalty.
    return 0.0
  return self._l2_weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in reg_variables])
def make_restore_checkpoint_fn(self):
"""Returns scaffold function to restore parameters from v1 checkpoint."""
......
......@@ -106,8 +106,7 @@ class RetinanetModel(base_model.Model):
labels['box_targets'],
labels['num_positives'])
model_loss = cls_loss + self._box_loss_weight * box_loss
l2_regularization_loss = self.weight_decay_loss(self._l2_weight_decay,
trainable_variables)
l2_regularization_loss = self.weight_decay_loss(trainable_variables)
total_loss = model_loss + l2_regularization_loss
return {
'total_loss': total_loss,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment