Commit d5747aac authored by vishnubanna

ready for review

parent 31c3ab9e
@@ -77,11 +77,11 @@ class layer_factory(object):
   """
   def __init__(self):
     self._layer_dict = {
-        "DarkConv": (nn_blocks.DarkConv, self.darkconv_config_todict),
+        "ConvBN": (nn_blocks.ConvBN, self.ConvBN_config_todict),
         "MaxPool": (tf.keras.layers.MaxPool2D, self.maxpool_config_todict)
     }
 
-  def darkconv_config_todict(self, config, kwargs):
+  def ConvBN_config_todict(self, config, kwargs):
    dictvals = {
        "filters": config.filters,
        "kernel_size": config.kernel_size,
@@ -124,7 +124,7 @@ CSPDARKNET53 = {
     "splits": {"backbone_split": 106,
                "neck_split": 138},
     "backbone": [
-        ["DarkConv", None, 1, False, 32, None, 3, 1, "same", "mish", -1, 0, False],
+        ["ConvBN", None, 1, False, 32, None, 3, 1, "same", "mish", -1, 0, False],
         ["DarkRes", "csp", 1, True, 64, None, None, None, None, "mish", -1, 1, False],
         ["DarkRes", "csp", 2, False, 128, None, None, None, None, "mish", -1, 2, False],
         ["DarkRes", "csp", 8, False, 256, None, None, None, None, "mish", -1, 3, True],
@@ -137,7 +137,7 @@ DARKNET53 = {
     "list_names": LISTNAMES,
     "splits": {"backbone_split": 76},
     "backbone": [
-        ["DarkConv", None, 1, False, 32, None, 3, 1, "same", "leaky", -1, 0, False],
+        ["ConvBN", None, 1, False, 32, None, 3, 1, "same", "leaky", -1, 0, False],
         ["DarkRes", "residual", 1, True, 64, None, None, None, None, "leaky", -1, 1, False],
         ["DarkRes", "residual", 2, False, 128, None, None, None, None, "leaky", -1, 2, False],
         ["DarkRes", "residual", 8, False, 256, None, None, None, None, "leaky", -1, 3, True],
@@ -150,12 +150,12 @@ CSPDARKNETTINY = {
     "list_names": LISTNAMES,
     "splits": {"backbone_split": 28},
     "backbone": [
-        ["DarkConv", None, 1, False, 32, None, 3, 2, "same", "leaky", -1, 0, False],
-        ["DarkConv", None, 1, False, 64, None, 3, 2, "same", "leaky", -1, 1, False],
+        ["ConvBN", None, 1, False, 32, None, 3, 2, "same", "leaky", -1, 0, False],
+        ["ConvBN", None, 1, False, 64, None, 3, 2, "same", "leaky", -1, 1, False],
         ["CSPTiny", "csp_tiny", 1, False, 64, None, 3, 2, "same", "leaky", -1, 2, False],
         ["CSPTiny", "csp_tiny", 1, False, 128, None, 3, 2, "same", "leaky", -1, 3, False],
         ["CSPTiny", "csp_tiny", 1, False, 256, None, 3, 2, "same", "leaky", -1, 4, True],
-        ["DarkConv", None, 1, False, 512, None, 3, 1, "same", "leaky", -1, 5, True],
+        ["ConvBN", None, 1, False, 512, None, 3, 1, "same", "leaky", -1, 5, True],
     ]
 }
@@ -163,7 +163,7 @@ DARKNETTINY = {
     "list_names": LISTNAMES,
     "splits": {"backbone_split": 14},
     "backbone": [
-        ["DarkConv", None, 1, False, 16, None, 3, 1, "same", "leaky", -1, 0, False],
+        ["ConvBN", None, 1, False, 16, None, 3, 1, "same", "leaky", -1, 0, False],
         ["DarkTiny", "tiny", 1, True, 32, None, 3, 2, "same", "leaky", -1, 1, False],
         ["DarkTiny", "tiny", 1, True, 64, None, 3, 2, "same", "leaky", -1, 2, False],
         ["DarkTiny", "tiny", 1, False, 128, None, 3, 2, "same", "leaky", -1, 3, False],
@@ -292,27 +292,28 @@ class Darknet(ks.Model):
   def _csp_stack(self, inputs, config, name):
     if config.bottleneck:
-      csp_filter_reduce = 1
-      residual_filter_reduce = 2
+      csp_filter_scale = 1
+      residual_filter_scale = 2
       scale_filters = 1
     else:
-      csp_filter_reduce = 2
-      residual_filter_reduce = 1
+      csp_filter_scale = 2
+      residual_filter_scale = 1
       scale_filters = 2
     self._default_dict["activation"] = self._get_activation(config.activation)
     self._default_dict["name"] = f"{name}_csp_down"
-    x, x_route = nn_blocks.CSPDownSample(filters=config.filters,
-                                         filter_reduce=csp_filter_reduce,
-                                         **self._default_dict)(inputs)
+    x, x_route = nn_blocks.CSPRoute(filters=config.filters,
+                                    filter_scale=csp_filter_scale,
+                                    downsample=True,
+                                    **self._default_dict)(inputs)
     for i in range(config.repetitions):
       self._default_dict["name"] = f"{name}_{i}"
       x = nn_blocks.DarkResidual(filters=config.filters // scale_filters,
-                                 filter_scale=residual_filter_reduce,
+                                 filter_scale=residual_filter_scale,
                                  **self._default_dict)(x)
     self._default_dict["name"] = f"{name}_csp_connect"
     output = nn_blocks.CSPConnect(filters=config.filters,
-                                  filter_reduce=csp_filter_reduce,
+                                  filter_scale=csp_filter_scale,
                                   **self._default_dict)([x, x_route])
     self._default_dict["activation"] = self._activation
     self._default_dict["name"] = None
@@ -335,7 +336,7 @@ class Darknet(ks.Model):
                                    name=f"{name}_tiny/pool")(inputs)
     self._default_dict["activation"] = self._get_activation(config.activation)
     self._default_dict["name"] = f"{name}_tiny/conv"
-    x = nn_blocks.DarkConv(filters=config.filters,
-                           kernel_size=(3, 3),
-                           strides=(1, 1),
-                           padding='same',
+    x = nn_blocks.ConvBN(filters=config.filters,
+                         kernel_size=(3, 3),
+                         strides=(1, 1),
+                         padding='same',
...
"""Contains common building blocks for yolo neural networks.""" """Contains common building blocks for yolo neural networks."""
from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Text
import tensorflow as tf import tensorflow as tf
import tensorflow.keras as ks
import tensorflow.keras.backend as K
from official.modeling import tf_utils from official.modeling import tf_utils
@ks.utils.register_keras_serializable(package='yolo') @tf.keras.utils.register_keras_serializable(package='yolo')
class Identity(ks.layers.Layer): class Identity(tf.keras.layers.Layer):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -15,8 +13,8 @@ class Identity(ks.layers.Layer):
     return input
 
-@ks.utils.register_keras_serializable(package='yolo')
-class DarkConv(ks.layers.Layer):
+@tf.keras.utils.register_keras_serializable(package='yolo')
+class ConvBN(tf.keras.layers.Layer):
   '''
   Modified convolution layer to match that of the DarkNet library. The layer is a standard combination of Conv, BatchNorm, and Activation;
   however, the use of bias in the conv is determined by the use of batch normalization. The layer also allows for feature grouping
@@ -59,9 +57,9 @@ class DarkConv(ks.layers.Layer):
                padding='same',
                dilation_rate=(1, 1),
                use_bias=True,
-               groups = 1,
-               group_id = 0,
-               grouping_only = False,
+               groups=1,
+               group_id=0,
+               grouping_only=False,
                kernel_initializer='glorot_uniform',
                bias_initializer='zeros',
                bias_regularizer=None,
@@ -74,7 +72,6 @@ class DarkConv(ks.layers.Layer):
                leaky_alpha=0.1,
                **kwargs):
     # convolution params
     self._filters = filters
     self._kernel_size = kernel_size
@@ -109,19 +106,19 @@ class DarkConv(ks.layers.Layer):
     self._activation = activation
     self._leaky_alpha = leaky_alpha
 
-    super(DarkConv, self).__init__(**kwargs)
+    super(ConvBN, self).__init__(**kwargs)
 
   def build(self, input_shape):
     if not self._grouping_only:
       kernel_size = self._kernel_size if type(
           self._kernel_size) == int else self._kernel_size[0]
       if self._padding == "same" and kernel_size != 1:
-        self._zeropad = ks.layers.ZeroPadding2D(
+        self._zeropad = tf.keras.layers.ZeroPadding2D(
             ((1, 1), (1, 1)))  # symmetric padding
       else:
         self._zeropad = Identity()
-      self.conv = ks.layers.Conv2D(
+      self.conv = tf.keras.layers.Conv2D(
           filters=self._filters,
           kernel_size=self._kernel_size,
           strides=self._strides,
@@ -140,24 +137,26 @@ class DarkConv(ks.layers.Layer):
             epsilon=self._norm_epsilon,
             axis=self._bn_axis)
       else:
-        self.bn = ks.layers.BatchNormalization(momentum=self._norm_moment,
-                                               epsilon=self._norm_epsilon,
-                                               axis=self._bn_axis)
+        self.bn = tf.keras.layers.BatchNormalization(
+            momentum=self._norm_moment,
+            epsilon=self._norm_epsilon,
+            axis=self._bn_axis)
     else:
       self.bn = Identity()
 
     if self._activation == 'leaky':
-      alpha = {"alpha": self._leaky_alpha}
-      self._activation_fn = partial(tf.nn.leaky_relu, **alpha)
+      self._activation_fn = tf.keras.layers.LeakyReLU(
+          alpha=self._leaky_alpha)
     elif self._activation == "mish":
-      self._activation_fn = lambda x: x * tf.math.tanh(tf.math.softplus(x))
+      self._activation_fn = lambda x: x * tf.math.tanh(
+          tf.math.softplus(x))
     else:
       self._activation_fn = tf_utils.get_activation(self._activation)
 
   def call(self, x):
     if self._groups != 1:
       x = tf.split(x, self._groups, axis=-1)
-      x = x[self._group_id]  # grouping
+      x = x[self._group_id]
 
     if not self._grouping_only:
       x = self._zeropad(x)
       x = self.conv(x)
@@ -188,15 +187,15 @@ class DarkConv(ks.layers.Layer):
         "activation": self._activation,
         "leaky_alpha": self._leaky_alpha
     }
-    layer_config.update(super(DarkConv, self).get_config())
+    layer_config.update(super(ConvBN, self).get_config())
     return layer_config
 
   def __repr__(self):
     return repr(self.get_config())
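A quick usage sketch for the renamed layer (import path taken from this commit's test file; the arguments mirror the CSPDarknet53 stem row in the config table above):

import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks

# Conv + BatchNorm + mish stem, as in the CSPDARKNET53 config above.
x = tf.keras.Input(shape=(224, 224, 3))
y = nn_blocks.ConvBN(filters=32, kernel_size=(3, 3), strides=(1, 1),
                     padding="same", activation="mish")(x)
model = tf.keras.Model(x, y)  # expected output shape: (None, 224, 224, 32)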
-@ks.utils.register_keras_serializable(package='yolo')
-class DarkResidual(ks.layers.Layer):
+@tf.keras.utils.register_keras_serializable(package='yolo')
+class DarkResidual(tf.keras.layers.Layer):
   '''
   DarkNet block with a residual connection for the YOLO v3 backbone
@@ -242,7 +241,7 @@ class DarkResidual(ks.layers.Layer):
     # downsample
     self._downsample = downsample
 
-    # darkconv params
+    # ConvBN params
     self._filters = filters
     self._filter_scale = filter_scale
     self._use_bias = use_bias
@@ -265,20 +264,21 @@ class DarkResidual(ks.layers.Layer):
     super().__init__(**kwargs)
 
   def build(self, input_shape):
-    _dark_conv_args = {"use_bias" : self._use_bias,
-                       "kernel_initializer" : self._kernel_initializer,
-                       "bias_initializer" : self._bias_initializer,
-                       "bias_regularizer" : self._bias_regularizer,
-                       "use_bn" : self._use_bn,
-                       "use_sync_bn" : self._use_sync_bn,
-                       "norm_momentum" : self._norm_moment,
-                       "norm_epsilon" : self._norm_epsilon,
-                       "activation" : self._conv_activation,
-                       "kernel_regularizer" : self._kernel_regularizer,
-                       "leaky_alpha" : self._leaky_alpha
-                       }
+    _dark_conv_args = {
+        "use_bias": self._use_bias,
+        "kernel_initializer": self._kernel_initializer,
+        "bias_initializer": self._bias_initializer,
+        "bias_regularizer": self._bias_regularizer,
+        "use_bn": self._use_bn,
+        "use_sync_bn": self._use_sync_bn,
+        "norm_momentum": self._norm_moment,
+        "norm_epsilon": self._norm_epsilon,
+        "activation": self._conv_activation,
+        "kernel_regularizer": self._kernel_regularizer,
+        "leaky_alpha": self._leaky_alpha
+    }
     if self._downsample:
-      self._dconv = DarkConv(filters=self._filters,
-                             kernel_size=(3, 3),
-                             strides=(2, 2),
-                             padding='same',
+      self._dconv = ConvBN(filters=self._filters,
+                           kernel_size=(3, 3),
+                           strides=(2, 2),
+                           padding='same',
@@ -286,25 +286,25 @@ class DarkResidual(ks.layers.Layer):
     else:
       self._dconv = Identity()
 
-    self._conv1 = DarkConv(filters=self._filters // self._filter_scale,
-                           kernel_size=(1, 1),
-                           strides=(1, 1),
-                           padding='same',
-                           **_dark_conv_args)
+    self._conv1 = ConvBN(filters=self._filters // self._filter_scale,
+                         kernel_size=(1, 1),
+                         strides=(1, 1),
+                         padding='same',
+                         **_dark_conv_args)
-    self._conv2 = DarkConv(filters=self._filters,
-                           kernel_size=(3, 3),
-                           strides=(1, 1),
-                           padding='same',
-                           **_dark_conv_args)
+    self._conv2 = ConvBN(filters=self._filters,
+                         kernel_size=(3, 3),
+                         strides=(1, 1),
+                         padding='same',
+                         **_dark_conv_args)
 
-    self._shortcut = ks.layers.Add()
-    # self._activation_fn = ks.layers.Activation(activation=self._sc_activation)
+    self._shortcut = tf.keras.layers.Add()
 
     if self._sc_activation == 'leaky':
-      alpha = {"alpha": self._leaky_alpha}
-      self._activation_fn = partial(tf.nn.leaky_relu, **alpha)
+      self._activation_fn = tf.keras.layers.LeakyReLU(
+          alpha=self._leaky_alpha)
     elif self._sc_activation == "mish":
-      self._activation_fn = lambda x: x * tf.math.tanh(tf.math.softplus(x))
+      self._activation_fn = lambda x: x * tf.math.tanh(
+          tf.math.softplus(x))
     else:
       self._activation_fn = tf_utils.get_activation(self._sc_activation)
     super().build(input_shape)
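A hedged usage sketch for DarkResidual (the activation and downsample kwargs are assumed from how _csp_stack forwards its _default_dict; shapes are expected values, not verified output):

import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks

x = tf.keras.Input(shape=(56, 56, 64))
# downsample=True adds the 3x3/stride-2 entry conv before the 1x1 -> 3x3
# residual body, halving spatial size while raising the depth to `filters`.
y = nn_blocks.DarkResidual(filters=128, downsample=True,
                           activation="leaky")(x)
# expected y shape: (None, 28, 28, 128)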
@@ -337,11 +337,11 @@ class DarkResidual(ks.layers.Layer):
     return layer_config
 
-@ks.utils.register_keras_serializable(package='yolo')
-class CSPTiny(ks.layers.Layer):
+@tf.keras.utils.register_keras_serializable(package='yolo')
+class CSPTiny(tf.keras.layers.Layer):
   """
   A small convolution block proposed in CSPNet. The layer uses shortcuts, routing (concatenation), and feature grouping
   to improve gradient variability and allow for highly efficient, low-power residual learning in small networks.
 
   Cross Stage Partial networks (CSPNets) were proposed in:
   [1] Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu, Ping-Yang Chen, Jun-Wei Hsieh
@@ -369,8 +369,7 @@ class CSPTiny(ks.layers.Layer):
       so the dimensions are forced to match
     **kwargs: Keyword Arguments
   """
-  def __init__(
-      self,
+  def __init__(self,
               filters=1,
               use_bias=True,
               kernel_initializer='glorot_uniform',
@@ -388,7 +387,7 @@ class CSPTiny(ks.layers.Layer):
                leaky_alpha=0.1,
                **kwargs):
 
-    # darkconv params
+    # ConvBN params
     self._filters = filters
     self._use_bias = use_bias
     self._kernel_initializer = kernel_initializer
@@ -412,31 +411,32 @@ class CSPTiny(ks.layers.Layer):
     super().__init__(**kwargs)
 
   def build(self, input_shape):
-    _dark_conv_args = {"use_bias" : self._use_bias,
-                       "kernel_initializer" : self._kernel_initializer,
-                       "bias_initializer" : self._bias_initializer,
-                       "bias_regularizer" : self._bias_regularizer,
-                       "use_bn" : self._use_bn,
-                       "use_sync_bn" : self._use_sync_bn,
-                       "norm_momentum" : self._norm_moment,
-                       "norm_epsilon" : self._norm_epsilon,
-                       "activation" : self._conv_activation,
-                       "kernel_regularizer" : self._kernel_regularizer,
-                       "leaky_alpha" : self._leaky_alpha
-                       }
+    _dark_conv_args = {
+        "use_bias": self._use_bias,
+        "kernel_initializer": self._kernel_initializer,
+        "bias_initializer": self._bias_initializer,
+        "bias_regularizer": self._bias_regularizer,
+        "use_bn": self._use_bn,
+        "use_sync_bn": self._use_sync_bn,
+        "norm_momentum": self._norm_moment,
+        "norm_epsilon": self._norm_epsilon,
+        "activation": self._conv_activation,
+        "kernel_regularizer": self._kernel_regularizer,
+        "leaky_alpha": self._leaky_alpha
+    }
-    self._convlayer1 = DarkConv(filters=self._filters,
-                                kernel_size=(3, 3),
-                                strides=(1, 1),
-                                padding='same',
-                                **_dark_conv_args)
+    self._convlayer1 = ConvBN(filters=self._filters,
+                              kernel_size=(3, 3),
+                              strides=(1, 1),
+                              padding='same',
+                              **_dark_conv_args)
-    self._convlayer2 = DarkConv(filters=self._filters // 2,
-                                kernel_size=(3, 3),
-                                strides=(1, 1),
-                                padding='same',
-                                use_bias=self._use_bias,
-                                groups = self._groups,
-                                group_id = self._group_id,
-                                kernel_initializer=self._kernel_initializer,
-                                bias_initializer=self._bias_initializer,
-                                bias_regularizer=self._bias_regularizer,
+    self._convlayer2 = ConvBN(filters=self._filters // 2,
+                              kernel_size=(3, 3),
+                              strides=(1, 1),
+                              padding='same',
+                              use_bias=self._use_bias,
+                              groups=self._groups,
+                              group_id=self._group_id,
+                              kernel_initializer=self._kernel_initializer,
+                              bias_initializer=self._bias_initializer,
+                              bias_regularizer=self._bias_regularizer,
@@ -448,13 +448,13 @@ class CSPTiny(ks.layers.Layer):
                               activation=self._conv_activation,
                               leaky_alpha=self._leaky_alpha)
-    self._convlayer3 = DarkConv(filters=self._filters // 2,
-                                kernel_size=(3, 3),
-                                strides=(1, 1),
-                                padding='same',
-                                **_dark_conv_args)
+    self._convlayer3 = ConvBN(filters=self._filters // 2,
+                              kernel_size=(3, 3),
+                              strides=(1, 1),
+                              padding='same',
+                              **_dark_conv_args)
-    self._convlayer4 = DarkConv(filters=self._filters,
-                                kernel_size=(1, 1),
-                                strides=(1, 1),
-                                padding='same',
+    self._convlayer4 = ConvBN(filters=self._filters,
+                              kernel_size=(1, 1),
+                              strides=(1, 1),
+                              padding='same',
@@ -499,12 +499,12 @@ class CSPTiny(ks.layers.Layer):
     return layer_config
 
-@ks.utils.register_keras_serializable(package='yolo')
-class CSPDownSample(ks.layers.Layer):
+@tf.keras.utils.register_keras_serializable(package='yolo')
+class CSPRoute(tf.keras.layers.Layer):
   """
   Downsampling layer to take the place of the downsampling done in residual networks. This is
   the first of two layers needed to convert any residual network model to a CSPNet. At the start of a new
   level change, this CSPRoute layer creates a learned identity that acts as a cross-stage connection
   used to inform the inputs to the next stage. It is called cross stage partial because the number of filters
   required in every intermittent residual layer is cut in half. The sister layer will take the partial generated by
   this layer and concatenate it with the output of the final residual layer in the stack to create a full feature-level
@@ -518,7 +518,8 @@ class CSPDownSample(ks.layers.Layer):
 
   Args:
     filters: integer for output depth, or the number of features to learn
-    filter_reduce: integer dictating (filters//2) or the number of filters in the partial feature stack
+    filter_scale: integer dictating (filters//2) or the number of filters in the partial feature stack
+    downsample: boolean for whether to downsample the input
     activation: string for activation function to use in layer
     kernel_initializer: string to indicate which function to use to initialize weights
     bias_initializer: string to indicate which function to use to initialize bias
@@ -531,10 +532,9 @@ class CSPDownSample(ks.layers.Layer):
     norm_epsilon: float for batch normalization epsilon
     **kwargs: Keyword Arguments
   """
-  def __init__(
-      self,
+  def __init__(self,
               filters,
-               filter_reduce=2,
+               filter_scale=2,
               activation="mish",
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
@@ -544,12 +544,13 @@ class CSPDownSample(ks.layers.Layer):
                use_sync_bn=False,
                norm_momentum=0.99,
                norm_epsilon=0.001,
+               downsample=True,
                **kwargs):
     super().__init__(**kwargs)
     # layer params
     self._filters = filters
-    self._filter_reduce = filter_reduce
+    self._filter_scale = filter_scale
     self._activation = activation
 
     # convolution params
@@ -561,28 +562,36 @@ class CSPDownSample(ks.layers.Layer):
     self._use_sync_bn = use_sync_bn
     self._norm_moment = norm_momentum
     self._norm_epsilon = norm_epsilon
+    self._downsample = downsample
 
   def build(self, input_shape):
-    _dark_conv_args = {"kernel_initializer" : self._kernel_initializer,
-                       "bias_initializer" : self._bias_initializer,
-                       "bias_regularizer" : self._bias_regularizer,
-                       "use_bn" : self._use_bn,
-                       "use_sync_bn" : self._use_sync_bn,
-                       "norm_momentum" : self._norm_moment,
-                       "norm_epsilon" : self._norm_epsilon,
-                       "activation" : self._activation,
-                       "kernel_regularizer" : self._kernel_regularizer,
-                       }
+    _dark_conv_args = {
+        "kernel_initializer": self._kernel_initializer,
+        "bias_initializer": self._bias_initializer,
+        "bias_regularizer": self._bias_regularizer,
+        "use_bn": self._use_bn,
+        "use_sync_bn": self._use_sync_bn,
+        "norm_momentum": self._norm_moment,
+        "norm_epsilon": self._norm_epsilon,
+        "activation": self._activation,
+        "kernel_regularizer": self._kernel_regularizer,
+    }
-    self._conv1 = DarkConv(filters=self._filters,
-                           kernel_size=(3, 3),
-                           strides=(2, 2),
-                           **_dark_conv_args)
+    if self._downsample:
+      self._conv1 = ConvBN(filters=self._filters,
+                           kernel_size=(3, 3),
+                           strides=(2, 2),
+                           **_dark_conv_args)
+    else:
+      self._conv1 = ConvBN(filters=self._filters,
+                           kernel_size=(3, 3),
+                           strides=(1, 1),
+                           **_dark_conv_args)
-    self._conv2 = DarkConv(filters=self._filters // self._filter_reduce,
-                           kernel_size=(1, 1),
-                           strides=(1, 1),
-                           **_dark_conv_args)
+    self._conv2 = ConvBN(filters=self._filters // self._filter_scale,
+                         kernel_size=(1, 1),
+                         strides=(1, 1),
+                         **_dark_conv_args)
-    self._conv3 = DarkConv(filters=self._filters // self._filter_reduce,
-                           kernel_size=(1, 1),
-                           strides=(1, 1),
-                           **_dark_conv_args)
+    self._conv3 = ConvBN(filters=self._filters // self._filter_scale,
+                         kernel_size=(1, 1),
+                         strides=(1, 1),
+                         **_dark_conv_args)
@@ -594,10 +603,10 @@ class CSPDownSample(ks.layers.Layer):
 
     return (x, y)
 
-@ks.utils.register_keras_serializable(package='yolo')
-class CSPConnect(ks.layers.Layer):
+@tf.keras.utils.register_keras_serializable(package='yolo')
+class CSPConnect(tf.keras.layers.Layer):
   """
   Sister layer to the CSPRoute layer. Merges the partial feature stack generated by the CSPRoute layer
   with the final output of the residual stack. Suggested in the CSPNet paper.
 
   Cross Stage Partial networks (CSPNets) were proposed in:
@@ -606,7 +615,7 @@ class CSPConnect(ks.layers.Layer):
 
   Args:
     filters: integer for output depth, or the number of features to learn
-    filter_reduce: integer dictating (filters//2) or the number of filters in the partial feature stack
+    filter_scale: integer dictating (filters//2) or the number of filters in the partial feature stack
     activation: string for activation function to use in layer
     kernel_initializer: string to indicate which function to use to initialize weights
     bias_initializer: string to indicate which function to use to initialize bias
@@ -619,10 +628,9 @@ class CSPConnect(ks.layers.Layer):
     norm_epsilon: float for batch normalization epsilon
     **kwargs: Keyword Arguments
   """
-  def __init__(
-      self,
+  def __init__(self,
               filters,
-               filter_reduce=2,
+               filter_scale=2,
               activation="mish",
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
@@ -637,7 +645,7 @@ class CSPConnect(ks.layers.Layer):
     super().__init__(**kwargs)
     # layer params
     self._filters = filters
-    self._filter_reduce = filter_reduce
+    self._filter_scale = filter_scale
     self._activation = activation
 
     # convolution params
@@ -652,22 +660,22 @@ class CSPConnect(ks.layers.Layer):
 
   def build(self, input_shape):
     _dark_conv_args = {
-        "kernel_initializer" : self._kernel_initializer,
-        "bias_initializer" : self._bias_initializer,
-        "bias_regularizer" : self._bias_regularizer,
-        "use_bn" : self._use_bn,
-        "use_sync_bn" : self._use_sync_bn,
-        "norm_momentum" : self._norm_moment,
-        "norm_epsilon" : self._norm_epsilon,
-        "activation" : self._activation,
-        "kernel_regularizer" : self._kernel_regularizer,
+        "kernel_initializer": self._kernel_initializer,
+        "bias_initializer": self._bias_initializer,
+        "bias_regularizer": self._bias_regularizer,
+        "use_bn": self._use_bn,
+        "use_sync_bn": self._use_sync_bn,
+        "norm_momentum": self._norm_moment,
+        "norm_epsilon": self._norm_epsilon,
+        "activation": self._activation,
+        "kernel_regularizer": self._kernel_regularizer,
     }
-    self._conv1 = DarkConv(filters=self._filters // self._filter_reduce,
-                           kernel_size=(1, 1),
-                           strides=(1, 1),
-                           **_dark_conv_args)
+    self._conv1 = ConvBN(filters=self._filters // self._filter_scale,
+                         kernel_size=(1, 1),
+                         strides=(1, 1),
+                         **_dark_conv_args)
-    self._concat = ks.layers.Concatenate(axis=-1)
-    self._conv2 = DarkConv(filters=self._filters,
-                           kernel_size=(1, 1),
-                           strides=(1, 1),
-                           **_dark_conv_args)
+    self._concat = tf.keras.layers.Concatenate(axis=-1)
+    self._conv2 = ConvBN(filters=self._filters,
+                         kernel_size=(1, 1),
+                         strides=(1, 1),
+                         **_dark_conv_args)
@@ -678,3 +686,102 @@ class CSPConnect(ks.layers.Layer):
     x = self._concat([x, x_csp])
     x = self._conv2(x)
     return x
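Putting the two sister layers together the way _csp_stack does (a sketch of the bottleneck=False path; import path from this commit's test file; shapes are expected values):

import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks

inputs = tf.keras.Input(shape=(112, 112, 64))
# CSPRoute downsamples and splits off a half-width cross-stage partial...
x, x_route = nn_blocks.CSPRoute(filters=128, filter_scale=2)(inputs)
# ...the residual body runs on one branch at the reduced width...
x = nn_blocks.DarkResidual(filters=64, filter_scale=1)(x)
# ...and CSPConnect concatenates the partial back and restores full width.
out = nn_blocks.CSPConnect(filters=128, filter_scale=2)([x, x_route])
# expected out shape: (None, 56, 56, 128)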
class CSPStack(tf.keras.layers.Layer):
  """
  CSP full stack: combines the route and the connect layers so you can quickly wrap an existing callable or list of layers
  to make it cross stage partial. Added for ease of use. You should be able to wrap any layer stack with a CSP independent of whether it belongs
  to the Darknet family. If filter_scale = 2, then the blocks in the stack passed into the CSP stack should also have filters = filters // filter_scale.

  Cross Stage Partial networks (CSPNets) were proposed in:
  [1] Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu, Ping-Yang Chen, Jun-Wei Hsieh
      CSPNet: A New Backbone that can Enhance Learning Capability of CNN. arXiv:1911.11929

  Args:
    model_to_wrap: callable Model or list of callable objects that will process the output of CSPRoute and feed into CSPConnect;
      a list will be called sequentially.
    downsample: boolean for whether to downsample the input
    filters: integer for output depth, or the number of features to learn
    filter_scale: integer dictating (filters//2) or the number of filters in the partial feature stack
    activation: string for activation function to use in layer
    kernel_initializer: string to indicate which function to use to initialize weights
    bias_initializer: string to indicate which function to use to initialize bias
    kernel_regularizer: string to indicate which function to use to regularize weights
    bias_regularizer: string to indicate which function to use to regularize bias
    use_bn: boolean for whether to use batch normalization
    use_sync_bn: boolean for whether to sync batch normalization statistics
      of all batch norm layers to the model's global statistics (across all input batches)
    norm_moment: float for momentum to use for batch normalization
    norm_epsilon: float for batch normalization epsilon
    **kwargs: Keyword Arguments
  """
  def __init__(self,
               filters,
               model_to_wrap=None,
               filter_scale=2,
               activation="mish",
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               bias_regularizer=None,
               kernel_regularizer=None,
               downsample=True,
               use_bn=True,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    super().__init__(**kwargs)
    # layer params
    self._filters = filters
    self._filter_scale = filter_scale
    self._activation = activation
    self._downsample = downsample

    # convolution params
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_bn = use_bn
    self._use_sync_bn = use_sync_bn
    self._norm_moment = norm_momentum
    self._norm_epsilon = norm_epsilon

    if model_to_wrap is not None:
      if isinstance(model_to_wrap, Callable):
        self._model_to_wrap = [model_to_wrap]
      elif isinstance(model_to_wrap, List):
        self._model_to_wrap = model_to_wrap
      else:
        raise ValueError(
            "the input to the CSPStack must be a list of layers that we can "
            "iterate through, or a callable")
    else:
      self._model_to_wrap = []

  def build(self, input_shape):
    _dark_conv_args = {
        "filters": self._filters,
        "filter_scale": self._filter_scale,
        "activation": self._activation,
        "kernel_initializer": self._kernel_initializer,
        "bias_initializer": self._bias_initializer,
        "bias_regularizer": self._bias_regularizer,
        "use_bn": self._use_bn,
        "use_sync_bn": self._use_sync_bn,
        "norm_momentum": self._norm_moment,
        "norm_epsilon": self._norm_epsilon,
        "kernel_regularizer": self._kernel_regularizer,
    }
    self._route = CSPRoute(downsample=self._downsample, **_dark_conv_args)
    self._connect = CSPConnect(**_dark_conv_args)

  def call(self, inputs):
    x, x_route = self._route(inputs)
    for layer in self._model_to_wrap:
      x = layer(x)
    x = self._connect([x, x_route])
    return x
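A usage sketch for CSPStack, mirroring the list case exercised by the new tests below (expected shapes under default settings):

import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks

# With filter_scale=2 the wrapped blocks run at filters // filter_scale,
# per the note in the docstring above.
blocks = [nn_blocks.DarkResidual(filters=64, filter_scale=2)
          for _ in range(2)]
stack = nn_blocks.CSPStack(filters=128, filter_scale=2, downsample=True,
                           model_to_wrap=blocks)
x = tf.keras.Input(shape=(64, 64, 64))
y = stack(x)  # expected shape: (None, 32, 32, 128)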
@@ -7,14 +7,14 @@ from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
-class CSPConnect(tf.test.TestCase, parameterized.TestCase):
+class CSPConnectTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(("same", 224, 224, 64, 1),
                                   ("downsample", 224, 224, 64, 2))
   def test_pass_through(self, width, height, filters, mod):
     x = ks.Input(shape=(width, height, filters))
-    test_layer = nn_blocks.CSPDownSample(filters=filters, filter_reduce=mod)
-    test_layer2 = nn_blocks.CSPConnect(filters=filters, filter_reduce=mod)
+    test_layer = nn_blocks.CSPRoute(filters=filters, filter_scale=mod)
+    test_layer2 = nn_blocks.CSPConnect(filters=filters, filter_scale=mod)
     outx, px = test_layer(x)
     outx = test_layer2([outx, px])
     print(outx)
@@ -29,8 +29,8 @@ class CSPConnect(tf.test.TestCase, parameterized.TestCase):
   def test_gradient_pass_though(self, filters, width, height, mod):
     loss = ks.losses.MeanSquaredError()
     optimizer = ks.optimizers.SGD()
-    test_layer = nn_blocks.CSPDownSample(filters, filter_reduce=mod)
-    path_layer = nn_blocks.CSPConnect(filters, filter_reduce=mod)
+    test_layer = nn_blocks.CSPRoute(filters, filter_scale=mod)
+    path_layer = nn_blocks.CSPConnect(filters, filter_scale=mod)
 
     init = tf.random_normal_initializer()
     x = tf.Variable(
@@ -49,13 +49,13 @@ class CSPConnect(tf.test.TestCase, parameterized.TestCase):
     self.assertNotIn(None, grad)
-class CSPDownSample(tf.test.TestCase, parameterized.TestCase):
+class CSPRouteTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(("same", 224, 224, 64, 1),
                                   ("downsample", 224, 224, 64, 2))
   def test_pass_through(self, width, height, filters, mod):
     x = ks.Input(shape=(width, height, filters))
-    test_layer = nn_blocks.CSPDownSample(filters=filters, filter_reduce=mod)
+    test_layer = nn_blocks.CSPRoute(filters=filters, filter_scale=mod)
     outx, px = test_layer(x)
     print(outx)
     print(outx.shape.as_list())
@@ -69,8 +69,8 @@ class CSPDownSample(tf.test.TestCase, parameterized.TestCase):
   def test_gradient_pass_though(self, filters, width, height, mod):
     loss = ks.losses.MeanSquaredError()
     optimizer = ks.optimizers.SGD()
-    test_layer = nn_blocks.CSPDownSample(filters, filter_reduce=mod)
-    path_layer = nn_blocks.CSPConnect(filters, filter_reduce=mod)
+    test_layer = nn_blocks.CSPRoute(filters, filter_scale=mod)
+    path_layer = nn_blocks.CSPConnect(filters, filter_scale=mod)
 
     init = tf.random_normal_initializer()
     x = tf.Variable(
@@ -89,7 +89,75 @@ class CSPDownSample(tf.test.TestCase, parameterized.TestCase):
     self.assertNotIn(None, grad)
-class DarkConvTest(tf.test.TestCase, parameterized.TestCase):
+class CSPStackTest(tf.test.TestCase, parameterized.TestCase):
  def build_layer(self, layer_type, filters, filter_scale, count, stack_type,
                  downsample):
    if stack_type is not None:
      layers = []
      if layer_type == "residual":
        for _ in range(count):
          layers.append(
              nn_blocks.DarkResidual(filters=filters // filter_scale,
                                     filter_scale=filter_scale))
      else:
        for _ in range(count):
          layers.append(nn_blocks.ConvBN(filters=filters))

      if stack_type == "model":
        layers = tf.keras.Sequential(layers=layers)
    else:
      layers = None

    stack = nn_blocks.CSPStack(filters=filters,
                               filter_scale=filter_scale,
                               downsample=downsample,
                               model_to_wrap=layers)
    return stack
  @parameterized.named_parameters(
      ("no_stack", 224, 224, 64, 2, "residual", None, 0, True),
      ("residual_stack", 224, 224, 64, 2, "residual", "list", 2, True),
      ("conv_stack", 224, 224, 64, 2, "conv", "list", 3, False),
      ("callable_no_scale", 224, 224, 64, 1, "residual", "model", 5, False))
  def test_pass_through(self, width, height, filters, mod, layer_type,
                        stack_type, count, downsample):
    x = ks.Input(shape=(width, height, filters))
    test_layer = self.build_layer(layer_type, filters, mod, count, stack_type,
                                  downsample)

    outx = test_layer(x)
    print(outx)
    print(outx.shape.as_list())
    if downsample:
      self.assertAllEqual(outx.shape.as_list(),
                          [None, width // 2, height // 2, filters])
    else:
      self.assertAllEqual(outx.shape.as_list(),
                          [None, width, height, filters])

  @parameterized.named_parameters(
      ("no_stack", 224, 224, 64, 2, "residual", None, 0, True),
      ("residual_stack", 224, 224, 64, 2, "residual", "list", 2, True),
      ("conv_stack", 224, 224, 64, 2, "conv", "list", 3, False),
      ("callable_no_scale", 224, 224, 64, 1, "residual", "model", 5, False))
  def test_gradient_pass_though(self, width, height, filters, mod, layer_type,
                                stack_type, count, downsample):
    loss = ks.losses.MeanSquaredError()
    optimizer = ks.optimizers.SGD()
    init = tf.random_normal_initializer()
    x = tf.Variable(
        initial_value=init(shape=(1, width, height, filters), dtype=tf.float32))
    if not downsample:
      y = tf.Variable(
          initial_value=init(shape=(1, width, height, filters),
                             dtype=tf.float32))
    else:
      y = tf.Variable(
          initial_value=init(shape=(1, width // 2, height // 2, filters),
                             dtype=tf.float32))
    test_layer = self.build_layer(layer_type, filters, mod, count, stack_type,
                                  downsample)

    with tf.GradientTape() as tape:
      x_hat = test_layer(x)
      grad_loss = loss(x_hat, y)
    grad = tape.gradient(grad_loss, test_layer.trainable_variables)
    optimizer.apply_gradients(zip(grad, test_layer.trainable_variables))

    self.assertNotIn(None, grad)
class ConvBNTest(tf.test.TestCase, parameterized.TestCase):
  @parameterized.named_parameters(
      ("valid", (3, 3), "valid", (1, 1)), ("same", (3, 3), "same", (1, 1)),
@@ -100,7 +168,7 @@ class DarkConvTest(tf.test.TestCase, parameterized.TestCase):
     else:
       pad_const = 0
     x = ks.Input(shape=(224, 224, 3))
-    test_layer = nn_blocks.DarkConv(filters=64,
-                                    kernel_size=kernel_size,
-                                    padding=padding,
-                                    strides=strides,
+    test_layer = nn_blocks.ConvBN(filters=64,
+                                  kernel_size=kernel_size,
+                                  padding=padding,
+                                  strides=strides,
@@ -120,7 +188,7 @@ class DarkConvTest(tf.test.TestCase, parameterized.TestCase):
     loss = ks.losses.MeanSquaredError()
     optimizer = ks.optimizers.SGD()
     with tf.device("/CPU:0"):
-      test_layer = nn_blocks.DarkConv(filters, kernel_size=(3, 3), padding="same")
+      test_layer = nn_blocks.ConvBN(filters, kernel_size=(3, 3), padding="same")
     init = tf.random_normal_initializer()
     x = tf.Variable(initial_value=init(shape=(1, 224, 224,
...