Commit 350b0aa6 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

decoder update

parent 7752c35a
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# Lint as: python3 # Lint as: python3
"""Feature Pyramid Network and Path Aggregation variants used in YOLO""" """Feature Pyramid Network and Path Aggregation variants used in YOLO."""
import tensorflow as tf import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
...@@ -206,7 +206,7 @@ class YoloPAN(tf.keras.layers.Layer): ...@@ -206,7 +206,7 @@ class YoloPAN(tf.keras.layers.Layer):
by zero. by zero.
kernel_initializer: kernel_initializer for convolutional layers. kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d. bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
fpn_input: `bool`, for whether the input into this fucntion is an FPN or fpn_input: `bool`, for whether the input into this fucntion is an FPN or
a backbone. a backbone.
fpn_filter_scale: `int`, scaling factor for the FPN filters. fpn_filter_scale: `int`, scaling factor for the FPN filters.
...@@ -381,7 +381,7 @@ class YoloDecoder(tf.keras.Model): ...@@ -381,7 +381,7 @@ class YoloDecoder(tf.keras.Model):
"""Yolo Decoder initialization function. """Yolo Decoder initialization function.
A unified model that ties all decoder components into a conditionally build A unified model that ties all decoder components into a conditionally build
YOLO decder. YOLO decoder.
Args: Args:
input_specs: `dict[str, tf.InputSpec]`: input specs of each of the inputs input_specs: `dict[str, tf.InputSpec]`: input specs of each of the inputs
...@@ -403,7 +403,7 @@ class YoloDecoder(tf.keras.Model): ...@@ -403,7 +403,7 @@ class YoloDecoder(tf.keras.Model):
zero. zero.
kernel_initializer: kernel_initializer for convolutional layers. kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d. bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment