Commit 7752c35a authored by Vishnu Banna's avatar Vishnu Banna
Browse files

decoder update

parent f27d88f9
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# Lint as: python3 # Lint as: python3
"""Feature Pyramid Network and Path Aggregation variants used in YOLO.""" """Feature Pyramid Network and Path Aggregation variants used in YOLO"""
import tensorflow as tf import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
...@@ -22,7 +22,7 @@ from official.vision.beta.projects.yolo.modeling.layers import nn_blocks ...@@ -22,7 +22,7 @@ from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
@tf.keras.utils.register_keras_serializable(package='yolo') @tf.keras.utils.register_keras_serializable(package='yolo')
class _IdentityRoute(tf.keras.layers.Layer): class _IdentityRoute(tf.keras.layers.Layer):
def call(self, inputs): def call(self, inputs): # pylint: disable=arguments-differ
return None, inputs return None, inputs
...@@ -39,7 +39,7 @@ class YoloFPN(tf.keras.layers.Layer): ...@@ -39,7 +39,7 @@ class YoloFPN(tf.keras.layers.Layer):
use_sync_bn=False, use_sync_bn=False,
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
kernel_initializer='glorot_uniform', kernel_initializer='VarianceScaling',
kernel_regularizer=None, kernel_regularizer=None,
bias_regularizer=None, bias_regularizer=None,
**kwargs): **kwargs):
...@@ -172,7 +172,7 @@ class YoloFPN(tf.keras.layers.Layer): ...@@ -172,7 +172,7 @@ class YoloFPN(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='yolo') @tf.keras.utils.register_keras_serializable(package='yolo')
class YoloPAN(tf.keras.layers.Layer): class YoloPAN(tf.keras.layers.Layer):
"""YOLO Path Aggregation Network.""" """YOLO Path Aggregation Network"""
def __init__(self, def __init__(self,
path_process_len=6, path_process_len=6,
...@@ -184,7 +184,7 @@ class YoloPAN(tf.keras.layers.Layer): ...@@ -184,7 +184,7 @@ class YoloPAN(tf.keras.layers.Layer):
use_sync_bn=False, use_sync_bn=False,
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
kernel_initializer='glorot_uniform', kernel_initializer='VarianceScaling',
kernel_regularizer=None, kernel_regularizer=None,
bias_regularizer=None, bias_regularizer=None,
fpn_input=True, fpn_input=True,
...@@ -323,7 +323,7 @@ class YoloPAN(tf.keras.layers.Layer): ...@@ -323,7 +323,7 @@ class YoloPAN(tf.keras.layers.Layer):
Args: Args:
minimum_depth: `int` depth of the smallest branch of the FPN. minimum_depth: `int` depth of the smallest branch of the FPN.
inputs: `dict[str, tf.InputSpec]` of the shape of input args as a inputs: `dict[str, tf.InputSpec]` of the shape of input args as a
dictionary of lists. dictionary of lists.
Returns: Returns:
...@@ -374,14 +374,14 @@ class YoloDecoder(tf.keras.Model): ...@@ -374,14 +374,14 @@ class YoloDecoder(tf.keras.Model):
use_sync_bn=False, use_sync_bn=False,
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
kernel_initializer='glorot_uniform', kernel_initializer='VarianceScaling',
kernel_regularizer=None, kernel_regularizer=None,
bias_regularizer=None, bias_regularizer=None,
**kwargs): **kwargs):
"""Yolo Decoder initialization function. """Yolo Decoder initialization function.
A unified model that ties all decoder components into a conditionally build A unified model that ties all decoder components into a conditionally build
YOLO decoder. YOLO decder.
Args: Args:
input_specs: `dict[str, tf.InputSpec]`: input specs of each of the inputs input_specs: `dict[str, tf.InputSpec]`: input specs of each of the inputs
...@@ -403,7 +403,7 @@ class YoloDecoder(tf.keras.Model): ...@@ -403,7 +403,7 @@ class YoloDecoder(tf.keras.Model):
zero. zero.
kernel_initializer: kernel_initializer for convolutional layers. kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment