ModelZoo / ResNet50_tensorflow

Commit 9a1dfdf2, authored Apr 12, 2016 by Derek Murray

    Merge pull request #43 from mrry/newslim

    Updated to the latest version of TF-Slim

Parents: c74897b0, c74d4385
Showing 4 changed files with 309 additions and 27 deletions (+309, -27):

    inception/inception/slim/ops.py              +28  -22
    inception/inception/slim/ops_test.py         +42   -0
    inception/inception/slim/variables.py        +76   -4
    inception/inception/slim/variables_test.py  +163   -1
inception/inception/slim/ops.py

@@ -42,6 +42,7 @@ UPDATE_OPS_COLLECTION = '_update_ops_'
 @scopes.add_arg_scope
 def batch_norm(inputs,
                decay=0.999,
+               center=True,
                scale=False,
                epsilon=0.001,
                moving_vars='moving_vars',
@@ -57,6 +58,7 @@ def batch_norm(inputs,
     inputs: a tensor of size [batch_size, height, width, channels]
             or [batch_size, channels].
     decay: decay for the moving average.
+    center: If True, subtract beta. If False, beta is not created and ignored.
     scale: If True, multiply by gamma. If False, gamma is
       not used. When the next layer is linear (also e.g. ReLU), this can be
       disabled since the scaling can be done by the next layer.
@@ -78,31 +80,35 @@ def batch_norm(inputs,
   with tf.variable_op_scope([inputs], scope, 'BatchNorm', reuse=reuse):
     axis = list(range(len(inputs_shape) - 1))
     params_shape = inputs_shape[-1:]
-    with scopes.arg_scope([variables.variable], restore=restore):
-      # Allocate parameters for the beta and gamma of the normalization.
-      beta = variables.variable('beta',
-                                params_shape,
-                                initializer=tf.zeros_initializer,
-                                trainable=trainable)
-      if scale:
-        gamma = variables.variable('gamma',
-                                   params_shape,
-                                   initializer=tf.ones,
-                                   trainable=trainable)
-      else:
-        gamma = None
-      # Create moving_mean and moving_variance add them to
-      # GraphKeys.MOVING_AVERAGE_VARIABLES collections.
-      with scopes.arg_scope([variables.variable], trainable=False,
-                            collections=[
-                                moving_vars,
-                                tf.GraphKeys.MOVING_AVERAGE_VARIABLES]):
-        moving_mean = variables.variable('moving_mean',
-                                         params_shape,
-                                         initializer=tf.zeros_initializer)
-        moving_variance = variables.variable('moving_variance',
-                                             params_shape,
-                                             initializer=tf.ones)
+    # Allocate parameters for the beta and gamma of the normalization.
+    beta, gamma = None, None
+    if center:
+      beta = variables.variable('beta',
+                                params_shape,
+                                initializer=tf.zeros_initializer,
+                                trainable=trainable,
+                                restore=restore)
+    if scale:
+      gamma = variables.variable('gamma',
+                                 params_shape,
+                                 initializer=tf.ones_initializer,
+                                 trainable=trainable,
+                                 restore=restore)
+    # Create moving_mean and moving_variance add them to moving_vars and
+    # GraphKeys.MOVING_AVERAGE_VARIABLES collections.
+    moving_collections = [moving_vars, tf.GraphKeys.MOVING_AVERAGE_VARIABLES]
+    moving_mean = variables.variable('moving_mean',
+                                     params_shape,
+                                     initializer=tf.zeros_initializer,
+                                     trainable=False,
+                                     restore=restore,
+                                     collections=moving_collections)
+    moving_variance = variables.variable('moving_variance',
+                                         params_shape,
+                                         initializer=tf.ones_initializer,
+                                         trainable=False,
+                                         restore=restore,
+                                         collections=moving_collections)
     if is_training:
       # Calculate the moments based on the individual batch.
       mean, variance = tf.nn.moments(inputs, axis)
@@ -400,7 +406,7 @@ def dropout(inputs, keep_prob=0.5, is_training=True, scope=None):
   Args:
     inputs: the tensor to pass to the Dropout layer.
-    keep_prob: the probability of dropping each input unit.
+    keep_prob: the probability of keeping each input unit.
     is_training: whether or not the model is in training mode. If so, dropout is
       applied and values scaled. Otherwise, inputs is returned.
     scope: Optional scope for op_scope.
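
Taken together, the ops.py changes above add a center argument to batch_norm so that the learned beta (and, with scale, gamma) parameters are only created when requested, and they attach the moving statistics to the requested collections directly. A minimal usage sketch, not part of the commit, assuming the inception.slim package from this repository is importable; the tensor shape and scope names are illustrative:

    import tensorflow as tf
    from inception.slim import ops

    images = tf.random_uniform((5, 3, 3, 3), seed=1)
    # Defaults: center=True, scale=False, so only a beta variable is created.
    net = ops.batch_norm(images, scope='bn_default')
    # No beta, but a learned gamma, matching the new tests added below.
    net = ops.batch_norm(images, center=False, scale=True, scope='bn_scale_only')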
inception/inception/slim/ops_test.py

@@ -476,6 +476,20 @@ class BatchNormTest(tf.test.TestCase):
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
 
   def testCreateVariables(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.batch_norm(images)
+      beta = variables.get_variables_by_name('beta')[0]
+      self.assertEquals(beta.op.name, 'BatchNorm/beta')
+      gamma = variables.get_variables_by_name('gamma')
+      self.assertEquals(gamma, [])
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
+      self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+      self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+  def testCreateVariablesWithScale(self):
     height, width = 3, 3
     with self.test_session():
       images = tf.random_uniform((5, height, width, 3), seed=1)
@@ -489,6 +503,34 @@ class BatchNormTest(tf.test.TestCase):
       self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
       self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
 
+  def testCreateVariablesWithoutCenterWithScale(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.batch_norm(images, center=False, scale=True)
+      beta = variables.get_variables_by_name('beta')
+      self.assertEquals(beta, [])
+      gamma = variables.get_variables_by_name('gamma')[0]
+      self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
+      self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+      self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+  def testCreateVariablesWithoutCenterWithoutScale(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.batch_norm(images, center=False, scale=False)
+      beta = variables.get_variables_by_name('beta')
+      self.assertEquals(beta, [])
+      gamma = variables.get_variables_by_name('gamma')
+      self.assertEquals(gamma, [])
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
+      self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+      self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
   def testMovingAverageVariables(self):
     height, width = 3, 3
     with self.test_session():
inception/inception/slim/variables.py

@@ -84,6 +84,7 @@ from __future__ import print_function
 
 import tensorflow as tf
 
+from tensorflow.core.framework import graph_pb2
 from inception.slim import scopes
 
 # Collection containing all the variables created using slim.variables
@@ -171,6 +172,79 @@ def get_unique_variable(name):
   raise ValueError('Variable %s does not uniquely identify a variable', name)
 
 
+class VariableDeviceChooser(object):
+  """Slim device chooser for variables.
+
+  When using a parameter server it will assign them in a round-robin fashion.
+  When not using a parameter server it allows GPU:0 placement otherwise CPU:0.
+  """
+
+  def __init__(self,
+               num_parameter_servers=0,
+               ps_device='/job:ps',
+               placement='CPU:0'):
+    """Initialize VariableDeviceChooser.
+
+    Args:
+      num_parameter_servers: number of parameter servers.
+      ps_device: string representing the parameter server device.
+      placement: string representing the placement of the variable either CPU:0
+        or GPU:0. When using parameter servers forced to CPU:0.
+    """
+    self._num_ps = num_parameter_servers
+    self._ps_device = ps_device
+    self._placement = placement if num_parameter_servers == 0 else 'CPU:0'
+    self._next_task_id = 0
+
+  def __call__(self, op):
+    device_string = ''
+    if self._num_ps > 0:
+      task_id = self._next_task_id
+      self._next_task_id = (self._next_task_id + 1) % self._num_ps
+      device_string = '%s/task:%d' % (self._ps_device, task_id)
+    device_string += '/%s' % self._placement
+    return device_string
+
+
+# TODO(sguada) Remove once get_variable is able to colocate op.devices.
+def variable_device(device, name):
+  """Fix the variable device to colocate its ops."""
+  if callable(device):
+    var_name = tf.get_variable_scope().name + '/' + name
+    var_def = graph_pb2.NodeDef(name=var_name, op='Variable')
+    device = device(var_def)
+  if device is None:
+    device = ''
+  return device
+
+
+@scopes.add_arg_scope
+def global_step(device=''):
+  """Returns the global step variable.
+
+  Args:
+    device: Optional device to place the variable. It can be an string or a
+      function that is called to get the device for the variable.
+
+  Returns:
+    the tensor representing the global step variable.
+  """
+  global_step_ref = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
+  if global_step_ref:
+    return global_step_ref[0]
+  else:
+    collections = [
+        VARIABLES_TO_RESTORE,
+        tf.GraphKeys.VARIABLES,
+        tf.GraphKeys.GLOBAL_STEP,
+    ]
+    # Get the device for the variable.
+    with tf.device(variable_device(device, 'global_step')):
+      return tf.get_variable('global_step', shape=[], dtype=tf.int64,
+                             initializer=tf.zeros_initializer,
+                             trainable=False, collections=collections)
+
+
 @scopes.add_arg_scope
 def variable(name, shape=None, dtype=tf.float32, initializer=None,
              regularizer=None, trainable=True, collections=None, device='',
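
The new VariableDeviceChooser is a callable meant to be passed as the device argument of slim variables, typically through an arg_scope (see the tests added in variables_test.py below). A brief sketch, not part of the diff; the variable names and shapes are illustrative:

    from inception.slim import scopes, variables

    # Round-robin placement across two parameter-server tasks: variables land
    # on /job:ps/task:0/CPU:0, /job:ps/task:1/CPU:0, /job:ps/task:0/CPU:0, ...
    device_fn = variables.VariableDeviceChooser(num_parameter_servers=2)
    with scopes.arg_scope([variables.variable], device=device_fn):
      weights = variables.variable('weights', [3, 3])
      biases = variables.variable('biases', [3])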
@@ -200,9 +274,6 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
   Returns:
     The created or existing variable.
   """
-  # Instantiate the device for this variable if it is passed as a function.
-  if device and callable(device):
-    device = device()
   collections = list(collections or [])
   # Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
@@ -212,7 +283,8 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
     collections.append(VARIABLES_TO_RESTORE)
   # Remove duplicates
   collections = set(collections)
-  with tf.device(device):
+  # Get the device for the variable.
+  with tf.device(variable_device(device, name)):
     return tf.get_variable(name, shape=shape, dtype=dtype,
                            initializer=initializer, regularizer=regularizer,
                            trainable=trainable, collections=collections)
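
Because variable() now routes placement through variable_device(), a device callable such as tf.train.replica_device_setter or the VariableDeviceChooser above can also be handed to the new global_step() helper. A short sketch, not part of the diff, mirroring the GlobalStepTest cases added in variables_test.py below:

    import tensorflow as tf
    from inception.slim import scopes, variables

    with tf.Graph().as_default():
      with scopes.arg_scope([variables.global_step],
                            device=tf.train.replica_device_setter(2)):
        gs = variables.global_step()   # created on /job:ps/task:0
        gs2 = variables.global_step()  # second call returns the same variable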
inception/inception/slim/variables_test.py

@@ -134,6 +134,109 @@ class VariablesTest(tf.test.TestCase):
       self.assertDeviceEqual(a.device, 'cpu:0')
       self.assertDeviceEqual(b.device, 'cpu:1')
 
+  def testVariableWithDeviceFunction(self):
+    class DevFn(object):
+
+      def __init__(self):
+        self.counter = -1
+
+      def __call__(self, op):
+        self.counter += 1
+        return 'cpu:%d' % self.counter
+
+    with self.test_session():
+      with scopes.arg_scope([variables.variable], device=DevFn()):
+        a = variables.variable('a', [])
+        b = variables.variable('b', [])
+        c = variables.variable('c', [], device='cpu:12')
+        d = variables.variable('d', [])
+        with tf.device('cpu:99'):
+          e_init = tf.constant(12)
+        e = variables.variable('e', initializer=e_init)
+      self.assertDeviceEqual(a.device, 'cpu:0')
+      self.assertDeviceEqual(a.initial_value.device, 'cpu:0')
+      self.assertDeviceEqual(b.device, 'cpu:1')
+      self.assertDeviceEqual(b.initial_value.device, 'cpu:1')
+      self.assertDeviceEqual(c.device, 'cpu:12')
+      self.assertDeviceEqual(c.initial_value.device, 'cpu:12')
+      self.assertDeviceEqual(d.device, 'cpu:2')
+      self.assertDeviceEqual(d.initial_value.device, 'cpu:2')
+      self.assertDeviceEqual(e.device, 'cpu:3')
+      self.assertDeviceEqual(e.initial_value.device, 'cpu:99')
+
+  def testVariableWithReplicaDeviceSetter(self):
+    with self.test_session():
+      with tf.device(tf.train.replica_device_setter(ps_tasks=2)):
+        a = variables.variable('a', [])
+        b = variables.variable('b', [])
+        c = variables.variable('c', [], device='cpu:12')
+        d = variables.variable('d', [])
+        with tf.device('cpu:99'):
+          e_init = tf.constant(12)
+        e = variables.variable('e', initializer=e_init)
+      # The values below highlight how the replica_device_setter puts initial
+      # values on the worker job, and how it merges explicit devices.
+      self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
+      self.assertDeviceEqual(a.initial_value.device, '/job:worker/cpu:0')
+      self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
+      self.assertDeviceEqual(b.initial_value.device, '/job:worker/cpu:0')
+      self.assertDeviceEqual(c.device, '/job:ps/task:0/cpu:12')
+      self.assertDeviceEqual(c.initial_value.device, '/job:worker/cpu:12')
+      self.assertDeviceEqual(d.device, '/job:ps/task:1/cpu:0')
+      self.assertDeviceEqual(d.initial_value.device, '/job:worker/cpu:0')
+      self.assertDeviceEqual(e.device, '/job:ps/task:0/cpu:0')
+      self.assertDeviceEqual(e.initial_value.device, '/job:worker/cpu:99')
+
+  def testVariableWithVariableDeviceChooser(self):
+    with tf.Graph().as_default():
+      device_fn = variables.VariableDeviceChooser(num_parameter_servers=2)
+      with scopes.arg_scope([variables.variable], device=device_fn):
+        a = variables.variable('a', [])
+        b = variables.variable('b', [])
+        c = variables.variable('c', [], device='cpu:12')
+        d = variables.variable('d', [])
+        with tf.device('cpu:99'):
+          e_init = tf.constant(12)
+        e = variables.variable('e', initializer=e_init)
+      # The values below highlight how the VariableDeviceChooser puts initial
+      # values on the same device as the variable job.
+      self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
+      self.assertDeviceEqual(a.initial_value.device, a.device)
+      self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
+      self.assertDeviceEqual(b.initial_value.device, b.device)
+      self.assertDeviceEqual(c.device, '/cpu:12')
+      self.assertDeviceEqual(c.initial_value.device, c.device)
+      self.assertDeviceEqual(d.device, '/job:ps/task:0/cpu:0')
+      self.assertDeviceEqual(d.initial_value.device, d.device)
+      self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0')
+      self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+
+  def testVariableGPUPlacement(self):
+    with tf.Graph().as_default():
+      device_fn = variables.VariableDeviceChooser(placement='gpu:0')
+      with scopes.arg_scope([variables.variable], device=device_fn):
+        a = variables.variable('a', [])
+        b = variables.variable('b', [])
+        c = variables.variable('c', [], device='cpu:12')
+        d = variables.variable('d', [])
+        with tf.device('cpu:99'):
+          e_init = tf.constant(12)
+        e = variables.variable('e', initializer=e_init)
+      # The values below highlight how the VariableDeviceChooser puts initial
+      # values on the same device as the variable job.
+      self.assertDeviceEqual(a.device, '/gpu:0')
+      self.assertDeviceEqual(a.initial_value.device, a.device)
+      self.assertDeviceEqual(b.device, '/gpu:0')
+      self.assertDeviceEqual(b.initial_value.device, b.device)
+      self.assertDeviceEqual(c.device, '/cpu:12')
+      self.assertDeviceEqual(c.initial_value.device, c.device)
+      self.assertDeviceEqual(d.device, '/gpu:0')
+      self.assertDeviceEqual(d.initial_value.device, d.device)
+      self.assertDeviceEqual(e.device, '/gpu:0')
+      self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+
   def testVariableCollection(self):
     with self.test_session():
       a = variables.variable('a', [], collections='A')
@@ -178,7 +281,8 @@ class VariablesTest(tf.test.TestCase):
     with self.test_session():
       with scopes.arg_scope([variables.variable], restore=True):
         a = variables.variable('a', [])
       with scopes.arg_scope([variables.variable], trainable=False,
                             collections=['A', 'B']):
         b = variables.variable('b', [])
         c = variables.variable('c', [])
@@ -226,5 +330,63 @@ class GetVariablesByNameTest(tf.test.TestCase):
     self.assertEquals([a], matched_variables)
 
 
+class GlobalStepTest(tf.test.TestCase):
+
+  def testStable(self):
+    with tf.Graph().as_default():
+      gs = variables.global_step()
+      gs2 = variables.global_step()
+      self.assertTrue(gs is gs2)
+
+  def testDevice(self):
+    with tf.Graph().as_default():
+      with scopes.arg_scope([variables.global_step], device='/gpu:0'):
+        gs = variables.global_step()
+      self.assertDeviceEqual(gs.device, '/gpu:0')
+
+  def testDeviceFn(self):
+    class DevFn(object):
+
+      def __init__(self):
+        self.counter = -1
+
+      def __call__(self, op):
+        self.counter += 1
+        return '/cpu:%d' % self.counter
+
+    with tf.Graph().as_default():
+      with scopes.arg_scope([variables.global_step], device=DevFn()):
+        gs = variables.global_step()
+        gs2 = variables.global_step()
+      self.assertDeviceEqual(gs.device, '/cpu:0')
+      self.assertEquals(gs, gs2)
+      self.assertDeviceEqual(gs2.device, '/cpu:0')
+
+  def testReplicaDeviceSetter(self):
+    device_fn = tf.train.replica_device_setter(2)
+    with tf.Graph().as_default():
+      with scopes.arg_scope([variables.global_step], device=device_fn):
+        gs = variables.global_step()
+        gs2 = variables.global_step()
+        self.assertEquals(gs, gs2)
+        self.assertDeviceEqual(gs.device, '/job:ps/task:0')
+        self.assertDeviceEqual(gs.initial_value.device, '/job:ps/task:0')
+        self.assertDeviceEqual(gs2.device, '/job:ps/task:0')
+        self.assertDeviceEqual(gs2.initial_value.device, '/job:ps/task:0')
+
+  def testVariableWithVariableDeviceChooser(self):
+    with tf.Graph().as_default():
+      device_fn = variables.VariableDeviceChooser()
+      with scopes.arg_scope([variables.global_step], device=device_fn):
+        gs = variables.global_step()
+        gs2 = variables.global_step()
+        self.assertEquals(gs, gs2)
+        self.assertDeviceEqual(gs.device, 'cpu:0')
+        self.assertDeviceEqual(gs.initial_value.device, gs.device)
+        self.assertDeviceEqual(gs2.device, 'cpu:0')
+        self.assertDeviceEqual(gs2.initial_value.device, gs2.device)
+
+
 if __name__ == '__main__':
   tf.test.main()