# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Video classification task definition."""
from absl import logging
import tensorflow as tf
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import video_input
from official.vision.beta.modeling import factory_3d


@task_factory.register_task_cls(exp_cfg.VideoClassificationTask)
class VideoClassificationTask(base_task.Task):
  """A task for video classification."""

  def build_model(self):
    """Builds the video classification model described by the task config.

    Returns:
      The model object produced by `factory_3d.build_model`.
    """
    train_shape = self.task_config.train_data.feature_shape
    eval_shape = self.task_config.validation_data.feature_shape
    # Keep a dimension only when the train and validation feature shapes agree
    # on it; any mismatching dimension is left unspecified (None).
    common_input_shape = [
        train_dim if train_dim == eval_dim else None
        for train_dim, eval_dim in zip(train_shape, eval_shape)
    ]
    # A leading None is prepended in front of the feature dimensions
    # (presumably the batch axis -- confirm against the data pipeline).
    input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
    logging.info('Build model input %r', common_input_shape)

    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_weight_decay = self.task_config.losses.l2_weight_decay
    if l2_weight_decay:
      l2_regularizer = tf.keras.regularizers.l2(l2_weight_decay / 2.0)
    else:
      # A falsy weight decay means no regularizer is handed to the factory.
      l2_regularizer = None

    return factory_3d.build_model(
        self.task_config.model.model_type,
        input_specs=input_specs,
        model_config=self.task_config.model,
        num_classes=self.task_config.train_data.num_classes,
        l2_regularizer=l2_regularizer)

  def _get_dataset_fn(self, params):
    """Returns the dataset constructor matching `params.file_type`.

    Args:
      params: a data config exposing a `file_type` attribute.

    Returns:
      `tf.data.TFRecordDataset` when the file type is 'tfrecord'.

    Raises:
      ValueError: for any other file type.
    """
    # 'tfrecord' is the only file type handled here; reject everything else.
    if params.file_type != 'tfrecord':
      raise ValueError('Unknown input file type {!r}'.format(params.file_type))
    return tf.data.TFRecordDataset

  def _get_decoder_fn(self, params):
    """Returns the function used to decode serialized examples.

    The audio settings are read from `self.task_config.train_data`, not from
    `params`; `params` is accepted but not used by this method.

    Args:
      params: the data config of the dataset being built (currently unused).

    Returns:
      The `decode` method of a `video_input.Decoder`, with the configured audio
      feature registered when audio output is enabled.

    Raises:
      ValueError: if audio output is enabled but no audio feature name is
        configured.
    """
    decoder = video_input.Decoder()
    train_data = self.task_config.train_data
    if train_data.output_audio:
      # Raise explicitly instead of using `assert`: assertions are stripped
      # when Python runs with -O, which would let an empty feature name reach
      # `add_feature` unchecked.
      if not train_data.audio_feature:
        raise ValueError('audio feature is empty')
      decoder.add_feature(train_data.audio_feature,
                          tf.io.VarLenFeature(dtype=tf.float32))
    return decoder.decode

  def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
    """Builds the input dataset for video classification.

    Args:
      params: data config describing the dataset to read.
      input_context: optional context forwarded unchanged to the reader's
        `read` call.

    Returns:
      Whatever `InputReader.read` returns for this configuration.
    """
    example_parser = video_input.Parser(input_params=params)
    batch_postprocessor = video_input.PostBatchProcessor(params)

    # Wire the dataset constructor, decoder, parser and post-batch processor
    # into a single reader, then read from it.
    return input_reader.InputReader(
        params,
        dataset_fn=self._get_dataset_fn(params),
        decoder_fn=self._get_decoder_fn(params),
        parser_fn=example_parser.parse_fn(params.is_training),
        postprocess_fn=batch_postprocessor).read(input_context=input_context)

  def build_losses(self, labels, model_outputs, aux_losses=None):
    """Computes the classification loss and collects every loss term.

    All cross-entropy functions are called with `from_logits=False`, so
    `model_outputs` is expected to be already activated, not raw logits.

    Args:
      labels: ground-truth labels.
      model_outputs: activated outputs of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      A dict of losses: 'class_loss', 'entropy' (multilabel only), 'reg_loss'
      (only when `aux_losses` is non-empty) and the total loss stored under
      the key `self.loss`.
    """
    all_losses = {}
    losses_config = self.task_config.losses

    if self.task_config.train_data.is_multilabel:
      # The small epsilon keeps the log finite when an output is exactly 0.
      entropy = -tf.reduce_mean(
          tf.reduce_sum(model_outputs * tf.math.log(model_outputs + 1e-8), -1))
      # NOTE(review): unlike the single-label branch below, this loss is not
      # passed through tf_utils.safe_mean -- confirm that this is intended.
      total_loss = tf.keras.losses.binary_crossentropy(
          labels, model_outputs, from_logits=False)
      all_losses['class_loss'] = total_loss
      all_losses['entropy'] = entropy
    else:
      if losses_config.one_hot:
        cross_entropy = tf.keras.losses.categorical_crossentropy(
            labels,
            model_outputs,
            from_logits=False,
            label_smoothing=losses_config.label_smoothing)
      else:
        cross_entropy = tf.keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=False)
      total_loss = tf_utils.safe_mean(cross_entropy)
      all_losses['class_loss'] = total_loss

    if aux_losses:
      # 'reg_loss' stores the auxiliary losses as given; only their sum is
      # added to the total.
      all_losses['reg_loss'] = aux_losses
      total_loss = total_loss + tf.add_n(aux_losses)

    all_losses[self.loss] = total_loss
    return all_losses

  def build_metrics(self, training=True):
    """Creates the streaming metrics for training/validation.

    Args:
      training: not used by this implementation; the same metrics are built
        regardless of its value.

    Returns:
      A list of `tf.keras.metrics` objects. Sparse accuracy metrics are used
      unless the loss config requests one-hot labels; the AUC and
      recall-at-precision metrics are added only for one-hot multilabel data.
    """
    if not self.task_config.losses.one_hot:
      return [
          tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
          tf.keras.metrics.SparseTopKCategoricalAccuracy(
              k=1, name='top_1_accuracy'),
          tf.keras.metrics.SparseTopKCategoricalAccuracy(
              k=5, name='top_5_accuracy')
      ]

    metrics = [
        tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
        tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
        tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
    ]
    multi_label = self.task_config.train_data.is_multilabel
    if multi_label:
      metrics.extend([
          tf.keras.metrics.AUC(
              curve='ROC', multi_label=multi_label, name='ROC-AUC'),
          tf.keras.metrics.RecallAtPrecision(
              0.95, name='RecallAtPrecision95'),
          tf.keras.metrics.AUC(
              curve='PR', multi_label=multi_label, name='PR-AUC'),
      ])
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    """Updates every metric with the given labels and model outputs.

    Called when using custom training loop API.

    Args:
      metrics: an iterable of metrics objects, as returned by
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    # Each metric accumulates its own state; nothing is returned.
    for streaming_metric in metrics:
      streaming_metric.update_state(labels, model_outputs)

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Runs one forward/backward pass and applies the gradients.

    Args:
      inputs: a `(features, labels)` pair. `features` is indexed with 'image'
        unless audio output is enabled, in which case it is passed to the
        model as-is.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs: the losses from `build_losses` plus the current
      result of every metric that was updated.
    """
    features, labels = inputs
    train_data = self.task_config.train_data
    uses_loss_scaling = isinstance(
        optimizer, tf.keras.mixed_precision.LossScaleOptimizer)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      model_inputs = features if train_data.output_audio else features['image']
      outputs = model(model_inputs, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss. The losses are built with
      # from_logits=False, so the outputs are activated first.
      activation = (
          tf.math.sigmoid if train_data.is_multilabel else tf.math.softmax)
      outputs = activation(outputs)
      all_losses = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      loss = all_losses[self.loss]
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if uses_loss_scaling:
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if uses_loss_scaling:
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    # The loss dict doubles as the log dict; metric results are merged in.
    logs = all_losses
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a `(features, labels)` pair.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs: the losses from `build_losses` plus the current
      result of every metric that was updated.
    """
    features, labels = inputs

    # `inference_step` already applies the activation; the cast to float32
    # happens afterwards, before the losses are computed.
    predictions = tf.nest.map_structure(
        lambda x: tf.cast(x, tf.float32), self.inference_step(features, model))
    logs = self.build_losses(
        model_outputs=predictions, labels=labels, aux_losses=model.losses)

    if metrics:
      self.process_metrics(metrics, labels, predictions)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, predictions)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, features, model):
    """Runs the forward pass with `training=False` and activates the outputs.

    Args:
      features: the model inputs. Indexed with 'image' unless audio output is
        enabled, in which case it is passed to the model as-is.
      model: the model to call.

    Returns:
      The model outputs after sigmoid (multilabel) or softmax (otherwise).
    """
    train_data = self.task_config.train_data
    model_inputs = features if train_data.output_audio else features['image']
    outputs = model(model_inputs, training=False)
    if train_data.is_multilabel:
      return tf.math.sigmoid(outputs)
    return tf.math.softmax(outputs)