Commit 7adc6ec1 authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 373155894
parent 6d6cd4ac
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for movinet_layers.py."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.projects.movinet.modeling import movinet_layers
class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
def test_squeeze3d(self):
squeeze = movinet_layers.Squeeze3D()
inputs = tf.ones([5, 1, 1, 1, 3])
predicted = squeeze(inputs)
expected = tf.ones([5, 3])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllEqual(predicted, expected)
def test_mobile_conv2d(self):
conv2d = movinet_layers.MobileConv2D(
filters=3,
kernel_size=(3, 3),
strides=(1, 1),
padding='same',
kernel_initializer='ones',
use_bias=False,
use_depthwise=False,
use_temporal=False,
use_buffered_input=True,
)
inputs = tf.ones([1, 2, 2, 2, 3])
predicted = conv2d(inputs)
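# With a 3x3 kernel of ones and 'same' padding, every output position of the
# all-ones 2x2x3 input covers 2 * 2 * 3 = 12 ones, so each output value is 12.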
expected = tf.constant(
[[[[[12., 12., 12.],
[12., 12., 12.]],
[[12., 12., 12.],
[12., 12., 12.]]],
[[[12., 12., 12.],
[12., 12., 12.]],
[[12., 12., 12.],
[12., 12., 12.]]]]])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
def test_mobile_conv2d_temporal(self):
conv2d = movinet_layers.MobileConv2D(
filters=3,
kernel_size=(3, 1),
strides=(1, 1),
padding='causal',
kernel_initializer='ones',
use_bias=False,
use_depthwise=True,
use_temporal=True,
use_buffered_input=True,
)
inputs = tf.ones([1, 2, 2, 1, 3])
paddings = [[0, 0], [2, 0], [0, 0], [0, 0], [0, 0]]
padded_inputs = tf.pad(inputs, paddings)
predicted = conv2d(padded_inputs)
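# The depthwise temporal kernel of ones spans 3 frames. With 2 causal padding
# frames of zeros, frame 0 only sees itself (sum = 1) and frame 1 sees frames
# 0 and 1 (sum = 2), matching the expected output below.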
expected = tf.constant(
[[[[[1., 1., 1.]],
[[1., 1., 1.]]],
[[[2., 2., 2.]],
[[2., 2., 2.]]]]])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
def test_stream_buffer(self):
conv3d_stream = nn_layers.Conv3D(
filters=3,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
padding='causal',
kernel_initializer='ones',
use_bias=False,
use_buffered_input=True,
)
buffer = movinet_layers.StreamBuffer(buffer_size=2)
conv3d = nn_layers.Conv3D(
filters=3,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
padding='causal',
kernel_initializer='ones',
use_bias=False,
use_buffered_input=False,
)
inputs = tf.ones([1, 4, 2, 2, 3])
expected = conv3d(inputs)
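# Feeding the clip frame-by-frame through a StreamBuffer that caches the last
# 2 frames should reproduce the non-streaming causal conv applied to the full
# clip, regardless of how the clip is split. Each all-ones frame contributes
# 2 * 2 * 3 = 12 to its causal window, giving outputs 12, 24, 36, 36.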
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = buffer(frame, states=states)
x = conv3d_stream(x)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(
predicted,
[[[[[12., 12., 12.]]],
[[[24., 24., 24.]]],
[[[36., 36., 36.]]],
[[[36., 36., 36.]]]]])
def test_stream_conv_block_2plus1d(self):
conv_block = movinet_layers.ConvBlock(
filters=3,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
causal=True,
kernel_initializer='ones',
use_bias=False,
activation='relu',
conv_type='2plus1d',
use_positional_encoding=True,
)
stream_conv_block = movinet_layers.StreamConvBlock(
filters=3,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
causal=True,
kernel_initializer='ones',
use_bias=False,
activation='relu',
conv_type='2plus1d',
use_positional_encoding=True,
)
inputs = tf.ones([1, 4, 2, 2, 3])
expected = conv_block(inputs)
predicted_disabled, _ = stream_conv_block(inputs)
self.assertEqual(predicted_disabled.shape, expected.shape)
self.assertAllClose(predicted_disabled, expected)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = stream_conv_block(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(
predicted,
[[[[[35.9640400, 35.9640400, 35.9640400]]],
[[[71.9280700, 71.9280700, 71.9280700]]],
[[[107.892105, 107.892105, 107.892105]]],
[[[107.892105, 107.892105, 107.892105]]]]])
def test_stream_conv_block_3d_2plus1d(self):
conv_block = movinet_layers.ConvBlock(
filters=3,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
causal=True,
kernel_initializer='ones',
use_bias=False,
activation='relu',
conv_type='3d_2plus1d',
use_positional_encoding=True,
)
stream_conv_block = movinet_layers.StreamConvBlock(
filters=3,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
causal=True,
kernel_initializer='ones',
use_bias=False,
activation='relu',
conv_type='3d_2plus1d',
use_positional_encoding=True,
)
inputs = tf.ones([1, 4, 2, 2, 3])
expected = conv_block(inputs)
predicted_disabled, _ = stream_conv_block(inputs)
self.assertEqual(predicted_disabled.shape, expected.shape)
self.assertAllClose(predicted_disabled, expected)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = stream_conv_block(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(
predicted,
[[[[[35.9640400, 35.9640400, 35.9640400]]],
[[[71.9280700, 71.9280700, 71.9280700]]],
[[[107.892105, 107.892105, 107.892105]]],
[[[107.892105, 107.892105, 107.892105]]]]])
def test_stream_conv_block(self):
conv_block = movinet_layers.ConvBlock(
filters=3,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
causal=True,
kernel_initializer='ones',
use_bias=False,
activation='relu',
)
stream_conv_block = movinet_layers.StreamConvBlock(
filters=3,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
causal=True,
kernel_initializer='ones',
use_bias=False,
activation='relu',
)
inputs = tf.ones([1, 4, 2, 2, 3])
expected = conv_block(inputs)
predicted_disabled, _ = stream_conv_block(inputs)
self.assertEqual(predicted_disabled.shape, expected.shape)
self.assertAllClose(predicted_disabled, expected)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = stream_conv_block(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(
predicted,
[[[[[11.994005, 11.994005, 11.994005]]],
[[[23.988010, 23.988010, 23.988010]]],
[[[35.982014, 35.982014, 35.982014]]],
[[[35.982014, 35.982014, 35.982014]]]]])
def test_stream_squeeze_excitation(self):
se = movinet_layers.StreamSqueezeExcitation(
3, causal=True, kernel_initializer='ones')
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
expected, _ = se(inputs)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = se(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
self.assertAllClose(
predicted,
[[[[[0.9998109, 0.9998109, 0.9998109]],
[[0.9998109, 0.9998109, 0.9998109]]],
[[[1.9999969, 1.9999969, 1.9999969]],
[[1.9999969, 1.9999969, 1.9999969]]],
[[[3., 3., 3.]],
[[3., 3., 3.]]],
[[[4., 4., 4.]],
[[4., 4., 4.]]]]],
1e-5, 1e-5)
def test_stream_movinet_block(self):
block = movinet_layers.MovinetBlock(
out_filters=3,
expand_filters=6,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
causal=True,
)
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
expected, _ = block(inputs)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = block(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
def test_stream_classifier_head(self):
head = movinet_layers.Head(project_filters=5)
classifier_head = movinet_layers.ClassifierHead(
head_filters=10, num_classes=4)
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
x, _ = head(inputs)
expected = classifier_head(x)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
for frame in frames:
x, states = head(frame, states=states)
predicted = classifier_head(x)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build Movinet for video classification.
Reference: https://arxiv.org/pdf/2103.11511.pdf
"""
from typing import Mapping
from absl import logging
import tensorflow as tf
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.projects.movinet.configs import movinet as cfg
from official.vision.beta.projects.movinet.modeling import movinet_layers
@tf.keras.utils.register_keras_serializable(package='Vision')
class MovinetClassifier(tf.keras.Model):
"""A video classification class builder."""
def __init__(self,
backbone: tf.keras.Model,
num_classes: int,
input_specs: Mapping[str, tf.keras.layers.InputSpec] = None,
dropout_rate: float = 0.0,
kernel_initializer: str = 'HeNormal',
kernel_regularizer: tf.keras.regularizers.Regularizer = None,
bias_regularizer: tf.keras.regularizers.Regularizer = None,
output_states: bool = False,
**kwargs):
"""Movinet initialization function.
Args:
backbone: A 3d backbone network.
num_classes: Number of classes in classification task.
input_specs: Specs of the input tensor.
dropout_rate: Rate for dropout regularization.
kernel_initializer: Kernel initializer for the final dense layer.
kernel_regularizer: Kernel regularizer.
bias_regularizer: Bias regularizer.
output_states: If True, also output intermediate states that can be used to
run the model in streaming mode. Feeding the output states of the previous
input clip together with the current input clip makes use of a stream
buffer, enabling streaming video inference.
**kwargs: Keyword arguments to be passed.
"""
if not input_specs:
input_specs = {
'image': tf.keras.layers.InputSpec(shape=[None, None, None, None, 3])
}
self._num_classes = num_classes
self._input_specs = input_specs
self._dropout_rate = dropout_rate
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._output_states = output_states
# Disable Keras's automatic attribute tracking so that @property setters are
# not tracked as model variables.
self._self_setattr_tracking = False
inputs = {
name: tf.keras.Input(shape=state.shape[1:], name=f'states/{name}')
for name, state in input_specs.items()
}
states = inputs.get('states', {})
endpoints, states = backbone(dict(image=inputs['image'], states=states))
x = endpoints['head']
x = movinet_layers.ClassifierHead(
head_filters=backbone._head_filters,
num_classes=num_classes,
dropout_rate=dropout_rate,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
conv_type=backbone._conv_type)(x)
if output_states:
inputs['states'] = {
k: tf.keras.Input(shape=v.shape[1:], name=k)
for k, v in states.items()
}
outputs = (x, states) if output_states else x
super(MovinetClassifier, self).__init__(
inputs=inputs, outputs=outputs, **kwargs)
# Assign the backbone after the super() call so that Keras does not raise an
# error for setting attributes before initialization.
self._backbone = backbone
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
return dict(backbone=self.backbone)
@property
def backbone(self):
return self._backbone
def get_config(self):
config = {
'backbone': self._backbone,
'num_classes': self._num_classes,
'input_specs': self._input_specs,
'dropout_rate': self._dropout_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'output_states': self._output_states,
}
return config
@classmethod
def from_config(cls, config, custom_objects=None):
# Each InputSpec may need to be deserialized.
# This handles the case where the model was saved and then reloaded with
# `tf.keras.models.load_model`.
if config['input_specs']:
for name in config['input_specs']:
if isinstance(config['input_specs'][name], dict):
config['input_specs'][name] = tf.keras.layers.deserialize(
config['input_specs'][name])
return cls(**config)
@model_factory.register_model_builder('movinet')
def build_movinet_model(
input_specs: tf.keras.layers.InputSpec,
model_config: cfg.MovinetModel,
num_classes: int,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds movinet model."""
logging.info('Building movinet model with num classes: %s', num_classes)
if l2_regularizer is not None:
logging.info('Building movinet model with regularizer: %s',
l2_regularizer.get_config())
input_specs_dict = {'image': input_specs}
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
norm_activation_config=model_config.norm_activation,
l2_regularizer=l2_regularizer)
model = MovinetClassifier(
backbone,
num_classes=num_classes,
kernel_regularizer=l2_regularizer,
input_specs=input_specs_dict,
dropout_rate=model_config.dropout_rate)
return model
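# Streaming usage (an illustrative sketch mirroring the unit tests; the
# model_id and clip shape below are assumptions):
#
#   from official.vision.beta.projects.movinet.modeling import movinet
#   backbone = movinet.Movinet(model_id='a0', causal=True)
#   clip = tf.ones([1, 8, 172, 172, 3])
#   states = {}
#   for frame in tf.split(clip, clip.shape[1], axis=1):
#     endpoints, states = backbone(dict(image=frame, states=states))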
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for movinet_model.py."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.movinet.modeling import movinet
from official.vision.beta.projects.movinet.modeling import movinet_model
class MovinetModelTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(False, True)
def test_movinet_classifier_creation(self, is_training):
"""Test for creation of a Movinet classifier."""
temporal_size = 16
spatial_size = 224
tf.keras.backend.set_image_data_format('channels_last')
input_specs = tf.keras.layers.InputSpec(
shape=[None, temporal_size, spatial_size, spatial_size, 3])
backbone = movinet.Movinet(model_id='a0', input_specs=input_specs)
num_classes = 1000
model = movinet_model.MovinetClassifier(
backbone=backbone,
num_classes=num_classes,
input_specs={'image': input_specs},
dropout_rate=0.2)
inputs = np.random.rand(2, temporal_size, spatial_size, spatial_size, 3)
logits = model(inputs, training=is_training)
self.assertAllEqual([2, num_classes], logits.shape)
def test_movinet_classifier_stream(self):
tf.keras.backend.set_image_data_format('channels_last')
model = movinet.Movinet(
model_id='a0',
causal=True,
)
inputs = tf.ones([1, 5, 128, 128, 3])
expected_endpoints, _ = model(dict(image=inputs, states={}))
frames = tf.split(inputs, inputs.shape[1], axis=1)
output, states = None, {}
for frame in frames:
output, states = model(dict(image=frame, states=states))
predicted_endpoints = output
predicted = predicted_endpoints['head']
# The expected final output is simply the mean across frames
expected = expected_endpoints['head']
expected = tf.reduce_mean(expected, 1, keepdims=True)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_serialize_deserialize(self):
"""Validate the classification network can be serialized and deserialized."""
backbone = movinet.Movinet(model_id='a0')
model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=1000)
config = model.get_config()
new_model = movinet_model.MovinetClassifier.from_config(config)
# Validate that the config can be forced to JSON.
new_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(model.get_config(), new_model.get_config())
def test_saved_model_save_load(self):
backbone = movinet.Movinet('a0')
model = movinet_model.MovinetClassifier(
backbone, num_classes=600)
model.build([1, 5, 172, 172, 3])
model.compile(metrics=['acc'])
tf.keras.models.save_model(model, '/tmp/movinet/')
loaded_model = tf.keras.models.load_model('/tmp/movinet/')
output = loaded_model(dict(image=tf.ones([1, 1, 1, 1, 3])))
self.assertAllEqual(output.shape, [1, 600])
@parameterized.parameters(
('a0', 3.126071),
('a1', 4.717912),
('a2', 5.280922),
('a3', 7.443289),
('a4', 11.422727),
('a5', 18.763355),
('t0', 1.740502),
)
def test_movinet_models(self, model_id, expected_params_millions):
"""Test creation of MoViNet family models with states."""
tf.keras.backend.set_image_data_format('channels_last')
model = movinet_model.MovinetClassifier(
backbone=movinet.Movinet(
model_id=model_id,
causal=True),
num_classes=600)
model.build([1, 1, 1, 1, 3])
num_params_millions = model.count_params() / 1e6
self.assertEqual(num_params_millions, expected_params_millions)
def test_movinet_a0_2plus1d(self):
"""Test creation of MoViNet with 2plus1d configuration."""
tf.keras.backend.set_image_data_format('channels_last')
model_2plus1d = movinet_model.MovinetClassifier(
backbone=movinet.Movinet(
model_id='a0',
conv_type='2plus1d'),
num_classes=600)
model_2plus1d.build([1, 1, 1, 1, 3])
model_3d_2plus1d = movinet_model.MovinetClassifier(
backbone=movinet.Movinet(
model_id='a0',
conv_type='3d_2plus1d'),
num_classes=600)
model_3d_2plus1d.build([1, 1, 1, 1, 3])
# Ensure both models have the same weights
weights = []
for var_2plus1d, var_3d_2plus1d in zip(
model_2plus1d.get_weights(), model_3d_2plus1d.get_weights()):
if var_2plus1d.shape == var_3d_2plus1d.shape:
weights.append(var_3d_2plus1d)
else:
if var_3d_2plus1d.shape[0] == 1:
weight = var_3d_2plus1d[0]
else:
weight = var_3d_2plus1d[:, 0]
if weight.shape[-1] != var_2plus1d.shape[-1]:
# Transpose any depthwise kernels (conv3d --> depthwise_conv2d)
weight = tf.transpose(weight, perm=(0, 1, 3, 2))
weights.append(weight)
model_2plus1d.set_weights(weights)
inputs = np.random.rand(2, 8, 172, 172, 3)
logits_2plus1d = model_2plus1d(inputs)
logits_3d_2plus1d = model_3d_2plus1d(inputs)
# Ensure both models have the same output, since the weights are the same
self.assertAllEqual(logits_2plus1d.shape, logits_3d_2plus1d.shape)
self.assertAllClose(logits_2plus1d, logits_3d_2plus1d)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for movinet.py."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.projects.movinet.modeling import movinet
class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
def test_network_creation(self):
"""Test creation of MoViNet family models."""
tf.keras.backend.set_image_data_format('channels_last')
network = movinet.Movinet(
model_id='a0',
causal=True,
)
inputs = tf.keras.Input(shape=(8, 128, 128, 3), batch_size=1)
endpoints, states = network(inputs)
self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['b1/l0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['b2/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b3/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])
self.assertNotEmpty(states)
def test_network_with_states(self):
"""Test creation of MoViNet family models with states."""
tf.keras.backend.set_image_data_format('channels_last')
network = movinet.Movinet(
model_id='a0',
causal=True,
)
inputs = tf.ones([1, 8, 128, 128, 3])
_, states = network(inputs)
endpoints, new_states = network(dict(image=inputs, states=states))
self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['b1/l0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['b2/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b3/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])
self.assertNotEmpty(states)
self.assertNotEmpty(new_states)
def test_movinet_stream(self):
tf.keras.backend.set_image_data_format('channels_last')
model = movinet.Movinet(
model_id='a0',
causal=True,
)
inputs = tf.ones([1, 5, 128, 128, 3])
expected_endpoints, _ = model(dict(image=inputs, states={}))
frames = tf.split(inputs, inputs.shape[1], axis=1)
output, states = None, {}
for frame in frames:
output, states = model(dict(image=frame, states=states))
predicted_endpoints = output
predicted = predicted_endpoints['head']
# The expected final output is simply the mean across frames
expected = expected_endpoints['head']
expected = tf.reduce_mean(expected, 1, keepdims=True)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_movinet_2plus1d_stream(self):
tf.keras.backend.set_image_data_format('channels_last')
model = movinet.Movinet(
model_id='a0',
causal=True,
conv_type='2plus1d',
)
inputs = tf.ones([1, 5, 128, 128, 3])
expected_endpoints, _ = model(dict(image=inputs, states={}))
frames = tf.split(inputs, inputs.shape[1], axis=1)
output, states = None, {}
for frame in frames:
output, states = model(dict(image=frame, states=states))
predicted_endpoints = output
predicted = predicted_endpoints['head']
# The expected final output is simply the mean across frames
expected = expected_endpoints['head']
expected = tf.reduce_mean(expected, 1, keepdims=True)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_movinet_3d_2plus1d_stream(self):
tf.keras.backend.set_image_data_format('channels_last')
model = movinet.Movinet(
model_id='a0',
causal=True,
conv_type='3d_2plus1d',
)
inputs = tf.ones([1, 5, 128, 128, 3])
expected_endpoints, _ = model(dict(image=inputs, states={}))
frames = tf.split(inputs, inputs.shape[1], axis=1)
output, states = None, {}
for frame in frames:
output, states = model(dict(image=frame, states=states))
predicted_endpoints = output
predicted = predicted_endpoints['head']
# The expected final output is simply the mean across frames
expected = expected_endpoints['head']
expected = tf.reduce_mean(expected, 1, keepdims=True)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
kwargs = dict(
model_id='a0',
causal=True,
use_positional_encoding=True,
)
network = movinet.Movinet(**kwargs)
# Create another network object from the first object's config.
new_network = movinet.Movinet.from_config(network.get_config())
# Validate that the config can be forced to JSON.
_ = new_network.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(network.get_config(), new_network.get_config())
if __name__ == '__main__':
tf.test.main()
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "3E96e1UKQ8uR"
},
"source": [
"# MoViNet Tutorial\n",
"\n",
"This notebook provides basic example code to create, build, and run [MoViNets (Mobile Video Networks)](https://arxiv.org/pdf/2103.11511.pdf). Models use TF Keras and support inference in TF 1 and TF 2. Pretrained models are provided by [TensorFlow Hub](https://tfhub.dev/google/collections/movinet/), trained on [Kinetics 600](https://deepmind.com/research/open-source/kinetics) for video action classification."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8_oLnvJy7kz5"
},
"source": [
"## Setup\n",
"\n",
"It is recommended to run the models using GPUs or TPUs.\n",
"\n",
"To select a GPU/TPU in Colab, select `Runtime \u003e Change runtime type \u003e Hardware accelerator` dropdown in the top menu.\n",
"\n",
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"- tf-models-official is the stable Model Garden package. Note that it may not include the latest changes in the tensorflow_models github repo.\n",
"- To include latest changes, you may install tf-models-nightly, which is the nightly Model Garden package created daily automatically.\n",
"pip will install all models and dependencies automatically.\n",
"\n",
"Install the [mediapy](https://github.com/google/mediapy) package for visualizing images/videos."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "s3khsunT7kWa"
},
"outputs": [],
"source": [
"!pip install -q tf-models-nightly\n",
"\n",
"!command -v ffmpeg \u003e/dev/null || (apt update \u0026\u0026 apt install -y ffmpeg)\n",
"!pip install -q mediapy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dI_1csl6Q-gH"
},
"outputs": [],
"source": [
"from io import BytesIO\n",
"import os\n",
"from six.moves import urllib\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import mediapy as media\n",
"import numpy as np\n",
"from PIL import Image\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow_hub as hub\n",
"\n",
"from official.vision.beta.configs import video_classification\n",
"from official.vision.beta.projects.movinet.configs import movinet as movinet_configs\n",
"from official.vision.beta.projects.movinet.modeling import movinet as movinet_backbone\n",
"from official.vision.beta.projects.movinet.modeling import movinet_layers\n",
"from official.vision.beta.projects.movinet.modeling import movinet_model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6g0tuFvf71S9"
},
"source": [
"## Example Usage with TensorFlow Hub\n",
"\n",
"Load MoViNet-A2-Base from TensorFlow Hub, as part of the [MoViNet collection](https://tfhub.dev/google/collections/movinet/).\n",
"\n",
"The following code will:\n",
"\n",
"- Load a MoViNet KerasLayer from [tfhub.dev](https://tfhub.dev).\n",
"- Wrap the layer in a [Keras Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model).\n",
"- Load an example image, and reshape it to a single frame video.\n",
"- Classify the video"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nTUdhlRJzl2o"
},
"outputs": [],
"source": [
"movinet_a2_hub_url = 'https://tfhub.dev/tensorflow/movinet/a2/base/kinetics-600/classification/1'\n",
"\n",
"inputs = tf.keras.layers.Input(\n",
" shape=[None, None, None, 3],\n",
" dtype=tf.float32)\n",
"\n",
"encoder = hub.KerasLayer(movinet_a2_hub_url, trainable=True)\n",
"\n",
"# Important: To use tf.nn.conv3d on CPU, we must compile with tf.function.\n",
"encoder.call = tf.function(encoder.call, experimental_compile=True)\n",
"\n",
"# [batch_size, 600]\n",
"outputs = encoder(dict(image=inputs))\n",
"\n",
"model = tf.keras.Model(inputs, outputs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7kU1_pL10l0B"
},
"source": [
"To provide a simple example video for classification, we can load a static image and reshape it to produce a video with a single frame."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Iy0rKRrT723_"
},
"outputs": [],
"source": [
"image_url = 'https://upload.wikimedia.org/wikipedia/commons/8/84/Ski_Famille_-_Family_Ski_Holidays.jpg'\n",
"image_height = 224\n",
"image_width = 224\n",
"\n",
"with urllib.request.urlopen(image_url) as f:\n",
" image = Image.open(BytesIO(f.read())).resize((image_height, image_width))\n",
"video = tf.reshape(np.array(image), [1, 1, image_height, image_width, 3])\n",
"video = tf.cast(video, tf.float32) / 255.\n",
"\n",
"image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yf6EefHuWfxC"
},
"source": [
"Run the model and output the predicted label. Expected output should be skiing (labels 464-467). E.g., 465 = \"skiing crosscountry\".\n",
"\n",
"See [here](https://gist.github.com/willprice/f19da185c9c5f32847134b87c1960769#file-kinetics_600_labels-csv) for a full list of all labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OOpEKuqH8sH7"
},
"outputs": [],
"source": [
"output = model(video)\n",
"output_label_index = tf.argmax(output, -1)[0].numpy()\n",
"\n",
"print(output_label_index)"
]
},
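{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustrative sketch (assuming the model outputs unnormalized logits; softmax preserves the ranking either way), we can also convert the output to probabilities and list the five most likely labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convert logits to probabilities, then list the top-5 label indices.\n",
"probs = tf.nn.softmax(output, -1)\n",
"top_probs, top_indices = tf.math.top_k(probs, k=5)\n",
"\n",
"for index, prob in zip(top_indices[0].numpy(), top_probs[0].numpy()):\n",
" print(f'{index}: {prob:.3f}')"
]
},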
{
"cell_type": "markdown",
"metadata": {
"id": "_s-7bEoa3f8g"
},
"source": [
"## Example Usage with the TensorFlow Model Garden\n",
"\n",
"Fine-tune MoViNet-A0-Base on [UCF-101](https://www.crcv.ucf.edu/research/data-sets/ucf101/).\n",
"\n",
"The following code will:\n",
"\n",
"- Load the UCF-101 dataset with [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/ucf101).\n",
"- Create a [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) pipeline for training and evaluation.\n",
"- Display some example videos from the dataset.\n",
"- Build a MoViNet model and load pretrained weights.\n",
"- Fine-tune the final classifier layers on UCF-101."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o7unW4WVr580"
},
"source": [
"### Load the UCF-101 Dataset with TensorFlow Datasets\n",
"\n",
"Calling `download_and_prepare()` will automatically download the dataset. After downloading, this cell will output information about the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"executionInfo": {
"elapsed": 2957,
"status": "ok",
"timestamp": 1619748263684,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 360
},
"id": "boQHbcfDhXpJ",
"outputId": "eabc3307-d6bf-4f29-cc5a-c8dc6360701b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of classes: 101\n",
"Number of examples for train: 9537\n",
"Number of examples for test: 3783\n",
"\n"
]
},
{
"data": {
"text/plain": [
"tfds.core.DatasetInfo(\n",
" name='ucf101',\n",
" full_name='ucf101/ucf101_1_256/2.0.0',\n",
" description=\"\"\"\n",
" A 101-label video classification dataset.\n",
" \"\"\",\n",
" config_description=\"\"\"\n",
" 256x256 UCF with the first action recognition split.\n",
" \"\"\",\n",
" homepage='https://www.crcv.ucf.edu/data-sets/ucf101/',\n",
" data_path='/readahead/128M/placer/prod/home/tensorflow-datasets-cns-storage-owner/datasets/ucf101/ucf101_1_256/2.0.0',\n",
" download_size=6.48 GiB,\n",
" dataset_size=Unknown size,\n",
" features=FeaturesDict({\n",
" 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=101),\n",
" 'video': Video(Image(shape=(256, 256, 3), dtype=tf.uint8)),\n",
" }),\n",
" supervised_keys=None,\n",
" splits={\n",
" 'test': \u003cSplitInfo num_examples=3783, num_shards=32\u003e,\n",
" 'train': \u003cSplitInfo num_examples=9537, num_shards=64\u003e,\n",
" },\n",
" citation=\"\"\"@article{DBLP:journals/corr/abs-1212-0402,\n",
" author = {Khurram Soomro and\n",
" Amir Roshan Zamir and\n",
" Mubarak Shah},\n",
" title = {{UCF101:} {A} Dataset of 101 Human Actions Classes From Videos in\n",
" The Wild},\n",
" journal = {CoRR},\n",
" volume = {abs/1212.0402},\n",
" year = {2012},\n",
" url = {http://arxiv.org/abs/1212.0402},\n",
" archivePrefix = {arXiv},\n",
" eprint = {1212.0402},\n",
" timestamp = {Mon, 13 Aug 2018 16:47:45 +0200},\n",
" biburl = {https://dblp.org/rec/bib/journals/corr/abs-1212-0402},\n",
" bibsource = {dblp computer science bibliography, https://dblp.org}\n",
" }\"\"\",\n",
")"
]
},
"execution_count": 0,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"dataset_name = 'ucf101'\n",
"\n",
"builder = tfds.builder(dataset_name)\n",
"builder.download_and_prepare()\n",
"\n",
"num_classes = builder.info.features['label'].num_classes\n",
"num_examples = {\n",
" name: split.num_examples\n",
" for name, split in builder.info.splits.items()\n",
"}\n",
"\n",
"print('Number of classes:', num_classes)\n",
"print('Number of examples for train:', num_examples['train'])\n",
"print('Number of examples for test:', num_examples['test'])\n",
"print()\n",
"\n",
"builder.info"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BsJJgnBBqDKZ"
},
"source": [
"Build the training and evaluation datasets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9cO_BCu9le3r"
},
"outputs": [],
"source": [
"batch_size = 8\n",
"num_frames = 8\n",
"frame_stride = 10\n",
"resolution = 172\n",
"\n",
"def format_features(features):\n",
" video = features['video']\n",
" video = video[:, ::frame_stride]\n",
" video = video[:, :num_frames]\n",
"\n",
" video = tf.reshape(video, [-1, video.shape[2], video.shape[3], 3])\n",
" video = tf.image.resize(video, (resolution, resolution))\n",
" video = tf.reshape(video, [-1, num_frames, resolution, resolution, 3])\n",
" video = tf.cast(video, tf.float32) / 255.\n",
"\n",
" label = tf.one_hot(features['label'], num_classes)\n",
" return (video, label)\n",
"\n",
"train_dataset = builder.as_dataset(\n",
" split='train',\n",
" batch_size=batch_size,\n",
" shuffle_files=True)\n",
"train_dataset = train_dataset.map(\n",
" format_features,\n",
" num_parallel_calls=tf.data.AUTOTUNE)\n",
"train_dataset = train_dataset.repeat()\n",
"train_dataset = train_dataset.prefetch(2)\n",
"\n",
"test_dataset = builder.as_dataset(\n",
" split='test',\n",
" batch_size=batch_size)\n",
"test_dataset = test_dataset.map(\n",
" format_features,\n",
" num_parallel_calls=tf.data.AUTOTUNE,\n",
" deterministic=True)\n",
"test_dataset = test_dataset.prefetch(2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rToX7_Ymgh57"
},
"source": [
"Display some example videos from the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KG8Z7rUj06of"
},
"outputs": [],
"source": [
"videos, labels = next(iter(train_dataset))\n",
"media.show_videos(videos.numpy(), codec='gif', fps=5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R3RHeuHdsd_3"
},
"source": [
"### Build MoViNet-A0-Base and Load Pretrained Weights"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JXVQOP9Rqk0I"
},
"source": [
"Here we create a MoViNet model using the open source code provided in [tensorflow/models](https://github.com/tensorflow/models) and load the pretrained weights. Here we freeze the all layers except the final classifier head to speed up fine-tuning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JpfxpeGSsbzJ"
},
"outputs": [],
"source": [
"model_id = 'a0'\n",
"\n",
"tf.keras.backend.clear_session()\n",
"\n",
"backbone = movinet.Movinet(\n",
" model_id=model_id,\n",
" stochastic_depth_rate=0.)\n",
"model = movinet_model.MovinetClassifier(\n",
" backbone=backbone,\n",
" num_classes=600,\n",
" dropout_rate=0.)\n",
"model.build([batch_size, num_frames, resolution, resolution, 3])\n",
"\n",
"# Load pretrained weights from TF Hub\n",
"movinet_hub_url = f'https://tfhub.dev/tensorflow/movinet/{model_id}/base/kinetics-600/classification/1'\n",
"movinet_hub_model = hub.KerasLayer(movinet_hub_url, trainable=True)\n",
"pretrained_weights = {w.name: w for w in movinet_hub_model.weights}\n",
"model_weights = {w.name: w for w in model.weights}\n",
"for name in pretrained_weights:\n",
" model_weights[name].assign(pretrained_weights[name])\n",
"\n",
"# Wrap the backbone with a new classifier to create a new classifier head\n",
"# with num_classes outputs\n",
"model = movinet_model.MovinetClassifier(\n",
" backbone=backbone,\n",
" num_classes=num_classes)\n",
"model.build([batch_size, num_frames, resolution, resolution, 3])\n",
"\n",
"# Freeze all layers except for the final classifier head\n",
"for layer in model.layers[:-1]:\n",
" layer.trainable = False\n",
"model.layers[-1].trainable = True"
]
},
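{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch; the all-ones input is an arbitrary assumption), run a dummy clip through the rebuilt classifier and verify that it outputs one logit per UCF-101 class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The rebuilt head should output `num_classes` logits per example.\n",
"example_output = model(tf.ones([1, num_frames, resolution, resolution, 3]))\n",
"print(example_output.shape) # Expected: (1, 101)"
]
},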
{
"cell_type": "markdown",
"metadata": {
"id": "ucntdu2xqgXB"
},
"source": [
"Configure fine-tuning with training/evaluation steps, loss object, metrics, learning rate, optimizer, and callbacks.\n",
"\n",
"Here we use 3 epochs. Training for more epochs should improve accuracy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WUYTw48BouTu"
},
"outputs": [],
"source": [
"num_epochs = 3\n",
"\n",
"train_steps = num_examples['train'] // batch_size\n",
"total_train_steps = train_steps * num_epochs\n",
"test_steps = num_examples['test'] // batch_size\n",
"\n",
"loss_obj = tf.keras.losses.CategoricalCrossentropy(\n",
" from_logits=True,\n",
" label_smoothing=0.1)\n",
"\n",
"metrics = [\n",
" tf.keras.metrics.TopKCategoricalAccuracy(\n",
" k=1, name='top_1', dtype=tf.float32),\n",
" tf.keras.metrics.TopKCategoricalAccuracy(\n",
" k=5, name='top_5', dtype=tf.float32),\n",
"]\n",
"\n",
"initial_learning_rate = 0.01\n",
"learning_rate = tf.keras.optimizers.schedules.CosineDecay(\n",
" initial_learning_rate, decay_steps=total_train_steps,\n",
")\n",
"optimizer = tf.keras.optimizers.RMSprop(\n",
" learning_rate, rho=0.9, momentum=0.9, epsilon=1.0, clipnorm=1.0)\n",
"\n",
"model.compile(loss=loss_obj, optimizer=optimizer, metrics=metrics)\n",
"\n",
"callbacks = [\n",
" tf.keras.callbacks.TensorBoard(),\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0IyAOOlcpHna"
},
"source": [
"Run the fine-tuning with Keras compile/fit. After fine-tuning the model, we should be able to achieve \u003e70% accuracy on the test set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"executionInfo": {
"elapsed": 982253,
"status": "ok",
"timestamp": 1619750139919,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 360
},
"id": "Zecc_K3lga8I",
"outputId": "e4c5c61e-aa08-47db-c04c-42dea3efb545"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"1192/1192 [==============================] - 348s 286ms/step - loss: 3.4914 - top_1: 0.3639 - top_5: 0.6294 - val_loss: 2.5153 - val_top_1: 0.5975 - val_top_5: 0.8565\n",
"Epoch 2/3\n",
"1192/1192 [==============================] - 286s 240ms/step - loss: 2.1397 - top_1: 0.6794 - top_5: 0.9231 - val_loss: 2.0695 - val_top_1: 0.6838 - val_top_5: 0.9070\n",
"Epoch 3/3\n",
"1192/1192 [==============================] - 348s 292ms/step - loss: 1.8925 - top_1: 0.7660 - top_5: 0.9454 - val_loss: 1.9848 - val_top_1: 0.7116 - val_top_5: 0.9227\n"
]
}
],
"source": [
"results = model.fit(\n",
" train_dataset,\n",
" validation_data=test_dataset,\n",
" epochs=num_epochs,\n",
" steps_per_epoch=train_steps,\n",
" validation_steps=test_steps,\n",
" callbacks=callbacks,\n",
" validation_freq=1,\n",
" verbose=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XuH8XflmpU9d"
},
"source": [
"We can also view the training and evaluation progress in TensorBoard."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9fZhzhRJRd2J"
},
"outputs": [],
"source": [
"%reload_ext tensorboard\n",
"%tensorboard --logdir logs --port 0"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
"kind": "private"
},
"name": "movinet_tutorial.ipynb",
"provenance": [
{
"file_id": "11msGCxFjxwioBOBJavP9alfTclUQCJf-",
"timestamp": 1617043059980
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
r"""Training driver.
To train:
CONFIG_FILE=official/vision/beta/projects/movinet/configs/yaml/movinet_a0_k600_8x8.yaml
python3 official/vision/beta/projects/movinet/train.py \
--experiment=movinet_kinetics600 \
--mode=train \
--model_dir=/tmp/movinet/ \
--config_file=${CONFIG_FILE} \
--params_override="" \
--gin_file="" \
--gin_params="" \
--tpu="" \
--tf_data_service=""
"""
from absl import app
from absl import flags
import gin
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# Import movinet libraries to register the backbone and model into the
# tf.vision Model Garden factory.
# pylint: disable=unused-import
from official.vision.beta.projects.movinet.modeling import movinet
from official.vision.beta.projects.movinet.modeling import movinet_model
# pylint: enable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise, a continuous eval job
# may race against the train job when writing the same file.
train_utils.serialize_config(params, model_dir)
if 'train_and_eval' in FLAGS.mode:
assert (params.task.train_data.feature_shape ==
params.task.validation_data.feature_shape), (
f'train {params.task.train_data.feature_shape} != validate '
f'{params.task.validation_data.feature_shape}')
# Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can significantly improve model speed by using float16 on GPUs and bfloat16
# on TPUs. loss_scale takes effect only when the dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for train.py."""
import json
import os
import random
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.movinet import train as train_lib
FLAGS = flags.FLAGS
class TrainTest(tf.test.TestCase):
def setUp(self):
super(TrainTest, self).setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
data_dir = os.path.join(self.get_temp_dir(), 'data')
tf.io.gfile.makedirs(data_dir)
self._data_path = os.path.join(data_dir, 'data.tfrecord')
# pylint: disable=g-complex-comprehension
examples = [
tfexample_utils.make_video_test_example(
image_shape=(32, 32, 3),
audio_shape=(20, 128),
label=random.randint(0, 100)) for _ in range(2)
]
# pylint: enable=g-complex-comprehension
tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
def test_train_and_evaluation_pipeline_runs(self):
saved_flag_values = flagsaver.save_flag_values()
train_lib.tfm_flags.define_flags()
FLAGS.mode = 'train'
FLAGS.model_dir = self._model_dir
FLAGS.experiment = 'movinet_kinetics600'
logging.info('Test pipeline correctness.')
num_frames = 4
# Test model training pipeline runs.
params_override = json.dumps({
'trainer': {
'train_steps': 2,
'validation_steps': 2,
},
'task': {
'train_data': {
'input_path': self._data_path,
'file_type': 'tfrecord',
'feature_shape': [num_frames, 32, 32, 3],
'global_batch_size': 2,
},
'validation_data': {
'input_path': self._data_path,
'file_type': 'tfrecord',
'global_batch_size': 2,
'feature_shape': [num_frames * 2, 32, 32, 3],
}
}
})
FLAGS.params_override = params_override
train_lib.main('unused_args')
# Test model evaluation pipeline runs on newly produced checkpoint.
FLAGS.mode = 'eval'
with train_lib.gin.unlock_config():
train_lib.main('unused_args')
flagsaver.restore_flag_values(saved_flag_values)
if __name__ == '__main__':
tf.test.main()