Commit dcf52aab authored by pkulzc's avatar pkulzc Committed by Sergio Guadarrama
Browse files

Merged commit includes the following changes: (#6062)

228203246  by Sergio Guadarrama:

    Add a write text graphdef option.

--
226110161  by Sergio Guadarrama:

    Add license to i3d/s3dg and tests.

--
226074013  by Sergio Guadarrama:

    Network definitions for I3D and S3D-G.

--
224394404  by Sergio Guadarrama:

    Add video model option for exported inference graphs.

--
224220779  by Sergio Guadarrama:

    Internal change

223589268  by Sergio Guadarrama:

    Internal change

PiperOrigin-RevId: 228203246
parent 8d5d9d19
...@@ -199,6 +199,7 @@ py_library( ...@@ -199,6 +199,7 @@ py_library(
":alexnet", ":alexnet",
":cifarnet", ":cifarnet",
":cyclegan", ":cyclegan",
":i3d",
":inception", ":inception",
":lenet", ":lenet",
":mobilenet", ":mobilenet",
...@@ -208,6 +209,7 @@ py_library( ...@@ -208,6 +209,7 @@ py_library(
":pnasnet", ":pnasnet",
":resnet_v1", ":resnet_v1",
":resnet_v2", ":resnet_v2",
":s3dg",
":vgg", ":vgg",
], ],
) )
...@@ -279,6 +281,38 @@ py_test( ...@@ -279,6 +281,38 @@ py_test(
], ],
) )
py_library(
name = "i3d",
srcs = ["nets/i3d.py"],
srcs_version = "PY2AND3",
deps = [
":i3d_utils",
":s3dg",
# "//tensorflow",
],
)
py_test(
name = "i3d_test",
size = "large",
srcs = ["nets/i3d_test.py"],
shard_count = 3,
srcs_version = "PY2AND3",
deps = [
":i3d",
# "//tensorflow",
],
)
py_library(
name = "i3d_utils",
srcs = ["nets/i3d_utils.py"],
srcs_version = "PY2AND3",
deps = [
# "//tensorflow",
],
)
py_library( py_library(
name = "inception", name = "inception",
srcs = ["nets/inception.py"], srcs = ["nets/inception.py"],
...@@ -653,6 +687,28 @@ py_test( ...@@ -653,6 +687,28 @@ py_test(
], ],
) )
py_library(
name = "s3dg",
srcs = ["nets/s3dg.py"],
srcs_version = "PY2AND3",
deps = [
":i3d_utils",
# "//tensorflow",
],
)
py_test(
name = "s3dg_test",
size = "large",
srcs = ["nets/s3dg_test.py"],
shard_count = 3,
srcs_version = "PY2AND3",
deps = [
":s3dg",
# "//tensorflow",
],
)
py_library( py_library(
name = "vgg", name = "vgg",
srcs = ["nets/vgg.py"], srcs = ["nets/vgg.py"],
...@@ -684,9 +740,9 @@ py_library( ...@@ -684,9 +740,9 @@ py_library(
py_test( py_test(
name = "nets_factory_test", name = "nets_factory_test",
size = "medium", size = "large",
srcs = ["nets/nets_factory_test.py"], srcs = ["nets/nets_factory_test.py"],
shard_count = 2, shard_count = 3,
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":nets_factory", ":nets_factory",
...@@ -709,7 +765,8 @@ py_library( ...@@ -709,7 +765,8 @@ py_library(
py_binary( py_binary(
name = "train_image_classifier", name = "train_image_classifier",
srcs = ["train_image_classifier.py"], srcs = ["train_image_classifier.py"],
paropts = ["--compress"], # WARNING: not supported in bazel; will be commented out by copybara.
# paropts = ["--compress"],
deps = [ deps = [
":train_image_classifier_lib", ":train_image_classifier_lib",
], ],
...@@ -737,7 +794,8 @@ py_binary( ...@@ -737,7 +794,8 @@ py_binary(
py_binary( py_binary(
name = "export_inference_graph", name = "export_inference_graph",
srcs = ["export_inference_graph.py"], srcs = ["export_inference_graph.py"],
paropts = ["--compress"], # WARNING: not supported in bazel; will be commented out by copybara.
# paropts = ["--compress"],
deps = [ deps = [
":dataset_factory", ":dataset_factory",
":nets_factory", ":nets_factory",
......
...@@ -55,6 +55,7 @@ bazel-bin/tensorflow/examples/label_image/label_image \ ...@@ -55,6 +55,7 @@ bazel-bin/tensorflow/examples/label_image/label_image \
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import tensorflow as tf import tensorflow as tf
...@@ -99,12 +100,25 @@ tf.app.flags.DEFINE_string( ...@@ -99,12 +100,25 @@ tf.app.flags.DEFINE_string(
tf.app.flags.DEFINE_bool( tf.app.flags.DEFINE_bool(
'quantize', False, 'whether to use quantized graph or not.') 'quantize', False, 'whether to use quantized graph or not.')
tf.app.flags.DEFINE_bool(
'is_video_model', False, 'whether to use 5-D inputs for video model.')
tf.app.flags.DEFINE_integer(
'num_frames', None,
'The number of frames to use. Only used if is_video_model is True.')
tf.app.flags.DEFINE_bool('write_text_graphdef', False,
'Whether to write a text version of graphdef.')
FLAGS = tf.app.flags.FLAGS FLAGS = tf.app.flags.FLAGS
def main(_): def main(_):
if not FLAGS.output_file: if not FLAGS.output_file:
raise ValueError('You must supply the path to save to with --output_file') raise ValueError('You must supply the path to save to with --output_file')
if FLAGS.is_video_model and not FLAGS.num_frames:
raise ValueError(
'Number of frames must be specified for video models with --num_frames')
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
with tf.Graph().as_default() as graph: with tf.Graph().as_default() as graph:
dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train', dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
...@@ -114,17 +128,28 @@ def main(_): ...@@ -114,17 +128,28 @@ def main(_):
num_classes=(dataset.num_classes - FLAGS.labels_offset), num_classes=(dataset.num_classes - FLAGS.labels_offset),
is_training=FLAGS.is_training) is_training=FLAGS.is_training)
image_size = FLAGS.image_size or network_fn.default_image_size image_size = FLAGS.image_size or network_fn.default_image_size
if FLAGS.is_video_model:
input_shape = [FLAGS.batch_size, FLAGS.num_frames,
image_size, image_size, 3]
else:
input_shape = [FLAGS.batch_size, image_size, image_size, 3]
placeholder = tf.placeholder(name='input', dtype=tf.float32, placeholder = tf.placeholder(name='input', dtype=tf.float32,
shape=[FLAGS.batch_size, image_size, shape=input_shape)
image_size, 3])
network_fn(placeholder) network_fn(placeholder)
if FLAGS.quantize: if FLAGS.quantize:
tf.contrib.quantize.create_eval_graph() tf.contrib.quantize.create_eval_graph()
graph_def = graph.as_graph_def() graph_def = graph.as_graph_def()
with gfile.GFile(FLAGS.output_file, 'wb') as f: if FLAGS.write_text_graphdef:
f.write(graph_def.SerializeToString()) tf.io.write_graph(
graph_def,
os.path.dirname(FLAGS.output_file),
os.path.basename(FLAGS.output_file),
as_text=True)
else:
with gfile.GFile(FLAGS.output_file, 'wb') as f:
f.write(graph_def.SerializeToString())
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -134,7 +134,6 @@ def cyclegan_generator_resnet(images, ...@@ -134,7 +134,6 @@ def cyclegan_generator_resnet(images,
num_filters=64, num_filters=64,
upsample_fn=cyclegan_upsample, upsample_fn=cyclegan_upsample,
kernel_size=3, kernel_size=3,
num_outputs=3,
tanh_linear_slope=0.0, tanh_linear_slope=0.0,
is_training=False): is_training=False):
"""Defines the cyclegan resnet network architecture. """Defines the cyclegan resnet network architecture.
...@@ -156,7 +155,6 @@ def cyclegan_generator_resnet(images, ...@@ -156,7 +155,6 @@ def cyclegan_generator_resnet(images,
upsample_fn: Upsampling function for the decoder part of the generator. upsample_fn: Upsampling function for the decoder part of the generator.
kernel_size: Size w or list/tuple [h, w] of the filter kernels for all inner kernel_size: Size w or list/tuple [h, w] of the filter kernels for all inner
layers. layers.
num_outputs: Number of output layers. Defaults to 3 for RGB.
tanh_linear_slope: Slope of the linear function to add to the tanh over the tanh_linear_slope: Slope of the linear function to add to the tanh over the
logits. logits.
is_training: Whether the network is created in training mode or inference is_training: Whether the network is created in training mode or inference
...@@ -182,6 +180,7 @@ def cyclegan_generator_resnet(images, ...@@ -182,6 +180,7 @@ def cyclegan_generator_resnet(images,
raise ValueError('The input height must be a multiple of 4.') raise ValueError('The input height must be a multiple of 4.')
if width and width % 4 != 0: if width and width % 4 != 0:
raise ValueError('The input width must be a multiple of 4.') raise ValueError('The input width must be a multiple of 4.')
num_outputs = input_size[3]
if not isinstance(kernel_size, (list, tuple)): if not isinstance(kernel_size, (list, tuple)):
kernel_size = [kernel_size, kernel_size] kernel_size = [kernel_size, kernel_size]
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition for Inflated 3D Inception V1 (I3D).
The network architecture is proposed by:
Joao Carreira and Andrew Zisserman,
Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset.
https://arxiv.org/abs/1705.07750
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import i3d_utils
from nets import s3dg
slim = tf.contrib.slim
trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
conv3d_spatiotemporal = i3d_utils.conv3d_spatiotemporal
def i3d_arg_scope(weight_decay=1e-7,
batch_norm_decay=0.999,
batch_norm_epsilon=0.001,
use_renorm=False,
separable_conv3d=False):
"""Defines default arg_scope for I3D.
Args:
weight_decay: The weight decay to use for regularizing the model.
batch_norm_decay: Decay for batch norm moving average.
batch_norm_epsilon: Small float added to variance to avoid dividing by zero
in batch norm.
use_renorm: Whether to use batch renormalization or not.
separable_conv3d: Whether to use separable 3d Convs.
Returns:
sc: An arg_scope to use for the models.
"""
batch_norm_params = {
# Decay for the moving averages.
'decay': batch_norm_decay,
# epsilon to prevent 0s in variance.
'epsilon': batch_norm_epsilon,
# Turns off fused batch norm.
'fused': False,
'renorm': use_renorm,
# collection containing the moving mean and moving variance.
'variables_collections': {
'beta': None,
'gamma': None,
'moving_mean': ['moving_vars'],
'moving_variance': ['moving_vars'],
}
}
with slim.arg_scope(
[slim.conv3d, conv3d_spatiotemporal],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params):
with slim.arg_scope(
[conv3d_spatiotemporal], separable=separable_conv3d) as sc:
return sc
def i3d_base(inputs, final_endpoint='Mixed_5c',
scope='InceptionV1'):
"""Defines the I3D base architecture.
Note that we use the names as defined in Inception V1 to facilitate checkpoint
conversion from an image-trained Inception V1 checkpoint to I3D checkpoint.
Args:
inputs: A 5-D float tensor of size [batch_size, num_frames, height, width,
channels].
final_endpoint: Specifies the endpoint to construct the network up to. It
can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c']
scope: Optional variable_scope.
Returns:
A dictionary from components of the network to the corresponding activation.
Raises:
ValueError: if final_endpoint is not set to one of the predefined values.
"""
return s3dg.s3dg_base(
inputs,
first_temporal_kernel_size=7,
temporal_conv_startat='Conv2d_2c_3x3',
gating_startat=None,
final_endpoint=final_endpoint,
min_depth=16,
depth_multiplier=1.0,
data_format='NDHWC',
scope=scope)
def i3d(inputs,
num_classes=1000,
dropout_keep_prob=0.8,
is_training=True,
prediction_fn=slim.softmax,
spatial_squeeze=True,
reuse=None,
scope='InceptionV1'):
"""Defines the I3D architecture.
The default image size used to train this network is 224x224.
Args:
inputs: A 5-D float tensor of size [batch_size, num_frames, height, width,
channels].
num_classes: number of predicted classes.
dropout_keep_prob: the percentage of activation values that are retained.
is_training: whether is training or not.
prediction_fn: a function to get predictions out of logits.
spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional variable_scope.
Returns:
logits: the pre-softmax activations, a tensor of size
[batch_size, num_classes]
end_points: a dictionary from components of the network to the corresponding
activation.
"""
# Final pooling and prediction
with tf.variable_scope(
scope, 'InceptionV1', [inputs, num_classes], reuse=reuse) as scope:
with slim.arg_scope(
[slim.batch_norm, slim.dropout], is_training=is_training):
net, end_points = i3d_base(inputs, scope=scope)
with tf.variable_scope('Logits'):
kernel_size = i3d_utils.reduced_kernel_size_3d(net, [2, 7, 7])
net = slim.avg_pool3d(
net, kernel_size, stride=1, scope='AvgPool_0a_7x7')
net = slim.dropout(net, dropout_keep_prob, scope='Dropout_0b')
logits = slim.conv3d(
net,
num_classes, [1, 1, 1],
activation_fn=None,
normalizer_fn=None,
scope='Conv2d_0c_1x1')
# Temporal average pooling.
logits = tf.reduce_mean(logits, axis=1)
if spatial_squeeze:
logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
end_points['Logits'] = logits
end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
return logits, end_points
i3d.default_image_size = 224
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for networks.i3d."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import i3d
class I3DTest(tf.test.TestCase):
def testBuildClassificationNetwork(self):
batch_size = 5
num_frames = 64
height, width = 224, 224
num_classes = 1000
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
logits, end_points = i3d.i3d(inputs, num_classes)
self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
self.assertListEqual(logits.get_shape().as_list(),
[batch_size, num_classes])
self.assertTrue('Predictions' in end_points)
self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
[batch_size, num_classes])
def testBuildBaseNetwork(self):
batch_size = 5
num_frames = 64
height, width = 224, 224
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
mixed_6c, end_points = i3d.i3d_base(inputs)
self.assertTrue(mixed_6c.op.name.startswith('InceptionV1/Mixed_5c'))
self.assertListEqual(mixed_6c.get_shape().as_list(),
[batch_size, 8, 7, 7, 1024])
expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b',
'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c',
'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2',
'Mixed_5b', 'Mixed_5c']
self.assertItemsEqual(end_points.keys(), expected_endpoints)
def testBuildOnlyUptoFinalEndpoint(self):
batch_size = 5
num_frames = 64
height, width = 224, 224
endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d',
'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b',
'Mixed_5c']
for index, endpoint in enumerate(endpoints):
with tf.Graph().as_default():
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
out_tensor, end_points = i3d.i3d_base(
inputs, final_endpoint=endpoint)
self.assertTrue(out_tensor.op.name.startswith(
'InceptionV1/' + endpoint))
self.assertItemsEqual(endpoints[:index+1], end_points)
def testBuildAndCheckAllEndPointsUptoMixed5c(self):
batch_size = 5
num_frames = 64
height, width = 224, 224
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
_, end_points = i3d.i3d_base(inputs,
final_endpoint='Mixed_5c')
endpoints_shapes = {'Conv2d_1a_7x7': [5, 32, 112, 112, 64],
'MaxPool_2a_3x3': [5, 32, 56, 56, 64],
'Conv2d_2b_1x1': [5, 32, 56, 56, 64],
'Conv2d_2c_3x3': [5, 32, 56, 56, 192],
'MaxPool_3a_3x3': [5, 32, 28, 28, 192],
'Mixed_3b': [5, 32, 28, 28, 256],
'Mixed_3c': [5, 32, 28, 28, 480],
'MaxPool_4a_3x3': [5, 16, 14, 14, 480],
'Mixed_4b': [5, 16, 14, 14, 512],
'Mixed_4c': [5, 16, 14, 14, 512],
'Mixed_4d': [5, 16, 14, 14, 512],
'Mixed_4e': [5, 16, 14, 14, 528],
'Mixed_4f': [5, 16, 14, 14, 832],
'MaxPool_5a_2x2': [5, 8, 7, 7, 832],
'Mixed_5b': [5, 8, 7, 7, 832],
'Mixed_5c': [5, 8, 7, 7, 1024]}
self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
for endpoint_name, expected_shape in endpoints_shapes.iteritems():
self.assertTrue(endpoint_name in end_points)
self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
expected_shape)
def testHalfSizeImages(self):
batch_size = 5
num_frames = 64
height, width = 112, 112
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
mixed_5c, _ = i3d.i3d_base(inputs)
self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c'))
self.assertListEqual(mixed_5c.get_shape().as_list(),
[batch_size, 8, 4, 4, 1024])
def testTenFrames(self):
batch_size = 5
num_frames = 10
height, width = 224, 224
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
mixed_5c, _ = i3d.i3d_base(inputs)
self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c'))
self.assertListEqual(mixed_5c.get_shape().as_list(),
[batch_size, 2, 7, 7, 1024])
def testEvaluation(self):
batch_size = 2
num_frames = 64
height, width = 224, 224
num_classes = 1000
eval_inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
logits, _ = i3d.i3d(eval_inputs, num_classes,
is_training=False)
predictions = tf.argmax(logits, 1)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(predictions)
self.assertEquals(output.shape, (batch_size,))
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for building I3D network models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
# Orignaly, add_arg_scope = slim.add_arg_scope and layers = slim, now switch to
# more update-to-date tf.contrib.* API.
add_arg_scope = tf.contrib.framework.add_arg_scope
layers = tf.contrib.layers
def center_initializer():
"""Centering Initializer for I3D.
This initializer allows identity mapping for temporal convolution at the
initialization, which is critical for a desired convergence behavior
for training a seprable I3D model.
The centering behavior of this initializer requires an odd-sized kernel,
typically set to 3.
Returns:
A weight initializer op used in temporal convolutional layers.
Raises:
ValueError: Input tensor data type has to be tf.float32.
ValueError: If input tensor is not a 5-D tensor.
ValueError: If input and output channel dimensions are different.
ValueError: If spatial kernel sizes are not 1.
ValueError: If temporal kernel size is even.
"""
def _initializer(shape, dtype=tf.float32, partition_info=None): # pylint: disable=unused-argument
"""Initializer op."""
if dtype != tf.float32 and dtype != tf.bfloat16:
raise ValueError(
'Input tensor data type has to be tf.float32 or tf.bfloat16.')
if len(shape) != 5:
raise ValueError('Input tensor has to be 5-D.')
if shape[3] != shape[4]:
raise ValueError('Input and output channel dimensions must be the same.')
if shape[1] != 1 or shape[2] != 1:
raise ValueError('Spatial kernel sizes must be 1 (pointwise conv).')
if shape[0] % 2 == 0:
raise ValueError('Temporal kernel size has to be odd.')
center_pos = int(shape[0] / 2)
init_mat = np.zeros(
[shape[0], shape[1], shape[2], shape[3], shape[4]], dtype=np.float32)
for i in range(0, shape[3]):
init_mat[center_pos, 0, 0, i, i] = 1.0
init_op = tf.constant(init_mat, dtype=dtype)
return init_op
return _initializer
@add_arg_scope
def conv3d_spatiotemporal(inputs,
num_outputs,
kernel_size,
stride=1,
padding='SAME',
activation_fn=None,
normalizer_fn=None,
normalizer_params=None,
weights_regularizer=None,
separable=False,
data_format='NDHWC',
scope=''):
"""A wrapper for conv3d to model spatiotemporal representations.
This allows switching between original 3D convolution and separable 3D
convolutions for spatial and temporal features respectively. On Kinetics,
seprable 3D convolutions yields better classification performance.
Args:
inputs: a 5-D tensor `[batch_size, depth, height, width, channels]`.
num_outputs: integer, the number of output filters.
kernel_size: a list of length 3
`[kernel_depth, kernel_height, kernel_width]` of the filters. Can be an
int if all values are the same.
stride: a list of length 3 `[stride_depth, stride_height, stride_width]`.
Can be an int if all strides are the same.
padding: one of `VALID` or `SAME`.
activation_fn: activation function.
normalizer_fn: normalization function to use instead of `biases`.
normalizer_params: dictionary of normalization function parameters.
weights_regularizer: Optional regularizer for the weights.
separable: If `True`, use separable spatiotemporal convolutions.
data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC".
The data format of the input and output data. With the default format
"NDHWC", the data is stored in the order of: [batch, in_depth, in_height,
in_width, in_channels]. Alternatively, the format could be "NCDHW", the
data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
scope: scope for `variable_scope`.
Returns:
A tensor representing the output of the (separable) conv3d operation.
"""
assert len(kernel_size) == 3
if separable and kernel_size[0] != 1:
spatial_kernel_size = [1, kernel_size[1], kernel_size[2]]
temporal_kernel_size = [kernel_size[0], 1, 1]
if isinstance(stride, list) and len(stride) == 3:
spatial_stride = [1, stride[1], stride[2]]
temporal_stride = [stride[0], 1, 1]
else:
spatial_stride = [1, stride, stride]
temporal_stride = [stride, 1, 1]
net = layers.conv3d(
inputs,
num_outputs,
spatial_kernel_size,
stride=spatial_stride,
padding=padding,
activation_fn=activation_fn,
normalizer_fn=normalizer_fn,
normalizer_params=normalizer_params,
weights_regularizer=weights_regularizer,
data_format=data_format,
scope=scope)
net = layers.conv3d(
net,
num_outputs,
temporal_kernel_size,
stride=temporal_stride,
padding=padding,
scope=scope + '/temporal',
activation_fn=activation_fn,
normalizer_fn=None,
data_format=data_format,
weights_initializer=center_initializer())
return net
else:
return layers.conv3d(
inputs,
num_outputs,
kernel_size,
stride=stride,
padding=padding,
activation_fn=activation_fn,
normalizer_fn=normalizer_fn,
normalizer_params=normalizer_params,
weights_regularizer=weights_regularizer,
data_format=data_format,
scope=scope)
@add_arg_scope
def inception_block_v1_3d(inputs,
num_outputs_0_0a,
num_outputs_1_0a,
num_outputs_1_0b,
num_outputs_2_0a,
num_outputs_2_0b,
num_outputs_3_0b,
temporal_kernel_size=3,
self_gating_fn=None,
data_format='NDHWC',
scope=''):
"""A 3D Inception v1 block.
This allows use of separable 3D convolutions and self-gating, as
described in:
Saining Xie, Chen Sun, Jonathan Huang, Zhuowen Tu and Kevin Murphy,
Rethinking Spatiotemporal Feature Learning For Video Understanding.
https://arxiv.org/abs/1712.04851.
Args:
inputs: a 5-D tensor `[batch_size, depth, height, width, channels]`.
num_outputs_0_0a: integer, the number of output filters for Branch 0,
operation Conv2d_0a_1x1.
num_outputs_1_0a: integer, the number of output filters for Branch 1,
operation Conv2d_0a_1x1.
num_outputs_1_0b: integer, the number of output filters for Branch 1,
operation Conv2d_0b_3x3.
num_outputs_2_0a: integer, the number of output filters for Branch 2,
operation Conv2d_0a_1x1.
num_outputs_2_0b: integer, the number of output filters for Branch 2,
operation Conv2d_0b_3x3.
num_outputs_3_0b: integer, the number of output filters for Branch 3,
operation Conv2d_0b_1x1.
temporal_kernel_size: integer, the size of the temporal convolutional
filters in the conv3d_spatiotemporal blocks.
self_gating_fn: function which optionally performs self-gating.
Must have two arguments, `inputs` and `scope`, and return one output
tensor the same size as `inputs`. If `None`, no self-gating is
applied.
data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC".
The data format of the input and output data. With the default format
"NDHWC", the data is stored in the order of: [batch, in_depth, in_height,
in_width, in_channels]. Alternatively, the format could be "NCDHW", the
data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
scope: scope for `variable_scope`.
Returns:
A 5-D tensor `[batch_size, depth, height, width, out_channels]`, where
`out_channels = num_outputs_0_0a + num_outputs_1_0b + num_outputs_2_0b
+ num_outputs_3_0b`.
"""
use_gating = self_gating_fn is not None
with tf.variable_scope(scope):
with tf.variable_scope('Branch_0'):
branch_0 = layers.conv3d(
inputs, num_outputs_0_0a, [1, 1, 1], scope='Conv2d_0a_1x1')
if use_gating:
branch_0 = self_gating_fn(branch_0, scope='Conv2d_0a_1x1')
with tf.variable_scope('Branch_1'):
branch_1 = layers.conv3d(
inputs, num_outputs_1_0a, [1, 1, 1], scope='Conv2d_0a_1x1')
branch_1 = conv3d_spatiotemporal(
branch_1, num_outputs_1_0b, [temporal_kernel_size, 3, 3],
scope='Conv2d_0b_3x3')
if use_gating:
branch_1 = self_gating_fn(branch_1, scope='Conv2d_0b_3x3')
with tf.variable_scope('Branch_2'):
branch_2 = layers.conv3d(
inputs, num_outputs_2_0a, [1, 1, 1], scope='Conv2d_0a_1x1')
branch_2 = conv3d_spatiotemporal(
branch_2, num_outputs_2_0b, [temporal_kernel_size, 3, 3],
scope='Conv2d_0b_3x3')
if use_gating:
branch_2 = self_gating_fn(branch_2, scope='Conv2d_0b_3x3')
with tf.variable_scope('Branch_3'):
branch_3 = layers.max_pool3d(inputs, [3, 3, 3], scope='MaxPool_0a_3x3')
branch_3 = layers.conv3d(
branch_3, num_outputs_3_0b, [1, 1, 1], scope='Conv2d_0b_1x1')
if use_gating:
branch_3 = self_gating_fn(branch_3, scope='Conv2d_0b_1x1')
index_c = data_format.index('C')
assert 1 <= index_c <= 4, 'Cannot identify channel dimension.'
output = tf.concat([branch_0, branch_1, branch_2, branch_3], index_c)
return output
def reduced_kernel_size_3d(input_tensor, kernel_size):
"""Define kernel size which is automatically reduced for small input.
If the shape of the input images is unknown at graph construction time this
function assumes that the input images are large enough.
Args:
input_tensor: input tensor of size
[batch_size, time, height, width, channels].
kernel_size: desired kernel size of length 3, corresponding to time,
height and width.
Returns:
a tensor with the kernel size.
"""
assert len(kernel_size) == 3
shape = input_tensor.get_shape().as_list()
assert len(shape) == 5
if None in shape[1:4]:
kernel_size_out = kernel_size
else:
kernel_size_out = [min(shape[1], kernel_size[0]),
min(shape[2], kernel_size[1]),
min(shape[3], kernel_size[2])]
return kernel_size_out
...@@ -23,17 +23,20 @@ import tensorflow as tf ...@@ -23,17 +23,20 @@ import tensorflow as tf
from nets import alexnet from nets import alexnet
from nets import cifarnet from nets import cifarnet
from nets import i3d
from nets import inception from nets import inception
from nets import lenet from nets import lenet
from nets import mobilenet_v1 from nets import mobilenet_v1
from nets import overfeat from nets import overfeat
from nets import resnet_v1 from nets import resnet_v1
from nets import resnet_v2 from nets import resnet_v2
from nets import s3dg
from nets import vgg from nets import vgg
from nets.mobilenet import mobilenet_v2 from nets.mobilenet import mobilenet_v2
from nets.nasnet import nasnet from nets.nasnet import nasnet
from nets.nasnet import pnasnet from nets.nasnet import pnasnet
slim = tf.contrib.slim slim = tf.contrib.slim
networks_map = {'alexnet_v2': alexnet.alexnet_v2, networks_map = {'alexnet_v2': alexnet.alexnet_v2,
...@@ -47,6 +50,8 @@ networks_map = {'alexnet_v2': alexnet.alexnet_v2, ...@@ -47,6 +50,8 @@ networks_map = {'alexnet_v2': alexnet.alexnet_v2,
'inception_v3': inception.inception_v3, 'inception_v3': inception.inception_v3,
'inception_v4': inception.inception_v4, 'inception_v4': inception.inception_v4,
'inception_resnet_v2': inception.inception_resnet_v2, 'inception_resnet_v2': inception.inception_resnet_v2,
'i3d': i3d.i3d,
's3dg': s3dg.s3dg,
'lenet': lenet.lenet, 'lenet': lenet.lenet,
'resnet_v1_50': resnet_v1.resnet_v1_50, 'resnet_v1_50': resnet_v1.resnet_v1_50,
'resnet_v1_101': resnet_v1.resnet_v1_101, 'resnet_v1_101': resnet_v1.resnet_v1_101,
...@@ -82,6 +87,8 @@ arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, ...@@ -82,6 +87,8 @@ arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
'inception_v4': inception.inception_v4_arg_scope, 'inception_v4': inception.inception_v4_arg_scope,
'inception_resnet_v2': 'inception_resnet_v2':
inception.inception_resnet_v2_arg_scope, inception.inception_resnet_v2_arg_scope,
'i3d': i3d.i3d_arg_scope,
's3dg': s3dg.s3dg_arg_scope,
'lenet': lenet.lenet_arg_scope, 'lenet': lenet.lenet_arg_scope,
'resnet_v1_50': resnet_v1.resnet_arg_scope, 'resnet_v1_50': resnet_v1.resnet_arg_scope,
'resnet_v1_101': resnet_v1.resnet_arg_scope, 'resnet_v1_101': resnet_v1.resnet_arg_scope,
...@@ -144,7 +151,8 @@ def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): ...@@ -144,7 +151,8 @@ def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
def network_fn(images, **kwargs): def network_fn(images, **kwargs):
arg_scope = arg_scopes_map[name](weight_decay=weight_decay) arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
with slim.arg_scope(arg_scope): with slim.arg_scope(arg_scope):
return func(images, num_classes, is_training=is_training, **kwargs) return func(images, num_classes=num_classes, is_training=is_training,
**kwargs)
if hasattr(func, 'default_image_size'): if hasattr(func, 'default_image_size'):
network_fn.default_image_size = func.default_image_size network_fn.default_image_size = func.default_image_size
......
...@@ -32,25 +32,45 @@ class NetworksTest(tf.test.TestCase): ...@@ -32,25 +32,45 @@ class NetworksTest(tf.test.TestCase):
num_classes = 1000 num_classes = 1000
for net in list(nets_factory.networks_map.keys())[:10]: for net in list(nets_factory.networks_map.keys())[:10]:
with tf.Graph().as_default() as g, self.test_session(g): with tf.Graph().as_default() as g, self.test_session(g):
net_fn = nets_factory.get_network_fn(net, num_classes) net_fn = nets_factory.get_network_fn(net, num_classes=num_classes)
# Most networks use 224 as their default_image_size # Most networks use 224 as their default_image_size
image_size = getattr(net_fn, 'default_image_size', 224) image_size = getattr(net_fn, 'default_image_size', 224)
inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) if net not in ['i3d', 's3dg']:
logits, end_points = net_fn(inputs) inputs = tf.random_uniform(
self.assertTrue(isinstance(logits, tf.Tensor)) (batch_size, image_size, image_size, 3))
self.assertTrue(isinstance(end_points, dict)) logits, end_points = net_fn(inputs)
self.assertEqual(logits.get_shape().as_list()[0], batch_size) self.assertTrue(isinstance(logits, tf.Tensor))
self.assertEqual(logits.get_shape().as_list()[-1], num_classes) self.assertTrue(isinstance(end_points, dict))
self.assertEqual(logits.get_shape().as_list()[0], batch_size)
self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
def testGetNetworkFnSecondHalf(self): def testGetNetworkFnSecondHalf(self):
batch_size = 5 batch_size = 5
num_classes = 1000 num_classes = 1000
for net in list(nets_factory.networks_map.keys())[10:]: for net in list(nets_factory.networks_map.keys())[10:]:
with tf.Graph().as_default() as g, self.test_session(g): with tf.Graph().as_default() as g, self.test_session(g):
net_fn = nets_factory.get_network_fn(net, num_classes) net_fn = nets_factory.get_network_fn(net, num_classes=num_classes)
# Most networks use 224 as their default_image_size # Most networks use 224 as their default_image_size
image_size = getattr(net_fn, 'default_image_size', 224) image_size = getattr(net_fn, 'default_image_size', 224)
inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) if net not in ['i3d', 's3dg']:
inputs = tf.random_uniform(
(batch_size, image_size, image_size, 3))
logits, end_points = net_fn(inputs)
self.assertTrue(isinstance(logits, tf.Tensor))
self.assertTrue(isinstance(end_points, dict))
self.assertEqual(logits.get_shape().as_list()[0], batch_size)
self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
def testGetNetworkFnVideoModels(self):
batch_size = 5
num_classes = 400
for net in ['i3d', 's3dg']:
with tf.Graph().as_default() as g, self.test_session(g):
net_fn = nets_factory.get_network_fn(net, num_classes=num_classes)
# Most networks use 224 as their default_image_size
image_size = getattr(net_fn, 'default_image_size', 224) // 2
inputs = tf.random_uniform(
(batch_size, 10, image_size, image_size, 3))
logits, end_points = net_fn(inputs) logits, end_points = net_fn(inputs)
self.assertTrue(isinstance(logits, tf.Tensor)) self.assertTrue(isinstance(logits, tf.Tensor))
self.assertTrue(isinstance(end_points, dict)) self.assertTrue(isinstance(end_points, dict))
......
This diff is collapsed.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for networks.s3dg."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import s3dg
class S3DGTest(tf.test.TestCase):
def testBuildClassificationNetwork(self):
batch_size = 5
num_frames = 64
height, width = 224, 224
num_classes = 1000
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
logits, end_points = s3dg.s3dg(inputs, num_classes)
self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
self.assertListEqual(logits.get_shape().as_list(),
[batch_size, num_classes])
self.assertTrue('Predictions' in end_points)
self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
[batch_size, num_classes])
def testBuildBaseNetwork(self):
batch_size = 5
num_frames = 64
height, width = 224, 224
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
mixed_6c, end_points = s3dg.s3dg_base(inputs)
self.assertTrue(mixed_6c.op.name.startswith('InceptionV1/Mixed_5c'))
self.assertListEqual(mixed_6c.get_shape().as_list(),
[batch_size, 8, 7, 7, 1024])
expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b',
'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c',
'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2',
'Mixed_5b', 'Mixed_5c']
self.assertItemsEqual(end_points.keys(), expected_endpoints)
def testBuildOnlyUptoFinalEndpointNoGating(self):
batch_size = 5
num_frames = 64
height, width = 224, 224
endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d',
'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b',
'Mixed_5c']
for index, endpoint in enumerate(endpoints):
with tf.Graph().as_default():
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
out_tensor, end_points = s3dg.s3dg_base(
inputs, final_endpoint=endpoint, gating_startat=None)
print(endpoint, out_tensor.op.name)
self.assertTrue(out_tensor.op.name.startswith(
'InceptionV1/' + endpoint))
self.assertItemsEqual(endpoints[:index+1], end_points)
def testBuildAndCheckAllEndPointsUptoMixed5c(self):
batch_size = 5
num_frames = 64
height, width = 224, 224
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
_, end_points = s3dg.s3dg_base(inputs,
final_endpoint='Mixed_5c')
endpoints_shapes = {'Conv2d_1a_7x7': [5, 32, 112, 112, 64],
'MaxPool_2a_3x3': [5, 32, 56, 56, 64],
'Conv2d_2b_1x1': [5, 32, 56, 56, 64],
'Conv2d_2c_3x3': [5, 32, 56, 56, 192],
'MaxPool_3a_3x3': [5, 32, 28, 28, 192],
'Mixed_3b': [5, 32, 28, 28, 256],
'Mixed_3c': [5, 32, 28, 28, 480],
'MaxPool_4a_3x3': [5, 16, 14, 14, 480],
'Mixed_4b': [5, 16, 14, 14, 512],
'Mixed_4c': [5, 16, 14, 14, 512],
'Mixed_4d': [5, 16, 14, 14, 512],
'Mixed_4e': [5, 16, 14, 14, 528],
'Mixed_4f': [5, 16, 14, 14, 832],
'MaxPool_5a_2x2': [5, 8, 7, 7, 832],
'Mixed_5b': [5, 8, 7, 7, 832],
'Mixed_5c': [5, 8, 7, 7, 1024]}
self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
for endpoint_name, expected_shape in endpoints_shapes.iteritems():
self.assertTrue(endpoint_name in end_points)
self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
expected_shape)
def testHalfSizeImages(self):
batch_size = 5
num_frames = 64
height, width = 112, 112
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
mixed_5c, _ = s3dg.s3dg_base(inputs)
self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c'))
self.assertListEqual(mixed_5c.get_shape().as_list(),
[batch_size, 8, 4, 4, 1024])
def testTenFrames(self):
batch_size = 5
num_frames = 10
height, width = 224, 224
inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
mixed_5c, _ = s3dg.s3dg_base(inputs)
self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c'))
self.assertListEqual(mixed_5c.get_shape().as_list(),
[batch_size, 2, 7, 7, 1024])
def testEvaluation(self):
batch_size = 2
num_frames = 64
height, width = 224, 224
num_classes = 1000
eval_inputs = tf.random_uniform((batch_size, num_frames, height, width, 3))
logits, _ = s3dg.s3dg(eval_inputs, num_classes,
is_training=False)
predictions = tf.argmax(logits, 1)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(predictions)
self.assertEquals(output.shape, (batch_size,))
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment