Commit f246dd8f authored by Anirudh Vegesana's avatar Anirudh Vegesana Committed by Jaeyoun Kim
Browse files

YOLO Family: Updated model (#9923)



* Update YOLO model

* Fix some docstrings

* Fix docstrings

* Address some of Dr. Davis' changes

* Give descriptive names to the test cases

* Fix bugs

* Fix YOLO head imports

* docstring and variable name updates

* docstring and variable name updates

* docstring and variable name updates
Co-authored-by: default avatarvishnubanna <banna3vishnu@gmail.com>
Co-authored-by: default avatarVishnu Banna <43182884+vishnubanna@users.noreply.github.com>
parent bcbce005
......@@ -43,6 +43,7 @@ from official.modeling import hyperparams
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
# builder required classes
class BlockConfig:
"""Class to store layer config to make code more readable."""
......@@ -666,7 +667,9 @@ def build_darknet(
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds darknet."""
backbone_cfg = backbone_config.get()
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
model = Darknet(
model_id=backbone_cfg.model_id,
min_level=backbone_cfg.min_level,
......
......@@ -70,7 +70,7 @@ class DarknetTest(parameterized.TestCase, tf.test.TestCase):
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
use_sync_bn=[False, True],
......
......@@ -14,7 +14,6 @@
# Lint as: python3
"""Contains common building blocks for yolo neural networks."""
from typing import Callable, List
import tensorflow as tf
from official.modeling import tf_utils
......@@ -24,6 +23,9 @@ from official.vision.beta.ops import spatial_transform_ops
@tf.keras.utils.register_keras_serializable(package='yolo')
class Identity(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def call(self, inputs):
return inputs
......@@ -640,6 +642,11 @@ class CSPRoute(tf.keras.layers.Layer):
x = self._conv3(inputs)
return (x, y)
self._conv2 = ConvBN(
filters=self._filters // self._filter_scale,
kernel_size=(1, 1),
strides=(1, 1),
**dark_conv_args)
@tf.keras.utils.register_keras_serializable(package='yolo')
class CSPConnect(tf.keras.layers.Layer):
......@@ -797,7 +804,6 @@ class CSPStack(tf.keras.layers.Layer):
"""CSPStack layer initializer.
Args:
filters: integer for output depth, or the number of features to learn.
model_to_wrap: callable Model or a list of callable objects that will
process the output of CSPRoute, and be input into CSPConnect.
list will be called sequentially.
......
......@@ -16,6 +16,8 @@
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import numpy as np
from absl.testing import parameterized
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
......@@ -71,7 +73,7 @@ class CSPRouteTest(tf.test.TestCase, parameterized.TestCase):
def test_pass_through(self, width, height, filters, mod):
x = tf.keras.Input(shape=(width, height, filters))
test_layer = nn_blocks.CSPRoute(filters=filters, filter_scale=mod)
outx, _ = test_layer(x)
outx, px = test_layer(x)
print(outx)
print(outx.shape.as_list())
self.assertAllEqual(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment