Commit e3f09134 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

head update

parent 350b0aa6
......@@ -30,10 +30,11 @@ class YoloHead(tf.keras.layers.Layer):
output_extras=0,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='glorot_uniform',
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation=None,
smart_bias=False,
**kwargs):
"""Yolo Prediction Head initialization function.
......@@ -68,6 +69,7 @@ class YoloHead(tf.keras.layers.Layer):
self._output_extras = output_extras
self._output_conv = (classes + output_extras + 5) * boxes_per_level
self._smart_bias = smart_bias
self._base_config = dict(
activation=activation,
......@@ -85,10 +87,29 @@ class YoloHead(tf.keras.layers.Layer):
use_bn=False,
**self._base_config)
def bias_init(self, scale, inshape, isize=640, no_per_conf=8):
def bias(shape, dtype):
init = tf.keras.initializers.Zeros()
base = init(shape, dtype=dtype)
if self._smart_bias:
base = tf.reshape(base, [self._boxes_per_level, -1])
box, conf, classes = tf.split(base, [4, 1, -1], axis=-1)
conf += tf.math.log(no_per_conf / ((isize / scale)**2))
classes += tf.math.log(0.6 / (self._classes - 0.99))
base = tf.concat([box, conf, classes], axis=-1)
base = tf.reshape(base, [-1])
return base
return bias
def build(self, input_shape):
self._head = dict()
for key in self._key_list:
self._head[key] = nn_blocks.ConvBN(**self._conv_config)
scale = 2**int(key)
self._head[key] = nn_blocks.ConvBN(
bias_initializer=self.bias_init(scale, input_shape[key][-1]),
**self._conv_config)
def call(self, inputs):
outputs = dict()
......@@ -104,9 +125,13 @@ class YoloHead(tf.keras.layers.Layer):
def num_boxes(self):
if self._min_level is None or self._max_level is None:
raise Exception(
'Model has to be built before number of boxes can be determined.')
'model has to be built before number of boxes can be determined')
return (self._max_level - self._min_level + 1) * self._boxes_per_level
@property
def num_heads(self):
return (self._max_level - self._min_level + 1)
def get_config(self):
config = dict(
min_level=self._min_level,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment