"external/vscode:/vscode.git/clone" did not exist on "e823d518cb46ad61ddb3c70eac8529e0a58af1f8"
Commit a04d9e0e authored by Vishnu Banna

merged

parents 64f16d61 bcbce005
@@ -26,7 +26,12 @@ from official.vision.beta.configs import backbones
 @dataclasses.dataclass
 class Darknet(hyperparams.Config):
   """Darknet config."""
-  model_id: str = "darknet53"
+  model_id: str = 'darknet53'
+  width_scale: float = 1.0
+  depth_scale: float = 1.0
+  dilate: bool = False
+  min_level: int = 3
+  max_level: int = 5
 @dataclasses.dataclass
...
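Editor's sketch: the new fields make width/depth scaling and dilation configurable per backbone. A minimal, hedged example of overriding them (the import path is assumed from the project layout, not shown in this diff):

# Hedged sketch: assumes the yolo project's backbones config module path.
from official.vision.beta.projects.yolo.configs import backbones

cfg = backbones.Darknet(
    model_id='darknet53',
    width_scale=0.75,  # thin every layer to 75% of its default filter count
    depth_scale=0.75,  # scale down block repetitions
    dilate=False,      # keep strided downsampling instead of dilation
    min_level=3,
    max_level=5)
print(cfg.as_dict())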
@@ -13,7 +13,6 @@
 # limitations under the License.
 # Lint as: python3
 """Contains definitions of Darknet Backbone Networks.
 These models are inspired by ResNet and CSPNet.
@@ -46,16 +45,14 @@ from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
 # builder required classes
 class BlockConfig:
-  """
-  This is a class to store layer config to make code more readable.
-  """
+  """Class to store layer config to make code more readable."""
   def __init__(self, layer, stack, reps, bottleneck, filters, pool_size,
                kernel_size, strides, padding, activation, route, dilation_rate,
                output_name, is_output):
-    """
+    """Initializing method for BlockConfig.
+
     Args:
       layer: A `str` for layer name.
       stack: A `str` for the type of layer ordering to use for this specific
@@ -69,7 +66,7 @@ class BlockConfig:
       padding: An `int` for the padding to apply to layers in this stack.
       activation: A `str` for the activation to use for this stack.
       route: An `int` for the level to route from to get the next input.
       dilation_rate: An `int` for the scale used in dilated Darknet.
       output_name: A `str` for the name to use for this output.
       is_output: A `bool` for whether this layer is an output in the default
         model.
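A hypothetical instantiation to make the positional signature concrete; every value below is illustrative, not copied from the real Darknet53 spec table:

# Hypothetical BlockConfig for one residual stage (values are made up).
block = BlockConfig(
    layer='DarkRes',      # layer name looked up via the LayerBuilder
    stack='residual',     # type of layer ordering for this stage
    reps=2,               # number of repetitions
    bottleneck=False,
    filters=128,
    pool_size=None,
    kernel_size=None,
    strides=None,
    padding=None,
    activation='leaky',
    route=-1,             # level to route from for the next input
    dilation_rate=1,
    output_name=3,        # name to use for this output
    is_output=True)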
@@ -98,11 +95,11 @@ def build_block_specs(config):
 class LayerBuilder:
-  """
-  This is a class that is used for quick look up of default layers used
-  by darknet to connect, introduce or exit a level. Used in place of an
-  if condition or switch to make adding new layers easier and to reduce
-  redundant code.
-  """
+  """Layer builder class.
+
+  Class for quick look up of default layers used by darknet to
+  connect, introduce or exit a level. Used in place of an if condition
+  or switch to make adding new layers easier and to reduce redundant code.
+  """
   def __init__(self):
@@ -378,7 +375,7 @@ BACKBONES = {
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class Darknet(tf.keras.Model):
-  """ The Darknet backbone architecture. """
+  """The Darknet backbone architecture."""
   def __init__(
       self,
@@ -596,8 +593,8 @@ class Darknet(tf.keras.Model):
           filters=config.filters, downsample=True, **self._default_dict)(
               inputs)
-      dilated_reps = config.repetitions - \
-          (self._default_dict['dilation_rate'] // 2) - 1
+      dilated_reps = config.repetitions - (
+          self._default_dict['dilation_rate'] // 2) - 1
       for i in range(dilated_reps):
         self._default_dict['name'] = f'{name}_{i}'
         x = nn_blocks.DarkResidual(
@@ -668,14 +665,13 @@ def build_darknet(
     backbone_config: hyperparams.Config,
     norm_activation_config: hyperparams.Config,
     l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
-  backbone_cfg = model_config.backbone.get()
-  norm_activation_config = model_config.norm_activation
+  """Builds darknet."""
+  backbone_cfg = backbone_config.get()
   model = Darknet(
       model_id=backbone_cfg.model_id,
-      min_level=model_config.min_level,
-      max_level=model_config.max_level,
+      min_level=backbone_cfg.min_level,
+      max_level=backbone_cfg.max_level,
       input_specs=input_specs,
       dilate=backbone_cfg.dilate,
       width_scale=backbone_cfg.width_scale,
@@ -686,4 +682,4 @@ def build_darknet(
       norm_epsilon=norm_activation_config.norm_epsilon,
       kernel_regularizer=l2_regularizer)
   model.summary()
-  return model
\ No newline at end of file
+  return model
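The factory above just forwards config fields into the Darknet constructor. A hedged sketch of the equivalent direct call (module path assumed; the parameter names are the ones visible in this hunk):

import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.backbones import darknet

# Sketch only: constructor arguments mirror those passed by build_darknet.
backbone = darknet.Darknet(
    model_id='darknet53',
    min_level=3,
    max_level=5,
    input_specs=tf.keras.layers.InputSpec(shape=[None, 416, 416, 3]),
    dilate=False,
    width_scale=1.0)
backbone.summary()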
@@ -13,7 +13,7 @@
 # limitations under the License.
 # Lint as: python3
-"""Tests for YOLO."""
+"""Tests for yolo."""
 from absl.testing import parameterized
 import numpy as np
@@ -125,6 +125,5 @@ class DarknetTest(parameterized.TestCase, tf.test.TestCase):
     # If the serialization was successful, the new config should match the old.
     self.assertAllEqual(network.get_config(), new_network.get_config())
 if __name__ == '__main__':
-  tf.test.main()
\ No newline at end of file
+  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -22,15 +22,7 @@ from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class _IdentityRoute(tf.keras.layers.Layer):
-  def __init__(self, **kwargs):
-    """
-    Private class to mirror the outputs of blocks in nn_blocks for an easier
-    programatic generation of the feature pyramid network.
-    """
-    super().__init__(**kwargs)
-
-  def call(self, inputs):  # pylint: disable=arguments-differ
+  def call(self, inputs):
     return None, inputs
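The slimmed-down layer now only mirrors the (route, output) tuple that nn_blocks layers return. A standalone re-creation of that contract, since _IdentityRoute itself is private:

import tensorflow as tf

# Re-creation for illustration only; mirrors the (None, x) contract.
class IdentityRoute(tf.keras.layers.Layer):

  def call(self, inputs):
    return None, inputs

route, out = IdentityRoute()(tf.ones([1, 13, 13, 256]))
assert route is None and out.shape == (1, 13, 13, 256)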
@@ -111,8 +103,7 @@ class YoloFPN(tf.keras.layers.Layer):
     return list(reversed(depths))
   def build(self, inputs):
-    """Use config dictionary to generate all important attributes for head
-    construction.
+    """Use config dictionary to generate all important attributes for head.
     Args:
       inputs: dictionary of the shape of input args as a dictionary of lists.
@@ -127,7 +118,7 @@ class YoloFPN(tf.keras.layers.Layer):
     # directly connect to an input path and process it
     self.preprocessors = dict()
     # resample an input and merge it with the output of another path
     # in order to aggregate backbone outputs
     self.resamples = dict()
     # set of convolution layers and upsample layers that are used to
     # prepare the FPN processors for output
@@ -181,7 +172,7 @@ class YoloFPN(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class YoloPAN(tf.keras.layers.Layer):
-  """YOLO Path Aggregation Network"""
+  """YOLO Path Aggregation Network."""
   def __init__(self,
                path_process_len=6,
@@ -216,7 +207,7 @@ class YoloPAN(tf.keras.layers.Layer):
       kernel_initializer: kernel_initializer for convolutional layers.
       kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
       bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
       fpn_input: `bool`, for whether the input into this function is an FPN or
         a backbone.
       fpn_filter_scale: `int`, scaling factor for the FPN filters.
       **kwargs: keyword arguments to be passed.
@@ -253,8 +244,7 @@ class YoloPAN(tf.keras.layers.Layer):
         norm_momentum=self._norm_momentum)
   def build(self, inputs):
-    """Use config dictionary to generate all important attributes for head
-    construction.
+    """Use config dictionary to generate all important attributes for head.
     Args:
       inputs: dictionary of the shape of input args as a dictionary of lists.
@@ -270,7 +260,7 @@ class YoloPAN(tf.keras.layers.Layer):
     # directly connect to an input path and process it
     self.preprocessors = dict()
     # resample an input and merge it with the output of another path
     # in order to aggregate backbone outputs
     self.resamples = dict()
     # FPN will reverse the key process order for the backbone, so we need
@@ -368,7 +358,7 @@ class YoloPAN(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class YoloDecoder(tf.keras.Model):
-  """Darknet Backbone Decoder"""
+  """Darknet Backbone Decoder."""
   def __init__(self,
                input_specs,
@@ -388,8 +378,10 @@ class YoloDecoder(tf.keras.Model):
                kernel_regularizer=None,
                bias_regularizer=None,
                **kwargs):
-    """Yolo Decoder initialization function. A unified model that ties all
-    decoder components into a conditionally build YOLO decoder.
+    """Yolo Decoder initialization function.
+
+    A unified model that ties all decoder components into a conditionally built
+    YOLO decoder.
     Args:
       input_specs: `dict[str, tf.InputSpec]`: input specs of each of the inputs
@@ -483,4 +475,4 @@ class YoloDecoder(tf.keras.Model):
   @classmethod
   def from_config(cls, config, custom_objects=None):
-    return cls(**config)
\ No newline at end of file
+    return cls(**config)
@@ -17,7 +17,6 @@
 # Import libraries
 from absl.testing import parameterized
-import numpy as np
 import tensorflow as tf
 from tensorflow.python.distribute import combinations
@@ -27,6 +26,45 @@ from official.vision.beta.projects.yolo.modeling.decoders import yolo_decoder as decoders
 class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
+  def _build_yolo_decoder(self, input_specs, name='1'):
+    # Builds 4 different arbitrary decoders.
+    if name == '1':
+      model = decoders.YoloDecoder(
+          input_specs=input_specs,
+          embed_spp=False,
+          use_fpn=False,
+          max_level_process_len=2,
+          path_process_len=1,
+          activation='mish')
+    elif name == '6spp':
+      model = decoders.YoloDecoder(
+          input_specs=input_specs,
+          embed_spp=True,
+          use_fpn=False,
+          max_level_process_len=None,
+          path_process_len=6,
+          activation='mish')
+    elif name == '6sppfpn':
+      model = decoders.YoloDecoder(
+          input_specs=input_specs,
+          embed_spp=True,
+          use_fpn=True,
+          max_level_process_len=None,
+          path_process_len=6,
+          activation='mish')
+    elif name == '6':
+      model = decoders.YoloDecoder(
+          input_specs=input_specs,
+          embed_spp=False,
+          use_fpn=False,
+          max_level_process_len=None,
+          path_process_len=6,
+          activation='mish')
+    else:
+      raise NotImplementedError(f'YOLO decoder test {name} not implemented.')
+    return model
   @parameterized.parameters('1', '6spp', '6sppfpn', '6')
   def test_network_creation(self, version):
     """Test creation of YOLO decoder models."""
@@ -36,10 +74,10 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
         '4': [1, 26, 26, 512],
         '5': [1, 13, 13, 1024]
     }
-    decoder = build_yolo_decoder(input_shape, version)
+    decoder = self._build_yolo_decoder(input_shape, version)
     inputs = {}
-    for key in input_shape.keys():
+    for key in input_shape:
       inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
     endpoints = decoder.call(inputs)
@@ -50,7 +88,7 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
   @combinations.generate(
       combinations.combine(
           strategy=[
-              strategy_combinations.tpu_strategy,
+              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
           ],
           use_sync_bn=[False, True],
@@ -66,10 +104,10 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
         '4': [1, 26, 26, 512],
         '5': [1, 13, 13, 1024]
     }
-    decoder = build_yolo_decoder(input_shape, '6')
+    decoder = self._build_yolo_decoder(input_shape, '6')
     inputs = {}
-    for key in input_shape.keys():
+    for key in input_shape:
       inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
     _ = decoder.call(inputs)
@@ -84,10 +122,10 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
         '4': [1, 26, 26, 512],
         '5': [1, 13, 13, 1024]
     }
-    decoder = build_yolo_decoder(input_shape, '6')
+    decoder = self._build_yolo_decoder(input_shape, '6')
     inputs = {}
-    for key in input_shape.keys():
+    for key in input_shape:
       inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
     _ = decoder(inputs)
@@ -100,10 +138,10 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
         '4': [1, 26, 26, 512],
         '5': [1, 13, 13, 1024]
     }
-    decoder = build_yolo_decoder(input_shape, '6')
+    decoder = self._build_yolo_decoder(input_shape, '6')
     inputs = {}
-    for key in input_shape.keys():
+    for key in input_shape:
       inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
     _ = decoder(inputs)
@@ -111,44 +149,5 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
     decoder_from_config = decoders.YoloDecoder.from_config(config)
     self.assertAllEqual(decoder.get_config(), decoder_from_config.get_config())
-def build_yolo_decoder(input_specs, type='1'):
-  if type == '1':
-    model = decoders.YoloDecoder(
-        input_specs=input_specs,
-        embed_spp=False,
-        use_fpn=False,
-        max_level_process_len=2,
-        path_process_len=1,
-        activation='mish')
-  elif type == '6spp':
-    model = decoders.YoloDecoder(
-        input_specs=input_specs,
-        embed_spp=True,
-        use_fpn=False,
-        max_level_process_len=None,
-        path_process_len=6,
-        activation='mish')
-  elif type == '6sppfpn':
-    model = decoders.YoloDecoder(
-        input_specs=input_specs,
-        embed_spp=True,
-        use_fpn=True,
-        max_level_process_len=None,
-        path_process_len=6,
-        activation='mish')
-  elif type == '6':
-    model = decoders.YoloDecoder(
-        input_specs=input_specs,
-        embed_spp=False,
-        use_fpn=False,
-        max_level_process_len=None,
-        path_process_len=6,
-        activation='mish')
-  else:
-    raise NotImplementedError(f"YOLO decoder test {type} not implemented.")
-  return model
 if __name__ == '__main__':
-  tf.test.main()
\ No newline at end of file
+  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -13,12 +13,14 @@
 # limitations under the License.
 # Lint as: python3
+"""Yolo heads."""
 import tensorflow as tf
 from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
 class YoloHead(tf.keras.layers.Layer):
-  """YOLO Prediction Head"""
+  """YOLO Prediction Head."""
   def __init__(self,
                min_level,
@@ -117,4 +119,4 @@ class YoloHead(tf.keras.layers.Layer):
   @classmethod
   def from_config(cls, config, custom_objects=None):
-    return cls(**config)
\ No newline at end of file
+    return cls(**config)
@@ -13,15 +13,12 @@
 # limitations under the License.
 # Lint as: python3
-"""Tests for YOLO heads."""
+"""Tests for yolo heads."""
 # Import libraries
 from absl.testing import parameterized
-import numpy as np
 import tensorflow as tf
-from tensorflow.python.distribute import combinations
-from tensorflow.python.distribute import strategy_combinations
 from official.vision.beta.projects.yolo.modeling.heads import yolo_head as heads
@@ -40,10 +37,11 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
     head = heads.YoloHead(3, 5, classes=classes, boxes_per_level=bps)
     inputs = {}
-    for key in input_shape.keys():
+    for key in input_shape:
       inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
     endpoints = head(inputs)
+    # print(endpoints)
     for key in endpoints.keys():
       expected_input_shape = input_shape[key]
@@ -63,7 +61,7 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
     head = heads.YoloHead(3, 5, classes=classes, boxes_per_level=bps)
     inputs = {}
-    for key in input_shape.keys():
+    for key in input_shape:
       inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
     _ = head(inputs)
@@ -71,6 +69,5 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
     head_from_config = heads.YoloHead.from_config(configs)
     self.assertAllEqual(head.get_config(), head_from_config.get_config())
 if __name__ == '__main__':
-  tf.test.main()
\ No newline at end of file
+  tf.test.main()
@@ -13,8 +13,8 @@
 # limitations under the License.
 # Lint as: python3
-"""Contains common building blocks for yolo neural networks."""
+"""Contains common building blocks for YOLO neural networks."""
 from typing import Callable, List
 import tensorflow as tf
 from official.modeling import tf_utils
@@ -33,11 +33,12 @@ class Identity(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class ConvBN(tf.keras.layers.Layer):
-  """
-  Modified Convolution layer to match that of the Darknet Library.
-  The Layer is a standard combination of Conv BatchNorm Activation,
-  however, the use of bias in the Conv is determined by the use of
-  batch normalization.
+  """ConvBN block.
+
+  Modified Convolution layer to match that of the Darknet Library.
+  The Layer is a standard combination of Conv BatchNorm Activation;
+  however, the use of bias in the conv is determined by the use of batch
+  normalization.
   Cross Stage Partial networks (CSPNets) were proposed in:
   [1] Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu,
       Ping-Yang Chen, Jun-Wei Hsieh
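An illustrative instantiation of the block, using the defaults visible in this hunk (the padding argument and shape values are assumptions):

import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks

# Sketch: bias is dropped automatically when batch norm is enabled.
conv = nn_blocks.ConvBN(
    filters=64,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding='same',       # assumed; standard Keras-style padding argument
    activation='leaky',
    leaky_alpha=0.1)
y = conv(tf.ones([1, 32, 32, 3]))
print(y.shape)  # expected (1, 32, 32, 64)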
@@ -62,7 +63,8 @@ class ConvBN(tf.keras.layers.Layer):
                activation='leaky',
                leaky_alpha=0.1,
                **kwargs):
-    """
+    """ConvBN initializer.
+
     Args:
       filters: integer for output depth, or the number of features to learn.
       kernel_size: integer or tuple for the shape of the weight matrix or kernel
@@ -190,9 +192,7 @@ class ConvBN(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class DarkResidual(tf.keras.layers.Layer):
-  """
-  Darknet block with Residual connection for YOLO v3 Backbone
-  """
+  """Darknet block with Residual connection for Yolo v3 Backbone."""
   def __init__(self,
                filters=1,
@@ -211,9 +211,13 @@ class DarkResidual(tf.keras.layers.Layer):
                sc_activation='linear',
                downsample=False,
                **kwargs):
-    """
+    """Dark Residual initializer.
+
     Args:
       filters: integer for output depth, or the number of features to learn.
+      filter_scale: `int` for filter scale.
+      dilation_rate: tuple to indicate how much to modulate kernel weights and
+        how many pixels in a feature map to skip.
       kernel_initializer: string to indicate which function to use to initialize
         weights.
       bias_initializer: string to indicate which function to use to initialize
@@ -228,6 +232,8 @@ class DarkResidual(tf.keras.layers.Layer):
         (across all input batches).
       norm_momentum: float for moment to use for batch normalization.
       norm_epsilon: float for batch normalization epsilon.
+      activation: string or None for activation function to use in layer,
+        if None activation is replaced by linear.
       leaky_alpha: float to use as alpha if activation function is leaky.
       sc_activation: string for activation function to use in layer.
       downsample: boolean for if image input is larger than layer output, set
@@ -349,11 +355,12 @@ class DarkResidual(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class CSPTiny(tf.keras.layers.Layer):
-  """
-  A small size convolution block proposed in the CSPNet. The layer uses shortcuts,
-  routing(concatenation), and feature grouping in order to improve gradient
-  variability and allow for high efficiency, low power residual learning for small
-  networks.
+  """CSP Tiny layer.
+
+  A small size convolution block proposed in the CSPNet. The layer uses
+  shortcuts, routing (concatenation), and feature grouping in order to improve
+  gradient variability and allow for high efficiency, low power residual
+  learning for small networks.
   Cross Stage Partial networks (CSPNets) were proposed in:
   [1] Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu,
       Ping-Yang Chen, Jun-Wei Hsieh
@@ -378,7 +385,8 @@ class CSPTiny(tf.keras.layers.Layer):
                downsample=True,
                leaky_alpha=0.1,
                **kwargs):
-    """
+    """Initializer for CSPTiny block.
+
     Args:
       filters: integer for output depth, or the number of features to learn.
       kernel_initializer: string to indicate which function to use to initialize
@@ -390,6 +398,7 @@ class CSPTiny(tf.keras.layers.Layer):
       kernel_regularizer: string to indicate which function to use to
         regularizer weights.
       use_bn: boolean for whether to use batch normalization.
+      dilation_rate: `int`, dilation rate for conv layers.
       use_sync_bn: boolean for whether sync batch normalization statistics
         of all batch norm layers to the models global statistics
         (across all input batches).
@@ -399,12 +408,11 @@ class CSPTiny(tf.keras.layers.Layer):
         feature stack output.
       norm_momentum: float for moment to use for batch normalization.
       norm_epsilon: float for batch normalization epsilon.
+      activation: string or None for activation function to use in layer,
+        if None activation is replaced by linear.
       downsample: boolean for if image input is larger than layer output, set
         downsample to True so the dimensions are forced to match.
       leaky_alpha: float to use as alpha if activation function is leaky.
-      sc_activation: string for activation function to use in layer.
-      conv_activation: string or None for activation function to use in layer,
-        if None activation is replaced by linear.
       **kwargs: Keyword Arguments.
     """
@@ -502,19 +510,20 @@ class CSPTiny(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class CSPRoute(tf.keras.layers.Layer):
-  """
-  Down sampling layer to take the place of down sampling done in Residual
-  networks. This is the first of 2 layers needed to convert any Residual Network
-  model to a CSPNet. At the start of a new level change, this CSPRoute layer
-  creates a learned identity that will act as a cross stage connection that
-  is used to inform the inputs to the next stage. This is called cross stage
-  partial because the number of filters required in every intermittent residual
-  layer is reduced by half. The sister layer will take the partial generated by
-  this layer and concatenate it with the output of the final residual layer in the
-  stack to create a fully feature level output. This concatenation merges the
-  partial blocks of 2 levels as input to the next allowing the gradients of each
-  level to be more unique, and reducing the number of parameters required by each
-  level by 50% while keeping accuracy consistent.
+  """CSPRoute block.
+
+  Down sampling layer to take the place of down sampling done in Residual
+  networks. This is the first of 2 layers needed to convert any Residual Network
+  model to a CSPNet. At the start of a new level change, this CSPRoute layer
+  creates a learned identity that will act as a cross stage connection
+  that is used to inform the inputs to the next stage. It is called cross stage
+  partial because the number of filters required in every intermittent Residual
+  layer is reduced by half. The sister layer will take the partial generated by
+  this layer and concatenate it with the output of the final residual layer in
+  the stack to create a fully feature level output. This concatenation merges
+  the partial blocks of 2 levels as input to the next allowing the gradients of
+  each level to be more unique, and reducing the number of parameters required
+  by each level by 50% while keeping accuracy consistent.
   Cross Stage Partial networks (CSPNets) were proposed in:
   [1] Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu,
@@ -539,7 +548,8 @@ class CSPRoute(tf.keras.layers.Layer):
                downsample=True,
                leaky_alpha=0.1,
                **kwargs):
-    """
+    """CSPRoute layer initializer.
+
     Args:
       filters: integer for output depth, or the number of features to learn
       filter_scale: integer dictating (filters//2) or the number of filters in
@@ -553,6 +563,7 @@ class CSPRoute(tf.keras.layers.Layer):
         bias.
       kernel_regularizer: string to indicate which function to use to
         regularizer weights.
+      dilation_rate: dilation rate for conv layers.
       use_bn: boolean for whether to use batch normalization.
       use_sync_bn: boolean for whether sync batch normalization statistics
         of all batch norm layers to the models global statistics
@@ -560,6 +571,7 @@ class CSPRoute(tf.keras.layers.Layer):
       norm_momentum: float for moment to use for batch normalization.
       norm_epsilon: float for batch normalization epsilon.
       downsample: down_sample the input.
+      leaky_alpha: `float`, for leaky alpha value.
       **kwargs: Keyword Arguments.
     """
@@ -569,7 +581,7 @@ class CSPRoute(tf.keras.layers.Layer):
     self._filter_scale = filter_scale
     self._activation = activation
     # convolution params
     self._kernel_initializer = kernel_initializer
     self._bias_initializer = bias_initializer
     self._kernel_regularizer = kernel_regularizer
@@ -631,12 +643,18 @@ class CSPRoute(tf.keras.layers.Layer):
       x = self._conv3(inputs)
       return (x, y)
+    self._conv2 = ConvBN(
+        filters=self._filters // self._filter_scale,
+        kernel_size=(1, 1),
+        strides=(1, 1),
+        **dark_conv_args)
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class CSPConnect(tf.keras.layers.Layer):
-  """
-  Sister Layer to the CSPRoute layer. Merges the partial feature stacks
+  """CSPConnect block.
+
+  Sister Layer to the CSPRoute layer. Merges the partial feature stacks
   generated by the CSPDownsampling layer, and the final output of the
   residual stack. Suggested in the CSPNet paper.
   Cross Stage Partial networks (CSPNets) were proposed in:
   [1] Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu,
@@ -663,12 +681,16 @@ class CSPConnect(tf.keras.layers.Layer):
                norm_epsilon=0.001,
                leaky_alpha=0.1,
                **kwargs):
-    """
+    """Initializer for CSPConnect block.
+
     Args:
       filters: integer for output depth, or the number of features to learn
       filter_scale: integer dictating (filters//2) or the number of filters in
         the partial feature stack.
+      drop_final: `bool`, whether to drop final conv layer.
+      drop_first: `bool`, whether to drop first conv layer.
       activation: string for activation function to use in layer.
+      kernel_size: `Tuple`, kernel size for conv layers.
       kernel_initializer: string to indicate which function to use to initialize
         weights.
       bias_initializer: string to indicate which function to use to initialize
@@ -677,12 +699,14 @@ class CSPConnect(tf.keras.layers.Layer):
         bias.
       kernel_regularizer: string to indicate which function to use to
         regularizer weights.
+      dilation_rate: `int`, dilation rate for conv layers.
       use_bn: boolean for whether to use batch normalization.
       use_sync_bn: boolean for whether sync batch normalization statistics
         of all batch norm layers to the models global
         statistics (across all input batches).
       norm_momentum: float for moment to use for batch normalization.
       norm_epsilon: float for batch normalization epsilon.
+      leaky_alpha: `float`, for leaky alpha value.
       **kwargs: Keyword Arguments.
     """
@@ -747,14 +771,15 @@ class CSPConnect(tf.keras.layers.Layer):
 class CSPStack(tf.keras.layers.Layer):
-  """
-  CSP full stack, combines the route and the connect in case you don't want to
-  just quickly wrap an existing callable or list of layers to make it a cross
-  stage partial. Added for ease of use. you should be able to wrap any layer
-  stack with a CSP independent of whether it belongs to the Darknet family. If
-  filter_scale = 2, then the blocks in the stack passed into the the CSP stack
-  should also have filters = filters/filter_scale Cross Stage Partial networks
-  (CSPNets) were proposed in:
+  """CSP Stack layer.
+
+  CSP full stack, combines the route and the connect in case you don't want to
+  just quickly wrap an existing callable or list of layers to
+  make it a cross stage partial. Added for ease of use. You should be able
+  to wrap any layer stack with a CSP independent of whether it belongs
+  to the Darknet family. If filter_scale = 2, then the blocks in the stack
+  passed into the CSP stack should also have filters = filters/filter_scale.
+  Cross Stage Partial networks (CSPNets) were proposed in:
   [1] Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu,
       Ping-Yang Chen, Jun-Wei Hsieh
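A hedged sketch of wrapping an existing residual stack per the filter_scale rule stated above (keyword names taken from this diff; defaults and shapes are assumptions):

import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks

# Per the docstring rule: with filter_scale=2, wrapped blocks use filters/2.
inner = [nn_blocks.DarkResidual(filters=128) for _ in range(2)]
csp = nn_blocks.CSPStack(
    filters=256,
    filter_scale=2,
    model_to_wrap=inner)  # callable(s) executed between route and connect
y = csp(tf.ones([1, 64, 64, 256]))
print(y.shape)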
@@ -777,7 +802,8 @@ class CSPStack(tf.keras.layers.Layer):
                norm_momentum=0.99,
                norm_epsilon=0.001,
                **kwargs):
-    """
+    """CSPStack layer initializer.
+
     Args:
       filters: integer for output depth, or the number of features to learn.
       model_to_wrap: callable Model or a list of callable objects that will
@@ -862,6 +888,7 @@ class CSPStack(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class PathAggregationBlock(tf.keras.layers.Layer):
+  """Path Aggregation block."""
   def __init__(self,
                filters=1,
@@ -881,7 +908,8 @@ class PathAggregationBlock(tf.keras.layers.Layer):
                upsample=False,
                upsample_size=2,
                **kwargs):
-    """
+    """Initializer for path aggregation block.
+
     Args:
       filters: integer for output depth, or the number of features to learn.
       drop_final: do not create the last convolution block.
@@ -903,13 +931,13 @@ class PathAggregationBlock(tf.keras.layers.Layer):
       activation: string or None for activation function to use in layer,
         if None activation is replaced by linear.
       leaky_alpha: float to use as alpha if activation function is leaky.
       downsample: `bool` for whether to downsample and merge.
       upsample: `bool` for whether to upsample and merge.
       upsample_size: `int` how much to upsample in order to match shapes.
       **kwargs: Keyword Arguments.
     """
-    # darkconv params
+    # Darkconv params
     self._filters = filters
     self._kernel_initializer = kernel_initializer
     self._bias_initializer = bias_initializer
@@ -918,11 +946,11 @@ class PathAggregationBlock(tf.keras.layers.Layer):
     self._use_bn = use_bn
     self._use_sync_bn = use_sync_bn
-    # normal params
+    # Normal params
     self._norm_momentum = norm_momentum
     self._norm_epsilon = norm_epsilon
-    # activation params
+    # Activation params
     self._conv_activation = activation
     self._leaky_alpha = leaky_alpha
     self._downsample = downsample
@@ -930,7 +958,7 @@ class PathAggregationBlock(tf.keras.layers.Layer):
     self._upsample_size = upsample_size
     self._drop_final = drop_final
-    # block params
+    # Block params
     self._inverted = inverted
     super().__init__(**kwargs)
@@ -1047,13 +1075,14 @@ class PathAggregationBlock(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class SPP(tf.keras.layers.Layer):
-  """
-  A non-aggregated SPP layer that uses Pooling to gain more performance.
+  """Spatial Pyramid Pooling.
+
+  A non-aggregated SPP layer that uses Pooling.
   """
   def __init__(self, sizes, **kwargs):
     self._sizes = list(reversed(sizes))
-    if len(sizes) == 0:
+    if not sizes:
       raise ValueError('More than one maxpool should be specified in SPP block')
     super().__init__(**kwargs)
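For illustration, the canonical YOLO pooling sizes applied to a feature map (a standalone sketch; the exact output channel count assumes the layer concatenates the input with each pooled copy):

import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks

# 5x5, 9x9 and 13x13 max pools, concatenated with the input along channels.
spp = nn_blocks.SPP([5, 9, 13])
y = spp(tf.ones([1, 13, 13, 512]))
print(y.shape)  # expected channels: 512 * 4 = 2048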
@@ -1084,11 +1113,12 @@ class SPP(tf.keras.layers.Layer):
 class SAM(tf.keras.layers.Layer):
-  """
+  """Spatial Attention Model.
+
   [1] Sanghyun Woo, Jongchan Park, Joon-Young Lee, In So Kweon
       CBAM: Convolutional Block Attention Module. arXiv:1807.06521
   Implementation of the Spatial Attention Model (SAM)
   """
   def __init__(self,
@@ -1161,7 +1191,8 @@ class SAM(tf.keras.layers.Layer):
 class CAM(tf.keras.layers.Layer):
-  """
+  """Channel Attention Model.
+
   [1] Sanghyun Woo, Jongchan Park, Joon-Young Lee, In So Kweon
       CBAM: Convolutional Block Attention Module. arXiv:1807.06521
@@ -1247,11 +1278,12 @@ class CAM(tf.keras.layers.Layer):
 class CBAM(tf.keras.layers.Layer):
-  """
+  """Convolutional Block Attention Module.
+
   [1] Sanghyun Woo, Jongchan Park, Joon-Young Lee, In So Kweon
       CBAM: Convolutional Block Attention Module. arXiv:1807.06521
   Implementation of the Convolution Block Attention Module (CBAM)
   """
   def __init__(self,
@@ -1318,10 +1350,10 @@ class CBAM(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package='yolo')
 class DarkRouteProcess(tf.keras.layers.Layer):
-  """
-  Processes darknet outputs and connects the backbone to the head for more
-  generalizability and abstracts the repetition of DarkConv objects that is
-  common in YOLO.
+  """Dark Route Process block.
+
+  Processes darknet outputs and connects the backbone to the head more
+  generalizably. Abstracts the repetition of DarkConv objects that is common
+  in YOLO.
   It is used like the following:
@@ -1342,8 +1374,8 @@ class DarkRouteProcess(tf.keras.layers.Layer):
                kernel_initializer='glorot_uniform',
                bias_initializer='zeros',
                bias_regularizer=None,
+               kernel_regularizer=None,
                use_sync_bn=False,
-               kernel_regularizer=None,  # default find where is it is stated
                norm_momentum=0.99,
                norm_epsilon=0.001,
                block_invert=False,
@@ -1351,30 +1383,53 @@ class DarkRouteProcess(tf.keras.layers.Layer):
                leaky_alpha=0.1,
                spp_keys=None,
                **kwargs):
-    """
+    """DarkRouteProcess initializer.
+
     Args:
       filters: the number of filters to be used in all subsequent layers
         filters should be the depth of the tensor input into this layer,
         as no downsampling can be done within this layer object.
-      repetitions: number of times to repeat the processing nodes
-        for tiny: 1 repetition, no spp allowed
-        for spp: insert_spp = True, and allow for 3+ repetitions
-        for regular: insert_spp = False, and allow for 3+ repetitions.
+      repetitions: number of times to repeat the processing nodes.
+        for tiny: 1 repetition, no spp allowed.
+        for spp: insert_spp = True, and allow for 6 repetitions.
+        for regular: insert_spp = False, and allow for 6 repetitions.
       insert_spp: bool if true add the spatial pyramid pooling layer.
+      insert_sam: bool if true add spatial attention module to path.
+      insert_cbam: bool if true add convolutional block attention
+        module to path.
+      csp_stack: int for the number of sequential layers from 0
+        to <value> you would like to convert into a Cross Stage
+        Partial(csp) type.
+      csp_scale: int for how much to down scale the number of filters
+        only for the csp layers in the csp section of the processing
+        path. A value 2 indicates that each layer that is in the CSP
+        stack will have filters = filters/2.
       kernel_initializer: method to use to initialize kernel weights.
       bias_initializer: method to use to initialize the bias of the conv
        layers.
-      norm_momentum: batch norm parameter see TensorFlow documentation.
-      norm_epsilon: batch norm parameter see TensorFlow documentation.
+      bias_regularizer: string to indicate which function to use to regularizer
+        bias.
+      kernel_regularizer: string to indicate which function to use to
+        regularizer weights.
+      use_sync_bn: bool if true use the sync batch normalization.
+      norm_momentum: batch norm parameter see TensorFlow documentation.
+      norm_epsilon: batch norm parameter see TensorFlow documentation.
+      block_invert: bool use for switching between the even and odd
+        repetitions of layers. usually the repetition is based on a
+        3x3 conv with filters, followed by a 1x1 with filters/2 with
+        an even number of repetitions to ensure each 3x3 gets a 1x1
+        squeeze. block invert swaps the 3x3/1 1x1/2 to a 1x1/2 3x3/1
+        ordering typically used when the model requires an odd number
+        of repetitions. All other parameters maintain their effects.
       activation: activation function to use in processing.
       leaky_alpha: if leaky activation function, the alpha to use in
        processing the relu input.
-
-    Returns:
-      callable tensorflow layer
-
-    Raises:
-      None
+      spp_keys: List[int] of the sampling levels to be applied by
+        the Spatial Pyramid Pooling Layer. By default it is
+        [5, 9, 13] indicating a 5x5 pooling followed by 9x9
+        followed by 13x13 then followed by the standard concatenation
+        and convolution.
+      **kwargs: Keyword Arguments.
     """
     super().__init__(**kwargs)
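The class docstring above promises a usage example that the diff view elides; a hedged reconstruction based on the test file later in this commit:

import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks

proc = nn_blocks.DarkRouteProcess(
    filters=256, repetitions=2, insert_spp=False)
# Per the tests, the layer returns two tensors: the route to the next path
# and the processed output (filters//2 channels when repetitions > 1).
x_route, x_out = proc(tf.ones([1, 26, 26, 256]))
print(x_route.shape, x_out.shape)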
@@ -1555,7 +1610,7 @@ class DarkRouteProcess(tf.keras.layers.Layer):
     x_prev = x
     output_prev = True
-    for i, (layer, output) in enumerate(zip(self.layers, self.outputs)):
+    for (layer, output) in zip(self.layers, self.outputs):
       if output_prev:
         x_prev = x
       x = layer(x)
@@ -1585,4 +1640,4 @@ class DarkRouteProcess(tf.keras.layers.Layer):
     if self._csp_stack > 0:
       return self._call_csp(inputs, training=training)
     else:
-      return self._call_regular(inputs)
\ No newline at end of file
+      return self._call_regular(inputs)
@@ -13,6 +13,8 @@
 # limitations under the License.
 # Lint as: python3
+from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf
-import numpy as np
-from absl.testing import parameterized
@@ -334,12 +336,13 @@ class DarkRouteProcessTest(tf.test.TestCase, parameterized.TestCase):
     test_layer = nn_blocks.DarkRouteProcess(
         filters=filters, repetitions=repetitions, insert_spp=spp)
     outx = test_layer(x)
-    self.assertEqual(len(outx), 2, msg='len(outx) != 2')
+    self.assertLen(outx, 2, msg='len(outx) != 2')
     if repetitions == 1:
       filter_y1 = filters
     else:
       filter_y1 = filters // 2
-    self.assertAllEqual(outx[1].shape.as_list(), [None, width, height, filter_y1])
+    self.assertAllEqual(
+        outx[1].shape.as_list(), [None, width, height, filter_y1])
     self.assertAllEqual(
         filters % 2,
         0,
@@ -366,7 +369,8 @@ class DarkRouteProcessTest(tf.test.TestCase, parameterized.TestCase):
     y_0 = tf.Variable(
         initial_value=init(shape=(1, width, height, filters), dtype=tf.float32))
     y_1 = tf.Variable(
-        initial_value=init(shape=(1, width, height, filter_y1), dtype=tf.float32))
+        initial_value=init(
+            shape=(1, width, height, filter_y1), dtype=tf.float32))
     with tf.GradientTape() as tape:
       x_hat_0, x_hat_1 = test_layer(x)
@@ -379,6 +383,5 @@ class DarkRouteProcessTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNotIn(None, grad)
     return
 if __name__ == '__main__':
-  tf.test.main()
\ No newline at end of file
+  tf.test.main()
@@ -34,9 +34,9 @@ class DetectionModule(export_base.ExportModule):
   def _build_model(self):
     if self._batch_size is None:
-      ValueError("batch_size can't be None for detection models")
+      raise ValueError('batch_size cannot be None for detection models.')
     if not self.params.task.model.detection_generator.use_batched_nms:
-      ValueError('Only batched_nms is supported.')
+      raise ValueError('Only batched_nms is supported.')
     input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
                                             self._input_image_size + [3])
...
@@ -118,6 +118,20 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(outputs['num_detections'].numpy(),
                         expected_outputs['num_detections'].numpy())
+  def test_build_model_fail_with_none_batch_size(self):
+    params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
+    with self.assertRaisesRegex(
+        ValueError, 'batch_size cannot be None for detection models.'):
+      detection.DetectionModule(
+          params, batch_size=None, input_image_size=[640, 640])
+
+  def test_build_model_fail_with_batched_nms_false(self):
+    params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
+    params.task.model.detection_generator.use_batched_nms = False
+    with self.assertRaisesRegex(ValueError, 'Only batched_nms is supported.'):
+      detection.DetectionModule(
+          params, batch_size=1, input_image_size=[640, 640])
 if __name__ == '__main__':
   tf.test.main()
@@ -104,6 +104,7 @@ class ImageClassificationTask(base_task.Task):
        num_classes=num_classes,
        image_field_key=image_field_key,
        label_field_key=label_field_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_type=params.aug_type,
        is_multilabel=is_multilabel,
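Plumbing `decode_jpeg_only` through to the parser is a throughput knob: when every record is known to be JPEG, the input pipeline can take the JPEG-specific decode path instead of the format-sniffing generic one. A toy illustration of the branch such a flag selects (assumed parser logic, not the actual Model Garden parser):

import tensorflow as tf


def decode(image_bytes: tf.Tensor, decode_jpeg_only: bool) -> tf.Tensor:
  if decode_jpeg_only:
    # JPEG-only path: cheaper, and admits fused decode-and-crop variants.
    return tf.image.decode_jpeg(image_bytes, channels=3)
  # Generic path: sniffs PNG/GIF/BMP/JPEG at runtime.
  return tf.io.decode_image(image_bytes, channels=3, expand_animations=False)


jpeg = tf.io.encode_jpeg(tf.zeros([8, 8, 3], tf.uint8))
print(decode(jpeg, decode_jpeg_only=True).shape)  # (8, 8, 3)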
...
@@ -133,12 +133,54 @@ class RetinaNetTask(base_task.Task):
    return dataset

  def build_attribute_loss(self,
                           attribute_heads: List[exp_cfg.AttributeHead],
                           outputs: Mapping[str, Any],
                           labels: Mapping[str, Any],
                           box_sample_weight: tf.Tensor) -> float:
    """Computes attribute loss.

    Args:
      attribute_heads: a list of attribute head configs.
      outputs: RetinaNet model outputs.
      labels: RetinaNet labels.
      box_sample_weight: normalized bounding box sample weights.

    Returns:
      Attribute loss of all attribute heads.
    """
    attribute_loss = 0.0
    for head in attribute_heads:
      if head.name not in labels['attribute_targets']:
        raise ValueError(f'Attribute {head.name} not found in label targets.')
      if head.name not in outputs['attribute_outputs']:
        raise ValueError(f'Attribute {head.name} not found in model outputs.')
      y_true_att = keras_cv.losses.multi_level_flatten(
          labels['attribute_targets'][head.name], last_dim=head.size)
      y_pred_att = keras_cv.losses.multi_level_flatten(
          outputs['attribute_outputs'][head.name], last_dim=head.size)
      if head.type == 'regression':
        att_loss_fn = tf.keras.losses.Huber(
            1.0, reduction=tf.keras.losses.Reduction.SUM)
        att_loss = att_loss_fn(
            y_true=y_true_att,
            y_pred=y_pred_att,
            sample_weight=box_sample_weight)
      else:
        raise ValueError(f'Attribute type {head.type} not supported.')
      attribute_loss += att_loss
    return attribute_loss
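One detail of `build_attribute_loss` deserves a note: the Huber loss uses `Reduction.SUM`, so the per-anchor `box_sample_weight` has to do double duty, masking out negative anchors and normalizing by the positive count, which turns the sum into a mean over positives. A toy sketch of that weighting scheme (made-up tensors, not real RetinaNet targets):

import tensorflow as tf

# Four anchors with a 2-dim regression attribute each; anchors 1 and 3
# are negatives and must not contribute to the loss.
y_true = tf.constant([[0.1, 0.2], [0.0, 0.0], [0.3, 0.1], [0.0, 0.0]])
y_pred = tf.constant([[0.2, 0.2], [0.5, 0.5], [0.3, 0.0], [0.9, 0.9]])

num_positives = 2.0
# Mask negatives and divide by the positive count, so the SUM reduction
# below behaves like a mean over positive anchors.
sample_weight = tf.constant([1.0, 0.0, 1.0, 0.0]) / num_positives

huber = tf.keras.losses.Huber(1.0, reduction=tf.keras.losses.Reduction.SUM)
print(float(huber(y_true, y_pred, sample_weight=sample_weight)))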
  def build_losses(self,
                   outputs: Mapping[str, Any],
                   labels: Mapping[str, Any],
                   aux_losses: Optional[Any] = None):
    """Build RetinaNet losses."""
    params = self.task_config
    attribute_heads = self.task_config.model.head.attribute_heads

    cls_loss_fn = keras_cv.losses.FocalLoss(
        alpha=params.losses.focal_loss_alpha,
        gamma=params.losses.focal_loss_gamma,
@@ -170,6 +212,10 @@ class RetinaNetTask(base_task.Task):
    model_loss = cls_loss + params.losses.box_loss_weight * box_loss

    if attribute_heads:
      model_loss += self.build_attribute_loss(attribute_heads, outputs, labels,
                                              box_sample_weight)

    total_loss = model_loss
    if aux_losses:
      reg_loss = tf.reduce_sum(aux_losses)
...
@@ -322,21 +322,21 @@ class DistributedExecutor(object):
    return test_step

  def train(
      self,
      train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
      eval_input_fn: Optional[Callable[[params_dict.ParamsDict],
                                       tf.data.Dataset]] = None,
      model_dir: Optional[Text] = None,
      total_steps: int = 1,
      iterations_per_loop: int = 1,
      train_metric_fn: Optional[Callable[[], Any]] = None,
      eval_metric_fn: Optional[Callable[[], Any]] = None,
      summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter,
      init_checkpoint: Optional[Callable[[tf.keras.Model], Any]] = None,
      custom_callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
      continuous_eval: bool = False,
      save_config: bool = True):
    """Runs distributed training.

    Args:
@@ -590,7 +590,7 @@ class DistributedExecutor(object):
      eval_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
      eval_metric_fn: Callable[[], Any],
      total_steps: int = -1,
      eval_timeout: Optional[int] = None,
      min_eval_interval: int = 180,
      summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter):
    """Runs distributed evaluation on model folder.
@@ -646,7 +646,7 @@ class DistributedExecutor(object):
      eval_input_fn: Callable[[params_dict.ParamsDict],
                              tf.data.Dataset],
      eval_metric_fn: Callable[[], Any],
      summary_writer: Optional[SummaryWriter] = None):
    """Runs distributed evaluation on the one checkpoint.

    Args:
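Every signature change in this file is the same fix: a parameter annotated `int` or `Callable[...]` but defaulting to `None` violates PEP 484 (implicit `Optional` is disallowed), so strict type checkers reject both the default and any `None` argument. Wrapping the type in `Optional[...]` makes the contract explicit. A before/after sketch with a toy function:

from typing import Optional


def wait_broken(eval_timeout: int = None):  # checker error: None is not int
  ...


def wait_fixed(eval_timeout: Optional[int] = None):
  # The annotation now matches the default, and callers (and checkers)
  # know a None check is required before using the value.
  return 180 if eval_timeout is None else eval_timeout


print(wait_fixed())    # 180
print(wait_fixed(60))  # 60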
...
@@ -20,7 +20,7 @@ from __future__ import division
from __future__ import print_function

import os
from typing import Any, List, MutableMapping, Optional, Text

from absl import logging
import tensorflow as tf
@@ -39,7 +39,7 @@ def get_callbacks(
    initial_step: int = 0,
    batch_size: int = 0,
    log_steps: int = 0,
    model_dir: Optional[str] = None,
    backup_and_restore: bool = False) -> List[tf.keras.callbacks.Callback]:
  """Get all callbacks."""
  model_dir = model_dir or ''
@@ -120,7 +120,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
  def on_batch_begin(self,
                     epoch: int,
                     logs: Optional[MutableMapping[str, Any]] = None) -> None:
    self.step += 1
    if logs is None:
      logs = {}
@@ -129,7 +129,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
  def on_epoch_begin(self,
                     epoch: int,
                     logs: Optional[MutableMapping[str, Any]] = None) -> None:
    if logs is None:
      logs = {}
    metrics = self._calculate_metrics()
@@ -140,7 +140,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
  def on_epoch_end(self,
                   epoch: int,
                   logs: Optional[MutableMapping[str, Any]] = None) -> None:
    if logs is None:
      logs = {}
    metrics = self._calculate_metrics()
@@ -195,13 +195,13 @@ class MovingAverageCallback(tf.keras.callbacks.Callback):
                        optimization.ExponentialMovingAverage)
      self.model.optimizer.shadow_copy(self.model)

  def on_test_begin(self, logs: Optional[MutableMapping[Text, Any]] = None):
    self.model.optimizer.swap_weights()

  def on_test_end(self, logs: Optional[MutableMapping[Text, Any]] = None):
    self.model.optimizer.swap_weights()

  def on_train_end(self, logs: Optional[MutableMapping[Text, Any]] = None):
    if self.overwrite_weights_on_train_end:
      self.model.optimizer.assign_average_vars(self.model.variables)
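The symmetric `swap_weights` calls in `on_test_begin`/`on_test_end` are the core of the moving-average trick: evaluation runs on the shadow (averaged) weights, and a second swap restores the live training weights. A toy sketch of the mechanism with a single variable (not the Model Garden optimizer API):

import tensorflow as tf


class EmaDemo:
  """Keeps an exponential moving average of a variable and can swap it in."""

  def __init__(self, var: tf.Variable, decay: float = 0.99):
    self.var = var
    self.decay = decay
    self.shadow = tf.Variable(var.read_value())

  def update(self):
    # shadow <- decay * shadow + (1 - decay) * var
    self.shadow.assign(self.decay * self.shadow + (1 - self.decay) * self.var)

  def swap(self):
    # Exchange live and averaged values; calling twice restores them,
    # mirroring the on_test_begin / on_test_end pairing above.
    tmp = tf.identity(self.var)
    self.var.assign(self.shadow)
    self.shadow.assign(tmp)


w = tf.Variable(1.0)
ema = EmaDemo(w)
w.assign(5.0)
ema.update()
ema.swap()          # evaluate with averaged weights (~1.04)
print(w.numpy())
ema.swap()          # back to training weights (5.0)
print(w.numpy())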
...
@@ -280,7 +280,9 @@ class DatasetBuilder:
        raise e
    return self.builder_info

  def build(
      self,
      strategy: Optional[tf.distribute.Strategy] = None) -> tf.data.Dataset:
    """Construct a dataset end-to-end and return it using an optional strategy.

    Args:
@@ -305,7 +307,8 @@ class DatasetBuilder:
  def _build(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Construct a dataset end-to-end and return it.

    Args:
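The `build`/`_build` split follows the usual `tf.distribute` input pattern: the public method hands a dataset function to the strategy, and the strategy calls it once per worker with an `InputContext` that drives sharding. A minimal sketch of that pattern under those assumptions (toy dataset, not the DatasetBuilder internals):

from typing import Optional

import tensorflow as tf


def _build(
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
  ds = tf.data.Dataset.range(8).batch(2)
  if input_context is not None:
    # Each input pipeline reads a disjoint shard of the data.
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
  return ds


def build(strategy: Optional[tf.distribute.Strategy] = None):
  if strategy is not None:
    return strategy.distribute_datasets_from_function(_build)
  return _build()


for batch in build():
  print(batch.numpy())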
...
@@ -160,9 +160,9 @@ def conv2d_block(inputs: tf.Tensor,
                 strides: Any = (1, 1),
                 use_batch_norm: bool = True,
                 use_bias: bool = False,
                 activation: Optional[Any] = None,
                 depthwise: bool = False,
                 name: Optional[Text] = None):
  """A conv2d followed by batch norm and an activation."""
  batch_norm = common_modules.get_batch_norm(config.batch_norm)
  bn_momentum = config.bn_momentum
@@ -212,7 +212,7 @@ def conv2d_block(inputs: tf.Tensor,
def mb_conv_block(inputs: tf.Tensor,
                  block: BlockConfig,
                  config: ModelConfig,
                  prefix: Optional[Text] = None):
  """Mobile Inverted Residual Bottleneck.

  Args:
@@ -432,8 +432,8 @@ class EfficientNet(tf.keras.Model):
  """

  def __init__(self,
               config: Optional[ModelConfig] = None,
               overrides: Optional[Dict[Text, Any]] = None):
    """Create an EfficientNet model.

    Args:
@@ -463,9 +463,9 @@ class EfficientNet(tf.keras.Model):
  @classmethod
  def from_name(cls,
                model_name: Text,
                model_weights_path: Optional[Text] = None,
                weights_format: Text = 'saved_model',
                overrides: Optional[Dict[Text, Any]] = None):
    """Construct an EfficientNet model from a predefined model name.

    E.g., `EfficientNet.from_name('efficientnet-b0')`.
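The `overrides` mapping lets callers patch the predefined config at construction time instead of subclassing. A hypothetical call, where the `'num_classes'` key is illustrative only; the real keys are whatever `ModelConfig` defines:

# Hypothetical usage; check ModelConfig for the override keys it accepts.
model = EfficientNet.from_name(
    'efficientnet-b0',
    weights_format='saved_model',
    overrides={'num_classes': 10})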
...
@@ -18,7 +18,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

from typing import Any, Dict, Optional, Text

from absl import logging
import tensorflow as tf
@@ -35,7 +35,7 @@ def build_optimizer(
    optimizer_name: Text,
    base_learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
    params: Dict[Text, Any],
    model: Optional[tf.keras.Model] = None):
  """Build the optimizer based on name.

  Args:
@@ -124,9 +124,9 @@ def build_optimizer(
def build_learning_rate(params: base_configs.LearningRateConfig,
                        batch_size: Optional[int] = None,
                        train_epochs: Optional[int] = None,
                        train_steps: Optional[int] = None):
  """Build the learning rate given the provided configuration."""
  decay_type = params.name
  base_lr = params.initial_lr
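`build_learning_rate` takes `batch_size`, `train_epochs`, and `train_steps` precisely so it can turn epoch-based config into step-based schedules (and, in some configs, scale the base rate with batch size). A sketch of that conversion, an assumed simplification rather than the factory's actual logic:

import tensorflow as tf


def build_learning_rate_sketch(initial_lr: float,
                               batch_size: int,
                               train_epochs: int,
                               steps_per_epoch: int):
  # Linear-scaling heuristic: grow the base LR with the global batch size.
  scaled_lr = initial_lr * batch_size / 256.0
  # Convert the epoch-based horizon into optimizer steps for the schedule.
  decay_steps = train_epochs * steps_per_epoch
  return tf.keras.optimizers.schedules.CosineDecay(
      initial_learning_rate=scaled_lr, decay_steps=decay_steps)


lr = build_learning_rate_sketch(
    0.1, batch_size=512, train_epochs=10, steps_per_epoch=100)
print(float(lr(0)), float(lr(1000)))  # 0.2 at step 0, 0.0 at the end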
...