Commit c3d18895 authored by Jianmin Chen

Merge pull request #40 from jmchen-g/master

update inception slim.
parents d2c7a37b 1a8c7121
......@@ -101,3 +101,12 @@ py_library(
":variables",
],
)
py_test(
name = "collections_test",
size = "small",
srcs = ["collections_test.py"],
deps = [
":slim",
],
)
......@@ -43,7 +43,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from inception.slim import ops
......@@ -98,10 +97,10 @@ def inception_v3(inputs,
# 73 x 73 x 64
end_points['conv3'] = ops.conv2d(end_points['pool1'], 80, [1, 1],
scope='conv3')
# 71 x 71 x 80.
# 73 x 73 x 80.
end_points['conv4'] = ops.conv2d(end_points['conv3'], 192, [3, 3],
scope='conv4')
# 69 x 69 x 192.
# 71 x 71 x 192.
end_points['pool2'] = ops.max_pool(end_points['conv4'], [3, 3],
stride=2, scope='pool2')
# 35 x 35 x 192.
......@@ -260,7 +259,10 @@ def inception_v3(inputs,
aux_logits = ops.fc(aux_logits, num_classes, activation=None,
stddev=0.001, restore=restore_logits)
end_points['aux_logits'] = aux_logits
# mixed_8: 17 x 17 x 1280.
# mixed_8: 8 x 8 x 1280.
      # Note that the scope below is not changed, so as not to invalidate
      # previous checkpoints.
      # TODO: Fix the scope when appropriate.
with tf.variable_scope('mixed_17x17x1280a'):
with tf.variable_scope('branch3x3'):
branch3x3 = ops.conv2d(net, 192, [1, 1])
......@@ -327,3 +329,28 @@ def inception_v3(inputs,
end_points['predictions'] = tf.nn.softmax(logits, name='predictions')
return logits, end_points
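For orientation, a minimal sketch of how this entry point is driven (mirroring the tests later in this diff; batch size, image size and num_classes below are illustrative, not mandated by the change):
import tensorflow as tf
from inception.slim import inception_model as inception

inputs = tf.random_uniform((5, 299, 299, 3))
logits, end_points = inception.inception_v3(inputs, num_classes=1000)
# end_points maps named activations, e.g. end_points['predictions'] holds the softmax output.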
def inception_v3_parameters(weight_decay=0.00004, stddev=0.1,
batch_norm_decay=0.9997, batch_norm_epsilon=0.001):
"""Yields the scope with the default parameters for inception_v3.
Args:
weight_decay: the weight decay for weights variables.
    stddev: standard deviation of the truncated gaussian weight distribution.
batch_norm_decay: decay for the moving average of batch_norm momentums.
batch_norm_epsilon: small float added to variance to avoid dividing by zero.
Yields:
    an arg_scope with the parameters needed for inception_v3.
"""
# Set weight_decay for weights in Conv and FC layers.
with scopes.arg_scope([ops.conv2d, ops.fc],
weight_decay=weight_decay):
# Set stddev, activation and parameters for batch_norm.
with scopes.arg_scope([ops.conv2d],
stddev=stddev,
activation=tf.nn.relu,
batch_norm_params={
'decay': batch_norm_decay,
'epsilon': batch_norm_epsilon}) as arg_scope:
yield arg_scope
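A hedged sketch of what the yielded scope provides: inside it, conv2d and fc pick up the weight decay, and conv2d additionally gets the stddev, activation and batch_norm defaults, so per-call arguments can be omitted. The nested with-blocks simply mirror the function body; the input tensor is illustrative:
import tensorflow as tf
from inception.slim import ops
from inception.slim import scopes

inputs = tf.random_uniform((5, 299, 299, 3))
with scopes.arg_scope([ops.conv2d, ops.fc], weight_decay=0.00004):
  with scopes.arg_scope([ops.conv2d], stddev=0.1, activation=tf.nn.relu,
                        batch_norm_params={'decay': 0.9997, 'epsilon': 0.001}):
    # This conv2d inherits the defaults above without restating them.
    net = ops.conv2d(inputs, 32, [3, 3], stride=2, scope='conv0')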
......@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from inception.slim import inception_model as inception
......@@ -55,6 +54,22 @@ class InceptionTest(tf.test.TestCase):
self.assertListEqual(pre_pool.get_shape().as_list(),
[batch_size, 8, 8, 2048])
def testVariablesSetDevice(self):
batch_size = 5
height, width = 299, 299
num_classes = 1000
with self.test_session():
inputs = tf.random_uniform((batch_size, height, width, 3))
# Force all Variables to reside on the device.
with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
inception.inception_v3(inputs, num_classes)
with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
inception.inception_v3(inputs, num_classes)
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
self.assertDeviceEqual(v.device, '/cpu:0')
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
self.assertDeviceEqual(v.device, '/gpu:0')
def testHalfSizeImages(self):
batch_size = 5
height, width = 150, 150
......
......@@ -26,7 +26,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
# In order to gather all losses in a network, the user should use this
......@@ -35,6 +34,71 @@ import tensorflow as tf
LOSSES_COLLECTION = '_losses'
def l1_regularizer(weight=1.0, scope=None):
"""Define a L1 regularizer.
Args:
weight: scale the loss by this factor.
scope: Optional scope for op_scope.
Returns:
a regularizer function.
"""
def regularizer(tensor):
with tf.op_scope([tensor], scope, 'L1Regularizer'):
l1_weight = tf.convert_to_tensor(weight,
dtype=tensor.dtype.base_dtype,
name='weight')
return tf.mul(l1_weight, tf.reduce_sum(tf.abs(tensor)), name='value')
return regularizer
def l2_regularizer(weight=1.0, scope=None):
"""Define a L2 regularizer.
Args:
weight: scale the loss by this factor.
scope: Optional scope for op_scope.
Returns:
a regularizer function.
"""
def regularizer(tensor):
with tf.op_scope([tensor], scope, 'L2Regularizer'):
l2_weight = tf.convert_to_tensor(weight,
dtype=tensor.dtype.base_dtype,
name='weight')
return tf.mul(l2_weight, tf.nn.l2_loss(tensor), name='value')
return regularizer
def l1_l2_regularizer(weight_l1=1.0, weight_l2=1.0, scope=None):
"""Define a L1L2 regularizer.
Args:
weight_l1: scale the L1 loss by this factor.
weight_l2: scale the L2 loss by this factor.
scope: Optional scope for op_scope.
Returns:
a regularizer function.
"""
def regularizer(tensor):
with tf.op_scope([tensor], scope, 'L1L2Regularizer'):
weight_l1_t = tf.convert_to_tensor(weight_l1,
dtype=tensor.dtype.base_dtype,
name='weight_l1')
weight_l2_t = tf.convert_to_tensor(weight_l2,
dtype=tensor.dtype.base_dtype,
name='weight_l2')
reg_l1 = tf.mul(weight_l1_t, tf.reduce_sum(tf.abs(tensor)),
name='value_l1')
reg_l2 = tf.mul(weight_l2_t, tf.nn.l2_loss(tensor),
name='value_l2')
return tf.add(reg_l1, reg_l2, name='value')
return regularizer
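A small usage sketch of the regularizer factories (the constant weights and the factor values are illustrative; the expected number follows the RegularizersTest cases further down):
import tensorflow as tf
from inception.slim import losses

weights = tf.constant(1.0, shape=[5, 5, 5])        # 125 elements, all 1.0
reg_fn = losses.l1_l2_regularizer(weight_l1=0.01, weight_l2=0.05)
reg_loss = reg_fn(weights)                         # 0.01 * sum|w| + 0.05 * l2_loss(w)
with tf.Session() as sess:
  print(sess.run(reg_loss))                        # 1.25 + 3.125 = 4.375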
def l1_loss(tensor, weight=1.0, scope=None):
"""Define a L1Loss, useful for regularize, i.e. lasso.
......
......@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
from inception.slim import losses
......@@ -47,6 +46,95 @@ class LossesTest(tf.test.TestCase):
self.assertAlmostEqual(loss.eval(), num_elem * wd / 2, 5)
class RegularizersTest(tf.test.TestCase):
def testL1Regularizer(self):
with self.test_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = tf.constant(1.0, shape=shape)
loss = losses.l1_regularizer()(tensor)
self.assertEquals(loss.op.name, 'L1Regularizer/value')
self.assertAlmostEqual(loss.eval(), num_elem, 5)
def testL1RegularizerWithScope(self):
with self.test_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = tf.constant(1.0, shape=shape)
loss = losses.l1_regularizer(scope='L1')(tensor)
self.assertEquals(loss.op.name, 'L1/value')
self.assertAlmostEqual(loss.eval(), num_elem, 5)
def testL1RegularizerWithWeight(self):
with self.test_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = tf.constant(1.0, shape=shape)
weight = 0.01
loss = losses.l1_regularizer(weight)(tensor)
self.assertEquals(loss.op.name, 'L1Regularizer/value')
self.assertAlmostEqual(loss.eval(), num_elem * weight, 5)
def testL2Regularizer(self):
with self.test_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = tf.constant(1.0, shape=shape)
loss = losses.l2_regularizer()(tensor)
self.assertEquals(loss.op.name, 'L2Regularizer/value')
self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
def testL2RegularizerWithScope(self):
with self.test_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = tf.constant(1.0, shape=shape)
loss = losses.l2_regularizer(scope='L2')(tensor)
self.assertEquals(loss.op.name, 'L2/value')
self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
def testL2RegularizerWithWeight(self):
with self.test_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = tf.constant(1.0, shape=shape)
weight = 0.01
loss = losses.l2_regularizer(weight)(tensor)
self.assertEquals(loss.op.name, 'L2Regularizer/value')
self.assertAlmostEqual(loss.eval(), num_elem * weight / 2, 5)
def testL1L2Regularizer(self):
with self.test_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = tf.constant(1.0, shape=shape)
loss = losses.l1_l2_regularizer()(tensor)
self.assertEquals(loss.op.name, 'L1L2Regularizer/value')
self.assertAlmostEqual(loss.eval(), num_elem + num_elem / 2, 5)
def testL1L2RegularizerWithScope(self):
with self.test_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = tf.constant(1.0, shape=shape)
loss = losses.l1_l2_regularizer(scope='L1L2')(tensor)
self.assertEquals(loss.op.name, 'L1L2/value')
self.assertAlmostEqual(loss.eval(), num_elem + num_elem / 2, 5)
def testL1L2RegularizerWithWeights(self):
with self.test_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = tf.constant(1.0, shape=shape)
weight_l1 = 0.01
weight_l2 = 0.05
loss = losses.l1_l2_regularizer(weight_l1, weight_l2)(tensor)
self.assertEquals(loss.op.name, 'L1L2Regularizer/value')
self.assertAlmostEqual(loss.eval(),
num_elem * weight_l1 + num_elem * weight_l2 / 2, 5)
class CrossEntropyLossTest(tf.test.TestCase):
def testCrossEntropyLossAllCorrect(self):
......
......@@ -27,7 +27,6 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.training import moving_averages
......@@ -50,7 +49,8 @@ def batch_norm(inputs,
is_training=True,
trainable=True,
restore=True,
scope=None):
scope=None,
reuse=None):
"""Adds a Batch Normalization layer.
Args:
......@@ -67,13 +67,15 @@ def batch_norm(inputs,
trainable: whether or not the variables should be trainable or not.
restore: whether or not the variables should be marked for restore.
scope: Optional scope for variable_op_scope.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer, scope must be given.
Returns:
a tensor representing the output of the operation.
"""
inputs_shape = inputs.get_shape()
with tf.variable_op_scope([inputs], scope, 'BatchNorm'):
with tf.variable_op_scope([inputs], scope, 'BatchNorm', reuse=reuse):
axis = range(len(inputs_shape) - 1)
params_shape = inputs_shape[-1:]
with scopes.arg_scope([variables.variable], restore=restore):
......@@ -124,6 +126,37 @@ def batch_norm(inputs,
return outputs
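The new reuse flag lets a second call share the variables of an earlier call with the same scope; a minimal sketch (input shapes are illustrative):
import tensorflow as tf
from inception.slim import ops

images_a = tf.random_uniform((5, 3, 3, 8), seed=1)
images_b = tf.random_uniform((5, 3, 3, 8), seed=2)
out_a = ops.batch_norm(images_a, scope='bn')
out_b = ops.batch_norm(images_b, scope='bn', reuse=True)  # reuses beta and the moving statistics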
def _two_element_tuple(int_or_tuple):
"""Converts `int_or_tuple` to height, width.
Several of the functions that follow accept arguments as either
a tuple of 2 integers or a single integer. A single integer
indicates that the 2 values of the tuple are the same.
  This function normalizes the input value by always returning a tuple.
Args:
int_or_tuple: A list of 2 ints, a single int or a tf.TensorShape.
Returns:
A tuple with 2 values.
Raises:
    ValueError: If `int_or_tuple` is not well formed.
"""
if isinstance(int_or_tuple, (list, tuple)):
if len(int_or_tuple) != 2:
raise ValueError('Must be a list with 2 elements: %s' % int_or_tuple)
return int(int_or_tuple[0]), int(int_or_tuple[1])
if isinstance(int_or_tuple, int):
return int(int_or_tuple), int(int_or_tuple)
if isinstance(int_or_tuple, tf.TensorShape):
if len(int_or_tuple) == 2:
return int_or_tuple[0], int_or_tuple[1]
raise ValueError('Must be an int, a list with 2 elements or a TensorShape of '
'length 2')
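The accepted forms, illustrated (the helper is module-private, so this would be exercised from within ops.py; in the TensorShape case the returned values are Dimension objects, which behave like ints here):
_two_element_tuple(3)                        # (3, 3)
_two_element_tuple([3, 1])                   # (3, 1)
_two_element_tuple(tf.TensorShape([7, 7]))   # the two dimensions, i.e. (7, 7)
_two_element_tuple((1, 2, 3))                # raises ValueError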
@scopes.add_arg_scope
def conv2d(inputs,
num_filters_out,
......@@ -138,7 +171,8 @@ def conv2d(inputs,
is_training=True,
trainable=True,
restore=True,
scope=None):
scope=None,
reuse=None):
"""Adds a 2D convolution followed by an optional batch_norm layer.
conv2d creates a variable called 'weights', representing the convolutional
......@@ -149,8 +183,11 @@ def conv2d(inputs,
Args:
inputs: a tensor of size [batch_size, height, width, channels].
num_filters_out: the number of output filters.
kernel_size: a 2-D list comprising of the height and width of the filters.
stride: the stride in height and width of the convolution.
    kernel_size: a list of length 2: [kernel_height, kernel_width] of
      the filters. Can be an int if both values are the same.
stride: a list of length 2: [stride_height, stride_width].
Can be an int if both strides are the same. Note that presently
both strides must have the same value.
padding: one of 'VALID' or 'SAME'.
activation: activation function.
    stddev: standard deviation of the truncated gaussian weight distribution.
......@@ -161,28 +198,29 @@ def conv2d(inputs,
trainable: whether or not the variables should be trainable or not.
restore: whether or not the variables should be marked for restore.
scope: Optional scope for variable_op_scope.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer, scope must be given.
Returns:
a tensor representing the output of the operation.
Raises:
ValueError: if 'kernel_size' is not a 2-D list.
"""
if len(kernel_size) != 2:
raise ValueError('kernel_size must be a 2-D list.')
with tf.variable_op_scope([inputs], scope, 'Conv'):
with tf.variable_op_scope([inputs], scope, 'Conv', reuse=reuse):
kernel_h, kernel_w = _two_element_tuple(kernel_size)
stride_h, stride_w = _two_element_tuple(stride)
num_filters_in = inputs.get_shape()[-1]
weights_shape = [kernel_size[0], kernel_size[1],
weights_shape = [kernel_h, kernel_w,
num_filters_in, num_filters_out]
weights_initializer = tf.truncated_normal_initializer(stddev=stddev)
l2_regularizer = lambda t: losses.l2_loss(t, weight_decay)
l2_regularizer = None
if weight_decay and weight_decay > 0:
l2_regularizer = losses.l2_regularizer(weight_decay)
weights = variables.variable('weights',
shape=weights_shape,
initializer=weights_initializer,
regularizer=l2_regularizer,
trainable=trainable,
restore=restore)
conv = tf.nn.conv2d(inputs, weights, [1, stride, stride, 1],
conv = tf.nn.conv2d(inputs, weights, [1, stride_h, stride_w, 1],
padding=padding)
if batch_norm_params is not None:
with scopes.arg_scope([batch_norm], is_training=is_training,
......@@ -213,7 +251,8 @@ def fc(inputs,
is_training=True,
trainable=True,
restore=True,
scope=None):
scope=None,
reuse=None):
"""Adds a fully connected layer followed by an optional batch_norm layer.
FC creates a variable called 'weights', representing the fully connected
......@@ -234,15 +273,19 @@ def fc(inputs,
trainable: whether or not the variables should be trainable or not.
restore: whether or not the variables should be marked for restore.
scope: Optional scope for variable_op_scope.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer, scope must be given.
Returns:
the tensor variable representing the result of the series of operations.
"""
with tf.variable_op_scope([inputs], scope, 'FC'):
with tf.variable_op_scope([inputs], scope, 'FC', reuse=reuse):
num_units_in = inputs.get_shape()[1]
weights_shape = [num_units_in, num_units_out]
weights_initializer = tf.truncated_normal_initializer(stddev=stddev)
l2_regularizer = lambda t: losses.l2_loss(t, weight_decay)
l2_regularizer = None
if weight_decay and weight_decay > 0:
l2_regularizer = losses.l2_regularizer(weight_decay)
weights = variables.variable('weights',
shape=weights_shape,
initializer=weights_initializer,
......@@ -298,8 +341,12 @@ def max_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
Args:
inputs: a tensor of size [batch_size, height, width, depth].
kernel_size: the size of the pooling kernel over which the op is computed.
stride: the stride in height and width of the convolution.
kernel_size: a list of length 2: [kernel_height, kernel_width] of the
pooling kernel over which the op is computed. Can be an int if both
values are the same.
stride: a list of length 2: [stride_height, stride_width].
Can be an int if both strides are the same. Note that presently
both strides must have the same value.
padding: the padding method, either 'VALID' or 'SAME'.
scope: Optional scope for op_scope.
......@@ -308,12 +355,12 @@ def max_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
Raises:
ValueError: if 'kernel_size' is not a 2-D list
"""
if len(kernel_size) != 2:
raise ValueError('kernel_size must be a 2-D list.')
with tf.op_scope([inputs], scope, 'MaxPool'):
kernel_h, kernel_w = _two_element_tuple(kernel_size)
stride_h, stride_w = _two_element_tuple(stride)
return tf.nn.max_pool(inputs,
ksize=[1, kernel_size[0], kernel_size[1], 1],
strides=[1, stride, stride, 1],
ksize=[1, kernel_h, kernel_w, 1],
strides=[1, stride_h, stride_w, 1],
padding=padding)
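With kernel_size now normalized through _two_element_tuple, an int or a TensorShape works as well as a 2-element list; a brief sketch (shapes are illustrative, matching the MaxPoolTest cases below):
import tensorflow as tf
from inception.slim import ops

images = tf.random_uniform((5, 6, 6, 3), seed=1)
net = ops.max_pool(images, 3, stride=2)                        # square 3x3 kernel from an int
gap = ops.max_pool(images, images.get_shape()[1:3], stride=1)  # "global" pool over the spatial dims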
......@@ -326,22 +373,24 @@ def avg_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
Args:
inputs: a tensor of size [batch_size, height, width, depth].
kernel_size: the size of the pooling kernel over which the op is computed.
stride: the stride in height and width of the convolution.
kernel_size: a list of length 2: [kernel_height, kernel_width] of the
pooling kernel over which the op is computed. Can be an int if both
values are the same.
stride: a list of length 2: [stride_height, stride_width].
Can be an int if both strides are the same. Note that presently
both strides must have the same value.
padding: the padding method, either 'VALID' or 'SAME'.
scope: Optional scope for op_scope.
Returns:
a tensor representing the results of the pooling operation.
Raises:
ValueError: if 'kernel_size' is not a 2-D list
"""
if len(kernel_size) != 2:
raise ValueError('kernel_size must be a 2-D list.')
with tf.op_scope([inputs], scope, 'AvgPool'):
kernel_h, kernel_w = _two_element_tuple(kernel_size)
stride_h, stride_w = _two_element_tuple(stride)
return tf.nn.avg_pool(inputs,
ksize=[1, kernel_size[0], kernel_size[1], 1],
strides=[1, stride, stride, 1],
ksize=[1, kernel_h, kernel_w, 1],
strides=[1, stride_h, stride_w, 1],
padding=padding)
......
......@@ -18,13 +18,11 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from inception.slim import losses
from inception.slim import ops
from inception.slim import scopes
from inception.slim import variables
......@@ -40,6 +38,57 @@ class ConvTest(tf.test.TestCase):
self.assertEquals(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
def testCreateSquareConv(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = ops.conv2d(images, 32, 3)
self.assertEquals(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
def testCreateConvWithTensorShape(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = ops.conv2d(images, 32, images.get_shape()[1:3])
self.assertEquals(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
def testCreateFullyConv(self):
height, width = 6, 6
with self.test_session():
images = tf.random_uniform((5, height, width, 32), seed=1)
output = ops.conv2d(images, 64, images.get_shape()[1:3], padding='VALID')
self.assertEquals(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 64])
def testCreateVerticalConv(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = ops.conv2d(images, 32, [3, 1])
self.assertEquals(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(),
[5, height, width, 32])
def testCreateHorizontalConv(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = ops.conv2d(images, 32, [1, 3])
self.assertEquals(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(),
[5, height, width, 32])
def testCreateConvWithStride(self):
height, width = 6, 6
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = ops.conv2d(images, 32, [3, 3], stride=2)
self.assertEquals(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(),
[5, height/2, width/2, 32])
def testCreateConvCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
images = tf.random_uniform((5, height, width, 3), seed=1)
......@@ -76,31 +125,73 @@ class ConvTest(tf.test.TestCase):
with self.test_session() as sess:
images = tf.random_uniform((5, height, width, 3), seed=1)
ops.conv2d(images, 32, [3, 3], weight_decay=0.01)
wd = tf.get_collection(losses.LOSSES_COLLECTION)[0]
self.assertEquals(wd.op.name, 'Conv/weights/Regularizer/L2Loss/value')
wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertEquals(wd.op.name,
'Conv/weights/Regularizer/L2Regularizer/value')
sess.run(tf.initialize_all_variables())
self.assertTrue(sess.run(wd) <= 0.01)
def testCreateConvWithoutWD(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
ops.conv2d(images, 32, [3, 3], weight_decay=0)
self.assertEquals(
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])
def testReuseVars(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
ops.conv2d(images, 32, [3, 3], scope='conv1')
self.assertEquals(len(variables.get_variables()), 2)
ops.conv2d(images, 32, [3, 3], scope='conv1', reuse=True)
self.assertEquals(len(variables.get_variables()), 2)
def testNonReuseVars(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
ops.conv2d(images, 32, [3, 3])
self.assertEquals(len(variables.get_variables()), 2)
ops.conv2d(images, 32, [3, 3])
self.assertEquals(len(variables.get_variables()), 4)
def testReuseConvWithWD(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
tf.get_variable_scope().reuse_variables()
ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
self.assertEquals(len(variables.get_variables()), 2)
self.assertEquals(
len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1',
reuse=True)
self.assertEquals(len(variables.get_variables()), 2)
self.assertEquals(
len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
def testConvWithBatchNorm(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
with scopes.arg_scope([ops.conv2d], batch_norm_params={}):
net = ops.conv2d(images, 32, [3, 3], scope='conv1')
net = ops.conv2d(net, 32, [3, 3], scope='conv2')
self.assertEquals(len(tf.get_collection('moving_vars')), 4)
self.assertEquals(len(variables.get_variables('conv1/BatchNorm')), 3)
self.assertEquals(len(variables.get_variables('conv2/BatchNorm')), 3)
images = tf.random_uniform((5, height, width, 32), seed=1)
with scopes.arg_scope([ops.conv2d], batch_norm_params={'decay': 0.9}):
net = ops.conv2d(images, 32, [3, 3])
net = ops.conv2d(net, 32, [3, 3])
self.assertEquals(len(variables.get_variables()), 8)
self.assertEquals(len(variables.get_variables('Conv/BatchNorm')), 3)
self.assertEquals(len(variables.get_variables('Conv_1/BatchNorm')), 3)
def testReuseConvWithBatchNorm(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 32), seed=1)
with scopes.arg_scope([ops.conv2d], batch_norm_params={'decay': 0.9}):
net = ops.conv2d(images, 32, [3, 3], scope='Conv')
net = ops.conv2d(net, 32, [3, 3], scope='Conv', reuse=True)
self.assertEquals(len(variables.get_variables()), 4)
self.assertEquals(len(variables.get_variables('Conv/BatchNorm')), 3)
self.assertEquals(len(variables.get_variables('Conv_1/BatchNorm')), 0)
class FCTest(tf.test.TestCase):
......@@ -136,8 +227,7 @@ class FCTest(tf.test.TestCase):
with self.test_session():
ops.fc(inputs, 32, scope='fc1')
self.assertEquals(len(variables.get_variables('fc1')), 2)
tf.get_variable_scope().reuse_variables()
ops.fc(inputs, 32, scope='fc1')
ops.fc(inputs, 32, scope='fc1', reuse=True)
self.assertEquals(len(variables.get_variables('fc1')), 2)
def testNonReuseVars(self):
......@@ -161,31 +251,53 @@ class FCTest(tf.test.TestCase):
with self.test_session() as sess:
inputs = tf.random_uniform((5, height * width * 3), seed=1)
ops.fc(inputs, 32, weight_decay=0.01)
wd = tf.get_collection(losses.LOSSES_COLLECTION)[0]
self.assertEquals(wd.op.name, 'FC/weights/Regularizer/L2Loss/value')
wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertEquals(wd.op.name,
'FC/weights/Regularizer/L2Regularizer/value')
sess.run(tf.initialize_all_variables())
self.assertTrue(sess.run(wd) <= 0.01)
def testCreateFCWithoutWD(self):
height, width = 3, 3
with self.test_session():
inputs = tf.random_uniform((5, height * width * 3), seed=1)
ops.fc(inputs, 32, weight_decay=0)
self.assertEquals(
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])
def testReuseFCWithWD(self):
height, width = 3, 3
with self.test_session():
inputs = tf.random_uniform((5, height * width * 3), seed=1)
ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
tf.get_variable_scope().reuse_variables()
ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
self.assertEquals(len(variables.get_variables()), 2)
self.assertEquals(
len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
ops.fc(inputs, 32, weight_decay=0.01, scope='fc', reuse=True)
self.assertEquals(len(variables.get_variables()), 2)
self.assertEquals(
len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
def testFCWithBatchNorm(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height * width * 3), seed=1)
with scopes.arg_scope([ops.fc], batch_norm_params={}):
net = ops.fc(images, 32, scope='fc1')
net = ops.fc(net, 32, scope='fc2')
self.assertEquals(len(tf.get_collection('moving_vars')), 4)
net = ops.fc(images, 27)
net = ops.fc(net, 27)
self.assertEquals(len(variables.get_variables()), 8)
self.assertEquals(len(variables.get_variables('FC/BatchNorm')), 3)
self.assertEquals(len(variables.get_variables('FC_1/BatchNorm')), 3)
def testReuseFCWithBatchNorm(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height * width * 3), seed=1)
with scopes.arg_scope([ops.fc], batch_norm_params={'decay': 0.9}):
net = ops.fc(images, 27, scope='fc1')
net = ops.fc(net, 27, scope='fc1', reuse=True)
self.assertEquals(len(variables.get_variables()), 4)
self.assertEquals(len(variables.get_variables('fc1/BatchNorm')), 3)
      self.assertEquals(len(variables.get_variables('fc2/BatchNorm')), 0)
class MaxPoolTest(tf.test.TestCase):
......@@ -198,6 +310,14 @@ class MaxPoolTest(tf.test.TestCase):
self.assertEquals(output.op.name, 'MaxPool/MaxPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
def testCreateSquareMaxPool(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = ops.max_pool(images, 3)
self.assertEquals(output.op.name, 'MaxPool/MaxPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
def testCreateMaxPoolWithScope(self):
height, width = 3, 3
with self.test_session():
......@@ -219,6 +339,13 @@ class MaxPoolTest(tf.test.TestCase):
output = ops.max_pool(images, [3, 3], stride=1, padding='SAME')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
def testGlobalMaxPool(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = ops.max_pool(images, images.get_shape()[1:3], stride=1)
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
class AvgPoolTest(tf.test.TestCase):
......@@ -230,6 +357,14 @@ class AvgPoolTest(tf.test.TestCase):
self.assertEquals(output.op.name, 'AvgPool/AvgPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
def testCreateSquareAvgPool(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = ops.avg_pool(images, 3)
self.assertEquals(output.op.name, 'AvgPool/AvgPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
def testCreateAvgPoolWithScope(self):
height, width = 3, 3
with self.test_session():
......@@ -251,6 +386,13 @@ class AvgPoolTest(tf.test.TestCase):
output = ops.avg_pool(images, [3, 3], stride=1, padding='SAME')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
def testGlobalAvgPool(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = ops.avg_pool(images, images.get_shape()[1:3], stride=1)
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
class OneHotEncodingTest(tf.test.TestCase):
......@@ -342,8 +484,8 @@ class BatchNormTest(tf.test.TestCase):
gamma = variables.get_variables_by_name('gamma')[0]
self.assertEquals(beta.op.name, 'BatchNorm/beta')
self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
moving_mean = tf.get_collection('moving_vars')[0]
moving_variance = tf.get_collection('moving_vars')[1]
moving_mean = tf.moving_average_variables()[0]
moving_variance = tf.moving_average_variables()[1]
self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
......@@ -375,8 +517,7 @@ class BatchNormTest(tf.test.TestCase):
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
ops.batch_norm(images, scale=True, scope='bn')
tf.get_variable_scope().reuse_variables()
ops.batch_norm(images, scale=True, scope='bn')
ops.batch_norm(images, scale=True, scope='bn', reuse=True)
beta = variables.get_variables_by_name('beta')
gamma = variables.get_variables_by_name('gamma')
self.assertEquals(len(beta), 1)
......@@ -390,8 +531,7 @@ class BatchNormTest(tf.test.TestCase):
images = tf.random_uniform((5, height, width, 3), seed=1)
ops.batch_norm(images, scope='bn')
self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 2)
tf.get_variable_scope().reuse_variables()
ops.batch_norm(images, scope='bn')
ops.batch_norm(images, scope='bn', reuse=True)
self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 4)
def testCreateMovingVars(self):
......
......@@ -19,7 +19,7 @@
Example of how to use scopes.arg_scope:
with slim.arg_scope(ops.conv2d, padding='SAME',
with scopes.arg_scope(ops.conv2d, padding='SAME',
stddev=0.01, weight_decay=0.0005):
net = ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
net = ops.conv2d(net, 256, [5, 5], scope='conv2')
......@@ -32,6 +32,15 @@
ops.conv2d(inputs, 256, [5, 5], padding='SAME',
stddev=0.01, weight_decay=0.0005, scope='conv2')
Example of how to reuse an arg_scope:
with scopes.arg_scope(ops.conv2d, padding='SAME',
stddev=0.01, weight_decay=0.0005) as conv2d_arg_scope:
net = ops.conv2d(net, 256, [5, 5], scope='conv1')
....
with scopes.arg_scope(conv2d_arg_scope):
net = ops.conv2d(net, 256, [5, 5], scope='conv2')
Example of how to use scopes.add_arg_scope:
@scopes.add_arg_scope
......@@ -44,7 +53,6 @@ from __future__ import print_function
import contextlib
import functools
from tensorflow.python.framework import ops
_ARGSTACK_KEY = ("__arg_stack",)
......@@ -74,12 +82,16 @@ def _add_op(op):
@contextlib.contextmanager
def arg_scope(list_ops, **kwargs):
def arg_scope(list_ops_or_scope, **kwargs):
"""Stores the default arguments for the given set of list_ops.
For usage, please see examples at top of the file.
Args:
list_ops: List or tuple of operations to set argument scope for. Every op in
list_ops need to be decorated with @add_arg_scope to work.
    list_ops_or_scope: List or tuple of operations to set argument scope for or
      a dictionary containing the current scope. When list_ops_or_scope is a
      dict, kwargs must be empty. When list_ops_or_scope is a list or tuple,
      then every op in it needs to be decorated with @add_arg_scope to work.
**kwargs: keyword=value that will define the defaults for each op in
list_ops. All the ops need to accept the given set of arguments.
......@@ -89,24 +101,38 @@ def arg_scope(list_ops, **kwargs):
TypeError: if list_ops is not a list or a tuple.
    ValueError: if any op in list_ops has not been decorated with @add_arg_scope.
"""
if not isinstance(list_ops, (list, tuple)):
raise TypeError("list_ops is not a list or a tuple")
try:
current_scope = _current_arg_scope().copy()
for op in list_ops:
key_op = (op.__module__, op.__name__)
if not has_arg_scope(op):
raise ValueError("%s is not decorated with @add_arg_scope", key_op)
if key_op in current_scope:
current_kwargs = current_scope[key_op].copy()
current_kwargs.update(kwargs)
current_scope[key_op] = current_kwargs
else:
current_scope[key_op] = kwargs.copy()
_get_arg_stack().append(current_scope)
yield current_scope
finally:
_get_arg_stack().pop()
if isinstance(list_ops_or_scope, dict):
# Assumes that list_ops_or_scope is a scope that is being reused.
if kwargs:
raise ValueError("When attempting to re-use a scope by suppling a"
"dictionary, kwargs must be empty.")
current_scope = list_ops_or_scope.copy()
try:
_get_arg_stack().append(current_scope)
yield current_scope
finally:
_get_arg_stack().pop()
else:
# Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
if not isinstance(list_ops_or_scope, (list, tuple)):
      raise TypeError("list_ops_or_scope must either be a list/tuple or a "
                      "reused scope (i.e. dict)")
try:
current_scope = _current_arg_scope().copy()
for op in list_ops_or_scope:
key_op = (op.__module__, op.__name__)
if not has_arg_scope(op):
raise ValueError("%s is not decorated with @add_arg_scope", key_op)
if key_op in current_scope:
current_kwargs = current_scope[key_op].copy()
current_kwargs.update(kwargs)
current_scope[key_op] = current_kwargs
else:
current_scope[key_op] = kwargs.copy()
_get_arg_stack().append(current_scope)
yield current_scope
finally:
_get_arg_stack().pop()
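A short sketch of the new reuse path, where a previously captured scope dict is replayed (ops and tensors are illustrative; the list/tuple path behaves as before):
import tensorflow as tf
from inception.slim import ops
from inception.slim import scopes

inputs = tf.random_uniform((5, 32, 32, 3), seed=1)
with scopes.arg_scope([ops.conv2d], stddev=0.01, weight_decay=0.0005) as conv_scope:
  net = ops.conv2d(inputs, 64, [3, 3], scope='conv1')
# Later, replay the same defaults without restating them:
with scopes.arg_scope(conv_scope):
  net = ops.conv2d(net, 128, [3, 3], scope='conv2')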
def add_arg_scope(func):
......
......@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
from inception.slim import scopes
......@@ -39,6 +38,51 @@ class ArgScopeTest(tf.test.TestCase):
with self.test_session():
self.assertEqual(scopes._current_arg_scope(), {})
def testCurrentArgScope(self):
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
key_op = (func1.__module__, func1.__name__)
current_scope = {key_op: func1_kwargs.copy()}
with self.test_session():
with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope:
self.assertDictEqual(scope, current_scope)
def testCurrentArgScopeNested(self):
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
func2_kwargs = {'b': 2, 'd': [2]}
key = lambda f: (f.__module__, f.__name__)
current_scope = {key(func1): func1_kwargs.copy(),
key(func2): func2_kwargs.copy()}
with self.test_session():
with scopes.arg_scope([func1], a=1, b=None, c=[1]):
with scopes.arg_scope([func2], b=2, d=[2]) as scope:
self.assertDictEqual(scope, current_scope)
def testReuseArgScope(self):
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
key_op = (func1.__module__, func1.__name__)
current_scope = {key_op: func1_kwargs.copy()}
with self.test_session():
with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope1:
pass
with scopes.arg_scope(scope1) as scope:
self.assertDictEqual(scope, current_scope)
def testReuseArgScopeNested(self):
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
func2_kwargs = {'b': 2, 'd': [2]}
key = lambda f: (f.__module__, f.__name__)
current_scope1 = {key(func1): func1_kwargs.copy()}
current_scope2 = {key(func1): func1_kwargs.copy(),
key(func2): func2_kwargs.copy()}
with self.test_session():
with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope1:
with scopes.arg_scope([func2], b=2, d=[2]) as scope2:
pass
with scopes.arg_scope(scope1):
self.assertDictEqual(scopes._current_arg_scope(), current_scope1)
with scopes.arg_scope(scope2):
self.assertDictEqual(scopes._current_arg_scope(), current_scope2)
def testSimpleArgScope(self):
func1_args = (0,)
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
......
......@@ -12,7 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains convenience wrappers for creating Variables in TensorFlow.
"""Contains convenience wrappers for creating variables in TF-Slim.
The variables module is typically used for defining model variables from the
ops routines (see slim.ops). Such variables are used for training, evaluation
and inference of models.
All the variables created through this module are added to the MODEL_VARIABLES
collection. If you create a model variable outside slim, it can be added with
slim.variables.add_variable(external_variable, reuse).
Usage:
weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
......@@ -24,15 +32,15 @@ Usage:
device='/cpu:0')
biases = variables.variable('biases',
shape=[100],
initializer=tf.zeros_initializer,
device='/cpu:0')
shape=[100],
initializer=tf.zeros_initializer,
device='/cpu:0')
# More complex example.
net = slim.ops.conv2d(input, 32, [3, 3], scope='conv1')
net = slim.ops.conv2d(net, 64, [3, 3], scope='conv2')
with slim.arg_scope(variables.Variables, restore=False):
with slim.arg_scope([variables.variable], restore=False):
net = slim.ops.conv2d(net, 64, [3, 3], scope='conv3')
# Get all model variables from all the layers.
......@@ -47,9 +55,9 @@ Usage:
# Get all bias from all the layers.
biases = slim.variables.get_variables_by_name('biases')
# Get all variables in the VARIABLES_TO_RESTORE collection
# Get all variables to restore.
# (i.e. only those created by 'conv1' and 'conv2')
variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
variables_to_restore = slim.variables.get_variables_to_restore()
************************************************
* Initializing model variables from a checkpoint
......@@ -60,7 +68,7 @@ v1 = slim.variables.variable(name="v1", ..., restore=False)
v2 = slim.variables.variable(name="v2", ...) # By default restore=True
...
# The list of variables to restore should only contain 'v2'.
variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
variables_to_restore = slim.variables.get_variables_to_restore()
restorer = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
# Restore variables from disk.
......@@ -74,92 +82,71 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from inception.slim import scopes
# Collection containing all the variables created using slim.variables
VARIABLES_COLLECTION = '_variables_'
MODEL_VARIABLES = '_model_variables_'
# Collection containing all the slim.variables that are marked to_restore
# Collection containing the slim.variables that are created with restore=True.
VARIABLES_TO_RESTORE = '_variables_to_restore_'
def get_variable_given_name(var):
"""Gets the variable given name without the scope.
Args:
var: a variable.
Returns:
the given name of the variable without the scope.
"""
name = var.op.name
if '/' in name:
name = name.split('/')[-1]
return name
def default_collections(given_name, restore):
"""Define the set of default collections that variables should be added.
Args:
given_name: the given name of the variable.
restore: whether the variable should be added to the VARIABLES_TO_RESTORE
collection.
Returns:
a list of default collections.
"""
defaults = [tf.GraphKeys.VARIABLES, VARIABLES_COLLECTION]
defaults += [VARIABLES_COLLECTION + given_name]
if restore:
defaults += [VARIABLES_TO_RESTORE]
return defaults
def add_variable(var, restore=True):
"""Adds a variable to the default set of collections.
"""Adds a variable to the MODEL_VARIABLES collection.
Optionally it will add the variable to the VARIABLES_TO_RESTORE collection.
Args:
var: a variable.
restore: whether the variable should be added to the
VARIABLES_TO_RESTORE collection.
"""
given_name = get_variable_given_name(var)
for collection in default_collections(given_name, restore):
collections = [MODEL_VARIABLES]
if restore:
collections.append(VARIABLES_TO_RESTORE)
for collection in collections:
if var not in tf.get_collection(collection):
tf.add_to_collection(collection, var)
def get_variables(prefix=None, suffix=None):
"""Gets the list of variables, filtered by prefix and/or suffix.
def get_variables(scope=None, suffix=None):
"""Gets the list of variables, filtered by scope and/or suffix.
Args:
prefix: an optional prefix for filtering the variables to return.
scope: an optional scope for filtering the variables to return.
suffix: an optional suffix for filtering the variables to return.
Returns:
a list of variables with prefix and suffix.
a copied list of variables with scope and suffix.
"""
candidates = tf.get_collection(VARIABLES_COLLECTION, prefix)
candidates = tf.get_collection(MODEL_VARIABLES, scope)[:]
if suffix is not None:
candidates = [var for var in candidates if var.op.name.endswith(suffix)]
return candidates
def get_variables_by_name(given_name, prefix=None):
"""Gets the list of variables were given that name.
def get_variables_to_restore():
"""Gets the list of variables to restore.
Returns:
a copied list of variables.
"""
return tf.get_collection(VARIABLES_TO_RESTORE)[:]
def get_variables_by_name(given_name, scope=None):
"""Gets the list of variables that were given that name.
Args:
given_name: name given to the variable without scope.
prefix: an optional prefix for filtering the variables to return.
scope: an optional scope for filtering the variables to return.
Returns:
a list of variables with prefix and suffix.
    a copied list of variables with the given name and scope.
"""
return tf.get_collection(VARIABLES_COLLECTION + given_name, prefix)
return get_variables(scope=scope, suffix=given_name)
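A usage sketch of the reworked getters (variable names and shapes are illustrative; the behaviour matches the VariablesTest cases below):
import tensorflow as tf
from inception.slim import variables

with tf.variable_scope('conv1'):
  w = variables.variable('weights', shape=[3, 3, 3, 16])
  b = variables.variable('biases', shape=[16], restore=False)
variables.get_variables('conv1')            # [w, b]
variables.get_variables_by_name('weights')  # [w]
variables.get_variables_to_restore()        # [w]  (b was created with restore=False)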
def get_unique_variable(name):
......@@ -204,7 +191,7 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
collections: A list of collection names to which the Variable will be added.
Note that the variable is always also added to the tf.GraphKeys.VARIABLES
collection.
and MODEL_VARIABLES collections.
device: Optional device to place the variable. It can be an string or a
function that is called to get the device for the variable.
restore: whether the variable should be added to the
......@@ -216,8 +203,15 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
# Instantiate the device for this variable if it is passed as a function.
if device and callable(device):
device = device()
collections = set(list(collections or []) + default_collections(name,
restore))
collections = list(collections or [])
# Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
collections += [tf.GraphKeys.VARIABLES, MODEL_VARIABLES]
# Add to VARIABLES_TO_RESTORE if necessary
if restore:
collections.append(VARIABLES_TO_RESTORE)
# Remove duplicates
collections = set(collections)
with tf.device(device):
return tf.get_variable(name, shape=shape, dtype=dtype,
initializer=initializer, regularizer=regularizer,
......
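A sketch of the resulting collection behaviour: every slim variable lands in tf.GraphKeys.VARIABLES and MODEL_VARIABLES, in VARIABLES_TO_RESTORE when restore is left True, and in any user-supplied collections (the collection name and device below are illustrative):
import tensorflow as tf
from inception.slim import variables

v = variables.variable('v', shape=[10], collections=['my_vars'], device='/cpu:0')
assert v in tf.get_collection(tf.GraphKeys.VARIABLES)
assert v in tf.get_collection(variables.MODEL_VARIABLES)
assert v in tf.get_collection(variables.VARIABLES_TO_RESTORE)  # restore defaults to True
assert v in tf.get_collection('my_vars')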
......@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from inception.slim import scopes
......@@ -33,29 +32,13 @@ class VariablesTest(tf.test.TestCase):
self.assertEquals(a.op.name, 'A/a')
self.assertListEqual(a.get_shape().as_list(), [5])
def testGetVariableGivenName(self):
with self.test_session():
with tf.variable_scope('A'):
a = variables.variable('a', [5])
with tf.variable_scope('B'):
b = variables.variable('a', [5])
self.assertEquals('a', variables.get_variable_given_name(a))
self.assertEquals('a', variables.get_variable_given_name(b))
def testGetVariableGivenNameScoped(self):
with self.test_session():
with tf.variable_scope('A'):
a = variables.variable('a', [5])
b = variables.variable('b', [5])
self.assertEquals([a], variables.get_variables_by_name('a'))
self.assertEquals([b], variables.get_variables_by_name('b'))
def testGetVariables(self):
with self.test_session():
with tf.variable_scope('A'):
a = variables.variable('a', [5])
with tf.variable_scope('B'):
b = variables.variable('a', [5])
self.assertEquals([a, b], variables.get_variables())
self.assertEquals([a], variables.get_variables('A'))
self.assertEquals([b], variables.get_variables('B'))
......@@ -103,19 +86,28 @@ class VariablesTest(tf.test.TestCase):
with tf.variable_scope('A'):
a = variables.variable('a', [5])
with tf.variable_scope('B'):
b = variables.variable('b', [5])
self.assertListEqual([a, b],
tf.get_collection(variables.VARIABLES_TO_RESTORE))
b = variables.variable('a', [5])
self.assertEquals([a, b], variables.get_variables_to_restore())
def testGetVariablesToRestorePartial(self):
def testNoneGetVariablesToRestore(self):
with self.test_session():
with tf.variable_scope('A'):
a = variables.variable('a', [5])
a = variables.variable('a', [5], restore=False)
with tf.variable_scope('B'):
b = variables.variable('a', [5], restore=False)
self.assertEquals([], variables.get_variables_to_restore())
self.assertEquals([a, b], variables.get_variables())
def testGetMixedVariablesToRestore(self):
with self.test_session():
with tf.variable_scope('A'):
a = variables.variable('a', [5])
b = variables.variable('b', [5], restore=False)
self.assertListEqual([a, b], variables.get_variables())
self.assertListEqual([a],
tf.get_collection(variables.VARIABLES_TO_RESTORE))
with tf.variable_scope('B'):
c = variables.variable('c', [5])
d = variables.variable('d', [5], restore=False)
self.assertEquals([a, b, c, d], variables.get_variables())
self.assertEquals([a, c], variables.get_variables_to_restore())
def testReuseVariable(self):
with self.test_session():
......@@ -190,11 +182,49 @@ class VariablesTest(tf.test.TestCase):
collections=['A', 'B']):
b = variables.variable('b', [])
c = variables.variable('c', [])
self.assertListEqual([a, b, c],
tf.get_collection(variables.VARIABLES_TO_RESTORE))
self.assertListEqual([a, b, c], variables.get_variables_to_restore())
self.assertListEqual([a, c], tf.trainable_variables())
self.assertListEqual([b], tf.get_collection('A'))
self.assertListEqual([b], tf.get_collection('B'))
class GetVariablesByNameTest(tf.test.TestCase):
def testGetVariableGivenNameScoped(self):
with self.test_session():
with tf.variable_scope('A'):
a = variables.variable('a', [5])
b = variables.variable('b', [5])
self.assertEquals([a], variables.get_variables_by_name('a'))
self.assertEquals([b], variables.get_variables_by_name('b'))
def testGetVariablesByNameReturnsByValueWithScope(self):
with self.test_session():
with tf.variable_scope('A'):
a = variables.variable('a', [5])
matched_variables = variables.get_variables_by_name('a')
# If variables.get_variables_by_name returns the list by reference, the
# following append should persist, and be returned, in subsequent calls
# to variables.get_variables_by_name('a').
matched_variables.append(4)
matched_variables = variables.get_variables_by_name('a')
self.assertEquals([a], matched_variables)
def testGetVariablesByNameReturnsByValueWithoutScope(self):
with self.test_session():
a = variables.variable('a', [5])
matched_variables = variables.get_variables_by_name('a')
# If variables.get_variables_by_name returns the list by reference, the
# following append should persist, and be returned, in subsequent calls
# to variables.get_variables_by_name('a').
matched_variables.append(4)
matched_variables = variables.get_variables_by_name('a')
self.assertEquals([a], matched_variables)
if __name__ == '__main__':
tf.test.main()