Commit d09d4bef authored by Vishnu Banna's avatar Vishnu Banna
Browse files

model heads update

parent 0789bd4c
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Yolo heads."""
import tensorflow as tf
......@@ -34,6 +35,7 @@ class YoloHead(tf.keras.layers.Layer):
bias_regularizer=None,
activation=None,
smart_bias=False,
use_separable_conv=False,
**kwargs):
"""Yolo Prediction Head initialization function.
......@@ -52,7 +54,6 @@ class YoloHead(tf.keras.layers.Layer):
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
activation: `str`, the activation function to use typically leaky or mish.
smart_bias: `bool` whether or not use smart bias.
**kwargs: keyword arguments to be passed.
"""
......@@ -70,6 +71,7 @@ class YoloHead(tf.keras.layers.Layer):
self._output_conv = (classes + output_extras + 5) * boxes_per_level
self._smart_bias = smart_bias
self._use_separable_conv = use_separable_conv
self._base_config = dict(
activation=activation,
......@@ -85,6 +87,7 @@ class YoloHead(tf.keras.layers.Layer):
strides=(1, 1),
padding='same',
use_bn=False,
use_separable_conv=self._use_separable_conv,
**self._base_config)
def bias_init(self, scale, inshape, isize=640, no_per_conf=8):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment