Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
27b4acd4
Commit
27b4acd4
authored
Sep 25, 2018
by
Aman Gupta
Browse files
Merge remote-tracking branch 'upstream/master'
parents
5133522f
d4e1f97f
Changes
240
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5457 additions
and
4 deletions
+5457
-4
research/fivo/experimental/models.py
research/fivo/experimental/models.py
+1227
-0
research/fivo/experimental/run.sh
research/fivo/experimental/run.sh
+54
-0
research/fivo/experimental/summary_utils.py
research/fivo/experimental/summary_utils.py
+332
-0
research/fivo/experimental/train.py
research/fivo/experimental/train.py
+637
-0
research/fivo/fivo/__init__.py
research/fivo/fivo/__init__.py
+0
-0
research/fivo/fivo/bounds.py
research/fivo/fivo/bounds.py
+317
-0
research/fivo/fivo/bounds_test.py
research/fivo/fivo/bounds_test.py
+183
-0
research/fivo/fivo/data/__init__.py
research/fivo/fivo/data/__init__.py
+0
-0
research/fivo/fivo/data/calculate_pianoroll_mean.py
research/fivo/fivo/data/calculate_pianoroll_mean.py
+1
-1
research/fivo/fivo/data/create_timit_dataset.py
research/fivo/fivo/data/create_timit_dataset.py
+1
-2
research/fivo/fivo/data/datasets.py
research/fivo/fivo/data/datasets.py
+231
-1
research/fivo/fivo/data/datasets_test.py
research/fivo/fivo/data/datasets_test.py
+303
-0
research/fivo/fivo/ghmm_runners.py
research/fivo/fivo/ghmm_runners.py
+235
-0
research/fivo/fivo/ghmm_runners_test.py
research/fivo/fivo/ghmm_runners_test.py
+106
-0
research/fivo/fivo/models/__init__.py
research/fivo/fivo/models/__init__.py
+0
-0
research/fivo/fivo/models/base.py
research/fivo/fivo/models/base.py
+342
-0
research/fivo/fivo/models/ghmm.py
research/fivo/fivo/models/ghmm.py
+483
-0
research/fivo/fivo/models/ghmm_test.py
research/fivo/fivo/models/ghmm_test.py
+313
-0
research/fivo/fivo/models/srnn.py
research/fivo/fivo/models/srnn.py
+587
-0
research/fivo/fivo/models/srnn_test.py
research/fivo/fivo/models/srnn_test.py
+105
-0
No files found.
research/fivo/experimental/models.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
functools
import
sonnet
as
snt
import
tensorflow
as
tf
import
numpy
as
np
import
math
# Observation models: how the emitted x relates to the latent state z.
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]

# Transition models: how z_t depends on z_{t-1}.
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
class Q(object):
  """Per-timestep Gaussian proposal q(z_t | x, z_{t-1}).

  The mean at each timestep is a Sonnet Linear layer applied to
  [observation, prev_state]; the (pre-softplus) variance is a learned
  per-timestep variable shared across the batch.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    """Creates the proposal variables.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps (one mu/sigma per step).
      sigma_min: Floor applied to softplus(sigma) to avoid collapse.
      dtype: TensorFlow dtype of the variables.
      random_seed: Seed for the uniform initializers.
      init_mu0_to_zero: If True, the t=0 mean layer is zero-initialized.
      graph_collection_name: Collection to which all q variables are added.
    """
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    # NOTE: xrange replaced with range for Python 3 compatibility
    # (identical iteration behavior).
    for t in range(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(
            {"w": tf.zeros_initializer, "b": tf.zeros_initializer})
      else:
        initializers.append(
            {"w": tf.random_uniform_initializer(seed=random_seed),
             "b": tf.zeros_initializer})

    def custom_getter(getter, *args, **kwargs):
      # Route every created variable into the q collection (deduplicated).
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers[t],
                   name="q_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in range(num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                         graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in range(num_timesteps)
    ]

  def q_zt(self, observation, prev_state, t):
    """Returns the Normal proposal distribution over z_t.

    Args:
      observation: [batch_size, state_size] Tensor.
      prev_state: [batch_size, state_size] Tensor, z_{t-1}.
      t: Python int timestep index.

    Returns:
      A tf.contrib.distributions.Normal with batch shape
      [batch_size, state_size].
    """
    batch_size = tf.shape(prev_state)[0]
    q_mu = self.mus[t](tf.concat([observation, prev_state], axis=1))
    # softplus keeps sigma positive; sigma_min keeps it away from zero.
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    # q_sigma is treated as a variance; scale is its square root.
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    """Adds scalar summaries for the first component of each mu/sigma."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0, 0])
      if t != 0:
        tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[1, 0])
class PreviousStateQ(Q):
  """Proposal that conditions only on the previous state, ignoring x."""

  def q_zt(self, unused_observation, prev_state, t):
    """Returns q(z_t | z_{t-1}) as a Normal distribution.

    Args:
      unused_observation: Ignored; present for interface compatibility.
      prev_state: [batch_size, state_size] Tensor, z_{t-1}.
      t: Python int timestep index.
    """
    n = tf.shape(prev_state)[0]
    mean = self.mus[t](prev_state)
    # Variance: softplus-ed variable floored at sigma_min, broadcast to batch.
    var = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    var = tf.tile(var[tf.newaxis, :], [n, 1])
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(var))

  def summarize_weights(self):
    """Summarizes the first component of each sigma and mu layer."""
    for step, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % step, sigma[0])
    for step, layer in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % step, layer.b[0])
      tf.summary.scalar("q_mu/w_prev_state_%d" % step, layer.w[0, 0])
class ObservationQ(Q):
  """Proposal that conditions only on the observation, ignoring z_{t-1}."""

  def q_zt(self, observation, prev_state, t):
    """Returns q(z_t | x) as a Normal distribution.

    Args:
      observation: [batch_size, state_size] Tensor.
      prev_state: Used only to read the batch size.
      t: Python int timestep index.
    """
    n = tf.shape(prev_state)[0]
    mean = self.mus[t](observation)
    # Variance: softplus-ed variable floored at sigma_min, broadcast to batch.
    var = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    var = tf.tile(var[tf.newaxis, :], [n, 1])
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(var))

  def summarize_weights(self):
    """Summarizes the first component of each sigma and mu layer."""
    for step, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % step, sigma[0])
    for step, layer in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % step, layer.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % step, layer.w[0, 0])
class SimpleMeanQ(object):
  """Proposal with a free per-timestep mean, independent of x and z_{t-1}."""

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    """Creates per-timestep mean and sigma variables.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps (one mu/sigma per step).
      sigma_min: Floor applied to softplus(sigma).
      dtype: TensorFlow dtype of the variables.
      random_seed: Seed for the uniform initializers.
      init_mu0_to_zero: If True, the t=0 mean is zero-initialized.
      graph_collection_name: Collection to which all variables are added.
    """
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    # NOTE: xrange replaced with range for Python 3 compatibility.
    for t in range(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(tf.zeros_initializer)
      else:
        initializers.append(tf.random_uniform_initializer(seed=random_seed))
    self.mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_mu_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                         graph_collection_name],
            initializer=initializers[t])
        for t in range(num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                         graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in range(num_timesteps)
    ]

  def q_zt(self, unused_observation, prev_state, t):
    """Returns q(z_t) as a Normal with learned mean and variance.

    Args:
      unused_observation: Ignored; present for interface compatibility.
      prev_state: Used only to read the batch size.
      t: Python int timestep index.
    """
    batch_size = tf.shape(prev_state)[0]
    q_mu = tf.tile(self.mus[t][tf.newaxis, :], [batch_size, 1])
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    # q_sigma is treated as a variance; scale is its square root.
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    """Adds scalar summaries for the first component of each mu/sigma."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/%d" % t, f[0])
class R(object):
  """Learned reverse message r(x_n | z_t): Linear mean, learned variance."""

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               sigma_init=1.,
               random_seed=None,
               graph_collection_name="R_VARS"):
    """Creates per-timestep mean layers and sigma variables.

    Args:
      state_size: Dimensionality of the latent state / observation.
      num_timesteps: Number of timesteps.
      sigma_min: Floor applied to softplus(sigma).
      dtype: TensorFlow dtype of the variables.
      sigma_init: Constant used to initialize the sigma variables.
      random_seed: Seed for the weight initializer.
      graph_collection_name: Collection to which all r variables are added.
    """
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {
        "w": tf.truncated_normal_initializer(seed=random_seed),
        "b": tf.zeros_initializer
    }
    self.graph_collection_name = graph_collection_name

    def custom_getter(getter, *args, **kwargs):
      # Route every created variable into the r collection (deduplicated).
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    # NOTE: xrange replaced with range for Python 3 compatibility.
    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers,
                   name="r_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in range(num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                         graph_collection_name],
            initializer=tf.constant_initializer(sigma_init))
        for t in range(num_timesteps)
    ]

  def r_xn(self, z_t, t):
    """Returns r(x_n | z_t) as a Normal distribution.

    Args:
      z_t: [batch_size, state_size] Tensor, current latent state.
      t: Python int timestep index.
    """
    batch_size = tf.shape(z_t)[0]
    r_mu = self.mus[t](z_t)
    r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    # r_sigma is treated as a variance; scale is its square root.
    return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))

  def summarize_weights(self):
    """Adds scalar summaries for mus and sigmas.

    NOTE(review): `self.mus[t][0]` indexes a snt.Linear module, which does
    not support subscripting — this summary looks broken as written and
    would raise if called. Left unchanged pending confirmation of intent.
    """
    for t in range(len(self.mus) - 1):
      tf.summary.scalar("r_mu/%d" % t, self.mus[t][0])
      tf.summary.scalar("r_sigma/%d" % t, self.sigmas[t][0])
class P(object):
  """Linear-Gaussian model p(z_t | z_{t-1}) = N(z_{t-1} + b_t, variance).

  The drift terms b_t are (optionally trainable) variables; Bs[t] caches
  the suffix sums sum_{k=t+1}^{n} b_k used by the closed-form posterior
  and lookahead distributions.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    """Creates the drift variables b_1..b_n.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps n.
      sigma_min: Stored floor for sigmas (kept for interface parity).
      variance: Transition variance of the Gaussian model.
      dtype: TensorFlow dtype of the variables.
      random_seed: Seed for the uniform initializers.
      trainable: Whether the b variables are trainable.
      init_bs_to_zero: If True, initialize all b_t to zero.
      graph_collection_name: Collection to which the b variables are added.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.graph_collection_name = graph_collection_name
    # NOTE: xrange replaced with range for Python 3 compatibility.
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in range(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed)
                      for _ in range(num_timesteps)]
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="p_b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                         graph_collection_name],
            trainable=trainable)
        for t in range(num_timesteps)
    ]
    # Bs[t] = sum_{k >= t} bs[k] (reverse cumulative sum over timesteps).
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
    # bs[0] is really b_1
    # Bs[i] is sum from k=i+1^n b_k
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n|z_t)."""
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(
        self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    """Computes the marginal log-likelihood of the final observation z_n."""
    batch_size = tf.shape(observation)[0]
    # Marginal mean is the total drift; variance accumulates over all steps.
    mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :],
                 [batch_size, 1])
    sigma = tf.ones_like(mu) * self.variance * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
    else:
      # p(z_0) is Normal(0,1)
      z_mu_p = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu,
        scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class ShortChainNonlinearP(object):
  """Nonlinear transition model with a configurable transition distribution.

  The transition mean is either the previous state ("standard") or its
  elementwise rounding ("round"); the distribution family (e.g. Normal or
  Cauchy) is injected via `transition_dist`.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               transition_type=STANDARD_TRANSITION,
               transition_dist=tf.contrib.distributions.Normal,
               dtype=tf.float32,
               random_seed=None):
    """Stores configuration; no variables are created.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps.
      sigma_min: Stored floor for sigmas (kept for interface parity).
      variance: Transition variance.
      observation_variance: Variance of the generative distribution.
      transition_type: One of TRANSITION_TYPES.
      transition_dist: Distribution class taking (loc, scale).
      dtype: TensorFlow dtype used for the t=0 prior mean.
      random_seed: Unused; kept for interface parity.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.transition_type = transition_type
    self.transition_dist = transition_dist

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      else:
        # Previously fell through with `loc` unbound (UnboundLocalError);
        # fail fast with a clear message instead.
        raise ValueError("Unknown transition_type: %s" % self.transition_type)
    else:
      # p(z_0) is Normal(0,1)
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
    p_zt = self.transition_dist(
        loc=loc, scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, unused_obs, z_ni):
    """Computes the model's generative distribution p(x_i| z_{ni})."""
    if self.transition_type == ROUND_TRANSITION:
      loc = tf.round(z_ni)
    elif self.transition_type == STANDARD_TRANSITION:
      loc = z_ni
    else:
      # Previously fell through with `loc` unbound (UnboundLocalError).
      raise ValueError("Unknown transition_type: %s" % self.transition_type)
    generative_sigma_sq = tf.ones_like(loc) * self.observation_variance
    return self.transition_dist(loc=loc, scale=tf.sqrt(generative_sigma_sq))
class BimodalPriorP(object):
  """Linear-Gaussian model whose z_0 prior is a two-mode Gaussian mixture.

  Modes are at +/- prior_mode_mean with weights (mixing_coeff,
  1 - mixing_coeff); transitions for t > 0 are the same drifted Gaussians
  as in P.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               mixing_coeff=0.5,
               prior_mode_mean=1,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    """Creates the drift variables b_1..b_n and stores mixture parameters.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps n.
      mixing_coeff: Weight of the positive mode of the z_0 prior.
      prior_mode_mean: Magnitude of the two prior mode means.
      sigma_min: Stored floor for sigmas (kept for interface parity).
      variance: Transition variance.
      dtype: TensorFlow dtype of the variables.
      random_seed: Seed for the uniform initializers.
      trainable: Whether the b variables are trainable.
      init_bs_to_zero: If True, initialize all b_t to zero.
      graph_collection_name: Collection to which the b variables are added.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.mixing_coeff = mixing_coeff
    self.prior_mode_mean = prior_mode_mean
    # NOTE: xrange replaced with range for Python 3 compatibility.
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in range(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed)
                      for _ in range(num_timesteps)]
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                         graph_collection_name],
            trainable=trainable)
        for t in range(num_timesteps)
    ]
    # Bs[t] = sum_{k >= t} bs[k] (reverse cumulative sum over timesteps).
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    # NOTE: This is currently wrong, but would require a refactoring of
    # summarize_q to fix as kl is not defined for a mixture
    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
    # bs[0] is really b_1
    # Bs[i] is sum from k=i+1^n b_k
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n|z_t)."""
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(
        self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    """Computes log p(z_n) under the bimodal marginal, batch-averaged."""
    batch_size = tf.shape(observation)[0]
    sum_of_bs = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :],
                        [batch_size, 1])
    sigma = tf.ones_like(sum_of_bs) * self.variance * (
        self.num_timesteps + 1)
    mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
              self.prior_mode_mean) + sum_of_bs
    mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
              -self.prior_mode_mean) + sum_of_bs
    zn_pos = tf.contrib.distributions.Normal(loc=mu_pos,
                                             scale=tf.sqrt(sigma))
    zn_neg = tf.contrib.distributions.Normal(loc=mu_neg,
                                             scale=tf.sqrt(sigma))
    # NOTE(review): mixture weights are built as float64 while the
    # components use self.dtype (default float32) — looks like a potential
    # dtype mismatch; left unchanged pending confirmation.
    mode_probs = tf.convert_to_tensor(
        [self.mixing_coeff, 1 - self.mixing_coeff], dtype=tf.float64)
    mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :],
                         [batch_size, 1, 1])
    mode_selection_dist = tf.contrib.distributions.Categorical(
        probs=mode_probs)
    zn_dist = tf.contrib.distributions.Mixture(
        cat=mode_selection_dist,
        components=[zn_pos, zn_neg],
        validate_args=True)
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(
        tf.reduce_sum(zn_dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
      p_zt = tf.contrib.distributions.Normal(
          loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
      return p_zt
    else:
      # p(z_0) is mixture of two Normals
      mu_pos = tf.ones([batch_size, self.state_size],
                       dtype=self.dtype) * self.prior_mode_mean
      mu_neg = tf.ones([batch_size, self.state_size],
                       dtype=self.dtype) * -self.prior_mode_mean
      z0_pos = tf.contrib.distributions.Normal(
          loc=mu_pos, scale=tf.sqrt(tf.ones_like(mu_pos) * self.variance))
      z0_neg = tf.contrib.distributions.Normal(
          loc=mu_neg, scale=tf.sqrt(tf.ones_like(mu_neg) * self.variance))
      mode_probs = tf.convert_to_tensor(
          [self.mixing_coeff, 1 - self.mixing_coeff], dtype=tf.float64)
      mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :],
                           [batch_size, 1, 1])
      mode_selection_dist = tf.contrib.distributions.Categorical(
          probs=mode_probs)
      z0_dist = tf.contrib.distributions.Mixture(
          cat=mode_selection_dist,
          components=[z0_pos, z0_neg],
          validate_args=False)
      return z0_dist

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu,
        scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class Model(object):
  """Bundles a prior p, proposal q, and reverse message r for one SMC step."""

  def __init__(self, p, q, r, state_size, num_timesteps, dtype=tf.float32):
    """Stores the component models and configuration.

    Args:
      p: Prior/transition model with p_zt() and generative().
      q: Proposal model with q_zt().
      r: Reverse-message model with r_xn().
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps.
      dtype: TensorFlow dtype used for zero_state.
    """
    self.p = p
    self.q = q
    self.r = r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zero initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    """Runs one timestep: sample z_t from q and score it under p, q, and r.

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn), each
      log-prob summed over the state dimension.
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q
    zt = q_zt.sample()
    r_xn = self.r.r_xn(zt, t)
    # Calculate the logprobs and sum over the state size.
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    log_r_xn = tf.reduce_sum(r_xn.log_prob(observation), axis=1)
    # If we're at the last timestep, also calc the logprob of the observation.
    if t == self.num_timesteps - 1:
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation),
                                      axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)

  @staticmethod
  def create(state_size,
             num_timesteps,
             sigma_min=1e-5,
             r_sigma_init=1,
             variance=1.0,
             mixing_coeff=0.5,
             prior_mode_mean=1.0,
             dtype=tf.float32,
             random_seed=None,
             train_p=True,
             p_type="unimodal",
             q_type="normal",
             observation_variance=1.0,
             transition_type=STANDARD_TRANSITION,
             use_bs=True):
    """Factory that wires up a Model from string-selected components.

    Raises:
      ValueError: If p_type or q_type is not a recognized option.
        (Previously an unrecognized option left a local unbound and
        crashed with UnboundLocalError.)
    """
    if p_type == "unimodal":
      p = P(state_size,
            num_timesteps,
            sigma_min=sigma_min,
            variance=variance,
            dtype=dtype,
            random_seed=random_seed,
            trainable=train_p,
            init_bs_to_zero=not use_bs)
    elif p_type == "bimodal":
      p = BimodalPriorP(state_size,
                        num_timesteps,
                        mixing_coeff=mixing_coeff,
                        prior_mode_mean=prior_mode_mean,
                        sigma_min=sigma_min,
                        variance=variance,
                        dtype=dtype,
                        random_seed=random_seed,
                        trainable=train_p,
                        init_bs_to_zero=not use_bs)
    elif "nonlinear" in p_type:
      if "cauchy" in p_type:
        trans_dist = tf.contrib.distributions.Cauchy
      else:
        trans_dist = tf.contrib.distributions.Normal
      p = ShortChainNonlinearP(state_size,
                               num_timesteps,
                               sigma_min=sigma_min,
                               variance=variance,
                               observation_variance=observation_variance,
                               transition_type=transition_type,
                               transition_dist=trans_dist,
                               dtype=dtype,
                               random_seed=random_seed)
    else:
      raise ValueError("Unknown p_type: %s" % p_type)
    if q_type == "normal":
      q_class = Q
    elif q_type == "simple_mean":
      q_class = SimpleMeanQ
    elif q_type == "prev_state":
      q_class = PreviousStateQ
    elif q_type == "observation":
      q_class = ObservationQ
    else:
      raise ValueError("Unknown q_type: %s" % q_type)
    q = q_class(state_size,
                num_timesteps,
                sigma_min=sigma_min,
                dtype=dtype,
                random_seed=random_seed,
                init_mu0_to_zero=not use_bs)
    r = R(state_size,
          num_timesteps,
          sigma_min=sigma_min,
          sigma_init=r_sigma_init,
          dtype=dtype,
          random_seed=random_seed)
    model = Model(p, q, r, state_size, num_timesteps, dtype=dtype)
    return model
class BackwardsModel(object):
  """Self-contained model that filters the chain in reverse time order.

  Indexing convention: a forward timestep t maps to the backwards index
  t_backwards = num_timesteps - t - 1 into the per-timestep variables.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32):
    """Creates drift, proposal, and reverse-message variables.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps.
      sigma_min: Floor applied to softplus(sigma).
      dtype: TensorFlow dtype of the variables.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    # NOTE: xrange replaced with range for Python 3 compatibility.
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=tf.zeros_initializer)
        for t in range(num_timesteps)
    ]
    # Bs[t] = sum_{k >= t} bs[k] (reverse cumulative sum over timesteps).
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
    self.q_mus = [
        snt.Linear(output_size=state_size) for _ in range(num_timesteps)
    ]
    self.q_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer)
        for t in range(num_timesteps)
    ]
    self.r_mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_mu_%d" % (t + 1),
            initializer=tf.zeros_initializer)
        for t in range(num_timesteps)
    ]
    self.r_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer)
        for t in range(num_timesteps)
    ]

  def zero_state(self, batch_size):
    """Returns an all-zero initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def posterior(self, unused_observation, prev_state, unused_t):
    # TODO(dieterichl): Correct this.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(prev_state), scale=tf.zeros_like(prev_state))

  def lookahead(self, state, unused_t):
    # TODO(dieterichl): Correct this.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(state), scale=tf.zeros_like(state))

  def q_zt(self, observation, next_state, t):
    """Computes the variational posterior q(z_{t}|z_{t+1}, z_n)."""
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(next_state)[0]
    q_mu = self.q_mus[t_backwards](
        tf.concat([observation, next_state], axis=1))
    q_sigma = tf.maximum(
        tf.nn.softplus(self.q_sigmas[t_backwards]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def p_zt(self, prev_state, t):
    """Computes the model p(z_{t+1}| z_{t})."""
    t_backwards = self.num_timesteps - t - 1
    z_mu_p = prev_state + self.bs[t_backwards]
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.ones_like(z_mu_p))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu, scale=tf.ones_like(generative_p_mu))

  def r(self, z_t, t):
    """Returns the learned reverse message r(. | z_t) as a Normal."""
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(z_t)[0]
    r_mu = tf.tile(self.r_mus[t_backwards][tf.newaxis, :], [batch_size, 1])
    r_sigma = tf.maximum(
        tf.nn.softplus(self.r_sigmas[t_backwards]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))

  def likelihood(self, observation):
    """Computes the marginal log-likelihood of the observation."""
    batch_size = tf.shape(observation)[0]
    mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :],
                 [batch_size, 1])
    sigma = tf.ones_like(mu) * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def __call__(self, next_state, observation, t):
    """Runs one backwards step; returns (zt, log_q, log_p, z0_lp, log_r)."""
    # next state = z_{t+1}
    # Compute the q distribution over z, q(z_{t}|z_n, z_{t+1}).
    q_zt = self.q_zt(observation, next_state, t)
    # sample from q
    zt = q_zt.sample()
    # Compute the p distribution over z, p(z_{t+1}|z_{t}).
    p_zt = self.p_zt(zt, t)
    # Compute log p(z_{t+1} | z_t)
    if t == 0:
      log_p_zt = p_zt.log_prob(observation)
    else:
      log_p_zt = p_zt.log_prob(next_state)
    # Compute r prior over zt
    r_zt = self.r(zt, t)
    log_r_zt = r_zt.log_prob(zt)
    # Compute proposal density at zt
    log_q_zt = q_zt.log_prob(zt)
    # If we're at the last timestep, also calc the logprob of the observation.
    if t == self.num_timesteps - 1:
      p_z0_dist = tf.contrib.distributions.Normal(
          loc=tf.zeros_like(zt), scale=tf.ones_like(zt))
      z0_log_prob = p_z0_dist.log_prob(zt)
    else:
      z0_log_prob = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt)
class LongChainP(object):
  """Transition/observation model for chains with sparse observations.

  The chain has steps_per_obs latent steps between consecutive
  observations, num_obs observations, and num_timesteps =
  steps_per_obs * num_obs + 1 latent steps in total.
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               observation_type=STANDARD_OBSERVATION,
               transition_type=STANDARD_TRANSITION,
               dtype=tf.float32,
               random_seed=None):
    """Stores configuration; no variables are created."""
    self.state_size = state_size
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = steps_per_obs * num_obs + 1
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.observation_type = observation_type
    self.transition_type = transition_type

  def likelihood(self, observations):
    """Computes the model's true likelihood of the observations.
    Args:
      observations: A [batch_size, m, state_size] Tensor representing each of
        the m observations.
    Returns:
      logprob: The true likelihood of the observations given the model.
    """
    # Intentionally unimplemented for this model family.
    raise ValueError("Likelihood is not defined for long-chain models")

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    num_in_batch = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        center = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        center = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
    else:
      # The t=0 prior is a zero-mean Gaussian.
      center = tf.zeros([num_in_batch, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
    return tf.contrib.distributions.Normal(
        loc=center, scale=tf.sqrt(tf.ones_like(center) * self.variance))

  def generative(self, z_ni, t):
    """Computes the model's generative distribution p(x_i| z_{ni})."""
    if self.observation_type == SQUARED_OBSERVATION:
      center = tf.square(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d^2, %0.1f)" %
                      (t, t, t, self.variance))
    elif self.observation_type == ABS_OBSERVATION:
      center = tf.abs(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(|z_%d|, %0.1f)" %
                      (t, t, t, self.variance))
    elif self.observation_type == STANDARD_OBSERVATION:
      center = z_ni
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d, %0.1f)" %
                      (t, t, t, self.variance))
    obs_var = tf.ones_like(center) * self.observation_variance
    return tf.contrib.distributions.Normal(
        loc=center, scale=tf.sqrt(obs_var))
class LongChainQ(object):
  """Proposal for long chains, conditioning on the remaining observations."""

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    """Creates per-timestep Linear mean layers and sigma variables.

    Args:
      state_size: Dimensionality of the latent state.
      num_obs: Number of observations.
      steps_per_obs: Latent steps between consecutive observations.
      sigma_min: Floor applied to softplus(sigma).
      dtype: TensorFlow dtype of the variables.
      random_seed: Seed for the uniform initializers.
    """
    self.state_size = state_size
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs * steps_per_obs + 1
    initializers = {
        "w": tf.random_uniform_initializer(seed=random_seed),
        "b": tf.zeros_initializer
    }
    # NOTE: xrange replaced with range for Python 3 compatibility.
    self.mus = [
        snt.Linear(output_size=state_size, initializers=initializers)
        for t in range(self.num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in range(self.num_timesteps)
    ]

  def first_relevant_obs_index(self, t):
    """Index of the first observation still relevant at timestep t.

    NOTE(review): this uses (t-1)/steps_per_obs clamped at 0, while q_zt
    below computes floor(max(t-1, 0)/steps_per_obs) inline — the two
    differ for t == 0 with Python 2 semantics; confirm which is intended.
    """
    return int(max((t - 1) / self.steps_per_obs, 0))

  def q_zt(self, observations, prev_state, t):
    """Computes a distribution over z_t.
    Args:
      observations: a [batch_size, num_observations, state_size] Tensor.
      prev_state: a [batch_size, state_size] Tensor.
      t: The current timestep, an int Tensor.
    """
    # filter out unneeded past obs
    first_relevant_obs_index = int(
        math.floor(max(t - 1, 0) / self.steps_per_obs))
    num_relevant_observations = self.num_obs - first_relevant_obs_index
    observations = observations[:, first_relevant_obs_index:, :]
    batch_size = tf.shape(prev_state)[0]
    # concatenate the prev state and observations along the second axis (that is
    # not the batch or state size axis, and then flatten it to
    # [batch_size, (num_relevant_observations + 1) * state_size] to feed it into
    # the linear layer.
    q_input = tf.concat([observations, prev_state[:, tf.newaxis, :]],
                        axis=1)
    q_input = tf.reshape(
        q_input,
        [batch_size, (num_relevant_observations + 1) * self.state_size])
    q_mu = self.mus[t](q_input)
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    tf.logging.info(
        "q(z_{t} | z_{tm1}, x_{obsf}:{obst}) ~ N(Linear([z_{tm1},x_{obsf}:{obst}]), sigma_{t})".format(
            **{"t": t,
               "tm1": t - 1,
               "obsf": (first_relevant_obs_index + 1) * self.steps_per_obs,
               "obst": self.steps_per_obs * self.num_obs}))
    return q_zt

  def summarize_weights(self):
    """No-op; kept for interface parity with the other proposals."""
    pass
class LongChainR(object):
  """Reverse message r(x_future | z_t) for long chains with sparse obs."""

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    """Creates one sigma variable per timestep, sized by remaining obs."""
    self.state_size = state_size
    self.dtype = dtype
    self.sigma_min = sigma_min
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs * steps_per_obs + 1
    # One variance entry per still-unseen observation at each timestep.
    self.sigmas = [
        tf.get_variable(
            shape=[self.num_future_obs(t)],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            #initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100))
            initializer=tf.constant_initializer(1.0))
        for t in range(self.num_timesteps)
    ]

  def first_future_obs_index(self, t):
    """Index of the first observation at or after timestep t."""
    return int(math.floor(t / self.steps_per_obs))

  def num_future_obs(self, t):
    """Number of observations not yet emitted at timestep t."""
    return int(self.num_obs - self.first_future_obs_index(t))

  def r_xn(self, z_t, t):
    """Computes a distribution over the future observations given current latent
    state.
    The indexing in these messages is 1 indexed and inclusive. This is
    consistent with the latex documents.
    Args:
      z_t: [batch_size, state_size] Tensor
      t: Current timestep
    """
    tf.logging.info(
        "r(x_{start}:{end} | z_{t}) ~ N(z_{t}, sigma_{t})".format(
            **{"t": t,
               "start":
                   (self.first_future_obs_index(t) + 1) * self.steps_per_obs,
               "end": self.num_timesteps - 1}))
    n = tf.shape(z_t)[0]
    # the mean for all future observations is the same.
    # this tiling results in a [batch_size, num_future_obs, state_size] Tensor
    mean = tf.tile(z_t[:, tf.newaxis, :], [1, self.num_future_obs(t), 1])
    # compute the variance
    var = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    # the variance is the same across all state dimensions, so we only have to
    # time sigma to be [batch_size, num_future_obs].
    var = tf.tile(var[tf.newaxis, :, tf.newaxis], [n, 1, self.state_size])
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(var))

  def summarize_weights(self):
    """No-op; kept for interface parity with the other messages."""
    pass
class LongChainModel(object):
  """Bundles p, q, and r into the long-chain importance-weighting model."""

  def __init__(self,
               p,
               q,
               r,
               state_size,
               num_obs,
               steps_per_obs,
               dtype=tf.float32,
               disable_r=False):
    """Stores the component distributions and chain geometry.

    Args:
      p: Generative model with p_zt(prev_state, t) and generative(zt, t).
      q: Proposal with q_zt(observations, prev_state, t).
      r: Reverse message with r_xn(zt, t).
      state_size: Dimensionality of the latent state.
      num_obs: Number of observations in the chain.
      steps_per_obs: Latent transitions between consecutive observations.
      dtype: TensorFlow dtype of the zero state.
      disable_r: If True, log r terms are replaced with zeros.
    """
    self.p = p
    self.q = q
    self.r = r
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_obs = num_obs
    self.steps_per_obs = steps_per_obs
    self.num_timesteps = steps_per_obs * num_obs + 1
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zero initial latent state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def next_obs_ind(self, t):
    """Index of the next observation to be emitted at or after timestep t."""
    return int(math.floor(max(t - 1, 0) / self.steps_per_obs))

  def __call__(self, prev_state, observations, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor
      observations: [batch_size, num_observations, state_size] Tensor
      t: Current (Python int) timestep.

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn), each log
      term of shape [batch_size].
    """
    # Proposal q(z_t | z_{t-1}, observations) and prior p(z_t | z_{t-1}).
    proposal = self.q.q_zt(observations, prev_state, t)
    prior = self.p.p_zt(prev_state, t)
    # Sample z_t from q and score it under both, summing over state dims.
    zt = proposal.sample()
    log_q_zt = tf.reduce_sum(proposal.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(prior.log_prob(zt), axis=1)

    last_step = self.num_timesteps - 1
    if not self.disable_r and t < last_step:
      # Score all not-yet-emitted observations under the reverse message r,
      # then reduce over the observation and state axes.
      future_obs = observations[:, self.next_obs_ind(t + 1):, :]
      log_r_xn = tf.reduce_sum(
          self.r.r_xn(zt, t).log_prob(future_obs), axis=[1, 2])
    else:
      log_r_xn = tf.zeros_like(log_p_zt)

    if t != 0 and t % self.steps_per_obs == 0:
      # An observation is emitted at this timestep; score it under the
      # generative distribution.
      current_obs = observations[:, self.next_obs_ind(t), :]
      log_p_x_given_z = tf.reduce_sum(
          self.p.generative(zt, t).log_prob(current_obs), axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)

    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)

  @staticmethod
  def create(state_size,
             num_obs,
             steps_per_obs,
             sigma_min=1e-5,
             variance=1.0,
             observation_variance=1.0,
             observation_type=STANDARD_OBSERVATION,
             transition_type=STANDARD_TRANSITION,
             dtype=tf.float32,
             random_seed=None,
             disable_r=False):
    """Factory that builds p, q, and r and wires them into a LongChainModel."""
    p = LongChainP(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        variance=variance,
        observation_variance=observation_variance,
        observation_type=observation_type,
        transition_type=transition_type,
        dtype=dtype,
        random_seed=random_seed)
    q = LongChainQ(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    r = LongChainR(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    return LongChainModel(
        p, q, r,
        state_size,
        num_obs,
        steps_per_obs,
        dtype=dtype,
        disable_r=disable_r)
class RTilde(object):
  """Per-timestep linear networks producing the r-tilde mean and variance.

  Each timestep owns a snt.Linear mapping [z_t, observation] to a
  2*state_size output that is split into a mean and a raw variance.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               graph_collection_name="R_TILDE_VARS"):
    """Builds one linear layer per timestep and registers their variables.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps (one linear layer each).
      sigma_min: Floor on the softplus-transformed variance.
      dtype: Stored for callers; the layers use TF's default float handling.
      random_seed: Seed for the truncated-normal weight initializer.
      graph_collection_name: Graph collection that gathers every variable
        created by these layers.
    """
    self.dtype = dtype
    self.sigma_min = sigma_min
    self.graph_collection_name = graph_collection_name
    initializers = {
        "w": tf.truncated_normal_initializer(seed=random_seed),
        "b": tf.zeros_initializer,
    }

    def custom_getter(getter, *args, **kwargs):
      # Route every created variable into our collection exactly once so the
      # training loop can fetch the r-tilde variables by collection name.
      created = getter(*args, **kwargs)
      collection = tf.get_collection_ref(self.graph_collection_name)
      if created not in collection:
        collection.append(created)
      return created

    self.fns = [
        snt.Linear(
            output_size=2 * state_size,
            initializers=initializers,
            name="r_tilde_%d" % t,
            custom_getter=custom_getter) for t in xrange(num_timesteps)
    ]

  def r_zt(self, z_t, observation, t):
    """Returns the (mu, sigma_sq) of r-tilde at timestep t.

    Args:
      z_t: [batch_size, state_size] latent state.
      observation: [batch_size, state_size] observation tensor.
      t: Python int timestep indexing into self.fns.
    """
    #out = self.fns[t](tf.stop_gradient(tf.concat([z_t, observation], axis=1)))
    raw = self.fns[t](tf.concat([z_t, observation], axis=1))
    mu, raw_sigma_sq = tf.split(raw, 2, axis=1)
    # Softplus + floor keeps the variance strictly positive.
    sigma_sq = tf.maximum(tf.nn.softplus(raw_sigma_sq), self.sigma_min)
    return mu, sigma_sq
class TDModel(object):
  """Short-chain model variant whose r message is trained TD-style.

  Bundles a generative model p, proposal q, and learned r_tilde network;
  __call__ produces the quantities needed by the fivo-aux-td bound.
  """

  def __init__(self,
               p,
               q,
               r_tilde,
               state_size,
               num_timesteps,
               dtype=tf.float32,
               disable_r=False):
    """Stores the component models and chain geometry.

    Args:
      p: Generative model with p_zt(prev_state, t) and generative(obs, zt).
      q: Proposal with q_zt(observation, prev_state, t).
      r_tilde: RTilde-style object with r_zt(zt, observation, t).
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps in the chain.
      dtype: TensorFlow dtype of the zero state.
      disable_r: If True, r_tilde outputs are replaced with None.
    """
    self.p = p
    self.q = q
    self.r_tilde = r_tilde
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zero initial latent state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor
      observation: [batch_size, state_size] Tensor
      t: Current (Python int) timestep.

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, r_tilde_mu,
      r_tilde_sigma_sq, p_ztplus1); the last three are None on timesteps
      where they are undefined.
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q and evaluate the logprobs, summing over the state size
    zt = q_zt.sample()
    # If it isn't the last timestep, compute the distribution over the next z.
    if t < self.num_timesteps - 1:
      p_ztplus1 = self.p.p_zt(zt, t + 1)
    else:
      p_ztplus1 = None
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps - 1:
      # score the remaining observations using r
      r_tilde_mu, r_tilde_sigma_sq = self.r_tilde.r_zt(
          zt, observation, t + 1)
    else:
      r_tilde_mu = None
      r_tilde_sigma_sq = None
    if t == self.num_timesteps - 1:
      # The observation is only scored on the final timestep.
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(
          generative_dist.log_prob(observation), axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, r_tilde_mu,
            r_tilde_sigma_sq, p_ztplus1)

  @staticmethod
  def create(state_size,
             num_timesteps,
             sigma_min=1e-5,
             variance=1.0,
             dtype=tf.float32,
             random_seed=None,
             train_p=True,
             p_type="unimodal",
             q_type="normal",
             mixing_coeff=0.5,
             prior_mode_mean=1.0,
             observation_variance=1.0,
             transition_type=STANDARD_TRANSITION,
             use_bs=True):
    """Factory that builds p, q, and r_tilde and wires them into a TDModel.

    Raises:
      ValueError: If p_type or q_type is not one of the supported values.
        (Previously an unrecognized value left `p`/`q_class` unbound and
        crashed later with an opaque NameError.)
    """
    if p_type == "unimodal":
      p = P(state_size,
            num_timesteps,
            sigma_min=sigma_min,
            variance=variance,
            dtype=dtype,
            random_seed=random_seed,
            trainable=train_p,
            init_bs_to_zero=not use_bs)
    elif p_type == "bimodal":
      p = BimodalPriorP(
          state_size,
          num_timesteps,
          mixing_coeff=mixing_coeff,
          prior_mode_mean=prior_mode_mean,
          sigma_min=sigma_min,
          variance=variance,
          dtype=dtype,
          random_seed=random_seed,
          trainable=train_p,
          init_bs_to_zero=not use_bs)
    elif "nonlinear" in p_type:
      # "cauchy" anywhere in p_type selects heavy-tailed transitions.
      if "cauchy" in p_type:
        trans_dist = tf.contrib.distributions.Cauchy
      else:
        trans_dist = tf.contrib.distributions.Normal
      p = ShortChainNonlinearP(
          state_size,
          num_timesteps,
          sigma_min=sigma_min,
          variance=variance,
          observation_variance=observation_variance,
          transition_type=transition_type,
          transition_dist=trans_dist,
          dtype=dtype,
          random_seed=random_seed)
    else:
      raise ValueError("Unknown p_type: %s" % p_type)

    q_classes = {
        "normal": Q,
        "simple_mean": SimpleMeanQ,
        "prev_state": PreviousStateQ,
        "observation": ObservationQ,
    }
    if q_type not in q_classes:
      raise ValueError("Unknown q_type: %s" % q_type)
    q_class = q_classes[q_type]
    q = q_class(state_size,
                num_timesteps,
                sigma_min=sigma_min,
                dtype=dtype,
                random_seed=random_seed,
                init_mu0_to_zero=not use_bs)
    r_tilde = RTilde(
        state_size,
        num_timesteps,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    model = TDModel(p, q, r_tilde, state_size, num_timesteps, dtype=dtype)
    return model
research/fivo/experimental/run.sh
0 → 100644
View file @
27b4acd4
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
model
=
"forward"
T
=
5
num_obs
=
1
var
=
0.1
n
=
4
lr
=
0.0001
bound
=
"fivo-aux"
q_type
=
"normal"
resampling_method
=
"multinomial"
rgrad
=
"true"
p_type
=
"unimodal"
use_bs
=
false
LOGDIR
=
/tmp/fivo/model-
$model
-
$bound
-
$resampling_method
-resampling-rgrad-
$rgrad
-T-
$T
-var-
$var
-n-
$n
-lr-
$lr
-q-
$q_type
-p-
$p_type
python train.py
\
--logdir
=
$LOGDIR
\
--model
=
$model
\
--bound
=
$bound
\
--q_type
=
$q_type
\
--p_type
=
$p_type
\
--variance
=
$var
\
--use_resampling_grads
=
$rgrad
\
--resampling
=
always
\
--resampling_method
=
$resampling_method
\
--batch_size
=
4
\
--num_samples
=
$n
\
--num_timesteps
=
$T
\
--num_eval_samples
=
256
\
--summarize_every
=
100
\
--learning_rate
=
$lr
\
--decay_steps
=
1000000
\
--max_steps
=
1000000000
\
--random_seed
=
1234
\
--train_p
=
false
\
--use_bs
=
$use_bs
\
--alsologtostderr
research/fivo/experimental/summary_utils.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for plotting and summarizing.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
matplotlib.gridspec
as
gridspec
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
scipy
import
tensorflow
as
tf
import
models
def summarize_ess(weights, only_last_timestep=False):
  """Adds scalar summaries of the effective sample size (ESS).

  For each timestep the importance weights are normalized with a softmax
  over the sample dimension, then the ESS, the ESS scaled by the batch size
  ("ese"), and the variance of the normalized weights are recorded.

  Args:
    weights: List of length num_timesteps Tensors of shape
      [num_samples, batch_size]
    only_last_timestep: If True, only summarize the final timestep.
  """
  num_timesteps = len(weights)
  batch_size = tf.cast(tf.shape(weights[0])[1], dtype=tf.float64)
  for i, raw_weights in enumerate(weights):
    if only_last_timestep and i < num_timesteps - 1:
      continue
    normalized = tf.nn.softmax(raw_weights, dim=0)
    deviations = normalized - tf.reduce_mean(
        normalized, axis=0, keepdims=True)
    # NOTE(review): preserved from the original — the normalizer here is
    # (batch_size - 1) even though centering is over the sample axis;
    # confirm which denominator was intended.
    variance = tf.reduce_sum(tf.square(deviations)) / (batch_size - 1)
    ess = 1. / tf.reduce_mean(
        tf.reduce_sum(tf.square(normalized), axis=0))
    tf.summary.scalar("ess/%d" % i, ess)
    tf.summary.scalar("ese/%d" % i, ess / batch_size)
    tf.summary.scalar("weight_variance/%d" % i, variance)
def summarize_particles(states, weights, observation, model):
  """Plots particle locations and weights.

  Renders, per batch element, the q/posterior/prior densities at t=0 followed
  by a bar plot and location histogram of the particles at every timestep,
  and emits the figures as an image summary in the "infrequent_summaries"
  collection.

  Args:
    states: List of length num_timesteps Tensors of shape
      [batch_size*num_particles, state_size].
    weights: List of length num_timesteps Tensors of shape [num_samples,
      batch_size]
    observation: Tensor of shape [batch_size*num_samples, state_size]
    model: Model object exposing q.q_zt and p with mixing_coeff,
      prior_mode_mean, variance, bs, num_timesteps (assumes the bimodal
      prior model — TODO confirm callers only pass that variant).
  """
  num_timesteps = len(weights)
  num_samples, batch_size = weights[0].get_shape().as_list()
  # get q0 information for plotting
  q0_dist = model.q.q_zt(observation, tf.zeros_like(states[0]), 0)
  # Only the first state dimension of the first batch_size rows is plotted.
  q0_loc = q0_dist.loc[0:batch_size, 0]
  q0_scale = q0_dist.scale[0:batch_size, 0]
  # get posterior information for plotting
  post = (model.p.mixing_coeff, model.p.prior_mode_mean, model.p.variance,
          tf.reduce_sum(model.p.bs), model.p.num_timesteps)
  # Reshape states and weights to be [time, num_samples, batch_size]
  states = tf.stack(states)
  weights = tf.stack(weights)
  # normalize the weights over the sample dimension
  weights = tf.nn.softmax(weights, dim=1)
  states = tf.reshape(states, tf.shape(weights))
  # Effective sample size per (timestep, batch element).
  ess = 1. / tf.reduce_sum(tf.square(weights), axis=1)

  def _plot_states(states_batch, weights_batch, observation_batch, ess_batch,
                   q0, post):
    """Numpy-side renderer invoked via tf.py_func.

    Args:
      states_batch: [time, num_samples, batch_size]
      weights_batch: [time, num_samples, batch_size]
      observation_batch: [batch_size, 1]
      ess_batch: [time, batch_size]
      q0: (loc, scale) arrays, each [batch_size]
      post: (mixing_coeff, prior_mode_mean, variance, sum_of_bs,
        num_timesteps) describing the bimodal prior/posterior.

    Returns:
      uint8 array of stacked RGB figures, one per batch element.
    """
    num_timesteps, _, batch_size = states_batch.shape
    plots = []
    for i in range(batch_size):
      # Slice out one batch element.
      states = states_batch[:, :, i]
      weights = weights_batch[:, :, i]
      observation = observation_batch[i]
      ess = ess_batch[:, i]
      q0_loc = q0[0][i]
      q0_scale = q0[1][i]
      fig = plt.figure(figsize=(7, (num_timesteps + 1) * 2))
      # Each timestep gets two plots -- a bar plot and a histogram of state locs.
      # The bar plot will be bar_rows rows tall.
      # The histogram will be 1 row tall.
      # There is also 1 extra plot at the top showing the posterior and q.
      bar_rows = 8
      num_rows = (num_timesteps + 1) * (bar_rows + 1)
      gs = gridspec.GridSpec(num_rows, 1)
      # Figure out how wide to make the plot
      prior_lims = (post[1] * -2, post[1] * 2)
      # Central 98% of the q0 density.
      q_lims = (scipy.stats.norm.ppf(0.01, loc=q0_loc, scale=q0_scale),
                scipy.stats.norm.ppf(0.99, loc=q0_loc, scale=q0_scale))
      state_width = states.max() - states.min()
      state_lims = (states.min() - state_width * 0.15,
                    states.max() + state_width * 0.15)
      lims = (min(prior_lims[0], q_lims[0], state_lims[0]),
              max(prior_lims[1], q_lims[1], state_lims[1]))
      # plot the posterior
      z0 = np.arange(lims[0], lims[1], 0.1)
      alpha, pos_mu, sigma_sq, B, T = post
      neg_mu = -pos_mu
      scale = np.sqrt((T + 1) * sigma_sq)
      # Marginal likelihood of the observation under the two-mode mixture.
      p_zn = (
          alpha * scipy.stats.norm.pdf(observation, loc=pos_mu + B,
                                       scale=scale) +
          (1 - alpha) * scipy.stats.norm.pdf(observation, loc=neg_mu + B,
                                             scale=scale))
      # Bimodal prior density over z0.
      p_z0 = (
          alpha * scipy.stats.norm.pdf(z0, loc=pos_mu,
                                       scale=np.sqrt(sigma_sq)) +
          (1 - alpha) * scipy.stats.norm.pdf(z0, loc=neg_mu,
                                             scale=np.sqrt(sigma_sq)))
      p_zn_given_z0 = scipy.stats.norm.pdf(observation, loc=z0 + B,
                                           scale=np.sqrt(T * sigma_sq))
      # Bayes rule: posterior over z0 given the final observation.
      post_z0 = (p_z0 * p_zn_given_z0) / p_zn
      # plot q
      q_z0 = scipy.stats.norm.pdf(z0, loc=q0_loc, scale=q0_scale)
      ax = plt.subplot(gs[0:bar_rows, :])
      ax.plot(z0, q_z0, color="blue")
      ax.plot(z0, post_z0, color="green")
      ax.plot(z0, p_z0, color="red")
      ax.legend(("q", "posterior", "prior"), loc="best", prop={"size": 10})
      ax.set_xticks([])
      ax.set_xlim(*lims)
      # plot the states
      for t in range(num_timesteps):
        start = (t + 1) * (bar_rows + 1)
        ax1 = plt.subplot(gs[start:start + bar_rows, :])
        ax2 = plt.subplot(gs[start + bar_rows:start + bar_rows + 1, :])
        # plot the states barplot
        # ax1.hist(
        #     states[t, :],
        #     weights=weights[t, :],
        #     bins=50,
        #     edgecolor="none",
        #     alpha=0.2)
        ax1.bar(states[t, :], weights[t, :], width=0.02, alpha=0.2,
                edgecolor="none")
        ax1.set_ylabel("t=%d" % t)
        ax1.set_xticks([])
        ax1.grid(True, which="both")
        ax1.set_xlim(*lims)
        # plot the observation
        ax1.axvline(x=observation, color="red", linestyle="dashed")
        # add the ESS
        ax1.text(0.1, 0.9, "ESS: %0.2f" % ess[t],
                 ha='center', va='center', transform=ax1.transAxes)
        # plot the state location histogram
        ax2.hist2d(states[t, :], np.zeros_like(states[t, :]),
                   bins=[50, 1], cmap="Greys")
        ax2.grid(False)
        ax2.set_yticks([])
        ax2.set_xlim(*lims)
        if t != num_timesteps - 1:
          ax2.set_xticks([])
      # Rasterize the figure into an RGB array.
      # NOTE(review): np.fromstring is deprecated in newer numpy; works on
      # the numpy versions contemporary with this code.
      fig.canvas.draw()
      p = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
      plots.append(p.reshape(fig.canvas.get_width_height()[::-1] + (3,)))
      plt.close(fig)
    return np.stack(plots)

  plots = tf.py_func(_plot_states,
                     [states, weights, observation, ess,
                      (q0_loc, q0_scale), post], [tf.uint8])[0]
  tf.summary.image("states", plots, 5, collections=["infrequent_summaries"])
def plot_weights(weights, resampled=None):
  """Plots the weights and effective sample size from an SMC rollout.

  Emits one stacked-area figure per batch element (normalized weights over
  time, with dashed vertical lines at resampling events) as an image summary
  in the "infrequent_summaries" collection.

  Args:
    weights: [num_timesteps, num_samples, batch_size] importance weights
    resampled: [num_timesteps] 0/1 indicating if resampling occurred
  """
  weights = tf.convert_to_tensor(weights)

  def _make_plots(weights, resampled):
    """Numpy-side renderer invoked via tf.py_func; returns stacked RGB images."""
    num_timesteps, num_samples, batch_size = weights.shape
    plots = []
    for i in range(batch_size):
      fig, axes = plt.subplots(nrows=1, sharex=True, figsize=(8, 4))
      # Normalized weights sum to 1, so the stacked areas fill [0, 1].
      axes.stackplot(np.arange(num_timesteps),
                     np.transpose(weights[:, :, i]))
      axes.set_title("Weights")
      axes.set_xlabel("Steps")
      axes.set_ylim([0, 1])
      axes.set_xlim([0, num_timesteps - 1])
      # Mark each timestep where resampling fired.
      for j in np.where(resampled > 0)[0]:
        axes.axvline(x=j, color="red", linestyle="dashed",
                     ymin=0.0, ymax=1.0)
      # Rasterize the figure into an RGB array.
      # NOTE(review): np.fromstring is deprecated in newer numpy; works on
      # the numpy versions contemporary with this code.
      fig.canvas.draw()
      data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
      data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
      plots.append(data)
      plt.close(fig)
    return np.stack(plots, axis=0)

  # Default: no resampling events.
  if resampled is None:
    num_timesteps, _, batch_size = weights.get_shape().as_list()
    resampled = tf.zeros([num_timesteps], dtype=tf.float32)
  plots = tf.py_func(_make_plots,
                     [tf.nn.softmax(weights, dim=1),
                      tf.to_float(resampled)], [tf.uint8])[0]
  batch_size = weights.get_shape().as_list()[-1]
  tf.summary.image("weights", plots, batch_size,
                   collections=["infrequent_summaries"])
def summarize_weights(weights, num_timesteps, num_samples):
  """Adds per-timestep scalar/histogram summaries of raw importance weights.

  Args:
    weights: [num_timesteps, num_samples, batch_size] weight tensor (or
      anything convertible to one).
    num_timesteps: Python int, number of timesteps to summarize.
    num_samples: Python int, sample count used as the variance normalizer.
  """
  weights = tf.convert_to_tensor(weights)
  # Unbiased per-batch-element variance over the sample axis, then averaged
  # across the batch.
  sample_mean = tf.reduce_mean(weights, axis=1, keepdims=True)
  per_element_var = tf.reduce_sum(
      tf.square(weights - sample_mean), axis=1) / (num_samples - 1)
  variances = tf.reduce_mean(per_element_var, axis=1)
  # Mean absolute weight per timestep.
  avg_magnitude = tf.reduce_mean(tf.abs(weights), axis=[1, 2])
  for t in xrange(num_timesteps):
    tf.summary.scalar("weights/variance_%d" % t, variances[t])
    tf.summary.scalar("weights/magnitude_%d" % t, avg_magnitude[t])
    tf.summary.histogram("weights/step_%d" % t, weights[t])
def summarize_learning_signal(rewards, tag):
  """Adds summaries of a learning signal at each resampling event.

  Args:
    rewards: [num_resampling_events, batch_size] Tensor of learning signals.
    tag: String prefix for the summary names.
  """
  num_resampling_events, _ = rewards.get_shape().as_list()
  # Reduce over the batch axis for each event.
  batch_mean = tf.reduce_mean(rewards, axis=1)
  batch_magnitude = tf.reduce_mean(tf.abs(rewards), axis=1)
  batch_square = tf.reduce_mean(tf.square(rewards), axis=1)
  for t in xrange(num_resampling_events):
    tf.summary.scalar("%s/mean_%d" % (tag, t), batch_mean[t])
    tf.summary.scalar("%s/magnitude_%d" % (tag, t), batch_magnitude[t])
    tf.summary.scalar("%s/squared_%d" % (tag, t), batch_square[t])
    tf.summary.histogram("%s/step_%d" % (tag, t), rewards[t])
def summarize_qs(model, observation, states):
  """Summarizes how closely q tracks the true posterior, when available.

  Emits KL(p || q) plus absolute/relative errors of q's mean and variance
  against model.p.posterior at every timestep. Skipped entirely when the
  model's p does not expose a callable posterior.

  Args:
    model: Model with q.q_zt and (optionally) p.posterior.
    observation: Observation tensor passed through to both distributions.
    states: List of per-timestep state Tensors.
  """
  model.q.summarize_weights()
  if not (hasattr(model.p, "posterior") and
          callable(getattr(model.p, "posterior"))):
    return
  # Shift states right by one so entry t holds the state preceding step t
  # (a zero state precedes the first step).
  prev_states = [tf.zeros_like(states[0])] + states[:-1]
  for t, prev_state in enumerate(prev_states):
    exact = model.p.posterior(observation, prev_state, t)
    approx = model.q.q_zt(observation, prev_state, t)
    kl = tf.reduce_mean(
        tf.contrib.distributions.kl_divergence(exact, approx))
    tf.summary.scalar("kl_q/%d" % t, tf.reduce_mean(kl))
    # Mean convergence.
    loc_gap = approx.loc - exact.loc
    tf.summary.scalar("q_mean_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(loc_gap)))
    tf.summary.scalar("q_mean_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(loc_gap / exact.loc)))
    # Variance convergence (compare squared scales).
    var_gap = tf.square(approx.scale) - tf.square(exact.scale)
    tf.summary.scalar("q_variance_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(var_gap)))
    tf.summary.scalar("q_variance_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(var_gap /
                                            tf.square(exact.scale))))
def summarize_rs(model, states):
  """Summarizes how closely the learned r matches the exact lookahead.

  Emits KL(true_r || r) plus absolute/relative errors of r's mean and
  variance against model.p.lookahead at every timestep.

  Args:
    model: Model with r.r_xn and p.lookahead.
    states: List of per-timestep state Tensors.
  """
  model.r.summarize_weights()
  for t, state in enumerate(states):
    exact = model.p.lookahead(state, t)
    learned = model.r.r_xn(state, t)
    kl = tf.reduce_mean(
        tf.contrib.distributions.kl_divergence(exact, learned))
    tf.summary.scalar("kl_r/%d" % t, tf.reduce_mean(kl))
    # Mean convergence.
    loc_gap = exact.loc - learned.loc
    tf.summary.scalar("r_mean_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(loc_gap)))
    tf.summary.scalar("r_mean_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(loc_gap / exact.loc)))
    # Variance convergence (compare squared scales).
    var_gap = tf.square(learned.scale) - tf.square(exact.scale)
    tf.summary.scalar("r_variance_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(var_gap)))
    tf.summary.scalar("r_variance_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(var_gap /
                                            tf.square(exact.scale))))
def summarize_model(model, true_bs, observation, states, bound,
                    summarize_r=True):
  """Summarizes the model's b parameters against the data-generating process.

  Args:
    model: Model whose p may expose learned per-timestep `bs`.
    true_bs: The b values used by the data-generating process.
    observation: Unused; kept for interface compatibility with the disabled
      diagnostics below.
    states: Unused; same as above.
    bound: Unused; same as above.
    summarize_r: Unused; same as above.
  """
  if hasattr(model.p, "bs"):
    # Compare the summed b offsets of the model and the true process.
    learned_b_sum = tf.reduce_sum(model.p.bs, axis=0)
    true_b_sum = tf.reduce_sum(true_bs, axis=0)
    absolute_error = tf.abs(learned_b_sum - true_b_sum)
    relative_error = absolute_error / true_b_sum
    tf.summary.scalar("sum_of_bs/data_generating_process",
                      tf.reduce_mean(true_b_sum))
    tf.summary.scalar("sum_of_bs/model", tf.reduce_mean(learned_b_sum))
    tf.summary.scalar("sum_of_bs/absolute_error",
                      tf.reduce_mean(absolute_error))
    tf.summary.scalar("sum_of_bs/relative_error",
                      tf.reduce_mean(relative_error))
  # Diagnostics intentionally disabled in the original, kept for reference:
  # summarize_qs(model, observation, states)
  # if bound == "fivo-aux" and summarize_r:
  #   summarize_rs(model, states)
def summarize_grads(grads, loss_name):
  """Tracks running mean/variance of flattened gradients for one loss.

  Args:
    grads: List of (gradient, variable) pairs; None gradients are skipped.
    loss_name: String used to namespace the summaries.

  Returns:
    The op that updates the exponential moving averages; callers must run it
    for the summaries to stay current.
  """
  ema = tf.train.ExponentialMovingAverage(decay=0.99)
  # Flatten every present gradient into one long vector.
  flat = tf.concat(
      [tf.reshape(g, [-1]) for g, _ in grads if g is not None], axis=0)
  first_moment_sample = flat
  second_moment_sample = tf.square(flat)
  update_ema_op = ema.apply([first_moment_sample, second_moment_sample])
  ema_first = ema.average(first_moment_sample)
  ema_second = ema.average(second_moment_sample)
  # Var[g] = E[g^2] - E[g]^2, per coordinate.
  per_coord_variance = ema_second - tf.square(ema_first)
  tf.summary.scalar("grad_variance/%s" % loss_name,
                    tf.reduce_mean(per_coord_variance))
  tf.summary.histogram("grad_variance/%s" % loss_name, per_coord_variance)
  tf.summary.histogram("grad_mean/%s" % loss_name, ema_first)
  return update_ema_op
research/fivo/experimental/train.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main script for running fivo"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
defaultdict
import
numpy
as
np
import
tensorflow
as
tf
import
bounds
import
data
import
models
import
summary_utils
as
summ
tf.logging.set_verbosity(tf.logging.INFO)

# ---- Reproducibility / model-selection flags ----
tf.app.flags.DEFINE_integer("random_seed", None,
                            "A random seed for the data generating process. Same seed "
                            "-> same data generating process and initialization.")
tf.app.flags.DEFINE_enum("bound", "fivo",
                         ["iwae", "fivo", "fivo-aux", "fivo-aux-td"],
                         "The bound to optimize.")
tf.app.flags.DEFINE_enum("model", "forward", ["forward", "long_chain"],
                         "The model to use.")
tf.app.flags.DEFINE_enum("q_type", "normal",
                         ["normal", "simple_mean", "prev_state", "observation"],
                         "The parameterization to use for q")
tf.app.flags.DEFINE_enum("p_type", "unimodal",
                         ["unimodal", "bimodal", "nonlinear"],
                         "The type of prior.")
tf.app.flags.DEFINE_boolean("train_p", True,
                            "If false, do not train the model p.")

# ---- Data-generating-process / model-shape flags ----
tf.app.flags.DEFINE_integer("state_size", 1,
                            "The dimensionality of the state space.")
tf.app.flags.DEFINE_float("variance", 1.0,
                          "The variance of the data generating process.")
tf.app.flags.DEFINE_boolean("use_bs", True,
                            "If False, initialize all bs to 0.")
tf.app.flags.DEFINE_float("bimodal_prior_weight", 0.5,
                          "The weight assigned to the positive mode of the prior in "
                          "both the data generating process and p.")
tf.app.flags.DEFINE_float("bimodal_prior_mean", None,
                          "If supplied, sets the mean of the 2 modes of the prior to "
                          "be 1 and -1 times the supplied value. This is for both the "
                          "data generating process and p.")
tf.app.flags.DEFINE_float("fixed_observation", None,
                          "If supplied, fix the observation to a constant value in the"
                          " data generating process only.")
tf.app.flags.DEFINE_float("r_sigma_init", 1.,
                          "Value to initialize variance of r to.")
tf.app.flags.DEFINE_enum("observation_type",
                         models.STANDARD_OBSERVATION, models.OBSERVATION_TYPES,
                         "The type of observation for the long chain model.")
tf.app.flags.DEFINE_enum("transition_type",
                         models.STANDARD_TRANSITION, models.TRANSITION_TYPES,
                         "The type of transition for the long chain model.")
tf.app.flags.DEFINE_float("observation_variance", None,
                          "The variance of the observation. Defaults to 'variance'")
tf.app.flags.DEFINE_integer("num_timesteps", 5,
                            "Number of timesteps in the sequence.")
tf.app.flags.DEFINE_integer("num_observations", 1,
                            "The number of observations.")
tf.app.flags.DEFINE_integer("steps_per_observation", 5,
                            "The number of timesteps between each observation.")

# ---- Batching / sampling flags ----
tf.app.flags.DEFINE_integer("batch_size", 4,
                            "The number of examples per batch.")
tf.app.flags.DEFINE_integer("num_samples", 4,
                            "The number particles to use.")
tf.app.flags.DEFINE_integer("num_eval_samples", 512,
                            "The batch size and # of particles to use for eval.")

# ---- Resampling flags ----
tf.app.flags.DEFINE_string("resampling", "always",
                           "How to resample. Accepts 'always','never', or a "
                           "comma-separated list of booleans like 'true,true,false'.")
tf.app.flags.DEFINE_enum("resampling_method", "multinomial",
                         ["multinomial", "stratified", "systematic",
                          "relaxed-logblend", "relaxed-stateblend",
                          "relaxed-linearblend", "relaxed-stateblend-st",],
                         "Type of resampling method to use.")
tf.app.flags.DEFINE_boolean("use_resampling_grads", True,
                            "Whether or not to use resampling grads to optimize FIVO."
                            "Disabled automatically if resampling_method=relaxed.")
tf.app.flags.DEFINE_boolean("disable_r", False,
                            "If false, r is not used for fivo-aux and is set to zeros.")

# ---- Optimization / logging flags ----
tf.app.flags.DEFINE_float("learning_rate", 1e-4,
                          "The learning rate to use for ADAM or SGD.")
tf.app.flags.DEFINE_integer("decay_steps", 25000,
                            "The number of steps before the learning rate is halved.")
tf.app.flags.DEFINE_integer("max_steps", int(1e6),
                            "The number of steps to run training for.")
tf.app.flags.DEFINE_string("logdir", "/tmp/fivo-aux",
                           "Directory for summaries and checkpoints.")
tf.app.flags.DEFINE_integer("summarize_every", int(1e3),
                            "The number of steps between each evaluation.")

FLAGS = tf.app.flags.FLAGS
def combine_grad_lists(grad_lists):
  """Sums gradients per variable across several (grad, var) lists.

  grad_lists is num_losses x num_variables; each inner list may cover a
  different variable subset. For each variable the gradients contributed by
  all losses are summed; variables with no gradient get None.

  Args:
    grad_lists: List of lists of (gradient, variable) pairs as returned by
      Optimizer.compute_gradients.

  Returns:
    A single list of (gradient, variable) pairs suitable for
    Optimizer.apply_gradients.
  """
  grads_dict = defaultdict(list)
  var_dict = {}
  for grad_list in grad_lists:
    for grad, var in grad_list:
      if grad is not None:
        grads_dict[var.name].append(grad)
      var_dict[var.name] = var

  final_grads = []
  # .items() works on both Python 2 and 3; the original used the
  # Python-2-only .iteritems().
  for var_name, var in var_dict.items():
    grads = grads_dict[var_name]
    if grads:
      tf.logging.info("Var %s has combined grads from %s." %
                      (var_name, [g.name for g in grads]))
      # Stack-and-sum over the loss axis to combine contributions.
      grad = tf.reduce_sum(grads, axis=0)
    else:
      tf.logging.info("Var %s has no grads" % var_name)
      grad = None
    final_grads.append((grad, var))
  return final_grads
def make_apply_grads_op(losses, global_step, learning_rate, lr_decay_steps):
  """Builds the training op: per-loss gradients, combined and applied.

  Each loss is differentiated only w.r.t. its own variable collection; the
  per-variable gradients are then summed across losses and applied with a
  single Adam optimizer under an exponentially decaying learning rate.
  Gradient-EMA summary updates run after the apply step.

  Args:
    losses: Iterable of bounds.Loss namedtuples
      (loss_name, loss, loss_var_collection).
    global_step: Step variable incremented by the apply op.
    learning_rate: Initial learning rate.
    lr_decay_steps: Steps over which the learning rate halves.

  Returns:
    The train op (gradient application followed by EMA updates).
  """
  for l in losses:
    assert isinstance(l, bounds.Loss)
  lr = tf.train.exponential_decay(
      learning_rate, global_step, lr_decay_steps, 0.5, staircase=False)
  tf.summary.scalar("learning_rate", lr)
  opt = tf.train.AdamOptimizer(lr)

  ema_ops = []
  per_loss_grads = []
  for loss_name, loss, loss_var_collection in losses:
    tf.logging.info("Computing grads of %s w.r.t. vars in collection %s" %
                    (loss_name, loss_var_collection))
    grad_list = opt.compute_gradients(
        loss, var_list=tf.get_collection(loss_var_collection))
    ema_ops.append(summ.summarize_grads(grad_list, loss_name))
    per_loss_grads.append(grad_list)

  combined = combine_grad_lists(per_loss_grads)
  apply_grads_op = opt.apply_gradients(combined, global_step=global_step)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads_op]):
    train_op = tf.group(*ema_ops)
  return train_op
def add_check_numerics_ops():
  """Adds chained check_numerics ops over every float tensor in the graph.

  Tensors whose op names contain known-benign substrings (ops that can
  legitimately produce inf/-inf, e.g. log-prob internals) are skipped.

  Returns:
    A grouped op that runs all the numeric checks.

  Raises:
    ValueError: If a checked op lives inside a control-flow context, which
      check_numerics chaining cannot handle.
  """
  # Op-name substrings excluded from checking.
  skip_substrings = [
      "logits/Log", "sample/Reshape", "log_prob/mul",
      "log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape",
      "entropy/Reshape", "entropy/LogSoftmax", "Categorical", "Mean"
  ]
  check_op = []
  float_dtypes = [tf.float16, tf.float32, tf.float64]
  for op in tf.get_default_graph().get_operations():
    if any(s in op.name for s in skip_substrings):
      continue
    for output in op.outputs:
      if output.dtype not in float_dtypes:
        continue
      if op._get_control_flow_context() is not None:  # pylint: disable=protected-access
        raise ValueError("`tf.add_check_numerics_ops() is not compatible "
                         "with TensorFlow control flow operations such as "
                         "`tf.cond()` or `tf.while_loop()`.")
      message = op.name + ":" + str(output.value_index)
      # Chain each check on the previous one so they execute serially.
      with tf.control_dependencies(check_op):
        check_op = [tf.check_numerics(output, message=message)]
  return tf.group(*check_op)
def create_long_chain_graph(bound,
                            state_size,
                            num_obs,
                            steps_per_obs,
                            batch_size,
                            num_samples,
                            num_eval_samples,
                            resampling_schedule,
                            use_resampling_grads,
                            learning_rate,
                            lr_decay_steps,
                            dtype="float64"):
  """Builds the train/eval graph for the long-chain model.

  Args:
    bound: The bound to train with, "iwae", "fivo", or "fivo-aux".
    state_size: Dimension of the latent state.
    num_obs: Number of observations in each sequence.
    steps_per_obs: Latent steps between consecutive observations.
    batch_size: Training batch size.
    num_samples: Number of particles/samples for the training bound.
    num_eval_samples: Number of particles/samples for the eval bound.
    resampling_schedule: Per-timestep booleans controlling resampling (FIVO).
    use_resampling_grads: Whether to use resampling gradients (FIVO).
    learning_rate: Initial learning rate.
    lr_decay_steps: Steps over which the learning rate halves.
    dtype: String dtype of the data, e.g. "float64".

  Returns:
    A (global_step, train_op, eval_log_p_hat, eval_likelihood) tuple. The
    likelihood is a zero scalar because it is intractable for these models.

  Raises:
    ValueError: If bound is not "iwae" or a FIVO variant.
  """
  num_timesteps = num_obs * steps_per_obs + 1
  # Make the dataset.
  dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Make the dataset for eval
  eval_dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()
  # Make the model.
  model = models.LongChainModel.create(
      state_size,
      num_obs,
      steps_per_obs,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=tf.as_dtype(dtype),
      disable_r=FLAGS.disable_r)
  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  # BUGFIX: previously `elif bound == "fivo" or "fivo-aux":` — the literal
  # "fivo-aux" is always truthy, so the condition was unconditionally True.
  # Match the substring check used in create_graph instead.
  elif "fivo" in bound:
    (_, losses, ema_op, _, _) = bounds.fivo(
        model,
        observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=use_resampling_grads,
        resampling_type=FLAGS.resampling_method,
        aux=("aux" in bound),
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
        model,
        eval_observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=False,
        resampling_type="multinomial",
        aux=("aux" in bound),
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  else:
    raise ValueError("Unknown bound %s" % bound)
  summ.summarize_ess(eval_log_weights, only_last_timestep=True)
  tf.summary.scalar("log_p_hat", eval_log_p_hat)
  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
  # We can't calculate the likelihood for most of these models
  # so we just return zeros.
  eval_likelihood = tf.zeros([], dtype=dtype)
  return global_step, train_op, eval_log_p_hat, eval_likelihood
def create_graph(bound,
                 state_size,
                 num_timesteps,
                 batch_size,
                 num_samples,
                 num_eval_samples,
                 resampling_schedule,
                 use_resampling_grads,
                 learning_rate,
                 lr_decay_steps,
                 train_p,
                 dtype='float64'):
  """Builds the train/eval graph for the standard (non-long-chain) models.

  Creates train and eval datasets, a Model (or TDModel for "fivo-aux-td"),
  computes the requested bound on both, wires summaries, and builds the
  training op.

  Args:
    bound: The bound to train with: "iwae" or any FIVO variant
      ("fivo", "fivo-aux", "fivo-aux-td").
    state_size: Dimension of the latent state.
    num_timesteps: Number of timesteps in each sequence.
    batch_size: Training batch size.
    num_samples: Number of particles/samples for the training bound.
    num_eval_samples: Number of particles/samples for the eval bound.
    resampling_schedule: Per-timestep booleans controlling resampling (FIVO).
    use_resampling_grads: Whether to use resampling gradients (FIVO).
    learning_rate: Initial learning rate.
    lr_decay_steps: Steps over which the learning rate halves.
    train_p: Whether the model distribution p is trained.
    dtype: String dtype of the data, e.g. 'float64'.

  Returns:
    A (global_step, train_op, eval_log_p_hat, eval_likelihood) tuple.
  """
  # When use_bs is set the dataset supplies the bs; otherwise force zeros.
  if FLAGS.use_bs:
    true_bs = None
  else:
    true_bs = [np.zeros([state_size]).astype(dtype)
               for _ in xrange(num_timesteps)]
  # Make the dataset.
  true_bs, dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Make the dataset for eval. Reuses the bs from the training dataset so
  # train and eval share the same ground-truth parameters.
  _, eval_dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=num_eval_samples,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()
  # Make the model. TD training uses a dedicated model class; note that only
  # the non-TD Model accepts r_sigma_init.
  if bound == "fivo-aux-td":
    model = models.TDModel.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)
  else:
    model = models.Model.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        r_sigma_init=FLAGS.r_sigma_init,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)
  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  elif "fivo" in bound:
    if bound == "fivo-aux-td":
      (_, losses, ema_op, _, _) = bounds.fivo_aux_td(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo_aux_td(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_eval_samples,
          summarize=True)
    else:
      (_, losses, ema_op, _, _) = bounds.fivo(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=use_resampling_grads,
          resampling_type=FLAGS.resampling_method,
          aux=("aux" in bound),
          num_samples=num_samples)
      # Eval always uses plain multinomial resampling with no resampling grads.
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=False,
          resampling_type="multinomial",
          aux=("aux" in bound),
          num_samples=num_eval_samples,
          summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  summ.summarize_ess(eval_log_weights, only_last_timestep=True)
  # if FLAGS.p_type == "bimodal":
  #   # create the observations that showcase the model.
  #   mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
  #                                          dtype=tf.float64)
  #   mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
  #   k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
  #   explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
  #   explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
  #   # run the model on the explainable observations
  #   if bound == "iwae":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         num_samples=num_eval_samples)
  #   elif bound == "fivo" or "fivo-aux":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         resampling_schedule=resampling_schedule,
  #         use_resampling_grads=False,
  #         resampling_type="multinomial",
  #         aux=("aux" in bound),
  #         num_samples=num_eval_samples)
  #   summ.summarize_particles(explain_states,
  #                            explain_log_weights,
  #                            explain_obs,
  #                            model)
  # Calculate the true likelihood.
  # NOTE(review): divides by FLAGS.num_timesteps rather than the
  # num_timesteps parameter — presumably identical in practice since main()
  # passes FLAGS.num_timesteps in; confirm before changing either.
  if hasattr(model.p, 'likelihood') and callable(getattr(model.p, 'likelihood')):
    eval_likelihood = model.p.likelihood(eval_observations) / FLAGS.num_timesteps
  else:
    eval_likelihood = tf.zeros_like(eval_log_p_hat)
  tf.summary.scalar("log_p_hat", eval_log_p_hat)
  tf.summary.scalar("likelihood", eval_likelihood)
  tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
  # r summaries are skipped for TD training, which has no r to summarize.
  summ.summarize_model(model, true_bs, eval_observations, eval_states, bound,
                       summarize_r=not bound == "fivo-aux-td")
  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
    #train_op = tf.group(ema_op, add_check_numerics_ops())
  return global_step, train_op, eval_log_p_hat, eval_likelihood
def parse_resampling_schedule(schedule, num_timesteps):
  """Parses a resampling-schedule string into per-timestep booleans.

  Supported formats (case-insensitive, surrounding whitespace ignored):
    "always":  resample at every timestep except the last.
    "never":   never resample.
    "every_n": resample on every n-th timestep (timesteps n-1, 2n-1, ...).
    "a,b,...": explicit comma-separated booleans, one per timestep, where
               exactly the entries equal to "true" resample.

  Args:
    schedule: The schedule specification string.
    num_timesteps: The number of timesteps the schedule must cover.

  Returns:
    A list of num_timesteps booleans, True where resampling should occur.

  Raises:
    AssertionError: If an explicit schedule has the wrong length.
  """
  schedule = schedule.strip().lower()
  if schedule == "always":
    # Never resample on the final step — there is nothing after it.
    return [True] * (num_timesteps - 1) + [False]
  elif schedule == "never":
    return [False] * num_timesteps
  elif "every" in schedule:
    n = int(schedule.split("_")[1])
    # range instead of the Python-2-only xrange: works on both Python 2 and 3.
    return [(i + 1) % n == 0 for i in range(num_timesteps)]
  else:
    sched = [x.strip() == "true" for x in schedule.split(",")]
    assert len(sched) == num_timesteps, "Wrong number of timesteps in resampling schedule."
    return sched
def create_log_hook(step, eval_log_p_hat, eval_likelihood):
  """Creates a hook that logs eval metrics every FLAGS.summarize_every steps.

  Args:
    step: The global step tensor.
    eval_log_p_hat: Scalar tensor with the eval bound value.
    eval_likelihood: Scalar tensor with the eval true likelihood.

  Returns:
    A tf.train.LoggingTensorHook that prints a single formatted line.
  """
  def format_log_line(values):
    return ("Step {step}, log p_hat: {log_p_hat:.5f} likelihood: {likelihood:.5f}"
            .format(**values))
  tensors_to_log = {
      "step": step,
      "log_p_hat": eval_log_p_hat,
      "likelihood": eval_likelihood,
  }
  return tf.train.LoggingTensorHook(
      tensors_to_log,
      every_n_iter=FLAGS.summarize_every,
      formatter=format_log_line)
def create_infrequent_summary_hook():
  """Creates a hook that saves the 'infrequent_summaries' collection.

  The merged summaries are written to FLAGS.logdir every 10000 steps,
  independently of the regular summary cadence.

  Returns:
    A tf.train.SummarySaverHook for the infrequent summaries.
  """
  merged_infrequent = tf.summary.merge_all(key="infrequent_summaries")
  return tf.train.SummarySaverHook(
      save_steps=10000,
      output_dir=FLAGS.logdir,
      summary_op=merged_infrequent)
def main(unused_argv):
  """Entry point: validates flags, builds the graph, and runs training.

  Parses the resampling schedule, seeds TF and numpy, builds either the
  long-chain graph or the standard graph from FLAGS, then trains inside a
  MonitoredTrainingSession until FLAGS.max_steps.
  """
  # The long-chain model has one extra latent timestep (see
  # create_long_chain_graph: num_timesteps = num_obs * steps_per_obs + 1).
  if FLAGS.model == "long_chain":
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps + 1)
  else:
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps)
  # Pick a random seed when none is supplied so runs are still reproducible
  # from the logged value.
  if FLAGS.random_seed is None:
    seed = np.random.randint(0, high=10000)
  else:
    seed = FLAGS.random_seed
  tf.logging.info("Using random seed %d", seed)
  # Flag-compatibility checks for the long-chain model.
  if FLAGS.model == "long_chain":
    assert FLAGS.q_type == "normal", "Q type %s not supported for long chain models" % FLAGS.q_type
    assert FLAGS.p_type == "unimodal", "Bimodal priors are not supported for long chain models"
    assert not FLAGS.use_bs, "Bs are not supported with long chain models"
    assert FLAGS.num_timesteps == FLAGS.num_observations * FLAGS.steps_per_observation, "Num timesteps does not match."
    assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with long chain models."
  # Flag-compatibility checks for the forward model.
  if FLAGS.model == "forward":
    if "nonlinear" not in FLAGS.p_type:
      assert FLAGS.transition_type == models.STANDARD_TRANSITION, "Non-standard transitions not supported by the forward model."
    assert FLAGS.observation_type == models.STANDARD_OBSERVATION, "Non-standard observations not supported by the forward model."
    assert FLAGS.observation_variance is None, "Forward model does not support observation variance."
    assert FLAGS.num_observations == 1, "Forward model only supports 1 observation."
  # Relaxed resampling is differentiable, so resampling grads are disabled.
  if "relaxed" in FLAGS.resampling_method:
    FLAGS.use_resampling_grads = False
    assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with relaxed resampling."
  # Default the observation variance to the transition variance.
  if FLAGS.observation_variance is None:
    FLAGS.observation_variance = FLAGS.variance
  if FLAGS.p_type == "bimodal":
    assert FLAGS.bimodal_prior_mean is not None, "Must specify prior mean if using bimodal p."
  if FLAGS.p_type == "nonlinear" or FLAGS.p_type == "nonlinear-cauchy":
    assert not FLAGS.use_bs, "Using bs is not compatible with the nonlinear model."
  g = tf.Graph()
  with g.as_default():
    # Set the seeds.
    tf.set_random_seed(seed)
    np.random.seed(seed)
    # Build the training graph for the selected model family.
    if FLAGS.model == "long_chain":
      (global_step, train_op,
       eval_log_p_hat, eval_likelihood) = create_long_chain_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_observations,
           FLAGS.steps_per_observation,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps)
    else:
      (global_step, train_op,
       eval_log_p_hat, eval_likelihood) = create_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_timesteps,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps,
           FLAGS.train_p)
    log_hooks = [create_log_hook(global_step, eval_log_p_hat, eval_likelihood)]
    # Only attach the infrequent-summary hook when such summaries exist.
    if len(tf.get_collection("infrequent_summaries")) > 0:
      log_hooks.append(create_infrequent_summary_hook())
    # Log the variable partitioning for debugging.
    tf.logging.info("trainable variables:")
    tf.logging.info([v.name for v in tf.trainable_variables()])
    tf.logging.info("p vars:")
    tf.logging.info([v.name for v in tf.get_collection("P_VARS")])
    tf.logging.info("q vars:")
    tf.logging.info([v.name for v in tf.get_collection("Q_VARS")])
    tf.logging.info("r vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_VARS")])
    tf.logging.info("r tilde vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_TILDE_VARS")])
    # MonitoredTrainingSession handles checkpointing, summaries, and hooks.
    with tf.train.MonitoredTrainingSession(
        master="",
        is_chief=True,
        hooks=log_hooks,
        checkpoint_dir=FLAGS.logdir,
        save_checkpoint_secs=120,
        save_summaries_steps=FLAGS.summarize_every,
        log_step_count_steps=FLAGS.summarize_every) as sess:
      cur_step = -1
      while True:
        if sess.should_stop() or cur_step > FLAGS.max_steps:
          break
        # run a step
        _, cur_step = sess.run([train_op, global_step])
if __name__ == "__main__":
  # tf.app.run parses the command-line flags and then invokes main(argv).
  tf.app.run(main)
research/fivo/fivo/__init__.py
0 → 100644
View file @
27b4acd4
research/fivo/bounds.py
→
research/fivo/
fivo/
bounds.py
View file @
27b4acd4
# Copyright 201
7
The TensorFlow Authors All Rights Reserved.
# Copyright 201
8
The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -23,13 +23,15 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
functools
import
tensorflow
as
tf
import
nested_utils
as
nested
from
fivo
import
nested_utils
as
nested
from
fivo
import
smc
def
iwae
(
cel
l
,
input
s
,
def
iwae
(
mode
l
,
observation
s
,
seq_lengths
,
num_samples
=
1
,
parallel_iterations
=
30
,
...
...
@@ -45,13 +47,13 @@ def iwae(cell,
When num_samples = 1, this bound becomes the evidence lower bound (ELBO).
Args:
cel
l: A
callable that implements one timestep of the model. Se
e
models/vrnn.py for an example.
input
s: The inputs to the model. A potentially nested list or tuple of
mode
l: A
subclass of ELBOTrainableSequenceModel that implements on
e
timestep of the model. See
models/vrnn.py for an example.
observation
s: The inputs to the model. A potentially nested list or tuple of
Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
have a rank at least two and have matching shapes in the first two
dimensions, which represent time and the batch respectively.
At each
timestep 'cell'
will be
call
ed with
a slice of the Tensors in inputs
.
dimensions, which represent time and the batch respectively.
The model
will be
provid
ed with
the observations before computing the bound
.
seq_lengths: A [batch_size] Tensor of ints encoding the length of each
sequence in the batch (sequences can be padded to a common length).
num_samples: The number of samples to use.
...
...
@@ -63,98 +65,28 @@ def iwae(cell,
Returns:
log_p_hat: A Tensor of shape [batch_size] containing IWAE's estimate of the
log marginal probability of the observations.
kl: A Tensor of shape [batch_size] containing the kl divergence
from q(z|x) to p(z), averaged over samples.
log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
containing the log weights at each timestep. Will not be valid for
timesteps past the end of a sequence.
log_ess: A Tensor of shape [max_seq_len, batch_size] containing the log
effective sample size at each timestep. Will not be valid for timesteps
past the end of a sequence.
"""
batch_size
=
tf
.
shape
(
seq_lengths
)[
0
]
max_seq_len
=
tf
.
reduce_max
(
seq_lengths
)
seq_mask
=
tf
.
transpose
(
tf
.
sequence_mask
(
seq_lengths
,
maxlen
=
max_seq_len
,
dtype
=
tf
.
float32
),
perm
=
[
1
,
0
])
if
num_samples
>
1
:
inputs
,
seq_mask
=
nested
.
tile_tensors
([
inputs
,
seq_mask
],
[
1
,
num_samples
])
inputs_ta
,
mask_ta
=
nested
.
tas_for_tensors
([
inputs
,
seq_mask
],
max_seq_len
)
t0
=
tf
.
constant
(
0
,
tf
.
int32
)
init_states
=
cell
.
zero_state
(
batch_size
*
num_samples
,
tf
.
float32
)
ta_names
=
[
'log_weights'
,
'log_ess'
]
tas
=
[
tf
.
TensorArray
(
tf
.
float32
,
max_seq_len
,
name
=
'%s_ta'
%
n
)
for
n
in
ta_names
]
log_weights_acc
=
tf
.
zeros
([
num_samples
,
batch_size
],
dtype
=
tf
.
float32
)
kl_acc
=
tf
.
zeros
([
num_samples
*
batch_size
],
dtype
=
tf
.
float32
)
accs
=
(
log_weights_acc
,
kl_acc
)
def
while_predicate
(
t
,
*
unused_args
):
return
t
<
max_seq_len
def
while_step
(
t
,
rnn_state
,
tas
,
accs
):
"""Implements one timestep of IWAE computation."""
log_weights_acc
,
kl_acc
=
accs
cur_inputs
,
cur_mask
=
nested
.
read_tas
([
inputs_ta
,
mask_ta
],
t
)
# Run the cell for one step.
log_q_z
,
log_p_z
,
log_p_x_given_z
,
kl
,
new_state
=
cell
(
cur_inputs
,
rnn_state
,
cur_mask
,
)
# Compute the incremental weight and use it to update the current
# accumulated weight.
kl_acc
+=
kl
*
cur_mask
log_alpha
=
(
log_p_x_given_z
+
log_p_z
-
log_q_z
)
*
cur_mask
log_alpha
=
tf
.
reshape
(
log_alpha
,
[
num_samples
,
batch_size
])
log_weights_acc
+=
log_alpha
# Calculate the effective sample size.
ess_num
=
2
*
tf
.
reduce_logsumexp
(
log_weights_acc
,
axis
=
0
)
ess_denom
=
tf
.
reduce_logsumexp
(
2
*
log_weights_acc
,
axis
=
0
)
log_ess
=
ess_num
-
ess_denom
# Update the Tensorarrays and accumulators.
ta_updates
=
[
log_weights_acc
,
log_ess
]
new_tas
=
[
ta
.
write
(
t
,
x
)
for
ta
,
x
in
zip
(
tas
,
ta_updates
)]
new_accs
=
(
log_weights_acc
,
kl_acc
)
return
t
+
1
,
new_state
,
new_tas
,
new_accs
_
,
_
,
tas
,
accs
=
tf
.
while_loop
(
while_predicate
,
while_step
,
loop_vars
=
(
t0
,
init_states
,
tas
,
accs
),
log_p_hat
,
log_weights
,
_
,
final_state
=
fivo
(
model
,
observations
,
seq_lengths
,
num_samples
=
num_samples
,
resampling_criterion
=
smc
.
never_resample_criterion
,
parallel_iterations
=
parallel_iterations
,
swap_memory
=
swap_memory
)
log_weights
,
log_ess
=
[
x
.
stack
()
for
x
in
tas
]
final_log_weights
,
kl
=
accs
log_p_hat
=
(
tf
.
reduce_logsumexp
(
final_log_weights
,
axis
=
0
)
-
tf
.
log
(
tf
.
to_float
(
num_samples
)))
kl
=
tf
.
reduce_mean
(
tf
.
reshape
(
kl
,
[
num_samples
,
batch_size
]),
axis
=
0
)
log_weights
=
tf
.
transpose
(
log_weights
,
perm
=
[
0
,
2
,
1
])
return
log_p_hat
,
kl
,
log_weights
,
log_ess
return
log_p_hat
,
log_weights
,
final_state
def ess_criterion(num_samples, log_ess, unused_t):
  """A criterion that resamples based on effective sample size."""
  # Resample when the effective sample size falls to half the number of
  # particles or below; the comparison is done in log space.
  return log_ess <= tf.log(num_samples / 2.0)
def never_resample_criterion(unused_num_samples, log_ess, unused_t):
  """A criterion that never resamples."""
  # An all-False boolean Tensor with the same shape as log_ess.
  return tf.cast(tf.zeros_like(log_ess), tf.bool)
def always_resample_criterion(unused_num_samples, log_ess, unused_t):
  """A criterion that resamples at every timestep."""
  # An all-True boolean Tensor with the same shape as log_ess.
  return tf.cast(tf.ones_like(log_ess), tf.bool)
def
fivo
(
cell
,
inputs
,
def
fivo
(
model
,
observations
,
seq_lengths
,
num_samples
=
1
,
resampling_criterion
=
ess_criterion
,
resampling_criterion
=
smc
.
ess_criterion
,
resampling_type
=
'multinomial'
,
relaxed_resampling_temperature
=
0.5
,
parallel_iterations
=
30
,
swap_memory
=
True
,
random_seed
=
None
):
...
...
@@ -170,21 +102,26 @@ def fivo(cell,
When the resampling criterion is "never resample", this bound becomes IWAE.
Args:
cel
l: A
callable that implements one timestep of the model. Se
e
models/vrnn.py for an example.
input
s: The inputs to the model. A potentially nested list or tuple of
mode
l: A
subclass of ELBOTrainableSequenceModel that implements on
e
timestep of the model. See
models/vrnn.py for an example.
observation
s: The inputs to the model. A potentially nested list or tuple of
Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
have a rank at least two and have matching shapes in the first two
dimensions, which represent time and the batch respectively.
At each
timestep 'cell'
will be
call
ed with
a slice of the Tensors in inputs
.
dimensions, which represent time and the batch respectively.
The model
will be
provid
ed with
the observations before computing the bound
.
seq_lengths: A [batch_size] Tensor of ints encoding the length of each
sequence in the batch (sequences can be padded to a common length).
num_samples: The number of particles to use in each particle filter.
resampling_criterion: The resampling criterion to use for this particle
filter. Must accept the number of samples, the
effective sample size
,
filter. Must accept the number of samples, the
current log weights
,
and the current timestep and return a boolean Tensor of shape [batch_size]
indicating whether each particle filter should resample. See
ess_criterion and related functions defined in this file for examples.
ess_criterion and related functions for examples. When
resampling_criterion is never_resample_criterion, resampling_fn is ignored
and never called.
resampling_type: The type of resampling, one of "multinomial" or "relaxed".
relaxed_resampling_temperature: A positive temperature only used for relaxed
resampling.
parallel_iterations: The number of parallel iterations to use for the
internal while loop. Note that values greater than 1 can introduce
non-determinism even when random_seed is provided.
...
...
@@ -196,28 +133,17 @@ def fivo(cell,
Returns:
log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of the
log marginal probability of the observations.
kl: A Tensor of shape [batch_size] containing the sum over time of the kl
divergence from q_t(z_t|x) to p_t(z_t), averaged over particles. Note that
this includes kl terms from trajectories that are culled during resampling
steps.
log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
containing the log weights at each timestep of the particle filter. Note
that on timesteps when a resampling operation is performed the log weights
are reset to 0. Will not be valid for timesteps past the end of a
sequence.
log_ess: A Tensor of shape [max_seq_len, batch_size] containing the log
effective sample size of each particle filter at each timestep. Will not
be valid for timesteps past the end of a sequence.
resampled: A Tensor of shape [max_seq_len, batch_size] indicating when the
particle filters resampled. Will be 1.0 on timesteps when resampling
occurred and 0.0 on timesteps when it did not.
"""
# batch_size
represent
s the number of particle filters running in parallel.
# batch_size
i
s the number of particle filters running in parallel.
batch_size
=
tf
.
shape
(
seq_lengths
)[
0
]
max_seq_len
=
tf
.
reduce_max
(
seq_lengths
)
seq_mask
=
tf
.
transpose
(
tf
.
sequence_mask
(
seq_lengths
,
maxlen
=
max_seq_len
,
dtype
=
tf
.
float32
),
perm
=
[
1
,
0
])
# Each sequence in the batch will be the input data for a different
# particle filter. The batch will be laid out as:
...
...
@@ -228,96 +154,164 @@ def fivo(cell,
# particle 2 of particle filter 1
# ...
# particle num_samples of particle filter batch_size
if
num_samples
>
1
:
inputs
,
seq_mask
=
nested
.
tile_tensors
([
inputs
,
seq_mask
],
[
1
,
num_samples
])
inputs_ta
,
mask_ta
=
nested
.
tas_for_tensors
([
inputs
,
seq_mask
],
max_seq_len
)
t0
=
tf
.
constant
(
0
,
tf
.
int32
)
init_states
=
cell
.
zero_state
(
batch_size
*
num_samples
,
tf
.
float32
)
ta_names
=
[
'log_weights'
,
'log_ess'
,
'resampled'
]
tas
=
[
tf
.
TensorArray
(
tf
.
float32
,
max_seq_len
,
name
=
'%s_ta'
%
n
)
for
n
in
ta_names
]
log_weights_acc
=
tf
.
zeros
([
num_samples
,
batch_size
],
dtype
=
tf
.
float32
)
log_p_hat_acc
=
tf
.
zeros
([
batch_size
],
dtype
=
tf
.
float32
)
kl_acc
=
tf
.
zeros
([
num_samples
*
batch_size
],
dtype
=
tf
.
float32
)
accs
=
(
log_weights_acc
,
log_p_hat_acc
,
kl_acc
)
observations
=
nested
.
tile_tensors
(
observations
,
[
1
,
num_samples
])
tiled_seq_lengths
=
tf
.
tile
(
seq_lengths
,
[
num_samples
])
model
.
set_observations
(
observations
,
tiled_seq_lengths
)
if
resampling_type
==
'multinomial'
:
resampling_fn
=
smc
.
multinomial_resampling
elif
resampling_type
==
'relaxed'
:
resampling_fn
=
functools
.
partial
(
smc
.
relaxed_resampling
,
temperature
=
relaxed_resampling_temperature
)
resampling_fn
=
functools
.
partial
(
resampling_fn
,
random_seed
=
random_seed
)
def
transition_fn
(
prev_state
,
t
):
if
prev_state
is
None
:
return
model
.
zero_state
(
batch_size
*
num_samples
,
tf
.
float32
)
return
model
.
propose_and_weight
(
prev_state
,
t
)
log_p_hat
,
log_weights
,
resampled
,
final_state
,
_
=
smc
.
smc
(
transition_fn
,
seq_lengths
,
num_particles
=
num_samples
,
resampling_criterion
=
resampling_criterion
,
resampling_fn
=
resampling_fn
,
parallel_iterations
=
parallel_iterations
,
swap_memory
=
swap_memory
)
def
while_predicate
(
t
,
*
unused_args
):
return
t
<
max_seq_len
return
log_p_hat
,
log_weights
,
resampled
,
final_state
def
fivo_aux_td
(
model
,
observations
,
seq_lengths
,
num_samples
=
1
,
resampling_criterion
=
smc
.
ess_criterion
,
resampling_type
=
'multinomial'
,
relaxed_resampling_temperature
=
0.5
,
parallel_iterations
=
30
,
swap_memory
=
True
,
random_seed
=
None
):
"""Experimental."""
# batch_size is the number of particle filters running in parallel.
batch_size
=
tf
.
shape
(
seq_lengths
)[
0
]
max_seq_len
=
tf
.
reduce_max
(
seq_lengths
)
def
while_step
(
t
,
rnn_state
,
tas
,
accs
):
"""Implements one timestep of FIVO computation."""
log_weights_acc
,
log_p_hat_acc
,
kl_acc
=
accs
cur_inputs
,
cur_mask
=
nested
.
read_tas
([
inputs_ta
,
mask_ta
],
t
)
# Run the cell for one step.
log_q_z
,
log_p_z
,
log_p_x_given_z
,
kl
,
new_state
=
cell
(
cur_inputs
,
rnn_state
,
cur_mask
,
)
# Compute the incremental weight and use it to update the current
# accumulated weight.
kl_acc
+=
kl
*
cur_mask
log_alpha
=
(
log_p_x_given_z
+
log_p_z
-
log_q_z
)
*
cur_mask
log_alpha
=
tf
.
reshape
(
log_alpha
,
[
num_samples
,
batch_size
])
log_weights_acc
+=
log_alpha
# Calculate the effective sample size.
ess_num
=
2
*
tf
.
reduce_logsumexp
(
log_weights_acc
,
axis
=
0
)
ess_denom
=
tf
.
reduce_logsumexp
(
2
*
log_weights_acc
,
axis
=
0
)
log_ess
=
ess_num
-
ess_denom
# Calculate the ancestor indices via resampling. Because we maintain the
# log unnormalized weights, we pass the weights in as logits, allowing
# the distribution object to apply a softmax and normalize them.
resampling_dist
=
tf
.
contrib
.
distributions
.
Categorical
(
logits
=
tf
.
transpose
(
log_weights_acc
,
perm
=
[
1
,
0
]))
ancestor_inds
=
tf
.
stop_gradient
(
resampling_dist
.
sample
(
sample_shape
=
num_samples
,
seed
=
random_seed
))
# Because the batch is flattened and laid out as discussed
# above, we must modify ancestor_inds to index the proper samples.
# The particles in the ith filter are distributed every batch_size rows
# in the batch, and offset i rows from the top. So, to correct the indices
# we multiply by the batch_size and add the proper offset. Crucially,
# when ancestor_inds is flattened the layout of the batch is maintained.
offset
=
tf
.
expand_dims
(
tf
.
range
(
batch_size
),
0
)
ancestor_inds
=
tf
.
reshape
(
ancestor_inds
*
batch_size
+
offset
,
[
-
1
])
noresample_inds
=
tf
.
range
(
num_samples
*
batch_size
)
# Decide whether or not we should resample; don't resample if we are past
# the end of a sequence.
should_resample
=
resampling_criterion
(
num_samples
,
log_ess
,
t
)
should_resample
=
tf
.
logical_and
(
should_resample
,
cur_mask
[:
batch_size
]
>
0.
)
float_should_resample
=
tf
.
to_float
(
should_resample
)
ancestor_inds
=
tf
.
where
(
tf
.
tile
(
should_resample
,
[
num_samples
]),
ancestor_inds
,
noresample_inds
)
new_state
=
nested
.
gather_tensors
(
new_state
,
ancestor_inds
)
# Update the TensorArrays before we reset the weights so that we capture
# the incremental weights and not zeros.
ta_updates
=
[
log_weights_acc
,
log_ess
,
float_should_resample
]
new_tas
=
[
ta
.
write
(
t
,
x
)
for
ta
,
x
in
zip
(
tas
,
ta_updates
)]
# For the particle filters that resampled, update log_p_hat and
# reset weights to zero.
# Each sequence in the batch will be the input data for a different
# particle filter. The batch will be laid out as:
# particle 1 of particle filter 1
# particle 1 of particle filter 2
# ...
# particle 1 of particle filter batch_size
# particle 2 of particle filter 1
# ...
# particle num_samples of particle filter batch_size
observations
=
nested
.
tile_tensors
(
observations
,
[
1
,
num_samples
])
tiled_seq_lengths
=
tf
.
tile
(
seq_lengths
,
[
num_samples
])
model
.
set_observations
(
observations
,
tiled_seq_lengths
)
if
resampling_type
==
'multinomial'
:
resampling_fn
=
smc
.
multinomial_resampling
elif
resampling_type
==
'relaxed'
:
resampling_fn
=
functools
.
partial
(
smc
.
relaxed_resampling
,
temperature
=
relaxed_resampling_temperature
)
resampling_fn
=
functools
.
partial
(
resampling_fn
,
random_seed
=
random_seed
)
def
transition_fn
(
prev_state
,
t
):
if
prev_state
is
None
:
model_init_state
=
model
.
zero_state
(
batch_size
*
num_samples
,
tf
.
float32
)
return
(
tf
.
zeros
([
num_samples
*
batch_size
],
dtype
=
tf
.
float32
),
(
tf
.
zeros
([
num_samples
*
batch_size
,
model
.
latent_size
],
dtype
=
tf
.
float32
),
tf
.
zeros
([
num_samples
*
batch_size
,
model
.
latent_size
],
dtype
=
tf
.
float32
)),
model_init_state
)
prev_log_r
,
prev_log_r_tilde
,
prev_model_state
=
prev_state
(
new_model_state
,
zt
,
log_q_zt
,
log_p_zt
,
log_p_x_given_z
,
log_r_tilde
,
p_ztplus1
)
=
model
(
prev_model_state
,
t
)
r_tilde_mu
,
r_tilde_sigma_sq
=
log_r_tilde
# Compute the weight without r.
log_weight
=
log_p_zt
+
log_p_x_given_z
-
log_q_zt
# Compute log_r and log_r_tilde.
p_mu
=
tf
.
stop_gradient
(
p_ztplus1
.
mean
())
p_sigma_sq
=
tf
.
stop_gradient
(
p_ztplus1
.
variance
())
log_r
=
(
tf
.
log
(
r_tilde_sigma_sq
)
-
tf
.
log
(
r_tilde_sigma_sq
+
p_sigma_sq
)
-
tf
.
square
(
r_tilde_mu
-
p_mu
)
/
(
r_tilde_sigma_sq
+
p_sigma_sq
))
# log_r is [num_samples*batch_size, latent_size]. We sum it along the last
# dimension to compute log r.
log_r
=
0.5
*
tf
.
reduce_sum
(
log_r
,
axis
=-
1
)
# Compute prev log r tilde
prev_r_tilde_mu
,
prev_r_tilde_sigma_sq
=
prev_log_r_tilde
prev_log_r_tilde
=
-
0.5
*
tf
.
reduce_sum
(
tf
.
square
(
tf
.
stop_gradient
(
zt
)
-
prev_r_tilde_mu
)
/
prev_r_tilde_sigma_sq
,
axis
=-
1
)
# If the sequence is on the last timestep, log_r and log_r_tilde are just zeros.
last_timestep
=
t
>=
(
tiled_seq_lengths
-
1
)
log_r
=
tf
.
where
(
last_timestep
,
tf
.
zeros_like
(
log_r
),
log_r
)
prev_log_r_tilde
=
tf
.
where
(
last_timestep
,
tf
.
zeros_like
(
prev_log_r_tilde
),
prev_log_r_tilde
)
log_weight
+=
tf
.
stop_gradient
(
log_r
-
prev_log_r
)
new_state
=
(
log_r
,
log_r_tilde
,
new_model_state
)
loop_fn_args
=
(
log_r
,
prev_log_r_tilde
,
log_p_x_given_z
,
log_r
-
prev_log_r
)
return
log_weight
,
new_state
,
loop_fn_args
def
loop_fn
(
loop_state
,
loop_args
,
unused_model_state
,
log_weights
,
resampled
,
mask
,
t
):
if
loop_state
is
None
:
return
(
tf
.
zeros
([
batch_size
],
dtype
=
tf
.
float32
),
tf
.
zeros
([
batch_size
],
dtype
=
tf
.
float32
),
tf
.
zeros
([
num_samples
,
batch_size
],
dtype
=
tf
.
float32
))
log_p_hat_acc
,
bellman_loss_acc
,
log_r_diff_acc
=
loop_state
log_r
,
prev_log_r_tilde
,
log_p_x_given_z
,
log_r_diff
=
loop_args
# Compute the log_p_hat update
log_p_hat_update
=
tf
.
reduce_logsumexp
(
log_weights_acc
,
axis
=
0
)
-
tf
.
log
(
tf
.
to_float
(
num_samples
))
log_p_hat_acc
+=
log_p_hat_update
*
float_should_resample
log_weights_acc
*=
(
1.
-
tf
.
tile
(
float_should_resample
[
tf
.
newaxis
,
:],
[
num_samples
,
1
]))
new_accs
=
(
log_weights_acc
,
log_p_hat_acc
,
kl_acc
)
return
t
+
1
,
new_state
,
new_tas
,
new_accs
_
,
_
,
tas
,
accs
=
tf
.
while_loop
(
while_predicate
,
while_step
,
loop_vars
=
(
t0
,
init_states
,
tas
,
accs
),
log_weights
,
axis
=
0
)
-
tf
.
log
(
tf
.
to_float
(
num_samples
))
# If it is the last timestep, we always add the update.
log_p_hat_acc
+=
tf
.
cond
(
t
>=
max_seq_len
-
1
,
lambda
:
log_p_hat_update
,
lambda
:
log_p_hat_update
*
resampled
)
# Compute the Bellman update.
log_r
=
tf
.
reshape
(
log_r
,
[
num_samples
,
batch_size
])
prev_log_r_tilde
=
tf
.
reshape
(
prev_log_r_tilde
,
[
num_samples
,
batch_size
])
log_p_x_given_z
=
tf
.
reshape
(
log_p_x_given_z
,
[
num_samples
,
batch_size
])
mask
=
tf
.
reshape
(
mask
,
[
num_samples
,
batch_size
])
# On the first timestep there is no bellman error because there is no
# prev_log_r_tilde.
mask
=
tf
.
cond
(
tf
.
equal
(
t
,
0
),
lambda
:
tf
.
zeros_like
(
mask
),
lambda
:
mask
)
# On the first timestep also fix up prev_log_r_tilde, which will be -inf.
prev_log_r_tilde
=
tf
.
where
(
tf
.
is_inf
(
prev_log_r_tilde
),
tf
.
zeros_like
(
prev_log_r_tilde
),
prev_log_r_tilde
)
# log_lambda is [num_samples, batch_size]
log_lambda
=
tf
.
reduce_mean
(
prev_log_r_tilde
-
log_p_x_given_z
-
log_r
,
axis
=
0
,
keepdims
=
True
)
bellman_error
=
mask
*
tf
.
square
(
prev_log_r_tilde
-
tf
.
stop_gradient
(
log_lambda
+
log_p_x_given_z
+
log_r
)
)
bellman_loss_acc
+=
tf
.
reduce_mean
(
bellman_error
,
axis
=
0
)
# Compute the log_r_diff update
log_r_diff_acc
+=
mask
*
tf
.
reshape
(
log_r_diff
,
[
num_samples
,
batch_size
])
return
(
log_p_hat_acc
,
bellman_loss_acc
,
log_r_diff_acc
)
log_weights
,
resampled
,
accs
=
smc
.
smc
(
transition_fn
,
seq_lengths
,
num_particles
=
num_samples
,
resampling_criterion
=
resampling_criterion
,
resampling_fn
=
resampling_fn
,
loop_fn
=
loop_fn
,
parallel_iterations
=
parallel_iterations
,
swap_memory
=
swap_memory
)
log_weights
,
log_ess
,
resampled
=
[
x
.
stack
()
for
x
in
tas
]
final_log_weights
,
log_p_hat
,
kl
=
accs
# Add in the final weight update to log_p_hat.
log_p_hat
+=
(
tf
.
reduce_logsumexp
(
final_log_weights
,
axis
=
0
)
-
tf
.
log
(
tf
.
to_float
(
num_samples
)))
kl
=
tf
.
reduce_mean
(
tf
.
reshape
(
kl
,
[
num_samples
,
batch_size
]),
axis
=
0
)
log_weights
=
tf
.
transpose
(
log_weights
,
perm
=
[
0
,
2
,
1
])
return
log_p_hat
,
kl
,
log_weights
,
log_ess
,
resampled
log_p_hat
,
bellman_loss
,
log_r_diff
=
accs
loss_per_seq
=
[
-
log_p_hat
,
bellman_loss
]
tf
.
summary
.
scalar
(
"bellman_loss"
,
tf
.
reduce_mean
(
bellman_loss
/
tf
.
to_float
(
seq_lengths
)))
tf
.
summary
.
scalar
(
"log_r_diff"
,
tf
.
reduce_mean
(
tf
.
reduce_mean
(
log_r_diff
,
axis
=
0
)
/
tf
.
to_float
(
seq_lengths
)))
return
loss_per_seq
,
log_p_hat
,
log_weights
,
resampled
research/fivo/fivo/bounds_test.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.bounds"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
from
fivo.test_utils
import
create_vrnn
from
fivo
import
bounds
class BoundsTest(tf.test.TestCase):
  """Golden-value tests for the ELBO, IWAE, and FIVO bounds.

  Each test builds a small VRNN with a fixed random seed and checks the
  bound's outputs against previously-recorded ("gold") values, so these
  tests guard against unintended numerical changes rather than asserting
  correctness from first principles.
  """

  def test_elbo(self):
    """A golden-value test for the ELBO (the IWAE bound with num_samples=1)."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=1,
                         parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, _, _ = sess.run(outs)
      self.assertAllClose([-21.615765, -13.614225], log_p_hat)

  def test_iwae(self):
    """A golden-value test for the IWAE bound."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=4,
                         parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, _ = sess.run(outs)
      self.assertAllClose([-23.301426, -13.64028], log_p_hat)
      # Gold log weights, shape [max_seq_len, batch_size, num_samples].
      weights_gt = np.array(
          [[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
            [-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
           [[-6.2539978, -4.37615728, -7.43738699, -7.85044909],
            [-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
           [[-9.19093227, -8.01637268, -11.64603615, -10.51128292],
            [-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
           [[-12.20609856, -10.47217369, -13.66270638, -13.46115875],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-16.14766312, -15.57472229, -17.47755432, -17.98189926],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-20.07182884, -18.43191147, -20.1606636, -21.45263863],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-24.10270691, -22.20865822, -24.14675522, -25.27248383],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]]])
      self.assertAllClose(weights_gt, weights)

  def test_fivo(self):
    """A golden-value test for the FIVO bound."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-22.98902512, -14.21689224], log_p_hat)
      # Gold log weights; note weights reset after resampling steps.
      weights_gt = np.array(
          [[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
            [-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
           [[-2.67100811, -2.30541706, -2.34178066, -2.81751347],
            [-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
           [[-5.65190411, -5.94563246, -6.55041981, -5.4783473],
            [-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
           [[-8.71947861, -8.40143299, -8.54593086, -8.42822266],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-12.7003831, -13.5039815, -12.3569726, -12.9489622],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-16.4520301, -16.3611698, -15.0314846, -16.4197006],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-20.7010765, -20.1379165, -19.0020351, -20.2395458],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]]])
      self.assertAllClose(weights_gt, weights)
      # Gold resampling indicators, shape [max_seq_len, batch_size].
      resampled_gt = np.array(
          [[1., 0.],
           [0., 0.],
           [0., 1.],
           [0., 0.],
           [0., 0.],
           [0., 0.],
           [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)

  def test_fivo_relaxed(self):
    """A golden-value test for the FIVO bound with relaxed sampling."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1,
                         resampling_type="relaxed")
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-22.942394, -14.273882], log_p_hat)
      weights_gt = np.array(
          [[[-3.66708851, -2.07074118, -4.91751575, -5.03293514],
            [-2.99690628, -3.17782831, -4.50084877, -3.48536515]],
           [[-2.84939098, -2.30087185, -2.35649204, -2.48417377],
            [-8.27518654, -6.71545172, -8.96199131, -7.05567837]],
           [[-5.92327023, -5.9433074, -6.5826683, -5.04259014],
            [-12.34527206, -11.54284668, -11.86675072, -9.69417477]],
           [[-8.95323944, -8.40061855, -8.52760506, -7.99130583],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-12.87836456, -13.49628639, -12.31680107, -12.74228859],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-16.78347397, -16.35150909, -14.98797417, -16.35162735],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-20.81165886, -20.1307621, -18.92229652, -20.17458153],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]]])
      self.assertAllClose(weights_gt, weights)
      resampled_gt = np.array(
          [[1., 0.],
           [0., 0.],
           [0., 1.],
           [0., 0.],
           [0., 0.],
           [0., 0.],
           [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)

  def test_fivo_aux_relaxed(self):
    """A golden-value test for the FIVO-AUX bound with relaxed sampling."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      # use_tilt=True enables the auxiliary tilting function (FIVO-AUX).
      model, inputs, targets, lengths = create_vrnn(random_seed=1234,
                                                    use_tilt=True)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1,
                         resampling_type="relaxed")
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-23.1395, -14.271059], log_p_hat)
      weights_gt = np.array(
          [[[-5.19826221, -3.55476403, -5.98663855, -6.08058834],
            [-6.31685925, -5.70243931, -7.07638931, -6.18138981]],
           [[-3.97986865, -3.58831525, -3.85753584, -3.5010016],
            [-11.38203049, -8.66213989, -11.23646641, -10.02024746]],
           [[-6.62269831, -6.36680222, -6.78096485, -5.80072498],
            [-3.55419445, -8.11326408, -3.48766923, -3.08593249]],
           [[-10.56472301, -10.16084099, -9.96741676, -8.5270071],
            [-6.04880285, -7.80853653, -4.72652149, -3.49711013]],
           [[-13.36585426, -16.08720398, -13.33416367, -13.1017189],
            [-0., -0., -0., -0.]],
           [[-17.54233551, -17.35167503, -16.79163361, -16.51471138],
            [0., -0., -0., -0.]],
           [[-19.74024963, -18.69452858, -17.76246452, -18.76182365],
            [0., -0., -0., -0.]]])
      self.assertAllClose(weights_gt, weights)
      resampled_gt = np.array([[1., 0.],
                               [0., 1.],
                               [0., 0.],
                               [0., 1.],
                               [0., 0.],
                               [0., 0.],
                               [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)
if __name__ == "__main__":
  # Disable array summarization so full gold values are printed.
  # Use print(repr(numpy_array)) to print the values.
  # NOTE: the original passed threshold=np.nan, which NumPy >= 1.16 rejects
  # with "threshold must be numeric and non-NAN"; sys.maxsize is the
  # documented way to always print the full repr.
  import sys
  np.set_printoptions(threshold=sys.maxsize)
  tf.test.main()
research/fivo/fivo/data/__init__.py
0 → 100644
View file @
27b4acd4
research/fivo/data/calculate_pianoroll_mean.py
→
research/fivo/
fivo/
data/calculate_pianoroll_mean.py
View file @
27b4acd4
# Copyright 201
7
The TensorFlow Authors All Rights Reserved.
# Copyright 201
8
The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
research/fivo/data/create_timit_dataset.py
→
research/fivo/
fivo/
data/create_timit_dataset.py
View file @
27b4acd4
# Copyright 201
7
The TensorFlow Authors All Rights Reserved.
# Copyright 201
8
The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocesses TIMIT from raw wavfiles to create a set of TFRecords.
"""
...
...
research/fivo/data/datasets.py
→
research/fivo/
fivo/
data/datasets.py
View file @
27b4acd4
# Copyright 201
7
The TensorFlow Authors All Rights Reserved.
# Copyright 201
8
The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -21,6 +21,8 @@ from __future__ import division
from
__future__
import
print_function
import
pickle
import
numpy
as
np
from
scipy.sparse
import
coo_matrix
import
tensorflow
as
tf
...
...
@@ -147,6 +149,95 @@ def create_pianoroll_dataset(path,
return
inputs
,
targets
,
lengths
,
tf
.
constant
(
mean
,
dtype
=
tf
.
float32
)
def create_human_pose_dataset(
    path,
    split,
    batch_size,
    num_parallel_calls=DEFAULT_PARALLELISM,
    shuffle=False,
    repeat=False,):
  """Creates a human pose dataset.

  Args:
    path: The path of a pickle file containing the dataset to load.
    split: The split to use, can be train, test, or valid.
    batch_size: The batch size. If repeat is False then it is not guaranteed
      that the true batch size will match for all batches since batch_size
      may not necessarily evenly divide the number of elements.
    num_parallel_calls: The number of threads to use for parallel processing of
      the data.
    shuffle: If true, shuffles the order of the dataset.
    repeat: If true, repeats the dataset endlessly.
  Returns:
    inputs: A batch of input sequences represented as a dense Tensor of shape
      [time, batch_size, data_dimension]. The sequences in inputs are the
      sequences in targets shifted one timestep into the future, padded with
      zeros. This tensor is mean-centered, with the mean taken from the pickle
      file key 'train_mean'.
    targets: A batch of target sequences represented as a dense Tensor of
      shape [time, batch_size, data_dimension].
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch.
    mean: A float Tensor of shape [data_dimension] containing the mean loaded
      from the pickle file.
  """
  # Load the data from disk. Pickle data is binary, so the file must be
  # opened in "rb" mode; text mode ("r") breaks pickle.load under Python 3.
  with tf.gfile.Open(path, "rb") as f:
    raw_data = pickle.load(f)
  mean = raw_data["train_mean"]
  pose_sequences = raw_data[split]
  num_examples = len(pose_sequences)
  num_features = pose_sequences[0].shape[1]

  def pose_generator():
    """A generator that yields pose data sequences."""
    # Each timestep has 32 x values followed by 32 y values so is 64
    # dimensional.
    for pose_sequence in pose_sequences:
      yield pose_sequence, pose_sequence.shape[0]

  dataset = tf.data.Dataset.from_generator(
      pose_generator,
      output_types=(tf.float64, tf.int64),
      output_shapes=([None, num_features], []))

  if repeat:
    dataset = dataset.repeat()
  if shuffle:
    dataset = dataset.shuffle(num_examples)

  # Batch sequences together, padding them to a common length in time.
  dataset = dataset.padded_batch(
      batch_size, padded_shapes=([None, num_features], []))

  # Post-process each batch, ensuring that it is mean-centered and time-major.
  def process_pose_data(data, lengths):
    """Creates Tensors for next step prediction and mean-centers the input."""
    # Convert from batch-major [batch, time, features] to time-major.
    data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
    lengths = tf.to_int32(lengths)
    targets = data
    # Mean center the inputs.
    inputs = data - tf.constant(
        mean, dtype=tf.float32, shape=[1, 1, mean.shape[0]])
    # Shift the inputs one step forward in time. Also remove the last timestep
    # so that targets and inputs are the same length.
    inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
    # Mask out unused timesteps.
    inputs *= tf.expand_dims(
        tf.transpose(tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
    return inputs, targets, lengths

  dataset = dataset.map(
      process_pose_data, num_parallel_calls=num_parallel_calls)
  dataset = dataset.prefetch(num_examples)
  itr = dataset.make_one_shot_iterator()
  inputs, targets, lengths = itr.get_next()
  return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
def
create_speech_dataset
(
path
,
batch_size
,
samples_per_timestep
=
200
,
...
...
@@ -221,3 +312,142 @@ def create_speech_dataset(path,
itr
=
dataset
.
make_one_shot_iterator
()
inputs
,
targets
,
lengths
=
itr
.
get_next
()
return
inputs
,
targets
,
lengths
# Observation-function choices for the chain graph dataset: the observation
# density is centered at the squared latent state, its absolute value, or the
# latent state unchanged (see create_chain_graph_dataset's observation_type).
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]

# Transition-function choices: the transition density is centered at the
# rounded previous latent state or at the previous latent state unchanged
# (see create_chain_graph_dataset's transition_type).
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
def create_chain_graph_dataset(
    batch_size,
    num_timesteps,
    steps_per_observation=None,
    state_size=1,
    transition_variance=1.,
    observation_variance=1.,
    transition_type=STANDARD_TRANSITION,
    observation_type=STANDARD_OBSERVATION,
    fixed_observation=None,
    prefetch_buffer_size=2048,
    dtype="float32"):
  """Creates a toy chain graph dataset.

  Creates a dataset where the data are sampled from a diffusion process. The
  'latent' states of the process are sampled as a chain of Normals:

  z0 ~ N(0, transition_variance)
  z1 ~ N(transition_fn(z0), transition_variance)
  ...

  where transition_fn could be round z0 or pass it through unchanged.

  The observations are produced every steps_per_observation timesteps as a
  function of the latent zs. For example if steps_per_observation is 3 then the
  first observation will be produced as a function of z3:

  x1 ~ N(observation_fn(z3), observation_variance)

  where observation_fn could square z3, take the absolute value, or pass
  it through unchanged.

  Only the observations are returned.

  Args:
    batch_size: The batch size. The number of trajectories to run in parallel.
    num_timesteps: The length of the chain of latent states (i.e. the
      number of z's excluding z0.
    steps_per_observation: The number of latent states between each observation,
      must evenly divide num_timesteps.
    state_size: The size of the latent state and observation, must be a
      python int.
    transition_variance: The variance of the transition density.
    observation_variance: The variance of the observation density.
    transition_type: Must be one of "round" or "standard". "round" means that
      the transition density is centered at the rounded previous latent state.
      "standard" centers the transition density at the previous latent state,
      unchanged.
    observation_type: Must be one of "squared", "abs" or "standard". "squared"
      centers the observation density at the squared latent state. "abs"
      centers the observation density at the absolute value of the current
      latent state. "standard" centers the observation density at the current
      latent state.
    fixed_observation: If not None, fixes all observations to be a constant.
      Must be a scalar.
    prefetch_buffer_size: The size of the prefetch queues to use after reading
      and processing the raw data.
    dtype: A string convertible to a tensorflow datatype. The datatype used
      to represent the states and observations.
  Returns:
    observations: A batch of observations represented as a dense Tensor of
      shape [num_observations, batch_size, state_size]. num_observations is
      num_timesteps/steps_per_observation.
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch. Will contain num_observations as each entry.
  Raises:
    ValueError: Raised if steps_per_observation does not evenly divide
      num_timesteps, or if transition_type or observation_type is not one of
      the supported values.
  """
  if steps_per_observation is None:
    steps_per_observation = num_timesteps
  if num_timesteps % steps_per_observation != 0:
    raise ValueError("steps_per_observation must evenly divide num_timesteps.")
  # Validate the type arguments up front. Previously an unrecognized type fell
  # through the if/elif chains inside the generator, producing an obscure
  # NameError at iteration time instead of a clear error here.
  if transition_type not in TRANSITION_TYPES:
    raise ValueError("transition_type must be one of %s, got %r." %
                     (TRANSITION_TYPES, transition_type))
  if observation_type not in OBSERVATION_TYPES:
    raise ValueError("observation_type must be one of %s, got %r." %
                     (OBSERVATION_TYPES, observation_type))
  num_observations = int(num_timesteps / steps_per_observation)

  def data_generator():
    """An infinite generator of latents and observations from the model."""
    transition_std = np.sqrt(transition_variance)
    observation_std = np.sqrt(observation_variance)
    while True:
      states = []
      observations = []
      # Sample z0 ~ Normal(0, sqrt(transition_variance)).
      # NOTE: this previously used observation_std, contradicting the
      # docstring's z0 ~ N(0, transition_variance). The two are identical at
      # the default equal variances used by all callers in this file.
      states.append(
          np.random.normal(size=[state_size],
                           scale=transition_std).astype(dtype))
      # Start the range at 1 because we've already generated z0.
      # The range ends at num_timesteps+1 because we want to include the
      # num_timesteps-th step.
      for t in xrange(1, num_timesteps + 1):
        if transition_type == ROUND_TRANSITION:
          loc = np.round(states[-1])
        elif transition_type == STANDARD_TRANSITION:
          loc = states[-1]
        z_t = np.random.normal(size=[state_size], loc=loc,
                               scale=transition_std)
        states.append(z_t.astype(dtype))
        if t % steps_per_observation == 0:
          if fixed_observation is None:
            if observation_type == SQUARED_OBSERVATION:
              loc = np.square(states[-1])
            elif observation_type == ABS_OBSERVATION:
              loc = np.abs(states[-1])
            elif observation_type == STANDARD_OBSERVATION:
              loc = states[-1]
            x_t = np.random.normal(size=[state_size],
                                   loc=loc,
                                   scale=observation_std).astype(dtype)
          else:
            x_t = np.ones([state_size]) * fixed_observation
          observations.append(x_t)
      yield states, observations

  dataset = tf.data.Dataset.from_generator(
      data_generator,
      output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
      output_shapes=([num_timesteps + 1, state_size],
                     [num_observations, state_size]))
  dataset = dataset.repeat().batch(batch_size)
  dataset = dataset.prefetch(prefetch_buffer_size)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Transpose observations from [batch, time, state_size] to
  # [time, batch, state_size].
  observations = tf.transpose(observations, perm=[1, 0, 2])
  lengths = tf.ones([batch_size], dtype=tf.int32) * num_observations
  return observations, lengths
research/fivo/fivo/data/datasets_test.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.data.datasets."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
pickle
import
os
import
numpy
as
np
import
tensorflow
as
tf
from
fivo.data
import
datasets
# Alias for the global TF flags object; kept for parity with other test files.
FLAGS = tf.app.flags.FLAGS
class DatasetsTest(tf.test.TestCase):
  """Unit tests for the dataset constructors in fivo.data.datasets."""

  def test_sparse_pianoroll_to_dense_empty_at_end(self):
    """Trailing empty chords must still be encoded as all-zero timesteps."""
    sparse_pianoroll = [(0, 1), (1, 0), (), (1,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1],
                         [0, 0],
                         [0, 0]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_with_chord(self):
    """Multi-note chords map to multiple 1s within a single timestep."""
    sparse_pianoroll = [(0, 1), (1, 0), (), (1,)]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 4)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_simple(self):
    """Single notes and rests round-trip into one-hot rows."""
    sparse_pianoroll = [(0,), (), (1,)]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 3)
    self.assertAllEqual([[1, 0],
                         [0, 0],
                         [0, 1]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_subtracts_min_note(self):
    """Note indices are shifted down by min_note before encoding."""
    sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=4, num_notes=2)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1],
                         [0, 0],
                         [0, 0]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_uses_num_notes(self):
    """num_notes sets the row width even when the top notes never occur."""
    sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=4, num_notes=3)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1, 0],
                         [1, 1, 0],
                         [0, 0, 0],
                         [0, 1, 0],
                         [0, 0, 0],
                         [0, 0, 0]], dense_pianoroll)

  def test_pianoroll_dataset(self):
    """End-to-end check of create_pianoroll_dataset on a 3-sequence pickle.

    With batch_size=2 and shuffle/repeat off, the first sess.run yields the
    first two sequences (padded to the longer length 3) and the second yields
    the remaining length-6 sequence.
    """
    pianoroll_data = [[(0,), (), (1,)],
                      [(0, 1), (1,)],
                      [(1,), (0,), (), (0, 1), (), ()]]
    pianoroll_mean = np.zeros([3])
    pianoroll_mean[-1] = 1
    data = {"train": pianoroll_data, "train_mean": pianoroll_mean}
    path = os.path.join(tf.test.get_temp_dir(), "test.pkl")
    # NOTE(review): the file handle from open() is never closed explicitly;
    # acceptable in a test but relies on refcounting to flush.
    pickle.dump(data, open(path, "wb"))
    with self.test_session() as sess:
      inputs, targets, lens, mean = datasets.create_pianoroll_dataset(
          path, "train", 2, num_parallel_calls=1,
          shuffle=False, repeat=False,
          min_note=0, max_note=2)
      i1, t1, l1 = sess.run([inputs, targets, lens])
      i2, t2, l2 = sess.run([inputs, targets, lens])
      m = sess.run(mean)
    # Check the lengths.
    self.assertAllEqual([3, 2], l1)
    self.assertAllEqual([6], l2)
    # Check the mean.
    self.assertAllEqual(pianoroll_mean, m)
    # Check the targets. The targets should not be mean-centered and should
    # be padded with zeros to a common length within a batch.
    self.assertAllEqual([[1, 0, 0],
                         [0, 0, 0],
                         [0, 1, 0]], t1[:, 0, :])
    self.assertAllEqual([[1, 1, 0],
                         [0, 1, 0],
                         [0, 0, 0]], t1[:, 1, :])
    self.assertAllEqual([[0, 1, 0],
                         [1, 0, 0],
                         [0, 0, 0],
                         [1, 1, 0],
                         [0, 0, 0],
                         [0, 0, 0]], t2[:, 0, :])
    # Check the inputs. Each sequence should start with zeros on the first
    # timestep. Each sequence should be padded with zeros to a common length
    # within a batch. The mean should be subtracted from all timesteps except
    # the first and the padding.
    self.assertAllEqual([[0, 0, 0],
                         [1, 0, -1],
                         [0, 0, -1]], i1[:, 0, :])
    self.assertAllEqual([[0, 0, 0],
                         [1, 1, -1],
                         [0, 0, 0]], i1[:, 1, :])
    self.assertAllEqual([[0, 0, 0],
                         [0, 1, -1],
                         [1, 0, -1],
                         [0, 0, -1],
                         [1, 1, -1],
                         [0, 0, -1]], i2[:, 0, :])

  def test_human_pose_dataset(self):
    """End-to-end check of create_human_pose_dataset on a 3-sequence pickle."""
    pose_data = [
        [[0, 0], [2, 2]],
        [[2, 2]],
        [[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]],
    ]
    pose_data = [np.array(x, dtype=np.float64) for x in pose_data]
    pose_data_mean = np.array([1, 1], dtype=np.float64)
    data = {
        "train": pose_data,
        "train_mean": pose_data_mean,
    }
    path = os.path.join(tf.test.get_temp_dir(),
                        "test_human_pose_dataset.pkl")
    with open(path, "wb") as out:
      pickle.dump(data, out)
    with self.test_session() as sess:
      inputs, targets, lens, mean = datasets.create_human_pose_dataset(
          path, "train", 2, num_parallel_calls=1, shuffle=False, repeat=False)
      i1, t1, l1 = sess.run([inputs, targets, lens])
      i2, t2, l2 = sess.run([inputs, targets, lens])
      m = sess.run(mean)
    # Check the lengths.
    self.assertAllEqual([2, 1], l1)
    self.assertAllEqual([5], l2)
    # Check the mean.
    self.assertAllEqual(pose_data_mean, m)
    # Check the targets. The targets should not be mean-centered and should
    # be padded with zeros to a common length within a batch.
    self.assertAllEqual([[0, 0], [2, 2]], t1[:, 0, :])
    self.assertAllEqual([[2, 2], [0, 0]], t1[:, 1, :])
    self.assertAllEqual([[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]],
                        t2[:, 0, :])
    # Check the inputs. Each sequence should start with zeros on the first
    # timestep. Each sequence should be padded with zeros to a common length
    # within a batch. The mean should be subtracted from all timesteps except
    # the first and the padding.
    self.assertAllEqual([[0, 0], [-1, -1]], i1[:, 0, :])
    self.assertAllEqual([[0, 0], [0, 0]], i1[:, 1, :])
    self.assertAllEqual([[0, 0], [-1, -1], [-1, -1], [1, 1], [1, 1]],
                        i2[:, 0, :])

  def test_speech_dataset(self):
    """Checks create_speech_dataset against a small checked-in tfrecord."""
    with self.test_session() as sess:
      # Locate test_data/ relative to this file (one directory up).
      path = os.path.join(
          os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
          "test_data",
          "tiny_speech_dataset.tfrecord")
      inputs, targets, lens = datasets.create_speech_dataset(
          path, 3, samples_per_timestep=2, num_parallel_calls=1,
          prefetch_buffer_size=3, shuffle=False, repeat=False)
      inputs1, targets1, lengths1 = sess.run([inputs, targets, lens])
      inputs2, targets2, lengths2 = sess.run([inputs, targets, lens])
    # Check the lengths.
    self.assertAllEqual([1, 2, 3], lengths1)
    self.assertAllEqual([4], lengths2)
    # Check the targets. The targets should be padded with zeros to a common
    # length within a batch.
    self.assertAllEqual([[[0., 1.], [0., 1.], [0., 1.]],
                         [[0., 0.], [2., 3.], [2., 3.]],
                         [[0., 0.], [0., 0.], [4., 5.]]],
                        targets1)
    self.assertAllEqual([[[0., 1.]],
                         [[2., 3.]],
                         [[4., 5.]],
                         [[6., 7.]]],
                        targets2)
    # Check the inputs. Each sequence should start with zeros on the first
    # timestep. Each sequence should be padded with zeros to a common length
    # within a batch.
    self.assertAllEqual([[[0., 0.], [0., 0.], [0., 0.]],
                         [[0., 0.], [0., 1.], [0., 1.]],
                         [[0., 0.], [0., 0.], [2., 3.]]],
                        inputs1)
    self.assertAllEqual([[[0., 0.]],
                         [[0., 1.]],
                         [[2., 3.]],
                         [[4., 5.]]],
                        inputs2)

  def test_chain_graph_raises_error_on_wrong_steps_per_observation(self):
    """steps_per_observation that doesn't divide num_timesteps must raise."""
    with self.assertRaises(ValueError):
      datasets.create_chain_graph_dataset(
          batch_size=4,
          num_timesteps=10,
          steps_per_observation=9)

  def test_chain_graph_single_obs(self):
    """Default steps_per_observation yields a single observation per chain."""
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 1
      num_timesteps = 5
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      # Golden values for the fixed numpy seed above.
      self.assertAllClose(
          [[[1.426677], [-1.789461]]],
          out_observations)

  def test_chain_graph_multiple_obs(self):
    """Observations are emitted every num_timesteps/num_observations steps."""
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 3
      num_timesteps = 6
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          steps_per_observation=num_timesteps/num_observations,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      # Golden values for the fixed numpy seed above.
      self.assertAllClose(
          [[[0.40051451], [1.07405114]],
           [[1.73932898], [3.16880035]],
           [[-1.98377144], [2.82669163]]],
          out_observations)

  def test_chain_graph_state_dims(self):
    """state_size > 1 produces vector-valued observations."""
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 1
      num_timesteps = 5
      batch_size = 2
      state_size = 3
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      # Golden values for the fixed numpy seed above.
      self.assertAllClose(
          [[[1.052287, -4.560759, 3.07988],
            [2.008926, 0.495567, 3.488678]]],
          out_observations)

  def test_chain_graph_fixed_obs(self):
    """fixed_observation pins every observation to the given constant."""
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 3
      num_timesteps = 6
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          steps_per_observation=num_timesteps/num_observations,
          state_size=state_size,
          fixed_observation=4.)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          np.ones([num_observations, batch_size, state_size]) * 4.,
          out_observations)
# Standard TF test entry point: discovers and runs the test cases above.
if __name__ == "__main__":
  tf.test.main()
research/fivo/fivo/ghmm_runners.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates and runs Gaussian HMM-related graphs."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
tensorflow
as
tf
from
fivo
import
smc
from
fivo
import
bounds
from
fivo.data
import
datasets
from
fivo.models
import
ghmm
def run_train(config):
  """Runs training for a Gaussian HMM setup.

  Builds a chain-graph dataset, a TrainableGaussianHMM, and the configured
  lower bound (elbo/iwae/fivo), then optimizes the negative bound with Adam
  inside a MonitoredTrainingSession until config.max_steps is reached.

  Args:
    config: An object with the attributes read below (bound, num_samples,
      resampling_type, batch_size, num_timesteps, variance, proposal_type,
      learning_rate, logdir, summarize_every, max_steps, random_seed,
      parallel_iterations, and relaxed_resampling_temperature when
      resampling_type == "relaxed").
  """

  def create_logging_hook(step, bound_value, likelihood, bound_gap):
    """Creates a logging hook that prints the bound value periodically."""
    bound_label = config.bound + "/t"
    def summary_formatter(log_dict):
      # %-format the bound label in first, then .format the tensor values.
      string = ("Step {step}, %s: {value:.3f}, "
                "likelihood: {ll:.3f}, gap: {gap:.3e}") % bound_label
      return string.format(**log_dict)
    logging_hook = tf.train.LoggingTensorHook(
        {"step": step, "value": bound_value,
         "ll": likelihood, "gap": bound_gap},
        every_n_iter=config.summarize_every,
        formatter=summary_formatter)
    return logging_hook

  def create_losses(model, observations, lengths):
    """Creates the loss to be optimized.

    Args:
      model: A Trainable GHMM model.
      observations: A set of observations.
      lengths: The lengths of each sequence in the observations.
    Returns:
      loss: A float Tensor that when differentiated yields the gradients
        to apply to the model. Should be optimized via gradient descent.
      bound: A float Tensor containing the value of the bound that is
        being optimized.
      true_ll: The true log-likelihood of the data under the model.
      bound_gap: The gap between the bound and the true log-likelihood.
    """
    # Compute lower bounds on the log likelihood.
    # NOTE(review): an unrecognized config.bound leaves ll_per_seq unbound
    # and fails with a NameError below — no explicit validation here.
    if config.bound == "elbo":
      # ELBO is IWAE with a single sample.
      ll_per_seq, _, _ = bounds.iwae(
          model, observations, lengths, num_samples=1,
          parallel_iterations=config.parallel_iterations)
    elif config.bound == "iwae":
      ll_per_seq, _, _ = bounds.iwae(
          model, observations, lengths, num_samples=config.num_samples,
          parallel_iterations=config.parallel_iterations)
    elif config.bound == "fivo":
      if config.resampling_type == "relaxed":
        # The relaxed branch additionally passes the resampling temperature.
        ll_per_seq, _, _, _ = bounds.fivo(
            model, observations, lengths,
            num_samples=config.num_samples,
            resampling_criterion=smc.ess_criterion,
            resampling_type=config.resampling_type,
            relaxed_resampling_temperature=config.relaxed_resampling_temperature,
            random_seed=config.random_seed,
            parallel_iterations=config.parallel_iterations)
      else:
        ll_per_seq, _, _, _ = bounds.fivo(
            model, observations, lengths,
            num_samples=config.num_samples,
            resampling_criterion=smc.ess_criterion,
            resampling_type=config.resampling_type,
            random_seed=config.random_seed,
            parallel_iterations=config.parallel_iterations)
    # Normalize per-sequence values by sequence length before averaging.
    ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
    # Compute the data's true likelihood under the model and the bound gap.
    true_ll_per_seq = model.likelihood(tf.squeeze(observations))
    true_ll_per_t = tf.reduce_mean(true_ll_per_seq / tf.to_float(lengths))
    bound_gap = true_ll_per_seq - ll_per_seq
    bound_gap = tf.reduce_mean(bound_gap / tf.to_float(lengths))
    tf.summary.scalar("train_ll_bound", ll_per_t)
    tf.summary.scalar("train_true_ll", true_ll_per_t)
    tf.summary.scalar("bound_gap", bound_gap)
    # The loss is the negated bound (gradient descent maximizes the bound).
    return -ll_per_t, ll_per_t, true_ll_per_t, bound_gap

  def create_graph():
    """Creates the training graph."""
    global_step = tf.train.get_or_create_global_step()
    # The same config.variance is used for both transitions and emissions.
    xs, lengths = datasets.create_chain_graph_dataset(
        config.batch_size,
        config.num_timesteps,
        steps_per_observation=1,
        state_size=1,
        transition_variance=config.variance,
        observation_variance=config.variance)
    model = ghmm.TrainableGaussianHMM(
        config.num_timesteps,
        config.proposal_type,
        transition_variances=config.variance,
        emission_variances=config.variance,
        random_seed=config.random_seed)
    loss, bound, true_ll, gap = create_losses(model, xs, lengths)
    opt = tf.train.AdamOptimizer(config.learning_rate)
    grads = opt.compute_gradients(loss, var_list=tf.trainable_variables())
    train_op = opt.apply_gradients(grads, global_step=global_step)
    return bound, true_ll, gap, train_op, global_step

  with tf.Graph().as_default():
    if config.random_seed:
      tf.set_random_seed(config.random_seed)
      np.random.seed(config.random_seed)
    bound, true_ll, gap, train_op, global_step = create_graph()
    log_hook = create_logging_hook(global_step, bound, true_ll, gap)
    with tf.train.MonitoredTrainingSession(
        master="",
        hooks=[log_hook],
        checkpoint_dir=config.logdir,
        save_checkpoint_secs=120,
        save_summaries_steps=config.summarize_every,
        log_step_count_steps=config.summarize_every * 20) as sess:
      cur_step = -1
      while cur_step <= config.max_steps and not sess.should_stop():
        # NOTE(review): global_step is fetched twice per iteration; the first
        # run only refreshes cur_step before the train step executes.
        cur_step = sess.run(global_step)
        _, cur_step = sess.run([train_op, global_step])
def run_eval(config):
  """Evaluates a Gaussian HMM using the given config.

  Restores the latest checkpoint from config.logdir, evaluates the configured
  bound (elbo/iwae/fivo) and the true likelihood on one batch, logs summary
  statistics, and writes a pickled results dict to <logdir>/out.npz.
  """

  def create_bound(model, xs, lengths):
    """Creates the bound to be evaluated."""
    if config.bound == "elbo":
      # ELBO is IWAE with a single sample.
      ll_per_seq, log_weights, _ = bounds.iwae(
          model, xs, lengths, num_samples=1,
          parallel_iterations=config.parallel_iterations)
    elif config.bound == "iwae":
      ll_per_seq, log_weights, _ = bounds.iwae(
          model, xs, lengths, num_samples=config.num_samples,
          parallel_iterations=config.parallel_iterations)
    elif config.bound == "fivo":
      ll_per_seq, log_weights, resampled, _ = bounds.fivo(
          model, xs, lengths, num_samples=config.num_samples,
          resampling_criterion=smc.ess_criterion,
          resampling_type=config.resampling_type,
          random_seed=config.random_seed,
          parallel_iterations=config.parallel_iterations)
    # Compute bound scaled by number of timesteps.
    bound_per_t = ll_per_seq / tf.to_float(lengths)
    # Only the fivo bound also exposes the per-step resampling indicator.
    if config.bound == "fivo":
      return bound_per_t, log_weights, resampled
    else:
      return bound_per_t, log_weights

  def create_graph():
    """Creates the dataset, model, and bound."""
    xs, lengths = datasets.create_chain_graph_dataset(
        config.batch_size,
        config.num_timesteps,
        steps_per_observation=1,
        state_size=1,
        transition_variance=config.variance,
        observation_variance=config.variance)
    model = ghmm.TrainableGaussianHMM(
        config.num_timesteps,
        config.proposal_type,
        transition_variances=config.variance,
        emission_variances=config.variance,
        random_seed=config.random_seed)
    true_likelihood = tf.reduce_mean(
        model.likelihood(tf.squeeze(xs)) / tf.to_float(lengths))
    # outs is [true_likelihood] + the 2 or 3 bound outputs, so its final
    # length (3 vs 4) tells the code below whether fivo was evaluated.
    outs = [true_likelihood]
    outs.extend(list(create_bound(model, xs, lengths)))
    return outs

  with tf.Graph().as_default():
    if config.random_seed:
      tf.set_random_seed(config.random_seed)
      np.random.seed(config.random_seed)
    graph_outs = create_graph()
    with tf.train.SingularMonitoredSession(
        checkpoint_dir=config.logdir) as sess:
      outs = sess.run(graph_outs)
      likelihood = outs[0]
      avg_bound = np.mean(outs[1])
      std = np.std(outs[1])
      log_weights = outs[2]
      log_weight_variances = np.var(log_weights, axis=2)
      # NOTE(review): despite the "avg_" name this takes np.var over axis 1,
      # not np.mean — confirm whether a mean of the per-sample variances was
      # intended.
      avg_log_weight_variance = np.var(log_weight_variances, axis=1)
      avg_log_weight = np.mean(log_weights, axis=(1, 2))
      data = {"mean": avg_bound, "std": std, "log_weights": log_weights,
              "log_weight_means": avg_log_weight,
              "log_weight_variances": avg_log_weight_variance}
      if len(outs) == 4:
        data["resampled"] = outs[3]
        data["avg_resampled"] = np.mean(outs[3], axis=1)
      # Log some useful statistics.
      tf.logging.info("Evaled bound %s with batch_size: %d, num_samples: %d."
                      % (config.bound, config.batch_size, config.num_samples))
      tf.logging.info("mean: %f, std: %f" % (avg_bound, std))
      tf.logging.info("true likelihood: %s" % likelihood)
      tf.logging.info("avg log weight: %s" % avg_log_weight)
      tf.logging.info("log weight variance: %s" % avg_log_weight_variance)
      if len(outs) == 4:
        tf.logging.info("avg resamples per t: %s" % data["avg_resampled"])
      if not tf.gfile.Exists(config.logdir):
        tf.gfile.MakeDirs(config.logdir)
      # NOTE(review): np.save (not np.savez) pickles the dict into a file
      # named "out.npz"; readers use np.load(...).item() accordingly.
      with tf.gfile.Open(os.path.join(config.logdir, "out.npz"), "w") as fout:
        np.save(fout, data)
research/fivo/fivo/ghmm_runners_test.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.ghmm_runners."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
tensorflow
as
tf
from
fivo
import
ghmm_runners
class GHMMRunnersTest(tf.test.TestCase):
  """Regression tests for ghmm_runners.run_train / run_eval.

  Each test runs eval (optionally after one training step) with a fixed
  random seed and checks the averaged bound against a golden value.
  """

  def default_config(self):
    """Builds a minimal attribute-bag config shared by all tests."""
    class Config(object):
      pass
    config = Config()
    config.model = "ghmm"
    config.bound = "fivo"
    config.proposal_type = "prior"
    config.batch_size = 4
    config.num_samples = 4
    config.num_timesteps = 10
    config.variance = 0.1
    config.resampling_type = "multinomial"
    config.random_seed = 1234
    config.parallel_iterations = 1
    config.learning_rate = 1e-4
    config.summarize_every = 1
    config.max_steps = 1
    return config

  def test_eval_ghmm_notraining_fivo_prior(self):
    self.eval_ghmm_notraining("fivo", "prior", -3.063864)

  def test_eval_ghmm_notraining_fivo_true_filtering(self):
    self.eval_ghmm_notraining("fivo", "true-filtering", -1.1409812)

  def test_eval_ghmm_notraining_fivo_true_smoothing(self):
    self.eval_ghmm_notraining("fivo", "true-smoothing", -0.85592091)

  def test_eval_ghmm_notraining_iwae_prior(self):
    self.eval_ghmm_notraining("iwae", "prior", -5.9730167)

  def test_eval_ghmm_notraining_iwae_true_filtering(self):
    self.eval_ghmm_notraining("iwae", "true-filtering", -1.1485999)

  def test_eval_ghmm_notraining_iwae_true_smoothing(self):
    self.eval_ghmm_notraining("iwae", "true-smoothing", -0.85592091)

  def eval_ghmm_notraining(self, bound, proposal_type, expected_bound_avg):
    """Runs eval without training and checks the mean bound (3 places)."""
    config = self.default_config()
    config.proposal_type = proposal_type
    config.bound = bound
    config.logdir = os.path.join(tf.test.get_temp_dir(),
                                 "test-ghmm-%s-%s" % (proposal_type, bound))
    ghmm_runners.run_eval(config)
    # run_eval writes a pickled dict via np.save; .item() unwraps it.
    data = np.load(os.path.join(config.logdir, "out.npz")).item()
    self.assertAlmostEqual(expected_bound_avg, data["mean"], places=3)

  def test_train_ghmm_for_one_step_and_eval_fivo_filtering(self):
    self.train_ghmm_for_one_step_and_eval("fivo", "filtering", -16.727108)

  def test_train_ghmm_for_one_step_and_eval_fivo_smoothing(self):
    self.train_ghmm_for_one_step_and_eval("fivo", "smoothing", -19.381277)

  def test_train_ghmm_for_one_step_and_eval_iwae_filtering(self):
    self.train_ghmm_for_one_step_and_eval("iwae", "filtering", -33.31966)

  def test_train_ghmm_for_one_step_and_eval_iwae_smoothing(self):
    self.train_ghmm_for_one_step_and_eval("iwae", "smoothing", -46.388447)

  def train_ghmm_for_one_step_and_eval(self, bound, proposal_type,
                                       expected_bound_avg):
    """Trains one step, then evals and checks the mean bound (2 places)."""
    config = self.default_config()
    config.proposal_type = proposal_type
    config.bound = bound
    config.max_steps = 1
    config.logdir = os.path.join(
        tf.test.get_temp_dir(),
        "test-ghmm-training-%s-%s" % (proposal_type, bound))
    ghmm_runners.run_train(config)
    ghmm_runners.run_eval(config)
    data = np.load(os.path.join(config.logdir, "out.npz")).item()
    self.assertAlmostEqual(expected_bound_avg, data["mean"], places=2)
# Standard TF test entry point: discovers and runs the test cases above.
if __name__ == "__main__":
  tf.test.main()
research/fivo/fivo/models/__init__.py
0 → 100644
View file @
27b4acd4
research/fivo/models/
vrnn
.py
→
research/fivo/
fivo/
models/
base
.py
View file @
27b4acd4
# Copyright 201
7
The TensorFlow Authors All Rights Reserved.
# Copyright 201
8
The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""
VRNN classes
."""
"""
Reusable model classes for FIVO
."""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -22,282 +22,67 @@ from __future__ import print_function
import
sonnet
as
snt
import
tensorflow
as
tf
from
fivo
import
nested_utils
as
nested
class
VRNNCell
(
snt
.
AbstractModule
):
"""Implementation of a Variational Recurrent Neural Network (VRNN).
tfd
=
tf
.
contrib
.
distributions
Introduced in "A Recurrent Latent Variable Model for Sequential data"
by Chung et al. https://arxiv.org/pdf/1506.02216.pdf.
The VRNN is a sequence model similar to an RNN that uses stochastic latent
variables to improve its representational power. It can be thought of as a
sequential analogue to the variational auto-encoder (VAE).
class
ELBOTrainableSequenceModel
(
object
):
"""An abstract class for ELBO-trainable sequence models to extend.
The VRNN has a deterministic RNN as its backbone, represented by the
sequence of RNN hidden states h_t. At each timestep, the RNN hidden state h_t
is conditioned on the previous sequence element, x_{t-1}, as well as the
latent state from the previous timestep, z_{t-1}.
In this implementation of the VRNN the latent state z_t is Gaussian. The
model's prior over z_t is distributed as Normal(mu_t, diag(sigma_t^2)) where
mu_t and sigma_t are the mean and standard deviation output from a fully
connected network that accepts the rnn hidden state h_t as input.
The approximate posterior (also known as q or the encoder in the VAE
framework) is similar to the prior except that it is conditioned on the
current target, x_t, as well as h_t via a fully connected network.
This implementation uses the 'res_q' parameterization of the approximate
posterior, meaning that instead of directly predicting the mean of z_t, the
approximate posterior predicts the 'residual' from the prior's mean. This is
explored more in section 3.3 of https://arxiv.org/pdf/1605.07571.pdf.
During training, the latent state z_t is sampled from the approximate
posterior and the reparameterization trick is used to provide low-variance
gradients.
The generative distribution p(x_t|z_t, h_t) is conditioned on the latent state
z_t as well as the current RNN hidden state h_t via a fully connected network.
To increase the modeling power of the VRNN, two additional networks are
used to extract features from the data and the latent state. Those networks
are called data_feat_extractor and latent_feat_extractor respectively.
There are a few differences between this exposition and the paper.
First, the indexing scheme for h_t is different than the paper's -- what the
paper calls h_t we call h_{t+1}. This is the same notation used by Fraccaro
et al. to describe the VRNN in the paper linked above. Also, the VRNN paper
uses VAE terminology to refer to the different internal networks, so it
refers to the approximate posterior as the encoder and the generative
distribution as the decoder. This implementation also renamed the functions
phi_x and phi_z in the paper to data_feat_extractor and latent_feat_extractor.
Because the ELBO, IWAE, and FIVO bounds all accept the same arguments,
any model that is ELBO-trainable is also IWAE- and FIVO-trainable.
"""
def
__init__
(
self
,
rnn_cell
,
data_feat_extractor
,
latent_feat_extractor
,
prior
,
approx_posterior
,
generative
,
random_seed
=
None
,
name
=
"vrnn"
):
"""Creates a VRNN cell.
def
zero_state
(
self
,
batch_size
,
dtype
):
"""Returns the initial state of the model as a Tensor or tuple of Tensors.
Args:
rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
deterministic backbone of the VRNN. The inputs to the RNN will be the
encoded latent state of the previous timestep with shape
[batch_size, encoded_latent_size] as well as the encoded input of the
current timestep, a Tensor of shape [batch_size, encoded_data_size].
data_feat_extractor: A callable that accepts a batch of data x_t and
'encodes' it, e.g. runs it through a fully connected network. Must
accept as argument the inputs x_t, a Tensor of the shape
[batch_size, data_size] and return a Tensor of shape
[batch_size, encoded_data_size]. This callable will be called multiple
times in the VRNN cell so if scoping is not handled correctly then
multiple copies of the variables in this network could be made. It is
recommended to use a snt.nets.MLP module, which takes care of this for
you.
latent_feat_extractor: A callable that accepts a latent state z_t and
'encodes' it, e.g. runs it through a fully connected network. Must
accept as argument a Tensor of shape [batch_size, latent_size] and
return a Tensor of shape [batch_size, encoded_latent_size].
This callable must also have the property 'output_size' defined,
returning encoded_latent_size.
prior: A callable that implements the prior p(z_t|h_t). Must accept as
argument the previous RNN hidden state and return a
tf.contrib.distributions.Normal distribution conditioned on the input.
approx_posterior: A callable that implements the approximate posterior
q(z_t|h_t,x_t). Must accept as arguments the encoded target of the
current timestep and the previous RNN hidden state. Must return
a tf.contrib.distributions.Normal distribution conditioned on the
inputs.
generative: A callable that implements the generative distribution
p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
and the RNN hidden state and return a subclass of
tf.contrib.distributions.Distribution that can be used to evaluate
the logprob of the targets.
random_seed: The seed for the random ops. Used mainly for testing.
name: The name of this VRNN.
batch_size: The batch size.
dtype: The datatype to use for the state.
"""
super
(
VRNNCell
,
self
).
__init__
(
name
=
name
)
self
.
rnn_cell
=
rnn_cell
self
.
data_feat_extractor
=
data_feat_extractor
self
.
latent_feat_extractor
=
latent_feat_extractor
self
.
prior
=
prior
self
.
approx_posterior
=
approx_posterior
self
.
generative
=
generative
self
.
random_seed
=
random_seed
self
.
encoded_z_size
=
latent_feat_extractor
.
output_size
self
.
state_size
=
(
self
.
rnn_cell
.
state_size
,
self
.
encoded_z_size
)
raise
NotImplementedError
(
"zero_state not yet implemented."
)
def
zero_state
(
self
,
batch_size
,
dtype
):
"""The initial state of the VRNN.
def
set_observations
(
self
,
observations
,
seq_lengths
):
"""Sets the observations for the model.
This method provides the model with all observed variables including both
inputs and targets. It will be called before running any computations with
the model that require the observations, e.g. training the model or
computing bounds, and should be used to run any necessary preprocessing
steps.
Contains the initial state of the RNN as well as a vector of zeros
corresponding to z_0.
Args:
batch_size: The batch size.
dtype: The data type of the VRNN.
Returns:
zero_state: The initial state of the VRNN.
observations: A potentially nested set of Tensors containing
all observations for the model, both inputs and targets. Typically
a set of Tensors with shape [max_seq_len, batch_size, data_size].
seq_lengths: A [batch_size] Tensor of ints encoding the length of each
sequence in the batch (sequences can be padded to a common length).
"""
return
(
self
.
rnn_cell
.
zero_state
(
batch_size
,
dtype
),
tf
.
zeros
([
batch_size
,
self
.
encoded_z_size
],
dtype
=
dtype
))
self
.
observations
=
observations
self
.
max_seq_len
=
tf
.
reduce_max
(
seq_lengths
)
self
.
observations_ta
=
nested
.
tas_for_tensors
(
observations
,
self
.
max_seq_len
,
clear_after_read
=
False
)
self
.
seq_lengths
=
seq_lengths
def
_build
(
self
,
observations
,
state
,
mask
):
"""
Computes one timestep of the VRNN
.
def
propose_and_weight
(
self
,
state
,
t
):
"""
Propogates model state one timestep and computes log weights
.
Args:
observations: The observations at the current timestep, a tuple
containing the model inputs and targets as Tensors of shape
[batch_size, data_size].
state: The current state of the VRNN
mask: Tensor of shape [batch_size], 1.0 if the current timestep is active
active, 0.0 if it is not active.
This method accepts the current state of the model and computes the state
for the next timestep as well as the incremental log weight of each
element in the batch.
Args:
state: The current state of the model.
t: A scalar integer Tensor representing the current timestep.
Returns:
log_q_z: The logprob of the latent state according to the approximate
posterior.
log_p_z: The logprob of the latent state according to the prior.
log_p_x_given_z: The conditional log-likelihood, i.e. logprob of the
observation according to the generative distribution.
kl: The analytic kl divergence from q(z) to p(z).
state: The new state of the VRNN.
next_state: The state of the model after one timestep.
log_weights: A [batch_size] Tensor containing the incremental log weights.
"""
inputs
,
targets
=
observations
rnn_state
,
prev_latent_encoded
=
state
# Encode the data.
inputs_encoded
=
self
.
data_feat_extractor
(
inputs
)
targets_encoded
=
self
.
data_feat_extractor
(
targets
)
# Run the RNN cell.
rnn_inputs
=
tf
.
concat
([
inputs_encoded
,
prev_latent_encoded
],
axis
=
1
)
rnn_out
,
new_rnn_state
=
self
.
rnn_cell
(
rnn_inputs
,
rnn_state
)
# Create the prior and approximate posterior distributions.
latent_dist_prior
=
self
.
prior
(
rnn_out
)
latent_dist_q
=
self
.
approx_posterior
(
rnn_out
,
targets_encoded
,
prior_mu
=
latent_dist_prior
.
loc
)
# Sample the new latent state z and encode it.
latent_state
=
latent_dist_q
.
sample
(
seed
=
self
.
random_seed
)
latent_encoded
=
self
.
latent_feat_extractor
(
latent_state
)
# Calculate probabilities of the latent state according to the prior p
# and approximate posterior q.
log_q_z
=
tf
.
reduce_sum
(
latent_dist_q
.
log_prob
(
latent_state
),
axis
=-
1
)
log_p_z
=
tf
.
reduce_sum
(
latent_dist_prior
.
log_prob
(
latent_state
),
axis
=-
1
)
analytic_kl
=
tf
.
reduce_sum
(
tf
.
contrib
.
distributions
.
kl_divergence
(
latent_dist_q
,
latent_dist_prior
),
axis
=-
1
)
# Create the generative dist. and calculate the logprob of the targets.
generative_dist
=
self
.
generative
(
latent_encoded
,
rnn_out
)
log_p_x_given_z
=
tf
.
reduce_sum
(
generative_dist
.
log_prob
(
targets
),
axis
=-
1
)
return
(
log_q_z
,
log_p_z
,
log_p_x_given_z
,
analytic_kl
,
(
new_rnn_state
,
latent_encoded
))
# Default variable initializers used when a caller supplies none:
# Xavier/Glorot for weights ("w"), zeros for biases ("b").
_DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
                         "b": tf.zeros_initializer()}
def
create_vrnn
(
data_size
,
latent_size
,
generative_class
,
rnn_hidden_size
=
None
,
fcnet_hidden_sizes
=
None
,
encoded_data_size
=
None
,
encoded_latent_size
=
None
,
sigma_min
=
0.0
,
raw_sigma_bias
=
0.25
,
generative_bias_init
=
0.0
,
initializers
=
None
,
random_seed
=
None
):
"""A factory method for creating VRNN cells.
raise
NotImplementedError
(
"propose_and_weight not yet implemented."
)
Args:
data_size: The dimension of the vectors that make up the data sequences.
latent_size: The size of the stochastic latent state of the VRNN.
generative_class: The class of the generative distribution. Can be either
ConditionalNormalDistribution or ConditionalBernoulliDistribution.
rnn_hidden_size: The hidden state dimension of the RNN that forms the
deterministic part of this VRNN. If None, then it defaults
to latent_size.
fcnet_hidden_sizes: A list of python integers, the size of the hidden
layers of the fully connected networks that parameterize the conditional
distributions of the VRNN. If None, then it defaults to one hidden
layer of size latent_size.
encoded_data_size: The size of the output of the data encoding network. If
None, defaults to latent_size.
encoded_latent_size: The size of the output of the latent state encoding
network. If None, defaults to latent_size.
sigma_min: The minimum value that the standard deviation of the
distribution over the latent state can take.
raw_sigma_bias: A scalar that is added to the raw standard deviation
output from the neural networks that parameterize the prior and
approximate posterior. Useful for preventing standard deviations close
to zero.
generative_bias_init: A bias to added to the raw output of the fully
connected network that parameterizes the generative distribution. Useful
for initalizing the mean of the distribution to a sensible starting point
such as the mean of the training data. Only used with Bernoulli generative
distributions.
initializers: The variable intitializers to use for the fully connected
networks and RNN cell. Must be a dictionary mapping the keys 'w' and 'b'
to the initializers for the weights and biases. Defaults to xavier for
the weights and zeros for the biases when initializers is None.
random_seed: A random seed for the VRNN resampling operations.
Returns:
model: A VRNNCell object.
"""
if
rnn_hidden_size
is
None
:
rnn_hidden_size
=
latent_size
if
fcnet_hidden_sizes
is
None
:
fcnet_hidden_sizes
=
[
latent_size
]
if
encoded_data_size
is
None
:
encoded_data_size
=
latent_size
if
encoded_latent_size
is
None
:
encoded_latent_size
=
latent_size
if
initializers
is
None
:
initializers
=
_DEFAULT_INITIALIZERS
data_feat_extractor
=
snt
.
nets
.
MLP
(
output_sizes
=
fcnet_hidden_sizes
+
[
encoded_data_size
],
initializers
=
initializers
,
name
=
"data_feat_extractor"
)
latent_feat_extractor
=
snt
.
nets
.
MLP
(
output_sizes
=
fcnet_hidden_sizes
+
[
encoded_latent_size
],
initializers
=
initializers
,
name
=
"latent_feat_extractor"
)
prior
=
ConditionalNormalDistribution
(
size
=
latent_size
,
hidden_layer_sizes
=
fcnet_hidden_sizes
,
sigma_min
=
sigma_min
,
raw_sigma_bias
=
raw_sigma_bias
,
initializers
=
initializers
,
name
=
"prior"
)
approx_posterior
=
NormalApproximatePosterior
(
size
=
latent_size
,
hidden_layer_sizes
=
fcnet_hidden_sizes
,
sigma_min
=
sigma_min
,
raw_sigma_bias
=
raw_sigma_bias
,
initializers
=
initializers
,
name
=
"approximate_posterior"
)
if
generative_class
==
ConditionalBernoulliDistribution
:
generative
=
ConditionalBernoulliDistribution
(
size
=
data_size
,
hidden_layer_sizes
=
fcnet_hidden_sizes
,
initializers
=
initializers
,
bias_init
=
generative_bias_init
,
name
=
"generative"
)
else
:
generative
=
ConditionalNormalDistribution
(
size
=
data_size
,
hidden_layer_sizes
=
fcnet_hidden_sizes
,
initializers
=
initializers
,
name
=
"generative"
)
rnn_cell
=
tf
.
nn
.
rnn_cell
.
LSTMCell
(
rnn_hidden_size
,
initializer
=
initializers
[
"w"
])
return
VRNNCell
(
rnn_cell
,
data_feat_extractor
,
latent_feat_extractor
,
prior
,
approx_posterior
,
generative
,
random_seed
=
random_seed
)
# Default variable initializers used when a caller supplies none:
# Xavier/Glorot for weights ("w"), zeros for biases ("b").
DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
                        "b": tf.zeros_initializer()}
class
ConditionalNormalDistribution
(
object
):
...
...
@@ -328,8 +113,9 @@ class ConditionalNormalDistribution(object):
self
.
sigma_min
=
sigma_min
self
.
raw_sigma_bias
=
raw_sigma_bias
self
.
name
=
name
self
.
size
=
size
if
initializers
is
None
:
initializers
=
_
DEFAULT_INITIALIZERS
initializers
=
DEFAULT_INITIALIZERS
self
.
fcnet
=
snt
.
nets
.
MLP
(
output_sizes
=
hidden_layer_sizes
+
[
2
*
size
],
activation
=
hidden_activation_fn
,
...
...
@@ -378,8 +164,9 @@ class ConditionalBernoulliDistribution(object):
name: The name of this distribution, used for sonnet scoping.
"""
self
.
bias_init
=
bias_init
self
.
size
=
size
if
initializers
is
None
:
initializers
=
_
DEFAULT_INITIALIZERS
initializers
=
DEFAULT_INITIALIZERS
self
.
fcnet
=
snt
.
nets
.
MLP
(
output_sizes
=
hidden_layer_sizes
+
[
size
],
activation
=
hidden_activation_fn
,
...
...
@@ -401,7 +188,18 @@ class ConditionalBernoulliDistribution(object):
class
NormalApproximatePosterior
(
ConditionalNormalDistribution
):
"""A Normally-distributed approx. posterior with res_q parameterization."""
def
condition
(
self
,
tensor_list
,
prior_mu
):
def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
             raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
             initializers=None, smoothing=False,
             name="conditional_normal_distribution"):
  """Creates a NormalApproximatePosterior.

  All arguments except `smoothing` are forwarded unchanged to
  ConditionalNormalDistribution.__init__.

  Args:
    size: The dimension of the distribution.
    hidden_layer_sizes: Sizes of the hidden layers of the underlying network.
    sigma_min: Minimum value of the standard deviation.
    raw_sigma_bias: Bias added to the raw standard deviation output.
    hidden_activation_fn: Activation for the hidden layers.
    initializers: Variable initializers for the network; when None the
      superclass substitutes its defaults.
    smoothing: Python bool. When True, `condition` will append the provided
      smoothing_tensors to its conditioning inputs.
    name: Name used for sonnet scoping.
  """
  super(NormalApproximatePosterior, self).__init__(
      size, hidden_layer_sizes, sigma_min=sigma_min,
      raw_sigma_bias=raw_sigma_bias,
      hidden_activation_fn=hidden_activation_fn, initializers=initializers,
      name=name)
  # Whether condition() should incorporate extra smoothing inputs.
  self.smoothing = smoothing
def condition(self, tensor_list, prior_mu, smoothing_tensors=None):
  """Generates the mean and variance of the normal distribution.

  Args:
    tensor_list: A list of Tensors to condition on. Not mutated; when
      smoothing is enabled a new list is built instead of extending the
      caller's list in place (the original implementation mutated its
      argument, a surprising side effect for callers that reuse the list).
    prior_mu: The mean of the prior distribution associated with this
      approximate posterior. Will be added to the mean produced by
      this approximate posterior, in res_q fashion.
    smoothing_tensors: A list of Tensors. If smoothing is True, these Tensors
      will be concatenated with the tensors in tensor_list.

  Returns:
    mu: The mean of the approximate posterior.
    sigma: The standard deviation of the approximate posterior.
  """
  if self.smoothing:
    # Build a fresh list so the caller's tensor_list is left untouched.
    tensor_list = list(tensor_list) + list(smoothing_tensors)
  mu, sigma = super(NormalApproximatePosterior, self).condition(tensor_list)
  # res_q parameterization: the network predicts a residual from the prior.
  return mu + prior_mu, sigma
class NonstationaryLinearDistribution(object):
  """A set of loc-scale distributions that are linear functions of inputs.

  This class defines a series of location-scale distributions such that
  the means are learnable linear functions of the inputs and the log variances
  are learnable constants. The functions and log variances are different across
  timesteps, allowing the distributions to be nonstationary.
  """

  def __init__(self, num_timesteps, inputs_per_timestep=None,
               outputs_per_timestep=None, initializers=None,
               variance_min=0.0, output_distribution=tfd.Normal,
               dtype=tf.float32):
    """Creates a NonstationaryLinearDistribution.

    Args:
      num_timesteps: The number of timesteps, i.e. the number of distributions.
      inputs_per_timestep: A list of python ints, the dimension of inputs to
        the linear function at each timestep. If not provided, the dimension
        at each timestep is assumed to be 1.
      outputs_per_timestep: A list of python ints, the dimension of the output
        distribution at each timestep. If not provided, the dimension at each
        timestep is assumed to be 1.
      initializers: A dictionary containing initializers for the variables.
        The initializer under the key 'w' is used for the weights in the
        linear function and the initializer under the key 'b' is used for the
        biases. Defaults to xavier initialization for the weights and zeros
        for the biases.
      variance_min: Python float, the minimum variance of each distribution.
      output_distribution: A location-scale subclass of tfd.Distribution that
        defines the output distribution, e.g. Normal.
      dtype: The dtype of the weights and biases.
    """
    if not initializers:
      initializers = DEFAULT_INITIALIZERS
    if not inputs_per_timestep:
      inputs_per_timestep = [1] * num_timesteps
    if not outputs_per_timestep:
      outputs_per_timestep = [1] * num_timesteps
    self.num_timesteps = num_timesteps
    self.variance_min = variance_min
    self.initializers = initializers
    self.dtype = dtype
    self.output_distribution = output_distribution

    def _get_variables_ta(shapes, name, initializer, trainable=True):
      """Creates a sequence of variables and stores them in a TensorArray."""
      # Infer shape if all shapes are equal.
      first_shape = shapes[0]
      infer_shape = all(shape == first_shape for shape in shapes)
      ta = tf.TensorArray(
          dtype=dtype, size=len(shapes), dynamic_size=False,
          clear_after_read=False, infer_shape=infer_shape)
      for t, shape in enumerate(shapes):
        var = tf.get_variable(
            name % t, shape=shape, initializer=initializer,
            trainable=trainable)
        ta = ta.write(t, var)
      return ta

    bias_shapes = [[num_outputs] for num_outputs in outputs_per_timestep]
    self.log_variances = _get_variables_ta(
        bias_shapes, "proposal_log_variance_%d", initializers["b"])
    self.mean_biases = _get_variables_ta(
        bias_shapes, "proposal_b_%d", initializers["b"])
    # NOTE: materialize the zip -- on Python 3 zip() returns a one-shot
    # iterator, which would break len()/indexing in _get_variables_ta and
    # would already be consumed by the time unstack() below reads it.
    weight_shapes = list(zip(inputs_per_timestep, outputs_per_timestep))
    self.mean_weights = _get_variables_ta(
        weight_shapes, "proposal_w_%d", initializers["w"])
    self.shapes = tf.TensorArray(
        dtype=tf.int32, size=num_timesteps, dynamic_size=False,
        clear_after_read=False).unstack(weight_shapes)

  def __call__(self, t, inputs):
    """Computes the distribution at timestep t.

    Args:
      t: Scalar integer Tensor, the current timestep. Must be in
        [0, num_timesteps).
      inputs: The inputs to the linear function parameterizing the mean of
        the current distribution. A Tensor of shape [batch_size, num_inputs_t].

    Returns:
      A tfd.Distribution subclass representing the distribution at timestep t.
    """
    b = self.mean_biases.read(t)
    w = self.mean_weights.read(t)
    shape = self.shapes.read(t)
    # Weights were stored flat; restore the [num_inputs, num_outputs] shape
    # recorded at construction time, and make the bias a column vector.
    w = tf.reshape(w, shape)
    b = tf.reshape(b, [shape[1], 1])
    log_variance = self.log_variances.read(t)
    # Clamp the variance from below before taking the square root.
    scale = tf.sqrt(tf.maximum(tf.exp(log_variance), self.variance_min))
    loc = tf.matmul(w, inputs, transpose_a=True) + b
    return self.output_distribution(loc=loc, scale=scale)
def encode_all(inputs, encoder):
  """Encodes a timeseries of inputs with a time independent encoder.

  Args:
    inputs: A [time, batch, feature_dimensions] tensor.
    encoder: A network that takes a [batch, features_dimensions] input and
      encodes the input.

  Returns:
    A [time, batch, encoded_feature_dimensions] output tensor.
  """
  dims = tf.shape(inputs)
  time_steps, batch = dims[0], dims[1]
  # Fold time into the batch dimension so the encoder sees a 2-D input,
  # then restore the [time, batch, ...] layout afterwards.
  flat = tf.reshape(inputs, [-1, inputs.shape[-1]])
  encoded_flat = encoder(flat)
  return tf.reshape(encoded_flat, [time_steps, batch, encoder.output_size])
def ta_for_tensor(x, **kwargs):
  """Creates a TensorArray holding the slices of `x` along its first axis."""
  num_elems = tf.shape(x)[0]
  ta = tf.TensorArray(x.dtype, num_elems, dynamic_size=False, **kwargs)
  return ta.unstack(x)
research/fivo/fivo/models/ghmm.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Gaussian hidden markov model.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
fivo.models
import
base
tfd
=
tf
.
contrib
.
distributions
class
GaussianHMM
(
object
):
"""A hidden markov model with 1-D Gaussian latent space and observations.
This is a hidden markov model where the state and observations are
one-dimensional Gaussians. The mean of each latent state is a linear
function of the previous latent state, and the mean of each observation
is a linear function of the current latent state.
The description that follows is 0-indexed instead of 1-indexed to make
it easier to reason about the parameters passed to the model.
The parameters of the model are:
T: The number timesteps, latent states, and observations.
vz_t, t=0 to T-1: The variance of the latent state at timestep t.
vx_t, t=0 to T-1: The variance of the observation at timestep t.
wz_t, t=1 to T-1: The weight that defines the latent transition at t.
wx_t, t=0 to T-1: The weight that defines the observation function at t.
There are T vz_t, vx_t, and wx_t but only T-1 wz_t because there are only
T-1 transitions in the model.
Given these parameters, sampling from the model is defined as
z_0 ~ N(0, vz_0)
x_0 | z_0 ~ N(wx_0 * z_0, vx_0)
z_1 | z_0 ~ N(wz_1 * z_0, vz_1)
x_1 | z_1 ~ N(wx_1 * z_1, vx_1)
...
z_{T-1} | z_{T-2} ~ N(wz_{T-1} * z_{T-2}, vz_{T-1})
x_{T-1} | z_{T-1} ~ N(wx_{T-1} * z_{T-1}, vx_{T-1}).
"""
def __init__(self, num_timesteps, transition_variances=1.,
             emission_variances=1., transition_weights=1.,
             emission_weights=1., dtype=tf.float32):
  """Creates a gaussian hidden markov model.

  Args:
    num_timesteps: A python int, the number of timesteps in the model.
    transition_variances: The variance of p(z_t | z_t-1). Can be a scalar,
      setting all variances to be the same, or a Tensor of shape
      [num_timesteps].
    emission_variances: The variance of p(x_t | z_t). Can be a scalar,
      setting all variances to be the same, or a Tensor of shape
      [num_timesteps].
    transition_weights: The weight that defines the linear function that
      produces the mean of z_t given z_{t-1}. Can be a scalar, setting
      all weights to be the same, or a Tensor of shape [num_timesteps-1].
    emission_weights: The weight that defines the linear function that
      produces the mean of x_t given z_t. Can be a scalar, setting
      all weights to be the same, or a Tensor of shape [num_timesteps].
    dtype: The datatype of the state.
  """
  self.num_timesteps = num_timesteps
  self.dtype = dtype

  def _broadcast(value, length):
    """Converts value to a rank-1 Tensor of `length`, tiling scalars."""
    tensor = tf.convert_to_tensor(value, dtype=self.dtype)
    if not tensor.get_shape().as_list():
      tensor = tf.tile(tensor[tf.newaxis], [length])
    return tensor

  def _to_ta(tensor):
    """Unstacks a rank-1 Tensor into a statically-sized TensorArray."""
    length = tf.shape(tensor)[0]
    return tf.TensorArray(
        dtype=tensor.dtype, size=length, dynamic_size=False,
        clear_after_read=False).unstack(tensor)

  # Note: there are only num_timesteps - 1 transitions, hence one fewer
  # transition weight than the other parameter sequences.
  self.transition_variances = _to_ta(
      _broadcast(transition_variances, num_timesteps))
  self.transition_weights = _to_ta(
      _broadcast(transition_weights, num_timesteps - 1))
  emiss_var = _broadcast(emission_variances, num_timesteps)
  emiss_w = _broadcast(emission_weights, num_timesteps)
  self.emission_variances = _to_ta(emiss_var)
  self.emission_weights = _to_ta(emiss_w)
  # Precompute the joint covariance structure of latents and observations.
  self._compute_covariances(emiss_w, emiss_var)
def _compute_covariances(self, emission_weights, emission_variances):
  """Compute all covariance matrices.

  Computes the covariance matrix for the latent variables, the observations,
  and the covariance between the latents and observations. Stores the results
  in self.sigma_z, self.sigma_x, self.sigma_zx, and builds self.obs_dist, the
  zero-mean multivariate normal over the full observation sequence.

  Args:
    emission_weights: A Tensor of shape [num_timesteps] containing
      the emission distribution weights at each timestep.
    emission_variances: A Tensor of shape [num_timesteps] containing
      the emission distribution variances at each timestep.
  """
  # Compute the marginal variance of each latent:
  # var(z_t) = var(z_{t-1}) * wz_t^2 + vz_t, starting from var(z_0) = vz_0.
  z_variances = [self.transition_variances.read(0)]
  for i in range(1, self.num_timesteps):
    z_variances.append(
        z_variances[i-1] * tf.square(self.transition_weights.read(i-1)) +
        self.transition_variances.read(i))
  # Compute the latent covariance matrix. Off-diagonal entries are the
  # variance of the earlier latent scaled by the product of the transition
  # weights linking the two timesteps.
  sigma_z = []
  for i in range(self.num_timesteps):
    sigma_z_row = []
    for j in range(self.num_timesteps):
      if i == j:
        sigma_z_row.append(z_variances[i])
        continue
      min_ind = min(i, j)
      max_ind = max(i, j)
      # Product of the transition weights between min_ind and max_ind.
      weight = tf.reduce_prod(
          self.transition_weights.gather(tf.range(min_ind, max_ind)))
      sigma_z_row.append(z_variances[min_ind] * weight)
    sigma_z.append(tf.stack(sigma_z_row))
  self.sigma_z = tf.stack(sigma_z)
  # Compute the observation covariance matrix:
  # cov(x_i, x_j) = wx_i * wx_j * cov(z_i, z_j), plus vx_i on the diagonal.
  x_weights_outer = tf.einsum("i,j->ij", emission_weights, emission_weights)
  self.sigma_x = x_weights_outer * self.sigma_z + tf.diag(emission_variances)
  # Compute the latent - observation covariance matrix.
  # The first axis will index latents, the second axis will index observtions.
  self.sigma_zx = emission_weights[tf.newaxis, :] * self.sigma_z
  # Marginal distribution over the observation sequence; zero mean because
  # all latents and emissions are zero-mean linear-Gaussian.
  self.obs_dist = tfd.MultivariateNormalFullCovariance(
      loc=tf.zeros([self.num_timesteps], dtype=tf.float32),
      covariance_matrix=self.sigma_x)
def transition(self, t, z_prev):
  """Compute the transition distribution p(z_t | z_t-1).

  Args:
    t: The current timestep, a scalar integer Tensor. When t=0 z_prev is
      mostly ignored and the distribution p(z_0) is returned. z_prev is
      'mostly' ignored because it is still used to derive batch_size.
    z_prev: A [batch_size] set of states.

  Returns:
    p(z_t | z_t-1) as a univariate normal distribution.
  """
  n = tf.shape(z_prev)[0]
  std = tf.tile(tf.sqrt(self.transition_variances.read(t))[tf.newaxis], [n])

  def _conditional_mean():
    # Transition weight index is t-1 because there are T-1 transitions.
    return self.transition_weights.read(t - 1) * z_prev

  def _initial_mean():
    # At t=0 the prior over z_0 is zero-mean.
    return tf.zeros_like(std)

  mean = tf.cond(tf.greater(t, 0), _conditional_mean, _initial_mean)
  return tfd.Normal(loc=mean, scale=std)
def emission(self, t, z):
  """Compute the emission distribution p(x_t | z_t).

  Args:
    t: The current timestep, a scalar integer Tensor.
    z: A [batch_size] set of the current states.

  Returns:
    p(x_t | z_t) as a univariate normal distribution.
  """
  n = tf.shape(z)[0]
  # Broadcast the scalar per-timestep std to the batch.
  std = tf.tile(tf.sqrt(self.emission_variances.read(t))[tf.newaxis], [n])
  mean = self.emission_weights.read(t) * z
  return tfd.Normal(loc=mean, scale=std)
def filtering(self, t, z_prev, x_cur):
  """Computes the filtering distribution p(z_t | z_{t-1}, x_t).

  Args:
    t: A python int, the index for z_t. When t is 0, z_prev is ignored,
      giving p(z_0 | x_0).
    z_prev: z_{t-1}, the previous z to condition on. A Tensor of shape
      [batch_size].
    x_cur: x_t, the current x to condition on. A Tensor of shape [batch_size].

  Returns:
    p(z_t | z_{t-1}, x_t) as a univariate normal distribution.
  """
  z_prev = tf.convert_to_tensor(z_prev)
  x_cur = tf.convert_to_tensor(x_cur)
  batch_size = tf.shape(z_prev)[0]
  z_var = self.transition_variances.read(t)
  x_var = self.emission_variances.read(t)
  x_weight = self.emission_weights.read(t)
  # Common denominator of the Gaussian conditioning update. Hoisted so the
  # graph computes it once instead of three separate times (the original
  # rebuilt this expression for prev_state_weight, cur_obs_weight and scale).
  denom = tf.square(x_weight) * z_var + x_var
  prev_state_weight = x_var / denom
  # At t=0 there is no previous state: zero out its contribution; otherwise
  # scale by the transition weight linking z_{t-1} to z_t.
  prev_state_weight *= tf.cond(
      tf.greater(t, 0),
      lambda: self.transition_weights.read(t - 1),
      lambda: tf.zeros_like(prev_state_weight))
  cur_obs_weight = (x_weight * z_var) / denom
  loc = prev_state_weight * z_prev + cur_obs_weight * x_cur
  scale = tf.sqrt((z_var * x_var) / denom)
  scale = tf.tile(scale[tf.newaxis], [batch_size])
  return tfd.Normal(loc=loc, scale=scale)
def smoothing(self, t, z_prev, xs):
  """Computes the smoothing distribution p(z_t | z_{t-1}, x_{t:num_timesteps).

  Args:
    t: A python int, the index for z_t. When t is 0, z_prev is ignored,
      giving p(z_0 | x_{0:num_timesteps-1}).
    z_prev: z_{t-1}, the previous z to condition on. A Tensor of shape
      [batch_size].
    xs: x_{t:num_timesteps}, the future xs to condition on. A Tensor of shape
      [num_timesteps - t, batch_size].

  Returns:
    p(z_t | z_{t-1}, x_{t:num_timesteps}) as a univariate normal distribution.
  """
  xs = tf.convert_to_tensor(xs)
  z_prev = tf.convert_to_tensor(z_prev)
  batch_size = tf.shape(xs)[1]
  # If there are future observations beyond t, run the backwards message
  # pass over xs[1:] and read the message arriving at z_t; otherwise the
  # incoming message is vacuous (zero mean and zero precision).
  mess_mean, mess_prec = tf.cond(
      tf.less(t, self.num_timesteps - 1),
      lambda: tf.unstack(self._compute_backwards_messages(xs[1:]).read(0)),
      lambda: [tf.zeros([batch_size]), tf.zeros([batch_size])])
  # Combine the incoming message with the local evidence x_t and z_{t-1}.
  return self._smoothing_from_message(t, z_prev, xs[0], mess_mean, mess_prec)
def _smoothing_from_message(self, t, z_prev, x_t, mess_mean, mess_prec):
  """Computes the smoothing distribution given message incoming to z_t.

  Computes p(z_t | z_{t-1}, x_{t:num_timesteps}) given the message incoming
  to the node for z_t. Works in information (natural-parameter) form:
  precisions add, and means are accumulated pre-division then normalized by
  the total precision at the end.

  Args:
    t: A python int, the index for z_t. When t is 0, z_prev is ignored.
    z_prev: z_{t-1}, the previous z to condition on. A Tensor of shape
      [batch_size].
    x_t: The observation x at timestep t.
    mess_mean: The mean of the message incoming to z_t, in information form.
    mess_prec: The precision of the message incoming to z_t.

  Returns:
    p(z_t | z_{t-1}, x_{t:num_timesteps}) as a univariate normal distribution.
  """
  batch_size = tf.shape(x_t)[0]
  z_var = self.transition_variances.read(t)
  x_var = self.emission_variances.read(t)
  w_x = self.emission_weights.read(t)

  def transition_term():
    # Precision contribution from the transition into z_{t+1}.
    return (tf.square(self.transition_weights.read(t)) /
            self.transition_variances.read(t + 1))

  # Total precision: prior/transition term + emission term + message term.
  prec = 1. / z_var + tf.square(w_x) / x_var + mess_prec
  # Only include the forward-transition term when z_{t+1} exists.
  prec += tf.cond(tf.less(t, self.num_timesteps - 1),
                  transition_term, lambda: 0.)
  # Information-form mean: emission evidence plus the incoming message.
  mean = x_t * (w_x / x_var) + mess_mean
  # Add the contribution from z_{t-1} unless t is 0 (no previous state).
  mean += tf.cond(tf.greater(t, 0),
                  lambda: z_prev * (self.transition_weights.read(t - 1) /
                                    z_var),
                  lambda: 0.)
  # Convert back from information form: mean / precision, std = 1/sqrt(prec).
  mean = tf.reshape(mean / prec, [batch_size])
  scale = tf.reshape(tf.sqrt(1. / prec), [batch_size])
  return tfd.Normal(loc=mean, scale=scale)
  def _compute_backwards_messages(self, xs):
    """Computes the backwards messages used in smoothing.

    Runs a backwards pass over the observations xs = x_{until_t:num_timesteps},
    producing for each timestep the Gaussian message (in information form)
    passed from z_t back to z_{t-1}.

    Args:
      xs: A [num_xs, batch_size] Tensor of future observations, where
        num_xs = num_timesteps - until_t.
    Returns:
      A TensorArray of length num_xs; element i is a [2, batch_size] Tensor
      stacking (message mean, message precision) for the message into
      z_{until_t + i - 1} — see the `write(t - until_t, ...)` below.
    """
    batch_size = tf.shape(xs)[1]
    num_xs = tf.shape(xs)[0]
    # Index of the first timestep covered by xs.
    until_t = self.num_timesteps - num_xs
    # Stage the observations in a TensorArray so the while_loop can read one
    # timestep at a time; each x is read exactly once, so clear_after_read
    # frees it.
    xs = tf.TensorArray(dtype=xs.dtype,
                        size=num_xs,
                        dynamic_size=False,
                        clear_after_read=True).unstack(xs)
    messages_ta = tf.TensorArray(dtype=xs.dtype,
                                 size=num_xs,
                                 dynamic_size=False,
                                 clear_after_read=False)

    def compute_message(t, prev_mean, prev_prec, messages_ta):
      """Computes one step of the backwards messages."""
      z_var = self.transition_variances.read(t)
      w_z = self.transition_weights.read(t - 1)
      x_var = self.emission_variances.read(t)
      w_x = self.emission_weights.read(t)
      cur_x = xs.read(t - until_t)

      # If it isn't the first message, add the terms from the transition.
      def transition_term():
        return (tf.square(self.transition_weights.read(t)) /
                self.transition_variances.read(t + 1))

      # Local (unary) potential at z_t: prior transition precision plus
      # emission evidence. NOTE(review): `1 / z_var` here vs `1. / z_var` in
      # _smoothing_from_message — same value, just inconsistent literals.
      unary_prec = 1 / z_var + tf.square(w_x) / x_var
      unary_prec += tf.cond(tf.less(t, self.num_timesteps - 1),
                            transition_term,
                            lambda: 0.)

      # Information-form mean of the emission evidence at z_t.
      unary_mean = (w_x / x_var) * cur_x
      # Pairwise coupling between z_{t-1} and z_t.
      pairwise_prec = w_z / z_var

      # Marginalize z_t out of (unary potential + incoming message) to get the
      # message sent to z_{t-1}, still in information form.
      next_prec = -tf.square(pairwise_prec) / (unary_prec + prev_prec)
      next_mean = (pairwise_prec * (unary_mean + prev_mean) /
                   (unary_prec + prev_prec))
      next_prec = tf.reshape(next_prec, [batch_size])
      next_mean = tf.reshape(next_mean, [batch_size])
      # Slot t - until_t holds the message flowing into node z_{t-1}.
      messages_ta = messages_ta.write(t - until_t,
                                      tf.stack([next_mean, next_prec]))
      return t - 1, next_mean, next_prec, messages_ta

    def pred(t, *unused_args):
      # Iterate backwards from the last timestep down to until_t (inclusive).
      return tf.greater_equal(t, until_t)

    # The pass starts with an uninformative (zero-precision) message.
    init_prec = tf.zeros([batch_size], dtype=xs.dtype)
    init_mean = tf.zeros([batch_size], dtype=xs.dtype)
    t0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)

    outs = tf.while_loop(pred, compute_message,
                         (t0, init_mean, init_prec, messages_ta))
    # Only the accumulated TensorArray of messages is needed by callers.
    messages = outs[-1]
    return messages
def
lookahead
(
self
,
t
,
z_prev
):
"""Compute the 'lookahead' distribution, p(x_{t:T} | z_{t-1}).
Args:
t: A scalar Tensor int, the current timestep. Must be at least 1.
z_prev: The latent state at time t-1. A Tensor of shape [batch_size].
Returns:
p(x_{t:T} | z_{t-1}) as a multivariate normal distribution.
"""
z_prev
=
tf
.
convert_to_tensor
(
z_prev
)
sigma_zx
=
self
.
sigma_zx
[
t
-
1
,
t
:]
z_var
=
self
.
sigma_z
[
t
-
1
,
t
-
1
]
mean
=
tf
.
einsum
(
"i,j->ij"
,
z_prev
,
sigma_zx
)
/
z_var
variance
=
(
self
.
sigma_x
[
t
:,
t
:]
-
tf
.
einsum
(
"i,j->ij"
,
sigma_zx
,
sigma_zx
)
/
z_var
)
return
tfd
.
MultivariateNormalFullCovariance
(
loc
=
mean
,
covariance_matrix
=
variance
)
def
likelihood
(
self
,
xs
):
"""Compute the true marginal likelihood of the data.
Args:
xs: The observations, a [num_timesteps, batch_size] float Tensor.
Returns:
likelihoods: A [batch_size] float Tensor representing the likelihood of
each sequence of observations in the batch.
"""
return
self
.
obs_dist
.
log_prob
(
tf
.
transpose
(
xs
))
class TrainableGaussianHMM(GaussianHMM, base.ELBOTrainableSequenceModel):
  """An interface between importance-sampling training methods and the GHMM.

  Wraps GaussianHMM with a per-timestep proposal distribution q(z_t | ...) so
  that importance-sampling bounds (ELBO/IWAE/FIVO-style) can be computed via
  propose_and_weight.
  """

  def __init__(self, num_timesteps, proposal_type,
               transition_variances=1.,
               emission_variances=1.,
               transition_weights=1.,
               emission_weights=1.,
               random_seed=None,
               dtype=tf.float32):
    """Constructs a trainable Gaussian HMM.

    Args:
      num_timesteps: A python int, the number of timesteps in the model.
      proposal_type: The type of proposal to use in the importance sampling
        setup. Must be one of "filtering", "smoothing", "prior",
        "true-filtering", or "true-smoothing" (enforced by the assert below).
        If "true-filtering" or "true-smoothing" are selected, then the true
        filtering or smoothing distributions are used to propose new states.
        If "filtering" is selected then a distribution with learnable
        parameters is used. Specifically at each timestep the proposal is
        Gaussian with mean that is a learnable linear function of the previous
        state and current observation. The log variance is a per-timestep
        learnable constant. "smoothing" is similar, but the mean is a
        learnable linear function of the previous state and all future
        observations. Note that this proposal class includes the true
        posterior. If "prior" is selected then states are proposed from the
        model's prior.
      transition_variances: The variance of p(z_t | z_t-1). Can be a scalar,
        setting all variances to be the same, or a Tensor of shape
        [num_timesteps].
      emission_variances: The variance of p(x_t | z_t). Can be a scalar,
        setting all variances to be the same, or a Tensor of shape
        [num_timesteps].
      transition_weights: The weight that defines the linear function that
        produces the mean of z_t given z_{t-1}. Can be a scalar, setting
        all weights to be the same, or a Tensor of shape [num_timesteps-1].
      emission_weights: The weight that defines the linear function that
        produces the mean of x_t given z_t. Can be a scalar, setting
        all weights to be the same, or a Tensor of shape [num_timesteps].
      random_seed: A seed for the proposal sampling, mainly useful for testing.
      dtype: The datatype of the state.
    """
    super(TrainableGaussianHMM, self).__init__(
        num_timesteps, transition_variances, emission_variances,
        transition_weights, emission_weights, dtype=dtype)
    self.random_seed = random_seed
    assert proposal_type in ["filtering", "smoothing", "prior",
                             "true-filtering", "true-smoothing"]
    # Bind self.proposal to the callable matching the requested proposal type.
    if proposal_type == "true-filtering":
      self.proposal = self._filtering_proposal
    elif proposal_type == "true-smoothing":
      self.proposal = self._smoothing_proposal
    elif proposal_type == "prior":
      self.proposal = self.transition
    elif proposal_type == "filtering":
      # At t=0 the proposal sees only x_0 (1 input); afterwards it sees
      # (z_{t-1}, x_t) — 2 inputs per timestep.
      self._learned_proposal_fn = base.NonstationaryLinearDistribution(
          num_timesteps,
          inputs_per_timestep=[1] + [2] * (num_timesteps - 1))
      self.proposal = self._learned_filtering_proposal
    elif proposal_type == "smoothing":
      # At t=0 the proposal sees all observations x_{0:T-1}; afterwards it
      # sees z_{t-1} plus the remaining observations x_{t:T-1}.
      inputs_per_timestep = [num_timesteps] + [num_timesteps - t
                                               for t in range(num_timesteps - 1)]
      self._learned_proposal_fn = base.NonstationaryLinearDistribution(
          num_timesteps, inputs_per_timestep=inputs_per_timestep)
      self.proposal = self._learned_smoothing_proposal

  def set_observations(self, xs, seq_lengths):
    """Sets the observations and stores the backwards messages.

    Args:
      xs: The observations; squeezed below to [num_timesteps, batch_size]
        since everything in this model is 1-d.
      seq_lengths: Per-sequence lengths, forwarded to the superclass.
    """
    # Squeeze out data dimension since everything is 1-d.
    xs = tf.squeeze(xs)
    self.batch_size = tf.shape(xs)[1]
    # NOTE(review): the superclass presumably stores self.observations /
    # self.observations_ta used by the proposal methods below — defined
    # outside this file, verify in base.ELBOTrainableSequenceModel.
    super(TrainableGaussianHMM, self).set_observations(xs, seq_lengths)
    # Precompute backwards messages once so smoothing proposals can reuse them.
    self.messages = self._compute_backwards_messages(xs[1:])

  def zero_state(self, batch_size, dtype):
    # The latent state is a scalar per batch element.
    return tf.zeros([batch_size], dtype=dtype)

  def propose_and_weight(self, state, t):
    """Computes the next state and log weights for the GHMM.

    Samples z_t from the proposal and returns the importance weight
    log p(z_t | z_{t-1}) + log p(x_t | z_t) - log q(z_t), along with z_t.
    """
    state_shape = tf.shape(state)
    xt = self.observations[t]
    p_zt = self.transition(t, state)
    q_zt = self.proposal(t, state)
    zt = q_zt.sample(seed=self.random_seed)
    # Some proposals may return samples with extra/collapsed dims; force the
    # sample back to the state's shape.
    zt = tf.reshape(zt, state_shape)
    p_xt_given_zt = self.emission(t, zt)
    log_p_zt = p_zt.log_prob(zt)
    log_q_zt = q_zt.log_prob(zt)
    log_p_xt_given_zt = p_xt_given_zt.log_prob(xt)
    weight = log_p_zt + log_p_xt_given_zt - log_q_zt
    return weight, zt

  def _filtering_proposal(self, t, state):
    """Uses the stored observations to compute the filtering distribution."""
    cur_x = self.observations[t]
    return self.filtering(t, state, cur_x)

  def _smoothing_proposal(self, t, state):
    """Uses the stored messages to compute the smoothing distribution."""
    # Same message/no-message structure as GaussianHMM.smoothing, but reading
    # from the messages precomputed in set_observations.
    mess_mean, mess_prec = tf.cond(
        tf.less(t, self.num_timesteps - 1),
        lambda: tf.unstack(self.messages.read(t)),
        lambda: [tf.zeros([self.batch_size]), tf.zeros([self.batch_size])])
    return self._smoothing_from_message(t, state, self.observations[t],
                                        mess_mean, mess_prec)

  def _learned_filtering_proposal(self, t, state):
    """Learned proposal conditioned on z_{t-1} (if t > 0) and x_t."""
    cur_x = self.observations[t]
    # At t=0 there is no previous state, so only the observation is fed in;
    # both branches produce a [num_inputs, batch_size] Tensor.
    inputs = tf.cond(tf.greater(t, 0),
                     lambda: tf.stack([state, cur_x], axis=0),
                     lambda: cur_x[tf.newaxis, :])
    return self._learned_proposal_fn(t, inputs)

  def _learned_smoothing_proposal(self, t, state):
    """Learned proposal conditioned on z_{t-1} (if t > 0) and x_{t:T-1}."""
    xs = self.observations_ta.gather(tf.range(t, self.num_timesteps))
    inputs = tf.cond(tf.greater(t, 0),
                     lambda: tf.concat([state[tf.newaxis, :], xs], axis=0),
                     lambda: xs)
    return self._learned_proposal_fn(t, inputs)
research/fivo/fivo/models/ghmm_test.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.models.ghmm"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
from
fivo.models.ghmm
import
GaussianHMM
from
fivo.models.ghmm
import
TrainableGaussianHMM
class GHMMTest(tf.test.TestCase):
  """Tests GaussianHMM's distributions against closed-form NumPy values.

  Covers the transition, emission, filtering, smoothing, and lookahead
  distributions, with and without non-unit weights, plus the precomputed
  covariance matrices.
  """

  def test_transition_no_weights(self):
    with self.test_session() as sess:
      ghmm = GaussianHMM(3, transition_variances=[1., 2., 3.])
      prev_z = tf.constant([1., 2.], dtype=tf.float32)
      z0 = ghmm.transition(0, prev_z)
      z1 = ghmm.transition(1, prev_z)
      z2 = ghmm.transition(2, prev_z)
      outs = sess.run([z0.mean(), z0.variance(),
                       z1.mean(), z1.variance(),
                       z2.mean(), z2.variance()])
      self.assertAllClose(outs, [[0., 0.], [1., 1.],
                                 [1., 2.], [2., 2.],
                                 [1., 2.], [3., 3.]])

  def test_transition_with_weights(self):
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         transition_variances=[1., 2., 3.],
                         transition_weights=[2., 3.])
      prev_z = tf.constant([1., 2.], dtype=tf.float32)
      z0 = ghmm.transition(0, prev_z)
      z1 = ghmm.transition(1, prev_z)
      z2 = ghmm.transition(2, prev_z)
      outs = sess.run([z0.mean(), z0.variance(),
                       z1.mean(), z1.variance(),
                       z2.mean(), z2.variance()])
      self.assertAllClose(outs, [[0., 0.], [1., 1.],
                                 [2., 4.], [2., 2.],
                                 [3., 6.], [3., 3.]])

  def test_emission_no_weights(self):
    with self.test_session() as sess:
      ghmm = GaussianHMM(3, emission_variances=[1., 2., 3.])
      z = tf.constant([1., 2.], dtype=tf.float32)
      x0 = ghmm.emission(0, z)
      x1 = ghmm.emission(1, z)
      x2 = ghmm.emission(2, z)
      outs = sess.run([x0.mean(), x0.variance(),
                       x1.mean(), x1.variance(),
                       x2.mean(), x2.variance()])
      self.assertAllClose(outs, [[1., 2.], [1., 1.],
                                 [1., 2.], [2., 2.],
                                 [1., 2.], [3., 3.]])

  def test_emission_with_weights(self):
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         emission_variances=[1., 2., 3.],
                         emission_weights=[1., 2., 3.])
      z = tf.constant([1., 2.], dtype=tf.float32)
      x0 = ghmm.emission(0, z)
      x1 = ghmm.emission(1, z)
      x2 = ghmm.emission(2, z)
      outs = sess.run([x0.mean(), x0.variance(),
                       x1.mean(), x1.variance(),
                       x2.mean(), x2.variance()])
      self.assertAllClose(outs, [[1., 2.], [1., 1.],
                                 [2., 4.], [2., 2.],
                                 [3., 6.], [3., 3.]])

  def test_filtering_no_weights(self):
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         transition_variances=[1., 2., 3.],
                         emission_variances=[4., 5., 6.])
      z_prev = tf.constant([1., 2.], dtype=tf.float32)
      x_cur = tf.constant([3., 4.], dtype=tf.float32)
      # Hand-computed filtering posteriors [mean, variance] per timestep.
      expected_outs = [[[3. / 5., 4. / 5.], [4. / 5., 4. / 5.]],
                       [[11. / 7., 18. / 7.], [10. / 7., 10. / 7.]],
                       [[5. / 3., 8. / 3.], [2., 2.]]]
      f_post_0 = ghmm.filtering(0, z_prev, x_cur)
      f_post_1 = ghmm.filtering(1, z_prev, x_cur)
      f_post_2 = ghmm.filtering(2, z_prev, x_cur)
      outs = sess.run([[f_post_0.mean(), f_post_0.variance()],
                       [f_post_1.mean(), f_post_1.variance()],
                       [f_post_2.mean(), f_post_2.variance()]])
      self.assertAllClose(expected_outs, outs)

  def test_filtering_with_weights(self):
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         transition_variances=[1., 2., 3.],
                         emission_variances=[4., 5., 6.],
                         transition_weights=[7., 8.],
                         emission_weights=[9., 10., 11])
      z_prev = tf.constant([1., 2.], dtype=tf.float32)
      x_cur = tf.constant([3., 4.], dtype=tf.float32)
      expected_outs = [[[27. / 85., 36. / 85.], [4. / 85., 4. / 85.]],
                       [[95. / 205., 150. / 205.], [10. / 205., 10. / 205.]],
                       [[147. / 369., 228. / 369.], [18. / 369., 18. / 369.]]]
      f_post_0 = ghmm.filtering(0, z_prev, x_cur)
      f_post_1 = ghmm.filtering(1, z_prev, x_cur)
      f_post_2 = ghmm.filtering(2, z_prev, x_cur)
      outs = sess.run([[f_post_0.mean(), f_post_0.variance()],
                       [f_post_1.mean(), f_post_1.variance()],
                       [f_post_2.mean(), f_post_2.variance()]])
      self.assertAllClose(expected_outs, outs)

  def test_smoothing(self):
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         transition_variances=[1., 2., 3.],
                         emission_variances=[4., 5., 6.])
      z_prev = tf.constant([1., 2.], dtype=tf.float32)
      xs = tf.constant([[1., 2.],
                        [3., 4.],
                        [5., 6.]], dtype=tf.float32)
      s_post1 = ghmm.smoothing(0, z_prev, xs)
      outs = sess.run([s_post1.mean(), s_post1.variance()])
      expected_outs = [[281. / 421., 410. / 421.], [292. / 421., 292. / 421.]]
      self.assertAllClose(expected_outs, outs)

      expected_outs = [[149. / 73., 222. / 73.], [90. / 73., 90. / 73.]]
      s_post2 = ghmm.smoothing(1, z_prev, xs[1:])
      outs = sess.run([s_post2.mean(), s_post2.variance()])
      self.assertAllClose(expected_outs, outs)

      s_post3 = ghmm.smoothing(2, z_prev, xs[2:])
      outs = sess.run([s_post3.mean(), s_post3.variance()])
      expected_outs = [[7. / 3., 10. / 3.], [2., 2.]]
      self.assertAllClose(expected_outs, outs)

  def test_smoothing_with_weights(self):
    with self.test_session() as sess:
      x_weight = np.array([4, 5, 6, 7], dtype=np.float32)
      sigma_x = np.array([5, 6, 7, 8], dtype=np.float32)
      z_weight = np.array([1, 2, 3], dtype=np.float32)
      sigma_z = np.array([1, 2, 3, 4], dtype=np.float32)
      z_prev = np.array([1, 2], dtype=np.float32)
      batch_size = 2
      xs = np.array([[1, 2],
                     [3, 4],
                     [5, 6],
                     [7, 8]], dtype=np.float32)
      z_cov, x_cov, z_x_cov = self._compute_covariance_matrices(
          x_weight, z_weight, sigma_x, sigma_z)

      expected_outs = []
      # Compute mean and variance for z_0 when we don't condition
      # on previous zs.
      sigma_12 = z_x_cov[0, :]
      sigma_12_22 = np.dot(sigma_12, np.linalg.inv(x_cov))
      mean = np.dot(sigma_12_22, xs)
      variance = np.squeeze(z_cov[0, 0] - np.dot(sigma_12_22, sigma_12))
      expected_outs.append([mean, np.tile(variance, [batch_size])])
      # Compute mean and variance for remaining z_ts.
      # BUGFIX: this loop previously used Python 2's `xrange`, which raises
      # NameError under Python 3; `range` is equivalent here (and is what the
      # rest of this file already uses).
      for t in range(1, 4):
        sigma_12 = np.concatenate([[z_cov[t, t - 1]], z_x_cov[t, t:]])
        sigma_22 = np.vstack(
            (np.hstack((z_cov[t - 1, t - 1], z_x_cov[t - 1, t:])),
             np.hstack((np.transpose([z_x_cov[t - 1, t:]]), x_cov[t:, t:]))
            ))
        sigma_12_22 = np.dot(sigma_12, np.linalg.inv(sigma_22))
        mean = np.dot(sigma_12_22, np.vstack((z_prev, xs[t:])))
        variance = np.squeeze(z_cov[t, t] - np.dot(sigma_12_22, sigma_12))
        expected_outs.append([mean, np.tile(variance, [batch_size])])

      ghmm = GaussianHMM(4,
                         transition_variances=sigma_z,
                         emission_variances=sigma_x,
                         transition_weights=z_weight,
                         emission_weights=x_weight)
      out_dists = [ghmm.smoothing(t, z_prev, xs[t:]) for t in range(0, 4)]
      outs = [[d.mean(), d.variance()] for d in out_dists]
      run_outs = sess.run(outs)
      self.assertAllClose(expected_outs, run_outs)

  def test_covariance_matrices(self):
    with self.test_session() as sess:
      x_weight = np.array([4, 5, 6, 7], dtype=np.float32)
      sigma_x = np.array([5, 6, 7, 8], dtype=np.float32)
      z_weight = np.array([1, 2, 3], dtype=np.float32)
      sigma_z = np.array([1, 2, 3, 4], dtype=np.float32)
      z_cov, x_cov, z_x_cov = self._compute_covariance_matrices(
          x_weight, z_weight, sigma_x, sigma_z)

      ghmm = GaussianHMM(4,
                         transition_variances=sigma_z,
                         emission_variances=sigma_x,
                         transition_weights=z_weight,
                         emission_weights=x_weight)
      self.assertAllClose(z_cov, sess.run(ghmm.sigma_z))
      self.assertAllClose(x_cov, sess.run(ghmm.sigma_x))
      self.assertAllClose(z_x_cov, sess.run(ghmm.sigma_zx))

  def _compute_covariance_matrices(self, x_weight, z_weight, sigma_x, sigma_z):
    """NumPy reference implementation of the model's covariance matrices.

    Returns (z_cov, x_cov, z_x_cov): the latent covariance, observation
    covariance, and latent-observation cross covariance for a 4-step model.
    """
    # Create z covariance matrix from the definitions.
    z_cov = np.zeros([4, 4])
    z_cov[0, 0] = sigma_z[0]
    for i in range(1, 4):
      z_cov[i, i] = (z_cov[i - 1, i - 1] * np.square(z_weight[i - 1]) +
                     sigma_z[i])
    for i in range(4):
      for j in range(4):
        if i == j:
          continue
        min_ind = min(i, j)
        max_ind = max(i, j)
        weights = np.prod(z_weight[min_ind:max_ind])
        z_cov[i, j] = z_cov[min_ind, min_ind] * weights
    # Compute the x covariance matrix and the z-x covariance matrix.
    x_weights_outer = np.outer(x_weight, x_weight)
    x_cov = x_weights_outer * z_cov + np.diag(sigma_x)
    z_x_cov = x_weight * z_cov
    return z_cov, x_cov, z_x_cov

  def test_lookahead(self):
    x_weight = np.array([4, 5, 6, 7], dtype=np.float32)
    sigma_x = np.array([5, 6, 7, 8], dtype=np.float32)
    z_weight = np.array([1, 2, 3], dtype=np.float32)
    sigma_z = np.array([1, 2, 3, 4], dtype=np.float32)
    z_prev = np.array([1, 2], dtype=np.float32)
    with self.test_session() as sess:
      z_cov, x_cov, z_x_cov = self._compute_covariance_matrices(
          x_weight, z_weight, sigma_x, sigma_z)
      expected_outs = []
      for t in range(1, 4):
        sigma_12 = z_x_cov[t - 1, t:]
        z_var = z_cov[t - 1, t - 1]
        mean = np.outer(z_prev, sigma_12 / z_var)
        variance = x_cov[t:, t:] - np.outer(sigma_12, sigma_12) / z_var
        expected_outs.append([mean, variance])

      ghmm = GaussianHMM(4,
                         transition_variances=sigma_z,
                         emission_variances=sigma_x,
                         transition_weights=z_weight,
                         emission_weights=x_weight)
      out_dists = [ghmm.lookahead(t, z_prev) for t in range(1, 4)]
      outs = [[d.mean(), d.covariance()] for d in out_dists]
      run_outs = sess.run(outs)
      self.assertAllClose(expected_outs, run_outs)
class TrainableGHMMTest(tf.test.TestCase):
  """Checks TrainableGaussianHMM's proposals against the exact posteriors."""

  def test_filtering_proposal(self):
    """Stashing the observations must not change the filtering distributions."""
    with self.test_session() as sess:
      model = TrainableGaussianHMM(
          3, "filtering",
          transition_variances=[1., 2., 3.],
          emission_variances=[4., 5., 6.],
          transition_weights=[7., 8.],
          emission_weights=[9., 10., 11])
      obs = tf.constant([[3., 4.], [3., 4.], [3., 4.]], dtype=tf.float32)
      model.set_observations(obs, [3, 3])
      prev_z = tf.constant([1., 2.], dtype=tf.float32)
      # Collect [mean, variance] of the proposal at every timestep.
      moments = []
      for step in range(3):
        dist = model._filtering_proposal(step, prev_z)
        moments.append([dist.mean(), dist.variance()])
      # Same closed-form values as GHMMTest.test_filtering_with_weights.
      want = [[[27. / 85., 36. / 85.], [4. / 85., 4. / 85.]],
              [[95. / 205., 150. / 205.], [10. / 205., 10. / 205.]],
              [[147. / 369., 228. / 369.], [18. / 369., 18. / 369.]]]
      self.assertAllClose(want, sess.run(moments))

  def test_smoothing_proposal(self):
    """The smoothing proposal must match the exact smoothing posterior."""
    with self.test_session() as sess:
      model = TrainableGaussianHMM(
          3, "smoothing",
          transition_variances=[1., 2., 3.],
          emission_variances=[4., 5., 6.])
      obs = tf.constant([[1., 2.], [3., 4.], [5., 6.]], dtype=tf.float32)
      model.set_observations(obs, [3, 3])
      prev_z = tf.constant([1., 2.], dtype=tf.float32)
      moments = []
      for step in range(3):
        dist = model._smoothing_proposal(step, prev_z)
        moments.append([dist.mean(), dist.variance()])
      # Same closed-form values as GHMMTest.test_smoothing.
      want = [[[281. / 421., 410. / 421.], [292. / 421., 292. / 421.]],
              [[149. / 73., 222. / 73.], [90. / 73., 90. / 73.]],
              [[7. / 3., 10. / 3.], [2., 2.]]]
      self.assertAllClose(want, sess.run(moments))
