ModelZoo / ResNet50_tensorflow · Commits

Commit 143464d2, authored Apr 03, 2018 by pkulzc

    Sync to latest.

Parents: 1f4747a4, c3b26603
Changes: 25 · Showing 5 changed files with 835 additions and 0 deletions (+835 −0)
research/learning_unsupervised_learning/optimizers.py          +133 −0
research/learning_unsupervised_learning/run_eval.py            +122 −0
research/learning_unsupervised_learning/summary_utils.py       +181 −0
research/learning_unsupervised_learning/utils.py               +287 −0
research/learning_unsupervised_learning/variable_replace.py    +112 −0
research/learning_unsupervised_learning/optimizers.py  0 → 100644 (new file)
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizers for use in unrolled optimization.
These optimizers contain a compute_updates function and its own ability to keep
track of internal state.
These functions can be used with a tf.while_loop to perform multiple training
steps per sess.run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import tensorflow as tf
import sonnet as snt

from learning_unsupervised_learning import utils

from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
class UnrollableOptimizer(snt.AbstractModule):
  """Interface for optimizers that can be used in unrolled computation.

  apply_gradients is derived from compute_updates and assign_state.
  """

  def __init__(self, *args, **kwargs):
    super(UnrollableOptimizer, self).__init__(*args, **kwargs)
    self()

  @abc.abstractmethod
  def compute_updates(self, xs, gs, state=None):
    """Compute next step updates for a given variable list and state.

    Args:
      xs: list of tensors
        The "variables" to perform an update on.
        Note these must be in the same order for which get_state was
        originally called.
      gs: list of tensors
        Gradients of `xs` with respect to some loss.
      state: Any
        Optimizer specific state to keep track of accumulators such as
        momentum terms.
    """
    raise NotImplementedError()

  def _build(self):
    pass

  @abc.abstractmethod
  def get_state(self, var_list):
    """Get the state value associated with a list of tf.Variables.

    This state is commonly going to be a NamedTuple that contains some
    mapping between variables and the state associated with those variables.
    This state could be a moving momentum variable tracked by the optimizer.

    Args:
      var_list: list of tf.Variable
    Returns:
      state: Any
        Optimizer specific state
    """
    raise NotImplementedError()

  def assign_state(self, state):
    """Assigns the state to the optimizer's internal variables.

    Args:
      state: Any
    Returns:
      op: tf.Operation
        The operation that performs the assignment.
    """
    raise NotImplementedError()

  def apply_gradients(self, grad_vars):
    gradients, variables = zip(*grad_vars)
    state = self.get_state(variables)
    new_vars, new_state = self.compute_updates(variables, gradients, state)
    assign_op = self.assign_state(new_state)
    op = utils.assign_variables(variables, new_vars)
    return tf.group(assign_op, op, name="apply_gradients")


class UnrollableGradientDescentRollingOptimizer(UnrollableOptimizer):

  def __init__(self,
               learning_rate,
               name="UnrollableGradientDescentRollingOptimizer"):
    self.learning_rate = learning_rate
    super(UnrollableGradientDescentRollingOptimizer, self).__init__(name=name)

  def compute_updates(self, xs, gs, learning_rates, state):
    new_vars = []
    for x, g, lr in utils.eqzip(xs, gs, learning_rates):
      if lr is None:
        lr = self.learning_rate
      if g is not None:
        new_vars.append((x * (1 - lr) - g * lr))
      else:
        new_vars.append(x)
    return new_vars, state

  def get_state(self, var_list):
    return tf.constant(0.0)

  def assign_state(self, state, var_list=None):
    return tf.no_op()
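A minimal sketch of how one of these optimizers might be driven by hand (not part of this commit; the variable and loss below are illustrative): the state returned by get_state is threaded through compute_updates, and the resulting values are written back with utils.assign_variables, so the whole step can live inside a tf.while_loop body.

# Hypothetical usage sketch (TF1 graph mode) -- not part of this commit.
import tensorflow as tf

from learning_unsupervised_learning import optimizers, utils

x = tf.get_variable("x", shape=[10], initializer=tf.ones_initializer())
loss = tf.reduce_sum(tf.square(x))
grads = tf.gradients(loss, [x])

opt = optimizers.UnrollableGradientDescentRollingOptimizer(learning_rate=0.1)
state = opt.get_state([x])
# One step of the rolling update x * (1 - lr) - g * lr; a per-variable
# learning rate of None falls back to the constructor value.
new_xs, new_state = opt.compute_updates([x], grads, learning_rates=[None],
                                        state=state)
train_op = tf.group(utils.assign_variables([x], new_xs),
                    opt.assign_state(new_state))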
research/learning_unsupervised_learning/run_eval.py  0 → 100644 (new file)
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Script that iteratively applies the unsupervised update rule and evaluates the
meta-objective performance.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl import app

from learning_unsupervised_learning import evaluation
from learning_unsupervised_learning import datasets
from learning_unsupervised_learning import architectures
from learning_unsupervised_learning import summary_utils
from learning_unsupervised_learning import meta_objective

import tensorflow as tf
import sonnet as snt

from tensorflow.contrib.framework.python.framework import checkpoint_utils
flags.DEFINE_string("checkpoint", None,
                    "Dir to load pretrained update rule from")
flags.DEFINE_string("train_log_dir", None, "Training log directory")

FLAGS = flags.FLAGS
def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
  dataset_fn = datasets.mnist.TinyMnist
  w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
  theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

  meta_objectives = []
  meta_objectives.append(
      meta_objective.linear_regression.LinearRegressionMetaObjective)
  meta_objectives.append(meta_objective.sklearn.LogisticRegression)

  checkpoint_vars, train_one_step_op, (
      base_model, dataset) = evaluation.construct_evaluation_graph(
          theta_process_fn=theta_process_fn,
          w_learner_fn=w_learner_fn,
          dataset_fn=dataset_fn,
          meta_objectives=meta_objectives)
  batch = dataset()
  pre_logit, outputs = base_model(batch)

  global_step = tf.train.get_or_create_global_step()
  var_list = list(
      snt.get_variables_in_module(base_model,
                                  tf.GraphKeys.TRAINABLE_VARIABLES))

  tf.logging.info("all vars")
  for v in tf.all_variables():
    tf.logging.info("  %s" % str(v))

  global_step = tf.train.get_global_step()
  accumulate_global_step = global_step.assign_add(1)
  reset_global_step = global_step.assign(0)

  train_op = tf.group(
      train_one_step_op, accumulate_global_step, name="train_op")

  summary_op = tf.summary.merge_all()

  file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])
  if checkpoint:
    str_var_list = checkpoint_utils.list_variables(checkpoint)
    name_to_v_map = {v.op.name: v for v in tf.all_variables()}
    var_list = [
        name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
    ]
    saver = tf.train.Saver(var_list)
    missed_variables = [
        v.op.name for v in set(
            snt.get_variables_in_scope(
                "LocalWeightUpdateProcess", tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_list)
    ]
    assert len(missed_variables) == 0, "Missed a theta variable."

  hooks = []

  with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:

    # The global step should be restored from the eval job's checkpoint, or be
    # zero for a fresh run.
    step = sess.run(global_step)

    if step == 0 and checkpoint:
      tf.logging.info("force restore")
      saver.restore(sess, checkpoint)
      tf.logging.info("force restore done")
      sess.run(reset_global_step)
      step = sess.run(global_step)

    while step < num_steps:
      if step % eval_every_n_steps == 0:
        s, _, step = sess.run([summary_op, train_op, global_step])
        file_writer.add_summary(s, step)
      else:
        _, step = sess.run([train_op, global_step])


def main(argv):
  train(FLAGS.train_log_dir, FLAGS.checkpoint)


if __name__ == "__main__":
  app.run(main)
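A hedged sketch of how this evaluation script might be driven. The flag names (--checkpoint, --train_log_dir) come from the DEFINE_string calls above; the module path and directory paths below are placeholders, not taken from this commit.

# Hypothetical invocation, mirroring the absl flags defined above.
# From a shell one would typically run something like:
#   python -m learning_unsupervised_learning.run_eval \
#       --checkpoint=/path/to/pretrained_update_rule \
#       --train_log_dir=/tmp/eval_logs
# The same thing can be done programmatically by bypassing flag parsing:
from learning_unsupervised_learning import run_eval

run_eval.train(train_log_dir="/tmp/eval_logs",
               checkpoint="/path/to/pretrained_update_rule",
               eval_every_n_steps=10,
               num_steps=3000)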
research/learning_unsupervised_learning/summary_utils.py  0 → 100644 (new file)
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import functools
import threading

import tensorflow as tf
import matplotlib
import numpy as np
import time
import re
import math

matplotlib.use("Agg")

import matplotlib.pyplot as plt
import scipy.signal

from tensorflow.python.util import tf_should_use
from tensorflow.contrib.summary import summary_ops
from tensorflow.python.ops import summary_op_util
from tensorflow.contrib.summary import gen_summary_ops

_DEBUG_DISABLE_SUMMARIES = False
class LoggingFileWriter(tf.summary.FileWriter):
  """A FileWriter that also logs things out.

  This is entirely for ease of debugging / not having to open up Tensorboard
  a lot.
  """

  def __init__(self, logdir, regexes=[], **kwargs):
    self.regexes = regexes
    super(LoggingFileWriter, self).__init__(logdir, **kwargs)

  def add_summary(self, summary, global_step):
    if type(summary) != tf.Summary:
      summary_p = tf.Summary()
      summary_p.ParseFromString(summary)
      summary = summary_p
    for s in summary.value:
      for exists in [re.match(p, s.tag) for p in self.regexes]:
        if exists is not None:
          tf.logging.info("%d ] %s : %f", global_step, s.tag, s.simple_value)
          break
    super(LoggingFileWriter, self).add_summary(summary, global_step)
def image_grid(images, max_grid_size=4, border=1):
  """Given images and N, return first N^2 images as an NxN image grid.

  Args:
    images: a `Tensor` of size [batch_size, height, width, channels]
    max_grid_size: Maximum image grid height/width
    border: number of pixels of padding added to the bottom/right of each image

  Returns:
    Single image batch, of dim [1, h*n, w*n, c]
  """
  batch_size = images.shape.as_list()[0]
  to_pad = int((np.ceil(np.sqrt(batch_size)))**2 - batch_size)
  images = tf.pad(images, [[0, to_pad], [0, border], [0, border], [0, 0]])

  batch_size = images.shape.as_list()[0]
  grid_size = min(int(np.sqrt(batch_size)), max_grid_size)
  assert images.shape.as_list()[0] >= grid_size * grid_size

  # If we have a depth channel
  if images.shape.as_list()[-1] == 4:
    images = images[:grid_size * grid_size, :, :, 0:3]
    depth = tf.image.grayscale_to_rgb(
        images[:grid_size * grid_size, :, :, 3:4])

    images = tf.reshape(images, [-1, images.shape.as_list()[2], 3])
    split = tf.split(images, grid_size, axis=0)

    depth = tf.reshape(depth, [-1, images.shape.as_list()[2], 3])
    depth_split = tf.split(depth, grid_size, axis=0)

    grid = tf.concat(split + depth_split, 1)
    return tf.expand_dims(grid, 0)
  else:
    images = images[:grid_size * grid_size, :, :, :]
    images = tf.reshape(
        images,
        [-1, images.shape.as_list()[2], images.shape.as_list()[3]])
    split = tf.split(value=images, num_or_size_splits=grid_size, axis=0)
    grid = tf.concat(split, 1)
    return tf.expand_dims(grid, 0)
def first_layer_weight_image(weight, shape):
  weight_image = tf.reshape(weight,
                            shape + [tf.identity(weight).shape.as_list()[1]])
  # [winx, winy, wout]
  mean, var = tf.nn.moments(weight_image, [0, 1, 2], keep_dims=True)
  #mean, var = tf.nn.moments(weight_image, [0,1], keep_dims=True)
  weight_image = (weight_image - mean) / tf.sqrt(var + 1e-5)
  weight_image = (weight_image + 1.0) / 2.0
  weight_image = tf.clip_by_value(weight_image, 0, 1)
  weight_image = tf.transpose(weight_image, (3, 0, 1, 2))
  grid = image_grid(weight_image, max_grid_size=10)
  return grid


def inner_layer_weight_image(weight):
  """Visualize a weight matrix of an inner layer.
  Add padding to make it square, then visualize as a gray scale image.
  """
  weight = tf.identity(weight)  # turn into a tensor
  weight = weight / (tf.reduce_max(tf.abs(weight), [0], keep_dims=True))
  weight = tf.reshape(weight, [1] + weight.shape.as_list() + [1])
  return weight


def activation_image(activations, label_onehot):
  """Make a row sorted by class for each activation. Put a black line around
  the activations."""
  labels = tf.argmax(label_onehot, axis=1)
  _, n_classes = label_onehot.shape.as_list()
  mean, var = tf.nn.moments(activations, [0, 1])
  activations = (activations - mean) / tf.sqrt(var + 1e-5)

  activations = tf.clip_by_value(activations, -1, 1)
  activations = (activations + 1.0) / 2.0  # shift to [0, 1]

  canvas = []
  for i in xrange(n_classes):
    inds = tf.where(tf.equal(labels, i))

    def _gather():
      return tf.squeeze(tf.gather(activations, inds), 1)

    def _empty():
      return tf.zeros([0, activations.shape.as_list()[1]], dtype=tf.float32)

    assert inds.shape.as_list()[0] is None
    x = tf.cond(tf.equal(tf.shape(inds)[0], 0), _empty, _gather)
    canvas.append(x)
    canvas.append(tf.zeros([1, activations.shape.as_list()[1]]))

  canvas = tf.concat(canvas, 0)
  canvas = tf.reshape(canvas,
                      [1, activations.shape.as_list()[0] + n_classes,
                       canvas.shape.as_list()[1], 1])
  return canvas


def sorted_images(images, label_onehot):
  # images is [bs, x, y, c]
  labels = tf.argmax(label_onehot, axis=1)
  _, n_classes = label_onehot.shape.as_list()
  to_stack = []
  for i in xrange(n_classes):
    inds = tf.where(tf.equal(labels, i))

    def _gather():
      return tf.squeeze(tf.gather(images, inds), 1)

    def _empty():
      return tf.zeros([0] + images.shape.as_list()[1:], dtype=tf.float32)

    assert inds.shape.as_list()[0] is None
    x = tf.cond(tf.equal(tf.shape(inds)[0], 0), _empty, _gather)
    to_stack.append(x)
  # pad / trim all up to 10.
  padded = []
  for t in to_stack:
    n_found = tf.shape(t)[0]
    pad = tf.pad(t[0:10],
                 tf.stack([tf.stack([0, tf.maximum(0, 10 - n_found)]),
                           [0, 0], [0, 0], [0, 0]]))
    padded.append(pad)

  xs = [tf.concat(tf.split(p, 10), axis=1) for p in padded]
  ys = tf.concat(xs, axis=2)
  ys = tf.cast(tf.clip_by_value(ys, 0., 1.) * 255., tf.uint8)
  return ys
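A small sketch of how image_grid and LoggingFileWriter might be wired together. This is illustrative and not part of the commit; the placeholder shapes, tag names, and log directory are assumptions.

# Hypothetical sketch: turn a batch of images into a grid summary and echo
# matching scalar tags to the INFO log as they are written.
import numpy as np
import tensorflow as tf

from learning_unsupervised_learning import summary_utils

images = tf.placeholder(tf.float32, [16, 28, 28, 1])
grid = summary_utils.image_grid(images, max_grid_size=4)  # -> [1, h*4, w*4, 1]
tf.summary.image("examples", grid)
loss = tf.reduce_mean(images)
tf.summary.scalar("loss", loss)
summary_op = tf.summary.merge_all()

# Only tags matching "loss.*" are logged; everything still goes to disk.
writer = summary_utils.LoggingFileWriter("/tmp/logs", regexes=["loss.*"])
with tf.Session() as sess:
  s = sess.run(summary_op,
               {images: np.random.rand(16, 28, 28, 1).astype("float32")})
  writer.add_summary(s, global_step=0)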
research/learning_unsupervised_learning/utils.py  0 → 100644 (new file)
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import tensorflow as tf
import sonnet as snt
import itertools
import functools

from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.ops import variable_scope as variable_scope_ops
from sonnet.python.modules import util as snt_util

from tensorflow.python.util import nest
def eqzip(*args):
  """Zip but raises error if lengths don't match.

  Args:
    *args: list of lists or tuples

  Returns:
    list: the result of zip
  Raises:
    ValueError: when the lengths don't match
  """
  sizes = [len(x) for x in args]
  if not all([sizes[0] == x for x in sizes]):
    raise ValueError("Lists are of different sizes. \n%s" % str(sizes))
  return zip(*args)


@contextlib.contextmanager
def assert_no_new_variables():
  """Ensure that no tf.Variables are constructed inside the context.

  Yields:
    None
  Raises:
    ValueError: if there is a variable created.
  """
  num_vars = len(tf.global_variables())
  old_variables = tf.global_variables()
  yield
  if len(tf.global_variables()) != num_vars:
    new_vars = set(tf.global_variables()) - set(old_variables)
    tf.logging.error("NEW VARIABLES CREATED")
    tf.logging.error(10 * "=")
    for v in new_vars:
      tf.logging.error(v)
    raise ValueError("Variables created inside an "
                     "assert_no_new_variables context")
  if old_variables != tf.global_variables():
    raise ValueError("Variables somehow changed inside an "
                     "assert_no_new_variables context."
                     "This means something modified the tf.global_variables()")


def get_variables_in_modules(module_list):
  var_list = []
  for m in module_list:
    var_list.extend(snt.get_variables_in_module(m))
  return var_list


def state_barrier_context(state):
  """Return a context manager that prevents interior ops from running
  unless the whole state has been computed.

  This is to prevent assign race conditions.
  """
  tensors = [x for x in nest.flatten(state) if type(x) == tf.Tensor]
  tarray = [x.flow for x in nest.flatten(state) if hasattr(x, "flow")]
  return tf.control_dependencies(tensors + tarray)


def _identity_fn(tf_entity):
  if hasattr(tf_entity, "identity"):
    return tf_entity.identity()
  else:
    return tf.identity(tf_entity)


def state_barrier_result(state):
  """Return the same state, but with a control dependency to prevent it from
  being partially computed.
  """
  with state_barrier_context(state):
    return nest.map_structure(_identity_fn, state)


def train_iterator(num_iterations):
  """Iterator that returns an index of the current step.
  This iterator runs forever if num_iterations is None,
  otherwise it runs for some fixed amount of steps.
  """
  if num_iterations is None:
    return itertools.count()
  else:
    return xrange(num_iterations)


def print_op(op, msg):
  """Print a string and return an op wrapped in a control dependency to make
  sure it ran."""
  print_op = tf.Print(tf.constant(0), [tf.constant(0)], msg)
  return tf.group(op, print_op)


class MultiQueueRunner(tf.train.QueueRunner):
  """A QueueRunner with multiple queues."""

  def __init__(self, queues, enqueue_ops):
    close_op = tf.group(*[q.close() for q in queues])
    cancel_op = tf.group(
        *[q.close(cancel_pending_enqueues=True) for q in queues])
    queue_closed_exception_types = (errors.OutOfRangeError,)

    enqueue_op = tf.group(*enqueue_ops, name="multi_enqueue")

    super(MultiQueueRunner, self).__init__(
        queues[0],
        enqueue_ops=[enqueue_op],
        close_op=close_op,
        cancel_op=cancel_op,
        queue_closed_exception_types=queue_closed_exception_types)
# This function is not elegant, but I tried so many other ways to get this to
# work and this is the only one that ended up not incurring significant
# overhead or obscure tensorflow bugs.
def sample_n_per_class(dataset, samples_per_class):
  """Create a new callable / dataset object that returns batches with
  samples_per_class examples per label.

  Args:
    dataset: fn
    samples_per_class: int
  Returns:
    function, [] -> batch where batch is the same type as the return of
    dataset().
  """
  with tf.control_dependencies(None), tf.name_scope(None):
    with tf.name_scope("queue_runner/sample_n_per_class"):
      batch = dataset()
      num_classes = batch.label_onehot.shape.as_list()[1]
      batch_size = num_classes * samples_per_class

      flatten = nest.flatten(batch)

      queues = []
      enqueue_ops = []
      capacity = samples_per_class * 20
      for i in xrange(num_classes):
        queue = tf.FIFOQueue(
            capacity=capacity,
            shapes=[f.shape.as_list()[1:] for f in flatten],
            dtypes=[f.dtype for f in flatten])
        queues.append(queue)

        idx = tf.where(tf.equal(batch.label, i))

        sub_batch = []
        to_enqueue = []
        for elem in batch:
          new_e = tf.gather(elem, idx)
          new_e = tf.squeeze(new_e, 1)
          to_enqueue.append(new_e)

        remaining = (capacity - queue.size())
        to_add = tf.minimum(tf.shape(idx)[0], remaining)

        def _enqueue():
          return queue.enqueue_many([t[:to_add] for t in to_enqueue])

        enqueue_op = tf.cond(tf.equal(to_add, 0), tf.no_op, _enqueue)
        enqueue_ops.append(enqueue_op)

      # This has caused many deadlocks / issues. This is some logging to at
      # least shed light on what is going on.
      print_lam = lambda: tf.Print(
          tf.constant(0.0), [q.size() for q in queues],
          "MultiQueueRunner queues status. Has capacity %d" % capacity)
      some_percent_of_time = tf.less(tf.random_uniform([]), 0.0005)
      maybe_print = tf.cond(some_percent_of_time, print_lam,
                            lambda: tf.constant(0.0))
      with tf.control_dependencies([maybe_print]):
        enqueue_ops = [tf.group(e) for e in enqueue_ops]
      qr = MultiQueueRunner(queues=queues, enqueue_ops=enqueue_ops)
      tf.train.add_queue_runner(qr)

  def dequeue_batch():
    with tf.name_scope("sample_n_per_batch/dequeue/"):
      entries = []
      for q in queues:
        entries.append(q.dequeue_many(samples_per_class))

      flat_batch = [tf.concat(x, 0) for x in zip(*entries)]

      idx = tf.random_shuffle(tf.range(batch_size))
      flat_batch = [tf.gather(f, idx, axis=0) for f in flat_batch]
      return nest.pack_sequence_as(batch, flat_batch)

  return dequeue_batch
def structure_map_multi(func, values):
  all_values = [nest.flatten(v) for v in values]
  rets = []
  for pair in zip(*all_values):
    rets.append(func(pair))
  return nest.pack_sequence_as(values[0], rets)


def structure_map_split(func, value):
  vv = nest.flatten(value)
  rets = []
  for v in vv:
    rets.append(func(v))
  return [nest.pack_sequence_as(value, r) for r in zip(*rets)]


def assign_variables(targets, values):
  return tf.group(*[t.assign(v) for t, v in eqzip(targets, values)],
                  name="assign_variables")


def create_variables_in_class_scope(method):
  """Force the variables constructed in this class to live in the sonnet
  module. Wraps a method on a sonnet module.

  For example the following will create two different variables.
  ```
  class Mod(snt.AbstractModule):
    @create_variables_in_class_scope
    def dynamic_thing(self, input, name):
      return snt.Linear(name)(input)
  mod.dynamic_thing(x, name="module_nameA")
  mod.dynamic_thing(x, name="module_nameB")
  # reuse
  mod.dynamic_thing(y, name="module_nameA")
  ```
  """
  @functools.wraps(method)
  def wrapper(obj, *args, **kwargs):
    def default_context_manager(reuse=None):
      variable_scope = obj.variable_scope
      return tf.variable_scope(variable_scope, reuse=reuse)

    variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                             default_context_manager)

    graph = tf.get_default_graph()

    # Temporarily enter the variable scope to capture it
    with variable_scope_context_manager() as tmp_variable_scope:
      variable_scope = tmp_variable_scope

    with variable_scope_ops._pure_variable_scope(
        variable_scope, reuse=tf.AUTO_REUSE) as pure_variable_scope:
      name_scope = variable_scope.original_name_scope
      if name_scope[-1] != "/":
        name_scope += "/"

      with tf.name_scope(name_scope):
        sub_scope = snt_util.to_snake_case(method.__name__)
        with tf.name_scope(sub_scope) as scope:
          out_ops = method(obj, *args, **kwargs)
          return out_ops

  return wrapper
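A brief sketch of the smaller helpers above, under the assumption that they are used from TF1 graph code as in the rest of this commit. The variable names below are illustrative only.

# Hypothetical sketch of eqzip, assert_no_new_variables, and assign_variables.
import tensorflow as tf

from learning_unsupervised_learning import utils

# eqzip behaves like zip but refuses silently mismatched lengths.
pairs = utils.eqzip([1, 2, 3], ["a", "b", "c"])    # fine
# utils.eqzip([1, 2], ["a", "b", "c"])             # raises ValueError

# assert_no_new_variables guards a block that should only reuse variables.
v = tf.get_variable("v", shape=[3])
with utils.assert_no_new_variables():
  w = v * 2.0  # fine: no new tf.Variable is constructed here

# assign_variables pairs targets and values (lengths checked via eqzip).
assign_op = utils.assign_variables([v], [tf.zeros([3])])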
research/learning_unsupervised_learning/variable_replace.py  0 → 100644 (new file)
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division

import tensorflow as tf

from contextlib import contextmanager

from tensorflow.python.ops import variable_scope
# sanity global state to ensure non recursive.
_is_variable_replacing = [False]


def in_variable_replace_scope():
  return _is_variable_replacing[0]
@contextmanager
def variable_replace(replacements, no_new=True):
  """A context manager that replaces variables.

  This is a context manager that replaces all calls to
  get_variable with the variable in replacements.
  This function does not support recursive application.

  Args:
    replacements: dict
      dictionary mapping a variable to replace (the key), with
      the variable one wants to replace this variable with (the value).
    no_new: bool
      raise an error if variables were created.
      This is for sanity checking.

  Raises:
    ValueError: if a new variable is created or not all the replacements are
      used.
  """

  # TODO(lmetz) This function is a bit scary, as it relies on monkey patching
  # the call to get_variable. Ideally this can be done with variable_scope's
  # custom_getter attribute, but when initially writing this that was not
  # available.

  replacements = {k: v for k, v in replacements.items() if not k == v}

  init_vars = tf.trainable_variables()
  old_get_variable = variable_scope.get_variable
  old_tf_get_variable = tf.get_variable

  names_replace = {}
  has_replaced_names = []

  tf.logging.vlog(2, "Trying to replace")
  for k, v in replacements.items():
    tf.logging.vlog(2, k.name + " >> " + v.name)
  tf.logging.vlog(2, "===")

  for k, v in replacements.items():
    strip_name = k.name.replace("/read:0", "")
    strip_name = strip_name.replace(":0", "")
    names_replace[strip_name] = v

  # TODO(lmetz) is there a cleaner way to do this?
  def new_get_variable(name, *args, **kwargs):
    #print "Monkeypatch get variable run with name:", name
    n = tf.get_variable_scope().name + "/" + name
    #print "Monkeypatch get variable run with name:", n
    if n in names_replace:
      has_replaced_names.append(n)
      return names_replace[n]
    else:
      return old_get_variable(name, *args, **kwargs)

  # perform the monkey patch
  if _is_variable_replacing[0] == True:
    raise ValueError("No recursive calling to variable replace allowed.")

  variable_scope.get_variable = new_get_variable
  tf.get_variable = new_get_variable

  _is_variable_replacing[0] = True

  yield

  if set(has_replaced_names) != set(names_replace.keys()):
    print "Didn't use all replacements"
    print "replaced variables that are not requested??"
    print "==="
    for n in list(set(has_replaced_names) - set(names_replace.keys())):
      print n
    print "Missed replacing variables"
    print "==="
    for n in list(set(names_replace.keys()) - set(has_replaced_names)):
      print n, "==>", names_replace[n].name
    raise ValueError("Fix this -- see stderr")

  # undo the monkey patch
  tf.get_variable = old_tf_get_variable
  variable_scope.get_variable = old_get_variable

  _is_variable_replacing[0] = False

  final_vars = tf.trainable_variables()
  assert set(init_vars) == set(final_vars), "trainable variables changed"