# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Video classification task definition."""

from typing import Dict, List, Optional, Tuple

from absl import logging
import tensorflow as tf

from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.modeling import tf_utils
from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.dataloaders import yt8m_input
from official.projects.yt8m.eval_utils import eval_util
from official.projects.yt8m.modeling import yt8m_model_utils as utils
from official.projects.yt8m.modeling.yt8m_model import DbofModel

@task_factory.register_task_cls(yt8m_cfg.YT8MTask)
class YT8MTask(base_task.Task):
  """A task for video classification."""

  def build_model(self):
    """Builds and returns the DBoF model for the YT8M task.

    Returns:
      A `DbofModel` configured from `self.task_config`.
    """
    data_cfg = self.task_config.train_data
    model_cfg = self.task_config.model
    norm_cfg = model_cfg.norm_activation

    # Frame-level input: [batch_size x num_frames x num_features], where the
    # feature axis is the concatenation of all configured feature groups.
    frame_input_shape = [None, sum(data_cfg.feature_sizes)]
    input_specs = tf.keras.layers.InputSpec(shape=[None] + frame_input_shape)
    logging.info('Build model input %r', frame_input_shape)

    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    decay = self.task_config.losses.l2_weight_decay
    l2_regularizer = tf.keras.regularizers.l2(decay / 2.0) if decay else None

    return DbofModel(
        params=model_cfg,
        input_specs=input_specs,
        num_frames=data_cfg.num_frames,
        num_classes=data_cfg.num_classes,
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        kernel_regularizer=l2_regularizer)

  def build_inputs(self, params: yt8m_cfg.DataConfig, input_context=None):
    """Builds the YT8M input pipeline.

    Args:
      params: configuration for input data.
      input_context: indicates information about the compute replicas and input
        pipelines.

    Returns:
      dataset: dataset fetched from reader.
    """
    decoder = yt8m_input.Decoder(input_params=params)
    parser = yt8m_input.Parser(input_params=params)
    postprocessor = yt8m_input.PostBatchProcessor(input_params=params)
    batcher = yt8m_input.TransformBatcher(input_params=params)

    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocessor.post_fn,
        transform_and_batch_fn=batcher.batch_fn)

    return reader.read(input_context=input_context)
  def build_losses(self,
                   labels,
                   model_outputs,
                   label_weights=None,
                   aux_losses=None):
    """Sigmoid Cross Entropy.

    Args:
      labels: tensor containing truth labels.
      model_outputs: output logits of the classifier.
      label_weights: optional tensor of label weights, broadcastable to the
        per-label loss.
      aux_losses: tensor containing auxiliary loss tensors, i.e. `losses` in
        keras.Model.

    Returns:
      A dict of tensors containing total loss and model loss tensors.
    """
    losses_config = self.task_config.losses
    # axis=None reduces the cross-entropy over all axes, leaving one value per
    # label element rather than per example, so label weighting below applies
    # element-wise.
    model_loss = tf.keras.losses.binary_crossentropy(
        labels,
        model_outputs,
        from_logits=losses_config.from_logits,
        label_smoothing=losses_config.label_smoothing,
        axis=None)

    if label_weights is None:
      model_loss = tf_utils.safe_mean(model_loss)
    else:
      model_loss = model_loss * label_weights
      # Manually compute weighted mean loss; divide_no_nan returns 0 when the
      # weight sum is 0 instead of producing NaN.
      total_loss = tf.reduce_sum(model_loss)
      total_weight = tf.cast(
          tf.reduce_sum(label_weights), dtype=total_loss.dtype)
      model_loss = tf.math.divide_no_nan(total_loss, total_weight)

    total_loss = model_loss
    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    return {'total_loss': total_loss, 'model_loss': model_loss}
  def build_metrics(self, training=True):
    """Gets streaming metrics for training/validation.

    Loss metrics (total/model loss) are always created. In eval mode an
    average-precision evaluator (mAP/gAP) is additionally built, configured
    with `top_k` (a positive integer specifying how many predictions are
    considered per video) and `top_n` (a positive integer specifying the
    average precision at n, or None to use all provided data points).

    Args:
      training: Bool value, true for training mode, false for eval/validation.

    Returns:
      A list of metrics to be used.
    """
    metrics = [
        tf.keras.metrics.Mean(name, dtype=tf.float32)
        for name in ('total_loss', 'model_loss')
    ]

    if not training:  # Cannot run in train step.
      self.avg_prec_metric = eval_util.EvaluationMetrics(
          self.task_config.validation_data.num_classes,
          top_k=self.task_config.top_k,
          top_n=self.task_config.top_n)

    return metrics
  def process_metrics(self,
                      metrics: List[tf.keras.metrics.Metric],
                      labels: tf.Tensor,
                      outputs: tf.Tensor,
                      model_losses: Optional[Dict[str, tf.Tensor]] = None,
                      label_weights: Optional[tf.Tensor] = None,
                      training: bool = True,
                      **kwargs) -> Dict[str, Tuple[tf.Tensor, ...]]:
    """Updates metrics.

    Args:
      metrics: Evaluation metrics to be updated.
      labels: A tensor containing truth labels.
      outputs: Model output logits of the classifier.
      model_losses: An optional dict of model losses.
      label_weights: Optional label weights, can be broadcast into shape of
        outputs/labels.
      training: Bool indicates if in training mode.
      **kwargs: Additional input arguments.

    Returns:
      Updated dict of metrics log.
    """
    model_losses = {} if model_losses is None else model_losses

    logs = {}
    if not training:
      # AP computation is deferred to aggregate_logs; pass the raw pair.
      logs[self.avg_prec_metric.name] = (labels, outputs)

    for metric in metrics:
      metric.update_state(model_losses[metric.name])
      logs[metric.name] = metric.result()
    return logs
  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors. output_dict = { "video_ids":
        batch_video_ids, "video_matrix": batch_video_matrix, "labels":
        batch_labels, "num_frames": batch_frames, }
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      a dictionary of logs.
    """
    features, labels = inputs['video_matrix'], inputs['labels']
    num_frames = inputs['num_frames']
    # Per-label weights are optional; build_losses falls back to a plain mean
    # when they are absent.
    label_weights = inputs.get('label_weights', None)

    # Sample random frames / random sequence down to the configured number of
    # frames per example. num_frames is cast to float32 for the samplers.
    num_frames = tf.cast(num_frames, tf.float32)
    sample_frames = self.task_config.train_data.num_frames
    if self.task_config.model.sample_random_frames:
      features = utils.sample_random_frames(features, num_frames, sample_frames)
    else:
      features = utils.sample_random_sequence(features, num_frames,
                                              sample_frames)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
      # Computes per-replica loss.
      all_losses = self.build_losses(
          model_outputs=outputs,
          labels=labels,
          label_weights=label_weights,
          aux_losses=model.losses)

      loss = all_losses['total_loss']
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)

    # Apply gradient clipping when a positive clip norm is configured.
    if self.task_config.gradient_clip_norm > 0:
      grads, _ = tf.clip_by_global_norm(grads,
                                        self.task_config.gradient_clip_norm)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    # self.loss is the canonical loss key provided by the base Task.
    logs = {self.loss: loss}

    logs.update(
        self.process_metrics(
            metrics,
            labels=labels,
            outputs=outputs,
            model_losses=all_losses,
            label_weights=label_weights,
            training=True))

    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors. output_dict = { "video_ids":
        batch_video_ids, "video_matrix": batch_video_matrix, "labels":
        batch_labels, "num_frames": batch_frames, }
      model: the model, forward definition.
      metrics: a nested structure of metrics objects.

    Returns:
      a dictionary of logs.
    """
    features, labels = inputs['video_matrix'], inputs['labels']
    num_frames = inputs['num_frames']
    label_weights = inputs.get('label_weights', None)

    # Sample random frames, e.g. (None, 5, 1152) -> (None, 30, 1152).
    # NOTE(review): unlike train_step, num_frames is not cast to float32 here
    # before being passed to the samplers — confirm the samplers accept the
    # raw dtype.
    sample_frames = self.task_config.validation_data.num_frames
    if self.task_config.model.sample_random_frames:
      features = utils.sample_random_frames(features, num_frames, sample_frames)
    else:
      features = utils.sample_random_sequence(features, num_frames,
                                              sample_frames)

    outputs = self.inference_step(features, model)
    # Cast to float32 so downstream loss/metric math is unaffected by a
    # mixed-precision policy.
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    if self.task_config.validation_data.segment_labels:
      # Workaround to ignore the unrated labels by zeroing their predictions.
      # NOTE(review): this path assumes 'label_weights' is present in inputs
      # when segment_labels is enabled — confirm against the data loader.
      outputs *= label_weights
      # Remove padding: rows whose labels are all -1 are padding entries.
      outputs = outputs[~tf.reduce_all(labels == -1, axis=1)]
      labels = labels[~tf.reduce_all(labels == -1, axis=1)]

    all_losses = self.build_losses(
        labels=labels,
        model_outputs=outputs,
        label_weights=label_weights,
        aux_losses=model.losses)

    logs = {self.loss: all_losses['total_loss']}

    logs.update(
        self.process_metrics(
            metrics,
            labels=labels,
            outputs=outputs,
            model_losses=all_losses,
            label_weights=inputs.get('label_weights', None),
            training=False))

    return logs

  def inference_step(self, inputs, model):
    """Performs the forward step (model called with training=False)."""
    return model(inputs, training=False)

  def aggregate_logs(self, state=None, step_logs=None):
    """Accumulates per-step (labels, predictions) into the AP evaluator."""
    metric = self.avg_prec_metric
    if state is None:
      state = metric
    # process_metrics stored the raw (labels, predictions) pair under the
    # evaluator's name; feed it into the streaming evaluator.
    step_labels, step_predictions = step_logs[metric.name]
    metric.accumulate(labels=step_labels, predictions=step_predictions)
    return state
  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Computes final mAP/gAP metrics and resets the evaluator's state."""
    final_metrics = self.avg_prec_metric.get()
    self.avg_prec_metric.clear()
    return final_metrics