if __name__ == "__main__":
  # Run all tf.test.TestCase classes defined in this module.
  tf.test.main()
research/fivo/fivo/models/srnn.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SRNN classes."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
namedtuple
import
functools
import
sonnet
as
snt
import
tensorflow
as
tf
from
fivo.models
import
base
# State carried by the SRNN between timesteps: the deterministic RNN cell
# state and the encoding of the most recent latent sample (the output of
# latent_encoder, consumed by the next transition call).
SRNNState = namedtuple("SRNNState", "rnn_state latent_encoded")
class SRNN(object):
  """Implementation of a Stochastic Recurrent Neural Network (SRNN).

  Introduced in "Sequential Neural Models with Stochastic Layers"
  by Fraccaro et al. https://arxiv.org/pdf/1605.07571.pdf.

  The SRNN is a sequence model similar to an RNN that uses stochastic latent
  variables to improve its representational power. It can be thought of as a
  sequential analogue to the variational auto-encoder (VAE).

  The SRNN has a deterministic RNN as its backbone, represented by the
  sequence of RNN hidden states h_t. The latent state is conditioned on
  the deterministic RNN states and previous latent state. Unlike the VRNN,
  the RNN state is not conditioned on the previous latent state. The latent
  states have a Markov structure and it is assumed that
  p(z_t | z_{1:t-1}) = p(z_t | z_{t-1}).

  In this implementation of the SRNN the latent state z_t is Gaussian. The
  model's prior over z_t (also called the transition distribution) is
  distributed as Normal(mu_t, diag(sigma_t^2)) where mu_t and sigma_t are the
  mean and standard deviation output from a fully connected network that
  accepts the rnn hidden state h_t and previous latent state z_{t-1} as input.

  The emission distribution p(x_t|z_t, h_t) is conditioned on the latent state
  z_t as well as the current RNN hidden state h_t via a fully connected
  network.

  To increase the modeling power of the SRNN, two additional networks are
  used to extract features from the data and the latent state. Those networks
  are called data_encoder and latent_encoder respectively.

  For an example of how to call the SRNN's methods see sample_step.

  There are a few differences between this exposition and the paper. The main
  goal was to be consistent with the VRNN code. A few components are renamed.
  The backward RNN for approximating the posterior, g_phi_a in the paper, is
  the rev_rnn_cell. The forward RNN that conditions the latent distribution,
  d in the paper, is the rnn_cell. The paper doesn't name the NN's that serve
  as feature extractors, and we name them here as the data_encoder and
  latent_encoder.
  """

  def __init__(self,
               rnn_cell,
               data_encoder,
               latent_encoder,
               transition,
               emission,
               random_seed=None):
    """Create a SRNN.

    Args:
      rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
        deterministic backbone of the SRNN. The inputs to the RNN will be the
        encoded input of the current timestep, a Tensor of shape
        [batch_size, encoded_data_size].
      data_encoder: A callable that accepts a batch of data x_t and
        'encodes' it, e.g. runs it through a fully connected network. Must
        accept as argument the inputs x_t, a Tensor of the shape
        [batch_size, data_size] and return a Tensor of shape
        [batch_size, encoded_data_size]. This callable will be called multiple
        times in the SRNN cell so if scoping is not handled correctly then
        multiple copies of the variables in this network could be made. It is
        recommended to use a snt.nets.MLP module, which takes care of this for
        you.
      latent_encoder: A callable that accepts a latent state z_t and
        'encodes' it, e.g. runs it through a fully connected network. Must
        accept as argument a Tensor of shape [batch_size, latent_size] and
        return a Tensor of shape [batch_size, encoded_latent_size].
        This callable must also have the property 'output_size' defined,
        returning encoded_latent_size.
      transition: A callable that implements the transition distribution
        p(z_t|h_t, z_t-1). Must accept as argument the previous RNN hidden
        state and previous encoded latent state then return a
        tf.distributions.Normal distribution conditioned on the input.
      emission: A callable that implements the emission distribution
        p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
        and the RNN hidden state and return a subclass of
        tf.distributions.Distribution that can be used to evaluate the logprob
        of the targets.
      random_seed: The seed for the random ops. Sets the seed for sample_step.
    """
    self.random_seed = random_seed
    self.rnn_cell = rnn_cell
    self.data_encoder = data_encoder
    self.latent_encoder = latent_encoder
    # Size of the encoded latent, used when allocating the zero state.
    self.encoded_z_size = latent_encoder.output_size
    self.state_size = (self.rnn_cell.state_size)
    # Stored as private attributes so the public transition()/emission()
    # methods can wrap them with documentation and extra behavior.
    self._transition = transition
    self._emission = emission

  def zero_state(self, batch_size, dtype):
    """The initial state of the SRNN.

    Contains the initial state of the RNN and the inital encoded latent.

    Args:
      batch_size: The batch size.
      dtype: The data type of the SRNN.
    Returns:
      zero_state: The initial state of the SRNN.
    """
    return SRNNState(
        rnn_state=self.rnn_cell.zero_state(batch_size, dtype),
        latent_encoded=tf.zeros(
            [batch_size, self.latent_encoder.output_size], dtype=dtype))

  def run_rnn(self, prev_rnn_state, inputs):
    """Runs the deterministic RNN for one step.

    Args:
      prev_rnn_state: The state of the RNN from the previous timestep.
      inputs: A Tensor of shape [batch_size, data_size], the current inputs to
        the model. Most often this is x_{t-1}, the previous token in the
        observation sequence.
    Returns:
      rnn_out: The output of the RNN.
      rnn_state: The new state of the RNN.
    """
    # Encode the raw inputs before feeding them to the RNN cell; cast to
    # float since inputs may arrive as e.g. integer-coded observations.
    rnn_inputs = self.data_encoder(tf.to_float(inputs))
    rnn_out, rnn_state = self.rnn_cell(rnn_inputs, prev_rnn_state)
    return rnn_out, rnn_state

  def transition(self, rnn_out, prev_latent_encoded):
    """Computes the transition distribution p(z_t|h_t, z_{t-1}).

    Note that p(z_t | h_t, z_{t-1}) = p(z_t| z_{t-1}, x_{1:t-1})

    Args:
      rnn_out: The output of the rnn for the current timestep.
      prev_latent_encoded: Float Tensor of shape
        [batch_size, encoded_latent_size], the previous latent state z_{t-1}
        run through latent_encoder.
    Returns:
      p(z_t | h_t): A normal distribution with event shape
        [batch_size, latent_size].
    """
    return self._transition(rnn_out, prev_latent_encoded)

  def emission(self, latent, rnn_out):
    """Computes the emission distribution p(x_t | z_t, h_t).

    Note that p(x_t | z_t, h_t) = p(x_t | z_t, x_{1:t-1})

    Args:
      latent: The stochastic latent state z_t.
      rnn_out: The output of the rnn for the current timestep.
    Returns:
      p(x_t | z_t, h_t): A distribution with event shape
        [batch_size, data_size].
      latent_encoded: The latent state encoded with latent_encoder. Should be
        passed to transition() on the next timestep.
    """
    # Encode once here and return the encoding so callers don't re-encode on
    # the next timestep's transition.
    latent_encoded = self.latent_encoder(latent)
    return self._emission(latent_encoded, rnn_out), latent_encoded

  def sample_step(self, prev_state, inputs, unused_t):
    """Samples one output from the model.

    Args:
      prev_state: The previous state of the model, a SRNNState containing the
        previous rnn state and the previous encoded latent.
      inputs: A Tensor of shape [batch_size, data_size], the current inputs to
        the model. Most often this is x_{t-1}, the previous token in the
        observation sequence.
      unused_t: The current timestep. Not used currently.
    Returns:
      new_state: The next state of the model, a SRNNState.
      xt: A float Tensor of shape [batch_size, data_size], an output sampled
        from the emission distribution.
    """
    # Advance the deterministic backbone, then sample z_t from the prior and
    # x_t from the emission. Both draws use self.random_seed for determinism
    # in tests.
    rnn_out, rnn_state = self.run_rnn(prev_state.rnn_state, inputs)
    p_zt = self.transition(rnn_out, prev_state.latent_encoded)
    zt = p_zt.sample(seed=self.random_seed)
    p_xt_given_zt, latent_encoded = self.emission(zt, rnn_out)
    xt = p_xt_given_zt.sample(seed=self.random_seed)
    new_state = SRNNState(rnn_state=rnn_state, latent_encoded=latent_encoded)
    return new_state, tf.to_float(xt)
# State of a TrainableSRNN: the SRNNState fields (rnn_state, latent_encoded)
# plus rnn_out, the forward RNN output cached for use by the tilt function.
# pylint: disable=invalid-name
# pylint thinks this is a top-level constant.
TrainableSRNNState = namedtuple(
    "TrainableSRNNState",
    SRNNState._fields + ("rnn_out",))
# pylint: enable=invalid-name
class TrainableSRNN(SRNN, base.ELBOTrainableSequenceModel):
  """A SRNN subclass with proposals and methods for training and evaluation.

  This class adds proposals used for training with importance-sampling based
  methods such as the ELBO. The model can be configured to propose from one
  of three proposals: a learned filtering proposal, a learned smoothing
  proposal, or the prior (i.e. the transition distribution).

  As described in the SRNN paper, the learned filtering proposal is
  parameterized by a fully connected neural network that accepts as input the
  current target x_t and the current rnn output h_t. The learned smoothing
  proposal is also given the hidden state of an RNN run in reverse over the
  inputs, so as to incorporate information about future observations.

  All learned proposals use the 'res_q' parameterization, meaning that instead
  of directly producing the mean of z_t, the proposal network predicts the
  'residual' from the prior's mean. This is explored more in section 3.3 of
  https://arxiv.org/pdf/1605.07571.pdf.

  During training, the latent state z_t is sampled from the proposal and the
  reparameterization trick is used to provide low-variance gradients.

  Note that the SRNN paper refers to the proposals as the approximate posterior,
  but we match the VRNN convention of referring to it as the encoder.
  """

  def __init__(self,
               rnn_cell,
               data_encoder,
               latent_encoder,
               transition,
               emission,
               proposal_type,
               proposal=None,
               rev_rnn_cell=None,
               tilt=None,
               random_seed=None):
    """Create a trainable RNN.

    Args:
      rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
        deterministic backbone of the SRNN. The inputs to the RNN will be the
        the encoded input of the current timestep, a Tensor of shape
        [batch_size, encoded_data_size].
      data_encoder: A callable that accepts a batch of data x_t and
        'encodes' it, e.g. runs it through a fully connected network. Must
        accept as argument the inputs x_t, a Tensor of the shape
        [batch_size, data_size] and return a Tensor of shape
        [batch_size, encoded_data_size]. This callable will be called multiple
        times in the SRNN cell so if scoping is not handled correctly then
        multiple copies of the variables in this network could be made. It is
        recommended to use a snt.nets.MLP module, which takes care of this for
        you.
      latent_encoder: A callable that accepts a latent state z_t and
        'encodes' it, e.g. runs it through a fully connected network. Must
        accept as argument a Tensor of shape [batch_size, latent_size] and
        return a Tensor of shape [batch_size, encoded_latent_size].
        This callable must also have the property 'output_size' defined,
        returning encoded_latent_size.
      transition: A callable that implements the transition distribution
        p(z_t|h_t, z_t-1). Must accept as argument the previous RNN hidden
        state and previous encoded latent state then return a
        tf.distributions.Normal distribution conditioned on the input.
      emission: A callable that implements the emission distribution
        p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
        and the RNN hidden state and return a subclass of
        tf.distributions.Distribution that can be used to evaluate the logprob
        of the targets.
      proposal_type: A string indicating the type of proposal to use. Can
        be either "filtering", "smoothing", or "prior". When proposal_type is
        "filtering" or "smoothing", proposal must be provided. When
        proposal_type is "smoothing", rev_rnn_cell must also be provided.
      proposal: A callable that implements the proposal q(z_t| h_t, x_{1:T}).
        If proposal_type is "filtering" then proposal must accept as arguments
        the current rnn output, the encoded target of the current timestep,
        and the mean of the prior. If proposal_type is "smoothing" then
        in addition to the current rnn output and the mean of the prior
        proposal must accept as arguments the output of the reverse rnn.
        proposal should return a tf.distributions.Normal distribution
        conditioned on its inputs. If proposal_type is "prior" this argument
        is ignored.
      rev_rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will aggregate
        forward rnn outputs in the reverse direction. The inputs to the RNN
        will be the encoded reverse input of the current timestep, a Tensor of
        shape [batch_size, encoded_data_size].
      tilt: A callable that implements the log of a positive tilting function
        (ideally approximating log p(x_{t+1}|z_t, h_t). Must accept as
        arguments the encoded latent state and the RNN hidden state and return
        a subclass of tf.distributions.Distribution that can be used to
        evaluate the logprob of x_{t+1}. Optionally, None and then no tilt is
        used.
      random_seed: The seed for the random ops. Sets the seed for sample_step
        and __call__.
    """
    super(TrainableSRNN, self).__init__(rnn_cell, data_encoder,
                                        latent_encoder, transition,
                                        emission, random_seed=random_seed)
    self.rev_rnn_cell = rev_rnn_cell
    self._tilt = tilt
    assert proposal_type in ["filtering", "smoothing", "prior"]
    self._proposal = proposal
    self.proposal_type = proposal_type
    # Learned proposals need a proposal network; the smoothing proposal
    # additionally needs the reverse RNN cell to look at future observations.
    if proposal_type != "prior":
      assert proposal, "If not proposing from the prior, must provide proposal."
    if proposal_type == "smoothing":
      assert rev_rnn_cell, "Must provide rev_rnn_cell for smoothing proposal."

  def zero_state(self, batch_size, dtype):
    """Returns the initial state: the superclass state plus zeroed rnn_out."""
    super_state = super(TrainableSRNN, self).zero_state(batch_size, dtype)
    return TrainableSRNNState(
        rnn_out=tf.zeros([batch_size, self.rnn_cell.output_size], dtype=dtype),
        **super_state._asdict())

  def set_observations(self, observations, seq_lengths):
    """Stores the model's observations.

    Stores the observations (inputs and targets) in TensorArrays and
    precomputes things for later like the reverse RNN output and encoded
    targets.

    Args:
      observations: The observations of the model, a tuple containing two
        Tensors of shape [max_seq_len, batch_size, data_size]. The Tensors
        should be the inputs and targets, respectively.
      seq_lengths: An int Tensor of shape [batch_size] containing the length
        of each sequence in observations.
    """
    inputs, targets = observations
    self.seq_lengths = seq_lengths
    self.max_seq_len = tf.reduce_max(seq_lengths)
    # Raw and encoded targets are read per-timestep in propose_and_weight and
    # the proposals, so keep them readable multiple times.
    self.targets_ta = base.ta_for_tensor(targets,
                                         clear_after_read=False)
    targets_encoded = base.encode_all(targets, self.data_encoder)
    self.targets_encoded_ta = base.ta_for_tensor(targets_encoded,
                                                 clear_after_read=False)
    # Run the forward (deterministic) RNN over the whole encoded input
    # sequence up front; inputs are time-major: [time, batch, features].
    inputs_encoded = base.encode_all(inputs, self.data_encoder)
    rnn_out, _ = tf.nn.dynamic_rnn(self.rnn_cell,
                                   inputs_encoded,
                                   time_major=True,
                                   dtype=tf.float32,
                                   scope="forward_rnn")
    self.rnn_ta = base.ta_for_tensor(rnn_out,
                                     clear_after_read=False)
    if self.rev_rnn_cell:
      # For the smoothing proposal, run a second RNN backwards over the
      # concatenated forward RNN outputs and encoded targets (concatenated
      # along the feature axis), then flip its outputs back to forward time
      # order before storing.
      targets_and_rnn_out = tf.concat([rnn_out, targets_encoded], 2)
      reversed_targets_and_rnn_out = tf.reverse_sequence(
          targets_and_rnn_out, seq_lengths, seq_axis=0, batch_axis=1)
      # Compute the reverse rnn over the targets.
      reverse_rnn_out, _ = tf.nn.dynamic_rnn(self.rev_rnn_cell,
                                             reversed_targets_and_rnn_out,
                                             time_major=True,
                                             dtype=tf.float32,
                                             scope="reverse_rnn")
      reverse_rnn_out = tf.reverse_sequence(reverse_rnn_out, seq_lengths,
                                            seq_axis=0, batch_axis=1)
      self.reverse_rnn_ta = base.ta_for_tensor(reverse_rnn_out,
                                               clear_after_read=False)

  def _filtering_proposal(self, rnn_out, prev_latent_encoded, prior, t):
    """Computes the filtering proposal distribution."""
    # Conditions on the current rnn output, the previous encoded latent, and
    # the encoded target at time t; prior_mu enables 'res_q'.
    return self._proposal(rnn_out,
                          prev_latent_encoded,
                          self.targets_encoded_ta.read(t),
                          prior_mu=prior.mean())

  def _smoothing_proposal(self, rnn_out, prev_latent_encoded, prior, t):
    """Computes the smoothing proposal distribution."""
    # Like the filtering proposal but conditioned on the reverse RNN output
    # at time t, which summarizes future observations.
    return self._proposal(rnn_out,
                          prev_latent_encoded,
                          smoothing_tensors=[self.reverse_rnn_ta.read(t)],
                          prior_mu=prior.mean())

  def proposal(self, rnn_out, prev_latent_encoded, prior, t):
    """Computes the proposal distribution specified by proposal_type.

    Args:
      rnn_out: The output of the rnn for the current timestep.
      prev_latent_encoded: Float Tensor of shape
        [batch_size, encoded_latent_size], the previous latent state z_{t-1}
        run through latent_encoder.
      prior: A tf.distributions.Normal distribution representing the prior
        over z_t, p(z_t | z_{1:t-1}, x_{1:t-1}). Used for 'res_q'.
      t: A scalar int Tensor, the current timestep.

    Returns:
      The proposal distribution over z_t. proposal_type was validated in
      __init__, so exactly one branch below is taken.
    """
    if self.proposal_type == "filtering":
      return self._filtering_proposal(rnn_out, prev_latent_encoded, prior, t)
    elif self.proposal_type == "smoothing":
      return self._smoothing_proposal(rnn_out, prev_latent_encoded, prior, t)
    elif self.proposal_type == "prior":
      return self.transition(rnn_out, prev_latent_encoded)

  def tilt(self, rnn_out, latent_encoded, targets):
    """Evaluates the log tilting function on targets.

    Runs the tilt distribution conditioned on (rnn_out, latent_encoded) and
    returns its log-probability of targets summed over the last (data) axis,
    giving a [batch_size] Tensor.
    """
    r_func = self._tilt(rnn_out, latent_encoded)
    return tf.reduce_sum(r_func.log_prob(targets), axis=-1)

  def propose_and_weight(self, state, t):
    """Runs the model and computes importance weights for one timestep.

    Runs the model and computes importance weights, sampling from the proposal
    instead of the transition/prior.

    Args:
      state: The previous state of the model, a TrainableSRNNState containing
        the previous rnn state, the previous rnn outs, and the previous
        encoded latent.
      t: A scalar integer Tensor, the current timestep.

    Returns:
      weights: A float Tensor of shape [batch_size].
      new_state: The new state of the model.
    """
    targets = self.targets_ta.read(t)
    # The forward RNN was precomputed in set_observations; just read h_t.
    rnn_out = self.rnn_ta.read(t)
    p_zt = self.transition(rnn_out, state.latent_encoded)
    q_zt = self.proposal(rnn_out, state.latent_encoded, p_zt, t)
    zt = q_zt.sample(seed=self.random_seed)
    p_xt_given_zt, latent_encoded = self.emission(zt, rnn_out)
    # Standard importance weight: log p(x_t, z_t | ...) - log q(z_t | ...).
    log_p_xt_given_zt = tf.reduce_sum(p_xt_given_zt.log_prob(targets),
                                      axis=-1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=-1)
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=-1)
    weights = log_p_zt + log_p_xt_given_zt - log_q_zt
    if self._tilt:
      # The tilt telescopes: add this step's log r and remove the previous
      # step's contribution.
      prev_log_r = tf.cond(
          tf.greater(t, 0),
          lambda: self.tilt(state.rnn_out, state.latent_encoded, targets),
          lambda: 0.)  # On the first step, prev_log_r = 0.
      log_r = tf.cond(
          tf.less(t + 1, self.max_seq_len),
          lambda: self.tilt(rnn_out, latent_encoded,
                            self.targets_ta.read(t + 1)),
          lambda: 0.)  # On the last step, log_r = 0.
      # Mask log_r to zero for sequences that end at or before step t+1.
      log_r *= tf.to_float(t < self.seq_lengths - 1)
      weights += log_r - prev_log_r
    # This reshape is required because the TensorArray reports different
    # shapes than the initial state provides (where the first dimension is
    # unknown). The difference breaks the while_loop. Reshape prevents the
    # error.
    rnn_out = tf.reshape(rnn_out, tf.shape(state.rnn_out))
    new_state = TrainableSRNNState(rnn_out=rnn_out,
                                   rnn_state=state.rnn_state,  # unmodified
                                   latent_encoded=latent_encoded)
    return weights, new_state
# Default variable initializers used by create_srnn when none are supplied:
# Xavier for weights ("w") and zeros for biases ("b").
_DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
                         "b": tf.zeros_initializer()}
def create_srnn(data_size,
                latent_size,
                emission_class,
                rnn_hidden_size=None,
                fcnet_hidden_sizes=None,
                encoded_data_size=None,
                encoded_latent_size=None,
                sigma_min=0.0,
                raw_sigma_bias=0.25,
                emission_bias_init=0.0,
                use_tilt=False,
                proposal_type="filtering",
                initializers=None,
                random_seed=None):
  """A factory method for creating SRNN cells.

  Args:
    data_size: The dimension of the vectors that make up the data sequences.
    latent_size: The size of the stochastic latent state of the SRNN.
    emission_class: The class of the emission distribution. Can be either
      ConditionalNormalDistribution or ConditionalBernoulliDistribution.
    rnn_hidden_size: The hidden state dimension of the RNN that forms the
      deterministic part of this SRNN. If None, then it defaults
      to latent_size.
    fcnet_hidden_sizes: A list of python integers, the size of the hidden
      layers of the fully connected networks that parameterize the
      conditional distributions of the SRNN. If None, then it defaults to one
      hidden layer of size latent_size.
    encoded_data_size: The size of the output of the data encoding network.
      If None, defaults to latent_size.
    encoded_latent_size: The size of the output of the latent state encoding
      network. If None, defaults to latent_size.
    sigma_min: The minimum value that the standard deviation of the
      distribution over the latent state can take.
    raw_sigma_bias: A scalar that is added to the raw standard deviation
      output from the neural networks that parameterize the prior and
      approximate posterior. Useful for preventing standard deviations close
      to zero.
    emission_bias_init: A bias to added to the raw output of the fully
      connected network that parameterizes the emission distribution. Useful
      for initalizing the mean of the distribution to a sensible starting
      point such as the mean of the training data. Only used with Bernoulli
      generative distributions.
    use_tilt: If true, create a SRNN with a tilting function.
    proposal_type: The type of proposal to use. Can be "filtering",
      "smoothing", or "prior".
    initializers: The variable intitializers to use for the fully connected
      networks and RNN cell. Must be a dictionary mapping the keys 'w' and
      'b' to the initializers for the weights and biases. Defaults to xavier
      for the weights and zeros for the biases when initializers is None.
    random_seed: A random seed for the SRNN resampling operations.

  Returns:
    model: A TrainableSRNN object.
  """
  # Every unspecified size defaults to latent_size; unspecified initializers
  # default to Xavier weights and zero biases.
  rnn_hidden_size = latent_size if rnn_hidden_size is None else rnn_hidden_size
  encoded_data_size = (latent_size if encoded_data_size is None
                       else encoded_data_size)
  encoded_latent_size = (latent_size if encoded_latent_size is None
                         else encoded_latent_size)
  if fcnet_hidden_sizes is None:
    fcnet_hidden_sizes = [latent_size]
  if initializers is None:
    initializers = _DEFAULT_INITIALIZERS

  # Networks that embed the raw data and the raw latent state before they
  # are consumed by the RNN and the conditional distributions.
  encode_data = snt.nets.MLP(
      output_sizes=fcnet_hidden_sizes + [encoded_data_size],
      initializers=initializers,
      name="data_encoder")
  encode_latent = snt.nets.MLP(
      output_sizes=fcnet_hidden_sizes + [encoded_latent_size],
      initializers=initializers,
      name="latent_encoder")

  # The transition distribution p(z_t | h_t, z_{t-1}).
  prior = base.ConditionalNormalDistribution(
      size=latent_size,
      hidden_layer_sizes=fcnet_hidden_sizes,
      sigma_min=sigma_min,
      raw_sigma_bias=raw_sigma_bias,
      initializers=initializers,
      name="prior")

  # Construct the emission distribution.
  # For Bernoulli distributed outputs, we initialize the bias so that the
  # network generates on average the mean from the training set.
  if emission_class == base.ConditionalBernoulliDistribution:
    make_emission = functools.partial(base.ConditionalBernoulliDistribution,
                                      bias_init=emission_bias_init)
  else:
    make_emission = base.ConditionalNormalDistribution
  generative = make_emission(
      size=data_size,
      hidden_layer_sizes=fcnet_hidden_sizes,
      initializers=initializers,
      name="generative")

  # Construct the proposal distribution. Only the learned ("filtering" /
  # "smoothing") proposals need a network; the "prior" proposal does not.
  if proposal_type in ("filtering", "smoothing"):
    proposal = base.NormalApproximatePosterior(
        size=latent_size,
        hidden_layer_sizes=fcnet_hidden_sizes,
        sigma_min=sigma_min,
        raw_sigma_bias=raw_sigma_bias,
        initializers=initializers,
        smoothing=(proposal_type == "smoothing"),
        name="approximate_posterior")
  else:
    proposal = None

  # The optional tilting function shares the emission's distribution family.
  tilt = (make_emission(size=data_size,
                        hidden_layer_sizes=fcnet_hidden_sizes,
                        initializers=initializers,
                        name="tilt")
          if use_tilt else None)

  fwd_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
                                     initializer=initializers["w"])
  bwd_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
                                     initializer=initializers["w"])

  return TrainableSRNN(fwd_cell,
                       encode_data,
                       encode_latent,
                       prior,
                       generative,
                       proposal_type,
                       proposal=proposal,
                       rev_rnn_cell=bwd_cell,
                       tilt=tilt,
                       random_seed=random_seed)
research/fivo/fivo/models/srnn_test.py
0 → 100644
View file @
27b4acd4
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.models.srnn."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
fivo.models
import
base
from
fivo.test_utils
import
create_srnn
class SrnnTest(tf.test.TestCase):
  """Golden-value tests for the SRNN model built by fivo.test_utils."""

  def test_srnn_normal_emission(self):
    # Golden log alpha values for a Normal emission distribution.
    self.run_srnn(base.ConditionalNormalDistribution,
                  [-5.947752, -1.182961])

  def test_srnn_bernoulli_emission(self):
    # Golden log alpha values for a Bernoulli emission distribution.
    self.run_srnn(base.ConditionalBernoulliDistribution,
                  [-2.566631, -2.479234])

  def run_srnn(self, generative_class, gt_log_alpha):
    """Tests the SRNN.

    All test values are 'golden values' derived by running the code and
    copying the output.

    Args:
      generative_class: The class of the generative distribution to use.
      gt_log_alpha: The ground-truth value of log alpha.
    """
    # Fix the graph-level seed so sampling is reproducible; the golden values
    # depend on it.
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      batch_size = 2
      # create_srnn here is the fivo.test_utils helper, which also returns
      # the toy inputs/targets used below.
      model, inputs, targets, _ = create_srnn(generative_class=generative_class,
                                              batch_size=batch_size,
                                              data_lengths=(1, 1),
                                              random_seed=1234)
      zero_state = model.zero_state(batch_size=batch_size, dtype=tf.float32)
      model.set_observations([inputs, targets],
                             tf.convert_to_tensor([1, 1]))
      model_out = model.propose_and_weight(zero_state, 0)
      sess.run(tf.global_variables_initializer())
      log_alpha, state = sess.run(model_out)
      self.assertAllClose(
          state.latent_encoded,
          [[0.591787, 1.310583], [-1.523136, 0.953918]])
      self.assertAllClose(
          state.rnn_out,
          [[0.041675, -0.056038, -0.001823, 0.005224],
           [0.042925, -0.044619, 0.021401, 0.016998]])
      self.assertAllClose(log_alpha, gt_log_alpha)

  def test_srnn_with_tilt_normal_emission(self):
    # Golden log alpha values for a Normal emission with a tilt.
    self.run_srnn_with_tilt(base.ConditionalNormalDistribution,
                            [-9.13577, -4.56725])

  def test_srnn_with_tilt_bernoulli_emission(self):
    # Golden log alpha values for a Bernoulli emission with a tilt.
    self.run_srnn_with_tilt(base.ConditionalBernoulliDistribution,
                            [-4.617461, -5.079248])

  def run_srnn_with_tilt(self, generative_class, gt_log_alpha):
    """Tests the SRNN with a tilting function.

    All test values are 'golden values' derived by running the code and
    copying the output.

    Args:
      generative_class: The class of the generative distribution to use.
      gt_log_alpha: The ground-truth value of log alpha.
    """
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      batch_size = 2
      # Sequences of unequal length (3 and 2) exercise the tilt's
      # end-of-sequence masking.
      model, inputs, targets, _ = create_srnn(generative_class=generative_class,
                                              batch_size=batch_size,
                                              data_lengths=(3, 2),
                                              random_seed=1234,
                                              use_tilt=True)
      zero_state = model.zero_state(batch_size=batch_size, dtype=tf.float32)
      model.set_observations([inputs, targets],
                             tf.convert_to_tensor([3, 2]))
      model_out = model.propose_and_weight(zero_state, 0)
      sess.run(tf.global_variables_initializer())
      log_alpha, state = sess.run(model_out)
      self.assertAllClose(
          state.latent_encoded,
          [[0.591787, 1.310583], [-1.523136, 0.953918]])
      self.assertAllClose(
          state.rnn_out,
          [[0.041675, -0.056038, -0.001823, 0.005224],
           [0.042925, -0.044619, 0.021401, 0.016998]])
      self.assertAllClose(log_alpha, gt_log_alpha)
if __name__ == "__main__":
  # Run the test cases above through the TensorFlow test runner.
  tf.test.main()
Prev
1
2
3
4
5
6
7
8
9
10
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment