Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f282f6ef
Commit
f282f6ef
authored
Jul 05, 2017
by
Alexander Gorban
Browse files
Merge branch 'master' of github.com:tensorflow/models
parents
58a5da7b
a2970b03
Changes
302
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
318 additions
and
0 deletions
+318
-0
object_detection/utils/variables_helper.py
object_detection/utils/variables_helper.py
+133
-0
object_detection/utils/variables_helper_test.py
object_detection/utils/variables_helper_test.py
+185
-0
No files found.
Too many changes to show.
To preserve performance only
302 of 302+
files are displayed.
Plain diff
Email patch
object_detection/utils/variables_helper.py
0 → 100644
View file @
f282f6ef
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for manipulating collections of variables during training.
"""
import
logging
import
re
import
tensorflow
as
tf
slim
=
tf
.
contrib
.
slim
# TODO: Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def
filter_variables
(
variables
,
filter_regex_list
,
invert
=
False
):
"""Filters out the variables matching the filter_regex.
Filter out the variables whose name matches the any of the regular
expressions in filter_regex_list and returns the remaining variables.
Optionally, if invert=True, the complement set is returned.
Args:
variables: a list of tensorflow variables.
filter_regex_list: a list of string regular expressions.
invert: (boolean). If True, returns the complement of the filter set; that
is, all variables matching filter_regex are kept and all others discarded.
Returns:
a list of filtered variables.
"""
kept_vars
=
[]
variables_to_ignore_patterns
=
filter
(
None
,
filter_regex_list
)
for
var
in
variables
:
add
=
True
for
pattern
in
variables_to_ignore_patterns
:
if
re
.
match
(
pattern
,
var
.
op
.
name
):
add
=
False
break
if
add
!=
invert
:
kept_vars
.
append
(
var
)
return
kept_vars
def
multiply_gradients_matching_regex
(
grads_and_vars
,
regex_list
,
multiplier
):
"""Multiply gradients whose variable names match a regular expression.
Args:
grads_and_vars: A list of gradient to variable pairs (tuples).
regex_list: A list of string regular expressions.
multiplier: A (float) multiplier to apply to each gradient matching the
regular expression.
Returns:
grads_and_vars: A list of gradient to variable pairs (tuples).
"""
variables
=
[
pair
[
1
]
for
pair
in
grads_and_vars
]
matching_vars
=
filter_variables
(
variables
,
regex_list
,
invert
=
True
)
for
var
in
matching_vars
:
logging
.
info
(
'Applying multiplier %f to variable [%s]'
,
multiplier
,
var
.
op
.
name
)
grad_multipliers
=
{
var
:
float
(
multiplier
)
for
var
in
matching_vars
}
return
slim
.
learning
.
multiply_gradients
(
grads_and_vars
,
grad_multipliers
)
def
freeze_gradients_matching_regex
(
grads_and_vars
,
regex_list
):
"""Freeze gradients whose variable names match a regular expression.
Args:
grads_and_vars: A list of gradient to variable pairs (tuples).
regex_list: A list of string regular expressions.
Returns:
grads_and_vars: A list of gradient to variable pairs (tuples) that do not
contain the variables and gradients matching the regex.
"""
variables
=
[
pair
[
1
]
for
pair
in
grads_and_vars
]
matching_vars
=
filter_variables
(
variables
,
regex_list
,
invert
=
True
)
kept_grads_and_vars
=
[
pair
for
pair
in
grads_and_vars
if
pair
[
1
]
not
in
matching_vars
]
for
var
in
matching_vars
:
logging
.
info
(
'Freezing variable [%s]'
,
var
.
op
.
name
)
return
kept_grads_and_vars
def
get_variables_available_in_checkpoint
(
variables
,
checkpoint_path
):
"""Returns the subset of variables available in the checkpoint.
Inspects given checkpoint and returns the subset of variables that are
available in it.
TODO: force input and output to be a dictionary.
Args:
variables: a list or dictionary of variables to find in checkpoint.
checkpoint_path: path to the checkpoint to restore variables from.
Returns:
A list or dictionary of variables.
Raises:
ValueError: if `variables` is not a list or dict.
"""
if
isinstance
(
variables
,
list
):
variable_names_map
=
{
variable
.
op
.
name
:
variable
for
variable
in
variables
}
elif
isinstance
(
variables
,
dict
):
variable_names_map
=
variables
else
:
raise
ValueError
(
'`variables` is expected to be a list or dict.'
)
ckpt_reader
=
tf
.
train
.
NewCheckpointReader
(
checkpoint_path
)
ckpt_vars
=
ckpt_reader
.
get_variable_to_shape_map
().
keys
()
vars_in_ckpt
=
{}
for
variable_name
,
variable
in
sorted
(
variable_names_map
.
items
()):
if
variable_name
in
ckpt_vars
:
vars_in_ckpt
[
variable_name
]
=
variable
else
:
logging
.
warning
(
'Variable [%s] not available in checkpoint'
,
variable_name
)
if
isinstance
(
variables
,
list
):
return
vars_in_ckpt
.
values
()
return
vars_in_ckpt
object_detection/utils/variables_helper_test.py
0 → 100644
View file @
f282f6ef
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.utils.variables_helper."""
import
os
import
tensorflow
as
tf
from
object_detection.utils
import
variables_helper
class
FilterVariablesTest
(
tf
.
test
.
TestCase
):
def
_create_variables
(
self
):
return
[
tf
.
Variable
(
1.0
,
name
=
'FeatureExtractor/InceptionV3/weights'
),
tf
.
Variable
(
1.0
,
name
=
'FeatureExtractor/InceptionV3/biases'
),
tf
.
Variable
(
1.0
,
name
=
'StackProposalGenerator/weights'
),
tf
.
Variable
(
1.0
,
name
=
'StackProposalGenerator/biases'
)]
def
test_return_all_variables_when_empty_regex
(
self
):
variables
=
self
.
_create_variables
()
out_variables
=
variables_helper
.
filter_variables
(
variables
,
[
''
])
self
.
assertItemsEqual
(
out_variables
,
variables
)
def
test_return_variables_which_do_not_match_single_regex
(
self
):
variables
=
self
.
_create_variables
()
out_variables
=
variables_helper
.
filter_variables
(
variables
,
[
'FeatureExtractor/.*'
])
self
.
assertItemsEqual
(
out_variables
,
variables
[
2
:])
def
test_return_variables_which_do_not_match_any_regex_in_list
(
self
):
variables
=
self
.
_create_variables
()
out_variables
=
variables_helper
.
filter_variables
(
variables
,
[
'FeatureExtractor.*biases'
,
'StackProposalGenerator.*biases'
])
self
.
assertItemsEqual
(
out_variables
,
[
variables
[
0
],
variables
[
2
]])
def
test_return_variables_matching_empty_regex_list
(
self
):
variables
=
self
.
_create_variables
()
out_variables
=
variables_helper
.
filter_variables
(
variables
,
[
''
],
invert
=
True
)
self
.
assertItemsEqual
(
out_variables
,
[])
def
test_return_variables_matching_some_regex_in_list
(
self
):
variables
=
self
.
_create_variables
()
out_variables
=
variables_helper
.
filter_variables
(
variables
,
[
'FeatureExtractor.*biases'
,
'StackProposalGenerator.*biases'
],
invert
=
True
)
self
.
assertItemsEqual
(
out_variables
,
[
variables
[
1
],
variables
[
3
]])
class
MultiplyGradientsMatchingRegexTest
(
tf
.
test
.
TestCase
):
def
_create_grads_and_vars
(
self
):
return
[(
tf
.
constant
(
1.0
),
tf
.
Variable
(
1.0
,
name
=
'FeatureExtractor/InceptionV3/weights'
)),
(
tf
.
constant
(
2.0
),
tf
.
Variable
(
2.0
,
name
=
'FeatureExtractor/InceptionV3/biases'
)),
(
tf
.
constant
(
3.0
),
tf
.
Variable
(
3.0
,
name
=
'StackProposalGenerator/weights'
)),
(
tf
.
constant
(
4.0
),
tf
.
Variable
(
4.0
,
name
=
'StackProposalGenerator/biases'
))]
def
test_multiply_all_feature_extractor_variables
(
self
):
grads_and_vars
=
self
.
_create_grads_and_vars
()
regex_list
=
[
'FeatureExtractor/.*'
]
multiplier
=
0.0
grads_and_vars
=
variables_helper
.
multiply_gradients_matching_regex
(
grads_and_vars
,
regex_list
,
multiplier
)
exp_output
=
[(
0.0
,
1.0
),
(
0.0
,
2.0
),
(
3.0
,
3.0
),
(
4.0
,
4.0
)]
init_op
=
tf
.
global_variables_initializer
()
with
self
.
test_session
()
as
sess
:
sess
.
run
(
init_op
)
output
=
sess
.
run
(
grads_and_vars
)
self
.
assertItemsEqual
(
output
,
exp_output
)
def
test_multiply_all_bias_variables
(
self
):
grads_and_vars
=
self
.
_create_grads_and_vars
()
regex_list
=
[
'.*/biases'
]
multiplier
=
0.0
grads_and_vars
=
variables_helper
.
multiply_gradients_matching_regex
(
grads_and_vars
,
regex_list
,
multiplier
)
exp_output
=
[(
1.0
,
1.0
),
(
0.0
,
2.0
),
(
3.0
,
3.0
),
(
0.0
,
4.0
)]
init_op
=
tf
.
global_variables_initializer
()
with
self
.
test_session
()
as
sess
:
sess
.
run
(
init_op
)
output
=
sess
.
run
(
grads_and_vars
)
self
.
assertItemsEqual
(
output
,
exp_output
)
class
FreezeGradientsMatchingRegexTest
(
tf
.
test
.
TestCase
):
def
_create_grads_and_vars
(
self
):
return
[(
tf
.
constant
(
1.0
),
tf
.
Variable
(
1.0
,
name
=
'FeatureExtractor/InceptionV3/weights'
)),
(
tf
.
constant
(
2.0
),
tf
.
Variable
(
2.0
,
name
=
'FeatureExtractor/InceptionV3/biases'
)),
(
tf
.
constant
(
3.0
),
tf
.
Variable
(
3.0
,
name
=
'StackProposalGenerator/weights'
)),
(
tf
.
constant
(
4.0
),
tf
.
Variable
(
4.0
,
name
=
'StackProposalGenerator/biases'
))]
def
test_freeze_all_feature_extractor_variables
(
self
):
grads_and_vars
=
self
.
_create_grads_and_vars
()
regex_list
=
[
'FeatureExtractor/.*'
]
grads_and_vars
=
variables_helper
.
freeze_gradients_matching_regex
(
grads_and_vars
,
regex_list
)
exp_output
=
[(
3.0
,
3.0
),
(
4.0
,
4.0
)]
init_op
=
tf
.
global_variables_initializer
()
with
self
.
test_session
()
as
sess
:
sess
.
run
(
init_op
)
output
=
sess
.
run
(
grads_and_vars
)
self
.
assertItemsEqual
(
output
,
exp_output
)
class
GetVariablesAvailableInCheckpointTest
(
tf
.
test
.
TestCase
):
def
test_return_all_variables_from_checkpoint
(
self
):
variables
=
[
tf
.
Variable
(
1.0
,
name
=
'weights'
),
tf
.
Variable
(
1.0
,
name
=
'biases'
)
]
checkpoint_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'graph.pb'
)
init_op
=
tf
.
global_variables_initializer
()
saver
=
tf
.
train
.
Saver
(
variables
)
with
self
.
test_session
()
as
sess
:
sess
.
run
(
init_op
)
saver
.
save
(
sess
,
checkpoint_path
)
out_variables
=
variables_helper
.
get_variables_available_in_checkpoint
(
variables
,
checkpoint_path
)
self
.
assertItemsEqual
(
out_variables
,
variables
)
def
test_return_variables_available_in_checkpoint
(
self
):
checkpoint_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'graph.pb'
)
graph1_variables
=
[
tf
.
Variable
(
1.0
,
name
=
'weights'
),
]
init_op
=
tf
.
global_variables_initializer
()
saver
=
tf
.
train
.
Saver
(
graph1_variables
)
with
self
.
test_session
()
as
sess
:
sess
.
run
(
init_op
)
saver
.
save
(
sess
,
checkpoint_path
)
graph2_variables
=
graph1_variables
+
[
tf
.
Variable
(
1.0
,
name
=
'biases'
)]
out_variables
=
variables_helper
.
get_variables_available_in_checkpoint
(
graph2_variables
,
checkpoint_path
)
self
.
assertItemsEqual
(
out_variables
,
graph1_variables
)
def
test_return_variables_available_an_checkpoint_with_dict_inputs
(
self
):
checkpoint_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'graph.pb'
)
graph1_variables
=
[
tf
.
Variable
(
1.0
,
name
=
'ckpt_weights'
),
]
init_op
=
tf
.
global_variables_initializer
()
saver
=
tf
.
train
.
Saver
(
graph1_variables
)
with
self
.
test_session
()
as
sess
:
sess
.
run
(
init_op
)
saver
.
save
(
sess
,
checkpoint_path
)
graph2_variables_dict
=
{
'ckpt_weights'
:
tf
.
Variable
(
1.0
,
name
=
'weights'
),
'ckpt_biases'
:
tf
.
Variable
(
1.0
,
name
=
'biases'
)
}
out_variables
=
variables_helper
.
get_variables_available_in_checkpoint
(
graph2_variables_dict
,
checkpoint_path
)
self
.
assertTrue
(
isinstance
(
out_variables
,
dict
))
self
.
assertItemsEqual
(
out_variables
.
keys
(),
[
'ckpt_weights'
])
self
.
assertTrue
(
out_variables
[
'ckpt_weights'
].
op
.
name
==
'weights'
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Prev
1
…
12
13
14
15
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment