Commit 7752c35a authored by Vishnu Banna's avatar Vishnu Banna
Browse files

decoder update

parent f27d88f9
......@@ -13,7 +13,7 @@
# limitations under the License.
# Lint as: python3
"""Feature Pyramid Network and Path Aggregation variants used in YOLO."""
"""Feature Pyramid Network and Path Aggregation variants used in YOLO"""
import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
......@@ -22,7 +22,7 @@ from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
@tf.keras.utils.register_keras_serializable(package='yolo')
class _IdentityRoute(tf.keras.layers.Layer):
def call(self, inputs):
def call(self, inputs): # pylint: disable=arguments-differ
return None, inputs
......@@ -39,7 +39,7 @@ class YoloFPN(tf.keras.layers.Layer):
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='glorot_uniform',
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
......@@ -172,7 +172,7 @@ class YoloFPN(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloPAN(tf.keras.layers.Layer):
"""YOLO Path Aggregation Network."""
"""YOLO Path Aggregation Network"""
def __init__(self,
path_process_len=6,
......@@ -184,7 +184,7 @@ class YoloPAN(tf.keras.layers.Layer):
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='glorot_uniform',
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
fpn_input=True,
......@@ -374,14 +374,14 @@ class YoloDecoder(tf.keras.Model):
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='glorot_uniform',
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""Yolo Decoder initialization function.
A unified model that ties all decoder components into a conditionally build
YOLO decoder.
YOLO decder.
Args:
input_specs: `dict[str, tf.InputSpec]`: input specs of each of the inputs
......@@ -403,7 +403,7 @@ class YoloDecoder(tf.keras.Model):
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
**kwargs: keyword arguments to be passed.
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment