Commit 47d10833 authored by Pengchong Jin's avatar Pengchong Jin Committed by A. Unique TensorFlower
Browse files

Base model refactor.

PiperOrigin-RevId: 306597558
parent d70eca30
...@@ -24,37 +24,7 @@ import re ...@@ -24,37 +24,7 @@ import re
import tensorflow.compat.v2 as tf import tensorflow.compat.v2 as tf
from official.vision.detection.modeling import checkpoint_utils from official.vision.detection.modeling import checkpoint_utils
from official.vision.detection.modeling import learning_rates from official.vision.detection.modeling import learning_rates
from official.vision.detection.modeling import optimizers
class OptimizerFactory(object):
"""Class to generate optimizer function."""
def __init__(self, params):
"""Creates optimized based on the specified flags."""
if params.type == 'momentum':
nesterov = False
try:
nesterov = params.nesterov
except AttributeError:
pass
self._optimizer = functools.partial(
tf.keras.optimizers.SGD,
momentum=params.momentum,
nesterov=nesterov)
elif params.type == 'adam':
self._optimizer = tf.keras.optimizers.Adam
elif params.type == 'adadelta':
self._optimizer = tf.keras.optimizers.Adadelta
elif params.type == 'adagrad':
self._optimizer = tf.keras.optimizers.Adagrad
elif params.type == 'rmsprop':
self._optimizer = functools.partial(
tf.keras.optimizers.RMSprop, momentum=params.momentum)
else:
raise ValueError('Unsupported optimizer type %s.' % self._optimizer)
def __call__(self, learning_rate):
return self._optimizer(learning_rate=learning_rate)
def _make_filter_trainable_variables_fn(frozen_variable_prefix): def _make_filter_trainable_variables_fn(frozen_variable_prefix):
...@@ -94,7 +64,7 @@ class Model(object): ...@@ -94,7 +64,7 @@ class Model(object):
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy) tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
# Optimization. # Optimization.
self._optimizer_fn = OptimizerFactory(params.train.optimizer) self._optimizer_fn = optimizers.OptimizerFactory(params.train.optimizer)
self._learning_rate = learning_rates.learning_rate_generator( self._learning_rate = learning_rates.learning_rate_generator(
params.train.learning_rate) params.train.learning_rate)
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import numpy as np
import tensorflow.compat.v2 as tf
class OptimizerFactory(object):
"""Class to generate optimizer function."""
def __init__(self, params):
"""Creates optimized based on the specified flags."""
if params.type == 'momentum':
nesterov = False
try:
nesterov = params.nesterov
except AttributeError:
pass
self._optimizer = functools.partial(
tf.keras.optimizers.SGD,
momentum=params.momentum,
nesterov=nesterov)
elif params.type == 'adam':
self._optimizer = tf.keras.optimizers.Adam
elif params.type == 'adadelta':
self._optimizer = tf.keras.optimizers.Adadelta
elif params.type == 'adagrad':
self._optimizer = tf.keras.optimizers.Adagrad
elif params.type == 'rmsprop':
self._optimizer = functools.partial(
tf.keras.optimizers.RMSprop, momentum=params.momentum)
else:
raise ValueError('Unsupported optimizer type `{}`.'.format(params.type))
def __call__(self, learning_rate):
return self._optimizer(learning_rate=learning_rate)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment