Commit e9355843 authored by A. Unique TensorFlower, committed by saberkun

Internal change

PiperOrigin-RevId: 401839863
parent 8bfa4d03
@@ -19,10 +19,10 @@
 import tensorflow as tf
 from official.modeling import hyperparams
+from official.projects.volumetric_models.modeling.decoders import factory as decoder_factory
+from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d
 from official.vision.beta.modeling import segmentation_model
 from official.vision.beta.modeling.backbones import factory as backbone_factory
-from official.vision.beta.projects.volumetric_models.modeling.decoders import factory as decoder_factory
-from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d


 def build_segmentation_model_3d(
...
@@ -18,10 +18,10 @@ from absl.testing import parameterized
 import tensorflow as tf

 # pylint: disable=unused-import
-from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.modeling import factory
+from official.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.modeling import factory


 class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -19,7 +19,7 @@ from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
-from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d
+from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d


 class SegmentationHead3DTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -19,7 +19,7 @@
 from absl.testing import parameterized
 import tensorflow as tf
-from official.vision.beta.projects.volumetric_models.modeling import nn_blocks_3d
+from official.projects.volumetric_models.modeling import nn_blocks_3d


 class NNBlocks3DTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -18,11 +18,10 @@
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d
 from official.vision.beta.modeling import segmentation_model
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d


 class SegmentationNetworkUNet3DTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -16,7 +16,7 @@
 # pylint: disable=unused-import
 from official.common import registry_imports
-from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as semantic_segmentation_3d_cfg
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d
+from official.projects.volumetric_models.configs import semantic_segmentation_3d as semantic_segmentation_3d_cfg
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.tasks import semantic_segmentation_3d
@@ -42,7 +42,7 @@ from absl import flags
 from official.common import registry_imports  # pylint: disable=unused-import
 from official.core import exp_factory
 from official.modeling import hyperparams
-from official.vision.beta.projects.volumetric_models.serving import semantic_segmentation_3d
+from official.projects.volumetric_models.serving import semantic_segmentation_3d
 from official.vision.beta.serving import export_saved_model_lib

 FLAGS = flags.FLAGS
...
@@ -19,9 +19,9 @@ from typing import Mapping
 import tensorflow as tf

 # pylint: disable=unused-import
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.modeling import factory
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.modeling import factory
 from official.vision.beta.serving import export_base
...
@@ -22,10 +22,10 @@ import tensorflow as tf
 # pylint: disable=unused-import
 from official.core import exp_factory
-from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.serving import semantic_segmentation_3d
+from official.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.serving import semantic_segmentation_3d


 class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -23,11 +23,11 @@ from official.common import dataset_fn
 from official.core import base_task
 from official.core import input_reader
 from official.core import task_factory
-from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
-from official.vision.beta.projects.volumetric_models.dataloaders import segmentation_input_3d
-from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics
-from official.vision.beta.projects.volumetric_models.losses import segmentation_losses
-from official.vision.beta.projects.volumetric_models.modeling import factory
+from official.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
+from official.projects.volumetric_models.dataloaders import segmentation_input_3d
+from official.projects.volumetric_models.evaluation import segmentation_metrics
+from official.projects.volumetric_models.losses import segmentation_losses
+from official.projects.volumetric_models.modeling import factory


 @task_factory.register_task_cls(exp_cfg.SemanticSegmentation3DTask)
...
@@ -26,11 +26,11 @@ import tensorflow as tf
 from official.common import registry_imports  # pylint: disable=unused-import
 from official.core import exp_factory
 from official.modeling import optimization
+from official.projects.volumetric_models.evaluation import segmentation_metrics
+from official.projects.volumetric_models.modeling import backbones
+from official.projects.volumetric_models.modeling import decoders
+from official.projects.volumetric_models.tasks import semantic_segmentation_3d as img_seg_task
 from official.vision.beta.dataloaders import tfexample_utils
-from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics
-from official.vision.beta.projects.volumetric_models.modeling import backbones
-from official.vision.beta.projects.volumetric_models.modeling import decoders
-from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d as img_seg_task


 class SemanticSegmentationTaskTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -15,12 +15,11 @@
 """TensorFlow Model Garden Vision training driver."""
 from absl import app
 import gin  # pylint: disable=unused-import
 from official.common import flags as tfm_flags
+from official.projects.volumetric_models import registry_imports  # pylint: disable=unused-import
 from official.vision.beta import train
-from official.vision.beta.projects.volumetric_models import registry_imports  # pylint: disable=unused-import


 def main(_):
...
@@ -20,9 +20,8 @@ from absl import flags
 from absl import logging
 from absl.testing import flagsaver
 import tensorflow as tf
+from official.projects.volumetric_models import train as train_lib
 from official.vision.beta.dataloaders import tfexample_utils
-from official.vision.beta.projects.volumetric_models import train as train_lib

 FLAGS = flags.FLAGS
...
@@ -13,14 +13,15 @@
 # limitations under the License.

 """Video classification configuration definition."""
+import dataclasses
 from typing import Optional, Tuple

 from absl import flags
-import dataclasses

 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import hyperparams
 from official.modeling import optimization
+from official.vision.beta.configs import common

 FLAGS = flags.FLAGS
@@ -66,16 +67,28 @@ def yt8m(is_training):
 @dataclasses.dataclass
-class YT8MModel(hyperparams.Config):
+class MoeModel(hyperparams.Config):
   """The model config."""
-  cluster_size: int = 2048
-  hidden_size: int = 2048
+  num_mixtures: int = 5
+  l2_penalty: float = 1e-5
+  use_input_context_gate: bool = False
+  use_output_context_gate: bool = False
+
+
+@dataclasses.dataclass
+class DbofModel(hyperparams.Config):
+  """The model config."""
+  cluster_size: int = 3000
+  hidden_size: int = 2000
   add_batch_norm: bool = True
   sample_random_frames: bool = True
-  is_training: bool = True
-  activation: str = 'relu6'
+  use_context_gate_cluster_layer: bool = False
+  context_gate_cluster_bottleneck_size: int = 0
   pooling_method: str = 'average'
   yt8m_agg_classifier_model: str = 'MoeModel'
+  agg_model: hyperparams.Config = MoeModel()
+  norm_activation: common.NormActivation = common.NormActivation(
+      activation='relu', use_sync_bn=False)


 @dataclasses.dataclass
@@ -83,12 +96,13 @@ class Losses(hyperparams.Config):
   name: str = 'binary_crossentropy'
   from_logits: bool = False
   label_smoothing: float = 0.0
+  l2_weight_decay: float = 1e-5


 @dataclasses.dataclass
 class YT8MTask(cfg.TaskConfig):
   """The task config."""
-  model: YT8MModel = YT8MModel()
+  model: DbofModel = DbofModel()
   train_data: DataConfig = yt8m(is_training=True)
   validation_data: DataConfig = yt8m(is_training=False)
   losses: Losses = Losses()
@@ -102,8 +116,8 @@ def add_trainer(
     experiment: cfg.ExperimentConfig,
     train_batch_size: int,
     eval_batch_size: int,
-    learning_rate: float = 0.005,
-    train_epochs: int = 44,
+    learning_rate: float = 0.0001,
+    train_epochs: int = 50,
 ):
   """Add and config a trainer to the experiment config."""
   if YT8M_TRAIN_EXAMPLES <= 0:
@@ -115,13 +129,14 @@ def add_trainer(
   experiment.task.train_data.global_batch_size = train_batch_size
   experiment.task.validation_data.global_batch_size = eval_batch_size
   steps_per_epoch = YT8M_TRAIN_EXAMPLES // train_batch_size
+  steps_per_loop = 30
   experiment.trainer = cfg.TrainerConfig(
-      steps_per_loop=steps_per_epoch,
-      summary_interval=steps_per_epoch,
-      checkpoint_interval=steps_per_epoch,
+      steps_per_loop=steps_per_loop,
+      summary_interval=steps_per_loop,
+      checkpoint_interval=steps_per_loop,
       train_steps=train_epochs * steps_per_epoch,
       validation_steps=YT8M_VAL_EXAMPLES // eval_batch_size,
-      validation_interval=steps_per_epoch,
+      validation_interval=steps_per_loop,
       optimizer_config=optimization.OptimizationConfig({
           'optimizer': {
               'type': 'adam',
@@ -132,9 +147,18 @@ def add_trainer(
               'exponential': {
                   'initial_learning_rate': learning_rate,
                   'decay_rate': 0.95,
-                  'decay_steps': 1500000,
+                  'decay_steps': int(steps_per_epoch * 1.5),
+                  'offset': 500,
               }
           },
+          'warmup': {
+              'linear': {
+                  'name': 'linear',
+                  'warmup_learning_rate': 0,
+                  'warmup_steps': 500,
+              },
+              'type': 'linear',
+          }
       }))
   return experiment
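For orientation, the optimizer block above pairs a 500-step linear warmup with an exponential decay whose `offset` skips the warmup region. A rough sketch of the resulting schedule, assuming the usual semantics of the `exponential` decay and `linear` warmup configs in `official.modeling.optimization` (the `steps_per_epoch` value here is hypothetical):

```python
def configured_lr(step: int, base_lr: float = 0.0001,
                  steps_per_epoch: int = 7000) -> float:
  """Sketch of the configured schedule; semantics assumed, values illustrative."""
  warmup_steps, offset = 500, 500
  decay_steps, decay_rate = int(steps_per_epoch * 1.5), 0.95
  if step < warmup_steps:
    # Linear ramp from warmup_learning_rate=0 up to the base rate.
    return base_lr * step / warmup_steps
  # Exponential decay, shifted by `offset` so decay starts after warmup.
  return base_lr * decay_rate ** ((step - offset) / decay_steps)
```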
@@ -154,4 +178,17 @@ def yt8m_experiment() -> cfg.ExperimentConfig:
           'task.train_data.feature_names != None',
       ])
-  return add_trainer(exp_config, train_batch_size=512, eval_batch_size=512)
+  # Per-core batch size for a TPUv3 core with 16GB HBM; `factor` in range(1, 26).
+  factor = 1
+  num_cores = 32  # for a 4x4 TPU topology
+  train_per_core_bs = 32 * factor
+  train_bs = train_per_core_bs * num_cores
+  eval_per_core_bs = 32 * 50  # multiplier <= 100
+  eval_bs = eval_per_core_bs * num_cores
+  # Scale linearly from the base lr=0.0001 at batch size 512.
+  return add_trainer(
+      exp_config,
+      train_batch_size=train_bs,
+      eval_batch_size=eval_bs,
+      learning_rate=0.0001 * (train_bs / 512),
+      train_epochs=100)
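As a quick sanity check on the new defaults (a worked example, not part of the commit): with `factor = 1` and 32 cores, the experiment trains at a global batch size of 1024, and the linear scaling rule doubles the base learning rate:

```python
# Worked example of the batch-size and learning-rate arithmetic above.
factor, num_cores = 1, 32
train_bs = 32 * factor * num_cores  # 1024 examples per training step
eval_bs = 32 * 50 * num_cores       # 51200 examples per eval step
lr = 0.0001 * (train_bs / 512)      # 0.0002, scaled linearly from bs=512
print(train_bs, eval_bs, lr)        # 1024 51200 0.0002
```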
@@ -13,13 +13,12 @@
 # limitations under the License.

 """Contains model definitions."""
+from typing import Optional, Dict, Any

 import tensorflow as tf
+from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as utils

 layers = tf.keras.layers
-regularizers = tf.keras.regularizers

-# The number of mixtures (excluding the dummy 'expert') used for MoeModel.
-moe_num_mixtures = 2


 class LogisticModel():
@@ -41,7 +40,7 @@ class LogisticModel():
     output = layers.Dense(
         vocab_size,
         activation=tf.nn.sigmoid,
-        kernel_regularizer=regularizers.l2(l2_penalty))(
+        kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
            model_input)
     return {"predictions": output}
@@ -52,8 +51,12 @@ class MoeModel():
   def create_model(self,
                    model_input,
                    vocab_size,
-                   num_mixtures=None,
-                   l2_penalty=1e-8):
+                   num_mixtures: int = 2,
+                   use_input_context_gate: bool = False,
+                   use_output_context_gate: bool = False,
+                   normalizer_fn=None,
+                   normalizer_params: Optional[Dict[str, Any]] = None,
+                   l2_penalty: float = 1e-5):
     """Creates a Mixture of (Logistic) Experts model.

     The model consists of a per-class softmax distribution over a
@@ -64,6 +67,10 @@ class MoeModel():
       vocab_size: The number of classes in the dataset.
       num_mixtures: The number of mixtures (excluding a dummy 'expert' that
         always predicts the non-existence of an entity).
+      use_input_context_gate: if True, applies a context gate layer to the
+        input.
+      use_output_context_gate: if True, applies a context gate layer to the
+        output.
+      normalizer_fn: normalization op constructor (e.g. batch norm).
+      normalizer_params: parameters to the `normalizer_fn`.
       l2_penalty: How much to penalize the squared magnitudes of parameter
         values.
@@ -72,18 +79,23 @@ class MoeModel():
       of the model in the 'predictions' key. The dimensions of the tensor
       are batch_size x num_classes.
     """
-    num_mixtures = num_mixtures or moe_num_mixtures
+    if use_input_context_gate:
+      model_input = utils.context_gate(
+          model_input,
+          normalizer_fn=normalizer_fn,
+          normalizer_params=normalizer_params,
+      )

     gate_activations = layers.Dense(
         vocab_size * (num_mixtures + 1),
         activation=None,
         bias_initializer=None,
-        kernel_regularizer=regularizers.l2(l2_penalty))(
+        kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
            model_input)
     expert_activations = layers.Dense(
         vocab_size * num_mixtures,
         activation=None,
-        kernel_regularizer=regularizers.l2(l2_penalty))(
+        kernel_regularizer=tf.keras.regularizers.l2(l2_penalty))(
            model_input)

     gating_distribution = tf.nn.softmax(
@@ -98,4 +110,10 @@ class MoeModel():
         gating_distribution[:, :num_mixtures] * expert_distribution, 1)
     final_probabilities = tf.reshape(final_probabilities_by_class_and_batch,
                                      [-1, vocab_size])
+    if use_output_context_gate:
+      final_probabilities = utils.context_gate(
+          final_probabilities,
+          normalizer_fn=normalizer_fn,
+          normalizer_params=normalizer_params,
+      )
     return {"predictions": final_probabilities}
@@ -13,6 +13,7 @@
 # limitations under the License.

 """YT8M model definition."""
+from typing import Optional

 import tensorflow as tf
 from official.modeling import tf_utils
@@ -23,23 +24,43 @@ from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as util
 layers = tf.keras.layers


-class YT8MModel(tf.keras.Model):
-  """A YT8M model class builder."""
+class DbofModel(tf.keras.Model):
+  """A YT8M model class builder.
+
+  Creates a Deep Bag of Frames model. The model projects the features for
+  each frame into a higher-dimensional 'clustering' space, pools across
+  frames in that space, and then uses a configurable video-level model to
+  classify the now aggregated features. The model will randomly sample
+  either frames or sequences of frames during training to speed up
+  convergence.
+  """

-  def __init__(self,
-               input_params: yt8m_cfg.YT8MModel,
-               num_frames=30,
-               num_classes=3862,
-               input_specs=layers.InputSpec(shape=[None, None, 1152]),
-               **kwargs):
+  def __init__(
+      self,
+      params: yt8m_cfg.DbofModel,
+      num_frames=30,
+      num_classes=3862,
+      input_specs=layers.InputSpec(shape=[None, None, 1152]),
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      activation: str = "relu",
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
+      **kwargs):
     """YT8M initialization function.

     Args:
-      input_params: model configuration parameters
+      params: model configuration parameters
       num_frames: `int` number of frames in a single input.
       num_classes: `int` number of classes in dataset.
       input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
         [batch_size x num_frames x num_features]
+      kernel_regularizer: tf.keras.regularizers.Regularizer object. Defaults
+        to None.
+      activation: A `str` name of the activation function.
+      use_sync_bn: If True, use synchronized batch normalization.
+      norm_momentum: A `float` of normalization momentum for the moving
+        average.
+      norm_epsilon: A `float` added to variance to avoid dividing by zero.
       **kwargs: keyword arguments to be passed.
     """
@@ -48,12 +69,19 @@ class YT8MModel(tf.keras.Model):
         "input_specs": input_specs,
         "num_classes": num_classes,
         "num_frames": num_frames,
-        "input_params": input_params
+        "params": params
     }
     self._num_classes = num_classes
     self._input_specs = input_specs
-    self._act_fn = tf_utils.get_activation(input_params.activation)
-    self._is_training = input_params.is_training
+    self._act_fn = tf_utils.get_activation(activation)
+    if use_sync_bn:
+      self._norm = layers.experimental.SyncBatchNormalization
+    else:
+      self._norm = layers.BatchNormalization
+    if tf.keras.backend.image_data_format() == "channels_last":
+      bn_axis = -1
+    else:
+      bn_axis = 1

     # [batch_size x num_frames x num_features]
     feature_size = input_specs.shape[-1]
@@ -63,31 +91,34 @@
       tf.summary.histogram("input_hist", model_input)

     # configure model
-    if input_params.add_batch_norm:
-      reshaped_input = layers.BatchNormalization(
-          name="input_bn", scale=True, center=True,
-          trainable=self._is_training)(
+    if params.add_batch_norm:
+      reshaped_input = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          name="input_bn")(
              reshaped_input)

     # activation = reshaped input * cluster weights
-    activation = layers.Dense(
-        input_params.cluster_size,
-        kernel_initializer=tf.random_normal_initializer(
-            stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
-                reshaped_input)
-    if input_params.add_batch_norm:
-      activation = layers.BatchNormalization(
-          name="cluster_bn",
-          scale=True,
-          center=True,
-          trainable=self._is_training)(
+    if params.cluster_size > 0:
+      activation = layers.Dense(
+          params.cluster_size,
+          kernel_regularizer=kernel_regularizer,
+          kernel_initializer=tf.random_normal_initializer(
+              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
+                  reshaped_input)
+    if params.add_batch_norm:
+      activation = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          name="cluster_bn")(
              activation)
     else:
       cluster_biases = tf.Variable(
           tf.random_normal_initializer(stddev=1 / tf.math.sqrt(feature_size))(
-              shape=[input_params.cluster_size]),
+              shape=[params.cluster_size]),
           name="cluster_biases")
       tf.summary.histogram("cluster_biases", cluster_biases)
       activation += cluster_biases
@@ -95,30 +126,42 @@
     activation = self._act_fn(activation)
     tf.summary.histogram("cluster_output", activation)

-    activation = tf.reshape(activation,
-                            [-1, num_frames, input_params.cluster_size])
-    activation = utils.FramePooling(activation, input_params.pooling_method)
+    if params.use_context_gate_cluster_layer:
+      pooling_method = None
+      norm_args = dict(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          name="context_gate_bn")
+      activation = utils.context_gate(
+          activation,
+          normalizer_fn=self._norm,
+          normalizer_params=norm_args,
+          pooling_method=pooling_method,
+          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
+          kernel_regularizer=kernel_regularizer)
+
+    activation = tf.reshape(activation, [-1, num_frames, params.cluster_size])
+    activation = utils.frame_pooling(activation, params.pooling_method)

     # activation = activation * hidden1_weights
     activation = layers.Dense(
-        input_params.hidden_size,
+        params.hidden_size,
+        kernel_regularizer=kernel_regularizer,
         kernel_initializer=tf.random_normal_initializer(
-            stddev=1 /
-            tf.sqrt(tf.cast(input_params.cluster_size, tf.float32))))(
+            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
                activation)

-    if input_params.add_batch_norm:
-      activation = layers.BatchNormalization(
-          name="hidden1_bn",
-          scale=True,
-          center=True,
-          trainable=self._is_training)(
+    if params.add_batch_norm:
+      activation = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          name="hidden1_bn")(
              activation)
     else:
       hidden1_biases = tf.Variable(
-          tf.random_normal_initializer(stddev=0.01)(
-              shape=[input_params.hidden_size]),
+          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
           name="hidden1_biases")
       tf.summary.histogram("hidden1_biases", hidden1_biases)
@@ -128,9 +171,15 @@
     tf.summary.histogram("hidden1_output", activation)

     aggregated_model = getattr(yt8m_agg_models,
-                               input_params.yt8m_agg_classifier_model)
+                               params.yt8m_agg_classifier_model)
+    norm_args = dict(axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)
     output = aggregated_model().create_model(
-        model_input=activation, vocab_size=self._num_classes)
+        model_input=activation,
+        vocab_size=self._num_classes,
+        num_mixtures=params.agg_model.num_mixtures,
+        normalizer_fn=self._norm,
+        normalizer_params=norm_args,
+        l2_penalty=params.agg_model.l2_penalty)

     super().__init__(
         inputs=model_input, outputs=output.get("predictions"), **kwargs)
...
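Schematically, the DBoF pipeline the new docstring describes comes down to project, pool, classify. A shape-level sketch with illustrative layer choices (not the class itself):

```python
import tensorflow as tf

batch, num_frames, feature_size = 2, 30, 1152
cluster_size, hidden_size, num_classes = 3000, 2000, 3862

frames = tf.random.normal([batch, num_frames, feature_size])
# 1) Project every frame into the higher-dimensional 'clustering' space.
x = tf.reshape(frames, [-1, feature_size])            # [batch*frames, 1152]
x = tf.keras.layers.Dense(cluster_size, activation="relu")(x)
# 2) Pool across frames in that space (average pooling here).
x = tf.reshape(x, [-1, num_frames, cluster_size])
x = tf.reduce_mean(x, axis=1)                         # [batch, 3000]
# 3) Hand the aggregated features to a video-level classifier.
x = tf.keras.layers.Dense(hidden_size, activation="relu")(x)
probs = tf.keras.layers.Dense(num_classes, activation="sigmoid")(x)
```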
@@ -37,8 +37,8 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
     input_specs = tf.keras.layers.InputSpec(shape=[num_frames, feature_dims])

     num_classes = 3862
-    model = yt8m_model.YT8MModel(
-        input_params=yt8m_cfg.YT8MTask.model,
+    model = yt8m_model.DbofModel(
+        params=yt8m_cfg.YT8MTask.model,
         num_frames=num_frames,
         num_classes=num_classes,
         input_specs=input_specs)
@@ -49,10 +49,10 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
     self.assertAllEqual([2, num_classes], logits.numpy().shape)

   def test_serialize_deserialize(self):
-    model = yt8m_model.YT8MModel(input_params=yt8m_cfg.YT8MTask.model)
+    model = yt8m_model.DbofModel(params=yt8m_cfg.YT8MTask.model)
     config = model.get_config()
-    new_model = yt8m_model.YT8MModel.from_config(config)
+    new_model = yt8m_model.DbofModel.from_config(config)

     # If the serialization was successful,
     # the new config should match the old.
...
@@ -13,10 +13,12 @@
 # limitations under the License.

 """Contains a collection of util functions for model construction."""
+from typing import Dict, Optional, Union, Any

 import tensorflow as tf


-def SampleRandomSequence(model_input, num_frames, num_samples):
+def sample_random_sequence(model_input, num_frames, num_samples):
   """Samples a random sequence of frames of size num_samples.

   Args:
@@ -44,7 +46,7 @@ def SampleRandomSequence(model_input, num_frames, num_samples):
   return tf.gather_nd(model_input, index)


-def SampleRandomFrames(model_input, num_frames, num_samples):
+def sample_random_frames(model_input, num_frames, num_samples):
   """Samples a random set of frames of size num_samples.

   Args:
@@ -66,7 +68,7 @@ def SampleRandomFrames(model_input, num_frames, num_samples):
   return tf.gather_nd(model_input, index)


-def FramePooling(frames, method):
+def frame_pooling(frames, method):
   """Pools over the frames of a video.

   Args:
@@ -93,3 +95,110 @@ def FramePooling(frames, method):
     raise ValueError("Unrecognized pooling method: %s" % method)
   return reduced
+
+
+def context_gate(
+    input_features,
+    normalizer_fn=None,
+    normalizer_params: Optional[Dict[str, Any]] = None,
+    kernel_initializer: Union[
+        str, tf.keras.initializers.Initializer] = "glorot_uniform",
+    kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+    bias_initializer: Union[str, tf.keras.initializers.Initializer] = "zeros",
+    hidden_layer_size: int = 0,
+    pooling_method: Optional[str] = None,
+    additive_residual: bool = False):
+  """Context Gating.
+
+  More details: https://arxiv.org/pdf/1706.06905.pdf.
+
+  Args:
+    input_features: a tensor of at least rank 2.
+    normalizer_fn: normalization op constructor (e.g. batch norm). If None,
+      bias is added.
+    normalizer_params: parameters to the `normalizer_fn`.
+    kernel_initializer: weight initializer to use instead of Xavier (e.g.
+      tf.keras.initializers.VarianceScaling).
+    kernel_regularizer: weight regularizer to use instead of None (e.g.
+      tf.keras.regularizers.l2(l2_penalty)).
+    bias_initializer: bias initializer to use (default tf.zeros_initializer).
+    hidden_layer_size: dimensionality of the context gating hidden layer, if
+      any. If 0, applies a fully connected context gating layer with shape
+      [input_size x input_size]. If set to an int N, factorizes the context
+      gating layer into [input_size x N] x [N x input_size] as in the
+      squeeze-and-excitation block from https://arxiv.org/pdf/1709.01507.pdf.
+    pooling_method: whether to perform global pooling of the local features
+      before applying the context gating layer. This is relevant only if the
+      input_features tensor has rank > 2, e.g. a sequence of frame features
+      [batch_size, num_frames, feature_dim] or spatial convolution features
+      [batch_size*num_frames, h, w, feature_dim]. If the inputs are a set of
+      local features and pooling_method is not None, pools features across
+      all but the batch_size dimension using the specified pooling method and
+      passes the aggregated features as context to the gating layer. For a
+      list of pooling methods, see the frame_pooling() function.
+    additive_residual: if True, uses ReLU6-activated (additive) residual
+      connections instead of sigmoid-activated (multiplicative) connections
+      when combining the input_features with the context gating branch.
+
+  Returns:
+    A tensor with the same shape as input_features.
+  """
+  if normalizer_params is None:
+    normalizer_params = {}
+  with tf.name_scope("ContextGating"):
+    num_dimensions = len(input_features.shape.as_list())
+    feature_size = input_features.shape.as_list()[-1]
+    if pooling_method:
+      assert num_dimensions > 2
+      # Collapse the inner axes of the original feature shape into a 3D tensor.
+      original_shape = tf.shape(input_features)
+      # The last dimension will change after concatenating the context.
+      new_shape = tf.concat(
+          [original_shape[:-1], tf.constant([2 * feature_size])], 0)
+      batch_size = original_shape[0]
+      reshaped_features = tf.reshape(input_features,
+                                     [batch_size, -1, feature_size])
+      num_features = tf.shape(reshaped_features)[1]
+      # Pool the feature channels across the inner axes to get global context.
+      context_features = frame_pooling(reshaped_features, pooling_method)
+      context_features = tf.expand_dims(context_features, 1)
+      # Replicate the global context features and concat to the local features.
+      context_features = tf.tile(context_features, [1, num_features, 1])
+      context_features = tf.concat([reshaped_features, context_features], 2)
+      context_features = tf.reshape(context_features, shape=new_shape)
+    else:
+      context_features = input_features
+
+    if hidden_layer_size >= 2:
+      gates_bottleneck = tf.keras.layers.Dense(
+          hidden_layer_size,
+          activation="relu6",
+          kernel_initializer=kernel_initializer,
+          bias_initializer=bias_initializer,
+          kernel_regularizer=kernel_regularizer,
+      )(context_features)
+      if normalizer_fn:
+        gates_bottleneck = normalizer_fn(**normalizer_params)(gates_bottleneck)
+    else:
+      gates_bottleneck = context_features
+
+    activation_fn = tf.nn.relu6 if additive_residual else tf.nn.sigmoid
+    gates = tf.keras.layers.Dense(
+        feature_size,
+        activation=activation_fn,
+        kernel_initializer=kernel_initializer,
+        bias_initializer=bias_initializer,
+        kernel_regularizer=kernel_regularizer,
+    )(gates_bottleneck)
+    if normalizer_fn:
+      gates = normalizer_fn(**normalizer_params)(gates)
+
+    if additive_residual:
+      input_features += gates
+    else:
+      input_features *= gates
+
+    return input_features
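A usage sketch for the new helper, assuming the module path shown in the imports above (tensor values are illustrative):

```python
import tensorflow as tf
from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as utils

# Per-frame features: [batch_size, num_frames, feature_dim].
features = tf.random.normal([2, 30, 1152])

# Plain multiplicative gating with no pooling: gates computed per frame.
gated = utils.context_gate(features)

# Factorized (squeeze-and-excitation style) gate conditioned on the
# average-pooled global context of each video.
gated_global = utils.context_gate(
    features,
    hidden_layer_size=128,
    pooling_method="average",
    normalizer_fn=tf.keras.layers.BatchNormalization)
print(gated.shape, gated_global.shape)  # (2, 30, 1152) (2, 30, 1152)
```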