Commit 04585eca authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 430533634
parent 475c3c3e
@@ -26,6 +26,8 @@ from official.modeling.hyperparams import oneof
 from official.projects.edgetpu.vision.modeling import common_modules
 from official.projects.edgetpu.vision.modeling import custom_layers
 
+InitializerType = Optional[Union[str, tf.keras.initializers.Initializer]]
+
 
 @dataclasses.dataclass
 class BlockType(oneof.OneOfConfig):
@@ -216,6 +218,8 @@ class ModelConfig(base_config.Config):
   stem_base_filters: int = 64
   stem_kernel_size: int = 5
   top_base_filters: int = 1280
+  conv_kernel_initializer: InitializerType = None
+  dense_kernel_initializer: InitializerType = None
   blocks: Tuple[BlockConfig, ...] = (
       # (input_filters, output_filters, kernel_size, num_repeat,
       #  expand_ratio, strides, se_ratio, id_skip, fused_conv, conv_type)
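
The two new ModelConfig fields accept either a Keras initializer name or an initializer instance, per the InitializerType alias added above, and fall back to the existing CONV_KERNEL_INITIALIZER / DENSE_KERNEL_INITIALIZER module defaults when left as None (see the resolution logic in the hunks further down). A minimal usage sketch; the particular initializer choices here are illustrative only:

import tensorflow as tf
from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v2_model_blocks

config = mobilenet_edgetpu_v2_model_blocks.ModelConfig()
config.conv_kernel_initializer = 'he_uniform'  # a string name resolved by Keras
config.dense_kernel_initializer = tf.keras.initializers.GlorotNormal()  # or an instance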
@@ -279,7 +283,8 @@ def mobilenet_edgetpu_v2_base(
     drop_connect_rate: float = 0.1,
     filter_size_overrides: Optional[Dict[int, int]] = None,
     block_op_overrides: Optional[Dict[int, Dict[int, Dict[str, Any]]]] = None,
-    block_group_overrides: Optional[Dict[int, Dict[str, Any]]] = None):
+    block_group_overrides: Optional[Dict[int, Dict[str, Any]]] = None,
+    topology: Optional[TopologyConfig] = None):
   """Creates MobilenetEdgeTPUV2 ModelConfig based on tuning parameters."""
   config = ModelConfig()
@@ -295,7 +300,7 @@ def mobilenet_edgetpu_v2_base(
   }
   config = config.replace(**param_overrides)
-  topology_config = TopologyConfig()
+  topology_config = TopologyConfig() if topology is None else topology
   if filter_size_overrides:
     for group_id in filter_size_overrides:
       topology_config.block_groups[group_id].filters = filter_size_overrides[
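
The new topology argument lets a caller hand in a pre-built TopologyConfig instead of relying only on the per-group override dictionaries; when it is omitted, a default TopologyConfig is constructed exactly as before. A rough sketch, assuming TopologyConfig is importable from this module, that the function returns the assembled ModelConfig as its docstring suggests, and using a made-up filter override:

topology = mobilenet_edgetpu_v2_model_blocks.TopologyConfig()
topology.block_groups[0].filters = 32  # hypothetical override for the first block group
config = mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2_base(topology=topology)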
@@ -724,6 +729,7 @@ def conv2d_block_as_layers(
     use_bias: bool = False,
     activation: Any = None,
     depthwise: bool = False,
+    kernel_initializer: InitializerType = None,
     name: Optional[str] = None) -> List[tf.keras.layers.Layer]:
   """A conv2d followed by batch norm and an activation."""
   batch_norm = common_modules.get_batch_norm(config.batch_norm)
@@ -748,11 +754,13 @@ def conv2d_block_as_layers(
   sequential_layers: List[tf.keras.layers.Layer] = []
   if depthwise:
     conv2d = tf.keras.layers.DepthwiseConv2D
-    init_kwargs.update({'depthwise_initializer': CONV_KERNEL_INITIALIZER})
+    init_kwargs.update({'depthwise_initializer': kernel_initializer})
   else:
     conv2d = tf.keras.layers.Conv2D
-    init_kwargs.update({'filters': conv_filters,
-                        'kernel_initializer': CONV_KERNEL_INITIALIZER})
+    init_kwargs.update({
+        'filters': conv_filters,
+        'kernel_initializer': kernel_initializer
+    })
   sequential_layers.append(conv2d(**init_kwargs))
@@ -780,12 +788,21 @@ def conv2d_block(inputs: tf.Tensor,
                  use_bias: bool = False,
                  activation: Any = None,
                  depthwise: bool = False,
+                 kernel_initializer: Optional[InitializerType] = None,
                  name: Optional[str] = None) -> tf.Tensor:
   """Compatibility with third_party/car/deep_nets."""
   x = inputs
-  for layer in conv2d_block_as_layers(conv_filters, config, kernel_size,
-                                      strides, use_batch_norm, use_bias,
-                                      activation, depthwise, name):
+  for layer in conv2d_block_as_layers(
+      conv_filters=conv_filters,
+      config=config,
+      kernel_size=kernel_size,
+      strides=strides,
+      use_batch_norm=use_batch_norm,
+      use_bias=use_bias,
+      activation=activation,
+      depthwise=depthwise,
+      kernel_initializer=kernel_initializer,
+      name=name):
     x = layer(x)
   return x
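
With the keyword-argument call style above, conv2d_block simply forwards kernel_initializer to conv2d_block_as_layers, which passes it to the underlying Conv2D or DepthwiseConv2D layer. An illustrative call, assuming x is a feature tensor, cfg is a ModelConfig, and tensorflow is imported as tf; the specific values are placeholders:

x = conv2d_block(
    inputs=x,
    conv_filters=64,
    config=cfg,
    kernel_size=(3, 3),
    strides=(1, 1),
    activation='relu',
    kernel_initializer=tf.keras.initializers.HeUniform(),
    name='example_block')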
@@ -828,6 +845,9 @@ class _MbConvBlock:
     use_groupconv = block.conv_type == 'group'
     prefix = prefix or ''
     self.name = prefix
+    conv_kernel_initializer = (
+        config.conv_kernel_initializer if config.conv_kernel_initializer
+        is not None else CONV_KERNEL_INITIALIZER)
     filters = block.input_filters * block.expand_ratio
@@ -851,22 +871,26 @@ class _MbConvBlock:
             activation=activation,
             name=prefix + 'fused'))
       else:
-        self.expand_block.extend(conv2d_block_as_layers(
-            filters,
-            config,
-            kernel_size=block.kernel_size,
-            strides=block.strides,
-            activation=activation,
-            name=prefix + 'fused'))
+        self.expand_block.extend(
+            conv2d_block_as_layers(
+                conv_filters=filters,
+                config=config,
+                kernel_size=block.kernel_size,
+                strides=block.strides,
+                activation=activation,
+                kernel_initializer=conv_kernel_initializer,
+                name=prefix + 'fused'))
     else:
       if block.expand_ratio != 1:
         # Expansion phase with a pointwise conv
-        self.expand_block.extend(conv2d_block_as_layers(
-            filters,
-            config,
-            kernel_size=(1, 1),
-            activation=activation,
-            name=prefix + 'expand'))
+        self.expand_block.extend(
+            conv2d_block_as_layers(
+                conv_filters=filters,
+                config=config,
+                kernel_size=(1, 1),
+                activation=activation,
+                kernel_initializer=conv_kernel_initializer,
+                name=prefix + 'expand'))
 
     # Main kernel, after the expansion (if applicable, i.e. not fused).
     if use_depthwise:
@@ -876,6 +900,7 @@ class _MbConvBlock:
           kernel_size=block.kernel_size,
           strides=block.strides,
           activation=activation,
+          kernel_initializer=conv_kernel_initializer,
           depthwise=True,
           name=prefix + 'depthwise'))
     elif use_groupconv:
@@ -907,27 +932,30 @@ class _MbConvBlock:
           tf.keras.layers.Reshape(se_shape, name=prefix + 'se_reshape'))
       self.squeeze_excitation.extend(
           conv2d_block_as_layers(
-              num_reduced_filters,
-              config,
+              conv_filters=num_reduced_filters,
+              config=config,
               use_bias=True,
               use_batch_norm=False,
               activation=activation,
+              kernel_initializer=conv_kernel_initializer,
               name=prefix + 'se_reduce'))
       self.squeeze_excitation.extend(
           conv2d_block_as_layers(
-              filters,
-              config,
+              conv_filters=filters,
+              config=config,
               use_bias=True,
               use_batch_norm=False,
               activation='sigmoid',
+              kernel_initializer=conv_kernel_initializer,
               name=prefix + 'se_expand'))
 
     # Output phase
     self.project_block.extend(
         conv2d_block_as_layers(
-            block.output_filters,
-            config,
+            conv_filters=block.output_filters,
+            config=config,
             activation=None,
+            kernel_initializer=conv_kernel_initializer,
             name=prefix + 'project'))
 
     # Add identity so that quantization-aware training can insert quantization
@@ -993,6 +1021,12 @@ def mobilenet_edgetpu_v2(image_input: tf.keras.layers.Input,
   activation = tf_utils.get_activation(config.activation)
   dropout_rate = config.dropout_rate
   drop_connect_rate = config.drop_connect_rate
+  conv_kernel_initializer = (
+      config.conv_kernel_initializer if config.conv_kernel_initializer
+      is not None else CONV_KERNEL_INITIALIZER)
+  dense_kernel_initializer = (
+      config.dense_kernel_initializer if config.dense_kernel_initializer
+      is not None else DENSE_KERNEL_INITIALIZER)
   num_classes = config.num_classes
   input_channels = config.input_channels
   rescale_input = config.rescale_input
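
When the config supplies a string such as 'he_uniform', Keras resolves it to an initializer instance at layer-construction time, which is what the new test below asserts via isinstance checks. A standalone illustration of that resolution, independent of the model code:

import tensorflow as tf

# Conv2D converts the string through tf.keras.initializers.get(...) in its constructor.
layer = tf.keras.layers.Conv2D(8, 3, kernel_initializer='he_uniform')
assert isinstance(layer.kernel_initializer, tf.keras.initializers.HeUniform)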
@@ -1010,12 +1044,13 @@ def mobilenet_edgetpu_v2(image_input: tf.keras.layers.Input,
   # Build stem
   x = conv2d_block(
-      x,
-      round_filters(stem_base_filters, config),
-      config,
+      inputs=x,
+      conv_filters=round_filters(stem_base_filters, config),
+      config=config,
       kernel_size=[stem_kernel_size, stem_kernel_size],
       strides=[2, 2],
       activation=activation,
+      kernel_initializer=conv_kernel_initializer,
       name='stem')
 
   # Build blocks
@@ -1061,11 +1096,13 @@ def mobilenet_edgetpu_v2(image_input: tf.keras.layers.Input,
   if config.backbone_only:
     return backbone_levels
 
   # Build top
-  x = conv2d_block(x,
-                   round_filters(top_base_filters, config),
-                   config,
-                   activation=activation,
-                   name='top')
+  x = conv2d_block(
+      inputs=x,
+      conv_filters=round_filters(top_base_filters, config),
+      config=config,
+      activation=activation,
+      kernel_initializer=conv_kernel_initializer,
+      name='top')
 
   # Build classifier
   pool_size = (x.shape.as_list()[1], x.shape.as_list()[2])
@@ -1075,7 +1112,7 @@ def mobilenet_edgetpu_v2(image_input: tf.keras.layers.Input,
   x = tf.keras.layers.Conv2D(
       num_classes,
       1,
-      kernel_initializer=DENSE_KERNEL_INITIALIZER,
+      kernel_initializer=dense_kernel_initializer,
       kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
       bias_regularizer=tf.keras.regularizers.l2(weight_decay),
       name='logits')(
The commit also adds a new test file covering the initializer configuration:
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mobilenet_edgetpu_v2_model_blocks."""
import tensorflow as tf
from official.projects.edgetpu.vision.modeling import custom_layers
from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v2_model_blocks
class MobilenetEdgetpuV2ModelBlocksTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.model_config = mobilenet_edgetpu_v2_model_blocks.ModelConfig()
def test_model_creatation(self):
model_input = tf.keras.layers.Input(shape=(224, 224, 1))
model_output = mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2(
image_input=model_input,
config=self.model_config)
test_model = tf.keras.Model(inputs=model_input, outputs=model_output)
self.assertIsInstance(test_model, tf.keras.Model)
self.assertEqual(test_model.input.shape, (None, 224, 224, 1))
self.assertEqual(test_model.output.shape, (None, 1001))
def test_model_with_customized_kernel_initializer(self):
self.model_config.conv_kernel_initializer = 'he_uniform'
self.model_config.dense_kernel_initializer = 'glorot_normal'
model_input = tf.keras.layers.Input(shape=(224, 224, 1))
model_output = mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2(
image_input=model_input,
config=self.model_config)
test_model = tf.keras.Model(inputs=model_input, outputs=model_output)
conv_layer_stack = []
for layer in test_model.layers:
if (isinstance(layer, tf.keras.layers.Conv2D) or
isinstance(layer, tf.keras.layers.DepthwiseConv2D) or
isinstance(layer, custom_layers.GroupConv2D)):
conv_layer_stack.append(layer)
self.assertGreater(len(conv_layer_stack), 2)
# The last Conv layer is used as a Dense layer.
for layer in conv_layer_stack[:-1]:
if isinstance(layer, custom_layers.GroupConv2D):
self.assertIsInstance(layer.kernel_initializer,
tf.keras.initializers.GlorotUniform)
elif isinstance(layer, tf.keras.layers.Conv2D):
self.assertIsInstance(layer.kernel_initializer,
tf.keras.initializers.HeUniform)
elif isinstance(layer, tf.keras.layers.DepthwiseConv2D):
self.assertIsInstance(layer.depthwise_initializer,
tf.keras.initializers.HeUniform)
self.assertIsInstance(conv_layer_stack[-1].kernel_initializer,
tf.keras.initializers.GlorotNormal)
if __name__ == '__main__':
tf.test.main()