yt8m_model.py 6.88 KB
Newer Older
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Hye Yoon's avatar
Hye Yoon committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

Hye Yoon's avatar
Hye Yoon committed
15
"""YT8M model definition."""
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
from typing import Optional
Hye Yoon's avatar
Hye Yoon committed
17
18
19

import tensorflow as tf
from official.modeling import tf_utils
Yeqing Li's avatar
Yeqing Li committed
20
21
22
from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.modeling import yt8m_agg_models
from official.projects.yt8m.modeling import yt8m_model_utils as utils
Hye Yoon's avatar
Hye Yoon committed
23
24
25
26

# Module-level alias so layer classes (e.g. `layers.InputSpec`,
# `layers.Dense`, `layers.BatchNormalization`) can be referenced concisely.
layers = tf.keras.layers


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class DbofModel(tf.keras.Model):
  """A YT8M model class builder.

  Creates a Deep Bag of Frames model.
  The model projects the features for each frame into a higher dimensional
  'clustering' space, pools across frames in that space, and then
  uses a configurable video-level model to classify the now aggregated features.
  The model will randomly sample either frames or sequences of frames during
  training to speed up convergence.
  """

  def __init__(
      self,
      params: yt8m_cfg.DbofModel,
      num_frames: int = 30,
      num_classes: int = 3862,
      input_specs: layers.InputSpec = layers.InputSpec(
          shape=[None, None, 1152]),
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      activation: str = "relu",
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      **kwargs):
    """YT8M initialization function.

    Args:
      params: model configuration parameters.
      num_frames: `int` number of frames in a single input.
      num_classes: `int` number of classes in dataset.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
        [batch_size x num_frames x num_features]
      kernel_regularizer: tf.keras.regularizers.Regularizer object. Default to
        None.
      activation: A `str` of name of the activation function.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      **kwargs: keyword arguments to be passed to `tf.keras.Model.__init__`.
    """
    # Disable Keras automatic attribute tracking: the layer objects created
    # below are already captured by the functional graph handed to
    # `super().__init__`, so tracking them again via attribute assignment is
    # unnecessary.
    self._self_setattr_tracking = False
    self._config_dict = {
        "input_specs": input_specs,
        "num_classes": num_classes,
        "num_frames": num_frames,
        "params": params
    }
    self._num_classes = num_classes
    self._input_specs = input_specs
    self._act_fn = tf_utils.get_activation(activation)
    # Pick the normalization layer class once; instantiated several times
    # below with different names.
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    # Channel axis for batch norm depends on the Keras image data format.
    if tf.keras.backend.image_data_format() == "channels_last":
      bn_axis = -1
    else:
      bn_axis = 1

    # [batch_size x num_frames x num_features]
    feature_size = input_specs.shape[-1]
    # shape 'excluding' batch_size
    model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
    # Fold the frame dimension into the batch so the per-frame dense
    # projections below operate on [batch_size * num_frames, feature_size].
    reshaped_input = tf.reshape(model_input, [-1, feature_size])
    tf.summary.histogram("input_hist", model_input)

    # Optionally normalize raw per-frame features before projection.
    if params.add_batch_norm:
      reshaped_input = self._norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="input_bn")(
              reshaped_input)

    # Project each frame into the 'clustering' space:
    # activation = reshaped input * cluster weights.
    # NOTE(review): if `params.cluster_size <= 0`, no projection is built and
    # `activation` is undefined below (NameError); the later reshape to
    # cluster_size would fail anyway, so a positive cluster_size is assumed —
    # confirm against config validation.
    if params.cluster_size > 0:
      activation = layers.Dense(
          params.cluster_size,
          kernel_regularizer=kernel_regularizer,
          kernel_initializer=tf.random_normal_initializer(
              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
                  reshaped_input)

    if params.add_batch_norm:
      activation = self._norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="cluster_bn")(
              activation)
    else:
      # Without batch norm, add an explicit learned bias instead.
      # Fix: `feature_size` is a Python int and `tf.math.sqrt` rejects
      # integer tensors, so cast to float32 first — consistent with the
      # other initializers in this constructor.
      cluster_biases = tf.Variable(
          tf.random_normal_initializer(
              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32)))(
                  shape=[params.cluster_size]),
          name="cluster_biases")
      tf.summary.histogram("cluster_biases", cluster_biases)
      activation += cluster_biases

    activation = self._act_fn(activation)
    tf.summary.histogram("cluster_output", activation)

    # Optional context gating applied in the cluster space.
    if params.use_context_gate_cluster_layer:
      pooling_method = None
      norm_args = dict(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="context_gate_bn")
      activation = utils.context_gate(
          activation,
          normalizer_fn=self._norm,
          normalizer_params=norm_args,
          pooling_method=pooling_method,
          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
          kernel_regularizer=kernel_regularizer)
    # Restore the frame dimension, then pool across frames to get one
    # video-level vector per example.
    activation = tf.reshape(activation, [-1, num_frames, params.cluster_size])
    activation = utils.frame_pooling(activation, params.pooling_method)

    # Hidden projection: activation = activation * hidden1_weights.
    activation = layers.Dense(
        params.hidden_size,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer=tf.random_normal_initializer(
            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
                activation)

    if params.add_batch_norm:
      activation = self._norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="hidden1_bn")(
              activation)
    else:
      hidden1_biases = tf.Variable(
          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
          name="hidden1_biases")
      tf.summary.histogram("hidden1_biases", hidden1_biases)
      activation += hidden1_biases

    activation = self._act_fn(activation)
    tf.summary.histogram("hidden1_output", activation)

    # Look up the configured video-level classifier class by name and build
    # the final prediction head.
    aggregated_model = getattr(yt8m_agg_models,
                               params.yt8m_agg_classifier_model)
    norm_args = dict(axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)
    output = aggregated_model().create_model(
        model_input=activation,
        vocab_size=self._num_classes,
        num_mixtures=params.agg_model.num_mixtures,
        normalizer_fn=self._norm,
        normalizer_params=norm_args,
        l2_penalty=params.agg_model.l2_penalty)

    # Finalize the functional model from the symbolic input/output tensors.
    super().__init__(
        inputs=model_input, outputs=output.get("predictions"), **kwargs)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    return dict()

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this model."""
    return self._config_dict

  @classmethod
  def from_config(cls, config):
    """Creates a model instance from its `get_config` dictionary."""
    return cls(**config)