Commit d4fb52e7 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

model builds

parent c631af40
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common factory functions yolo neural networks."""
from absl import logging
from official.vision.beta.modeling.backbones import factory as backbone_factory
from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.projects.yolo.configs import yolo
from official.vision.beta.projects.yolo.modeling import yolo_model
from official.vision.beta.projects.yolo.modeling.heads import yolo_head
from official.vision.beta.projects.yolo.modeling.layers import detection_generator
def build_yolo_detection_generator(model_config: yolo.Yolo, anchor_boxes):
  """Create the YOLO detection generator (post-processing) layer.

  Args:
    model_config: `yolo.Yolo` config holding detection-generator and loss
      parameters.
    anchor_boxes: anchor boxes used by the layer to decode predictions.

  Returns:
    A configured `detection_generator.YoloLayer` instance.
  """
  # Alias the two config sub-trees the layer draws its parameters from.
  gen_cfg = model_config.detection_generator
  loss_cfg = model_config.loss
  return detection_generator.YoloLayer(
      classes=model_config.num_classes,
      anchors=anchor_boxes,
      iou_thresh=gen_cfg.iou_thresh,
      nms_thresh=gen_cfg.nms_thresh,
      max_boxes=gen_cfg.max_boxes,
      pre_nms_points=gen_cfg.pre_nms_points,
      nms_type=gen_cfg.nms_type,
      box_type=gen_cfg.box_type.get(),
      path_scale=gen_cfg.path_scales.get(),
      scale_xy=gen_cfg.scale_xy.get(),
      label_smoothing=loss_cfg.label_smoothing,
      use_scaled_loss=loss_cfg.use_scaled_loss,
      update_on_repeat=loss_cfg.update_on_repeat,
      truth_thresh=loss_cfg.truth_thresh.get(),
      loss_type=loss_cfg.box_loss_type.get(),
      max_delta=loss_cfg.max_delta.get(),
      iou_normalizer=loss_cfg.iou_normalizer.get(),
      cls_normalizer=loss_cfg.cls_normalizer.get(),
      obj_normalizer=loss_cfg.obj_normalizer.get(),
      ignore_thresh=loss_cfg.ignore_thresh.get(),
      objectness_smooth=loss_cfg.objectness_smooth.get())
def build_yolo_head(input_specs, model_config: yolo.Yolo, l2_regularization):
  """Create the YOLO prediction head spanning the decoder's output levels.

  Args:
    input_specs: mapping of (string) level -> output spec; the head covers
      the min..max of these levels.
    model_config: `yolo.Yolo` model config.
    l2_regularization: kernel regularizer passed through to the head.

  Returns:
    A configured `yolo_head.YoloHead` instance.
  """
  # Derive the level range once from the decoder's output keys.
  levels = [int(level) for level in input_specs.keys()]
  return yolo_head.YoloHead(
      min_level=min(levels),
      max_level=max(levels),
      classes=model_config.num_classes,
      boxes_per_level=model_config.anchor_boxes.anchors_per_scale,
      norm_momentum=model_config.norm_activation.norm_momentum,
      norm_epsilon=model_config.norm_activation.norm_epsilon,
      kernel_regularizer=l2_regularization,
      smart_bias=model_config.head.smart_bias)
def build_yolo(input_specs, model_config, l2_regularization):
  """Assemble the full YOLO model: backbone, decoder, head and generator.

  Args:
    input_specs: `tf.keras.layers.InputSpec` for the model input.
    model_config: the full YOLO model config.
    l2_regularization: regularizer shared by all sub-networks.

  Returns:
    Tuple of (built `yolo_model.Yolo` model, losses obtained from the
    detection generator).
  """
  # The anchor boxes depend on the level range of the configured backbone.
  backbone_cfg = model_config.backbone.get()
  anchor_dict, anchor_free = model_config.anchor_boxes.get(
      backbone_cfg.min_level, backbone_cfg.max_level)

  backbone = backbone_factory.build_backbone(input_specs,
                                             model_config.backbone,
                                             model_config.norm_activation,
                                             l2_regularization)
  decoder = decoder_factory.build_decoder(backbone.output_specs, model_config,
                                          l2_regularization)
  head = build_yolo_head(decoder.output_specs, model_config, l2_regularization)
  # NOTE: local name avoids shadowing the `detection_generator` module import.
  generator = build_yolo_detection_generator(model_config, anchor_dict)

  model = yolo_model.Yolo(
      backbone=backbone,
      decoder=decoder,
      head=head,
      detection_generator=generator)
  model.build(input_specs.shape)
  model.summary(print_fn=logging.info)

  if anchor_free is not None:
    logging.info("Anchor Boxes: None -> Model is operating anchor-free.")
    logging.info(" --> anchors_per_scale set to 1. ")
  else:
    logging.info(f"Anchor Boxes: {anchor_dict}")

  losses = generator.get_losses()
  return model, losses
......@@ -15,6 +15,7 @@
"""Contains common building blocks for yolo neural networks."""
from typing import Callable, List, Tuple
import tensorflow as tf
import logging
from official.modeling import tf_utils
from official.vision.beta.ops import spatial_transform_ops
......@@ -141,6 +142,7 @@ class ConvBN(tf.keras.layers.Layer):
# activation params
self._activation = activation
self._leaky_alpha = leaky_alpha
self._fuse = False
super().__init__(**kwargs)
......@@ -164,6 +166,8 @@ class ConvBN(tf.keras.layers.Layer):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
axis=self._bn_axis)
else:
self.bn = None
if self._activation == 'leaky':
self._activation_fn = tf.keras.layers.LeakyReLU(alpha=self._leaky_alpha)
......@@ -174,11 +178,46 @@ class ConvBN(tf.keras.layers.Layer):
def call(self, x):
x = self.conv(x)
if self._use_bn:
if self._use_bn and not self._fuse:
x = self.bn(x)
x = self._activation_fn(x)
return x
  def fuse(self):
    """Fold the batch-norm statistics into the conv kernel for inference.

    After fusion the batch-norm layer is removed, the convolution gains a
    bias term, and the layer is frozen (no longer trainable). Safe to call
    repeatedly; subsequent calls are no-ops. Layers without batch norm or
    using separable convolutions are left untouched.
    """
    if self.bn is not None and not self._use_separable_conv:
      # Fuse convolution and batchnorm, gives me +2 to 3 FPS 2ms latency.
      # layers: https://tehnokv.com/posts/fusing-batchnorm-and-conv/
      if self._fuse:
        # Already fused; nothing to do.
        return

      self._fuse = True
      conv_weights = self.conv.get_weights()[0]
      gamma, beta, moving_mean, moving_variance = self.bn.get_weights()

      # Enable the bias term and rebuild so the conv layer allocates it.
      self.conv.use_bias = True
      infilters = conv_weights.shape[-2]
      self.conv.build([None, None, None, infilters])

      # Scale factor derived from BN: 1/sqrt(var + eps) applied per channel.
      base = tf.sqrt(self._norm_epsilon + moving_variance)
      # Flatten the kernel to (out_channels, -1) so the per-channel BN scale
      # can be applied as a diagonal matrix multiply, then restore layout.
      w_conv_base = tf.transpose(conv_weights, perm = (3, 2, 0, 1))
      w_conv = tf.reshape(w_conv_base, [conv_weights.shape[-1], -1])
      w_bn = tf.linalg.diag(gamma/base)
      w_conv = tf.reshape(tf.matmul(w_bn, w_conv), w_conv_base.get_shape())
      w_conv = tf.transpose(w_conv, perm = (2, 3, 1, 0))

      # Fold the BN shift into the new conv bias.
      b_bn = beta - gamma * moving_mean/base

      self.conv.set_weights([w_conv, b_bn])
      # Drop the BN layer and freeze the fused result.
      del self.bn
      self.trainable = False
      self.conv.trainable = False
      self.bn = None
      logging.info(f"fusing: {self.name} -> no longer trainable")
      return
def get_config(self):
# used to store/share parameters to reconstruct the model
layer_config = {
......
......@@ -15,62 +15,7 @@
"""Yolo models."""
import tensorflow as tf
# Static base YOLO model presets that do not require configuration, similar
# to a backbone model id; keeping them here greatly simplifies the model
# config. Structure: model version ("v3", "v4", ...) mapped to the model
# config type ("regular", "tiny", "csp", ...).
YOLO_MODELS = {
    'v4': {
        'regular': {
            'embed_spp': False,
            'use_fpn': True,
            'max_level_process_len': None,
            'path_process_len': 6,
        },
        'tiny': {
            'embed_spp': False,
            'use_fpn': False,
            'max_level_process_len': 2,
            'path_process_len': 1,
        },
        'csp': {
            'embed_spp': False,
            'use_fpn': True,
            'max_level_process_len': None,
            'csp_stack': 5,
            'fpn_depth': 5,
            'path_process_len': 6,
        },
        'csp_large': {
            'embed_spp': False,
            'use_fpn': True,
            'max_level_process_len': None,
            'csp_stack': 7,
            'fpn_depth': 7,
            'path_process_len': 8,
            'fpn_filter_scale': 2,
        },
    },
    'v3': {
        'regular': {
            'embed_spp': False,
            'use_fpn': False,
            'max_level_process_len': None,
            'path_process_len': 6,
        },
        'tiny': {
            'embed_spp': False,
            'use_fpn': False,
            'max_level_process_len': 2,
            'path_process_len': 1,
        },
        'spp': {
            'embed_spp': True,
            'use_fpn': False,
            'max_level_process_len': 2,
            'path_process_len': 1,
        },
    },
}
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
class Yolo(tf.keras.Model):
"""The YOLO model class."""
......@@ -82,21 +27,20 @@ class Yolo(tf.keras.Model):
detection_generator=None,
**kwargs):
"""Detection initialization function.
Args:
backbone: `tf.keras.Model` a backbone network.
decoder: `tf.keras.Model` a decoder network.
head: `RetinaNetHead`, the RetinaNet head.
detection_generator: the detection generator.
filter: the detection generator.
**kwargs: keyword arguments to be passed.
"""
super(Yolo, self).__init__(**kwargs)
self._config_dict = {
"backbone": backbone,
"decoder": decoder,
"head": head,
"filter": detection_generator
'backbone': backbone,
'decoder': decoder,
'head': head,
'detection_generator': detection_generator
}
# model components
......@@ -104,6 +48,7 @@ class Yolo(tf.keras.Model):
self._decoder = decoder
self._head = head
self._detection_generator = detection_generator
self._fused = False
return
def call(self, inputs, training=False):
......@@ -142,10 +87,10 @@ class Yolo(tf.keras.Model):
return cls(**config)
def get_weight_groups(self, train_vars):
"""Sort the list of trainable variables into groups for optimization.
"""Sort the list of trainable variables into groups for optimization.
Args:
train_vars: a list of tf.Variables that need to get sorted into their
train_vars: a list of tf.Variables that need to get sorted into their
respective groups.
Returns:
......@@ -166,3 +111,13 @@ class Yolo(tf.keras.Model):
else:
other.append(var)
return weights, bias, other
def fuse(self):
print("Fusing Conv Batch Norm Layers.")
if not self._fused:
self._fused = True
for layer in self.submodules:
if isinstance(layer, nn_blocks.ConvBN):
layer.fuse()
self.summary()
return
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization package definition."""
# pylint: disable=wildcard-import
from official.modeling.optimization.configs.learning_rate_config import *
from official.vision.beta.projects.yolo.optimization.configs.optimization_config import *
from official.vision.beta.projects.yolo.optimization.configs.optimizer_config import *
from official.vision.beta.projects.yolo.optimization.optimizer_factory import OptimizerFactory
from official.vision.beta.projects.yolo.optimization.optimizer_factory import OptimizerFactory as YoloOptimizerFactory
from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclasses for learning rate schedule config."""
from typing import List, Optional
import dataclasses
from official.modeling.hyperparams import base_config
# from official.modeling.optimization.configs import learning_rate_config
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclasses for optimization configs.
This file define the dataclass for optimization configs (OptimizationConfig).
It also has two helper functions get_optimizer_config, and get_lr_config from
an OptimizationConfig class.
"""
from typing import Optional
import dataclasses
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import oneof
from official.vision.beta.projects.yolo.optimization.configs import learning_rate_config as lr_cfg
from official.vision.beta.projects.yolo.optimization.configs import optimizer_config as opt_cfg
from official.modeling.optimization.configs import optimization_config as optimization_cfg
@dataclasses.dataclass
class OptimizerConfig(optimization_cfg.OptimizerConfig):
  """Configuration for optimizer.

  Extends the base Model Garden optimizer config with the project-local
  PyTorch-style SGD optimizer; all other optimizer sub-configs (sgd, adam,
  adamw, lamb, rmsprop, ...) are inherited from the base class.

  Attributes:
    type: 'str', type of optimizer to be used, one of the fields below or a
      field inherited from the base config.
    sgd_torch: PyTorch-style SGD optimizer config.
  """
  type: Optional[str] = None
  sgd_torch: opt_cfg.SGDTorchConfig = opt_cfg.SGDTorchConfig()
@dataclasses.dataclass
class OptimizationConfig(optimization_cfg.OptimizationConfig):
  """Configuration for optimizer and learning rate schedule.

  Attributes:
    type: optional type name. NOTE(review): not referenced by the visible
      code — confirm whether the base config's machinery consumes it.
    optimizer: optimizer oneof config (the project-local `OptimizerConfig`
      that adds `sgd_torch`).
    ema: optional exponential moving average optimizer config, if specified,
      ema optimizer will be used.
    learning_rate: learning rate oneof config.
    warmup: warmup oneof config.
  """
  type: Optional[str] = None
  optimizer: OptimizerConfig = OptimizerConfig()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclasses for optimizer configs."""
from typing import List, Optional
import dataclasses
from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs.optimizer_config import BaseOptimizerConfig
# NOTE(review): this class shadows the `BaseOptimizerConfig` imported above
# from official.modeling.optimization.configs.optimizer_config — confirm the
# local redefinition (rather than the import) is intended.
@dataclasses.dataclass
class BaseOptimizerConfig(base_config.Config):
  """Base optimizer config.

  Attributes:
    clipnorm: float >= 0 or None. If not None, Gradients will be clipped when
      their L2 norm exceeds this value.
    clipvalue: float >= 0 or None. If not None, Gradients will be clipped when
      their absolute value exceeds this value.
    global_clipnorm: float >= 0 or None. If not None, gradient of all weights is
      clipped so that their global norm is no higher than this value
  """
  clipnorm: Optional[float] = None
  clipvalue: Optional[float] = None
  global_clipnorm: Optional[float] = None
@dataclasses.dataclass
class SGDTorchConfig(BaseOptimizerConfig):
  """Configuration for SGD optimizer.

  The attributes for this class matches the arguments of tf.keras.optimizer.SGD.

  Attributes:
    name: name of the optimizer.
    decay: decay rate for SGD optimizer.
    nesterov: nesterov for SGD optimizer.
    momentum_start: momentum starting point for SGD optimizer.
    momentum: momentum for SGD optimizer.
    warmup_steps: number of steps over which momentum is warmed up from
      `momentum_start` to `momentum`.
    weight_decay: weight decay coefficient.
    sim_torch: if True, use the PyTorch-style update rule (see
      `SGDTorch._apply`) instead of the TensorFlow kernels.
  """
  name: str = "SGD"
  decay: float = 0.0
  nesterov: bool = False
  momentum_start: float = 0.0
  momentum: float = 0.9
  warmup_steps: int = 1000
  weight_decay: float = 0.0
  sim_torch: bool = False
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimizer factory class."""
import gin
from official.vision.beta.projects.yolo.optimization import sgd_torch
from official.modeling.optimization import ema_optimizer
from official.modeling.optimization import optimizer_factory
# Register the project-local SGDTorch optimizer with the shared Model Garden
# registry so it can be selected via config type 'sgd_torch'.
optimizer_factory.OPTIMIZERS_CLS.update({
    'sgd_torch': sgd_torch.SGDTorch,
})
# Re-export the (now extended) registries from this module.
OPTIMIZERS_CLS = optimizer_factory.OPTIMIZERS_CLS
LR_CLS = optimizer_factory.LR_CLS
WARMUP_CLS = optimizer_factory.WARMUP_CLS
class OptimizerFactory(optimizer_factory.OptimizerFactory):
  """Optimizer factory class.

  This class builds learning rate and optimizer based on an optimization
  config. To use this class, you need to do the following:
  (1) Define optimization config, this includes optimizer, and learning rate
      schedule.
  (2) Initialize the class using the optimization config.
  (3) Build learning rate.
  (4) Build optimizer.

  This is a typical example for using this class:

  params = {
      'optimizer': {
          'type': 'sgd',
          'sgd': {'momentum': 0.9}
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {'boundaries': [10000, 20000],
                       'values': [0.1, 0.01, 0.001]}
      },
      'warmup': {
          'type': 'linear',
          'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
      }
  }
  opt_config = OptimizationConfig(params)
  opt_factory = OptimizerFactory(opt_config)
  lr = opt_factory.build_learning_rate()
  optimizer = opt_factory.build_optimizer(lr)
  """

  def get_bias_lr_schedule(self, bias_lr):
    """Build an additional LR schedule whose warmup starts at `bias_lr`.

    The warmup config is temporarily patched and restored in a `finally`
    block, so the factory is left in a consistent state even if
    `build_learning_rate` raises (the original restored it only on success).

    Args:
      bias_lr: the warmup learning rate to use for this schedule.

    Returns:
      The learning rate schedule built with the patched warmup config.
    """
    original_warmup_lr = self._warmup_config.warmup_learning_rate
    self._warmup_config.warmup_learning_rate = bias_lr
    try:
      lr = self.build_learning_rate()
    finally:
      self._warmup_config.warmup_learning_rate = original_warmup_lr
    return lr

  @gin.configurable
  def add_ema(self, optimizer):
    """Add EMA to the optimizer independently of the build optimizer method.

    Args:
      optimizer: the optimizer to (optionally) wrap.

    Returns:
      The EMA-wrapped optimizer when EMA is enabled in the config, otherwise
      the input optimizer unchanged.
    """
    if self._use_ema:
      optimizer = ema_optimizer.ExponentialMovingAverage(
          optimizer, **self._ema_config.as_dict())
    return optimizer
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.python.training import gen_training_ops
import tensorflow as tf
import logging
__all__ = ['SGDTorch']
def _var_key(var):
  """Return the unique lookup key for a variable, used for slot lookups.

  In graph mode the key is the variable's shared name; in eager mode it is
  the variable's unique id. If a distribution strategy exists, the primary
  (distributed-container) variable is resolved first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """
  # pylint: disable=protected-access
  # Resolve the distributed container variable when present.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  return var._shared_name if var._in_graph_mode else var._unique_id
class SGDTorch(tf.keras.optimizers.Optimizer):
  """SGD optimizer that mimics PyTorch's `torch.optim.SGD` behavior.

  Fix: the original class docstring was copy/pasted from an EMA optimizer
  wrapper and described the wrong class. Differences from the stock
  `tf.keras.optimizers.SGD`:

  * Decoupled weight decay applied only to variables registered as
    "weights" via `set_params`; biases and "other" variables get no decay.
  * Independent learning rates for the weight, bias and "other" groups
    (`set_bias_lr` / `set_other_lr`).
  * Momentum is linearly warmed up from `momentum_start` to `momentum`
    over the first `warmup_steps` iterations.
  * When `sim_torch` is True the hand-written PyTorch-style momentum rule
    in `_apply` is used; otherwise the fused TensorFlow kernels.

  Example of usage for training:

  ```python
  opt = SGDTorch(learning_rate=0.01, momentum=0.9, weight_decay=5e-4)
  opt.set_params(weights, biases, others)
  ```
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               weight_decay=0.0,
               learning_rate=0.01,
               momentum=0.0,
               momentum_start=0.0,
               warmup_steps=1000,
               nesterov=False,
               sim_torch=False,
               name="SGD",
               **kwargs):
    """Initializes the optimizer.

    Args:
      weight_decay: float, decoupled weight decay applied to the "weights"
        variable group only.
      learning_rate: float or schedule, base learning rate.
      momentum: float, tensor or callable momentum coefficient in [0, 1].
      momentum_start: float, momentum value at the start of warmup.
      warmup_steps: int, number of iterations to warm momentum up over.
      nesterov: bool, whether to use Nesterov momentum.
      sim_torch: bool, if True use the PyTorch-style update rule.
      name: str, name of the optimizer.
      **kwargs: forwarded to the base `Optimizer`; `lr` overrides
        `learning_rate` for backward compatibility.
    """
    super(SGDTorch, self).__init__(name, **kwargs)

    # One learning-rate hyper parameter per variable group.
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("bias_learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("other_learning_rate", kwargs.get("lr", learning_rate))

    # SGD decay param.
    self._set_hyper("decay", self._initial_decay)

    # Weight decay param; the bool caches whether decay is active at all.
    self._weight_decay = weight_decay != 0.0
    self._set_hyper("weight_decay", weight_decay)

    # Enable momentum when a tensor, callable or positive scalar is given.
    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be between [0, 1].")
    self._set_hyper("momentum", momentum)
    self._set_hyper("momentum_start", momentum_start)
    self._set_hyper("warmup_steps", tf.cast(warmup_steps, tf.int32))

    # Enable Nesterov momentum.
    self.nesterov = nesterov

    # Simulate the PyTorch optimizer update rule when True.
    self.sim_torch = sim_torch

    # Key sets identifying the weights, biases and "other" variable groups.
    self._wset = set()
    self._bset = set()
    self._oset = set()
    # Lazy %-style args: formatting happens only if the record is emitted.
    logging.info("Pytorch SGD simulation: ")
    logging.info("Weight Decay: %s", weight_decay)

  def set_bias_lr(self, lr):
    """Sets the learning rate (or schedule) for the bias group."""
    self._set_hyper("bias_learning_rate", lr)

  def set_other_lr(self, lr):
    """Sets the learning rate (or schedule) for the "other" group."""
    self._set_hyper("other_learning_rate", lr)

  def set_params(self, weights, biases, others):
    """Registers which variables belong to each optimization group.

    Args:
      weights: variables that receive weight decay and the base LR.
      biases: variables updated with the bias LR, no weight decay.
      others: remaining variables, updated with the "other" LR, no decay.
    """
    self._wset = set(_var_key(w) for w in weights)
    self._bset = set(_var_key(b) for b in biases)
    self._oset = set(_var_key(o) for o in others)
    logging.info("Weights: %d Biases: %d Others: %d", len(weights),
                 len(biases), len(others))
    return

  def _create_slots(self, var_list):
    """Creates a momentum slot per variable when momentum is enabled."""
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")

  def _get_momentum(self, iteration):
    """Returns momentum linearly warmed up over `warmup_steps` iterations."""
    momentum = self._get_hyper("momentum")
    momentum_start = self._get_hyper("momentum_start")
    momentum_warm_up_steps = tf.cast(
        self._get_hyper("warmup_steps"), iteration.dtype)
    value = tf.cond(
        (iteration - momentum_warm_up_steps) <= 0,
        true_fn=lambda: (momentum_start +
                         (tf.cast(iteration, momentum.dtype) *
                          (momentum - momentum_start) / tf.cast(
                              momentum_warm_up_steps, momentum.dtype))),
        false_fn=lambda: momentum)
    return value

  def _prepare_local(self, var_device, var_dtype, apply_state):
    """Caches the per-(device, dtype) coefficients used by the update rules."""
    super(SGDTorch, self)._prepare_local(var_device, var_dtype, apply_state)
    weight_decay = self._get_hyper("weight_decay")
    apply_state[(var_device,
                 var_dtype)]["weight_decay"] = tf.cast(weight_decay, var_dtype)

    if self._momentum:
      momentum = self._get_momentum(self.iterations)
      momentum = tf.cast(momentum, var_dtype)
      apply_state[(var_device,
                   var_dtype)]["momentum"] = array_ops.identity(momentum)

    bias_lr = self._get_hyper("bias_learning_rate")
    if isinstance(bias_lr, LearningRateSchedule):
      bias_lr = bias_lr(self.iterations)
    bias_lr = tf.cast(bias_lr, var_dtype)
    apply_state[(var_device,
                 var_dtype)]["bias_lr_t"] = array_ops.identity(bias_lr)

    other_lr = self._get_hyper("other_learning_rate")
    if isinstance(other_lr, LearningRateSchedule):
      other_lr = other_lr(self.iterations)
    other_lr = tf.cast(other_lr, var_dtype)
    apply_state[(var_device,
                 var_dtype)]["other_lr_t"] = array_ops.identity(other_lr)

    return apply_state[(var_device, var_dtype)]

  def _apply_tf(self, grad, var, weight_decay, momentum, lr):
    """Updates `var` using TensorFlow's fused SGD/momentum kernels."""

    def decay_op(var, learning_rate, wd):
      # Shrink the variable by lr * wd * var before the gradient step.
      if self._weight_decay and wd > 0:
        return var.assign_sub(
            learning_rate * var * wd, use_locking=self._use_locking)
      return tf.no_op()

    decay = decay_op(var, lr, weight_decay)
    with tf.control_dependencies([decay]):
      if self._momentum:
        momentum_var = self.get_slot(var, "momentum")
        return gen_training_ops.ResourceApplyKerasMomentum(
            var=var.handle,
            accum=momentum_var.handle,
            lr=lr,
            grad=grad,
            momentum=momentum,
            use_locking=self._use_locking,
            use_nesterov=self.nesterov)
      else:
        return gen_training_ops.ResourceApplyGradientDescent(
            var=var.handle, alpha=lr, delta=grad, use_locking=self._use_locking)

  def _apply(self, grad, var, weight_decay, momentum, lr):
    """Updates `var` with the hand-written PyTorch-style momentum rule."""
    dparams = grad
    groups = []

    # Do not update non-trainable weights.
    if not var.trainable:
      return tf.group(*groups)

    if self._weight_decay:
      dparams += (weight_decay * var)

    if self._momentum:
      momentum_var = self.get_slot(var, "momentum")
      momentum_update = momentum_var.assign(
          momentum * momentum_var + dparams, use_locking=self._use_locking)
      groups.append(momentum_update)

      if self.nesterov:
        dparams += (momentum * momentum_update)
      else:
        dparams = momentum_update

    weight_update = var.assign_add(-lr * dparams, use_locking=self._use_locking)
    groups.append(weight_update)
    return tf.group(*groups)

  def _get_vartype(self, var, coefficients=None):
    """Returns (is_weight, is_bias, is_other) flags for `var`.

    `coefficients` is unused; kept (with a default) for call compatibility.
    """
    if _var_key(var) in self._wset:
      return True, False, False
    elif _var_key(var) in self._bset:
      return False, True, False
    return False, False, True

  def _run_sgd(self, grad, var, apply_state=None):
    """Dispatches the update of `var` to its group's LR/decay and rule."""
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
                    self._fallback_apply_state(var_device, var_dtype))

    weights, bias, others = self._get_vartype(var, coefficients)
    # Default: base LR, no decay; only the "weights" group gets decay.
    weight_decay = tf.zeros_like(coefficients["weight_decay"])
    lr = coefficients["lr_t"]
    if weights:
      weight_decay = coefficients["weight_decay"]
      lr = coefficients["lr_t"]
    elif bias:
      weight_decay = tf.zeros_like(coefficients["weight_decay"])
      lr = coefficients["bias_lr_t"]
    elif others:
      weight_decay = tf.zeros_like(coefficients["weight_decay"])
      lr = coefficients["other_lr_t"]
    momentum = coefficients["momentum"]

    if self.sim_torch:
      return self._apply(grad, var, weight_decay, momentum, lr)
    else:
      return self._apply_tf(grad, var, weight_decay, momentum, lr)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    return self._run_sgd(grad, var, apply_state=apply_state)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    # This method is only needed for momentum optimization: scatter the
    # sparse gradient into a dense tensor and run the dense path.
    holder = tf.tensor_scatter_nd_add(
        tf.zeros_like(var), tf.expand_dims(indices, axis=-1), grad)
    return self._run_sgd(holder, var, apply_state=apply_state)

  def get_config(self):
    config = super(SGDTorch, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._initial_decay,
        "momentum": self._serialize_hyperparameter("momentum"),
        "momentum_start": self._serialize_hyperparameter("momentum_start"),
        "warmup_steps": self._serialize_hyperparameter("warmup_steps"),
        "nesterov": self.nesterov,
        # Fix: also serialize weight_decay and sim_torch so that
        # from_config(get_config()) round-trips (they were dropped before).
        "weight_decay": self._serialize_hyperparameter("weight_decay"),
        "sim_torch": self.sim_torch,
    })
    return config

  @property
  def learning_rate(self):
    # Bug fix: the original read `self._optimizer._get_hyper(...)`, but this
    # class has no `_optimizer` attribute (copy/paste from an optimizer
    # wrapper), so the property always raised AttributeError.
    return self._get_hyper("learning_rate")
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes used to train Yolo."""
from absl import logging
import collections
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.core import config_definitions
from official.modeling import performance
from official.vision.beta.ops import box_ops
from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.dataloaders import tfds_detection_decoders
from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.projects.yolo import optimization
from official.vision.beta.projects.yolo.ops import mosaic
from official.vision.beta.projects.yolo.ops import preprocessing_ops
from official.vision.beta.projects.yolo.dataloaders import yolo_input
from official.vision.beta.projects.yolo.dataloaders import tf_example_decoder
from official.vision.beta.projects.yolo.configs import yolo as exp_cfg
import tensorflow as tf
from typing import Optional
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
@task_factory.register_task_cls(exp_cfg.YoloTask)
class YoloTask(base_task.Task):
"""A single-replica view of training procedure.
YOLO task provides artifacts for training/evaluation procedures, including
loading/iterating over Datasets, initializing the model, calculating the loss,
post-processing, and customized metrics with reduction.
"""
def __init__(self, params, logging_dir: str = None):
super().__init__(params, logging_dir)
self.coco_metric = None
self._loss_fn = None
self._model = None
self._metrics = []
# globally set the random seed
preprocessing_ops.set_random_seeds(seed=params.train_data.seed)
return
def build_model(self):
"""Build an instance of Yolo."""
from official.vision.beta.projects.yolo.modeling.factory import build_yolo
model_base_cfg = self.task_config.model
l2_weight_decay = self.task_config.weight_decay / 2.0
input_size = model_base_cfg.input_size.copy()
input_specs = tf.keras.layers.InputSpec(shape=[None] + input_size)
l2_regularizer = (
tf.keras.regularizers.l2(l2_weight_decay) if l2_weight_decay else None)
model, losses = build_yolo(input_specs, model_base_cfg,
l2_regularizer)
# save for later usage within the task.
self._loss_fn = losses
self._model = model
return model
def get_decoder(self, params):
"""Get a decoder object to decode the dataset."""
if params.tfds_name:
if params.tfds_name in tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP[
params.tfds_name]()
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else:
decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder':
decoder = tf_example_decoder.TfExampleDecoder(
coco91_to_80=decoder_cfg.coco91_to_80,
regenerate_source_id=decoder_cfg.regenerate_source_id)
elif params.decoder.type == 'label_map_decoder':
decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
label_map=decoder_cfg.label_map,
regenerate_source_id=decoder_cfg.regenerate_source_id)
else:
raise ValueError('Unknown decoder type: {}!'.format(
params.decoder.type))
return decoder
def build_inputs(self, params, input_context=None):
"""Build input dataset."""
model = self.task_config.model
# get anchor boxes dict based on models min and max level
backbone = model.backbone.get()
anchor_dict, level_limits = model.anchor_boxes.get(backbone.min_level,
backbone.max_level)
# set shared patamters between mosaic and yolo_input
base_config = dict(
letter_box=params.parser.letter_box,
aug_rand_translate=params.parser.aug_rand_translate,
aug_rand_angle=params.parser.aug_rand_angle,
aug_rand_perspective=params.parser.aug_rand_perspective,
area_thresh=params.parser.area_thresh,
random_flip=params.parser.random_flip,
seed=params.seed,
)
# get the decoder
decoder = self.get_decoder(params)
# init Mosaic
sample_fn = mosaic.Mosaic(
output_size=model.input_size,
mosaic_frequency=params.parser.mosaic.mosaic_frequency,
mixup_frequency=params.parser.mosaic.mixup_frequency,
jitter=params.parser.mosaic.jitter,
mosaic_center=params.parser.mosaic.mosaic_center,
mosaic_crop_mode=params.parser.mosaic.mosaic_crop_mode,
aug_scale_min=params.parser.mosaic.aug_scale_min,
aug_scale_max=params.parser.mosaic.aug_scale_max,
**base_config)
# init Parser
parser = yolo_input.Parser(
output_size=model.input_size,
anchors=anchor_dict,
use_tie_breaker=params.parser.use_tie_breaker,
jitter=params.parser.jitter,
aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max,
aug_rand_hue=params.parser.aug_rand_hue,
aug_rand_saturation=params.parser.aug_rand_saturation,
aug_rand_brightness=params.parser.aug_rand_brightness,
max_num_instances=params.parser.max_num_instances,
scale_xy=model.detection_generator.scale_xy.get(),
expanded_strides=model.detection_generator.path_scales.get(),
darknet=model.darknet_based_model,
best_match_only=params.parser.best_match_only,
anchor_t=params.parser.anchor_thresh,
random_pad=params.parser.random_pad,
level_limits=level_limits,
dtype=params.dtype,
**base_config)
# init the dataset reader
reader = input_reader.InputReader(
params,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn=decoder.decode,
sample_fn=sample_fn.mosaic_fn(is_training=params.is_training),
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_metrics(self, training=True):
  """Build detection metrics.

  Creates one `ListMetrics` group per backbone output level (tracking
  'loss', 'avg_iou', 'avg_obj') plus a 'net' group for the network-wide
  loss components ('box', 'class', 'conf').

  Args:
    training: `bool`, when False a COCO evaluator is also constructed and
      stored on `self.coco_metric` for validation-time box evaluation.

  Returns:
    A list of `ListMetrics` objects; also cached on `self._metrics`.
  """
  metrics = []
  backbone = self.task_config.model.backbone.get()

  # Per-level metric groups keyed by the (stringified) level index.
  metric_names = collections.defaultdict(list)
  for key in range(backbone.min_level, backbone.max_level + 1):
    key = str(key)
    metric_names[key].append('loss')
    metric_names[key].append('avg_iou')
    metric_names[key].append('avg_obj')

  # Network-wide loss components.
  metric_names['net'].append('box')
  metric_names['net'].append('class')
  metric_names['net'].append('conf')

  # Fixed: the original iterated `enumerate(...)` but never used the index.
  for key in metric_names:
    metrics.append(ListMetrics(metric_names[key], name=key))

  self._metrics = metrics
  if not training:
    self.coco_metric = coco_evaluator.COCOEvaluator(
        annotation_file=self.task_config.annotation_file,
        include_mask=False,
        need_rescale_bboxes=False,
        per_category_metrics=self._task_config.per_category_metrics)

  return metrics
def build_losses(self, outputs, labels, aux_losses=None):
"""Build YOLO losses."""
return self._loss_fn(labels, outputs)
def train_step(self, inputs, model, optimizer, metrics=None):
  """Train Step. Forward step and backwards propagate the model.

  Args:
    inputs: a dictionary of input tensors.
    model: the model, forward pass definition.
    optimizer: the optimizer for this training step.
    metrics: a nested structure of metrics objects.

  Returns:
    A dictionary of logs.
  """
  image, label = inputs

  with tf.GradientTape(persistent=False) as tape:
    # Compute a prediction
    y_pred = model(image, training=True)

    # Cast to float32 for gradient computation
    y_pred = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), y_pred)

    # Get the total loss
    (scaled_loss, metric_loss,
     loss_metrics) = self.build_losses(y_pred['raw_output'], label)

    # Scale the loss for numerical stability when training in fp16; the
    # scaling must happen inside the tape so gradients are of the scaled
    # loss.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      scaled_loss = optimizer.get_scaled_loss(scaled_loss)

  # Compute the gradient (outside the tape context; persistent=False)
  train_vars = model.trainable_variables
  gradients = tape.gradient(scaled_loss, train_vars)

  # Get unscaled loss if we are using the loss scale optimizer on fp16
  if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
    gradients = optimizer.get_unscaled_gradients(gradients)

  # Clip the gradients (skipped when gradient_clip_norm <= 0)
  if self.task_config.gradient_clip_norm > 0.0:
    gradients, _ = tf.clip_by_global_norm(gradients,
                                          self.task_config.gradient_clip_norm)

  # Apply gradients to the model
  optimizer.apply_gradients(zip(gradients, train_vars))
  logs = {self.loss: metric_loss}

  # Compute all metrics; each ListMetrics group pulls its values from
  # loss_metrics by the group's name.
  if metrics:
    for m in metrics:
      m.update_state(loss_metrics[m.name])
      logs.update({m.name: m.result()})
  return logs
def _reorg_boxes(self, boxes, num_detections, image):
  """Denormalize boxes to image scale and blank out padded detections.

  Valid detections keep their (denormalized) coordinates; padded slots
  in the detection axis are set to -1 so evaluation can ignore them.
  """
  # Boolean mask over the detection axis: True for real detections only.
  valid = tf.sequence_mask(num_detections, maxlen=tf.shape(boxes)[1])
  valid = tf.cast(tf.expand_dims(valid, axis=-1), boxes.dtype)

  # Scale the normalized boxes by the shape of the image.
  image_shape = tf.cast(preprocessing_ops.get_image_shape(image), boxes.dtype)
  denormed = box_ops.denormalize_boxes(boxes, image_shape)

  # Valid entries: box * 1 + (1 - 1) = box. Padded: box * 0 + (0 - 1) = -1.
  return denormed * valid + (valid - 1)
def validation_step(self, inputs, model, metrics=None):
  """Validation step.

  Args:
    inputs: a dictionary of input tensors.
    model: the keras.Model.
    metrics: a nested structure of metrics objects.

  Returns:
    A dictionary of logs.
  """
  image, label = inputs

  # Step the model once (inference mode; no gradients needed)
  y_pred = model(image, training=False)
  y_pred = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), y_pred)
  (_, metric_loss, loss_metrics) = self.build_losses(y_pred['raw_output'],
                                                     label)
  logs = {self.loss: metric_loss}

  # Reorganize and rescale the boxes: denormalize to pixel coordinates and
  # mark padded detection slots with -1 (see _reorg_boxes). Note this
  # mutates label['groundtruths'] in place.
  boxes = self._reorg_boxes(y_pred['bbox'], y_pred['num_detections'], image)
  label['groundtruths']["boxes"] = self._reorg_boxes(
      label['groundtruths']["boxes"], label['groundtruths']["num_detections"],
      image)

  # Build the input for the COCO evaluation metric
  coco_model_outputs = {
      'detection_boxes': boxes,
      'detection_scores': y_pred['confidence'],
      'detection_classes': y_pred['classes'],
      'num_detections': y_pred['num_detections'],
      'source_id': label['groundtruths']['source_id'],
      'image_info': label['groundtruths']['image_info']
  }

  # Compute all metrics. The COCO evaluator's log entry carries the raw
  # (groundtruth, prediction) pair; actual aggregation happens later in
  # aggregate_logs.
  if metrics:
    logs.update(
        {self.coco_metric.name: (label['groundtruths'], coco_model_outputs)})
    for m in metrics:
      m.update_state(loss_metrics[m.name])
      logs.update({m.name: m.result()})
  return logs
def aggregate_logs(self, state=None, step_outputs=None):
  """Accumulate one validation step's outputs into the COCO evaluator.

  On the first call (`state` is falsy) the evaluator is reset and then
  used as the running state object for subsequent calls.
  """
  if not state:
    self.coco_metric.reset_states()
    state = self.coco_metric
  groundtruths, predictions = step_outputs[self.coco_metric.name]
  self.coco_metric.update_state(groundtruths, predictions)
  return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
  """Finalize evaluation: return the COCO evaluator's computed metrics.

  Args:
    aggregated_logs: unused; state already lives in `self.coco_metric`.
    global_step: unused; present for task-interface compatibility.

  Returns:
    The COCO metric results dict.
  """
  final_results = self.coco_metric.result()
  return final_results
def initialize(self, model: tf.keras.Model):
  """Loading pretrained checkpoint.

  Restores `model` (fully or partially) from
  `self.task_config.init_checkpoint`; a no-op when no checkpoint is set.

  Args:
    model: the `tf.keras.Model` to restore into.

  Raises:
    ValueError: if `init_checkpoint_modules` is not one of
      'all', 'backbone', or 'decoder'.
  """
  if not self.task_config.init_checkpoint:
    logging.info("Training from Scratch.")
    return

  ckpt_dir_or_file = self.task_config.init_checkpoint
  if tf.io.gfile.isdir(ckpt_dir_or_file):
    ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

  # Restoring checkpoint.
  if self.task_config.init_checkpoint_modules == 'all':
    ckpt = tf.train.Checkpoint(model=model)
    status = ckpt.restore(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
  elif self.task_config.init_checkpoint_modules == 'backbone':
    ckpt = tf.train.Checkpoint(backbone=model.backbone)
    status = ckpt.restore(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
  elif self.task_config.init_checkpoint_modules == 'decoder':
    ckpt = tf.train.Checkpoint(backbone=model.backbone, decoder=model.decoder)
    status = ckpt.restore(ckpt_dir_or_file)
    status.expect_partial()
  else:
    # Bug fix: the original used `assert "<message>"`, which always passes
    # (a non-empty string literal is truthy), so bad configs were silently
    # ignored. Fail loudly instead, and mention 'decoder' which is also
    # supported above.
    raise ValueError(
        "Only 'all', 'backbone', or 'decoder' can be used to initialize "
        'the model.')

  logging.info('Finished loading pretrained checkpoint from %s',
               ckpt_dir_or_file)
def _wrap_optimizer(self, optimizer, runtime_config):
"""Wraps the optimizer object with the loss scale optimizer."""
if runtime_config and runtime_config.loss_scale:
use_float16 = runtime_config.mixed_precision_dtype == "float16"
optimizer = performance.configure_optimizer(
optimizer,
use_graph_rewrite=False,
use_float16=use_float16,
loss_scale=runtime_config.loss_scale)
return optimizer
def create_optimizer(self,
                     optimizer_config: OptimizationConfig,
                     runtime_config: Optional[RuntimeConfig] = None):
  """Creates an TF optimizer from configurations.

  EMA is temporarily disabled on the factory so the base optimizer is
  constructed first; EMA is layered on afterwards, followed by the
  loss-scale wrapper for mixed precision.

  Args:
    optimizer_config: the parameters of the Optimization settings.
    runtime_config: the parameters of the runtime.

  Returns:
    A tf.optimizers.Optimizer object.
  """
  opt_factory = optimization.YoloOptimizerFactory(optimizer_config)
  # NOTE(review): reaches into factory privates (_use_ema, _optimizer_type)
  # to defer EMA until after the base optimizer is built.
  ema = opt_factory._use_ema
  opt_factory._use_ema = False

  opt_type = opt_factory._optimizer_type
  optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
  if opt_type == 'sgd_torch':
    # Torch-style SGD additionally needs an explicit bias learning rate
    # and weight/bias/other parameter groups.
    optimizer.set_bias_lr(
        opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
    weights, biases, others = self._model.get_weight_groups(
        self._model.trainable_variables)
    optimizer.set_params(weights, biases, others)

  opt_factory._use_ema = ema
  if ema:
    logging.info("EMA is enabled.")
    optimizer = opt_factory.add_ema(optimizer)

  return self._wrap_optimizer(optimizer, runtime_config)
class ListMetrics:
  """Groups several named Mean metrics and updates them as one unit.

  Each wrapped metric is fed from a dict of values keyed by metric name.
  """

  def __init__(self, metric_names, name="ListMetrics", **kwargs):
    # Group name used as the log key for this collection of metrics.
    self.name = name
    self._metric_names = metric_names
    self._metrics = self.build_metric()

  def build_metric(self):
    """Create one float32 Mean metric per configured name."""
    return [
        tf.keras.metrics.Mean(metric_name, dtype=tf.float32)
        for metric_name in self._metric_names
    ]

  def update_state(self, loss_metrics):
    """Feed each metric its value from `loss_metrics`, keyed by name."""
    for metric in self._metrics:
      metric.update_state(loss_metrics[metric.name])

  def result(self):
    """Return a dict mapping each metric name to its current result."""
    return {metric.name: metric.result() for metric in self._metrics}

  def reset_states(self):
    """Reset every wrapped metric's accumulated state."""
    for metric in self._metrics:
      metric.reset_states()
\ No newline at end of file
......@@ -31,6 +31,9 @@ FLAGS = flags.FLAGS
'''
python3 -m official.vision.beta.projects.yolo.train --mode=train_and_eval --experiment=darknet_classification --model_dir=training_dir --config_file=official/vision/beta/projects/yolo/configs/experiments/darknet53_tfds.yaml
python3.8 -m official.vision.beta.projects.yolo.train --experiment=yolo_darknet --mode train_and_eval --config_file yolo/configs/experiments/yolov4/inference/512-swin.yaml --model_dir ../checkpoints/test-swin
'''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment