Commit 706a0bd9 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 296470381
parent a2e95a60
@@ -14,8 +14,9 @@
 # ==============================================================================
 """Config template to train Mask R-CNN."""
 
-from official.vision.detection.configs import base_config
 from official.modeling.hyperparams import params_dict
+from official.vision.detection.configs import base_config
 
 # pylint: disable=line-too-long
 MASKRCNN_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
@@ -23,6 +24,7 @@ MASKRCNN_CFG.override({
     'type': 'mask_rcnn',
     'eval': {
         'type': 'box_and_mask',
+        'num_images_to_visualize': 0,
     },
     'architecture': {
         'parser': 'maskrcnn_parser',
...
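
For context on the second hunk: `num_images_to_visualize` is a new eval-time key, defaulting to 0 (no visualization). A minimal sketch of how such a key is layered onto a `ParamsDict`, with a toy base dict standing in for `base_config.BASE_CFG`:

from official.modeling.hyperparams import params_dict

# Toy stand-in for base_config.BASE_CFG; only the piece relevant here.
cfg = params_dict.ParamsDict({'eval': {'type': 'box_and_mask'}})

# is_strict=False lets the override introduce a key the base dict does
# not define yet, which is how a newly added field gets populated.
cfg.override({'eval': {'num_images_to_visualize': 2}}, is_strict=False)

print(cfg.eval.num_images_to_visualize)  # -> 2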
@@ -23,9 +23,8 @@
 # need to be fine-tuned for the detection task.
 # Note that we need the trailing `/` to avoid an incorrect match.
 # [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
-RESNET50_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
 RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
+REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
 
 # pylint: disable=line-too-long
 RETINANET_CFG = {
@@ -54,10 +53,11 @@ RETINANET_CFG = {
             'path': '',
             'prefix': '',
         },
-        'frozen_variable_prefix': RESNET50_FROZEN_VAR_PREFIX,
+        'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
         'train_file_pattern': '',
         # TODO(b/142174042): Support transpose_input option.
         'transpose_input': False,
+        'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
         'l2_weight_decay': 0.0001,
         'input_sharding': False,
     },
...
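
A quick sanity check of the two patterns above, using hand-written variable names in the Keras naming scheme they assume: the widened frozen prefix now covers batch normalization variables as well as convolutions (the removed `RESNET50_FROZEN_VAR_PREFIX` matched only `conv2d*`), and the new `REGULARIZATION_VAR_REGEX` selects kernels and weights only, reproducing the old hard-coded "skip batch norm and bias" filter in `weight_decay_loss` below.

import re

RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'

# Frozen: conv2d, conv2d_1, ..., conv2d_10 and the batch_normalization
# counterparts (for standard ResNet-50 naming, the stem plus the first
# residual stage), but nothing from conv2d_11 onward.
assert re.match(RESNET_FROZEN_VAR_PREFIX, 'resnet50/conv2d/kernel:0')
assert re.match(RESNET_FROZEN_VAR_PREFIX, 'resnet50/batch_normalization_10/gamma:0')
assert not re.match(RESNET_FROZEN_VAR_PREFIX, 'resnet50/conv2d_11/kernel:0')

# Regularized: kernels and weights only; biases and batch norm
# parameters no longer receive weight decay.
assert re.match(REGULARIZATION_VAR_REGEX, 'resnet50/conv2d_11/kernel:0')
assert not re.match(REGULARIZATION_VAR_REGEX, 'resnet50/conv2d/bias:0')
assert not re.match(REGULARIZATION_VAR_REGEX, 'resnet50/batch_normalization/gamma:0')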
@@ -18,11 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import abc
 import functools
 import re
 
-import six
-
 from absl import logging
 import tensorflow.compat.v2 as tf
@@ -53,7 +51,7 @@ class OptimizerFactory(object):
       self._optimizer = tf.keras.optimizers.Adagrad
     elif params.type == 'rmsprop':
       self._optimizer = functools.partial(
-          tf.keras.optimizers.RMSProp, momentum=params.momentum)
+          tf.keras.optimizers.RMSprop, momentum=params.momentum)
     else:
       raise ValueError('Unsupported optimizer type %s.' % self._optimizer)
@@ -104,6 +102,7 @@ class Model(object):
         params.train.learning_rate)
 
     self._frozen_variable_prefix = params.train.frozen_variable_prefix
+    self._regularization_var_regex = params.train.regularization_variable_regex
     self._l2_weight_decay = params.train.l2_weight_decay
 
     # Checkpoint restoration.
@@ -146,12 +145,17 @@
     """
     return _make_filter_trainable_variables_fn(self._frozen_variable_prefix)
 
-  def weight_decay_loss(self, l2_weight_decay, trainable_variables):
-    return l2_weight_decay * tf.add_n([
-        tf.nn.l2_loss(v)
-        for v in trainable_variables
-        if 'batch_normalization' not in v.name and 'bias' not in v.name
-    ])
+  def weight_decay_loss(self, trainable_variables):
+    reg_variables = [
+        v for v in trainable_variables
+        if self._regularization_var_regex is None
+        or re.match(self._regularization_var_regex, v.name)
+    ]
+    logging.info('Regularization Variables: %s',
+                 [v.name for v in reg_variables])
+
+    return self._l2_weight_decay * tf.add_n(
+        [tf.nn.l2_loss(v) for v in reg_variables])
 
   def make_restore_checkpoint_fn(self):
     """Returns scaffold function to restore parameters from v1 checkpoint."""
...
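
The `weight_decay_loss` rewrite above moves both pieces of policy into the model: the decay factor now comes from `self._l2_weight_decay`, and the variable filter matches `v.name` against the configured regex instead of hard-coding name checks, regularizing every variable when the regex is None. This is also why the RetinaNet call site below drops the explicit decay argument. A self-contained sketch of the same filtering logic, using a standalone function and two toy variables rather than the actual `Model` class:

import re

import tensorflow.compat.v2 as tf


def weight_decay_loss(trainable_variables, var_regex, l2_weight_decay):
  """Standalone copy of the filtering logic in Model.weight_decay_loss."""
  reg_variables = [
      v for v in trainable_variables
      if var_regex is None or re.match(var_regex, v.name)
  ]
  return l2_weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in reg_variables])


kernel = tf.Variable(tf.ones([2, 2]), name='conv2d/kernel')  # v.name == 'conv2d/kernel:0'
bias = tf.Variable(tf.zeros([2]), name='conv2d/bias')

# Only the kernel matches: 1e-4 * l2_loss(kernel) = 1e-4 * (4 * 1**2) / 2 = 2e-4.
loss = weight_decay_loss([kernel, bias], r'.*(kernel|weight):0$', 1e-4)
print(float(loss))  # 0.0002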
@@ -106,8 +106,7 @@ class RetinanetModel(base_model.Model):
                                    labels['box_targets'],
                                    labels['num_positives'])
       model_loss = cls_loss + self._box_loss_weight * box_loss
-      l2_regularization_loss = self.weight_decay_loss(self._l2_weight_decay,
-                                                      trainable_variables)
+      l2_regularization_loss = self.weight_decay_loss(trainable_variables)
       total_loss = model_loss + l2_regularization_loss
       return {
           'total_loss': total_loss,
...