# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""YT8M model definition."""
from typing import Optional

import tensorflow as tf
from official.modeling import tf_utils
from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.modeling import yt8m_agg_models
from official.projects.yt8m.modeling import yt8m_model_utils as utils

layers = tf.keras.layers


class DbofModel(tf.keras.Model):
  """A YT8M model class builder.

  Creates a Deep Bag of Frames model.
  The model projects the features for each frame into a higher dimensional
  'clustering' space, pools across frames in that space, and then
  uses a configurable video-level model to classify the now aggregated features.
  The model will randomly sample either frames or sequences of frames during
  training to speed up convergence.
  """

  def __init__(
      self,
      params: yt8m_cfg.DbofModel,
      num_frames: int = 30,
      num_classes: int = 3862,
      input_specs: layers.InputSpec = layers.InputSpec(
          shape=[None, None, 1152]),
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      activation: str = "relu",
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      **kwargs):
    """YT8M initialization function.

    Args:
      params: Model configuration parameters (`yt8m_cfg.DbofModel`).
      num_frames: `int` number of frames in a single input.
      num_classes: `int` number of classes in the dataset.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor,
        shaped [batch_size x num_frames x num_features].
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object.
        Defaults to None.
      activation: A `str` name of the activation function.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving
        average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      **kwargs: keyword arguments to be passed.
    """

    self._self_setattr_tracking = False
    self._config_dict = {
        "input_specs": input_specs,
        "num_classes": num_classes,
        "num_frames": num_frames,
        "params": params
    }
    self._num_classes = num_classes
    self._input_specs = input_specs
    self._act_fn = tf_utils.get_activation(activation)
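    # Choose between synchronized (cross-replica) and regular batch
    # normalization, and infer the normalization axis from the data format.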
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    if tf.keras.backend.image_data_format() == "channels_last":
      bn_axis = -1
    else:
      bn_axis = 1

    # [batch_size x num_frames x num_features]
    feature_size = input_specs.shape[-1]
    # Shape excludes the batch dimension.
    model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
    reshaped_input = tf.reshape(model_input, [-1, feature_size])
    tf.summary.histogram("input_hist", model_input)

    # configure model
    if params.add_batch_norm:
      reshaped_input = self._norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="input_bn")(
              reshaped_input)

    # Project frame features into the cluster space:
    # activation = reshaped_input * cluster_weights
    if params.cluster_size > 0:
      activation = layers.Dense(
          params.cluster_size,
          kernel_regularizer=kernel_regularizer,
          kernel_initializer=tf.random_normal_initializer(
              stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
                  reshaped_input)

    if params.add_batch_norm:
      activation = self._norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="cluster_bn")(
              activation)
    else:
      cluster_biases = tf.Variable(
          tf.random_normal_initializer(
              stddev=1 / tf.math.sqrt(tf.cast(feature_size, tf.float32)))(
              shape=[params.cluster_size]),
          name="cluster_biases")
      tf.summary.histogram("cluster_biases", cluster_biases)
      activation += cluster_biases

    activation = self._act_fn(activation)
    tf.summary.histogram("cluster_output", activation)

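    # Optionally apply a context gate (a learned, elementwise gating with a
    # small bottleneck layer) to the cluster activations before pooling.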
    if params.use_context_gate_cluster_layer:
      pooling_method = None
      norm_args = dict(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="context_gate_bn")
      activation = utils.context_gate(
          activation,
          normalizer_fn=self._norm,
          normalizer_params=norm_args,
          pooling_method=pooling_method,
          hidden_layer_size=params.context_gate_cluster_bottleneck_size,
          kernel_regularizer=kernel_regularizer)
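    # Pool the per-frame cluster activations over the frame axis (using
    # `params.pooling_method`) to get a single video-level representation.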
    activation = tf.reshape(activation, [-1, num_frames, params.cluster_size])
    activation = utils.frame_pooling(activation, params.pooling_method)

    # Project pooled features into the hidden space:
    # activation = activation * hidden1_weights
    activation = layers.Dense(
        params.hidden_size,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer=tf.random_normal_initializer(
            stddev=1 / tf.sqrt(tf.cast(params.cluster_size, tf.float32))))(
                activation)

    if params.add_batch_norm:
      activation = self._norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          name="hidden1_bn")(
              activation)

    else:
      hidden1_biases = tf.Variable(
          tf.random_normal_initializer(stddev=0.01)(shape=[params.hidden_size]),
          name="hidden1_biases")

      tf.summary.histogram("hidden1_biases", hidden1_biases)
      activation += hidden1_biases

    activation = self._act_fn(activation)
    tf.summary.histogram("hidden1_output", activation)

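    # Classify the video-level representation with the configured aggregation
    # head from `yt8m_agg_models` (e.g. a mixture-of-experts classifier).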
    aggregated_model = getattr(yt8m_agg_models,
                               params.yt8m_agg_classifier_model)
    norm_args = dict(axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)
    output = aggregated_model().create_model(
        model_input=activation,
        vocab_size=self._num_classes,
        num_mixtures=params.agg_model.num_mixtures,
        normalizer_fn=self._norm,
        normalizer_params=norm_args,
        l2_penalty=params.agg_model.l2_penalty)

    super().__init__(
        inputs=model_input, outputs=output.get("predictions"), **kwargs)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    return dict()

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config):
    return cls(**config)
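

# A minimal usage sketch (not part of the original module): it assumes the
# default `yt8m_cfg.DbofModel()` config is valid for 1152-dimensional YT8M
# frame features, builds the model, and runs a forward pass on random inputs.
if __name__ == "__main__":
  model = DbofModel(
      params=yt8m_cfg.DbofModel(),
      num_frames=30,
      num_classes=3862,
      input_specs=layers.InputSpec(shape=[None, None, 1152]))
  dummy_frames = tf.random.uniform([2, 30, 1152])  # [batch, frames, features]
  predictions = model(dummy_frames)  # Expected shape: [2, 3862].
  print(predictions.shape)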