Commit d09d4bef authored by Vishnu Banna's avatar Vishnu Banna
Browse files

model heads update

parent 0789bd4c
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Yolo heads.""" """Yolo heads."""
import tensorflow as tf import tensorflow as tf
...@@ -34,6 +35,7 @@ class YoloHead(tf.keras.layers.Layer): ...@@ -34,6 +35,7 @@ class YoloHead(tf.keras.layers.Layer):
bias_regularizer=None, bias_regularizer=None,
activation=None, activation=None,
smart_bias=False, smart_bias=False,
use_separable_conv=False,
**kwargs): **kwargs):
"""Yolo Prediction Head initialization function. """Yolo Prediction Head initialization function.
...@@ -52,7 +54,6 @@ class YoloHead(tf.keras.layers.Layer): ...@@ -52,7 +54,6 @@ class YoloHead(tf.keras.layers.Layer):
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d. bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
activation: `str`, the activation function to use typically leaky or mish. activation: `str`, the activation function to use typically leaky or mish.
smart_bias: `bool` whether or not use smart bias.
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
...@@ -70,6 +71,7 @@ class YoloHead(tf.keras.layers.Layer): ...@@ -70,6 +71,7 @@ class YoloHead(tf.keras.layers.Layer):
self._output_conv = (classes + output_extras + 5) * boxes_per_level self._output_conv = (classes + output_extras + 5) * boxes_per_level
self._smart_bias = smart_bias self._smart_bias = smart_bias
self._use_separable_conv = use_separable_conv
self._base_config = dict( self._base_config = dict(
activation=activation, activation=activation,
...@@ -85,6 +87,7 @@ class YoloHead(tf.keras.layers.Layer): ...@@ -85,6 +87,7 @@ class YoloHead(tf.keras.layers.Layer):
strides=(1, 1), strides=(1, 1),
padding='same', padding='same',
use_bn=False, use_bn=False,
use_separable_conv=self._use_separable_conv,
**self._base_config) **self._base_config)
def bias_init(self, scale, inshape, isize=640, no_per_conf=8): def bias_init(self, scale, inshape, isize=640, no_per_conf=8):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment