Unverified Commit c127d527 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents 78657911 457bcb85
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for detection."""
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from official.projects.detr import optimization
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.tasks import detection
_NUM_EXAMPLES = 10
def _gen_fn():
h = np.random.randint(0, 300)
w = np.random.randint(0, 300)
num_boxes = np.random.randint(0, 50)
return {
'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
'image/id': np.random.randint(0, 100),
'image/filename': 'test',
'objects': {
'is_crowd': np.ones(shape=(num_boxes), dtype=np.bool),
'bbox': np.ones(shape=(num_boxes, 4), dtype=np.float32),
'label': np.ones(shape=(num_boxes), dtype=np.int64),
'id': np.ones(shape=(num_boxes), dtype=np.int64),
'area': np.ones(shape=(num_boxes), dtype=np.int64),
}
}
def _as_dataset(self, *args, **kwargs):
  """Builds a mocked `tf.data.Dataset` of randomly generated examples.

  Signature matches the `as_dataset_fn` hook expected by
  `tfds.testing.mock_data`; positional/keyword arguments are ignored.
  """
  del args, kwargs

  def _generator():
    for _ in range(_NUM_EXAMPLES):
      yield _gen_fn()

  return tf.data.Dataset.from_generator(
      _generator,
      output_types=self.info.features.dtype,
      output_shapes=self.info.features.shape,
  )
class DetectionTest(tf.test.TestCase):
  """Smoke tests for the DETR detection task on mocked COCO data."""

  def test_train_step(self):
    """Builds the task with mocked tfds data and runs one train step."""
    train_data = coco.COCODataConfig(
        tfds_name='coco/2017',
        tfds_split='validation',
        is_training=True,
        global_batch_size=2,
    )
    config = detr_cfg.DetectionConfig(
        num_encoder_layers=1,
        num_decoder_layers=1,
        train_data=train_data)
    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      # NOTE(review): `DectectionTask` spelling follows the project API.
      task = detection.DectectionTask(config)
      model = task.build_model()
      batch = next(iter(task.build_inputs(config.train_data)))
      opt_cfg = optimization.OptimizationConfig({
          'optimizer': {
              'type': 'detr_adamw',
              'detr_adamw': {
                  'weight_decay_rate': 1e-4,
                  'global_clipnorm': 0.1,
              }
          },
          'learning_rate': {
              'type': 'stepwise',
              'stepwise': {
                  'boundaries': [120000],
                  'values': [0.0001, 1.0e-05]
              }
          },
      })
      optimizer = detection.DectectionTask.create_optimizer(opt_cfg)
      task.train_step(batch, model, optimizer)

  def test_validation_step(self):
    """Runs one validation step and aggregates/reduces the resulting logs."""
    validation_data = coco.COCODataConfig(
        tfds_name='coco/2017',
        tfds_split='validation',
        is_training=False,
        global_batch_size=2,
    )
    config = detr_cfg.DetectionConfig(
        num_encoder_layers=1,
        num_decoder_layers=1,
        validation_data=validation_data)
    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DectectionTask(config)
      model = task.build_model()
      metrics = task.build_metrics(training=False)
      batch = next(iter(task.build_inputs(config.validation_data)))
      logs = task.validation_step(batch, model, metrics)
      state = task.aggregate_logs(step_outputs=logs)
      task.reduce_aggregated_logs(state)
# Entry point: runs the test cases above when executed directly.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.projects.detr.configs import detr
from official.projects.detr.tasks import detection
# pylint: enable=unused-import
FLAGS = flags.FLAGS
def main(_):
  """Parses configuration, sets up distribution, and runs the experiment."""
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir

  # Pure eval modes do not output yaml files. Otherwise continuous eval job
  # may race against the train job for writing the same file.
  if 'train' in FLAGS.mode:
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case
  # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
  # when dtype is float16.
  mixed_precision_dtype = params.runtime.mixed_precision_dtype
  if mixed_precision_dtype:
    performance.set_mixed_precision_policy(mixed_precision_dtype)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  # The task must be constructed inside the strategy scope so its variables
  # are created on the right devices.
  with strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)
# Entry point: defines the standard TF Model Garden flags, requires the
# core ones, then hands control to absl.
if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
...@@ -135,7 +135,8 @@ def main(argv: Sequence[str]) -> None: ...@@ -135,7 +135,8 @@ def main(argv: Sequence[str]) -> None:
checkpoint = tf.train.Checkpoint(**checkpoint_dict) checkpoint = tf.train.Checkpoint(**checkpoint_dict)
checkpoint.restore(FLAGS.model_checkpoint).assert_existing_objects_matched() checkpoint.restore(FLAGS.model_checkpoint).assert_existing_objects_matched()
model_for_serving = build_model_for_serving(model) model_for_serving = build_model_for_serving(model, FLAGS.sequence_length,
FLAGS.batch_size)
model_for_serving.summary() model_for_serving.summary()
# TODO(b/194449109): Need to save the model to file and then convert tflite # TODO(b/194449109): Need to save the model to file and then convert tflite
......
...@@ -18,6 +18,7 @@ task: ...@@ -18,6 +18,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.2 dropout_rate: 0.2
activation: 'swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -12,6 +12,7 @@ task: ...@@ -12,6 +12,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: false use_sync_bn: false
dropout_rate: 0.5 dropout_rate: 0.5
activation: 'swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -24,6 +24,7 @@ task: ...@@ -24,6 +24,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.2 dropout_rate: 0.2
activation: 'hard_swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -18,6 +18,7 @@ task: ...@@ -18,6 +18,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.5 dropout_rate: 0.5
activation: 'swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -24,6 +24,7 @@ task: ...@@ -24,6 +24,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.2 dropout_rate: 0.2
activation: 'hard_swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -18,6 +18,7 @@ task: ...@@ -18,6 +18,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.5 dropout_rate: 0.5
activation: 'swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -24,6 +24,7 @@ task: ...@@ -24,6 +24,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.5 dropout_rate: 0.5
activation: 'hard_swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -18,6 +18,7 @@ task: ...@@ -18,6 +18,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.5 dropout_rate: 0.5
activation: 'swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -25,6 +25,7 @@ task: ...@@ -25,6 +25,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.5 dropout_rate: 0.5
activation: 'hard_swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -18,6 +18,7 @@ task: ...@@ -18,6 +18,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.5 dropout_rate: 0.5
activation: 'swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -25,6 +25,7 @@ task: ...@@ -25,6 +25,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.5 dropout_rate: 0.5
activation: 'hard_swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -18,6 +18,7 @@ task: ...@@ -18,6 +18,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.5 dropout_rate: 0.5
activation: 'swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -25,6 +25,7 @@ task: ...@@ -25,6 +25,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.5 dropout_rate: 0.5
activation: 'hard_swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -18,6 +18,7 @@ task: ...@@ -18,6 +18,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.2 dropout_rate: 0.2
activation: 'swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -24,6 +24,7 @@ task: ...@@ -24,6 +24,7 @@ task:
norm_activation: norm_activation:
use_sync_bn: true use_sync_bn: true
dropout_rate: 0.2 dropout_rate: 0.2
activation: 'hard_swish'
train_data: train_data:
name: kinetics600 name: kinetics600
variant_name: rgb variant_name: rgb
......
...@@ -338,7 +338,7 @@ class Movinet(tf.keras.Model): ...@@ -338,7 +338,7 @@ class Movinet(tf.keras.Model):
3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with 3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
by 5x1x1 conv). by 5x1x1 conv).
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D se_type: '3d', '2d', '2plus3d' or 'none'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d' spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d' uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling. concatenates both 3D and 2D global average pooling.
...@@ -369,7 +369,7 @@ class Movinet(tf.keras.Model): ...@@ -369,7 +369,7 @@ class Movinet(tf.keras.Model):
if conv_type not in ('3d', '2plus1d', '3d_2plus1d'): if conv_type not in ('3d', '2plus1d', '3d_2plus1d'):
raise ValueError('Unknown conv type: {}'.format(conv_type)) raise ValueError('Unknown conv type: {}'.format(conv_type))
if se_type not in ('3d', '2d', '2plus3d'): if se_type not in ('3d', '2d', '2plus3d', 'none'):
raise ValueError('Unknown squeeze excitation type: {}'.format(se_type)) raise ValueError('Unknown squeeze excitation type: {}'.format(se_type))
self._model_id = model_id self._model_id = model_id
...@@ -602,10 +602,11 @@ class Movinet(tf.keras.Model): ...@@ -602,10 +602,11 @@ class Movinet(tf.keras.Model):
expand_filters, expand_filters,
) )
states[f'{prefix}_pool_buffer'] = ( if '3d' in self._se_type:
input_shape[0], 1, 1, 1, expand_filters, states[f'{prefix}_pool_buffer'] = (
) input_shape[0], 1, 1, 1, expand_filters,
states[f'{prefix}_pool_frame_count'] = (1,) )
states[f'{prefix}_pool_frame_count'] = (1,)
if use_positional_encoding: if use_positional_encoding:
name = f'{prefix}_pos_enc_frame_count' name = f'{prefix}_pos_enc_frame_count'
......
...@@ -93,10 +93,9 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -93,10 +93,9 @@ class MobileConv2D(tf.keras.layers.Layer):
data_format: Optional[str] = None, data_format: Optional[str] = None,
dilation_rate: Union[int, Sequence[int]] = (1, 1), dilation_rate: Union[int, Sequence[int]] = (1, 1),
groups: int = 1, groups: int = 1,
activation: Optional[nn_layers.Activation] = None,
use_bias: bool = True, use_bias: bool = True,
kernel_initializer: tf.keras.initializers.Initializer = 'glorot_uniform', kernel_initializer: str = 'glorot_uniform',
bias_initializer: tf.keras.initializers.Initializer = 'zeros', bias_initializer: str = 'zeros',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
activity_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, activity_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
...@@ -105,6 +104,8 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -105,6 +104,8 @@ class MobileConv2D(tf.keras.layers.Layer):
use_depthwise: bool = False, use_depthwise: bool = False,
use_temporal: bool = False, use_temporal: bool = False,
use_buffered_input: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras use_buffered_input: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
batch_norm_op: Optional[Any] = None,
activation_op: Optional[Any] = None,
**kwargs): # pylint: disable=g-doc-args **kwargs): # pylint: disable=g-doc-args
"""Initializes mobile conv2d. """Initializes mobile conv2d.
...@@ -117,6 +118,10 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -117,6 +118,10 @@ class MobileConv2D(tf.keras.layers.Layer):
use_buffered_input: if True, the input is expected to be padded use_buffered_input: if True, the input is expected to be padded
beforehand. In effect, calling this layer will use 'valid' padding on beforehand. In effect, calling this layer will use 'valid' padding on
the temporal dimension to simulate 'causal' padding. the temporal dimension to simulate 'causal' padding.
batch_norm_op: A callable object of batch norm layer. If None, no batch
norm will be applied after the convolution.
activation_op: A callable object of activation layer. If None, no activation_op: A callable object of activation layer. If None, no
activation will be applied after the convolution. activation will be applied after the convolution.
**kwargs: keyword arguments to be passed to this layer. **kwargs: keyword arguments to be passed to this layer.
Returns: Returns:
...@@ -130,7 +135,6 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -130,7 +135,6 @@ class MobileConv2D(tf.keras.layers.Layer):
self._data_format = data_format self._data_format = data_format
self._dilation_rate = dilation_rate self._dilation_rate = dilation_rate
self._groups = groups self._groups = groups
self._activation = activation
self._use_bias = use_bias self._use_bias = use_bias
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._bias_initializer = bias_initializer self._bias_initializer = bias_initializer
...@@ -142,6 +146,8 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -142,6 +146,8 @@ class MobileConv2D(tf.keras.layers.Layer):
self._use_depthwise = use_depthwise self._use_depthwise = use_depthwise
self._use_temporal = use_temporal self._use_temporal = use_temporal
self._use_buffered_input = use_buffered_input self._use_buffered_input = use_buffered_input
self._batch_norm_op = batch_norm_op
self._activation_op = activation_op
kernel_size = normalize_tuple(kernel_size, 2, 'kernel_size') kernel_size = normalize_tuple(kernel_size, 2, 'kernel_size')
...@@ -156,7 +162,6 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -156,7 +162,6 @@ class MobileConv2D(tf.keras.layers.Layer):
depth_multiplier=1, depth_multiplier=1,
data_format=data_format, data_format=data_format,
dilation_rate=dilation_rate, dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias, use_bias=use_bias,
depthwise_initializer=kernel_initializer, depthwise_initializer=kernel_initializer,
bias_initializer=bias_initializer, bias_initializer=bias_initializer,
...@@ -175,7 +180,6 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -175,7 +180,6 @@ class MobileConv2D(tf.keras.layers.Layer):
data_format=data_format, data_format=data_format,
dilation_rate=dilation_rate, dilation_rate=dilation_rate,
groups=groups, groups=groups,
activation=activation,
use_bias=use_bias, use_bias=use_bias,
kernel_initializer=kernel_initializer, kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer, bias_initializer=bias_initializer,
...@@ -196,7 +200,6 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -196,7 +200,6 @@ class MobileConv2D(tf.keras.layers.Layer):
'data_format': self._data_format, 'data_format': self._data_format,
'dilation_rate': self._dilation_rate, 'dilation_rate': self._dilation_rate,
'groups': self._groups, 'groups': self._groups,
'activation': self._activation,
'use_bias': self._use_bias, 'use_bias': self._use_bias,
'kernel_initializer': self._kernel_initializer, 'kernel_initializer': self._kernel_initializer,
'bias_initializer': self._bias_initializer, 'bias_initializer': self._bias_initializer,
...@@ -229,6 +232,10 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -229,6 +232,10 @@ class MobileConv2D(tf.keras.layers.Layer):
x = tf.reshape(inputs, input_shape) x = tf.reshape(inputs, input_shape)
x = self._conv(x) x = self._conv(x)
if self._batch_norm_op is not None:
x = self._batch_norm_op(x)
if self._activation_op is not None:
x = self._activation_op(x)
if self._use_temporal: if self._use_temporal:
output_shape = [ output_shape = [
...@@ -357,8 +364,20 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -357,8 +364,20 @@ class ConvBlock(tf.keras.layers.Layer):
padding = 'causal' if self._causal else 'same' padding = 'causal' if self._causal else 'same'
self._groups = input_shape[-1] if self._depthwise else 1 self._groups = input_shape[-1] if self._depthwise else 1
self._conv_temporal = None self._batch_norm = None
self._batch_norm_temporal = None
if self._use_batch_norm:
self._batch_norm = self._batch_norm_layer(
momentum=self._batch_norm_momentum,
epsilon=self._batch_norm_epsilon,
name='bn')
if self._conv_type != '3d' and self._kernel_size[0] > 1:
self._batch_norm_temporal = self._batch_norm_layer(
momentum=self._batch_norm_momentum,
epsilon=self._batch_norm_epsilon,
name='bn_temporal')
self._conv_temporal = None
if self._conv_type == '3d_2plus1d' and self._kernel_size[0] > 1: if self._conv_type == '3d_2plus1d' and self._kernel_size[0] > 1:
self._conv = nn_layers.Conv3D( self._conv = nn_layers.Conv3D(
self._filters, self._filters,
...@@ -394,6 +413,8 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -394,6 +413,8 @@ class ConvBlock(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
use_buffered_input=False, use_buffered_input=False,
batch_norm_op=self._batch_norm,
activation_op=self._activation_layer,
name='conv2d') name='conv2d')
if self._kernel_size[0] > 1: if self._kernel_size[0] > 1:
self._conv_temporal = MobileConv2D( self._conv_temporal = MobileConv2D(
...@@ -408,6 +429,8 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -408,6 +429,8 @@ class ConvBlock(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
use_buffered_input=self._use_buffered_input, use_buffered_input=self._use_buffered_input,
batch_norm_op=self._batch_norm_temporal,
activation_op=self._activation_layer,
name='conv2d_temporal') name='conv2d_temporal')
else: else:
self._conv = nn_layers.Conv3D( self._conv = nn_layers.Conv3D(
...@@ -422,37 +445,26 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -422,37 +445,26 @@ class ConvBlock(tf.keras.layers.Layer):
use_buffered_input=self._use_buffered_input, use_buffered_input=self._use_buffered_input,
name='conv3d') name='conv3d')
self._batch_norm = None
self._batch_norm_temporal = None
if self._use_batch_norm:
self._batch_norm = self._batch_norm_layer(
momentum=self._batch_norm_momentum,
epsilon=self._batch_norm_epsilon,
name='bn')
if self._conv_type != '3d' and self._conv_temporal is not None:
self._batch_norm_temporal = self._batch_norm_layer(
momentum=self._batch_norm_momentum,
epsilon=self._batch_norm_epsilon,
name='bn_temporal')
super(ConvBlock, self).build(input_shape) super(ConvBlock, self).build(input_shape)
def call(self, inputs): def call(self, inputs):
"""Calls the layer with the given inputs.""" """Calls the layer with the given inputs."""
x = inputs x = inputs
# bn_op and activation_op are folded into the '2plus1d' conv layer so that
# we do not explicitly call them here.
# TODO(lzyuan): clean the conv layers api once the models are re-trained.
x = self._conv(x) x = self._conv(x)
if self._batch_norm is not None: if self._batch_norm is not None and self._conv_type != '2plus1d':
x = self._batch_norm(x) x = self._batch_norm(x)
if self._activation_layer is not None: if self._activation_layer is not None and self._conv_type != '2plus1d':
x = self._activation_layer(x) x = self._activation_layer(x)
if self._conv_temporal is not None: if self._conv_temporal is not None:
x = self._conv_temporal(x) x = self._conv_temporal(x)
if self._batch_norm_temporal is not None: if self._batch_norm_temporal is not None and self._conv_type != '2plus1d':
x = self._batch_norm_temporal(x) x = self._batch_norm_temporal(x)
if self._activation_layer is not None: if self._activation_layer is not None and self._conv_type != '2plus1d':
x = self._activation_layer(x) x = self._activation_layer(x)
return x return x
...@@ -640,10 +652,13 @@ class StreamConvBlock(ConvBlock): ...@@ -640,10 +652,13 @@ class StreamConvBlock(ConvBlock):
if self._conv_temporal is None and self._stream_buffer is not None: if self._conv_temporal is None and self._stream_buffer is not None:
x, states = self._stream_buffer(x, states=states) x, states = self._stream_buffer(x, states=states)
# bn_op and activation_op are folded into the '2plus1d' conv layer so that
# we do not explicitly call them here.
# TODO(lzyuan): clean the conv layers api once the models are re-trained.
x = self._conv(x) x = self._conv(x)
if self._batch_norm is not None: if self._batch_norm is not None and self._conv_type != '2plus1d':
x = self._batch_norm(x) x = self._batch_norm(x)
if self._activation_layer is not None: if self._activation_layer is not None and self._conv_type != '2plus1d':
x = self._activation_layer(x) x = self._activation_layer(x)
if self._conv_temporal is not None: if self._conv_temporal is not None:
...@@ -653,9 +668,9 @@ class StreamConvBlock(ConvBlock): ...@@ -653,9 +668,9 @@ class StreamConvBlock(ConvBlock):
x, states = self._stream_buffer(x, states=states) x, states = self._stream_buffer(x, states=states)
x = self._conv_temporal(x) x = self._conv_temporal(x)
if self._batch_norm_temporal is not None: if self._batch_norm_temporal is not None and self._conv_type != '2plus1d':
x = self._batch_norm_temporal(x) x = self._batch_norm_temporal(x)
if self._activation_layer is not None: if self._activation_layer is not None and self._conv_type != '2plus1d':
x = self._activation_layer(x) x = self._activation_layer(x)
return x, states return x, states
...@@ -885,7 +900,8 @@ class MobileBottleneck(tf.keras.layers.Layer): ...@@ -885,7 +900,8 @@ class MobileBottleneck(tf.keras.layers.Layer):
x = self._expansion_layer(inputs) x = self._expansion_layer(inputs)
x, states = self._feature_layer(x, states=states) x, states = self._feature_layer(x, states=states)
x, states = self._attention_layer(x, states=states) if self._attention_layer is not None:
x, states = self._attention_layer(x, states=states)
x = self._projection_layer(x) x = self._projection_layer(x)
# Add identity so that the ops are ordered as written. This is useful for, # Add identity so that the ops are ordered as written. This is useful for,
...@@ -1136,18 +1152,20 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1136,18 +1152,20 @@ class MovinetBlock(tf.keras.layers.Layer):
batch_norm_momentum=self._batch_norm_momentum, batch_norm_momentum=self._batch_norm_momentum,
batch_norm_epsilon=self._batch_norm_epsilon, batch_norm_epsilon=self._batch_norm_epsilon,
name='projection') name='projection')
self._attention = StreamSqueezeExcitation( self._attention = None
se_hidden_filters, if se_type != 'none':
se_type=se_type, self._attention = StreamSqueezeExcitation(
activation=activation, se_hidden_filters,
gating_activation=gating_activation, se_type=se_type,
causal=self._causal, activation=activation,
conv_type=conv_type, gating_activation=gating_activation,
use_positional_encoding=use_positional_encoding, causal=self._causal,
kernel_initializer=kernel_initializer, conv_type=conv_type,
kernel_regularizer=kernel_regularizer, use_positional_encoding=use_positional_encoding,
state_prefix=state_prefix, kernel_initializer=kernel_initializer,
name='se') kernel_regularizer=kernel_regularizer,
state_prefix=state_prefix,
name='se')
def get_config(self): def get_config(self):
"""Returns a dictionary containing the config used for initialization.""" """Returns a dictionary containing the config used for initialization."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment