# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Helper functions for manipulating collections of variables during training.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import os
import re

import tensorflow.compat.v1 as tf
import tf_slim as slim

from tensorflow.python.ops import variables as tf_variables

# Maps checkpoint types to variable name prefixes that are no longer
# supported.
DETECTION_FEATURE_EXTRACTOR_MSG = """\
The checkpoint type 'detection' is not supported when it contains variable
names with 'feature_extractor'. Please download the new checkpoint file from
model zoo.
"""

DEPRECATED_CHECKPOINT_MAP = {
    'detection': ('feature_extractor', DETECTION_FEATURE_EXTRACTOR_MSG)
}


# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False):
  """Filters out the variables matching the filter_regex.

  Filters out the variables whose name matches any of the regular
  expressions in filter_regex_list and returns the remaining variables.
  Optionally, if invert=True, the complement set is returned.

  Args:
    variables: a list of tensorflow variables.
    filter_regex_list: a list of string regular expressions.
    invert: (boolean). If True, returns the complement of the filter set; that
      is, all variables matching filter_regex are kept and all others
      discarded.

  Returns:
    a list of filtered variables.
  """
  kept_vars = []
  variables_to_ignore_patterns = [fre for fre in filter_regex_list if fre]
  for var in variables:
    add = True
    for pattern in variables_to_ignore_patterns:
      if re.match(pattern, var.op.name):
        add = False
        break
    if add != invert:
      kept_vars.append(var)
  return kept_vars


def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
  """Multiplies gradients whose variable names match a regular expression.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.
    multiplier: A (float) multiplier to apply to each gradient matching the
      regular expression.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples).
  """
  variables = [pair[1] for pair in grads_and_vars]
  matching_vars = filter_variables(variables, regex_list, invert=True)
  for var in matching_vars:
    logging.info('Applying multiplier %f to variable [%s]',
                 multiplier, var.op.name)
  grad_multipliers = {var: float(multiplier) for var in matching_vars}
  return slim.learning.multiply_gradients(grads_and_vars, grad_multipliers)
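
# Example usage (an illustrative sketch kept as a comment so the module stays
# import-safe; `optimizer`, `loss`, and the scope name 'FeatureExtractor' are
# hypothetical and depend on the model being trained):
#
#   grads_and_vars = optimizer.compute_gradients(loss)
#   # Scale the gradients of all feature extractor variables by 10x while
#   # leaving the remaining gradients untouched.
#   grads_and_vars = multiply_gradients_matching_regex(
#       grads_and_vars, ['FeatureExtractor/.*'], multiplier=10.0)
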
def freeze_gradients_matching_regex(grads_and_vars, regex_list):
  """Freezes gradients whose variable names match a regular expression.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples) that do not
      contain the variables and gradients matching the regex.
  """
  variables = [pair[1] for pair in grads_and_vars]
  matching_vars = filter_variables(variables, regex_list, invert=True)
  kept_grads_and_vars = [pair for pair in grads_and_vars
                         if pair[1] not in matching_vars]
  for var in matching_vars:
    logging.info('Freezing variable [%s]', var.op.name)
  return kept_grads_and_vars


def get_variables_available_in_checkpoint(variables,
                                          checkpoint_path,
                                          include_global_step=True):
  """Returns the subset of variables available in the checkpoint.

  Inspects the given checkpoint and returns the subset of variables that are
  available in it.

  TODO(rathodv): force input and output to be a dictionary.

  Args:
    variables: a list or dictionary of variables to find in the checkpoint.
    checkpoint_path: path to the checkpoint to restore variables from.
    include_global_step: whether to include the `global_step` variable, if it
      exists. Default True.

  Returns:
    A list or dictionary of variables.

  Raises:
    ValueError: if `variables` is not a list or dict.
  """
  if isinstance(variables, list):
    variable_names_map = {}
    for variable in variables:
      if isinstance(variable, tf_variables.PartitionedVariable):
        name = variable.name
      else:
        name = variable.op.name
      variable_names_map[name] = variable
  elif isinstance(variables, dict):
    variable_names_map = variables
  else:
    raise ValueError('`variables` is expected to be a list or dict.')
  ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
  ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
  if not include_global_step:
    ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
  vars_in_ckpt = {}
  for variable_name, variable in sorted(variable_names_map.items()):
    if variable_name in ckpt_vars_to_shape_map:
      if ckpt_vars_to_shape_map[variable_name] == variable.shape.as_list():
        vars_in_ckpt[variable_name] = variable
      else:
        logging.warning('Variable [%s] is available in checkpoint, but has an '
                        'incompatible shape with model variable. Checkpoint '
                        'shape: [%s], model variable shape: [%s]. This '
                        'variable will not be initialized from the '
                        'checkpoint.',
                        variable_name, ckpt_vars_to_shape_map[variable_name],
                        variable.shape.as_list())
    else:
      logging.warning('Variable [%s] is not available in checkpoint',
                      variable_name)
  if isinstance(variables, list):
    return list(vars_in_ckpt.values())
  return vars_in_ckpt


def get_global_variables_safely():
  """If not executing eagerly, returns tf.global_variables().

  Raises a ValueError if eager execution is enabled, because the global
  variables collection is not tracked when executing eagerly. If executing
  eagerly, use a Keras model's `.variables` property instead.

  Returns:
    The result of tf.global_variables().
  """
  with tf.init_scope():
    if tf.executing_eagerly():
      raise ValueError("Global variables collection is not tracked when "
                       "executing eagerly. Use a Keras model's `.variables` "
                       "attribute instead.")
    return tf.global_variables()
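
# Example usage (an illustrative sketch kept as a comment so the module stays
# import-safe; the checkpoint path is hypothetical). Restoring only the
# variables that exist in the checkpoint with compatible shapes:
#
#   variables = get_global_variables_safely()
#   available = get_variables_available_in_checkpoint(
#       variables, '/tmp/pretrained/model.ckpt', include_global_step=False)
#   init_saver = tf.train.Saver(available)
#   # Later, inside a tf.Session:
#   #   init_saver.restore(sess, '/tmp/pretrained/model.ckpt')
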
def ensure_checkpoint_supported(checkpoint_path, checkpoint_type, model_dir):
  """Ensures that the given checkpoint can be properly loaded.

  Performs the following checks:
  1. Raises an error if checkpoint_path and model_dir are the same.
  2. Checks that checkpoint_path does not contain a deprecated checkpoint file
     by inspecting its variables.

  Args:
    checkpoint_path: str, path to checkpoint.
    checkpoint_type: str, denotes the type of checkpoint.
    model_dir: The model directory to store intermediate training checkpoints.

  Raises:
    RuntimeError: If
      1. We detect a deprecated checkpoint file.
      2. model_dir and checkpoint_path are in the same directory.
  """
  variables = tf.train.list_variables(checkpoint_path)
  if checkpoint_type in DEPRECATED_CHECKPOINT_MAP:
    blocked_prefix, msg = DEPRECATED_CHECKPOINT_MAP[checkpoint_type]
    for var_name, _ in variables:
      if var_name.startswith(blocked_prefix):
        tf.logging.error('Found variable name - %s with prefix %s', var_name,
                         blocked_prefix)
        raise RuntimeError(msg)

  checkpoint_path_dir = os.path.abspath(os.path.dirname(checkpoint_path))
  model_dir = os.path.abspath(model_dir)

  if model_dir == checkpoint_path_dir:
    raise RuntimeError(
        'Checkpoint dir ({}) and model_dir ({}) cannot be the same. Please '
        'set model_dir to a different path.'.format(checkpoint_path_dir,
                                                    model_dir))
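
# Example usage (an illustrative sketch kept as a comment so the module stays
# import-safe; both paths are hypothetical):
#
#   ensure_checkpoint_supported(
#       checkpoint_path='/tmp/pretrained/model.ckpt',
#       checkpoint_type='detection',
#       model_dir='/tmp/train_dir')
#
# A RuntimeError is raised if the 'detection' checkpoint contains deprecated
# 'feature_extractor' variables, or if model_dir resolves to the same
# directory that holds the checkpoint.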