variables_helper.py 7.17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Helper functions for manipulating collections of variables during training.
"""
pkulzc's avatar
pkulzc committed
18
19
20
21
22

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

23
24
25
import logging
import re

26
27
import tensorflow.compat.v1 as tf
import tf_slim as slim
28

29
30
from tensorflow.python.ops import variables as tf_variables

31

32
# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False):
  """Selects variables by testing their op names against regex patterns.

  By default, a variable is dropped when its op name matches (via `re.match`,
  so the match is anchored at the start of the name) any non-empty pattern in
  filter_regex_list, and all remaining variables are returned. With
  invert=True the selection is flipped: only the matching variables are
  returned.

  Args:
    variables: a list of tensorflow variables.
    filter_regex_list: a list of string regular expressions. Empty entries
      are skipped.
    invert: (boolean).  If True, returns the variables matching at least one
      pattern instead of the variables matching none.

  Returns:
    a list of the selected variables, in the order they were given.

  Raises:
    ValueError: if called while executing eagerly.
  """
  if tf.executing_eagerly():
    raise ValueError('Accessing variables is not supported in eager mode.')
  # Falsy entries (e.g. empty strings) would match every name, so skip them.
  active_patterns = [regex for regex in filter_regex_list if regex]
  selected = []
  for candidate in variables:
    is_excluded = any(
        re.match(regex, candidate.op.name) for regex in active_patterns)
    # Non-matching variables are kept normally; matching ones when inverted.
    if (not is_excluded) != invert:
      selected.append(candidate)
  return selected


def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
  """Scales the gradients of variables whose op names match a regex.

  The variables are taken from the second element of each pair, matched
  against regex_list with `filter_variables(..., invert=True)`, and the
  resulting variable-to-multiplier mapping is handed to
  `slim.learning.multiply_gradients`.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.
    multiplier: A (float) multiplier to apply to each gradient matching the
      regular expression.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples), as returned
      by `slim.learning.multiply_gradients`.

  Raises:
    ValueError: if called while executing eagerly.
  """
  if tf.executing_eagerly():
    raise ValueError('Accessing variables is not supported in eager mode.')
  all_vars = [entry[1] for entry in grads_and_vars]
  scaled_vars = filter_variables(all_vars, regex_list, invert=True)
  for scaled_var in scaled_vars:
    logging.info(
        'Applying multiplier %f to variable [%s]', multiplier,
        scaled_var.op.name)
  multiplier_by_var = {
      scaled_var: float(multiplier) for scaled_var in scaled_vars
  }
  return slim.learning.multiply_gradients(grads_and_vars, multiplier_by_var)


def freeze_gradients_matching_regex(grads_and_vars, regex_list):
  """Drops the gradient/variable pairs whose variable name matches a regex.

  Variables are taken from the second element of each pair and matched against
  regex_list with `filter_variables(..., invert=True)`. Each matched variable
  is logged as frozen and its pair is left out of the result.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples) that do not
      contain the variables and gradients matching the regex.

  Raises:
    ValueError: if called while executing eagerly.
  """
  if tf.executing_eagerly():
    raise ValueError('Accessing variables is not supported in eager mode.')
  candidate_vars = [entry[1] for entry in grads_and_vars]
  frozen_vars = filter_variables(candidate_vars, regex_list, invert=True)
  trainable_pairs = []
  for entry in grads_and_vars:
    if entry[1] in frozen_vars:
      continue  # Frozen: this pair is not passed on.
    trainable_pairs.append(entry)
  for frozen_var in frozen_vars:
    logging.info('Freezing variable [%s]', frozen_var.op.name)
  return trainable_pairs


111
112
113
def get_variables_available_in_checkpoint(variables,
                                          checkpoint_path,
                                          include_global_step=True):
  """Returns the variables that can be restored from a checkpoint.

  Reads the name-to-shape map of the checkpoint and keeps every given variable
  whose name appears there with an identical shape. A variable that is missing
  from the checkpoint, or present with a different shape, is reported with a
  warning and left out of the result.

  For list input, each variable is looked up under `variable.op.name`
  (`variable.name` for a PartitionedVariable). For dict input, the dict keys
  are used as the checkpoint names.

  Args:
    variables: a list or dictionary of variables to find in checkpoint.
    checkpoint_path: path to the checkpoint to restore variables from.
    include_global_step: whether to include `global_step` variable, if it
      exists. Default True.

  Returns:
    A list or dictionary of variables, matching the type of `variables`.

  Raises:
    ValueError: if `variables` is not a list or dict, or if called while
      executing eagerly.
  """
  # TODO(rathodv): force input and output to be a dictionary.
  if tf.executing_eagerly():
    raise ValueError('Accessing variables is not supported in eager mode.')

  returns_list = isinstance(variables, list)
  if returns_list:
    name_to_variable = {}
    for item in variables:
      is_partitioned = isinstance(item, tf_variables.PartitionedVariable)
      key = item.name if is_partitioned else item.op.name
      name_to_variable[key] = item
  elif isinstance(variables, dict):
    name_to_variable = variables
  else:
    raise ValueError('`variables` is expected to be a list or dict.')

  reader = tf.train.NewCheckpointReader(checkpoint_path)
  ckpt_shapes = reader.get_variable_to_shape_map()
  if not include_global_step:
    # Removing the entry makes global_step look absent from the checkpoint.
    ckpt_shapes.pop(tf.GraphKeys.GLOBAL_STEP, None)

  restorable = {}
  for name, model_var in sorted(name_to_variable.items()):
    if name not in ckpt_shapes:
      logging.warning('Variable [%s] is not available in checkpoint', name)
      continue
    if ckpt_shapes[name] == model_var.shape.as_list():
      restorable[name] = model_var
      continue
    logging.warning('Variable [%s] is available in checkpoint, but has an '
                    'incompatible shape with model variable. Checkpoint '
                    'shape: [%s], model variable shape: [%s]. This '
                    'variable will not be initialized from the checkpoint.',
                    name, ckpt_shapes[name], model_var.shape.as_list())

  return list(restorable.values()) if returns_list else restorable
pkulzc's avatar
pkulzc committed
168
169
170
171
172
173
174
175
176
177
178
179
180


def get_global_variables_safely():
  """Returns tf.global_variables(), refusing to run under eager execution.

  Eager execution is checked twice: once in the current context and once
  inside `tf.init_scope()`. Either check raises a ValueError, because the
  global variables collection is not tracked when executing eagerly.

  If executing eagerly, use a Keras model's .variables property instead.

  Returns:
    The result of tf.global_variables()

  Raises:
    ValueError: if eager execution is detected by either check.
  """
  if tf.executing_eagerly():
    raise ValueError('Accessing variables is not supported in eager mode.')
  # NOTE(review): presumably the second check catches calls made while a
  # function is being traced under an eager outer context -- confirm against
  # the tf.init_scope documentation.
  with tf.init_scope():
    outer_context_is_eager = tf.executing_eagerly()
    if outer_context_is_eager:
      raise ValueError(
          "Global variables collection is not tracked when executing "
          "eagerly. Use a Keras model's `.variables` attribute instead.")
  return tf.global_variables()