Commit f246dd8f authored by Anirudh Vegesana's avatar Anirudh Vegesana Committed by Jaeyoun Kim
Browse files

YOLO Family: Updated model (#9923)



* Update YOLO model

* Fix some docstrings

* Fix docstrings

* Address some of Dr. Davis' changes

* Give descriptive names to the test cases

* Fix bugs

* Fix YOLO head imports

* docstring and variable name updates

* docstring and variable name updates

* docstring and variable name updates
Co-authored-by: default avatarvishnubanna <banna3vishnu@gmail.com>
Co-authored-by: default avatarVishnu Banna <43182884+vishnubanna@users.noreply.github.com>
parent bcbce005
...@@ -43,6 +43,7 @@ from official.modeling import hyperparams ...@@ -43,6 +43,7 @@ from official.modeling import hyperparams
from official.vision.beta.modeling.backbones import factory from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
# builder required classes
class BlockConfig: class BlockConfig:
"""Class to store layer config to make code more readable.""" """Class to store layer config to make code more readable."""
...@@ -666,7 +667,9 @@ def build_darknet( ...@@ -666,7 +667,9 @@ def build_darknet(
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds darknet.""" """Builds darknet."""
backbone_cfg = backbone_config.get() backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
model = Darknet( model = Darknet(
model_id=backbone_cfg.model_id, model_id=backbone_cfg.model_id,
min_level=backbone_cfg.min_level, min_level=backbone_cfg.min_level,
......
...@@ -70,7 +70,7 @@ class DarknetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -70,7 +70,7 @@ class DarknetTest(parameterized.TestCase, tf.test.TestCase):
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
strategy=[ strategy=[
strategy_combinations.cloud_tpu_strategy, strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu, strategy_combinations.one_device_strategy_gpu,
], ],
use_sync_bn=[False, True], use_sync_bn=[False, True],
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# Lint as: python3 # Lint as: python3
"""Contains common building blocks for yolo neural networks.""" """Contains common building blocks for yolo neural networks."""
from typing import Callable, List from typing import Callable, List
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
...@@ -24,6 +23,9 @@ from official.vision.beta.ops import spatial_transform_ops ...@@ -24,6 +23,9 @@ from official.vision.beta.ops import spatial_transform_ops
@tf.keras.utils.register_keras_serializable(package='yolo') @tf.keras.utils.register_keras_serializable(package='yolo')
class Identity(tf.keras.layers.Layer): class Identity(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def call(self, inputs): def call(self, inputs):
return inputs return inputs
...@@ -640,6 +642,11 @@ class CSPRoute(tf.keras.layers.Layer): ...@@ -640,6 +642,11 @@ class CSPRoute(tf.keras.layers.Layer):
x = self._conv3(inputs) x = self._conv3(inputs)
return (x, y) return (x, y)
self._conv2 = ConvBN(
filters=self._filters // self._filter_scale,
kernel_size=(1, 1),
strides=(1, 1),
**dark_conv_args)
@tf.keras.utils.register_keras_serializable(package='yolo') @tf.keras.utils.register_keras_serializable(package='yolo')
class CSPConnect(tf.keras.layers.Layer): class CSPConnect(tf.keras.layers.Layer):
...@@ -797,7 +804,6 @@ class CSPStack(tf.keras.layers.Layer): ...@@ -797,7 +804,6 @@ class CSPStack(tf.keras.layers.Layer):
"""CSPStack layer initializer. """CSPStack layer initializer.
Args: Args:
filters: integer for output depth, or the number of features to learn.
model_to_wrap: callable Model or a list of callable objects that will model_to_wrap: callable Model or a list of callable objects that will
process the output of CSPRoute, and be input into CSPConnect. process the output of CSPRoute, and be input into CSPConnect.
list will be called sequentially. list will be called sequentially.
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import numpy as np
from absl.testing import parameterized
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
...@@ -71,7 +73,7 @@ class CSPRouteTest(tf.test.TestCase, parameterized.TestCase): ...@@ -71,7 +73,7 @@ class CSPRouteTest(tf.test.TestCase, parameterized.TestCase):
def test_pass_through(self, width, height, filters, mod): def test_pass_through(self, width, height, filters, mod):
x = tf.keras.Input(shape=(width, height, filters)) x = tf.keras.Input(shape=(width, height, filters))
test_layer = nn_blocks.CSPRoute(filters=filters, filter_scale=mod) test_layer = nn_blocks.CSPRoute(filters=filters, filter_scale=mod)
outx, _ = test_layer(x) outx, px = test_layer(x)
print(outx) print(outx)
print(outx.shape.as_list()) print(outx.shape.as_list())
self.assertAllEqual( self.assertAllEqual(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment