Commit 6a55ecde authored by A. Unique TensorFlower

Merge pull request #10286 from PurdueDualityLab:task_pr

PiperOrigin-RevId: 402338060
parents 2d353306 379d64c5
@@ -170,14 +170,14 @@ def get_image_shape(image):
 
 def _augment_hsv_darknet(image, rh, rs, rv, seed=None):
   """Randomize the hue, saturation, and brightness via the darknet method."""
   if rh > 0.0:
-    delta = random_uniform_strong(-rh, rh, seed=seed)
-    image = tf.image.adjust_hue(image, delta)
+    deltah = random_uniform_strong(-rh, rh, seed=seed)
+    image = tf.image.adjust_hue(image, deltah)
   if rs > 0.0:
-    delta = random_scale(rs, seed=seed)
-    image = tf.image.adjust_saturation(image, delta)
+    deltas = random_scale(rs, seed=seed)
+    image = tf.image.adjust_saturation(image, deltas)
   if rv > 0.0:
-    delta = random_scale(rv, seed=seed)
-    image *= delta
+    deltav = random_scale(rv, seed=seed)
+    image *= tf.cast(deltav, image.dtype)
   # clip the values of the image between 0.0 and 1.0
   image = tf.clip_by_value(image, 0.0, 1.0)
@@ -719,7 +719,7 @@ def affine_warp_boxes(affine, boxes, output_size, box_history):
   return tf.stack([y_min, x_min, y_max, x_max], axis=-1)
 
   def _aug_boxes(affine_matrix, box):
-    """Apply an affine transformation matrix M to the boxes augmente boxes."""
+    """Apply the affine transformation matrix M to augment the boxes."""
    corners = _get_corners(box)
    corners = tf.reshape(corners, [-1, 4, 2])
    z = tf.expand_dims(tf.ones_like(corners[..., 1]), axis=-1)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization package definition."""
# pylint: disable=wildcard-import
from official.modeling.optimization.configs.learning_rate_config import *
from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage
from official.vision.beta.projects.yolo.optimization.configs.optimization_config import *
from official.vision.beta.projects.yolo.optimization.configs.optimizer_config import *
from official.vision.beta.projects.yolo.optimization.optimizer_factory import OptimizerFactory as YoloOptimizerFactory
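Since this package re-exports both the base optimization configs and the YOLO-specific ones, downstream code only needs a single import. A hedged sketch of the intended entry point; the config values are illustrative and mirror the factory docstring further below, not part of this change:

```python
# Assumed usage sketch; the parameter values are illustrative only.
from official.vision.beta.projects.yolo import optimization

params = {
    'optimizer': {'type': 'sgd_torch', 'sgd_torch': {'momentum': 0.9}},
    'learning_rate': {'type': 'stepwise',
                      'stepwise': {'boundaries': [10000, 20000],
                                   'values': [0.1, 0.01, 0.001]}},
    'warmup': {'type': 'linear',
               'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}},
}
opt_config = optimization.OptimizationConfig(params)
opt_factory = optimization.YoloOptimizerFactory(opt_config)
```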
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclasses for optimization configs.
This file define the dataclass for optimization configs (OptimizationConfig).
It also has two helper functions get_optimizer_config, and get_lr_config from
an OptimizationConfig class.
"""
import dataclasses
from typing import Optional
from official.modeling.optimization.configs import optimization_config as optimization_cfg
from official.vision.beta.projects.yolo.optimization.configs import optimizer_config as opt_cfg
@dataclasses.dataclass
class OptimizerConfig(optimization_cfg.OptimizerConfig):
"""Configuration for optimizer.
Attributes:
type: 'str', type of optimizer to be used, on the of fields below.
sgd: sgd optimizer config.
adam: adam optimizer config.
adamw: adam with weight decay.
lamb: lamb optimizer.
rmsprop: rmsprop optimizer.
"""
type: Optional[str] = None
sgd_torch: opt_cfg.SGDTorchConfig = opt_cfg.SGDTorchConfig()
@dataclasses.dataclass
class OptimizationConfig(optimization_cfg.OptimizationConfig):
"""Configuration for optimizer and learning rate schedule.
Attributes:
optimizer: optimizer oneof config.
ema: optional exponential moving average optimizer config, if specified, ema
optimizer will be used.
learning_rate: learning rate oneof config.
warmup: warmup oneof config.
"""
type: Optional[str] = None
optimizer: OptimizerConfig = OptimizerConfig()
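For readers unfamiliar with the oneof-style configs, a minimal sketch of constructing this config directly from the dataclasses; the module aliases and field values here are illustrative assumptions:

```python
# Illustrative only: build an OptimizationConfig that selects the sgd_torch
# optimizer defined in optimizer_config.py below.
from official.vision.beta.projects.yolo.optimization.configs import (
    optimization_config as yolo_optimization_config)
from official.vision.beta.projects.yolo.optimization.configs import (
    optimizer_config as yolo_optimizer_config)

config = yolo_optimization_config.OptimizationConfig(
    optimizer=yolo_optimization_config.OptimizerConfig(
        type='sgd_torch',
        sgd_torch=yolo_optimizer_config.SGDTorchConfig(
            momentum=0.9, warmup_steps=1000)))
```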
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclasses for optimizer configs."""
import dataclasses
from typing import List, Optional
from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs import optimizer_config
@dataclasses.dataclass
class BaseOptimizerConfig(base_config.Config):
"""Base optimizer config.
Attributes:
clipnorm: float >= 0 or None. If not None, gradients will be clipped when
their L2 norm exceeds this value.
clipvalue: float >= 0 or None. If not None, gradients will be clipped when
their absolute value exceeds this value.
global_clipnorm: float >= 0 or None. If not None, the gradients of all weights
are clipped so that their global norm is no higher than this value.
"""
clipnorm: Optional[float] = None
clipvalue: Optional[float] = None
global_clipnorm: Optional[float] = None
@dataclasses.dataclass
class SGDTorchConfig(optimizer_config.BaseOptimizerConfig):
"""Configuration for SGD optimizer.
The attributes for this class matches the arguments of tf.keras.optimizer.SGD.
Attributes:
name: name of the optimizer.
decay: decay rate for SGD optimizer.
nesterov: nesterov for SGD optimizer.
momentum_start: momentum starting point for SGD optimizer.
momentum: momentum for SGD optimizer.
"""
name: str = "SGD"
decay: float = 0.0
nesterov: bool = False
momentum_start: float = 0.0
momentum: float = 0.9
warmup_steps: int = 0
weight_decay: float = 0.0
weight_keys: Optional[List[str]] = dataclasses.field(
default_factory=lambda: ["kernel", "weight"])
bias_keys: Optional[List[str]] = dataclasses.field(
default_factory=lambda: ["bias", "beta"])
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimizer factory class."""
import gin
from official.modeling.optimization import ema_optimizer
from official.modeling.optimization import optimizer_factory
from official.vision.beta.projects.yolo.optimization import sgd_torch
optimizer_factory.OPTIMIZERS_CLS.update({
'sgd_torch': sgd_torch.SGDTorch,
})
OPTIMIZERS_CLS = optimizer_factory.OPTIMIZERS_CLS
LR_CLS = optimizer_factory.LR_CLS
WARMUP_CLS = optimizer_factory.WARMUP_CLS
class OptimizerFactory(optimizer_factory.OptimizerFactory):
"""Optimizer factory class.
This class builds learning rate and optimizer based on an optimization config.
To use this class, you need to do the following:
(1) Define optimization config, this includes optimizer, and learning rate
schedule.
(2) Initialize the class using the optimization config.
(3) Build learning rate.
(4) Build optimizer.
This is a typical example for using this class:
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
},
'warmup': {
'type': 'linear',
'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
}
}
opt_config = OptimizationConfig(params)
opt_factory = OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)
"""
def get_bias_lr_schedule(self, bias_lr):
"""Build learning rate.
Builds learning rate from config. Learning rate schedule is built according
to the learning rate config. If learning rate type is consant,
lr_config.learning_rate is returned.
Args:
bias_lr: learning rate config.
Returns:
tf.keras.optimizers.schedules.LearningRateSchedule instance. If
learning rate type is consant, lr_config.learning_rate is returned.
"""
if self._lr_type == 'constant':
lr = self._lr_config.learning_rate
else:
lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())
if self._warmup_config:
if self._warmup_type != 'linear':
raise ValueError('Smart Bias is currently only supported with a '
'linear warm up.')
warm_up_cfg = self._warmup_config.as_dict()
warm_up_cfg['warmup_learning_rate'] = bias_lr
lr = WARMUP_CLS['linear'](lr, **warm_up_cfg)
return lr
@gin.configurable
def add_ema(self, optimizer):
"""Add EMA to the optimizer independently of the build optimizer method."""
if self._use_ema:
optimizer = ema_optimizer.ExponentialMovingAverage(
optimizer, **self._ema_config.as_dict())
return optimizer
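Putting the factory together, a hedged sketch of the flow this class is meant to support (it mirrors what YoloTask.create_optimizer does later in this change); opt_config, model, and the smart-bias learning rate are assumed inputs:

```python
# Sketch only: build LR, optimizer, per-bias LR schedule, and optional EMA.
opt_factory = OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)

# With the sgd_torch optimizer, biases can warm up from their own starting LR.
optimizer.set_bias_lr(opt_factory.get_bias_lr_schedule(0.1))
optimizer.search_and_set_variable_groups(model.trainable_variables)

# EMA is applied last so it wraps the fully configured optimizer.
optimizer = opt_factory.add_ema(optimizer)
```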
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SGD PyTorch optimizer."""
import re
from absl import logging
import tensorflow as tf
LearningRateSchedule = tf.keras.optimizers.schedules.LearningRateSchedule
def _var_key(var):
"""Key for representing a primary variable, for looking up slots.
In graph mode the name is derived from the var shared name.
In eager mode the name is derived from the var unique id.
If distribution strategy exists, get the primary variable first.
Args:
var: the variable.
Returns:
the unique name of the variable.
"""
# pylint: disable=protected-access
# Get the distributed variable if it exists.
if hasattr(var, "_distributed_container"):
var = var._distributed_container()
if var._in_graph_mode:
return var._shared_name
return var._unique_id
class SGDTorch(tf.keras.optimizers.Optimizer):
"""Optimizer that simulates the SGD module used in pytorch.
For details on the differences between the original SGD implemention and the
one in pytorch:
https://pytorch.org/docs/stable/generated/torch.optim.SGD.html.
This optimizer also allow for the usage of a momentum warmup along side a
learning rate warm up, though using this is not required.
Example of usage for training:
```python
opt = SGDTorch(learning_rate, weight_decay = 0.0001)
l2_regularization = None
# iterate all model.trainable_variables and split the variables by key
# into the weights, biases, and others.
optimizer.search_and_set_variable_groups(model.trainable_variables)
# if the learning rate schedule on the biases are different. if lr is not set
# the default schedule used for weights will be used on the biases.
opt.set_bias_lr(<lr schedule>)
# if the learning rate schedule on the others are different. if lr is not set
# the default schedule used for weights will be used on the biases.
opt.set_other_lr(<lr schedule>)
```
"""
_HAS_AGGREGATE_GRAD = True
def __init__(self,
weight_decay=0.0,
learning_rate=0.01,
momentum=0.0,
momentum_start=0.0,
warmup_steps=1000,
nesterov=False,
name="SGD",
weight_keys=("kernel", "weight"),
bias_keys=("bias", "beta"),
**kwargs):
super(SGDTorch, self).__init__(name, **kwargs)
# Create Hyper Params for each group of the LR
self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
self._set_hyper("bias_learning_rate", kwargs.get("lr", learning_rate))
self._set_hyper("other_learning_rate", kwargs.get("lr", learning_rate))
# SGD decay param
self._set_hyper("decay", self._initial_decay)
# Weight decay param
self._weight_decay = weight_decay != 0.0
self._set_hyper("weight_decay", weight_decay)
# Enable Momentum
self._momentum = False
if isinstance(momentum, tf.Tensor) or callable(momentum) or momentum > 0:
self._momentum = True
if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
raise ValueError("`momentum` must be between [0, 1].")
self._set_hyper("momentum", momentum)
self._set_hyper("momentum_start", momentum_start)
self._set_hyper("warmup_steps", tf.cast(warmup_steps, tf.int32))
# Enable Nesterov Momentum
self.nesterov = nesterov
# weights, biases, other
self._weight_keys = weight_keys
self._bias_keys = bias_keys
self._variables_set = False
self._wset = set()
self._bset = set()
self._oset = set()
logging.info("Pytorch SGD simulation: ")
logging.info("Weight Decay: %f", weight_decay)
def set_bias_lr(self, lr):
self._set_hyper("bias_learning_rate", lr)
def set_other_lr(self, lr):
self._set_hyper("other_learning_rate", lr)
def _search(self, var, keys):
"""Search all all keys for matches. Return True on match."""
if keys is not None:
# variable group is not ignored so search for the keys.
for r in keys:
if re.search(r, var.name) is not None:
return True
return False
def search_and_set_variable_groups(self, variables):
"""Search all variable for matches at each group."""
weights = []
biases = []
others = []
for var in variables:
if self._search(var, self._weight_keys):
# search for weights
weights.append(var)
elif self._search(var, self._bias_keys):
# search for biases
biases.append(var)
else:
# if all searches fail, add to other group
others.append(var)
self._set_variable_groups(weights, biases, others)
return weights, biases, others
def _set_variable_groups(self, weights, biases, others):
"""Sets the variables to be used in each group."""
if self._variables_set:
logging.warning("_set_variable_groups has been called again indicating"
"that the variable groups have already been set, they"
"will be updated.")
self._wset.update(set([_var_key(w) for w in weights]))
self._bset.update(set([_var_key(b) for b in biases]))
self._oset.update(set([_var_key(o) for o in others]))
self._variables_set = True
return
def _get_variable_group(self, var, coefficients):
if self._variables_set:
# check which preset group holds the variable.
if _var_key(var) in self._wset:
return True, False, False
elif _var_key(var) in self._bset:
return False, True, False
else:
# search the variables at run time.
if self._search(var, self._weight_keys):
return True, False, False
elif self._search(var, self._bias_keys):
return False, True, False
return False, False, True
def _create_slots(self, var_list):
"""Create a momentum variable for each variable."""
if self._momentum:
for var in var_list:
# check if trainable to support GPU EMA.
if var.trainable:
self.add_slot(var, "momentum")
def _get_momentum(self, iteration):
"""Get the momentum value."""
momentum = self._get_hyper("momentum")
momentum_start = self._get_hyper("momentum_start")
momentum_warm_up_steps = tf.cast(
self._get_hyper("warmup_steps"), iteration.dtype)
value = tf.cond(
(iteration - momentum_warm_up_steps) <= 0,
true_fn=lambda: (momentum_start + # pylint: disable=g-long-lambda
(tf.cast(iteration, momentum.dtype) *
(momentum - momentum_start) / tf.cast(
momentum_warm_up_steps, momentum.dtype))),
false_fn=lambda: momentum)
return value
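# A worked example of the linear ramp above (illustrative numbers, not new
# behaviour): with momentum_start=0.0, momentum=0.9 and warmup_steps=1000,
#   m(t) = momentum_start + t * (momentum - momentum_start) / warmup_steps
# for t <= warmup_steps, so step 0 uses 0.0, step 500 uses 0.45, and every step
# at or beyond 1000 uses the full 0.9.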
def _prepare_local(self, var_device, var_dtype, apply_state):
super(SGDTorch, self)._prepare_local(var_device, var_dtype, apply_state) # pytype: disable=attribute-error
weight_decay = self._get_hyper("weight_decay")
apply_state[(var_device,
var_dtype)]["weight_decay"] = tf.cast(weight_decay, var_dtype)
if self._momentum:
momentum = self._get_momentum(self.iterations)
momentum = tf.cast(momentum, var_dtype)
apply_state[(var_device,
var_dtype)]["momentum"] = tf.identity(momentum)
bias_lr = self._get_hyper("bias_learning_rate")
if isinstance(bias_lr, LearningRateSchedule):
bias_lr = bias_lr(self.iterations)
bias_lr = tf.cast(bias_lr, var_dtype)
apply_state[(var_device,
var_dtype)]["bias_lr_t"] = tf.identity(bias_lr)
other_lr = self._get_hyper("other_learning_rate")
if isinstance(other_lr, LearningRateSchedule):
other_lr = other_lr(self.iterations)
other_lr = tf.cast(other_lr, var_dtype)
apply_state[(var_device,
var_dtype)]["other_lr_t"] = tf.identity(other_lr)
return apply_state[(var_device, var_dtype)]
def _apply(self, grad, var, weight_decay, momentum, lr):
"""Uses Pytorch Optimizer with Weight decay SGDW."""
dparams = grad
groups = []
# do not update non-trainable weights
if not var.trainable:
return tf.group(*groups)
if self._weight_decay:
dparams += (weight_decay * var)
if self._momentum:
momentum_var = self.get_slot(var, "momentum")
momentum_update = momentum_var.assign(
momentum * momentum_var + dparams, use_locking=self._use_locking)
groups.append(momentum_update)
if self.nesterov:
dparams += (momentum * momentum_update)
else:
dparams = momentum_update
weight_update = var.assign_add(-lr * dparams, use_locking=self._use_locking)
groups.append(weight_update)
return tf.group(*groups)
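# In equations (matching the branches above, following the PyTorch convention):
#   g <- grad + weight_decay * var        (only when weight decay is enabled)
#   v <- momentum * v + g                 (momentum slot update)
#   var <- var - lr * (g + momentum * v)  (nesterov=True)
#   var <- var - lr * v                   (nesterov=False)
# With momentum disabled, the update collapses to plain SGD: var <- var - lr * g.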
def _run_sgd(self, grad, var, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
self._fallback_apply_state(var_device, var_dtype))
weights, bias, others = self._get_variable_group(var, coefficients)
weight_decay = tf.zeros_like(coefficients["weight_decay"])
lr = coefficients["lr_t"]
if weights:
weight_decay = coefficients["weight_decay"]
lr = coefficients["lr_t"]
elif bias:
weight_decay = tf.zeros_like(coefficients["weight_decay"])
lr = coefficients["bias_lr_t"]
elif others:
weight_decay = tf.zeros_like(coefficients["weight_decay"])
lr = coefficients["other_lr_t"]
momentum = coefficients["momentum"]
return self._apply(grad, var, weight_decay, momentum, lr)
def _resource_apply_dense(self, grad, var, apply_state=None):
return self._run_sgd(grad, var, apply_state=apply_state)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
# This method is only needed for momentum optimization.
holder = tf.tensor_scatter_nd_add(
tf.zeros_like(var), tf.expand_dims(indices, axis=-1), grad)
return self._run_sgd(holder, var, apply_state=apply_state)
def get_config(self):
config = super(SGDTorch, self).get_config()
config.update({
"learning_rate": self._serialize_hyperparameter("learning_rate"),
"decay": self._initial_decay,
"momentum": self._serialize_hyperparameter("momentum"),
"momentum_start": self._serialize_hyperparameter("momentum_start"),
"warmup_steps": self._serialize_hyperparameter("warmup_steps"),
"nesterov": self.nesterov,
})
return config
@property
def learning_rate(self):
return self._optimizer._get_hyper("learning_rate") # pylint: disable=protected-access
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for yolo task."""
import tensorflow as tf
class ListMetrics:
"""Private class used to cleanly place the matric values for each level."""
def __init__(self, metric_names, name="ListMetrics"):
self.name = name
self._metric_names = metric_names
self._metrics = self.build_metric()
return
def build_metric(self):
metric_names = self._metric_names
metrics = []
for name in metric_names:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
return metrics
def update_state(self, loss_metrics):
metrics = self._metrics
for m in metrics:
m.update_state(loss_metrics[m.name])
return
def result(self):
logs = dict()
metrics = self._metrics
for m in metrics:
logs.update({m.name: m.result()})
return logs
def reset_states(self):
metrics = self._metrics
for m in metrics:
m.reset_states()
return
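A hedged usage sketch of ListMetrics; the level name and values below are illustrative, but the update/result/reset cycle is the one YoloTask relies on in its train and validation steps:

```python
# Illustrative only.
metrics = ListMetrics(['loss', 'avg_iou', 'avg_obj'], name='5')
metrics.update_state({'loss': 1.2, 'avg_iou': 0.6, 'avg_obj': 0.8})
logs = metrics.result()   # {'loss': 1.2, 'avg_iou': 0.6, 'avg_obj': 0.8} as tensors
metrics.reset_states()
```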
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes used to train Yolo."""
import collections
from typing import Optional
from absl import logging
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.core import input_reader
from official.core import task_factory
from official.modeling import performance
from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.ops import box_ops
from official.vision.beta.projects.yolo import optimization
from official.vision.beta.projects.yolo.configs import yolo as exp_cfg
from official.vision.beta.projects.yolo.dataloaders import tf_example_decoder
from official.vision.beta.projects.yolo.dataloaders import yolo_input
from official.vision.beta.projects.yolo.modeling import factory
from official.vision.beta.projects.yolo.ops import mosaic
from official.vision.beta.projects.yolo.ops import preprocessing_ops
from official.vision.beta.projects.yolo.tasks import task_utils
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
@task_factory.register_task_cls(exp_cfg.YoloTask)
class YoloTask(base_task.Task):
"""A single-replica view of training procedure.
YOLO task provides artifacts for training/evaluation procedures, including
loading/iterating over Datasets, initializing the model, calculating the loss,
post-processing, and customized metrics with reduction.
"""
def __init__(self, params, logging_dir: Optional[str] = None):
super().__init__(params, logging_dir)
self.coco_metric = None
self._loss_fn = None
self._model = None
self._coco_91_to_80 = False
self._metrics = []
# globally set the random seed
preprocessing_ops.set_random_seeds(seed=params.seed)
return
def build_model(self):
"""Build an instance of Yolo."""
model_base_cfg = self.task_config.model
l2_weight_decay = self.task_config.weight_decay / 2.0
input_size = model_base_cfg.input_size.copy()
input_specs = tf.keras.layers.InputSpec(shape=[None] + input_size)
l2_regularizer = (
tf.keras.regularizers.l2(l2_weight_decay) if l2_weight_decay else None)
model, losses = factory.build_yolo(
input_specs, model_base_cfg, l2_regularizer)
# save for later usage within the task.
self._loss_fn = losses
self._model = model
return model
def _get_data_decoder(self, params):
"""Get a decoder object to decode the dataset."""
if params.tfds_name:
decoder = tfds_factory.get_detection_decoder(params.tfds_name)
else:
decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder':
self._coco_91_to_80 = decoder_cfg.coco91_to_80
decoder = tf_example_decoder.TfExampleDecoder(
coco91_to_80=decoder_cfg.coco91_to_80,
regenerate_source_id=decoder_cfg.regenerate_source_id)
elif params.decoder.type == 'label_map_decoder':
decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
label_map=decoder_cfg.label_map,
regenerate_source_id=decoder_cfg.regenerate_source_id)
else:
raise ValueError('Unknown decoder type: {}!'.format(
params.decoder.type))
return decoder
def build_inputs(self, params, input_context=None):
"""Build input dataset."""
model = self.task_config.model
# get the anchor boxes dict based on the model's min and max levels
backbone = model.backbone.get()
anchor_dict, level_limits = model.anchor_boxes.get(backbone.min_level,
backbone.max_level)
params.seed = self.task_config.seed
# set shared parameters between mosaic and yolo_input
base_config = dict(
letter_box=params.parser.letter_box,
aug_rand_translate=params.parser.aug_rand_translate,
aug_rand_angle=params.parser.aug_rand_angle,
aug_rand_perspective=params.parser.aug_rand_perspective,
area_thresh=params.parser.area_thresh,
random_flip=params.parser.random_flip,
seed=params.seed,
)
# get the decoder
decoder = self._get_data_decoder(params)
# init Mosaic
sample_fn = mosaic.Mosaic(
output_size=model.input_size,
mosaic_frequency=params.parser.mosaic.mosaic_frequency,
mixup_frequency=params.parser.mosaic.mixup_frequency,
jitter=params.parser.mosaic.jitter,
mosaic_center=params.parser.mosaic.mosaic_center,
mosaic_crop_mode=params.parser.mosaic.mosaic_crop_mode,
aug_scale_min=params.parser.mosaic.aug_scale_min,
aug_scale_max=params.parser.mosaic.aug_scale_max,
**base_config)
# init Parser
parser = yolo_input.Parser(
output_size=model.input_size,
anchors=anchor_dict,
use_tie_breaker=params.parser.use_tie_breaker,
jitter=params.parser.jitter,
aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max,
aug_rand_hue=params.parser.aug_rand_hue,
aug_rand_saturation=params.parser.aug_rand_saturation,
aug_rand_brightness=params.parser.aug_rand_brightness,
max_num_instances=params.parser.max_num_instances,
scale_xy=model.detection_generator.scale_xy.get(),
expanded_strides=model.detection_generator.path_scales.get(),
darknet=model.darknet_based_model,
best_match_only=params.parser.best_match_only,
anchor_t=params.parser.anchor_thresh,
random_pad=params.parser.random_pad,
level_limits=level_limits,
dtype=params.dtype,
**base_config)
# init the dataset reader
reader = input_reader.InputReader(
params,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn=decoder.decode,
sample_fn=sample_fn.mosaic_fn(is_training=params.is_training),
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_metrics(self, training=True):
"""Build detection metrics."""
metrics = []
backbone = self.task_config.model.backbone.get()
metric_names = collections.defaultdict(list)
for key in range(backbone.min_level, backbone.max_level + 1):
key = str(key)
metric_names[key].append('loss')
metric_names[key].append('avg_iou')
metric_names[key].append('avg_obj')
metric_names['net'].append('box')
metric_names['net'].append('class')
metric_names['net'].append('conf')
for _, key in enumerate(metric_names.keys()):
metrics.append(task_utils.ListMetrics(metric_names[key], name=key))
self._metrics = metrics
if not training:
annotation_file = self.task_config.annotation_file
if self._coco_91_to_80:
annotation_file = None
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=annotation_file,
include_mask=False,
need_rescale_bboxes=False,
per_category_metrics=self._task_config.per_category_metrics)
return metrics
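# For example, with a backbone spanning min_level=3 to max_level=5 (an
# illustrative assumption), metric_names ends up as:
#   {'3': ['loss', 'avg_iou', 'avg_obj'],
#    '4': ['loss', 'avg_iou', 'avg_obj'],
#    '5': ['loss', 'avg_iou', 'avg_obj'],
#    'net': ['box', 'class', 'conf']}
# so one ListMetrics instance is built per FPN level plus one for the whole net.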
def build_losses(self, outputs, labels, aux_losses=None):
"""Build YOLO losses."""
return self._loss_fn(labels, outputs)
def train_step(self, inputs, model, optimizer, metrics=None):
"""Train Step.
Forward step and backwards propagate the model.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
image, label = inputs
with tf.GradientTape(persistent=False) as tape:
# Compute a prediction
y_pred = model(image, training=True)
# Cast to float32 for gradient computation
y_pred = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), y_pred)
# Get the total loss
(scaled_loss, metric_loss,
loss_metrics) = self.build_losses(y_pred['raw_output'], label)
# Scale the loss for numerical stability
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
# Compute the gradient
train_vars = model.trainable_variables
gradients = tape.gradient(scaled_loss, train_vars)
# Get unscaled loss if we are using the loss scale optimizer on fp16
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
gradients = optimizer.get_unscaled_gradients(gradients)
# Apply gradients to the model
optimizer.apply_gradients(zip(gradients, train_vars))
logs = {self.loss: metric_loss}
# Compute all metrics
if metrics:
for m in metrics:
m.update_state(loss_metrics[m.name])
logs.update({m.name: m.result()})
return logs
def _reorg_boxes(self, boxes, num_detections, image):
"""Scale and Clean boxes prior to Evaluation."""
# Build a prediciton mask to take only the number of detections
mask = tf.sequence_mask(num_detections, maxlen=tf.shape(boxes)[1])
mask = tf.cast(tf.expand_dims(mask, axis=-1), boxes.dtype)
# Denormalize the boxes by the shape of the image
inshape = tf.cast(preprocessing_ops.get_image_shape(image), boxes.dtype)
boxes = box_ops.denormalize_boxes(boxes, inshape)
# Mask the boxes for usage
boxes *= mask
boxes += (mask - 1)
return boxes
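# Note on the masking arithmetic above: valid rows are multiplied by 1 and have
# (1 - 1) = 0 added, so they keep their denormalized coordinates; padded rows
# are multiplied by 0 and have (0 - 1) = -1 added, so they become
# [-1, -1, -1, -1], a padding convention for unused detections.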
def validation_step(self, inputs, model, metrics=None):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
image, label = inputs
# Step the model once
y_pred = model(image, training=False)
y_pred = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), y_pred)
(_, metric_loss, loss_metrics) = self.build_losses(y_pred['raw_output'],
label)
logs = {self.loss: metric_loss}
# Reorganize and rescale the boxes
boxes = self._reorg_boxes(y_pred['bbox'], y_pred['num_detections'], image)
label['groundtruths']['boxes'] = self._reorg_boxes(
label['groundtruths']['boxes'], label['groundtruths']['num_detections'],
image)
# Build the input for the COCO evaluation metric
coco_model_outputs = {
'detection_boxes': boxes,
'detection_scores': y_pred['confidence'],
'detection_classes': y_pred['classes'],
'num_detections': y_pred['num_detections'],
'source_id': label['groundtruths']['source_id'],
'image_info': label['groundtruths']['image_info']
}
# Compute all metrics
if metrics:
logs.update(
{self.coco_metric.name: (label['groundtruths'], coco_model_outputs)})
for m in metrics:
m.update_state(loss_metrics[m.name])
logs.update({m.name: m.result()})
return logs
def aggregate_logs(self, state=None, step_outputs=None):
"""Get Metric Results."""
if not state:
self.coco_metric.reset_states()
state = self.coco_metric
self.coco_metric.update_state(step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1])
return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
"""Reduce logs and remove unneeded items. Update with COCO results."""
res = self.coco_metric.result()
return res
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint:
logging.info('Training from Scratch.')
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'decoder' in self.task_config.init_checkpoint_modules:
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def create_optimizer(self,
optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations.
Args:
optimizer_config: the parameters of the Optimization settings.
runtime_config: the parameters of the runtime.
Returns:
A tf.optimizers.Optimizer object.
"""
opt_factory = optimization.YoloOptimizerFactory(optimizer_config)
# pylint: disable=protected-access
ema = opt_factory._use_ema
opt_factory._use_ema = False
opt_type = opt_factory._optimizer_type
if opt_type == 'sgd_torch':
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
optimizer.set_bias_lr(
opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
optimizer.search_and_set_variable_groups(self._model.trainable_variables)
else:
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
opt_factory._use_ema = ema
if ema:
logging.info('EMA is enabled.')
optimizer = opt_factory.add_ema(optimizer)
# pylint: enable=protected-access
if runtime_config and runtime_config.loss_scale:
use_float16 = runtime_config.mixed_precision_dtype == 'float16'
optimizer = performance.configure_optimizer(
optimizer,
use_graph_rewrite=False,
use_float16=use_float16,
loss_scale=runtime_config.loss_scale)
return optimizer
@@ -12,62 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Lint as: python3
 """TensorFlow Model Garden Vision training driver."""
 from absl import app
-from absl import flags
-import gin
-from official.common import distribute_utils
 from official.common import flags as tfm_flags
-from official.core import task_factory
-from official.core import train_lib
-from official.core import train_utils
-from official.modeling import performance
+from official.vision.beta import train
 from official.vision.beta.projects.yolo.common import registry_imports # pylint: disable=unused-import
-FLAGS = flags.FLAGS
-'''
-python3 -m official.vision.beta.projects.yolo.train --mode=train_and_eval --experiment=darknet_classification --model_dir=training_dir --config_file=official/vision/beta/projects/yolo/configs/experiments/darknet53_tfds.yaml
-'''
-def main(_):
-  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
-  print(FLAGS.experiment)
-  params = train_utils.parse_configuration(FLAGS)
-  model_dir = FLAGS.model_dir
-  if 'train' in FLAGS.mode:
-    # Pure eval modes do not output yaml files. Otherwise continuous eval job
-    # may race against the train job for writing the same file.
-    train_utils.serialize_config(params, model_dir)
-  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
-  # can have significant impact on model speeds by utilizing float16 in case of
-  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
-  # dtype is float16
-  if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
-  distribution_strategy = distribute_utils.get_distribution_strategy(
-      distribution_strategy=params.runtime.distribution_strategy,
-      all_reduce_alg=params.runtime.all_reduce_alg,
-      num_gpus=params.runtime.num_gpus,
-      tpu_address=params.runtime.tpu)
-  with distribution_strategy.scope():
-    task = task_factory.get_task(params.task, logging_dir=model_dir)
-  train_lib.run_experiment(
-      distribution_strategy=distribution_strategy,
-      task=task,
-      mode=FLAGS.mode,
-      params=params,
-      model_dir=model_dir)
-  train_utils.save_gin_config(FLAGS.mode, model_dir)
 if __name__ == '__main__':
   tfm_flags.define_flags()
-  app.run(main)
+  app.run(train.main)