variables_helper.py 6.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Helper functions for manipulating collections of variables during training.
"""
pkulzc's avatar
pkulzc committed
18
19
20
21
22

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

23
24
25
import logging
import re

26
27
import tensorflow.compat.v1 as tf
import tf_slim as slim
28

29
30
from tensorflow.python.ops import variables as tf_variables

31

32
# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False):
  """Filters out the variables whose names match any of the given regexes.

  Removes every variable whose op name matches any of the regular expressions
  in `filter_regex_list` and returns the remaining variables. Optionally, if
  invert=True, the complement set is returned.

  Matching uses `re.match`, so a pattern only has to match at the start of
  the variable's op name (append `$` to a pattern to require a full match).

  Args:
    variables: a list of tensorflow variables; only `var.op.name` is read, and
      only when there is at least one pattern to test.
    filter_regex_list: a list of string regular expressions. Falsy entries
      (e.g. '' or None) are ignored. None is treated as an empty list.
    invert: (boolean).  If True, returns the complement of the filter set; that
      is, all variables matching any of the regexes are kept and all others
      discarded.

  Returns:
    a list of filtered variables, in their original order.
  """
  if filter_regex_list is None:
    filter_regex_list = []
  # Drop falsy entries: an empty pattern would re.match every name.
  patterns = [regex for regex in filter_regex_list if regex]
  kept_vars = []
  for var in variables:
    # any() short-circuits on the first matching pattern.
    matched = any(re.match(pattern, var.op.name) for pattern in patterns)
    # Without invert keep non-matching vars; with invert keep matching ones.
    if matched == invert:
      kept_vars.append(var)
  return kept_vars


def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
  """Scales the gradients of variables whose names match any given regex.

  Each matching variable is logged before the scaling is applied.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions, matched against each
      variable's op name via `filter_variables`.
    multiplier: A (float) multiplier to apply to each gradient matching the
      regular expression.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples), as returned
      by `slim.learning.multiply_gradients`.
  """
  candidates = [grad_and_var[1] for grad_and_var in grads_and_vars]
  selected = filter_variables(candidates, regex_list, invert=True)
  for selected_var in selected:
    logging.info('Applying multiplier %f to variable [%s]',
                 multiplier, selected_var.op.name)
  scale_by_var = dict(
      (selected_var, float(multiplier)) for selected_var in selected)
  return slim.learning.multiply_gradients(grads_and_vars, scale_by_var)


def freeze_gradients_matching_regex(grads_and_vars, regex_list):
  """Freeze gradients whose variable names match a regular expression.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions, matched against each
      variable's op name via `filter_variables`.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples) that do not
      contain the variables and gradients matching the regex. The relative
      order of the kept pairs is preserved.
  """
  variables = [pair[1] for pair in grads_and_vars]
  matching_vars = filter_variables(variables, regex_list, invert=True)
  # Compare by object identity instead of `in` on a list, which falls back to
  # `==`. With TF2 equality semantics `tf.Variable.__eq__` is element-wise, so
  # list membership can call bool() on a tensor (error) or treat a distinct
  # variable with equal value as frozen. A set of ids is also O(1) per lookup.
  frozen_var_ids = {id(var) for var in matching_vars}
  kept_grads_and_vars = [pair for pair in grads_and_vars
                         if id(pair[1]) not in frozen_var_ids]
  for var in matching_vars:
    logging.info('Freezing variable [%s]', var.op.name)
  return kept_grads_and_vars


105
106
107
def get_variables_available_in_checkpoint(variables,
                                          checkpoint_path,
                                          include_global_step=True):
  """Returns the subset of `variables` that can be restored from a checkpoint.

  A variable is kept when its name is present in the checkpoint and the
  checkpoint shape equals the variable's shape. Every variable that is skipped
  (missing from the checkpoint or shape-mismatched) is reported with a
  warning.

  TODO(rathodv): force input and output to be a dictionary.

  Args:
    variables: a list or dictionary of variables to find in checkpoint. List
      entries are keyed by `variable.name` for `PartitionedVariable`s and by
      `variable.op.name` otherwise; dictionary keys are used as-is.
    checkpoint_path: path to the checkpoint to restore variables from.
    include_global_step: whether to include `global_step` variable, if it
      exists. Default True.

  Returns:
    A list (when `variables` is a list) or a dictionary (when it is a dict)
    of the available variables, ordered by variable name.
  Raises:
    ValueError: if `variables` is not a list or dict.
  """
  if isinstance(variables, dict):
    name_to_var = variables
  elif isinstance(variables, list):
    name_to_var = {}
    for var in variables:
      is_partitioned = isinstance(var, tf_variables.PartitionedVariable)
      name_to_var[var.name if is_partitioned else var.op.name] = var
  else:
    raise ValueError('`variables` is expected to be a list or dict.')

  checkpoint_shapes = tf.train.NewCheckpointReader(
      checkpoint_path).get_variable_to_shape_map()
  if not include_global_step:
    checkpoint_shapes.pop(tf.GraphKeys.GLOBAL_STEP, None)

  restorable = {}
  for name, var in sorted(name_to_var.items()):
    if name not in checkpoint_shapes:
      logging.warning('Variable [%s] is not available in checkpoint', name)
      continue
    ckpt_shape = checkpoint_shapes[name]
    var_shape = var.shape.as_list()
    if ckpt_shape == var_shape:
      restorable[name] = var
    else:
      logging.warning('Variable [%s] is available in checkpoint, but has an '
                      'incompatible shape with model variable. Checkpoint '
                      'shape: [%s], model variable shape: [%s]. This '
                      'variable will not be initialized from the checkpoint.',
                      name, ckpt_shape, var_shape)

  return (list(restorable.values()) if isinstance(variables, list)
          else restorable)


def get_global_variables_safely():
  """Returns `tf.global_variables()`, but only outside of eager execution.

  Variables are not tracked in the global collection when executing eagerly,
  so in that case this raises instead of returning the collection. Eager
  callers should use a Keras model's `.variables` property instead.

  Returns:
    The result of `tf.global_variables()`.

  Raises:
    ValueError: if eager execution is enabled, as observed from within
      `tf.init_scope()`.
  """
  with tf.init_scope():
    is_eager = tf.executing_eagerly()
  if is_eager:
    raise ValueError("Global variables collection is not tracked when "
                     "executing eagerly. Use a Keras model's `.variables` "
                     "attribute instead.")
  return tf.global_variables()