Commit a7f56602 authored by vishnubanna's avatar vishnubanna

cleaner grouping

parent d5747aac
......@@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Text
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='yolo')
class Identity(tf.keras.layers.Layer):
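# no-op passthrough layer; used below in place of optional padding or batch norm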
def __init__(self, **kwargs):
......@@ -17,8 +16,7 @@ class Identity(tf.keras.layers.Layer):
class ConvBN(tf.keras.layers.Layer):
'''
Modified convolution layer to match that of the DarkNet library. The layer is a standard combination of Conv, BatchNorm, and Activation;
however, the use of bias in the conv is determined by the use of batch normalization. The layer also allows for the feature grouping
suggested in the CSPNet paper.
however, the use of bias in the conv is determined by the use of batch normalization.
Cross Stage Partial networks (CSPNets) were proposed in:
[1] Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu, Ping-Yang Chen, Jun-Wei Hsieh.
CSPNet: A New Backbone that can Enhance Learning Capability of CNN. arXiv:1911.11929.
......@@ -36,9 +34,6 @@ class ConvBN(tf.keras.layers.Layer):
bias_initializer: string to indicate which function to use to initialize the bias
kernel_regularizer: string to indicate which function to use to regularize the weights
bias_regularizer: string to indicate which function to use to regularize the bias
group_id: integer for which group of features to pass through the conv
groups: integer for how many splits to make in the convolution's input feature stack
grouping_only: skip the convolution and only return the group of features indicated by group_id
use_bn: boolean for whether to use batch normalization
use_sync_bn: boolean for whether to synchronize the batch normalization statistics
of all batch norm layers with the model's global statistics (across all input batches)
......@@ -57,9 +52,6 @@ class ConvBN(tf.keras.layers.Layer):
padding='same',
dilation_rate=(1, 1),
use_bias=True,
groups=1,
group_id=0,
grouping_only=False,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
bias_regularizer=None,
......@@ -79,9 +71,6 @@ class ConvBN(tf.keras.layers.Layer):
self._padding = padding
self._dilation_rate = dilation_rate
self._use_bias = use_bias
self._groups = groups
self._group_id = group_id
self._grouping_only = grouping_only
self._kernel_initializer = kernel_initializer
self._bias_initializer = bias_initializer
self._kernel_regularizer = kernel_regularizer
......@@ -109,59 +98,55 @@ class ConvBN(tf.keras.layers.Layer):
super(ConvBN, self).__init__(**kwargs)
def build(self, input_shape):
if not self._grouping_only:
kernel_size = self._kernel_size if type(
self._kernel_size) == int else self._kernel_size[0]
if self._padding == "same" and kernel_size != 1:
self._zeropad = tf.keras.layers.ZeroPadding2D(
((1, 1), (1, 1))) # symmetric padding
else:
self._zeropad = Identity()
self.conv = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=self._kernel_size,
strides=self._strides,
padding="valid",
dilation_rate=self._dilation_rate,
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
if self._use_bn:
if self._use_sync_bn:
self.bn = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=self._norm_moment,
epsilon=self._norm_epsilon,
axis=self._bn_axis)
else:
self.bn = tf.keras.layers.BatchNormalization(
momentum=self._norm_moment,
epsilon=self._norm_epsilon,
axis=self._bn_axis)
else:
self.bn = Identity()
if self._activation == 'leaky':
self._activation_fn = tf.keras.layers.LeakyReLU(
alpha=self._leaky_alpha)
elif self._activation == "mish":
self._activation_fn = lambda x: x * tf.math.tanh(
tf.math.softplus(x))
kernel_size = self._kernel_size if isinstance(
self._kernel_size, int) else self._kernel_size[0]
dilation_rate = self._dilation_rate if isinstance(
self._dilation_rate, int) else self._dilation_rate[0]
if self._padding == "same" and kernel_size != 1:
padding = dilation_rate * (kernel_size - 1)
left_shift = padding // 2  # integer halves; ZeroPadding2D expects Python ints, not tensors
self._zeropad = tf.keras.layers.ZeroPadding2D(
((left_shift, left_shift), (left_shift, left_shift)))
else:
self._zeropad = Identity()
self.conv = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=self._kernel_size,
strides=self._strides,
padding="valid",
dilation_rate=self._dilation_rate,
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
if self._use_bn:
if self._use_sync_bn:
self.bn = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=self._norm_moment,
epsilon=self._norm_epsilon,
axis=self._bn_axis)
else:
self._activation_fn = tf_utils.get_activation(self._activation)
self.bn = tf.keras.layers.BatchNormalization(
momentum=self._norm_moment,
epsilon=self._norm_epsilon,
axis=self._bn_axis)
else:
self.bn = Identity()
if self._activation == 'leaky':
self._activation_fn = tf.keras.layers.LeakyReLU(
alpha=self._leaky_alpha)
elif self._activation == "mish":
self._activation_fn = lambda x: x * tf.math.tanh(
tf.math.softplus(x))
else:
self._activation_fn = tf_utils.get_activation(self._activation)
def call(self, x):
if self._groups != 1:
x = tf.split(x, self._groups, axis=-1)
x = x[self._group_id]
if not self._grouping_only:
x = self._zeropad(x)
x = self.conv(x)
x = self.bn(x)
x = self._activation_fn(x)
x = self._zeropad(x)
x = self.conv(x)
x = self.bn(x)
x = self._activation_fn(x)
return x
def get_config(self):
......@@ -173,9 +158,6 @@ class ConvBN(tf.keras.layers.Layer):
"padding": self._padding,
"dilation_rate": self._dilation_rate,
"use_bias": self._use_bias,
"groups": self._groups,
"group_id": self._group_id,
"grouping_only": self._grouping_only,
"kernel_initializer": self._kernel_initializer,
"bias_initializer": self._bias_initializer,
"bias_regularizer": self._bias_regularizer,
......@@ -435,8 +417,6 @@ class CSPTiny(tf.keras.layers.Layer):
strides=(1, 1),
padding='same',
use_bias=self._use_bias,
groups=self._groups,
group_id=self._group_id,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
bias_regularizer=self._bias_regularizer,
......@@ -469,7 +449,8 @@ class CSPTiny(tf.keras.layers.Layer):
def call(self, inputs):
x1 = self._convlayer1(inputs)
x2 = self._convlayer2(x1) # grouping
x_interm = tf.split(x1, self._groups, axis=-1)[self._group_id]  # select one channel group
x2 = self._convlayer2(x_interm) # grouping
x3 = self._convlayer3(x2)
x4 = tf.concat([x3, x2], axis=-1) # csp partial using grouping
x5 = self._convlayer4(x4)
......
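As a standalone sketch of the grouping step now done inline in CSPTiny.call; the group count, group index, and shapes are assumed for illustration, with tf.identity standing in for the conv layers:
import tensorflow as tf
x1 = tf.ones([2, 16, 16, 64])  # stand-in for the output of self._convlayer1
groups, group_id = 2, 1  # split the channels in half and keep the second half
x_interm = tf.split(x1, groups, axis=-1)[group_id]  # -> (2, 16, 16, 32)
x2 = tf.identity(x_interm)  # stand-in for self._convlayer2(x_interm)
x3 = tf.identity(x2)  # stand-in for self._convlayer3(x2)
x4 = tf.concat([x3, x2], axis=-1)  # csp partial merge restores the channels -> (2, 16, 16, 64)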