"magic_pdf/vscode:/vscode.git/clone" did not exist on "c968ce860dd284f6814fe69456d6c19d5cdf9a18"
Commit 143464d2 authored by pkulzc's avatar pkulzc
Browse files

Sync to latest.

parents 1f4747a4 c3b26603
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizers for use in unrolled optimization.
These optimizers contain a compute_updates function and their own ability to
keep track of internal state.
These functions can be used with a tf.while_loop to perform multiple training
steps per sess.run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import tensorflow as tf
import sonnet as snt
from learning_unsupervised_learning import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
class UnrollableOptimizer(snt.AbstractModule):
"""Interface for optimizers that can be used in unrolled computation.
apply_gradients is derived from compute_updates and assign_state.
"""
def __init__(self, *args, **kwargs):
super(UnrollableOptimizer, self).__init__(*args, **kwargs)
self()
@abc.abstractmethod
def compute_updates(self, xs, gs, state=None):
"""Compute next step updates for a given variable list and state.
Args:
xs: list of tensors
The "variables" to perform an update on.
Note these must be in the same order as when get_state was originally
called.
gs: list of tensors
Gradients of `xs` with respect to some loss.
state: Any
Optimizer specific state to keep track of accumulators such as momentum
terms
"""
raise NotImplementedError()
def _build(self):
pass
@abc.abstractmethod
def get_state(self, var_list):
"""Get the state value associated with a list of tf.Variables.
This state is commonly going to be a NamedTuple that contains some
mapping between variables and the state associated with those variables.
This state could be a moving momentum variable tracked by the optimizer.
Args:
var_list: list of tf.Variable
Returns:
state: Any
Optimizer specific state
"""
raise NotImplementedError()
def assign_state(self, state):
"""Assigns the state to the optimizers internal variables.
Args:
state: Any
Returns:
op: tf.Operation
The operation that performs the assignment.
"""
raise NotImplementedError()
def apply_gradients(self, grad_vars):
gradients, variables = zip(*grad_vars)
state = self.get_state(variables)
new_vars, new_state = self.compute_updates(variables, gradients, state)
assign_op = self.assign_state(new_state)
op = utils.assign_variables(variables, new_vars)
return tf.group(assign_op, op, name="apply_gradients")
class UnrollableGradientDescentRollingOptimizer(UnrollableOptimizer):
def __init__(self,
learning_rate,
name="UnrollableGradientDescentRollingOptimizer"):
self.learning_rate = learning_rate
super(UnrollableGradientDescentRollingOptimizer, self).__init__(name=name)
def compute_updates(self, xs, gs, learning_rates, state):
new_vars = []
for x, g, lr in utils.eqzip(xs, gs, learning_rates):
if lr is None:
lr = self.learning_rate
if g is not None:
new_vars.append((x * (1 - lr) - g * lr))
else:
new_vars.append(x)
return new_vars, state
def get_state(self, var_list):
return tf.constant(0.0)
def assign_state(self, state, var_list=None):
return tf.no_op()
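# --- Hedged usage sketch (not part of the original file). ---
# A minimal illustration of one unrolled update step with the rolling
# gradient-descent optimizer above, assuming TF 1.x graph mode. The variable
# values and learning rate below are made up for illustration.
if __name__ == "__main__":
  opt = UnrollableGradientDescentRollingOptimizer(learning_rate=0.1)
  xs = [tf.Variable([1.0, 2.0]), tf.Variable([3.0])]
  loss = tf.reduce_sum(xs[0] ** 2) + tf.reduce_sum(xs[1] ** 2)
  gs = tf.gradients(loss, xs)
  state = opt.get_state(xs)
  # compute_updates returns new parameter values as tensors rather than
  # assigning them, so the result can feed the next step of an unroll.
  new_xs, new_state = opt.compute_updates(xs, gs, [None, None], state)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(new_xs))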
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Script that iteratively applies the unsupervised update rule and evaluates the
meta-objective performance.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from absl import app
from learning_unsupervised_learning import evaluation
from learning_unsupervised_learning import datasets
from learning_unsupervised_learning import architectures
from learning_unsupervised_learning import summary_utils
from learning_unsupervised_learning import meta_objective
import tensorflow as tf
import sonnet as snt
from tensorflow.contrib.framework.python.framework import checkpoint_utils
flags.DEFINE_string("checkpoint", None, "Dir to load pretrained update rule from")
flags.DEFINE_string("train_log_dir", None, "Training log directory")
FLAGS = flags.FLAGS
def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
dataset_fn = datasets.mnist.TinyMnist
w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess
meta_objectives = []
meta_objectives.append(
meta_objective.linear_regression.LinearRegressionMetaObjective)
meta_objectives.append(meta_objective.sklearn.LogisticRegression)
checkpoint_vars, train_one_step_op, (
base_model, dataset) = evaluation.construct_evaluation_graph(
theta_process_fn=theta_process_fn,
w_learner_fn=w_learner_fn,
dataset_fn=dataset_fn,
meta_objectives=meta_objectives)
batch = dataset()
pre_logit, outputs = base_model(batch)
global_step = tf.train.get_or_create_global_step()
var_list = list(
snt.get_variables_in_module(base_model, tf.GraphKeys.TRAINABLE_VARIABLES))
tf.logging.info("all vars")
for v in tf.all_variables():
tf.logging.info(" %s" % str(v))
global_step = tf.train.get_global_step()
accumulate_global_step = global_step.assign_add(1)
reset_global_step = global_step.assign(0)
train_op = tf.group(
train_one_step_op, accumulate_global_step, name="train_op")
summary_op = tf.summary.merge_all()
file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])
if checkpoint:
str_var_list = checkpoint_utils.list_variables(checkpoint)
name_to_v_map = {v.op.name: v for v in tf.all_variables()}
var_list = [
name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
]
saver = tf.train.Saver(var_list)
missed_variables = [
v.op.name for v in set(
snt.get_variables_in_scope("LocalWeightUpdateProcess",
tf.GraphKeys.GLOBAL_VARIABLES)) -
set(var_list)
]
assert len(missed_variables) == 0, "Missed a theta variable."
hooks = []
with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:
# Global step is restored from the eval job's checkpoint, or is zero for a fresh run.
step = sess.run(global_step)
if step == 0 and checkpoint:
tf.logging.info("force restore")
saver.restore(sess, checkpoint)
tf.logging.info("force restore done")
sess.run(reset_global_step)
step = sess.run(global_step)
while step < num_steps:
if step % eval_every_n_steps == 0:
s, _, step = sess.run([summary_op, train_op, global_step])
file_writer.add_summary(s, step)
else:
_, step = sess.run([train_op, global_step])
def main(argv):
train(FLAGS.train_log_dir, FLAGS.checkpoint)
if __name__ == "__main__":
app.run(main)
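# --- Hedged usage note (not part of the original file). ---
# The script above is typically launched with the two flags it defines. The
# script filename and paths below are placeholders for illustration, not
# values from the original project:
#   python <this_script>.py \
#     --checkpoint=/path/to/pretrained_update_rule/model.ckpt \
#     --train_log_dir=/path/to/eval_logs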
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import functools
import threading
import tensorflow as tf
import matplotlib
import numpy as np
import time
import re
import math
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import scipy.signal
from tensorflow.python.util import tf_should_use
from tensorflow.contrib.summary import summary_ops
from tensorflow.python.ops import summary_op_util
from tensorflow.contrib.summary import gen_summary_ops
_DEBUG_DISABLE_SUMMARIES = False
class LoggingFileWriter(tf.summary.FileWriter):
"""A FileWriter that also logs things out.
This is entirely for ease of debugging / not having to open up TensorBoard
a lot.
"""
def __init__(self, logdir, regexes=[], **kwargs):
self.regexes = regexes
super(LoggingFileWriter, self).__init__(logdir, **kwargs)
def add_summary(self, summary, global_step):
if type(summary) != tf.Summary:
summary_p = tf.Summary()
summary_p.ParseFromString(summary)
summary = summary_p
for s in summary.value:
for exists in [re.match(p, s.tag) for p in self.regexes]:
if exists is not None:
tf.logging.info("%d ] %s : %f", global_step, s.tag, s.simple_value)
break
super(LoggingFileWriter, self).add_summary(summary, global_step)
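# --- Hedged usage note (not part of the original file). ---
# A LoggingFileWriter is constructed like a regular tf.summary.FileWriter,
# plus a list of regexes selecting which summary tags to echo through
# tf.logging.info. The directory and pattern below are illustrative only:
#   writer = LoggingFileWriter("/tmp/train_logs", regexes=[r".*loss.*"])
#   writer.add_summary(serialized_or_proto_summary, global_step=step)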
def image_grid(images, max_grid_size=4, border=1):
"""Given images and N, return first N^2 images as an NxN image grid.
Args:
images: a `Tensor` of size [batch_size, height, width, channels]
max_grid_size: Maximum image grid height/width
Returns:
Single image batch, of dim [1, h*n, w*n, c]
"""
batch_size = images.shape.as_list()[0]
to_pad = int((np.ceil(np.sqrt(batch_size)))**2 - batch_size)
images = tf.pad(images, [[0, to_pad], [0, border], [0, border], [0, 0]])
batch_size = images.shape.as_list()[0]
grid_size = min(int(np.sqrt(batch_size)), max_grid_size)
assert images.shape.as_list()[0] >= grid_size * grid_size
# If we have a depth channel
if images.shape.as_list()[-1] == 4:
images = images[:grid_size * grid_size, :, :, 0:3]
depth = tf.image.grayscale_to_rgb(images[:grid_size * grid_size, :, :, 3:4])
images = tf.reshape(images, [-1, images.shape.as_list()[2], 3])
split = tf.split(images, grid_size, axis=0)
depth = tf.reshape(depth, [-1, images.shape.as_list()[2], 3])
depth_split = tf.split(depth, grid_size, axis=0)
grid = tf.concat(split + depth_split, 1)
return tf.expand_dims(grid, 0)
else:
images = images[:grid_size * grid_size, :, :, :]
images = tf.reshape(
images, [-1, images.shape.as_list()[2],
images.shape.as_list()[3]])
split = tf.split(value=images, num_or_size_splits=grid_size, axis=0)
grid = tf.concat(split, 1)
return tf.expand_dims(grid, 0)
def first_layer_weight_image(weight, shape):
weight_image = tf.reshape(weight,
shape + [tf.identity(weight).shape.as_list()[1]])
# [winx, winy, wout]
mean, var = tf.nn.moments(weight_image, [0,1,2], keep_dims=True)
#mean, var = tf.nn.moments(weight_image, [0,1], keep_dims=True)
weight_image = (weight_image - mean) / tf.sqrt(var + 1e-5)
weight_image = (weight_image + 1.0) / 2.0
weight_image = tf.clip_by_value(weight_image, 0, 1)
weight_image = tf.transpose(weight_image, (3, 0, 1, 2))
grid = image_grid(weight_image, max_grid_size=10)
return grid
def inner_layer_weight_image(weight):
"""Visualize a weight matrix of an inner layer.
Add padding to make it square, then visualize as a grayscale image.
"""
weight = tf.identity(weight) # turn into a tensor
weight = weight / (tf.reduce_max(tf.abs(weight), [0], keep_dims=True))
weight = tf.reshape(weight, [1]+weight.shape.as_list() + [1])
return weight
def activation_image(activations, label_onehot):
"""Make a row sorted by class for each activation. Put a black line around the activations."""
labels = tf.argmax(label_onehot, axis=1)
_, n_classes = label_onehot.shape.as_list()
mean, var = tf.nn.moments(activations, [0, 1])
activations = (activations - mean)/tf.sqrt(var+1e-5)
activations = tf.clip_by_value(activations, -1, 1)
activations = (activations + 1.0) / 2.0 # shift to [0, 1]
canvas = []
for i in xrange(n_classes):
inds = tf.where(tf.equal(labels, i))
def _gather():
return tf.squeeze(tf.gather(activations, inds), 1)
def _empty():
return tf.zeros([0, activations.shape.as_list()[1]], dtype=tf.float32)
assert inds.shape.as_list()[0] is None
x = tf.cond(tf.equal(tf.shape(inds)[0], 0), _empty, _gather)
canvas.append(x)
canvas.append(tf.zeros([1, activations.shape.as_list()[1]]))
canvas = tf.concat(canvas, 0)
canvas = tf.reshape(canvas, [1, activations.shape.as_list()[0]+n_classes, canvas.shape.as_list()[1], 1])
return canvas
def sorted_images(images, label_onehot):
# images is [bs, x, y, c]
labels = tf.argmax(label_onehot, axis=1)
_, n_classes = label_onehot.shape.as_list()
to_stack = []
for i in xrange(n_classes):
inds = tf.where(tf.equal(labels, i))
def _gather():
return tf.squeeze(tf.gather(images, inds), 1)
def _empty():
return tf.zeros([0] + images.shape.as_list()[1:], dtype=tf.float32)
assert inds.shape.as_list()[0] is None
x = tf.cond(tf.equal(tf.shape(inds)[0], 0), _empty, _gather)
to_stack.append(x)
# pad / trim all up to 10.
padded = []
for t in to_stack:
n_found = tf.shape(t)[0]
pad = tf.pad(t[0:10], tf.stack([tf.stack([0,tf.maximum(0, 10-n_found)]), [0,0], [0,0], [0,0]]))
padded.append(pad)
xs = [tf.concat(tf.split(p, 10), axis=1) for p in padded]
ys = tf.concat(xs, axis=2)
ys = tf.cast(tf.clip_by_value(ys, 0., 1.) * 255., tf.uint8)
return ys
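# --- Hedged usage sketch (not part of the original file). ---
# A minimal demonstration of image_grid on a random batch, assuming TF 1.x
# graph mode. The shapes and values below are made up for illustration.
if __name__ == "__main__":
  fake_images = tf.random_uniform([7, 28, 28, 1])
  # 7 images are padded up to fill a square grid; the result is a single
  # image batch of shape [1, grid_height, grid_width, channels].
  grid = image_grid(fake_images, max_grid_size=4)
  with tf.Session() as sess:
    print(sess.run(grid).shape)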
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import tensorflow as tf
import sonnet as snt
import itertools
import functools
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.ops import variable_scope as variable_scope_ops
from sonnet.python.modules import util as snt_util
from tensorflow.python.util import nest
def eqzip(*args):
"""Zip but raises error if lengths don't match.
Args:
*args: list of lists or tuples
Returns:
list: the result of zip
Raises:
ValueError: when the lengths don't match
"""
sizes = [len(x) for x in args]
if not all([sizes[0] == x for x in sizes]):
raise ValueError("Lists are of different sizes. \n %s"%str(sizes))
return zip(*args)
@contextlib.contextmanager
def assert_no_new_variables():
"""Ensure that no tf.Variables are constructed inside the context.
Yields:
None
Raises:
ValueError: if there is a variable created.
"""
num_vars = len(tf.global_variables())
old_variables = tf.global_variables()
yield
if len(tf.global_variables()) != num_vars:
new_vars = set(tf.global_variables()) - set(old_variables)
tf.logging.error("NEW VARIABLES CREATED")
tf.logging.error(10*"=")
for v in new_vars:
tf.logging.error(v)
raise ValueError("Variables created inside an "
"assert_no_new_variables context")
if old_variables != tf.global_variables():
raise ValueError("Variables somehow changed inside an "
"assert_no_new_variables context."
"This means something modified the tf.global_variables()")
def get_variables_in_modules(module_list):
var_list = []
for m in module_list:
var_list.extend(snt.get_variables_in_module(m))
return var_list
def state_barrier_context(state):
"""Return a context manager that prevents interior ops from running
unless the whole state has been computed.
This is to prevent assign race conditions.
"""
tensors = [x for x in nest.flatten(state) if type(x) == tf.Tensor]
tarray = [x.flow for x in nest.flatten(state) if hasattr(x, "flow")]
return tf.control_dependencies(tensors + tarray)
def _identity_fn(tf_entity):
if hasattr(tf_entity, "identity"):
return tf_entity.identity()
else:
return tf.identity(tf_entity)
def state_barrier_result(state):
"""Return the same state, but with a control dependency to prevent it from
being partially computed
"""
with state_barrier_context(state):
return nest.map_structure(_identity_fn, state)
def train_iterator(num_iterations):
"""Iterator that returns an index of the current step.
This iterator runs forever if num_iterations is None;
otherwise it runs for a fixed number of steps.
"""
if num_iterations is None:
return itertools.count()
else:
return xrange(num_iterations)
def print_op(op, msg):
"""Print a string and return an op wrapped in a control dependency to make
sure it ran."""
print_op = tf.Print(tf.constant(0), [tf.constant(0)], msg)
return tf.group(op, print_op)
class MultiQueueRunner(tf.train.QueueRunner):
"""A QueueRunner with multiple queues """
def __init__(self, queues, enqueue_ops):
close_op = tf.group(* [q.close() for q in queues])
cancel_op = tf.group(
* [q.close(cancel_pending_enqueues=True) for q in queues])
queue_closed_exception_types = (errors.OutOfRangeError,)
enqueue_op = tf.group(*enqueue_ops, name="multi_enqueue")
super(MultiQueueRunner, self).__init__(
queues[0],
enqueue_ops=[enqueue_op],
close_op=close_op,
cancel_op=cancel_op,
queue_closed_exception_types=queue_closed_exception_types)
# This function is not elegant, but I tried so many other ways to get this to
# work and this is the only one that ended up not incurring significant
# overhead or obscure TensorFlow bugs.
def sample_n_per_class(dataset, samples_per_class):
"""Create a new callable / dataset object that returns batches of each with
samples_per_class per label.
Args:
dataset: fn
samples_per_class: int
Returns:
function, [] -> batch where batch is the same type as the return of
dataset().
"""
with tf.control_dependencies(None), tf.name_scope(None):
with tf.name_scope("queue_runner/sample_n_per_class"):
batch = dataset()
num_classes = batch.label_onehot.shape.as_list()[1]
batch_size = num_classes * samples_per_class
flatten = nest.flatten(batch)
queues = []
enqueue_ops = []
capacity = samples_per_class * 20
for i in xrange(num_classes):
queue = tf.FIFOQueue(
capacity=capacity,
shapes=[f.shape.as_list()[1:] for f in flatten],
dtypes=[f.dtype for f in flatten])
queues.append(queue)
idx = tf.where(tf.equal(batch.label, i))
sub_batch = []
to_enqueue = []
for elem in batch:
new_e = tf.gather(elem, idx)
new_e = tf.squeeze(new_e, 1)
to_enqueue.append(new_e)
remaining = (capacity - queue.size())
to_add = tf.minimum(tf.shape(idx)[0], remaining)
def _enqueue():
return queue.enqueue_many([t[:to_add] for t in to_enqueue])
enqueue_op = tf.cond(
tf.equal(to_add, 0), tf.no_op, _enqueue)
enqueue_ops.append(enqueue_op)
# This has caused many deadlocks / issues. This is some logging to at least
# shed light on what is going on.
print_lam = lambda: tf.Print(tf.constant(0.0), [q.size() for q in queues], "MultiQueueRunner queues status. Has capacity %d"%capacity)
some_percent_of_time = tf.less(tf.random_uniform([]), 0.0005)
maybe_print = tf.cond(some_percent_of_time, print_lam, lambda: tf.constant(0.0))
with tf.control_dependencies([maybe_print]):
enqueue_ops = [tf.group(e) for e in enqueue_ops]
qr = MultiQueueRunner(queues=queues, enqueue_ops=enqueue_ops)
tf.train.add_queue_runner(qr)
def dequeue_batch():
with tf.name_scope("sample_n_per_batch/dequeue/"):
entries = []
for q in queues:
entries.append(q.dequeue_many(samples_per_class))
flat_batch = [tf.concat(x, 0) for x in zip(*entries)]
idx = tf.random_shuffle(tf.range(batch_size))
flat_batch = [tf.gather(f, idx, axis=0) for f in flat_batch]
return nest.pack_sequence_as(batch, flat_batch)
return dequeue_batch
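# --- Hedged usage note (not part of the original file). ---
# Given a dataset callable whose batches expose `label` and `label_onehot`
# fields, the intended pattern is roughly:
#   get_batch = sample_n_per_class(dataset_fn, samples_per_class=4)
#   balanced_batch = get_batch()
# which dequeues a class-balanced, shuffled batch assembled from the
# per-class FIFO queues once the queue runners have been started.
# `dataset_fn` is a placeholder name for illustration.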
def structure_map_multi(func, values):
all_values = [nest.flatten(v) for v in values]
rets = []
for pair in zip(*all_values):
rets.append(func(pair))
return nest.pack_sequence_as(values[0], rets)
def structure_map_split(func, value):
vv = nest.flatten(value)
rets = []
for v in vv:
rets.append(func(v))
return [nest.pack_sequence_as(value, r) for r in zip(*rets)]
def assign_variables(targets, values):
return tf.group(*[t.assign(v) for t,v in eqzip(targets, values)],
name="assign_variables")
def create_variables_in_class_scope(method):
"""Force the variables constructed in this class to live in the sonnet module.
Wraps a method on a sonnet module.
For example the following will create two different variables.
```
class Mod(snt.AbstractModule):
@create_variables_in_class_scope
def dynamic_thing(self, input, name):
return snt.Linear(output_size=16, name=name)(input)
mod = Mod()
mod.dynamic_thing(x, name="module_nameA")
mod.dynamic_thing(x, name="module_nameB")
# reuse
mod.dynamic_thing(y, name="module_nameA")
```
"""
@functools.wraps(method)
def wrapper(obj, *args, **kwargs):
def default_context_manager(reuse=None):
variable_scope = obj.variable_scope
return tf.variable_scope(variable_scope, reuse=reuse)
variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
default_context_manager)
graph = tf.get_default_graph()
# Temporarily enter the variable scope to capture it
with variable_scope_context_manager() as tmp_variable_scope:
variable_scope = tmp_variable_scope
with variable_scope_ops._pure_variable_scope(
variable_scope, reuse=tf.AUTO_REUSE) as pure_variable_scope:
name_scope = variable_scope.original_name_scope
if name_scope[-1] != "/":
name_scope += "/"
with tf.name_scope(name_scope):
sub_scope = snt_util.to_snake_case(method.__name__)
with tf.name_scope(sub_scope) as scope:
out_ops = method(obj, *args, **kwargs)
return out_ops
return wrapper
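# --- Hedged usage sketch (not part of the original file). ---
# A small demonstration of eqzip and assert_no_new_variables, assuming TF 1.x
# graph mode. The variable name below is made up for illustration.
if __name__ == "__main__":
  print(list(eqzip([1, 2], ["a", "b"])))  # [(1, 'a'), (2, 'b')]
  try:
    eqzip([1, 2], [3])                    # mismatched lengths raise ValueError
  except ValueError as e:
    print("eqzip raised: %s" % e)
  v = tf.get_variable("demo_w", shape=[2], initializer=tf.zeros_initializer())
  with assert_no_new_variables():
    doubled = v * 2.0                     # building ops is fine; new variables are not
  print(doubled)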
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
import tensorflow as tf
from contextlib import contextmanager
from tensorflow.python.ops import variable_scope
# Sanity-check global state to ensure variable_replace is not applied recursively.
_is_variable_replacing = [False]
def in_variable_replace_scope():
return _is_variable_replacing[0]
@contextmanager
def variable_replace(replacements, no_new=True):
""" A context manager that replaces variables.
This is a context manager that replaces all calls to
get_variable with the variable in replacements.
This function does not support recursive application.
Args:
replacements: dict
dictionary mapping a variable to replace (the key), with
the variable one wants to replace this variable with (the value).
no_new: bool
raise an error if variables were created.
This is for sanity checking.
Raises:
ValueError: if a new variable is created or not all the replacements are used.
"""
# TODO(lmetz) This function is a bit scary, as it relies on monkey patching
# the call to get_variable. Ideally this can be done with variable_scope's
# custom_getter attribute, but when initially writing this that was not
# available.
replacements = {k: v for k, v in replacements.items() if not k == v}
init_vars = tf.trainable_variables()
old_get_variable = variable_scope.get_variable
old_tf_get_variable = tf.get_variable
names_replace = {}
has_replaced_names = []
tf.logging.vlog(2, "Trying to replace")
for k, v in replacements.items():
tf.logging.vlog(2, k.name + " >> " + v.name)
tf.logging.vlog(2, "===")
for k, v in replacements.items():
strip_name = k.name.replace("/read:0", "")
strip_name = strip_name.replace(":0", "")
names_replace[strip_name] = v
# TODO(lmetz) is there a cleaner way to do this?
def new_get_variable(name, *args, **kwargs):
#print "Monkeypatch get variable run with name:", name
n = tf.get_variable_scope().name + "/" + name
#print "Monkeypatch get variable run with name:", n
if n in names_replace:
has_replaced_names.append(n)
return names_replace[n]
else:
return old_get_variable(name, *args, **kwargs)
# perform the monkey patch
if _is_variable_replacing[0] == True:
raise ValueError("No recursive calling to variable replace allowed.")
variable_scope.get_variable = new_get_variable
tf.get_variable = new_get_variable
_is_variable_replacing[0] = True
yield
if set(has_replaced_names) != set(names_replace.keys()):
print "Didn't use all replacements"
print "replaced variables that are not requested??"
print "==="
for n in list(set(has_replaced_names) - set(names_replace.keys())):
print n
print "Missed replacing variables"
print "==="
for n in list(set(names_replace.keys()) - set(has_replaced_names)):
print n, "==>", names_replace[n].name
raise ValueError("Fix this -- see stderr")
# undo the monkey patch
tf.get_variable = old_tf_get_variable
variable_scope.get_variable = old_get_variable
_is_variable_replacing[0] = False
final_vars = tf.trainable_variables()
assert set(init_vars) == set(final_vars), "trainable variables changed"
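# --- Hedged usage sketch (not part of the original file). ---
# Illustrates the intended pattern: rebuild an existing variable scope while
# get_variable calls are redirected to replacement variables. All names below
# are made up for illustration. Assumes TF 1.x graph mode.
if __name__ == "__main__":
  with tf.variable_scope("model"):
    w = tf.get_variable("w", shape=[3], initializer=tf.ones_initializer())
  w_replacement = tf.get_variable(
      "w_replacement", shape=[3], initializer=tf.zeros_initializer())
  with variable_replace({w: w_replacement}):
    with tf.variable_scope("model", reuse=True):
      # The monkey-patched get_variable returns w_replacement here.
      w_used = tf.get_variable("w", shape=[3])
  assert w_used is w_replacement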