# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom training loop for running TensorFlow 2.0 models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from official.common import distribute_utils
from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags
from official.utils.misc import keras_utils
FLAGS = flags.FLAGS
strategy_flags_dict = hyperparams_flags.strategy_flags_dict
hparam_flags_dict = hyperparams_flags.hparam_flags_dict
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
"""Saves model to model_dir with provided checkpoint prefix."""
checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
saved_path = checkpoint.save(checkpoint_path)
logging.info('Saving model as TF checkpoint: %s', saved_path)
def _steps_to_run(current_step, total_steps, steps_per_loop):
"""Calculates steps to run on device."""
if steps_per_loop <= 0:
raise ValueError('steps_per_loop should be a positive integer.')
return min(total_steps - current_step, steps_per_loop)
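# Illustrative sketch (hypothetical values): with total_steps=100 and
# steps_per_loop=30, successive loop sizes are 30, 30, 30, and finally 10:
#   _steps_to_run(current_step=90, total_steps=100, steps_per_loop=30)  # -> 10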
def _no_metric():
return None
def metrics_as_dict(metric):
"""Puts input metric(s) into a list.
Args:
metric: metric(s) to be put into the list. `metric` could be an object, a
list, or a dict of tf.keras.metrics.Metric or has the `required_method`.
Returns:
A dictionary of valid metrics.
"""
if isinstance(metric, tf.keras.metrics.Metric):
metrics = {metric.name: metric}
elif isinstance(metric, list):
metrics = {m.name: m for m in metric}
elif isinstance(metric, dict):
metrics = metric
elif not metric:
return {}
else:
metrics = {'metric': metric}
return metrics
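# Illustrative sketch of the normalization above (the dict key assumes the
# default Keras name for SparseCategoricalAccuracy):
#   acc = tf.keras.metrics.SparseCategoricalAccuracy()
#   metrics_as_dict(acc)    # -> {'sparse_categorical_accuracy': acc}
#   metrics_as_dict([acc])  # -> {'sparse_categorical_accuracy': acc}
#   metrics_as_dict(None)   # -> {}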
def metric_results(metric):
"""Collects results from the given metric(s)."""
metrics = metrics_as_dict(metric)
metric_result = {
name: m.result().numpy().astype(float) for name, m in metrics.items()
}
return metric_result
def reset_states(metric):
"""Resets states of the given metric(s)."""
metrics = metrics_as_dict(metric)
for m in metrics.values():
m.reset_states()
class SummaryWriter(object):
"""Simple SummaryWriter for writing dictionary of metrics.
Attributes:
writer: The tf.SummaryWriter.
"""
def __init__(self, model_dir: Text, name: Text):
"""Inits SummaryWriter with paths.
Args:
model_dir: the model folder path.
name: the summary subfolder name.
"""
self.writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
def __call__(self, metrics: Union[Dict[Text, float], float], step: int):
"""Write metrics to summary with the given writer.
Args:
metrics: a dictionary of metric values. A dictionary is preferred.
step: integer. The training step.
"""
if not isinstance(metrics, dict):
# Support scalar metric without name.
logging.warning('Summary writer prefers metrics as a dictionary.')
metrics = {'metric': metrics}
with self.writer.as_default():
for k, v in metrics.items():
tf.summary.scalar(k, v, step=step)
self.writer.flush()
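# Illustrative usage sketch (hypothetical model_dir):
#   writer = SummaryWriter('/tmp/my_model', 'train')
#   writer({'total_loss': 0.42, 'learning_rate': 1e-3}, step=100)
# Scalars land under /tmp/my_model/train and are viewable in TensorBoard.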
class DistributedExecutor(object):
"""Interface to train and eval models with tf.distribute.Strategy."""
def __init__(self, strategy, params, model_fn, loss_fn, is_multi_host=False):
"""Constructor.
Args:
strategy: an instance of tf.distribute.Strategy.
params: Model configuration needed to run distribution strategy.
model_fn: Keras model function. Signature:
(params: ParamsDict) -> tf.keras.models.Model.
loss_fn: loss function. Signature:
(y_true: Tensor, y_pred: Tensor) -> Tensor
is_multi_host: Set to True when training with multiple hosts, such as
multi-worker GPU or TPU pod (slice) training. Otherwise, False.
"""
self._params = params
self._model_fn = model_fn
self._loss_fn = loss_fn
self._strategy = strategy
self._checkpoint_name = 'ctl_step_{step}.ckpt'
self._is_multi_host = is_multi_host
self.train_summary_writer = None
self.eval_summary_writer = None
self.global_train_step = None
@property
def checkpoint_name(self):
"""Returns default checkpoint name."""
return self._checkpoint_name
@checkpoint_name.setter
def checkpoint_name(self, name):
"""Sets default summary writer for the current thread."""
self._checkpoint_name = name
def loss_fn(self):
return self._loss_fn()
def model_fn(self, params):
return self._model_fn(params)
def _save_config(self, model_dir):
"""Save parameters to config files if model_dir is defined."""
if model_dir:
logging.info('Saving config to model_dir %s.', model_dir)
if not tf.io.gfile.exists(model_dir):
tf.io.gfile.makedirs(model_dir)
self._params.lock()
params_dict.save_params_dict_to_yaml(self._params,
model_dir + '/params.yaml')
else:
logging.warning('model_dir is empty, so skipping the config save.')
def _get_input_iterator(
self, input_fn: Callable[..., tf.data.Dataset],
strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
"""Returns distributed dataset iterator.
Args:
input_fn: (params: dict) -> tf.data.Dataset.
strategy: an instance of tf.distribute.Strategy.
Returns:
An iterator that yields input tensors.
"""
if input_fn is None:
return None
# When training with multiple TPU workers, datasets need to be cloned
# across workers. Since a Dataset instance cannot be cloned in eager mode,
# we instead pass a callable that returns a dataset.
if self._is_multi_host:
return iter(strategy.distribute_datasets_from_function(input_fn))
else:
input_data = input_fn()
return iter(strategy.experimental_distribute_dataset(input_data))
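# Illustrative input_fn sketch (hypothetical data; in the multi-host path,
# distribute_datasets_from_function invokes the callable with a
# tf.distribute.InputContext argument):
#   def input_fn(ctx=None):
#     ds = tf.data.Dataset.from_tensor_slices((images, labels))
#     return ds.repeat().batch(32, drop_remainder=True)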
def _create_replicated_step(self,
strategy,
model,
loss_fn,
optimizer,
metric=None):
"""Creates a single training step.
Args:
strategy: an instance of tf.distribute.Strategy.
model: (Tensor, bool) -> Tensor. model function.
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
optimizer: tf.keras.optimizers.Optimizer.
metric: tf.keras.metrics.Metric subclass.
Returns:
The training step callable.
"""
metrics = metrics_as_dict(metric)
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
prediction_loss = loss_fn(labels, outputs)
loss = tf.reduce_mean(prediction_loss)
loss = loss / strategy.num_replicas_in_sync
for m in metrics.values():
m.update_state(labels, outputs)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
return _replicated_step
def _create_train_step(self,
strategy,
model,
loss_fn,
optimizer,
metric=None):
"""Creates a distributed training step.
Args:
strategy: an instance of tf.distribute.Strategy.
model: (Tensor, bool) -> Tensor. model function.
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
optimizer: tf.keras.optimizers.Optimizer.
metric: tf.keras.metrics.Metric subclass.
Returns:
The training step callable.
"""
replicated_step = self._create_replicated_step(strategy, model, loss_fn,
optimizer, metric)
@tf.function
def train_step(iterator, num_steps):
"""Performs a distributed training step.
Args:
iterator: an iterator that yields input tensors.
num_steps: the number of steps in the loop.
Returns:
The loss tensor.
"""
if not isinstance(num_steps, tf.Tensor):
raise ValueError('num_steps should be a Tensor. A Python object may '
'cause retracing.')
per_replica_losses = strategy.run(replicated_step, args=(next(iterator),))
for _ in tf.range(num_steps - 1):
per_replica_losses = strategy.run(
replicated_step, args=(next(iterator),))
# For reporting, we return the mean of the losses.
losses = tf.nest.map_structure(
lambda x: strategy.reduce(tf.distribute.ReduceOp.MEAN, x, axis=None),
per_replica_losses)
return losses
return train_step
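# A sketch of the intended call (hypothetical loop size). Passing a Python
# int for num_steps would bake the value into the traced graph and force a
# retrace of the tf.function for every distinct loop size:
#   loss = train_step(iterator, tf.convert_to_tensor(30, dtype=tf.int32))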
def _create_test_step(self, strategy, model, metric):
"""Creates a distributed test step."""
metrics = metrics_as_dict(metric)
@tf.function
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
if not metric:
logging.info('Skip test_step because metric is None (%s)', metric)
return None, None
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
for m in metrics.values():
m.update_state(labels, model_outputs)
return labels, model_outputs
return strategy.run(_test_step_fn, args=(next(iterator),))
return test_step
def train(
self,
train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_input_fn: Optional[Callable[[params_dict.ParamsDict],
tf.data.Dataset]] = None,
model_dir: Optional[Text] = None,
total_steps: int = 1,
iterations_per_loop: int = 1,
train_metric_fn: Optional[Callable[[], Any]] = None,
eval_metric_fn: Optional[Callable[[], Any]] = None,
summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter,
init_checkpoint: Optional[Callable[[tf.keras.Model], Any]] = None,
custom_callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
continuous_eval: bool = False,
save_config: bool = True):
"""Runs distributed training.
Args:
train_input_fn: (params: dict) -> tf.data.Dataset training data input
function.
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluating metric on eval data. If None, will not run the eval
step.
model_dir: the folder path for model checkpoints.
total_steps: total training steps.
iterations_per_loop: train steps per loop. After each loop, this job will
update metrics like loss and save checkpoint.
train_metric_fn: metric_fn for evaluation in train_step.
eval_metric_fn: metric_fn for evaluation in test_step.
summary_writer_fn: function to create summary writer.
init_checkpoint: function to load checkpoint.
custom_callbacks: A list of Keras Callback objects to run during
training. More specifically, the `on_batch_begin()` and `on_batch_end()`
methods are invoked during training.
continuous_eval: If `True`, will continuously run evaluation on every
available checkpoint. If `False`, will do the evaluation once after the
final step.
save_config: bool. Whether to save params to model_dir.
Returns:
The training loss and eval metrics.
"""
assert train_input_fn is not None
if train_metric_fn and not callable(train_metric_fn):
raise ValueError('if `train_metric_fn` is specified, '
'train_metric_fn must be a callable.')
if eval_metric_fn and not callable(eval_metric_fn):
raise ValueError('if `eval_metric_fn` is specified, '
'eval_metric_fn must be a callable.')
train_metric_fn = train_metric_fn or _no_metric
eval_metric_fn = eval_metric_fn or _no_metric
if custom_callbacks and iterations_per_loop != 1:
logging.warning(
'It is semantically wrong to run callbacks when '
'iterations_per_loop is not one (%s)', iterations_per_loop)
custom_callbacks = custom_callbacks or []
def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
if callback:
callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch):
"""Runs custom callbacks at the end of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
if callback:
callback.on_batch_end(batch)
if save_config:
self._save_config(model_dir)
if FLAGS.save_checkpoint_freq:
save_freq = FLAGS.save_checkpoint_freq
else:
save_freq = iterations_per_loop
params = self._params
strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operations, we place
# input pipeline ops on the worker task.
train_iterator = self._get_input_iterator(train_input_fn, strategy)
train_loss = None
train_metric_result = None
eval_metric_result = None
tf.keras.backend.set_learning_phase(1)
with strategy.scope():
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model = self.model_fn(params.as_dict())
if not hasattr(model, 'optimizer'):
raise ValueError('`model_fn` should set the `optimizer` attribute on the '
'model it returns.')
optimizer = model.optimizer
# Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
initial_step = 0
if latest_checkpoint_file:
logging.info(
'Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file)
initial_step = optimizer.iterations.numpy()
logging.info('Loading from checkpoint file completed. Init step %d',
initial_step)
elif init_checkpoint:
logging.info('Restoring from init checkpoint function')
init_checkpoint(model)
logging.info('Loading from init checkpoint file completed')
current_step = optimizer.iterations.numpy()
checkpoint_name = self.checkpoint_name
eval_metric = eval_metric_fn()
train_metric = train_metric_fn()
train_summary_writer = summary_writer_fn(model_dir, 'eval_train')
self.train_summary_writer = train_summary_writer.writer
test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
self.eval_summary_writer = test_summary_writer.writer
# Use training summary writer in TimeHistory if it's in use
for cb in custom_callbacks:
if isinstance(cb, keras_utils.TimeHistory):
cb.summary_writer = self.train_summary_writer
# Continue training loop.
train_step = self._create_train_step(
strategy=strategy,
model=model,
loss_fn=self.loss_fn(),
optimizer=optimizer,
metric=train_metric)
test_step = None
if eval_input_fn and eval_metric:
self.global_train_step = model.optimizer.iterations
test_step = self._create_test_step(strategy, model, metric=eval_metric)
# Step-0 operations
if current_step == 0 and not latest_checkpoint_file:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Step: %s evaluation metric = %s.', current_step,
eval_metric_result)
test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)
reset_states(eval_metric)
logging.info('Training started')
last_save_checkpoint_step = current_step
while current_step < total_steps:
num_steps = _steps_to_run(current_step, total_steps, iterations_per_loop)
_run_callbacks_on_batch_begin(current_step)
train_loss = train_step(train_iterator,
tf.convert_to_tensor(num_steps, dtype=tf.int32))
current_step += num_steps
train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
train_loss)
_run_callbacks_on_batch_end(current_step - 1)
if not isinstance(train_loss, dict):
train_loss = {'total_loss': train_loss}
if np.isnan(train_loss['total_loss']):
raise ValueError('total loss is NaN.')
if train_metric:
train_metric_result = metric_results(train_metric)
train_metric_result.update(train_loss)
else:
train_metric_result = train_loss
if callable(optimizer.lr):
train_metric_result.update(
{'learning_rate': optimizer.lr(current_step).numpy()})
else:
train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
current_step, total_steps, train_loss, train_metric_result)
train_summary_writer(
metrics=train_metric_result, step=optimizer.iterations)
# Saves model checkpoints and run validation steps at every
# iterations_per_loop steps.
# To avoid repeated model saving, we do not save after the last
# step of training.
if save_freq > 0 and current_step < total_steps and (
current_step - last_save_checkpoint_step) >= save_freq:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
last_save_checkpoint_step = current_step
if continuous_eval and current_step < total_steps and test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Step: %s evaluation metric = %s.', current_step,
eval_metric_result)
test_summary_writer(
metrics=eval_metric_result, step=optimizer.iterations)
# Re-initialize evaluation metrics, except after the last step.
if eval_metric and current_step < total_steps:
reset_states(eval_metric)
if train_metric and current_step < total_steps:
reset_states(train_metric)
# Reached the end of training; save the final checkpoint.
if last_save_checkpoint_step < total_steps:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_step:
logging.info('Running final evaluation after training is complete.')
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Final evaluation metric = %s.', eval_metric_result)
test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)
self.train_summary_writer.close()
self.eval_summary_writer.close()
return train_metric_result, eval_metric_result
def _run_evaluation(self, test_step, current_training_step, metric,
test_iterator):
"""Runs validation steps and aggregate metrics."""
if not test_iterator or not metric:
logging.warning(
'Both test_iterator (%s) and metric (%s) must not be None.',
test_iterator, metric)
return None
logging.info('Running evaluation after step: %s.', current_training_step)
eval_step = 0
while True:
try:
with tf.experimental.async_scope():
test_step(test_iterator)
eval_step += 1
except (StopIteration, tf.errors.OutOfRangeError):
tf.experimental.async_clear_error()
break
metric_result = metric_results(metric)
logging.info('Total eval steps: [%d]', eval_step)
logging.info('At training step: [%r] Validation metric = %r',
current_training_step, metric_result)
return metric_result
def evaluate_from_model_dir(
self,
model_dir: Text,
eval_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_metric_fn: Callable[[], Any],
total_steps: int = -1,
eval_timeout: Optional[int] = None,
min_eval_interval: int = 180,
summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter):
"""Runs distributed evaluation on model folder.
Args:
model_dir: the folder for storing model checkpoints.
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluating metric on eval data. If None, will not run eval step.
eval_metric_fn: metric_fn for evaluation in test_step.
total_steps: total training steps. If the current step reaches the
total_steps, the evaluation loop will stop.
eval_timeout: The maximum number of seconds to wait between checkpoints.
If left as None, then the process will wait indefinitely. Used by
tf.train.checkpoints_iterator.
min_eval_interval: The minimum number of seconds between yielding
checkpoints. Used by tf.train.checkpoints_iterator.
summary_writer_fn: function to create summary writer.
Returns:
Eval metrics dictionary of the last checkpoint.
"""
if not model_dir:
raise ValueError('model_dir must be set.')
def terminate_eval():
logging.info('Terminating eval after %d seconds of no checkpoints',
eval_timeout)
return True
summary_writer = summary_writer_fn(model_dir, 'eval')
self.eval_summary_writer = summary_writer.writer
# Read checkpoints from the given model directory
# until `eval_timeout` seconds elapses.
for checkpoint_path in tf.train.checkpoints_iterator(
model_dir,
min_interval_secs=min_eval_interval,
timeout=eval_timeout,
timeout_fn=terminate_eval):
eval_metric_result, current_step = self.evaluate_checkpoint(
checkpoint_path=checkpoint_path,
eval_input_fn=eval_input_fn,
eval_metric_fn=eval_metric_fn,
summary_writer=summary_writer)
if total_steps > 0 and current_step >= total_steps:
logging.info('Evaluation finished after training step %d', current_step)
break
return eval_metric_result
def evaluate_checkpoint(self,
checkpoint_path: Text,
eval_input_fn: Callable[[params_dict.ParamsDict],
tf.data.Dataset],
eval_metric_fn: Callable[[], Any],
summary_writer: Optional[SummaryWriter] = None):
"""Runs distributed evaluation on the one checkpoint.
Args:
checkpoint_path: the checkpoint to evaluate.
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluating metric on eval data. If None, will not run eval step.
eval_metric_fn: metric_fn for evaluation in test_step.
summary_writer: function to create summary writer.
Returns:
Eval metrics dictionary of the last checkpoint.
"""
if not callable(eval_metric_fn):
raise ValueError('if `eval_metric_fn` is specified, '
'eval_metric_fn must be a callable.')
old_phase = tf.keras.backend.learning_phase()
tf.keras.backend.set_learning_phase(0)
params = self._params
strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operations, we place
# input pipeline ops on the worker task.
with strategy.scope():
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model = self.model_fn(params.as_dict())
checkpoint = tf.train.Checkpoint(model=model)
eval_metric = eval_metric_fn()
assert eval_metric, 'eval_metric does not exist'
test_step = self._create_test_step(strategy, model, metric=eval_metric)
logging.info('Starting to evaluate.')
if not checkpoint_path:
raise ValueError('checkpoint path is empty')
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
if reader.has_tensor('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE'):
# Legacy keras optimizer iteration.
current_step = reader.get_tensor(
'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
else:
# New keras optimizer iteration.
current_step = reader.get_tensor(
'optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE')
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', checkpoint_path)
status = checkpoint.restore(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
self.global_train_step = model.optimizer.iterations
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Step: %s evaluation metric = %s.', current_step,
eval_metric_result)
summary_writer(metrics=eval_metric_result, step=current_step)
reset_states(eval_metric)
tf.keras.backend.set_learning_phase(old_phase)
return eval_metric_result, current_step
def predict(self):
raise NotImplementedError('Unimplemented function.')
class ExecutorBuilder(object):
"""Builder of DistributedExecutor.
Example 1: Builds an executor with supported Strategy.
builder = ExecutorBuilder(
strategy_type='tpu',
strategy_config={'tpu': '/bns/xxx'})
dist_executor = builder.build_executor(
params=params,
model_fn=my_model_fn,
loss_fn=my_loss_fn,
metric_fn=my_metric_fn)
Example 2: Builds an executor with customized Strategy.
builder = ExecutorBuilder()
builder.strategy = <some customized Strategy>
dist_executor = builder.build_executor(
params=params,
model_fn=my_model_fn,
loss_fn=my_loss_fn,
metric_fn=my_metric_fn)
Example 3: Builds a customized executor with customized Strategy.
class MyDistributedExecutor(DistributedExecutor):
# implementation ...
builder = ExecutorBuilder()
builder.strategy = <some customized Strategy>
dist_executor = builder.build_executor(
class_ctor=MyDistributedExecutor,
params=params,
model_fn=my_model_fn,
loss_fn=my_loss_fn,
metric_fn=my_metric_fn)
"""
def __init__(self, strategy_type=None, strategy_config=None):
"""Constructor.
Args:
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
If None, the user is responsible for setting the strategy before calling
build_executor(...).
strategy_config: necessary config for constructing the proper Strategy.
Check strategy_flags_dict() for examples of the structure.
"""
_ = distribute_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
self._strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=strategy_type,
num_gpus=strategy_config.num_gpus,
all_reduce_alg=strategy_config.all_reduce_alg,
num_packs=strategy_config.num_packs,
tpu_address=strategy_config.tpu)
@property
def strategy(self):
"""Returns default checkpoint name."""
return self._strategy
@strategy.setter
def strategy(self, new_strategy):
"""Sets default summary writer for the current thread."""
self._strategy = new_strategy
def build_executor(self,
class_ctor=DistributedExecutor,
params=None,
model_fn=None,
loss_fn=None,
**kwargs):
"""Creates an executor according to strategy type.
See the docstring of DistributedExecutor.__init__ for more information on
the input arguments.
Args:
class_ctor: A constructor of executor (default: DistributedExecutor).
params: ParamsDict, all the model parameters and runtime parameters.
model_fn: Keras model function.
loss_fn: loss function.
**kwargs: other arguments to the executor constructor.
Returns:
An instance of DistributedExecutor or its subclass.
"""
if self._strategy is None:
raise ValueError('`strategy` should not be None. You need to specify '
'`strategy_type` in the builder constructor or directly '
'set the `strategy` property of the builder.')
return class_ctor(
strategy=self._strategy,
params=params,
model_fn=model_fn,
loss_fn=loss_fn,
**kwargs)
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main function to train various object detection models."""
import functools
import pprint
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.legacy.detection.configs import factory as config_factory
from official.legacy.detection.dataloader import input_reader
from official.legacy.detection.dataloader import mode_keys as ModeKeys
from official.legacy.detection.executor import distributed_executor as executor
from official.legacy.detection.executor.detection_executor import DetectionDistributedExecutor
from official.legacy.detection.modeling import factory as model_factory
from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
hyperparams_flags.initialize_common_flags()
flags_core.define_log_steps()
flags.DEFINE_bool('enable_xla', default=False, help='Enable XLA for GPU')
flags.DEFINE_string(
'mode',
default='train',
help='Mode to run: `train`, `eval` or `eval_once`.')
flags.DEFINE_string(
'model', default='retinanet',
help='Model to run: `retinanet`, `mask_rcnn` or `shapemask`.')
flags.DEFINE_string('training_file_pattern', None,
'Location of the train data.')
flags.DEFINE_string('eval_file_pattern', None, 'Location of the eval data.')
flags.DEFINE_string(
'checkpoint_path', None,
'The checkpoint path to eval. Only used in eval_once mode.')
FLAGS = flags.FLAGS
def run_executor(params,
mode,
checkpoint_path=None,
train_input_fn=None,
eval_input_fn=None,
callbacks=None,
prebuilt_strategy=None):
"""Runs the object detection model on distribution strategy defined by the user."""
if params.architecture.use_bfloat16:
tf.compat.v2.keras.mixed_precision.set_global_policy('mixed_bfloat16')
model_builder = model_factory.model_generator(params)
if prebuilt_strategy is not None:
strategy = prebuilt_strategy
else:
strategy_config = params.strategy_config
distribute_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.strategy_type,
num_gpus=strategy_config.num_gpus,
all_reduce_alg=strategy_config.all_reduce_alg,
num_packs=strategy_config.num_packs,
tpu_address=strategy_config.tpu)
num_workers = int(strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (int(num_workers) >= 2)
if mode == 'train':
def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.TRAIN)
logging.info(
'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
strategy.num_replicas_in_sync, num_workers, is_multi_host)
dist_executor = DetectionDistributedExecutor(
strategy=strategy,
params=params,
model_fn=_model_fn,
loss_fn=model_builder.build_loss_fn,
is_multi_host=is_multi_host,
predict_post_process_fn=model_builder.post_processing,
trainable_variables_filter=model_builder
.make_filter_trainable_variables_fn())
if is_multi_host:
train_input_fn = functools.partial(
train_input_fn,
batch_size=params.train.batch_size // strategy.num_replicas_in_sync)
return dist_executor.train(
train_input_fn=train_input_fn,
model_dir=params.model_dir,
iterations_per_loop=params.train.iterations_per_loop,
total_steps=params.train.total_steps,
init_checkpoint=model_builder.make_restore_checkpoint_fn(),
custom_callbacks=callbacks,
save_config=True)
elif mode == 'eval' or mode == 'eval_once':
def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)
logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
strategy.num_replicas_in_sync, num_workers, is_multi_host)
if is_multi_host:
eval_input_fn = functools.partial(
eval_input_fn,
batch_size=params.eval.batch_size // strategy.num_replicas_in_sync)
dist_executor = DetectionDistributedExecutor(
strategy=strategy,
params=params,
model_fn=_model_fn,
loss_fn=model_builder.build_loss_fn,
is_multi_host=is_multi_host,
predict_post_process_fn=model_builder.post_processing,
trainable_variables_filter=model_builder
.make_filter_trainable_variables_fn())
if mode == 'eval':
results = dist_executor.evaluate_from_model_dir(
model_dir=params.model_dir,
eval_input_fn=eval_input_fn,
eval_metric_fn=model_builder.eval_metrics,
eval_timeout=params.eval.eval_timeout,
min_eval_interval=params.eval.min_eval_interval,
total_steps=params.train.total_steps)
else:
# Run evaluation once for a single checkpoint.
if not checkpoint_path:
raise ValueError('checkpoint_path cannot be empty.')
if tf.io.gfile.isdir(checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
results, _ = dist_executor.evaluate_checkpoint(
checkpoint_path=checkpoint_path,
eval_input_fn=eval_input_fn,
eval_metric_fn=model_builder.eval_metrics,
summary_writer=summary_writer)
for k, v in results.items():
logging.info('Final eval metric %s: %f', k, v)
return results
else:
raise ValueError('Mode not found: %s.' % mode)
def run(callbacks=None):
"""Runs the experiment."""
keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)
params = config_factory.config_generator(FLAGS.model)
params = params_dict.override_params_dict(
params, FLAGS.config_file, is_strict=True)
params = params_dict.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.override(
{
'strategy_type': FLAGS.strategy_type,
'model_dir': FLAGS.model_dir,
'strategy_config': executor.strategy_flags_dict(),
},
is_strict=False)
# Make sure use_tpu and strategy_type are in sync.
params.use_tpu = (params.strategy_type == 'tpu')
if not params.use_tpu:
params.override({
'architecture': {
'use_bfloat16': False,
},
'norm_activation': {
'use_sync_bn': False,
},
}, is_strict=True)
params.validate()
params.lock()
pp = pprint.PrettyPrinter()
params_str = pp.pformat(params.as_dict())
logging.info('Model Parameters: %s', params_str)
train_input_fn = None
eval_input_fn = None
training_file_pattern = FLAGS.training_file_pattern or params.train.train_file_pattern
eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
if not training_file_pattern and not eval_file_pattern:
raise ValueError('Must provide at least one of training_file_pattern and '
'eval_file_pattern.')
if training_file_pattern:
# Use global batch size for single host.
train_input_fn = input_reader.InputFn(
file_pattern=training_file_pattern,
params=params,
mode=input_reader.ModeKeys.TRAIN,
batch_size=params.train.batch_size)
if eval_file_pattern:
eval_input_fn = input_reader.InputFn(
file_pattern=eval_file_pattern,
params=params,
mode=input_reader.ModeKeys.PREDICT_WITH_GT,
batch_size=params.eval.batch_size,
num_examples=params.eval.eval_samples)
if callbacks is None:
callbacks = []
if FLAGS.log_steps:
callbacks.append(
keras_utils.TimeHistory(
batch_size=params.train.batch_size,
log_steps=FLAGS.log_steps,
))
return run_executor(
params,
FLAGS.mode,
checkpoint_path=FLAGS.checkpoint_path,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
callbacks=callbacks)
def main(argv):
del argv # Unused.
run()
if __name__ == '__main__':
tf.config.set_soft_device_placement(True)
app.run(main)
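# Example invocation (illustrative; paths and flag values are hypothetical):
#   python main.py --model=retinanet --mode=train \
#     --model_dir=/tmp/retinanet \
#     --training_file_pattern='/data/coco/train-*' \
#     --strategy_type=mirrored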
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model architecture factory."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from official.legacy.detection.modeling.architecture import fpn
from official.legacy.detection.modeling.architecture import heads
from official.legacy.detection.modeling.architecture import identity
from official.legacy.detection.modeling.architecture import nn_ops
from official.legacy.detection.modeling.architecture import resnet
from official.legacy.detection.modeling.architecture import spinenet
def norm_activation_generator(params):
return nn_ops.norm_activation_builder(
momentum=params.batch_norm_momentum,
epsilon=params.batch_norm_epsilon,
trainable=params.batch_norm_trainable,
activation=params.activation)
def backbone_generator(params):
"""Generator function for various backbone models."""
if params.architecture.backbone == 'resnet':
resnet_params = params.resnet
backbone_fn = resnet.Resnet(
resnet_depth=resnet_params.resnet_depth,
activation=params.norm_activation.activation,
norm_activation=norm_activation_generator(
params.norm_activation))
elif params.architecture.backbone == 'spinenet':
spinenet_params = params.spinenet
backbone_fn = spinenet.SpineNetBuilder(model_id=spinenet_params.model_id)
else:
raise ValueError('Backbone model `{}` is not supported.'
.format(params.architecture.backbone))
return backbone_fn
def multilevel_features_generator(params):
"""Generator function for various FPN models."""
if params.architecture.multilevel_features == 'fpn':
fpn_params = params.fpn
fpn_fn = fpn.Fpn(
min_level=params.architecture.min_level,
max_level=params.architecture.max_level,
fpn_feat_dims=fpn_params.fpn_feat_dims,
use_separable_conv=fpn_params.use_separable_conv,
activation=params.norm_activation.activation,
use_batch_norm=fpn_params.use_batch_norm,
norm_activation=norm_activation_generator(
params.norm_activation))
elif params.architecture.multilevel_features == 'identity':
fpn_fn = identity.Identity()
else:
raise ValueError('The multi-level feature model `{}` is not supported.'
.format(params.architecture.multilevel_features))
return fpn_fn
def retinanet_head_generator(params):
"""Generator function for RetinaNet head architecture."""
head_params = params.retinanet_head
anchors_per_location = params.anchor.num_scales * len(
params.anchor.aspect_ratios)
return heads.RetinanetHead(
params.architecture.min_level,
params.architecture.max_level,
params.architecture.num_classes,
anchors_per_location,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
norm_activation=norm_activation_generator(params.norm_activation))
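# Worked example (hypothetical config values): with params.anchor.num_scales
# = 3 and params.anchor.aspect_ratios = [1.0, 2.0, 0.5], anchors_per_location
# is 3 * 3 = 9, the standard RetinaNet setting.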
def rpn_head_generator(params):
"""Generator function for RPN head architecture."""
head_params = params.rpn_head
anchors_per_location = params.anchor.num_scales * len(
params.anchor.aspect_ratios)
return heads.RpnHead(
params.architecture.min_level,
params.architecture.max_level,
anchors_per_location,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def oln_rpn_head_generator(params):
"""Generator function for OLN-proposal (OLN-RPN) head architecture."""
head_params = params.rpn_head
anchors_per_location = params.anchor.num_scales * len(
params.anchor.aspect_ratios)
return heads.OlnRpnHead(
params.architecture.min_level,
params.architecture.max_level,
anchors_per_location,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def fast_rcnn_head_generator(params):
"""Generator function for Fast R-CNN head architecture."""
head_params = params.frcnn_head
return heads.FastrcnnHead(
params.architecture.num_classes,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
head_params.num_fcs,
head_params.fc_dims,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def oln_box_score_head_generator(params):
"""Generator function for Scoring Fast R-CNN head architecture."""
head_params = params.frcnn_head
return heads.OlnBoxScoreHead(
params.architecture.num_classes,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
head_params.num_fcs,
head_params.fc_dims,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def mask_rcnn_head_generator(params):
"""Generator function for Mask R-CNN head architecture."""
head_params = params.mrcnn_head
return heads.MaskrcnnHead(
params.architecture.num_classes,
params.architecture.mask_target_size,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def oln_mask_score_head_generator(params):
"""Generator function for Scoring Mask R-CNN head architecture."""
head_params = params.mrcnn_head
return heads.OlnMaskScoreHead(
params.architecture.num_classes,
params.architecture.mask_target_size,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def shapeprior_head_generator(params):
"""Generator function for shape prior head architecture."""
head_params = params.shapemask_head
return heads.ShapemaskPriorHead(
params.architecture.num_classes,
head_params.num_downsample_channels,
head_params.mask_crop_size,
head_params.use_category_for_mask,
head_params.shape_prior_path)
def coarsemask_head_generator(params):
"""Generator function for ShapeMask coarse mask head architecture."""
head_params = params.shapemask_head
return heads.ShapemaskCoarsemaskHead(
params.architecture.num_classes,
head_params.num_downsample_channels,
head_params.mask_crop_size,
head_params.use_category_for_mask,
head_params.num_convs,
norm_activation=norm_activation_generator(params.norm_activation))
def finemask_head_generator(params):
"""Generator function for Shapemask fine mask head architecture."""
head_params = params.shapemask_head
return heads.ShapemaskFinemaskHead(
params.architecture.num_classes,
head_params.num_downsample_channels,
head_params.mask_crop_size,
head_params.use_category_for_mask,
head_params.num_convs,
head_params.upsample_factor)
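# Illustrative dispatch sketch (hypothetical ParamsDict contents):
#   params.architecture.backbone = 'resnet'
#   params.resnet.resnet_depth = 50
#   backbone_fn = backbone_generator(params)  # -> a resnet.Resnet instance
# Unknown names raise ValueError, e.g. backbone 'hourglass' is rejected.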
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature Pyramid Networks.
Feature Pyramid Networks were proposed in:
[1] Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan,
and Serge Belongie.
Feature Pyramid Networks for Object Detection. CVPR 2017.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
from official.legacy.detection.modeling.architecture import nn_ops
from official.legacy.detection.ops import spatial_transform_ops
class Fpn(object):
"""Feature pyramid networks."""
def __init__(self,
min_level=3,
max_level=7,
fpn_feat_dims=256,
use_separable_conv=False,
activation='relu',
use_batch_norm=True,
norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""FPN initialization function.
Args:
min_level: `int` minimum level in FPN output feature maps.
max_level: `int` maximum level in FPN output feature maps.
fpn_feat_dims: `int` number of filters in FPN layers.
use_separable_conv: `bool`, if True, use separable convolutions in the FPN
layers.
activation: the activation function.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
"""
self._min_level = min_level
self._max_level = max_level
self._fpn_feat_dims = fpn_feat_dims
if use_separable_conv:
self._conv2d_op = functools.partial(
tf.keras.layers.SeparableConv2D, depth_multiplier=1)
else:
self._conv2d_op = tf.keras.layers.Conv2D
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
self._norm_activation = norm_activation
self._norm_activations = {}
self._lateral_conv2d_op = {}
self._post_hoc_conv2d_op = {}
self._coarse_conv2d_op = {}
for level in range(self._min_level, self._max_level + 1):
if self._use_batch_norm:
self._norm_activations[level] = norm_activation(
use_activation=False, name='p%d-bn' % level)
self._lateral_conv2d_op[level] = self._conv2d_op(
filters=self._fpn_feat_dims,
kernel_size=(1, 1),
padding='same',
name='l%d' % level)
self._post_hoc_conv2d_op[level] = self._conv2d_op(
filters=self._fpn_feat_dims,
strides=(1, 1),
kernel_size=(3, 3),
padding='same',
name='post_hoc_d%d' % level)
self._coarse_conv2d_op[level] = self._conv2d_op(
filters=self._fpn_feat_dims,
strides=(2, 2),
kernel_size=(3, 3),
padding='same',
name='p%d' % level)
def __call__(self, multilevel_features, is_training=None):
"""Returns the FPN features for a given multilevel features.
Args:
multilevel_features: a `dict` containing `int` keys for continuous feature
levels, e.g., [2, 3, 4, 5]. The values are corresponding features with
shape [batch_size, height_l, width_l, num_filters].
is_training: `bool` if True, the model is in training mode.
Returns:
a `dict` containing `int` keys for continuous feature levels
[min_level, min_level + 1, ..., max_level]. The values are corresponding
FPN features with shape [batch_size, height_l, width_l, fpn_feat_dims].
"""
input_levels = list(multilevel_features.keys())
if min(input_levels) > self._min_level:
raise ValueError(
'The minimum backbone level %d should be ' % (min(input_levels)) +
'less than or equal to the FPN minimum level %d.' % (self._min_level))
backbone_max_level = min(max(input_levels), self._max_level)
with tf.name_scope('fpn'):
# Adds lateral connections.
feats_lateral = {}
for level in range(self._min_level, backbone_max_level + 1):
feats_lateral[level] = self._lateral_conv2d_op[level](
multilevel_features[level])
# Adds top-down path.
feats = {backbone_max_level: feats_lateral[backbone_max_level]}
for level in range(backbone_max_level - 1, self._min_level - 1, -1):
feats[level] = spatial_transform_ops.nearest_upsampling(
feats[level + 1], 2) + feats_lateral[level]
# Adds post-hoc 3x3 convolution kernel.
for level in range(self._min_level, backbone_max_level + 1):
feats[level] = self._post_hoc_conv2d_op[level](feats[level])
# Adds coarser FPN levels introduced for RetinaNet.
for level in range(backbone_max_level + 1, self._max_level + 1):
feats_in = feats[level - 1]
if level > backbone_max_level + 1:
feats_in = self._activation_op(feats_in)
feats[level] = self._coarse_conv2d_op[level](feats_in)
if self._use_batch_norm:
# Adds batch_norm layer.
for level in range(self._min_level, self._max_level + 1):
feats[level] = self._norm_activations[level](
feats[level], is_training=is_training)
return feats
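# Illustrative shape sketch (hypothetical 512x512 input, batch size b):
#   fpn_fn = Fpn(min_level=3, max_level=7, fpn_feat_dims=256)
#   feats = fpn_fn({2: c2, 3: c3, 4: c4, 5: c5}, is_training=False)
# feats[3] is [b, 64, 64, 256] down to feats[7] at [b, 4, 4, 256]; levels 6
# and 7 come from the stride-2 coarse convs stacked on top of level 5.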
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes to build various prediction heads in all supported models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import numpy as np
import tensorflow as tf
from official.legacy.detection.modeling.architecture import nn_ops
from official.legacy.detection.ops import spatial_transform_ops
class RpnHead(tf.keras.layers.Layer):
"""Region Proposal Network head."""
def __init__(
self,
min_level,
max_level,
anchors_per_location,
num_convs=2,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_batch_norm=True,
norm_activation=nn_ops.norm_activation_builder(activation='relu')):
"""Initialize params to build Region Proposal Network head.
Args:
min_level: `int` number of minimum feature level.
max_level: `int` number of maximum feature level.
anchors_per_location: `int` number of anchors per pixel location.
num_convs: `int` number that represents the number of the intermediate
conv layers before the prediction.
num_filters: `int` number that represents the number of filters of the
intermediate conv layers.
use_separable_conv: `bool`, indicating whether separable conv layers are
used.
activation: activation function. Supports 'relu' and 'swish'.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer followed
by an optional activation layer.
"""
super().__init__(autocast=False)
self._min_level = min_level
self._max_level = max_level
self._anchors_per_location = anchors_per_location
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
if use_separable_conv:
self._conv2d_op = functools.partial(
tf.keras.layers.SeparableConv2D,
depth_multiplier=1,
bias_initializer=tf.zeros_initializer())
else:
self._conv2d_op = functools.partial(
tf.keras.layers.Conv2D,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
bias_initializer=tf.zeros_initializer())
self._rpn_conv = self._conv2d_op(
num_filters,
kernel_size=(3, 3),
strides=(1, 1),
activation=(None if self._use_batch_norm else self._activation_op),
padding='same',
name='rpn')
self._rpn_class_conv = self._conv2d_op(
anchors_per_location,
kernel_size=(1, 1),
strides=(1, 1),
padding='valid',
name='rpn-class')
self._rpn_box_conv = self._conv2d_op(
4 * anchors_per_location,
kernel_size=(1, 1),
strides=(1, 1),
padding='valid',
name='rpn-box')
self._norm_activations = {}
if self._use_batch_norm:
for level in range(self._min_level, self._max_level + 1):
self._norm_activations[level] = norm_activation(name='rpn-l%d-bn' %
level)
def _shared_rpn_heads(self, features, anchors_per_location, level,
is_training):
"""Shared RPN heads."""
features = self._rpn_conv(features)
if self._use_batch_norm:
# The batch normalization layers are not shared between levels.
features = self._norm_activations[level](
features, is_training=is_training)
# Proposal classification scores
scores = self._rpn_class_conv(features)
# Proposal bbox regression deltas
bboxes = self._rpn_box_conv(features)
return scores, bboxes
def call(self, features, is_training=None):
scores_outputs = {}
box_outputs = {}
with tf.name_scope('rpn_head'):
for level in range(self._min_level, self._max_level + 1):
scores_output, box_output = self._shared_rpn_heads(
features[level], self._anchors_per_location, level, is_training)
scores_outputs[level] = scores_output
box_outputs[level] = box_output
return scores_outputs, box_outputs
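# Illustrative output shapes (hypothetical level-l feature of shape
# [b, h, w, c] with anchors_per_location = 3): scores_outputs[l] is
# [b, h, w, 3] and box_outputs[l] is [b, h, w, 12] (4 box deltas per anchor).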
class OlnRpnHead(tf.keras.layers.Layer):
"""Region Proposal Network for Object Localization Network (OLN)."""
def __init__(
self,
min_level,
max_level,
anchors_per_location,
num_convs=2,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_batch_norm=True,
norm_activation=nn_ops.norm_activation_builder(activation='relu')):
"""Initialize params to build Region Proposal Network head.
Args:
min_level: `int` number of minimum feature level.
max_level: `int` number of maximum feature level.
anchors_per_location: `int` number of anchors per pixel location.
num_convs: `int` number that represents the number of the intermediate
conv layers before the prediction.
num_filters: `int` number that represents the number of filters of the
intermediate conv layers.
use_separable_conv: `bool`, indicating whether separable conv layers are
used.
activation: activation function. Supports 'relu' and 'swish'.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer followed
by an optional activation layer.
"""
# Initialize the Keras Layer base class before assigning attributes.
super().__init__(autocast=False)
self._min_level = min_level
self._max_level = max_level
self._anchors_per_location = anchors_per_location
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
if use_separable_conv:
self._conv2d_op = functools.partial(
tf.keras.layers.SeparableConv2D,
depth_multiplier=1,
bias_initializer=tf.zeros_initializer())
else:
self._conv2d_op = functools.partial(
tf.keras.layers.Conv2D,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
bias_initializer=tf.zeros_initializer())
self._rpn_conv = self._conv2d_op(
num_filters,
kernel_size=(3, 3),
strides=(1, 1),
activation=(None if self._use_batch_norm else self._activation_op),
padding='same',
name='rpn')
self._rpn_class_conv = self._conv2d_op(
anchors_per_location,
kernel_size=(1, 1),
strides=(1, 1),
padding='valid',
name='rpn-class')
self._rpn_box_conv = self._conv2d_op(
4 * anchors_per_location,
kernel_size=(1, 1),
strides=(1, 1),
padding='valid',
name='rpn-box-lrtb')
self._rpn_center_conv = self._conv2d_op(
anchors_per_location,
kernel_size=(1, 1),
strides=(1, 1),
padding='valid',
name='rpn-centerness')
self._norm_activations = {}
if self._use_batch_norm:
for level in range(self._min_level, self._max_level + 1):
self._norm_activations[level] = norm_activation(name='rpn-l%d-bn' %
level)
def _shared_rpn_heads(self, features, anchors_per_location, level,
is_training):
"""Shared RPN heads."""
features = self._rpn_conv(features)
if self._use_batch_norm:
# The batch normalization layers are not shared between levels.
features = self._norm_activations[level](
features, is_training=is_training)
# Feature L2 normalization for training stability
features = tf.math.l2_normalize(features, axis=-1, name='rpn-norm')
# Proposal classification scores
scores = self._rpn_class_conv(features)
# Proposal bbox regression deltas
bboxes = self._rpn_box_conv(features)
# Proposal centerness scores
centers = self._rpn_center_conv(features)
return scores, bboxes, centers
def __call__(self, features, is_training=None):
scores_outputs = {}
box_outputs = {}
center_outputs = {}
with tf.name_scope('rpn_head'):
for level in range(self._min_level, self._max_level + 1):
scores_output, box_output, center_output = self._shared_rpn_heads(
features[level], self._anchors_per_location, level, is_training)
scores_outputs[level] = scores_output
box_outputs[level] = box_output
center_outputs[level] = center_output
return scores_outputs, box_outputs, center_outputs
class FastrcnnHead(tf.keras.layers.Layer):
"""Fast R-CNN box head."""
def __init__(
self,
num_classes,
num_convs=0,
num_filters=256,
use_separable_conv=False,
num_fcs=2,
fc_dims=1024,
activation='relu',
use_batch_norm=True,
norm_activation=nn_ops.norm_activation_builder(activation='relu')):
"""Initialize params to build Fast R-CNN box head.
Args:
num_classes: an integer for the number of classes.
num_convs: `int` number that represents the number of the intermediate
conv layers before the FC layers.
num_filters: `int` number that represents the number of filters of the
intermediate conv layers.
use_separable_conv: `bool`, indicating whether separable conv layers are
used.
num_fcs: `int` number that represents the number of FC layers before the
predictions.
fc_dims: `int` number that represents the dimension of the FC layers.
activation: activation function. Supports 'relu' and 'swish'.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer followed
by an optional activation layer.
"""
super(FastrcnnHead, self).__init__(autocast=False)
self._num_classes = num_classes
self._num_convs = num_convs
self._num_filters = num_filters
if use_separable_conv:
self._conv2d_op = functools.partial(
tf.keras.layers.SeparableConv2D,
depth_multiplier=1,
bias_initializer=tf.zeros_initializer())
else:
self._conv2d_op = functools.partial(
tf.keras.layers.Conv2D,
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer())
self._num_fcs = num_fcs
self._fc_dims = fc_dims
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
self._norm_activation = norm_activation
self._conv_ops = []
self._conv_bn_ops = []
for i in range(self._num_convs):
self._conv_ops.append(
self._conv2d_op(
self._num_filters,
kernel_size=(3, 3),
strides=(1, 1),
padding='same',
dilation_rate=(1, 1),
activation=(None
if self._use_batch_norm else self._activation_op),
name='conv_{}'.format(i)))
if self._use_batch_norm:
self._conv_bn_ops.append(self._norm_activation())
self._fc_ops = []
self._fc_bn_ops = []
for i in range(self._num_fcs):
self._fc_ops.append(
tf.keras.layers.Dense(
units=self._fc_dims,
activation=(None
if self._use_batch_norm else self._activation_op),
name='fc{}'.format(i)))
if self._use_batch_norm:
self._fc_bn_ops.append(self._norm_activation(fused=False))
self._class_predict = tf.keras.layers.Dense(
self._num_classes,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
bias_initializer=tf.zeros_initializer(),
name='class-predict')
self._box_predict = tf.keras.layers.Dense(
self._num_classes * 4,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.001),
bias_initializer=tf.zeros_initializer(),
name='box-predict')
def call(self, roi_features, is_training=None):
"""Box and class branches for the Mask-RCNN model.
Args:
roi_features: A ROI feature tensor of shape [batch_size, num_rois,
height_l, width_l, num_filters].
      is_training: `bool`, True if the model is in training mode.
Returns:
class_outputs: a tensor with a shape of
[batch_size, num_rois, num_classes], representing the class predictions.
box_outputs: a tensor with a shape of
[batch_size, num_rois, num_classes * 4], representing the box
predictions.
"""
    with tf.name_scope('fast_rcnn_head'):
      # Reshape inputs before the FC layers.
_, num_rois, height, width, filters = roi_features.get_shape().as_list()
net = tf.reshape(roi_features, [-1, height, width, filters])
for i in range(self._num_convs):
net = self._conv_ops[i](net)
if self._use_batch_norm:
net = self._conv_bn_ops[i](net, is_training=is_training)
filters = self._num_filters if self._num_convs > 0 else filters
net = tf.reshape(net, [-1, num_rois, height * width * filters])
for i in range(self._num_fcs):
net = self._fc_ops[i](net)
if self._use_batch_norm:
net = self._fc_bn_ops[i](net, is_training=is_training)
class_outputs = self._class_predict(net)
box_outputs = self._box_predict(net)
return class_outputs, box_outputs
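# A minimal, illustrative usage sketch of FastrcnnHead; the batch/ROI/crop
# shapes below are assumptions chosen for demonstration, not values mandated
# by the API.
def _example_fastrcnn_head():
  head = FastrcnnHead(num_classes=91, use_batch_norm=False)
  roi_features = tf.ones([2, 8, 7, 7, 256])  # [batch, num_rois, h, w, filters]
  class_outputs, box_outputs = head(roi_features, is_training=False)
  # class_outputs -> [2, 8, 91]; box_outputs -> [2, 8, 91 * 4].
  return class_outputs, box_outputs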
class OlnBoxScoreHead(tf.keras.layers.Layer):
"""Box head of Object Localization Network (OLN)."""
def __init__(
self,
num_classes,
num_convs=0,
num_filters=256,
use_separable_conv=False,
num_fcs=2,
fc_dims=1024,
activation='relu',
use_batch_norm=True,
norm_activation=nn_ops.norm_activation_builder(activation='relu')):
"""Initialize params to build OLN box head.
Args:
      num_classes: an `int` number of classes.
      num_convs: `int` number of intermediate conv layers before the FC
        layers.
      num_filters: `int` number of filters of the intermediate conv layers.
      use_separable_conv: `bool`, whether separable conv layers are used.
      num_fcs: `int` number of FC layers before the predictions.
      fc_dims: `int` dimension of the FC layers.
      activation: activation function. Supports 'relu' and 'swish'.
      use_batch_norm: `bool`, whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer followed
by an optional activation layer.
"""
    super(OlnBoxScoreHead, self).__init__(autocast=False)
    self._num_classes = num_classes
self._num_convs = num_convs
self._num_filters = num_filters
if use_separable_conv:
self._conv2d_op = functools.partial(
tf.keras.layers.SeparableConv2D,
depth_multiplier=1,
bias_initializer=tf.zeros_initializer())
else:
self._conv2d_op = functools.partial(
tf.keras.layers.Conv2D,
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer())
self._num_fcs = num_fcs
self._fc_dims = fc_dims
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
self._norm_activation = norm_activation
self._conv_ops = []
self._conv_bn_ops = []
for i in range(self._num_convs):
self._conv_ops.append(
self._conv2d_op(
self._num_filters,
kernel_size=(3, 3),
strides=(1, 1),
padding='same',
dilation_rate=(1, 1),
activation=(None
if self._use_batch_norm else self._activation_op),
name='conv_{}'.format(i)))
if self._use_batch_norm:
self._conv_bn_ops.append(self._norm_activation())
self._fc_ops = []
self._fc_bn_ops = []
for i in range(self._num_fcs):
self._fc_ops.append(
tf.keras.layers.Dense(
units=self._fc_dims,
activation=(None
if self._use_batch_norm else self._activation_op),
name='fc{}'.format(i)))
if self._use_batch_norm:
self._fc_bn_ops.append(self._norm_activation(fused=False))
self._class_predict = tf.keras.layers.Dense(
self._num_classes,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
bias_initializer=tf.zeros_initializer(),
name='class-predict')
self._box_predict = tf.keras.layers.Dense(
self._num_classes * 4,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.001),
bias_initializer=tf.zeros_initializer(),
name='box-predict')
self._score_predict = tf.keras.layers.Dense(
1,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
bias_initializer=tf.zeros_initializer(),
name='score-predict')
def __call__(self, roi_features, is_training=None):
"""Box and class branches for the Mask-RCNN model.
Args:
roi_features: A ROI feature tensor of shape [batch_size, num_rois,
height_l, width_l, num_filters].
is_training: `boolean`, if True if model is in training mode.
Returns:
class_outputs: a tensor with a shape of
[batch_size, num_rois, num_classes], representing the class predictions.
box_outputs: a tensor with a shape of
[batch_size, num_rois, num_classes * 4], representing the box
predictions.
"""
with tf.name_scope('fast_rcnn_head'):
      # Reshape inputs before the FC layers.
_, num_rois, height, width, filters = roi_features.get_shape().as_list()
net = tf.reshape(roi_features, [-1, height, width, filters])
for i in range(self._num_convs):
net = self._conv_ops[i](net)
if self._use_batch_norm:
net = self._conv_bn_ops[i](net, is_training=is_training)
filters = self._num_filters if self._num_convs > 0 else filters
net = tf.reshape(net, [-1, num_rois, height * width * filters])
for i in range(self._num_fcs):
net = self._fc_ops[i](net)
if self._use_batch_norm:
net = self._fc_bn_ops[i](net, is_training=is_training)
class_outputs = self._class_predict(net)
box_outputs = self._box_predict(net)
score_outputs = self._score_predict(net)
return class_outputs, box_outputs, score_outputs
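# A minimal, illustrative usage sketch of OlnBoxScoreHead; it mirrors
# FastrcnnHead but adds a per-ROI localization-quality score branch. Shape
# values are assumptions, not part of the API.
def _example_oln_box_score_head():
  head = OlnBoxScoreHead(num_classes=2, use_batch_norm=False)
  roi_features = tf.ones([2, 8, 7, 7, 256])  # [batch, num_rois, h, w, filters]
  class_outputs, box_outputs, score_outputs = head(roi_features)
  # class_outputs -> [2, 8, 2]; box_outputs -> [2, 8, 8];
  # score_outputs -> [2, 8, 1].
  return class_outputs, box_outputs, score_outputs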
class MaskrcnnHead(tf.keras.layers.Layer):
"""Mask R-CNN head."""
def __init__(
self,
num_classes,
mask_target_size,
num_convs=4,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_batch_norm=True,
norm_activation=nn_ops.norm_activation_builder(activation='relu')):
"""Initialize params to build Fast R-CNN head.
Args:
num_classes: a integer for the number of classes.
mask_target_size: a integer that is the resolution of masks.
num_convs: `int` number that represents the number of the intermediate
conv layers before the prediction.
num_filters: `int` number that represents the number of filters of the
intermediate conv layers.
use_separable_conv: `bool`, indicating whether the separable conv layers
is used.
activation: activation function. Support 'relu' and 'swish'.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer followed
by an optional activation layer.
"""
super(MaskrcnnHead, self).__init__(autocast=False)
self._num_classes = num_classes
self._mask_target_size = mask_target_size
self._num_convs = num_convs
self._num_filters = num_filters
if use_separable_conv:
self._conv2d_op = functools.partial(
tf.keras.layers.SeparableConv2D,
depth_multiplier=1,
bias_initializer=tf.zeros_initializer())
else:
self._conv2d_op = functools.partial(
tf.keras.layers.Conv2D,
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer())
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
self._norm_activation = norm_activation
self._conv2d_ops = []
for i in range(self._num_convs):
self._conv2d_ops.append(
self._conv2d_op(
self._num_filters,
kernel_size=(3, 3),
strides=(1, 1),
padding='same',
dilation_rate=(1, 1),
activation=(None
if self._use_batch_norm else self._activation_op),
name='mask-conv-l%d' % i))
self._mask_conv_transpose = tf.keras.layers.Conv2DTranspose(
self._num_filters,
kernel_size=(2, 2),
strides=(2, 2),
padding='valid',
activation=(None if self._use_batch_norm else self._activation_op),
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer(),
name='conv5-mask')
with tf.name_scope('mask_head'):
self._mask_conv2d_op = self._conv2d_op(
self._num_classes,
kernel_size=(1, 1),
strides=(1, 1),
padding='valid',
name='mask_fcn_logits')
def call(self, roi_features, class_indices, is_training=None):
"""Mask branch for the Mask-RCNN model.
Args:
roi_features: A ROI feature tensor of shape [batch_size, num_rois,
height_l, width_l, num_filters].
      class_indices: a Tensor of shape [batch_size, num_rois], indicating the
        class of each ROI.
      is_training: `bool`, True if the model is in training mode.
Returns:
mask_outputs: a tensor with a shape of
[batch_size, num_masks, mask_height, mask_width, num_classes],
representing the mask predictions.
    """
with tf.name_scope('mask_head'):
_, num_rois, height, width, filters = roi_features.get_shape().as_list()
net = tf.reshape(roi_features, [-1, height, width, filters])
for i in range(self._num_convs):
net = self._conv2d_ops[i](net)
if self._use_batch_norm:
net = self._norm_activation()(net, is_training=is_training)
net = self._mask_conv_transpose(net)
if self._use_batch_norm:
net = self._norm_activation()(net, is_training=is_training)
mask_outputs = self._mask_conv2d_op(net)
mask_outputs = tf.reshape(mask_outputs, [
-1, num_rois, self._mask_target_size, self._mask_target_size,
self._num_classes
])
with tf.name_scope('masks_post_processing'):
mask_outputs = tf.gather(
mask_outputs,
tf.cast(class_indices, tf.int32),
axis=-1,
batch_dims=2,
)
return mask_outputs
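# A minimal, illustrative usage sketch of MaskrcnnHead: the transposed conv
# upsamples the ROI crop once, so mask_target_size must be twice the crop
# size. The shape values are assumptions for demonstration only.
def _example_maskrcnn_head():
  head = MaskrcnnHead(num_classes=91, mask_target_size=28,
                      use_batch_norm=False)
  roi_features = tf.ones([2, 8, 14, 14, 256])  # 14x14 crops -> 28x28 masks
  class_indices = tf.zeros([2, 8], dtype=tf.int32)
  mask_outputs = head(roi_features, class_indices, is_training=False)
  # mask_outputs -> [2, 8, 28, 28] after gathering the per-class channel.
  return mask_outputs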
class RetinanetHead(object):
"""RetinaNet head."""
def __init__(
self,
min_level,
max_level,
num_classes,
anchors_per_location,
num_convs=4,
num_filters=256,
use_separable_conv=False,
norm_activation=nn_ops.norm_activation_builder(activation='relu')):
"""Initialize params to build RetinaNet head.
Args:
min_level: `int` number of minimum feature level.
max_level: `int` number of maximum feature level.
num_classes: `int` number of classification categories.
anchors_per_location: `int` number of anchors per pixel location.
      num_convs: `int` number of stacked convolutions before the last
        prediction layer.
      num_filters: `int` number of filters used in the head architecture.
      use_separable_conv: `bool` to indicate whether to use separable
        convolution.
norm_activation: an operation that includes a normalization layer followed
by an optional activation layer.
"""
self._min_level = min_level
self._max_level = max_level
self._num_classes = num_classes
self._anchors_per_location = anchors_per_location
self._num_convs = num_convs
self._num_filters = num_filters
self._use_separable_conv = use_separable_conv
with tf.name_scope('class_net') as scope_name:
self._class_name_scope = tf.name_scope(scope_name)
with tf.name_scope('box_net') as scope_name:
self._box_name_scope = tf.name_scope(scope_name)
self._build_class_net_layers(norm_activation)
self._build_box_net_layers(norm_activation)
def _class_net_batch_norm_name(self, i, level):
return 'class-%d-%d' % (i, level)
def _box_net_batch_norm_name(self, i, level):
return 'box-%d-%d' % (i, level)
def _build_class_net_layers(self, norm_activation):
"""Build re-usable layers for class prediction network."""
if self._use_separable_conv:
self._class_predict = tf.keras.layers.SeparableConv2D(
self._num_classes * self._anchors_per_location,
kernel_size=(3, 3),
bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
padding='same',
name='class-predict')
else:
self._class_predict = tf.keras.layers.Conv2D(
self._num_classes * self._anchors_per_location,
kernel_size=(3, 3),
bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1e-5),
padding='same',
name='class-predict')
self._class_conv = []
self._class_norm_activation = {}
for i in range(self._num_convs):
if self._use_separable_conv:
self._class_conv.append(
tf.keras.layers.SeparableConv2D(
self._num_filters,
kernel_size=(3, 3),
bias_initializer=tf.zeros_initializer(),
activation=None,
padding='same',
name='class-' + str(i)))
else:
self._class_conv.append(
tf.keras.layers.Conv2D(
self._num_filters,
kernel_size=(3, 3),
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(
stddev=0.01),
activation=None,
padding='same',
name='class-' + str(i)))
for level in range(self._min_level, self._max_level + 1):
name = self._class_net_batch_norm_name(i, level)
self._class_norm_activation[name] = norm_activation(name=name)
def _build_box_net_layers(self, norm_activation):
"""Build re-usable layers for box prediction network."""
if self._use_separable_conv:
self._box_predict = tf.keras.layers.SeparableConv2D(
4 * self._anchors_per_location,
kernel_size=(3, 3),
bias_initializer=tf.zeros_initializer(),
padding='same',
name='box-predict')
else:
self._box_predict = tf.keras.layers.Conv2D(
4 * self._anchors_per_location,
kernel_size=(3, 3),
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1e-5),
padding='same',
name='box-predict')
self._box_conv = []
self._box_norm_activation = {}
for i in range(self._num_convs):
if self._use_separable_conv:
self._box_conv.append(
tf.keras.layers.SeparableConv2D(
self._num_filters,
kernel_size=(3, 3),
activation=None,
bias_initializer=tf.zeros_initializer(),
padding='same',
name='box-' + str(i)))
else:
self._box_conv.append(
tf.keras.layers.Conv2D(
self._num_filters,
kernel_size=(3, 3),
activation=None,
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(
stddev=0.01),
padding='same',
name='box-' + str(i)))
for level in range(self._min_level, self._max_level + 1):
name = self._box_net_batch_norm_name(i, level)
self._box_norm_activation[name] = norm_activation(name=name)
def __call__(self, fpn_features, is_training=None):
"""Returns outputs of RetinaNet head."""
class_outputs = {}
box_outputs = {}
with tf.name_scope('retinanet_head'):
for level in range(self._min_level, self._max_level + 1):
features = fpn_features[level]
class_outputs[level] = self.class_net(
features, level, is_training=is_training)
box_outputs[level] = self.box_net(
features, level, is_training=is_training)
return class_outputs, box_outputs
def class_net(self, features, level, is_training):
"""Class prediction network for RetinaNet."""
with self._class_name_scope:
for i in range(self._num_convs):
features = self._class_conv[i](features)
        # The convolution layers in the class net are shared among all levels,
        # but each level has its own batch normalization to capture the
        # statistical differences among levels.
name = self._class_net_batch_norm_name(i, level)
features = self._class_norm_activation[name](
features, is_training=is_training)
classes = self._class_predict(features)
return classes
def box_net(self, features, level, is_training=None):
"""Box regression network for RetinaNet."""
with self._box_name_scope:
for i in range(self._num_convs):
features = self._box_conv[i](features)
        # The convolution layers in the box net are shared among all levels,
        # but each level has its own batch normalization to capture the
        # statistical differences among levels.
name = self._box_net_batch_norm_name(i, level)
features = self._box_norm_activation[name](
features, is_training=is_training)
boxes = self._box_predict(features)
return boxes
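# A minimal, illustrative usage sketch of RetinanetHead: one instance serves
# the whole pyramid because conv weights are shared across levels. Note the
# class-predict bias of -log((1 - 0.01) / 0.01) makes the initial foreground
# probability sigmoid(bias) ~ 0.01 (the focal-loss prior). The feature shapes
# below are assumptions for demonstration.
def _example_retinanet_head():
  head = RetinanetHead(min_level=3, max_level=7, num_classes=91,
                       anchors_per_location=9)
  fpn_features = {level: tf.ones([2, 8, 8, 256]) for level in range(3, 8)}
  class_outputs, box_outputs = head(fpn_features, is_training=False)
  # class_outputs[l] -> [2, 8, 8, 91 * 9]; box_outputs[l] -> [2, 8, 8, 36].
  return class_outputs, box_outputs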
# TODO(yeqing): Refactor this class when it is ready for var_scope reuse.
class ShapemaskPriorHead(object):
"""ShapeMask Prior head."""
def __init__(self, num_classes, num_downsample_channels, mask_crop_size,
use_category_for_mask, shape_prior_path):
"""Initialize params to build RetinaNet head.
Args:
num_classes: Number of output classes.
num_downsample_channels: number of channels in mask branch.
mask_crop_size: feature crop size.
use_category_for_mask: use class information in mask branch.
shape_prior_path: the path to load shape priors.
"""
self._mask_num_classes = num_classes if use_category_for_mask else 1
self._num_downsample_channels = num_downsample_channels
self._mask_crop_size = mask_crop_size
self._shape_prior_path = shape_prior_path
self._use_category_for_mask = use_category_for_mask
self._shape_prior_fc = tf.keras.layers.Dense(
self._num_downsample_channels, name='shape-prior-fc')
def __call__(self, fpn_features, boxes, outer_boxes, classes, is_training):
"""Generate the detection priors from the box detections and FPN features.
    This corresponds to Fig. 4 of the ShapeMask paper:
    https://arxiv.org/pdf/1904.03239.pdf
Args:
fpn_features: a dictionary of FPN features.
boxes: a float tensor of shape [batch_size, num_instances, 4] representing
the tight gt boxes from dataloader/detection.
outer_boxes: a float tensor of shape [batch_size, num_instances, 4]
representing the loose gt boxes from dataloader/detection.
      classes: an `int` Tensor of shape [batch_size, num_instances] of instance
        classes.
is_training: training mode or not.
Returns:
instance_features: a float Tensor of shape [batch_size * num_instances,
mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
instance feature crop.
detection_priors: A float Tensor of shape [batch_size * num_instances,
mask_size, mask_size, 1].
"""
with tf.name_scope('prior_mask'):
batch_size, num_instances, _ = boxes.get_shape().as_list()
outer_boxes = tf.cast(outer_boxes, tf.float32)
boxes = tf.cast(boxes, tf.float32)
instance_features = spatial_transform_ops.multilevel_crop_and_resize(
fpn_features, outer_boxes, output_size=self._mask_crop_size)
instance_features = self._shape_prior_fc(instance_features)
shape_priors = self._get_priors()
# Get uniform priors for each outer box.
uniform_priors = tf.ones([
batch_size, num_instances, self._mask_crop_size, self._mask_crop_size
])
uniform_priors = spatial_transform_ops.crop_mask_in_target_box(
uniform_priors, boxes, outer_boxes, self._mask_crop_size)
# Classify shape priors using uniform priors + instance features.
prior_distribution = self._classify_shape_priors(
tf.cast(instance_features, tf.float32), uniform_priors, classes)
instance_priors = tf.gather(shape_priors, classes)
instance_priors *= tf.expand_dims(
tf.expand_dims(tf.cast(prior_distribution, tf.float32), axis=-1),
axis=-1)
instance_priors = tf.reduce_sum(instance_priors, axis=2)
detection_priors = spatial_transform_ops.crop_mask_in_target_box(
instance_priors, boxes, outer_boxes, self._mask_crop_size)
return instance_features, detection_priors
def _get_priors(self):
"""Load shape priors from file."""
    # Loads class-specific or class-agnostic shape priors.
if self._shape_prior_path:
# Priors are loaded into shape [mask_num_classes, num_clusters, 32, 32].
priors = np.load(tf.io.gfile.GFile(self._shape_prior_path, 'rb'))
priors = tf.convert_to_tensor(priors, dtype=tf.float32)
self._num_clusters = priors.get_shape().as_list()[1]
else:
      # If the prior path does not exist, do not use priors, i.e., priors
      # equal a uniform empty 32x32 patch.
self._num_clusters = 1
priors = tf.zeros([
self._mask_num_classes, self._num_clusters, self._mask_crop_size,
self._mask_crop_size
])
return priors
def _classify_shape_priors(self, features, uniform_priors, classes):
"""Classify the uniform prior by predicting the shape modes.
Classify the object crop features into K modes of the clusters for each
category.
Args:
features: A float Tensor of shape [batch_size, num_instances, mask_size,
mask_size, num_channels].
uniform_priors: A float Tensor of shape [batch_size, num_instances,
mask_size, mask_size] representing the uniform detection priors.
      classes: an `int` Tensor of shape [batch_size, num_instances] of
        detection class ids.
Returns:
prior_distribution: A float Tensor of shape
[batch_size, num_instances, num_clusters] representing the classifier
output probability over all possible shapes.
"""
batch_size, num_instances, _, _, _ = features.get_shape().as_list()
features *= tf.expand_dims(uniform_priors, axis=-1)
    # Reduce the spatial dimensions of features; the result has shape
    # [batch_size, num_instances, num_channels].
features = tf.reduce_mean(features, axis=(2, 3))
logits = tf.keras.layers.Dense(
self._mask_num_classes * self._num_clusters,
kernel_initializer=tf.random_normal_initializer(stddev=0.01),
name='classify-shape-prior-fc')(features)
logits = tf.reshape(
logits,
[batch_size, num_instances, self._mask_num_classes, self._num_clusters])
if self._use_category_for_mask:
logits = tf.gather(logits, tf.expand_dims(classes, axis=-1), batch_dims=2)
logits = tf.squeeze(logits, axis=2)
else:
logits = logits[:, :, 0, :]
distribution = tf.nn.softmax(logits, name='shape_prior_weights')
return distribution
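# A minimal, illustrative sketch of the prior-mixing step performed in
# ShapemaskPriorHead.__call__: each instance's detection prior is a soft
# mixture of the K cluster priors for its class, weighted by the predicted
# prior distribution. All tensor values and sizes here are synthetic.
def _example_prior_mixing():
  num_classes, num_clusters, size = 5, 3, 32
  shape_priors = tf.random.uniform([num_classes, num_clusters, size, size])
  classes = tf.constant([[0, 2], [1, 4]])  # [batch_size, num_instances]
  prior_distribution = tf.nn.softmax(
      tf.random.uniform([2, 2, num_clusters]), axis=-1)
  instance_priors = tf.gather(shape_priors, classes)  # [2, 2, K, 32, 32]
  weights = prior_distribution[..., tf.newaxis, tf.newaxis]  # [2, 2, K, 1, 1]
  mixed_priors = tf.reduce_sum(instance_priors * weights, axis=2)
  return mixed_priors  # [2, 2, 32, 32]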
class ShapemaskCoarsemaskHead(object):
"""ShapemaskCoarsemaskHead head."""
def __init__(self,
num_classes,
num_downsample_channels,
mask_crop_size,
use_category_for_mask,
num_convs,
norm_activation=nn_ops.norm_activation_builder()):
"""Initialize params to build ShapeMask coarse and fine prediction head.
Args:
num_classes: `int` number of mask classification categories.
num_downsample_channels: `int` number of filters at mask head.
mask_crop_size: feature crop size.
use_category_for_mask: use class information in mask branch.
      num_convs: `int` number of stacked convolutions before the last
        prediction layer.
norm_activation: an operation that includes a normalization layer followed
by an optional activation layer.
"""
self._mask_num_classes = num_classes if use_category_for_mask else 1
self._use_category_for_mask = use_category_for_mask
self._num_downsample_channels = num_downsample_channels
self._mask_crop_size = mask_crop_size
self._num_convs = num_convs
self._norm_activation = norm_activation
self._coarse_mask_fc = tf.keras.layers.Dense(
self._num_downsample_channels, name='coarse-mask-fc')
self._class_conv = []
self._class_norm_activation = []
for i in range(self._num_convs):
self._class_conv.append(
tf.keras.layers.Conv2D(
self._num_downsample_channels,
kernel_size=(3, 3),
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(
stddev=0.01),
padding='same',
name='coarse-mask-class-%d' % i))
self._class_norm_activation.append(
norm_activation(name='coarse-mask-class-%d-bn' % i))
self._class_predict = tf.keras.layers.Conv2D(
self._mask_num_classes,
kernel_size=(1, 1),
# Focal loss bias initialization to have foreground 0.01 probability.
bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
padding='same',
name='coarse-mask-class-predict')
def __call__(self, features, detection_priors, classes, is_training):
"""Generate instance masks from FPN features and detection priors.
    This corresponds to Figs. 5-6 of the ShapeMask paper:
    https://arxiv.org/pdf/1904.03239.pdf
Args:
features: a float Tensor of shape [batch_size, num_instances,
mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
instance feature crop.
detection_priors: a float Tensor of shape [batch_size, num_instances,
mask_crop_size, mask_crop_size, 1]. This is the detection prior for the
instance.
      classes: an `int` Tensor of shape [batch_size, num_instances] of instance
        classes.
is_training: a bool indicating whether in training mode.
Returns:
mask_outputs: instance mask prediction as a float Tensor of shape
[batch_size, num_instances, mask_size, mask_size].
"""
with tf.name_scope('coarse_mask'):
# Transform detection priors to have the same dimension as features.
detection_priors = tf.expand_dims(detection_priors, axis=-1)
detection_priors = self._coarse_mask_fc(detection_priors)
features += detection_priors
mask_logits = self.decoder_net(features, is_training)
# Gather the logits with right input class.
if self._use_category_for_mask:
mask_logits = tf.transpose(mask_logits, [0, 1, 4, 2, 3])
mask_logits = tf.gather(
mask_logits, tf.expand_dims(classes, -1), batch_dims=2)
mask_logits = tf.squeeze(mask_logits, axis=2)
else:
mask_logits = mask_logits[..., 0]
return mask_logits
def decoder_net(self, features, is_training=False):
"""Coarse mask decoder network architecture.
Args:
      features: A tensor of shape [batch_size, num_instances, height, width,
        num_channels].
      is_training: Whether batch_norm layers are in training mode.
    Returns:
      mask_logits: A tensor of shape [batch_size, num_instances, height, width,
        mask_num_classes] of predicted mask logits.
"""
(batch_size, num_instances, height, width,
num_channels) = features.get_shape().as_list()
features = tf.reshape(
features, [batch_size * num_instances, height, width, num_channels])
for i in range(self._num_convs):
features = self._class_conv[i](features)
features = self._class_norm_activation[i](
features, is_training=is_training)
mask_logits = self._class_predict(features)
mask_logits = tf.reshape(
mask_logits,
[batch_size, num_instances, height, width, self._mask_num_classes])
return mask_logits
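# A minimal, illustrative usage sketch: the coarse decoder can be exercised
# standalone on instance feature crops. Shape values are assumptions for
# demonstration only.
def _example_coarse_mask_decoder():
  head = ShapemaskCoarsemaskHead(
      num_classes=91, num_downsample_channels=128, mask_crop_size=32,
      use_category_for_mask=True, num_convs=2)
  features = tf.ones([2, 4, 32, 32, 128])
  mask_logits = head.decoder_net(features, is_training=False)
  # mask_logits -> [2, 4, 32, 32, 91] (one logit map per class).
  return mask_logits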
class ShapemaskFinemaskHead(object):
"""ShapemaskFinemaskHead head."""
def __init__(self,
num_classes,
num_downsample_channels,
mask_crop_size,
use_category_for_mask,
num_convs,
upsample_factor,
norm_activation=nn_ops.norm_activation_builder()):
"""Initialize params to build ShapeMask coarse and fine prediction head.
Args:
num_classes: `int` number of mask classification categories.
num_downsample_channels: `int` number of filters at mask head.
mask_crop_size: feature crop size.
use_category_for_mask: use class information in mask branch.
      num_convs: `int` number of stacked convolutions before the last
        prediction layer.
upsample_factor: `int` number of fine mask upsampling factor.
      norm_activation: an operation that includes a batch normalization layer
        followed by an optional ReLU layer.
"""
self._use_category_for_mask = use_category_for_mask
self._mask_num_classes = num_classes if use_category_for_mask else 1
self._num_downsample_channels = num_downsample_channels
self._mask_crop_size = mask_crop_size
self._num_convs = num_convs
self.up_sample_factor = upsample_factor
self._fine_mask_fc = tf.keras.layers.Dense(
self._num_downsample_channels, name='fine-mask-fc')
self._upsample_conv = tf.keras.layers.Conv2DTranspose(
self._num_downsample_channels,
(self.up_sample_factor, self.up_sample_factor),
(self.up_sample_factor, self.up_sample_factor),
name='fine-mask-conv2d-tran')
self._fine_class_conv = []
self._fine_class_bn = []
for i in range(self._num_convs):
self._fine_class_conv.append(
tf.keras.layers.Conv2D(
self._num_downsample_channels,
kernel_size=(3, 3),
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(
stddev=0.01),
activation=None,
padding='same',
name='fine-mask-class-%d' % i))
self._fine_class_bn.append(
norm_activation(name='fine-mask-class-%d-bn' % i))
self._class_predict_conv = tf.keras.layers.Conv2D(
self._mask_num_classes,
kernel_size=(1, 1),
# Focal loss bias initialization to have foreground 0.01 probability.
bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
padding='same',
name='fine-mask-class-predict')
def __call__(self, features, mask_logits, classes, is_training):
"""Generate instance masks from FPN features and detection priors.
    This corresponds to Figs. 5-6 of the ShapeMask paper:
    https://arxiv.org/pdf/1904.03239.pdf
Args:
features: a float Tensor of shape [batch_size, num_instances,
mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
instance feature crop.
mask_logits: a float Tensor of shape [batch_size, num_instances,
mask_crop_size, mask_crop_size] indicating predicted mask logits.
      classes: an `int` Tensor of shape [batch_size, num_instances] of instance
        classes.
is_training: a bool indicating whether in training mode.
Returns:
mask_outputs: instance mask prediction as a float Tensor of shape
[batch_size, num_instances, mask_size, mask_size].
"""
    # Extract the foreground mean features.
with tf.name_scope('fine_mask'):
mask_probs = tf.nn.sigmoid(mask_logits)
# Compute instance embedding for hard average.
binary_mask = tf.cast(tf.greater(mask_probs, 0.5), features.dtype)
instance_embedding = tf.reduce_sum(
features * tf.expand_dims(binary_mask, axis=-1), axis=(2, 3))
instance_embedding /= tf.expand_dims(
tf.reduce_sum(binary_mask, axis=(2, 3)) + 1e-20, axis=-1)
# Take the difference between crop features and mean instance features.
features -= tf.expand_dims(
tf.expand_dims(instance_embedding, axis=2), axis=2)
features += self._fine_mask_fc(tf.expand_dims(mask_probs, axis=-1))
# Decoder to generate upsampled segmentation mask.
mask_logits = self.decoder_net(features, is_training)
if self._use_category_for_mask:
mask_logits = tf.transpose(mask_logits, [0, 1, 4, 2, 3])
mask_logits = tf.gather(
mask_logits, tf.expand_dims(classes, -1), batch_dims=2)
mask_logits = tf.squeeze(mask_logits, axis=2)
else:
mask_logits = mask_logits[..., 0]
return mask_logits
def decoder_net(self, features, is_training=False):
"""Fine mask decoder network architecture.
Args:
      features: A tensor of shape [batch_size, num_instances, height, width,
        num_channels].
      is_training: Whether batch_norm layers are in training mode.
    Returns:
      mask_logits: A tensor of shape [batch_size, num_instances,
        height * up_sample_factor, width * up_sample_factor, mask_num_classes]
        of predicted mask logits.
"""
(batch_size, num_instances, height, width,
num_channels) = features.get_shape().as_list()
features = tf.reshape(
features, [batch_size * num_instances, height, width, num_channels])
for i in range(self._num_convs):
features = self._fine_class_conv[i](features)
features = self._fine_class_bn[i](features, is_training=is_training)
if self.up_sample_factor > 1:
features = self._upsample_conv(features)
# Predict per-class instance masks.
mask_logits = self._class_predict_conv(features)
mask_logits = tf.reshape(mask_logits, [
batch_size, num_instances, height * self.up_sample_factor,
width * self.up_sample_factor, self._mask_num_classes
])
return mask_logits
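# A minimal, illustrative usage sketch: the fine decoder upsamples by
# up_sample_factor before predicting per-class logits. Shape values are
# assumptions for demonstration only.
def _example_fine_mask_decoder():
  head = ShapemaskFinemaskHead(
      num_classes=91, num_downsample_channels=128, mask_crop_size=32,
      use_category_for_mask=True, num_convs=2, upsample_factor=2)
  features = tf.ones([2, 4, 32, 32, 128])
  mask_logits = head.decoder_net(features, is_training=False)
  # mask_logits -> [2, 4, 64, 64, 91].
  return mask_logits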
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Identity Fn that forwards the input features."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class Identity(object):
"""Identity function that forwards the input features."""
def __call__(self, features, is_training=False):
"""Only forwards the input features."""
return features
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common building blocks for neural networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils
class ResidualBlock(tf.keras.layers.Layer):
"""A residual block."""
def __init__(self,
filters,
strides,
use_projection=False,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""A residual block with BN after convolutions.
Args:
      filters: `int` number of filters for the convolutions.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Defaults to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Defaults to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super(ResidualBlock, self).__init__(**kwargs)
self._filters = filters
self._strides = strides
self._use_projection = use_projection
self._use_sync_bn = use_sync_bn
self._activation = activation
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
if self._use_projection:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=self._strides,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=1,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
super(ResidualBlock, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'strides': self._strides,
'use_projection': self._use_projection,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(ResidualBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
shortcut = inputs
if self._use_projection:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._conv2(x)
x = self._norm2(x)
return self._activation_fn(x + shortcut)
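# A minimal, illustrative usage sketch of ResidualBlock: a strided block needs
# the projection shortcut so the identity path matches the downsampled output.
# Shape values are assumptions for demonstration only.
def _example_residual_block():
  block = ResidualBlock(filters=64, strides=2, use_projection=True)
  x = tf.ones([2, 56, 56, 64])
  y = block(x)  # -> [2, 28, 28, 64]
  return y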
class BottleneckBlock(tf.keras.layers.Layer):
"""A standard bottleneck block."""
def __init__(self,
filters,
strides,
use_projection=False,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""A standard bottleneck block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Defaults to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Defaults to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super(BottleneckBlock, self).__init__(**kwargs)
self._filters = filters
self._strides = strides
self._use_projection = use_projection
self._use_sync_bn = use_sync_bn
self._activation = activation
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
if self._use_projection:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=self._strides,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv3 = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm3 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
super(BottleneckBlock, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'strides': self._strides,
'use_projection': self._use_projection,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(BottleneckBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
shortcut = inputs
if self._use_projection:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._conv2(x)
x = self._norm2(x)
x = self._activation_fn(x)
x = self._conv3(x)
x = self._norm3(x)
return self._activation_fn(x + shortcut)
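# A minimal, illustrative usage sketch of BottleneckBlock: outputs have
# 4 * filters channels, so the projection shortcut is needed whenever the
# input channel count differs, even at stride 1. Shapes are assumptions.
def _example_bottleneck_block():
  block = BottleneckBlock(filters=64, strides=1, use_projection=True)
  x = tf.ones([2, 56, 56, 64])
  y = block(x)  # -> [2, 56, 56, 256]
  return y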
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neural network operations commonly shared by the architectures."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
class NormActivation(tf.keras.layers.Layer):
"""Combined Normalization and Activation layers."""
def __init__(self,
momentum=0.997,
epsilon=1e-4,
trainable=True,
init_zero=False,
use_activation=True,
activation='relu',
fused=True,
name=None):
"""A class to construct layers for a batch normalization followed by a ReLU.
Args:
momentum: momentum for the moving average.
epsilon: small float added to variance to avoid dividing by zero.
trainable: `bool`, if True also add variables to the graph collection
GraphKeys.TRAINABLE_VARIABLES. If False, freeze batch normalization
layer.
init_zero: `bool` if True, initializes scale parameter of batch
normalization with 0. If False, initialize it with 1.
use_activation: `bool`, whether to add the optional activation layer after
the batch normalization layer.
      activation: `str`, the type of the activation layer. Currently supports
        `relu` and `swish`.
      fused: `bool` fused option in batch normalization.
name: `str` name for the operation.
"""
super(NormActivation, self).__init__(trainable=trainable)
if init_zero:
gamma_initializer = tf.keras.initializers.Zeros()
else:
gamma_initializer = tf.keras.initializers.Ones()
self._normalization_op = tf.keras.layers.BatchNormalization(
momentum=momentum,
epsilon=epsilon,
center=True,
scale=True,
trainable=trainable,
fused=fused,
gamma_initializer=gamma_initializer,
name=name)
self._use_activation = use_activation
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
def __call__(self, inputs, is_training=None):
"""Builds the normalization layer followed by an optional activation layer.
Args:
inputs: `Tensor` of shape `[batch, channels, ...]`.
      is_training: `bool`, True if the model is in training mode.
Returns:
A normalized `Tensor` with the same `data_format`.
"""
    # We need to keep is_training=None by default so that it can be inherited
    # from the Keras training phase.
if is_training and self.trainable:
is_training = True
inputs = self._normalization_op(inputs, training=is_training)
if self._use_activation:
inputs = self._activation_op(inputs)
return inputs
def norm_activation_builder(momentum=0.997,
epsilon=1e-4,
trainable=True,
activation='relu',
**kwargs):
return functools.partial(
NormActivation,
momentum=momentum,
epsilon=epsilon,
trainable=trainable,
activation=activation,
**kwargs)
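# A minimal, illustrative usage sketch: the builder returns a partial, so each
# call of the returned callable creates a fresh BN + activation pair. The
# layer name and input shape below are assumptions for demonstration only.
def _example_norm_activation():
  builder = norm_activation_builder(activation='relu')
  norm_act = builder(name='demo-bn')  # a new NormActivation instance
  x = tf.ones([2, 8, 8, 16])
  return norm_act(x, is_training=False)  # same shape, normalized then ReLU-ed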
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions for the post-activation form of Residual Networks.
Residual networks (ResNets) were proposed in:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.legacy.detection.modeling.architecture import nn_ops
# TODO(b/140112644): Refactor the code with Keras style, i.e. build and call.
class Resnet(object):
"""Class to build ResNet family model."""
def __init__(
self,
resnet_depth,
activation='relu',
norm_activation=nn_ops.norm_activation_builder(activation='relu'),
data_format='channels_last'):
"""ResNet initialization function.
Args:
resnet_depth: `int` depth of ResNet backbone model.
activation: the activation function.
norm_activation: an operation that includes a normalization layer followed
by an optional activation layer.
      data_format: `str` either "channels_first" for `[batch, channels, height,
        width]` or "channels_last" for `[batch, height, width, channels]`.
"""
self._resnet_depth = resnet_depth
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._norm_activation = norm_activation
self._data_format = data_format
model_params = {
10: {
'block': self.residual_block,
'layers': [1, 1, 1, 1]
},
18: {
'block': self.residual_block,
'layers': [2, 2, 2, 2]
},
34: {
'block': self.residual_block,
'layers': [3, 4, 6, 3]
},
50: {
'block': self.bottleneck_block,
'layers': [3, 4, 6, 3]
},
101: {
'block': self.bottleneck_block,
'layers': [3, 4, 23, 3]
},
152: {
'block': self.bottleneck_block,
'layers': [3, 8, 36, 3]
},
200: {
'block': self.bottleneck_block,
'layers': [3, 24, 36, 3]
}
}
if resnet_depth not in model_params:
valid_resnet_depths = ', '.join(
[str(depth) for depth in sorted(model_params.keys())])
      raise ValueError(
          'resnet_depth should be one of [%s]; got %s.' %
          (valid_resnet_depths, self._resnet_depth))
params = model_params[resnet_depth]
self._resnet_fn = self.resnet_v1_generator(params['block'],
params['layers'])
def __call__(self, inputs, is_training=None):
"""Returns the ResNet model for a given size and number of output classes.
Args:
      inputs: a `Tensor` with shape [batch_size, height, width, 3] representing
a batch of images.
is_training: `bool` if True, the model is in training mode.
Returns:
a `dict` containing `int` keys for continuous feature levels [2, 3, 4, 5].
The values are corresponding feature hierarchy in ResNet with shape
[batch_size, height_l, width_l, num_filters].
"""
with tf.name_scope('resnet%s' % self._resnet_depth):
return self._resnet_fn(inputs, is_training)
def fixed_padding(self, inputs, kernel_size):
"""Pads the input along the spatial dimensions independently of input size.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]` or `[batch,
height, width, channels]` depending on `data_format`.
      kernel_size: `int` kernel size to be used for `conv2d` or `max_pool2d`
        operations. Should be a positive integer.
Returns:
A padded `Tensor` of the same `data_format` with size either intact
(if `kernel_size == 1`) or padded (if `kernel_size > 1`).
"""
pad_total = kernel_size - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
if self._data_format == 'channels_first':
padded_inputs = tf.pad(
tensor=inputs,
paddings=[[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]])
else:
padded_inputs = tf.pad(
tensor=inputs,
paddings=[[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
return padded_inputs
def conv2d_fixed_padding(self, inputs, filters, kernel_size, strides):
"""Strided 2-D convolution with explicit padding.
The padding is consistent and is based only on `kernel_size`, not on the
dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
Args:
inputs: `Tensor` of size `[batch, channels, height_in, width_in]`.
filters: `int` number of filters in the convolution.
kernel_size: `int` size of the kernel to be used in the convolution.
strides: `int` strides of the convolution.
Returns:
A `Tensor` of shape `[batch, filters, height_out, width_out]`.
"""
if strides > 1:
inputs = self.fixed_padding(inputs, kernel_size)
return tf.keras.layers.Conv2D(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=('SAME' if strides == 1 else 'VALID'),
use_bias=False,
kernel_initializer=tf.initializers.VarianceScaling(),
data_format=self._data_format)(
inputs=inputs)
def residual_block(self,
inputs,
filters,
strides,
use_projection=False,
is_training=None):
"""Standard building block for residual networks with BN after convolutions.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]`.
      filters: `int` number of filters for the convolutions.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
is_training: `bool` if True, the model is in training mode.
Returns:
The output `Tensor` of the block.
"""
shortcut = inputs
if use_projection:
# Projection shortcut in first layer to match filters and strides
shortcut = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=strides)
shortcut = self._norm_activation(use_activation=False)(
shortcut, is_training=is_training)
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides)
inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=1)
inputs = self._norm_activation(
use_activation=False, init_zero=True)(
inputs, is_training=is_training)
return self._activation_op(inputs + shortcut)
def bottleneck_block(self,
inputs,
filters,
strides,
use_projection=False,
is_training=None):
"""Bottleneck block variant for residual networks with BN after convolutions.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]`.
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
is_training: `bool` if True, the model is in training mode.
Returns:
The output `Tensor` of the block.
"""
shortcut = inputs
if use_projection:
# Projection shortcut only in first block within a group. Bottleneck
# blocks end with 4 times the number of filters.
filters_out = 4 * filters
shortcut = self.conv2d_fixed_padding(
inputs=inputs, filters=filters_out, kernel_size=1, strides=strides)
shortcut = self._norm_activation(use_activation=False)(
shortcut, is_training=is_training)
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=1)
inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides)
inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=4 * filters, kernel_size=1, strides=1)
inputs = self._norm_activation(
use_activation=False, init_zero=True)(
inputs, is_training=is_training)
return self._activation_op(inputs + shortcut)
def block_group(self, inputs, filters, block_fn, blocks, strides, name,
is_training):
"""Creates one group of blocks for the ResNet model.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]`.
filters: `int` number of filters for the first convolution of the layer.
      block_fn: `function` for the block to use within the model.
      blocks: `int` number of blocks contained in the layer.
      strides: `int` stride to use for the first convolution of the layer. If
        greater than 1, this layer will downsample the input.
      name: `str` name for the Tensor output of the block layer.
is_training: `bool` if True, the model is in training mode.
Returns:
The output `Tensor` of the block layer.
"""
# Only the first block per block_group uses projection shortcut and strides.
inputs = block_fn(
inputs, filters, strides, use_projection=True, is_training=is_training)
for _ in range(1, blocks):
inputs = block_fn(inputs, filters, 1, is_training=is_training)
return tf.identity(inputs, name)
def resnet_v1_generator(self, block_fn, layers):
"""Generator for ResNet v1 models.
Args:
block_fn: `function` for the block to use within the model. Either
`residual_block` or `bottleneck_block`.
layers: list of 4 `int`s denoting the number of blocks to include in each
of the 4 block groups. Each group consists of blocks that take inputs of
the same resolution.
Returns:
Model `function` that takes in `inputs` and `is_training` and returns the
output `Tensor` of the ResNet model.
"""
def model(inputs, is_training=None):
"""Creation of the model graph."""
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=64, kernel_size=7, strides=2)
inputs = tf.identity(inputs, 'initial_conv')
inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = tf.keras.layers.MaxPool2D(
pool_size=3, strides=2, padding='SAME',
data_format=self._data_format)(
inputs)
inputs = tf.identity(inputs, 'initial_max_pool')
c2 = self.block_group(
inputs=inputs,
filters=64,
block_fn=block_fn,
blocks=layers[0],
strides=1,
name='block_group1',
is_training=is_training)
c3 = self.block_group(
inputs=c2,
filters=128,
block_fn=block_fn,
blocks=layers[1],
strides=2,
name='block_group2',
is_training=is_training)
c4 = self.block_group(
inputs=c3,
filters=256,
block_fn=block_fn,
blocks=layers[2],
strides=2,
name='block_group3',
is_training=is_training)
c5 = self.block_group(
inputs=c4,
filters=512,
block_fn=block_fn,
blocks=layers[3],
strides=2,
name='block_group4',
is_training=is_training)
return {2: c2, 3: c3, 4: c4, 5: c5}
return model
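# A minimal, illustrative usage sketch of the Resnet backbone: it maps images
# to a feature pyramid at strides 4, 8, 16, and 32. The image size below is
# an assumption for demonstration only.
def _example_resnet_backbone():
  backbone = Resnet(resnet_depth=50)
  images = tf.ones([2, 224, 224, 3])
  endpoints = backbone(images, is_training=False)
  # endpoints[2] -> [2, 56, 56, 256], ..., endpoints[5] -> [2, 7, 7, 2048].
  return endpoints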
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of SpineNet model.
X. Du, T-Y. Lin, P. Jin, G. Ghiasi, M. Tan, Y. Cui, Q. V. Le, X. Song
SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization
https://arxiv.org/abs/1912.05027
"""
import math
from absl import logging
import tensorflow as tf
from official.legacy.detection.modeling.architecture import nn_blocks
from official.modeling import tf_utils
layers = tf.keras.layers
FILTER_SIZE_MAP = {
1: 32,
2: 64,
3: 128,
4: 256,
5: 256,
6: 256,
7: 256,
}
# The fixed SpineNet architecture discovered by NAS.
# Each element represents a specification of a building block:
# (block_level, block_fn, (input_offset0, input_offset1), is_output).
SPINENET_BLOCK_SPECS = [
(2, 'bottleneck', (0, 1), False),
(4, 'residual', (0, 1), False),
(3, 'bottleneck', (2, 3), False),
(4, 'bottleneck', (2, 4), False),
(6, 'residual', (3, 5), False),
(4, 'bottleneck', (3, 5), False),
(5, 'residual', (6, 7), False),
(7, 'residual', (6, 8), False),
(5, 'bottleneck', (8, 9), False),
(5, 'bottleneck', (8, 10), False),
(4, 'bottleneck', (5, 10), True),
(3, 'bottleneck', (4, 10), True),
(5, 'bottleneck', (7, 12), True),
(7, 'bottleneck', (5, 14), True),
(6, 'bottleneck', (12, 14), True),
]
SCALING_MAP = {
'49S': {
'endpoints_num_filters': 128,
'filter_size_scale': 0.65,
'resample_alpha': 0.5,
'block_repeats': 1,
},
'49': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 0.5,
'block_repeats': 1,
},
'96': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 0.5,
'block_repeats': 2,
},
'143': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 1.0,
'block_repeats': 3,
},
'190': {
'endpoints_num_filters': 512,
'filter_size_scale': 1.3,
'resample_alpha': 1.0,
'block_repeats': 4,
},
}
class BlockSpec(object):
"""A container class that specifies the block configuration for SpineNet."""
def __init__(self, level, block_fn, input_offsets, is_output):
self.level = level
self.block_fn = block_fn
self.input_offsets = input_offsets
self.is_output = is_output
def build_block_specs(block_specs=None):
"""Builds the list of BlockSpec objects for SpineNet."""
if not block_specs:
block_specs = SPINENET_BLOCK_SPECS
logging.info('Building SpineNet block specs: %s', block_specs)
return [BlockSpec(*b) for b in block_specs]
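# An editorial sketch of how the raw spec tuples decode. Input offsets index
# previously built blocks: offsets 0 and 1 refer to the two stem blocks, and
# offset k >= 2 refers to the (k - 2)-th scale-permuted block.
#
#   specs = build_block_specs()
#   first = specs[0]
#   assert first.level == 2
#   assert first.block_fn == 'bottleneck'
#   assert first.input_offsets == (0, 1)
#   assert not first.is_output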
class SpineNet(tf.keras.Model):
"""Class to build SpineNet models."""
def __init__(self,
input_specs=tf.keras.layers.InputSpec(shape=[None, 640, 640, 3]),
min_level=3,
max_level=7,
block_specs=None,
endpoints_num_filters=256,
resample_alpha=0.5,
block_repeats=1,
filter_size_scale=1.0,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""SpineNet model."""
self._min_level = min_level
self._max_level = max_level
self._block_specs = (
build_block_specs() if block_specs is None else block_specs
)
self._endpoints_num_filters = endpoints_num_filters
self._resample_alpha = resample_alpha
self._block_repeats = block_repeats
self._filter_size_scale = filter_size_scale
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if activation == 'relu':
self._activation = tf.nn.relu
elif activation == 'swish':
self._activation = tf.nn.swish
else:
raise ValueError('Activation {} not implemented.'.format(activation))
self._init_block_fn = 'bottleneck'
self._num_init_blocks = 2
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
# Build SpineNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:])
net = self._build_stem(inputs=inputs)
net = self._build_scale_permuted_network(
net=net, input_width=input_specs.shape[1])
net = self._build_endpoints(net=net)
super(SpineNet, self).__init__(inputs=inputs, outputs=net)
def _block_group(self,
inputs,
filters,
strides,
block_fn_cand,
block_repeats=1,
name='block_group'):
"""Creates one group of blocks for the SpineNet model."""
block_fn_candidates = {
'bottleneck': nn_blocks.BottleneckBlock,
'residual': nn_blocks.ResidualBlock,
}
block_fn = block_fn_candidates[block_fn_cand]
_, _, _, num_filters = inputs.get_shape().as_list()
if block_fn_cand == 'bottleneck':
use_projection = not (num_filters == (filters * 4) and strides == 1)
else:
use_projection = not (num_filters == filters and strides == 1)
x = block_fn(
filters=filters,
strides=strides,
use_projection=use_projection,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
inputs)
for _ in range(1, block_repeats):
x = block_fn(
filters=filters,
strides=1,
use_projection=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
x)
return tf.identity(x, name=name)
def _build_stem(self, inputs):
"""Build SpineNet stem."""
x = layers.Conv2D(
filters=64,
kernel_size=7,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
net = []
# Build the initial level 2 blocks.
for i in range(self._num_init_blocks):
x = self._block_group(
inputs=x,
filters=int(FILTER_SIZE_MAP[2] * self._filter_size_scale),
strides=1,
block_fn_cand=self._init_block_fn,
block_repeats=self._block_repeats,
name='stem_block_{}'.format(i + 1))
net.append(x)
return net
def _build_scale_permuted_network(self,
net,
input_width,
weighted_fusion=False):
"""Build scale-permuted network."""
net_sizes = [int(math.ceil(input_width / 2**2))] * len(net)
net_block_fns = [self._init_block_fn] * len(net)
num_outgoing_connections = [0] * len(net)
endpoints = {}
for i, block_spec in enumerate(self._block_specs):
# Find out specs for the target block.
target_width = int(math.ceil(input_width / 2**block_spec.level))
target_num_filters = int(FILTER_SIZE_MAP[block_spec.level] *
self._filter_size_scale)
target_block_fn = block_spec.block_fn
# Resample then merge input0 and input1.
parents = []
input0 = block_spec.input_offsets[0]
input1 = block_spec.input_offsets[1]
x0 = self._resample_with_alpha(
inputs=net[input0],
input_width=net_sizes[input0],
input_block_fn=net_block_fns[input0],
target_width=target_width,
target_num_filters=target_num_filters,
target_block_fn=target_block_fn,
alpha=self._resample_alpha)
parents.append(x0)
num_outgoing_connections[input0] += 1
x1 = self._resample_with_alpha(
inputs=net[input1],
input_width=net_sizes[input1],
input_block_fn=net_block_fns[input1],
target_width=target_width,
target_num_filters=target_num_filters,
target_block_fn=target_block_fn,
alpha=self._resample_alpha)
parents.append(x1)
num_outgoing_connections[input1] += 1
# Merge 0 outdegree blocks to the output block.
if block_spec.is_output:
for j, (j_feat,
j_connections) in enumerate(zip(net, num_outgoing_connections)):
if j_connections == 0 and (j_feat.shape[2] == target_width and
j_feat.shape[3] == x0.shape[3]):
parents.append(j_feat)
num_outgoing_connections[j] += 1
# pylint: disable=g-direct-tensorflow-import
if weighted_fusion:
dtype = parents[0].dtype
parent_weights = [
tf.nn.relu(tf.cast(tf.Variable(1.0, name='block{}_fusion{}'.format(
i, j)), dtype=dtype)) for j in range(len(parents))]
weights_sum = tf.add_n(parent_weights)
parents = [
parents[i] * parent_weights[i] / (weights_sum + 0.0001)
for i in range(len(parents))
]
# Fuse all parent nodes then build a new block.
x = tf_utils.get_activation(self._activation)(tf.add_n(parents))
x = self._block_group(
inputs=x,
filters=target_num_filters,
strides=1,
block_fn_cand=target_block_fn,
block_repeats=self._block_repeats,
name='scale_permuted_block_{}'.format(i + 1))
net.append(x)
net_sizes.append(target_width)
net_block_fns.append(target_block_fn)
num_outgoing_connections.append(0)
# Save output feats.
if block_spec.is_output:
if block_spec.level in endpoints:
raise ValueError('Duplicate feats found for output level {}.'.format(
block_spec.level))
if (block_spec.level < self._min_level or
block_spec.level > self._max_level):
raise ValueError('Output level is out of range [{}, {}]'.format(
self._min_level, self._max_level))
endpoints[block_spec.level] = x
return endpoints
def _build_endpoints(self, net):
"""Match filter size for endpoints before sharing conv layers."""
endpoints = {}
for level in range(self._min_level, self._max_level + 1):
x = layers.Conv2D(
filters=self._endpoints_num_filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
net[level])
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
endpoints[level] = x
return endpoints
def _resample_with_alpha(self,
inputs,
input_width,
input_block_fn,
target_width,
target_num_filters,
target_block_fn,
alpha=0.5):
"""Match resolution and feature dimension."""
_, _, _, input_num_filters = inputs.get_shape().as_list()
if input_block_fn == 'bottleneck':
input_num_filters /= 4
new_num_filters = int(input_num_filters * alpha)
x = layers.Conv2D(
filters=new_num_filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
# Spatial resampling.
if input_width > target_width:
x = layers.Conv2D(
filters=new_num_filters,
kernel_size=3,
strides=2,
padding='SAME',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
input_width /= 2
while input_width > target_width:
x = layers.MaxPool2D(pool_size=3, strides=2, padding='SAME')(x)
input_width /= 2
elif input_width < target_width:
scale = target_width // input_width
x = layers.UpSampling2D(size=(scale, scale))(x)
# Last 1x1 conv to match filter size.
if target_block_fn == 'bottleneck':
target_num_filters *= 4
x = layers.Conv2D(
filters=target_num_filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
return x
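# A worked example of the resampling above (an editorial note). Downsampling
# from input_width=160 to target_width=20 applies one stride-2 3x3 conv
# (160 -> 80) followed by two stride-2 max pools (80 -> 40 -> 20); upsampling
# from input_width=20 to target_width=80 applies a single nearest-neighbor
# UpSampling2D with scale = 80 // 20 = 4.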
class SpineNetBuilder(object):
"""SpineNet builder."""
def __init__(self,
model_id,
input_specs=tf.keras.layers.InputSpec(shape=[None, 640, 640, 3]),
min_level=3,
max_level=7,
block_specs=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001):
if model_id not in SCALING_MAP:
raise ValueError(
'SpineNet {} is not a valid architecture.'.format(model_id))
scaling_params = SCALING_MAP[model_id]
self._input_specs = input_specs
self._min_level = min_level
self._max_level = max_level
self._block_specs = block_specs or build_block_specs()
self._endpoints_num_filters = scaling_params['endpoints_num_filters']
self._resample_alpha = scaling_params['resample_alpha']
self._block_repeats = scaling_params['block_repeats']
self._filter_size_scale = scaling_params['filter_size_scale']
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._activation = activation
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
def __call__(self, inputs, is_training=None):
model = SpineNet(
input_specs=self._input_specs,
min_level=self._min_level,
max_level=self._max_level,
block_specs=self._block_specs,
endpoints_num_filters=self._endpoints_num_filters,
resample_alpha=self._resample_alpha,
block_repeats=self._block_repeats,
filter_size_scale=self._filter_size_scale,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)
return model(inputs)
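# A minimal usage sketch of SpineNetBuilder (an editorial illustration, not
# part of the original file):
#
#   builder = SpineNetBuilder(model_id='49')
#   images = tf.keras.Input(shape=[640, 640, 3])
#   endpoints = builder(images)
#   # endpoints maps levels 3..7 to feature maps, each with 256 filters
#   # (the '49' scaling uses endpoints_num_filters=256).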
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base Model definition."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import re
import tensorflow as tf
from official.legacy.detection.modeling import checkpoint_utils
from official.legacy.detection.modeling import learning_rates
from official.legacy.detection.modeling import optimizers
def _make_filter_trainable_variables_fn(frozen_variable_prefix):
"""Creates a function for filtering trainable varialbes."""
def _filter_trainable_variables(variables):
"""Filters trainable varialbes.
Args:
variables: a list of tf.Variable to be filtered.
Returns:
      filtered_variables: a list of tf.Variable with the frozen variables
        filtered out.
"""
    # frozen_variable_prefix: a regex string specifying the prefix pattern of
# the frozen variables' names.
filtered_variables = [
v for v in variables if not frozen_variable_prefix or
not re.match(frozen_variable_prefix, v.name)
]
return filtered_variables
return _filter_trainable_variables
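# A usage sketch (an editorial illustration; `model` is an assumed
# tf.keras.Model). Variables whose names match the given prefix regex are
# treated as frozen and excluded from training:
#
#   filter_fn = _make_filter_trainable_variables_fn(r'resnet50/conv2d')
#   trainable = filter_fn(model.trainable_variables)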
class Model(object):
"""Base class for model function."""
__metaclass__ = abc.ABCMeta
def __init__(self, params):
self._use_bfloat16 = params.architecture.use_bfloat16
if params.architecture.use_bfloat16:
tf.compat.v2.keras.mixed_precision.set_global_policy('mixed_bfloat16')
# Optimization.
self._optimizer_fn = optimizers.OptimizerFactory(params.train.optimizer)
self._learning_rate = learning_rates.learning_rate_generator(
params.train.total_steps, params.train.learning_rate)
self._frozen_variable_prefix = params.train.frozen_variable_prefix
self._regularization_var_regex = params.train.regularization_variable_regex
self._l2_weight_decay = params.train.l2_weight_decay
# Checkpoint restoration.
self._checkpoint = params.train.checkpoint.as_dict()
# Summary.
self._enable_summary = params.enable_summary
self._model_dir = params.model_dir
@abc.abstractmethod
def build_outputs(self, inputs, mode):
"""Build the graph of the forward path."""
pass
@abc.abstractmethod
def build_model(self, params, mode):
"""Build the model object."""
pass
@abc.abstractmethod
def build_loss_fn(self):
"""Build the model object."""
pass
def post_processing(self, labels, outputs):
"""Post-processing function."""
return labels, outputs
def model_outputs(self, inputs, mode):
"""Build the model outputs."""
return self.build_outputs(inputs, mode)
def build_optimizer(self):
"""Returns train_op to optimize total loss."""
# Sets up the optimizer.
return self._optimizer_fn(self._learning_rate)
def make_filter_trainable_variables_fn(self):
"""Creates a function for filtering trainable varialbes."""
return _make_filter_trainable_variables_fn(self._frozen_variable_prefix)
def weight_decay_loss(self, trainable_variables):
reg_variables = [
v for v in trainable_variables
if self._regularization_var_regex is None or
re.match(self._regularization_var_regex, v.name)
]
return self._l2_weight_decay * tf.add_n(
[tf.nn.l2_loss(v) for v in reg_variables])
def make_restore_checkpoint_fn(self):
"""Returns scaffold function to restore parameters from v1 checkpoint."""
if 'skip_checkpoint_variables' in self._checkpoint:
skip_regex = self._checkpoint['skip_checkpoint_variables']
else:
skip_regex = None
return checkpoint_utils.make_restore_checkpoint_fn(
self._checkpoint['path'],
prefix=self._checkpoint['prefix'],
skip_regex=skip_regex)
def eval_metrics(self):
"""Returns tuple of metric function and its inputs for evaluation."""
raise NotImplementedError('Unimplemented eval_metrics')
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util functions for loading checkpoints.
Especially for loading Tensorflow 1.x
checkpoint to Tensorflow 2.x (keras) model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from absl import logging
import tensorflow as tf
def _build_assignment_map(keras_model,
prefix='',
skip_variables_regex=None,
var_to_shape_map=None):
"""Builds the variable assignment map.
Compute an assignment mapping for loading older checkpoints into a Keras
model. Variable names are remapped from the original TPUEstimator model to
the new Keras name.
Args:
keras_model: tf.keras.Model object to provide variables to assign.
    prefix: prefix in the variable name to be removed for alignment with names
      in the checkpoint.
    skip_variables_regex: regular expression to match the names of variables
      that do not need to be assigned.
var_to_shape_map: variable name to shape mapping from the checkpoint.
Returns:
The variable assignment map.
"""
assignment_map = {}
checkpoint_names = []
if var_to_shape_map:
# pylint: disable=g-long-lambda
checkpoint_names = list(
filter(
lambda x: not x.endswith('Momentum') and not x.endswith(
'global_step'), var_to_shape_map.keys()))
# pylint: enable=g-long-lambda
logging.info('Number of variables in the checkpoint %d',
len(checkpoint_names))
for var in keras_model.variables:
var_name = var.name
if skip_variables_regex and re.match(skip_variables_regex, var_name):
continue
# Trim the index of the variable.
if ':' in var_name:
var_name = var_name[:var_name.rindex(':')]
if var_name.startswith(prefix):
var_name = var_name[len(prefix):]
if not var_to_shape_map:
assignment_map[var_name] = var
continue
# Match name with variables in the checkpoint.
# pylint: disable=cell-var-from-loop
match_names = list(filter(lambda x: x.endswith(var_name), checkpoint_names))
# pylint: enable=cell-var-from-loop
try:
if match_names:
        assert len(match_names) == 1, 'more than one match for {}: {}'.format(
            var_name, match_names)
checkpoint_names.remove(match_names[0])
assignment_map[match_names[0]] = var
else:
        logging.info('No matching checkpoint variable found for: %s', var_name)
except Exception as e:
logging.info('Error removing the match_name: %s', match_names)
logging.info('Exception: %s', e)
raise
  logging.info('Found %d matching variables in the checkpoint.',
               len(assignment_map))
return assignment_map
def _get_checkpoint_map(checkpoint_path):
reader = tf.train.load_checkpoint(checkpoint_path)
return reader.get_variable_to_shape_map()
def make_restore_checkpoint_fn(checkpoint_path, prefix='', skip_regex=None):
"""Returns scaffold function to restore parameters from v1 checkpoint.
Args:
checkpoint_path: path of the checkpoint folder or file.
Example 1: '/path/to/model_dir/'
Example 2: '/path/to/model.ckpt-22500'
    prefix: prefix in the variable name to be removed for alignment with names
      in the checkpoint.
    skip_regex: regular expression to match the names of variables that do not
      need to be assigned.
Returns:
    Callable[[tf.keras.Model], None]. Fn to load a v1 checkpoint into a Keras
      model.
"""
def _restore_checkpoint_fn(keras_model):
"""Loads pretrained model through scaffold function."""
if not checkpoint_path:
logging.info('checkpoint_path is empty')
return
var_prefix = prefix
if prefix and not prefix.endswith('/'):
var_prefix += '/'
var_to_shape_map = _get_checkpoint_map(checkpoint_path)
assert var_to_shape_map, 'var_to_shape_map should not be empty'
vars_to_load = _build_assignment_map(
keras_model,
prefix=var_prefix,
skip_variables_regex=skip_regex,
var_to_shape_map=var_to_shape_map)
if not vars_to_load:
raise ValueError('Variables to load is empty.')
tf.compat.v1.train.init_from_checkpoint(checkpoint_path, vars_to_load)
return _restore_checkpoint_fn
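# A usage sketch (an editorial illustration; the checkpoint path, prefix and
# regex below are placeholders, and `keras_model` is assumed to be built
# already):
#
#   restore_fn = make_restore_checkpoint_fn(
#       '/path/to/model.ckpt-22500', prefix='resnet50/', skip_regex='.*fpn.*')
#   restore_fn(keras_model)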
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory to build detection model."""
from official.legacy.detection.modeling import maskrcnn_model
from official.legacy.detection.modeling import olnmask_model
from official.legacy.detection.modeling import retinanet_model
from official.legacy.detection.modeling import shapemask_model
def model_generator(params):
"""Model function generator."""
if params.type == 'retinanet':
model_fn = retinanet_model.RetinanetModel(params)
elif params.type == 'mask_rcnn':
model_fn = maskrcnn_model.MaskrcnnModel(params)
elif params.type == 'olnmask':
model_fn = olnmask_model.OlnMaskModel(params)
elif params.type == 'shapemask':
model_fn = shapemask_model.ShapeMaskModel(params)
else:
    raise ValueError('Model %s is not supported.' % params.type)
return model_fn
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Learning rate schedule."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from official.modeling.hyperparams import params_dict
class StepLearningRateWithLinearWarmup(
tf.keras.optimizers.schedules.LearningRateSchedule):
"""Class to generate learning rate tensor."""
def __init__(self, total_steps, params):
"""Creates the step learning rate tensor with linear warmup."""
super(StepLearningRateWithLinearWarmup, self).__init__()
self._total_steps = total_steps
assert isinstance(params, (dict, params_dict.ParamsDict))
if isinstance(params, dict):
params = params_dict.ParamsDict(params)
self._params = params
def __call__(self, global_step):
warmup_lr = self._params.warmup_learning_rate
warmup_steps = self._params.warmup_steps
init_lr = self._params.init_learning_rate
lr_levels = self._params.learning_rate_levels
lr_steps = self._params.learning_rate_steps
linear_warmup = (
warmup_lr + tf.cast(global_step, dtype=tf.float32) / warmup_steps *
(init_lr - warmup_lr))
learning_rate = tf.where(global_step < warmup_steps, linear_warmup, init_lr)
for next_learning_rate, start_step in zip(lr_levels, lr_steps):
learning_rate = tf.where(global_step >= start_step, next_learning_rate,
learning_rate)
return learning_rate
def get_config(self):
return {'_params': self._params.as_dict()}
class CosineLearningRateWithLinearWarmup(
tf.keras.optimizers.schedules.LearningRateSchedule):
"""Class to generate learning rate tensor."""
def __init__(self, total_steps, params):
"""Creates the cosine learning rate tensor with linear warmup."""
super(CosineLearningRateWithLinearWarmup, self).__init__()
self._total_steps = total_steps
assert isinstance(params, (dict, params_dict.ParamsDict))
if isinstance(params, dict):
params = params_dict.ParamsDict(params)
self._params = params
def __call__(self, global_step):
global_step = tf.cast(global_step, dtype=tf.float32)
warmup_lr = self._params.warmup_learning_rate
warmup_steps = self._params.warmup_steps
init_lr = self._params.init_learning_rate
total_steps = self._total_steps
linear_warmup = (
warmup_lr + global_step / warmup_steps * (init_lr - warmup_lr))
cosine_learning_rate = (
init_lr * (tf.cos(np.pi * (global_step - warmup_steps) /
(total_steps - warmup_steps)) + 1.0) / 2.0)
learning_rate = tf.where(global_step < warmup_steps, linear_warmup,
cosine_learning_rate)
return learning_rate
def get_config(self):
return {'_params': self._params.as_dict()}
def learning_rate_generator(total_steps, params):
"""The learning rate function generator."""
if params.type == 'step':
return StepLearningRateWithLinearWarmup(total_steps, params)
elif params.type == 'cosine':
return CosineLearningRateWithLinearWarmup(total_steps, params)
else:
raise ValueError('Unsupported learning rate type: {}.'.format(params.type))
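# A usage sketch of the step schedule (an editorial illustration; the
# parameter values below are made up):
#
#   lr_params = params_dict.ParamsDict({
#       'type': 'step',
#       'warmup_learning_rate': 0.0067,
#       'warmup_steps': 500,
#       'init_learning_rate': 0.08,
#       'learning_rate_levels': [0.008, 0.0008],
#       'learning_rate_steps': [15000, 20000],
#   })
#   lr_fn = learning_rate_generator(22500, lr_params)
#   lr_fn(tf.constant(0))      # 0.0067: the start of the linear warmup.
#   lr_fn(tf.constant(16000))  # 0.008: after the first decay boundary.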
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses used for detection models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import tensorflow as tf
def focal_loss(logits, targets, alpha, gamma, normalizer):
"""Compute the focal loss between `logits` and the golden `target` values.
Focal loss = -(1-pt)^gamma * log(pt)
where pt is the probability of being classified to the true class.
Args:
logits: A float32 tensor of size
[batch, height_in, width_in, num_predictions].
targets: A float32 tensor of size
[batch, height_in, width_in, num_predictions].
    alpha: A float32 scalar that weights the loss from positive examples by
      alpha and the loss from negative examples by (1 - alpha).
    gamma: A float32 scalar modulating loss from hard and easy examples.
    normalizer: A float32 scalar that normalizes the total loss from all
      examples.
Returns:
loss: A float32 Tensor of size [batch, height_in, width_in, num_predictions]
representing normalized loss on the prediction map.
"""
with tf.name_scope('focal_loss'):
positive_label_mask = tf.math.equal(targets, 1.0)
cross_entropy = (
tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))
    # Below are comments/derivations for computing the modulator.
    # For brevity, let x = logits, z = targets, r = gamma, and
    # p_t = sigmoid(x) for positive samples and 1 - sigmoid(x) for negative
    # samples.
    #
    # The modulator, defined as (1 - p_t)^r, is a critical part in focal loss
    # computation. For r > 0, it puts more weight on hard examples, and less
    # weight on easier ones. However, if it is directly computed as
    # (1 - p_t)^r, its back-propagation is not stable when r < 1. The
    # implementation here resolves the issue.
#
# For positive samples (labels being 1),
# (1 - p_t)^r
# = (1 - sigmoid(x))^r
# = (1 - (1 / (1 + exp(-x))))^r
# = (exp(-x) / (1 + exp(-x)))^r
# = exp(log((exp(-x) / (1 + exp(-x)))^r))
# = exp(r * log(exp(-x)) - r * log(1 + exp(-x)))
# = exp(- r * x - r * log(1 + exp(-x)))
#
# For negative samples (labels being 0),
# (1 - p_t)^r
# = (sigmoid(x))^r
# = (1 / (1 + exp(-x)))^r
# = exp(log((1 / (1 + exp(-x)))^r))
# = exp(-r * log(1 + exp(-x)))
#
# Therefore one unified form for positive (z = 1) and negative (z = 0)
# samples is:
# (1 - p_t)^r = exp(-r * z * x - r * log(1 + exp(-x))).
neg_logits = -1.0 * logits
modulator = tf.math.exp(gamma * targets * neg_logits -
gamma * tf.math.log1p(tf.math.exp(neg_logits)))
loss = modulator * cross_entropy
weighted_loss = tf.where(positive_label_mask, alpha * loss,
(1.0 - alpha) * loss)
weighted_loss /= normalizer
return weighted_loss
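# A numeric sanity check of the unified modulator form above (an editorial
# sketch). For a positive sample z = 1 with logit x = 0 and gamma r = 2, both
# the direct form (1 - sigmoid(0))^2 and the stable form
# exp(-r*z*x - r*log(1 + exp(-x))) evaluate to 0.25:
#
#   x, z, r = 0.0, 1.0, 2.0
#   direct = (1.0 - tf.sigmoid(x)) ** r                          # 0.25
#   stable = tf.exp(-r * z * x - r * tf.math.log1p(tf.exp(-x)))  # 0.25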
class RpnScoreLoss(object):
"""Region Proposal Network score loss function."""
def __init__(self, params):
self._rpn_batch_size_per_im = params.rpn_batch_size_per_im
self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
def __call__(self, score_outputs, labels):
"""Computes total RPN detection loss.
Computes total RPN detection loss including box and score from all levels.
Args:
score_outputs: an OrderDict with keys representing levels and values
representing scores in [batch_size, height, width, num_anchors].
labels: the dictionary that returned from dataloader that includes
groundturth targets.
Returns:
rpn_score_loss: a scalar tensor representing total score loss.
"""
with tf.name_scope('rpn_loss'):
levels = sorted(score_outputs.keys())
score_losses = []
for level in levels:
score_losses.append(
self._rpn_score_loss(
score_outputs[level],
labels[level],
normalizer=tf.cast(
tf.shape(score_outputs[level])[0] *
self._rpn_batch_size_per_im, dtype=tf.float32)))
# Sums per level losses to total loss.
return tf.math.add_n(score_losses)
def _rpn_score_loss(self, score_outputs, score_targets, normalizer=1.0):
"""Computes score loss."""
# score_targets has three values:
# (1) score_targets[i]=1, the anchor is a positive sample.
# (2) score_targets[i]=0, negative.
# (3) score_targets[i]=-1, the anchor is don't care (ignore).
with tf.name_scope('rpn_score_loss'):
mask = tf.math.logical_or(tf.math.equal(score_targets, 1),
tf.math.equal(score_targets, 0))
score_targets = tf.math.maximum(score_targets,
tf.zeros_like(score_targets))
score_targets = tf.expand_dims(score_targets, axis=-1)
score_outputs = tf.expand_dims(score_outputs, axis=-1)
score_loss = self._binary_crossentropy(
score_targets, score_outputs, sample_weight=mask)
score_loss /= normalizer
return score_loss
class RpnBoxLoss(object):
"""Region Proposal Network box regression loss function."""
def __init__(self, params):
logging.info('RpnBoxLoss huber_loss_delta %s', params.huber_loss_delta)
    # The delta is typically around the mean value of the regression targets.
    # For instance, the regression targets of a 512x512 input with 6 anchors
    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
self._huber_loss = tf.keras.losses.Huber(
delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
def __call__(self, box_outputs, labels):
"""Computes total RPN detection loss.
Computes total RPN detection loss including box and score from all levels.
Args:
box_outputs: an OrderDict with keys representing levels and values
representing box regression targets in
[batch_size, height, width, num_anchors * 4].
labels: the dictionary that returned from dataloader that includes
groundturth targets.
Returns:
rpn_box_loss: a scalar tensor representing total box regression loss.
"""
with tf.name_scope('rpn_loss'):
levels = sorted(box_outputs.keys())
box_losses = []
for level in levels:
box_losses.append(self._rpn_box_loss(box_outputs[level], labels[level]))
# Sum per level losses to total loss.
return tf.add_n(box_losses)
def _rpn_box_loss(self, box_outputs, box_targets, normalizer=1.0):
"""Computes box regression loss."""
with tf.name_scope('rpn_box_loss'):
mask = tf.cast(tf.not_equal(box_targets, 0.0), dtype=tf.float32)
box_targets = tf.expand_dims(box_targets, axis=-1)
box_outputs = tf.expand_dims(box_outputs, axis=-1)
box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
# The loss is normalized by the sum of non-zero weights and additional
# normalizer provided by the function caller. Using + 0.01 here to avoid
# division by zero.
box_loss /= normalizer * (tf.reduce_sum(mask) + 0.01)
return box_loss
class OlnRpnCenterLoss(object):
"""Object Localization Network RPN centerness regression loss function."""
def __init__(self):
self._l1_loss = tf.keras.losses.MeanAbsoluteError(
reduction=tf.keras.losses.Reduction.SUM)
def __call__(self, center_outputs, labels):
"""Computes total RPN centerness regression loss.
Computes total RPN centerness score regression loss from all levels.
Args:
      center_outputs: an OrderedDict with keys representing levels and values
        representing anchor centerness regression outputs in
        [batch_size, height, width, num_anchors * 4].
      labels: the dictionary returned from the dataloader that includes
        groundtruth targets.
Returns:
rpn_center_loss: a scalar tensor representing total centerness regression
loss.
"""
with tf.name_scope('rpn_loss'):
# Normalizer.
levels = sorted(center_outputs.keys())
num_valid = 0
      # Targets: 0 < positive <= 1, negative = 0, ignore = -1.
for level in levels:
num_valid += tf.reduce_sum(tf.cast(
tf.greater(labels[level], -1.0), tf.float32)) # in and out of box
num_valid += 1e-12
# Centerness loss over multi levels.
center_losses = []
for level in levels:
center_losses.append(
self._rpn_center_l1_loss(
center_outputs[level], labels[level],
normalizer=num_valid))
# Sum per level losses to total loss.
return tf.add_n(center_losses)
def _rpn_center_l1_loss(self, center_outputs, center_targets,
normalizer=1.0):
"""Computes centerness regression loss."""
    # For instance, the regression targets of a 512x512 input with 6 anchors
    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
with tf.name_scope('rpn_center_loss'):
# mask = tf.greater(center_targets, 0.0) # inside box only.
mask = tf.greater(center_targets, -1.0) # in and out of box.
center_targets = tf.maximum(center_targets, tf.zeros_like(center_targets))
center_outputs = tf.sigmoid(center_outputs)
center_targets = tf.expand_dims(center_targets, -1)
center_outputs = tf.expand_dims(center_outputs, -1)
mask = tf.cast(mask, dtype=tf.float32)
center_loss = self._l1_loss(center_targets, center_outputs,
sample_weight=mask)
center_loss /= normalizer
return center_loss
class OlnRpnIoULoss(object):
"""Object Localization Network RPN box-lrtb regression iou loss function."""
def __call__(self, box_outputs, labels, center_targets):
"""Computes total RPN detection loss.
Computes total RPN box regression loss from all levels.
Args:
box_outputs: an OrderDict with keys representing levels and values
representing box regression targets in
[batch_size, height, width, num_anchors * 4].
last channel: (left, right, top, bottom).
labels: the dictionary that returned from dataloader that includes
groundturth targets (left, right, top, bottom).
center_targets: valid_target mask.
Returns:
rpn_iou_loss: a scalar tensor representing total box regression loss.
"""
with tf.name_scope('rpn_loss'):
# Normalizer.
levels = sorted(box_outputs.keys())
normalizer = 0.
for level in levels:
# center_targets pos>0, neg=0, ign=-1.
mask_ = tf.cast(tf.logical_and(
tf.greater(center_targets[level][..., 0], 0.0),
tf.greater(tf.reduce_min(labels[level], -1), 0.0)), tf.float32)
normalizer += tf.reduce_sum(mask_)
normalizer += 1e-8
# iou_loss over multi levels.
iou_losses = []
for level in levels:
iou_losses.append(
self._rpn_iou_loss(
box_outputs[level], labels[level],
center_weight=center_targets[level][..., 0],
normalizer=normalizer))
# Sum per level losses to total loss.
return tf.add_n(iou_losses)
def _rpn_iou_loss(self, box_outputs, box_targets,
center_weight=None, normalizer=1.0):
"""Computes box regression loss."""
    # For instance, the regression targets of a 512x512 input with 6 anchors
    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
with tf.name_scope('rpn_iou_loss'):
mask = tf.logical_and(
tf.greater(center_weight, 0.0),
tf.greater(tf.reduce_min(box_targets, -1), 0.0))
pred_left = box_outputs[..., 0]
pred_right = box_outputs[..., 1]
pred_top = box_outputs[..., 2]
pred_bottom = box_outputs[..., 3]
gt_left = box_targets[..., 0]
gt_right = box_targets[..., 1]
gt_top = box_targets[..., 2]
gt_bottom = box_targets[..., 3]
inter_width = (tf.minimum(pred_left, gt_left) +
tf.minimum(pred_right, gt_right))
inter_height = (tf.minimum(pred_top, gt_top) +
tf.minimum(pred_bottom, gt_bottom))
inter_area = inter_width * inter_height
union_area = ((pred_left + pred_right) * (pred_top + pred_bottom) +
(gt_left + gt_right) * (gt_top + gt_bottom) -
inter_area)
iou = inter_area / (union_area + 1e-8)
mask_ = tf.cast(mask, tf.float32)
iou = tf.clip_by_value(iou, clip_value_min=1e-8, clip_value_max=1.0)
neg_log_iou = -tf.math.log(iou)
iou_loss = tf.reduce_sum(neg_log_iou * mask_)
iou_loss /= normalizer
return iou_loss
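# A worked example of the lrtb IoU above (an editorial note). For a predicted
# (left, right, top, bottom) of (1, 1, 1, 1) against a target of (2, 2, 2, 2):
# inter_area = (1 + 1) * (1 + 1) = 4, union_area = 2 * 2 + 4 * 4 - 4 = 16,
# so iou = 0.25 and the per-anchor contribution is -log(0.25) before
# normalization.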
class FastrcnnClassLoss(object):
"""Fast R-CNN classification loss function."""
def __init__(self):
self._categorical_crossentropy = tf.keras.losses.CategoricalCrossentropy(
reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
def __call__(self, class_outputs, class_targets):
"""Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
This function implements the classification loss of the Fast-RCNN.
The classification loss is softmax on all RoIs.
Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
Args:
class_outputs: a float tensor representing the class prediction for each box
with a shape of [batch_size, num_boxes, num_classes].
class_targets: a float tensor representing the class label for each box
with a shape of [batch_size, num_boxes].
Returns:
a scalar tensor representing total class loss.
"""
with tf.name_scope('fast_rcnn_loss'):
batch_size, num_boxes, num_classes = class_outputs.get_shape().as_list()
class_targets = tf.cast(class_targets, dtype=tf.int32)
class_targets_one_hot = tf.one_hot(class_targets, num_classes)
return self._fast_rcnn_class_loss(class_outputs, class_targets_one_hot,
normalizer=batch_size * num_boxes / 2.0)
def _fast_rcnn_class_loss(self, class_outputs, class_targets_one_hot,
normalizer):
"""Computes classification loss."""
with tf.name_scope('fast_rcnn_class_loss'):
class_loss = self._categorical_crossentropy(class_targets_one_hot,
class_outputs)
class_loss /= normalizer
return class_loss
class FastrcnnBoxLoss(object):
"""Fast R-CNN box regression loss function."""
def __init__(self, params):
logging.info('FastrcnnBoxLoss huber_loss_delta %s', params.huber_loss_delta)
    # The delta is typically around the mean value of the regression targets.
    # For instance, the regression targets of a 512x512 input with 6 anchors
    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
self._huber_loss = tf.keras.losses.Huber(
delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
def __call__(self, box_outputs, class_targets, box_targets):
"""Computes the box loss (Fast-RCNN branch) of Mask-RCNN.
This function implements the box regression loss of the Fast-RCNN. As the
`box_outputs` produces `num_classes` boxes for each RoI, the reference model
    expands `box_targets` to match the shape of `box_outputs` and selects only
    the target with which the RoI has the maximum overlap. (Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/fast_rcnn.py) # pylint: disable=line-too-long
Instead, this function selects the `box_outputs` by the `class_targets` so
that it doesn't expand `box_targets`.
The box loss is smooth L1-loss on only positive samples of RoIs.
Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
Args:
box_outputs: a float tensor representing the box prediction for each box
with a shape of [batch_size, num_boxes, num_classes * 4].
class_targets: a float tensor representing the class label for each box
with a shape of [batch_size, num_boxes].
box_targets: a float tensor representing the box label for each box
with a shape of [batch_size, num_boxes, 4].
Returns:
box_loss: a scalar tensor representing total box regression loss.
"""
with tf.name_scope('fast_rcnn_loss'):
class_targets = tf.cast(class_targets, dtype=tf.int32)
# Selects the box from `box_outputs` based on `class_targets`, with which
# the box has the maximum overlap.
(batch_size, num_rois,
num_class_specific_boxes) = box_outputs.get_shape().as_list()
num_classes = num_class_specific_boxes // 4
box_outputs = tf.reshape(box_outputs,
[batch_size, num_rois, num_classes, 4])
box_indices = tf.reshape(
class_targets + tf.tile(
tf.expand_dims(
tf.range(batch_size) * num_rois * num_classes, 1),
[1, num_rois]) + tf.tile(
tf.expand_dims(tf.range(num_rois) * num_classes, 0),
[batch_size, 1]), [-1])
box_outputs = tf.matmul(
tf.one_hot(
box_indices,
batch_size * num_rois * num_classes,
dtype=box_outputs.dtype), tf.reshape(box_outputs, [-1, 4]))
box_outputs = tf.reshape(box_outputs, [batch_size, -1, 4])
return self._fast_rcnn_box_loss(box_outputs, box_targets, class_targets)
def _fast_rcnn_box_loss(self, box_outputs, box_targets, class_targets,
normalizer=1.0):
"""Computes box regression loss."""
with tf.name_scope('fast_rcnn_box_loss'):
mask = tf.tile(tf.expand_dims(tf.greater(class_targets, 0), axis=2),
[1, 1, 4])
mask = tf.cast(mask, dtype=tf.float32)
box_targets = tf.expand_dims(box_targets, axis=-1)
box_outputs = tf.expand_dims(box_outputs, axis=-1)
box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
      # The loss is normalized by the number of ones in the mask and an
      # additional normalizer provided by the caller; + 0.01 is used to
      # avoid division by zero.
box_loss /= normalizer * (tf.reduce_sum(mask) + 0.01)
return box_loss
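# The one-hot matmul in __call__ above selects, for every RoI, the 4 box
# coordinates of its target class. An equivalent formulation using tf.gather
# with batch_dims (an editorial sketch, not the original implementation)
# would be:
#
#   reshaped = tf.reshape(box_outputs,
#                         [batch_size, num_rois, num_classes, 4])
#   selected = tf.gather(reshaped, class_targets, batch_dims=2)
#   # selected has shape [batch_size, num_rois, 4].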
class OlnBoxScoreLoss(object):
"""Object Localization Network Box-Iou scoring function."""
def __init__(self, params):
self._ignore_threshold = params.ignore_threshold
self._l1_loss = tf.keras.losses.MeanAbsoluteError(
reduction=tf.keras.losses.Reduction.SUM)
def __call__(self, score_outputs, score_targets):
"""Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
This function implements the classification loss of the Fast-RCNN.
The classification loss is softmax on all RoIs.
Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
Args:
score_outputs: a float tensor representing the class prediction for each box
with a shape of [batch_size, num_boxes, num_classes].
score_targets: a float tensor representing the class label for each box
with a shape of [batch_size, num_boxes].
Returns:
a scalar tensor representing total score loss.
"""
with tf.name_scope('fast_rcnn_loss'):
score_outputs = tf.squeeze(score_outputs, -1)
mask = tf.greater(score_targets, self._ignore_threshold)
num_valid = tf.reduce_sum(tf.cast(mask, tf.float32))
score_targets = tf.maximum(score_targets, tf.zeros_like(score_targets))
score_outputs = tf.sigmoid(score_outputs)
score_targets = tf.expand_dims(score_targets, -1)
score_outputs = tf.expand_dims(score_outputs, -1)
mask = tf.cast(mask, dtype=tf.float32)
score_loss = self._l1_loss(score_targets, score_outputs,
sample_weight=mask)
score_loss /= (num_valid + 1e-10)
return score_loss
class MaskrcnnLoss(object):
"""Mask R-CNN instance segmentation mask loss function."""
def __init__(self):
self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
def __call__(self, mask_outputs, mask_targets, select_class_targets):
"""Computes the mask loss of Mask-RCNN.
This function implements the mask loss of Mask-RCNN. As the `mask_outputs`
produces `num_classes` masks for each RoI, the reference model expands
    `mask_targets` to match the shape of `mask_outputs` and selects only the
    target with which the RoI has the maximum overlap. (Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/mask_rcnn.py) # pylint: disable=line-too-long
Instead, this implementation selects the `mask_outputs` by the `class_targets`
so that it doesn't expand `mask_targets`. Note that the selection logic is
done in the post-processing of mask_rcnn_fn in mask_rcnn_architecture.py.
Args:
mask_outputs: a float tensor representing the prediction for each mask,
with a shape of
[batch_size, num_masks, mask_height, mask_width].
mask_targets: a float tensor representing the binary mask of ground truth
labels for each mask with a shape of
[batch_size, num_masks, mask_height, mask_width].
select_class_targets: a tensor with a shape of [batch_size, num_masks],
representing the foreground mask targets.
Returns:
mask_loss: a float tensor representing total mask loss.
"""
with tf.name_scope('mask_rcnn_loss'):
(batch_size, num_masks, mask_height,
mask_width) = mask_outputs.get_shape().as_list()
weights = tf.tile(
tf.reshape(tf.greater(select_class_targets, 0),
[batch_size, num_masks, 1, 1]),
[1, 1, mask_height, mask_width])
weights = tf.cast(weights, dtype=tf.float32)
mask_targets = tf.expand_dims(mask_targets, axis=-1)
mask_outputs = tf.expand_dims(mask_outputs, axis=-1)
mask_loss = self._binary_crossentropy(mask_targets, mask_outputs,
sample_weight=weights)
# The loss is normalized by the number of 1's in weights and
# + 0.01 is used to avoid division by zero.
return mask_loss / (tf.reduce_sum(weights) + 0.01)
class RetinanetClassLoss(object):
"""RetinaNet class loss."""
def __init__(self, params, num_classes):
self._num_classes = num_classes
self._focal_loss_alpha = params.focal_loss_alpha
self._focal_loss_gamma = params.focal_loss_gamma
def __call__(self, cls_outputs, labels, num_positives):
"""Computes total detection loss.
Computes total detection loss including box and class loss from all levels.
Args:
cls_outputs: an OrderDict with keys representing levels and values
representing logits in [batch_size, height, width,
num_anchors * num_classes].
labels: the dictionary that returned from dataloader that includes
class groundturth targets.
num_positives: number of positive examples in the minibatch.
Returns:
an integar tensor representing total class loss.
"""
# Sums all positives in a batch for normalization and avoids zero
# num_positives_sum, which would lead to inf loss during training
num_positives_sum = tf.reduce_sum(input_tensor=num_positives) + 1.0
cls_losses = []
for level in cls_outputs.keys():
cls_losses.append(self.class_loss(
cls_outputs[level], labels[level], num_positives_sum))
# Sums per level losses to total loss.
return tf.add_n(cls_losses)
def class_loss(self, cls_outputs, cls_targets, num_positives,
ignore_label=-2):
"""Computes RetinaNet classification loss."""
# Onehot encoding for classification labels.
cls_targets_one_hot = tf.one_hot(cls_targets, self._num_classes)
bs, height, width, _, _ = cls_targets_one_hot.get_shape().as_list()
cls_targets_one_hot = tf.reshape(cls_targets_one_hot,
[bs, height, width, -1])
loss = focal_loss(tf.cast(cls_outputs, dtype=tf.float32),
tf.cast(cls_targets_one_hot, dtype=tf.float32),
self._focal_loss_alpha,
self._focal_loss_gamma,
num_positives)
ignore_loss = tf.where(
tf.equal(cls_targets, ignore_label),
tf.zeros_like(cls_targets, dtype=tf.float32),
tf.ones_like(cls_targets, dtype=tf.float32),
)
ignore_loss = tf.expand_dims(ignore_loss, -1)
ignore_loss = tf.tile(ignore_loss, [1, 1, 1, 1, self._num_classes])
ignore_loss = tf.reshape(ignore_loss, tf.shape(input=loss))
return tf.reduce_sum(input_tensor=ignore_loss * loss)
class RetinanetBoxLoss(object):
"""RetinaNet box loss."""
def __init__(self, params):
self._huber_loss = tf.keras.losses.Huber(
delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
def __call__(self, box_outputs, labels, num_positives):
"""Computes box detection loss.
Computes total detection loss including box and class loss from all levels.
Args:
box_outputs: an OrderDict with keys representing levels and values
representing box regression targets in [batch_size, height, width,
num_anchors * 4].
labels: the dictionary that returned from dataloader that includes
box groundturth targets.
num_positives: number of positive examples in the minibatch.
Returns:
an integer tensor representing total box regression loss.
"""
# Sums all positives in a batch for normalization and avoids zero
# num_positives_sum, which would lead to inf loss during training
num_positives_sum = tf.reduce_sum(input_tensor=num_positives) + 1.0
box_losses = []
for level in box_outputs.keys():
box_targets_l = labels[level]
box_losses.append(
self.box_loss(box_outputs[level], box_targets_l, num_positives_sum))
# Sums per level losses to total loss.
return tf.add_n(box_losses)
def box_loss(self, box_outputs, box_targets, num_positives):
"""Computes RetinaNet box regression loss."""
    # The delta is typically around the mean value of the regression targets.
    # For instance, the regression targets of a 512x512 input with 6 anchors
    # on the P3-P7 pyramid are about [0.1, 0.1, 0.2, 0.2].
normalizer = num_positives * 4.0
mask = tf.cast(tf.not_equal(box_targets, 0.0), dtype=tf.float32)
box_targets = tf.expand_dims(box_targets, axis=-1)
box_outputs = tf.expand_dims(box_outputs, axis=-1)
box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
box_loss /= normalizer
return box_loss
class ShapemaskMseLoss(object):
"""ShapeMask mask Mean Squared Error loss function wrapper."""
def __call__(self, probs, labels, valid_mask):
"""Compute instance segmentation loss.
Args:
      probs: A Tensor of shape [batch_size * num_points, height, width,
        num_classes]. The values are not necessarily between 0 and 1.
labels: A float32/float16 Tensor of shape [batch_size, num_instances,
mask_size, mask_size], where mask_size =
mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
for coarse masks and shape priors.
valid_mask: a binary mask indicating valid training masks.
Returns:
      loss: a float tensor representing the total mask classification loss.
"""
with tf.name_scope('shapemask_prior_loss'):
batch_size, num_instances = valid_mask.get_shape().as_list()[:2]
diff = (tf.cast(labels, dtype=tf.float32) -
tf.cast(probs, dtype=tf.float32))
diff *= tf.cast(
tf.reshape(valid_mask, [batch_size, num_instances, 1, 1]),
tf.float32)
# Adding 0.001 in the denominator to avoid division by zero.
loss = tf.nn.l2_loss(diff) / (tf.reduce_sum(labels) + 0.001)
return loss
class ShapemaskLoss(object):
"""ShapeMask mask loss function wrapper."""
def __init__(self):
self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
def __call__(self, logits, labels, valid_mask):
"""ShapeMask mask cross entropy loss function wrapper.
Args:
logits: A Tensor of shape [batch_size * num_instances, height, width,
num_classes]. The logits are not necessarily between 0 and 1.
labels: A float16/float32 Tensor of shape [batch_size, num_instances,
mask_size, mask_size], where mask_size =
mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
for coarse masks and shape priors.
valid_mask: a binary mask of shape [batch_size, num_instances]
indicating valid training masks.
Returns:
      loss: a float tensor representing the total mask classification loss.
"""
with tf.name_scope('shapemask_loss'):
batch_size, num_instances = valid_mask.get_shape().as_list()[:2]
labels = tf.cast(labels, tf.float32)
logits = tf.cast(logits, tf.float32)
loss = self._binary_crossentropy(labels, logits)
loss *= tf.cast(tf.reshape(
valid_mask, [batch_size, num_instances, 1, 1]), loss.dtype)
# Adding 0.001 in the denominator to avoid division by zero.
loss = tf.reduce_sum(loss) / (tf.reduce_sum(labels) + 0.001)
return loss
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model defination for the Mask R-CNN Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.legacy.detection.dataloader import anchor
from official.legacy.detection.dataloader import mode_keys
from official.legacy.detection.evaluation import factory as eval_factory
from official.legacy.detection.modeling import base_model
from official.legacy.detection.modeling import losses
from official.legacy.detection.modeling.architecture import factory
from official.legacy.detection.ops import postprocess_ops
from official.legacy.detection.ops import roi_ops
from official.legacy.detection.ops import spatial_transform_ops
from official.legacy.detection.ops import target_ops
from official.legacy.detection.utils import box_utils
class MaskrcnnModel(base_model.Model):
"""Mask R-CNN model function."""
def __init__(self, params):
super(MaskrcnnModel, self).__init__(params)
# For eval metrics.
self._params = params
self._keras_model = None
self._include_mask = params.architecture.include_mask
# Architecture generators.
self._backbone_fn = factory.backbone_generator(params)
self._fpn_fn = factory.multilevel_features_generator(params)
self._rpn_head_fn = factory.rpn_head_generator(params)
self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
self._sample_rois_fn = target_ops.ROISampler(params.roi_sampling)
self._sample_masks_fn = target_ops.MaskSampler(
params.architecture.mask_target_size,
params.mask_sampling.num_mask_samples_per_image)
self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
if self._include_mask:
self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)
# Loss function.
self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
if self._include_mask:
self._mask_loss_fn = losses.MaskrcnnLoss()
self._generate_detections_fn = postprocess_ops.GenericDetectionGenerator(
params.postprocess)
self._transpose_input = params.train.transpose_input
    assert not self._transpose_input, 'Transpose input is not supported.'
def build_outputs(self, inputs, mode):
is_training = mode == mode_keys.TRAIN
model_outputs = {}
image = inputs['image']
_, image_height, image_width, _ = image.get_shape().as_list()
backbone_features = self._backbone_fn(image, is_training)
fpn_features = self._fpn_fn(backbone_features, is_training)
rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(
fpn_features, is_training)
model_outputs.update({
'rpn_score_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
rpn_score_outputs),
'rpn_box_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
rpn_box_outputs),
})
input_anchor = anchor.Anchor(self._params.architecture.min_level,
self._params.architecture.max_level,
self._params.anchor.num_scales,
self._params.anchor.aspect_ratios,
self._params.anchor.anchor_size,
(image_height, image_width))
rpn_rois, _ = self._generate_rois_fn(rpn_box_outputs, rpn_score_outputs,
input_anchor.multilevel_boxes,
inputs['image_info'][:, 1, :],
is_training)
if is_training:
rpn_rois = tf.stop_gradient(rpn_rois)
# Sample proposals.
rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
self._sample_rois_fn(rpn_rois, inputs['gt_boxes'],
inputs['gt_classes']))
# Create bounding box training targets.
box_targets = box_utils.encode_boxes(
matched_gt_boxes, rpn_rois, weights=[10.0, 10.0, 5.0, 5.0])
# If the target is background, the box target is set to all 0s.
box_targets = tf.where(
tf.tile(
tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
[1, 1, 4]), tf.zeros_like(box_targets), box_targets)
model_outputs.update({
'class_targets': matched_gt_classes,
'box_targets': box_targets,
})
roi_features = spatial_transform_ops.multilevel_crop_and_resize(
fpn_features, rpn_rois, output_size=7)
class_outputs, box_outputs = self._frcnn_head_fn(roi_features, is_training)
model_outputs.update({
'class_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
class_outputs),
'box_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
box_outputs),
})
# Add this output to train to make the checkpoint loadable in predict mode.
# If we skip it in train mode, the heads will be out-of-order and checkpoint
# loading will fail.
boxes, scores, classes, valid_detections = self._generate_detections_fn(
box_outputs, class_outputs, rpn_rois, inputs['image_info'][:, 1:2, :])
model_outputs.update({
'num_detections': valid_detections,
'detection_boxes': boxes,
'detection_classes': classes,
'detection_scores': scores,
})
if not self._include_mask:
return model_outputs
if is_training:
rpn_rois, classes, mask_targets = self._sample_masks_fn(
rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices,
inputs['gt_masks'])
mask_targets = tf.stop_gradient(mask_targets)
classes = tf.cast(classes, dtype=tf.int32)
model_outputs.update({
'mask_targets': mask_targets,
'sampled_class_targets': classes,
})
else:
rpn_rois = boxes
classes = tf.cast(classes, dtype=tf.int32)
mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
fpn_features, rpn_rois, output_size=14)
mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes, is_training)
if is_training:
model_outputs.update({
'mask_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
mask_outputs),
})
else:
model_outputs.update({'detection_masks': tf.nn.sigmoid(mask_outputs)})
return model_outputs
def build_loss_fn(self):
if self._keras_model is None:
raise ValueError('build_loss_fn() must be called after build_model().')
filter_fn = self.make_filter_trainable_variables_fn()
trainable_variables = filter_fn(self._keras_model.trainable_variables)
def _total_loss_fn(labels, outputs):
rpn_score_loss = self._rpn_score_loss_fn(outputs['rpn_score_outputs'],
labels['rpn_score_targets'])
rpn_box_loss = self._rpn_box_loss_fn(outputs['rpn_box_outputs'],
labels['rpn_box_targets'])
frcnn_class_loss = self._frcnn_class_loss_fn(outputs['class_outputs'],
outputs['class_targets'])
frcnn_box_loss = self._frcnn_box_loss_fn(outputs['box_outputs'],
outputs['class_targets'],
outputs['box_targets'])
if self._include_mask:
mask_loss = self._mask_loss_fn(outputs['mask_outputs'],
outputs['mask_targets'],
outputs['sampled_class_targets'])
else:
mask_loss = 0.0
model_loss = (
rpn_score_loss + rpn_box_loss + frcnn_class_loss + frcnn_box_loss +
mask_loss)
l2_regularization_loss = self.weight_decay_loss(trainable_variables)
total_loss = model_loss + l2_regularization_loss
return {
'total_loss': total_loss,
'loss': total_loss,
'fast_rcnn_class_loss': frcnn_class_loss,
'fast_rcnn_box_loss': frcnn_box_loss,
'mask_loss': mask_loss,
'model_loss': model_loss,
'l2_regularization_loss': l2_regularization_loss,
'rpn_score_loss': rpn_score_loss,
'rpn_box_loss': rpn_box_loss,
}
return _total_loss_fn
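  # A minimal usage sketch (illustrative only; `labels` come from the input
  # pipeline and `outputs` from build_outputs()):
  #
  #   loss_fn = model.build_loss_fn()
  #   losses = loss_fn(labels, outputs)
  #   losses['total_loss']  # Scalar that drives the gradient step.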
def build_input_layers(self, params, mode):
is_training = mode == mode_keys.TRAIN
input_shape = (
params.maskrcnn_parser.output_size +
[params.maskrcnn_parser.num_channels])
if is_training:
batch_size = params.train.batch_size
input_layer = {
'image':
tf.keras.layers.Input(
shape=input_shape,
batch_size=batch_size,
name='image',
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
'image_info':
tf.keras.layers.Input(
shape=[4, 2],
batch_size=batch_size,
name='image_info',
),
'gt_boxes':
tf.keras.layers.Input(
shape=[params.maskrcnn_parser.max_num_instances, 4],
batch_size=batch_size,
name='gt_boxes'),
'gt_classes':
tf.keras.layers.Input(
shape=[params.maskrcnn_parser.max_num_instances],
batch_size=batch_size,
name='gt_classes',
dtype=tf.int64),
}
if self._include_mask:
input_layer['gt_masks'] = tf.keras.layers.Input(
shape=[
params.maskrcnn_parser.max_num_instances,
params.maskrcnn_parser.mask_crop_size,
params.maskrcnn_parser.mask_crop_size
],
batch_size=batch_size,
name='gt_masks')
else:
batch_size = params.eval.batch_size
input_layer = {
'image':
tf.keras.layers.Input(
shape=input_shape,
batch_size=batch_size,
name='image',
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
'image_info':
tf.keras.layers.Input(
shape=[4, 2],
batch_size=batch_size,
name='image_info',
),
}
return input_layer
def build_model(self, params, mode):
if self._keras_model is None:
input_layers = self.build_input_layers(self._params, mode)
outputs = self.model_outputs(input_layers, mode)
model = tf.keras.models.Model(
inputs=input_layers, outputs=outputs, name='maskrcnn')
      assert model is not None, 'Failed to build tf.keras.Model.'
model.optimizer = self.build_optimizer()
self._keras_model = model
return self._keras_model
def post_processing(self, labels, outputs):
required_output_fields = ['class_outputs', 'box_outputs']
for field in required_output_fields:
if field not in outputs:
raise ValueError('"%s" is missing in outputs, requried %s found %s' %
(field, required_output_fields, outputs.keys()))
predictions = {
'image_info': labels['image_info'],
'num_detections': outputs['num_detections'],
'detection_boxes': outputs['detection_boxes'],
'detection_classes': outputs['detection_classes'],
'detection_scores': outputs['detection_scores'],
}
if self._include_mask:
predictions.update({
'detection_masks': outputs['detection_masks'],
})
if 'groundtruths' in labels:
predictions['source_id'] = labels['groundtruths']['source_id']
predictions['gt_source_id'] = labels['groundtruths']['source_id']
predictions['gt_height'] = labels['groundtruths']['height']
predictions['gt_width'] = labels['groundtruths']['width']
predictions['gt_image_info'] = labels['image_info']
predictions['gt_num_detections'] = (
labels['groundtruths']['num_detections'])
predictions['gt_boxes'] = labels['groundtruths']['boxes']
predictions['gt_classes'] = labels['groundtruths']['classes']
predictions['gt_areas'] = labels['groundtruths']['areas']
predictions['gt_is_crowds'] = labels['groundtruths']['is_crowds']
return labels, predictions
def eval_metrics(self):
return eval_factory.evaluator_generator(self._params.eval)
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model defination for the Object Localization Network (OLN) Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.legacy.detection.dataloader import anchor
from official.legacy.detection.dataloader import mode_keys
from official.legacy.detection.modeling import losses
from official.legacy.detection.modeling.architecture import factory
from official.legacy.detection.modeling.maskrcnn_model import MaskrcnnModel
from official.legacy.detection.ops import postprocess_ops
from official.legacy.detection.ops import roi_ops
from official.legacy.detection.ops import spatial_transform_ops
from official.legacy.detection.ops import target_ops
from official.legacy.detection.utils import box_utils
class OlnMaskModel(MaskrcnnModel):
"""OLN-Mask model function."""
def __init__(self, params):
super(OlnMaskModel, self).__init__(params)
self._params = params
# Different heads and layers.
self._include_rpn_class = params.architecture.include_rpn_class
self._include_mask = params.architecture.include_mask
self._include_frcnn_class = params.architecture.include_frcnn_class
self._include_frcnn_box = params.architecture.include_frcnn_box
self._include_centerness = params.rpn_head.has_centerness
self._include_box_score = (params.frcnn_head.has_scoring and
params.architecture.include_frcnn_box)
self._include_mask_score = (params.mrcnn_head.has_scoring and
params.architecture.include_mask)
# Architecture generators.
self._backbone_fn = factory.backbone_generator(params)
self._fpn_fn = factory.multilevel_features_generator(params)
if self._include_centerness:
self._rpn_head_fn = factory.oln_rpn_head_generator(params)
else:
self._rpn_head_fn = factory.rpn_head_generator(params)
self._generate_rois_fn = roi_ops.OlnROIGenerator(params.roi_proposal)
self._sample_rois_fn = target_ops.ROIScoreSampler(params.roi_sampling)
self._sample_masks_fn = target_ops.MaskSampler(
params.architecture.mask_target_size,
params.mask_sampling.num_mask_samples_per_image)
if self._include_box_score:
self._frcnn_head_fn = factory.oln_box_score_head_generator(params)
else:
self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
if self._include_mask:
if self._include_mask_score:
self._mrcnn_head_fn = factory.oln_mask_score_head_generator(params)
else:
self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)
# Loss function.
self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
if self._include_centerness:
self._rpn_iou_loss_fn = losses.OlnRpnIoULoss()
self._rpn_center_loss_fn = losses.OlnRpnCenterLoss()
self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
if self._include_box_score:
self._frcnn_box_score_loss_fn = losses.OlnBoxScoreLoss(
params.frcnn_box_score_loss)
if self._include_mask:
self._mask_loss_fn = losses.MaskrcnnLoss()
self._generate_detections_fn = postprocess_ops.OlnDetectionGenerator(
params.postprocess)
self._transpose_input = params.train.transpose_input
    assert not self._transpose_input, 'Transpose input is not supported.'
def build_outputs(self, inputs, mode):
is_training = mode == mode_keys.TRAIN
model_outputs = {}
image = inputs['image']
_, image_height, image_width, _ = image.get_shape().as_list()
backbone_features = self._backbone_fn(image, is_training)
fpn_features = self._fpn_fn(backbone_features, is_training)
# rpn_centerness.
if self._include_centerness:
rpn_score_outputs, rpn_box_outputs, rpn_center_outputs = (
self._rpn_head_fn(fpn_features, is_training))
model_outputs.update({
'rpn_center_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
rpn_center_outputs),
})
object_scores = rpn_center_outputs
else:
rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(
fpn_features, is_training)
object_scores = None
model_outputs.update({
'rpn_score_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
rpn_score_outputs),
'rpn_box_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
rpn_box_outputs),
})
input_anchor = anchor.Anchor(self._params.architecture.min_level,
self._params.architecture.max_level,
self._params.anchor.num_scales,
self._params.anchor.aspect_ratios,
self._params.anchor.anchor_size,
(image_height, image_width))
rpn_rois, rpn_roi_scores = self._generate_rois_fn(
rpn_box_outputs,
rpn_score_outputs,
input_anchor.multilevel_boxes,
inputs['image_info'][:, 1, :],
is_training,
is_box_lrtb=self._include_centerness,
object_scores=object_scores,
)
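    # Note: with centerness enabled, the RPN regresses boxes in (left, right,
    # top, bottom) distance form -- hence is_box_lrtb -- and object_scores
    # lets the generator rank proposals by predicted centerness instead of
    # classification score.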
if (not self._include_frcnn_class and
not self._include_frcnn_box and
not self._include_mask):
      # For direct RPN detection, use dummy
      # box_outputs = (dy, dx, dh, dw) = (0, 0, 0, 0).
box_outputs = tf.zeros_like(rpn_rois)
box_outputs = tf.concat([box_outputs, box_outputs], -1)
boxes, scores, classes, valid_detections = self._generate_detections_fn(
box_outputs, rpn_roi_scores, rpn_rois,
inputs['image_info'][:, 1:2, :],
          is_single_fg_score=True,  # No background class, so no softmax is applied.
keep_nms=True)
model_outputs.update({
'num_detections': valid_detections,
'detection_boxes': boxes,
'detection_classes': classes,
'detection_scores': scores,
})
return model_outputs
# ---- OLN-Proposal finishes here. ----
if is_training:
rpn_rois = tf.stop_gradient(rpn_rois)
rpn_roi_scores = tf.stop_gradient(rpn_roi_scores)
# Sample proposals.
(rpn_rois, rpn_roi_scores, matched_gt_boxes, matched_gt_classes,
matched_gt_indices) = (
self._sample_rois_fn(rpn_rois, rpn_roi_scores, inputs['gt_boxes'],
inputs['gt_classes']))
# Create bounding box training targets.
box_targets = box_utils.encode_boxes(
matched_gt_boxes, rpn_rois, weights=[10.0, 10.0, 5.0, 5.0])
# If the target is background, the box target is set to all 0s.
box_targets = tf.where(
tf.tile(
tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
[1, 1, 4]), tf.zeros_like(box_targets), box_targets)
model_outputs.update({
'class_targets': matched_gt_classes,
'box_targets': box_targets,
})
      # Create box-IoU targets: for each sampled ROI, its maximum IoU with
      # any ground-truth box, which the box-score head regresses against.
      box_ious = box_utils.bbox_overlap(rpn_rois, inputs['gt_boxes'])
      matched_box_ious = tf.reduce_max(box_ious, 2)
      model_outputs.update({
          'box_iou_targets': matched_box_ious,
      })
roi_features = spatial_transform_ops.multilevel_crop_and_resize(
fpn_features, rpn_rois, output_size=7)
if not self._include_box_score:
class_outputs, box_outputs = self._frcnn_head_fn(
roi_features, is_training)
else:
class_outputs, box_outputs, score_outputs = self._frcnn_head_fn(
roi_features, is_training)
model_outputs.update({
'box_score_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
score_outputs),})
model_outputs.update({
'class_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
class_outputs),
'box_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
box_outputs),
})
# Add this output to train to make the checkpoint loadable in predict mode.
# If we skip it in train mode, the heads will be out-of-order and checkpoint
# loading will fail.
if not self._include_frcnn_box:
box_outputs = tf.zeros_like(box_outputs) # dummy zeros.
if self._include_box_score:
score_outputs = tf.cast(tf.squeeze(score_outputs, -1),
rpn_roi_scores.dtype)
      # box-score = sqrt(rpn-centerness * box-iou).
      # Train: rpn_roi_scores is [batch, 1000], score_outputs is [batch, 512].
      # Eval: rpn_roi_scores is [batch, 1000], score_outputs is [batch, 1000].
box_scores = tf.pow(
rpn_roi_scores * tf.sigmoid(score_outputs), 1/2.)
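      # Worked example: centerness 0.81 and sigmoid(score) 0.49 fuse to
      # sqrt(0.81 * 0.49) = 0.63.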
if not self._include_frcnn_class:
boxes, scores, classes, valid_detections = self._generate_detections_fn(
box_outputs,
box_scores,
rpn_rois,
inputs['image_info'][:, 1:2, :],
is_single_fg_score=True,
keep_nms=True,)
else:
boxes, scores, classes, valid_detections = self._generate_detections_fn(
box_outputs, class_outputs, rpn_rois,
inputs['image_info'][:, 1:2, :],
keep_nms=True,)
model_outputs.update({
'num_detections': valid_detections,
'detection_boxes': boxes,
'detection_classes': classes,
'detection_scores': scores,
})
# ---- OLN-Box finishes here. ----
if not self._include_mask:
return model_outputs
if is_training:
rpn_rois, classes, mask_targets = self._sample_masks_fn(
rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices,
inputs['gt_masks'])
mask_targets = tf.stop_gradient(mask_targets)
classes = tf.cast(classes, dtype=tf.int32)
model_outputs.update({
'mask_targets': mask_targets,
'sampled_class_targets': classes,
})
else:
rpn_rois = boxes
classes = tf.cast(classes, dtype=tf.int32)
mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
fpn_features, rpn_rois, output_size=14)
mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes, is_training)
if is_training:
model_outputs.update({
'mask_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
mask_outputs),
})
else:
model_outputs.update({'detection_masks': tf.nn.sigmoid(mask_outputs)})
return model_outputs
def build_loss_fn(self):
if self._keras_model is None:
raise ValueError('build_loss_fn() must be called after build_model().')
filter_fn = self.make_filter_trainable_variables_fn()
trainable_variables = filter_fn(self._keras_model.trainable_variables)
def _total_loss_fn(labels, outputs):
if self._include_rpn_class:
rpn_score_loss = self._rpn_score_loss_fn(outputs['rpn_score_outputs'],
labels['rpn_score_targets'])
else:
rpn_score_loss = 0.0
if self._include_centerness:
rpn_center_loss = self._rpn_center_loss_fn(
outputs['rpn_center_outputs'], labels['rpn_center_targets'])
rpn_box_loss = self._rpn_iou_loss_fn(
outputs['rpn_box_outputs'], labels['rpn_box_targets'],
labels['rpn_center_targets'])
else:
rpn_center_loss = 0.0
rpn_box_loss = self._rpn_box_loss_fn(
outputs['rpn_box_outputs'], labels['rpn_box_targets'])
if self._include_frcnn_class:
frcnn_class_loss = self._frcnn_class_loss_fn(
outputs['class_outputs'], outputs['class_targets'])
else:
frcnn_class_loss = 0.0
if self._include_frcnn_box:
frcnn_box_loss = self._frcnn_box_loss_fn(
outputs['box_outputs'], outputs['class_targets'],
outputs['box_targets'])
else:
frcnn_box_loss = 0.0
if self._include_box_score:
box_score_loss = self._frcnn_box_score_loss_fn(
outputs['box_score_outputs'], outputs['box_iou_targets'])
else:
box_score_loss = 0.0
if self._include_mask:
mask_loss = self._mask_loss_fn(outputs['mask_outputs'],
outputs['mask_targets'],
outputs['sampled_class_targets'])
else:
mask_loss = 0.0
model_loss = (
rpn_score_loss + rpn_box_loss + rpn_center_loss +
frcnn_class_loss + frcnn_box_loss + box_score_loss +
mask_loss)
l2_regularization_loss = self.weight_decay_loss(trainable_variables)
total_loss = model_loss + l2_regularization_loss
return {
'total_loss': total_loss,
'loss': total_loss,
'fast_rcnn_class_loss': frcnn_class_loss,
'fast_rcnn_box_loss': frcnn_box_loss,
'fast_rcnn_box_score_loss': box_score_loss,
'mask_loss': mask_loss,
'model_loss': model_loss,
'l2_regularization_loss': l2_regularization_loss,
'rpn_score_loss': rpn_score_loss,
'rpn_box_loss': rpn_box_loss,
'rpn_center_loss': rpn_center_loss,
}
return _total_loss_fn
def build_input_layers(self, params, mode):
is_training = mode == mode_keys.TRAIN
input_shape = (
params.olnmask_parser.output_size +
[params.olnmask_parser.num_channels])
if is_training:
batch_size = params.train.batch_size
input_layer = {
'image':
tf.keras.layers.Input(
shape=input_shape,
batch_size=batch_size,
name='image',
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
'image_info':
tf.keras.layers.Input(
shape=[4, 2],
batch_size=batch_size,
name='image_info',
),
'gt_boxes':
tf.keras.layers.Input(
shape=[params.olnmask_parser.max_num_instances, 4],
batch_size=batch_size,
name='gt_boxes'),
'gt_classes':
tf.keras.layers.Input(
shape=[params.olnmask_parser.max_num_instances],
batch_size=batch_size,
name='gt_classes',
dtype=tf.int64),
}
if self._include_mask:
input_layer['gt_masks'] = tf.keras.layers.Input(
shape=[
params.olnmask_parser.max_num_instances,
params.olnmask_parser.mask_crop_size,
params.olnmask_parser.mask_crop_size
],
batch_size=batch_size,
name='gt_masks')
else:
batch_size = params.eval.batch_size
input_layer = {
'image':
tf.keras.layers.Input(
shape=input_shape,
batch_size=batch_size,
name='image',
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
'image_info':
tf.keras.layers.Input(
shape=[4, 2],
batch_size=batch_size,
name='image_info',
),
}
return input_layer
def build_model(self, params, mode):
if self._keras_model is None:
input_layers = self.build_input_layers(self._params, mode)
outputs = self.model_outputs(input_layers, mode)
model = tf.keras.models.Model(
inputs=input_layers, outputs=outputs, name='olnmask')
      assert model is not None, 'Failed to build tf.keras.Model.'
model.optimizer = self.build_optimizer()
self._keras_model = model
return self._keras_model
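# A minimal end-to-end sketch (illustrative only; assumes a `params`
# ParamsDict carrying the architecture, head, train, and eval fields read
# above):
#
#   model = OlnMaskModel(params)
#   keras_model = model.build_model(params, mode_keys.TRAIN)
#   loss_fn = model.build_loss_fn()
#   metrics = model.eval_metrics()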
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
class OptimizerFactory(object):
"""Class to generate optimizer function."""
def __init__(self, params):
"""Creates optimized based on the specified flags."""
if params.type == 'momentum':
self._optimizer = functools.partial(
tf.keras.optimizers.SGD,
momentum=params.momentum,
nesterov=params.nesterov)
elif params.type == 'adam':
self._optimizer = tf.keras.optimizers.Adam
elif params.type == 'adadelta':
self._optimizer = tf.keras.optimizers.Adadelta
elif params.type == 'adagrad':
self._optimizer = tf.keras.optimizers.Adagrad
elif params.type == 'rmsprop':
self._optimizer = functools.partial(
tf.keras.optimizers.RMSprop, momentum=params.momentum)
else:
raise ValueError('Unsupported optimizer type `{}`.'.format(params.type))
def __call__(self, learning_rate):
return self._optimizer(learning_rate=learning_rate)
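# Example (illustrative; `params` is any object with a `type` attribute and,
# for the momentum/rmsprop cases, `momentum`/`nesterov` attributes):
#
#   opt_fn = OptimizerFactory(params)       # e.g. params.type == 'momentum'
#   optimizer = opt_fn(learning_rate=0.08)
#   # -> tf.keras.optimizers.SGD(learning_rate=0.08, momentum=params.momentum,
#   #                            nesterov=params.nesterov)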