Commit 588d6da4 authored by Jaeyoun Kim, committed by A. Unique TensorFlower

Copybara import of the project:

--
63719f08

 by Anirudh Vegesana <anirudh.vegesana@gmail.com>:

YOLO Family: Updated model (#9923)

* Update YOLO model

* Fix some docstrings

* Fix docstrings

* Address some of Dr. Davis' changes

* Give descriptive names to the test cases

* Fix bugs

* Fix YOLO head imports

* docstring and variable name updates

* docstring and variable name updates

* docstring and variable name updates
Co-authored-by: vishnubanna <banna3vishnu@gmail.com>
Co-authored-by: Vishnu Banna <43182884+vishnubanna@users.noreply.github.com>
--
725b8c8c

 by Anirudh Vegesana <anirudh.vegesana@gmail.com>:

disclaimer (#10020)
Co-authored-by: Vishnu Banna <43182884+vishnubanna@users.noreply.github.com>
--
404d24b0

 by Anirudh Vegesana <anirudh.vegesana@gmail.com>:

YOLO Family: Linting (#10027)

* YOLO Family: Updated model (#9923)

* Update YOLO model

* Fix some docstrings

* Fix docstrings

* Address some of Dr. Davis' changes

* Give descriptive names to the test cases

* Fix bugs

* Fix YOLO head imports

* docstring and variable name updates

* docstring and variable name updates

* docstring and variable name updates
Co-authored-by: vishnubanna <banna3vishnu@gmail.com>
Co-authored-by: Vishnu Banna <43182884+vishnubanna@users.noreply.github.com>

* disclaimer

* Fix some PyLint errors
Co-authored-by: vishnubanna <banna3vishnu@gmail.com>
Co-authored-by: Vishnu Banna <43182884+vishnubanna@users.noreply.github.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/models/pull/10021 from tensorflow:purdue-yolo 404d24b0
PiperOrigin-RevId: 379372162
parent e15c0aec
DISCLAIMER: This YOLO implementation is still under development. No support
will be provided during the development phase.
# YOLO Object Detectors, You Only Look Once
[![Paper](http://img.shields.io/badge/Paper-arXiv.1804.02767-B3181B?logo=arXiv)](https://arxiv.org/abs/1804.02767)
@@ -74,3 +77,5 @@ head could be connected to a new, more powerful backbone if a person chose to.
[![TensorFlow 2.2](https://img.shields.io/badge/TensorFlow-2.2-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.2.0)
[![Python 3.8](https://img.shields.io/badge/Python-3.8-3776AB)](https://www.python.org/downloads/release/python-380/)
@@ -24,11 +24,14 @@ from official.vision.beta.configs import backbones
@dataclasses.dataclass
class Darknet(hyperparams.Config):
  """Darknet config."""
  model_id: str = 'darknet53'
  width_scale: float = 1.0
  depth_scale: float = 1.0
  dilate: bool = False


@dataclasses.dataclass
class Backbone(backbones.Backbone):
  darknet: Darknet = Darknet()
@@ -32,7 +32,7 @@ class ImageClassificationModel(hyperparams.Config):
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='darknet', darknet=backbones.Darknet())
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = common.NormActivation()
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
@@ -13,7 +13,7 @@
# limitations under the License.
# Lint as: python3
"""Tests for yolo."""
from absl.testing import parameterized
import numpy as np
@@ -24,35 +24,48 @@ from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.projects.yolo.modeling.backbones import darknet


class DarknetTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (224, 'darknet53', 2, 1, True),
      (224, 'darknettiny', 1, 2, False),
      (224, 'cspdarknettiny', 1, 1, False),
      (224, 'cspdarknet53', 2, 1, True),
  )
  def test_network_creation(self, input_size, model_id, endpoint_filter_scale,
                            scale_final, dilate):
"""Test creation of ResNet family models.""" """Test creation of ResNet family models."""
    tf.keras.backend.set_image_data_format('channels_last')
    network = darknet.Darknet(
        model_id=model_id, min_level=3, max_level=5, dilate=dilate)
    self.assertEqual(network.model_id, model_id)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    endpoints = network(inputs)

    if dilate:
      self.assertAllEqual([
          1, input_size / 2**3, input_size / 2**3, 128 * endpoint_filter_scale
      ], endpoints['3'].shape.as_list())
      self.assertAllEqual([
          1, input_size / 2**3, input_size / 2**3, 256 * endpoint_filter_scale
      ], endpoints['4'].shape.as_list())
      self.assertAllEqual([
          1, input_size / 2**3, input_size / 2**3,
          512 * endpoint_filter_scale * scale_final
      ], endpoints['5'].shape.as_list())
    else:
      self.assertAllEqual([
          1, input_size / 2**3, input_size / 2**3, 128 * endpoint_filter_scale
      ], endpoints['3'].shape.as_list())
      self.assertAllEqual([
          1, input_size / 2**4, input_size / 2**4, 256 * endpoint_filter_scale
      ], endpoints['4'].shape.as_list())
      self.assertAllEqual([
          1, input_size / 2**5, input_size / 2**5,
          512 * endpoint_filter_scale * scale_final
      ], endpoints['5'].shape.as_list())

  @combinations.generate(
      combinations.combine(
@@ -66,20 +79,20 @@ class DarkNetTest(parameterized.TestCase, tf.test.TestCase):
    """Test for sync bn on TPU and GPU devices."""
    inputs = np.random.rand(1, 224, 224, 3)
    tf.keras.backend.set_image_data_format('channels_last')
    with strategy.scope():
      network = darknet.Darknet(
          model_id='darknet53', min_level=3, max_level=5)
      _ = network(inputs)

  @parameterized.parameters(1, 3, 4)
  def test_input_specs(self, input_dim):
    """Test different input feature dimensions."""
    tf.keras.backend.set_image_data_format('channels_last')
    input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, input_dim])
    network = darknet.Darknet(
        model_id='darknet53', min_level=3, max_level=5, input_specs=input_specs)
    inputs = tf.keras.Input(shape=(224, 224, input_dim), batch_size=1)
    _ = network(inputs)
@@ -87,14 +100,14 @@ class DarkNetTest(parameterized.TestCase, tf.test.TestCase):
  def test_serialize_deserialize(self):
    # Create a network object that sets all of its config options.
    kwargs = dict(
        model_id='darknet53',
        min_level=3,
        max_level=5,
        use_sync_bn=False,
        activation='relu',
        norm_momentum=0.99,
        norm_epsilon=0.001,
        kernel_initializer='VarianceScaling',
        kernel_regularizer=None,
        bias_regularizer=None,
    )
@@ -113,5 +126,5 @@ class DarkNetTest(parameterized.TestCase, tf.test.TestCase):
    self.assertAllEqual(network.get_config(), new_network.get_config())

if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Feature Pyramid Network and Path Aggregation variants used in YOLO."""
import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks


@tf.keras.utils.register_keras_serializable(package='yolo')
class _IdentityRoute(tf.keras.layers.Layer):

  def call(self, inputs):
    return None, inputs


@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloFPN(tf.keras.layers.Layer):
  """YOLO Feature pyramid network."""

  def __init__(self,
               fpn_depth=4,
               use_spatial_attention=False,
               csp_stack=False,
               activation='leaky',
               fpn_filter_scale=1,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='glorot_uniform',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """Yolo FPN initialization function (Yolo V4).

    Args:
      fpn_depth: `int`, number of layers to use in each FPN path if you choose
        to use an FPN.
      use_spatial_attention: `bool`, use the spatial attention module.
      csp_stack: `bool`, CSPize the FPN.
      activation: `str`, the activation function to use, typically leaky or
        mish.
      fpn_filter_scale: `int`, scaling factor for the FPN filters.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float`, normalization momentum for the moving average.
      norm_epsilon: `float`, small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    self._fpn_depth = fpn_depth

    self._activation = activation
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_spatial_attention = use_spatial_attention
    self._filter_scale = fpn_filter_scale
    self._csp_stack = csp_stack

    self._base_config = dict(
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        kernel_regularizer=self._kernel_regularizer,
        kernel_initializer=self._kernel_initializer,
        bias_regularizer=self._bias_regularizer,
        norm_epsilon=self._norm_epsilon,
        norm_momentum=self._norm_momentum)

  def get_raw_depths(self, minimum_depth, inputs):
    """Calculates the unscaled depths of the FPN branches.

    Args:
      minimum_depth (int): depth of the smallest branch of the FPN.
      inputs (dict): dictionary of the shape of input args as a dictionary of
        lists.

    Returns:
      The unscaled depths of the FPN branches.
    """
    depths = []
    for i in range(self._min_level, self._max_level + 1):
      depths.append(inputs[str(i)][-1] / self._filter_scale)
    return list(reversed(depths))

  def build(self, inputs):
    """Use the config dictionary to generate all important attributes.

    Args:
      inputs: dictionary of the shape of input args as a dictionary of lists.
    """
    keys = [int(key) for key in inputs.keys()]
    self._min_level = min(keys)
    self._max_level = max(keys)
    self._min_depth = inputs[str(self._min_level)][-1]
    self._depths = self.get_raw_depths(self._min_depth, inputs)

    # directly connect to an input path and process it
    self.preprocessors = dict()
    # resample an input and merge it with the output of another path
    # in order to aggregate backbone outputs
    self.resamples = dict()
    # set of convolution layers and upsample layers that are used to
    # prepare the FPN processors for output

    for level, depth in zip(
        reversed(range(self._min_level, self._max_level + 1)), self._depths):
      if level == self._min_level:
        self.resamples[str(level)] = nn_blocks.PathAggregationBlock(
            filters=depth // 2,
            inverted=True,
            upsample=True,
            drop_final=self._csp_stack == 0,
            upsample_size=2,
            **self._base_config)
        self.preprocessors[str(level)] = _IdentityRoute()
      elif level != self._max_level:
        self.resamples[str(level)] = nn_blocks.PathAggregationBlock(
            filters=depth // 2,
            inverted=True,
            upsample=True,
            drop_final=False,
            upsample_size=2,
            **self._base_config)
        self.preprocessors[str(level)] = nn_blocks.DarkRouteProcess(
            filters=depth,
            repetitions=self._fpn_depth - int(level == self._min_level),
            block_invert=True,
            insert_spp=False,
            csp_stack=self._csp_stack,
            **self._base_config)
      else:
        self.preprocessors[str(level)] = nn_blocks.DarkRouteProcess(
            filters=depth,
            repetitions=self._fpn_depth + 1 * int(self._csp_stack == 0),
            insert_spp=True,
            block_invert=False,
            csp_stack=self._csp_stack,
            **self._base_config)

  def call(self, inputs):
    outputs = dict()
    layer_in = inputs[str(self._max_level)]
    for level in reversed(range(self._min_level, self._max_level + 1)):
      _, x = self.preprocessors[str(level)](layer_in)
      outputs[str(level)] = x
      if level > self._min_level:
        x_next = inputs[str(level - 1)]
        _, layer_in = self.resamples[str(level - 1)]([x_next, x])
    return outputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloPAN(tf.keras.layers.Layer):
  """YOLO Path Aggregation Network."""

  def __init__(self,
               path_process_len=6,
               max_level_process_len=None,
               embed_spp=False,
               use_spatial_attention=False,
               csp_stack=False,
               activation='leaky',
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='glorot_uniform',
               kernel_regularizer=None,
               bias_regularizer=None,
               fpn_input=True,
               fpn_filter_scale=1.0,
               **kwargs):
    """Yolo Path Aggregation Network initialization function (Yolo V3 and V4).

    Args:
      path_process_len: `int`, number of layers to use in each decoder path.
      max_level_process_len: `int`, number of layers to use in the largest
        processing path, or the backbone's largest output if it is different.
      embed_spp: `bool`, use the SPP found in the YoloV3 and V4 model.
      use_spatial_attention: `bool`, use the spatial attention module.
      csp_stack: `bool`, CSPize the FPN.
      activation: `str`, the activation function to use, typically leaky or
        mish.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float`, normalization momentum for the moving average.
      norm_epsilon: `float`, small float added to variance to avoid dividing
        by zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      fpn_input: `bool`, whether the input into this network is an FPN or a
        backbone.
      fpn_filter_scale: `int`, scaling factor for the FPN filters.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)

    self._path_process_len = path_process_len
    self._embed_spp = embed_spp
    self._use_spatial_attention = use_spatial_attention

    self._activation = activation
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._fpn_input = fpn_input
    self._max_level_process_len = max_level_process_len
    self._csp_stack = csp_stack
    self._fpn_filter_scale = fpn_filter_scale

    if max_level_process_len is None:
      self._max_level_process_len = path_process_len

    self._base_config = dict(
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        kernel_regularizer=self._kernel_regularizer,
        kernel_initializer=self._kernel_initializer,
        bias_regularizer=self._bias_regularizer,
        norm_epsilon=self._norm_epsilon,
        norm_momentum=self._norm_momentum)

  def build(self, inputs):
    """Use the config dictionary to generate all important attributes.

    Args:
      inputs: dictionary of the shape of input args as a dictionary of lists.
    """
    # define the key order
    keys = [int(key) for key in inputs.keys()]
    self._min_level = min(keys)
    self._max_level = max(keys)
    self._min_depth = inputs[str(self._min_level)][-1]
    self._depths = self.get_raw_depths(self._min_depth, inputs)

    # directly connect to an input path and process it
    self.preprocessors = dict()
    # resample an input and merge it with the output of another path
    # in order to aggregate backbone outputs
    self.resamples = dict()

    # FPN will reverse the key process order for the backbone, so we need to
    # adjust the order that objects are created and processed to account for
    # this. Not using an FPN will directly connect the decoder to the
    # backbone, so the object creation order needs to go from the largest to
    # the smallest level.
    if self._fpn_input:
      # process order {... 3, 4, 5}
      self._iterator = range(self._min_level, self._max_level + 1)
      self._check = lambda x: x < self._max_level
      self._key_shift = lambda x: x + 1
      self._input = self._min_level
      downsample = True
      upsample = False
    else:
      # process order {5, 4, 3, ...}
      self._iterator = list(
          reversed(range(self._min_level, self._max_level + 1)))
      self._check = lambda x: x > self._min_level
      self._key_shift = lambda x: x - 1
      self._input = self._max_level
      downsample = False
      upsample = True

    if self._csp_stack == 0:
      proc_filters = lambda x: x
      resample_filters = lambda x: x // 2
    else:
      proc_filters = lambda x: x * 2
      resample_filters = lambda x: x

    for level, depth in zip(self._iterator, self._depths):
      if level == self._input:
        self.preprocessors[str(level)] = nn_blocks.DarkRouteProcess(
            filters=proc_filters(depth),
            repetitions=self._max_level_process_len,
            insert_spp=self._embed_spp,
            block_invert=False,
            insert_sam=self._use_spatial_attention,
            csp_stack=self._csp_stack,
            **self._base_config)
      else:
        self.resamples[str(level)] = nn_blocks.PathAggregationBlock(
            filters=resample_filters(depth),
            upsample=upsample,
            downsample=downsample,
            inverted=False,
            drop_final=self._csp_stack == 0,
            **self._base_config)
        self.preprocessors[str(level)] = nn_blocks.DarkRouteProcess(
            filters=proc_filters(depth),
            repetitions=self._path_process_len,
            insert_spp=False,
            insert_sam=self._use_spatial_attention,
            csp_stack=self._csp_stack,
            **self._base_config)

  def get_raw_depths(self, minimum_depth, inputs):
    """Calculates the unscaled depths of the FPN branches.

    Args:
      minimum_depth: `int` depth of the smallest branch of the FPN.
      inputs: `dict[str, tf.InputSpec]` of the shape of input args as a
        dictionary of lists.

    Returns:
      The unscaled depths of the FPN branches.
    """
    depths = []
    if len(inputs.keys()) > 3 or self._fpn_filter_scale > 1:
      for i in range(self._min_level, self._max_level + 1):
        depths.append(inputs[str(i)][-1] * 2)
    else:
      for _ in range(self._min_level, self._max_level + 1):
        depths.append(minimum_depth)
        minimum_depth *= 2
    if self._fpn_input:
      return depths
    return list(reversed(depths))

  def call(self, inputs):
    outputs = dict()
    layer_in = inputs[str(self._input)]

    for level in self._iterator:
      x_route, x = self.preprocessors[str(level)](layer_in)
      outputs[str(level)] = x
      if self._check(level):
        x_next = inputs[str(self._key_shift(level))]
        _, layer_in = self.resamples[str(
            self._key_shift(level))]([x_route, x_next])
    return outputs
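
# A worked example of the depth bookkeeping above (hypothetical three-level
# input): given {'3': [1, 52, 52, 256], '4': [1, 26, 26, 512],
# '5': [1, 13, 13, 1024]} with fpn_filter_scale=1, the minimum depth is 256,
# so get_raw_depths returns [256, 512, 1024] when fpn_input=True (process
# order 3 -> 5) and [1024, 512, 256] when fpn_input=False (process order
# 5 -> 3), matching the per-level iteration order set up in build().
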
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloDecoder(tf.keras.Model):
  """Darknet Backbone Decoder."""

  def __init__(self,
               input_specs,
               use_fpn=False,
               use_spatial_attention=False,
               csp_stack=False,
               fpn_depth=4,
               fpn_filter_scale=1,
               path_process_len=6,
               max_level_process_len=None,
               embed_spp=False,
               activation='leaky',
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='glorot_uniform',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """Yolo Decoder initialization function.

    A unified model that ties all decoder components into a conditionally
    built YOLO decoder.

    Args:
      input_specs: `dict[str, tf.InputSpec]`: input specs of each of the
        inputs to the heads.
      use_fpn: `bool`, use the FPN found in the YoloV4 model.
      use_spatial_attention: `bool`, use the spatial attention module.
      csp_stack: `bool`, CSPize the FPN.
      fpn_depth: `int`, number of layers to use in each FPN path if you choose
        to use an FPN.
      fpn_filter_scale: `int`, scaling factor for the FPN filters.
      path_process_len: `int`, number of layers to use in each decoder path.
      max_level_process_len: `int`, number of layers to use in the largest
        processing path, or the backbone's largest output if it is different.
      embed_spp: `bool`, use the SPP found in the YoloV3 and V4 model.
      activation: `str`, the activation function to use, typically leaky or
        mish.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float`, normalization momentum for the moving average.
      norm_epsilon: `float`, small float added to variance to avoid dividing
        by zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      **kwargs: keyword arguments to be passed.
    """
    self._input_specs = input_specs
    self._use_fpn = use_fpn
    self._fpn_depth = fpn_depth
    self._path_process_len = path_process_len
    self._max_level_process_len = max_level_process_len
    self._embed_spp = embed_spp

    self._activation = activation
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer

    self._base_config = dict(
        use_spatial_attention=use_spatial_attention,
        csp_stack=csp_stack,
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        fpn_filter_scale=fpn_filter_scale,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)

    self._decoder_config = dict(
        path_process_len=self._path_process_len,
        max_level_process_len=self._max_level_process_len,
        embed_spp=self._embed_spp,
        fpn_input=self._use_fpn,
        **self._base_config)

    inputs = {
        key: tf.keras.layers.Input(shape=value[1:])
        for key, value in input_specs.items()
    }
    if self._use_fpn:
      inter_outs = YoloFPN(
          fpn_depth=self._fpn_depth, **self._base_config)(
              inputs)
      outputs = YoloPAN(**self._decoder_config)(inter_outs)
    else:
      inter_outs = None
      outputs = YoloPAN(**self._decoder_config)(inputs)

    self._output_specs = {key: value.shape for key, value in outputs.items()}
    super().__init__(inputs=inputs, outputs=outputs, name='YoloDecoder')

  @property
  def use_fpn(self):
    return self._use_fpn

  @property
  def output_specs(self):
    return self._output_specs

  def get_config(self):
    config = dict(
        input_specs=self._input_specs,
        use_fpn=self._use_fpn,
        fpn_depth=self._fpn_depth,
        **self._decoder_config)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
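
# A minimal usage sketch for YoloDecoder (hypothetical shapes, mirroring the
# unit tests for this module rather than a tested call):
#
#   input_specs = {
#       '3': [1, 52, 52, 256],
#       '4': [1, 26, 26, 512],
#       '5': [1, 13, 13, 1024],
#   }
#   decoder = YoloDecoder(input_specs, use_fpn=True, activation='mish')
#   outputs = decoder({k: tf.ones(v) for k, v in input_specs.items()})
#   # outputs keeps the input keys; decoder.output_specs records each shape.
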
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for YOLO."""

# Import libraries
from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.projects.yolo.modeling.decoders import yolo_decoder as decoders


class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):

  def _build_yolo_decoder(self, input_specs, name='1'):
    # Builds 4 different arbitrary decoders.
    if name == '1':
      model = decoders.YoloDecoder(
          input_specs=input_specs,
          embed_spp=False,
          use_fpn=False,
          max_level_process_len=2,
          path_process_len=1,
          activation='mish')
    elif name == '6spp':
      model = decoders.YoloDecoder(
          input_specs=input_specs,
          embed_spp=True,
          use_fpn=False,
          max_level_process_len=None,
          path_process_len=6,
          activation='mish')
    elif name == '6sppfpn':
      model = decoders.YoloDecoder(
          input_specs=input_specs,
          embed_spp=True,
          use_fpn=True,
          max_level_process_len=None,
          path_process_len=6,
          activation='mish')
    elif name == '6':
      model = decoders.YoloDecoder(
          input_specs=input_specs,
          embed_spp=False,
          use_fpn=False,
          max_level_process_len=None,
          path_process_len=6,
          activation='mish')
    else:
      raise NotImplementedError(f'YOLO decoder test {name} not implemented.')
    return model

  @parameterized.parameters('1', '6spp', '6sppfpn', '6')
  def test_network_creation(self, version):
    """Test creation of YOLO decoder models."""
    tf.keras.backend.set_image_data_format('channels_last')
    input_shape = {
        '3': [1, 52, 52, 256],
        '4': [1, 26, 26, 512],
        '5': [1, 13, 13, 1024]
    }
    decoder = self._build_yolo_decoder(input_shape, version)

    inputs = {}
    for key in input_shape:
      inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)

    endpoints = decoder.call(inputs)
    for key in endpoints.keys():
      self.assertAllEqual(endpoints[key].shape.as_list(), input_shape[key])

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          use_sync_bn=[False, True],
      ))
  def test_sync_bn_multiple_devices(self, strategy, use_sync_bn):
    """Test for sync bn on TPU and GPU devices."""
    tf.keras.backend.set_image_data_format('channels_last')
    with strategy.scope():
      input_shape = {
          '3': [1, 52, 52, 256],
          '4': [1, 26, 26, 512],
          '5': [1, 13, 13, 1024]
      }
      decoder = self._build_yolo_decoder(input_shape, '6')

      inputs = {}
      for key in input_shape:
        inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
      _ = decoder.call(inputs)

  @parameterized.parameters(1, 3, 4)
  def test_input_specs(self, input_dim):
    """Test different input feature dimensions."""
    tf.keras.backend.set_image_data_format('channels_last')
    input_shape = {
        '3': [1, 52, 52, 256],
        '4': [1, 26, 26, 512],
        '5': [1, 13, 13, 1024]
    }
    decoder = self._build_yolo_decoder(input_shape, '6')

    inputs = {}
    for key in input_shape:
      inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
    _ = decoder(inputs)

  def test_serialize_deserialize(self):
    """Create a network object that sets all of its config options."""
    tf.keras.backend.set_image_data_format('channels_last')
    input_shape = {
        '3': [1, 52, 52, 256],
        '4': [1, 26, 26, 512],
        '5': [1, 13, 13, 1024]
    }
    decoder = self._build_yolo_decoder(input_shape, '6')

    inputs = {}
    for key in input_shape:
      inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
    _ = decoder(inputs)

    config = decoder.get_config()
    decoder_from_config = decoders.YoloDecoder.from_config(config)
    self.assertAllEqual(decoder.get_config(), decoder_from_config.get_config())


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Yolo heads."""
import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks


class YoloHead(tf.keras.layers.Layer):
  """YOLO Prediction Head."""

  def __init__(self,
               min_level,
               max_level,
               classes=80,
               boxes_per_level=3,
               output_extras=0,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='glorot_uniform',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation=None,
               **kwargs):
    """Yolo Prediction Head initialization function.

    Args:
      min_level: `int`, the minimum backbone output level.
      max_level: `int`, the maximum backbone output level.
      classes: `int`, number of classes per category.
      boxes_per_level: `int`, number of boxes to predict per level.
      output_extras: `int`, number of additional output channels that the head
        should predict for non-object detection and non-image classification
        tasks.
      norm_momentum: `float`, normalization momentum for the moving average.
      norm_epsilon: `float`, small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      activation: `str`, the activation function to use, typically leaky or
        mish.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    self._min_level = min_level
    self._max_level = max_level

    self._key_list = [
        str(key) for key in range(self._min_level, self._max_level + 1)
    ]

    self._classes = classes
    self._boxes_per_level = boxes_per_level
    self._output_extras = output_extras

    self._output_conv = (classes + output_extras + 5) * boxes_per_level

    self._base_config = dict(
        activation=activation,
        norm_momentum=norm_momentum,
        norm_epsilon=norm_epsilon,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer)

    self._conv_config = dict(
        filters=self._output_conv,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding='same',
        use_bn=False,
        **self._base_config)

  def build(self, input_shape):
    self._head = dict()
    for key in self._key_list:
      self._head[key] = nn_blocks.ConvBN(**self._conv_config)

  def call(self, inputs):
    outputs = dict()
    for key in self._key_list:
      outputs[key] = self._head[key](inputs[key])
    return outputs

  @property
  def output_depth(self):
    return (self._classes + self._output_extras + 5) * self._boxes_per_level

  @property
  def num_boxes(self):
    if self._min_level is None or self._max_level is None:
      raise Exception(
          'Model has to be built before number of boxes can be determined.')
    return (self._max_level - self._min_level + 1) * self._boxes_per_level

  def get_config(self):
    config = dict(
        min_level=self._min_level,
        max_level=self._max_level,
        classes=self._classes,
        boxes_per_level=self._boxes_per_level,
        output_extras=self._output_extras,
        **self._base_config)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
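
# A minimal usage sketch for YoloHead (hypothetical values): with levels 3-5,
# 80 classes, and 3 boxes per level, each level predicts
# (80 + 0 + 5) * 3 = 255 channels (4 box coordinates + 1 objectness score +
# 80 class scores per box), the familiar YOLOv3 output depth.
#
#   head = YoloHead(min_level=3, max_level=5, classes=80, boxes_per_level=3)
#   preds = head(decoder_outputs)  # dict keyed '3'..'5', 255 channels each
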
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for yolo heads."""

# Import libraries
from absl.testing import parameterized
import tensorflow as tf

from official.vision.beta.projects.yolo.modeling.heads import yolo_head as heads


class YoloHeadTest(parameterized.TestCase, tf.test.TestCase):

  def test_network_creation(self):
    """Test creation of YOLO family models."""
    tf.keras.backend.set_image_data_format('channels_last')
    input_shape = {
        '3': [1, 52, 52, 256],
        '4': [1, 26, 26, 512],
        '5': [1, 13, 13, 1024]
    }
    classes = 100
    bps = 3
    head = heads.YoloHead(3, 5, classes=classes, boxes_per_level=bps)

    inputs = {}
    for key in input_shape:
      inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)

    endpoints = head(inputs)

    for key in endpoints.keys():
      expected_input_shape = input_shape[key]
      expected_input_shape[-1] = (classes + 5) * bps
      self.assertAllEqual(endpoints[key].shape.as_list(), expected_input_shape)

  def test_serialize_deserialize(self):
    # Create a network object that sets all of its config options.
    tf.keras.backend.set_image_data_format('channels_last')
    input_shape = {
        '3': [1, 52, 52, 256],
        '4': [1, 26, 26, 512],
        '5': [1, 13, 13, 1024]
    }
    classes = 100
    bps = 3
    head = heads.YoloHead(3, 5, classes=classes, boxes_per_level=bps)

    inputs = {}
    for key in input_shape:
      inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)

    _ = head(inputs)
    configs = head.get_config()
    head_from_config = heads.YoloHead.from_config(configs)
    self.assertAllEqual(head.get_config(), head_from_config.get_config())


if __name__ == '__main__':
  tf.test.main()
@@ -13,7 +13,6 @@
# limitations under the License.
# Lint as: python3
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
@@ -23,8 +22,8 @@ from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
class CSPConnectTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(('same', 224, 224, 64, 1),
                                  ('downsample', 224, 224, 64, 2))
  def test_pass_through(self, width, height, filters, mod):
    x = tf.keras.Input(shape=(width, height, filters))
    test_layer = nn_blocks.CSPRoute(filters=filters, filter_scale=mod)
@@ -38,8 +37,8 @@ class CSPConnectTest(tf.test.TestCase, parameterized.TestCase):
        [None, np.ceil(width // 2),
         np.ceil(height // 2), (filters)])

  @parameterized.named_parameters(('same', 224, 224, 64, 1),
                                  ('downsample', 224, 224, 128, 2))
  def test_gradient_pass_though(self, filters, width, height, mod):
    loss = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.SGD()
@@ -49,10 +48,11 @@ class CSPConnectTest(tf.test.TestCase, parameterized.TestCase):
    init = tf.random_normal_initializer()
    x = tf.Variable(
        initial_value=init(shape=(1, width, height, filters), dtype=tf.float32))
    y = tf.Variable(
        initial_value=init(
            shape=(1, int(np.ceil(width // 2)), int(np.ceil(height // 2)),
                   filters),
            dtype=tf.float32))

    with tf.GradientTape() as tape:
      x_hat, x_prev = test_layer(x)
@@ -66,8 +66,8 @@ class CSPConnectTest(tf.test.TestCase, parameterized.TestCase):
class CSPRouteTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(('same', 224, 224, 64, 1),
                                  ('downsample', 224, 224, 64, 2))
  def test_pass_through(self, width, height, filters, mod):
    x = tf.keras.Input(shape=(width, height, filters))
    test_layer = nn_blocks.CSPRoute(filters=filters, filter_scale=mod)
@@ -79,8 +79,8 @@ class CSPRouteTest(tf.test.TestCase, parameterized.TestCase):
        [None, np.ceil(width // 2),
         np.ceil(height // 2), (filters / mod)])

  @parameterized.named_parameters(('same', 224, 224, 64, 1),
                                  ('downsample', 224, 224, 128, 2))
  def test_gradient_pass_though(self, filters, width, height, mod):
    loss = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.SGD()
@@ -90,10 +90,11 @@ class CSPRouteTest(tf.test.TestCase, parameterized.TestCase):
    init = tf.random_normal_initializer()
    x = tf.Variable(
        initial_value=init(shape=(1, width, height, filters), dtype=tf.float32))
    y = tf.Variable(
        initial_value=init(
            shape=(1, int(np.ceil(width // 2)), int(np.ceil(height // 2)),
                   filters),
            dtype=tf.float32))

    with tf.GradientTape() as tape:
      x_hat, x_prev = test_layer(x)
@@ -107,11 +108,11 @@ class CSPRouteTest(tf.test.TestCase, parameterized.TestCase):
class CSPStackTest(tf.test.TestCase, parameterized.TestCase):

  def build_layer(self, layer_type, filters, filter_scale, count, stack_type,
                  downsample):
    if stack_type is not None:
      layers = []
      if layer_type == 'residual':
        for _ in range(count):
          layers.append(
              nn_blocks.DarkResidual(
@@ -120,7 +121,7 @@ class CSPStackTest(tf.test.TestCase, parameterized.TestCase):
        for _ in range(count):
          layers.append(nn_blocks.ConvBN(filters=filters))

      if stack_type == 'model':
        layers = tf.keras.Sequential(layers=layers)
    else:
      layers = None
@@ -133,10 +134,10 @@ class CSPStackTest(tf.test.TestCase, parameterized.TestCase):
    return stack

  @parameterized.named_parameters(
      ('no_stack', 224, 224, 64, 2, 'residual', None, 0, True),
      ('residual_stack', 224, 224, 64, 2, 'residual', 'list', 2, True),
      ('conv_stack', 224, 224, 64, 2, 'conv', 'list', 3, False),
      ('callable_no_scale', 224, 224, 64, 1, 'residual', 'model', 5, False))
  def test_pass_through(self, width, height, filters, mod, layer_type,
                        stack_type, count, downsample):
    x = tf.keras.Input(shape=(width, height, filters))
@@ -152,10 +153,10 @@ class CSPStackTest(tf.test.TestCase, parameterized.TestCase):
    self.assertAllEqual(outx.shape.as_list(), [None, width, height, filters])

  @parameterized.named_parameters(
      ('no_stack', 224, 224, 64, 2, 'residual', None, 0, True),
      ('residual_stack', 224, 224, 64, 2, 'residual', 'list', 2, True),
      ('conv_stack', 224, 224, 64, 2, 'conv', 'list', 3, False),
      ('callable_no_scale', 224, 224, 64, 1, 'residual', 'model', 5, False))
  def test_gradient_pass_though(self, width, height, filters, mod, layer_type,
                                stack_type, count, downsample):
    loss = tf.keras.losses.MeanSquaredError()
@@ -188,10 +189,10 @@ class CSPStackTest(tf.test.TestCase, parameterized.TestCase):
class ConvBNTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ('valid', (3, 3), 'valid', (1, 1)), ('same', (3, 3), 'same', (1, 1)),
      ('downsample', (3, 3), 'same', (2, 2)), ('test', (1, 1), 'valid', (1, 1)))
  def test_pass_through(self, kernel_size, padding, strides):
    if padding == 'same':
      pad_const = 1
    else:
      pad_const = 0
@@ -212,16 +213,16 @@ class ConvBNTest(tf.test.TestCase, parameterized.TestCase):
    print(test)
    self.assertAllEqual(outx.shape.as_list(), test)

  @parameterized.named_parameters(('filters', 3))
  def test_gradient_pass_though(self, filters):
    loss = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.SGD()
    with tf.device('/CPU:0'):
      test_layer = nn_blocks.ConvBN(filters, kernel_size=(3, 3), padding='same')

    init = tf.random_normal_initializer()
    x = tf.Variable(
        initial_value=init(shape=(1, 224, 224, 3), dtype=tf.float32))
    y = tf.Variable(
        initial_value=init(shape=(1, 224, 224, filters), dtype=tf.float32))
@@ -235,9 +236,9 @@ class ConvBNTest(tf.test.TestCase, parameterized.TestCase):
class DarkResidualTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(('same', 224, 224, 64, False),
                                  ('downsample', 223, 223, 32, True),
                                  ('oddball', 223, 223, 32, False))
  def test_pass_through(self, width, height, filters, downsample):
    mod = 1
    if downsample:
@@ -252,9 +253,9 @@ class DarkResidualTest(tf.test.TestCase, parameterized.TestCase):
        [None, np.ceil(width / mod),
         np.ceil(height / mod), filters])

  @parameterized.named_parameters(('same', 64, 224, 224, False),
                                  ('downsample', 32, 223, 223, True),
                                  ('oddball', 32, 223, 223, False))
  def test_gradient_pass_though(self, filters, width, height, downsample):
    loss = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.SGD()
@@ -268,10 +269,11 @@ class DarkResidualTest(tf.test.TestCase, parameterized.TestCase):
    init = tf.random_normal_initializer()
    x = tf.Variable(
        initial_value=init(shape=(1, width, height, filters), dtype=tf.float32))
    y = tf.Variable(
        initial_value=init(
            shape=(1, int(np.ceil(width / mod)), int(np.ceil(height / mod)),
                   filters),
            dtype=tf.float32))

    with tf.GradientTape() as tape:
      x_hat = test_layer(x)
@@ -281,5 +283,104 @@ class DarkResidualTest(tf.test.TestCase, parameterized.TestCase):
    self.assertNotIn(None, grad)
class DarkSppTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(('RouteProcessSpp', 224, 224, 3, [5, 9, 13]),
                                  ('test1', 300, 300, 10, [2, 3, 4, 5]),
                                  ('test2', 256, 256, 5, [10]))
  def test_pass_through(self, width, height, channels, sizes):
    x = tf.keras.Input(shape=(width, height, channels))
    test_layer = nn_blocks.SPP(sizes=sizes)
    outx = test_layer(x)
    self.assertAllEqual(outx.shape.as_list(),
                        [None, width, height, channels * (len(sizes) + 1)])

  @parameterized.named_parameters(('RouteProcessSpp', 224, 224, 3, [5, 9, 13]),
                                  ('test1', 300, 300, 10, [2, 3, 4, 5]),
                                  ('test2', 256, 256, 5, [10]))
  def test_gradient_pass_though(self, width, height, channels, sizes):
    loss = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.SGD()
    test_layer = nn_blocks.SPP(sizes=sizes)

    init = tf.random_normal_initializer()
    x = tf.Variable(
        initial_value=init(
            shape=(1, width, height, channels), dtype=tf.float32))
    y = tf.Variable(
        initial_value=init(
            shape=(1, width, height, channels * (len(sizes) + 1)),
            dtype=tf.float32))

    with tf.GradientTape() as tape:
      x_hat = test_layer(x)
      grad_loss = loss(x_hat, y)
    grad = tape.gradient(grad_loss, test_layer.trainable_variables)
    optimizer.apply_gradients(zip(grad, test_layer.trainable_variables))

    self.assertNotIn(None, grad)
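
# For reference: nn_blocks.SPP concatenates its input with one max-pooled
# copy per kernel size, so the channel count grows to
# channels * (len(sizes) + 1). For example, sizes=[5, 9, 13] on a 3-channel
# input yields 3 * (3 + 1) = 12 output channels, which is exactly what the
# shape assertions in DarkSppTest check.
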

class DarkRouteProcessTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ('test1', 224, 224, 64, 7, False), ('test2', 223, 223, 32, 3, False),
      ('tiny', 223, 223, 16, 1, False), ('spp', 224, 224, 64, 7, False))
  def test_pass_through(self, width, height, filters, repetitions, spp):
    x = tf.keras.Input(shape=(width, height, filters))
    test_layer = nn_blocks.DarkRouteProcess(
        filters=filters, repetitions=repetitions, insert_spp=spp)
    outx = test_layer(x)
    self.assertLen(outx, 2, msg='len(outx) != 2')
    if repetitions == 1:
      filter_y1 = filters
    else:
      filter_y1 = filters // 2
    self.assertAllEqual(
        outx[1].shape.as_list(), [None, width, height, filter_y1])
    self.assertAllEqual(
        filters % 2,
        0,
        msg='Output of a DarkRouteProcess layer has an odd number of filters')
    self.assertAllEqual(outx[0].shape.as_list(), [None, width, height, filters])

  @parameterized.named_parameters(
      ('test1', 224, 224, 64, 7, False), ('test2', 223, 223, 32, 3, False),
      ('tiny', 223, 223, 16, 1, False), ('spp', 224, 224, 64, 7, False))
  def test_gradient_pass_though(self, width, height, filters, repetitions, spp):
    loss = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.SGD()
    test_layer = nn_blocks.DarkRouteProcess(
        filters=filters, repetitions=repetitions, insert_spp=spp)

    if repetitions == 1:
      filter_y1 = filters
    else:
      filter_y1 = filters // 2

    init = tf.random_normal_initializer()
    x = tf.Variable(
        initial_value=init(shape=(1, width, height, filters), dtype=tf.float32))
    y_0 = tf.Variable(
        initial_value=init(shape=(1, width, height, filters), dtype=tf.float32))
    y_1 = tf.Variable(
        initial_value=init(
            shape=(1, width, height, filter_y1), dtype=tf.float32))

    with tf.GradientTape() as tape:
      x_hat_0, x_hat_1 = test_layer(x)
      grad_loss_0 = loss(x_hat_0, y_0)
      grad_loss_1 = loss(x_hat_1, y_1)
    grad = tape.gradient([grad_loss_0, grad_loss_1],
                         test_layer.trainable_variables)
    optimizer.apply_gradients(zip(grad, test_layer.trainable_variables))

    self.assertNotIn(None, grad)

if __name__ == '__main__':
  tf.test.main()