# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Video classification task definition."""
from typing import Any, Optional, List, Tuple

from absl import logging
import tensorflow as tf
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import video_input
from official.vision.beta.modeling import factory_3d


@task_factory.register_task_cls(exp_cfg.VideoClassificationTask)
class VideoClassificationTask(base_task.Task):
  """A task for video classification.

  Builds a 3D video classification model, its input pipeline, losses and
  metrics from `self.task_config`, and implements the custom-training-loop
  train, validation and inference steps.
  """

  def _is_multilabel(self) -> bool:
    """Returns whether the training data is configured as multi-label."""
    return self.task_config.train_data.is_multilabel

  def _apply_output_activation(self, logits: tf.Tensor) -> tf.Tensor:
    """Maps raw model logits to probabilities.

    Multi-label tasks use an element-wise sigmoid (independent per-class
    probabilities); single-label tasks use a softmax over the last axis.

    Args:
      logits: raw model outputs.

    Returns:
      A tensor of probabilities with the same shape as `logits`.
    """
    if self._is_multilabel():
      return tf.math.sigmoid(logits)
    return tf.math.softmax(logits)

  def _update_metric_logs(self, logs, labels, outputs, model, metrics):
    """Updates the given (or compiled) metrics and adds their results to logs.

    Args:
      logs: dict of logs to extend in place.
      labels: ground-truth labels of the batch.
      outputs: model probabilities of the batch.
      model: the keras.Model; its compiled metrics are used when `metrics`
        is empty.
      metrics: optional list of metric objects from `build_metrics`.

    Returns:
      `logs`, extended with `{metric.name: metric.result()}` entries.
    """
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def build_model(self):
    """Builds video classification model.

    The input spec keeps each feature dimension on which the train and
    validation `feature_shape`s agree and leaves the others, as well as the
    leading batch dimension, as `None`.

    Returns:
      The model created by `factory_3d.build_model`.
    """
    common_input_shape = [
        d1 if d1 == d2 else None
        for d1, d2 in zip(self.task_config.train_data.feature_shape,
                          self.task_config.validation_data.feature_shape)
    ]
    input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
    logging.info('Build model input %r', common_input_shape)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory_3d.build_model(
        self.task_config.model.model_type,
        input_specs=input_specs,
        model_config=self.task_config.model,
        num_classes=self.task_config.train_data.num_classes,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads pretrained checkpoint.

    Does nothing when `init_checkpoint` is empty. When it names a directory,
    the latest checkpoint inside that directory is restored.

    Args:
      model: the model to restore into. With `init_checkpoint_modules='all'`
        all of `model.checkpoint_items` are restored and checked with
        `assert_consumed()`; with `'backbone'` only `model.backbone` is
        restored and checked with
        `expect_partial().assert_existing_objects_matched()`.

    Raises:
      ValueError: if `init_checkpoint_modules` is neither 'all' nor
        'backbone'.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    elif self.task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def _get_dataset_fn(self, params):
    """Returns the `tf.data` dataset class for `params.file_type`.

    Raises:
      ValueError: if `params.file_type` is not 'tfrecord'.
    """
    if params.file_type == 'tfrecord':
      return tf.data.TFRecordDataset
    raise ValueError('Unknown input file type {!r}'.format(params.file_type))

  def _get_decoder_fn(self, params):
    """Returns the decode function for the dataset described by `params`.

    A TFDS decoder is used when `params.tfds_name` is set, otherwise the
    default `video_input.Decoder`.

    Raises:
      ValueError: if audio output is enabled in `train_data` but no audio
        feature name is configured.
    """
    if params.tfds_name:
      decoder = video_input.VideoTfdsDecoder(
          image_key=params.image_field_key, label_key=params.label_field_key)
    else:
      decoder = video_input.Decoder(
          image_key=params.image_field_key, label_key=params.label_field_key)
    # NOTE: the audio settings are read from `train_data` (not `params`), so
    # every reader built by this task decodes the same audio feature.
    train_data = self.task_config.train_data
    if train_data.output_audio:
      # Explicit raise rather than `assert`, which is stripped under `-O`.
      if not train_data.audio_feature:
        raise ValueError('audio feature is empty')
      decoder.add_feature(train_data.audio_feature,
                          tf.io.VarLenFeature(dtype=tf.float32))
    return decoder.decode

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds classification input.

    Args:
      params: the data config to read.
      input_context: optional distribution input context passed to the
        reader.

    Returns:
      The dataset produced by the input reader.
    """
    parser = video_input.Parser(
        input_params=params,
        image_key=params.image_field_key,
        label_key=params.label_field_key)
    postprocess_fn = video_input.PostBatchProcessor(params)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=self._get_dataset_fn(params),
        decoder_fn=self._get_decoder_fn(params),
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   labels: Any,
                   model_outputs: Any,
                   aux_losses: Optional[Any] = None):
    """Builds the classification loss plus any auxiliary losses.

    Multi-label tasks use binary cross entropy and also report the mean
    output entropy; single-label tasks use categorical (with label
    smoothing) or sparse categorical cross entropy depending on
    `losses.one_hot`. In every case the per-example loss is reduced to a
    scalar with `tf_utils.safe_mean`.

    Args:
      labels: labels.
      model_outputs: class probabilities, i.e. model logits after sigmoid or
        softmax; the losses are computed with `from_logits=False`.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      A dict with 'class_loss', 'entropy' (multi-label only), 'reg_loss'
      (only when `aux_losses` is non-empty) and the total loss under the key
      `self.loss`.
    """
    all_losses = {}
    losses_config = self.task_config.losses
    is_multilabel = self._is_multilabel()
    if is_multilabel:
      entropy = -tf.reduce_mean(
          tf.reduce_sum(model_outputs * tf.math.log(model_outputs + 1e-8), -1))
      per_example_loss = tf.keras.losses.binary_crossentropy(
          labels, model_outputs, from_logits=False)
    elif losses_config.one_hot:
      per_example_loss = tf.keras.losses.categorical_crossentropy(
          labels,
          model_outputs,
          from_logits=False,
          label_smoothing=losses_config.label_smoothing)
    else:
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, model_outputs, from_logits=False)

    # Keras loss functions return one value per example. Reduce in every
    # branch: an un-reduced loss would make the gradient a sum over the batch
    # and would broadcast-add the auxiliary losses below to every example.
    total_loss = tf_utils.safe_mean(per_example_loss)
    all_losses['class_loss'] = total_loss
    if is_multilabel:
      all_losses['entropy'] = entropy

    if aux_losses:
      all_losses['reg_loss'] = aux_losses
      total_loss += tf.add_n(aux_losses)
    all_losses[self.loss] = total_loss

    return all_losses

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation.

    Args:
      training: unused; the same metrics are built for training and
        validation.

    Returns:
      A list of Keras metrics. With `losses.one_hot`, categorical accuracy
      and top-1/top-5 metrics, plus (for multi-label data) ROC-AUC, recall at
      95% precision, PR-AUC and optionally per-class recall. Otherwise the
      sparse categorical equivalents of the accuracy metrics.
    """
    del training  # Unused.
    if not self.task_config.losses.one_hot:
      return [
          tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
          tf.keras.metrics.SparseTopKCategoricalAccuracy(
              k=1, name='top_1_accuracy'),
          tf.keras.metrics.SparseTopKCategoricalAccuracy(
              k=5, name='top_5_accuracy')
      ]

    metrics = [
        tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
        tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
        tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
    ]
    if self._is_multilabel():
      metrics.extend([
          tf.keras.metrics.AUC(curve='ROC', multi_label=True, name='ROC-AUC'),
          tf.keras.metrics.RecallAtPrecision(
              0.95, name='RecallAtPrecision95'),
          tf.keras.metrics.AUC(curve='PR', multi_label=True, name='PR-AUC'),
      ])
      if self.task_config.metrics.use_per_class_recall:
        metrics.extend(
            tf.keras.metrics.Recall(class_id=i, name=f'recall-{i}')
            for i in range(self.task_config.train_data.num_classes))
    return metrics

  def process_metrics(self, metrics: List[Any], labels: Any,
                      model_outputs: Any):
    """Process and update metrics.

    Called when using custom training loop API.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a `(features, labels)` tuple.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs: the losses from `build_losses` plus metric
      results.
    """
    features, labels = inputs

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)
      outputs = self._apply_output_activation(outputs)

      # Computes per-replica loss.
      all_losses = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      loss = all_losses[self.loss]
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(
          optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    return self._update_metric_logs(all_losses, labels, outputs, model,
                                    metrics)

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf.keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a `(features, labels)` tuple.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs: the losses from `build_losses` plus metric
      results.
    """
    features, labels = inputs

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    logs = self.build_losses(model_outputs=outputs, labels=labels,
                             aux_losses=model.losses)

    return self._update_metric_logs(logs, labels, outputs, model, metrics)

  def inference_step(self, features: tf.Tensor, model: tf.keras.Model):
    """Performs the forward step.

    Args:
      features: the model input for one batch.
      model: the keras.Model.

    Returns:
      float32 class probabilities. When `num_test_clips * num_test_crops` of
      `validation_data` is greater than 1, each consecutive group of that
      many rows (presumably the views of one video) is averaged into a
      single row.
    """
    outputs = model(features, training=False)
    # Cast before the activation, as `train_step` does, so that the sigmoid /
    # softmax and the view averaging run in float32 under mixed_float16 or
    # mixed_bfloat16 policies.
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    outputs = self._apply_output_activation(outputs)

    validation_data = self.task_config.validation_data
    num_test_views = (
        validation_data.num_test_clips * validation_data.num_test_crops)
    if num_test_views > 1:
      # Averaging output probabilities across multiples views.
      outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
      outputs = tf.reduce_mean(outputs, axis=1)
    return outputs