Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a8ba923c
Unverified
Commit
a8ba923c
authored
Jul 30, 2020
by
Jaeyoun Kim
Committed by
GitHub
Jul 30, 2020
Browse files
Deprecate old models (#8934)
Deprecate old models
parent
5eb294f8
Changes
278
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
0 additions
and
5299 deletions
+0
-5299
research/fivo/experimental/bounds.py
research/fivo/experimental/bounds.py
+0
-673
research/fivo/experimental/data.py
research/fivo/experimental/data.py
+0
-192
research/fivo/experimental/models.py
research/fivo/experimental/models.py
+0
-1227
research/fivo/experimental/run.sh
research/fivo/experimental/run.sh
+0
-54
research/fivo/experimental/summary_utils.py
research/fivo/experimental/summary_utils.py
+0
-332
research/fivo/experimental/train.py
research/fivo/experimental/train.py
+0
-637
research/fivo/fivo/__init__.py
research/fivo/fivo/__init__.py
+0
-0
research/fivo/fivo/bounds.py
research/fivo/fivo/bounds.py
+0
-317
research/fivo/fivo/bounds_test.py
research/fivo/fivo/bounds_test.py
+0
-183
research/fivo/fivo/data/__init__.py
research/fivo/fivo/data/__init__.py
+0
-0
research/fivo/fivo/data/calculate_pianoroll_mean.py
research/fivo/fivo/data/calculate_pianoroll_mean.py
+0
-65
research/fivo/fivo/data/create_timit_dataset.py
research/fivo/fivo/data/create_timit_dataset.py
+0
-180
research/fivo/fivo/data/datasets.py
research/fivo/fivo/data/datasets.py
+0
-453
research/fivo/fivo/data/datasets_test.py
research/fivo/fivo/data/datasets_test.py
+0
-303
research/fivo/fivo/ghmm_runners.py
research/fivo/fivo/ghmm_runners.py
+0
-235
research/fivo/fivo/ghmm_runners_test.py
research/fivo/fivo/ghmm_runners_test.py
+0
-106
research/fivo/fivo/models/__init__.py
research/fivo/fivo/models/__init__.py
+0
-0
research/fivo/fivo/models/base.py
research/fivo/fivo/models/base.py
+0
-342
No files found.
Too many changes to show.
To preserve performance only
278 of 278+
files are displayed.
Plain diff
Email patch
research/fivo/experimental/bounds.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
namedtuple
import
tensorflow
as
tf
import
summary_utils
as
summ
# A Loss bundles a display name, a scalar loss tensor, and the variable
# collection the loss should be differentiated with respect to. The `vars`
# field defaults to the trainable-variables collection key so most callers
# only supply (name, loss).
Loss = namedtuple("Loss", "name loss vars")
Loss.__new__.__defaults__ = (tf.GraphKeys.TRAINABLE_VARIABLES,)
def iwae(model, observation, num_timesteps, num_samples=1,
         summarize=False):
  """Compute the IWAE evidence lower bound.

  Args:
    model: A callable that computes one timestep of the model.
    observation: A shape [batch_size*num_samples, state_size] Tensor
      containing z_n, the observation for each sequence in the batch.
    num_timesteps: The number of timesteps in each sequence, an integer.
    num_samples: The number of samples to use to compute the IWAE bound.
    summarize: If True, emit a scalar summary of the entropy of the
      normalized importance weights at each timestep.
  Returns:
    log_p_hat: The IWAE estimator of the lower bound on the log marginal,
      divided by num_timesteps.
    losses: A list with a single Loss tuple that you can perform gradient
      descent on to optimize the bound.
    maintain_ema_op: A no-op included for compatibility with FIVO.
    states: The sequence of states sampled (initial state clipped off).
    log_weights: The accumulated log weights after each timestep.
  """
  # Initialization
  num_instances = tf.shape(observation)[0]
  batch_size = tf.cast(num_instances / num_samples, tf.int32)
  states = [model.zero_state(num_instances)]
  log_weights = []
  log_weight_acc = tf.zeros([num_samples, batch_size], dtype=observation.dtype)
  # `range` instead of `xrange` keeps this loop Python-3 compatible, in line
  # with the module's __future__ imports; behavior is identical here.
  for t in range(num_timesteps):
    # run the model for one timestep
    (zt, log_q_zt, log_p_zt, log_p_x_given_z, _) = model(
        states[-1], observation, t)
    # update accumulators
    states.append(zt)
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
    if summarize:
      weight_dist = tf.contrib.distributions.Categorical(
          logits=tf.transpose(log_weight_acc, perm=[1, 0]),
          allow_nan_stats=False)
      weight_entropy = weight_dist.entropy()
      weight_entropy = tf.reduce_mean(weight_entropy)
      tf.summary.scalar("weight_entropy/%d" % t, weight_entropy)
    log_weights.append(log_weight_acc)
  # Compute the lower bound on the log evidence: a log-mean-exp over the
  # sample dimension, averaged per timestep.
  log_p_hat = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
               tf.log(tf.cast(num_samples, observation.dtype))) / num_timesteps
  loss = -tf.reduce_mean(log_p_hat)
  losses = [Loss("log_p_hat", loss)]
  # we clip off the initial state before returning.
  # there are no emas for iwae, so we return a noop for that
  return log_p_hat, losses, tf.no_op(), states[1:], log_weights
def multinomial_resampling(log_weights, states, n, b):
  """Resample states with multinomial resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distribution.
    states: A list of (b*n x d) Tensors that will be resampled from the
      groups of every n-th row.
    n: The number of particles (samples) per filter.
    b: The number of particle filters in the batch.
  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via multinomial
      sampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry
      decisions.
    resampling_parameters: The Tensor of parameters of the resampling
      distribution.
    ancestors: An (n x b) Tensor of integral indices representing the
      ancestry decisions.
    resampling_dist: The distribution object for resampling.
  """
  log_weights = tf.convert_to_tensor(log_weights)
  states = [tf.convert_to_tensor(state) for state in states]
  # Transpose to get a set of b distributions over [0, n).
  resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
  resampling_dist = tf.contrib.distributions.Categorical(
      logits=resampling_parameters)
  # Sample n ancestors per filter; gradients do not flow through the draw.
  ancestors = tf.stop_gradient(
      resampling_dist.sample(sample_shape=n))
  log_probs = resampling_dist.log_prob(ancestors)
  # States are stored flat as [n*b, d] with filter index varying fastest,
  # so flat row index = ancestor * b + filter offset.
  offset = tf.expand_dims(tf.range(b), 0)
  ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
  resampled_states = []
  for state in states:
    resampled_states.append(tf.gather(state, ancestor_inds))
  return (resampled_states, log_probs, resampling_parameters, ancestors,
          resampling_dist)
def stratified_resampling(log_weights, states, n, b):
  """Resample states with stratified resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distribution.
    states: A list of (b*n x d) Tensors that will be resampled from the
      groups of every n-th row.
    n: The number of particles (samples) per filter.
    b: The number of particle filters in the batch.
  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via stratified
      sampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry
      decisions.
    resampling_parameters: The Tensor of parameters of the resampling
      distribution.
    ancestors: An (n x b) Tensor of integral indices representing the
      ancestry decisions.
    resampling_dist: The distribution object for resampling.
  """
  log_weights = tf.convert_to_tensor(log_weights)
  states = [tf.convert_to_tensor(state) for state in states]
  log_weights = tf.transpose(log_weights, perm=[1, 0])
  # Tile so there is one (identical) weight distribution per particle,
  # then normalize: probs is [b, n, n].
  probs = tf.nn.softmax(
      tf.tile(tf.expand_dims(log_weights, axis=1), [1, n, 1]))
  # Prepend a zero column so cdfs[:, :, i] is the CDF just left of weight i;
  # cdfs is [b, n, n+1].
  cdfs = tf.concat([tf.zeros((b, n, 1), dtype=probs.dtype),
                    tf.cumsum(probs, axis=2)], 2)
  # Evenly spaced strata boundaries [0, 1/n, ..., (n-1)/n], one per particle.
  bins = tf.range(n, dtype=probs.dtype) / n
  bins = tf.tile(tf.reshape(bins, [1, -1, 1]), [b, 1, n + 1])
  # Restrict each particle's CDF to its own stratum and rescale to [0, 1].
  strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
  # Per-stratum probabilities are the differences of adjacent CDF values.
  resampling_parameters = strat_cdfs[:, :, 1:] - strat_cdfs[:, :, :-1]
  resampling_dist = tf.contrib.distributions.Categorical(
      probs=resampling_parameters,
      allow_nan_stats=False)
  ancestors = tf.stop_gradient(resampling_dist.sample())
  log_probs = resampling_dist.log_prob(ancestors)
  # Back to (n x b) layout.
  ancestors = tf.transpose(ancestors, perm=[1, 0])
  log_probs = tf.transpose(log_probs, perm=[1, 0])
  # States are stored flat as [n*b, d] with filter index varying fastest,
  # so flat row index = ancestor * b + filter offset.
  offset = tf.expand_dims(tf.range(b), 0)
  ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
  resampled_states = []
  for state in states:
    resampled_states.append(tf.gather(state, ancestor_inds))
  return (resampled_states, log_probs, resampling_parameters, ancestors,
          resampling_dist)
def systematic_resampling(log_weights, states, n, b):
  """Resample states with systematic resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distribution.
    states: A list of (b*n x d) Tensors that will be resampled from the
      groups of every n-th row.
    n: The number of particles (samples) per filter.
    b: The number of particle filters in the batch.
  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via systematic
      sampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry
      decisions.
    resampling_parameters: The Tensor of parameters of the resampling
      distribution.
    ancestors: An (n x b) Tensor of indices representing the ancestry
      decisions.
    resampling_dist: The distribution object for resampling.
  """
  log_weights = tf.convert_to_tensor(log_weights)
  states = [tf.convert_to_tensor(state) for state in states]
  log_weights = tf.transpose(log_weights, perm=[1, 0])
  # Tile so there is one (identical) weight distribution per particle,
  # then normalize: probs is [b, n, n].
  probs = tf.nn.softmax(
      tf.tile(tf.expand_dims(log_weights, axis=1), [1, n, 1]))
  # Prepend a zero column so cdfs[:, :, i] is the CDF just left of weight i;
  # cdfs is [b, n, n+1].
  cdfs = tf.concat([tf.zeros((b, n, 1), dtype=probs.dtype),
                    tf.cumsum(probs, axis=2)], 2)
  # Evenly spaced strata boundaries [0, 1/n, ..., (n-1)/n], one per particle.
  bins = tf.range(n, dtype=probs.dtype) / n
  bins = tf.tile(tf.reshape(bins, [1, -1, 1]), [b, 1, n + 1])
  # Restrict each particle's CDF to its own stratum and rescale to [0, 1].
  strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
  resampling_parameters = strat_cdfs[:, :, 1:] - strat_cdfs[:, :, :-1]
  resampling_dist = tf.contrib.distributions.Categorical(
      probs=resampling_parameters,
      allow_nan_stats=True)
  # Systematic resampling: a single uniform per filter is shared by all n
  # particles; the ancestor index is the count of CDF entries it exceeds.
  U = tf.random_uniform((b, 1, 1), dtype=probs.dtype)
  ancestors = tf.stop_gradient(
      tf.reduce_sum(tf.to_float(U > strat_cdfs[:, :, 1:]), axis=-1))
  log_probs = resampling_dist.log_prob(ancestors)
  # Back to (n x b) layout.
  ancestors = tf.transpose(ancestors, perm=[1, 0])
  log_probs = tf.transpose(log_probs, perm=[1, 0])
  # States are stored flat as [n*b, d] with filter index varying fastest,
  # so flat row index = ancestor * b + filter offset.
  offset = tf.expand_dims(tf.range(b, dtype=probs.dtype), 0)
  # BUGFIX: `ancestors` comes from tf.reduce_sum(tf.to_float(...)) and
  # `offset` is built with a float dtype, so the flat indices were float
  # Tensors — but tf.gather requires integer indices. Cast to int32 before
  # gathering; the values are exact small integers so the cast is lossless.
  ancestor_inds = tf.reshape(
      tf.cast(ancestors * b + offset, tf.int32), [-1])
  resampled_states = []
  for state in states:
    resampled_states.append(tf.gather(state, ancestor_inds))
  return (resampled_states, log_probs, resampling_parameters, ancestors,
          resampling_dist)
def log_blend(inputs, weights):
  """Blends state in the log space.

  Args:
    inputs: A set of scalar states, one for each particle in each particle
      filter. Should be [num_samples, batch_size].
    weights: A set of weights used to blend the state. Each set of weights
      should be of dimension [num_samples] (one weight for each previous
      particle). There should be one set of weights for each new particle in
      each particle filter. Thus the shape should be
      [num_samples, batch_size, num_samples] where the first axis indexes new
      particle and the last axis indexes old particles.
  Returns:
    blended: The blended states, a tensor of shape [num_samples, batch_size].
  """
  # Subtract the per-column max before exponentiating for numerical
  # stability (standard log-sum-exp shifting).
  col_max = tf.reduce_max(inputs, axis=0, keepdims=True)
  # Re-add the shift afterwards through a gradient-stopped copy, with
  # non-finite maxima replaced by zero so they don't poison the result.
  safe_max = tf.stop_gradient(
      tf.where(tf.is_finite(col_max), col_max, tf.zeros_like(col_max)))
  shifted = tf.exp(inputs - col_max)
  # Linearly mix the shifted exponentials with the per-new-particle weights,
  # then return to log space and undo the shift.
  mixed = tf.einsum("ijk,kj->ij", weights, shifted)
  return tf.log(mixed) + safe_max
def relaxed_resampling(log_weights, states, num_samples, batch_size,
                       log_r_x=None, blend_type="log", temperature=0.5,
                       straight_through=False):
  """Resample states with relaxed (differentiable) resampling.

  Instead of discrete ancestor indices, samples relaxed one-hot blending
  weights from a RelaxedOneHotCategorical and mixes the old particles.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distribution.
    states: A list of (b*n x d) Tensors that will be resampled from the
      groups of every n-th row.
    num_samples: n, the number of particles per filter.
    batch_size: b, the number of particle filters.
    log_r_x: Optional [num_samples*batch_size] Tensor of log r values to
      blend along with the states. Ignored if None.
    blend_type: Either "log" (blend log_r_x via log_blend) or "linear"
      (append log_r_x to `states` and blend it linearly like a state).
    temperature: Positive temperature of the RelaxedOneHotCategorical.
    straight_through: If True, use discrete (argmax) choices on the forward
      pass and the soft relaxed choices on the backward pass.
  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled by blending.
    log_probs: Log probabilities of the sampled blending weights under the
      relaxed distribution.
    log_r_x: The blended log r values (or the input log_r_x if no blending
      applied to it).
    resampling_parameters: The Tensor of parameters of the resampling
      distribution.
    ancestors: An (n x b x n) Tensor of relaxed one hot representations of
      the ancestry decisions.
    new_dist: A Categorical distribution over the same logits.
  """
  assert blend_type in ["log", "linear"], "Blend type must be 'log' or 'linear'."
  log_weights = tf.convert_to_tensor(log_weights)
  states = [tf.convert_to_tensor(state) for state in states]
  state_dim = states[0].get_shape().as_list()[-1]
  # weights are num_samples by batch_size, so we transpose to get a
  # set of batch_size distributions over [0,num_samples).
  resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
  resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
      temperature,
      logits=resampling_parameters)
  # sample num_samples samples from the distribution, resulting in a
  # [num_samples, batch_size, num_samples] Tensor that represents a set of
  # [num_samples, batch_size] blending weights. The dimensions represent
  # [sample index, batch index, blending weight index]
  ancestors = resampling_dist.sample(sample_shape=num_samples)
  if straight_through:
    # Forward pass discrete choices, backwards pass soft choices
    hard_ancestor_indices = tf.argmax(ancestors, axis=-1)
    hard_ancestors = tf.one_hot(hard_ancestor_indices, num_samples,
                                dtype=ancestors.dtype)
    # stop_gradient makes the forward value hard while gradients flow
    # through the soft `ancestors` term.
    ancestors = tf.stop_gradient(hard_ancestors - ancestors) + ancestors
  log_probs = resampling_dist.log_prob(ancestors)
  if log_r_x is not None and blend_type == "log":
    # Blend log r in log space using the sampled blending weights.
    log_r_x = tf.reshape(log_r_x, [num_samples, batch_size])
    log_r_x = log_blend(log_r_x, ancestors)
    log_r_x = tf.reshape(log_r_x, [num_samples * batch_size])
  elif log_r_x is not None and blend_type == "linear":
    # If blend type is linear just add log_r to the states that will be blended
    # linearly.
    states.append(log_r_x)
  # transpose the 'indices' to be [batch_index, blending weight index, sample index]
  ancestor_inds = tf.transpose(ancestors, perm=[1, 2, 0])
  resampled_states = []
  for state in states:
    # state is currently [num_samples * batch_size, state_dim] so we reshape
    # to [num_samples, batch_size, state_dim] and then transpose to
    # [batch_size, state_size, num_samples]
    state = tf.transpose(tf.reshape(state, [num_samples, batch_size, -1]),
                         perm=[1, 2, 0])
    # state is now (batch_size, state_size, num_samples)
    # and ancestor is (batch index, blending weight index, sample index)
    # multiplying these gives a matrix of size [batch_size, state_size, num_samples]
    next_state = tf.matmul(state, ancestor_inds)
    # transpose the state to be [num_samples, batch_size, state_size]
    # and then reshape it to match the state format.
    # NOTE(review): this reshape uses state_dim taken from states[0]; when
    # blend_type == "linear" the appended log_r_x has inner dimension 1, so
    # this presumably assumes state_dim is compatible — verify with callers.
    next_state = tf.reshape(tf.transpose(next_state, perm=[2, 0, 1]),
                            [num_samples * batch_size, state_dim])
    resampled_states.append(next_state)
  new_dist = tf.contrib.distributions.Categorical(
      logits=resampling_parameters)
  if log_r_x is not None and blend_type == "linear":
    # If blend type is linear pop off log_r that we added to the states.
    log_r_x = tf.squeeze(resampled_states[-1])
    resampled_states = resampled_states[:-1]
  return (resampled_states, log_probs, log_r_x, resampling_parameters,
          ancestors, new_dist)
def fivo(model, observation, num_timesteps, resampling_schedule,
         num_samples=1, use_resampling_grads=True,
         resampling_type="multinomial", resampling_temperature=0.5,
         aux=True, summarize=False):
  """Compute the FIVO evidence lower bound.

  Args:
    model: A callable that computes one timestep of the model.
    observation: A shape [batch_size*num_samples, state_size] Tensor
      containing z_n, the observation for each sequence in the batch.
    num_timesteps: The number of timesteps in each sequence, an integer.
    resampling_schedule: A list of booleans of length num_timesteps, contains
      True if a resampling should occur on a specific timestep.
    num_samples: The number of samples to use to compute the IWAE bound.
    use_resampling_grads: Whether or not to include the resampling gradients
      in loss.
    resampling_type: The type of resampling, one of "multinomial",
      "stratified", "relaxed-logblend", "relaxed-linearblend",
      "relaxed-stateblend", "relaxed-stateblend-st", or "systematic".
    resampling_temperature: A positive temperature only used for relaxed
      resampling.
    aux: If true, compute the FIVO-AUX bound.
    summarize: If True, emit learning-signal summaries for the resampling
      gradients.
  Returns:
    log_p_hat: The IWAE estimator of the lower bound on the log marginal.
    loss: A tensor that you can perform gradient descent on to optimize the
      bound.
    maintain_ema_op: An op to update the baseline ema used for the resampling
      gradients.
    states: The sequence of states sampled.
  """
  # Initialization
  num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
  batch_size = tf.cast(num_instances / num_samples, tf.int32)
  states = [model.zero_state(num_instances)]
  prev_state = states[0]
  log_weight_acc = tf.zeros(shape=[num_samples, batch_size],
                            dtype=observation.dtype)
  prev_log_r_zt = tf.zeros([num_instances], dtype=observation.dtype)
  # Accumulated weights captured only on resampling timesteps.
  log_weights = []
  # Accumulated weights captured on every timestep.
  log_weights_all = []
  # Per-resampling-event log p hat contributions.
  log_p_hats = []
  # Per-resampling-event log probs of the ancestry decisions.
  resampling_log_probs = []
  for t in xrange(num_timesteps):
    # run the model for one timestep
    (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_zt) = model(
        prev_state, observation, t)
    # update accumulators
    states.append(zt)
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    if aux:
      # FIVO-AUX: telescope the log r terms into the weights. On the final
      # timestep only the previous log r is removed.
      if t == num_timesteps - 1:
        log_weight -= prev_log_r_zt
      else:
        log_weight += log_r_zt - prev_log_r_zt
        prev_log_r_zt = log_r_zt
    log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
    log_weights_all.append(log_weight_acc)
    if resampling_schedule[t]:
      # These objects will be resampled
      to_resample = [states[-1]]
      if aux and "relaxed" not in resampling_type:
        to_resample.append(prev_log_r_zt)
      # do the resampling
      if resampling_type == "multinomial":
        (resampled, resampling_log_prob, _, _, _) = multinomial_resampling(
            log_weight_acc, to_resample, num_samples, batch_size)
      elif resampling_type == "stratified":
        (resampled, resampling_log_prob, _, _, _) = stratified_resampling(
            log_weight_acc, to_resample, num_samples, batch_size)
      elif resampling_type == "systematic":
        (resampled, resampling_log_prob, _, _, _) = systematic_resampling(
            log_weight_acc, to_resample, num_samples, batch_size)
      elif "relaxed" in resampling_type:
        if aux:
          if resampling_type == "relaxed-logblend":
            # Blend log r in log space inside the relaxed resampler.
            (resampled, resampling_log_prob, prev_log_r_zt,
             _, _, _) = relaxed_resampling(
                 log_weight_acc, to_resample, num_samples, batch_size,
                 temperature=resampling_temperature,
                 log_r_x=prev_log_r_zt, blend_type="log")
          elif resampling_type == "relaxed-linearblend":
            # Blend log r linearly along with the states.
            (resampled, resampling_log_prob, prev_log_r_zt,
             _, _, _) = relaxed_resampling(
                 log_weight_acc, to_resample, num_samples, batch_size,
                 temperature=resampling_temperature,
                 log_r_x=prev_log_r_zt, blend_type="linear")
          elif resampling_type == "relaxed-stateblend":
            (resampled, resampling_log_prob, _, _, _, _) = relaxed_resampling(
                log_weight_acc, to_resample, num_samples, batch_size,
                temperature=resampling_temperature)
            # Calculate prev_log_r_zt from the post-resampling state
            prev_r_zt = model.r.r_xn(resampled[0], t)
            prev_log_r_zt = tf.reduce_sum(prev_r_zt.log_prob(observation),
                                          axis=[1])
          elif resampling_type == "relaxed-stateblend-st":
            # Same as relaxed-stateblend but with straight-through
            # (hard forward, soft backward) ancestry choices.
            (resampled, resampling_log_prob, _, _, _, _) = relaxed_resampling(
                log_weight_acc, to_resample, num_samples, batch_size,
                temperature=resampling_temperature,
                straight_through=True)
            # Calculate prev_log_r_zt from the post-resampling state
            prev_r_zt = model.r.r_xn(resampled[0], t)
            prev_log_r_zt = tf.reduce_sum(prev_r_zt.log_prob(observation),
                                          axis=[1])
        else:
          (resampled, resampling_log_prob, _, _, _, _) = relaxed_resampling(
              log_weight_acc, to_resample, num_samples, batch_size,
              temperature=resampling_temperature)
      #if summarize:
      #  resampling_entropy = resampling_dist.entropy()
      #  resampling_entropy = tf.reduce_mean(resampling_entropy)
      #  tf.summary.scalar("weight_entropy/%d" % t, resampling_entropy)
      resampling_log_probs.append(tf.reduce_sum(resampling_log_prob, axis=0))
      prev_state = resampled[0]
      if aux and "relaxed" not in resampling_type:
        # Squeeze out the extra dim potentially added by resampling.
        # prev_log_r_zt should always be [num_instances]
        prev_log_r_zt = tf.squeeze(resampled[1])
      # Update the log p hat estimate, taking a log sum exp over the sample
      # dimension. The appended tensor is [batch_size].
      log_p_hats.append(
          tf.reduce_logsumexp(log_weight_acc, axis=0) -
          tf.log(tf.cast(num_samples, dtype=observation.dtype)))
      # reset the weights
      log_weights.append(log_weight_acc)
      log_weight_acc = tf.zeros_like(log_weight_acc)
    else:
      prev_state = states[-1]
  # Compute the final weight update. If we just resampled this will be zero.
  final_update = (
      tf.reduce_logsumexp(log_weight_acc, axis=0) -
      tf.log(tf.cast(num_samples, dtype=observation.dtype)))
  # If we ever resampled, then sum up the previous log p hat terms
  if len(log_p_hats) > 0:
    log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
  else:
    # otherwise, log_p_hat only comes from the final update
    log_p_hat = final_update
  if use_resampling_grads and any(resampling_schedule):
    # compute the rewards
    # cumsum([a, b, c]) => [a, a+b, a+b+c]
    # learning signal at timestep t is
    # [sum from i=t+1 to T of log_p_hat_i for t=1:T]
    # so we will compute (sum from i=1 to T of log_p_hat_i)
    # and at timestep t will subtract off (sum from i=1 to t of log_p_hat_i)
    # rewards is a [num_resampling_events, batch_size] Tensor
    rewards = tf.stop_gradient(
        tf.expand_dims(log_p_hat, 0) - tf.cumsum(log_p_hats, axis=0))
    batch_avg_rewards = tf.reduce_mean(rewards, axis=1)
    # compute ema baseline.
    # centered_rewards is [num_resampling_events, batch_size]
    baseline_ema = tf.train.ExponentialMovingAverage(decay=0.94)
    maintain_baseline_op = baseline_ema.apply([batch_avg_rewards])
    baseline = tf.expand_dims(baseline_ema.average(batch_avg_rewards), 1)
    centered_rewards = rewards - baseline
    if summarize:
      summ.summarize_learning_signal(rewards, "rewards")
      summ.summarize_learning_signal(centered_rewards, "centered_rewards")
    # compute the loss tensor.
    resampling_grads = tf.reduce_sum(
        tf.stop_gradient(centered_rewards) * resampling_log_probs, axis=0)
    losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat) / num_timesteps),
              Loss("resampling_grads",
                   -tf.reduce_mean(resampling_grads) / num_timesteps)]
  else:
    losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat) / num_timesteps)]
    # No resampling gradients, so there is no baseline to maintain.
    maintain_baseline_op = tf.no_op()
  log_p_hat /= num_timesteps
  # we clip off the initial state before returning.
  return (log_p_hat, losses, maintain_baseline_op, states[1:],
          log_weights_all)
def
fivo_aux_td
(
model
,
observation
,
num_timesteps
,
resampling_schedule
,
num_samples
=
1
,
summarize
=
False
):
"""Compute the FIVO_AUX evidence lower bound."""
# Initialization
num_instances
=
tf
.
cast
(
tf
.
shape
(
observation
)[
0
],
tf
.
int32
)
batch_size
=
tf
.
cast
(
num_instances
/
num_samples
,
tf
.
int32
)
states
=
[
model
.
zero_state
(
num_instances
)]
prev_state
=
states
[
0
]
log_weight_acc
=
tf
.
zeros
(
shape
=
[
num_samples
,
batch_size
],
dtype
=
observation
.
dtype
)
prev_log_r
=
tf
.
zeros
([
num_instances
],
dtype
=
observation
.
dtype
)
# must be pre-resampling
log_rs
=
[]
# must be post-resampling
r_tilde_params
=
[
model
.
r_tilde
.
r_zt
(
states
[
0
],
observation
,
0
)]
log_r_tildes
=
[]
log_p_xs
=
[]
# contains the weight at each timestep before resampling only on resampling timesteps
log_weights
=
[]
# contains weight at each timestep before resampling
log_weights_all
=
[]
log_p_hats
=
[]
for
t
in
xrange
(
num_timesteps
):
# run the model for one timestep
# zt is state, [num_instances, state_dim]
# log_q_zt, log_p_x_given_z is [num_instances]
# r_tilde_mu, r_tilde_sigma is [num_instances, state_dim]
# p_ztplus1 is a normal distribution on [num_instances, state_dim]
(
zt
,
log_q_zt
,
log_p_zt
,
log_p_x_given_z
,
r_tilde_mu
,
r_tilde_sigma_sq
,
p_ztplus1
)
=
model
(
prev_state
,
observation
,
t
)
# Compute the log weight without log r.
log_weight
=
log_p_zt
+
log_p_x_given_z
-
log_q_zt
# Compute log r.
if
t
==
num_timesteps
-
1
:
log_r
=
tf
.
zeros_like
(
prev_log_r
)
else
:
p_mu
=
p_ztplus1
.
mean
()
p_sigma_sq
=
p_ztplus1
.
variance
()
log_r
=
(
tf
.
log
(
r_tilde_sigma_sq
)
-
tf
.
log
(
r_tilde_sigma_sq
+
p_sigma_sq
)
-
tf
.
square
(
r_tilde_mu
-
p_mu
)
/
(
r_tilde_sigma_sq
+
p_sigma_sq
))
log_r
=
0.5
*
tf
.
reduce_sum
(
log_r
,
axis
=-
1
)
#log_weight += tf.stop_gradient(log_r - prev_log_r)
log_weight
+=
log_r
-
prev_log_r
log_weight_acc
+=
tf
.
reshape
(
log_weight
,
[
num_samples
,
batch_size
])
# Update accumulators
states
.
append
(
zt
)
log_weights_all
.
append
(
log_weight_acc
)
log_p_xs
.
append
(
log_p_x_given_z
)
log_rs
.
append
(
log_r
)
# Compute log_r_tilde as [num_instances] Tensor.
prev_r_tilde_mu
,
prev_r_tilde_sigma_sq
=
r_tilde_params
[
-
1
]
prev_log_r_tilde
=
-
0.5
*
tf
.
reduce_sum
(
tf
.
square
(
zt
-
prev_r_tilde_mu
)
/
prev_r_tilde_sigma_sq
,
axis
=-
1
)
#tf.square(tf.stop_gradient(zt) - r_tilde_mu)/r_tilde_sigma_sq, axis=-1)
#tf.square(zt - r_tilde_mu)/r_tilde_sigma_sq, axis=-1)
log_r_tildes
.
append
(
prev_log_r_tilde
)
# optionally resample
if
resampling_schedule
[
t
]:
# These objects will be resampled
if
t
<
num_timesteps
-
1
:
to_resample
=
[
zt
,
log_r
,
r_tilde_mu
,
r_tilde_sigma_sq
]
else
:
to_resample
=
[
zt
,
log_r
]
(
resampled
,
_
,
_
,
_
,
_
)
=
multinomial_resampling
(
log_weight_acc
,
to_resample
,
num_samples
,
batch_size
)
prev_state
=
resampled
[
0
]
# Squeeze out the extra dim potentially added by resampling.
# prev_log_r_zt and log_r_tilde should always be [num_instances]
prev_log_r
=
tf
.
squeeze
(
resampled
[
1
])
if
t
<
num_timesteps
-
1
:
r_tilde_params
.
append
((
resampled
[
2
],
resampled
[
3
]))
# Update the log p hat estimate, taking a log sum exp over the sample
# dimension. The appended tensor is [batch_size].
log_p_hats
.
append
(
tf
.
reduce_logsumexp
(
log_weight_acc
,
axis
=
0
)
-
tf
.
log
(
tf
.
cast
(
num_samples
,
dtype
=
observation
.
dtype
)))
# reset the weights
log_weights
.
append
(
log_weight_acc
)
log_weight_acc
=
tf
.
zeros_like
(
log_weight_acc
)
else
:
prev_state
=
zt
prev_log_r
=
log_r
if
t
<
num_timesteps
-
1
:
r_tilde_params
.
append
((
r_tilde_mu
,
r_tilde_sigma_sq
))
# Compute the final weight update. If we just resampled this will be zero.
final_update
=
(
tf
.
reduce_logsumexp
(
log_weight_acc
,
axis
=
0
)
-
tf
.
log
(
tf
.
cast
(
num_samples
,
dtype
=
observation
.
dtype
)))
# If we ever resampled, then sum up the previous log p hat terms
if
len
(
log_p_hats
)
>
0
:
log_p_hat
=
tf
.
reduce_sum
(
log_p_hats
,
axis
=
0
)
+
final_update
else
:
# otherwise, log_p_hat only comes from the final update
log_p_hat
=
final_update
# Compute the bellman loss.
# Will remove the first timestep as it is not used.
# log p(x_t|z_t) is in row t-1.
log_p_x
=
tf
.
reshape
(
tf
.
stack
(
log_p_xs
),
[
num_timesteps
,
num_samples
,
batch_size
])
# log r_t is contained in row t-1.
# last column is zeros (because at timestep T (num_timesteps) r is 1.
log_r
=
tf
.
reshape
(
tf
.
stack
(
log_rs
),
[
num_timesteps
,
num_samples
,
batch_size
])
# [num_timesteps, num_instances]. log r_tilde_t is in row t-1.
log_r_tilde
=
tf
.
reshape
(
tf
.
stack
(
log_r_tildes
),
[
num_timesteps
,
num_samples
,
batch_size
])
log_lambda
=
tf
.
reduce_mean
(
log_r_tilde
-
log_p_x
-
log_r
,
axis
=
1
,
keepdims
=
True
)
bellman_sos
=
tf
.
reduce_mean
(
tf
.
square
(
log_r_tilde
-
tf
.
stop_gradient
(
log_lambda
+
log_p_x
+
log_r
)),
axis
=
[
0
,
1
])
bellman_loss
=
tf
.
reduce_mean
(
bellman_sos
)
/
num_timesteps
tf
.
summary
.
scalar
(
"bellman_loss"
,
bellman_loss
)
if
len
(
tf
.
get_collection
(
"LOG_P_HAT_VARS"
))
==
0
:
log_p_hat_collection
=
list
(
set
(
tf
.
trainable_variables
())
-
set
(
tf
.
get_collection
(
"R_TILDE_VARS"
)))
for
v
in
log_p_hat_collection
:
tf
.
add_to_collection
(
"LOG_P_HAT_VARS"
,
v
)
log_p_hat
/=
num_timesteps
losses
=
[
Loss
(
"log_p_hat"
,
-
tf
.
reduce_mean
(
log_p_hat
),
"LOG_P_HAT_VARS"
),
Loss
(
"bellman_loss"
,
bellman_loss
,
"R_TILDE_VARS"
)]
return
log_p_hat
,
losses
,
tf
.
no_op
(),
states
[
1
:],
log_weights_all
research/fivo/experimental/data.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Datasets."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
import
models
def make_long_chain_dataset(
    state_size=1,
    num_obs=5,
    steps_per_obs=3,
    variance=1.,
    observation_variance=1.,
    batch_size=4,
    num_samples=1,
    observation_type=models.STANDARD_OBSERVATION,
    transition_type=models.STANDARD_TRANSITION,
    fixed_observation=None,
    dtype="float32"):
  """Creates a long chain data generating process.

  Creates a tf.data.Dataset that provides batches of data from a long
  chain.

  Args:
    state_size: The dimension of the state space of the process.
    num_obs: The number of observations in the chain.
    steps_per_obs: The number of steps between each observation.
    variance: The variance of the normal distributions used at each timestep.
    observation_variance: The variance of the distribution that generates
      observations from the latent state.
    batch_size: The number of trajectories to include in each batch.
    num_samples: The number of replicas of each trajectory to include in each
      batch.
    observation_type: How observation means are computed from the state; one
      of models.OBSERVATION_TYPES.
    transition_type: How transition means are computed from the previous
      state; one of models.TRANSITION_TYPES.
    fixed_observation: If not None, every observation is this constant value
      instead of being sampled from the model.
    dtype: The datatype of the states and observations.

  Returns:
    dataset: A tf.data.Dataset that can be iterated over.
  """
  num_timesteps = num_obs * steps_per_obs

  def data_generator():
    """An infinite generator of latents and observations from the model."""
    while True:
      states = []
      observations = []
      # z0 ~ Normal(0, sqrt(variance)).
      states.append(
          np.random.normal(size=[state_size],
                           scale=np.sqrt(variance)).astype(dtype))
      # Start at 1 because z0 has already been generated; go to
      # num_timesteps+1 to include the num_timesteps-th step.
      # Use range (not xrange) so the generator also runs under Python 3,
      # matching the file's __future__ imports.
      for t in range(1, num_timesteps + 1):
        if transition_type == models.ROUND_TRANSITION:
          loc = np.round(states[-1])
        elif transition_type == models.STANDARD_TRANSITION:
          loc = states[-1]
        new_state = np.random.normal(size=[state_size], loc=loc,
                                     scale=np.sqrt(variance))
        states.append(new_state.astype(dtype))
        if t % steps_per_obs == 0:
          if fixed_observation is None:
            if observation_type == models.SQUARED_OBSERVATION:
              loc = np.square(states[-1])
            elif observation_type == models.ABS_OBSERVATION:
              loc = np.abs(states[-1])
            elif observation_type == models.STANDARD_OBSERVATION:
              loc = states[-1]
            new_obs = np.random.normal(
                size=[state_size],
                loc=loc,
                scale=np.sqrt(observation_variance)).astype(dtype)
          else:
            # Cast to the requested dtype so the yielded value matches the
            # declared output_types (np.ones defaults to float64).
            new_obs = (np.ones([state_size]) * fixed_observation).astype(dtype)
          observations.append(new_obs)
      yield states, observations

  dataset = tf.data.Dataset.from_generator(
      data_generator,
      output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
      output_shapes=([num_timesteps + 1, state_size], [num_obs, state_size]))
  dataset = dataset.repeat().batch(batch_size)

  def tile_batch(state, observation):
    # Replicate each batched trajectory num_samples times along the leading
    # dimension so each trajectory appears num_samples times per batch.
    state = tf.tile(state, [num_samples, 1, 1])
    observation = tf.tile(observation, [num_samples, 1, 1])
    return state, observation

  dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
  return dataset
def make_dataset(bs=None,
                 state_size=1,
                 num_timesteps=10,
                 variance=1.,
                 prior_type="unimodal",
                 bimodal_prior_weight=0.5,
                 bimodal_prior_mean=1,
                 transition_type=models.STANDARD_TRANSITION,
                 fixed_observation=None,
                 batch_size=4,
                 num_samples=1,
                 dtype='float32'):
  """Creates a data generating process.

  Creates a tf.data.Dataset that provides batches of data.

  Args:
    bs: The parameters of the data generating process. If None, new bs are
      randomly generated.
    state_size: The dimension of the state space of the process.
    num_timesteps: The length of the state sequences in the process.
    variance: The variance of the normal distributions used at each timestep.
    prior_type: One of "unimodal", "nonlinear", or "bimodal"; selects the
      distribution z0 is drawn from.
    bimodal_prior_weight: For the bimodal prior, the probability of drawing
      z0 from the negative mode.
    bimodal_prior_mean: For the bimodal prior, the magnitude of the two mode
      means (+/- bimodal_prior_mean).
    transition_type: How transition means are computed from the previous
      state; one of models.TRANSITION_TYPES.
    fixed_observation: If not None, the observation is this constant value
      instead of the final state.
    batch_size: The number of trajectories to include in each batch.
    num_samples: The number of replicas of each trajectory to include in each
      batch.
    dtype: The datatype of the states and observations.

  Returns:
    bs: The true bs used to generate the data
    dataset: A tf.data.Dataset that can be iterated over.
  """
  if bs is None:
    # Use range (not xrange) so this also runs under Python 3, matching the
    # file's __future__ imports.
    bs = [np.random.uniform(size=[state_size]).astype(dtype)
          for _ in range(num_timesteps)]
    # Fixed typo in the log message ("processs" -> "process").
    tf.logging.info("data generating process bs: %s",
                    np.array(bs).reshape(num_timesteps))

  def data_generator():
    """An infinite generator of latents and observations from the model."""
    while True:
      states = []
      if prior_type == "unimodal" or prior_type == "nonlinear":
        # Prior is Normal(0, sqrt(variance)).
        states.append(
            np.random.normal(size=[state_size],
                             scale=np.sqrt(variance)).astype(dtype))
      elif prior_type == "bimodal":
        # Mixture of two Normals centered at +/- bimodal_prior_mean.
        if np.random.uniform() > bimodal_prior_weight:
          loc = bimodal_prior_mean
        else:
          loc = -bimodal_prior_mean
        states.append(
            np.random.normal(size=[state_size],
                             loc=loc,
                             scale=np.sqrt(variance)).astype(dtype))
      for t in range(num_timesteps):
        if transition_type == models.ROUND_TRANSITION:
          loc = np.round(states[-1])
        elif transition_type == models.STANDARD_TRANSITION:
          loc = states[-1]
        loc += bs[t]
        new_state = np.random.normal(size=[state_size], loc=loc,
                                     scale=np.sqrt(variance)).astype(dtype)
        states.append(new_state)
      if fixed_observation is None:
        observation = states[-1]
      else:
        observation = np.ones_like(states[-1]) * fixed_observation
      # The final state doubles as the observation and is dropped from the
      # yielded state sequence.
      yield np.array(states[:-1]), observation

  dataset = tf.data.Dataset.from_generator(
      data_generator,
      output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
      output_shapes=([num_timesteps, state_size], [state_size]))
  dataset = dataset.repeat().batch(batch_size)

  def tile_batch(state, observation):
    # Replicate each batched trajectory num_samples times along the leading
    # dimension.
    state = tf.tile(state, [num_samples, 1, 1])
    observation = tf.tile(observation, [num_samples, 1])
    return state, observation

  dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
  return np.array(bs), dataset
research/fivo/experimental/models.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
functools
import
sonnet
as
snt
import
tensorflow
as
tf
import
numpy
as
np
import
math
# Observation functions: how the observation mean is computed from the
# latent state.
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
# All observation types accepted by the dataset/model constructors.
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]

# Transition functions: how the transition mean is computed from the
# previous latent state.
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
# All transition types accepted by the dataset/model constructors.
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
class Q(object):
  """Amortized proposal q(z_t | x, z_{t-1}).

  The mean at each timestep is a linear function of the concatenated
  observation and previous state; the variance is a free per-timestep
  variable passed through a softplus.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    """Creates the per-timestep mean layers and sigma variables.

    Args:
      state_size: Dimension of the latent state.
      num_timesteps: Number of timesteps (one linear layer per step).
      sigma_min: Lower bound on the proposal variance.
      dtype: dtype of the sigma variables.
      random_seed: Seed for the random initializers.
      init_mu0_to_zero: If True, initialize the t=0 layer weights to zeros.
      graph_collection_name: Graph collection the created variables join.
    """
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    for t in xrange(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(
            {"w": tf.zeros_initializer, "b": tf.zeros_initializer})
      else:
        initializers.append(
            {"w": tf.random_uniform_initializer(seed=random_seed),
             "b": tf.zeros_initializer})

    def custom_getter(getter, *args, **kwargs):
      # Routes every variable created by the snt.Linear layers into
      # graph_collection_name so they can be fetched/optimized as a group.
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers[t],
                   name="q_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in xrange(num_timesteps)
    ]
    # NOTE: sigma variable names are 1-indexed ("q_sigma_1" is t=0), while
    # the mu layers are 0-indexed.
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(num_timesteps)
    ]

  def q_zt(self, observation, prev_state, t):
    """Builds q(z_t | x, z_{t-1}) as a diagonal Normal.

    Args:
      observation: [batch_size, obs_size] Tensor of observations.
      prev_state: [batch_size, state_size] Tensor, the previous state.
      t: Python int, the timestep index.

    Returns:
      A Normal distribution with batch shape [batch_size, state_size].
    """
    batch_size = tf.shape(prev_state)[0]
    q_mu = self.mus[t](tf.concat([observation, prev_state], axis=1))
    # sigmas are variances: softplus keeps them positive, sigma_min bounds
    # them away from zero; scale = sqrt(variance).
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    """Emits scalar summaries of the proposal parameters.

    Only the first component of each variable is logged; the w[0, 0] /
    w[1, 0] split assumes the layer input is [observation, prev_state]
    concatenated with state_size == 1 -- TODO confirm for larger states.
    """
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0, 0])
      if t != 0:
        tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[1, 0])
class PreviousStateQ(Q):
  """Proposal variant whose mean conditions only on the previous state.

  Reuses the variables created by Q.__init__; only the inputs fed to the
  per-timestep linear layers (and the summaries) differ.
  """

  def q_zt(self, unused_observation, prev_state, t):
    """Builds q(z_t | z_{t-1}) as a diagonal Normal."""
    n = tf.shape(prev_state)[0]
    # Per-timestep variance: softplus keeps it positive, sigma_min bounds it
    # away from zero; broadcast to the batch.
    sigma_sq = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    sigma_sq = tf.tile(sigma_sq[tf.newaxis, :], [n, 1])
    # The mean is a linear function of the previous state alone.
    mean = self.mus[t](prev_state)
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(sigma_sq))

  def summarize_weights(self):
    """Emits scalar summaries of the proposal parameters."""
    for i, sig in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % i, sig[0])
    for i, layer in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % i, layer.b[0])
      tf.summary.scalar("q_mu/w_prev_state_%d" % i, layer.w[0, 0])
class ObservationQ(Q):
  """Proposal variant whose mean conditions only on the observation.

  Reuses the variables created by Q.__init__; only the inputs fed to the
  per-timestep linear layers (and the summaries) differ.
  """

  def q_zt(self, observation, prev_state, t):
    """Builds q(z_t | x) as a diagonal Normal; prev_state sets the batch."""
    n = tf.shape(prev_state)[0]
    # Per-timestep variance: softplus keeps it positive, sigma_min bounds it
    # away from zero; broadcast to the batch.
    sigma_sq = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    sigma_sq = tf.tile(sigma_sq[tf.newaxis, :], [n, 1])
    # The mean is a linear function of the observation alone.
    mean = self.mus[t](observation)
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(sigma_sq))

  def summarize_weights(self):
    """Emits scalar summaries of the proposal parameters."""
    for i, sig in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % i, sig[0])
    for i, layer in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % i, layer.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % i, layer.w[0, 0])
class SimpleMeanQ(object):
  """Proposal q(z_t) whose mean is a free per-timestep variable.

  Unlike Q, the mean depends on neither the observation nor the previous
  state; each timestep has an independent learnable mean and (softplus)
  variance.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    """Creates the per-timestep mean and sigma variables.

    Args:
      state_size: Dimension of the latent state.
      num_timesteps: Number of timesteps (one mean/sigma pair per step).
      sigma_min: Lower bound on the proposal variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializers.
      init_mu0_to_zero: If True, initialize the t=0 mean to zeros.
      graph_collection_name: Graph collection the created variables join.
    """
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    for t in xrange(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(tf.zeros_initializer)
      else:
        initializers.append(tf.random_uniform_initializer(seed=random_seed))
    # NOTE: variable names are 1-indexed ("q_mu_1" holds the t=0 mean).
    self.mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_mu_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=initializers[t])
        for t in xrange(num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(num_timesteps)
    ]

  def q_zt(self, unused_observation, prev_state, t):
    """Builds q(z_t) as a diagonal Normal; prev_state sets the batch size.

    Args:
      unused_observation: Ignored; present for interface parity with Q.
      prev_state: [batch_size, state_size] Tensor; only its batch dimension
        is used.
      t: Python int, the timestep index.

    Returns:
      A Normal distribution with batch shape [batch_size, state_size].
    """
    batch_size = tf.shape(prev_state)[0]
    q_mu = tf.tile(self.mus[t][tf.newaxis, :], [batch_size, 1])
    # sigmas are variances: softplus keeps them positive, sigma_min bounds
    # them away from zero; scale = sqrt(variance).
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    """Emits scalar summaries of the first component of each mean/sigma."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/%d" % t, f[0])
class R(object):
  """Per-timestep Gaussian model r(x | z_t) of the observation.

  The mean at each timestep is a linear function of the state; the variance
  is a free per-timestep variable passed through a softplus.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               sigma_init=1.,
               random_seed=None,
               graph_collection_name="R_VARS"):
    """Creates the per-timestep mean layers and sigma variables.

    Args:
      state_size: Dimension of the latent state (and observation).
      num_timesteps: Number of timesteps (one linear layer per step).
      sigma_min: Lower bound on the variance.
      dtype: dtype of the sigma variables.
      sigma_init: Constant initial value of the (pre-softplus) sigmas.
      random_seed: Seed for the weight initializer.
      graph_collection_name: Graph collection the created variables join.
    """
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}
    self.graph_collection_name = graph_collection_name

    def custom_getter(getter, *args, **kwargs):
      # Routes every variable created by the snt.Linear layers into
      # graph_collection_name so they can be fetched/optimized as a group.
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers,
                   name="r_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in range(num_timesteps)
    ]
    # NOTE: sigma variable names are 1-indexed ("r_sigma_1" is t=0).
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.constant_initializer(sigma_init))
        for t in range(num_timesteps)
    ]

  def r_xn(self, z_t, t):
    """Builds r(x | z_t) as a diagonal Normal.

    Args:
      z_t: [batch_size, state_size] Tensor of states.
      t: Python int, the timestep index.

    Returns:
      A Normal distribution with batch shape [batch_size, state_size].
    """
    batch_size = tf.shape(z_t)[0]
    r_mu = self.mus[t](z_t)
    # sigmas are variances: softplus keeps them positive, sigma_min bounds
    # them away from zero; scale = sqrt(variance).
    r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))

  def summarize_weights(self):
    """Emits scalar summaries of the r parameters.

    Bug fix: the original logged `self.mus[t][0]`, but self.mus holds
    snt.Linear modules, which are not subscriptable, so building summaries
    raised TypeError. Log the layer bias instead, matching the convention
    of Q.summarize_weights.
    """
    for t in range(len(self.mus) - 1):
      tf.summary.scalar("r_mu/%d" % t, self.mus[t].b[0])
      tf.summary.scalar("r_sigma/%d" % t, self.sigmas[t][0])
class P(object):
  """Linear-Gaussian chain model with learnable per-step offsets b_t.

  The model is z_0 ~ N(0, variance) and z_t ~ N(z_{t-1} + b_t, variance);
  the final state plays the role of the observation. Because the model is
  linear-Gaussian, the posterior, lookahead, and marginal likelihood below
  are computed in closed form.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    """Creates the per-timestep offset variables b_t.

    Args:
      state_size: Dimension of the latent state.
      num_timesteps: Number of transitions in the chain.
      sigma_min: Stored on the instance but not referenced by the methods
        visible here.
      variance: Fixed variance of every transition distribution.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializers.
      trainable: Whether the b_t variables are trainable.
      init_bs_to_zero: If True, initialize all b_t to zeros.
      graph_collection_name: Graph collection the created variables join.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.graph_collection_name = graph_collection_name
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed)
                      for _ in xrange(num_timesteps)]
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="p_b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            trainable=trainable)
        for t in xrange(num_timesteps)
    ]
    # Reverse cumulative sums of the offsets: Bs[i] = sum_{k=i+1}^{n} b_k,
    # i.e. the total remaining drift after step i.
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    """Computes the true posterior p(z_t|z_{t-1}, z_n).

    Args:
      observation: [batch_size, state_size] Tensor holding z_n.
      prev_state: [batch_size, state_size] Tensor holding z_{t-1}; ignored
        when t == 0.
      t: Python int, the timestep index.

    Returns:
      A Normal distribution with the closed-form posterior parameters.
    """
    # bs[0] is really b_1
    # Bs[i] is sum from k=i+1^n b_k
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    # Gaussian conditioning: (num_timesteps - t) steps remain after t, so the
    # prediction and the observation are combined with these weights.
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n|z_t)."""
    # Mean drifts by the remaining offsets; variance grows linearly with the
    # number of remaining steps.
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(
        self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    """Computes the closed-form marginal log-likelihood of the observation.

    Args:
      observation: [batch_size, state_size] Tensor holding z_n.

    Returns:
      A scalar Tensor: the log-probability summed over the state dimension
      and averaged over the batch.
    """
    batch_size = tf.shape(observation)[0]
    # Marginally z_n ~ N(sum_k b_k, (n + 1) * variance).
    mu = tf.tile(
        tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(mu) * self.variance * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
    else:
      # p(z_0) is Normal(0,1)
      z_mu_p = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu,
        scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class ShortChainNonlinearP(object):
  """Nonlinear-transition model for the short-chain experiments.

  Transitions are drawn from transition_dist (Normal by default, optionally
  Cauchy) centered at either the previous state or round(previous state).
  This model has no trainable variables.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               transition_type=STANDARD_TRANSITION,
               transition_dist=tf.contrib.distributions.Normal,
               dtype=tf.float32,
               random_seed=None):
    """Stores the model configuration.

    Args:
      state_size: Dimension of the latent state.
      num_timesteps: Number of transitions in the chain.
      sigma_min: Stored on the instance but not referenced by the methods
        visible here.
      variance: Variance of every transition distribution.
      observation_variance: Variance of the generative (observation)
        distribution.
      transition_type: One of TRANSITION_TYPES; selects how the transition
        mean is computed from the previous state.
      transition_dist: Distribution class used for transitions and
        observations (e.g. Normal or Cauchy, both loc/scale parameterized).
      dtype: dtype used for the zero prior mean.
      random_seed: Unused by the visible code; kept for interface parity.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.transition_type = transition_type
    self.transition_dist = transition_dist

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
    else:
      # p(z_0) is Normal(0,1)
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
    # NOTE(review): if transition_type is not one of the two known values,
    # `loc` is unbound for t > 0 and this raises UnboundLocalError.
    p_zt = self.transition_dist(
        loc=loc, scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, unused_obs, z_ni):
    """Computes the model's generative distribution p(x_i| z_{ni})."""
    if self.transition_type == ROUND_TRANSITION:
      loc = tf.round(z_ni)
    elif self.transition_type == STANDARD_TRANSITION:
      loc = z_ni
    # NOTE(review): same unbound-`loc` hazard as p_zt for an unrecognized
    # transition_type.
    generative_sigma_sq = tf.ones_like(loc) * self.observation_variance
    return self.transition_dist(
        loc=loc, scale=tf.sqrt(generative_sigma_sq))
class BimodalPriorP(object):
  """Linear-Gaussian chain whose prior over z_0 is a two-mode mixture.

  Identical to P except that p(z_0) is a mixture of two Normals centered at
  +/- prior_mode_mean with mixing weight mixing_coeff.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               mixing_coeff=0.5,
               prior_mode_mean=1,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    """Creates the per-timestep offset variables b_t.

    Args:
      state_size: Dimension of the latent state.
      num_timesteps: Number of transitions in the chain.
      mixing_coeff: Probability assigned to the +prior_mode_mean mode.
      prior_mode_mean: Magnitude of the two prior mode means.
      sigma_min: Stored on the instance but not referenced by the methods
        visible here.
      variance: Fixed variance of every transition distribution.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializers.
      trainable: Whether the b_t variables are trainable.
      init_bs_to_zero: If True, initialize all b_t to zeros.
      graph_collection_name: Graph collection the created variables join.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.mixing_coeff = mixing_coeff
    self.prior_mode_mean = prior_mode_mean
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed)
                      for _ in xrange(num_timesteps)]
    # NOTE: these are named "b_%d", unlike P's "p_b_%d".
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            trainable=trainable)
        for t in xrange(num_timesteps)
    ]
    # Reverse cumulative sums: Bs[i] = sum_{k=i+1}^{n} b_k.
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    # NOTE: This is currently wrong, but would require a refactoring of
    # summarize_q to fix as kl is not defined for a mixture
    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
    # bs[0] is really b_1
    # Bs[i] is sum from k=i+1^n b_k
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n|z_t)."""
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(
        self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    """Computes the marginal log-likelihood of the observation.

    Marginally z_n is a two-component mixture whose component means are the
    prior mode means shifted by the total drift sum_k b_k.

    Args:
      observation: [batch_size, state_size] Tensor holding z_n.

    Returns:
      A scalar Tensor: the log-probability summed over the state dimension
      and averaged over the batch.
    """
    batch_size = tf.shape(observation)[0]
    sum_of_bs = tf.tile(
        tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(sum_of_bs) * self.variance * (
        self.num_timesteps + 1)
    mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
              self.prior_mode_mean) + sum_of_bs
    mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
              -self.prior_mode_mean) + sum_of_bs
    zn_pos = tf.contrib.distributions.Normal(
        loc=mu_pos, scale=tf.sqrt(sigma))
    zn_neg = tf.contrib.distributions.Normal(
        loc=mu_neg, scale=tf.sqrt(sigma))
    # NOTE(review): mode probabilities are hard-coded to tf.float64 even when
    # self.dtype is float32; the Mixture (validate_args=True) may reject the
    # dtype mismatch -- confirm against the tf.contrib version in use.
    mode_probs = tf.convert_to_tensor(
        [self.mixing_coeff, 1 - self.mixing_coeff], dtype=tf.float64)
    mode_probs = tf.tile(
        mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
    mode_selection_dist = tf.contrib.distributions.Categorical(
        probs=mode_probs)
    zn_dist = tf.contrib.distributions.Mixture(
        cat=mode_selection_dist,
        components=[zn_pos, zn_neg],
        validate_args=True)
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(
        tf.reduce_sum(zn_dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
      p_zt = tf.contrib.distributions.Normal(
          loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
      return p_zt
    else:
      # p(z_0) is mixture of two Normals
      mu_pos = tf.ones(
          [batch_size, self.state_size], dtype=self.dtype) * (
              self.prior_mode_mean)
      mu_neg = tf.ones(
          [batch_size, self.state_size], dtype=self.dtype) * (
              -self.prior_mode_mean)
      z0_pos = tf.contrib.distributions.Normal(
          loc=mu_pos, scale=tf.sqrt(tf.ones_like(mu_pos) * self.variance))
      z0_neg = tf.contrib.distributions.Normal(
          loc=mu_neg, scale=tf.sqrt(tf.ones_like(mu_neg) * self.variance))
      # NOTE(review): same hard-coded tf.float64 dtype as in likelihood().
      mode_probs = tf.convert_to_tensor(
          [self.mixing_coeff, 1 - self.mixing_coeff], dtype=tf.float64)
      mode_probs = tf.tile(
          mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
      mode_selection_dist = tf.contrib.distributions.Categorical(
          probs=mode_probs)
      z0_dist = tf.contrib.distributions.Mixture(
          cat=mode_selection_dist,
          components=[z0_pos, z0_neg],
          validate_args=False)
      return z0_dist

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu,
        scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class Model(object):
  """Bundles a prior p, a proposal q, and an r model into one SMC step.

  Calling the model runs one timestep: sample z_t from q, score it under
  p and q, score the observation under r, and (at the final step) under
  the generative distribution.
  """

  def __init__(self, p, q, r, state_size, num_timesteps, dtype=tf.float32):
    """Stores the component models and chain configuration.

    Args:
      p: Prior/transition model exposing p_zt() and generative().
      q: Proposal model exposing q_zt().
      r: Observation model exposing r_xn().
      state_size: Dimension of the latent state.
      num_timesteps: Number of timesteps in the chain.
      dtype: dtype of the zero initial state.
    """
    self.p = p
    self.q = q
    self.r = r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zeros [batch_size, state_size] initial state."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    """Runs one timestep of the model.

    Args:
      prev_state: [batch_size, state_size] Tensor, the previous state.
      observation: [batch_size, obs_size] Tensor of observations.
      t: Python int, the timestep index.

    Returns:
      A tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn) where zt
      is the sampled state and the log-probs are [batch_size] Tensors
      (summed over the state dimension). log_p_x_given_z is zeros except at
      the final timestep.
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q
    zt = q_zt.sample()
    r_xn = self.r.r_xn(zt, t)
    # Calculate the logprobs and sum over the state size.
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    log_r_xn = tf.reduce_sum(r_xn.log_prob(observation), axis=1)
    # If we're at the last timestep, also calc the logprob of the observation.
    if t == self.num_timesteps - 1:
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(
          generative_dist.log_prob(observation), axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)

  @staticmethod
  def create(state_size,
             num_timesteps,
             sigma_min=1e-5,
             r_sigma_init=1,
             variance=1.0,
             mixing_coeff=0.5,
             prior_mode_mean=1.0,
             dtype=tf.float32,
             random_seed=None,
             train_p=True,
             p_type="unimodal",
             q_type="normal",
             observation_variance=1.0,
             transition_type=STANDARD_TRANSITION,
             use_bs=True):
    """Factory that wires together a p, q, and r into a Model.

    Args:
      state_size: Dimension of the latent state.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Lower bound on learned variances.
      r_sigma_init: Constant initial value of r's (pre-softplus) sigmas.
      variance: Transition variance of p.
      mixing_coeff: Mixture weight for the "bimodal" p.
      prior_mode_mean: Mode-mean magnitude for the "bimodal" p.
      dtype: dtype of created variables.
      random_seed: Seed for variable initializers.
      train_p: Whether p's offset variables are trainable.
      p_type: "unimodal", "bimodal", or a string containing "nonlinear"
        (optionally also "cauchy").
      q_type: "normal", "simple_mean", "prev_state", or "observation".
      observation_variance: Observation variance of the nonlinear p.
      transition_type: One of TRANSITION_TYPES (nonlinear p only).
      use_bs: If False, initialize p's bs and q's mu0 to zeros.

    Returns:
      A Model combining the constructed p, q, and r.
    """
    if p_type == "unimodal":
      p = P(state_size,
            num_timesteps,
            sigma_min=sigma_min,
            variance=variance,
            dtype=dtype,
            random_seed=random_seed,
            trainable=train_p,
            init_bs_to_zero=not use_bs)
    elif p_type == "bimodal":
      p = BimodalPriorP(
          state_size,
          num_timesteps,
          mixing_coeff=mixing_coeff,
          prior_mode_mean=prior_mode_mean,
          sigma_min=sigma_min,
          variance=variance,
          dtype=dtype,
          random_seed=random_seed,
          trainable=train_p,
          init_bs_to_zero=not use_bs)
    elif "nonlinear" in p_type:
      if "cauchy" in p_type:
        trans_dist = tf.contrib.distributions.Cauchy
      else:
        trans_dist = tf.contrib.distributions.Normal
      p = ShortChainNonlinearP(
          state_size,
          num_timesteps,
          sigma_min=sigma_min,
          variance=variance,
          observation_variance=observation_variance,
          transition_type=transition_type,
          transition_dist=trans_dist,
          dtype=dtype,
          random_seed=random_seed)
    # NOTE(review): an unrecognized p_type leaves `p` unbound and raises
    # UnboundLocalError at the Model(...) call below.
    if q_type == "normal":
      q_class = Q
    elif q_type == "simple_mean":
      q_class = SimpleMeanQ
    elif q_type == "prev_state":
      q_class = PreviousStateQ
    elif q_type == "observation":
      q_class = ObservationQ
    # NOTE(review): same unbound hazard for an unrecognized q_type.
    q = q_class(state_size,
                num_timesteps,
                sigma_min=sigma_min,
                dtype=dtype,
                random_seed=random_seed,
                init_mu0_to_zero=not use_bs)
    r = R(state_size,
          num_timesteps,
          sigma_min=sigma_min,
          sigma_init=r_sigma_init,
          dtype=dtype,
          random_seed=random_seed)
    model = Model(p, q, r, state_size, num_timesteps, dtype=dtype)
    return model
class BackwardsModel(object):
  """A model that runs the latent chain backwards in time.

  Timestep t here indexes the backwards pass; t_backwards translates it into
  an index into the forward-time parameter lists.
  """

  def __init__(self, state_size, num_timesteps, sigma_min=1e-5,
               dtype=tf.float32):
    """Creates the per-timestep variables for q, r, and the drift terms.

    Args:
      state_size: Python int, dimensionality of the latent state.
      num_timesteps: Python int, length of the chain.
      sigma_min: Lower clamp applied to all softplus-parameterized scales.
      dtype: TensorFlow dtype for created variables.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    # Per-timestep drift terms b_1..b_n of the transition mean.
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    # Reverse cumulative sums of the drifts: Bs[t] = sum_{i>=t} bs[i].
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
    # One linear layer per timestep producing the proposal mean.
    self.q_mus = [
        snt.Linear(output_size=state_size) for _ in xrange(num_timesteps)
    ]
    # Raw (pre-softplus) proposal scales, one vector per timestep.
    self.q_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    # Per-timestep means of the auxiliary distribution r.
    self.r_mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_mu_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    # Raw (pre-softplus) scales of the auxiliary distribution r.
    self.r_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]

  def zero_state(self, batch_size):
    """Returns an all-zeros initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def posterior(self, unused_observation, prev_state, unused_t):
    # TODO(dieterichl): Correct this.
    # NOTE(review): placeholder — scale=0 makes this a degenerate Normal.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(prev_state), scale=tf.zeros_like(prev_state))

  def lookahead(self, state, unused_t):
    # TODO(dieterichl): Correct this.
    # NOTE(review): placeholder — scale=0 makes this a degenerate Normal.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(state), scale=tf.zeros_like(state))

  def q_zt(self, observation, next_state, t):
    """Computes the variational posterior q(z_{t}|z_{t+1}, z_n).

    Args:
      observation: [batch_size, state_size] Tensor, the observation.
      next_state: [batch_size, state_size] Tensor, z_{t+1}.
      t: Python int, current (backwards) timestep.

    Returns:
      A Normal over z_t; the softplus of q_sigma is treated as a variance
      (the scale passed to Normal is its square root).
    """
    # Map the backwards step index onto the forward-time parameter lists.
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(next_state)[0]
    q_mu = self.q_mus[t_backwards](tf.concat([observation, next_state],
                                             axis=1))
    q_sigma = tf.maximum(
        tf.nn.softplus(self.q_sigmas[t_backwards]), self.sigma_min)
    # Broadcast the shared per-timestep variance across the batch.
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def p_zt(self, prev_state, t):
    """Computes the model p(z_{t+1}| z_{t})."""
    t_backwards = self.num_timesteps - t - 1
    # Transition mean: previous state plus the learned drift for this step.
    z_mu_p = prev_state + self.bs[t_backwards]
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.ones_like(z_mu_p))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu, scale=tf.ones_like(generative_p_mu))

  def r(self, z_t, t):
    """Returns the auxiliary distribution r for (backwards) timestep t."""
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(z_t)[0]
    r_mu = tf.tile(self.r_mus[t_backwards][tf.newaxis, :], [batch_size, 1])
    r_sigma = tf.maximum(
        tf.nn.softplus(self.r_sigmas[t_backwards]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    # softplus output is treated as a variance; sqrt converts it to a scale.
    return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))

  def likelihood(self, observation):
    """Computes the marginal log-likelihood of the observation in closed form.

    The marginal mean is the total accumulated drift; the variance is
    num_timesteps + 1 (unit variance per step plus the prior).
    """
    batch_size = tf.shape(observation)[0]
    mu = tf.tile(
        tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(mu) * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def __call__(self, next_state, observation, t):
    """Runs one backwards timestep and returns the per-term log probs.

    Args:
      next_state: [batch_size, state_size] Tensor, z_{t+1}.
      observation: [batch_size, state_size] Tensor.
      t: Python int, current (backwards) timestep.

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt). Unlike the
      forward Model, these log probs are NOT summed over the state size.
    """
    # next state = z_{t+1}
    # Compute the q distribution over z, q(z_{t}|z_n, z_{t+1}).
    q_zt = self.q_zt(observation, next_state, t)
    # sample from q
    zt = q_zt.sample()
    # Compute the p distribution over z, p(z_{t+1}|z_{t}).
    p_zt = self.p_zt(zt, t)
    # Compute log p(z_{t+1} | z_t)
    if t == 0:
      # At the first backwards step, z_{t+1} is the observation itself.
      log_p_zt = p_zt.log_prob(observation)
    else:
      log_p_zt = p_zt.log_prob(next_state)
    # Compute r prior over zt
    r_zt = self.r(zt, t)
    log_r_zt = r_zt.log_prob(zt)
    # Compute proposal density at zt
    log_q_zt = q_zt.log_prob(zt)
    # If we're at the last timestep, also calc the logprob of the observation.
    if t == self.num_timesteps - 1:
      # Score the final sample under the standard-normal prior on z_0.
      p_z0_dist = tf.contrib.distributions.Normal(
          loc=tf.zeros_like(zt), scale=tf.ones_like(zt))
      z0_log_prob = p_z0_dist.log_prob(zt)
    else:
      z0_log_prob = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt)
class LongChainP(object):
  """Prior for the long-chain model: a random walk observed every k steps.

  The latent chain has steps_per_obs * num_obs + 1 timesteps; an observation
  is emitted every steps_per_obs steps.
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               observation_type=STANDARD_OBSERVATION,
               transition_type=STANDARD_TRANSITION,
               dtype=tf.float32,
               random_seed=None):
    """Stores the chain dimensions and noise/transition configuration.

    Args:
      state_size: Python int, dimensionality of the latent state.
      num_obs: Python int, number of observations.
      steps_per_obs: Python int, latent steps between observations.
      sigma_min: Minimum scale (kept for interface parity; unused here).
      variance: Transition variance of the latent random walk.
      observation_variance: Variance of the observation noise.
      observation_type: One of SQUARED_OBSERVATION, ABS_OBSERVATION,
        STANDARD_OBSERVATION.
      transition_type: One of ROUND_TRANSITION, STANDARD_TRANSITION.
      dtype: TensorFlow dtype.
      random_seed: Unused; kept for interface parity.
    """
    self.state_size = state_size
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = steps_per_obs * num_obs + 1
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.observation_type = observation_type
    self.transition_type = transition_type

  def likelihood(self, observations):
    """Computes the model's true likelihood of the observations.

    Args:
      observations: A [batch_size, m, state_size] Tensor representing each of
        the m observations.
    Returns:
      logprob: The true likelihood of the observations given the model.

    Raises:
      ValueError: Always; the closed form is not implemented for this model.
    """
    raise ValueError("Likelihood is not defined for long-chain models")
    # batch_size = tf.shape(observations)[0]
    # mu = tf.zeros([batch_size, self.state_size, self.num_obs], dtype=self.dtype)
    # sigma = np.fromfunction(
    #     lambda i, j: 1 + self.steps_per_obs*np.minimum(i+1, j+1),
    #     [self.num_obs, self.num_obs])
    # sigma += np.eye(self.num_obs)
    # sigma = tf.convert_to_tensor(sigma * self.variance, dtype=self.dtype)
    # sigma = tf.tile(sigma[tf.newaxis, tf.newaxis, ...],
    #                 [batch_size, self.state_size, 1, 1])
    # dist = tf.contrib.distributions.MultivariateNormalFullCovariance(
    #     loc=mu,
    #     covariance_matrix=sigma)
    # Average over the batch and take the sum over the state size
    #return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observations), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1}).

    Args:
      prev_state: [batch_size, state_size] Tensor, z_{t-1}.
      t: Python int; t == 0 yields the N(0, variance) prior on z_0.

    Returns:
      A Normal over z_t with the configured transition variance.

    Raises:
      ValueError: If self.transition_type is unrecognized.
    """
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      else:
        # Previously an unknown transition_type fell through and raised a
        # confusing NameError on the unbound 'loc'; fail fast instead.
        raise ValueError("Unknown transition_type: %s" % self.transition_type)
    else:
      # p(z_0) is Normal(0,1)
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
    p_zt = tf.contrib.distributions.Normal(
        loc=loc, scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, z_ni, t):
    """Computes the model's generative distribution p(x_i| z_{ni}).

    Args:
      z_ni: [batch_size, state_size] Tensor, the latent state at an
        observation timestep.
      t: Python int, the current timestep (for logging only).

    Returns:
      A Normal over x_i with variance self.observation_variance.

    Raises:
      ValueError: If self.observation_type is unrecognized.
    """
    if self.observation_type == SQUARED_OBSERVATION:
      generative_mu = tf.square(z_ni)
      # Bug fix: the log previously reported self.variance although the
      # emission distribution uses self.observation_variance.
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d^2, %0.1f)" %
                      (t, t, t, self.observation_variance))
    elif self.observation_type == ABS_OBSERVATION:
      generative_mu = tf.abs(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(|z_%d|, %0.1f)" %
                      (t, t, t, self.observation_variance))
    elif self.observation_type == STANDARD_OBSERVATION:
      generative_mu = z_ni
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d, %0.1f)" %
                      (t, t, t, self.observation_variance))
    else:
      # Previously an unknown observation_type fell through and raised a
      # confusing NameError on the unbound 'generative_mu'.
      raise ValueError("Unknown observation_type: %s" % self.observation_type)
    generative_sigma_sq = tf.ones_like(generative_mu) * self.observation_variance
    return tf.contrib.distributions.Normal(
        loc=generative_mu, scale=tf.sqrt(generative_sigma_sq))
class LongChainQ(object):
  """Proposal for the long-chain model.

  Each timestep's proposal is a linear function of the previous state and all
  observations that are still relevant at that timestep.
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    """Creates one linear layer and one sigma variable per timestep.

    Args:
      state_size: Python int, dimensionality of the latent state.
      num_obs: Python int, number of observations.
      steps_per_obs: Python int, latent steps between observations.
      sigma_min: Lower clamp for the softplus-parameterized variances.
      dtype: TensorFlow dtype.
      random_seed: Seed for the uniform initializers.
    """
    self.state_size = state_size
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs * steps_per_obs + 1
    initializers = {
        "w": tf.random_uniform_initializer(seed=random_seed),
        "b": tf.zeros_initializer
    }
    # One linear layer per timestep computing the proposal mean.
    self.mus = [
        snt.Linear(output_size=state_size, initializers=initializers)
        for t in xrange(self.num_timesteps)
    ]
    # Raw (pre-softplus) per-timestep proposal variances.
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(self.num_timesteps)
    ]

  def first_relevant_obs_index(self, t):
    # NOTE(review): q_zt below computes this quantity inline as
    # int(math.floor(max(t-1, 0) / steps_per_obs)), which differs from this
    # expression for negative t-1 under true division — confirm which is
    # intended if this helper is ever used.
    return int(max((t - 1) / self.steps_per_obs, 0))

  def q_zt(self, observations, prev_state, t):
    """Computes a distribution over z_t.

    Args:
      observations: a [batch_size, num_observations, state_size] Tensor.
      prev_state: a [batch_size, state_size] Tensor.
      t: The current timestep, an int Tensor.
    """
    # filter out unneeded past obs
    first_relevant_obs_index = int(
        math.floor(max(t - 1, 0) / self.steps_per_obs))
    num_relevant_observations = self.num_obs - first_relevant_obs_index
    observations = observations[:, first_relevant_obs_index:, :]
    batch_size = tf.shape(prev_state)[0]
    # concatenate the prev state and observations along the second axis (that is
    # not the batch or state size axis, and then flatten it to
    # [batch_size, (num_relevant_observations + 1) * state_size] to feed it into
    # the linear layer.
    q_input = tf.concat([observations, prev_state[:, tf.newaxis, :]], axis=1)
    q_input = tf.reshape(
        q_input,
        [batch_size, (num_relevant_observations + 1) * self.state_size])
    q_mu = self.mus[t](q_input)
    # Clamp the softplus'd variance away from zero for numerical stability.
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    # q_sigma is treated as a variance; sqrt converts it to a scale.
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    tf.logging.info(
        "q(z_{t} | z_{tm1}, x_{obsf}:{obst}) ~ N(Linear([z_{tm1},x_{obsf}:{obst}]), sigma_{t})".format(
            **{"t": t,
               "tm1": t - 1,
               "obsf": (first_relevant_obs_index + 1) * self.steps_per_obs,
               "obst": self.steps_per_obs * self.num_obs}))
    return q_zt

  def summarize_weights(self):
    # Intentionally a no-op; kept for interface parity with other proposals.
    pass
class LongChainR(object):
  """Auxiliary distribution r over all future observations for the long chain.

  At timestep t, r predicts every not-yet-emitted observation as a Normal
  centered on the current latent state.
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    """Creates one sigma variable per timestep, sized by remaining obs.

    Args:
      state_size: Python int, dimensionality of the latent state.
      num_obs: Python int, number of observations.
      steps_per_obs: Python int, latent steps between observations.
      sigma_min: Lower clamp for the softplus-parameterized variances.
      dtype: TensorFlow dtype.
      random_seed: Unused with the constant initializer below.
    """
    self.state_size = state_size
    self.dtype = dtype
    self.sigma_min = sigma_min
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs * steps_per_obs + 1
    # One raw sigma vector per timestep; its length shrinks as fewer
    # observations remain in the future.
    self.sigmas = [
        tf.get_variable(
            shape=[self.num_future_obs(t)],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            #initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100))
            initializer=tf.constant_initializer(1.0))
        for t in range(self.num_timesteps)
    ]

  def first_future_obs_index(self, t):
    # Index of the first observation strictly in the future of timestep t.
    return int(math.floor(t / self.steps_per_obs))

  def num_future_obs(self, t):
    # Number of observations not yet emitted at timestep t.
    return int(self.num_obs - self.first_future_obs_index(t))

  def r_xn(self, z_t, t):
    """Computes a distribution over the future observations given current latent
    state.

    The indexing in these messages is 1 indexed and inclusive. This is
    consistent with the latex documents.

    Args:
      z_t: [batch_size, state_size] Tensor
      t: Current timestep

    Returns:
      A Normal over the [batch_size, num_future_obs, state_size] block of
      future observations, centered on z_t.
    """
    tf.logging.info(
        "r(x_{start}:{end} | z_{t}) ~ N(z_{t}, sigma_{t})".format(
            **{"t": t,
               "start": (self.first_future_obs_index(t) + 1) * self.steps_per_obs,
               "end": self.num_timesteps - 1}))
    batch_size = tf.shape(z_t)[0]
    # the mean for all future observations is the same.
    # this tiling results in a [batch_size, num_future_obs, state_size] Tensor
    r_mu = tf.tile(z_t[:, tf.newaxis, :], [1, self.num_future_obs(t), 1])
    # compute the variance
    r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    # the variance is the same across all state dimensions, so we only have to
    # time sigma to be [batch_size, num_future_obs].
    r_sigma = tf.tile(r_sigma[tf.newaxis, :, tf.newaxis],
                      [batch_size, 1, self.state_size])
    # r_sigma is treated as a variance; sqrt converts it to a scale.
    return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))

  def summarize_weights(self):
    # Intentionally a no-op; kept for interface parity with other components.
    pass
class LongChainModel(object):
  """Ties together LongChainP, LongChainQ, and LongChainR for SMC."""

  def __init__(self,
               p,
               q,
               r,
               state_size,
               num_obs,
               steps_per_obs,
               dtype=tf.float32,
               disable_r=False):
    """Stores components and chain dimensions.

    Args:
      p: Prior with p_zt() and generative() methods.
      q: Proposal with a q_zt() method.
      r: Auxiliary distribution with an r_xn() method.
      state_size: Python int, dimensionality of the latent state.
      num_obs: Python int, number of observations.
      steps_per_obs: Python int, latent steps between observations.
      dtype: TensorFlow dtype.
      disable_r: If True, r contributes zero to the weights.
    """
    self.p = p
    self.q = q
    self.r = r
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_obs = num_obs
    self.steps_per_obs = steps_per_obs
    self.num_timesteps = steps_per_obs * num_obs + 1
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zeros initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def next_obs_ind(self, t):
    # Index of the observation associated with (or next after) timestep t.
    return int(math.floor(max(t - 1, 0) / self.steps_per_obs))

  def __call__(self, prev_state, observations, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor
      observations: [batch_size, num_observations, state_size] Tensor

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn) of the
      sampled state and [batch_size] log probabilities.
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observations, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q and evaluate the logprobs, summing over the state size
    zt = q_zt.sample()
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps - 1:
      # score the remaining observations using r
      r_xn = self.r.r_xn(zt, t)
      log_r_xn = r_xn.log_prob(observations[:, self.next_obs_ind(t + 1):, :])
      # sum over state size and observation, leaving the batch index
      log_r_xn = tf.reduce_sum(log_r_xn, axis=[1, 2])
    else:
      log_r_xn = tf.zeros_like(log_p_zt)
    # Every steps_per_obs steps (except t=0) an observation is emitted and
    # scored under the generative distribution.
    if t != 0 and t % self.steps_per_obs == 0:
      generative_dist = self.p.generative(zt, t)
      log_p_x_given_z = generative_dist.log_prob(
          observations[:, self.next_obs_ind(t), :])
      log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)

  @staticmethod
  def create(state_size,
             num_obs,
             steps_per_obs,
             sigma_min=1e-5,
             variance=1.0,
             observation_variance=1.0,
             observation_type=STANDARD_OBSERVATION,
             transition_type=STANDARD_TRANSITION,
             dtype=tf.float32,
             random_seed=None,
             disable_r=False):
    """Factory that builds the p, q, and r components and wraps them.

    Returns:
      A constructed LongChainModel.
    """
    p = LongChainP(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        variance=variance,
        observation_variance=observation_variance,
        observation_type=observation_type,
        transition_type=transition_type,
        dtype=dtype,
        random_seed=random_seed)
    q = LongChainQ(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    r = LongChainR(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    model = LongChainModel(
        p, q, r, state_size, num_obs, steps_per_obs,
        dtype=dtype,
        disable_r=disable_r)
    return model
class RTilde(object):
  """Learned Gaussian approximation r~ used by the TD-style model.

  Variables created by the per-timestep linear layers are additionally
  registered in a named graph collection via a custom getter so they can be
  fetched (e.g. for a separate optimizer) later.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               graph_collection_name="R_TILDE_VARS"):
    """Creates one linear layer per timestep with collection tracking.

    Args:
      state_size: Python int, dimensionality of the latent state.
      num_timesteps: Python int, length of the chain.
      sigma_min: Lower clamp for the softplus-parameterized variances.
      dtype: TensorFlow dtype.
      random_seed: Seed for the truncated-normal weight initializer.
      graph_collection_name: Graph collection that receives every variable
        created by these layers.
    """
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {
        "w": tf.truncated_normal_initializer(seed=random_seed),
        "b": tf.zeros_initializer
    }
    self.graph_collection_name = graph_collection_name

    def custom_getter(getter, *args, **kwargs):
      # Wraps variable creation so each new variable is appended (once) to
      # the configured graph collection.
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    # Each layer outputs [mu, raw_sigma_sq], hence 2 * state_size units.
    self.fns = [
        snt.Linear(
            output_size=2 * state_size,
            initializers=initializers,
            name="r_tilde_%d" % t,
            custom_getter=custom_getter) for t in xrange(num_timesteps)
    ]

  def r_zt(self, z_t, observation, t):
    """Computes the mean and variance of r~(z_t | observation) at timestep t.

    Args:
      z_t: [batch_size, state_size] Tensor, current latent state.
      observation: [batch_size, state_size] Tensor.
      t: Python int, index of the linear layer to apply.

    Returns:
      (mu, sigma_sq): two [batch_size, state_size] Tensors.
    """
    #out = self.fns[t](tf.stop_gradient(tf.concat([z_t, observation], axis=1)))
    out = self.fns[t](tf.concat([z_t, observation], axis=1))
    # Split the layer output into mean and raw (unconstrained) variance.
    mu, raw_sigma_sq = tf.split(out, 2, axis=1)
    # softplus makes the variance positive; the clamp keeps it numerically safe.
    sigma_sq = tf.maximum(tf.nn.softplus(raw_sigma_sq), self.sigma_min)
    return mu, sigma_sq
class TDModel(object):
  """Model variant whose auxiliary term is the learned Gaussian RTilde.

  Unlike Model, __call__ returns the r~ parameters and the next-step prior
  instead of an r log probability, for use in TD-style bound computations.
  """

  def __init__(self,
               p,
               q,
               r_tilde,
               state_size,
               num_timesteps,
               dtype=tf.float32,
               disable_r=False):
    """Stores the component distributions and chain dimensions.

    Args:
      p: Prior with p_zt(prev_state, t) and generative(observation, z) methods.
      q: Proposal with a q_zt(observation, prev_state, t) method.
      r_tilde: Object with an r_zt(z, observation, t) method returning
        (mu, sigma_sq).
      state_size: Python int, dimensionality of the latent state.
      num_timesteps: Python int, length of the chain.
      dtype: TensorFlow dtype.
      disable_r: If True, r_tilde outputs are returned as None.
    """
    self.p = p
    self.q = q
    self.r_tilde = r_tilde
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zeros initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor, z_{t-1}.
      observation: [batch_size, state_size] Tensor.
      t: Python int, the current timestep.

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, r_tilde_mu,
      r_tilde_sigma_sq, p_ztplus1). The r~ outputs and p_ztplus1 are None at
      the final timestep (and the r~ outputs also when disable_r is set).
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q and evaluate the logprobs, summing over the state size
    zt = q_zt.sample()
    # If it isn't the last timestep, compute the distribution over the next z.
    if t < self.num_timesteps - 1:
      p_ztplus1 = self.p.p_zt(zt, t + 1)
    else:
      p_ztplus1 = None
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps - 1:
      # score the remaining observations using r
      r_tilde_mu, r_tilde_sigma_sq = self.r_tilde.r_zt(zt, observation, t + 1)
    else:
      r_tilde_mu = None
      r_tilde_sigma_sq = None
    if t == self.num_timesteps - 1:
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation),
                                      axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, r_tilde_mu,
            r_tilde_sigma_sq, p_ztplus1)

  @staticmethod
  def create(state_size,
             num_timesteps,
             sigma_min=1e-5,
             variance=1.0,
             dtype=tf.float32,
             random_seed=None,
             train_p=True,
             p_type="unimodal",
             q_type="normal",
             mixing_coeff=0.5,
             prior_mode_mean=1.0,
             observation_variance=1.0,
             transition_type=STANDARD_TRANSITION,
             use_bs=True):
    """Factory that assembles a TDModel from string-selected components.

    Args:
      state_size: Python int, dimensionality of the latent state.
      num_timesteps: Python int, length of the chain.
      sigma_min: Minimum allowed scale for the learned distributions.
      variance: Transition variance of the prior.
      dtype: TensorFlow dtype.
      random_seed: Seed for variable initializers.
      train_p: Whether the prior's parameters are trainable.
      p_type: One of "unimodal", "bimodal", or a string containing
        "nonlinear" (optionally also "cauchy").
      q_type: One of "normal", "simple_mean", "prev_state", "observation".
      mixing_coeff: Mixture weight for the bimodal prior.
      prior_mode_mean: Mode location for the bimodal prior.
      observation_variance: Observation variance for the nonlinear prior.
      transition_type: Transition-type constant for the nonlinear prior.
      use_bs: If False, bias terms are initialized to zero.

    Returns:
      A constructed TDModel.

    Raises:
      ValueError: If p_type or q_type is not one of the supported values.
    """
    if p_type == "unimodal":
      p = P(state_size,
            num_timesteps,
            sigma_min=sigma_min,
            variance=variance,
            dtype=dtype,
            random_seed=random_seed,
            trainable=train_p,
            init_bs_to_zero=not use_bs)
    elif p_type == "bimodal":
      p = BimodalPriorP(state_size,
                        num_timesteps,
                        mixing_coeff=mixing_coeff,
                        prior_mode_mean=prior_mode_mean,
                        sigma_min=sigma_min,
                        variance=variance,
                        dtype=dtype,
                        random_seed=random_seed,
                        trainable=train_p,
                        init_bs_to_zero=not use_bs)
    elif "nonlinear" in p_type:
      if "cauchy" in p_type:
        trans_dist = tf.contrib.distributions.Cauchy
      else:
        trans_dist = tf.contrib.distributions.Normal
      p = ShortChainNonlinearP(state_size,
                               num_timesteps,
                               sigma_min=sigma_min,
                               variance=variance,
                               observation_variance=observation_variance,
                               transition_type=transition_type,
                               transition_dist=trans_dist,
                               dtype=dtype,
                               random_seed=random_seed)
    else:
      # Previously an unknown p_type fell through and raised a confusing
      # NameError on the unbound 'p'; fail fast with a clear message instead.
      raise ValueError("Unknown p_type: %s" % p_type)
    if q_type == "normal":
      q_class = Q
    elif q_type == "simple_mean":
      q_class = SimpleMeanQ
    elif q_type == "prev_state":
      q_class = PreviousStateQ
    elif q_type == "observation":
      q_class = ObservationQ
    else:
      # Same fail-fast treatment for unrecognized proposal types.
      raise ValueError("Unknown q_type: %s" % q_type)
    q = q_class(state_size,
                num_timesteps,
                sigma_min=sigma_min,
                dtype=dtype,
                random_seed=random_seed,
                init_mu0_to_zero=not use_bs)
    r_tilde = RTilde(state_size,
                     num_timesteps,
                     sigma_min=sigma_min,
                     dtype=dtype,
                     random_seed=random_seed)
    model = TDModel(p, q, r_tilde, state_size, num_timesteps, dtype=dtype)
    return model
research/fivo/experimental/run.sh
deleted
100644 → 0
View file @
5eb294f8
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
model
=
"forward"
T
=
5
num_obs
=
1
var
=
0.1
n
=
4
lr
=
0.0001
bound
=
"fivo-aux"
q_type
=
"normal"
resampling_method
=
"multinomial"
rgrad
=
"true"
p_type
=
"unimodal"
use_bs
=
false
LOGDIR
=
/tmp/fivo/model-
$model
-
$bound
-
$resampling_method
-resampling-rgrad-
$rgrad
-T-
$T
-var-
$var
-n-
$n
-lr-
$lr
-q-
$q_type
-p-
$p_type
python train.py
\
--logdir
=
$LOGDIR
\
--model
=
$model
\
--bound
=
$bound
\
--q_type
=
$q_type
\
--p_type
=
$p_type
\
--variance
=
$var
\
--use_resampling_grads
=
$rgrad
\
--resampling
=
always
\
--resampling_method
=
$resampling_method
\
--batch_size
=
4
\
--num_samples
=
$n
\
--num_timesteps
=
$T
\
--num_eval_samples
=
256
\
--summarize_every
=
100
\
--learning_rate
=
$lr
\
--decay_steps
=
1000000
\
--max_steps
=
1000000000
\
--random_seed
=
1234
\
--train_p
=
false
\
--use_bs
=
$use_bs
\
--alsologtostderr
research/fivo/experimental/summary_utils.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for plotting and summarizing.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
matplotlib.gridspec
as
gridspec
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
scipy
import
tensorflow
as
tf
import
models
def summarize_ess(weights, only_last_timestep=False):
  """Plots the effective sample size.

  Emits per-timestep scalar summaries for the effective sample size (ess),
  the ESS normalized by batch size (ese), and the weight variance.

  Args:
    weights: List of length num_timesteps Tensors of shape
      [num_samples, batch_size]
    only_last_timestep: If True, only summarize the final timestep.
  """
  num_timesteps = len(weights)
  batch_size = tf.cast(tf.shape(weights[0])[1], dtype=tf.float64)
  for i in range(num_timesteps):
    if only_last_timestep and i < num_timesteps - 1:
      continue
    # Normalize log-weights over the sample dimension (axis 0).
    w = tf.nn.softmax(weights[i], dim=0)
    centered_weights = w - tf.reduce_mean(w, axis=0, keepdims=True)
    # NOTE(review): the weights are centered over the sample axis but the
    # divisor is (batch_size - 1), not (num_samples - 1) — confirm intended.
    variance = tf.reduce_sum(tf.square(centered_weights))/(batch_size-1)
    # Standard ESS estimate 1 / sum(w^2), averaged over the batch.
    ess = 1./tf.reduce_mean(tf.reduce_sum(tf.square(w), axis=0))
    tf.summary.scalar("ess/%d" % i, ess)
    tf.summary.scalar("ese/%d" % i, ess / batch_size)
    tf.summary.scalar("weight_variance/%d" % i, variance)
def summarize_particles(states, weights, observation, model):
  """Plots particle locations and weights.

  Renders, per batch element, the prior/posterior/proposal densities for z_0
  followed by a per-timestep bar plot of particle weights and a histogram of
  particle locations, and emits the result as an image summary.

  Args:
    states: List of length num_timesteps Tensors of shape
      [batch_size*num_particles, state_size].
    weights: List of length num_timesteps Tensors of shape [num_samples,
      batch_size]
    observation: Tensor of shape [batch_size*num_samples, state_size]
    model: Model whose q and p attributes provide the proposal and the
      (bimodal) prior parameters used for the density plots.
  """
  num_timesteps = len(weights)
  num_samples, batch_size = weights[0].get_shape().as_list()
  # get q0 information for plotting
  q0_dist = model.q.q_zt(observation, tf.zeros_like(states[0]), 0)
  q0_loc = q0_dist.loc[0:batch_size, 0]
  q0_scale = q0_dist.scale[0:batch_size, 0]
  # get posterior information for plotting
  post = (model.p.mixing_coeff, model.p.prior_mode_mean, model.p.variance,
          tf.reduce_sum(model.p.bs), model.p.num_timesteps)
  # Reshape states and weights to be [time, num_samples, batch_size]
  states = tf.stack(states)
  weights = tf.stack(weights)
  # normalize the weights over the sample dimension
  weights = tf.nn.softmax(weights, dim=1)
  states = tf.reshape(states, tf.shape(weights))
  # Effective sample size per (timestep, batch element).
  ess = 1./tf.reduce_sum(tf.square(weights), axis=1)

  def _plot_states(states_batch, weights_batch, observation_batch, ess_batch,
                   q0, post):
    """Renders one figure per batch element (runs in numpy via tf.py_func).

    Args:
      states_batch: [time, num_samples, batch_size] particle locations.
      weights_batch: [time, num_samples, batch_size] normalized weights.
      observation_batch: [batch_size, 1] observations.
      ess_batch: [time, batch_size] effective sample sizes.
      q0: Tuple ([batch_size], [batch_size]) of q's z_0 loc and scale.
      post: Tuple (mixing_coeff, prior_mode_mean, variance, sum(bs),
        num_timesteps) describing the bimodal prior.

    Returns:
      [batch_size, height, width, 3] uint8 array of rendered figures.
    """
    num_timesteps, _, batch_size = states_batch.shape
    plots = []
    for i in range(batch_size):
      states = states_batch[:, :, i]
      weights = weights_batch[:, :, i]
      observation = observation_batch[i]
      ess = ess_batch[:, i]
      q0_loc = q0[0][i]
      q0_scale = q0[1][i]
      fig = plt.figure(figsize=(7, (num_timesteps + 1) * 2))
      # Each timestep gets two plots -- a bar plot and a histogram of state locs.
      # The bar plot will be bar_rows rows tall.
      # The histogram will be 1 row tall.
      # There is also 1 extra plot at the top showing the posterior and q.
      bar_rows = 8
      num_rows = (num_timesteps + 1) * (bar_rows + 1)
      gs = gridspec.GridSpec(num_rows, 1)
      # Figure out how wide to make the plot
      prior_lims = (post[1] * -2, post[1] * 2)
      q_lims = (scipy.stats.norm.ppf(0.01, loc=q0_loc, scale=q0_scale),
                scipy.stats.norm.ppf(0.99, loc=q0_loc, scale=q0_scale))
      state_width = states.max() - states.min()
      state_lims = (states.min() - state_width * 0.15,
                    states.max() + state_width * 0.15)
      lims = (min(prior_lims[0], q_lims[0], state_lims[0]),
              max(prior_lims[1], q_lims[1], state_lims[1]))
      # plot the posterior
      z0 = np.arange(lims[0], lims[1], 0.1)
      alpha, pos_mu, sigma_sq, B, T = post
      neg_mu = -pos_mu
      scale = np.sqrt((T + 1) * sigma_sq)
      # Marginal p(z_n): a two-component Gaussian mixture.
      p_zn = (alpha * scipy.stats.norm.pdf(observation, loc=pos_mu + B,
                                           scale=scale) +
              (1 - alpha) * scipy.stats.norm.pdf(observation, loc=neg_mu + B,
                                                 scale=scale))
      # Prior p(z_0): same mixture, without the accumulated drift B.
      p_z0 = (alpha * scipy.stats.norm.pdf(z0, loc=pos_mu,
                                           scale=np.sqrt(sigma_sq)) +
              (1 - alpha) * scipy.stats.norm.pdf(z0, loc=neg_mu,
                                                 scale=np.sqrt(sigma_sq)))
      p_zn_given_z0 = scipy.stats.norm.pdf(observation, loc=z0 + B,
                                           scale=np.sqrt(T * sigma_sq))
      # Bayes rule: posterior over z_0 given the observation.
      post_z0 = (p_z0 * p_zn_given_z0) / p_zn
      # plot q
      q_z0 = scipy.stats.norm.pdf(z0, loc=q0_loc, scale=q0_scale)
      ax = plt.subplot(gs[0:bar_rows, :])
      ax.plot(z0, q_z0, color="blue")
      ax.plot(z0, post_z0, color="green")
      ax.plot(z0, p_z0, color="red")
      ax.legend(("q", "posterior", "prior"), loc="best", prop={"size": 10})
      ax.set_xticks([])
      ax.set_xlim(*lims)
      # plot the states
      for t in range(num_timesteps):
        start = (t + 1) * (bar_rows + 1)
        ax1 = plt.subplot(gs[start:start + bar_rows, :])
        ax2 = plt.subplot(gs[start + bar_rows:start + bar_rows + 1, :])
        # plot the states barplot
        # ax1.hist(
        #     states[t, :],
        #     weights=weights[t, :],
        #     bins=50,
        #     edgecolor="none",
        #     alpha=0.2)
        ax1.bar(states[t, :], weights[t, :], width=0.02, alpha=0.2,
                edgecolor="none")
        ax1.set_ylabel("t=%d" % t)
        ax1.set_xticks([])
        ax1.grid(True, which="both")
        ax1.set_xlim(*lims)
        # plot the observation
        ax1.axvline(x=observation, color="red", linestyle="dashed")
        # add the ESS
        ax1.text(0.1, 0.9, "ESS: %0.2f" % ess[t],
                 ha='center', va='center', transform=ax1.transAxes)
        # plot the state location histogram
        ax2.hist2d(states[t, :], np.zeros_like(states[t, :]), bins=[50, 1],
                   cmap="Greys")
        ax2.grid(False)
        ax2.set_yticks([])
        ax2.set_xlim(*lims)
        if t != num_timesteps - 1:
          ax2.set_xticks([])
      # Rasterize the figure into an RGB uint8 array.
      # NOTE(review): np.fromstring with sep="" is deprecated in newer numpy.
      fig.canvas.draw()
      p = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
      plots.append(p.reshape(fig.canvas.get_width_height()[::-1] + (3,)))
      plt.close(fig)
    return np.stack(plots)

  plots = tf.py_func(_plot_states,
                     [states, weights, observation, ess, (q0_loc, q0_scale),
                      post], [tf.uint8])[0]
  tf.summary.image("states", plots, 5, collections=["infrequent_summaries"])
def plot_weights(weights, resampled=None):
  """Plots the weights and effective sample size from an SMC rollout.

  Produces one stacked-area plot of normalized weights per batch element,
  with dashed vertical lines at resampling events, emitted as an image
  summary in the "infrequent_summaries" collection.

  Args:
    weights: [num_timesteps, num_samples, batch_size] importance weights
    resampled: [num_timesteps] 0/1 indicating if resampling occurred
  """
  weights = tf.convert_to_tensor(weights)

  def _make_plots(weights, resampled):
    """Renders the stacked-area plots in numpy (runs via tf.py_func)."""
    num_timesteps, num_samples, batch_size = weights.shape
    plots = []
    for i in range(batch_size):
      fig, axes = plt.subplots(nrows=1, sharex=True, figsize=(8, 4))
      axes.stackplot(np.arange(num_timesteps),
                     np.transpose(weights[:, :, i]))
      axes.set_title("Weights")
      axes.set_xlabel("Steps")
      axes.set_ylim([0, 1])
      axes.set_xlim([0, num_timesteps - 1])
      # Mark each timestep where resampling occurred.
      for j in np.where(resampled > 0)[0]:
        axes.axvline(x=j, color="red", linestyle="dashed",
                     ymin=0.0, ymax=1.0)
      # Rasterize the figure into an RGB uint8 array.
      # NOTE(review): np.fromstring with sep="" is deprecated in newer numpy.
      fig.canvas.draw()
      data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
      data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
      plots.append(data)
      plt.close(fig)
    return np.stack(plots, axis=0)

  if resampled is None:
    # Default: treat the rollout as if no resampling ever happened.
    num_timesteps, _, batch_size = weights.get_shape().as_list()
    resampled = tf.zeros([num_timesteps], dtype=tf.float32)
  plots = tf.py_func(_make_plots,
                     [tf.nn.softmax(weights, dim=1),
                      tf.to_float(resampled)], [tf.uint8])[0]
  batch_size = weights.get_shape().as_list()[-1]
  tf.summary.image("weights", plots, batch_size,
                   collections=["infrequent_summaries"])
def summarize_weights(weights, num_timesteps, num_samples):
  """Writes per-timestep weight variance, magnitude, and histogram summaries.

  Args:
    weights: [num_timesteps, num_samples, batch_size] importance weights.
    num_timesteps: Python int, the number of timesteps to summarize.
    num_samples: Python int, the number of particles; used as the (n - 1)
      normalizer for the unbiased sample variance.
  """
  # weights is [num_timesteps, num_samples, batch_size]
  weights = tf.convert_to_tensor(weights)
  mean = tf.reduce_mean(weights, axis=1, keepdims=True)
  squared_diff = tf.square(weights - mean)
  # Unbiased sample variance across the particle axis.
  variances = tf.reduce_sum(squared_diff, axis=1) / (num_samples - 1)
  # average the variance over the batch
  variances = tf.reduce_mean(variances, axis=1)
  avg_magnitude = tf.reduce_mean(tf.abs(weights), axis=[1, 2])
  # range (not Python-2-only xrange) so the module also runs under Python 3.
  for t in range(num_timesteps):
    tf.summary.scalar("weights/variance_%d" % t, variances[t])
    tf.summary.scalar("weights/magnitude_%d" % t, avg_magnitude[t])
    tf.summary.histogram("weights/step_%d" % t, weights[t])
def summarize_learning_signal(rewards, tag):
  """Writes summaries of the learning signal at each resampling event.

  Args:
    rewards: [num_resampling_events, batch_size] learning-signal values.
    tag: String prefix for the emitted summary names.
  """
  num_resampling_events, _ = rewards.get_shape().as_list()
  mean = tf.reduce_mean(rewards, axis=1)
  avg_magnitude = tf.reduce_mean(tf.abs(rewards), axis=1)
  reward_square = tf.reduce_mean(tf.square(rewards), axis=1)
  # range (not Python-2-only xrange) so the module also runs under Python 3.
  for t in range(num_resampling_events):
    tf.summary.scalar("%s/mean_%d" % (tag, t), mean[t])
    tf.summary.scalar("%s/magnitude_%d" % (tag, t), avg_magnitude[t])
    tf.summary.scalar("%s/squared_%d" % (tag, t), reward_square[t])
    tf.summary.histogram("%s/step_%d" % (tag, t), rewards[t])
def summarize_qs(model, observation, states):
  """Compares q against the model's analytic posterior, when one exists.

  Emits, per timestep: the KL from posterior to q, and absolute/relative
  errors of q's mean and variance against the posterior's.
  """
  model.q.summarize_weights()
  # Only possible when the model exposes an analytic posterior method.
  if hasattr(model.p, "posterior") and callable(getattr(model.p, "posterior")):
    # Shift right by one so entry t holds the state preceding timestep t
    # (a zero state for t=0).
    prev_states = [tf.zeros_like(states[0])] + states[:-1]
    for t, prev_state in enumerate(prev_states):
      p = model.p.posterior(observation, prev_state, t)
      q = model.q.q_zt(observation, prev_state, t)
      kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(p, q))
      tf.summary.scalar("kl_q/%d" % t, tf.reduce_mean(kl))
      # Convergence of q's mean to the posterior mean.
      loc_gap = q.loc - p.loc
      tf.summary.scalar("q_mean_convergence/absolute_error_%d" % t,
                        tf.reduce_mean(tf.abs(loc_gap)))
      tf.summary.scalar("q_mean_convergence/relative_error_%d" % t,
                        tf.reduce_mean(tf.abs(loc_gap / p.loc)))
      # Convergence of q's variance to the posterior variance.
      var_gap = tf.square(q.scale) - tf.square(p.scale)
      tf.summary.scalar("q_variance_convergence/absolute_error_%d" % t,
                        tf.reduce_mean(tf.abs(var_gap)))
      tf.summary.scalar("q_variance_convergence/relative_error_%d" % t,
                        tf.reduce_mean(tf.abs(var_gap / tf.square(p.scale))))
def summarize_rs(model, states):
  """Compares the learned r distributions against the true lookahead.

  Emits, per timestep: the KL from the true lookahead to r, and
  absolute/relative errors of r's mean and variance.
  """
  model.r.summarize_weights()
  for t, state in enumerate(states):
    true_r = model.p.lookahead(state, t)
    r = model.r.r_xn(state, t)
    kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(true_r, r))
    tf.summary.scalar("kl_r/%d" % t, tf.reduce_mean(kl))
    # Convergence of r's mean to the true lookahead mean.
    loc_gap = true_r.loc - r.loc
    tf.summary.scalar("r_mean_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(loc_gap)))
    tf.summary.scalar("r_mean_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(loc_gap / true_r.loc)))
    # Convergence of r's variance to the true lookahead variance.
    var_gap = tf.square(r.scale) - tf.square(true_r.scale)
    tf.summary.scalar("r_variance_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(var_gap)))
    tf.summary.scalar("r_variance_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(var_gap / tf.square(true_r.scale))))
def summarize_model(model, true_bs, observation, states, bound,
                    summarize_r=True):
  """Summarizes how well the model's learned bs track the data-generating bs.

  Only applies to models that expose a `bs` attribute on p; otherwise this
  is a no-op (the q/r diagnostics below are currently disabled).
  """
  if hasattr(model.p, "bs"):
    model_b = tf.reduce_sum(model.p.bs, axis=0)
    true_b = tf.reduce_sum(true_bs, axis=0)
    abs_err = tf.abs(model_b - true_b)
    tf.summary.scalar("sum_of_bs/data_generating_process",
                      tf.reduce_mean(true_b))
    tf.summary.scalar("sum_of_bs/model", tf.reduce_mean(model_b))
    tf.summary.scalar("sum_of_bs/absolute_error", tf.reduce_mean(abs_err))
    tf.summary.scalar("sum_of_bs/relative_error",
                      tf.reduce_mean(abs_err / true_b))
  # Disabled diagnostics, kept for reference:
  #summarize_qs(model, observation, states)
  #if bound == "fivo-aux" and summarize_r:
  #  summarize_rs(model, states)
def summarize_grads(grads, loss_name):
  """Tracks EMA estimates of gradient mean and variance for one loss.

  Args:
    grads: List of (gradient, variable) pairs as produced by an optimizer's
      compute_gradients; pairs with a None gradient are skipped.
    loss_name: String used to namespace the emitted summaries.

  Returns:
    The op that updates the moving averages; run it after applying grads.
  """
  ema = tf.train.ExponentialMovingAverage(decay=0.99)
  flat_grads = [tf.reshape(g, [-1]) for g, _ in grads if g is not None]
  vectorized_grads = tf.concat(flat_grads, axis=0)
  new_first_moments = vectorized_grads
  new_second_moments = tf.square(vectorized_grads)
  maintain_grad_ema_op = ema.apply([new_first_moments, new_second_moments])
  first_moments = ema.average(new_first_moments)
  second_moments = ema.average(new_second_moments)
  # Var[g] = E[g^2] - E[g]^2, with both expectations taken as EMAs.
  variances = second_moments - tf.square(first_moments)
  tf.summary.scalar("grad_variance/%s" % loss_name, tf.reduce_mean(variances))
  tf.summary.histogram("grad_variance/%s" % loss_name, variances)
  tf.summary.histogram("grad_mean/%s" % loss_name, first_moments)
  return maintain_grad_ema_op
research/fivo/experimental/train.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main script for running fivo"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
defaultdict
import
numpy
as
np
import
tensorflow
as
tf
import
bounds
import
data
import
models
import
summary_utils
as
summ
tf.logging.set_verbosity(tf.logging.INFO)

# --- Experiment setup flags. ---
tf.app.flags.DEFINE_integer("random_seed", None,
                            "A random seed for the data generating process. Same seed "
                            "-> same data generating process and initialization.")
tf.app.flags.DEFINE_enum("bound", "fivo",
                         ["iwae", "fivo", "fivo-aux", "fivo-aux-td"],
                         "The bound to optimize.")
tf.app.flags.DEFINE_enum("model", "forward", ["forward", "long_chain"],
                         "The model to use.")
tf.app.flags.DEFINE_enum("q_type", "normal",
                         ["normal", "simple_mean", "prev_state", "observation"],
                         "The parameterization to use for q")
tf.app.flags.DEFINE_enum("p_type", "unimodal",
                         ["unimodal", "bimodal", "nonlinear"],
                         "The type of prior.")
tf.app.flags.DEFINE_boolean("train_p", True,
                            "If false, do not train the model p.")
tf.app.flags.DEFINE_integer("state_size", 1,
                            "The dimensionality of the state space.")
tf.app.flags.DEFINE_float("variance", 1.0,
                          "The variance of the data generating process.")
tf.app.flags.DEFINE_boolean("use_bs", True,
                            "If False, initialize all bs to 0.")
tf.app.flags.DEFINE_float("bimodal_prior_weight", 0.5,
                          "The weight assigned to the positive mode of the prior in "
                          "both the data generating process and p.")
tf.app.flags.DEFINE_float("bimodal_prior_mean", None,
                          "If supplied, sets the mean of the 2 modes of the prior to "
                          "be 1 and -1 times the supplied value. This is for both the "
                          "data generating process and p.")
tf.app.flags.DEFINE_float("fixed_observation", None,
                          "If supplied, fix the observation to a constant value in the"
                          " data generating process only.")
tf.app.flags.DEFINE_float("r_sigma_init", 1.,
                          "Value to initialize variance of r to.")
# --- Long-chain model flags. ---
tf.app.flags.DEFINE_enum("observation_type",
                         models.STANDARD_OBSERVATION, models.OBSERVATION_TYPES,
                         "The type of observation for the long chain model.")
tf.app.flags.DEFINE_enum("transition_type",
                         models.STANDARD_TRANSITION, models.TRANSITION_TYPES,
                         "The type of transition for the long chain model.")
tf.app.flags.DEFINE_float("observation_variance", None,
                          "The variance of the observation. Defaults to 'variance'")
# --- Sequence and batch shape flags. ---
tf.app.flags.DEFINE_integer("num_timesteps", 5,
                            "Number of timesteps in the sequence.")
tf.app.flags.DEFINE_integer("num_observations", 1,
                            "The number of observations.")
tf.app.flags.DEFINE_integer("steps_per_observation", 5,
                            "The number of timesteps between each observation.")
tf.app.flags.DEFINE_integer("batch_size", 4,
                            "The number of examples per batch.")
tf.app.flags.DEFINE_integer("num_samples", 4,
                            "The number particles to use.")
tf.app.flags.DEFINE_integer("num_eval_samples", 512,
                            "The batch size and # of particles to use for eval.")
# --- Resampling flags. ---
tf.app.flags.DEFINE_string("resampling", "always",
                           "How to resample. Accepts 'always','never', or a "
                           "comma-separated list of booleans like 'true,true,false'.")
tf.app.flags.DEFINE_enum("resampling_method", "multinomial",
                         ["multinomial",
                          "stratified",
                          "systematic",
                          "relaxed-logblend",
                          "relaxed-stateblend",
                          "relaxed-linearblend",
                          "relaxed-stateblend-st",],
                         "Type of resampling method to use.")
tf.app.flags.DEFINE_boolean("use_resampling_grads", True,
                            "Whether or not to use resampling grads to optimize FIVO."
                            "Disabled automatically if resampling_method=relaxed.")
tf.app.flags.DEFINE_boolean("disable_r", False,
                            "If false, r is not used for fivo-aux and is set to zeros.")
# --- Optimization flags. ---
tf.app.flags.DEFINE_float("learning_rate", 1e-4,
                          "The learning rate to use for ADAM or SGD.")
tf.app.flags.DEFINE_integer("decay_steps", 25000,
                            "The number of steps before the learning rate is halved.")
tf.app.flags.DEFINE_integer("max_steps", int(1e6),
                            "The number of steps to run training for.")
# --- Logging and checkpointing flags. ---
tf.app.flags.DEFINE_string("logdir", "/tmp/fivo-aux",
                           "Directory for summaries and checkpoints.")
tf.app.flags.DEFINE_integer("summarize_every", int(1e3),
                            "The number of steps between each evaluation.")

FLAGS = tf.app.flags.FLAGS
def combine_grad_lists(grad_lists):
  """Sums gradients for each variable across several per-loss gradient lists.

  Args:
    grad_lists: A list (one entry per loss) of lists of (gradient, variable)
      pairs, as produced by compute_gradients. The per-loss lists may cover
      different variables.

  Returns:
    A single list of (gradient, variable) pairs where each gradient is the
    sum over all losses; the gradient is None for variables that received no
    gradient from any loss.
  """
  # grads is num_losses by num_variables.
  # each list could have different variables.
  # for each variable, sum the grads across all losses.
  grads_dict = defaultdict(list)
  var_dict = {}
  for grad_list in grad_lists:
    for grad, var in grad_list:
      if grad is not None:
        grads_dict[var.name].append(grad)
      var_dict[var.name] = var

  final_grads = []
  # .items() (not Python-2-only .iteritems()) so this also runs on Python 3.
  for var_name, var in var_dict.items():
    grads = grads_dict[var_name]
    if grads:
      tf.logging.info("Var %s has combined grads from %s." %
                      (var_name, [g.name for g in grads]))
      grad = tf.reduce_sum(grads, axis=0)
    else:
      tf.logging.info("Var %s has no grads" % var_name)
      grad = None
    final_grads.append((grad, var))
  return final_grads
def make_apply_grads_op(losses, global_step, learning_rate, lr_decay_steps):
  """Builds a train op that applies the combined gradients of all losses.

  Args:
    losses: Iterable of bounds.Loss namedtuples
      (loss_name, loss, loss_var_collection).
    global_step: The global step Variable, incremented on apply.
    learning_rate: Initial learning rate for Adam.
    lr_decay_steps: Steps over which the learning rate halves.

  Returns:
    A train op that applies gradients and then updates gradient EMAs.
  """
  for entry in losses:
    assert isinstance(entry, bounds.Loss)
  lr = tf.train.exponential_decay(
      learning_rate, global_step, lr_decay_steps, 0.5, staircase=False)
  tf.summary.scalar("learning_rate", lr)
  opt = tf.train.AdamOptimizer(lr)

  ema_ops = []
  per_loss_grads = []
  for loss_name, loss, loss_var_collection in losses:
    tf.logging.info("Computing grads of %s w.r.t. vars in collection %s" %
                    (loss_name, loss_var_collection))
    grad_list = opt.compute_gradients(
        loss, var_list=tf.get_collection(loss_var_collection))
    ema_ops.append(summ.summarize_grads(grad_list, loss_name))
    per_loss_grads.append(grad_list)

  combined_grads = combine_grad_lists(per_loss_grads)
  apply_grads_op = opt.apply_gradients(combined_grads,
                                       global_step=global_step)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads_op]):
    train_op = tf.group(*ema_ops)
  return train_op
def add_check_numerics_ops():
  """Adds check_numerics ops for every float output in the default graph.

  Like tf.add_check_numerics_ops, but skips ops whose names contain known
  substrings (sampling/log-prob internals) that legitimately produce
  non-finite values and would otherwise trigger false alarms.

  Returns:
    A tf.group of the chained check_numerics ops.

  Raises:
    ValueError: If a checked op lives inside a control flow context
      (tf.cond / tf.while_loop), where check_numerics cannot be inserted.
  """
  check_op = []
  for op in tf.get_default_graph().get_operations():
    # Name fragments of ops that are expected to emit inf/nan; skip them.
    bad = ["logits/Log", "sample/Reshape", "log_prob/mul",
           "log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape",
           "entropy/Reshape", "entropy/LogSoftmax", "Categorical", "Mean"]
    if all([x not in op.name for x in bad]):
      for output in op.outputs:
        if output.dtype in [tf.float16, tf.float32, tf.float64]:
          if op._get_control_flow_context() is not None:
            # pylint: disable=protected-access
            raise ValueError("`tf.add_check_numerics_ops() is not compatible "
                             "with TensorFlow control flow operations such as "
                             "`tf.cond()` or `tf.while_loop()`.")
          message = op.name + ":" + str(output.value_index)
          # Chain each new check on the previous one so the checks run
          # serially; check_op always holds the most recent check.
          with tf.control_dependencies(check_op):
            check_op = [tf.check_numerics(output, message=message)]
  return tf.group(*check_op)
def create_long_chain_graph(bound,
                            state_size,
                            num_obs,
                            steps_per_obs,
                            batch_size,
                            num_samples,
                            num_eval_samples,
                            resampling_schedule,
                            use_resampling_grads,
                            learning_rate,
                            lr_decay_steps,
                            dtype="float64"):
  """Builds the train/eval graph for the long-chain model.

  Args:
    bound: One of "iwae", "fivo", "fivo-aux" (fivo-aux-td is rejected by
      main() for long-chain models).
    state_size: Dimensionality of the latent state.
    num_obs: Number of observations in the chain.
    steps_per_obs: Latent steps between consecutive observations.
    batch_size: Training batch size.
    num_samples: Number of particles for training.
    num_eval_samples: Batch size and particle count for eval.
    resampling_schedule: List of per-step booleans controlling resampling.
    use_resampling_grads: Whether to use resampling gradients for FIVO.
    learning_rate: Initial Adam learning rate.
    lr_decay_steps: Steps over which the learning rate halves.
    dtype: Numpy-style dtype string for the data and model.

  Returns:
    Tuple (global_step, train_op, eval_log_p_hat, eval_likelihood). The
    likelihood is a zero scalar because it is intractable for these models.
  """
  num_timesteps = num_obs * steps_per_obs + 1
  # Make the dataset.
  dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Make the dataset for eval
  eval_dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()

  # Make the model.
  model = models.LongChainModel.create(
      state_size,
      num_obs,
      steps_per_obs,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=tf.as_dtype(dtype),
      disable_r=FLAGS.disable_r)

  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  # Fixed: previously `bound == "fivo" or "fivo-aux"`, which is always true
  # because the string literal is truthy. Matching create_graph's check is
  # behavior-identical for every legal value of the bound enum.
  elif "fivo" in bound:
    (_, losses, ema_op, _, _) = bounds.fivo(
        model,
        observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=use_resampling_grads,
        resampling_type=FLAGS.resampling_method,
        aux=("aux" in bound),
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
        model,
        eval_observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=False,
        resampling_type="multinomial",
        aux=("aux" in bound),
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

  summ.summarize_ess(eval_log_weights, only_last_timestep=True)
  tf.summary.scalar("log_p_hat", eval_log_p_hat)

  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
  # We can't calculate the likelihood for most of these models
  # so we just return zeros.
  eval_likelihood = tf.zeros([], dtype=dtype)
  return global_step, train_op, eval_log_p_hat, eval_likelihood
def create_graph(bound,
                 state_size,
                 num_timesteps,
                 batch_size,
                 num_samples,
                 num_eval_samples,
                 resampling_schedule,
                 use_resampling_grads,
                 learning_rate,
                 lr_decay_steps,
                 train_p,
                 dtype='float64'):
  """Builds the train/eval graph for the forward model.

  Args:
    bound: One of "iwae", "fivo", "fivo-aux", "fivo-aux-td".
    state_size: Dimensionality of the latent state.
    num_timesteps: Number of timesteps in each sequence.
    batch_size: Training batch size.
    num_samples: Number of particles for training.
    num_eval_samples: Batch size and particle count for eval.
    resampling_schedule: List of per-step booleans controlling resampling.
    use_resampling_grads: Whether to use resampling gradients for FIVO.
    learning_rate: Initial Adam learning rate.
    lr_decay_steps: Steps over which the learning rate halves.
    train_p: If False, the model p is held fixed.
    dtype: Numpy-style dtype string for the data and model.

  Returns:
    Tuple (global_step, train_op, eval_log_p_hat, eval_likelihood).
  """
  if FLAGS.use_bs:
    true_bs = None
  else:
    # range (not Python-2-only xrange) so this also runs under Python 3.
    true_bs = [np.zeros([state_size]).astype(dtype)
               for _ in range(num_timesteps)]

  # Make the dataset.
  true_bs, dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Make the dataset for eval
  _, eval_dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=num_eval_samples,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()

  # Make the model.
  if bound == "fivo-aux-td":
    model = models.TDModel.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)
  else:
    model = models.Model.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        r_sigma_init=FLAGS.r_sigma_init,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)

  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  elif "fivo" in bound:
    if bound == "fivo-aux-td":
      (_, losses, ema_op, _, _) = bounds.fivo_aux_td(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo_aux_td(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_eval_samples,
          summarize=True)
    else:
      (_, losses, ema_op, _, _) = bounds.fivo(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=use_resampling_grads,
          resampling_type=FLAGS.resampling_method,
          aux=("aux" in bound),
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=False,
          resampling_type="multinomial",
          aux=("aux" in bound),
          num_samples=num_eval_samples,
          summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

  summ.summarize_ess(eval_log_weights, only_last_timestep=True)

  # if FLAGS.p_type == "bimodal":
  #   # create the observations that showcase the model.
  #   mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
  #                                          dtype=tf.float64)
  #   mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
  #   k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
  #   explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
  #   explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
  #   # run the model on the explainable observations
  #   if bound == "iwae":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         num_samples=num_eval_samples)
  #   elif bound == "fivo" or "fivo-aux":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         resampling_schedule=resampling_schedule,
  #         use_resampling_grads=False,
  #         resampling_type="multinomial",
  #         aux=("aux" in bound),
  #         num_samples=num_eval_samples)
  #   summ.summarize_particles(explain_states,
  #                            explain_log_weights,
  #                            explain_obs,
  #                            model)

  # Calculate the true likelihood.
  if hasattr(model.p, 'likelihood') and callable(getattr(model.p, 'likelihood')):
    eval_likelihood = model.p.likelihood(eval_observations) / FLAGS.num_timesteps
  else:
    eval_likelihood = tf.zeros_like(eval_log_p_hat)

  tf.summary.scalar("log_p_hat", eval_log_p_hat)
  tf.summary.scalar("likelihood", eval_likelihood)
  tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
  summ.summarize_model(model, true_bs, eval_observations, eval_states, bound,
                       summarize_r=not bound == "fivo-aux-td")

  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)

  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
    #train_op = tf.group(ema_op, add_check_numerics_ops())

  return global_step, train_op, eval_log_p_hat, eval_likelihood
def parse_resampling_schedule(schedule, num_timesteps):
  """Parses the --resampling flag into a per-timestep boolean schedule.

  Args:
    schedule: One of "always" (resample every step except the last), "never",
      "every_N" (resample every N steps), or a comma-separated list of
      booleans like "true,true,false" with one entry per timestep.
    num_timesteps: The number of timesteps the schedule must cover.

  Returns:
    A list of num_timesteps booleans, True where resampling should occur.

  Raises:
    AssertionError: If an explicit list has the wrong number of entries.
  """
  schedule = schedule.strip().lower()
  if schedule == "always":
    # Never resample on the final step -- there is nothing left to propagate.
    return [True] * (num_timesteps - 1) + [False]
  elif schedule == "never":
    return [False] * num_timesteps
  elif "every" in schedule:
    n = int(schedule.split("_")[1])
    # range (not Python-2-only xrange) so this also runs under Python 3.
    return [(i + 1) % n == 0 for i in range(num_timesteps)]
  else:
    sched = [x.strip() == "true" for x in schedule.split(",")]
    assert len(
        sched
    ) == num_timesteps, "Wrong number of timesteps in resampling schedule."
    return sched
def create_log_hook(step, eval_log_p_hat, eval_likelihood):
  """Creates a LoggingTensorHook printing the bound and likelihood values."""
  tensors = {
      "step": step,
      "log_p_hat": eval_log_p_hat,
      "likelihood": eval_likelihood,
  }

  def summ_formatter(d):
    # d maps the tensor names above to their evaluated values.
    return ("Step {step}, log p_hat: {log_p_hat:.5f} likelihood: {likelihood:.5f}"
            .format(**d))

  return tf.train.LoggingTensorHook(
      tensors,
      every_n_iter=FLAGS.summarize_every,
      formatter=summ_formatter)
def create_infrequent_summary_hook():
  """Creates a hook that saves the 'infrequent_summaries' collection rarely."""
  return tf.train.SummarySaverHook(
      save_steps=10000,
      output_dir=FLAGS.logdir,
      summary_op=tf.summary.merge_all(key="infrequent_summaries"))
def main(unused_argv):
  """Entry point: validates flags, builds the graph, and runs training."""
  # The long-chain model uses one extra timestep (the initial state).
  if FLAGS.model == "long_chain":
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps + 1)
  else:
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps)
  if FLAGS.random_seed is None:
    seed = np.random.randint(0, high=10000)
  else:
    seed = FLAGS.random_seed
  tf.logging.info("Using random seed %d", seed)

  # Flag-compatibility checks for the long-chain model.
  if FLAGS.model == "long_chain":
    assert FLAGS.q_type == "normal", (
        "Q type %s not supported for long chain models" % FLAGS.q_type)
    assert FLAGS.p_type == "unimodal", (
        "Bimodal priors are not supported for long chain models")
    assert not FLAGS.use_bs, "Bs are not supported with long chain models"
    assert FLAGS.num_timesteps == (
        FLAGS.num_observations * FLAGS.steps_per_observation), (
            "Num timesteps does not match.")
    assert FLAGS.bound != "fivo-aux-td", (
        "TD Training is not compatible with long chain models.")

  # Flag-compatibility checks for the forward model.
  if FLAGS.model == "forward":
    if "nonlinear" not in FLAGS.p_type:
      assert FLAGS.transition_type == models.STANDARD_TRANSITION, (
          "Non-standard transitions not supported by the forward model.")
    assert FLAGS.observation_type == models.STANDARD_OBSERVATION, (
        "Non-standard observations not supported by the forward model.")
    assert FLAGS.observation_variance is None, (
        "Forward model does not support observation variance.")
    assert FLAGS.num_observations == 1, (
        "Forward model only supports 1 observation.")

  # Relaxed resampling replaces resampling gradients entirely.
  if "relaxed" in FLAGS.resampling_method:
    FLAGS.use_resampling_grads = False
    assert FLAGS.bound != "fivo-aux-td", (
        "TD Training is not compatible with relaxed resampling.")

  if FLAGS.observation_variance is None:
    FLAGS.observation_variance = FLAGS.variance

  if FLAGS.p_type == "bimodal":
    assert FLAGS.bimodal_prior_mean is not None, (
        "Must specify prior mean if using bimodal p.")

  if FLAGS.p_type == "nonlinear" or FLAGS.p_type == "nonlinear-cauchy":
    assert not FLAGS.use_bs, (
        "Using bs is not compatible with the nonlinear model.")

  g = tf.Graph()
  with g.as_default():
    # Set the seeds.
    tf.set_random_seed(seed)
    np.random.seed(seed)
    if FLAGS.model == "long_chain":
      (global_step, train_op, eval_log_p_hat,
       eval_likelihood) = create_long_chain_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_observations,
           FLAGS.steps_per_observation,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps)
    else:
      (global_step, train_op,
       eval_log_p_hat, eval_likelihood) = create_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_timesteps,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps,
           FLAGS.train_p)

    log_hooks = [create_log_hook(global_step, eval_log_p_hat, eval_likelihood)]
    # Only add the infrequent-summary hook when something was registered.
    if len(tf.get_collection("infrequent_summaries")) > 0:
      log_hooks.append(create_infrequent_summary_hook())

    # Log which variables belong to which collection for debugging.
    tf.logging.info("trainable variables:")
    tf.logging.info([v.name for v in tf.trainable_variables()])
    tf.logging.info("p vars:")
    tf.logging.info([v.name for v in tf.get_collection("P_VARS")])
    tf.logging.info("q vars:")
    tf.logging.info([v.name for v in tf.get_collection("Q_VARS")])
    tf.logging.info("r vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_VARS")])
    tf.logging.info("r tilde vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_TILDE_VARS")])

    with tf.train.MonitoredTrainingSession(
        master="",
        is_chief=True,
        hooks=log_hooks,
        checkpoint_dir=FLAGS.logdir,
        save_checkpoint_secs=120,
        save_summaries_steps=FLAGS.summarize_every,
        log_step_count_steps=FLAGS.summarize_every) as sess:
      cur_step = -1
      while True:
        if sess.should_stop() or cur_step > FLAGS.max_steps:
          break
        # run a step
        _, cur_step = sess.run([train_op, global_step])
if __name__ == "__main__":
  # tf.app.run parses flags and then invokes main(argv).
  tf.app.run(main)
research/fivo/fivo/__init__.py
deleted
100644 → 0
View file @
5eb294f8
research/fivo/fivo/bounds.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of objectives for training stochastic latent variable models.
Contains implementations of the Importance Weighted Autoencoder objective (IWAE)
and the Filtering Variational objective (FIVO).
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
functools
import
tensorflow
as
tf
from
fivo
import
nested_utils
as
nested
from
fivo
import
smc
def iwae(model, observations, seq_lengths, num_samples=1,
         parallel_iterations=30, swap_memory=True):
  """Computes the IWAE lower bound on the log marginal probability.

  Accepts a stochastic latent variable model and some observations and
  computes a stochastic lower bound on the log marginal probability of the
  observations by averaging multiple importance weights. See "Importance
  Weighted Autoencoders" by Burda et al.
  https://arxiv.org/abs/1509.00519. With num_samples = 1 this reduces to
  the evidence lower bound (ELBO).

  Args:
    model: A subclass of ELBOTrainableSequenceModel that implements one
      timestep of the model. See models/vrnn.py for an example.
    observations: The inputs to the model. A potentially nested list or tuple
      of Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors
      must have a rank at least two and have matching shapes in the first two
      dimensions, which represent time and the batch respectively. The model
      will be provided with the observations before computing the bound.
    seq_lengths: A [batch_size] Tensor of ints encoding the length of each
      sequence in the batch (sequences can be padded to a common length).
    num_samples: The number of samples to use.
    parallel_iterations: The number of parallel iterations to use for the
      internal while loop.
    swap_memory: Whether GPU-CPU memory swapping should be enabled for the
      internal while loop.

  Returns:
    log_p_hat: A Tensor of shape [batch_size] containing IWAE's estimate of
      the log marginal probability of the observations.
    log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
      containing the log weights at each timestep. Will not be valid for
      timesteps past the end of a sequence.
    final_state: The final model state after consuming the observations.
  """
  # IWAE is exactly FIVO with resampling disabled.
  log_p_hat, log_weights, _, final_state = fivo(
      model,
      observations,
      seq_lengths,
      num_samples=num_samples,
      resampling_criterion=smc.never_resample_criterion,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)
  return log_p_hat, log_weights, final_state
def fivo(model,
         observations,
         seq_lengths,
         num_samples=1,
         resampling_criterion=smc.ess_criterion,
         resampling_type='multinomial',
         relaxed_resampling_temperature=0.5,
         parallel_iterations=30,
         swap_memory=True,
         random_seed=None):
  """Computes the FIVO lower bound on the log marginal probability.

  This method accepts a stochastic latent variable model and some observations
  and computes a stochastic lower bound on the log marginal probability of the
  observations. The lower bound is defined by a particle filter's unbiased
  estimate of the marginal probability of the observations. For more details see
  "Filtering Variational Objectives" by Maddison et al.
  https://arxiv.org/abs/1705.09279.

  When the resampling criterion is "never resample", this bound becomes IWAE.

  Args:
    model: A subclass of ELBOTrainableSequenceModel that implements one
      timestep of the model. See models/vrnn.py for an example.
    observations: The inputs to the model. A potentially nested list or tuple of
      Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
      have a rank at least two and have matching shapes in the first two
      dimensions, which represent time and the batch respectively. The model
      will be provided with the observations before computing the bound.
    seq_lengths: A [batch_size] Tensor of ints encoding the length of each
      sequence in the batch (sequences can be padded to a common length).
    num_samples: The number of particles to use in each particle filter.
    resampling_criterion: The resampling criterion to use for this particle
      filter. Must accept the number of samples, the current log weights,
      and the current timestep and return a boolean Tensor of shape [batch_size]
      indicating whether each particle filter should resample. See
      ess_criterion and related functions for examples. When
      resampling_criterion is never_resample_criterion, resampling_fn is ignored
      and never called.
    resampling_type: The type of resampling, one of "multinomial" or "relaxed".
    relaxed_resampling_temperature: A positive temperature only used for relaxed
      resampling.
    parallel_iterations: The number of parallel iterations to use for the
      internal while loop. Note that values greater than 1 can introduce
      non-determinism even when random_seed is provided.
    swap_memory: Whether GPU-CPU memory swapping should be enabled for the
      internal while loop.
    random_seed: The random seed to pass to the resampling operations in
      the particle filter. Mainly useful for testing.

  Returns:
    log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of the
      log marginal probability of the observations.
    log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
      containing the log weights at each timestep of the particle filter. Note
      that on timesteps when a resampling operation is performed the log weights
      are reset to 0. Will not be valid for timesteps past the end of a
      sequence.
    resampled: A Tensor of shape [max_seq_len, batch_size] indicating when the
      particle filters resampled. Will be 1.0 on timesteps when resampling
      occurred and 0.0 on timesteps when it did not.

  Raises:
    ValueError: If resampling_type is not "multinomial" or "relaxed".
  """
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]
  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)
  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  else:
    # Fail fast with a clear message instead of the UnboundLocalError the
    # functools.partial call below would otherwise raise.
    raise ValueError(
        "Unknown resampling_type %r, must be 'multinomial' or 'relaxed'." %
        (resampling_type,))
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    # On the first call (prev_state is None) smc.smc expects the initial state.
    if prev_state is None:
      return model.zero_state(batch_size * num_samples, tf.float32)
    return model.propose_and_weight(prev_state, t)

  log_p_hat, log_weights, resampled, final_state, _ = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)

  return log_p_hat, log_weights, resampled, final_state
def fivo_aux_td(model,
                observations,
                seq_lengths,
                num_samples=1,
                resampling_criterion=smc.ess_criterion,
                resampling_type='multinomial',
                relaxed_resampling_temperature=0.5,
                parallel_iterations=30,
                swap_memory=True,
                random_seed=None):
  """Experimental.

  Appears to compute a FIVO-style particle-filter bound augmented with an
  auxiliary 'r' term whose approximation r_tilde is trained with a
  Bellman-style (temporal-difference) squared-error loss -- hence the _td
  suffix. NOTE(review): semantics are not fully documented upstream; the model
  is expected to return, per step, a tuple
  (new_state, zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_tilde, p_ztplus1)
  where log_r_tilde is a (mean, variance) pair -- confirm against the model
  implementation.
  """
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]
  max_seq_len = tf.reduce_max(seq_lengths)
  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)
  # Select the resampling scheme. Note: an unrecognized resampling_type leaves
  # resampling_fn unbound and fails below -- only 'multinomial' and 'relaxed'
  # are supported.
  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    """One particle-filter step: propose, weight, and carry the aux-r state."""
    if prev_state is None:
      # Initial call: build the starting state. The carried state is
      # (log_r, (r_tilde_mu, r_tilde_sigma_sq), model_state), all zeros.
      model_init_state = model.zero_state(batch_size * num_samples, tf.float32)
      return (tf.zeros([num_samples * batch_size], dtype=tf.float32),
              (tf.zeros([num_samples * batch_size, model.latent_size],
                        dtype=tf.float32),
               tf.zeros([num_samples * batch_size, model.latent_size],
                        dtype=tf.float32)),
              model_init_state)
    prev_log_r, prev_log_r_tilde, prev_model_state = prev_state
    (new_model_state, zt, log_q_zt, log_p_zt,
     log_p_x_given_z, log_r_tilde, p_ztplus1) = model(prev_model_state, t)
    # Despite the name, log_r_tilde is a (mean, variance) parameterization.
    r_tilde_mu, r_tilde_sigma_sq = log_r_tilde
    # Compute the weight without r.
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    # Compute log_r and log_r_tilde. The prior's moments are treated as
    # constants so gradients only flow through r_tilde.
    p_mu = tf.stop_gradient(p_ztplus1.mean())
    p_sigma_sq = tf.stop_gradient(p_ztplus1.variance())
    log_r = (tf.log(r_tilde_sigma_sq) -
             tf.log(r_tilde_sigma_sq + p_sigma_sq) -
             tf.square(r_tilde_mu - p_mu) / (r_tilde_sigma_sq + p_sigma_sq))
    # log_r is [num_samples*batch_size, latent_size]. We sum it along the last
    # dimension to compute log r.
    log_r = 0.5 * tf.reduce_sum(log_r, axis=-1)
    # Compute prev log r tilde: an unnormalized Gaussian log-density of zt
    # under the previous step's r_tilde, with zt treated as a constant.
    prev_r_tilde_mu, prev_r_tilde_sigma_sq = prev_log_r_tilde
    prev_log_r_tilde = -0.5 * tf.reduce_sum(
        tf.square(tf.stop_gradient(zt) - prev_r_tilde_mu) /
        prev_r_tilde_sigma_sq,
        axis=-1)
    # If the sequence is on the last timestep, log_r and log_r_tilde are just
    # zeros (no future to predict).
    last_timestep = t >= (tiled_seq_lengths - 1)
    log_r = tf.where(last_timestep, tf.zeros_like(log_r), log_r)
    prev_log_r_tilde = tf.where(last_timestep,
                                tf.zeros_like(prev_log_r_tilde),
                                prev_log_r_tilde)
    # The r correction adjusts the weight but contributes no gradient.
    log_weight += tf.stop_gradient(log_r - prev_log_r)
    new_state = (log_r, log_r_tilde, new_model_state)
    # Extra per-step quantities consumed by loop_fn below.
    loop_fn_args = (log_r, prev_log_r_tilde, log_p_x_given_z,
                    log_r - prev_log_r)
    return log_weight, new_state, loop_fn_args

  def loop_fn(loop_state, loop_args, unused_model_state, log_weights, resampled,
              mask, t):
    """Accumulates log_p_hat, the Bellman loss, and the log_r diff per step."""
    if loop_state is None:
      # Initial accumulators: log_p_hat [batch], bellman_loss [batch],
      # log_r_diff [num_samples, batch].
      return (tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([num_samples, batch_size], dtype=tf.float32))
    log_p_hat_acc, bellman_loss_acc, log_r_diff_acc = loop_state
    log_r, prev_log_r_tilde, log_p_x_given_z, log_r_diff = loop_args
    # Compute the log_p_hat update
    log_p_hat_update = tf.reduce_logsumexp(
        log_weights, axis=0) - tf.log(tf.to_float(num_samples))
    # If it is the last timestep, we always add the update; otherwise only on
    # steps where resampling occurred (resampled is 1.0 there, 0.0 elsewhere).
    log_p_hat_acc += tf.cond(t >= max_seq_len - 1,
                             lambda: log_p_hat_update,
                             lambda: log_p_hat_update * resampled)
    # Compute the Bellman update. Reshape flat [num_samples*batch] vectors into
    # [num_samples, batch_size] so reductions can run over particles (axis 0).
    log_r = tf.reshape(log_r, [num_samples, batch_size])
    prev_log_r_tilde = tf.reshape(prev_log_r_tilde,
                                  [num_samples, batch_size])
    log_p_x_given_z = tf.reshape(log_p_x_given_z, [num_samples, batch_size])
    mask = tf.reshape(mask, [num_samples, batch_size])
    # On the first timestep there is no bellman error because there is no
    # prev_log_r_tilde.
    mask = tf.cond(tf.equal(t, 0),
                   lambda: tf.zeros_like(mask),
                   lambda: mask)
    # On the first timestep also fix up prev_log_r_tilde, which will be -inf.
    prev_log_r_tilde = tf.where(
        tf.is_inf(prev_log_r_tilde),
        tf.zeros_like(prev_log_r_tilde),
        prev_log_r_tilde)
    # log_lambda is [num_samples, batch_size]; an offset that is treated as a
    # constant in the TD target below.
    log_lambda = tf.reduce_mean(prev_log_r_tilde - log_p_x_given_z - log_r,
                                axis=0,
                                keepdims=True)
    bellman_error = mask * tf.square(
        prev_log_r_tilde -
        tf.stop_gradient(log_lambda + log_p_x_given_z + log_r)
    )
    bellman_loss_acc += tf.reduce_mean(bellman_error, axis=0)
    # Compute the log_r_diff update
    log_r_diff_acc += mask * tf.reshape(log_r_diff, [num_samples, batch_size])
    return (log_p_hat_acc, bellman_loss_acc, log_r_diff_acc)

  log_weights, resampled, accs = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      loop_fn=loop_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)

  log_p_hat, bellman_loss, log_r_diff = accs
  # Two training objectives: maximize log_p_hat, minimize the Bellman loss.
  loss_per_seq = [-log_p_hat, bellman_loss]
  tf.summary.scalar("bellman_loss",
                    tf.reduce_mean(bellman_loss / tf.to_float(seq_lengths)))
  tf.summary.scalar("log_r_diff",
                    tf.reduce_mean(tf.reduce_mean(log_r_diff, axis=0) /
                                   tf.to_float(seq_lengths)))
  return loss_per_seq, log_p_hat, log_weights, resampled
research/fivo/fivo/bounds_test.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.bounds"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
from
fivo.test_utils
import
create_vrnn
from
fivo
import
bounds
class BoundsTest(tf.test.TestCase):
  """Golden-value regression tests for the bounds in fivo.bounds.

  Each test builds a small deterministic VRNN (via create_vrnn with a fixed
  seed), evaluates one bound, and compares against previously recorded values.
  Shapes throughout: log_p_hat is [batch_size=2]; weights is
  [max_seq_len=7, batch_size=2, num_samples=4]; resampled is
  [max_seq_len, batch_size].
  """

  def test_elbo(self):
    """A golden-value test for the ELBO (the IWAE bound with num_samples=1)."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=1,
                         parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, _, _ = sess.run(outs)
      self.assertAllClose([-21.615765, -13.614225], log_p_hat)

  def test_iwae(self):
    """A golden-value test for the IWAE bound."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=4,
                         parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, _ = sess.run(outs)
      self.assertAllClose([-23.301426, -13.64028], log_p_hat)
      # Recorded log weights, [time, batch, particle]. Entries past a
      # sequence's end repeat the last valid row (they are not meaningful).
      weights_gt = np.array(
          [[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
            [-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
           [[-6.2539978, -4.37615728, -7.43738699, -7.85044909],
            [-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
           [[-9.19093227, -8.01637268, -11.64603615, -10.51128292],
            [-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
           [[-12.20609856, -10.47217369, -13.66270638, -13.46115875],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-16.14766312, -15.57472229, -17.47755432, -17.98189926],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-20.07182884, -18.43191147, -20.1606636, -21.45263863],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-24.10270691, -22.20865822, -24.14675522, -25.27248383],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]]])
      self.assertAllClose(weights_gt, weights)

  def test_fivo(self):
    """A golden-value test for the FIVO bound."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-22.98902512, -14.21689224], log_p_hat)
      # Recorded log weights; note they restart from a resampling event's
      # values on the steps following each resample.
      weights_gt = np.array(
          [[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
            [-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
           [[-2.67100811, -2.30541706, -2.34178066, -2.81751347],
            [-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
           [[-5.65190411, -5.94563246, -6.55041981, -5.4783473],
            [-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
           [[-8.71947861, -8.40143299, -8.54593086, -8.42822266],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-12.7003831, -13.5039815, -12.3569726, -12.9489622],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-16.4520301, -16.3611698, -15.0314846, -16.4197006],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-20.7010765, -20.1379165, -19.0020351, -20.2395458],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]]])
      self.assertAllClose(weights_gt, weights)
      # 1.0 marks the timesteps where each filter resampled.
      resampled_gt = np.array(
          [[1., 0.],
           [0., 0.],
           [0., 1.],
           [0., 0.],
           [0., 0.],
           [0., 0.],
           [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)

  def test_fivo_relaxed(self):
    """A golden-value test for the FIVO bound with relaxed sampling."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1,
                         resampling_type="relaxed")
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-22.942394, -14.273882], log_p_hat)
      weights_gt = np.array(
          [[[-3.66708851, -2.07074118, -4.91751575, -5.03293514],
            [-2.99690628, -3.17782831, -4.50084877, -3.48536515]],
           [[-2.84939098, -2.30087185, -2.35649204, -2.48417377],
            [-8.27518654, -6.71545172, -8.96199131, -7.05567837]],
           [[-5.92327023, -5.9433074, -6.5826683, -5.04259014],
            [-12.34527206, -11.54284668, -11.86675072, -9.69417477]],
           [[-8.95323944, -8.40061855, -8.52760506, -7.99130583],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-12.87836456, -13.49628639, -12.31680107, -12.74228859],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-16.78347397, -16.35150909, -14.98797417, -16.35162735],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-20.81165886, -20.1307621, -18.92229652, -20.17458153],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]]])
      self.assertAllClose(weights_gt, weights)
      resampled_gt = np.array(
          [[1., 0.],
           [0., 0.],
           [0., 1.],
           [0., 0.],
           [0., 0.],
           [0., 0.],
           [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)

  def test_fivo_aux_relaxed(self):
    """A golden-value test for the FIVO-AUX bound with relaxed sampling."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      # use_tilt=True enables the auxiliary tilting distribution in the VRNN.
      model, inputs, targets, lengths = create_vrnn(random_seed=1234,
                                                    use_tilt=True)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1,
                         resampling_type="relaxed")
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-23.1395, -14.271059], log_p_hat)
      weights_gt = np.array(
          [[[-5.19826221, -3.55476403, -5.98663855, -6.08058834],
            [-6.31685925, -5.70243931, -7.07638931, -6.18138981]],
           [[-3.97986865, -3.58831525, -3.85753584, -3.5010016],
            [-11.38203049, -8.66213989, -11.23646641, -10.02024746]],
           [[-6.62269831, -6.36680222, -6.78096485, -5.80072498],
            [-3.55419445, -8.11326408, -3.48766923, -3.08593249]],
           [[-10.56472301, -10.16084099, -9.96741676, -8.5270071],
            [-6.04880285, -7.80853653, -4.72652149, -3.49711013]],
           [[-13.36585426, -16.08720398, -13.33416367, -13.1017189],
            [-0., -0., -0., -0.]],
           [[-17.54233551, -17.35167503, -16.79163361, -16.51471138],
            [0., -0., -0., -0.]],
           [[-19.74024963, -18.69452858, -17.76246452, -18.76182365],
            [0., -0., -0., -0.]]])
      self.assertAllClose(weights_gt, weights)
      resampled_gt = np.array([[1., 0.],
                               [0., 1.],
                               [0., 0.],
                               [0., 1.],
                               [0., 0.],
                               [0., 0.],
                               [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)
if __name__ == "__main__":
  import sys
  # Print arrays in full (no "..." truncation) to easily see the gold values.
  # Use print(repr(numpy_array)) to print the values.
  # NOTE: the original passed threshold=np.nan, which modern NumPy rejects
  # (threshold must be an int); sys.maxsize is the supported equivalent.
  np.set_printoptions(threshold=sys.maxsize)
  tf.test.main()
research/fivo/fivo/data/__init__.py
deleted
100644 → 0
View file @
5eb294f8
research/fivo/fivo/data/calculate_pianoroll_mean.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to calculate the mean of a pianoroll dataset.
Given a pianoroll pickle file, this script loads the dataset and
calculates the mean of the training set. Then it updates the pickle file
so that the key "train_mean" points to the mean vector.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
pickle
import
numpy
as
np
import
tensorflow
as
tf
from
datasets
import
sparse_pianoroll_to_dense
# Command-line flags for the script.
tf.app.flags.DEFINE_string('in_file', None,
                           'Filename of the pickled pianoroll dataset to load.')
tf.app.flags.DEFINE_string('out_file', None,
                           'Name of the output pickle file. Defaults to in_file, '
                           'updating the input pickle file.')
tf.app.flags.mark_flag_as_required('in_file')
FLAGS = tf.app.flags.FLAGS

# Note-number range shared by the pianoroll datasets; determines the width of
# the dense representation produced by sparse_pianoroll_to_dense.
MIN_NOTE = 21
MAX_NOTE = 108
NUM_NOTES = MAX_NOTE - MIN_NOTE + 1
def main(unused_argv):
  """Loads a pickled pianoroll dataset, adds the train mean, and re-saves it.

  Reads the pickle named by --in_file, computes the elementwise mean over all
  timesteps of the dense training pianorolls, stores it under the key
  'train_mean', and writes the updated dict to --out_file (defaulting to
  --in_file, i.e. updating in place).
  """
  if FLAGS.out_file is None:
    FLAGS.out_file = FLAGS.in_file
  # Pickle data is binary; open in 'rb' (text mode 'r' breaks under Python 3,
  # where it would yield str instead of bytes).
  with tf.gfile.Open(FLAGS.in_file, 'rb') as f:
    pianorolls = pickle.load(f)
  dense_pianorolls = [sparse_pianoroll_to_dense(p, MIN_NOTE, NUM_NOTES)[0]
                      for p in pianorolls['train']]
  # Concatenate all elements along the time axis.
  concatenated = np.concatenate(dense_pianorolls, axis=0)
  mean = np.mean(concatenated, axis=0)
  pianorolls['train_mean'] = mean
  # Write out the whole pickle file, including the train mean. A context
  # manager guarantees the handle is flushed and closed (the original passed
  # an anonymous open() to pickle.dump and leaked the handle).
  with open(FLAGS.out_file, 'wb') as out_f:
    pickle.dump(pianorolls, out_f)
if __name__ == '__main__':
  # tf.app.run() parses the flags defined above and then calls main().
  tf.app.run()
research/fivo/fivo/data/create_timit_dataset.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocesses TIMIT from raw wavfiles to create a set of TFRecords.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
glob
import
os
import
random
import
re
import
numpy
as
np
import
tensorflow
as
tf
# Command-line flags for the script.
tf.app.flags.DEFINE_string("raw_timit_dir", None,
                           "Directory containing TIMIT files.")
tf.app.flags.DEFINE_string("out_dir", None,
                           "Output directory for TFRecord files.")
tf.app.flags.DEFINE_float("valid_frac", 0.05,
                          "Fraction of train set to use as valid set. "
                          "Must be between 0.0 and 1.0.")
tf.app.flags.mark_flag_as_required("raw_timit_dir")
tf.app.flags.mark_flag_as_required("out_dir")
FLAGS = tf.app.flags.FLAGS

# Expected file counts of the TIMIT distribution, asserted in main().
NUM_TRAIN_FILES = 4620
NUM_TEST_FILES = 1680
# Number of raw audio samples grouped into one model timestep.
SAMPLES_PER_TIMESTEP = 200

# Regexes for reading SPHERE header files.
SAMPLE_COUNT_REGEX = re.compile(r"sample_count -i (\d+)")
SAMPLE_MIN_REGEX = re.compile(r"sample_min -i (-?\d+)")
SAMPLE_MAX_REGEX = re.compile(r"sample_max -i (-?\d+)")
def get_filenames(split):
  """Get all wav filenames from the TIMIT archive.

  Args:
    split: The TIMIT directory to scan, e.g. "TRAIN" or "TEST".

  Returns:
    A deterministically (name-) sorted list of matching .WAV paths.
  """
  pattern = os.path.join(
      FLAGS.raw_timit_dir, "TIMIT", split, "*", "*", "*.WAV")
  # Sorting makes the ordering independent of filesystem enumeration order.
  return sorted(glob.glob(pattern))
def load_timit_wav(filename):
  """Loads a TIMIT wavfile into a numpy array.

  TIMIT wavfiles include a SPHERE header, detailed in the TIMIT docs. The first
  line is the header type and the second is the length of the header in bytes.
  After the header, the remaining bytes are actual WAV data.

  The header includes information about the WAV data such as the number of
  samples and minimum and maximum amplitude. This function asserts that the
  loaded wav data matches the header.

  Args:
    filename: The name of the TIMIT wavfile to load.

  Returns:
    wav: A numpy float32 array containing the loaded wav data.
  """
  # A context manager guarantees the file is closed (the original leaked the
  # handle).
  with open(filename, "rb") as wav_file:
    header_type = wav_file.readline()
    header_length_str = wav_file.readline()
    # The header length includes the length of the first two lines.
    header_remaining_bytes = (int(header_length_str) - len(header_type) -
                              len(header_length_str))
    # Decode so the str-pattern regexes below also work under Python 3, where
    # read() on a binary file returns bytes.
    header = wav_file.read(header_remaining_bytes).decode("ascii", "replace")
    # Read the relevant header fields.
    sample_count = int(SAMPLE_COUNT_REGEX.search(header).group(1))
    sample_min = int(SAMPLE_MIN_REGEX.search(header).group(1))
    sample_max = int(SAMPLE_MAX_REGEX.search(header).group(1))
    # np.frombuffer replaces the deprecated np.fromstring.
    wav = np.frombuffer(wav_file.read(), dtype="int16").astype("float32")
  # Check that the loaded data conforms to the header description.
  assert len(wav) == sample_count
  assert wav.min() == sample_min
  assert wav.max() == sample_max
  return wav
def preprocess(wavs, block_size, mean, std):
  """Normalize the wav data and reshape it into chunks.

  Args:
    wavs: A list of 1-D numpy arrays of raw audio samples.
    block_size: Number of samples per chunk; each wav is zero-padded up to a
      multiple of this before reshaping.
    mean: Normalization offset subtracted from every sample.
    std: Normalization scale dividing every sample.

  Returns:
    A list of 2-D arrays, one per input wav, each of shape [-1, block_size].
  """
  processed = []
  for raw_wav in wavs:
    normalized = (raw_wav - mean) / std
    # Zero-pad the tail so the length divides evenly into blocks.
    remainder = normalized.shape[0] % block_size
    if remainder != 0:
      normalized = np.pad(normalized, (0, block_size - remainder), "constant")
    assert normalized.shape[0] % block_size == 0
    processed.append(normalized.reshape((-1, block_size)))
  return processed
def create_tfrecord_from_wavs(wavs, output_file):
  """Writes processed wav files to disk as sharded TFRecord files.

  Args:
    wavs: A list of chunked wav arrays (see preprocess).
    output_file: Path of the TFRecord file to write.
  """
  with tf.python_io.TFRecordWriter(output_file) as writer:
    for chunked_wav in wavs:
      # Serialize each wav as raw little-endian float32 bytes.
      writer.write(chunked_wav.astype(np.float32).tobytes())
def main(unused_argv):
  """Converts the raw TIMIT archive into normalized train/valid/test TFRecords.

  Carves a validation split out of the training files (deterministic shuffle),
  normalizes all audio with statistics from the training split only, and
  writes one TFRecord file per split into --out_dir.
  """
  train_filenames = get_filenames("TRAIN")
  test_filenames = get_filenames("TEST")
  num_valid_files = int(len(train_filenames) * FLAGS.valid_frac)
  num_train_files = len(train_filenames) - num_valid_files
  num_test_files = len(test_filenames)
  print("%d train / %d valid / %d test" % (num_train_files,
                                           num_valid_files,
                                           num_test_files))
  # Fixed seed => the same validation split on every run.
  random.seed(1234)
  random.shuffle(train_filenames)
  valid_filenames = train_filenames[:num_valid_files]
  train_filenames = train_filenames[num_valid_files:]
  # Make sure there is no overlap in the train, test, and valid sets.
  train_set = set(train_filenames)
  test_set = set(test_filenames)
  valid_set = set(valid_filenames)
  assert not train_set & test_set
  assert not train_set & valid_set
  assert not valid_set & test_set
  train_wavs = [load_timit_wav(name) for name in train_filenames]
  valid_wavs = [load_timit_wav(name) for name in valid_filenames]
  test_wavs = [load_timit_wav(name) for name in test_filenames]
  assert len(train_wavs) + len(valid_wavs) == NUM_TRAIN_FILES
  assert len(test_wavs) == NUM_TEST_FILES
  # Calculate the mean and standard deviation of the train set; the other
  # splits are normalized with these same statistics.
  train_stacked = np.hstack(train_wavs)
  train_mean = np.mean(train_stacked)
  train_std = np.std(train_stacked)
  print("train mean: %f train std: %f" % (train_mean, train_std))
  # Normalize, chunk, and write each split to disk.
  for split_name, split_wavs in [("train", train_wavs),
                                 ("valid", valid_wavs),
                                 ("test", test_wavs)]:
    chunked = preprocess(split_wavs, SAMPLES_PER_TIMESTEP,
                         train_mean, train_std)
    create_tfrecord_from_wavs(chunked,
                              os.path.join(FLAGS.out_dir, split_name))
if __name__ == "__main__":
  # tf.app.run() parses the flags defined above and then calls main().
  tf.app.run()
research/fivo/fivo/data/datasets.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for creating sequence datasets.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
pickle
import
numpy
as
np
from
scipy.sparse
import
coo_matrix
import
tensorflow
as
tf
# The default number of threads used to process data in parallel, passed as
# num_parallel_calls to the tf.data map stages below.
DEFAULT_PARALLELISM = 12
def sparse_pianoroll_to_dense(pianoroll, min_note, num_notes):
  """Converts a sparse pianoroll to a dense numpy array.

  Given a sparse pianoroll, converts it to a dense numpy array of shape
  [num_timesteps, num_notes] where entry i,j is 1.0 if note j is active on
  timestep i and 0.0 otherwise.

  Args:
    pianoroll: A sparse pianoroll object, a list of tuples where the i'th tuple
      contains the indices of the notes active at timestep i.
    min_note: The minimum note in the pianoroll, subtracted from all notes so
      that the minimum note becomes 0.
    num_notes: The number of possible different note indices, determines the
      second dimension of the resulting dense array.

  Returns:
    dense_pianoroll: A [num_timesteps, num_notes] numpy array of floats.
    num_timesteps: A python int, the number of timesteps in the pianoroll.
  """
  num_timesteps = len(pianoroll)
  rows = []
  cols = []
  for timestep, active_notes in enumerate(pianoroll):
    for note in active_notes:
      rows.append(timestep)
      # Re-index the notes to start from min_note.
      cols.append(note - min_note)
  # Build a sparse matrix with a 1.0 at every (timestep, note) pair, then
  # densify it.
  sparse = coo_matrix(([1.] * len(rows), (rows, cols)),
                      shape=[num_timesteps, num_notes])
  return sparse.toarray(), num_timesteps
def create_pianoroll_dataset(path,
                             split,
                             batch_size,
                             num_parallel_calls=DEFAULT_PARALLELISM,
                             shuffle=False,
                             repeat=False,
                             min_note=21,
                             max_note=108):
  """Creates a pianoroll dataset.

  Args:
    path: The path of a pickle file containing the dataset to load.
    split: The split to use, can be train, test, or valid.
    batch_size: The batch size. If repeat is False then it is not guaranteed
      that the true batch size will match for all batches since batch_size
      may not necessarily evenly divide the number of elements.
    num_parallel_calls: The number of threads to use for parallel processing of
      the data.
    shuffle: If true, shuffles the order of the dataset.
    repeat: If true, repeats the dataset endlessly.
    min_note: The minimum note number of the dataset. For all pianoroll datasets
      the minimum note is number 21, and changing this affects the dimension of
      the data. This is useful mostly for testing.
    max_note: The maximum note number of the dataset. For all pianoroll datasets
      the maximum note is number 108, and changing this affects the dimension of
      the data. This is useful mostly for testing.

  Returns:
    inputs: A batch of input sequences represented as a dense Tensor of shape
      [time, batch_size, data_dimension]. The sequences in inputs are the
      sequences in targets shifted one timestep into the future, padded with
      zeros. This tensor is mean-centered, with the mean taken from the pickle
      file key 'train_mean'.
    targets: A batch of target sequences represented as a dense Tensor of
      shape [time, batch_size, data_dimension].
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch.
    mean: A float Tensor of shape [data_dimension] containing the mean loaded
      from the pickle file.
  """
  # Load the data from disk. Pickle data is binary; open in "rb" (text mode
  # "r" breaks under Python 3, where it would yield str instead of bytes).
  num_notes = max_note - min_note + 1
  with tf.gfile.Open(path, "rb") as f:
    raw_data = pickle.load(f)
  pianorolls = raw_data[split]
  mean = raw_data["train_mean"]
  num_examples = len(pianorolls)

  def pianoroll_generator():
    # Densify lazily, one pianoroll at a time; yields (dense_roll, length).
    for sparse_pianoroll in pianorolls:
      yield sparse_pianoroll_to_dense(sparse_pianoroll, min_note, num_notes)

  dataset = tf.data.Dataset.from_generator(
      pianoroll_generator,
      output_types=(tf.float64, tf.int64),
      output_shapes=([None, num_notes], []))

  if repeat:
    dataset = dataset.repeat()
  if shuffle:
    dataset = dataset.shuffle(num_examples)

  # Batch sequences together, padding them to a common length in time.
  dataset = dataset.padded_batch(batch_size,
                                 padded_shapes=([None, num_notes], []))

  def process_pianoroll_batch(data, lengths):
    """Create mean-centered and time-major next-step prediction Tensors."""
    data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
    lengths = tf.to_int32(lengths)
    targets = data
    # Mean center the inputs.
    inputs = data - tf.constant(mean, dtype=tf.float32,
                                shape=[1, 1, mean.shape[0]])
    # Shift the inputs one step forward in time. Also remove the last timestep
    # so that targets and inputs are the same length.
    inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
    # Mask out unused timesteps.
    inputs *= tf.expand_dims(
        tf.transpose(tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
    return inputs, targets, lengths

  dataset = dataset.map(process_pianoroll_batch,
                        num_parallel_calls=num_parallel_calls)
  dataset = dataset.prefetch(num_examples)

  itr = dataset.make_one_shot_iterator()
  inputs, targets, lengths = itr.get_next()
  return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
def create_human_pose_dataset(
    path,
    split,
    batch_size,
    num_parallel_calls=DEFAULT_PARALLELISM,
    shuffle=False,
    repeat=False,):
  """Creates a human pose dataset.

  Args:
    path: The path of a pickle file containing the dataset to load.
    split: The split to use, can be train, test, or valid.
    batch_size: The batch size. If repeat is False then it is not guaranteed
      that the true batch size will match for all batches since batch_size
      may not necessarily evenly divide the number of elements.
    num_parallel_calls: The number of threads to use for parallel processing of
      the data.
    shuffle: If true, shuffles the order of the dataset.
    repeat: If true, repeats the dataset endlessly.

  Returns:
    inputs: A batch of input sequences represented as a dense Tensor of shape
      [time, batch_size, data_dimension]. The sequences in inputs are the
      sequences in targets shifted one timestep into the future, padded with
      zeros. This tensor is mean-centered, with the mean taken from the pickle
      file key 'train_mean'.
    targets: A batch of target sequences represented as a dense Tensor of
      shape [time, batch_size, data_dimension].
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch.
    mean: A float Tensor of shape [data_dimension] containing the mean loaded
      from the pickle file.
  """
  # Load the data from disk. Note: pickled data must be read in binary mode
  # ("rb"); text mode breaks pickle.load under Python 3.
  with tf.gfile.Open(path, "rb") as f:
    raw_data = pickle.load(f)
  mean = raw_data["train_mean"]
  pose_sequences = raw_data[split]
  num_examples = len(pose_sequences)
  # Feature dimension is read off the first sequence; all sequences are
  # assumed to share it.
  num_features = pose_sequences[0].shape[1]

  def pose_generator():
    """A generator that yields pose data sequences."""
    # Each timestep has 32 x values followed by 32 y values so is 64
    # dimensional.
    for pose_sequence in pose_sequences:
      yield pose_sequence, pose_sequence.shape[0]

  dataset = tf.data.Dataset.from_generator(
      pose_generator,
      output_types=(tf.float64, tf.int64),
      output_shapes=([None, num_features], []))

  if repeat:
    dataset = dataset.repeat()
  if shuffle:
    dataset = dataset.shuffle(num_examples)

  # Batch sequences together, padding them to a common length in time.
  dataset = dataset.padded_batch(
      batch_size, padded_shapes=([None, num_features], []))

  # Post-process each batch, ensuring that it is mean-centered and time-major.
  def process_pose_data(data, lengths):
    """Creates Tensors for next step prediction and mean-centers the input."""
    data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
    lengths = tf.to_int32(lengths)
    targets = data
    # Mean center the inputs.
    inputs = data - tf.constant(
        mean, dtype=tf.float32, shape=[1, 1, mean.shape[0]])
    # Shift the inputs one step forward in time. Also remove the last timestep
    # so that targets and inputs are the same length.
    inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
    # Mask out unused timesteps so padding stays exactly zero after the
    # mean-centering above.
    inputs *= tf.expand_dims(
        tf.transpose(tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
    return inputs, targets, lengths

  dataset = dataset.map(
      process_pose_data, num_parallel_calls=num_parallel_calls)
  dataset = dataset.prefetch(num_examples)

  itr = dataset.make_one_shot_iterator()
  inputs, targets, lengths = itr.get_next()
  return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
def create_speech_dataset(path,
                          batch_size,
                          samples_per_timestep=200,
                          num_parallel_calls=DEFAULT_PARALLELISM,
                          prefetch_buffer_size=2048,
                          shuffle=False,
                          repeat=False):
  """Creates a speech dataset.

  Args:
    path: The path of a possibly sharded TFRecord file containing the data.
    batch_size: The batch size. If repeat is False then it is not guaranteed
      that the true batch size will match for all batches since batch_size
      may not necessarily evenly divide the number of elements.
    samples_per_timestep: The number of audio samples per timestep. Used to
      reshape the data into sequences of shape [time, samples_per_timestep].
      Should not change except for testing -- in all speech datasets 200 is
      the number of samples per timestep.
    num_parallel_calls: The number of threads to use for parallel processing
      of the data.
    prefetch_buffer_size: The size of the prefetch queues to use after reading
      and processing the raw data.
    shuffle: If true, shuffles the order of the dataset.
    repeat: If true, repeats the dataset endlessly.

  Returns:
    inputs: A batch of input sequences represented as a dense Tensor of shape
      [time, batch_size, samples_per_timestep]. The sequences in inputs are
      the sequences in targets shifted one timestep into the future, padded
      with zeros.
    targets: A batch of target sequences represented as a dense Tensor of
      shape [time, batch_size, samples_per_timestep].
    lens: An int Tensor of shape [batch_size] representing the lengths of
      each sequence in the batch.
  """

  def _decode_record(value):
    """Parses a single tf.Example from the TFRecord file."""
    samples = tf.decode_raw(value, out_type=tf.float32)
    sequence = tf.reshape(samples, [-1, samples_per_timestep])
    return sequence, tf.shape(sequence)[0]

  # Read the raw records and decode each one into a (sequence, length) pair.
  dataset = tf.data.TFRecordDataset([path]).map(
      _decode_record, num_parallel_calls=num_parallel_calls)
  dataset = dataset.prefetch(prefetch_buffer_size)
  if repeat:
    dataset = dataset.repeat()
  if shuffle:
    dataset = dataset.shuffle(prefetch_buffer_size)
  # Batch sequences together, zero-padding them to a common length in time.
  dataset = dataset.padded_batch(
      batch_size, padded_shapes=([None, samples_per_timestep], []))

  def _make_next_step_tensors(batch, batch_lengths):
    """Creates Tensors for next step prediction."""
    time_major = tf.transpose(batch, perm=[1, 0, 2])
    batch_lengths = tf.to_int32(batch_lengths)
    # Inputs are the targets delayed by one step: prepend a zero frame and
    # drop the final frame so both Tensors cover the same timesteps.
    shifted = tf.pad(
        time_major, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
    # Zero out frames past each sequence's true length.
    valid = tf.sequence_mask(batch_lengths, dtype=shifted.dtype)
    shifted *= tf.expand_dims(tf.transpose(valid), 2)
    return shifted, time_major, batch_lengths

  dataset = dataset.map(
      _make_next_step_tensors, num_parallel_calls=num_parallel_calls)
  dataset = dataset.prefetch(prefetch_buffer_size)
  inputs, targets, lengths = dataset.make_one_shot_iterator().get_next()
  return inputs, targets, lengths
# Observation-model identifiers for the toy chain-graph dataset: the
# observation density is centered at the squared latent state, its absolute
# value, or the latent state unchanged.
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]

# Transition-model identifiers: the transition density is centered at the
# rounded previous latent state or at the previous latent state unchanged.
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
def create_chain_graph_dataset(
    batch_size,
    num_timesteps,
    steps_per_observation=None,
    state_size=1,
    transition_variance=1.,
    observation_variance=1.,
    transition_type=STANDARD_TRANSITION,
    observation_type=STANDARD_OBSERVATION,
    fixed_observation=None,
    prefetch_buffer_size=2048,
    dtype="float32"):
  """Creates a toy chain graph dataset.

  Creates a dataset where the data are sampled from a diffusion process. The
  'latent' states of the process are sampled as a chain of Normals:

    z0 ~ N(0, transition_variance)
    z1 ~ N(transition_fn(z0), transition_variance)
    ...

  where transition_fn could round z0 or pass it through unchanged.

  The observations are produced every steps_per_observation timesteps as a
  function of the latent zs. For example if steps_per_observation is 3 then
  the first observation will be produced as a function of z3:

    x1 ~ N(observation_fn(z3), observation_variance)

  where observation_fn could square z3, take the absolute value, or pass it
  through unchanged. Only the observations are returned.

  Args:
    batch_size: The batch size. The number of trajectories to run in parallel.
    num_timesteps: The length of the chain of latent states (i.e. the number
      of z's excluding z0).
    steps_per_observation: The number of latent states between each
      observation, must evenly divide num_timesteps.
    state_size: The size of the latent state and observation, must be a
      python int.
    transition_variance: The variance of the transition density.
    observation_variance: The variance of the observation density.
    transition_type: Must be one of "round" or "standard". "round" means that
      the transition density is centered at the rounded previous latent state.
      "standard" centers the transition density at the previous latent state,
      unchanged.
    observation_type: Must be one of "squared", "abs" or "standard". "squared"
      centers the observation density at the squared latent state. "abs"
      centers the observation density at the absolute value of the current
      latent state. "standard" centers the observation density at the current
      latent state.
    fixed_observation: If not None, fixes all observations to be a constant.
      Must be a scalar.
    prefetch_buffer_size: The size of the prefetch queues to use after reading
      and processing the raw data.
    dtype: A string convertible to a tensorflow datatype. The datatype used
      to represent the states and observations.

  Returns:
    observations: A batch of observations represented as a dense Tensor of
      shape [num_observations, batch_size, state_size]. num_observations is
      num_timesteps/steps_per_observation.
    lens: An int Tensor of shape [batch_size] representing the lengths of
      each sequence in the batch. Will contain num_observations as each entry.

  Raises:
    ValueError: Raised if steps_per_observation does not evenly divide
      num_timesteps.
  """
  if steps_per_observation is None:
    steps_per_observation = num_timesteps
  if num_timesteps % steps_per_observation != 0:
    raise ValueError("steps_per_observation must evenly divide num_timesteps.")
  num_observations = int(num_timesteps / steps_per_observation)

  def data_generator():
    """An infinite generator of latents and observations from the model."""
    transition_std = np.sqrt(transition_variance)
    observation_std = np.sqrt(observation_variance)
    while True:
      states = []
      observations = []
      # Sample z0 ~ Normal(0, sqrt(transition_variance)). This previously used
      # observation_std, contradicting the docstring; with the default equal
      # variances the two are identical.
      states.append(
          np.random.normal(size=[state_size],
                           scale=transition_std).astype(dtype))
      # Start the range at 1 because we've already generated z0.
      # The range ends at num_timesteps+1 because we want to include the
      # num_timesteps-th step.
      for t in range(1, num_timesteps + 1):
        if transition_type == ROUND_TRANSITION:
          loc = np.round(states[-1])
        elif transition_type == STANDARD_TRANSITION:
          loc = states[-1]
        z_t = np.random.normal(
            size=[state_size], loc=loc, scale=transition_std)
        states.append(z_t.astype(dtype))
        if t % steps_per_observation == 0:
          if fixed_observation is None:
            if observation_type == SQUARED_OBSERVATION:
              loc = np.square(states[-1])
            elif observation_type == ABS_OBSERVATION:
              loc = np.abs(states[-1])
            elif observation_type == STANDARD_OBSERVATION:
              loc = states[-1]
            x_t = np.random.normal(size=[state_size],
                                   loc=loc,
                                   scale=observation_std).astype(dtype)
          else:
            # Cast for consistency with the sampled branch above.
            x_t = (np.ones([state_size]) * fixed_observation).astype(dtype)
          observations.append(x_t)
      yield states, observations

  dataset = tf.data.Dataset.from_generator(
      data_generator,
      output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
      output_shapes=([num_timesteps + 1, state_size],
                     [num_observations, state_size])
  )
  dataset = dataset.repeat().batch(batch_size)
  dataset = dataset.prefetch(prefetch_buffer_size)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Transpose observations from [batch, time, state_size] to
  # [time, batch, state_size].
  observations = tf.transpose(observations, perm=[1, 0, 2])
  # Every sequence has exactly num_observations valid steps.
  lengths = tf.ones([batch_size], dtype=tf.int32) * num_observations
  return observations, lengths
research/fivo/fivo/data/datasets_test.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.data.datasets."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
pickle
import
os
import
numpy
as
np
import
tensorflow
as
tf
from
fivo.data
import
datasets
FLAGS
=
tf
.
app
.
flags
.
FLAGS
class DatasetsTest(tf.test.TestCase):
  """Tests for the dataset-construction helpers in fivo.data.datasets."""

  def test_sparse_pianoroll_to_dense_empty_at_end(self):
    # Trailing empty chords must still count toward num_timesteps.
    sparse_pianoroll = [(0, 1), (1, 0), (), (1,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1],
                         [0, 0],
                         [0, 0]],
                        dense_pianoroll)

  def test_sparse_pianoroll_to_dense_with_chord(self):
    # Multi-note chords set multiple columns in a single timestep.
    sparse_pianoroll = [(0, 1), (1, 0), (), (1,)]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 4)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1]],
                        dense_pianoroll)

  def test_sparse_pianoroll_to_dense_simple(self):
    sparse_pianoroll = [(0,), (), (1,)]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 3)
    self.assertAllEqual([[1, 0],
                         [0, 0],
                         [0, 1]],
                        dense_pianoroll)

  def test_sparse_pianoroll_to_dense_subtracts_min_note(self):
    # Note numbers are shifted down by min_note before indexing columns.
    sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=4, num_notes=2)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1],
                         [0, 0],
                         [0, 0]],
                        dense_pianoroll)

  def test_sparse_pianoroll_to_dense_uses_num_notes(self):
    # num_notes wider than the used range leaves extra zero columns.
    sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=4, num_notes=3)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1, 0],
                         [1, 1, 0],
                         [0, 0, 0],
                         [0, 1, 0],
                         [0, 0, 0],
                         [0, 0, 0]],
                        dense_pianoroll)

  def test_pianoroll_dataset(self):
    # Three sequences of lengths 3, 2, and 6; batch size 2 yields one batch
    # of two sequences followed by a batch containing only the third.
    pianoroll_data = [[(0,), (), (1,)],
                      [(0, 1), (1,)],
                      [(1,), (0,), (), (0, 1), (), ()]]
    pianoroll_mean = np.zeros([3])
    pianoroll_mean[-1] = 1
    data = {"train": pianoroll_data, "train_mean": pianoroll_mean}
    path = os.path.join(tf.test.get_temp_dir(), "test.pkl")
    pickle.dump(data, open(path, "wb"))
    with self.test_session() as sess:
      inputs, targets, lens, mean = datasets.create_pianoroll_dataset(
          path, "train", 2, num_parallel_calls=1,
          shuffle=False, repeat=False,
          min_note=0, max_note=2)
      i1, t1, l1 = sess.run([inputs, targets, lens])
      i2, t2, l2 = sess.run([inputs, targets, lens])
      m = sess.run(mean)
      # Check the lengths.
      self.assertAllEqual([3, 2], l1)
      self.assertAllEqual([6], l2)
      # Check the mean.
      self.assertAllEqual(pianoroll_mean, m)
      # Check the targets. The targets should not be mean-centered and should
      # be padded with zeros to a common length within a batch.
      self.assertAllEqual([[1, 0, 0],
                           [0, 0, 0],
                           [0, 1, 0]],
                          t1[:, 0, :])
      self.assertAllEqual([[1, 1, 0],
                           [0, 1, 0],
                           [0, 0, 0]],
                          t1[:, 1, :])
      self.assertAllEqual([[0, 1, 0],
                           [1, 0, 0],
                           [0, 0, 0],
                           [1, 1, 0],
                           [0, 0, 0],
                           [0, 0, 0]],
                          t2[:, 0, :])
      # Check the inputs. Each sequence should start with zeros on the first
      # timestep. Each sequence should be padded with zeros to a common length
      # within a batch. The mean should be subtracted from all timesteps except
      # the first and the padding.
      self.assertAllEqual([[0, 0, 0],
                           [1, 0, -1],
                           [0, 0, -1]],
                          i1[:, 0, :])
      self.assertAllEqual([[0, 0, 0],
                           [1, 1, -1],
                           [0, 0, 0]],
                          i1[:, 1, :])
      self.assertAllEqual([[0, 0, 0],
                           [0, 1, -1],
                           [1, 0, -1],
                           [0, 0, -1],
                           [1, 1, -1],
                           [0, 0, -1]],
                          i2[:, 0, :])

  def test_human_pose_dataset(self):
    # Sequences of lengths 2, 1, and 5 with a nonzero mean to verify
    # mean-centering of the inputs.
    pose_data = [
        [[0, 0], [2, 2]],
        [[2, 2]],
        [[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]],
    ]
    pose_data = [np.array(x, dtype=np.float64) for x in pose_data]
    pose_data_mean = np.array([1, 1], dtype=np.float64)
    data = {
        "train": pose_data,
        "train_mean": pose_data_mean,
    }
    path = os.path.join(tf.test.get_temp_dir(),
                        "test_human_pose_dataset.pkl")
    with open(path, "wb") as out:
      pickle.dump(data, out)
    with self.test_session() as sess:
      inputs, targets, lens, mean = datasets.create_human_pose_dataset(
          path, "train", 2, num_parallel_calls=1, shuffle=False, repeat=False)
      i1, t1, l1 = sess.run([inputs, targets, lens])
      i2, t2, l2 = sess.run([inputs, targets, lens])
      m = sess.run(mean)
      # Check the lengths.
      self.assertAllEqual([2, 1], l1)
      self.assertAllEqual([5], l2)
      # Check the mean.
      self.assertAllEqual(pose_data_mean, m)
      # Check the targets. The targets should not be mean-centered and should
      # be padded with zeros to a common length within a batch.
      self.assertAllEqual([[0, 0], [2, 2]], t1[:, 0, :])
      self.assertAllEqual([[2, 2], [0, 0]], t1[:, 1, :])
      self.assertAllEqual([[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]],
                          t2[:, 0, :])
      # Check the inputs. Each sequence should start with zeros on the first
      # timestep. Each sequence should be padded with zeros to a common length
      # within a batch. The mean should be subtracted from all timesteps except
      # the first and the padding.
      self.assertAllEqual([[0, 0], [-1, -1]], i1[:, 0, :])
      self.assertAllEqual([[0, 0], [0, 0]], i1[:, 1, :])
      self.assertAllEqual([[0, 0], [-1, -1], [-1, -1], [1, 1], [1, 1]],
                          i2[:, 0, :])

  def test_speech_dataset(self):
    with self.test_session() as sess:
      # The fixture lives alongside the package under test_data/.
      path = os.path.join(
          os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
          "test_data",
          "tiny_speech_dataset.tfrecord")
      inputs, targets, lens = datasets.create_speech_dataset(
          path, 3, samples_per_timestep=2, num_parallel_calls=1,
          prefetch_buffer_size=3, shuffle=False, repeat=False)
      inputs1, targets1, lengths1 = sess.run([inputs, targets, lens])
      inputs2, targets2, lengths2 = sess.run([inputs, targets, lens])
      # Check the lengths.
      self.assertAllEqual([1, 2, 3], lengths1)
      self.assertAllEqual([4], lengths2)
      # Check the targets. The targets should be padded with zeros to a common
      # length within a batch.
      self.assertAllEqual([[[0., 1.], [0., 1.], [0., 1.]],
                           [[0., 0.], [2., 3.], [2., 3.]],
                           [[0., 0.], [0., 0.], [4., 5.]]],
                          targets1)
      self.assertAllEqual([[[0., 1.]],
                           [[2., 3.]],
                           [[4., 5.]],
                           [[6., 7.]]],
                          targets2)
      # Check the inputs. Each sequence should start with zeros on the first
      # timestep. Each sequence should be padded with zeros to a common length
      # within a batch.
      self.assertAllEqual([[[0., 0.], [0., 0.], [0., 0.]],
                           [[0., 0.], [0., 1.], [0., 1.]],
                           [[0., 0.], [0., 0.], [2., 3.]]],
                          inputs1)
      self.assertAllEqual([[[0., 0.]],
                           [[0., 1.]],
                           [[2., 3.]],
                           [[4., 5.]]],
                          inputs2)

  def test_chain_graph_raises_error_on_wrong_steps_per_observation(self):
    # 9 does not evenly divide 10, so construction must fail.
    with self.assertRaises(ValueError):
      datasets.create_chain_graph_dataset(
          batch_size=4,
          num_timesteps=10,
          steps_per_observation=9)

  def test_chain_graph_single_obs(self):
    with self.test_session() as sess:
      # Fixed seed makes the sampled observations reproducible.
      np.random.seed(1234)
      num_observations = 1
      num_timesteps = 5
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          [[[1.426677], [-1.789461]]],
          out_observations)

  def test_chain_graph_multiple_obs(self):
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 3
      num_timesteps = 6
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          steps_per_observation=num_timesteps/num_observations,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          [[[0.40051451], [1.07405114]],
           [[1.73932898], [3.16880035]],
           [[-1.98377144], [2.82669163]]],
          out_observations)

  def test_chain_graph_state_dims(self):
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 1
      num_timesteps = 5
      batch_size = 2
      state_size = 3
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          [[[1.052287, -4.560759, 3.07988],
            [2.008926, 0.495567, 3.488678]]],
          out_observations)

  def test_chain_graph_fixed_obs(self):
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 3
      num_timesteps = 6
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          steps_per_observation=num_timesteps/num_observations,
          state_size=state_size,
          fixed_observation=4.)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      # All observations are pinned to the constant 4.
      self.assertAllClose(
          np.ones([num_observations, batch_size, state_size]) * 4.,
          out_observations)
# Run the test suite when this module is executed directly.
if __name__ == "__main__":
  tf.test.main()
research/fivo/fivo/ghmm_runners.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates and runs Gaussian HMM-related graphs."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
tensorflow
as
tf
from
fivo
import
smc
from
fivo
import
bounds
from
fivo.data
import
datasets
from
fivo.models
import
ghmm
def run_train(config):
  """Runs training for a Gaussian HMM setup.

  Builds a chain-graph dataset, a trainable Gaussian HMM, and a lower bound
  (elbo, iwae, or fivo per config.bound), then optimizes the negative bound
  with Adam inside a MonitoredTrainingSession until config.max_steps.

  Args:
    config: A configuration object/namespace with (at least) the attributes
      read below: bound, num_samples, resampling_type,
      relaxed_resampling_temperature, random_seed, parallel_iterations,
      batch_size, num_timesteps, variance, proposal_type, learning_rate,
      logdir, summarize_every, max_steps.
  """

  def create_logging_hook(step, bound_value, likelihood, bound_gap):
    """Creates a logging hook that prints the bound value periodically."""
    bound_label = config.bound + "/t"
    def summary_formatter(log_dict):
      # %-format the label in first, then .format the tensor values.
      string = ("Step {step}, %s: {value:.3f}, "
                "likelihood: {ll:.3f}, gap: {gap:.3e}") % bound_label
      return string.format(**log_dict)
    logging_hook = tf.train.LoggingTensorHook(
        {"step": step, "value": bound_value,
         "ll": likelihood, "gap": bound_gap},
        every_n_iter=config.summarize_every,
        formatter=summary_formatter)
    return logging_hook

  def create_losses(model, observations, lengths):
    """Creates the loss to be optimized.

    Args:
      model: A Trainable GHMM model.
      observations: A set of observations.
      lengths: The lengths of each sequence in the observations.
    Returns:
      loss: A float Tensor that when differentiated yields the gradients
        to apply to the model. Should be optimized via gradient descent.
      bound: A float Tensor containing the value of the bound that is
        being optimized.
      true_ll: The true log-likelihood of the data under the model.
      bound_gap: The gap between the bound and the true log-likelihood.
    """
    # Compute lower bounds on the log likelihood.
    if config.bound == "elbo":
      # The ELBO is IWAE with a single sample.
      ll_per_seq, _, _ = bounds.iwae(
          model, observations, lengths, num_samples=1,
          parallel_iterations=config.parallel_iterations)
    elif config.bound == "iwae":
      ll_per_seq, _, _ = bounds.iwae(
          model, observations, lengths, num_samples=config.num_samples,
          parallel_iterations=config.parallel_iterations)
    elif config.bound == "fivo":
      if config.resampling_type == "relaxed":
        # Relaxed resampling takes an extra temperature argument.
        ll_per_seq, _, _, _ = bounds.fivo(
            model, observations, lengths, num_samples=config.num_samples,
            resampling_criterion=smc.ess_criterion,
            resampling_type=config.resampling_type,
            relaxed_resampling_temperature=config.
            relaxed_resampling_temperature,
            random_seed=config.random_seed,
            parallel_iterations=config.parallel_iterations)
      else:
        ll_per_seq, _, _, _ = bounds.fivo(
            model, observations, lengths, num_samples=config.num_samples,
            resampling_criterion=smc.ess_criterion,
            resampling_type=config.resampling_type,
            random_seed=config.random_seed,
            parallel_iterations=config.parallel_iterations)
    # Normalize by sequence length so bounds are comparable across lengths.
    ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
    # Compute the data's true likelihood under the model and the bound gap.
    true_ll_per_seq = model.likelihood(tf.squeeze(observations))
    true_ll_per_t = tf.reduce_mean(true_ll_per_seq / tf.to_float(lengths))
    bound_gap = true_ll_per_seq - ll_per_seq
    bound_gap = tf.reduce_mean(bound_gap / tf.to_float(lengths))
    tf.summary.scalar("train_ll_bound", ll_per_t)
    tf.summary.scalar("train_true_ll", true_ll_per_t)
    tf.summary.scalar("bound_gap", bound_gap)
    # The loss is the negated bound (gradient descent maximizes the bound).
    return -ll_per_t, ll_per_t, true_ll_per_t, bound_gap

  def create_graph():
    """Creates the training graph."""
    global_step = tf.train.get_or_create_global_step()
    xs, lengths = datasets.create_chain_graph_dataset(
        config.batch_size,
        config.num_timesteps,
        steps_per_observation=1,
        state_size=1,
        transition_variance=config.variance,
        observation_variance=config.variance)
    model = ghmm.TrainableGaussianHMM(
        config.num_timesteps,
        config.proposal_type,
        transition_variances=config.variance,
        emission_variances=config.variance,
        random_seed=config.random_seed)
    loss, bound, true_ll, gap = create_losses(model, xs, lengths)
    opt = tf.train.AdamOptimizer(config.learning_rate)
    grads = opt.compute_gradients(loss, var_list=tf.trainable_variables())
    train_op = opt.apply_gradients(grads, global_step=global_step)
    return bound, true_ll, gap, train_op, global_step

  with tf.Graph().as_default():
    if config.random_seed:
      # Seed both TF and numpy so dataset sampling is reproducible too.
      tf.set_random_seed(config.random_seed)
      np.random.seed(config.random_seed)
    bound, true_ll, gap, train_op, global_step = create_graph()
    log_hook = create_logging_hook(global_step, bound, true_ll, gap)
    with tf.train.MonitoredTrainingSession(
        master="",
        hooks=[log_hook],
        checkpoint_dir=config.logdir,
        save_checkpoint_secs=120,
        save_summaries_steps=config.summarize_every,
        log_step_count_steps=config.summarize_every * 20) as sess:
      cur_step = -1
      while cur_step <= config.max_steps and not sess.should_stop():
        cur_step = sess.run(global_step)
        _, cur_step = sess.run([train_op, global_step])
def run_eval(config):
  """Evaluates a Gaussian HMM using the given config.

  Restores the latest checkpoint from config.logdir, computes the configured
  bound (elbo, iwae, or fivo) together with the exact likelihood, logs summary
  statistics, and saves the raw results to <logdir>/out.npz.

  Args:
    config: A configuration object/namespace with (at least) the attributes
      read below: bound, num_samples, resampling_type, random_seed,
      parallel_iterations, batch_size, num_timesteps, variance,
      proposal_type, logdir.
  """

  def create_bound(model, xs, lengths):
    """Creates the bound to be evaluated."""
    if config.bound == "elbo":
      # The ELBO is IWAE with a single sample.
      ll_per_seq, log_weights, _ = bounds.iwae(
          model, xs, lengths, num_samples=1,
          parallel_iterations=config.parallel_iterations)
    elif config.bound == "iwae":
      ll_per_seq, log_weights, _ = bounds.iwae(
          model, xs, lengths, num_samples=config.num_samples,
          parallel_iterations=config.parallel_iterations)
    elif config.bound == "fivo":
      ll_per_seq, log_weights, resampled, _ = bounds.fivo(
          model, xs, lengths, num_samples=config.num_samples,
          resampling_criterion=smc.ess_criterion,
          resampling_type=config.resampling_type,
          random_seed=config.random_seed,
          parallel_iterations=config.parallel_iterations)
    # Compute bound scaled by number of timesteps.
    bound_per_t = ll_per_seq / tf.to_float(lengths)
    # Only FIVO produces resampling indicators.
    if config.bound == "fivo":
      return bound_per_t, log_weights, resampled
    else:
      return bound_per_t, log_weights

  def create_graph():
    """Creates the dataset, model, and bound."""
    xs, lengths = datasets.create_chain_graph_dataset(
        config.batch_size,
        config.num_timesteps,
        steps_per_observation=1,
        state_size=1,
        transition_variance=config.variance,
        observation_variance=config.variance)
    model = ghmm.TrainableGaussianHMM(
        config.num_timesteps,
        config.proposal_type,
        transition_variances=config.variance,
        emission_variances=config.variance,
        random_seed=config.random_seed)
    true_likelihood = tf.reduce_mean(
        model.likelihood(tf.squeeze(xs)) / tf.to_float(lengths))
    outs = [true_likelihood]
    outs.extend(list(create_bound(model, xs, lengths)))
    return outs

  with tf.Graph().as_default():
    if config.random_seed:
      # Seed both TF and numpy so dataset sampling is reproducible too.
      tf.set_random_seed(config.random_seed)
      np.random.seed(config.random_seed)
    graph_outs = create_graph()
    with tf.train.SingularMonitoredSession(
        checkpoint_dir=config.logdir) as sess:
      outs = sess.run(graph_outs)
      likelihood = outs[0]
      avg_bound = np.mean(outs[1])
      std = np.std(outs[1])
      log_weights = outs[2]
      log_weight_variances = np.var(log_weights, axis=2)
      # NOTE(review): despite the "avg_" name this takes np.var of the
      # per-step variances, not np.mean — possibly intended to be np.mean;
      # preserved as-is to keep reported numbers unchanged. Confirm.
      avg_log_weight_variance = np.var(log_weight_variances, axis=1)
      avg_log_weight = np.mean(log_weights, axis=(1, 2))
      data = {"mean": avg_bound, "std": std, "log_weights": log_weights,
              "log_weight_means": avg_log_weight,
              "log_weight_variances": avg_log_weight_variance}
      if len(outs) == 4:
        # FIVO also returns per-timestep resampling indicators.
        data["resampled"] = outs[3]
        data["avg_resampled"] = np.mean(outs[3], axis=1)
      # Log some useful statistics.
      tf.logging.info("Evaled bound %s with batch_size: %d, num_samples: %d."
                      % (config.bound, config.batch_size, config.num_samples))
      tf.logging.info("mean: %f, std: %f" % (avg_bound, std))
      tf.logging.info("true likelihood: %s" % likelihood)
      tf.logging.info("avg log weight: %s" % avg_log_weight)
      tf.logging.info("log weight variance: %s" % avg_log_weight_variance)
      if len(outs) == 4:
        tf.logging.info("avg resamples per t: %s" % data["avg_resampled"])
      if not tf.gfile.Exists(config.logdir):
        tf.gfile.MakeDirs(config.logdir)
      # np.save writes binary data, so the file must be opened in "wb" mode
      # (text mode breaks under Python 3).
      with tf.gfile.Open(os.path.join(config.logdir, "out.npz"),
                         "wb") as fout:
        np.save(fout, data)
research/fivo/fivo/ghmm_runners_test.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.ghmm_runners."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
tensorflow
as
tf
from
fivo
import
ghmm_runners
class GHMMRunnersTest(tf.test.TestCase):
  """Integration tests for training and evaluating the Gaussian HMM runners."""

  def default_config(self):
    """Builds a config object populated with default runner settings."""
    class Config(object):
      pass

    config = Config()
    settings = {
        "model": "ghmm",
        "bound": "fivo",
        "proposal_type": "prior",
        "batch_size": 4,
        "num_samples": 4,
        "num_timesteps": 10,
        "variance": 0.1,
        "resampling_type": "multinomial",
        "random_seed": 1234,
        "parallel_iterations": 1,
        "learning_rate": 1e-4,
        "summarize_every": 1,
        "max_steps": 1,
    }
    for attr, value in settings.items():
      setattr(config, attr, value)
    return config

  def test_eval_ghmm_notraining_fivo_prior(self):
    self.eval_ghmm_notraining("fivo", "prior", -3.063864)

  def test_eval_ghmm_notraining_fivo_true_filtering(self):
    self.eval_ghmm_notraining("fivo", "true-filtering", -1.1409812)

  def test_eval_ghmm_notraining_fivo_true_smoothing(self):
    self.eval_ghmm_notraining("fivo", "true-smoothing", -0.85592091)

  def test_eval_ghmm_notraining_iwae_prior(self):
    self.eval_ghmm_notraining("iwae", "prior", -5.9730167)

  def test_eval_ghmm_notraining_iwae_true_filtering(self):
    self.eval_ghmm_notraining("iwae", "true-filtering", -1.1485999)

  def test_eval_ghmm_notraining_iwae_true_smoothing(self):
    self.eval_ghmm_notraining("iwae", "true-smoothing", -0.85592091)

  def eval_ghmm_notraining(self, bound, proposal_type, expected_bound_avg):
    """Runs eval from a fresh model and checks the reported bound average."""
    config = self.default_config()
    config.proposal_type = proposal_type
    config.bound = bound
    subdir = "test-ghmm-%s-%s" % (proposal_type, bound)
    config.logdir = os.path.join(tf.test.get_temp_dir(), subdir)

    ghmm_runners.run_eval(config)

    out_path = os.path.join(config.logdir, "out.npz")
    data = np.load(out_path).item()
    self.assertAlmostEqual(expected_bound_avg, data["mean"], places=3)

  def test_train_ghmm_for_one_step_and_eval_fivo_filtering(self):
    self.train_ghmm_for_one_step_and_eval("fivo", "filtering", -16.727108)

  def test_train_ghmm_for_one_step_and_eval_fivo_smoothing(self):
    self.train_ghmm_for_one_step_and_eval("fivo", "smoothing", -19.381277)

  def test_train_ghmm_for_one_step_and_eval_iwae_filtering(self):
    self.train_ghmm_for_one_step_and_eval("iwae", "filtering", -33.31966)

  def test_train_ghmm_for_one_step_and_eval_iwae_smoothing(self):
    self.train_ghmm_for_one_step_and_eval("iwae", "smoothing", -46.388447)

  def train_ghmm_for_one_step_and_eval(self, bound, proposal_type,
                                       expected_bound_avg):
    """Trains for a single step, then evals and checks the bound average."""
    config = self.default_config()
    config.proposal_type = proposal_type
    config.bound = bound
    config.max_steps = 1
    subdir = "test-ghmm-training-%s-%s" % (proposal_type, bound)
    config.logdir = os.path.join(tf.test.get_temp_dir(), subdir)

    ghmm_runners.run_train(config)
    ghmm_runners.run_eval(config)

    out_path = os.path.join(config.logdir, "out.npz")
    data = np.load(out_path).item()
    self.assertAlmostEqual(expected_bound_avg, data["mean"], places=2)
# Script entry point: delegate to the TensorFlow test runner, which discovers
# and executes the test_* methods defined above.
if __name__ == "__main__":
  tf.test.main()
research/fivo/fivo/models/__init__.py
deleted
100644 → 0
View file @
5eb294f8
research/fivo/fivo/models/base.py
deleted
100644 → 0
View file @
5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reusable model classes for FIVO."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
sonnet
as
snt
import
tensorflow
as
tf
from
fivo
import
nested_utils
as
nested
# Shorthand alias for the TF contrib distributions namespace, used as a
# default argument value later in this module.
tfd = tf.contrib.distributions
class ELBOTrainableSequenceModel(object):
  """Abstract base class for sequence models trainable via the ELBO.

  The ELBO, IWAE, and FIVO bounds all accept the same arguments, so any
  model that is ELBO-trainable is automatically IWAE- and FIVO-trainable.
  """

  def zero_state(self, batch_size, dtype):
    """Returns the model's initial state as a Tensor or tuple of Tensors.

    Args:
      batch_size: The batch size.
      dtype: The datatype to use for the state.
    """
    raise NotImplementedError("zero_state not yet implemented.")

  def set_observations(self, observations, seq_lengths):
    """Stores the observed variables (inputs and targets) on the model.

    This is called before any computation that needs the observations, such
    as training the model or computing bounds, and is the place for any
    required preprocessing.

    Args:
      observations: A potentially nested set of Tensors containing all
        observations for the model, both inputs and targets. Typically a set
        of Tensors with shape [max_seq_len, batch_size, data_size].
      seq_lengths: A [batch_size] Tensor of ints encoding the length of each
        sequence in the batch (sequences can be padded to a common length).
    """
    max_len = tf.reduce_max(seq_lengths)
    self.observations = observations
    self.max_seq_len = max_len
    # Cache the observations as TensorArrays for per-timestep reads.
    self.observations_ta = nested.tas_for_tensors(
        observations, max_len, clear_after_read=False)
    self.seq_lengths = seq_lengths

  def propose_and_weight(self, state, t):
    """Advances the model one timestep and computes incremental log weights.

    Args:
      state: The current state of the model.
      t: A scalar integer Tensor representing the current timestep.

    Returns:
      next_state: The state of the model after one timestep.
      log_weights: A [batch_size] Tensor containing the incremental log
        weights.
    """
    raise NotImplementedError("propose_and_weight not yet implemented.")
# Default variable initializers for the snt.nets.MLP networks below:
# Xavier initialization for the weights ("w") and zeros for the biases ("b").
DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
                        "b": tf.zeros_initializer()}
class ConditionalNormalDistribution(object):
  """A Normal distribution whose parameters are a fc network of its inputs."""

  def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
               raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
               initializers=None, name="conditional_normal_distribution"):
    """Creates a conditional Normal distribution.

    Args:
      size: The dimension of the random variable.
      hidden_layer_sizes: The sizes of the hidden layers of the fully
        connected network used to condition the distribution on the inputs.
      sigma_min: The minimum standard deviation allowed, a scalar.
      raw_sigma_bias: A scalar added to the raw standard deviation output of
        the fully connected network; defaults to 0.25 to keep standard
        deviations away from 0.
      hidden_activation_fn: The activation function for the hidden layers of
        the fully connected network.
      initializers: The variable initializers for the fully connected
        network, a dict mapping 'w' and 'b' to weight and bias initializers
        (the network is an snt.nets.MLP). When None, defaults to Xavier for
        weights and zeros for biases.
      name: The name of this distribution, used for sonnet scoping.
    """
    self.size = size
    self.name = name
    self.sigma_min = sigma_min
    self.raw_sigma_bias = raw_sigma_bias
    if initializers is None:
      initializers = DEFAULT_INITIALIZERS
    # The final layer emits 2 * size units: one half parameterizes the mean,
    # the other half the (pre-softplus) standard deviation.
    self.fcnet = snt.nets.MLP(
        output_sizes=hidden_layer_sizes + [2 * size],
        activation=hidden_activation_fn,
        initializers=initializers,
        activate_final=False,
        use_bias=True,
        name=name + "_fcnet")

  def condition(self, tensor_list, **unused_kwargs):
    """Computes the Normal's mean and standard deviation from the inputs."""
    net_inputs = tf.concat(tensor_list, axis=1)
    net_outputs = self.fcnet(net_inputs)
    mu, raw_sigma = tf.split(net_outputs, 2, axis=1)
    # Softplus with a bias keeps sigma positive; clamp to the allowed floor.
    sigma = tf.maximum(
        tf.nn.softplus(raw_sigma + self.raw_sigma_bias), self.sigma_min)
    return mu, sigma

  def __call__(self, *args, **kwargs):
    """Builds a Normal distribution conditioned on the inputs."""
    mu, sigma = self.condition(args, **kwargs)
    return tf.contrib.distributions.Normal(loc=mu, scale=sigma)
class ConditionalBernoulliDistribution(object):
  """A Bernoulli distribution whose logits are a fc network of its inputs."""

  def __init__(self, size, hidden_layer_sizes,
               hidden_activation_fn=tf.nn.relu, initializers=None,
               bias_init=0.0, name="conditional_bernoulli_distribution"):
    """Creates a conditional Bernoulli distribution.

    Args:
      size: The dimension of the random variable.
      hidden_layer_sizes: The sizes of the hidden layers of the fully
        connected network used to condition the distribution on the inputs.
      hidden_activation_fn: The activation function for the hidden layers of
        the fully connected network.
      initializers: The variable initializers for the fully connected
        network, a dict mapping 'w' and 'b' to weight and bias initializers
        (the network is an snt.nets.MLP). When None, defaults to Xavier for
        weights and zeros for biases.
      bias_init: A scalar or vector Tensor added to the network output that
        parameterizes this distribution's mean.
      name: The name of this distribution, used for sonnet scoping.
    """
    self.size = size
    self.bias_init = bias_init
    if initializers is None:
      initializers = DEFAULT_INITIALIZERS
    self.fcnet = snt.nets.MLP(
        output_sizes=hidden_layer_sizes + [size],
        activation=hidden_activation_fn,
        initializers=initializers,
        activate_final=False,
        use_bias=True,
        name=name + "_fcnet")

  def condition(self, tensor_list):
    """Computes the p parameter of the Bernoulli distribution."""
    net_inputs = tf.concat(tensor_list, axis=1)
    return self.fcnet(net_inputs) + self.bias_init

  def __call__(self, *args):
    """Builds a Bernoulli distribution conditioned on the inputs."""
    logits = self.condition(args)
    return tf.contrib.distributions.Bernoulli(logits=logits)
class NormalApproximatePosterior(ConditionalNormalDistribution):
  """A Normally-distributed approximate posterior, res_q parameterized."""

  def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
               raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
               initializers=None, smoothing=False,
               name="conditional_normal_distribution"):
    super(NormalApproximatePosterior, self).__init__(
        size,
        hidden_layer_sizes,
        sigma_min=sigma_min,
        raw_sigma_bias=raw_sigma_bias,
        hidden_activation_fn=hidden_activation_fn,
        initializers=initializers,
        name=name)
    self.smoothing = smoothing

  def condition(self, tensor_list, prior_mu, smoothing_tensors=None):
    """Generates the mean and variance of the normal distribution.

    Args:
      tensor_list: The list of Tensors to condition on. Concatenated and fed
        through a fully connected network. NOTE: extended in place when
        smoothing is enabled.
      prior_mu: The mean of the prior distribution associated with this
        approximate posterior. Added to the network's mean output, in res_q
        fashion.
      smoothing_tensors: A list of Tensors appended to tensor_list when
        self.smoothing is True.

    Returns:
      mu: The mean of the approximate posterior.
      sigma: The standard deviation of the approximate posterior.
    """
    if self.smoothing:
      tensor_list.extend(smoothing_tensors)
    raw_mu, sigma = super(NormalApproximatePosterior, self).condition(
        tensor_list)
    # res_q: the network only learns the residual from the prior mean.
    return raw_mu + prior_mu, sigma
class NonstationaryLinearDistribution(object):
  """A set of loc-scale distributions that are linear functions of inputs.

  This class defines a series of location-scale distributions such that
  the means are learnable linear functions of the inputs and the log
  variances are learnable constants. The functions and log variances are
  different across timesteps, allowing the distributions to be
  nonstationary.
  """

  def __init__(self,
               num_timesteps,
               inputs_per_timestep=None,
               outputs_per_timestep=None,
               initializers=None,
               variance_min=0.0,
               output_distribution=tfd.Normal,
               dtype=tf.float32):
    """Creates a NonstationaryLinearDistribution.

    Args:
      num_timesteps: The number of timesteps, i.e. the number of
        distributions.
      inputs_per_timestep: A list of python ints, the dimension of inputs to
        the linear function at each timestep. If not provided, the dimension
        at each timestep is assumed to be 1.
      outputs_per_timestep: A list of python ints, the dimension of the
        output distribution at each timestep. If not provided, the dimension
        at each timestep is assumed to be 1.
      initializers: A dictionary containing initializers for the variables.
        The initializer under the key 'w' is used for the weights in the
        linear function and the initializer under the key 'b' is used for
        the biases. Defaults to xavier initialization for the weights and
        zeros for the biases.
      variance_min: Python float, the minimum variance of each distribution.
      output_distribution: A location-scale subclass of tfd.Distribution
        that defines the output distribution, e.g. Normal.
      dtype: The dtype of the weights and biases.
    """
    if not initializers:
      initializers = DEFAULT_INITIALIZERS
    if not inputs_per_timestep:
      inputs_per_timestep = [1] * num_timesteps
    if not outputs_per_timestep:
      outputs_per_timestep = [1] * num_timesteps
    self.num_timesteps = num_timesteps
    self.variance_min = variance_min
    self.initializers = initializers
    self.dtype = dtype
    self.output_distribution = output_distribution

    def _get_variables_ta(shapes, name, initializer, trainable=True):
      """Creates a sequence of variables and stores them in a TensorArray."""
      # Infer shape if all shapes are equal.
      first_shape = shapes[0]
      infer_shape = all(shape == first_shape for shape in shapes)
      ta = tf.TensorArray(
          dtype=dtype, size=len(shapes), dynamic_size=False,
          clear_after_read=False, infer_shape=infer_shape)
      for t, shape in enumerate(shapes):
        var = tf.get_variable(
            name % t, shape=shape, initializer=initializer,
            trainable=trainable)
        ta = ta.write(t, var)
      return ta

    bias_shapes = [[num_outputs] for num_outputs in outputs_per_timestep]
    self.log_variances = _get_variables_ta(
        bias_shapes, "proposal_log_variance_%d", initializers["b"])
    self.mean_biases = _get_variables_ta(
        bias_shapes, "proposal_b_%d", initializers["b"])
    # BUGFIX: materialize the zip. On Python 3, zip() is a one-shot
    # iterator: indexing shapes[0]/len(shapes) in _get_variables_ta would
    # raise a TypeError, and the same object is consumed a second time by
    # .unstack() below.
    weight_shapes = list(zip(inputs_per_timestep, outputs_per_timestep))
    self.mean_weights = _get_variables_ta(
        weight_shapes, "proposal_w_%d", initializers["w"])
    # Store the per-timestep weight shapes so __call__ can reshape the
    # flattened weights read back from the TensorArray.
    self.shapes = tf.TensorArray(
        dtype=tf.int32, size=num_timesteps,
        dynamic_size=False, clear_after_read=False).unstack(weight_shapes)

  def __call__(self, t, inputs):
    """Computes the distribution at timestep t.

    Args:
      t: Scalar integer Tensor, the current timestep. Must be in
        [0, num_timesteps).
      inputs: The inputs to the linear function parameterizing the mean of
        the current distribution. A Tensor of shape
        [batch_size, num_inputs_t].

    Returns:
      A tfd.Distribution subclass representing the distribution at
      timestep t.
    """
    b = self.mean_biases.read(t)
    w = self.mean_weights.read(t)
    shape = self.shapes.read(t)
    w = tf.reshape(w, shape)
    b = tf.reshape(b, [shape[1], 1])
    log_variance = self.log_variances.read(t)
    # Clamp the variance from below before converting to a scale.
    scale = tf.sqrt(tf.maximum(tf.exp(log_variance), self.variance_min))
    loc = tf.matmul(w, inputs, transpose_a=True) + b
    return self.output_distribution(loc=loc, scale=scale)
def encode_all(inputs, encoder):
  """Applies a time-independent encoder to every timestep of a sequence.

  Args:
    inputs: A [time, batch, feature_dimensions] tensor.
    encoder: A network that takes a [batch, feature_dimensions] input and
      encodes the input.

  Returns:
    A [time, batch, encoded_feature_dimensions] output tensor.
  """
  dims = tf.shape(inputs)
  num_steps, batch = dims[0], dims[1]
  # Fold the time dimension into the batch so the encoder sees 2-D input.
  flat_inputs = tf.reshape(inputs, [-1, inputs.shape[-1]])
  flat_encoded = encoder(flat_inputs)
  # Restore the leading [time, batch] dimensions.
  return tf.reshape(flat_encoded, [num_steps, batch, encoder.output_size])
def ta_for_tensor(x, **kwargs):
  """Creates a TensorArray holding the slices of x along its first axis."""
  num_elements = tf.shape(x)[0]
  ta = tf.TensorArray(x.dtype, num_elements, dynamic_size=False, **kwargs)
  return ta.unstack(x)
Prev
1
…
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment