Unverified commit 5ffcc5b6 authored by Anirudh Vegesana, committed by GitHub

Merge branch 'purdue-yolo' into detection_generator_pr

parents 0b81a843 76e0c014
...@@ -53,6 +53,18 @@ flags.DEFINE_string(
'3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
'followed by 5x1x1 conv).')
flags.DEFINE_string(
'se_type', '3d',
'3d, 2d, or 2plus3d. 3d uses the default 3D spatiotemporal global average '
'pooling for squeeze excitation. 2d uses 2D spatial global average pooling '
'on each frame. 2plus3d concatenates both 3D and 2D global average '
'pooling.')
flags.DEFINE_string(
'activation', 'swish',
'The main activation to use across layers.')
flags.DEFINE_string(
'gating_activation', 'sigmoid',
'The gating activation to use in squeeze-excitation layers.')
flags.DEFINE_bool(
'use_positional_encoding', False,
'Whether to use positional encoding (only applied when causal=True).')
...@@ -94,6 +106,9 @@ def main(_) -> None:
conv_type=FLAGS.conv_type,
use_external_states=FLAGS.causal,
input_specs=input_specs,
activation=FLAGS.activation,
gating_activation=FLAGS.gating_activation,
se_type=FLAGS.se_type,
use_positional_encoding=FLAGS.use_positional_encoding)
model = movinet_model.MovinetClassifier(
backbone,
......
...@@ -307,8 +307,10 @@ class Movinet(tf.keras.Model):
causal: bool = False,
use_positional_encoding: bool = False,
conv_type: str = '3d',
se_type: str = '3d',
input_specs: Optional[tf.keras.layers.InputSpec] = None,
activation: str = 'swish',
gating_activation: str = 'sigmoid',
use_sync_bn: bool = True,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
...@@ -332,8 +334,13 @@ class Movinet(tf.keras.Model):
3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
by 5x1x1 conv).
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
input_specs: the model input spec to use.
activation: name of the main activation function.
gating_activation: gating activation to use in squeeze excitation layers.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: normalization momentum for the moving average.
norm_epsilon: small float added to variance to avoid dividing by
...@@ -354,15 +361,19 @@ class Movinet(tf.keras.Model):
if conv_type not in ('3d', '2plus1d', '3d_2plus1d'):
raise ValueError('Unknown conv type: {}'.format(conv_type))
if se_type not in ('3d', '2d', '2plus3d'):
raise ValueError('Unknown squeeze excitation type: {}'.format(se_type))
self._model_id = model_id
self._block_specs = block_specs
self._causal = causal
self._use_positional_encoding = use_positional_encoding
self._conv_type = conv_type
self._se_type = se_type
self._input_specs = input_specs
self._use_sync_bn = use_sync_bn
self._activation = activation
self._gating_activation = gating_activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
...@@ -475,10 +486,12 @@ class Movinet(tf.keras.Model):
strides=strides,
causal=self._causal,
activation=self._activation,
gating_activation=self._gating_activation,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
conv_type=self._conv_type,
se_type=self._se_type,
use_positional_encoding=
self._use_positional_encoding and self._causal,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
batch_norm_layer=self._norm,
...@@ -691,8 +704,10 @@ def build_movinet(
causal=backbone_cfg.causal,
use_positional_encoding=backbone_cfg.use_positional_encoding,
conv_type=backbone_cfg.conv_type,
se_type=backbone_cfg.se_type,
input_specs=input_specs,
activation=backbone_cfg.activation,
gating_activation=backbone_cfg.gating_activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
......
...@@ -669,6 +669,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
def __init__(
self,
hidden_filters: int,
se_type: str = '3d',
activation: nn_layers.Activation = 'swish',
gating_activation: nn_layers.Activation = 'sigmoid',
causal: bool = False,
...@@ -683,6 +684,10 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
Args:
hidden_filters: The hidden filters of squeeze excite.
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
activation: name of the activation function.
gating_activation: name of the activation function for gating.
causal: if True, use causal mode in the global average pool.
...@@ -700,6 +705,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
super(StreamSqueezeExcitation, self).__init__(**kwargs)
self._hidden_filters = hidden_filters
self._se_type = se_type
self._activation = activation
self._gating_activation = gating_activation
self._causal = causal
...@@ -709,8 +715,9 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
self._use_positional_encoding = use_positional_encoding
self._state_prefix = state_prefix
self._spatiotemporal_pool = nn_layers.GlobalAveragePool3D(
keepdims=True, causal=causal, state_prefix=state_prefix)
self._spatial_pool = nn_layers.SpatialAveragePool3D(keepdims=True)
self._pos_encoding = None
if use_positional_encoding:
...@@ -721,6 +728,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
"""Returns a dictionary containing the config used for initialization."""
config = {
'hidden_filters': self._hidden_filters,
'se_type': self._se_type,
'activation': self._activation,
'gating_activation': self._gating_activation,
'causal': self._causal,
...@@ -777,13 +785,28 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
"""
states = dict(states) if states is not None else {}
if self._se_type == '3d':
x, states = self._spatiotemporal_pool(inputs, states=states)
elif self._se_type == '2d':
x = self._spatial_pool(inputs)
elif self._se_type == '2plus3d':
x_space = self._spatial_pool(inputs)
x, states = self._spatiotemporal_pool(x_space, states=states)
if not self._causal:
x = tf.tile(x, [1, tf.shape(inputs)[1], 1, 1, 1])
x = tf.concat([x, x_space], axis=-1)
else:
raise ValueError('Unknown Squeeze Excitation type {}'.format(
self._se_type))
if self._pos_encoding is not None:
x, states = self._pos_encoding(x, states=states)
x = self._se_reduce(x)
x = self._se_expand(x)
return x * inputs, states
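As an aside (not part of this diff), a minimal sketch of what the '2plus3d' branch above computes in the non-causal case, using plain TensorFlow ops and hypothetical shapes: the spatial pool keeps the temporal axis, the spatiotemporal pool collapses it, and the two results are concatenated along channels.

import tensorflow as tf

# Hypothetical shapes: batch=1, frames=4, height=2, width=2, channels=3.
inputs = tf.random.uniform([1, 4, 2, 2, 3])

# '2d'-style spatial pooling per frame: (1, 4, 2, 2, 3) -> (1, 4, 1, 1, 3).
x_space = tf.reduce_mean(inputs, axis=[2, 3], keepdims=True)

# '3d'-style spatiotemporal pooling collapses time as well: -> (1, 1, 1, 1, 3).
x = tf.reduce_mean(x_space, axis=1, keepdims=True)

# Non-causal '2plus3d': broadcast the clip-level features over time and
# concatenate with the per-frame features, doubling the channel count.
x = tf.tile(x, [1, tf.shape(inputs)[1], 1, 1, 1])
x = tf.concat([x, x_space], axis=-1)
print(x.shape)  # (1, 4, 1, 1, 6)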
...@@ -999,9 +1022,11 @@ class MovinetBlock(tf.keras.layers.Layer):
strides: Union[int, Sequence[int]] = (1, 1, 1),
causal: bool = False,
activation: nn_layers.Activation = 'swish',
gating_activation: nn_layers.Activation = 'sigmoid',
se_ratio: float = 0.25,
stochastic_depth_drop_rate: float = 0.,
conv_type: str = '3d',
se_type: str = '3d',
use_positional_encoding: bool = False,
kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
...@@ -1021,12 +1046,17 @@ class MovinetBlock(tf.keras.layers.Layer):
strides: strides of the main depthwise convolution.
causal: if True, run the temporal convolutions in causal mode.
activation: activation to use across all conv operations.
gating_activation: gating activation to use in squeeze excitation layers.
se_ratio: squeeze excite filters ratio.
stochastic_depth_drop_rate: optional drop rate for stochastic depth.
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D
ops. '2plus1d' splits any 3D ops into two sequential 2D ops with their
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but
uses two sequential 3D ops instead.
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
use_positional_encoding: add a positional encoding after the (cumulative)
global average pooling layer in the squeeze excite layer.
kernel_initializer: kernel initializer for the conv operations.
...@@ -1042,17 +1072,21 @@ class MovinetBlock(tf.keras.layers.Layer):
self._kernel_size = normalize_tuple(kernel_size, 3, 'kernel_size')
self._strides = normalize_tuple(strides, 3, 'strides')
# Use a multiplier of 2 if concatenating multiple features
se_multiplier = 2 if se_type == '2plus3d' else 1
se_hidden_filters = nn_layers.make_divisible(
se_ratio * expand_filters * se_multiplier, divisor=8)
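# Example (illustrative values): with se_ratio=0.25 and expand_filters=64,
# the '2plus3d' multiplier gives make_divisible(0.25 * 64 * 2, 8) = 32
# hidden filters.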
self._out_filters = out_filters
self._expand_filters = expand_filters
self._kernel_size = kernel_size
self._causal = causal
self._activation = activation
self._gating_activation = gating_activation
self._se_ratio = se_ratio
self._downsample = any(s > 1 for s in self._strides)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._conv_type = conv_type
self._se_type = se_type
self._use_positional_encoding = use_positional_encoding
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
...@@ -1103,7 +1137,9 @@ class MovinetBlock(tf.keras.layers.Layer):
name='projection')
self._attention = StreamSqueezeExcitation(
se_hidden_filters,
se_type=se_type,
activation=activation,
gating_activation=gating_activation,
causal=self._causal,
conv_type=conv_type,
use_positional_encoding=use_positional_encoding,
...@@ -1121,9 +1157,11 @@ class MovinetBlock(tf.keras.layers.Layer):
'strides': self._strides,
'causal': self._causal,
'activation': self._activation,
'gating_activation': self._gating_activation,
'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'conv_type': self._conv_type,
'se_type': self._se_type,
'use_positional_encoding': self._use_positional_encoding,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
......
...@@ -314,6 +314,43 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
[[4., 4., 4.]]]]],
1e-5, 1e-5)
def test_stream_squeeze_excitation_2plus3d(self):
se = movinet_layers.StreamSqueezeExcitation(
3,
se_type='2plus3d',
causal=True,
activation='hard_swish',
gating_activation='hard_sigmoid',
kernel_initializer='ones')
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
expected, _ = se(inputs)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = se(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(
predicted,
[[[[[1., 1., 1.]],
[[1., 1., 1.]]],
[[[2., 2., 2.]],
[[2., 2., 2.]]],
[[[3., 3., 3.]],
[[3., 3., 3.]]],
[[[4., 4., 4.]],
[[4., 4., 4.]]]]])
def test_stream_movinet_block(self):
block = movinet_layers.MovinetBlock(
out_filters=3,
......
...@@ -131,6 +131,37 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_movinet_classifier_mobile(self):
"""Test if the model can run with mobile parameters."""
tf.keras.backend.set_image_data_format('channels_last')
backbone = movinet.Movinet(
model_id='a0',
causal=True,
use_external_states=True,
conv_type='2plus1d',
se_type='2plus3d',
activation='hard_swish',
gating_activation='hard_sigmoid'
)
model = movinet_model.MovinetClassifier(
backbone, num_classes=600, output_states=True)
inputs = tf.ones([1, 8, 172, 172, 3])
init_states = model.init_states(tf.shape(inputs))
expected, _ = model({**init_states, 'image': inputs})
frames = tf.split(inputs, inputs.shape[1], axis=1)
states = init_states
for frame in frames:
output, states = model({**states, 'image': frame})
predicted = output
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_serialize_deserialize(self):
"""Validate the classification network can be serialized and deserialized."""
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converts '3d_2plus1d' checkpoints into '2plus1d'."""
from absl import app
from absl import flags
import tensorflow as tf
from official.vision.beta.projects.movinet.modeling import movinet
from official.vision.beta.projects.movinet.modeling import movinet_model
flags.DEFINE_string(
'input_checkpoint_path', None,
'Checkpoint path to load.')
flags.DEFINE_string(
'output_checkpoint_path', None,
'Export path to save the converted checkpoint.')
flags.DEFINE_string(
'model_id', 'a0', 'MoViNet model name.')
flags.DEFINE_bool(
'causal', False, 'Run the model in causal mode.')
flags.DEFINE_bool(
'use_positional_encoding', False,
'Whether to use positional encoding (only applied when causal=True).')
flags.DEFINE_integer(
'num_classes', 600, 'The number of classes for prediction.')
flags.DEFINE_bool(
'verify_output', False, 'Verify the output matches between the models.')
FLAGS = flags.FLAGS
def main(_) -> None:
backbone_2plus1d = movinet.Movinet(
model_id=FLAGS.model_id,
causal=FLAGS.causal,
conv_type='2plus1d',
use_positional_encoding=FLAGS.use_positional_encoding)
model_2plus1d = movinet_model.MovinetClassifier(
backbone=backbone_2plus1d,
num_classes=FLAGS.num_classes)
model_2plus1d.build([1, 1, 1, 1, 3])
backbone_3d_2plus1d = movinet.Movinet(
model_id=FLAGS.model_id,
causal=FLAGS.causal,
conv_type='3d_2plus1d',
use_positional_encoding=FLAGS.use_positional_encoding)
model_3d_2plus1d = movinet_model.MovinetClassifier(
backbone=backbone_3d_2plus1d,
num_classes=FLAGS.num_classes)
model_3d_2plus1d.build([1, 1, 1, 1, 3])
checkpoint = tf.train.Checkpoint(model=model_3d_2plus1d)
status = checkpoint.restore(FLAGS.input_checkpoint_path)
status.assert_existing_objects_matched()
# Ensure both models have the same weights
weights = []
for var_2plus1d, var_3d_2plus1d in zip(
model_2plus1d.get_weights(), model_3d_2plus1d.get_weights()):
if var_2plus1d.shape == var_3d_2plus1d.shape:
weights.append(var_3d_2plus1d)
else:
if var_3d_2plus1d.shape[0] == 1:
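# Kernel with a leading singleton axis, i.e. a spatial (1, kh, kw, ...)
# kernel: drop the dummy temporal dimension.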
weight = var_3d_2plus1d[0]
else:
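# Otherwise assume a temporal (kt, 1, 1, ...) kernel: drop one of the
# singleton spatial axes instead.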
weight = var_3d_2plus1d[:, 0]
if weight.shape[-1] != var_2plus1d.shape[-1]:
# Transpose any depthwise kernels (conv3d --> depthwise_conv2d)
weight = tf.transpose(weight, perm=(0, 1, 3, 2))
weights.append(weight)
model_2plus1d.set_weights(weights)
if FLAGS.verify_output:
inputs = tf.random.uniform([1, 6, 64, 64, 3], dtype=tf.float32)
logits_2plus1d = model_2plus1d(inputs)
logits_3d_2plus1d = model_3d_2plus1d(inputs)
if tf.reduce_mean(logits_2plus1d - logits_3d_2plus1d) > 1e-5:
raise ValueError('Bad conversion, model outputs do not match.')
save_checkpoint = tf.train.Checkpoint(
model=model_2plus1d, backbone=backbone_2plus1d)
save_checkpoint.save(FLAGS.output_checkpoint_path)
if __name__ == '__main__':
flags.mark_flag_as_required('input_checkpoint_path')
flags.mark_flag_as_required('output_checkpoint_path')
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for convert_3d_2plus1d."""
import os
from absl import flags
import tensorflow as tf
from official.vision.beta.projects.movinet.modeling import movinet
from official.vision.beta.projects.movinet.modeling import movinet_model
from official.vision.beta.projects.movinet.tools import convert_3d_2plus1d
FLAGS = flags.FLAGS
class Convert3d2plus1dTest(tf.test.TestCase):
def test_convert_model(self):
saved_model_path = self.get_temp_dir()
input_checkpoint_path = os.path.join(saved_model_path, 'ckpt-input')
output_checkpoint_path = os.path.join(saved_model_path, 'ckpt')
model_3d_2plus1d = movinet_model.MovinetClassifier(
backbone=movinet.Movinet(
model_id='a0',
conv_type='3d_2plus1d'),
num_classes=600)
model_3d_2plus1d.build([1, 1, 1, 1, 3])
save_checkpoint = tf.train.Checkpoint(model=model_3d_2plus1d)
save_checkpoint.save(input_checkpoint_path)
FLAGS.input_checkpoint_path = f'{input_checkpoint_path}-1'
FLAGS.output_checkpoint_path = output_checkpoint_path
FLAGS.model_id = 'a0'
FLAGS.use_positional_encoding = False
FLAGS.num_classes = 600
FLAGS.verify_output = True
convert_3d_2plus1d.main('unused_args')
print(os.listdir(saved_model_path))
self.assertTrue(tf.io.gfile.exists(f'{output_checkpoint_path}-1.index'))
if __name__ == '__main__':
tf.test.main()
...@@ -46,6 +46,7 @@ from official.modeling import performance
# Import movinet libraries to register the backbone and model into tf.vision
# model garden factory.
# pylint: disable=unused-import
# The following are the necessary imports.
from official.vision.beta.projects.movinet.modeling import movinet
from official.vision.beta.projects.movinet.modeling import movinet_model
# pylint: enable=unused-import
......
# Panoptic Segmentation
## Description
Panoptic segmentation combines two distinct vision tasks: semantic
segmentation and instance segmentation. The tasks are unified such that each
pixel in the image is assigned both the label of the class it belongs to and
the instance identifier of the object it is a part of.
## Environment setup
The code can be run on multiple GPUs or TPUs with different distribution
strategies. See the TensorFlow distributed training
[guide](https://www.tensorflow.org/guide/distributed_training) for an overview
of `tf.distribute`.
The code is compatible with TensorFlow 2.4+. See requirements.txt for all
prerequisites; you can install them with
`pip install -r ./official/requirements.txt`.
**DISCLAIMER**: Panoptic MaskRCNN is still under active development, stay tuned!
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Panoptic Mask R-CNN configuration definition."""
import dataclasses
from official.vision.beta.configs import maskrcnn
from official.vision.beta.configs import semantic_segmentation
@dataclasses.dataclass
class PanopticMaskRCNN(maskrcnn.MaskRCNN):
"""Panoptic Mask R-CNN model config."""
segmentation_model: semantic_segmentation.SemanticSegmentationModel = (
semantic_segmentation.SemanticSegmentationModel(num_classes=2))
shared_backbone: bool = True
shared_decoder: bool = True
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory method to build panoptic segmentation model."""
import tensorflow as tf
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import factory as models_factory
from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
def build_panoptic_maskrcnn(
input_specs: tf.keras.layers.InputSpec,
model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds Panoptic Mask R-CNN model.
This factory function builds the mask rcnn first, builds the non-shared
semantic segmentation layers, and finally combines the two models to form
the panoptic segmentation model.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
model_config: Config instance for the panoptic maskrcnn model.
l2_regularizer: Optional `tf.keras.regularizers.Regularizer`, if specified,
the model is built with the provided regularization layer.
Returns:
tf.keras.Model for the panoptic segmentation model.
"""
norm_activation_config = model_config.norm_activation
segmentation_config = model_config.segmentation_model
# Builds the maskrcnn model.
maskrcnn_model = models_factory.build_maskrcnn(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
# Builds the semantic segmentation branch.
if not model_config.shared_backbone:
segmentation_backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=segmentation_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
segmentation_decoder_input_specs = segmentation_backbone.output_specs
else:
segmentation_backbone = None
segmentation_decoder_input_specs = maskrcnn_model.backbone.output_specs
if not model_config.shared_decoder:
segmentation_decoder = decoder_factory.build_decoder(
input_specs=segmentation_decoder_input_specs,
model_config=segmentation_config,
l2_regularizer=l2_regularizer)
else:
segmentation_decoder = None
segmentation_head_config = segmentation_config.head
detection_head_config = model_config.detection_head
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=segmentation_config.num_classes,
level=segmentation_head_config.level,
num_convs=segmentation_head_config.num_convs,
prediction_kernel_size=segmentation_head_config.prediction_kernel_size,
num_filters=segmentation_head_config.num_filters,
upsample_factor=segmentation_head_config.upsample_factor,
feature_fusion=segmentation_head_config.feature_fusion,
low_level=segmentation_head_config.low_level,
low_level_num_filters=segmentation_head_config.low_level_num_filters,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
# Combines maskrcnn, and segmentation models to build panoptic segmentation
# model.
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone=maskrcnn_model.backbone,
decoder=maskrcnn_model.decoder,
rpn_head=maskrcnn_model.rpn_head,
detection_head=maskrcnn_model.detection_head,
roi_generator=maskrcnn_model.roi_generator,
roi_sampler=maskrcnn_model.roi_sampler,
roi_aligner=maskrcnn_model.roi_aligner,
detection_generator=maskrcnn_model.detection_generator,
mask_head=maskrcnn_model.mask_head,
mask_sampler=maskrcnn_model.mask_sampler,
mask_roi_aligner=maskrcnn_model.mask_roi_aligner,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
cascade_class_ensemble=detection_head_config.cascade_class_ensemble,
min_level=model_config.min_level,
max_level=model_config.max_level,
num_scales=model_config.anchor.num_scales,
aspect_ratios=model_config.anchor.aspect_ratios,
anchor_size=model_config.anchor.anchor_size)
return model
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for factory.py."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.configs import backbones
from official.vision.beta.configs import decoders
from official.vision.beta.configs import semantic_segmentation
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory
class PanopticMaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
('resnet', (640, 640), 'dilated_resnet', 'fpn'),
('resnet', (640, 640), 'dilated_resnet', 'aspp'),
('resnet', (640, 640), None, 'fpn'),
('resnet', (640, 640), None, 'aspp'),
('resnet', (640, 640), None, None),
('resnet', (None, None), 'dilated_resnet', 'fpn'),
('resnet', (None, None), 'dilated_resnet', 'aspp'),
('resnet', (None, None), None, 'fpn'),
('resnet', (None, None), None, 'aspp'),
('resnet', (None, None), None, None)
)
def test_builder(self, backbone_type, input_size, segmentation_backbone_type,
segmentation_decoder_type):
num_classes = 2
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3])
segmentation_output_stride = 16
level = int(np.math.log2(segmentation_output_stride))
segmentation_model = semantic_segmentation.SemanticSegmentationModel(
num_classes=2,
backbone=backbones.Backbone(type=segmentation_backbone_type),
decoder=decoders.Decoder(type=segmentation_decoder_type),
head=semantic_segmentation.SegmentationHead(level=level))
model_config = panoptic_maskrcnn_cfg.PanopticMaskRCNN(
num_classes=num_classes,
segmentation_model=segmentation_model,
backbone=backbones.Backbone(type=backbone_type),
shared_backbone=segmentation_backbone_type is None,
shared_decoder=segmentation_decoder_type is None)
l2_regularizer = tf.keras.regularizers.l2(5e-5)
_ = factory.build_panoptic_maskrcnn(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Panoptic Segmentation model."""
from typing import List, Mapping, Optional, Union
import tensorflow as tf
from official.vision.beta.modeling import maskrcnn_model
@tf.keras.utils.register_keras_serializable(package='Vision')
class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
"""The Panoptic Segmentation model."""
def __init__(self,
backbone: tf.keras.Model,
decoder: tf.keras.Model,
rpn_head: tf.keras.layers.Layer,
detection_head: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_generator: tf.keras.layers.Layer,
roi_sampler: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_aligner: tf.keras.layers.Layer,
detection_generator: tf.keras.layers.Layer,
mask_head: Optional[tf.keras.layers.Layer] = None,
mask_sampler: Optional[tf.keras.layers.Layer] = None,
mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
segmentation_backbone: Optional[tf.keras.Model] = None,
segmentation_decoder: Optional[tf.keras.Model] = None,
segmentation_head: tf.keras.layers.Layer = None,
class_agnostic_bbox_pred: bool = False,
cascade_class_ensemble: bool = False,
min_level: Optional[int] = None,
max_level: Optional[int] = None,
num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None,
**kwargs):
"""Initializes the Panoptic Mask R-CNN model.
Args:
backbone: `tf.keras.Model`, the backbone network.
decoder: `tf.keras.Model`, the decoder network.
rpn_head: the RPN head.
detection_head: the detection head or a list of heads.
roi_generator: the ROI generator.
roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
detection heads.
roi_aligner: the ROI aligner.
detection_generator: the detection generator.
mask_head: the mask head.
mask_sampler: the mask sampler.
mask_roi_aligner: the ROI aligner for mask prediction.
segmentation_backbone: `tf.keras.Model`, the backbone network for the
segmentation head of the panoptic task. Providing `segmentation_backbone`
allows the segmentation head to use a standalone backbone. Setting
`segmentation_backbone=None` enables backbone sharing between the
MaskRCNN model and segmentation head.
segmentation_decoder: `tf.keras.Model`, the decoder network for the
segmentation head of the panoptic task. Providing `segmentation_decoder`
allows the segmentation head to use a standalone decoder. Setting
`segmentation_decoder=None` enables decoder sharing between the
MaskRCNN model and segmentation head. Decoders can only be shared when
`segmentation_backbone` is shared as well.
segmentation_head: segmentation head for the panoptic task.
class_agnostic_bbox_pred: if True, perform class agnostic bounding box
prediction. Needs to be `True` for Cascade RCNN models.
cascade_class_ensemble: if True, ensemble classification scores over all
detection heads.
min_level: Minimum level in output feature maps.
max_level: Maximum level in output feature maps.
num_scales: A number representing intermediate scales added on each level.
For instance, num_scales=2 adds one additional intermediate anchor scale,
yielding scales [2^0, 2^0.5] on each level.
aspect_ratios: A list representing the aspect ratio anchors added on each
level. Each number indicates the ratio of width to height. For instance,
aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each scale level.
anchor_size: A number representing the scale of the base anchor relative to
the feature stride 2^level.
**kwargs: keyword arguments to be passed.
"""
super(PanopticMaskRCNNModel, self).__init__(
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=detection_head,
roi_generator=roi_generator,
roi_sampler=roi_sampler,
roi_aligner=roi_aligner,
detection_generator=detection_generator,
mask_head=mask_head,
mask_sampler=mask_sampler,
mask_roi_aligner=mask_roi_aligner,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
**kwargs)
self._config_dict.update({
'segmentation_backbone': segmentation_backbone,
'segmentation_decoder': segmentation_decoder,
'segmentation_head': segmentation_head
})
if not self._include_mask:
raise ValueError(
'`mask_head` needs to be provided for Panoptic Mask R-CNN.')
if segmentation_backbone is not None and segmentation_decoder is None:
raise ValueError(
'`segmentation_decoder` needs to be provided for Panoptic Mask R-CNN '
'if `backbone` is not shared.')
self.segmentation_backbone = segmentation_backbone
self.segmentation_decoder = segmentation_decoder
self.segmentation_head = segmentation_head
def call(self,
images: tf.Tensor,
image_shape: tf.Tensor,
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
gt_boxes: Optional[tf.Tensor] = None,
gt_classes: Optional[tf.Tensor] = None,
gt_masks: Optional[tf.Tensor] = None,
training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
model_outputs = super(PanopticMaskRCNNModel, self).call(
images=images,
image_shape=image_shape,
anchor_boxes=anchor_boxes,
gt_boxes=gt_boxes,
gt_classes=gt_classes,
gt_masks=gt_masks,
training=training)
if self.segmentation_backbone is not None:
backbone_features = self.segmentation_backbone(images, training=training)
else:
backbone_features = model_outputs['backbone_features']
if self.segmentation_decoder is not None:
decoder_features = self.segmentation_decoder(
backbone_features, training=training)
else:
decoder_features = model_outputs['decoder_features']
segmentation_outputs = self.segmentation_head(
backbone_features, decoder_features, training=training)
model_outputs.update({
'segmentation_outputs': segmentation_outputs,
})
return model_outputs
@property
def checkpoint_items(
self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
"""Returns a dictionary of items to be additionally checkpointed."""
items = super(PanopticMaskRCNNModel, self).checkpoint_items
if self.segmentation_backbone is not None:
items.update(segmentation_backbone=self.segmentation_backbone)
if self.segmentation_decoder is not None:
items.update(segmentation_decoder=self.segmentation_decoder)
items.update(segmentation_head=self.segmentation_head)
return items
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for panoptic_maskrcnn_model.py."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.modeling.backbones import resnet
from official.vision.beta.modeling.decoders import aspp
from official.vision.beta.modeling.decoders import fpn
from official.vision.beta.modeling.heads import dense_prediction_heads
from official.vision.beta.modeling.heads import instance_heads
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.modeling.layers import detection_generator
from official.vision.beta.modeling.layers import mask_sampler
from official.vision.beta.modeling.layers import roi_aligner
from official.vision.beta.modeling.layers import roi_generator
from official.vision.beta.modeling.layers import roi_sampler
from official.vision.beta.ops import anchor
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
@combinations.generate(
combinations.combine(
use_separable_conv=[True, False],
build_anchor_boxes=[True, False],
shared_backbone=[True, False],
shared_decoder=[True, False],
is_training=[True, False]))
def test_build_model(self,
use_separable_conv,
build_anchor_boxes,
shared_backbone,
shared_decoder,
is_training=True):
num_classes = 3
min_level = 3
max_level = 7
num_scales = 3
aspect_ratios = [1.0]
anchor_size = 3
resnet_model_id = 50
segmentation_resnet_model_id = 50
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride))
fpn_decoder_level = 3
num_anchors_per_location = num_scales * len(aspect_ratios)
image_size = 128
images = np.random.rand(2, image_size, image_size, 3)
image_shape = np.array([[image_size, image_size], [image_size, image_size]])
shared_decoder = shared_decoder and shared_backbone
if build_anchor_boxes:
anchor_boxes = anchor.Anchor(
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=3,
image_size=(image_size, image_size)).multilevel_boxes
for l in anchor_boxes:
anchor_boxes[l] = tf.tile(
tf.expand_dims(anchor_boxes[l], axis=0), [2, 1, 1, 1])
else:
anchor_boxes = None
backbone = resnet.ResNet(model_id=resnet_model_id)
decoder = fpn.FPN(
input_specs=backbone.output_specs,
min_level=min_level,
max_level=max_level,
use_separable_conv=use_separable_conv)
rpn_head = dense_prediction_heads.RPNHead(
min_level=min_level,
max_level=max_level,
num_anchors_per_location=num_anchors_per_location,
num_convs=1)
detection_head = instance_heads.DetectionHead(num_classes=num_classes)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
mask_head = instance_heads.MaskHead(
num_classes=num_classes, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things,
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size)
gt_boxes = np.array(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
[[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
dtype=np.float32)
gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32)
gt_masks = np.ones((2, 3, 100, 100))
# Results will be checked in test_forward.
_ = model(
images,
image_shape,
anchor_boxes,
gt_boxes,
gt_classes,
gt_masks,
training=is_training)
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
shared_backbone=[True, False],
shared_decoder=[True, False],
training=[True, False],
))
def test_forward(self, strategy, training,
shared_backbone, shared_decoder):
num_classes = 3
min_level = 3
max_level = 4
num_scales = 3
aspect_ratios = [1.0]
anchor_size = 3
segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride))
fpn_decoder_level = 3
class_agnostic_bbox_pred = False
cascade_class_ensemble = False
image_size = (256, 256)
images = np.random.rand(2, image_size[0], image_size[1], 3)
image_shape = np.array([[224, 100], [100, 224]])
shared_decoder = shared_decoder and shared_backbone
with strategy.scope():
anchor_boxes = anchor.Anchor(
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
image_size=image_size).multilevel_boxes
num_anchors_per_location = len(aspect_ratios) * num_scales
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=min_level,
max_level=max_level,
input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=min_level,
max_level=max_level,
num_anchors_per_location=num_anchors_per_location)
detection_head = instance_heads.DetectionHead(
num_classes=num_classes,
class_agnostic_bbox_pred=class_agnostic_bbox_pred)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_cascade = []
roi_sampler_obj = roi_sampler.ROISampler()
roi_sampler_cascade.append(roi_sampler_obj)
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
mask_head = instance_heads.MaskHead(
num_classes=num_classes, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things,
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size)
gt_boxes = np.array(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
[[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
dtype=np.float32)
gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32)
gt_masks = np.ones((2, 3, 100, 100))
results = model(
images,
image_shape,
anchor_boxes,
gt_boxes,
gt_classes,
gt_masks,
training=training)
self.assertIn('rpn_boxes', results)
self.assertIn('rpn_scores', results)
if training:
self.assertIn('class_targets', results)
self.assertIn('box_targets', results)
self.assertIn('class_outputs', results)
self.assertIn('box_outputs', results)
self.assertIn('mask_outputs', results)
else:
self.assertIn('detection_boxes', results)
self.assertIn('detection_scores', results)
self.assertIn('detection_classes', results)
self.assertIn('num_detections', results)
self.assertIn('detection_masks', results)
self.assertIn('segmentation_outputs', results)
self.assertAllEqual(
[2, image_size[0] // (2**level), image_size[1] // (2**level), 2],
results['segmentation_outputs'].numpy().shape)
@combinations.generate(
combinations.combine(
shared_backbone=[True, False], shared_decoder=[True, False]))
def test_serialize_deserialize(self, shared_backbone, shared_decoder):
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=3, max_level=7, input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=3, max_level=7, num_anchors_per_location=3)
detection_head = instance_heads.DetectionHead(num_classes=2)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride))
fpn_decoder_level = 3
shared_decoder = shared_decoder and shared_backbone
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things,
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
min_level=3,
max_level=7,
num_scales=3,
aspect_ratios=[1.0],
anchor_size=3)
config = model.get_config()
new_model = panoptic_maskrcnn_model.PanopticMaskRCNNModel.from_config(
config)
# Validate that the config can be forced to JSON.
_ = new_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(model.get_config(), new_model.get_config())
@combinations.generate(
combinations.combine(
shared_backbone=[True, False], shared_decoder=[True, False]))
def test_checkpoint(self, shared_backbone, shared_decoder):
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=3, max_level=7, input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=3, max_level=7, num_anchors_per_location=3)
detection_head = instance_heads.DetectionHead(num_classes=2)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride))
fpn_decoder_level = 3
shared_decoder = shared_decoder and shared_backbone
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things,
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
min_level=3,
max_level=7,
num_scales=3,
aspect_ratios=[1.0],
anchor_size=3)
expect_checkpoint_items = dict(
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=[detection_head])
expect_checkpoint_items['mask_head'] = mask_head
if not shared_backbone:
expect_checkpoint_items['segmentation_backbone'] = segmentation_backbone
if not shared_decoder:
expect_checkpoint_items['segmentation_decoder'] = segmentation_decoder
expect_checkpoint_items['segmentation_head'] = segmentation_head
self.assertAllEqual(expect_checkpoint_items, model.checkpoint_items)
# Test save and load checkpoints.
ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
save_dir = self.create_tempdir().full_path
ckpt.save(os.path.join(save_dir, 'ckpt'))
partial_ckpt = tf.train.Checkpoint(backbone=backbone)
partial_ckpt.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
partial_ckpt_mask = tf.train.Checkpoint(
backbone=backbone, mask_head=mask_head)
partial_ckpt_mask.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
if not shared_backbone:
partial_ckpt_segmentation = tf.train.Checkpoint(
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head)
elif not shared_decoder:
partial_ckpt_segmentation = tf.train.Checkpoint(
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head)
else:
partial_ckpt_segmentation = tf.train.Checkpoint(
segmentation_head=segmentation_head)
partial_ckpt_segmentation.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
if __name__ == '__main__':
tf.test.main()
...@@ -97,7 +97,6 @@ class ProjectionHead(tf.keras.layers.Layer):
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'use_normalization': self._use_normalization,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
......
...@@ -90,14 +90,15 @@ class SimCLRModel(tf.keras.Model):
if training and self._mode == PRETRAIN:
num_transforms = 2
# Split channels, and optionally apply extra batched augmentation.
# (bsz, h, w, c*num_transforms) -> [(bsz, h, w, c), ....]
features_list = tf.split(
inputs, num_or_size_splits=num_transforms, axis=-1)
# (num_transforms * bsz, h, w, c)
features = tf.concat(features_list, 0)
else:
num_transforms = 1
features = inputs
# Split channels, and optionally apply extra batched augmentation.
# (bsz, h, w, c*num_transforms) -> [(bsz, h, w, c), ....]
features_list = tf.split(inputs, num_or_size_splits=num_transforms, axis=-1)
# (num_transforms * bsz, h, w, c)
features = tf.concat(features_list, 0)
# Base network forward pass.
endpoints = self._backbone(features, training=training)
......
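As an aside (not part of this diff), a minimal shape-only sketch, with hypothetical sizes, of the pretraining reshaping above: the two augmented views arrive stacked along the channel axis and are unstacked into a doubled batch before the backbone forward pass.

import tensorflow as tf

bsz, h, w, c, num_transforms = 8, 32, 32, 3, 2

# Pretraining inputs carry both augmented views stacked along channels.
inputs = tf.random.uniform([bsz, h, w, c * num_transforms])

# (bsz, h, w, c * num_transforms) -> [(bsz, h, w, c), (bsz, h, w, c)]
features_list = tf.split(inputs, num_or_size_splits=num_transforms, axis=-1)

# -> (num_transforms * bsz, h, w, c): the backbone sees both views in one batch.
features = tf.concat(features_list, 0)
print(features.shape)  # (16, 32, 32, 3)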
...@@ -415,7 +415,8 @@ class SimCLRFinetuneTask(base_task.Task):
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
norm_activation_config=model_config.norm_activation,
l2_regularizer=l2_regularizer)
norm_activation_config = model_config.norm_activation
......
# Vision Transformer (ViT)
**DISCLAIMER**: This implementation is still under development. No support will
be provided during the development phase.
[![Paper](http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv)](https://arxiv.org/abs/2010.11929)
This repository contains an implementation of the Vision Transformer (ViT) in
TensorFlow 2.
* Paper title:
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf).
\ No newline at end of file
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.vision.beta.projects.vit.configs import image_classification