Commit bcb231f0 authored by Yeqing Li, committed by A. Unique TensorFlower

Move retinanet keras model to tensorflow_models/official

PiperOrigin-RevId: 274010788
parent 04ce9636
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom training loop for running TensorFlow 2.0 models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import json
import os
from absl import flags
from absl import logging
import tensorflow as tf
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from typing import Optional, Dict, Text, Callable, Union, Iterator, Any
from official.modeling.hyperparams import params_dict
from official.utils.misc import tpu_lib
FLAGS = flags.FLAGS
# TODO(yeqing): Move all the flags out of this file.
def define_common_hparams_flags():
"""Define the common flags across models."""
flags.DEFINE_string(
'model_dir',
default=None,
help=('The directory where the model and training/evaluation summaries '
'are stored.'))
flags.DEFINE_integer(
'train_batch_size', default=None, help='Batch size for training.')
flags.DEFINE_integer(
'eval_batch_size', default=None, help='Batch size for evaluation.')
flags.DEFINE_string(
'precision',
default=None,
help=('Precision to use; one of: {bfloat16, float32}'))
flags.DEFINE_string(
'config_file',
default=None,
help=('A YAML file which specifies overrides. Note that this file can be '
'used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, the one in '
'`--params_override` will be used finally.'))
flags.DEFINE_string(
'params_override',
default=None,
help=('a YAML/JSON string or a YAML file which specifies additional '
'overrides over the default parameters and those specified in '
'`--config_file`. Note that this is supposed to be used only to '
'override the model parameters, but not the parameters like TPU '
'specific flags. One canonical use case of `--config_file` and '
'`--params_override` is users first define a template config file '
'using `--config_file`, then use `--params_override` to adjust the '
'minimal set of tuning parameters, for example setting up different'
' `train_batch_size`. '
'The final override order of parameters: default_model_params --> '
'params from config_file --> params in params_override.'
'See also the help message of `--config_file`.'))
flags.DEFINE_string(
'strategy_type', 'mirrored', 'Type of distribution strategy. '
'One of mirrored, tpu and multi_worker_mirrored.')
def initialize_common_flags():
"""Define flags that are common across models and distribution strategies."""
define_common_hparams_flags()
flags.DEFINE_string(
'tpu',
default=None,
help='The Cloud TPU to use for training. This should be either the name '
'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
'url.')
# Parameters for MultiWorkerMirroredStrategy
flags.DEFINE_string(
'worker_hosts',
default=None,
help='Comma-separated list of worker ip:port pairs for running '
'multi-worker models with distribution strategy. The user would '
'start the program on each host with identical value for this flag.')
flags.DEFINE_integer(
'task_index', 0,
'If multi-worker training, the task_index of this worker.')
def strategy_flags_dict():
"""Returns TPU related flags in a dictionary."""
return {
# TPUStrategy related flags.
'tpu': FLAGS.tpu,
# MultiWorkerMirroredStrategy related flags.
'worker_hosts': FLAGS.worker_hosts,
'task_index': FLAGS.task_index,
}
def hparam_flags_dict():
"""Returns model params related flags in a dictionary."""
return {
'data_dir': FLAGS.data_dir,
'model_dir': FLAGS.model_dir,
'train_batch_size': FLAGS.train_batch_size,
'eval_batch_size': FLAGS.eval_batch_size,
'precision': FLAGS.precision,
'config_file': FLAGS.config_file,
'params_override': FLAGS.params_override,
}
def primary_cpu_task(use_remote_tpu=False):
"""Returns primary CPU task to which input pipeline Ops are put."""
# Remote Eager Borg job configures the TPU worker with job name 'worker'.
return '/job:worker' if use_remote_tpu else ''
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
"""Saves model to model_dir with provided checkpoint prefix."""
checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
saved_path = checkpoint.save(checkpoint_path)
logging.info('Saving model as TF checkpoint: %s', saved_path)
def _no_metric():
return None
class SummaryWriter(object):
"""Simple SummaryWriter for writing dictionary of metrics.
Attributes:
_writer: The tf.summary.SummaryWriter instance.
"""
def __init__(self, model_dir: Text, name: Text):
"""Inits SummaryWriter with paths.
Arguments:
model_dir: the model folder path.
name: the summary subfolder name.
"""
self._writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
def __call__(self, metrics: Union[Dict[Text, float], float], step: int):
"""Write metrics to summary with the given writer.
Args:
metrics: a dictionary of metric values, or a single scalar value.
step: integer. The training step.
"""
if not isinstance(metrics, dict):
# Support scalar metric without name.
logging.warning('Warning: summary writer prefers metrics as a dictionary.')
metrics = {'metric': metrics}
with self._writer.as_default():
for k, v in metrics.items():
tf.summary.scalar(k, v, step=step)
self._writer.flush()
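# Illustrative usage of SummaryWriter (a sketch; the directory and metric
# values below are hypothetical):
#
#   writer = SummaryWriter(model_dir='/tmp/retinanet_model', name='eval_train')
#   writer({'total_loss': 0.71, 'learning_rate': 0.08}, step=500)
#
# Each key in the dictionary becomes a scalar summary under
# <model_dir>/eval_train and can be inspected with TensorBoard.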
class DistributedExecutor(object):
"""Interface to train and eval models with tf.distribute.Strategy.
Arguments:
strategy: an instance of tf.distribute.Strategy.
params: Model configuration needed to run distribution strategy.
model_fn: Keras model function. Signature:
(params: ParamsDict) -> tf.keras.models.Model.
loss_fn: loss function. Signature:
(y_true: Tensor, y_pred: Tensor) -> Tensor
metric_fn: metric function. Signature: () -> tf.keras.metrics.Metric.
is_multi_host: Set to True when using multiple hosts for training, such as
multi-worker GPU or TPU pod (slice) training. Otherwise, False.
use_remote_tpu: If True, run on remote TPU mode.
"""
def __init__(self,
strategy,
params,
model_fn,
loss_fn,
is_multi_host=False,
use_remote_tpu=False):
self._params = params
self._model_fn = model_fn
self._loss_fn = loss_fn
self._strategy = strategy
self._use_remote_tpu = use_remote_tpu
self._checkpoint_name = 'ctl_step_{step}.ckpt'
self._is_multi_host = is_multi_host
@property
def checkpoint_name(self):
"""Returns default checkpoint name."""
return self._checkpoint_name
@checkpoint_name.setter
def checkpoint_name(self, name):
"""Sets default checkpoint name."""
self._checkpoint_name = name
def loss_fn(self):
return self._loss_fn()
def model_fn(self, params):
return self._model_fn(params)
def _save_config(self, model_dir):
"""Save parameters to config files if model_dir is defined."""
logging.info('Save config to model_dir %s.', model_dir)
if model_dir:
if not tf.io.gfile.exists(model_dir):
tf.io.gfile.makedirs(model_dir)
self._params.lock()
params_dict.save_params_dict_to_yaml(self._params,
model_dir + '/params.yaml')
else:
logging.warning('model_dir is empty, so skip saving the config.')
def _get_input_iterator(
self, input_fn: Callable[[Optional[params_dict.ParamsDict]],
tf.data.Dataset],
strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
"""Returns distributed dataset iterator.
Args:
input_fn: (params: dict) -> tf.data.Dataset.
strategy: an instance of tf.distribute.Strategy.
Returns:
An iterator that yields input tensors.
"""
if input_fn is None:
return None
# When training with multiple TPU workers, the dataset needs to be cloned
# across workers. Since a Dataset instance cannot be cloned in eager mode,
# we instead pass a callable that returns a dataset.
if self._is_multi_host:
return iter(
strategy.experimental_distribute_datasets_from_function(input_fn))
else:
input_data = input_fn(self._params)
return iter(strategy.experimental_distribute_dataset(input_data))
# TODO(yeqing): Extract the train_step out as a class for re-usability.
def _create_train_step(self,
strategy,
model,
loss_fn,
optimizer,
metric=None):
"""Creates a distributed training step.
Args:
strategy: an instance of tf.distribute.Strategy.
model: (Tensor, bool) -> Tensor. model function.
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
optimizer: tf.keras.optimizers.Optimizer.
metric: tf.keras.metrics.Metric subclass.
Returns:
The training step callable.
"""
@tf.function
def train_step(iterator):
"""Performs a distributed training step.
Args:
iterator: an iterator that yields input tensors.
Returns:
The loss tensor.
"""
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
prediction_loss = loss_fn(labels, outputs)
loss = tf.reduce_mean(prediction_loss)
loss = loss / strategy.num_replicas_in_sync
if isinstance(metric, tf.keras.metrics.Metric):
metric.update_state(labels, outputs)
else:
logging.error('train metric is not an instance of '
'tf.keras.metrics.Metric.')
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
per_replica_losses = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
# For reporting, we return the mean of losses.
loss = strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return loss
return train_step
def _create_test_step(self, strategy, model, metric):
"""Creates a distributed test step."""
@tf.function
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
if not metric:
logging.info('Skip test_step because metric is None (%s)', metric)
return None, None
if not isinstance(metric, tf.keras.metrics.Metric):
raise ValueError(
'Metric must be an instance of tf.keras.metrics.Metric '
'for running in test_step. Actual {}'.format(metric))
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
metric.update_state(labels, model_outputs)
return labels, model_outputs
return strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
return test_step
def train(self,
train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_input_fn: Callable[[params_dict.ParamsDict],
tf.data.Dataset] = None,
model_dir: Text = None,
total_steps: int = 1,
iterations_per_loop: int = 1,
train_metric_fn: Callable[[], Any] = None,
eval_metric_fn: Callable[[], Any] = None,
summary_writer_fn: Callable[[Text, Text],
SummaryWriter] = SummaryWriter,
init_checkpoint: Callable[[tf.keras.Model], Any] = None,
save_config: bool = True):
"""Runs distributed training.
Args:
train_input_fn: (params: dict) -> tf.data.Dataset training data input
function.
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluating metrics on eval data. If None, will not run eval step.
model_dir: the folder path for model checkpoints.
total_steps: total training steps.
iterations_per_loop: train steps per loop. After each loop, this job will
update metrics like loss and save checkpoint.
train_metric_fn: metric_fn for evaluation in train_step.
eval_metric_fn: metric_fn for evaluation in test_step.
summary_writer_fn: function to create summary writer.
init_checkpoint: function to load checkpoint.
save_config: bool. Whether to save params to model_dir.
Returns:
The trained keras model.
"""
assert train_input_fn is not None
if train_metric_fn and not callable(train_metric_fn):
raise ValueError('if `train_metric_fn` is specified, '
'train_metric_fn must be a callable.')
if eval_metric_fn and not callable(eval_metric_fn):
raise ValueError('if `eval_metric_fn` is specified, '
'eval_metric_fn must be a callable.')
train_metric_fn = train_metric_fn or _no_metric
eval_metric_fn = eval_metric_fn or _no_metric
if save_config:
self._save_config(model_dir)
params = self._params
strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operations, we place
# input pipeline ops in the worker task.
train_iterator = self._get_input_iterator(train_input_fn, strategy)
with strategy.scope():
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model = self.model_fn(params.as_dict())
if not hasattr(model, 'optimizer'):
raise ValueError('`model_fn` should set the `optimizer` attribute on the '
'model it returns.')
optimizer = model.optimizer
# Training loop starts here.
# TODO(yeqing): Implementing checkpoints with Callbacks.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
initial_step = 0
if latest_checkpoint_file:
logging.info(
'Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file)
initial_step = optimizer.iterations.numpy()
logging.info('Loading from checkpoint file completed. Init step %d',
initial_step)
elif init_checkpoint:
logging.info('Restoring from init checkpoint function')
init_checkpoint(model)
logging.info('Loading from init checkpoint file completed')
current_step = optimizer.iterations.numpy()
checkpoint_name = self.checkpoint_name
eval_metric = eval_metric_fn()
train_metric = train_metric_fn()
train_summary_writer = summary_writer_fn(model_dir, 'eval_train')
test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
# Continue training loop.
train_step = self._create_train_step(
strategy, model, self.loss_fn(), optimizer, metric=train_metric)
test_step = None
if eval_input_fn and eval_metric:
test_step = self._create_test_step(strategy, model, metric=eval_metric)
logging.info('Training started')
for step in range(initial_step, total_steps):
current_step = step + 1
train_loss = train_step(train_iterator)
if current_step % iterations_per_loop != 0:
# Only report metrics and save checkpoints every `iterations_per_loop` steps.
continue
train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
train_loss)
if not isinstance(train_loss, dict):
train_loss = {'total_loss': train_loss}
if train_metric:
train_metric_result = train_metric.result()
if isinstance(train_metric, tf.keras.metrics.Metric):
train_metric_result = tf.nest.map_structure(
lambda x: x.numpy().astype(float), train_metric_result)
if not isinstance(train_metric_result, dict):
train_metric_result = {'metric': train_metric_result}
train_metric_result.update(train_loss)
else:
train_metric_result = train_loss
if callable(optimizer.lr):
train_metric_result.update(
{'learning_rate': optimizer.lr(current_step).numpy()})
else:
train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
current_step, total_steps, train_loss,
train_metric_result)
train_summary_writer(
metrics=train_metric_result, step=optimizer.iterations)
# Saves model checkpoints and runs validation steps at every
# `iterations_per_loop` boundary. To avoid repeated model saving, we do
# not save after the last step of training.
if current_step < total_steps:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Step: %s evaluation metric = %s.', current_step,
eval_metric_result)
test_summary_writer(
metrics=eval_metric_result, step=optimizer.iterations)
# Re-initialize the evaluation metric, except at the last step.
if eval_metric and current_step < total_steps:
eval_metric.reset_states()
if train_metric and current_step < total_steps:
train_metric.reset_states()
# Reaches the end of training and saves the last checkpoint.
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_step:
logging.info('Running final evaluation after training is complete.')
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Final evaluation metric = %s.', eval_metric_result)
test_summary_writer(
metrics=eval_metric_result, step=optimizer.iterations)
return model
def _run_evaluation(self, test_step, current_training_step, metric,
test_iterator):
"""Runs validation steps and aggregate metrics."""
if not test_iterator or not metric:
logging.warning(
'Both test_iterator (%s) and metrics (%s) must not be None.',
test_iterator, metric)
return None
logging.info('Running evaluation after step: %s.', current_training_step)
while True:
try:
test_step(test_iterator)
except (StopIteration, tf.errors.OutOfRangeError):
break
metric_result = metric.result()
if isinstance(metric, tf.keras.metrics.Metric):
metric_result = metric_result.numpy().astype(float)
logging.info('Step: [%d] Validation metric = %f', current_training_step,
metric_result)
return metric_result
def evaluate_from_model_dir(
self,
model_dir: Text,
eval_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_metric_fn: Callable[[], Any],
total_steps: int = -1,
eval_timeout: int = None,
min_eval_interval: int = 180,
summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter):
"""Runs distributed evaluation on model folder.
Args:
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluating metrics on eval data. If None, will not run eval step.
eval_metric_fn: metric_fn for evaluation in test_step.
model_dir: the folder for storing model checkpoints.
total_steps: total training steps. If the current step reaches the
total_steps, the evaluation loop will stop.
eval_timeout: The maximum number of seconds to wait between checkpoints.
If left as None, then the process will wait indefinitely. Used by
tf.train.checkpoints_iterator.
min_eval_interval: The minimum number of seconds between yielding
checkpoints. Used by tf.train.checkpoints_iterator.
summary_writer_fn: function to create summary writer.
Returns:
Eval metrics dictionary of the last checkpoint.
"""
if not model_dir:
raise ValueError('model_dir must be set.')
def terminate_eval():
logging.info('Terminating eval after %d seconds of no checkpoints',
eval_timeout)
return True
summary_writer = summary_writer_fn(model_dir, 'eval')
# Read checkpoints from the given model directory
# until `eval_timeout` seconds elapses.
for checkpoint_path in tf.train.checkpoints_iterator(
model_dir,
min_interval_secs=min_eval_interval,
timeout=eval_timeout,
timeout_fn=terminate_eval):
eval_metric_result, current_step = self.evaluate_checkpoint(
checkpoint_path=checkpoint_path,
eval_input_fn=eval_input_fn,
eval_metric_fn=eval_metric_fn,
summary_writer=summary_writer)
if total_steps > 0 and current_step >= total_steps:
logging.info('Evaluation finished after training step %d', current_step)
break
return eval_metric_result
def evaluate_checkpoint(self,
checkpoint_path: Text,
eval_input_fn: Callable[[params_dict.ParamsDict],
tf.data.Dataset],
eval_metric_fn: Callable[[], Any],
summary_writer: SummaryWriter = None):
"""Runs distributed evaluation on the one checkpoint.
Args:
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluating metrics on eval data. If None, will not run eval step.
eval_metric_fn: metric_fn for evaluation in test_step.
checkpoint_path: the checkpoint to evaluate.
summary_writer: a SummaryWriter object used to record eval metrics.
Returns:
Eval metrics dictionary of the last checkpoint.
"""
if not callable(eval_metric_fn):
raise ValueError('if `eval_metric_fn` is specified, '
'eval_metric_fn must be a callable.')
params = self._params
strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operations, we place
# input pipeline ops in the worker task.
with strategy.scope():
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model = self.model_fn(params.as_dict())
checkpoint = tf.train.Checkpoint(model=model)
eval_metric = eval_metric_fn()
assert eval_metric, 'eval_metric does not exist'
test_step = self._create_test_step(strategy, model, metric=eval_metric)
logging.info('Starting to evaluate.')
if not checkpoint_path:
raise ValueError('checkpoint path is empty')
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
current_step = reader.get_tensor(
'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
logging.info(
'Checkpoint file %s found and restoring from '
'checkpoint', checkpoint_path)
checkpoint.restore(checkpoint_path)
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Step: %s evaluation metric = %s.', current_step,
eval_metric_result)
summary_writer(metrics=eval_metric_result, step=current_step)
eval_metric.reset_states()
return eval_metric_result, current_step
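# Illustrative sketch of continuous evaluation against a training job
# (hypothetical `my_eval_input_fn` / `my_eval_metric_fn`; `executor` is a
# DistributedExecutor built as in the ExecutorBuilder examples below, and
# `params` is assumed to come from the detection config factory):
#
#   executor.evaluate_from_model_dir(
#       model_dir=FLAGS.model_dir,
#       eval_input_fn=my_eval_input_fn,
#       eval_metric_fn=my_eval_metric_fn,
#       total_steps=params.train.total_steps)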
def predict(self):
raise NotImplementedError('Unimplemented function.')
# TODO(yeqing): Add unit test for MultiWorkerMirroredStrategy.
class ExecutorBuilder(object):
"""Builder of DistributedExecutor.
Example 1: Builds an executor with supported Strategy.
builder = ExecutorBuilder(
strategy_type='tpu',
strategy_config={'tpu': '/bns/xxx'})
dist_executor = builder.build_executor(
params=params,
model_fn=my_model_fn,
loss_fn=my_loss_fn,
metric_fn=my_metric_fn)
Example 2: Builds an executor with customized Strategy.
builder = ExecutorBuilder()
builder.strategy = <some customized Strategy>
dist_executor = builder.build_executor(
params=params,
model_fn=my_model_fn,
loss_fn=my_loss_fn,
metric_fn=my_metric_fn)
Example 3: Builds a customized executor with customized Strategy.
class MyDistributedExecutor(DistributedExecutor):
# implementation ...
builder = ExecutorBuilder()
builder.strategy = <some customized Strategy>
dist_executor = builder.build_executor(
class_ctor=MyDistributedExecutor,
params=params,
model_fn=my_model_fn,
loss_fn=my_loss_fn,
metric_fn=my_metric_fn)
Args:
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'. If
None, the user is responsible for setting the strategy before calling
build_executor(...).
strategy_config: necessary config for constructing the proper Strategy.
Check strategy_flags_dict() for examples of the structure.
"""
def __init__(self, strategy_type=None, strategy_config=None):
self._strategy_config = strategy_config
self._strategy = self._build_strategy(strategy_type)
@property
def strategy(self):
"""Returns the distribution strategy."""
return self._strategy
@strategy.setter
def strategy(self, new_strategy):
"""Sets the distribution strategy."""
self._strategy = new_strategy
def _build_strategy(self, strategy_type):
"""Builds tf.distribute.Strategy instance.
Args:
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
Returns:
A tf.distribute.Strategy object. Returns None if strategy_type is None.
"""
if strategy_type is None:
return None
if strategy_type == 'tpu':
return self._build_tpu_strategy()
elif strategy_type == 'mirrored':
return self._build_mirrored_strategy()
elif strategy_type == 'multi_worker_mirrored':
return self._build_multiworker_mirrored_strategy()
else:
raise NotImplementedError('Unsupported accelerator type "%s"' %
strategy_type)
def _build_mirrored_strategy(self):
"""Builds a MirroredStrategy object."""
return tf.distribute.MirroredStrategy()
def _build_tpu_strategy(self):
"""Builds a TPUStrategy object."""
tpu = self._strategy_config.tpu
logging.info('Use TPU at %s', tpu if tpu is not None else '')
cluster_resolver = tpu_lib.tpu_initialize(tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
return strategy
def _build_multiworker_mirrored_strategy(self):
"""Builds a MultiWorkerMirroredStrategy object."""
worker_hosts = self._strategy_config.worker_hosts
if worker_hosts is not None:
# Set TF_CONFIG environment variable
worker_hosts = worker_hosts.split(',')
task_index = self._strategy_config.task_index
os.environ['TF_CONFIG'] = json.dumps({
'cluster': {
'worker': worker_hosts
},
'task': {
'type': 'worker',
'index': task_index
}
})
multiworker_strategy = (
tf.distribute.experimental.MultiWorkerMirroredStrategy())
return multiworker_strategy
def build_executor(self,
class_ctor=DistributedExecutor,
params=None,
model_fn=None,
loss_fn=None,
**kwargs):
"""Creates an executor according to strategy type.
See the docstring of DistributedExecutor.__init__ for more information on the
input arguments.
Args:
class_ctor: A constructor of executor (default: DistributedExecutor).
params: ParamsDict, all the model parameters and runtime parameters.
model_fn: Keras model function.
loss_fn: loss function.
**kwargs: other arguments to the executor constructor.
Returns:
An instance of DistributedExecutor or its subclass.
"""
if self._strategy is None:
raise ValueError('`strategy` should not be None. You need to specify '
'`strategy_type` in the builder constructor or directly '
'set the `strategy` property of the builder.')
if 'use_remote_tpu' not in kwargs:
use_remote_tpu = (
isinstance(self._strategy, tf.distribute.experimental.TPUStrategy) and
bool(self._strategy_config.tpu))
kwargs['use_remote_tpu'] = use_remote_tpu
return class_ctor(
strategy=self._strategy,
params=params,
model_fn=model_fn,
loss_fn=loss_fn,
**kwargs)
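# Illustrative end-to-end sketch (a sketch only; `my_model_fn`, `my_loss_fn`
# and `my_train_input_fn` are hypothetical and must follow the signatures
# documented in DistributedExecutor, and `params` is assumed to be a ParamsDict
# such as the one returned by the detection config factory):
#
#   initialize_common_flags()
#   builder = ExecutorBuilder(
#       strategy_type=FLAGS.strategy_type,
#       strategy_config=params_dict.ParamsDict(strategy_flags_dict()))
#   executor = builder.build_executor(
#       params=params, model_fn=my_model_fn, loss_fn=my_loss_fn)
#   executor.train(
#       train_input_fn=my_train_input_fn,
#       model_dir=FLAGS.model_dir,
#       total_steps=params.train.total_steps,
#       iterations_per_loop=params.train.iterations_per_loop)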
# Object Detection Models on TensorFlow 2.0
**Note**: The repo is still under construction. More features and instructions
will be added soon.
## Prerequisite
To get started, make sure to use TensorFlow 2.1+ on Google Cloud. Also, here
are a few packages you need to install to get started:
```bash
sudo apt-get install -y python-tk && \
pip install Cython matplotlib opencv-python-headless pyyaml Pillow && \
pip install 'git+https://github.com/cocodataset/cocoapi#egg=pycocotools&subdirectory=PythonAPI'
```
Next, download the code from the TensorFlow Models GitHub repository or use the
pre-installed Google Cloud VM.
```bash
git clone https://github.com/tensorflow/models.git
```
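Depending on your environment, you may also need to make the `official` package
importable by adding the cloned repository to `PYTHONPATH` (an optional step;
adjust the path to wherever you cloned the repository):
```bash
export PYTHONPATH="$PYTHONPATH:$HOME/models"
```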
## Train RetinaNet on TPU
### Train a vanilla ResNet-50 based RetinaNet.
```bash
TPU_NAME="<your GCP TPU name>"
MODEL_DIR="<path to the directory to store model files>"
RESNET_CHECKPOINT="<path to the pre-trained Resnet-50 checkpoint>"
TRAIN_FILE_PATTERN="<path to the TFRecord training data>"
EVAL_FILE_PATTERN="<path to the TFRecord validation data>"
VAL_JSON_FILE="<path to the validation annotation JSON file>"
python ~/models/official/vision/detection/main.py \
--strategy_type=tpu \
--tpu="${TPU_NAME?}" \
--model_dir="${MODEL_DIR?}" \
--mode=train \
--params_override="{ type: retinanet, train: { checkpoint: { path: ${RESNET_CHECKPOINT?}, prefix: resnet50/ }, train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} } }"
```
### Train a custom RetinaNet using the config file.
First, create a YAML config file, e.g. *my_retinanet.yaml*. This file specifies
the parameters to be overridden, which should at least include the following
fields.
```YAML
# my_retinanet.yaml
type: 'retinanet'
train:
train_file_pattern: <path to the TFRecord training data>
eval:
eval_file_pattern: <path to the TFRecord validation data>
val_json_file: <path to the validation annotation JSON file>
```
Once the YAML config file is created, you can launch the training using the
following command.
```bash
TPU_NAME="<your GCP TPU name>"
MODEL_DIR="<path to the directory to store model files>"
python ~/models/official/vision/detection/main.py \
--strategy_type=tpu \
--tpu="${TPU_NAME?}" \
--model_dir="${MODEL_DIR?}" \
--mode=train \
--config_file="my_retinanet.yaml"
```
## Train RetinaNet on GPU
Note: Instructions are coming soon.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory to provide model configs."""
from official.vision.detection.configs import retinanet_config
from official.modeling.hyperparams import params_dict
def config_generator(model):
"""Model function generator."""
if model == 'retinanet':
default_config = retinanet_config.RETINANET_CFG
restrictions = retinanet_config.RETINANET_RESTRICTIONS
else:
raise ValueError('Model %s is not supported.' % model)
return params_dict.ParamsDict(default_config, restrictions)
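# Illustrative usage (a sketch; `my_retinanet.yaml` is a hypothetical override
# file, and `override_params_dict` is assumed to be available in params_dict):
#
#   params = config_generator('retinanet')
#   params = params_dict.override_params_dict(
#       params, 'my_retinanet.yaml', is_strict=True)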
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config template to train Retinanet."""
# pylint: disable=line-too-long
# For ResNet-50, this freezes the variables of the first conv1 and conv2_x
# layers [1], which leads to higher training speed and slightly better testing
# accuracy. The intuition is that the low-level architecture (e.g., ResNet-50)
# is able to capture low-level features such as edges; therefore, it does not
# need to be fine-tuned for the detection task.
# Note that we need the trailing `/` to avoid incorrect matches.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET50_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
# pylint: disable=line-too-long
RETINANET_CFG = {
'type': 'retinanet',
'model_dir': '',
'use_tpu': True,
'train': {
'batch_size': 64,
'iterations_per_loop': 500,
'total_steps': 22500,
'optimizer': {
'type': 'momentum',
'momentum': 0.9,
},
'learning_rate': {
'type': 'step',
'warmup_learning_rate': 0.0067,
'warmup_steps': 500,
'init_learning_rate': 0.08,
'learning_rate_levels': [0.008, 0.0008],
'learning_rate_steps': [15000, 20000],
},
'checkpoint': {
'path': '',
'prefix': '',
},
'frozen_variable_prefix': RESNET50_FROZEN_VAR_PREFIX,
'train_file_pattern': '',
# TODO(b/142174042): Support transpose_input option.
'transpose_input': False,
'l2_weight_decay': 0.0001,
},
'eval': {
'batch_size': 8,
'min_eval_interval': 180,
'eval_timeout': None,
'eval_samples': 5000,
'type': 'box',
'val_json_file': '',
'eval_file_pattern': '',
},
'predict': {
'predict_batch_size': 8,
},
'architecture': {
'parser': 'retinanet_parser',
'backbone': 'resnet',
'multilevel_features': 'fpn',
'use_bfloat16': False,
},
'anchor': {
'min_level': 3,
'max_level': 7,
'num_scales': 3,
'aspect_ratios': [1.0, 2.0, 0.5],
'anchor_size': 4.0,
},
'retinanet_parser': {
'use_bfloat16': False,
'output_size': [640, 640],
'num_channels': 3,
'match_threshold': 0.5,
'unmatched_threshold': 0.5,
'aug_rand_hflip': True,
'aug_scale_min': 1.0,
'aug_scale_max': 1.0,
'use_autoaugment': False,
'autoaugment_policy_name': 'v0',
'skip_crowd_during_training': True,
'max_num_instances': 100,
},
'resnet': {
'resnet_depth': 50,
'dropblock': {
'dropblock_keep_prob': None,
'dropblock_size': None,
},
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
},
},
'fpn': {
'min_level': 3,
'max_level': 7,
'fpn_feat_dims': 256,
'use_separable_conv': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
},
},
'nasfpn': {
'min_level': 3,
'max_level': 7,
'fpn_feat_dims': 256,
'num_repeats': 5,
'use_separable_conv': False,
'dropblock': {
'dropblock_keep_prob': None,
'dropblock_size': None,
},
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
},
},
'retinanet_head': {
'min_level': 3,
'max_level': 7,
# Note that `num_classes` is the total number of classes including
# one background class whose index is 0.
'num_classes': 91,
'anchors_per_location': 9,
'retinanet_head_num_convs': 4,
'retinanet_head_num_filters': 256,
'use_separable_conv': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
},
},
'retinanet_loss': {
'num_classes': 91,
'focal_loss_alpha': 0.25,
'focal_loss_gamma': 1.5,
'huber_loss_delta': 0.1,
'box_loss_weight': 50,
},
'postprocess': {
'use_batched_nms': False,
'min_level': 3,
'max_level': 7,
'num_classes': 91,
'max_total_size': 100,
'nms_iou_threshold': 0.5,
'score_threshold': 0.05
},
'enable_summary': False,
}
RETINANET_RESTRICTIONS = [
'architecture.use_bfloat16 == retinanet_parser.use_bfloat16',
'anchor.min_level == fpn.min_level',
'anchor.max_level == fpn.max_level',
'anchor.min_level == nasfpn.min_level',
'anchor.max_level == nasfpn.max_level',
'anchor.min_level == retinanet_head.min_level',
'anchor.max_level == retinanet_head.max_level',
'anchor.min_level == postprocess.min_level',
'anchor.max_level == postprocess.max_level',
'retinanet_head.num_classes == retinanet_loss.num_classes',
'retinanet_head.num_classes == postprocess.num_classes',
]
# pylint: enable=line-too-long
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Anchor box and labeler definition."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import tensorflow.compat.v2 as tf
from official.vision.detection.utils.object_detection import argmax_matcher
from official.vision.detection.utils.object_detection import balanced_positive_negative_sampler
from official.vision.detection.utils.object_detection import box_list
from official.vision.detection.utils.object_detection import faster_rcnn_box_coder
from official.vision.detection.utils.object_detection import region_similarity_calculator
from official.vision.detection.utils.object_detection import target_assigner
class Anchor(object):
"""Anchor class for anchor-based object detectors."""
def __init__(self,
min_level,
max_level,
num_scales,
aspect_ratios,
anchor_size,
image_size):
"""Constructs multiscale anchors.
Args:
min_level: integer number of minimum level of the output feature pyramid.
max_level: integer number of maximum level of the output feature pyramid.
num_scales: integer number representing intermediate scales added
on each level. For instance, num_scales=2 adds one additional
intermediate anchor scale, yielding scales [2^0, 2^0.5] on each level.
aspect_ratios: list of float numbers representing the aspect ratio anchors
added on each level. The number indicates the ratio of width to height.
For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each
scale level.
anchor_size: float number representing the scale of size of the base
anchor to the feature stride 2^level.
image_size: a list of integer numbers or Tensors representing
[height, width] of the input image size. The image_size should be
divisible by the largest feature stride 2^max_level.
"""
self.min_level = min_level
self.max_level = max_level
self.num_scales = num_scales
self.aspect_ratios = aspect_ratios
self.anchor_size = anchor_size
self.image_size = image_size
self.boxes = self._generate_boxes()
def _generate_boxes(self):
"""Generates multiscale anchor boxes.
Returns:
a Tensor of shape [N, 4], representing anchor boxes of all levels
concatenated together.
"""
boxes_all = []
for level in range(self.min_level, self.max_level + 1):
boxes_l = []
for scale in range(self.num_scales):
for aspect_ratio in self.aspect_ratios:
stride = 2 ** level
intermediate_scale = 2 ** (scale / float(self.num_scales))
base_anchor_size = self.anchor_size * stride * intermediate_scale
aspect_x = aspect_ratio ** 0.5
aspect_y = aspect_ratio ** -0.5
half_anchor_size_x = base_anchor_size * aspect_x / 2.0
half_anchor_size_y = base_anchor_size * aspect_y / 2.0
x = tf.range(stride / 2, self.image_size[1], stride)
y = tf.range(stride / 2, self.image_size[0], stride)
xv, yv = tf.meshgrid(x, y)
xv = tf.cast(tf.reshape(xv, [-1]), dtype=tf.float32)
yv = tf.cast(tf.reshape(yv, [-1]), dtype=tf.float32)
# Tensor shape Nx4.
boxes = tf.stack([yv - half_anchor_size_y, xv - half_anchor_size_x,
yv + half_anchor_size_y, xv + half_anchor_size_x],
axis=1)
boxes_l.append(boxes)
# Concat anchors on the same level to tensor shape NxAx4.
boxes_l = tf.stack(boxes_l, axis=1)
boxes_l = tf.reshape(boxes_l, [-1, 4])
boxes_all.append(boxes_l)
return tf.concat(boxes_all, axis=0)
def unpack_labels(self, labels):
"""Unpacks an array of labels into multiscales labels."""
unpacked_labels = collections.OrderedDict()
count = 0
for level in range(self.min_level, self.max_level + 1):
feat_size_y = tf.cast(self.image_size[0] / 2 ** level, tf.int32)
feat_size_x = tf.cast(self.image_size[1] / 2 ** level, tf.int32)
steps = feat_size_y * feat_size_x * self.anchors_per_location
unpacked_labels[level] = tf.reshape(
labels[count:count + steps], [feat_size_y, feat_size_x, -1])
count += steps
return unpacked_labels
@property
def anchors_per_location(self):
return self.num_scales * len(self.aspect_ratios)
@property
def multilevel_boxes(self):
return self.unpack_labels(self.boxes)
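# Worked example (a sketch based on the default RetinaNet config): for a
# 640x640 input with min_level=3, max_level=7, num_scales=3,
# aspect_ratios=[1.0, 2.0, 0.5] and anchor_size=4.0, anchors_per_location is
# 3 * 3 = 9 and the per-level feature maps are 80x80, 40x40, 20x20, 10x10 and
# 5x5, so `Anchor(3, 7, 3, [1.0, 2.0, 0.5], 4.0, [640, 640]).boxes` has
# (80*80 + 40*40 + 20*20 + 10*10 + 5*5) * 9 = 76725 rows of [y0, x0, y1, x1].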
class AnchorLabeler(object):
"""Labeler for dense object detector."""
def __init__(self,
anchor,
match_threshold=0.5,
unmatched_threshold=0.5):
"""Constructs anchor labeler to assign labels to anchors.
Args:
anchor: an instance of class Anchors.
match_threshold: a float number between 0 and 1 representing the
lower-bound threshold to assign positive labels for anchors. An anchor
with a score over the threshold is labeled positive.
unmatched_threshold: a float number between 0 and 1 representing the
upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative.
"""
similarity_calc = region_similarity_calculator.IouSimilarity()
matcher = argmax_matcher.ArgMaxMatcher(
match_threshold,
unmatched_threshold=unmatched_threshold,
negatives_lower_than_unmatched=True,
force_match_for_each_row=True)
box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
self._target_assigner = target_assigner.TargetAssigner(
similarity_calc, matcher, box_coder)
self._anchor = anchor
self._match_threshold = match_threshold
self._unmatched_threshold = unmatched_threshold
def label_anchors(self, gt_boxes, gt_labels):
"""Labels anchors with ground truth inputs.
Args:
gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
For each row, it stores [y0, x0, y1, x1] for four corners of a box.
gt_labels: An integer tensor with shape [N, 1] representing groundtruth
classes.
Returns:
cls_targets_dict: ordered dictionary with keys
[min_level, min_level+1, ..., max_level]. The values are tensors with
shape [height_l, width_l, num_anchors_per_location]. The height_l and
width_l represent the dimension of class logits at l-th level.
box_targets_dict: ordered dictionary with keys
[min_level, min_level+1, ..., max_level]. The values are tensors with
shape [height_l, width_l, num_anchors_per_location * 4]. The height_l
and width_l represent the dimension of bounding box regression output at
l-th level.
num_positives: scalar tensor storing number of positives in an image.
"""
gt_box_list = box_list.BoxList(gt_boxes)
anchor_box_list = box_list.BoxList(self._anchor.boxes)
# The cls_weights, box_weights are not used.
cls_targets, _, box_targets, _, matches = self._target_assigner.assign(
anchor_box_list, gt_box_list, gt_labels)
# Labels definition in matches.match_results:
# (1) match_results[i]>=0, meaning that column i is matched with row
# match_results[i].
# (2) match_results[i]=-1, meaning that column i is not matched.
# (3) match_results[i]=-2, meaning that column i is ignored.
match_results = tf.expand_dims(matches.match_results, axis=1)
cls_targets = tf.cast(cls_targets, tf.int32)
cls_targets = tf.where(
tf.equal(match_results, -1), -tf.ones_like(cls_targets), cls_targets)
cls_targets = tf.where(
tf.equal(match_results, -2), -2 * tf.ones_like(cls_targets),
cls_targets)
# Unpacks labels into multi-level representations.
cls_targets_dict = self._anchor.unpack_labels(cls_targets)
box_targets_dict = self._anchor.unpack_labels(box_targets)
num_positives = tf.reduce_sum(
input_tensor=tf.cast(tf.greater(matches.match_results, -1), tf.float32))
return cls_targets_dict, box_targets_dict, num_positives
class RpnAnchorLabeler(AnchorLabeler):
"""Labeler for Region Proposal Network."""
def __init__(self, anchor, match_threshold=0.7,
unmatched_threshold=0.3, rpn_batch_size_per_im=256,
rpn_fg_fraction=0.5):
AnchorLabeler.__init__(self, anchor, match_threshold=0.7,
unmatched_threshold=0.3)
self._rpn_batch_size_per_im = rpn_batch_size_per_im
self._rpn_fg_fraction = rpn_fg_fraction
def _get_rpn_samples(self, match_results):
"""Computes anchor labels.
This function performs subsampling for foreground (fg) and background (bg)
anchors.
Args:
match_results: An integer tensor with shape [N] representing the
matching results of anchors. (1) match_results[i]>=0,
meaning that column i is matched with row match_results[i].
(2) match_results[i]=-1, meaning that column i is not matched.
(3) match_results[i]=-2, meaning that column i is ignored.
Returns:
score_targets: an integer tensor with a shape of [N].
(1) score_targets[i]=1, the anchor is a positive sample.
(2) score_targets[i]=0, negative. (3) score_targets[i]=-1, the anchor is
ignored (don't care).
"""
sampler = (
balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
positive_fraction=self._rpn_fg_fraction, is_static=False))
# indicator includes both positive and negative labels.
# labels includes only positive labels.
# positives = indicator & labels.
# negatives = indicator & !labels.
# ignore = !indicator.
indicator = tf.greater(match_results, -2)
labels = tf.greater(match_results, -1)
samples = sampler.subsample(
indicator, self._rpn_batch_size_per_im, labels)
positive_labels = tf.where(
tf.logical_and(samples, labels),
tf.constant(2, dtype=tf.int32, shape=match_results.shape),
tf.constant(0, dtype=tf.int32, shape=match_results.shape))
negative_labels = tf.where(
tf.logical_and(samples, tf.logical_not(labels)),
tf.constant(1, dtype=tf.int32, shape=match_results.shape),
tf.constant(0, dtype=tf.int32, shape=match_results.shape))
ignore_labels = tf.fill(match_results.shape, -1)
return (ignore_labels + positive_labels + negative_labels,
positive_labels, negative_labels)
def label_anchors(self, gt_boxes, gt_labels):
"""Labels anchors with ground truth inputs.
Args:
gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
For each row, it stores [y0, x0, y1, x1] for four corners of a box.
gt_labels: An integer tensor with shape [N, 1] representing groundtruth
classes.
Returns:
score_targets_dict: ordered dictionary with keys
[min_level, min_level+1, ..., max_level]. The values are tensors with
shape [height_l, width_l, num_anchors]. The height_l and width_l
represent the dimension of class logits at l-th level.
box_targets_dict: ordered dictionary with keys
[min_level, min_level+1, ..., max_level]. The values are tensors with
shape [height_l, width_l, num_anchors * 4]. The height_l and
width_l represent the dimension of bounding box regression output at
l-th level.
"""
gt_box_list = box_list.BoxList(gt_boxes)
anchor_box_list = box_list.BoxList(self._anchor.boxes)
# cls_targets, cls_weights, box_weights are not used.
_, _, box_targets, _, matches = self._target_assigner.assign(
anchor_box_list, gt_box_list, gt_labels)
# score_targets contains the subsampled positive and negative anchors.
score_targets, _, _ = self._get_rpn_samples(matches.match_results)
# Unpacks labels.
score_targets_dict = self._anchor.unpack_labels(score_targets)
box_targets_dict = self._anchor.unpack_labels(box_targets)
return score_targets_dict, box_targets_dict
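# Illustrative usage of AnchorLabeler (a sketch; `gt_boxes` and `gt_classes`
# are hypothetical tensors in the formats documented above):
#
#   anchors = Anchor(min_level=3, max_level=7, num_scales=3,
#                    aspect_ratios=[1.0, 2.0, 0.5], anchor_size=4.0,
#                    image_size=[640, 640])
#   labeler = AnchorLabeler(anchors, match_threshold=0.5,
#                           unmatched_threshold=0.5)
#   cls_targets, box_targets, num_positives = labeler.label_anchors(
#       gt_boxes, gt_classes)  # gt_boxes: [N, 4], gt_classes: [N, 1]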
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model architecture factory."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from official.vision.detection.dataloader import retinanet_parser
def parser_generator(params, mode):
"""Generator function for various dataset parser."""
if params.architecture.parser == 'retinanet_parser':
anchor_params = params.anchor
parser_params = params.retinanet_parser
parser_fn = retinanet_parser.Parser(
output_size=parser_params.output_size,
min_level=anchor_params.min_level,
max_level=anchor_params.max_level,
num_scales=anchor_params.num_scales,
aspect_ratios=anchor_params.aspect_ratios,
anchor_size=anchor_params.anchor_size,
match_threshold=parser_params.match_threshold,
unmatched_threshold=parser_params.unmatched_threshold,
aug_rand_hflip=parser_params.aug_rand_hflip,
aug_scale_min=parser_params.aug_scale_min,
aug_scale_max=parser_params.aug_scale_max,
use_autoaugment=parser_params.use_autoaugment,
autoaugment_policy_name=parser_params.autoaugment_policy_name,
skip_crowd_during_training=parser_params.skip_crowd_during_training,
max_num_instances=parser_params.max_num_instances,
use_bfloat16=parser_params.use_bfloat16,
mode=mode)
else:
raise ValueError('Parser %s is not supported.' % params.architecture.parser)
return parser_fn
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data loader and input processing."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow.compat.v2 as tf
from typing import Text, Optional
from official.modeling.hyperparams import params_dict
from official.vision.detection.dataloader import factory
from official.vision.detection.dataloader import mode_keys as ModeKeys
class InputFn(object):
"""Input function for tf.Estimator."""
def __init__(self,
file_pattern: Text,
params: params_dict.ParamsDict,
mode: Text,
batch_size: int,
num_examples: Optional[int] = -1):
"""Initialize.
Args:
file_pattern: the file pattern for the data example (TFRecords).
params: the parameter object for constructing example parser and model.
mode: ModeKeys.TRAIN or ModeKeys.EVAL.
batch_size: the data batch size.
num_examples: If positive, only takes this number of examples and raises
tf.errors.OutOfRangeError after that. If non-positive, it will be
ignored.
"""
assert file_pattern is not None
assert mode is not None
assert batch_size is not None
self._file_pattern = file_pattern
self._mode = mode
self._is_training = (mode == ModeKeys.TRAIN)
self._batch_size = batch_size
self._num_examples = num_examples
self._parser_fn = factory.parser_generator(params, mode)
self._dataset_fn = tf.data.TFRecordDataset
def __call__(self,
params: params_dict.ParamsDict = None,
batch_size=None,
ctx=None):
"""Provides tf.data.Dataset object.
Args:
params: placeholder for model parameters.
batch_size: expected batch size of the input data.
ctx: context object.
Returns:
tf.data.Dataset object.
"""
if not batch_size:
batch_size = self._batch_size
assert batch_size is not None
dataset = tf.data.Dataset.list_files(
self._file_pattern, shuffle=self._is_training)
if ctx and ctx.num_input_pipelines > 1:
dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
dataset = dataset.apply(
tf.data.experimental.parallel_interleave(
lambda file_name: self._dataset_fn(file_name).prefetch(1),
cycle_length=32,
sloppy=self._is_training))
if self._is_training:
dataset = dataset.shuffle(64)
if self._num_examples > 0:
dataset = dataset.take(self._num_examples)
# Parses the fetched records to input tensors for model function.
dataset = dataset.map(self._parser_fn, num_parallel_calls=64)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset
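# Illustrative usage (a sketch; the file pattern is hypothetical and `params`
# is assumed to come from the detection config factory):
#
#   train_input_fn = InputFn(
#       file_pattern='/data/coco/train-*.tfrecord',
#       params=params,
#       mode=ModeKeys.TRAIN,
#       batch_size=params.train.batch_size)
#   dataset = train_input_fn()  # a tf.data.Dataset of (image, labels) batches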
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Standard names for input dataloader modes.
The following standard keys are defined:
* `TRAIN`: training mode.
* `EVAL`: evaluation mode.
* `PREDICT`: prediction mode.
* `PREDICT_WITH_GT`: prediction mode with groundtruths in returned variables.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
TRAIN = 'train'
EVAL = 'eval'
PREDICT = 'predict'
PREDICT_WITH_GT = 'predict_with_gt'
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data parser and processing.
Parses images and ground truths in a dataset into training targets and packages
them into (image, labels) tuples for RetinaNet.
T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollar
Focal Loss for Dense Object Detection. arXiv:1708.02002
"""
import tensorflow.compat.v2 as tf
from official.vision.detection.dataloader import anchor
from official.vision.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.dataloader import tf_example_decoder
from official.vision.detection.utils import autoaugment_utils
from official.vision.detection.utils import box_utils
from official.vision.detection.utils import input_utils
def process_source_id(source_id):
"""Processes source_id to the right format."""
if source_id.dtype == tf.string:
source_id = tf.cast(tf.strings.to_number(source_id), tf.int32)
with tf.control_dependencies([source_id]):
source_id = tf.cond(
pred=tf.equal(tf.size(input=source_id), 0),
true_fn=lambda: tf.cast(tf.constant(-1), tf.int32),
false_fn=lambda: tf.identity(source_id))
return source_id
def pad_groundtruths_to_fixed_size(gt, n):
"""Pads the first dimension of groundtruths labels to the fixed size."""
gt['boxes'] = input_utils.pad_to_fixed_size(gt['boxes'], n, -1)
gt['is_crowds'] = input_utils.pad_to_fixed_size(gt['is_crowds'], n, 0)
gt['areas'] = input_utils.pad_to_fixed_size(gt['areas'], n, -1)
gt['classes'] = input_utils.pad_to_fixed_size(gt['classes'], n, -1)
return gt
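# Worked example (a sketch): with n = 100 (max_num_instances), a groundtruth
# dict whose 'boxes' entry has shape [3, 4] comes back with shape [100, 4],
# where the 97 padded rows are filled with -1; 'is_crowds' is padded with 0,
# and 'areas' and 'classes' are padded with -1.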
class Parser(object):
"""Parser to parse an image and its annotations into a dictionary of tensors."""
def __init__(self,
output_size,
min_level,
max_level,
num_scales,
aspect_ratios,
anchor_size,
match_threshold=0.5,
unmatched_threshold=0.5,
aug_rand_hflip=False,
aug_scale_min=1.0,
aug_scale_max=1.0,
use_autoaugment=False,
autoaugment_policy_name='v0',
skip_crowd_during_training=True,
max_num_instances=100,
use_bfloat16=True,
mode=None):
"""Initializes parameters for parsing annotations in the dataset.
Args:
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divisible by the largest feature stride 2^max_level.
min_level: `int` number of minimum level of the output feature pyramid.
max_level: `int` number of maximum level of the output feature pyramid.
num_scales: `int` number representing intermediate scales added
on each level. For instance, num_scales=2 adds one additional
intermediate anchor scale, yielding scales [2^0, 2^0.5] on each level.
aspect_ratios: `list` of float numbers representing the aspect ratio
anchors added on each level. The number indicates the ratio of width to
height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
on each scale level.
anchor_size: `float` number representing the scale of size of the base
anchor to the feature stride 2^level.
match_threshold: `float` number between 0 and 1 representing the
lower-bound threshold to assign positive labels for anchors. An anchor
with a score over the threshold is labeled positive.
unmatched_threshold: `float` number between 0 and 1 representing the
upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_scale_min: `float`, the minimum scale applied to `output_size` for
data augmentation during training.
aug_scale_max: `float`, the maximum scale applied to `output_size` for
data augmentation during training.
use_autoaugment: `bool`, if True, use the AutoAugment augmentation policy
during training.
autoaugment_policy_name: `string` that specifies the name of the
AutoAugment policy that will be used during training.
skip_crowd_during_training: `bool`, if True, skip annotations labeled with
`is_crowd` equal to 1.
max_num_instances: `int`, the maximum number of instances in an
image. The groundtruth data will be padded to `max_num_instances`.
use_bfloat16: `bool`, if True, casts the output image to tf.bfloat16.
mode: a ModeKeys. Specifies if this is training, evaluation, prediction
or prediction with groundtruths in the outputs.
"""
self._mode = mode
self._max_num_instances = max_num_instances
self._skip_crowd_during_training = skip_crowd_during_training
self._is_training = (mode == ModeKeys.TRAIN)
self._example_decoder = tf_example_decoder.TfExampleDecoder(
include_mask=False)
# Anchor.
self._output_size = output_size
self._min_level = min_level
self._max_level = max_level
self._num_scales = num_scales
self._aspect_ratios = aspect_ratios
self._anchor_size = anchor_size
self._match_threshold = match_threshold
self._unmatched_threshold = unmatched_threshold
# Data augmentation.
self._aug_rand_hflip = aug_rand_hflip
self._aug_scale_min = aug_scale_min
self._aug_scale_max = aug_scale_max
# Data Augmentation with AutoAugment.
self._use_autoaugment = use_autoaugment
self._autoaugment_policy_name = autoaugment_policy_name
# Device.
self._use_bfloat16 = use_bfloat16
# Data is parsed depending on the mode.
if mode == ModeKeys.TRAIN:
self._parse_fn = self._parse_train_data
elif mode == ModeKeys.EVAL:
self._parse_fn = self._parse_eval_data
elif mode == ModeKeys.PREDICT or mode == ModeKeys.PREDICT_WITH_GT:
self._parse_fn = self._parse_predict_data
else:
raise ValueError('mode is not defined.')
def __call__(self, value):
"""Parses data to an image and associated training labels.
Args:
value: a string tensor holding a serialized tf.Example proto.
Returns:
image: image tensor that is preprocessed to have normalized values and
shape [output_size[0], output_size[1], 3]
labels:
cls_targets: ordered dictionary with keys
[min_level, min_level+1, ..., max_level]. The values are tensors with
shape [height_l, width_l, anchors_per_location]. The height_l and
width_l represent the dimension of class logits at l-th level.
box_targets: ordered dictionary with keys
[min_level, min_level+1, ..., max_level]. The values are tensors with
shape [height_l, width_l, anchors_per_location * 4]. The height_l and
width_l represent the dimension of bounding box regression output at
l-th level.
num_positives: number of positive anchors in the image.
anchor_boxes: ordered dictionary with keys
[min_level, min_level+1, ..., max_level]. The values are tensors with
shape [height_l, width_l, 4] representing anchor boxes at each level.
image_info: a 2D `Tensor` that encodes the information of the image and
the applied preprocessing. It is in the format of
[[original_height, original_width], [scaled_height, scaled_width],
[y_scale, x_scale], [y_offset, x_offset]].
groundtruths:
source_id: source image id. Default value -1 if the source id is empty
in the groundtruth annotation.
boxes: groundtruth bounding box annotations. The box is represented in
[y1, x1, y2, x2] format. The tensor is padded with -1 to the fixed
dimension [self._max_num_instances, 4].
classes: groundtruth classes annotations. The tensor is padded with
-1 to the fixed dimension [self._max_num_instances].
areas: groundtruth areas annotations. The tensor is padded with -1
to the fixed dimension [self._max_num_instances].
is_crowds: groundtruth annotations to indicate if an annotation
represents a group of instances by value {0, 1}. The tensor is
padded with 0 to the fixed dimension [self._max_num_instances].
"""
with tf.name_scope('parser'):
data = self._example_decoder.decode(value)
return self._parse_fn(data)
def _parse_train_data(self, data):
"""Parses data for training and evaluation."""
classes = data['groundtruth_classes']
boxes = data['groundtruth_boxes']
is_crowds = data['groundtruth_is_crowd']
# Skips annotations with `is_crowd` = True.
if self._skip_crowd_during_training and self._is_training:
num_groundtruths = tf.shape(input=classes)[0]
with tf.control_dependencies([num_groundtruths, is_crowds]):
indices = tf.cond(
pred=tf.greater(tf.size(input=is_crowds), 0),
true_fn=lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
false_fn=lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
classes = tf.gather(classes, indices)
boxes = tf.gather(boxes, indices)
# Gets original image and its size.
image = data['image']
# NOTE: The autoaugment method works best when used alongside the standard
# horizontal flipping of images along with size jittering and normalization.
if self._use_autoaugment:
image, boxes = autoaugment_utils.distort_image_with_autoaugment(
image, boxes, self._autoaugment_policy_name)
image_shape = tf.shape(input=image)[0:2]
# Normalizes image with mean and std pixel values.
image = input_utils.normalize_image(image)
# Flips image randomly during training.
if self._aug_rand_hflip:
image, boxes = input_utils.random_horizontal_flip(image, boxes)
# Converts boxes from normalized coordinates to pixel coordinates.
boxes = box_utils.denormalize_boxes(boxes, image_shape)
# Resizes and crops image.
image, image_info = input_utils.resize_and_crop_image(
image,
self._output_size,
padded_size=input_utils.compute_padded_size(
self._output_size, 2 ** self._max_level),
aug_scale_min=self._aug_scale_min,
aug_scale_max=self._aug_scale_max)
image_height, image_width, _ = image.get_shape().as_list()
# Resizes and crops boxes.
image_scale = image_info[2, :]
offset = image_info[3, :]
boxes = input_utils.resize_and_crop_boxes(
boxes, image_scale, (image_height, image_width), offset)
# Filters out ground truth boxes that are all zeros.
indices = input_utils.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
# Assigns anchors.
input_anchor = anchor.Anchor(
self._min_level, self._max_level, self._num_scales,
self._aspect_ratios, self._anchor_size, (image_height, image_width))
anchor_labeler = anchor.AnchorLabeler(
input_anchor, self._match_threshold, self._unmatched_threshold)
(cls_targets, box_targets, num_positives) = anchor_labeler.label_anchors(
boxes,
tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
# If bfloat16 is used, casts input image to tf.bfloat16.
if self._use_bfloat16:
image = tf.cast(image, dtype=tf.bfloat16)
# Packs labels for model_fn outputs.
labels = {
'cls_targets': cls_targets,
'box_targets': box_targets,
'anchor_boxes': input_anchor.multilevel_boxes,
'num_positives': num_positives,
'image_info': image_info,
}
return image, labels
def _parse_eval_data(self, data):
"""Parses data for training and evaluation."""
groundtruths = {}
classes = data['groundtruth_classes']
boxes = data['groundtruth_boxes']
# Gets original image and its size.
image = data['image']
image_shape = tf.shape(input=image)[0:2]
# Normalizes image with mean and std pixel values.
image = input_utils.normalize_image(image)
# Converts boxes from normalized coordinates to pixel coordinates.
boxes = box_utils.denormalize_boxes(boxes, image_shape)
# Resizes and crops image.
image, image_info = input_utils.resize_and_crop_image(
image,
self._output_size,
padded_size=input_utils.compute_padded_size(
self._output_size, 2 ** self._max_level),
aug_scale_min=1.0,
aug_scale_max=1.0)
image_height, image_width, _ = image.get_shape().as_list()
# Resizes and crops boxes.
image_scale = image_info[2, :]
offset = image_info[3, :]
boxes = input_utils.resize_and_crop_boxes(
boxes, image_scale, (image_height, image_width), offset)
# Filters out ground truth boxes that are all zeros.
indices = input_utils.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
# Assigns anchors.
input_anchor = anchor.Anchor(
self._min_level, self._max_level, self._num_scales,
self._aspect_ratios, self._anchor_size, (image_height, image_width))
anchor_labeler = anchor.AnchorLabeler(
input_anchor, self._match_threshold, self._unmatched_threshold)
(cls_targets, box_targets, num_positives) = anchor_labeler.label_anchors(
boxes,
tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
# If bfloat16 is used, casts input image to tf.bfloat16.
if self._use_bfloat16:
image = tf.cast(image, dtype=tf.bfloat16)
# Sets up groundtruth data for evaluation.
groundtruths = {
'source_id': data['source_id'],
'num_groundtrtuhs': tf.shape(data['groundtruth_classes']),
'image_info': image_info,
'boxes': box_utils.denormalize_boxes(
data['groundtruth_boxes'], image_shape),
'classes': data['groundtruth_classes'],
'areas': data['groundtruth_area'],
'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
}
groundtruths['source_id'] = process_source_id(groundtruths['source_id'])
groundtruths = pad_groundtruths_to_fixed_size(
groundtruths, self._max_num_instances)
# Packs labels for model_fn outputs.
labels = {
'cls_targets': cls_targets,
'box_targets': box_targets,
'anchor_boxes': input_anchor.multilevel_boxes,
'num_positives': num_positives,
'image_info': image_info,
'groundtruths': groundtruths,
}
return image, labels
def _parse_predict_data(self, data):
"""Parses data for prediction."""
# Gets original image and its size.
image = data['image']
image_shape = tf.shape(input=image)[0:2]
# Normalizes image with mean and std pixel values.
image = input_utils.normalize_image(image)
# Resizes and crops image.
image, image_info = input_utils.resize_and_crop_image(
image,
self._output_size,
padded_size=input_utils.compute_padded_size(
self._output_size, 2 ** self._max_level),
aug_scale_min=1.0,
aug_scale_max=1.0)
image_height, image_width, _ = image.get_shape().as_list()
# If bfloat16 is used, casts input image to tf.bfloat16.
if self._use_bfloat16:
image = tf.cast(image, dtype=tf.bfloat16)
# Computes anchor boxes.
input_anchor = anchor.Anchor(
self._min_level, self._max_level, self._num_scales,
self._aspect_ratios, self._anchor_size, (image_height, image_width))
labels = {
'anchor_boxes': input_anchor.multilevel_boxes,
'image_info': image_info,
}
# If mode is PREDICT_WITH_GT, returns groundtruths and training targets
# in labels.
if self._mode == ModeKeys.PREDICT_WITH_GT:
# Converts boxes from normalized coordinates to pixel coordinates.
boxes = box_utils.denormalize_boxes(
data['groundtruth_boxes'], image_shape)
groundtruths = {
'source_id': data['source_id'],
'num_detections': tf.shape(data['groundtruth_classes']),
'boxes': boxes,
'classes': data['groundtruth_classes'],
'areas': data['groundtruth_area'],
'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
}
groundtruths['source_id'] = process_source_id(groundtruths['source_id'])
groundtruths = pad_groundtruths_to_fixed_size(
groundtruths, self._max_num_instances)
labels['groundtruths'] = groundtruths
# Computes training objective for evaluation loss.
classes = data['groundtruth_classes']
image_scale = image_info[2, :]
offset = image_info[3, :]
boxes = input_utils.resize_and_crop_boxes(
boxes, image_scale, (image_height, image_width), offset)
# Filters out ground truth boxes that are all zeros.
indices = input_utils.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
# Assigns anchors.
anchor_labeler = anchor.AnchorLabeler(
input_anchor, self._match_threshold, self._unmatched_threshold)
(cls_targets, box_targets, num_positives) = anchor_labeler.label_anchors(
boxes,
tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
labels['cls_targets'] = cls_targets
labels['box_targets'] = box_targets
labels['num_positives'] = num_positives
return image, labels
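A minimal usage sketch (not part of the original file): the parser is typically wired into a tf.data pipeline so each serialized tf.Example is mapped to an (image, labels) pair. All parameter values and the file pattern below are hypothetical.

# Hypothetical example of wiring the Parser into a tf.data input pipeline.
train_parser = Parser(
    output_size=[640, 640],        # divisible by 2^max_level = 128
    min_level=3,
    max_level=7,
    num_scales=3,
    aspect_ratios=[1.0, 2.0, 0.5],
    anchor_size=4.0,
    aug_rand_hflip=True,
    mode=ModeKeys.TRAIN)

dataset = tf.data.TFRecordDataset(
    tf.io.gfile.glob('/path/to/train-*.tfrecord'))  # hypothetical path
dataset = dataset.map(
    train_parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(8, drop_remainder=True)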
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow Example proto decoder for object detection.
A decoder to decode string tensors containing serialized tensorflow.Example
protos for object detection.
"""
import tensorflow.compat.v2 as tf
class TfExampleDecoder(object):
"""Tensorflow Example proto decoder."""
def __init__(self, include_mask=False):
self._include_mask = include_mask
self._keys_to_features = {
'image/encoded':
tf.io.FixedLenFeature((), tf.string),
'image/source_id':
tf.io.FixedLenFeature((), tf.string),
'image/height':
tf.io.FixedLenFeature((), tf.int64),
'image/width':
tf.io.FixedLenFeature((), tf.int64),
'image/object/bbox/xmin':
tf.io.VarLenFeature(tf.float32),
'image/object/bbox/xmax':
tf.io.VarLenFeature(tf.float32),
'image/object/bbox/ymin':
tf.io.VarLenFeature(tf.float32),
'image/object/bbox/ymax':
tf.io.VarLenFeature(tf.float32),
'image/object/class/label':
tf.io.VarLenFeature(tf.int64),
'image/object/area':
tf.io.VarLenFeature(tf.float32),
'image/object/is_crowd':
tf.io.VarLenFeature(tf.int64),
}
if include_mask:
self._keys_to_features.update({
'image/object/mask':
tf.io.VarLenFeature(tf.string),
})
def _decode_image(self, parsed_tensors):
"""Decodes the image and set its static shape."""
image = tf.io.decode_image(parsed_tensors['image/encoded'], channels=3)
image.set_shape([None, None, 3])
return image
def _decode_boxes(self, parsed_tensors):
"""Concat box coordinates in the format of [ymin, xmin, ymax, xmax]."""
xmin = parsed_tensors['image/object/bbox/xmin']
xmax = parsed_tensors['image/object/bbox/xmax']
ymin = parsed_tensors['image/object/bbox/ymin']
ymax = parsed_tensors['image/object/bbox/ymax']
return tf.stack([ymin, xmin, ymax, xmax], axis=-1)
def _decode_masks(self, parsed_tensors):
"""Decode a set of PNG masks to the tf.float32 tensors."""
def _decode_png_mask(png_bytes):
mask = tf.squeeze(
tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8), axis=-1)
mask = tf.cast(mask, dtype=tf.float32)
mask.set_shape([None, None])
return mask
height = parsed_tensors['image/height']
width = parsed_tensors['image/width']
masks = parsed_tensors['image/object/mask']
return tf.cond(
pred=tf.greater(tf.size(input=masks), 0),
true_fn=lambda: tf.map_fn(_decode_png_mask, masks, dtype=tf.float32),
false_fn=lambda: tf.zeros([0, height, width], dtype=tf.float32))
def _decode_areas(self, parsed_tensors):
xmin = parsed_tensors['image/object/bbox/xmin']
xmax = parsed_tensors['image/object/bbox/xmax']
ymin = parsed_tensors['image/object/bbox/ymin']
ymax = parsed_tensors['image/object/bbox/ymax']
return tf.cond(
tf.greater(tf.shape(parsed_tensors['image/object/area'])[0], 0),
lambda: parsed_tensors['image/object/area'],
lambda: (xmax - xmin) * (ymax - ymin))
def decode(self, serialized_example):
"""Decode the serialized example.
Args:
serialized_example: a single serialized tf.Example string.
Returns:
decoded_tensors: a dictionary of tensors with the following fields:
- image: a uint8 tensor of shape [None, None, 3].
- source_id: a string scalar tensor.
- height: an integer scalar tensor.
- width: an integer scalar tensor.
- groundtruth_classes: an int64 tensor of shape [None].
- groundtruth_is_crowd: a bool tensor of shape [None].
- groundtruth_area: a float32 tensor of shape [None].
- groundtruth_boxes: a float32 tensor of shape [None, 4].
- groundtruth_instance_masks: a float32 tensor of shape
[None, None, None].
- groundtruth_instance_masks_png: a string tensor of shape [None].
"""
parsed_tensors = tf.io.parse_single_example(
serialized=serialized_example, features=self._keys_to_features)
for k in parsed_tensors:
if isinstance(parsed_tensors[k], tf.SparseTensor):
if parsed_tensors[k].dtype == tf.string:
parsed_tensors[k] = tf.sparse.to_dense(
parsed_tensors[k], default_value='')
else:
parsed_tensors[k] = tf.sparse.to_dense(
parsed_tensors[k], default_value=0)
image = self._decode_image(parsed_tensors)
boxes = self._decode_boxes(parsed_tensors)
areas = self._decode_areas(parsed_tensors)
is_crowds = tf.cond(
tf.greater(tf.shape(parsed_tensors['image/object/is_crowd'])[0], 0),
lambda: tf.cast(parsed_tensors['image/object/is_crowd'], dtype=tf.bool),
lambda: tf.zeros_like(parsed_tensors['image/object/class/label'], dtype=tf.bool)) # pylint: disable=line-too-long
if self._include_mask:
masks = self._decode_masks(parsed_tensors)
decoded_tensors = {
'image': image,
'source_id': parsed_tensors['image/source_id'],
'height': parsed_tensors['image/height'],
'width': parsed_tensors['image/width'],
'groundtruth_classes': parsed_tensors['image/object/class/label'],
'groundtruth_is_crowd': is_crowds,
'groundtruth_area': areas,
'groundtruth_boxes': boxes,
}
if self._include_mask:
decoded_tensors.update({
'groundtruth_instance_masks': masks,
'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'],
})
return decoded_tensors
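A short hedged sketch (not part of the original file) showing how the decoder can be applied eagerly to one serialized example; the TFRecord path is hypothetical.

# Hypothetical example: decoding one serialized tf.Example string tensor.
decoder = TfExampleDecoder(include_mask=False)
for serialized in tf.data.TFRecordDataset('/path/to/val.tfrecord').take(1):  # hypothetical path
  decoded = decoder.decode(serialized)
  # decoded['image'] is a uint8 tensor of shape [height, width, 3];
  # decoded['groundtruth_boxes'] holds [ymin, xmin, ymax, xmax] boxes, which
  # are typically normalized coordinates in this dataset format.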
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The COCO-style evaluator.
The following snippet demonstrates the use of interfaces:
evaluator = COCOEvaluator(...)
for _ in range(num_evals):
for _ in range(num_batches_per_eval):
predictions, groundtruths = predictor.predict(...) # pop a batch.
evaluator.update(predictions, groundtruths) # aggregate internal stats.
evaluator.evaluate() # finish one full eval.
See also: https://github.com/cocodataset/cocoapi/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import atexit
import tempfile
import numpy as np
from pycocotools import cocoeval
import six
import tensorflow.compat.v2 as tf
from official.vision.detection.evaluation import coco_utils
class COCOEvaluator(object):
"""COCO evaluation metric class."""
def __init__(self, annotation_file, include_mask):
"""Constructs COCO evaluation class.
The class provides the interface for COCO-style evaluation. `update()` takes
detections from each image and pushes them to the internal state. `evaluate()`
loads the groundtruth annotations (from a JSON file in COCO format or
aggregated from the dataloader) and runs COCO evaluation.
Args:
annotation_file: a JSON file that stores annotations of the eval dataset.
If `annotation_file` is None, groundtruth annotations will be loaded
from the dataloader.
include_mask: a boolean to indicate whether or not to include the mask
eval.
"""
if annotation_file:
if annotation_file.startswith('gs://'):
_, local_val_json = tempfile.mkstemp(suffix='.json')
tf.io.gfile.remove(local_val_json)
tf.io.gfile.copy(annotation_file, local_val_json)
atexit.register(tf.io.gfile.remove, local_val_json)
else:
local_val_json = annotation_file
self._coco_gt = coco_utils.COCOWrapper(
eval_type=('mask' if include_mask else 'box'),
annotation_file=local_val_json)
self._annotation_file = annotation_file
self._include_mask = include_mask
self._metric_names = ['AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'ARmax1',
'ARmax10', 'ARmax100', 'ARs', 'ARm', 'ARl']
self._required_prediction_fields = [
'source_id', 'image_info', 'num_detections', 'detection_classes',
'detection_scores', 'detection_boxes']
self._required_groundtruth_fields = [
'source_id', 'height', 'width', 'classes', 'boxes']
if self._include_mask:
mask_metric_names = ['mask_' + x for x in self._metric_names]
self._metric_names.extend(mask_metric_names)
self._required_prediction_fields.extend(['detection_masks'])
self._required_groundtruth_fields.extend(['masks'])
self.reset()
def reset(self):
"""Resets internal states for a fresh run."""
self._predictions = {}
if not self._annotation_file:
self._groundtruths = {}
def evaluate(self):
"""Evaluates with detections from all images with COCO API.
Returns:
coco_metric: a dictionary mapping each metric name in `self._metric_names`
to a float numpy scalar (box metrics, plus mask metrics when
`include_mask` is True).
"""
if not self._annotation_file:
gt_dataset = coco_utils.convert_groundtruths_to_coco_dataset(
self._groundtruths)
coco_gt = coco_utils.COCOWrapper(
eval_type=('mask' if self._include_mask else 'box'),
gt_dataset=gt_dataset)
else:
coco_gt = self._coco_gt
coco_predictions = coco_utils.convert_predictions_to_coco_annotations(
self._predictions)
coco_dt = coco_gt.loadRes(predictions=coco_predictions)
image_ids = [ann['image_id'] for ann in coco_predictions]
coco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='bbox')
coco_eval.params.imgIds = image_ids
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
coco_metrics = coco_eval.stats
if self._include_mask:
mcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='segm')
mcoco_eval.params.imgIds = image_ids
mcoco_eval.evaluate()
mcoco_eval.accumulate()
mcoco_eval.summarize()
mask_coco_metrics = mcoco_eval.stats
if self._include_mask:
metrics = np.hstack((coco_metrics, mask_coco_metrics))
else:
metrics = coco_metrics
# Cleans up the internal variables in order for a fresh eval next time.
self.reset()
metrics_dict = {}
for i, name in enumerate(self._metric_names):
metrics_dict[name] = metrics[i].astype(np.float32)
return metrics_dict
def _process_predictions(self, predictions):
image_scale = np.tile(predictions['image_info'][:, 2:3, :], (1, 1, 2))
predictions['detection_boxes'] = (
predictions['detection_boxes'] / image_scale)
def update(self, predictions, groundtruths=None):
"""Update and aggregate detection results and groundtruth data.
Args:
predictions: a dictionary of numpy arrays including the fields below.
See different parsers under `../dataloader` for more details.
Required fields:
- source_id: a numpy array of int or string of shape [batch_size].
- image_info: a numpy array of float of shape [batch_size, 4, 2].
- num_detections: a numpy array of int of shape [batch_size].
- detection_boxes: a numpy array of float of shape [batch_size, K, 4].
- detection_classes: a numpy array of int of shape [batch_size, K].
- detection_scores: a numpy array of float of shape [batch_size, K].
Optional fields:
- detection_masks: a numpy array of float of shape
[batch_size, K, mask_height, mask_width].
groundtruths: a dictionary of numpy arrays including the fields below.
See also different parsers under `../dataloader` for more details.
Required fields:
- source_id: a numpy array of int or string of shape [batch_size].
- height: a numpy array of int of shape [batch_size].
- width: a numpy array of int of shape [batch_size].
- num_detections: a numpy array of int of shape [batch_size].
- boxes: a numpy array of float of shape [batch_size, K, 4].
- classes: a numpy array of int of shape [batch_size, K].
Optional fields:
- is_crowds: a numpy array of int of shape [batch_size, K]. If the
field is absent, it is assumed that this instance is not crowd.
- areas: a numpy array of float of shape [batch_size, K]. If the
field is absent, the area is calculated using either boxes or
masks depending on which one is available.
- masks: a numpy array of float of shape
[batch_size, K, mask_height, mask_width].
Raises:
ValueError: if the required prediction or groundtruth fields are not
present in the incoming `predictions` or `groundtruths`.
"""
for k in self._required_prediction_fields:
if k not in predictions:
raise ValueError('Missing the required key `{}` in predictions!'
.format(k))
self._process_predictions(predictions)
for k, v in six.iteritems(predictions):
if k not in self._predictions:
self._predictions[k] = [v]
else:
self._predictions[k].append(v)
if not self._annotation_file:
assert groundtruths
for k in self._required_groundtruth_fields:
if k not in groundtruths:
raise ValueError('Missing the required key `{}` in groundtruths!'
.format(k))
for k, v in six.iteritems(groundtruths):
if k not in self._groundtruths:
self._groundtruths[k] = [v]
else:
self._groundtruths[k].append(v)
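A hedged usage sketch (not part of the original file) of the update/evaluate cycle with a pre-built annotation file; the file path and the `prediction_batches` iterable of per-batch numpy dictionaries are hypothetical.

# Hypothetical example of one full evaluation pass.
evaluator = COCOEvaluator(
    annotation_file='/path/to/instances_val.json',  # hypothetical path
    include_mask=False)
for predictions in prediction_batches:  # hypothetical: dicts of numpy arrays
  evaluator.update(predictions)
metrics = evaluator.evaluate()  # e.g. metrics['AP'], metrics['AP50'], ...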
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util functions related to pycocotools and COCO eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
from absl import logging
import numpy as np
from PIL import Image
from pycocotools import coco
from pycocotools import mask as mask_utils
import six
import tensorflow.compat.v2 as tf
from official.vision.detection.dataloader import tf_example_decoder
from official.vision.detection.utils import box_utils
class COCOWrapper(coco.COCO):
"""COCO wrapper class.
This class wraps the COCO API object, which provides the following additional
functionalities:
1. Support string type image id.
2. Support loading the groundtruth dataset using the external annotation
dictionary.
3. Support loading the prediction results using the external annotation
dictionary.
"""
def __init__(self, eval_type='box', annotation_file=None, gt_dataset=None):
"""Instantiates a COCO-style API object.
Args:
eval_type: either 'box' or 'mask'.
annotation_file: a JSON file that stores annotations of the eval dataset.
This is required if `gt_dataset` is not provided.
gt_dataset: the groundtruth eval dataset in COCO API format.
"""
if ((annotation_file and gt_dataset) or
((not annotation_file) and (not gt_dataset))):
raise ValueError('One and only one of `annotation_file` and `gt_dataset` '
'needs to be specified.')
if eval_type not in ['box', 'mask']:
raise ValueError('The `eval_type` can only be either `box` or `mask`.')
coco.COCO.__init__(self, annotation_file=annotation_file)
self._eval_type = eval_type
if gt_dataset:
self.dataset = gt_dataset
self.createIndex()
def loadRes(self, predictions):
"""Loads result file and return a result api object.
Args:
predictions: a list of dictionary each representing an annotation in COCO
format. The required fields are `image_id`, `category_id`, `score`,
`bbox`, `segmentation`.
Returns:
res: result COCO api object.
Raises:
ValueError: if the set of image ids from predictions is not a subset of
the image ids in the groundtruth dataset.
"""
res = coco.COCO()
res.dataset['images'] = copy.deepcopy(self.dataset['images'])
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
image_ids = [ann['image_id'] for ann in predictions]
if set(image_ids) != (set(image_ids) & set(self.getImgIds())):
raise ValueError('Results do not correspond to the current dataset!')
for ann in predictions:
x1, x2, y1, y2 = [ann['bbox'][0], ann['bbox'][0] + ann['bbox'][2],
ann['bbox'][1], ann['bbox'][1] + ann['bbox'][3]]
if self._eval_type == 'box':
ann['area'] = ann['bbox'][2] * ann['bbox'][3]
ann['segmentation'] = [
[x1, y1, x1, y2, x2, y2, x2, y1]]
elif self._eval_type == 'mask':
ann['bbox'] = mask_utils.toBbox(ann['segmentation'])
ann['area'] = mask_utils.area(ann['segmentation'])
res.dataset['annotations'] = copy.deepcopy(predictions)
res.createIndex()
return res
def convert_predictions_to_coco_annotations(predictions):
"""Converts a batch of predictions to annotations in COCO format.
Args:
predictions: a dictionary of lists of numpy arrays including the following
fields. K below denotes the maximum number of instances per image.
Required fields:
- source_id: a list of numpy arrays of int or string of shape
[batch_size].
- num_detections: a list of numpy arrays of int of shape [batch_size].
- detection_boxes: a list of numpy arrays of float of shape
[batch_size, K, 4], where coordinates are in the original image
space (not the scaled image space).
- detection_classes: a list of numpy arrays of int of shape
[batch_size, K].
- detection_scores: a list of numpy arrays of float of shape
[batch_size, K].
Optional fields:
- detection_masks: a list of numpy arrays of float of shape
[batch_size, K, mask_height, mask_width].
Returns:
coco_predictions: prediction in COCO annotation format.
"""
coco_predictions = []
num_batches = len(predictions['source_id'])
batch_size = predictions['source_id'][0].shape[0]
max_num_detections = predictions['detection_classes'][0].shape[1]
for i in range(num_batches):
for j in range(batch_size):
for k in range(max_num_detections):
ann = {}
ann['image_id'] = predictions['source_id'][i][j]
ann['category_id'] = predictions['detection_classes'][i][j, k]
boxes = predictions['detection_boxes'][i]
ann['bbox'] = [
boxes[j, k, 1],
boxes[j, k, 0],
boxes[j, k, 3] - boxes[j, k, 1],
boxes[j, k, 2] - boxes[j, k, 0]]
ann['score'] = predictions['detection_scores'][i][j, k]
if 'detection_masks' in predictions:
encoded_mask = mask_utils.encode(
np.asfortranarray(
predictions['detection_masks'][i][j, k].astype(np.uint8)))
ann['segmentation'] = encoded_mask
coco_predictions.append(ann)
for i, ann in enumerate(coco_predictions):
ann['id'] = i + 1
return coco_predictions
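A toy illustration (hypothetical, not part of the original file) of the box-format conversion performed above: detection boxes arrive as [ymin, xmin, ymax, xmax] and come out as COCO-style [x, y, width, height]. `np` refers to the numpy import at the top of this module.

# Hypothetical example: one batch, one image, one detection.
toy_predictions = {
    'source_id': [np.array([42])],
    'num_detections': [np.array([1])],
    'detection_boxes': [np.array([[[10., 20., 110., 220.]]])],  # [ymin, xmin, ymax, xmax]
    'detection_classes': [np.array([[1]])],
    'detection_scores': [np.array([[0.9]])],
}
anns = convert_predictions_to_coco_annotations(toy_predictions)
# anns[0]['bbox'] == [20.0, 10.0, 200.0, 100.0]  # COCO [x, y, width, height]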
def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
"""Converts groundtruths to the dataset in COCO format.
Args:
groundtruths: a dictionary of lists of numpy arrays including the fields
below. Note that each element in a list corresponds to one batch of
examples. K below denotes the actual number of
instances for each image.
Required fields:
- source_id: a list of numpy arrays of int or string of shape
[batch_size].
- height: a list of numpy arrays of int of shape [batch_size].
- width: a list of numpy arrays of int of shape [batch_size].
- num_detections: a list of numpy arrays of int of shape [batch_size].
- boxes: a list of numpy arrays of float of shape [batch_size, K, 4],
where coordinates are in the original image space (not the
normalized coordinates).
- classes: a list of numpy arrays of int of shape [batch_size, K].
Optional fields:
- is_crowds: a list of numpy arrays of int of shape [batch_size, K]. If
the field is absent, it is assumed that this instance is not crowd.
- areas: a list of numpy arrays of float of shape [batch_size, K]. If the
field is absent, the area is calculated using either boxes or
masks depending on which one is available.
- masks: a list of numpy arrays of string of shape [batch_size, K].
label_map: (optional) a dictionary that defines items from the category id
to the category name. If `None`, collects the category mapping from the
`groundtruths`.
Returns:
coco_groundtruths: the groundtruth dataset in COCO format.
"""
source_ids = np.concatenate(groundtruths['source_id'], axis=0)
heights = np.concatenate(groundtruths['height'], axis=0)
widths = np.concatenate(groundtruths['width'], axis=0)
gt_images = [{'id': int(i), 'height': int(h), 'width': int(w)} for i, h, w
in zip(source_ids, heights, widths)]
gt_annotations = []
num_batches = len(groundtruths['source_id'])
batch_size = groundtruths['source_id'][0].shape[0]
for i in range(num_batches):
for j in range(batch_size):
num_instances = groundtruths['num_detections'][i][j]
for k in range(num_instances):
ann = {}
ann['image_id'] = int(groundtruths['source_id'][i][j])
if 'is_crowds' in groundtruths:
ann['iscrowd'] = int(groundtruths['is_crowds'][i][j, k])
else:
ann['iscrowd'] = 0
ann['category_id'] = int(groundtruths['classes'][i][j, k])
boxes = groundtruths['boxes'][i]
ann['bbox'] = [
float(boxes[j, k, 1]),
float(boxes[j, k, 0]),
float(boxes[j, k, 3] - boxes[j, k, 1]),
float(boxes[j, k, 2] - boxes[j, k, 0])]
if 'areas' in groundtruths:
ann['area'] = float(groundtruths['areas'][i][j, k])
else:
ann['area'] = float(
(boxes[j, k, 3] - boxes[j, k, 1]) *
(boxes[j, k, 2] - boxes[j, k, 0]))
if 'masks' in groundtruths:
mask = Image.open(six.BytesIO(groundtruths['masks'][i][j, k]))
width, height = mask.size
np_mask = (
np.array(mask.getdata()).reshape(height, width).astype(np.uint8))
np_mask[np_mask > 0] = 255
encoded_mask = mask_utils.encode(np.asfortranarray(np_mask))
ann['segmentation'] = encoded_mask
if 'areas' not in groundtruths:
ann['area'] = mask_utils.area(encoded_mask)
gt_annotations.append(ann)
for i, ann in enumerate(gt_annotations):
ann['id'] = i + 1
if label_map:
gt_categories = [{'id': i, 'name': label_map[i]} for i in label_map]
else:
category_ids = [gt['category_id'] for gt in gt_annotations]
gt_categories = [{'id': i} for i in set(category_ids)]
gt_dataset = {
'images': gt_images,
'categories': gt_categories,
'annotations': copy.deepcopy(gt_annotations),
}
return gt_dataset
class COCOGroundtruthGenerator(object):
"""Generates the groundtruth annotations from a single example sequentially."""
def __init__(self, file_pattern, num_examples, include_mask):
self._file_pattern = file_pattern
self._num_examples = num_examples
self._include_mask = include_mask
self._dataset_fn = tf.data.TFRecordDataset
def _parse_single_example(self, example):
"""Parses a single serialized tf.Example proto.
Args:
example: a serialized tf.Example proto string.
Returns:
A dictionary of groundtruth with the following fields:
source_id: a scalar tensor of int64 representing the image source_id.
height: a scalar tensor of int64 representing the image height.
width: a scalar tensor of int64 representing the image width.
boxes: a float tensor of shape [K, 4], representing the groundtruth
boxes in absolute coordinates with respect to the original image size.
classes: an int64 tensor of shape [K], representing the class labels of
each instance.
is_crowds: a bool tensor of shape [K], indicating whether the instance
is crowd.
areas: a float tensor of shape [K], indicating the area of each
instance.
masks: a string tensor of shape [K], containing the bytes of the png
mask of each instance.
"""
decoder = tf_example_decoder.TfExampleDecoder(
include_mask=self._include_mask)
decoded_tensors = decoder.decode(example)
image = decoded_tensors['image']
image_size = tf.shape(image)[0:2]
boxes = box_utils.denormalize_boxes(
decoded_tensors['groundtruth_boxes'], image_size)
groundtruths = {
'source_id': tf.strings.to_number(
decoded_tensors['source_id'], out_type=tf.int64),
'height': decoded_tensors['height'],
'width': decoded_tensors['width'],
'num_detections': tf.shape(decoded_tensors['groundtruth_classes'])[0],
'boxes': boxes,
'classes': decoded_tensors['groundtruth_classes'],
'is_crowds': decoded_tensors['groundtruth_is_crowd'],
'areas': decoded_tensors['groundtruth_area'],
}
if self._include_mask:
groundtruths.update({
'masks': decoded_tensors['groundtruth_instance_masks_png'],
})
return groundtruths
def _build_pipeline(self):
"""Builds data pipeline to generate groundtruth annotations."""
dataset = tf.data.Dataset.list_files(self._file_pattern, shuffle=False)
dataset = dataset.apply(
tf.data.experimental.parallel_interleave(
lambda filename: self._dataset_fn(filename).prefetch(1),
cycle_length=32,
sloppy=False))
dataset = dataset.map(self._parse_single_example, num_parallel_calls=64)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(1, drop_remainder=False)
return dataset
def __call__(self):
with tf.Graph().as_default():
dataset = self._build_pipeline()
groundtruth = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
with tf.compat.v1.Session() as sess:
for _ in range(self._num_examples):
groundtruth_result = sess.run(groundtruth)
yield groundtruth_result
def scan_and_generator_annotation_file(file_pattern,
num_samples,
include_mask,
annotation_file):
"""Scans and generate the COCO-style annotation JSON file given a dataset."""
groundtruth_generator = COCOGroundtruthGenerator(
file_pattern, num_samples, include_mask)
generate_annotation_file(groundtruth_generator, annotation_file)
def generate_annotation_file(groundtruth_generator,
annotation_file):
"""Generates COCO-style annotation JSON file given a groundtruth generator."""
groundtruths = {}
logging.info('Loading groundtruth annotations from dataset to memory...')
for groundtruth in groundtruth_generator():
for k, v in six.iteritems(groundtruth):
if k not in groundtruths:
groundtruths[k] = [v]
else:
groundtruths[k].append(v)
gt_dataset = convert_groundtruths_to_coco_dataset(groundtruths)
logging.info('Saving groundtruth annotations to the JSON file...')
with tf.io.gfile.GFile(annotation_file, 'w') as f:
f.write(json.dumps(gt_dataset))
logging.info('Done saving the JSON file...')
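A hedged sketch (not part of the original file) of generating the annotation file directly from eval TFRecords; the file pattern, sample count, and output path are hypothetical.

# Hypothetical example: build a COCO-style annotation JSON from eval TFRecords.
scan_and_generator_annotation_file(
    file_pattern='/path/to/val-*.tfrecord',          # hypothetical pattern
    num_samples=5000,                                 # hypothetical eval set size
    include_mask=False,
    annotation_file='/tmp/generated_annotations.json')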
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluator factory."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from official.vision.detection.evaluation import coco_evaluator
def evaluator_generator(params):
"""Generator function for various evaluators."""
if params.type == 'box':
evaluator = coco_evaluator.COCOEvaluator(
annotation_file=params.val_json_file, include_mask=False)
elif params.type == 'box_and_mask':
evaluator = coco_evaluator.COCOEvaluator(
annotation_file=params.val_json_file, include_mask=True)
else:
raise ValueError('Evaluator %s is not supported.' % params.type)
return evaluator
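A hedged sketch (not part of the original file) of selecting an evaluator from a config object with `type` and `val_json_file` attributes; it assumes `params_dict.ParamsDict` accepts a plain dict of defaults, as it does elsewhere in this repository, and the annotation path is hypothetical.

# Hypothetical example: selecting the box-only COCO evaluator from a config.
from official.modeling.hyperparams import params_dict

eval_params = params_dict.ParamsDict({
    'type': 'box',
    'val_json_file': '/path/to/instances_val.json',  # hypothetical path
})
evaluator = evaluator_generator(eval_params)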
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An executor class for running model on TensorFlow 2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import logging
import os
import json
import tensorflow.compat.v2 as tf
from official.modeling.training import distributed_executor as executor
class DetectionDistributedExecutor(executor.DistributedExecutor):
"""Detection specific customer training loop executor.
Subclasses the DistributedExecutor and adds support for numpy based metrics.
"""
def __init__(self,
predict_post_process_fn=None,
trainable_variables_filter=None,
**kwargs):
super(DetectionDistributedExecutor, self).__init__(**kwargs)
params = kwargs['params']
if predict_post_process_fn:
assert callable(predict_post_process_fn)
if trainable_variables_filter:
assert callable(trainable_variables_filter)
self._predict_post_process_fn = predict_post_process_fn
self._trainable_variables_filter = trainable_variables_filter
def _create_train_step(self,
strategy,
model,
loss_fn,
optimizer,
metric=None):
"""Creates a distributed training step."""
@tf.function
def train_step(iterator):
"""Performs a distributed training step.
Args:
strategy: an instance of tf.distribute.Strategy.
model: (Tensor, bool) -> Tensor. model function.
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
optimizer: tf.keras.optimizers.Optimizer.
iterator: an iterator that yields input tensors.
metric: eval metrics to be run outside the graph.
Returns:
The loss tensor.
"""
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
all_losses = loss_fn(labels, outputs)
losses = {}
for k, v in all_losses.items():
v = tf.reduce_mean(v) / strategy.num_replicas_in_sync
losses[k] = v
loss = losses['total_loss']
if isinstance(metric, tf.keras.metrics.Metric):
metric.update_state(labels, outputs)
else:
logging.error('train metric is not an instance of '
'tf.keras.metrics.Metric.')
trainable_variables = model.trainable_variables
if self._trainable_variables_filter:
trainable_variables = self._trainable_variables_filter(
trainable_variables)
logging.info('Filter trainable variables from %d to %d',
len(model.trainable_variables), len(trainable_variables))
grads = tape.gradient(loss, trainable_variables)
optimizer.apply_gradients(zip(grads, trainable_variables))
# return losses, labels
return loss
per_replica_losses = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
# For reporting, we return the mean of the losses.
loss = strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return loss
return train_step
def _create_test_step(self, strategy, model, metric):
"""Creates a distributed test step."""
@tf.function
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
if self._predict_post_process_fn:
labels, model_outputs = self._predict_post_process_fn(
labels, model_outputs)
return labels, model_outputs
labels, outputs = strategy.experimental_run_v2(
_test_step_fn, args=(next(iterator),))
outputs = tf.nest.map_structure(strategy.experimental_local_results,
outputs)
labels = tf.nest.map_structure(strategy.experimental_local_results,
labels)
return labels, outputs
return test_step
def _run_evaluation(self, test_step, current_training_step, metric,
test_iterator):
"""Runs validation steps and aggregate metrics."""
if not test_iterator or not metric:
logging.warning(
'Both test_iterator (%s) and metric (%s) must not be None.',
test_iterator, metric)
return None
logging.info('Running evaluation after step: %s.', current_training_step)
while True:
try:
labels, outputs = test_step(test_iterator)
if metric:
metric.update_state(labels, outputs)
except (StopIteration, tf.errors.OutOfRangeError):
break
metric_result = metric.result()
if isinstance(metric, tf.keras.metrics.Metric):
metric_result = tf.nest.map_structure(lambda x: x.numpy().astype(float),
metric_result)
logging.info('Step: [%d] Validation metric = %s', current_training_step,
metric_result)
return metric_result
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main function to train various object detection models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import app
from absl import flags
from absl import logging
import functools
import os
import pprint
import tensorflow.compat.v2 as tf
from official.modeling.hyperparams import params_dict
from official.modeling.training import distributed_executor as executor
from official.vision.detection.configs import factory as config_factory
from official.vision.detection.dataloader import input_reader
from official.vision.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
from official.vision.detection.modeling import factory as model_factory
executor.initialize_common_flags()
flags.DEFINE_string(
'mode',
default='train',
help='Mode to run: `train`, `eval` or `train_and_eval`.')
flags.DEFINE_string(
'model', default='retinanet',
help='Model to run: `retinanet` or `shapemask`.')
flags.DEFINE_string('training_file_pattern', None,
'Location of the train data.')
flags.DEFINE_string('eval_file_pattern', None, 'Location of the eval data.')
FLAGS = flags.FLAGS
def run_executor(params, train_input_fn=None, eval_input_fn=None):
"""Runs Retinanet model on distribution strategy defined by the user."""
model_builder = model_factory.model_generator(params)
if FLAGS.mode == 'train':
def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.TRAIN)
builder = executor.ExecutorBuilder(
strategy_type=params.strategy_type,
strategy_config=params.strategy_config)
num_workers = (builder.strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (num_workers > 1)
if is_multi_host:
train_input_fn = functools.partial(
train_input_fn,
batch_size=params.train.batch_size //
builder.strategy.num_replicas_in_sync)
dist_executor = builder.build_executor(
class_ctor=DetectionDistributedExecutor,
params=params,
is_multi_host=is_multi_host,
model_fn=_model_fn,
loss_fn=model_builder.build_loss_fn,
predict_post_process_fn=model_builder.post_processing,
trainable_variables_filter=model_builder
.make_filter_trainable_variables_fn())
return dist_executor.train(
train_input_fn=train_input_fn,
model_dir=params.model_dir,
iterations_per_loop=params.train.iterations_per_loop,
total_steps=params.train.total_steps,
init_checkpoint=model_builder.make_restore_checkpoint_fn(),
save_config=True)
elif FLAGS.mode == 'eval':
def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)
builder = executor.ExecutorBuilder(
strategy_type=params.strategy_type,
strategy_config=params.strategy_config)
dist_executor = builder.build_executor(
class_ctor=DetectionDistributedExecutor,
params=params,
model_fn=_model_fn,
loss_fn=model_builder.build_loss_fn,
predict_post_process_fn=model_builder.post_processing,
trainable_variables_filter=model_builder
.make_filter_trainable_variables_fn())
results = dist_executor.evaluate_from_model_dir(
model_dir=params.model_dir,
eval_input_fn=eval_input_fn,
eval_metric_fn=model_builder.eval_metrics,
eval_timeout=params.eval.eval_timeout,
min_eval_interval=params.eval.min_eval_interval,
total_steps=params.train.total_steps)
for k, v in results.items():
logging.info('Final eval metric %s: %f', k, v)
return results
else:
logging.info('Mode not found: %s.', FLAGS.mode)
def main(argv):
del argv # Unused.
params = config_factory.config_generator(FLAGS.model)
params = params_dict.override_params_dict(
params, FLAGS.config_file, is_strict=True)
params = params_dict.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.override(
{
'strategy_type': FLAGS.strategy_type,
'model_dir': FLAGS.model_dir,
'strategy_config': executor.strategy_flags_dict(),
},
is_strict=False)
params.validate()
params.lock()
pp = pprint.PrettyPrinter()
params_str = pp.pformat(params.as_dict())
logging.info('Model Parameters: {}'.format(params_str))
train_input_fn = None
eval_input_fn = None
training_file_pattern = FLAGS.training_file_pattern or params.train.train_file_pattern
eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
if not training_file_pattern and not eval_file_pattern:
raise ValueError('Must provide at least one of training_file_pattern and '
'eval_file_pattern.')
if training_file_pattern:
# Use global batch size for single host.
train_input_fn = input_reader.InputFn(
file_pattern=training_file_pattern,
params=params,
mode=input_reader.ModeKeys.TRAIN,
batch_size=params.train.batch_size)
if eval_file_pattern:
eval_input_fn = input_reader.InputFn(
file_pattern=eval_file_pattern,
params=params,
mode=input_reader.ModeKeys.PREDICT_WITH_GT,
batch_size=params.eval.batch_size,
num_examples=params.eval.eval_samples)
run_executor(
params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
app.run(main)