Commit a44be35f authored by A. Unique TensorFlower, committed by saberkun

Internal change

PiperOrigin-RevId: 401839863
parent 21ce83d8
@@ -19,10 +19,10 @@
 import tensorflow as tf
 from official.modeling import hyperparams
+from official.projects.volumetric_models.modeling.decoders import factory as decoder_factory
+from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d
 from official.vision.beta.modeling import segmentation_model
 from official.vision.beta.modeling.backbones import factory as backbone_factory
-from official.vision.beta.projects.volumetric_models.modeling.decoders import factory as decoder_factory
-from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d


 def build_segmentation_model_3d(
@@ -18,10 +18,10 @@ from absl.testing import parameterized
 import tensorflow as tf
 # pylint: disable=unused-import
-from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.modeling import factory
+from official.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.modeling import factory


 class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
@@ -19,7 +19,7 @@ from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
-from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d
+from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d


 class SegmentationHead3DTest(parameterized.TestCase, tf.test.TestCase):
@@ -19,7 +19,7 @@
 from absl.testing import parameterized
 import tensorflow as tf
-from official.vision.beta.projects.volumetric_models.modeling import nn_blocks_3d
+from official.projects.volumetric_models.modeling import nn_blocks_3d


 class NNBlocks3DTest(parameterized.TestCase, tf.test.TestCase):
@@ -18,11 +18,10 @@
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d
 from official.vision.beta.modeling import segmentation_model
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d


 class SegmentationNetworkUNet3DTest(parameterized.TestCase, tf.test.TestCase):
@@ -16,7 +16,7 @@
 # pylint: disable=unused-import
 from official.common import registry_imports
-from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as semantic_segmentation_3d_cfg
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d
+from official.projects.volumetric_models.configs import semantic_segmentation_3d as semantic_segmentation_3d_cfg
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.tasks import semantic_segmentation_3d
@@ -42,7 +42,7 @@ from absl import flags
 from official.common import registry_imports  # pylint: disable=unused-import
 from official.core import exp_factory
 from official.modeling import hyperparams
-from official.vision.beta.projects.volumetric_models.serving import semantic_segmentation_3d
+from official.projects.volumetric_models.serving import semantic_segmentation_3d
 from official.vision.beta.serving import export_saved_model_lib

 FLAGS = flags.FLAGS
@@ -19,9 +19,9 @@ from typing import Mapping
 import tensorflow as tf
 # pylint: disable=unused-import
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.modeling import factory
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.modeling import factory
 from official.vision.beta.serving import export_base
@@ -22,10 +22,10 @@ import tensorflow as tf
 # pylint: disable=unused-import
 from official.core import exp_factory
-from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.serving import semantic_segmentation_3d
+from official.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.serving import semantic_segmentation_3d


 class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
@@ -23,11 +23,11 @@ from official.common import dataset_fn
 from official.core import base_task
 from official.core import input_reader
 from official.core import task_factory
-from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
-from official.vision.beta.projects.volumetric_models.dataloaders import segmentation_input_3d
-from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics
-from official.vision.beta.projects.volumetric_models.losses import segmentation_losses
-from official.vision.beta.projects.volumetric_models.modeling import factory
+from official.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
+from official.projects.volumetric_models.dataloaders import segmentation_input_3d
+from official.projects.volumetric_models.evaluation import segmentation_metrics
+from official.projects.volumetric_models.losses import segmentation_losses
+from official.projects.volumetric_models.modeling import factory


 @task_factory.register_task_cls(exp_cfg.SemanticSegmentation3DTask)
@@ -26,11 +26,11 @@ import tensorflow as tf
 from official.common import registry_imports  # pylint: disable=unused-import
 from official.core import exp_factory
 from official.modeling import optimization
+from official.projects.volumetric_models.evaluation import segmentation_metrics
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.tasks import semantic_segmentation_3d as img_seg_task
 from official.vision.beta.dataloaders import tfexample_utils
-from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d as img_seg_task


 class SemanticSegmentationTaskTest(tf.test.TestCase, parameterized.TestCase):
@@ -15,12 +15,11 @@
 """TensorFlow Model Garden Vision training driver."""
 from absl import app
 import gin  # pylint: disable=unused-import

 from official.common import flags as tfm_flags
+from official.projects.volumetric_models import registry_imports  # pylint: disable=unused-import
 from official.vision.beta import train
-from official.vision.beta.projects.volumetric_models import registry_imports  # pylint: disable=unused-import


 def main(_):
@@ -20,9 +20,8 @@ from absl import flags
 from absl import logging
 from absl.testing import flagsaver
 import tensorflow as tf
+from official.projects.volumetric_models import train as train_lib
 from official.vision.beta.dataloaders import tfexample_utils
-from official.vision.beta.projects.volumetric_models import train as train_lib

 FLAGS = flags.FLAGS
@@ -13,14 +13,15 @@
 # limitations under the License.
 """Video classification configuration definition."""
+import dataclasses
 from typing import Optional, Tuple

 from absl import flags
-import dataclasses
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import hyperparams
 from official.modeling import optimization
+from official.vision.beta.configs import common

 FLAGS = flags.FLAGS
@@ -66,16 +67,28 @@ def yt8m(is_training):
 @dataclasses.dataclass
-class YT8MModel(hyperparams.Config):
+class MoeModel(hyperparams.Config):
   """The model config."""
-  cluster_size: int = 2048
-  hidden_size: int = 2048
   num_mixtures: int = 5
+  l2_penalty: float = 1e-5
+  use_input_context_gate: bool = False
+  use_output_context_gate: bool = False
+
+
+@dataclasses.dataclass
+class DbofModel(hyperparams.Config):
+  """The model config."""
+  cluster_size: int = 3000
+  hidden_size: int = 2000
   add_batch_norm: bool = True
   sample_random_frames: bool = True
   is_training: bool = True
   activation: str = 'relu6'
+  use_context_gate_cluster_layer: bool = False
+  context_gate_cluster_bottleneck_size: int = 0
   pooling_method: str = 'average'
   yt8m_agg_classifier_model: str = 'MoeModel'
+  agg_model: hyperparams.Config = MoeModel()
+  norm_activation: common.NormActivation = common.NormActivation(
+      activation='relu', use_sync_bn=False)


 @dataclasses.dataclass
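The model config is now nested: `DbofModel` carries an `agg_model` sub-config (a `MoeModel` by default). A minimal usage sketch, assuming the usual `override` API of `official.modeling.hyperparams.Config` (values illustrative, not from this commit):

```python
# Hypothetical override of the new nested config defined above.
model_cfg = DbofModel()
model_cfg.override({
    'cluster_size': 2048,              # tune the projection width
    'agg_model': {'num_mixtures': 2},  # reach into the nested MoeModel
})
assert model_cfg.yt8m_agg_classifier_model == 'MoeModel'
```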
@@ -83,12 +96,13 @@ class Losses(hyperparams.Config):
   name: str = 'binary_crossentropy'
   from_logits: bool = False
   label_smoothing: float = 0.0
+  l2_weight_decay: float = 1e-5


 @dataclasses.dataclass
 class YT8MTask(cfg.TaskConfig):
   """The task config."""
-  model: YT8MModel = YT8MModel()
+  model: DbofModel = DbofModel()
   train_data: DataConfig = yt8m(is_training=True)
   validation_data: DataConfig = yt8m(is_training=False)
   losses: Losses = Losses()
@@ -102,8 +116,8 @@ def add_trainer(
    experiment: cfg.ExperimentConfig,
    train_batch_size: int,
    eval_batch_size: int,
-    learning_rate: float = 0.005,
-    train_epochs: int = 44,
+    learning_rate: float = 0.0001,
+    train_epochs: int = 50,
 ):
   """Add and config a trainer to the experiment config."""
   if YT8M_TRAIN_EXAMPLES <= 0:
@@ -115,13 +129,14 @@ def add_trainer(
   experiment.task.train_data.global_batch_size = train_batch_size
   experiment.task.validation_data.global_batch_size = eval_batch_size
   steps_per_epoch = YT8M_TRAIN_EXAMPLES // train_batch_size
+  steps_per_loop = 30
   experiment.trainer = cfg.TrainerConfig(
-      steps_per_loop=steps_per_epoch,
-      summary_interval=steps_per_epoch,
-      checkpoint_interval=steps_per_epoch,
+      steps_per_loop=steps_per_loop,
+      summary_interval=steps_per_loop,
+      checkpoint_interval=steps_per_loop,
       train_steps=train_epochs * steps_per_epoch,
       validation_steps=YT8M_VAL_EXAMPLES // eval_batch_size,
-      validation_interval=steps_per_epoch,
+      validation_interval=steps_per_loop,
       optimizer_config=optimization.OptimizationConfig({
           'optimizer': {
               'type': 'adam',
@@ -132,9 +147,18 @@ def add_trainer(
               'exponential': {
                   'initial_learning_rate': learning_rate,
                   'decay_rate': 0.95,
-                  'decay_steps': 1500000,
+                  'decay_steps': int(steps_per_epoch * 1.5),
+                  'offset': 500,
               }
           },
+          'warmup': {
+              'linear': {
+                  'name': 'linear',
+                  'warmup_learning_rate': 0,
+                  'warmup_steps': 500,
+              },
+              'type': 'linear',
+          }
       }))
   return experiment
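For intuition, an approximate plain-Python rendering of the schedule this optimizer_config now describes: 500 linear-warmup steps from 0, then exponential decay whose offset of 500 starts the decay clock after the warmup. The exact curve is whatever `official.modeling.optimization` builds from the dict; this sketch ignores staircase behavior.

```python
def approx_lr(step: int, base_lr: float, steps_per_epoch: int) -> float:
  """Rough sketch of the warmup + exponential decay configured above."""
  warmup_steps, decay_rate = 500, 0.95
  decay_steps = int(steps_per_epoch * 1.5)
  if step < warmup_steps:
    return base_lr * step / warmup_steps  # linear warmup from 0
  # 'offset': 500 shifts the decay so it begins where the warmup ends.
  return base_lr * decay_rate**((step - warmup_steps) / decay_steps)
```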
@@ -154,4 +178,17 @@ def yt8m_experiment() -> cfg.ExperimentConfig:
           'task.train_data.feature_names != None',
       ])
-  return add_trainer(exp_config, train_batch_size=512, eval_batch_size=512)
+  # Per TPUv3 Core batch size 16GB HBM. `factor` in range(1, 26)
+  factor = 1
+  num_cores = 32  # for TPU 4x4
+  train_per_core_bs = 32 * factor
+  train_bs = train_per_core_bs * num_cores
+  eval_per_core_bs = 32 * 50  # multiplier<=100
+  eval_bs = eval_per_core_bs * num_cores
+  # based lr=0.0001 for bs=512
+  return add_trainer(
+      exp_config,
+      train_batch_size=train_bs,
+      eval_batch_size=eval_bs,
+      learning_rate=0.0001 * (train_bs / 512),
+      train_epochs=100)
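The new experiment body applies a linear learning-rate scaling rule (base lr 1e-4 at global batch size 512, per its comment). Worked numbers under the default `factor = 1`:

```python
# Worked numbers for the defaults above (factor = 1, TPU 4x4).
num_cores = 32
train_bs = (32 * 1) * num_cores   # 1024
eval_bs = (32 * 50) * num_cores   # 51200
lr = 0.0001 * (train_bs / 512)    # 0.0002: lr scales linearly with batch size
```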
@@ -13,13 +13,12 @@
 # limitations under the License.
 """Contains model definitions."""
+from typing import Optional, Dict, Any
 import tensorflow as tf

+from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as utils
+
 layers = tf.keras.layers
-regularizers = tf.keras.regularizers
-
-# The number of mixtures (excluding the dummy 'expert') used for MoeModel.
-moe_num_mixtures = 2


 class LogisticModel():
@@ -41,7 +40,7 @@ class LogisticModel():
     output = layers.Dense(
         vocab_size,
         activation=tf.nn.sigmoid,
-        kernel_regularizer=regularizers.l2(l2_penalty))(
+        kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
             model_input)
     return {"predictions": output}
@@ -52,8 +51,12 @@ class MoeModel():
   def create_model(self,
                    model_input,
                    vocab_size,
-                   num_mixtures=None,
-                   l2_penalty=1e-8):
+                   num_mixtures: int = 2,
+                   use_input_context_gate: bool = False,
+                   use_output_context_gate: bool = False,
+                   normalizer_fn=None,
+                   normalizer_params: Optional[Dict[str, Any]] = None,
+                   l2_penalty: float = 1e-5):
     """Creates a Mixture of (Logistic) Experts model.

     The model consists of a per-class softmax distribution over a
@@ -64,6 +67,10 @@ class MoeModel():
       vocab_size: The number of classes in the dataset.
       num_mixtures: The number of mixtures (excluding a dummy 'expert' that
         always predicts the non-existence of an entity).
+      use_input_context_gate: if True apply context gate layer to the input.
+      use_output_context_gate: if True apply context gate layer to the output.
+      normalizer_fn: normalization op constructor (e.g. batch norm).
+      normalizer_params: parameters to the `normalizer_fn`.
       l2_penalty: How much to penalize the squared magnitudes of parameter
         values.
@@ -72,18 +79,23 @@ class MoeModel():
         of the model in the 'predictions' key. The dimensions of the tensor
         are batch_size x num_classes.
     """
-    num_mixtures = num_mixtures or moe_num_mixtures
+    if use_input_context_gate:
+      model_input = utils.context_gate(
+          model_input,
+          normalizer_fn=normalizer_fn,
+          normalizer_params=normalizer_params,
+      )

     gate_activations = layers.Dense(
         vocab_size * (num_mixtures + 1),
         activation=None,
         bias_initializer=None,
-        kernel_regularizer=regularizers.l2(l2_penalty))(
+        kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
            model_input)
     expert_activations = layers.Dense(
         vocab_size * num_mixtures,
         activation=None,
-        kernel_regularizer=regularizers.l2(l2_penalty))(
+        kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
            model_input)

     gating_distribution = tf.nn.softmax(
@@ -98,4 +110,10 @@ class MoeModel():
         gating_distribution[:, :num_mixtures] * expert_distribution, 1)
     final_probabilities = tf.reshape(final_probabilities_by_class_and_batch,
                                      [-1, vocab_size])
+    if use_output_context_gate:
+      final_probabilities = utils.context_gate(
+          final_probabilities,
+          normalizer_fn=normalizer_fn,
+          normalizer_params=normalizer_params,
+      )
     return {"predictions": final_probabilities}
@@ -13,6 +13,7 @@
 # limitations under the License.
 """YT8M model definition."""
+from typing import Optional

 import tensorflow as tf
 from official.modeling import tf_utils
@@ -23,23 +24,43 @@ from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as utils
 layers = tf.keras.layers


-class YT8MModel(tf.keras.Model):
-  """A YT8M model class builder."""
-
-  def __init__(self,
-               input_params: yt8m_cfg.YT8MModel,
-               num_frames=30,
-               num_classes=3862,
-               input_specs=layers.InputSpec(shape=[None, None, 1152]),
-               **kwargs):
+class DbofModel(tf.keras.Model):
+  """A YT8M model class builder.
+
+  Creates a Deep Bag of Frames model.
+  The model projects the features for each frame into a higher dimensional
+  'clustering' space, pools across frames in that space, and then
+  uses a configurable video-level model to classify the now aggregated features.
+  The model will randomly sample either frames or sequences of frames during
+  training to speed up convergence.
+  """
+
+  def __init__(
+      self,
+      params: yt8m_cfg.DbofModel,
+      num_frames=30,
+      num_classes=3862,
+      input_specs=layers.InputSpec(shape=[None, None, 1152]),
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      activation: str = "relu",
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
+      **kwargs):
     """YT8M initialization function.

     Args:
-      input_params: model configuration parameters
+      params: model configuration parameters
       num_frames: `int` number of frames in a single input.
       num_classes: `int` number of classes in dataset.
       input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
         [batch_size x num_frames x num_features]
+      kernel_regularizer: tf.keras.regularizers.Regularizer object. Default to
+        None.
+      activation: A `str` of name of the activation function.
+      use_sync_bn: If True, use synchronized batch normalization.
+      norm_momentum: A `float` of normalization momentum for the moving average.
+      norm_epsilon: A `float` added to variance to avoid dividing by zero.
       **kwargs: keyword arguments to be passed.
     """
@@ -48,12 +69,19 @@ class YT8MModel(tf.keras.Model):
         "input_specs": input_specs,
         "num_classes": num_classes,
         "num_frames": num_frames,
-        "input_params": input_params
+        "params": params
     }
     self._num_classes = num_classes
     self._input_specs = input_specs
-    self._act_fn = tf_utils.get_activation(input_params.activation)
-    self._is_training = input_params.is_training
+    self._act_fn = tf_utils.get_activation(activation)
+
+    if use_sync_bn:
+      self._norm = layers.experimental.SyncBatchNormalization
+    else:
+      self._norm = layers.BatchNormalization
+
+    if tf.keras.backend.image_data_format() == "channels_last":
+      bn_axis = -1
+    else:
+      bn_axis = 1

     # [batch_size x num_frames x num_features]
     feature_size = input_specs.shape[-1]
@@ -63,31 +91,34 @@ class YT8MModel(tf.keras.Model):
     tf.summary.histogram("input_hist", model_input)

     # configure model
-    if input_params.add_batch_norm:
-      reshaped_input = layers.BatchNormalization(
-          name="input_bn", scale=True, center=True,
-          trainable=self._is_training)(
+    if params.add_batch_norm:
+      reshaped_input = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          name="input_bn")(
              reshaped_input)

     # activation = reshaped input * cluster weights
-    activation = layers.Dense(
-        input_params.cluster_size,
-        kernel_initializer=tf.random_normal_initializer(
-            stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
-                reshaped_input)
-    if input_params.add_batch_norm:
-      activation = layers.BatchNormalization(
-          name="cluster_bn",
-          scale=True,
-          center=True,
-          trainable=self._is_training)(
+    if params.cluster_size > 0:
+      activation = layers.Dense(
+          params.cluster_size,
+          kernel_regularizer=kernel_regularizer,
+          kernel_initializer=tf.random_normal_initializer(
+              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
+                  reshaped_input)
+    if params.add_batch_norm:
+      activation = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          name="cluster_bn")(
              activation)
     else:
       cluster_biases = tf.Variable(
           tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
-              shape=[input_params.cluster_size]),
+              shape=[params.cluster_size]),
           name="cluster_biases")
       tf.summary.histogram("cluster_biases", cluster_biases)
       activation += cluster_biases
@@ -95,30 +126,42 @@ class YT8MModel(tf.keras.Model):
     activation = self._act_fn(activation)
     tf.summary.histogram("cluster_output", activation)

-    activation = tf.reshape(activation,
-                            [-1, num_frames, input_params.cluster_size])
-    activation = utils.FramePooling(activation, input_params.pooling_method)
+    if params.use_context_gate_cluster_layer:
+      pooling_method = None
+      norm_args = dict(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          name="context_gate_bn")
+      activation = utils.context_gate(
+          activation,
+          normalizer_fn=self._norm,
+          normalizer_params=norm_args,
+          pooling_method=pooling_method,
+          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
+          kernel_regularizer=kernel_regularizer)
+
+    activation = tf.reshape(activation, [-1, num_frames, params.cluster_size])
+    activation = utils.frame_pooling(activation, params.pooling_method)

     # activation = activation * hidden1_weights
     activation = layers.Dense(
-        input_params.hidden_size,
+        params.hidden_size,
+        kernel_regularizer=kernel_regularizer,
         kernel_initializer=tf.random_normal_initializer(
-            stddev=1 /
-            tf.sqrt(tf.cast(input_params.cluster_size, tf.float32))))(
+            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
                activation)
-    if input_params.add_batch_norm:
-      activation = layers.BatchNormalization(
-          name="hidden1_bn",
-          scale=True,
-          center=True,
-          trainable=self._is_training)(
+    if params.add_batch_norm:
+      activation = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          name="hidden1_bn")(
              activation)
     else:
       hidden1_biases = tf.Variable(
-          tf.random_normal_initializer(stddev=0.01)(
-              shape=[input_params.hidden_size]),
+          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
          name="hidden1_biases")
       tf.summary.histogram("hidden1_biases", hidden1_biases)
@@ -128,9 +171,15 @@ class YT8MModel(tf.keras.Model):
     tf.summary.histogram("hidden1_output", activation)

     aggregated_model = getattr(yt8m_agg_models,
-                               input_params.yt8m_agg_classifier_model)
+                               params.yt8m_agg_classifier_model)
+    norm_args = dict(axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)
     output = aggregated_model().create_model(
-        model_input=activation, vocab_size=self._num_classes)
+        model_input=activation,
+        vocab_size=self._num_classes,
+        num_mixtures=params.agg_model.num_mixtures,
+        normalizer_fn=self._norm,
+        normalizer_params=norm_args,
+        l2_penalty=params.agg_model.l2_penalty)

     super().__init__(
         inputs=model_input, outputs=output.get("predictions"), **kwargs)
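Boiled down, the `DbofModel` built above follows the flow its docstring describes. A minimal functional sketch under the config defaults (cluster 3000, hidden 2000, 'average' pooling); batch norm, context gating, and regularizers are omitted, and the final head is left abstract:

```python
import tensorflow as tf

batch, num_frames, feature_size = 2, 30, 1152
frames = tf.random.normal([batch, num_frames, feature_size])

x = tf.reshape(frames, [-1, feature_size])             # [B*F, feature]
x = tf.keras.layers.Dense(3000, activation="relu")(x)  # cluster projection
x = tf.reshape(x, [-1, num_frames, 3000])
x = tf.reduce_mean(x, axis=1)                          # 'average' frame pooling
x = tf.keras.layers.Dense(2000, activation="relu")(x)  # hidden1
# A video-level head (e.g. MoeModel.create_model) then maps x to classes.
```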
@@ -37,8 +37,8 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
     input_specs = tf.keras.layers.InputSpec(shape=[num_frames, feature_dims])
     num_classes = 3862

-    model = yt8m_model.YT8MModel(
-        input_params=yt8m_cfg.YT8MTask.model,
+    model = yt8m_model.DbofModel(
+        params=yt8m_cfg.YT8MTask.model,
         num_frames=num_frames,
         num_classes=num_classes,
         input_specs=input_specs)
@@ -49,10 +49,10 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
     self.assertAllEqual([2, num_classes], logits.numpy().shape)

   def test_serialize_deserialize(self):
-    model = yt8m_model.YT8MModel(input_params=yt8m_cfg.YT8MTask.model)
+    model = yt8m_model.DbofModel(params=yt8m_cfg.YT8MTask.model)

     config = model.get_config()
-    new_model = yt8m_model.YT8MModel.from_config(config)
+    new_model = yt8m_model.DbofModel.from_config(config)

     # If the serialization was successful,
     # the new config should match the old.
@@ -13,10 +13,12 @@
 # limitations under the License.
 """Contains a collection of util functions for model construction."""
+from typing import Dict, Optional, Union, Any

 import tensorflow as tf


-def SampleRandomSequence(model_input, num_frames, num_samples):
+def sample_random_sequence(model_input, num_frames, num_samples):
   """Samples a random sequence of frames of size num_samples.

   Args:
@@ -44,7 +46,7 @@ def SampleRandomSequence(model_input, num_frames, num_samples):
   return tf.gather_nd(model_input, index)


-def SampleRandomFrames(model_input, num_frames, num_samples):
+def sample_random_frames(model_input, num_frames, num_samples):
   """Samples a random set of frames of size num_samples.

   Args:
@@ -66,7 +68,7 @@ def SampleRandomFrames(model_input, num_frames, num_samples):
   return tf.gather_nd(model_input, index)


-def FramePooling(frames, method):
+def frame_pooling(frames, method):
   """Pools over the frames of a video.

   Args:
@@ -93,3 +95,110 @@ def FramePooling(frames, method):
     raise ValueError("Unrecognized pooling method: %s" % method)
   return reduced
+
+
+def context_gate(
+    input_features,
+    normalizer_fn=None,
+    normalizer_params: Optional[Dict[str, Any]] = None,
+    kernel_initializer: Union[
+        str, tf.keras.regularizers.Regularizer] = "glorot_uniform",
+    kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+    bias_initializer: Union[str, tf.keras.regularizers.Regularizer] = "zeros",
+    hidden_layer_size: int = 0,
+    pooling_method: Optional[str] = None,
+    additive_residual: bool = False):
+  """Context Gating.
+
+  More details: https://arxiv.org/pdf/1706.06905.pdf.
+
+  Args:
+    input_features: a tensor of at least rank 2.
+    normalizer_fn: Normalization function to use instead of `biases` (e.g.
+      tf.contrib.layers.batch_norm). If None, bias is added.
+    normalizer_params: Normalization function parameters.
+    kernel_initializer: Weight initializer to use instead of Xavier (e.g.
+      tf.contrib.layers.variance_scaling_initializer).
+    kernel_regularizer: Weight regularizer to use instead of None (e.g.,
+      tf.contrib.layers.l2_regularizer(l2_penalty)).
+    bias_initializer: Biases initializer to use (default tf.zeros_initializer)
+    hidden_layer_size: Dimensionality of the context gating hidden layer size,
+      if any. If None, will apply a fully-connected context gating layer with
+      shape [input_size x input_size]. If set to an int N, will factorize the
+      context gating layer into [input_size x N] x [N x input_size] as in the
+      squeeze-and-excitation block from https://arxiv.org/pdf/1709.01507.pdf.
+    pooling_method: Whether to perform global pooling of the local features
+      before applying the context gating layer. This is relevant only if the
+      input_features tensor has rank > 2, e.g., it's a sequence of frame
+      features, [batch_size, num_frames, feature_dim], or spatial convolution
+      features, [batch_size*num_frames, h, w, feature_dim]. If the inputs are a
+      set of local features and pooling_method is not None, will pool features
+      across all but the batch_size dimension using the specified pooling
+      method, and pass the aggregated features as context to the gating layer.
+      For a list of pooling methods, see the frame_pooling() function.
+    additive_residual: If true, will use ReLu6-activated (additive) residual
+      connections instead of Sigmoid-activated (multiplicative) connections when
+      combining the input_features with the context gating branch.
+
+  Returns:
+    A tensor with the same shape as input_features.
+  """
+  if normalizer_params is None:
+    normalizer_params = {}
+  with tf.name_scope("ContextGating"):
+    num_dimensions = len(input_features.shape.as_list())
+    feature_size = input_features.shape.as_list()[-1]
+    if pooling_method:
+      assert num_dimensions > 2
+      # Collapse the inner axes of the original features shape into a 3D tensor
+      original_shape = tf.shape(input_features)
+      # The last dimension will change after concatenating the context
+      new_shape = tf.concat(
+          [original_shape[:-1],
+           tf.constant([2 * feature_size])], 0)
+      batch_size = original_shape[0]
+      reshaped_features = tf.reshape(input_features,
+                                     [batch_size, -1, feature_size])
+      num_features = tf.shape(reshaped_features)[1]
+      # Pool the feature channels across the inner axes to get global context
+      context_features = frame_pooling(reshaped_features, pooling_method)
+      context_features = tf.expand_dims(context_features, 1)
+      # Replicate the global context features and concat to the local features.
+      context_features = tf.tile(context_features, [1, num_features, 1])
+      context_features = tf.concat([reshaped_features, context_features], 2)
+      context_features = tf.reshape(context_features, shape=new_shape)
+    else:
+      context_features = input_features
+
+    if hidden_layer_size >= 2:
+      gates_bottleneck = tf.keras.layers.Dense(
+          hidden_layer_size,
+          activation="relu6",
+          kernel_initializer=kernel_initializer,
+          bias_initializer=bias_initializer,
+          kernel_regularizer=kernel_regularizer,
+      )(
+          context_features)
+      if normalizer_fn:
+        gates_bottleneck = normalizer_fn(**normalizer_params)(gates_bottleneck)
+    else:
+      gates_bottleneck = context_features
+
+    activation_fn = (tf.nn.relu6 if additive_residual else tf.nn.sigmoid)
+    gates = tf.keras.layers.Dense(
+        feature_size,
+        activation=activation_fn,
+        kernel_initializer=kernel_initializer,
+        bias_initializer=bias_initializer,
+        kernel_regularizer=kernel_regularizer,
+    )(
+        gates_bottleneck)
+    if normalizer_fn:
+      gates = normalizer_fn(**normalizer_params)(gates)
+
+    if additive_residual:
+      input_features += gates
+    else:
+      input_features *= gates
+
+    return input_features
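A small usage sketch of the new helper: gating per-frame features against an average-pooled global context through a factorized (squeeze-and-excitation style) bottleneck. Shapes and hyperparameters are illustrative, not from the commit:

```python
import tensorflow as tf

frames = tf.random.normal([4, 30, 1152])  # [batch, num_frames, feature_dim]
gated = context_gate(
    frames,
    normalizer_fn=tf.keras.layers.BatchNormalization,
    normalizer_params={"axis": -1, "momentum": 0.99, "epsilon": 1e-3},
    hidden_layer_size=128,          # factorized [1152 x 128] x [128 x 1152]
    pooling_method="average")       # global context pooled across frames
assert gated.shape == frames.shape  # multiplicative (sigmoid) gating
```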