yt8m_model.py
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""YT8M model definition."""
from typing import Optional

import tensorflow as tf

from official.modeling import tf_utils
from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.modeling import nn_layers
from official.projects.yt8m.modeling import yt8m_model_utils as utils

layers = tf.keras.layers


class DbofModel(tf.keras.Model):
  """A YT8M model class builder.

  Creates a Deep Bag of Frames model.
  The model projects the features for each frame into a higher dimensional
  'clustering' space, pools across frames in that space, and then
  uses a configurable video-level model to classify the now aggregated features.
  The model will randomly sample either frames or sequences of frames during
  training to speed up convergence.
  """

  def __init__(
      self,
      params: yt8m_cfg.DbofModel,
      num_frames: int = 30,
      num_classes: int = 3862,
      input_specs: layers.InputSpec = layers.InputSpec(
          shape=[None, None, 1152]),
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      activation: str = "relu",
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      **kwargs):
    """YT8M initialization function.

    Args:
      params: model configuration parameters.
      num_frames: `int` number of frames in a single input.
      num_classes: `int` number of classes in dataset.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
        [batch_size x num_frames x num_features]
      kernel_regularizer: tf.keras.regularizers.Regularizer object. Default to
        None.
      activation: A `str` of name of the activation function.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      **kwargs: keyword arguments to be passed.
    """
    del num_frames
    self._self_setattr_tracking = False
    self._config_dict = {
        "input_specs": input_specs,
        "num_classes": num_classes,
        "params": params
    }
    self._num_classes = num_classes
    self._input_specs = input_specs
    self._act_fn = tf_utils.get_activation(activation)
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization

    bn_axis = -1
    # [batch_size x num_frames x num_features]
    feature_size = input_specs.shape[-1]
    # shape 'excluding' batch_size
    model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
    # normalize input features
    input_data = tf.nn.l2_normalize(model_input, -1)
    tf.summary.histogram("input_hist", input_data)

    # configure model
    if params.add_batch_norm:
      input_data = self._norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="input_bn")(
              input_data)

    # activation = reshaped input * cluster weights
    if params.cluster_size > 0:
      activation = layers.Dense(
          params.cluster_size,
          kernel_regularizer=kernel_regularizer,
          kernel_initializer=tf.random_normal_initializer(
              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
                  input_data)

    if params.add_batch_norm:
      activation = self._norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="cluster_bn")(
              activation)
    else:
      cluster_biases = tf.Variable(
          tf.random_normal_initializer(
              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32)))(
              shape=[params.cluster_size]),
          name="cluster_biases")
      tf.summary.histogram("cluster_biases", cluster_biases)
      activation += cluster_biases

    activation = self._act_fn(activation)
    tf.summary.histogram("cluster_output", activation)

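    # Optionally refine the cluster activations with a context gating layer.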
    if params.use_context_gate_cluster_layer:
      pooling_method = None
      norm_args = dict(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="context_gate_bn")
      activation = utils.context_gate(
          activation,
          normalizer_fn=self._norm,
          normalizer_params=norm_args,
          pooling_method=pooling_method,
          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
          kernel_regularizer=kernel_regularizer)

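    # Pool activations across frames into a video-level representation.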
    activation = utils.frame_pooling(activation, params.pooling_method)

    # activation = activation * hidden1_weights
    activation = layers.Dense(
        params.hidden_size,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer=tf.random_normal_initializer(
            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
                activation)

    if params.add_batch_norm:
      activation = self._norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="hidden1_bn")(
              activation)

    else:
      hidden1_biases = tf.Variable(
          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
          name="hidden1_biases")

      tf.summary.histogram("hidden1_biases", hidden1_biases)
      activation += hidden1_biases

    activation = self._act_fn(activation)
    tf.summary.histogram("hidden1_output", activation)

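    # Classify the pooled features with the configured video-level model.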
    aggregated_model = getattr(nn_layers,
                               params.yt8m_agg_classifier_model)
    norm_args = dict(axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)
    output = aggregated_model().create_model(
        model_input=activation,
        vocab_size=self._num_classes,
        num_mixtures=params.agg_model.num_mixtures,
        normalizer_fn=self._norm,
        normalizer_params=norm_args,
        l2_penalty=params.agg_model.l2_penalty)

    super().__init__(
        inputs=model_input, outputs=output.get("predictions"), **kwargs)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    return dict()

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config):
    return cls(**config)
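

# Illustrative usage sketch (added for clarity; not part of the upstream
# module). It assumes the `official` package is importable and that
# `yt8m_cfg.DbofModel()` can be constructed with its default field values.
if __name__ == "__main__":
  # Build a Deep Bag-of-Frames model from the default configuration.
  dbof_params = yt8m_cfg.DbofModel()
  model = DbofModel(
      params=dbof_params,
      num_classes=3862,
      input_specs=layers.InputSpec(shape=[None, None, 1152]))
  # Forward pass on dummy frame-level features: 2 videos, 30 frames each,
  # 1152-dimensional per-frame features (1024 visual + 128 audio).
  dummy_frames = tf.random.uniform([2, 30, 1152])
  predictions = model(dummy_frames)
  print(predictions.shape)  # Expected: (2, 3862) video-level class scores.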