# Commit 27b4acd4 authored by Aman Gupta
# Browse files
#
# Merge remote-tracking branch 'upstream/master'
#
# parents 5133522f d4e1f97f
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import sonnet as snt
import tensorflow as tf
import numpy as np
import math
# Observation model variants: how a latent state is mapped to the mean of
# the observation distribution.
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]
# Transition model variants: how the previous latent state parameterizes the
# mean of the next latent state's distribution.
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
class Q(object):
  """Amortized variational posterior q(z_t | observation, z_{t-1}).

  Each timestep t has its own linear mean network (input: the concatenation
  of the observation and the previous state) and its own raw diagonal
  variance variable.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    """Creates the per-timestep mean networks and variance variables.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Lower bound on the (softplus-transformed) variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializers.
      init_mu0_to_zero: If True, the t=0 mean network is zero-initialized.
      graph_collection_name: Graph collection every created variable is
        added to, so Q's variables can be selected for training.
    """
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    for t in xrange(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(
            {"w": tf.zeros_initializer, "b": tf.zeros_initializer})
      else:
        initializers.append(
            {"w": tf.random_uniform_initializer(seed=random_seed),
             "b": tf.zeros_initializer})

    # Routes every variable created inside the snt.Linear modules into the
    # graph collection as well.
    def custom_getter(getter, *args, **kwargs):
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    # Per-timestep linear mean networks.
    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers[t],
                   name="q_mu_%d" % t,
                   custom_getter=custom_getter
                  )
        for t in xrange(num_timesteps)
    ]
    # Raw (pre-softplus) per-timestep variance parameters.
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(num_timesteps)
    ]

  def q_zt(self, observation, prev_state, t):
    """Returns the Normal q(z_t | observation, z_{t-1}) for timestep t."""
    batch_size = tf.shape(prev_state)[0]
    q_mu = self.mus[t](tf.concat([observation, prev_state], axis=1))
    # softplus + max keeps the variance positive and bounded away from zero.
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    # q_sigma is treated as a variance, hence scale=sqrt(q_sigma).
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    """Adds scalar summaries of the first element of each q parameter."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0])
      # At t=0 there is no previous state input, so no second weight row.
      if t != 0:
        tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[1,0])
class PreviousStateQ(Q):
  """Variational posterior that conditions only on the previous state."""

  def q_zt(self, unused_observation, prev_state, t):
    """Returns Normal q(z_t | z_{t-1}); the observation is ignored."""
    num_samples = tf.shape(prev_state)[0]
    mean = self.mus[t](prev_state)
    # Positive, lower-bounded variance, broadcast across the batch.
    variance = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    variance = tf.tile(variance[tf.newaxis, :], [num_samples, 1])
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(variance))

  def summarize_weights(self):
    """Adds scalar summaries for the per-timestep q parameters."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
      tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[0,0])
class ObservationQ(Q):
  """Variational posterior that conditions only on the observation."""

  def q_zt(self, observation, prev_state, t):
    """Returns Normal q(z_t | observation); prev_state only sets the batch."""
    num_samples = tf.shape(prev_state)[0]
    mean = self.mus[t](observation)
    # Positive, lower-bounded variance, broadcast across the batch.
    variance = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    variance = tf.tile(variance[tf.newaxis, :], [num_samples, 1])
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(variance))

  def summarize_weights(self):
    """Adds scalar summaries for the per-timestep q parameters."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0])
class SimpleMeanQ(object):
  """Variational posterior with a learned constant mean per timestep.

  Ignores both the observation and the previous state; q(z_t) is an
  independent diagonal Normal per timestep with trainable mean and variance.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    """Creates per-timestep mean and raw-variance variables.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Lower bound on the (softplus-transformed) variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializers.
      init_mu0_to_zero: If True, the t=0 mean is zero-initialized.
      graph_collection_name: Graph collection all variables are added to.
    """
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    for t in xrange(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(tf.zeros_initializer)
      else:
        initializers.append(tf.random_uniform_initializer(seed=random_seed))
    # Per-timestep constant means.
    self.mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_mu_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=initializers[t])
        for t in xrange(num_timesteps)
    ]
    # Raw (pre-softplus) per-timestep variance parameters.
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(num_timesteps)
    ]

  def q_zt(self, unused_observation, prev_state, t):
    """Returns Normal q(z_t); prev_state is used only for the batch size."""
    batch_size = tf.shape(prev_state)[0]
    q_mu = tf.tile(self.mus[t][tf.newaxis, :], [batch_size, 1])
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    """Adds scalar summaries of the first element of each q parameter."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/%d" % t, f[0])
class R(object):
  """Lookahead distribution r(x_n | z_t) with a linear mean per timestep.

  Each timestep has its own linear mean network over z_t and a raw
  per-timestep variance variable.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               sigma_init=1.,
               random_seed=None,
               graph_collection_name="R_VARS"):
    """Creates the per-timestep mean networks and variance variables.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Lower bound on the (softplus-transformed) variance.
      dtype: dtype of the created variables.
      sigma_init: Constant initial value of the raw variance parameters.
      random_seed: Seed for the truncated-normal weight initializer.
      graph_collection_name: Graph collection every created variable is
        added to, so R's variables can be selected for training.
    """
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}
    self.graph_collection_name=graph_collection_name

    # Routes every variable created inside the snt.Linear modules into the
    # graph collection as well.
    def custom_getter(getter, *args, **kwargs):
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    self.mus= [
        snt.Linear(output_size=state_size,
                   initializers=initializers,
                   name="r_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in xrange(num_timesteps)
    ]
    # Raw (pre-softplus) per-timestep variance parameters.
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.constant_initializer(sigma_init))
        for t in xrange(num_timesteps)
    ]

  def r_xn(self, z_t, t):
    """Returns the Normal r(x_n | z_t) for timestep t."""
    batch_size = tf.shape(z_t)[0]
    r_mu = self.mus[t](z_t)
    # softplus + max keeps the variance positive and bounded away from zero.
    r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    return tf.contrib.distributions.Normal(
        loc=r_mu, scale=tf.sqrt(r_sigma))

  def summarize_weights(self):
    """Adds scalar summaries of the first element of each r parameter.

    Bug fix: self.mus holds snt.Linear modules, which are not subscriptable;
    the previous self.mus[t][0] raised a TypeError. Summarize the module's
    bias first element instead, matching Q.summarize_weights.
    """
    for t in range(len(self.mus) - 1):
      tf.summary.scalar("r_mu/%d" % t, self.mus[t].b[0])
      tf.summary.scalar("r_sigma/%d" % t, self.sigmas[t][0])
class P(object):
  """Linear-Gaussian chain model p(z_t | z_{t-1}) = N(z_{t-1} + b_t, var).

  The learnable parameters are the per-timestep drifts b_t. Because the
  model is linear-Gaussian, the exact posterior, lookahead distribution,
  and marginal likelihood are all available in closed form.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    """Creates the drift variables b_1..b_n.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of transitions in the chain.
      sigma_min: Stored but not read in this class (kept for interface
        parity with the other model classes).
      variance: Fixed transition variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializer.
      trainable: Whether the drift variables are trainable.
      init_bs_to_zero: If True, initialize all drifts to zero.
      graph_collection_name: Graph collection the drifts are added to.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.graph_collection_name = graph_collection_name
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed) for _ in xrange(num_timesteps)]
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="p_b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            trainable=trainable) for t in xrange(num_timesteps)
    ]
    # Suffix sums of the drifts: Bs[i] = sum_{k=i+1}^{n} b_k.
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
    # bs[0] is really b_1
    # Bs[i] is sum from k=i+1^n b_k
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    # Gaussian conditioning: precision-weighted average of the lookahead
    # mean and the transition mean.
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n|z_t)."""
    # Drifts accumulate additively and variances accumulate linearly over
    # the remaining (num_timesteps - t) transitions.
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    """Computes the exact marginal log-likelihood of the final observation."""
    batch_size = tf.shape(observation)[0]
    # Marginal mean is the total accumulated drift; variance accumulates
    # over num_timesteps transitions plus the unit-variance prior p(z_0).
    mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(mu) * self.variance * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
    else:  # p(z_0) is Normal(0,1)
      z_mu_p = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class ShortChainNonlinearP(object):
  """Nonlinear chain model with a configurable transition nonlinearity.

  Has no trainable parameters; the transition and observation distributions
  are fixed by the constructor arguments.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               transition_type=STANDARD_TRANSITION,
               transition_dist=tf.contrib.distributions.Normal,
               dtype=tf.float32,
               random_seed=None):
    """Stores the model configuration.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Stored but not read in this class (interface parity).
      variance: Fixed transition variance.
      observation_variance: Fixed observation variance.
      transition_type: One of TRANSITION_TYPES; selects the transition mean.
      transition_dist: Distribution class used for transitions/observations.
      dtype: dtype of created tensors.
      random_seed: Unused; kept for interface parity.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.transition_type = transition_type
    self.transition_dist = transition_dist

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1}).

    Raises:
      ValueError: If transition_type is not a known TRANSITION_TYPES value
        (previously this left `loc` unbound and raised UnboundLocalError).
    """
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance))
      else:
        raise ValueError("Unknown transition_type: %s" % self.transition_type)
    else:  # p(z_0) is Normal(0,1)
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
    p_zt = self.transition_dist(
        loc=loc,
        scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, unused_obs, z_ni):
    """Computes the model's generative distribution p(x_i| z_{ni}).

    Raises:
      ValueError: If transition_type is not a known TRANSITION_TYPES value
        (previously this left `loc` unbound and raised UnboundLocalError).
    """
    if self.transition_type == ROUND_TRANSITION:
      loc = tf.round(z_ni)
    elif self.transition_type == STANDARD_TRANSITION:
      loc = z_ni
    else:
      raise ValueError("Unknown transition_type: %s" % self.transition_type)
    generative_sigma_sq = tf.ones_like(loc) * self.observation_variance
    return self.transition_dist(
        loc=loc, scale=tf.sqrt(generative_sigma_sq))
class BimodalPriorP(object):
  """Gaussian chain model whose prior p(z_0) is a two-component mixture.

  Identical to P except that p(z_0) is a mixture of Normals centered at
  +/- prior_mode_mean with mixing weight mixing_coeff.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               mixing_coeff=0.5,
               prior_mode_mean=1,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    """Creates the drift variables b_1..b_n and stores the mixture config.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of transitions in the chain.
      mixing_coeff: Probability of the +prior_mode_mean mixture component.
      prior_mode_mean: Magnitude of the two prior mode locations.
      sigma_min: Stored but not read in this class (interface parity).
      variance: Fixed transition variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializer.
      trainable: Whether the drift variables are trainable.
      init_bs_to_zero: If True, initialize all drifts to zero.
      graph_collection_name: Graph collection the drifts are added to.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.mixing_coeff = mixing_coeff
    self.prior_mode_mean = prior_mode_mean
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed) for _ in xrange(num_timesteps)]
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            trainable=trainable) for t in xrange(num_timesteps)
    ]
    # Suffix sums of the drifts: Bs[i] = sum_{k=i+1}^{n} b_k.
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    # NOTE: This is currently wrong, but would require a refactoring of
    # summarize_q to fix as kl is not defined for a mixture
    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
    # bs[0] is really b_1
    # Bs[i] is sum from k=i+1^n b_k
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n|z_t)."""
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    """Computes the marginal log-likelihood under the bimodal prior."""
    batch_size = tf.shape(observation)[0]
    sum_of_bs = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(sum_of_bs) * self.variance * (self.num_timesteps + 1)
    mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean) + sum_of_bs
    mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean) + sum_of_bs
    zn_pos = tf.contrib.distributions.Normal(
        loc=mu_pos,
        scale=tf.sqrt(sigma))
    zn_neg = tf.contrib.distributions.Normal(
        loc=mu_neg,
        scale=tf.sqrt(sigma))
    # NOTE(review): mode_probs is float64 while the components use
    # self.dtype (float32 by default) — confirm Mixture accepts this,
    # especially with validate_args=True below.
    mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1-self.mixing_coeff], dtype=tf.float64)
    mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
    mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs)
    zn_dist = tf.contrib.distributions.Mixture(
        cat=mode_selection_dist,
        components=[zn_pos, zn_neg],
        validate_args=True)
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(tf.reduce_sum(zn_dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
      p_zt = tf.contrib.distributions.Normal(
          loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
      return p_zt
    else:  # p(z_0) is mixture of two Normals
      mu_pos = tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean
      mu_neg = tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean
      z0_pos = tf.contrib.distributions.Normal(
          loc=mu_pos,
          scale=tf.sqrt(tf.ones_like(mu_pos) * self.variance))
      z0_neg = tf.contrib.distributions.Normal(
          loc=mu_neg,
          scale=tf.sqrt(tf.ones_like(mu_neg) * self.variance))
      # NOTE(review): same float64/float32 dtype mismatch as in likelihood().
      mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1-self.mixing_coeff], dtype=tf.float64)
      mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
      mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs)
      z0_dist = tf.contrib.distributions.Mixture(
          cat=mode_selection_dist,
          components=[z0_pos, z0_neg],
          validate_args=False)
      return z0_dist

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class Model(object):
  """Ties together a model p, a proposal q, and a lookahead r.

  Calling the model at timestep t samples z_t from q and returns the log
  probabilities needed to form the importance weight.
  """

  def __init__(self,
               p,
               q,
               r,
               state_size,
               num_timesteps,
               dtype=tf.float32):
    """Stores the component distributions and chain configuration.

    Args:
      p: Model distribution object (P, BimodalPriorP, ...).
      q: Proposal distribution object (Q, SimpleMeanQ, ...).
      r: Lookahead distribution object (R).
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps in the chain.
      dtype: dtype of the state tensors.
    """
    self.p = p
    self.q = q
    self.r = r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns the all-zeros initial state, [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    """Runs one timestep; returns (z_t, log q, log p, log p(x|z), log r)."""
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q
    zt = q_zt.sample()
    r_xn = self.r.r_xn(zt, t)
    # Calculate the logprobs and sum over the state size.
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    log_r_xn = tf.reduce_sum(r_xn.log_prob(observation), axis=1)
    # If we're at the last timestep, also calc the logprob of the observation.
    if t == self.num_timesteps - 1:
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)

  @staticmethod
  def create(state_size,
             num_timesteps,
             sigma_min=1e-5,
             r_sigma_init=1,
             variance=1.0,
             mixing_coeff=0.5,
             prior_mode_mean=1.0,
             dtype=tf.float32,
             random_seed=None,
             train_p=True,
             p_type="unimodal",
             q_type="normal",
             observation_variance=1.0,
             transition_type=STANDARD_TRANSITION,
             use_bs=True):
    """Factory that builds p, q, and r from string configuration.

    Raises:
      ValueError: If p_type or q_type is not recognized. (Previously an
        unrecognized value left `p`/`q` unbound, raising NameError later.)
    """
    if p_type == "unimodal":
      p = P(state_size,
            num_timesteps,
            sigma_min=sigma_min,
            variance=variance,
            dtype=dtype,
            random_seed=random_seed,
            trainable=train_p,
            init_bs_to_zero=not use_bs)
    elif p_type == "bimodal":
      p = BimodalPriorP(
          state_size,
          num_timesteps,
          mixing_coeff=mixing_coeff,
          prior_mode_mean=prior_mode_mean,
          sigma_min=sigma_min,
          variance=variance,
          dtype=dtype,
          random_seed=random_seed,
          trainable=train_p,
          init_bs_to_zero=not use_bs)
    elif "nonlinear" in p_type:
      if "cauchy" in p_type:
        trans_dist = tf.contrib.distributions.Cauchy
      else:
        trans_dist = tf.contrib.distributions.Normal
      p = ShortChainNonlinearP(
          state_size,
          num_timesteps,
          sigma_min=sigma_min,
          variance=variance,
          observation_variance=observation_variance,
          transition_type=transition_type,
          transition_dist=trans_dist,
          dtype=dtype,
          random_seed=random_seed
      )
    else:
      raise ValueError("Unknown p_type: %s" % p_type)
    if q_type == "normal":
      q_class = Q
    elif q_type == "simple_mean":
      q_class = SimpleMeanQ
    elif q_type == "prev_state":
      q_class = PreviousStateQ
    elif q_type == "observation":
      q_class = ObservationQ
    else:
      raise ValueError("Unknown q_type: %s" % q_type)
    q = q_class(state_size,
                num_timesteps,
                sigma_min=sigma_min,
                dtype=dtype,
                random_seed=random_seed,
                init_mu0_to_zero=not use_bs)
    r = R(state_size,
          num_timesteps,
          sigma_min=sigma_min,
          sigma_init=r_sigma_init,
          dtype=dtype,
          random_seed=random_seed)
    model = Model(p, q, r, state_size, num_timesteps, dtype=dtype)
    return model
class BackwardsModel(object):
  """Model that runs the chain in reverse, proposing z_t given z_{t+1}.

  Indices are mapped through t_backwards = num_timesteps - t - 1 so that
  t counts filtering steps while the parameter lists stay in forward order.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32):
    """Creates the drift, proposal, and r variables for every timestep.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Lower bound on the (softplus-transformed) variances.
      dtype: dtype of the created variables.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    # Per-timestep drifts b_t of the forward model.
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    # Suffix sums of the drifts: Bs[i] = sum_{k=i+1}^{n} b_k.
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
    # Proposal mean networks and raw variances, one per timestep.
    self.q_mus = [
        snt.Linear(output_size=state_size) for _ in xrange(num_timesteps)
    ]
    self.q_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    # Per-timestep constant r means and raw variances.
    self.r_mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_mu_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    self.r_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]

  def zero_state(self, batch_size):
    """Returns the all-zeros initial state, [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def posterior(self, unused_observation, prev_state, unused_t):
    # TODO(dieterichl): Correct this.
    # Placeholder: returns a degenerate Normal (scale 0), not a real
    # posterior.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(prev_state), scale=tf.zeros_like(prev_state))

  def lookahead(self, state, unused_t):
    # TODO(dieterichl): Correct this.
    # Placeholder: returns a degenerate Normal (scale 0), not a real
    # lookahead.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(state), scale=tf.zeros_like(state))

  def q_zt(self, observation, next_state, t):
    """Computes the variational posterior q(z_{t}|z_{t+1}, z_n)."""
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(next_state)[0]
    q_mu = self.q_mus[t_backwards](tf.concat([observation, next_state], axis=1))
    q_sigma = tf.maximum(
        tf.nn.softplus(self.q_sigmas[t_backwards]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def p_zt(self, prev_state, t):
    """Computes the model p(z_{t+1}| z_{t})."""
    t_backwards = self.num_timesteps - t - 1
    z_mu_p = prev_state + self.bs[t_backwards]
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.ones_like(z_mu_p))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu, scale=tf.ones_like(generative_p_mu))

  def r(self, z_t, t):
    """Returns the Normal r distribution over z_t at backwards step t."""
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(z_t)[0]
    r_mu = tf.tile(self.r_mus[t_backwards][tf.newaxis, :], [batch_size, 1])
    r_sigma = tf.maximum(
        tf.nn.softplus(self.r_sigmas[t_backwards]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))

  def likelihood(self, observation):
    """Computes the exact marginal log-likelihood of the final observation."""
    batch_size = tf.shape(observation)[0]
    mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(mu) * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def __call__(self, next_state, observation, t):
    """Runs one backwards step; returns (z_t, log q, log p, log p(z_0), log r)."""
    # next state = z_{t+1}
    # Compute the q distribution over z, q(z_{t}|z_n, z_{t+1}).
    q_zt = self.q_zt(observation, next_state, t)
    # sample from q
    zt = q_zt.sample()
    # Compute the p distribution over z, p(z_{t+1}|z_{t}).
    p_zt = self.p_zt(zt, t)
    # Compute log p(z_{t+1} | z_t)
    if t == 0:
      # First backwards step: the "next state" is the observation z_n.
      log_p_zt = p_zt.log_prob(observation)
    else:
      log_p_zt = p_zt.log_prob(next_state)
    # Compute r prior over zt
    r_zt = self.r(zt, t)
    log_r_zt = r_zt.log_prob(zt)
    # Compute proposal density at zt
    log_q_zt = q_zt.log_prob(zt)
    # If we're at the last timestep, also calc the logprob of the observation.
    if t == self.num_timesteps - 1:
      # Last backwards step reaches z_0, which has a standard Normal prior.
      p_z0_dist = tf.contrib.distributions.Normal(
          loc=tf.zeros_like(zt), scale=tf.ones_like(zt))
      z0_log_prob = p_z0_dist.log_prob(zt)
    else:
      z0_log_prob = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt)
class LongChainP(object):
  """Chain model with an observation every steps_per_obs latent steps.

  Has no trainable parameters; transitions and observations are fixed by
  the constructor arguments.
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               observation_type=STANDARD_OBSERVATION,
               transition_type=STANDARD_TRANSITION,
               dtype=tf.float32,
               random_seed=None):
    """Stores the model configuration.

    Args:
      state_size: Dimensionality of the latent state.
      num_obs: Number of observations.
      steps_per_obs: Latent steps between consecutive observations.
      sigma_min: Stored but not read in this class (interface parity).
      variance: Fixed transition variance.
      observation_variance: Fixed observation variance.
      observation_type: One of OBSERVATION_TYPES; selects the obs mean.
      transition_type: One of TRANSITION_TYPES; selects the transition mean.
      dtype: dtype of created tensors.
      random_seed: Unused; kept for interface parity.
    """
    self.state_size = state_size
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = steps_per_obs*num_obs + 1
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.observation_type = observation_type
    self.transition_type = transition_type

  def likelihood(self, observations):
    """Computes the model's true likelihood of the observations.

    Args:
      observations: A [batch_size, m, state_size] Tensor representing each of
        the m observations.

    Returns:
      logprob: The true likelihood of the observations given the model.

    Raises:
      ValueError: Always; the closed-form likelihood only holds for the
        linear-Gaussian case and is intentionally disabled here.
    """
    raise ValueError("Likelihood is not defined for long-chain models")
    # batch_size = tf.shape(observations)[0]
    # mu = tf.zeros([batch_size, self.state_size, self.num_obs], dtype=self.dtype)
    # sigma = np.fromfunction(
    #     lambda i, j: 1 + self.steps_per_obs*np.minimum(i+1, j+1),
    #     [self.num_obs, self.num_obs])
    # sigma += np.eye(self.num_obs)
    # sigma = tf.convert_to_tensor(sigma * self.variance, dtype=self.dtype)
    # sigma = tf.tile(sigma[tf.newaxis, tf.newaxis, ...],
    #                 [batch_size, self.state_size, 1, 1])
    # dist = tf.contrib.distributions.MultivariateNormalFullCovariance(
    #     loc=mu,
    #     covariance_matrix=sigma)
    # Average over the batch and take the sum over the state size
    #return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observations), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1}).

    Raises:
      ValueError: If transition_type is not a known TRANSITION_TYPES value
        (previously this left `loc` unbound and raised UnboundLocalError).
    """
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance))
      else:
        raise ValueError("Unknown transition_type: %s" % self.transition_type)
    else:  # p(z_0) is Normal(0,1)
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
    p_zt = tf.contrib.distributions.Normal(
        loc=loc,
        scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, z_ni, t):
    """Computes the model's generative distribution p(x_i| z_{ni}).

    Raises:
      ValueError: If observation_type is not a known OBSERVATION_TYPES value
        (previously this left `generative_mu` unbound and raised
        UnboundLocalError).
    """
    if self.observation_type == SQUARED_OBSERVATION:
      generative_mu = tf.square(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d^2, %0.1f)" % (t, t, t, self.variance))
    elif self.observation_type == ABS_OBSERVATION:
      generative_mu = tf.abs(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(|z_%d|, %0.1f)" % (t, t, t, self.variance))
    elif self.observation_type == STANDARD_OBSERVATION:
      generative_mu = z_ni
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t, t, self.variance))
    else:
      raise ValueError("Unknown observation_type: %s" % self.observation_type)
    generative_sigma_sq = tf.ones_like(generative_mu) * self.observation_variance
    return tf.contrib.distributions.Normal(
        loc=generative_mu, scale=tf.sqrt(generative_sigma_sq))
class LongChainQ(object):
  """Proposal for the long-chain model, conditioned on all relevant obs.

  At each timestep a linear network maps the previous state plus every
  not-yet-passed observation to the proposal mean.
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    """Creates the per-timestep mean networks and variance variables.

    Args:
      state_size: Dimensionality of the latent state.
      num_obs: Number of observations.
      steps_per_obs: Latent steps between consecutive observations.
      sigma_min: Lower bound on the (softplus-transformed) variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializers.
    """
    self.state_size = state_size
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs*steps_per_obs +1
    initializers = {
        "w": tf.random_uniform_initializer(seed=random_seed),
        "b": tf.zeros_initializer
    }
    self.mus = [
        snt.Linear(output_size=state_size, initializers=initializers)
        for t in xrange(self.num_timesteps)
    ]
    # Raw (pre-softplus) per-timestep variance parameters.
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(self.num_timesteps)
    ]

  def first_relevant_obs_index(self, t):
    # NOTE(review): q_zt below recomputes this index inline with a slightly
    # different formula; confirm which is intended before relying on this.
    return int(max((t-1)/self.steps_per_obs, 0))

  def q_zt(self, observations, prev_state, t):
    """Computes a distribution over z_t.

    Args:
      observations: a [batch_size, num_observations, state_size] Tensor.
      prev_state: a [batch_size, state_size] Tensor.
      t: The current timestep, an int Tensor.
    """
    # filter out unneeded past obs
    first_relevant_obs_index = int(math.floor(max(t-1, 0) / self.steps_per_obs))
    num_relevant_observations = self.num_obs - first_relevant_obs_index
    observations = observations[:,first_relevant_obs_index:,:]
    batch_size = tf.shape(prev_state)[0]
    # concatenate the prev state and observations along the second axis (that is
    # not the batch or state size axis, and then flatten it to
    # [batch_size, (num_relevant_observations + 1) * state_size] to feed it into
    # the linear layer.
    q_input = tf.concat([observations, prev_state[:,tf.newaxis, :]], axis=1)
    q_input = tf.reshape(q_input,
                         [batch_size, (num_relevant_observations + 1) * self.state_size])
    q_mu = self.mus[t](q_input)
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    tf.logging.info(
        "q(z_{t} | z_{tm1}, x_{obsf}:{obst}) ~ N(Linear([z_{tm1},x_{obsf}:{obst}]), sigma_{t})".format(
            **{"t": t,
               "tm1": t-1,
               "obsf": (first_relevant_obs_index+1)*self.steps_per_obs,
               "obst":self.steps_per_obs*self.num_obs}))
    return q_zt

  def summarize_weights(self):
    # No summaries are currently emitted for this proposal.
    pass
class LongChainR(object):
  """Lookahead distribution over all future observations given z_t.

  The mean of every future observation is z_t itself; only a per-future-obs
  variance is learned for each timestep.
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    """Creates the per-timestep raw-variance variables.

    Args:
      state_size: Dimensionality of the latent state.
      num_obs: Number of observations.
      steps_per_obs: Latent steps between consecutive observations.
      sigma_min: Lower bound on the (softplus-transformed) variance.
      dtype: dtype of the created variables.
      random_seed: Unused; the constant initializer below is seed-free.
    """
    self.state_size = state_size
    self.dtype = dtype
    self.sigma_min = sigma_min
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs*steps_per_obs + 1
    # One raw variance per future observation, per timestep.
    self.sigmas = [
        tf.get_variable(
            shape=[self.num_future_obs(t)],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            #initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100))
            initializer=tf.constant_initializer(1.0))
        for t in range(self.num_timesteps)
    ]

  def first_future_obs_index(self, t):
    """Returns the index of the first observation still ahead of step t."""
    return int(math.floor(t / self.steps_per_obs))

  def num_future_obs(self, t):
    """Returns how many observations remain after timestep t."""
    return int(self.num_obs - self.first_future_obs_index(t))

  def r_xn(self, z_t, t):
    """Computes a distribution over the future observations given current latent
    state.

    The indexing in these messages is 1 indexed and inclusive. This is
    consistent with the latex documents.

    Args:
      z_t: [batch_size, state_size] Tensor
      t: Current timestep
    """
    tf.logging.info(
        "r(x_{start}:{end} | z_{t}) ~ N(z_{t}, sigma_{t})".format(
            **{"t": t,
               "start": (self.first_future_obs_index(t)+1)*self.steps_per_obs,
               "end": self.num_timesteps-1}))
    batch_size = tf.shape(z_t)[0]
    # the mean for all future observations is the same.
    # this tiling results in a [batch_size, num_future_obs, state_size] Tensor
    r_mu = tf.tile(z_t[:,tf.newaxis,:], [1, self.num_future_obs(t), 1])
    # compute the variance
    r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    # the variance is the same across all state dimensions, so we only have to
    # time sigma to be [batch_size, num_future_obs].
    r_sigma = tf.tile(r_sigma[tf.newaxis,:, tf.newaxis], [batch_size, 1, self.state_size])
    return tf.contrib.distributions.Normal(
        loc=r_mu, scale=tf.sqrt(r_sigma))

  def summarize_weights(self):
    # No summaries are currently emitted for this distribution.
    pass
class LongChainModel(object):
  """Ties together p, q, and r for the long-chain (multi-observation) model.

  Calling the model at timestep t samples z_t from q and returns the log
  probabilities needed to form the importance weight; observations are only
  scored at timesteps that are multiples of steps_per_obs.
  """

  def __init__(self,
               p,
               q,
               r,
               state_size,
               num_obs,
               steps_per_obs,
               dtype=tf.float32,
               disable_r=False):
    """Stores the component distributions and chain configuration.

    Args:
      p: Model distribution object (LongChainP).
      q: Proposal distribution object (LongChainQ).
      r: Lookahead distribution object (LongChainR).
      state_size: Dimensionality of the latent state.
      num_obs: Number of observations.
      steps_per_obs: Latent steps between consecutive observations.
      dtype: dtype of the state tensors.
      disable_r: If True, the r lookahead term is always zero.
    """
    self.p = p
    self.q = q
    self.r = r
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_obs = num_obs
    self.steps_per_obs = steps_per_obs
    self.num_timesteps = steps_per_obs*num_obs + 1
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns the all-zeros initial state, [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def next_obs_ind(self, t):
    """Returns the index of the observation associated with timestep t."""
    return int(math.floor(max(t-1,0)/self.steps_per_obs))

  def __call__(self, prev_state, observations, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor
      observations: [batch_size, num_observations, state_size] Tensor
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observations, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q and evaluate the logprobs, summing over the state size
    zt = q_zt.sample()
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps-1:
      # score the remaining observations using r
      r_xn = self.r.r_xn(zt, t)
      log_r_xn = r_xn.log_prob(observations[:, self.next_obs_ind(t+1):, :])
      # sum over state size and observation, leaving the batch index
      log_r_xn = tf.reduce_sum(log_r_xn, axis=[1,2])
    else:
      log_r_xn = tf.zeros_like(log_p_zt)
    # Observations only occur every steps_per_obs latent steps.
    if t != 0 and t % self.steps_per_obs == 0:
      generative_dist = self.p.generative(zt, t)
      log_p_x_given_z = generative_dist.log_prob(observations[:,self.next_obs_ind(t),:])
      log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)

  @staticmethod
  def create(state_size,
             num_obs,
             steps_per_obs,
             sigma_min=1e-5,
             variance=1.0,
             observation_variance=1.0,
             observation_type=STANDARD_OBSERVATION,
             transition_type=STANDARD_TRANSITION,
             dtype=tf.float32,
             random_seed=None,
             disable_r=False):
    """Factory that builds a LongChainModel with matching p, q, and r."""
    p = LongChainP(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        variance=variance,
        observation_variance=observation_variance,
        observation_type=observation_type,
        transition_type=transition_type,
        dtype=dtype,
        random_seed=random_seed)
    q = LongChainQ(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    r = LongChainR(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    model = LongChainModel(
        p, q, r, state_size, num_obs, steps_per_obs,
        dtype=dtype,
        disable_r=disable_r)
    return model
class RTilde(object):
  """Per-timestep networks predicting r-tilde's mean and variance.

  Builds one sonnet Linear per timestep that maps [z_t, observation] to
  (mu, sigma_sq). All created variables are also mirrored into a named
  graph collection so they can be fetched/optimized separately.
  """
  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               graph_collection_name="R_TILDE_VARS"):
    """Creates the per-timestep Linear modules.

    Args:
      state_size: Dimensionality of the latent state (output is split
        into two state_size-sized halves).
      num_timesteps: Number of timesteps; one Linear is built per step.
      sigma_min: Lower bound applied to the predicted variance.
      dtype: Stored but not otherwise used in this class.
      random_seed: Seed for the truncated-normal weight initializer.
      graph_collection_name: Collection the variables are added to.
    """
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}
    self.graph_collection_name=graph_collection_name
    def custom_getter(getter, *args, **kwargs):
      # Mirror every created variable into our collection exactly once.
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out
    self.fns = [
        snt.Linear(output_size=2*state_size,
                   initializers=initializers,
                   name="r_tilde_%d" % t,
                   custom_getter=custom_getter)
        for t in xrange(num_timesteps)
    ]
  def r_zt(self, z_t, observation, t):
    """Predicts (mu, sigma_sq) for timestep t from z_t and the observation.

    Args:
      z_t: [batch_size, state_size] latent state.
      observation: [batch_size, state_size] observation tensor.
      t: Python int indexing the per-timestep network.

    Returns:
      mu: [batch_size, state_size] predicted mean.
      sigma_sq: [batch_size, state_size] predicted variance; softplus of
        the raw output, clipped below at sigma_min.
    """
    #out = self.fns[t](tf.stop_gradient(tf.concat([z_t, observation], axis=1)))
    out = self.fns[t](tf.concat([z_t, observation], axis=1))
    mu, raw_sigma_sq = tf.split(out, 2, axis=1)
    sigma_sq = tf.maximum(tf.nn.softplus(raw_sigma_sq), self.sigma_min)
    return mu, sigma_sq
class TDModel(object):
  """Model wrapper that pairs p and q with a learned RTilde network.

  Instead of a distribution object r, each timestep produces the
  (mu, sigma_sq) predicted by RTilde; these are consumed by the
  fivo-aux-td bound (see the bounds.fivo_aux_td call sites).
  """
  def __init__(self,
               p,
               q,
               r_tilde,
               state_size,
               num_timesteps,
               dtype=tf.float32,
               disable_r=False):
    """Creates the TD model.

    Args:
      p: Generative model with p_zt and generative methods.
      q: Proposal with a q_zt method.
      r_tilde: RTilde instance producing (mu, sigma_sq) per timestep.
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps in the sequence.
      dtype: dtype of the zero initial state.
      disable_r: If True, the r_tilde outputs are not computed and are
        returned as None at every timestep.
    """
    self.p = p
    self.q = q
    self.r_tilde = r_tilde
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype
  def zero_state(self, batch_size):
    """Returns the all-zeros initial state, [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)
  def __call__(self, prev_state, observation, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor
      observation: [batch_size, state_size] Tensor (single observation).
      t: Python int, the current timestep index.

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, r_tilde_mu,
      r_tilde_sigma_sq, p_ztplus1). The r_tilde outputs are None on the
      final timestep or when disable_r is set; p_ztplus1 is None on the
      final timestep.
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q and evaluate the logprobs, summing over the state size
    zt = q_zt.sample()
    # If it isn't the last timestep, compute the distribution over the next z.
    if t < self.num_timesteps - 1:
      p_ztplus1 = self.p.p_zt(zt, t+1)
    else:
      p_ztplus1 = None
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps-1:
      # score the remaining observations using r
      r_tilde_mu, r_tilde_sigma_sq = self.r_tilde.r_zt(zt, observation, t+1)
    else:
      r_tilde_mu = None
      r_tilde_sigma_sq = None
    if t == self.num_timesteps - 1:
      # NOTE(review): arguments here are (observation, zt), unlike
      # LongChainModel which calls generative(zt, t) -- presumably this
      # P class has a different generative signature; verify against the
      # P definition earlier in the file.
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z,
            r_tilde_mu, r_tilde_sigma_sq, p_ztplus1)
  @staticmethod
  def create(state_size,
             num_timesteps,
             sigma_min=1e-5,
             variance=1.0,
             dtype=tf.float32,
             random_seed=None,
             train_p=True,
             p_type="unimodal",
             q_type="normal",
             mixing_coeff=0.5,
             prior_mode_mean=1.0,
             observation_variance=1.0,
             transition_type=STANDARD_TRANSITION,
             use_bs=True):
    """Factory: builds p, q, and r_tilde and wraps them in a TDModel.

    p_type selects among unimodal/bimodal/nonlinear priors and q_type
    selects the proposal class. Note an unrecognized p_type or q_type
    falls through and raises a NameError at the subsequent use.
    """
    if p_type == "unimodal":
      p = P(state_size,
            num_timesteps,
            sigma_min=sigma_min,
            variance=variance,
            dtype=dtype,
            random_seed=random_seed,
            trainable=train_p,
            init_bs_to_zero=not use_bs)
    elif p_type == "bimodal":
      p = BimodalPriorP(
          state_size,
          num_timesteps,
          mixing_coeff=mixing_coeff,
          prior_mode_mean=prior_mode_mean,
          sigma_min=sigma_min,
          variance=variance,
          dtype=dtype,
          random_seed=random_seed,
          trainable=train_p,
          init_bs_to_zero=not use_bs)
    elif "nonlinear" in p_type:
      if "cauchy" in p_type:
        trans_dist = tf.contrib.distributions.Cauchy
      else:
        trans_dist = tf.contrib.distributions.Normal
      p = ShortChainNonlinearP(
          state_size,
          num_timesteps,
          sigma_min=sigma_min,
          variance=variance,
          observation_variance=observation_variance,
          transition_type=transition_type,
          transition_dist=trans_dist,
          dtype=dtype,
          random_seed=random_seed
      )
    if q_type == "normal":
      q_class = Q
    elif q_type == "simple_mean":
      q_class = SimpleMeanQ
    elif q_type == "prev_state":
      q_class = PreviousStateQ
    elif q_type == "observation":
      q_class = ObservationQ
    q = q_class(state_size,
                num_timesteps,
                sigma_min=sigma_min,
                dtype=dtype,
                random_seed=random_seed,
                init_mu0_to_zero=not use_bs)
    r_tilde = RTilde(
        state_size,
        num_timesteps,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    model = TDModel(p, q, r_tilde, state_size, num_timesteps, dtype=dtype)
    return model
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Launches train.py for the forward model with the fivo-aux bound,
# multinomial resampling with resampling gradients, and bs disabled.
model="forward"
# Number of timesteps in the sequence.
T=5
# Number of observations.
num_obs=1
# Variance of the data generating process.
var=0.1
# Number of particles.
n=4
# Learning rate for Adam.
lr=0.0001
bound="fivo-aux"
q_type="normal"
resampling_method="multinomial"
# Whether to use resampling gradients.
rgrad="true"
p_type="unimodal"
use_bs=false
# Log directory name encodes the hyperparameter configuration.
LOGDIR=/tmp/fivo/model-$model-$bound-$resampling_method-resampling-rgrad-$rgrad-T-$T-var-$var-n-$n-lr-$lr-q-$q_type-p-$p_type
python train.py \
  --logdir=$LOGDIR \
  --model=$model \
  --bound=$bound \
  --q_type=$q_type \
  --p_type=$p_type \
  --variance=$var \
  --use_resampling_grads=$rgrad \
  --resampling=always \
  --resampling_method=$resampling_method \
  --batch_size=4 \
  --num_samples=$n \
  --num_timesteps=$T \
  --num_eval_samples=256 \
  --summarize_every=100 \
  --learning_rate=$lr \
  --decay_steps=1000000 \
  --max_steps=1000000000 \
  --random_seed=1234 \
  --train_p=false \
  --use_bs=$use_bs \
  --alsologtostderr
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for plotting and summarizing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import scipy
import tensorflow as tf
import models
def summarize_ess(weights, only_last_timestep=False):
  """Writes effective-sample-size summaries for SMC weights.

  Args:
    weights: List of length num_timesteps Tensors of shape
      [num_samples, batch_size]
    only_last_timestep: If True, only the final timestep is summarized.
  """
  total_steps = len(weights)
  batch = tf.cast(tf.shape(weights[0])[1], dtype=tf.float64)
  for step, raw_w in enumerate(weights):
    if only_last_timestep and step != total_steps - 1:
      continue
    # Normalize the log-weights over the sample dimension.
    normed = tf.nn.softmax(raw_w, dim=0)
    centered = normed - tf.reduce_mean(normed, axis=0, keepdims=True)
    # NOTE(review): the normalizer is batch_size-1 although the sum runs
    # over both axes -- confirm this is the intended estimator.
    var = tf.reduce_sum(tf.square(centered)) / (batch - 1)
    ess = 1. / tf.reduce_mean(tf.reduce_sum(tf.square(normed), axis=0))
    tf.summary.scalar("ess/%d" % step, ess)
    tf.summary.scalar("ese/%d" % step, ess / batch)
    tf.summary.scalar("weight_variance/%d" % step, var)
def summarize_particles(states, weights, observation, model):
  """Plots particle locations and weights.

  Renders one figure per batch element via a tf.py_func and emits them
  as image summaries in the "infrequent_summaries" collection.

  Args:
    states: List of length num_timesteps Tensors of shape
      [batch_size*num_particles, state_size].
    weights: List of length num_timesteps Tensors of shape [num_samples,
      batch_size]
    observation: Tensor of shape [batch_size*num_samples, state_size]
    model: Model whose p exposes mixing_coeff, prior_mode_mean, variance,
      bs, and num_timesteps (i.e. a bimodal prior) and whose q exposes
      q_zt.
  """
  num_timesteps = len(weights)
  num_samples, batch_size = weights[0].get_shape().as_list()
  # get q0 information for plotting
  q0_dist = model.q.q_zt(observation, tf.zeros_like(states[0]), 0)
  q0_loc = q0_dist.loc[0:batch_size, 0]
  q0_scale = q0_dist.scale[0:batch_size, 0]
  # get posterior information for plotting
  post = (model.p.mixing_coeff, model.p.prior_mode_mean, model.p.variance,
          tf.reduce_sum(model.p.bs), model.p.num_timesteps)
  # Reshape states and weights to be [time, num_samples, batch_size]
  states = tf.stack(states)
  weights = tf.stack(weights)
  # normalize the weights over the sample dimension
  weights = tf.nn.softmax(weights, dim=1)
  states = tf.reshape(states, tf.shape(weights))
  ess = 1./tf.reduce_sum(tf.square(weights), axis=1)
  def _plot_states(states_batch, weights_batch, observation_batch, ess_batch, q0, post):
    """Numpy/matplotlib worker run via tf.py_func; returns stacked RGB images.

    states: [time, num_samples, batch_size]
    weights [time, num_samples, batch_size]
    observation: [batch_size, 1]
    q0: ([batch_size], [batch_size])
    post: ...
    """
    num_timesteps, _, batch_size = states_batch.shape
    plots = []
    for i in range(batch_size):
      states = states_batch[:,:,i]
      weights = weights_batch[:,:,i]
      observation = observation_batch[i]
      ess = ess_batch[:,i]
      q0_loc = q0[0][i]
      q0_scale = q0[1][i]
      fig = plt.figure(figsize=(7, (num_timesteps + 1) * 2))
      # Each timestep gets two plots -- a bar plot and a histogram of state locs.
      # The bar plot will be bar_rows rows tall.
      # The histogram will be 1 row tall.
      # There is also 1 extra plot at the top showing the posterior and q.
      bar_rows = 8
      num_rows = (num_timesteps + 1) * (bar_rows + 1)
      gs = gridspec.GridSpec(num_rows, 1)
      # Figure out how wide to make the plot
      prior_lims = (post[1] * -2, post[1] * 2)
      q_lims = (scipy.stats.norm.ppf(0.01, loc=q0_loc, scale=q0_scale),
                scipy.stats.norm.ppf(0.99, loc=q0_loc, scale=q0_scale))
      state_width = states.max() - states.min()
      state_lims = (states.min() - state_width * 0.15,
                    states.max() + state_width * 0.15)
      lims = (min(prior_lims[0], q_lims[0], state_lims[0]),
              max(prior_lims[1], q_lims[1], state_lims[1]))
      # plot the posterior
      z0 = np.arange(lims[0], lims[1], 0.1)
      alpha, pos_mu, sigma_sq, B, T = post
      neg_mu = -pos_mu
      scale = np.sqrt((T + 1) * sigma_sq)
      p_zn = (
          alpha * scipy.stats.norm.pdf(
              observation, loc=pos_mu + B, scale=scale) + (1 - alpha) *
          scipy.stats.norm.pdf(observation, loc=neg_mu + B, scale=scale))
      p_z0 = (
          alpha * scipy.stats.norm.pdf(z0, loc=pos_mu, scale=np.sqrt(sigma_sq))
          + (1 - alpha) * scipy.stats.norm.pdf(
              z0, loc=neg_mu, scale=np.sqrt(sigma_sq)))
      p_zn_given_z0 = scipy.stats.norm.pdf(
          observation, loc=z0 + B, scale=np.sqrt(T * sigma_sq))
      # Bayes rule: posterior over z0 given the final observation.
      post_z0 = (p_z0 * p_zn_given_z0) / p_zn
      # plot q
      q_z0 = scipy.stats.norm.pdf(z0, loc=q0_loc, scale=q0_scale)
      ax = plt.subplot(gs[0:bar_rows, :])
      ax.plot(z0, q_z0, color="blue")
      ax.plot(z0, post_z0, color="green")
      ax.plot(z0, p_z0, color="red")
      ax.legend(("q", "posterior", "prior"), loc="best", prop={"size": 10})
      ax.set_xticks([])
      ax.set_xlim(*lims)
      # plot the states
      for t in range(num_timesteps):
        start = (t + 1) * (bar_rows + 1)
        ax1 = plt.subplot(gs[start:start + bar_rows, :])
        ax2 = plt.subplot(gs[start + bar_rows:start + bar_rows + 1, :])
        # plot the states barplot
        # ax1.hist(
        #     states[t, :],
        #     weights=weights[t, :],
        #     bins=50,
        #     edgecolor="none",
        #     alpha=0.2)
        ax1.bar(states[t,:], weights[t,:], width=0.02, alpha=0.2, edgecolor = "none")
        ax1.set_ylabel("t=%d" % t)
        ax1.set_xticks([])
        ax1.grid(True, which="both")
        ax1.set_xlim(*lims)
        # plot the observation
        ax1.axvline(x=observation, color="red", linestyle="dashed")
        # add the ESS
        ax1.text(0.1, 0.9, "ESS: %0.2f" % ess[t],
                 ha='center', va='center', transform=ax1.transAxes)
        # plot the state location histogram
        ax2.hist2d(
            states[t, :], np.zeros_like(states[t, :]), bins=[50, 1], cmap="Greys")
        ax2.grid(False)
        ax2.set_yticks([])
        ax2.set_xlim(*lims)
        if t != num_timesteps - 1:
          ax2.set_xticks([])
      # Serialize the rendered figure into an RGB numpy array.
      fig.canvas.draw()
      p = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
      plots.append(p.reshape(fig.canvas.get_width_height()[::-1] + (3,)))
      plt.close(fig)
    return np.stack(plots)
  plots = tf.py_func(_plot_states,
                     [states, weights, observation, ess, (q0_loc, q0_scale), post],
                     [tf.uint8])[0]
  tf.summary.image("states", plots, 5, collections=["infrequent_summaries"])
def plot_weights(weights, resampled=None):
  """Plots the weights and effective sample size from an SMC rollout.

  Args:
    weights: [num_timesteps, num_samples, batch_size] importance weights
    resampled: [num_timesteps] 0/1 indicating if resampling occurred;
      defaults to all zeros (no markers).
  """
  weights = tf.convert_to_tensor(weights)
  def _make_plots(weights, resampled):
    # py_func worker: renders one stacked-area plot per batch element and
    # returns the stacked RGB image array.
    num_timesteps, num_samples, batch_size = weights.shape
    plots = []
    for i in range(batch_size):
      fig, axes = plt.subplots(nrows=1, sharex=True, figsize=(8, 4))
      axes.stackplot(np.arange(num_timesteps), np.transpose(weights[:, :, i]))
      axes.set_title("Weights")
      axes.set_xlabel("Steps")
      axes.set_ylim([0, 1])
      axes.set_xlim([0, num_timesteps - 1])
      # Mark the timesteps at which resampling occurred.
      for j in np.where(resampled > 0)[0]:
        axes.axvline(x=j, color="red", linestyle="dashed", ymin=0.0, ymax=1.0)
      fig.canvas.draw()
      data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
      data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
      plots.append(data)
      plt.close(fig)
    return np.stack(plots, axis=0)
  if resampled is None:
    num_timesteps, _, batch_size = weights.get_shape().as_list()
    resampled = tf.zeros([num_timesteps], dtype=tf.float32)
  plots = tf.py_func(_make_plots,
                     [tf.nn.softmax(weights, dim=1),
                      tf.to_float(resampled)], [tf.uint8])[0]
  batch_size = weights.get_shape().as_list()[-1]
  tf.summary.image(
      "weights", plots, batch_size, collections=["infrequent_summaries"])
def summarize_weights(weights, num_timesteps, num_samples):
  """Summarizes per-timestep variance and magnitude of SMC weights.

  Args:
    weights: [num_timesteps, num_samples, batch_size] weight tensor.
    num_timesteps: Number of timesteps to emit summaries for.
    num_samples: Particle count, used for the unbiased variance estimate.
  """
  weights = tf.convert_to_tensor(weights)
  # Unbiased variance over the sample axis, per (timestep, batch) entry.
  sample_mean = tf.reduce_mean(weights, axis=1, keepdims=True)
  per_batch_var = tf.reduce_sum(
      tf.square(weights - sample_mean), axis=1) / (num_samples - 1)
  # average the variance over the batch
  variances = tf.reduce_mean(per_batch_var, axis=1)
  avg_magnitude = tf.reduce_mean(tf.abs(weights), axis=[1, 2])
  for t in xrange(num_timesteps):
    tf.summary.scalar("weights/variance_%d" % t, variances[t])
    tf.summary.scalar("weights/magnitude_%d" % t, avg_magnitude[t])
    tf.summary.histogram("weights/step_%d" % t, weights[t])
def summarize_learning_signal(rewards, tag):
  """Emits mean, magnitude, and second-moment summaries of a learning signal.

  Args:
    rewards: [num_resampling_events, batch_size] tensor of rewards.
    tag: String prefix for the emitted summary names.
  """
  num_events = rewards.get_shape().as_list()[0]
  mean_t = tf.reduce_mean(rewards, axis=1)
  magnitude_t = tf.reduce_mean(tf.abs(rewards), axis=1)
  squared_t = tf.reduce_mean(tf.square(rewards), axis=1)
  for t in xrange(num_events):
    tf.summary.scalar("%s/mean_%d" % (tag, t), mean_t[t])
    tf.summary.scalar("%s/magnitude_%d" % (tag, t), magnitude_t[t])
    tf.summary.scalar("%s/squared_%d" % (tag, t), squared_t[t])
    tf.summary.histogram("%s/step_%d" % (tag, t), rewards[t])
def summarize_qs(model, observation, states):
  """Summarizes KL(posterior || q) and convergence of q's moments.

  The per-timestep comparison only runs when model.p exposes a callable
  posterior; q's own weight summaries are always emitted.
  """
  model.q.summarize_weights()
  if not (hasattr(model.p, "posterior") and
          callable(getattr(model.p, "posterior"))):
    return
  # Previous state for timestep t is states[t-1]; zeros for t == 0.
  prev_states = [tf.zeros_like(states[0])] + states[:-1]
  for t, prev in enumerate(prev_states):
    posterior = model.p.posterior(observation, prev, t)
    proposal = model.q.q_zt(observation, prev, t)
    kl = tf.reduce_mean(
        tf.contrib.distributions.kl_divergence(posterior, proposal))
    tf.summary.scalar("kl_q/%d" % t, tf.reduce_mean(kl))
    loc_delta = proposal.loc - posterior.loc
    tf.summary.scalar("q_mean_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(loc_delta)))
    tf.summary.scalar("q_mean_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(loc_delta / posterior.loc)))
    var_delta = tf.square(proposal.scale) - tf.square(posterior.scale)
    tf.summary.scalar("q_variance_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(var_delta)))
    tf.summary.scalar(
        "q_variance_convergence/relative_error_%d" % t,
        tf.reduce_mean(tf.abs(var_delta / tf.square(posterior.scale))))
def summarize_rs(model, states):
  """Summarizes KL(true lookahead || r) and convergence of r's moments."""
  model.r.summarize_weights()
  for t, state in enumerate(states):
    exact = model.p.lookahead(state, t)
    approx = model.r.r_xn(state, t)
    kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(exact, approx))
    tf.summary.scalar("kl_r/%d" % t, tf.reduce_mean(kl))
    loc_delta = exact.loc - approx.loc
    tf.summary.scalar("r_mean_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(loc_delta)))
    tf.summary.scalar("r_mean_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(loc_delta / exact.loc)))
    var_delta = tf.square(approx.scale) - tf.square(exact.scale)
    tf.summary.scalar("r_variance_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(var_delta)))
    tf.summary.scalar("r_variance_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(var_delta / tf.square(exact.scale))))
def summarize_model(model, true_bs, observation, states, bound, summarize_r=True):
  """Compares the model's learned bs against the data generating process.

  Emits scalars for the summed bs of both processes plus absolute and
  relative error. No-ops when model.p has no bs attribute.
  """
  if not hasattr(model.p, "bs"):
    return
  learned_sum = tf.reduce_sum(model.p.bs, axis=0)
  true_sum = tf.reduce_sum(true_bs, axis=0)
  absolute = tf.abs(learned_sum - true_sum)
  relative = absolute / true_sum
  tf.summary.scalar("sum_of_bs/data_generating_process", tf.reduce_mean(true_sum))
  tf.summary.scalar("sum_of_bs/model", tf.reduce_mean(learned_sum))
  tf.summary.scalar("sum_of_bs/absolute_error", tf.reduce_mean(absolute))
  tf.summary.scalar("sum_of_bs/relative_error", tf.reduce_mean(relative))
def summarize_grads(grads, loss_name):
  """Tracks gradient mean and variance with an EMA and summarizes them.

  Args:
    grads: List of (gradient, variable) pairs as produced by
      compute_gradients; None gradients are skipped.
    loss_name: Name used to namespace the emitted summaries.

  Returns:
    The op that updates the gradient-moment EMAs.
  """
  ema = tf.train.ExponentialMovingAverage(decay=0.99)
  # Flatten all gradients into one vector so moments are tracked jointly.
  flat = tf.concat(
      [tf.reshape(g, [-1]) for g, _ in grads if g is not None], axis=0)
  first = flat
  second = tf.square(flat)
  update_op = ema.apply([first, second])
  avg_first = ema.average(first)
  avg_second = ema.average(second)
  # Var[g] = E[g^2] - E[g]^2 under the EMA estimates.
  grad_var = avg_second - tf.square(avg_first)
  tf.summary.scalar("grad_variance/%s" % loss_name, tf.reduce_mean(grad_var))
  tf.summary.histogram("grad_variance/%s" % loss_name, grad_var)
  tf.summary.histogram("grad_mean/%s" % loss_name, avg_first)
  return update_op
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main script for running fivo"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import numpy as np
import tensorflow as tf
import bounds
import data
import models
import summary_utils as summ
tf.logging.set_verbosity(tf.logging.INFO)

tf.app.flags.DEFINE_integer("random_seed", None,
                            "A random seed for the data generating process. Same seed "
                            "-> same data generating process and initialization.")
tf.app.flags.DEFINE_enum("bound", "fivo", ["iwae", "fivo", "fivo-aux", "fivo-aux-td"],
                         "The bound to optimize.")
tf.app.flags.DEFINE_enum("model", "forward", ["forward", "long_chain"],
                         "The model to use.")
tf.app.flags.DEFINE_enum("q_type", "normal",
                         ["normal", "simple_mean", "prev_state", "observation"],
                         "The parameterization to use for q")
tf.app.flags.DEFINE_enum("p_type", "unimodal", ["unimodal", "bimodal", "nonlinear"],
                         "The type of prior.")
tf.app.flags.DEFINE_boolean("train_p", True,
                            "If false, do not train the model p.")
tf.app.flags.DEFINE_integer("state_size", 1,
                            "The dimensionality of the state space.")
tf.app.flags.DEFINE_float("variance", 1.0,
                          "The variance of the data generating process.")
tf.app.flags.DEFINE_boolean("use_bs", True,
                            "If False, initialize all bs to 0.")
tf.app.flags.DEFINE_float("bimodal_prior_weight", 0.5,
                          "The weight assigned to the positive mode of the prior in "
                          "both the data generating process and p.")
tf.app.flags.DEFINE_float("bimodal_prior_mean", None,
                          "If supplied, sets the mean of the 2 modes of the prior to "
                          "be 1 and -1 times the supplied value. This is for both the "
                          "data generating process and p.")
tf.app.flags.DEFINE_float("fixed_observation", None,
                          "If supplied, fix the observation to a constant value in the"
                          " data generating process only.")
tf.app.flags.DEFINE_float("r_sigma_init", 1.,
                          "Value to initialize variance of r to.")
tf.app.flags.DEFINE_enum("observation_type",
                         models.STANDARD_OBSERVATION, models.OBSERVATION_TYPES,
                         "The type of observation for the long chain model.")
tf.app.flags.DEFINE_enum("transition_type",
                         models.STANDARD_TRANSITION, models.TRANSITION_TYPES,
                         "The type of transition for the long chain model.")
tf.app.flags.DEFINE_float("observation_variance", None,
                          "The variance of the observation. Defaults to 'variance'")
tf.app.flags.DEFINE_integer("num_timesteps", 5,
                            "Number of timesteps in the sequence.")
tf.app.flags.DEFINE_integer("num_observations", 1,
                            "The number of observations.")
tf.app.flags.DEFINE_integer("steps_per_observation", 5,
                            "The number of timesteps between each observation.")
tf.app.flags.DEFINE_integer("batch_size", 4,
                            "The number of examples per batch.")
tf.app.flags.DEFINE_integer("num_samples", 4,
                            "The number particles to use.")
tf.app.flags.DEFINE_integer("num_eval_samples", 512,
                            "The batch size and # of particles to use for eval.")
tf.app.flags.DEFINE_string("resampling", "always",
                           "How to resample. Accepts 'always','never', or a "
                           "comma-separated list of booleans like 'true,true,false'.")
tf.app.flags.DEFINE_enum("resampling_method", "multinomial", ["multinomial",
                                                              "stratified",
                                                              "systematic",
                                                              "relaxed-logblend",
                                                              "relaxed-stateblend",
                                                              "relaxed-linearblend",
                                                              "relaxed-stateblend-st",],
                         "Type of resampling method to use.")
# Help text fix: add the missing space between sentences.
tf.app.flags.DEFINE_boolean("use_resampling_grads", True,
                            "Whether or not to use resampling grads to optimize FIVO. "
                            "Disabled automatically if resampling_method=relaxed.")
# Help text fix: the flag disables r when set to TRUE (the code checks
# `if not self.disable_r` before using r); the old text said "If false".
tf.app.flags.DEFINE_boolean("disable_r", False,
                            "If true, r is not used for fivo-aux and is set to zeros.")
tf.app.flags.DEFINE_float("learning_rate", 1e-4,
                          "The learning rate to use for ADAM or SGD.")
tf.app.flags.DEFINE_integer("decay_steps", 25000,
                            "The number of steps before the learning rate is halved.")
tf.app.flags.DEFINE_integer("max_steps", int(1e6),
                            "The number of steps to run training for.")
tf.app.flags.DEFINE_string("logdir", "/tmp/fivo-aux",
                           "Directory for summaries and checkpoints.")
tf.app.flags.DEFINE_integer("summarize_every", int(1e3),
                            "The number of steps between each evaluation.")
FLAGS = tf.app.flags.FLAGS
def combine_grad_lists(grad_lists):
  """Sums each variable's gradients across several per-loss grad lists.

  Args:
    grad_lists: List (one entry per loss) of (gradient, variable) lists;
      different losses may cover different variables.

  Returns:
    A single list of (gradient, variable) pairs where each gradient is
    the sum over all losses, or None when no loss produced one.
  """
  per_var_grads = defaultdict(list)
  vars_by_name = {}
  for pairs in grad_lists:
    for g, v in pairs:
      if g is not None:
        per_var_grads[v.name].append(g)
      vars_by_name[v.name] = v
  combined = []
  for name, v in vars_by_name.iteritems():
    gs = per_var_grads[name]
    if gs:
      tf.logging.info("Var %s has combined grads from %s." %
                      (name, [g.name for g in gs]))
      summed = tf.reduce_sum(gs, axis=0)
    else:
      tf.logging.info("Var %s has no grads" % name)
      summed = None
    combined.append((summed, v))
  return combined
def make_apply_grads_op(losses, global_step, learning_rate, lr_decay_steps):
  """Builds a train op that applies the combined gradients of all losses.

  Args:
    losses: List of bounds.Loss tuples (name, loss tensor, variable
      collection to differentiate with respect to).
    global_step: Global step variable, incremented by the apply op.
    learning_rate: Base learning rate for Adam.
    lr_decay_steps: Steps over which the learning rate is halved
      (smooth exponential decay, staircase=False).

  Returns:
    An op that applies the summed gradients and then updates the
    gradient-moment EMAs used for summaries.
  """
  for l in losses:
    assert isinstance(l, bounds.Loss)
  lr = tf.train.exponential_decay(
      learning_rate, global_step, lr_decay_steps, 0.5, staircase=False)
  tf.summary.scalar("learning_rate", lr)
  opt = tf.train.AdamOptimizer(lr)
  ema_ops = []
  grads = []
  for loss_name, loss, loss_var_collection in losses:
    tf.logging.info("Computing grads of %s w.r.t. vars in collection %s" %
                    (loss_name, loss_var_collection))
    g = opt.compute_gradients(loss,
                              var_list=tf.get_collection(loss_var_collection))
    ema_ops.append(summ.summarize_grads(g, loss_name))
    grads.append(g)
  # Sum gradients for variables that appear in multiple losses.
  all_grads = combine_grad_lists(grads)
  apply_grads_op = opt.apply_gradients(all_grads, global_step=global_step)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads_op]):
    train_op = tf.group(*ema_ops)
  return train_op
def add_check_numerics_ops():
  """Adds check_numerics ops for float tensors in the default graph.

  Ops whose names match a blacklist of substrings are skipped; the
  checks are chained via control dependencies so they run sequentially.

  Returns:
    A single grouped op that runs all the numeric checks.

  Raises:
    ValueError: If a checked op lives inside a control flow context
      (tf.cond / tf.while_loop), which check_numerics cannot handle.
  """
  check_op = []
  for op in tf.get_default_graph().get_operations():
    # Op-name substrings expected to legitimately produce infs/nans.
    bad = ["logits/Log", "sample/Reshape", "log_prob/mul",
           "log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape",
           "entropy/Reshape", "entropy/LogSoftmax", "Categorical", "Mean"]
    if all([x not in op.name for x in bad]):
      for output in op.outputs:
        if output.dtype in [tf.float16, tf.float32, tf.float64]:
          if op._get_control_flow_context() is not None: # pylint: disable=protected-access
            raise ValueError("`tf.add_check_numerics_ops() is not compatible "
                             "with TensorFlow control flow operations such as "
                             "`tf.cond()` or `tf.while_loop()`.")
          message = op.name + ":" + str(output.value_index)
          # Chain each check after all previous ones via control deps.
          with tf.control_dependencies(check_op):
            check_op = [tf.check_numerics(output, message=message)]
  return tf.group(*check_op)
def create_long_chain_graph(bound, state_size, num_obs, steps_per_obs,
                            batch_size, num_samples, num_eval_samples,
                            resampling_schedule, use_resampling_grads,
                            learning_rate, lr_decay_steps, dtype="float64"):
  """Builds the train and eval graph for the long-chain model.

  Args:
    bound: One of "iwae", "fivo", "fivo-aux", or "fivo-aux-td".
    state_size: Dimensionality of the latent state.
    num_obs: Number of observations in the chain.
    steps_per_obs: Latent timesteps between observations.
    batch_size: Training batch size.
    num_samples: Number of particles for training.
    num_eval_samples: Batch size and particle count for eval.
    resampling_schedule: List of booleans, when to resample.
    use_resampling_grads: Whether to use resampling gradients.
    learning_rate: Base Adam learning rate.
    lr_decay_steps: Steps over which the learning rate is halved.
    dtype: Numpy-style dtype string for the data and model.

  Returns:
    Tuple (global_step, train_op, eval_log_p_hat, eval_likelihood).
    eval_likelihood is always zero; the true likelihood is not computed
    for these models.
  """
  num_timesteps = num_obs * steps_per_obs + 1
  # Make the dataset.
  dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Make the dataset for eval
  eval_dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()
  # Make the model.
  model = models.LongChainModel.create(
      state_size,
      num_obs,
      steps_per_obs,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=tf.as_dtype(dtype),
      disable_r=FLAGS.disable_r)
  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  # BUG FIX: this previously read `elif bound == "fivo" or "fivo-aux":`,
  # which is always true because the non-empty literal "fivo-aux" is
  # truthy. Match the fivo family explicitly, the same way create_graph
  # does; behavior is unchanged for every valid bound value.
  elif "fivo" in bound:
    (_, losses, ema_op, _, _) = bounds.fivo(
        model,
        observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=use_resampling_grads,
        resampling_type=FLAGS.resampling_method,
        aux=("aux" in bound),
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
        model,
        eval_observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=False,
        resampling_type="multinomial",
        aux=("aux" in bound),
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  summ.summarize_ess(eval_log_weights, only_last_timestep=True)
  tf.summary.scalar("log_p_hat", eval_log_p_hat)
  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
  # We can't calculate the likelihood for most of these models
  # so we just return zeros.
  eval_likelihood = tf.zeros([], dtype=dtype)
  return global_step, train_op, eval_log_p_hat, eval_likelihood
def create_graph(bound, state_size, num_timesteps, batch_size,
                 num_samples, num_eval_samples, resampling_schedule,
                 use_resampling_grads, learning_rate, lr_decay_steps,
                 train_p, dtype='float64'):
  """Builds the train and eval graph for the forward (short-chain) models.

  Args:
    bound: One of "iwae", "fivo", "fivo-aux", or "fivo-aux-td".
    state_size: Dimensionality of the latent state.
    num_timesteps: Number of timesteps in the sequence.
    batch_size: Training batch size.
    num_samples: Number of particles for training.
    num_eval_samples: Batch size and particle count for eval.
    resampling_schedule: List of booleans, when to resample.
    use_resampling_grads: Whether to use resampling gradients.
    learning_rate: Base Adam learning rate.
    lr_decay_steps: Steps over which the learning rate is halved.
    train_p: Whether the generative model p is trained.
    dtype: Numpy-style dtype string for the data and model.

  Returns:
    Tuple (global_step, train_op, eval_log_p_hat, eval_likelihood).
  """
  if FLAGS.use_bs:
    true_bs = None
  else:
    # Force the data generating process to use all-zero bs.
    true_bs = [np.zeros([state_size]).astype(dtype) for _ in xrange(num_timesteps)]
  # Make the dataset.
  true_bs, dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Make the dataset for eval
  _, eval_dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=num_eval_samples,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()
  # Make the model.
  if bound == "fivo-aux-td":
    model = models.TDModel.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)
  else:
    model = models.Model.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        r_sigma_init=FLAGS.r_sigma_init,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)
  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  elif "fivo" in bound:
    if bound == "fivo-aux-td":
      (_, losses, ema_op, _, _) = bounds.fivo_aux_td(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo_aux_td(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_eval_samples,
          summarize=True)
    else:
      (_, losses, ema_op, _, _) = bounds.fivo(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=use_resampling_grads,
          resampling_type=FLAGS.resampling_method,
          aux=("aux" in bound),
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=False,
          resampling_type="multinomial",
          aux=("aux" in bound),
          num_samples=num_eval_samples,
          summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  summ.summarize_ess(eval_log_weights, only_last_timestep=True)
  # if FLAGS.p_type == "bimodal":
  #   # create the observations that showcase the model.
  #   mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
  #                                          dtype=tf.float64)
  #   mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
  #   k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
  #   explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
  #   explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
  #   # run the model on the explainable observations
  #   if bound == "iwae":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         num_samples=num_eval_samples)
  #   elif bound == "fivo" or "fivo-aux":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         resampling_schedule=resampling_schedule,
  #         use_resampling_grads=False,
  #         resampling_type="multinomial",
  #         aux=("aux" in bound),
  #         num_samples=num_eval_samples)
  #   summ.summarize_particles(explain_states,
  #                            explain_log_weights,
  #                            explain_obs,
  #                            model)
  # Calculate the true likelihood.
  if hasattr(model.p, 'likelihood') and callable(getattr(model.p, 'likelihood')):
    eval_likelihood = model.p.likelihood(eval_observations)/ FLAGS.num_timesteps
  else:
    eval_likelihood = tf.zeros_like(eval_log_p_hat)
  tf.summary.scalar("log_p_hat", eval_log_p_hat)
  tf.summary.scalar("likelihood", eval_likelihood)
  tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
  summ.summarize_model(model, true_bs, eval_observations, eval_states, bound,
                       summarize_r=not bound == "fivo-aux-td")
  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
    #train_op = tf.group(ema_op, add_check_numerics_ops())
  return global_step, train_op, eval_log_p_hat, eval_likelihood
def parse_resampling_schedule(schedule, num_timesteps):
  """Parses a resampling schedule string into a per-timestep boolean list.

  Args:
    schedule: A string describing when to resample. One of 'always'
      (resample every step except the last), 'never', 'every_N' (resample
      on every Nth step), or a comma-separated list of 'true'/'false'
      literals with exactly one entry per timestep.
    num_timesteps: The number of timesteps the schedule must cover.

  Returns:
    A list of num_timesteps booleans; True means resample at that timestep.

  Raises:
    AssertionError: If an explicit true/false schedule has the wrong length.
  """
  schedule = schedule.strip().lower()
  if schedule == "always":
    # Resampling on the final step is pointless, so 'always' skips it.
    return [True] * (num_timesteps - 1) + [False]
  elif schedule == "never":
    return [False] * num_timesteps
  elif "every" in schedule:
    # 'every_N': resample on steps N, 2N, 3N, ... (1-indexed).
    n = int(schedule.split("_")[1])
    # range (not Python-2-only xrange) so the code also runs under Python 3.
    return [(i + 1) % n == 0 for i in range(num_timesteps)]
  else:
    sched = [x.strip() == "true" for x in schedule.split(",")]
    assert len(
        sched
    ) == num_timesteps, "Wrong number of timesteps in resampling schedule."
    return sched
def create_log_hook(step, eval_log_p_hat, eval_likelihood):
  """Builds a LoggingTensorHook that periodically prints eval metrics."""
  tensors_to_log = {
      "step": step,
      "log_p_hat": eval_log_p_hat,
      "likelihood": eval_likelihood,
  }

  def format_summary(tensor_values):
    # Substitute the fetched tensor values into the log message by name.
    return ("Step {step}, log p_hat: {log_p_hat:.5f} likelihood: {likelihood:.5f}".format(**tensor_values))

  return tf.train.LoggingTensorHook(
      tensors_to_log,
      every_n_iter=FLAGS.summarize_every,
      formatter=format_summary)
def create_infrequent_summary_hook():
  """Builds a SummarySaverHook for summaries saved only every 10000 steps."""
  merged_infrequent_summaries = tf.summary.merge_all(
      key="infrequent_summaries")
  return tf.train.SummarySaverHook(
      save_steps=10000,
      output_dir=FLAGS.logdir,
      summary_op=merged_infrequent_summaries)
def main(unused_argv):
  """Validates flag combinations, builds the training graph, runs training."""
  # The long chain model uses a schedule one entry longer (num_timesteps + 1).
  if FLAGS.model == "long_chain":
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps + 1)
  else:
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps)
  # Draw a random seed when none is supplied, and log it so runs can be
  # reproduced after the fact.
  if FLAGS.random_seed is None:
    seed = np.random.randint(0, high=10000)
  else:
    seed = FLAGS.random_seed
  tf.logging.info("Using random seed %d", seed)

  # Flag combinations not supported by the long chain model.
  if FLAGS.model == "long_chain":
    assert FLAGS.q_type == "normal", "Q type %s not supported for long chain models" % FLAGS.q_type
    assert FLAGS.p_type == "unimodal", "Bimodal priors are not supported for long chain models"
    assert not FLAGS.use_bs, "Bs are not supported with long chain models"
    assert FLAGS.num_timesteps == FLAGS.num_observations * FLAGS.steps_per_observation, "Num timesteps does not match."
    assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with long chain models."

  # Flag combinations not supported by the forward model.
  if FLAGS.model == "forward":
    if "nonlinear" not in FLAGS.p_type:
      assert FLAGS.transition_type == models.STANDARD_TRANSITION, "Non-standard transitions not supported by the forward model."
      assert FLAGS.observation_type == models.STANDARD_OBSERVATION, "Non-standard observations not supported by the forward model."
    assert FLAGS.observation_variance is None, "Forward model does not support observation variance."
    assert FLAGS.num_observations == 1, "Forward model only supports 1 observation."

  # Relaxed resampling forces resampling grads off and excludes TD training.
  if "relaxed" in FLAGS.resampling_method:
    FLAGS.use_resampling_grads = False
    assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with relaxed resampling."

  # Default the observation variance to the transition variance.
  if FLAGS.observation_variance is None:
    FLAGS.observation_variance = FLAGS.variance

  if FLAGS.p_type == "bimodal":
    assert FLAGS.bimodal_prior_mean is not None, "Must specify prior mean if using bimodal p."

  if FLAGS.p_type == "nonlinear" or FLAGS.p_type == "nonlinear-cauchy":
    assert not FLAGS.use_bs, "Using bs is not compatible with the nonlinear model."

  g = tf.Graph()
  with g.as_default():
    # Set the seeds.
    tf.set_random_seed(seed)
    np.random.seed(seed)
    # Build the train/eval graph for the selected model family.
    if FLAGS.model == "long_chain":
      (global_step, train_op, eval_log_p_hat,
       eval_likelihood) = create_long_chain_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_observations,
           FLAGS.steps_per_observation,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps)
    else:
      (global_step, train_op,
       eval_log_p_hat, eval_likelihood) = create_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_timesteps,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps,
           FLAGS.train_p)

    log_hooks = [create_log_hook(global_step, eval_log_p_hat, eval_likelihood)]
    # Only attach the infrequent summary hook when such summaries exist.
    if len(tf.get_collection("infrequent_summaries")) > 0:
      log_hooks.append(create_infrequent_summary_hook())

    # Log the contents of each variable collection for debugging.
    tf.logging.info("trainable variables:")
    tf.logging.info([v.name for v in tf.trainable_variables()])
    tf.logging.info("p vars:")
    tf.logging.info([v.name for v in tf.get_collection("P_VARS")])
    tf.logging.info("q vars:")
    tf.logging.info([v.name for v in tf.get_collection("Q_VARS")])
    tf.logging.info("r vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_VARS")])
    tf.logging.info("r tilde vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_TILDE_VARS")])

    # MonitoredTrainingSession handles checkpointing and summary writing.
    with tf.train.MonitoredTrainingSession(
        master="",
        is_chief=True,
        hooks=log_hooks,
        checkpoint_dir=FLAGS.logdir,
        save_checkpoint_secs=120,
        save_summaries_steps=FLAGS.summarize_every,
        log_step_count_steps=FLAGS.summarize_every) as sess:
      cur_step = -1
      while True:
        # Stop when the session requests it or max_steps is exceeded.
        if sess.should_stop() or cur_step > FLAGS.max_steps:
          break
        # run a step
        _, cur_step = sess.run([train_op, global_step])
if __name__ == "__main__":
  # tf.app.run parses flags before invoking main.
  tf.app.run(main)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -23,13 +23,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
import nested_utils as nested
from fivo import nested_utils as nested
from fivo import smc
# NOTE(review): this span was merge-conflict residue interleaving the old
# `cell`-based implementation with the new `model`-based one; the block below
# is the reconstructed post-merge version.
def iwae(model,
         observations,
         seq_lengths,
         num_samples=1,
         parallel_iterations=30,
         swap_memory=True):
  """Computes the IWAE lower bound on the log marginal probability.

  When num_samples = 1, this bound becomes the evidence lower bound (ELBO).
  IWAE is implemented as FIVO with a resampling criterion that never
  resamples.

  Args:
    model: A subclass of ELBOTrainableSequenceModel that implements one
      timestep of the model. See models/vrnn.py for an example.
    observations: The inputs to the model. A potentially nested list or tuple
      of Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors
      must have a rank at least two and have matching shapes in the first two
      dimensions, which represent time and the batch respectively. The model
      will be provided with the observations before computing the bound.
    seq_lengths: A [batch_size] Tensor of ints encoding the length of each
      sequence in the batch (sequences can be padded to a common length).
    num_samples: The number of samples to use.
    parallel_iterations: The number of parallel iterations to use for the
      internal while loop.
    swap_memory: Whether GPU-CPU memory swapping should be enabled for the
      internal while loop.

  Returns:
    log_p_hat: A Tensor of shape [batch_size] containing IWAE's estimate of
      the log marginal probability of the observations.
    log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
      containing the log weights at each timestep. Will not be valid for
      timesteps past the end of a sequence.
    final_state: The final model state after the last timestep.
  """
  # With never_resample_criterion, FIVO reduces exactly to IWAE; the
  # resampling indicator output is dropped since it is all zeros.
  log_p_hat, log_weights, _, final_state = fivo(
      model,
      observations,
      seq_lengths,
      num_samples=num_samples,
      resampling_criterion=smc.never_resample_criterion,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)
  return log_p_hat, log_weights, final_state
def ess_criterion(num_samples, log_ess, unused_t):
  """Signals a resample whenever log ESS falls to log(num_samples/2) or below."""
  ess_threshold = tf.log(num_samples / 2.0)
  return log_ess <= ess_threshold
def never_resample_criterion(unused_num_samples, log_ess, unused_t):
  """Disables resampling by returning False for every particle filter."""
  all_false = tf.zeros_like(log_ess)
  return tf.cast(all_false, tf.bool)
def always_resample_criterion(unused_num_samples, log_ess, unused_t):
  """Forces a resample at every timestep by returning True for every filter."""
  all_true = tf.ones_like(log_ess)
  return tf.cast(all_true, tf.bool)
# NOTE(review): this span was merge-conflict residue interleaving the old
# `cell`-based implementation with the new `model`-based one; the block below
# is the reconstructed post-merge version.
def fivo(model,
         observations,
         seq_lengths,
         num_samples=1,
         resampling_criterion=smc.ess_criterion,
         resampling_type='multinomial',
         relaxed_resampling_temperature=0.5,
         parallel_iterations=30,
         swap_memory=True,
         random_seed=None):
  """Computes the FIVO lower bound on the log marginal probability.

  FIVO runs a particle filter over the model; when the resampling criterion
  is "never resample", this bound becomes IWAE.

  Args:
    model: A subclass of ELBOTrainableSequenceModel that implements one
      timestep of the model. See models/vrnn.py for an example.
    observations: The inputs to the model. A potentially nested list or tuple
      of Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors
      must have a rank at least two and have matching shapes in the first two
      dimensions, which represent time and the batch respectively. The model
      will be provided with the observations before computing the bound.
    seq_lengths: A [batch_size] Tensor of ints encoding the length of each
      sequence in the batch (sequences can be padded to a common length).
    num_samples: The number of particles to use in each particle filter.
    resampling_criterion: The resampling criterion to use for this particle
      filter. Must accept the number of samples, the current log weights,
      and the current timestep and return a boolean Tensor of shape
      [batch_size] indicating whether each particle filter should resample.
      See ess_criterion and related functions for examples. When
      resampling_criterion is never_resample_criterion, resampling_fn is
      ignored and never called.
    resampling_type: The type of resampling, one of "multinomial" or
      "relaxed".
    relaxed_resampling_temperature: A positive temperature only used for
      relaxed resampling.
    parallel_iterations: The number of parallel iterations to use for the
      internal while loop. Note that values greater than 1 can introduce
      non-determinism even when random_seed is provided.
    swap_memory: Whether GPU-CPU memory swapping should be enabled for the
      internal while loop.
    random_seed: The random seed passed to the resampling operations.

  Returns:
    log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of
      the log marginal probability of the observations.
    log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
      containing the log weights at each timestep of the particle filter.
    resampled: A Tensor of shape [max_seq_len, batch_size] indicating when
      the particle filters resampled. Will be 1.0 on timesteps when
      resampling occurred and 0.0 on timesteps when it did not.
    final_state: The final model state after the last timestep.
  """
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]

  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)

  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    # smc.smc calls transition_fn with prev_state=None to get the initial
    # state of the particles.
    if prev_state is None:
      return model.zero_state(batch_size * num_samples, tf.float32)
    return model.propose_and_weight(prev_state, t)

  log_p_hat, log_weights, resampled, final_state, _ = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)

  return log_p_hat, log_weights, resampled, final_state
# NOTE(review): this span had the old fivo while-loop body interleaved into it
# by a bad merge; the block below is the reconstructed post-merge version.
def fivo_aux_td(
    model,
    observations,
    seq_lengths,
    num_samples=1,
    resampling_criterion=smc.ess_criterion,
    resampling_type='multinomial',
    relaxed_resampling_temperature=0.5,
    parallel_iterations=30,
    swap_memory=True,
    random_seed=None):
  """Experimental: FIVO with an auxiliary tilt trained by a TD-style loss."""
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]
  max_seq_len = tf.reduce_max(seq_lengths)

  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)

  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    """One particle-filter step; state is (log_r, log_r_tilde, model_state)."""
    if prev_state is None:
      model_init_state = model.zero_state(batch_size * num_samples, tf.float32)
      return (tf.zeros([num_samples*batch_size], dtype=tf.float32),
              (tf.zeros([num_samples*batch_size, model.latent_size],
                        dtype=tf.float32),
               tf.zeros([num_samples*batch_size, model.latent_size],
                        dtype=tf.float32)),
              model_init_state)

    prev_log_r, prev_log_r_tilde, prev_model_state = prev_state
    (new_model_state, zt, log_q_zt, log_p_zt,
     log_p_x_given_z, log_r_tilde, p_ztplus1) = model(prev_model_state, t)
    r_tilde_mu, r_tilde_sigma_sq = log_r_tilde
    # Compute the weight without r.
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    # Compute log_r and log_r_tilde.
    p_mu = tf.stop_gradient(p_ztplus1.mean())
    p_sigma_sq = tf.stop_gradient(p_ztplus1.variance())
    log_r = (tf.log(r_tilde_sigma_sq) -
             tf.log(r_tilde_sigma_sq + p_sigma_sq) -
             tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
    # log_r is [num_samples*batch_size, latent_size]. We sum it along the last
    # dimension to compute log r.
    log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
    # Compute prev log r tilde
    prev_r_tilde_mu, prev_r_tilde_sigma_sq = prev_log_r_tilde
    prev_log_r_tilde = -0.5*tf.reduce_sum(
        tf.square(tf.stop_gradient(zt) - prev_r_tilde_mu)/prev_r_tilde_sigma_sq,
        axis=-1)
    # If the sequence is on the last timestep, log_r and log_r_tilde are just
    # zeros.
    last_timestep = t >= (tiled_seq_lengths - 1)
    log_r = tf.where(last_timestep,
                     tf.zeros_like(log_r),
                     log_r)
    prev_log_r_tilde = tf.where(last_timestep,
                                tf.zeros_like(prev_log_r_tilde),
                                prev_log_r_tilde)
    log_weight += tf.stop_gradient(log_r - prev_log_r)
    new_state = (log_r, log_r_tilde, new_model_state)
    loop_fn_args = (log_r, prev_log_r_tilde, log_p_x_given_z,
                    log_r - prev_log_r)
    return log_weight, new_state, loop_fn_args

  def loop_fn(loop_state, loop_args, unused_model_state, log_weights,
              resampled, mask, t):
    """Accumulates log_p_hat, the Bellman loss, and the log r diff."""
    if loop_state is None:
      return (tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([num_samples, batch_size], dtype=tf.float32))
    log_p_hat_acc, bellman_loss_acc, log_r_diff_acc = loop_state
    log_r, prev_log_r_tilde, log_p_x_given_z, log_r_diff = loop_args
    # Compute the log_p_hat update
    log_p_hat_update = tf.reduce_logsumexp(
        log_weights, axis=0) - tf.log(tf.to_float(num_samples))
    # If it is the last timestep, we always add the update.
    log_p_hat_acc += tf.cond(t >= max_seq_len-1,
                             lambda: log_p_hat_update,
                             lambda: log_p_hat_update * resampled)
    # Compute the Bellman update.
    log_r = tf.reshape(log_r, [num_samples, batch_size])
    prev_log_r_tilde = tf.reshape(prev_log_r_tilde, [num_samples, batch_size])
    log_p_x_given_z = tf.reshape(log_p_x_given_z, [num_samples, batch_size])
    mask = tf.reshape(mask, [num_samples, batch_size])
    # On the first timestep there is no bellman error because there is no
    # prev_log_r_tilde.
    mask = tf.cond(tf.equal(t, 0),
                   lambda: tf.zeros_like(mask),
                   lambda: mask)
    # On the first timestep also fix up prev_log_r_tilde, which will be -inf.
    prev_log_r_tilde = tf.where(
        tf.is_inf(prev_log_r_tilde),
        tf.zeros_like(prev_log_r_tilde),
        prev_log_r_tilde)
    # log_lambda is [num_samples, batch_size]
    log_lambda = tf.reduce_mean(prev_log_r_tilde - log_p_x_given_z - log_r,
                                axis=0, keepdims=True)
    bellman_error = mask * tf.square(
        prev_log_r_tilde -
        tf.stop_gradient(log_lambda + log_p_x_given_z + log_r)
    )
    bellman_loss_acc += tf.reduce_mean(bellman_error, axis=0)
    # Compute the log_r_diff update
    log_r_diff_acc += mask * tf.reshape(log_r_diff, [num_samples, batch_size])
    return (log_p_hat_acc, bellman_loss_acc, log_r_diff_acc)

  # NOTE(review): verify this unpacking arity against smc.smc's return
  # signature -- these two lines are reproduced from the post-merge side of
  # the diff.
  log_weights, resampled, accs = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      loop_fn=loop_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)

  log_p_hat, bellman_loss, log_r_diff = accs
  loss_per_seq = [- log_p_hat, bellman_loss]
  tf.summary.scalar("bellman_loss",
                    tf.reduce_mean(bellman_loss / tf.to_float(seq_lengths)))
  tf.summary.scalar("log_r_diff",
                    tf.reduce_mean(
                        tf.reduce_mean(log_r_diff, axis=0) /
                        tf.to_float(seq_lengths)))
  return loss_per_seq, log_p_hat, log_weights, resampled
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.bounds"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from fivo.test_utils import create_vrnn
from fivo import bounds
class BoundsTest(tf.test.TestCase):
  """Golden-value tests for the IWAE and FIVO bounds in fivo.bounds."""

  def test_elbo(self):
    """A golden-value test for the ELBO (the IWAE bound with num_samples=1)."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      # Small VRNN built with a fixed seed so values are reproducible.
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=1,
                         parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, _, _ = sess.run(outs)
      # One golden value per batch element.
      self.assertAllClose([-21.615765, -13.614225], log_p_hat)

  def test_iwae(self):
    """A golden-value test for the IWAE bound."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=4,
                         parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, _ = sess.run(outs)
      self.assertAllClose([-23.301426, -13.64028], log_p_hat)
      # Golden log weights with shape [max_seq_len, batch_size, num_samples].
      weights_gt = np.array(
          [[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
            [-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
           [[-6.2539978, -4.37615728, -7.43738699, -7.85044909],
            [-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
           [[-9.19093227, -8.01637268, -11.64603615, -10.51128292],
            [-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
           [[-12.20609856, -10.47217369, -13.66270638, -13.46115875],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-16.14766312, -15.57472229, -17.47755432, -17.98189926],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-20.07182884, -18.43191147, -20.1606636, -21.45263863],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-24.10270691, -22.20865822, -24.14675522, -25.27248383],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]]])
      self.assertAllClose(weights_gt, weights)

  def test_fivo(self):
    """A golden-value test for the FIVO bound."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-22.98902512, -14.21689224], log_p_hat)
      # Golden log weights with shape [max_seq_len, batch_size, num_samples].
      weights_gt = np.array(
          [[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
            [-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
           [[-2.67100811, -2.30541706, -2.34178066, -2.81751347],
            [-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
           [[-5.65190411, -5.94563246, -6.55041981, -5.4783473],
            [-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
           [[-8.71947861, -8.40143299, -8.54593086, -8.42822266],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-12.7003831, -13.5039815, -12.3569726, -12.9489622],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-16.4520301, -16.3611698, -15.0314846, -16.4197006],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-20.7010765, -20.1379165, -19.0020351, -20.2395458],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]]])
      self.assertAllClose(weights_gt, weights)
      # 1.0 marks a timestep on which that particle filter resampled.
      resampled_gt = np.array(
          [[1., 0.],
           [0., 0.],
           [0., 1.],
           [0., 0.],
           [0., 0.],
           [0., 0.],
           [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)

  def test_fivo_relaxed(self):
    """A golden-value test for the FIVO bound with relaxed sampling."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1,
                         resampling_type="relaxed")
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-22.942394, -14.273882], log_p_hat)
      # Golden log weights with shape [max_seq_len, batch_size, num_samples].
      weights_gt = np.array(
          [[[-3.66708851, -2.07074118, -4.91751575, -5.03293514],
            [-2.99690628, -3.17782831, -4.50084877, -3.48536515]],
           [[-2.84939098, -2.30087185, -2.35649204, -2.48417377],
            [-8.27518654, -6.71545172, -8.96199131, -7.05567837]],
           [[-5.92327023, -5.9433074, -6.5826683, -5.04259014],
            [-12.34527206, -11.54284668, -11.86675072, -9.69417477]],
           [[-8.95323944, -8.40061855, -8.52760506, -7.99130583],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-12.87836456, -13.49628639, -12.31680107, -12.74228859],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-16.78347397, -16.35150909, -14.98797417, -16.35162735],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-20.81165886, -20.1307621, -18.92229652, -20.17458153],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]]])
      self.assertAllClose(weights_gt, weights)
      # 1.0 marks a timestep on which that particle filter resampled.
      resampled_gt = np.array(
          [[1., 0.],
           [0., 0.],
           [0., 1.],
           [0., 0.],
           [0., 0.],
           [0., 0.],
           [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)

  def test_fivo_aux_relaxed(self):
    """A golden-value test for the FIVO-AUX bound with relaxed sampling."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      # use_tilt=True enables the auxiliary tilt (FIVO-AUX).
      model, inputs, targets, lengths = create_vrnn(random_seed=1234,
                                                    use_tilt=True)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1,
                         resampling_type="relaxed")
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-23.1395, -14.271059], log_p_hat)
      # Golden log weights with shape [max_seq_len, batch_size, num_samples].
      weights_gt = np.array(
          [[[-5.19826221, -3.55476403, -5.98663855, -6.08058834],
            [-6.31685925, -5.70243931, -7.07638931, -6.18138981]],
           [[-3.97986865, -3.58831525, -3.85753584, -3.5010016],
            [-11.38203049, -8.66213989, -11.23646641, -10.02024746]],
           [[-6.62269831, -6.36680222, -6.78096485, -5.80072498],
            [-3.55419445, -8.11326408, -3.48766923, -3.08593249]],
           [[-10.56472301, -10.16084099, -9.96741676, -8.5270071],
            [-6.04880285, -7.80853653, -4.72652149, -3.49711013]],
           [[-13.36585426, -16.08720398, -13.33416367, -13.1017189],
            [-0., -0., -0., -0.]],
           [[-17.54233551, -17.35167503, -16.79163361, -16.51471138],
            [0., -0., -0., -0.]],
           [[-19.74024963, -18.69452858, -17.76246452, -18.76182365],
            [0., -0., -0., -0.]]])
      self.assertAllClose(weights_gt, weights)
      # 1.0 marks a timestep on which that particle filter resampled.
      resampled_gt = np.array([[1., 0.],
                               [0., 1.],
                               [0., 0.],
                               [0., 1.],
                               [0., 0.],
                               [0., 0.],
                               [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)
if __name__ == "__main__":
  # Used to easily see the gold values: print arrays untruncated.
  # np.nan is rejected as a threshold by modern NumPy (raises ValueError);
  # sys.maxsize is the documented way to get an untruncated repr.
  import sys
  np.set_printoptions(threshold=sys.maxsize)
  # Use print(repr(numpy_array)) to print the values.
  tf.test.main()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocesses TIMIT from raw wavfiles to create a set of TFRecords.
"""
......
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,6 +21,8 @@ from __future__ import division
from __future__ import print_function
import pickle
import numpy as np
from scipy.sparse import coo_matrix
import tensorflow as tf
......@@ -147,6 +149,95 @@ def create_pianoroll_dataset(path,
return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
def create_human_pose_dataset(
    path,
    split,
    batch_size,
    num_parallel_calls=DEFAULT_PARALLELISM,
    shuffle=False,
    repeat=False):
  """Creates a human pose dataset.

  Args:
    path: The path of a pickle file containing the dataset to load.
    split: The split to use, can be train, test, or valid.
    batch_size: The batch size. If repeat is False then it is not guaranteed
      that the true batch size will match for all batches since batch_size
      may not necessarily evenly divide the number of elements.
    num_parallel_calls: The number of threads to use for parallel processing of
      the data.
    shuffle: If true, shuffles the order of the dataset.
    repeat: If true, repeats the dataset endlessly.

  Returns:
    inputs: A batch of input sequences represented as a dense Tensor of shape
      [time, batch_size, data_dimension]. The sequences in inputs are the
      sequences in targets shifted one timestep into the future, padded with
      zeros. This tensor is mean-centered, with the mean taken from the pickle
      file key 'train_mean'.
    targets: A batch of target sequences represented as a dense Tensor of
      shape [time, batch_size, data_dimension].
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch.
    mean: A float Tensor of shape [data_dimension] containing the mean loaded
      from the pickle file.
  """
  # Load the data from disk. Pickle streams are binary, so open with "rb" --
  # text mode "r" only happened to work under python 2.
  with tf.gfile.Open(path, "rb") as f:
    raw_data = pickle.load(f)
  mean = raw_data["train_mean"]
  pose_sequences = raw_data[split]
  num_examples = len(pose_sequences)
  num_features = pose_sequences[0].shape[1]

  def pose_generator():
    """A generator that yields pose data sequences."""
    # Each timestep has 32 x values followed by 32 y values so is 64
    # dimensional.
    for pose_sequence in pose_sequences:
      yield pose_sequence, pose_sequence.shape[0]

  dataset = tf.data.Dataset.from_generator(
      pose_generator,
      output_types=(tf.float64, tf.int64),
      output_shapes=([None, num_features], []))

  if repeat:
    dataset = dataset.repeat()
  if shuffle:
    dataset = dataset.shuffle(num_examples)

  # Batch sequences together, padding them to a common length in time.
  dataset = dataset.padded_batch(
      batch_size, padded_shapes=([None, num_features], []))

  # Post-process each batch, ensuring that it is mean-centered and time-major.
  def process_pose_data(data, lengths):
    """Creates Tensors for next step prediction and mean-centers the input."""
    data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
    lengths = tf.to_int32(lengths)
    targets = data
    # Mean center the inputs.
    inputs = data - tf.constant(
        mean, dtype=tf.float32, shape=[1, 1, mean.shape[0]])
    # Shift the inputs one step forward in time. Also remove the last timestep
    # so that targets and inputs are the same length.
    inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
    # Mask out unused timesteps so padding does not leak mean-centered zeros.
    inputs *= tf.expand_dims(
        tf.transpose(tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
    return inputs, targets, lengths

  dataset = dataset.map(
      process_pose_data,
      num_parallel_calls=num_parallel_calls)
  dataset = dataset.prefetch(num_examples)
  itr = dataset.make_one_shot_iterator()
  inputs, targets, lengths = itr.get_next()
  return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
def create_speech_dataset(path,
batch_size,
samples_per_timestep=200,
......@@ -221,3 +312,142 @@ def create_speech_dataset(path,
itr = dataset.make_one_shot_iterator()
inputs, targets, lengths = itr.get_next()
return inputs, targets, lengths
# Valid observation functions for the chain graph dataset: the observation
# density may be centered at the squared latent state, at its absolute
# value, or at the latent state unchanged.
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]

# Valid transition functions: the transition density may be centered at the
# rounded previous latent state or at the previous state unchanged.
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
def create_chain_graph_dataset(
    batch_size,
    num_timesteps,
    steps_per_observation=None,
    state_size=1,
    transition_variance=1.,
    observation_variance=1.,
    transition_type=STANDARD_TRANSITION,
    observation_type=STANDARD_OBSERVATION,
    fixed_observation=None,
    prefetch_buffer_size=2048,
    dtype="float32"):
  """Creates a toy chain graph dataset.

  Creates a dataset where the data are sampled from a diffusion process. The
  'latent' states of the process are sampled as a chain of Normals:

    z0 ~ N(0, transition_variance)
    z1 ~ N(transition_fn(z0), transition_variance)
    ...

  where transition_fn could round z0 or pass it through unchanged.

  The observations are produced every steps_per_observation timesteps as a
  function of the latent zs. For example if steps_per_observation is 3 then the
  first observation will be produced as a function of z3:

    x1 ~ N(observation_fn(z3), observation_variance)

  where observation_fn could square z3, take the absolute value, or pass
  it through unchanged. Only the observations are returned.

  Args:
    batch_size: The batch size. The number of trajectories to run in parallel.
    num_timesteps: The length of the chain of latent states (i.e. the
      number of z's excluding z0.
    steps_per_observation: The number of latent states between each observation,
      must evenly divide num_timesteps.
    state_size: The size of the latent state and observation, must be a
      python int.
    transition_variance: The variance of the transition density.
    observation_variance: The variance of the observation density.
    transition_type: Must be one of "round" or "standard". "round" means that
      the transition density is centered at the rounded previous latent state.
      "standard" centers the transition density at the previous latent state,
      unchanged.
    observation_type: Must be one of "squared", "abs" or "standard". "squared"
      centers the observation density at the squared latent state. "abs"
      centers the observation density at the absolute value of the current
      latent state. "standard" centers the observation density at the current
      latent state.
    fixed_observation: If not None, fixes all observations to be a constant.
      Must be a scalar.
    prefetch_buffer_size: The size of the prefetch queues to use after reading
      and processing the raw data.
    dtype: A string convertible to a tensorflow datatype. The datatype used
      to represent the states and observations.

  Returns:
    observations: A batch of observations represented as a dense Tensor of
      shape [num_observations, batch_size, state_size]. num_observations is
      num_timesteps/steps_per_observation.
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch. Will contain num_observations as each entry.

  Raises:
    ValueError: Raised if steps_per_observation does not evenly divide
      num_timesteps, or if transition_type or observation_type is not one of
      the allowed values.
  """
  if steps_per_observation is None:
    steps_per_observation = num_timesteps
  if num_timesteps % steps_per_observation != 0:
    raise ValueError("steps_per_observation must evenly divide num_timesteps.")
  # Validate the type strings up front. Previously an invalid string produced
  # a confusing UnboundLocalError from inside the generator at iteration time.
  if transition_type not in TRANSITION_TYPES:
    raise ValueError("Unknown transition_type: %s" % transition_type)
  if observation_type not in OBSERVATION_TYPES:
    raise ValueError("Unknown observation_type: %s" % observation_type)
  num_observations = int(num_timesteps / steps_per_observation)

  def data_generator():
    """An infinite generator of latents and observations from the model."""
    transition_std = np.sqrt(transition_variance)
    observation_std = np.sqrt(observation_variance)
    while True:
      states = []
      observations = []
      # Sample z0 ~ N(0, transition_variance), as documented above. The
      # previous code used observation_std here, which only agrees when the
      # two variances are equal (as they are in all current callers/tests).
      states.append(
          np.random.normal(size=[state_size],
                           scale=transition_std).astype(dtype))
      # Start the range at 1 because we've already generated z0.
      # The range ends at num_timesteps+1 because we want to include the
      # num_timesteps-th step.
      for t in range(1, num_timesteps+1):
        if transition_type == ROUND_TRANSITION:
          loc = np.round(states[-1])
        else:  # STANDARD_TRANSITION, validated above.
          loc = states[-1]
        z_t = np.random.normal(size=[state_size], loc=loc, scale=transition_std)
        states.append(z_t.astype(dtype))
        if t % steps_per_observation == 0:
          if fixed_observation is None:
            if observation_type == SQUARED_OBSERVATION:
              loc = np.square(states[-1])
            elif observation_type == ABS_OBSERVATION:
              loc = np.abs(states[-1])
            else:  # STANDARD_OBSERVATION, validated above.
              loc = states[-1]
            x_t = np.random.normal(size=[state_size],
                                   loc=loc,
                                   scale=observation_std)
          else:
            x_t = np.ones([state_size]) * fixed_observation
          # Cast in one place so both branches yield the requested dtype
          # (the fixed_observation branch previously yielded float64).
          observations.append(x_t.astype(dtype))
      yield states, observations

  dataset = tf.data.Dataset.from_generator(
      data_generator,
      output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
      output_shapes=([num_timesteps+1, state_size],
                     [num_observations, state_size])
  )
  dataset = dataset.repeat().batch(batch_size)
  dataset = dataset.prefetch(prefetch_buffer_size)
  itr = dataset.make_one_shot_iterator()
  # The latent states are discarded; only the observations are returned.
  _, observations = itr.get_next()
  # Transpose observations from [batch, time, state_size] to
  # [time, batch, state_size].
  observations = tf.transpose(observations, perm=[1, 0, 2])
  # Every sequence has exactly num_observations observations.
  lengths = tf.ones([batch_size], dtype=tf.int32) * num_observations
  return observations, lengths
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.data.datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import os
import numpy as np
import tensorflow as tf
from fivo.data import datasets
FLAGS = tf.app.flags.FLAGS
class DatasetsTest(tf.test.TestCase):
  """Tests for the dataset creation helpers in fivo.data.datasets.

  The chain-graph tests below pin gold values produced with a fixed numpy
  seed; any change to the sampling order in create_chain_graph_dataset will
  invalidate them.
  """

  def test_sparse_pianoroll_to_dense_empty_at_end(self):
    # Trailing empty timesteps must be preserved in the dense roll.
    sparse_pianoroll = [(0, 1), (1, 0), (), (1,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1],
                         [0, 0],
                         [0, 0]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_with_chord(self):
    sparse_pianoroll = [(0, 1), (1, 0), (), (1,)]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 4)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_simple(self):
    sparse_pianoroll = [(0,), (), (1,)]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 3)
    self.assertAllEqual([[1, 0],
                         [0, 0],
                         [0, 1]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_subtracts_min_note(self):
    # Notes are stored relative to min_note in the dense representation.
    sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=4, num_notes=2)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1],
                         [0, 0],
                         [0, 0]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_uses_num_notes(self):
    # num_notes sets the width of the dense roll even if unused.
    sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=4, num_notes=3)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1, 0],
                         [1, 1, 0],
                         [0, 0, 0],
                         [0, 1, 0],
                         [0, 0, 0],
                         [0, 0, 0]], dense_pianoroll)

  def test_pianoroll_dataset(self):
    pianoroll_data = [[(0,), (), (1,)],
                      [(0, 1), (1,)],
                      [(1,), (0,), (), (0, 1), (), ()]]
    pianoroll_mean = np.zeros([3])
    pianoroll_mean[-1] = 1
    data = {"train": pianoroll_data, "train_mean": pianoroll_mean}
    path = os.path.join(tf.test.get_temp_dir(), "test.pkl")
    pickle.dump(data, open(path, "wb"))
    with self.test_session() as sess:
      inputs, targets, lens, mean = datasets.create_pianoroll_dataset(
          path, "train", 2, num_parallel_calls=1,
          shuffle=False, repeat=False,
          min_note=0, max_note=2)
      # First batch holds the first two sequences, second batch the third.
      i1, t1, l1 = sess.run([inputs, targets, lens])
      i2, t2, l2 = sess.run([inputs, targets, lens])
      m = sess.run(mean)
      # Check the lengths.
      self.assertAllEqual([3, 2], l1)
      self.assertAllEqual([6], l2)
      # Check the mean.
      self.assertAllEqual(pianoroll_mean, m)
      # Check the targets. The targets should not be mean-centered and should
      # be padded with zeros to a common length within a batch.
      self.assertAllEqual([[1, 0, 0],
                           [0, 0, 0],
                           [0, 1, 0]], t1[:, 0, :])
      self.assertAllEqual([[1, 1, 0],
                           [0, 1, 0],
                           [0, 0, 0]], t1[:, 1, :])
      self.assertAllEqual([[0, 1, 0],
                           [1, 0, 0],
                           [0, 0, 0],
                           [1, 1, 0],
                           [0, 0, 0],
                           [0, 0, 0]], t2[:, 0, :])
      # Check the inputs. Each sequence should start with zeros on the first
      # timestep. Each sequence should be padded with zeros to a common length
      # within a batch. The mean should be subtracted from all timesteps except
      # the first and the padding.
      self.assertAllEqual([[0, 0, 0],
                           [1, 0, -1],
                           [0, 0, -1]], i1[:, 0, :])
      self.assertAllEqual([[0, 0, 0],
                           [1, 1, -1],
                           [0, 0, 0]], i1[:, 1, :])
      self.assertAllEqual([[0, 0, 0],
                           [0, 1, -1],
                           [1, 0, -1],
                           [0, 0, -1],
                           [1, 1, -1],
                           [0, 0, -1]], i2[:, 0, :])

  def test_human_pose_dataset(self):
    pose_data = [
        [[0, 0], [2, 2]],
        [[2, 2]],
        [[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]],
    ]
    pose_data = [np.array(x, dtype=np.float64) for x in pose_data]
    pose_data_mean = np.array([1, 1], dtype=np.float64)
    data = {
        "train": pose_data,
        "train_mean": pose_data_mean,
    }
    path = os.path.join(tf.test.get_temp_dir(), "test_human_pose_dataset.pkl")
    with open(path, "wb") as out:
      pickle.dump(data, out)
    with self.test_session() as sess:
      inputs, targets, lens, mean = datasets.create_human_pose_dataset(
          path, "train", 2, num_parallel_calls=1, shuffle=False, repeat=False)
      i1, t1, l1 = sess.run([inputs, targets, lens])
      i2, t2, l2 = sess.run([inputs, targets, lens])
      m = sess.run(mean)
      # Check the lengths.
      self.assertAllEqual([2, 1], l1)
      self.assertAllEqual([5], l2)
      # Check the mean.
      self.assertAllEqual(pose_data_mean, m)
      # Check the targets. The targets should not be mean-centered and should
      # be padded with zeros to a common length within a batch.
      self.assertAllEqual([[0, 0], [2, 2]], t1[:, 0, :])
      self.assertAllEqual([[2, 2], [0, 0]], t1[:, 1, :])
      self.assertAllEqual([[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]], t2[:, 0, :])
      # Check the inputs. Each sequence should start with zeros on the first
      # timestep. Each sequence should be padded with zeros to a common length
      # within a batch. The mean should be subtracted from all timesteps except
      # the first and the padding.
      self.assertAllEqual([[0, 0], [-1, -1]], i1[:, 0, :])
      self.assertAllEqual([[0, 0], [0, 0]], i1[:, 1, :])
      self.assertAllEqual([[0, 0], [-1, -1], [-1, -1], [1, 1], [1, 1]],
                          i2[:, 0, :])

  def test_speech_dataset(self):
    with self.test_session() as sess:
      path = os.path.join(
          os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
          "test_data",
          "tiny_speech_dataset.tfrecord")
      inputs, targets, lens = datasets.create_speech_dataset(
          path, 3, samples_per_timestep=2, num_parallel_calls=1,
          prefetch_buffer_size=3, shuffle=False, repeat=False)
      inputs1, targets1, lengths1 = sess.run([inputs, targets, lens])
      inputs2, targets2, lengths2 = sess.run([inputs, targets, lens])
      # Check the lengths.
      self.assertAllEqual([1, 2, 3], lengths1)
      self.assertAllEqual([4], lengths2)
      # Check the targets. The targets should be padded with zeros to a common
      # length within a batch.
      self.assertAllEqual([[[0., 1.], [0., 1.], [0., 1.]],
                           [[0., 0.], [2., 3.], [2., 3.]],
                           [[0., 0.], [0., 0.], [4., 5.]]],
                          targets1)
      self.assertAllEqual([[[0., 1.]],
                           [[2., 3.]],
                           [[4., 5.]],
                           [[6., 7.]]],
                          targets2)
      # Check the inputs. Each sequence should start with zeros on the first
      # timestep. Each sequence should be padded with zeros to a common length
      # within a batch.
      self.assertAllEqual([[[0., 0.], [0., 0.], [0., 0.]],
                           [[0., 0.], [0., 1.], [0., 1.]],
                           [[0., 0.], [0., 0.], [2., 3.]]],
                          inputs1)
      self.assertAllEqual([[[0., 0.]],
                           [[0., 1.]],
                           [[2., 3.]],
                           [[4., 5.]]],
                          inputs2)

  def test_chain_graph_raises_error_on_wrong_steps_per_observation(self):
    # 9 does not evenly divide 10, so construction must fail.
    with self.assertRaises(ValueError):
      datasets.create_chain_graph_dataset(
          batch_size=4,
          num_timesteps=10,
          steps_per_observation=9)

  def test_chain_graph_single_obs(self):
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 1
      num_timesteps = 5
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          [[[1.426677], [-1.789461]]],
          out_observations)

  def test_chain_graph_multiple_obs(self):
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 3
      num_timesteps = 6
      batch_size = 2
      state_size = 1
      # NOTE(review): num_timesteps/num_observations is a float under
      # python 3 true division; it still works arithmetically but an integer
      # division may have been intended -- confirm.
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          steps_per_observation=num_timesteps/num_observations,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          [[[0.40051451], [1.07405114]],
           [[1.73932898], [3.16880035]],
           [[-1.98377144], [2.82669163]]],
          out_observations)

  def test_chain_graph_state_dims(self):
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 1
      num_timesteps = 5
      batch_size = 2
      state_size = 3
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          [[[1.052287, -4.560759, 3.07988],
            [2.008926, 0.495567, 3.488678]]],
          out_observations)

  def test_chain_graph_fixed_obs(self):
    # With fixed_observation set, every observation is the given constant.
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 3
      num_timesteps = 6
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          steps_per_observation=num_timesteps/num_observations,
          state_size=state_size,
          fixed_observation=4.)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          np.ones([num_observations, batch_size, state_size]) * 4.,
          out_observations)
if __name__ == "__main__":
  # Run all test cases in this module.
  tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates and runs Gaussian HMM-related graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from fivo import smc
from fivo import bounds
from fivo.data import datasets
from fivo.models import ghmm
def run_train(config):
  """Runs training for a Gaussian HMM setup.

  Builds a chain-graph dataset, a trainable Gaussian HMM, and the requested
  lower bound (elbo/iwae/fivo), then optimizes the negative bound with Adam
  inside a MonitoredTrainingSession until config.max_steps is reached.

  Args:
    config: An object with the experiment configuration attributes used
      below (bound, num_samples, resampling_type, batch_size, num_timesteps,
      variance, proposal_type, learning_rate, random_seed, logdir,
      summarize_every, max_steps, parallel_iterations, ...).
  """

  def create_logging_hook(step, bound_value, likelihood, bound_gap):
    """Creates a logging hook that prints the bound value periodically."""
    bound_label = config.bound + "/t"
    def summary_formatter(log_dict):
      string = ("Step {step}, %s: {value:.3f}, "
                "likelihood: {ll:.3f}, gap: {gap:.3e}") % bound_label
      return string.format(**log_dict)
    logging_hook = tf.train.LoggingTensorHook(
        {"step": step, "value": bound_value,
         "ll": likelihood, "gap": bound_gap},
        every_n_iter=config.summarize_every,
        formatter=summary_formatter)
    return logging_hook

  def create_losses(model, observations, lengths):
    """Creates the loss to be optimized.

    Args:
      model: A Trainable GHMM model.
      observations: A set of observations.
      lengths: The lengths of each sequence in the observations.
    Returns:
      loss: A float Tensor that when differentiated yields the gradients
        to apply to the model. Should be optimized via gradient descent.
      bound: A float Tensor containing the value of the bound that is
        being optimized.
      true_ll: The true log-likelihood of the data under the model.
      bound_gap: The gap between the bound and the true log-likelihood.
    """
    # Compute lower bounds on the log likelihood. config.bound is assumed to
    # be one of "elbo", "iwae", or "fivo"; any other value would leave
    # ll_per_seq unbound.
    if config.bound == "elbo":
      # The ELBO is the IWAE bound with a single sample.
      ll_per_seq, _, _ = bounds.iwae(
          model, observations, lengths, num_samples=1,
          parallel_iterations=config.parallel_iterations
      )
    elif config.bound == "iwae":
      ll_per_seq, _, _ = bounds.iwae(
          model, observations, lengths, num_samples=config.num_samples,
          parallel_iterations=config.parallel_iterations
      )
    elif config.bound == "fivo":
      if config.resampling_type == "relaxed":
        ll_per_seq, _, _, _ = bounds.fivo(
            model,
            observations,
            lengths,
            num_samples=config.num_samples,
            resampling_criterion=smc.ess_criterion,
            resampling_type=config.resampling_type,
            relaxed_resampling_temperature=(
                config.relaxed_resampling_temperature),
            random_seed=config.random_seed,
            parallel_iterations=config.parallel_iterations)
      else:
        ll_per_seq, _, _, _ = bounds.fivo(
            model, observations, lengths,
            num_samples=config.num_samples,
            resampling_criterion=smc.ess_criterion,
            resampling_type=config.resampling_type,
            random_seed=config.random_seed,
            parallel_iterations=config.parallel_iterations
        )
    # Normalize by sequence length so bounds are comparable across lengths.
    ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
    # Compute the data's true likelihood under the model and the bound gap.
    true_ll_per_seq = model.likelihood(tf.squeeze(observations))
    true_ll_per_t = tf.reduce_mean(true_ll_per_seq / tf.to_float(lengths))
    bound_gap = true_ll_per_seq - ll_per_seq
    bound_gap = tf.reduce_mean(bound_gap / tf.to_float(lengths))
    tf.summary.scalar("train_ll_bound", ll_per_t)
    tf.summary.scalar("train_true_ll", true_ll_per_t)
    tf.summary.scalar("bound_gap", bound_gap)
    # The loss is the negative bound (gradient descent maximizes the bound).
    return -ll_per_t, ll_per_t, true_ll_per_t, bound_gap

  def create_graph():
    """Creates the training graph."""
    global_step = tf.train.get_or_create_global_step()
    xs, lengths = datasets.create_chain_graph_dataset(
        config.batch_size,
        config.num_timesteps,
        steps_per_observation=1,
        state_size=1,
        transition_variance=config.variance,
        observation_variance=config.variance)
    model = ghmm.TrainableGaussianHMM(
        config.num_timesteps,
        config.proposal_type,
        transition_variances=config.variance,
        emission_variances=config.variance,
        random_seed=config.random_seed)
    loss, bound, true_ll, gap = create_losses(model, xs, lengths)
    opt = tf.train.AdamOptimizer(config.learning_rate)
    grads = opt.compute_gradients(loss, var_list=tf.trainable_variables())
    train_op = opt.apply_gradients(grads, global_step=global_step)
    return bound, true_ll, gap, train_op, global_step

  with tf.Graph().as_default():
    if config.random_seed:
      tf.set_random_seed(config.random_seed)
      np.random.seed(config.random_seed)
    bound, true_ll, gap, train_op, global_step = create_graph()
    log_hook = create_logging_hook(global_step, bound, true_ll, gap)
    with tf.train.MonitoredTrainingSession(
        master="",
        hooks=[log_hook],
        checkpoint_dir=config.logdir,
        save_checkpoint_secs=120,
        save_summaries_steps=config.summarize_every,
        log_step_count_steps=config.summarize_every*20) as sess:
      cur_step = -1
      while cur_step <= config.max_steps and not sess.should_stop():
        # Fetch the updated global step in the same run as the train op.
        # (A previous version also ran sess.run(global_step) first, whose
        # result was immediately overwritten -- a wasted round-trip.)
        _, cur_step = sess.run([train_op, global_step])
def run_eval(config):
  """Evaluates a Gaussian HMM using the given config.

  Builds the dataset, model, and requested bound, evaluates them in a single
  session run restored from config.logdir, logs summary statistics, and
  saves the results dict to <logdir>/out.npz via np.save.
  """

  def create_bound(model, xs, lengths):
    """Creates the bound to be evaluated.

    Returns a 3-tuple (bound_per_t, log_weights, resampled) for the "fivo"
    bound and a 2-tuple (bound_per_t, log_weights) otherwise; config.bound
    is assumed to be one of "elbo", "iwae", or "fivo".
    """
    if config.bound == "elbo":
      # The ELBO is the IWAE bound with a single sample.
      ll_per_seq, log_weights, _ = bounds.iwae(
          model, xs, lengths, num_samples=1,
          parallel_iterations=config.parallel_iterations
      )
    elif config.bound == "iwae":
      ll_per_seq, log_weights, _ = bounds.iwae(
          model, xs, lengths, num_samples=config.num_samples,
          parallel_iterations=config.parallel_iterations
      )
    elif config.bound == "fivo":
      ll_per_seq, log_weights, resampled, _ = bounds.fivo(
          model, xs, lengths,
          num_samples=config.num_samples,
          resampling_criterion=smc.ess_criterion,
          resampling_type=config.resampling_type,
          random_seed=config.random_seed,
          parallel_iterations=config.parallel_iterations
      )
    # Compute bound scaled by number of timesteps.
    bound_per_t = ll_per_seq / tf.to_float(lengths)
    if config.bound == "fivo":
      return bound_per_t, log_weights, resampled
    else:
      return bound_per_t, log_weights

  def create_graph():
    """Creates the dataset, model, and bound.

    Returns a list [true_likelihood, bound_per_t, log_weights(, resampled)].
    """
    xs, lengths = datasets.create_chain_graph_dataset(
        config.batch_size,
        config.num_timesteps,
        steps_per_observation=1,
        state_size=1,
        transition_variance=config.variance,
        observation_variance=config.variance)
    model = ghmm.TrainableGaussianHMM(
        config.num_timesteps,
        config.proposal_type,
        transition_variances=config.variance,
        emission_variances=config.variance,
        random_seed=config.random_seed)
    # Per-timestep true log-likelihood, averaged over the batch.
    true_likelihood = tf.reduce_mean(
        model.likelihood(tf.squeeze(xs)) / tf.to_float(lengths))
    outs = [true_likelihood]
    outs.extend(list(create_bound(model, xs, lengths)))
    return outs

  with tf.Graph().as_default():
    if config.random_seed:
      tf.set_random_seed(config.random_seed)
      np.random.seed(config.random_seed)
    graph_outs = create_graph()
    # SingularMonitoredSession restores the latest checkpoint from logdir.
    with tf.train.SingularMonitoredSession(
        checkpoint_dir=config.logdir) as sess:
      outs = sess.run(graph_outs)
      likelihood = outs[0]
      avg_bound = np.mean(outs[1])
      std = np.std(outs[1])
      log_weights = outs[2]
      log_weight_variances = np.var(log_weights, axis=2)
      # NOTE(review): this is the variance of the per-sample variances
      # across the batch; the variable name suggests a mean (np.mean) may
      # have been intended -- confirm before relying on this statistic.
      avg_log_weight_variance = np.var(log_weight_variances, axis=1)
      avg_log_weight = np.mean(log_weights, axis=(1, 2))
      data = {"mean": avg_bound, "std": std, "log_weights": log_weights,
              "log_weight_means": avg_log_weight,
              "log_weight_variances": avg_log_weight_variance}
      # The fivo bound additionally returns the resampling indicators.
      if len(outs) == 4:
        data["resampled"] = outs[3]
        data["avg_resampled"] = np.mean(outs[3], axis=1)
      # Log some useful statistics.
      tf.logging.info("Evaled bound %s with batch_size: %d, num_samples: %d."
                      % (config.bound, config.batch_size, config.num_samples))
      tf.logging.info("mean: %f, std: %f" % (avg_bound, std))
      tf.logging.info("true likelihood: %s" % likelihood)
      tf.logging.info("avg log weight: %s" % avg_log_weight)
      tf.logging.info("log weight variance: %s" % avg_log_weight_variance)
      if len(outs) == 4:
        tf.logging.info("avg resamples per t: %s" % data["avg_resampled"])
      if not tf.gfile.Exists(config.logdir):
        tf.gfile.MakeDirs(config.logdir)
      # NOTE(review): despite the .npz name this writes np.save (.npy)
      # format, and the file is opened in text mode "w" for binary data --
      # works on python 2 / unix, but "wb" looks intended. Confirm against
      # the np.load(...).item() readers before changing.
      with tf.gfile.Open(os.path.join(config.logdir, "out.npz"), "w") as fout:
        np.save(fout, data)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.ghmm_runners."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from fivo import ghmm_runners
class GHMMRunnersTest(tf.test.TestCase):
  """End-to-end tests for ghmm_runners.run_train / run_eval.

  The expected bound averages below are gold values produced with the fixed
  random_seed in default_config; any change to graph construction order will
  invalidate them.
  """

  def default_config(self):
    # A minimal attribute-bag standing in for the flags/config object the
    # runners expect.
    class Config(object):
      pass
    config = Config()
    config.model = "ghmm"
    config.bound = "fivo"
    config.proposal_type = "prior"
    config.batch_size = 4
    config.num_samples = 4
    config.num_timesteps = 10
    config.variance = 0.1
    config.resampling_type = "multinomial"
    config.random_seed = 1234
    config.parallel_iterations = 1
    config.learning_rate = 1e-4
    config.summarize_every = 1
    config.max_steps = 1
    return config

  def test_eval_ghmm_notraining_fivo_prior(self):
    self.eval_ghmm_notraining("fivo", "prior", -3.063864)

  def test_eval_ghmm_notraining_fivo_true_filtering(self):
    self.eval_ghmm_notraining("fivo", "true-filtering", -1.1409812)

  def test_eval_ghmm_notraining_fivo_true_smoothing(self):
    self.eval_ghmm_notraining("fivo", "true-smoothing", -0.85592091)

  def test_eval_ghmm_notraining_iwae_prior(self):
    self.eval_ghmm_notraining("iwae", "prior", -5.9730167)

  def test_eval_ghmm_notraining_iwae_true_filtering(self):
    self.eval_ghmm_notraining("iwae", "true-filtering", -1.1485999)

  def test_eval_ghmm_notraining_iwae_true_smoothing(self):
    self.eval_ghmm_notraining("iwae", "true-smoothing", -0.85592091)

  def eval_ghmm_notraining(self, bound, proposal_type, expected_bound_avg):
    # Shared driver: evaluate an untrained model and check the saved mean.
    config = self.default_config()
    config.proposal_type = proposal_type
    config.bound = bound
    config.logdir = os.path.join(
        tf.test.get_temp_dir(), "test-ghmm-%s-%s" % (proposal_type, bound))
    ghmm_runners.run_eval(config)
    # run_eval saves its results dict via np.save; .item() unwraps it.
    data = np.load(os.path.join(config.logdir, "out.npz")).item()
    self.assertAlmostEqual(expected_bound_avg, data["mean"], places=3)

  def test_train_ghmm_for_one_step_and_eval_fivo_filtering(self):
    self.train_ghmm_for_one_step_and_eval("fivo", "filtering", -16.727108)

  def test_train_ghmm_for_one_step_and_eval_fivo_smoothing(self):
    self.train_ghmm_for_one_step_and_eval("fivo", "smoothing", -19.381277)

  def test_train_ghmm_for_one_step_and_eval_iwae_filtering(self):
    self.train_ghmm_for_one_step_and_eval("iwae", "filtering", -33.31966)

  def test_train_ghmm_for_one_step_and_eval_iwae_smoothing(self):
    self.train_ghmm_for_one_step_and_eval("iwae", "smoothing", -46.388447)

  def train_ghmm_for_one_step_and_eval(self, bound, proposal_type, expected_bound_avg):
    # Shared driver: train for one step, then evaluate from the checkpoint.
    config = self.default_config()
    config.proposal_type = proposal_type
    config.bound = bound
    config.max_steps = 1
    config.logdir = os.path.join(
        tf.test.get_temp_dir(), "test-ghmm-training-%s-%s" % (proposal_type, bound))
    ghmm_runners.run_train(config)
    ghmm_runners.run_eval(config)
    data = np.load(os.path.join(config.logdir, "out.npz")).item()
    self.assertAlmostEqual(expected_bound_avg, data["mean"], places=2)
if __name__ == "__main__":
  # Run all test cases in this module.
  tf.test.main()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""VRNN classes."""
"""Reusable model classes for FIVO."""
from __future__ import absolute_import
from __future__ import division
......@@ -22,282 +22,67 @@ from __future__ import print_function
import sonnet as snt
import tensorflow as tf
from fivo import nested_utils as nested
class VRNNCell(snt.AbstractModule):
"""Implementation of a Variational Recurrent Neural Network (VRNN).
tfd = tf.contrib.distributions
Introduced in "A Recurrent Latent Variable Model for Sequential data"
by Chung et al. https://arxiv.org/pdf/1506.02216.pdf.
The VRNN is a sequence model similar to an RNN that uses stochastic latent
variables to improve its representational power. It can be thought of as a
sequential analogue to the variational auto-encoder (VAE).
class ELBOTrainableSequenceModel(object):
"""An abstract class for ELBO-trainable sequence models to extend.
The VRNN has a deterministic RNN as its backbone, represented by the
sequence of RNN hidden states h_t. At each timestep, the RNN hidden state h_t
is conditioned on the previous sequence element, x_{t-1}, as well as the
latent state from the previous timestep, z_{t-1}.
In this implementation of the VRNN the latent state z_t is Gaussian. The
model's prior over z_t is distributed as Normal(mu_t, diag(sigma_t^2)) where
mu_t and sigma_t are the mean and standard deviation output from a fully
connected network that accepts the rnn hidden state h_t as input.
The approximate posterior (also known as q or the encoder in the VAE
framework) is similar to the prior except that it is conditioned on the
current target, x_t, as well as h_t via a fully connected network.
This implementation uses the 'res_q' parameterization of the approximate
posterior, meaning that instead of directly predicting the mean of z_t, the
approximate posterior predicts the 'residual' from the prior's mean. This is
explored more in section 3.3 of https://arxiv.org/pdf/1605.07571.pdf.
During training, the latent state z_t is sampled from the approximate
posterior and the reparameterization trick is used to provide low-variance
gradients.
The generative distribution p(x_t|z_t, h_t) is conditioned on the latent state
z_t as well as the current RNN hidden state h_t via a fully connected network.
To increase the modeling power of the VRNN, two additional networks are
used to extract features from the data and the latent state. Those networks
are called data_feat_extractor and latent_feat_extractor respectively.
There are a few differences between this exposition and the paper.
First, the indexing scheme for h_t is different than the paper's -- what the
paper calls h_t we call h_{t+1}. This is the same notation used by Fraccaro
et al. to describe the VRNN in the paper linked above. Also, the VRNN paper
uses VAE terminology to refer to the different internal networks, so it
refers to the approximate posterior as the encoder and the generative
distribution as the decoder. This implementation also renamed the functions
phi_x and phi_z in the paper to data_feat_extractor and latent_feat_extractor.
Because the ELBO, IWAE, and FIVO bounds all accept the same arguments,
any model that is ELBO-trainable is also IWAE- and FIVO-trainable.
"""
def __init__(self,
rnn_cell,
data_feat_extractor,
latent_feat_extractor,
prior,
approx_posterior,
generative,
random_seed=None,
name="vrnn"):
"""Creates a VRNN cell.
def zero_state(self, batch_size, dtype):
"""Returns the initial state of the model as a Tensor or tuple of Tensors.
Args:
rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
deterministic backbone of the VRNN. The inputs to the RNN will be the
encoded latent state of the previous timestep with shape
[batch_size, encoded_latent_size] as well as the encoded input of the
current timestep, a Tensor of shape [batch_size, encoded_data_size].
data_feat_extractor: A callable that accepts a batch of data x_t and
'encodes' it, e.g. runs it through a fully connected network. Must
accept as argument the inputs x_t, a Tensor of the shape
[batch_size, data_size] and return a Tensor of shape
[batch_size, encoded_data_size]. This callable will be called multiple
times in the VRNN cell so if scoping is not handled correctly then
multiple copies of the variables in this network could be made. It is
recommended to use a snt.nets.MLP module, which takes care of this for
you.
latent_feat_extractor: A callable that accepts a latent state z_t and
'encodes' it, e.g. runs it through a fully connected network. Must
accept as argument a Tensor of shape [batch_size, latent_size] and
return a Tensor of shape [batch_size, encoded_latent_size].
This callable must also have the property 'output_size' defined,
returning encoded_latent_size.
prior: A callable that implements the prior p(z_t|h_t). Must accept as
argument the previous RNN hidden state and return a
tf.contrib.distributions.Normal distribution conditioned on the input.
approx_posterior: A callable that implements the approximate posterior
q(z_t|h_t,x_t). Must accept as arguments the encoded target of the
current timestep and the previous RNN hidden state. Must return
a tf.contrib.distributions.Normal distribution conditioned on the
inputs.
generative: A callable that implements the generative distribution
p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
and the RNN hidden state and return a subclass of
tf.contrib.distributions.Distribution that can be used to evaluate
the logprob of the targets.
random_seed: The seed for the random ops. Used mainly for testing.
name: The name of this VRNN.
batch_size: The batch size.
dtype: The datatype to use for the state.
"""
super(VRNNCell, self).__init__(name=name)
self.rnn_cell = rnn_cell
self.data_feat_extractor = data_feat_extractor
self.latent_feat_extractor = latent_feat_extractor
self.prior = prior
self.approx_posterior = approx_posterior
self.generative = generative
self.random_seed = random_seed
self.encoded_z_size = latent_feat_extractor.output_size
self.state_size = (self.rnn_cell.state_size, self.encoded_z_size)
raise NotImplementedError("zero_state not yet implemented.")
def zero_state(self, batch_size, dtype):
"""The initial state of the VRNN.
def set_observations(self, observations, seq_lengths):
"""Sets the observations for the model.
This method provides the model with all observed variables including both
inputs and targets. It will be called before running any computations with
the model that require the observations, e.g. training the model or
computing bounds, and should be used to run any necessary preprocessing
steps.
Contains the initial state of the RNN as well as a vector of zeros
corresponding to z_0.
Args:
batch_size: The batch size.
dtype: The data type of the VRNN.
Returns:
zero_state: The initial state of the VRNN.
observations: A potentially nested set of Tensors containing
all observations for the model, both inputs and targets. Typically
a set of Tensors with shape [max_seq_len, batch_size, data_size].
seq_lengths: A [batch_size] Tensor of ints encoding the length of each
sequence in the batch (sequences can be padded to a common length).
"""
return (self.rnn_cell.zero_state(batch_size, dtype),
tf.zeros([batch_size, self.encoded_z_size], dtype=dtype))
self.observations = observations
self.max_seq_len = tf.reduce_max(seq_lengths)
self.observations_ta = nested.tas_for_tensors(
observations, self.max_seq_len, clear_after_read=False)
self.seq_lengths = seq_lengths
def _build(self, observations, state, mask):
"""Computes one timestep of the VRNN.
def propose_and_weight(self, state, t):
"""Propogates model state one timestep and computes log weights.
Args:
observations: The observations at the current timestep, a tuple
containing the model inputs and targets as Tensors of shape
[batch_size, data_size].
state: The current state of the VRNN
mask: Tensor of shape [batch_size], 1.0 if the current timestep is active
active, 0.0 if it is not active.
This method accepts the current state of the model and computes the state
for the next timestep as well as the incremental log weight of each
element in the batch.
Args:
state: The current state of the model.
t: A scalar integer Tensor representing the current timestep.
Returns:
log_q_z: The logprob of the latent state according to the approximate
posterior.
log_p_z: The logprob of the latent state according to the prior.
log_p_x_given_z: The conditional log-likelihood, i.e. logprob of the
observation according to the generative distribution.
kl: The analytic kl divergence from q(z) to p(z).
state: The new state of the VRNN.
next_state: The state of the model after one timestep.
log_weights: A [batch_size] Tensor containing the incremental log weights.
"""
inputs, targets = observations
rnn_state, prev_latent_encoded = state
# Encode the data.
inputs_encoded = self.data_feat_extractor(inputs)
targets_encoded = self.data_feat_extractor(targets)
# Run the RNN cell.
rnn_inputs = tf.concat([inputs_encoded, prev_latent_encoded], axis=1)
rnn_out, new_rnn_state = self.rnn_cell(rnn_inputs, rnn_state)
# Create the prior and approximate posterior distributions.
latent_dist_prior = self.prior(rnn_out)
latent_dist_q = self.approx_posterior(rnn_out, targets_encoded,
prior_mu=latent_dist_prior.loc)
# Sample the new latent state z and encode it.
latent_state = latent_dist_q.sample(seed=self.random_seed)
latent_encoded = self.latent_feat_extractor(latent_state)
# Calculate probabilities of the latent state according to the prior p
# and approximate posterior q.
log_q_z = tf.reduce_sum(latent_dist_q.log_prob(latent_state), axis=-1)
log_p_z = tf.reduce_sum(latent_dist_prior.log_prob(latent_state), axis=-1)
analytic_kl = tf.reduce_sum(
tf.contrib.distributions.kl_divergence(
latent_dist_q, latent_dist_prior),
axis=-1)
# Create the generative dist. and calculate the logprob of the targets.
generative_dist = self.generative(latent_encoded, rnn_out)
log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(targets), axis=-1)
return (log_q_z, log_p_z, log_p_x_given_z, analytic_kl,
(new_rnn_state, latent_encoded))
# Default variable initializers: Xavier/Glorot for weights, zeros for biases.
_DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
                         "b": tf.zeros_initializer()}
def create_vrnn(
data_size,
latent_size,
generative_class,
rnn_hidden_size=None,
fcnet_hidden_sizes=None,
encoded_data_size=None,
encoded_latent_size=None,
sigma_min=0.0,
raw_sigma_bias=0.25,
generative_bias_init=0.0,
initializers=None,
random_seed=None):
"""A factory method for creating VRNN cells.
raise NotImplementedError("propose_and_weight not yet implemented.")
Args:
data_size: The dimension of the vectors that make up the data sequences.
latent_size: The size of the stochastic latent state of the VRNN.
generative_class: The class of the generative distribution. Can be either
ConditionalNormalDistribution or ConditionalBernoulliDistribution.
rnn_hidden_size: The hidden state dimension of the RNN that forms the
deterministic part of this VRNN. If None, then it defaults
to latent_size.
fcnet_hidden_sizes: A list of python integers, the size of the hidden
layers of the fully connected networks that parameterize the conditional
distributions of the VRNN. If None, then it defaults to one hidden
layer of size latent_size.
encoded_data_size: The size of the output of the data encoding network. If
None, defaults to latent_size.
encoded_latent_size: The size of the output of the latent state encoding
network. If None, defaults to latent_size.
sigma_min: The minimum value that the standard deviation of the
distribution over the latent state can take.
raw_sigma_bias: A scalar that is added to the raw standard deviation
output from the neural networks that parameterize the prior and
approximate posterior. Useful for preventing standard deviations close
to zero.
generative_bias_init: A bias to added to the raw output of the fully
connected network that parameterizes the generative distribution. Useful
for initializing the mean of the distribution to a sensible starting point
such as the mean of the training data. Only used with Bernoulli generative
distributions.
initializers: The variable initializers to use for the fully connected
networks and RNN cell. Must be a dictionary mapping the keys 'w' and 'b'
to the initializers for the weights and biases. Defaults to xavier for
the weights and zeros for the biases when initializers is None.
random_seed: A random seed for the VRNN resampling operations.
Returns:
model: A VRNNCell object.
"""
if rnn_hidden_size is None:
rnn_hidden_size = latent_size
if fcnet_hidden_sizes is None:
fcnet_hidden_sizes = [latent_size]
if encoded_data_size is None:
encoded_data_size = latent_size
if encoded_latent_size is None:
encoded_latent_size = latent_size
if initializers is None:
initializers = _DEFAULT_INITIALIZERS
data_feat_extractor = snt.nets.MLP(
output_sizes=fcnet_hidden_sizes + [encoded_data_size],
initializers=initializers,
name="data_feat_extractor")
latent_feat_extractor = snt.nets.MLP(
output_sizes=fcnet_hidden_sizes + [encoded_latent_size],
initializers=initializers,
name="latent_feat_extractor")
prior = ConditionalNormalDistribution(
size=latent_size,
hidden_layer_sizes=fcnet_hidden_sizes,
sigma_min=sigma_min,
raw_sigma_bias=raw_sigma_bias,
initializers=initializers,
name="prior")
approx_posterior = NormalApproximatePosterior(
size=latent_size,
hidden_layer_sizes=fcnet_hidden_sizes,
sigma_min=sigma_min,
raw_sigma_bias=raw_sigma_bias,
initializers=initializers,
name="approximate_posterior")
if generative_class == ConditionalBernoulliDistribution:
generative = ConditionalBernoulliDistribution(
size=data_size,
hidden_layer_sizes=fcnet_hidden_sizes,
initializers=initializers,
bias_init=generative_bias_init,
name="generative")
else:
generative = ConditionalNormalDistribution(
size=data_size,
hidden_layer_sizes=fcnet_hidden_sizes,
initializers=initializers,
name="generative")
rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
initializer=initializers["w"])
return VRNNCell(rnn_cell, data_feat_extractor, latent_feat_extractor,
prior, approx_posterior, generative, random_seed=random_seed)
# Default variable initializers: Xavier/Glorot for weights, zeros for biases.
DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
                        "b": tf.zeros_initializer()}
class ConditionalNormalDistribution(object):
......@@ -328,8 +113,9 @@ class ConditionalNormalDistribution(object):
self.sigma_min = sigma_min
self.raw_sigma_bias = raw_sigma_bias
self.name = name
self.size = size
if initializers is None:
initializers = _DEFAULT_INITIALIZERS
initializers = DEFAULT_INITIALIZERS
self.fcnet = snt.nets.MLP(
output_sizes=hidden_layer_sizes + [2*size],
activation=hidden_activation_fn,
......@@ -378,8 +164,9 @@ class ConditionalBernoulliDistribution(object):
name: The name of this distribution, used for sonnet scoping.
"""
self.bias_init = bias_init
self.size = size
if initializers is None:
initializers = _DEFAULT_INITIALIZERS
initializers = DEFAULT_INITIALIZERS
self.fcnet = snt.nets.MLP(
output_sizes=hidden_layer_sizes + [size],
activation=hidden_activation_fn,
......@@ -401,7 +188,18 @@ class ConditionalBernoulliDistribution(object):
class NormalApproximatePosterior(ConditionalNormalDistribution):
"""A Normally-distributed approx. posterior with res_q parameterization."""
def condition(self, tensor_list, prior_mu):
def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
initializers=None, smoothing=False,
name="conditional_normal_distribution"):
super(NormalApproximatePosterior, self).__init__(
size, hidden_layer_sizes, sigma_min=sigma_min,
raw_sigma_bias=raw_sigma_bias,
hidden_activation_fn=hidden_activation_fn, initializers=initializers,
name=name)
self.smoothing = smoothing
def condition(self, tensor_list, prior_mu, smoothing_tensors=None):
"""Generates the mean and variance of the normal distribution.
Args:
......@@ -410,9 +208,135 @@ class NormalApproximatePosterior(ConditionalNormalDistribution):
prior_mu: The mean of the prior distribution associated with this
approximate posterior. Will be added to the mean produced by
this approximate posterior, in res_q fashion.
smoothing_tensors: A list of Tensors. If smoothing is True, these Tensors
will be concatenated with the tensors in tensor_list.
Returns:
mu: The mean of the approximate posterior.
sigma: The standard deviation of the approximate posterior.
"""
if self.smoothing:
tensor_list.extend(smoothing_tensors)
mu, sigma = super(NormalApproximatePosterior, self).condition(tensor_list)
return mu + prior_mu, sigma
class NonstationaryLinearDistribution(object):
  """A set of loc-scale distributions that are linear functions of inputs.

  This class defines a series of location-scale distributions such that
  the means are learnable linear functions of the inputs and the log variances
  are learnable constants. The functions and log variances are different across
  timesteps, allowing the distributions to be nonstationary.
  """

  def __init__(self,
               num_timesteps,
               inputs_per_timestep=None,
               outputs_per_timestep=None,
               initializers=None,
               variance_min=0.0,
               output_distribution=tfd.Normal,
               dtype=tf.float32):
    """Creates a NonstationaryLinearDistribution.

    Args:
      num_timesteps: The number of timesteps, i.e. the number of distributions.
      inputs_per_timestep: A list of python ints, the dimension of inputs to the
        linear function at each timestep. If not provided, the dimension at each
        timestep is assumed to be 1.
      outputs_per_timestep: A list of python ints, the dimension of the output
        distribution at each timestep. If not provided, the dimension at each
        timestep is assumed to be 1.
      initializers: A dictionary containing initializers for the variables. The
        initializer under the key 'w' is used for the weights in the linear
        function and the initializer under the key 'b' is used for the biases.
        Defaults to xavier initialization for the weights and zeros for the
        biases.
      variance_min: Python float, the minimum variance of each distribution.
      output_distribution: A location-scale subclass of tfd.Distribution that
        defines the output distribution, e.g. Normal.
      dtype: The dtype of the weights and biases.
    """
    if not initializers:
      initializers = DEFAULT_INITIALIZERS
    if not inputs_per_timestep:
      inputs_per_timestep = [1] * num_timesteps
    if not outputs_per_timestep:
      outputs_per_timestep = [1] * num_timesteps
    self.num_timesteps = num_timesteps
    self.variance_min = variance_min
    self.initializers = initializers
    self.dtype = dtype
    self.output_distribution = output_distribution

    def _get_variables_ta(shapes, name, initializer, trainable=True):
      """Creates a sequence of variables and stores them in a TensorArray."""
      # Infer shape if all shapes are equal.
      first_shape = shapes[0]
      infer_shape = all(shape == first_shape for shape in shapes)
      ta = tf.TensorArray(
          dtype=dtype, size=len(shapes), dynamic_size=False,
          clear_after_read=False, infer_shape=infer_shape)
      for t, shape in enumerate(shapes):
        var = tf.get_variable(
            name % t, shape=shape, initializer=initializer, trainable=trainable)
        ta = ta.write(t, var)
      return ta

    bias_shapes = [[num_outputs] for num_outputs in outputs_per_timestep]
    self.log_variances = _get_variables_ta(
        bias_shapes, "proposal_log_variance_%d", initializers["b"])
    self.mean_biases = _get_variables_ta(
        bias_shapes, "proposal_b_%d", initializers["b"])
    # Materialize the zip into a list: under Python 3 zip() returns a one-shot
    # iterator, which would break _get_variables_ta (indexing, len, iteration)
    # and leave nothing for the unstack below that consumes it a second time.
    weight_shapes = list(zip(inputs_per_timestep, outputs_per_timestep))
    self.mean_weights = _get_variables_ta(
        weight_shapes, "proposal_w_%d", initializers["w"])
    self.shapes = tf.TensorArray(
        dtype=tf.int32, size=num_timesteps,
        dynamic_size=False, clear_after_read=False).unstack(weight_shapes)

  def __call__(self, t, inputs):
    """Computes the distribution at timestep t.

    Args:
      t: Scalar integer Tensor, the current timestep. Must be in
        [0, num_timesteps).
      inputs: The inputs to the linear function parameterizing the mean of
        the current distribution. A Tensor of shape [batch_size, num_inputs_t].

    Returns:
      A tfd.Distribution subclass representing the distribution at timestep t.
    """
    b = self.mean_biases.read(t)
    w = self.mean_weights.read(t)
    shape = self.shapes.read(t)
    # Restore the per-timestep [num_inputs_t, num_outputs_t] weight shape,
    # which may not be statically known when shapes differ across timesteps.
    w = tf.reshape(w, shape)
    b = tf.reshape(b, [shape[1], 1])
    log_variance = self.log_variances.read(t)
    # Clamp the variance from below before taking the square root.
    scale = tf.sqrt(tf.maximum(tf.exp(log_variance), self.variance_min))
    loc = tf.matmul(w, inputs, transpose_a=True) + b
    return self.output_distribution(loc=loc, scale=scale)
def encode_all(inputs, encoder):
  """Encodes a timeseries of inputs with a time independent encoder.

  Args:
    inputs: A [time, batch, feature_dimensions] tensor.
    encoder: A network that takes a [batch, features_dimensions] input and
      encodes the input.

  Returns:
    A [time, batch, encoded_feature_dimensions] output tensor.
  """
  dynamic_shape = tf.shape(inputs)
  time_steps = dynamic_shape[0]
  batch = dynamic_shape[1]
  # Fold the time axis into the batch axis so the encoder sees rank-2 input.
  flattened = tf.reshape(inputs, [-1, inputs.shape[-1]])
  encoded_flat = encoder(flattened)
  # Restore the leading [time, batch] axes around the encoded features.
  return tf.reshape(encoded_flat, [time_steps, batch, encoder.output_size])
def ta_for_tensor(x, **kwargs):
  """Creates a TensorArray for the input tensor.

  Args:
    x: A Tensor whose leading dimension indexes the TensorArray elements.
    **kwargs: Forwarded to the tf.TensorArray constructor, e.g.
      clear_after_read.

  Returns:
    A TensorArray of size tf.shape(x)[0] holding the unstacked slices of x.
  """
  return tf.TensorArray(
      x.dtype, tf.shape(x)[0], dynamic_size=False, **kwargs).unstack(x)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Gaussian hidden markov model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from fivo.models import base
tfd = tf.contrib.distributions
class GaussianHMM(object):
  """A hidden markov model with 1-D Gaussian latent space and observations.

  This is a hidden markov model where the state and observations are
  one-dimensional Gaussians. The mean of each latent state is a linear
  function of the previous latent state, and the mean of each observation
  is a linear function of the current latent state.

  The description that follows is 0-indexed instead of 1-indexed to make
  it easier to reason about the parameters passed to the model.

  The parameters of the model are:
    T: The number of timesteps, latent states, and observations.
    vz_t, t=0 to T-1: The variance of the latent state at timestep t.
    vx_t, t=0 to T-1: The variance of the observation at timestep t.
    wz_t, t=1 to T-1: The weight that defines the latent transition at t.
    wx_t, t=0 to T-1: The weight that defines the observation function at t.

  There are T vz_t, vx_t, and wx_t but only T-1 wz_t because there are only
  T-1 transitions in the model.

  Given these parameters, sampling from the model is defined as

  z_0 ~ N(0, vz_0)
  x_0 | z_0 ~ N(wx_0 * z_0, vx_0)
  z_1 | z_0 ~ N(wz_1 * z_0, vz_1)
  x_1 | z_1 ~ N(wx_1 * z_1, vx_1)
  ...
  z_{T-1} | z_{T-2} ~ N(wz_{T-1} * z_{T-2}, vz_{T-1})
  x_{T-1} | z_{T-1} ~ N(wx_{T-1} * z_{T-1}, vx_{T-1}).
  """

  def __init__(self,
               num_timesteps,
               transition_variances=1.,
               emission_variances=1.,
               transition_weights=1.,
               emission_weights=1.,
               dtype=tf.float32):
    """Creates a Gaussian hidden markov model.

    Args:
      num_timesteps: A python int, the number of timesteps in the model.
      transition_variances: The variance of p(z_t | z_t-1). Can be a scalar,
        setting all variances to be the same, or a Tensor of shape
        [num_timesteps].
      emission_variances: The variance of p(x_t | z_t). Can be a scalar,
        setting all variances to be the same, or a Tensor of shape
        [num_timesteps].
      transition_weights: The weight that defines the linear function that
        produces the mean of z_t given z_{t-1}. Can be a scalar, setting
        all weights to be the same, or a Tensor of shape [num_timesteps-1].
      emission_weights: The weight that defines the linear function that
        produces the mean of x_t given z_t. Can be a scalar, setting
        all weights to be the same, or a Tensor of shape [num_timesteps].
      dtype: The datatype of the state.
    """
    self.num_timesteps = num_timesteps
    self.dtype = dtype

    def _expand_param(param, size):
      # Broadcasts a scalar parameter to a [size] vector; a parameter that
      # already has a (non-scalar) shape is passed through unchanged.
      param = tf.convert_to_tensor(param, dtype=self.dtype)
      if not param.get_shape().as_list():
        param = tf.tile(param[tf.newaxis], [size])
      return param

    def _ta_for_param(param):
      # Stores the per-timestep values in a TensorArray so methods can read
      # the value for a (possibly dynamic) timestep t.
      size = tf.shape(param)[0]
      ta = tf.TensorArray(dtype=param.dtype,
                          size=size,
                          dynamic_size=False,
                          clear_after_read=False).unstack(param)
      return ta

    self.transition_variances = _ta_for_param(
        _expand_param(transition_variances, num_timesteps))
    # Only num_timesteps-1 transition weights: there is no transition into z_0.
    self.transition_weights = _ta_for_param(
        _expand_param(transition_weights, num_timesteps-1))
    em_var = _expand_param(emission_variances, num_timesteps)
    self.emission_variances = _ta_for_param(em_var)
    em_w = _expand_param(emission_weights, num_timesteps)
    self.emission_weights = _ta_for_param(em_w)
    # Precompute the joint covariance structure used by lookahead/likelihood.
    self._compute_covariances(em_w, em_var)

  def _compute_covariances(self, emission_weights, emission_variances):
    """Compute all covariance matrices.

    Computes the covariance matrix for the latent variables, the observations,
    and the covariance between the latents and observations.

    Args:
      emission_weights: A Tensor of shape [num_timesteps] containing
        the emission distribution weights at each timestep.
      emission_variances: A Tensor of shape [num_timesteps] containing
        the emission distribution variances at each timestep.
    """
    # Compute the marginal variance of each latent.
    z_variances = [self.transition_variances.read(0)]
    for i in range(1, self.num_timesteps):
      z_variances.append(
          z_variances[i-1] * tf.square(self.transition_weights.read(i-1)) +
          self.transition_variances.read(i))
    # Compute the latent covariance matrix.
    sigma_z = []
    for i in range(self.num_timesteps):
      sigma_z_row = []
      for j in range(self.num_timesteps):
        if i == j:
          sigma_z_row.append(z_variances[i])
          continue
        # Off-diagonal entries: the earlier latent's variance scaled by the
        # product of transition weights on the path between the two indices.
        min_ind = min(i, j)
        max_ind = max(i, j)
        weight = tf.reduce_prod(
            self.transition_weights.gather(tf.range(min_ind, max_ind)))
        sigma_z_row.append(z_variances[min_ind] * weight)
      sigma_z.append(tf.stack(sigma_z_row))
    self.sigma_z = tf.stack(sigma_z)
    # Compute the observation covariance matrix.
    x_weights_outer = tf.einsum("i,j->ij", emission_weights, emission_weights)
    self.sigma_x = x_weights_outer * self.sigma_z + tf.diag(emission_variances)
    # Compute the latent - observation covariance matrix.
    # The first axis will index latents, the second axis will index observtions.
    self.sigma_zx = emission_weights[tf.newaxis, :] * self.sigma_z
    # NOTE(review): loc uses tf.float32 here rather than self.dtype -- confirm
    # whether non-float32 dtypes are meant to be supported end-to-end.
    self.obs_dist = tfd.MultivariateNormalFullCovariance(
        loc=tf.zeros([self.num_timesteps], dtype=tf.float32),
        covariance_matrix=self.sigma_x)

  def transition(self, t, z_prev):
    """Compute the transition distribution p(z_t | z_t-1).

    Args:
      t: The current timestep, a scalar integer Tensor. When t=0 z_prev is
        mostly ignored and the distribution p(z_0) is returned. z_prev is
        'mostly' ignored because it is still used to derive batch_size.
      z_prev: A [batch_size] set of states.

    Returns:
      p(z_t | z_t-1) as a univariate normal distribution.
    """
    batch_size = tf.shape(z_prev)[0]
    scale = tf.sqrt(self.transition_variances.read(t))
    scale = tf.tile(scale[tf.newaxis], [batch_size])
    # At t=0 there is no previous state, so the prior mean is zero.
    loc = tf.cond(tf.greater(t, 0),
                  lambda: self.transition_weights.read(t-1)*z_prev,
                  lambda: tf.zeros_like(scale))
    return tfd.Normal(loc=loc, scale=scale)

  def emission(self, t, z):
    """Compute the emission distribution p(x_t | z_t).

    Args:
      t: The current timestep, a scalar integer Tensor.
      z: A [batch_size] set of the current states.

    Returns:
      p(x_t | z_t) as a univariate normal distribution.
    """
    batch_size = tf.shape(z)[0]
    scale = tf.sqrt(self.emission_variances.read(t))
    scale = tf.tile(scale[tf.newaxis], [batch_size])
    loc = self.emission_weights.read(t)*z
    return tfd.Normal(loc=loc, scale=scale)

  def filtering(self, t, z_prev, x_cur):
    """Computes the filtering distribution p(z_t | z_{t-1}, x_t).

    Args:
      t: A python int, the index for z_t. When t is 0, z_prev is ignored,
        giving p(z_0 | x_0).
      z_prev: z_{t-1}, the previous z to condition on. A Tensor of shape
        [batch_size].
      x_cur: x_t, the current x to condition on. A Tensor of shape [batch_size].

    Returns:
      p(z_t | z_{t-1}, x_t) as a univariate normal distribution.
    """
    z_prev = tf.convert_to_tensor(z_prev)
    x_cur = tf.convert_to_tensor(x_cur)
    batch_size = tf.shape(z_prev)[0]
    z_var = self.transition_variances.read(t)
    x_var = self.emission_variances.read(t)
    x_weight = self.emission_weights.read(t)
    # Closed-form Gaussian conditioning: the posterior mean is a convex-style
    # combination of the previous state and the current observation, both
    # sharing the normalizer wx^2*vz + vx.
    prev_state_weight = x_var/(tf.square(x_weight)*z_var + x_var)
    # At t=0 there is no previous state; zero out its contribution.
    prev_state_weight *= tf.cond(tf.greater(t, 0),
                                 lambda: self.transition_weights.read(t-1),
                                 lambda: tf.zeros_like(prev_state_weight))
    cur_obs_weight = (x_weight*z_var)/(tf.square(x_weight)*z_var + x_var)
    loc = prev_state_weight*z_prev + cur_obs_weight*x_cur
    scale = tf.sqrt((z_var*x_var)/(tf.square(x_weight)*z_var + x_var))
    scale = tf.tile(scale[tf.newaxis], [batch_size])
    return tfd.Normal(loc=loc, scale=scale)

  def smoothing(self, t, z_prev, xs):
    """Computes the smoothing distribution p(z_t | z_{t-1}, x_{t:num_timesteps}).

    Args:
      t: A python int, the index for z_t. When t is 0, z_prev is ignored,
        giving p(z_0 | x_{0:num_timesteps-1}).
      z_prev: z_{t-1}, the previous z to condition on. A Tensor of shape
        [batch_size].
      xs: x_{t:num_timesteps}, the future xs to condition on. A Tensor of shape
        [num_timesteps - t, batch_size].

    Returns:
      p(z_t | z_{t-1}, x_{t:num_timesteps}) as a univariate normal distribution.
    """
    xs = tf.convert_to_tensor(xs)
    z_prev = tf.convert_to_tensor(z_prev)
    batch_size = tf.shape(xs)[1]
    # If future observations exist, compute the backwards message they induce
    # on z_t; otherwise use a vacuous message (zero mean and zero precision).
    mess_mean, mess_prec = tf.cond(
        tf.less(t, self.num_timesteps-1),
        lambda: tf.unstack(self._compute_backwards_messages(xs[1:]).read(0)),
        lambda: [tf.zeros([batch_size]), tf.zeros([batch_size])])
    return self._smoothing_from_message(t, z_prev, xs[0], mess_mean, mess_prec)

  def _smoothing_from_message(self, t, z_prev, x_t, mess_mean, mess_prec):
    """Computes the smoothing distribution given message incoming to z_t.

    Computes p(z_t | z_{t-1}, x_{t:num_timesteps}) given the message incoming
    to the node for z_t.

    Args:
      t: A python int, the index for z_t. When t is 0, z_prev is ignored.
      z_prev: z_{t-1}, the previous z to condition on. A Tensor of shape
        [batch_size].
      x_t: The observation x at timestep t.
      mess_mean: The mean of the message incoming to z_t, in information form.
      mess_prec: The precision of the message incoming to z_t.

    Returns:
      p(z_t | z_{t-1}, x_{t:num_timesteps}) as a univariate normal distribution.
    """
    batch_size = tf.shape(x_t)[0]
    z_var = self.transition_variances.read(t)
    x_var = self.emission_variances.read(t)
    w_x = self.emission_weights.read(t)

    def transition_term():
      # Precision contributed by the transition into z_{t+1}.
      return (tf.square(self.transition_weights.read(t))/
              self.transition_variances.read(t+1))

    # Accumulate precision in information form: own prior + emission +
    # incoming message (+ transition term when t is not the last timestep).
    prec = 1./z_var + tf.square(w_x)/x_var + mess_prec
    prec += tf.cond(tf.less(t, self.num_timesteps-1),
                    transition_term, lambda: 0.)
    # Accumulate the information-form mean; the z_prev term only applies when
    # t > 0 (there is no previous latent at t=0).
    mean = x_t*(w_x/x_var) + mess_mean
    mean += tf.cond(tf.greater(t, 0),
                    lambda: z_prev*(self.transition_weights.read(t-1)/z_var),
                    lambda: 0.)
    # Convert from information form back to moment form.
    mean = tf.reshape(mean / prec, [batch_size])
    scale = tf.reshape(tf.sqrt(1./prec), [batch_size])
    return tfd.Normal(loc=mean, scale=scale)

  def _compute_backwards_messages(self, xs):
    """Computes the backwards messages used in smoothing.

    Runs a backwards pass from the last timestep towards `until_t`, writing
    each message (mean, precision) into a TensorArray indexed relative to
    `until_t`.
    """
    batch_size = tf.shape(xs)[1]
    num_xs = tf.shape(xs)[0]
    # xs covers the tail of the sequence; until_t is the first timestep the
    # backwards pass must reach.
    until_t = self.num_timesteps - num_xs
    xs = tf.TensorArray(dtype=xs.dtype,
                        size=num_xs,
                        dynamic_size=False,
                        clear_after_read=True).unstack(xs)
    messages_ta = tf.TensorArray(dtype=xs.dtype,
                                 size=num_xs,
                                 dynamic_size=False,
                                 clear_after_read=False)

    def compute_message(t, prev_mean, prev_prec, messages_ta):
      """Computes one step of the backwards messages."""
      z_var = self.transition_variances.read(t)
      w_z = self.transition_weights.read(t-1)
      x_var = self.emission_variances.read(t)
      w_x = self.emission_weights.read(t)
      cur_x = xs.read(t - until_t)

      # If it isn't the first message, add the terms from the transition.
      def transition_term():
        return (tf.square(self.transition_weights.read(t))/
                self.transition_variances.read(t+1))

      unary_prec = 1/z_var + tf.square(w_x)/x_var
      unary_prec += tf.cond(tf.less(t, self.num_timesteps-1),
                            transition_term, lambda: 0.)

      unary_mean = (w_x / x_var) * cur_x
      pairwise_prec = w_z / z_var

      # Marginalize z_t out to produce the message passed to z_{t-1}.
      next_prec = -tf.square(pairwise_prec)/(unary_prec + prev_prec)
      next_mean = (pairwise_prec * (unary_mean + prev_mean) /
                   (unary_prec + prev_prec))
      next_prec = tf.reshape(next_prec, [batch_size])
      next_mean = tf.reshape(next_mean, [batch_size])
      messages_ta = messages_ta.write(t - until_t,
                                      tf.stack([next_mean, next_prec]))
      return t-1, next_mean, next_prec, messages_ta

    def pred(t, *unused_args):
      return tf.greater_equal(t, until_t)

    init_prec = tf.zeros([batch_size], dtype=xs.dtype)
    init_mean = tf.zeros([batch_size], dtype=xs.dtype)
    t0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)

    outs = tf.while_loop(pred, compute_message,
                         (t0, init_mean, init_prec, messages_ta))
    messages = outs[-1]
    return messages

  def lookahead(self, t, z_prev):
    """Compute the 'lookahead' distribution, p(x_{t:T} | z_{t-1}).

    Args:
      t: A scalar Tensor int, the current timestep. Must be at least 1.
      z_prev: The latent state at time t-1. A Tensor of shape [batch_size].

    Returns:
      p(x_{t:T} | z_{t-1}) as a multivariate normal distribution.
    """
    z_prev = tf.convert_to_tensor(z_prev)
    # Condition the joint Gaussian over (z_{t-1}, x_{t:T}) on z_{t-1} using
    # the precomputed covariance blocks.
    sigma_zx = self.sigma_zx[t-1, t:]
    z_var = self.sigma_z[t-1, t-1]
    mean = tf.einsum("i,j->ij", z_prev, sigma_zx) / z_var
    variance = (self.sigma_x[t:, t:] -
                tf.einsum("i,j->ij", sigma_zx, sigma_zx) / z_var)
    return tfd.MultivariateNormalFullCovariance(
        loc=mean, covariance_matrix=variance)

  def likelihood(self, xs):
    """Compute the true marginal likelihood of the data.

    Args:
      xs: The observations, a [num_timesteps, batch_size] float Tensor.

    Returns:
      likelihoods: A [batch_size] float Tensor representing the likelihood of
        each sequence of observations in the batch.
    """
    # obs_dist expects [batch, num_timesteps], so transpose before scoring.
    return self.obs_dist.log_prob(tf.transpose(xs))
class TrainableGaussianHMM(GaussianHMM, base.ELBOTrainableSequenceModel):
  """An interface between importance-sampling training methods and the GHMM."""

  def __init__(self,
               num_timesteps,
               proposal_type,
               transition_variances=1.,
               emission_variances=1.,
               transition_weights=1.,
               emission_weights=1.,
               random_seed=None,
               dtype=tf.float32):
    """Constructs a trainable Gaussian HMM.

    Args:
      num_timesteps: A python int, the number of timesteps in the model.
      proposal_type: The type of proposal to use in the importance sampling
        setup. Could be "filtering", "smoothing", "prior", "true-filtering",
        or "true-smoothing". If "true-filtering" or "true-smoothing" are
        selected, then the true filtering or smoothing distributions are used to
        propose new states. If "filtering" is selected then a
        distribution with learnable parameters is used. Specifically at each
        timestep the proposal is Gaussian with mean that is a learnable linear
        function of the previous state and current observation. The log variance
        is a per-timestep learnable constant. "smoothing" is similar,
        but the mean is a learnable linear function of the previous state and
        all future observations. Note that this proposal class includes the true
        posterior. If "prior" is selected then states are proposed from the
        model's prior.
      transition_variances: The variance of p(z_t | z_t-1). Can be a scalar,
        setting all variances to be the same, or a Tensor of shape
        [num_timesteps].
      emission_variances: The variance of p(x_t | z_t). Can be a scalar,
        setting all variances to be the same, or a Tensor of shape
        [num_timesteps].
      transition_weights: The weight that defines the linear function that
        produces the mean of z_t given z_{t-1}. Can be a scalar, setting
        all weights to be the same, or a Tensor of shape [num_timesteps-1].
      emission_weights: The weight that defines the linear function that
        produces the mean of x_t given z_t. Can be a scalar, setting
        all weights to be the same, or a Tensor of shape [num_timesteps].
      random_seed: A seed for the proposal sampling, mainly useful for testing.
      dtype: The datatype of the state.
    """
    super(TrainableGaussianHMM, self).__init__(
        num_timesteps, transition_variances, emission_variances,
        transition_weights, emission_weights, dtype=dtype)
    self.random_seed = random_seed
    assert proposal_type in ["filtering", "smoothing", "prior",
                             "true-filtering", "true-smoothing"]
    # Bind self.proposal to the callable implementing the requested
    # proposal; all variants share the signature proposal(t, state).
    if proposal_type == "true-filtering":
      self.proposal = self._filtering_proposal
    elif proposal_type == "true-smoothing":
      self.proposal = self._smoothing_proposal
    elif proposal_type == "prior":
      self.proposal = self.transition
    elif proposal_type == "filtering":
      # At t=0 the proposal sees only x_0; afterwards it sees [z_{t-1}, x_t].
      self._learned_proposal_fn = base.NonstationaryLinearDistribution(
          num_timesteps, inputs_per_timestep=[1] + [2] * (num_timesteps-1))
      self.proposal = self._learned_filtering_proposal
    elif proposal_type == "smoothing":
      # At t=0 the proposal sees all observations; afterwards it sees
      # z_{t-1} plus the observations from t onwards.
      inputs_per_timestep = [num_timesteps] + [num_timesteps - t
                                               for t in range(num_timesteps-1)]
      self._learned_proposal_fn = base.NonstationaryLinearDistribution(
          num_timesteps, inputs_per_timestep=inputs_per_timestep)
      self.proposal = self._learned_smoothing_proposal

  def set_observations(self, xs, seq_lengths):
    """Sets the observations and stores the backwards messages."""
    # Squeeze out data dimension since everything is 1-d.
    xs = tf.squeeze(xs)
    self.batch_size = tf.shape(xs)[1]
    super(TrainableGaussianHMM, self).set_observations(xs, seq_lengths)
    # Precompute the backwards messages used by the smoothing proposal.
    self.messages = self._compute_backwards_messages(xs[1:])

  def zero_state(self, batch_size, dtype):
    """Returns the initial latent state, a [batch_size] Tensor of zeros."""
    return tf.zeros([batch_size], dtype=dtype)

  def propose_and_weight(self, state, t):
    """Computes the next state and log weights for the GHMM."""
    state_shape = tf.shape(state)
    xt = self.observations[t]
    p_zt = self.transition(t, state)
    q_zt = self.proposal(t, state)
    # Sample the new latent from the proposal, not the prior.
    zt = q_zt.sample(seed=self.random_seed)
    zt = tf.reshape(zt, state_shape)
    p_xt_given_zt = self.emission(t, zt)
    log_p_zt = p_zt.log_prob(zt)
    log_q_zt = q_zt.log_prob(zt)
    log_p_xt_given_zt = p_xt_given_zt.log_prob(xt)
    # Importance weight: log [p(z_t|z_{t-1}) p(x_t|z_t) / q(z_t)].
    weight = log_p_zt + log_p_xt_given_zt - log_q_zt
    return weight, zt

  def _filtering_proposal(self, t, state):
    """Uses the stored observations to compute the filtering distribution."""
    cur_x = self.observations[t]
    return self.filtering(t, state, cur_x)

  def _smoothing_proposal(self, t, state):
    """Uses the stored messages to compute the smoothing distribution."""
    # At the last timestep there is no backwards message, so substitute an
    # uninformative zero-mean, zero-precision message.
    mess_mean, mess_prec = tf.cond(
        tf.less(t, self.num_timesteps-1),
        lambda: tf.unstack(self.messages.read(t)),
        lambda: [tf.zeros([self.batch_size]), tf.zeros([self.batch_size])])
    return self._smoothing_from_message(t, state, self.observations[t],
                                        mess_mean, mess_prec)

  def _learned_filtering_proposal(self, t, state):
    """Proposes from a learned linear function of z_{t-1} and x_t."""
    cur_x = self.observations[t]
    # At t=0 there is no previous state, so condition only on x_0.
    inputs = tf.cond(tf.greater(t, 0),
                     lambda: tf.stack([state, cur_x], axis=0),
                     lambda: cur_x[tf.newaxis, :])
    return self._learned_proposal_fn(t, inputs)

  def _learned_smoothing_proposal(self, t, state):
    """Proposes from a learned linear function of z_{t-1} and x_{t:T}."""
    xs = self.observations_ta.gather(tf.range(t, self.num_timesteps))
    # At t=0 there is no previous state, so condition only on the xs.
    inputs = tf.cond(tf.greater(t, 0),
                     lambda: tf.concat([state[tf.newaxis, :], xs], axis=0),
                     lambda: xs)
    return self._learned_proposal_fn(t, inputs)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.models.ghmm"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from fivo.models.ghmm import GaussianHMM
from fivo.models.ghmm import TrainableGaussianHMM
class GHMMTest(tf.test.TestCase):
  """Tests for the GaussianHMM distributions against closed-form values."""

  def test_transition_no_weights(self):
    """Checks transition distributions with default (unit) weights."""
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         transition_variances=[1., 2., 3.])
      prev_z = tf.constant([1., 2.], dtype=tf.float32)
      z0 = ghmm.transition(0, prev_z)
      z1 = ghmm.transition(1, prev_z)
      z2 = ghmm.transition(2, prev_z)
      outs = sess.run([z0.mean(), z0.variance(),
                       z1.mean(), z1.variance(),
                       z2.mean(), z2.variance()])
      self.assertAllClose(outs, [[0., 0.], [1., 1.],
                                 [1., 2.], [2., 2.],
                                 [1., 2.], [3., 3.]])

  def test_transition_with_weights(self):
    """Checks transition distributions with explicit transition weights."""
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         transition_variances=[1., 2., 3.],
                         transition_weights=[2., 3.])
      prev_z = tf.constant([1., 2.], dtype=tf.float32)
      z0 = ghmm.transition(0, prev_z)
      z1 = ghmm.transition(1, prev_z)
      z2 = ghmm.transition(2, prev_z)
      outs = sess.run([z0.mean(), z0.variance(),
                       z1.mean(), z1.variance(),
                       z2.mean(), z2.variance()])
      self.assertAllClose(outs, [[0., 0.], [1., 1.],
                                 [2., 4.], [2., 2.],
                                 [3., 6.], [3., 3.]])

  def test_emission_no_weights(self):
    """Checks emission distributions with default (unit) weights."""
    with self.test_session() as sess:
      ghmm = GaussianHMM(3, emission_variances=[1., 2., 3.])
      z = tf.constant([1., 2.], dtype=tf.float32)
      x0 = ghmm.emission(0, z)
      x1 = ghmm.emission(1, z)
      x2 = ghmm.emission(2, z)
      outs = sess.run([x0.mean(), x0.variance(),
                       x1.mean(), x1.variance(),
                       x2.mean(), x2.variance()])
      self.assertAllClose(outs, [[1., 2.], [1., 1.],
                                 [1., 2.], [2., 2.],
                                 [1., 2.], [3., 3.]])

  def test_emission_with_weights(self):
    """Checks emission distributions with explicit emission weights."""
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         emission_variances=[1., 2., 3.],
                         emission_weights=[1., 2., 3.])
      z = tf.constant([1., 2.], dtype=tf.float32)
      x0 = ghmm.emission(0, z)
      x1 = ghmm.emission(1, z)
      x2 = ghmm.emission(2, z)
      outs = sess.run([x0.mean(), x0.variance(),
                       x1.mean(), x1.variance(),
                       x2.mean(), x2.variance()])
      self.assertAllClose(outs, [[1., 2.], [1., 1.],
                                 [2., 4.], [2., 2.],
                                 [3., 6.], [3., 3.]])

  def test_filtering_no_weights(self):
    """Checks the filtering posterior against hand-computed values."""
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         transition_variances=[1., 2., 3.],
                         emission_variances=[4., 5., 6.])
      z_prev = tf.constant([1., 2.], dtype=tf.float32)
      x_cur = tf.constant([3., 4.], dtype=tf.float32)
      expected_outs = [[[3./5., 4./5.], [4./5., 4./5.]],
                       [[11./7., 18./7.], [10./7., 10./7.]],
                       [[5./3., 8./3.], [2., 2.]]]
      f_post_0 = ghmm.filtering(0, z_prev, x_cur)
      f_post_1 = ghmm.filtering(1, z_prev, x_cur)
      f_post_2 = ghmm.filtering(2, z_prev, x_cur)
      outs = sess.run([[f_post_0.mean(), f_post_0.variance()],
                       [f_post_1.mean(), f_post_1.variance()],
                       [f_post_2.mean(), f_post_2.variance()]])
      self.assertAllClose(expected_outs, outs)

  def test_filtering_with_weights(self):
    """Checks the filtering posterior when weights are non-trivial."""
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         transition_variances=[1., 2., 3.],
                         emission_variances=[4., 5., 6.],
                         transition_weights=[7., 8.],
                         emission_weights=[9., 10., 11])
      z_prev = tf.constant([1., 2.], dtype=tf.float32)
      x_cur = tf.constant([3., 4.], dtype=tf.float32)
      expected_outs = [[[27./85., 36./85.], [4./85., 4./85.]],
                       [[95./205., 150./205.], [10./205., 10./205.]],
                       [[147./369., 228./369.], [18./369., 18./369.]]]
      f_post_0 = ghmm.filtering(0, z_prev, x_cur)
      f_post_1 = ghmm.filtering(1, z_prev, x_cur)
      f_post_2 = ghmm.filtering(2, z_prev, x_cur)
      outs = sess.run([[f_post_0.mean(), f_post_0.variance()],
                       [f_post_1.mean(), f_post_1.variance()],
                       [f_post_2.mean(), f_post_2.variance()]])
      self.assertAllClose(expected_outs, outs)

  def test_smoothing(self):
    """Checks the smoothing posterior against hand-computed values."""
    with self.test_session() as sess:
      ghmm = GaussianHMM(3,
                         transition_variances=[1., 2., 3.],
                         emission_variances=[4., 5., 6.])
      z_prev = tf.constant([1., 2.], dtype=tf.float32)
      xs = tf.constant([[1., 2.],
                        [3., 4.],
                        [5., 6.]], dtype=tf.float32)
      s_post1 = ghmm.smoothing(0, z_prev, xs)
      outs = sess.run([s_post1.mean(), s_post1.variance()])
      expected_outs = [[281./421., 410./421.], [292./421., 292./421.]]
      self.assertAllClose(expected_outs, outs)
      expected_outs = [[149./73., 222./73.], [90./73., 90./73.]]
      s_post2 = ghmm.smoothing(1, z_prev, xs[1:])
      outs = sess.run([s_post2.mean(), s_post2.variance()])
      self.assertAllClose(expected_outs, outs)
      s_post3 = ghmm.smoothing(2, z_prev, xs[2:])
      outs = sess.run([s_post3.mean(), s_post3.variance()])
      expected_outs = [[7./3., 10./3.], [2., 2.]]
      self.assertAllClose(expected_outs, outs)

  def test_smoothing_with_weights(self):
    """Checks the smoothing posterior against a direct Gaussian conditioning."""
    with self.test_session() as sess:
      x_weight = np.array([4, 5, 6, 7], dtype=np.float32)
      sigma_x = np.array([5, 6, 7, 8], dtype=np.float32)
      z_weight = np.array([1, 2, 3], dtype=np.float32)
      sigma_z = np.array([1, 2, 3, 4], dtype=np.float32)
      z_prev = np.array([1, 2], dtype=np.float32)
      batch_size = 2
      xs = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32)
      z_cov, x_cov, z_x_cov = self._compute_covariance_matrices(
          x_weight, z_weight, sigma_x, sigma_z)
      expected_outs = []
      # Compute mean and variance for z_0 when we don't condition
      # on previous zs.
      sigma_12 = z_x_cov[0, :]
      sigma_12_22 = np.dot(sigma_12, np.linalg.inv(x_cov))
      mean = np.dot(sigma_12_22, xs)
      variance = np.squeeze(z_cov[0, 0] - np.dot(sigma_12_22, sigma_12))
      expected_outs.append([mean, np.tile(variance, [batch_size])])
      # Compute mean and variance for remaining z_ts.
      # NOTE: was `xrange`, which is Python-2-only and inconsistent with the
      # `range` calls used elsewhere in this class.
      for t in range(1, 4):
        sigma_12 = np.concatenate([[z_cov[t, t - 1]], z_x_cov[t, t:]])
        sigma_22 = np.vstack((
            np.hstack((z_cov[t-1, t-1], z_x_cov[t-1, t:])),
            np.hstack((np.transpose([z_x_cov[t-1, t:]]), x_cov[t:, t:]))
        ))
        sigma_12_22 = np.dot(sigma_12, np.linalg.inv(sigma_22))
        mean = np.dot(sigma_12_22, np.vstack((z_prev, xs[t:])))
        variance = np.squeeze(z_cov[t, t] - np.dot(sigma_12_22, sigma_12))
        expected_outs.append([mean, np.tile(variance, [batch_size])])
      ghmm = GaussianHMM(4,
                         transition_variances=sigma_z,
                         emission_variances=sigma_x,
                         transition_weights=z_weight,
                         emission_weights=x_weight)
      out_dists = [ghmm.smoothing(t, z_prev, xs[t:]) for t in range(0, 4)]
      outs = [[d.mean(), d.variance()] for d in out_dists]
      run_outs = sess.run(outs)
      self.assertAllClose(expected_outs, run_outs)

  def test_covariance_matrices(self):
    """Checks the model's covariance matrices against a NumPy construction."""
    with self.test_session() as sess:
      x_weight = np.array([4, 5, 6, 7], dtype=np.float32)
      sigma_x = np.array([5, 6, 7, 8], dtype=np.float32)
      z_weight = np.array([1, 2, 3], dtype=np.float32)
      sigma_z = np.array([1, 2, 3, 4], dtype=np.float32)
      z_cov, x_cov, z_x_cov = self._compute_covariance_matrices(
          x_weight, z_weight, sigma_x, sigma_z)
      ghmm = GaussianHMM(4,
                         transition_variances=sigma_z,
                         emission_variances=sigma_x,
                         transition_weights=z_weight,
                         emission_weights=x_weight)
      self.assertAllClose(z_cov, sess.run(ghmm.sigma_z))
      self.assertAllClose(x_cov, sess.run(ghmm.sigma_x))
      self.assertAllClose(z_x_cov, sess.run(ghmm.sigma_zx))

  def _compute_covariance_matrices(self, x_weight, z_weight, sigma_x, sigma_z):
    """Builds the z, x, and z-x covariance matrices with NumPy."""
    # Create z covariance matrix from the definitions.
    z_cov = np.zeros([4, 4])
    z_cov[0, 0] = sigma_z[0]
    for i in range(1, 4):
      z_cov[i, i] = (z_cov[i - 1, i - 1] * np.square(z_weight[i - 1]) +
                     sigma_z[i])
    for i in range(4):
      for j in range(4):
        if i == j: continue
        min_ind = min(i, j)
        max_ind = max(i, j)
        weights = np.prod(z_weight[min_ind:max_ind])
        z_cov[i, j] = z_cov[min_ind, min_ind] * weights
    # Compute the x covariance matrix and the z-x covariance matrix.
    x_weights_outer = np.outer(x_weight, x_weight)
    x_cov = x_weights_outer * z_cov + np.diag(sigma_x)
    z_x_cov = x_weight * z_cov
    return z_cov, x_cov, z_x_cov

  def test_lookahead(self):
    """Checks the lookahead distribution against a NumPy conditioning."""
    x_weight = np.array([4, 5, 6, 7], dtype=np.float32)
    sigma_x = np.array([5, 6, 7, 8], dtype=np.float32)
    z_weight = np.array([1, 2, 3], dtype=np.float32)
    sigma_z = np.array([1, 2, 3, 4], dtype=np.float32)
    z_prev = np.array([1, 2], dtype=np.float32)
    with self.test_session() as sess:
      z_cov, x_cov, z_x_cov = self._compute_covariance_matrices(
          x_weight, z_weight, sigma_x, sigma_z)
      expected_outs = []
      for t in range(1, 4):
        sigma_12 = z_x_cov[t-1, t:]
        z_var = z_cov[t-1, t-1]
        mean = np.outer(z_prev, sigma_12/z_var)
        variance = x_cov[t:, t:] - np.outer(sigma_12, sigma_12)/ z_var
        expected_outs.append([mean, variance])
      ghmm = GaussianHMM(4,
                         transition_variances=sigma_z,
                         emission_variances=sigma_x,
                         transition_weights=z_weight,
                         emission_weights=x_weight)
      out_dists = [ghmm.lookahead(t, z_prev) for t in range(1, 4)]
      outs = [[d.mean(), d.covariance()] for d in out_dists]
      run_outs = sess.run(outs)
      self.assertAllClose(expected_outs, run_outs)
class TrainableGHMMTest(tf.test.TestCase):
  """Tests for TrainableGaussianHMM's proposal distributions."""

  def test_filtering_proposal(self):
    """Check that stashing the xs doesn't change the filtering distributions."""
    with self.test_session() as sess:
      model = TrainableGaussianHMM(
          3, "filtering",
          transition_variances=[1., 2., 3.],
          emission_variances=[4., 5., 6.],
          transition_weights=[7., 8.],
          emission_weights=[9., 10., 11])
      observations = tf.constant([[3., 4.],
                                  [3., 4.],
                                  [3., 4.]], dtype=tf.float32)
      model.set_observations(observations, [3, 3])
      prev_state = tf.constant([1., 2.], dtype=tf.float32)
      dist_params = []
      for t in range(3):
        proposal = model._filtering_proposal(t, prev_state)
        dist_params.append([proposal.mean(), proposal.variance()])
      expected_outs = [[[27./85., 36./85.], [4./85., 4./85.]],
                       [[95./205., 150./205.], [10./205., 10./205.]],
                       [[147./369., 228./369.], [18./369., 18./369.]]]
      self.assertAllClose(expected_outs, sess.run(dist_params))

  def test_smoothing_proposal(self):
    """Check the smoothing proposal computed from the stored messages."""
    with self.test_session() as sess:
      model = TrainableGaussianHMM(
          3, "smoothing",
          transition_variances=[1., 2., 3.],
          emission_variances=[4., 5., 6.])
      xs = tf.constant([[1., 2.],
                        [3., 4.],
                        [5., 6.]], dtype=tf.float32)
      model.set_observations(xs, [3, 3])
      prev_state = tf.constant([1., 2.], dtype=tf.float32)
      dist_params = []
      for t in range(3):
        proposal = model._smoothing_proposal(t, prev_state)
        dist_params.append([proposal.mean(), proposal.variance()])
      expected_outs = [[[281./421., 410./421.], [292./421., 292./421.]],
                       [[149./73., 222./73.], [90./73., 90./73.]],
                       [[7./3., 10./3.], [2., 2.]]]
      self.assertAllClose(expected_outs, sess.run(dist_params))
# Standard TensorFlow test entry point.
if __name__ == "__main__":
  tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SRNN classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import functools
import sonnet as snt
import tensorflow as tf
from fivo.models import base
SRNNState = namedtuple("SRNNState", "rnn_state latent_encoded")
class SRNN(object):
  """A Stochastic Recurrent Neural Network (SRNN).

  Introduced in "Sequential Neural Models with Stochastic Layers"
  by Fraccaro et al. (https://arxiv.org/pdf/1605.07571.pdf).

  The SRNN is a sequence model, a sequential analogue of the variational
  auto-encoder, that augments a deterministic RNN backbone (hidden states
  h_t) with stochastic latent variables z_t to increase representational
  power. The latents form a Markov chain, p(z_t | z_{1:t-1}) =
  p(z_t | z_{t-1}), and are conditioned on the RNN states; unlike the VRNN,
  the RNN state is never conditioned on the latents.

  Here z_t is Gaussian: the prior/transition p(z_t | h_t, z_{t-1}) is a
  Normal(mu_t, diag(sigma_t^2)) whose parameters come from a fully
  connected network taking h_t and z_{t-1}. The emission p(x_t | z_t, h_t)
  likewise conditions on the latent and RNN state via a fully connected
  network. Two extra feature-extractor networks, data_encoder and
  latent_encoder, encode the data and the latent state before use.

  For an example of how to call the SRNN's methods see sample_step.

  Naming differs slightly from the paper, to match the VRNN code: the
  paper's backward posterior RNN g_phi_a is the rev_rnn_cell, the forward
  conditioning RNN d is the rnn_cell, and the paper's unnamed feature
  extractors are the data_encoder and latent_encoder.
  """

  def __init__(self,
               rnn_cell,
               data_encoder,
               latent_encoder,
               transition,
               emission,
               random_seed=None):
    """Create a SRNN.

    Args:
      rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that forms the
        deterministic backbone of the SRNN. Its inputs are the encoded data
        for the current timestep, a [batch_size, encoded_data_size] Tensor.
      data_encoder: A callable that 'encodes' a batch of data x_t, e.g.
        runs it through a fully connected network; maps a
        [batch_size, data_size] Tensor to [batch_size, encoded_data_size].
        It is called multiple times in the SRNN cell, so use a module that
        handles variable scoping correctly (snt.nets.MLP is recommended) or
        multiple variable copies may be created.
      latent_encoder: A callable that 'encodes' a latent state z_t; maps a
        [batch_size, latent_size] Tensor to
        [batch_size, encoded_latent_size] and must define the property
        'output_size', equal to encoded_latent_size.
      transition: A callable implementing the transition distribution
        p(z_t|h_t, z_t-1). Given the previous RNN hidden state and previous
        encoded latent state, it returns a tf.distributions.Normal
        conditioned on those inputs.
      emission: A callable implementing the emission distribution
        p(x_t|z_t, h_t). Given the encoded latent state and the RNN hidden
        state, it returns a subclass of tf.distributions.Distribution that
        can evaluate the logprob of the targets.
      random_seed: The seed for the random ops. Sets the seed for
        sample_step.
    """
    self.random_seed = random_seed
    self.rnn_cell = rnn_cell
    self.data_encoder = data_encoder
    self.latent_encoder = latent_encoder
    self.encoded_z_size = latent_encoder.output_size
    self.state_size = self.rnn_cell.state_size
    self._transition = transition
    self._emission = emission

  def zero_state(self, batch_size, dtype):
    """The initial state of the SRNN.

    Contains the initial RNN state and an all-zero encoded latent.

    Args:
      batch_size: The batch size.
      dtype: The data type of the SRNN.
    Returns:
      zero_state: The initial state of the SRNN.
    """
    initial_latent = tf.zeros(
        [batch_size, self.latent_encoder.output_size], dtype=dtype)
    return SRNNState(rnn_state=self.rnn_cell.zero_state(batch_size, dtype),
                     latent_encoded=initial_latent)

  def run_rnn(self, prev_rnn_state, inputs):
    """Runs the deterministic RNN for one step.

    Args:
      prev_rnn_state: The state of the RNN from the previous timestep.
      inputs: A Tensor of shape [batch_size, data_size], the current inputs
        to the model. Most often this is x_{t-1}, the previous token in the
        observation sequence.
    Returns:
      rnn_out: The output of the RNN.
      rnn_state: The new state of the RNN.
    """
    encoded_inputs = self.data_encoder(tf.to_float(inputs))
    return self.rnn_cell(encoded_inputs, prev_rnn_state)

  def transition(self, rnn_out, prev_latent_encoded):
    """Computes the transition distribution p(z_t|h_t, z_{t-1}).

    Note that p(z_t | h_t, z_{t-1}) = p(z_t| z_{t-1}, x_{1:t-1}).

    Args:
      rnn_out: The output of the rnn for the current timestep.
      prev_latent_encoded: Float Tensor of shape
        [batch_size, encoded_latent_size], the previous latent state
        z_{t-1} run through latent_encoder.
    Returns:
      p(z_t | h_t): A normal distribution with event shape
        [batch_size, latent_size].
    """
    return self._transition(rnn_out, prev_latent_encoded)

  def emission(self, latent, rnn_out):
    """Computes the emission distribution p(x_t | z_t, h_t).

    Note that p(x_t | z_t, h_t) = p(x_t | z_t, x_{1:t-1}).

    Args:
      latent: The stochastic latent state z_t.
      rnn_out: The output of the rnn for the current timestep.
    Returns:
      p(x_t | z_t, h_t): A distribution with event shape
        [batch_size, data_size].
      latent_encoded: The latent state encoded with latent_encoder. Should
        be passed to transition() on the next timestep.
    """
    encoded_latent = self.latent_encoder(latent)
    return self._emission(encoded_latent, rnn_out), encoded_latent

  def sample_step(self, prev_state, inputs, unused_t):
    """Samples one output from the model.

    Args:
      prev_state: The previous state of the model, a SRNNState containing
        the previous rnn state and the previous encoded latent.
      inputs: A Tensor of shape [batch_size, data_size], the current inputs
        to the model. Most often this is x_{t-1}, the previous token in the
        observation sequence.
      unused_t: The current timestep. Not used currently.
    Returns:
      new_state: The next state of the model, a SRNNState.
      xt: A float Tensor of shape [batch_size, data_size], an output
        sampled from the emission distribution.
    """
    rnn_out, next_rnn_state = self.run_rnn(prev_state.rnn_state, inputs)
    # Sample the latent from the prior, then the observation given it.
    prior = self.transition(rnn_out, prev_state.latent_encoded)
    zt = prior.sample(seed=self.random_seed)
    emission_dist, encoded_zt = self.emission(zt, rnn_out)
    xt = emission_dist.sample(seed=self.random_seed)
    next_state = SRNNState(rnn_state=next_rnn_state,
                           latent_encoded=encoded_zt)
    return next_state, tf.to_float(xt)
# pylint: disable=invalid-name
# pylint thinks this is a top-level constant.
# The trainable SRNN's state: the sampling SRNNState fields plus the forward
# RNN output for the current timestep.
TrainableSRNNState = namedtuple("TrainableSRNNState",
                                SRNNState._fields + ("rnn_out",))
# pylint: enable=invalid-name
class TrainableSRNN(SRNN, base.ELBOTrainableSequenceModel):
"""A SRNN subclass with proposals and methods for training and evaluation.
This class adds proposals used for training with importance-sampling based
methods such as the ELBO. The model can be configured to propose from one
of three proposals: a learned filtering proposal, a learned smoothing
proposal, or the prior (i.e. the transition distribution).
As described in the SRNN paper, the learned filtering proposal is
parameterized by a fully connected neural network that accepts as input the
current target x_t and the current rnn output h_t. The learned smoothing
proposal is also given the hidden state of an RNN run in reverse over the
inputs, so as to incorporate information about future observations.
All learned proposals use the 'res_q' parameterization, meaning that instead
of directly producing the mean of z_t, the proposal network predicts the
'residual' from the prior's mean. This is explored more in section 3.3 of
https://arxiv.org/pdf/1605.07571.pdf.
During training, the latent state z_t is sampled from the proposal and the
reparameterization trick is used to provide low-variance gradients.
Note that the SRNN paper refers to the proposals as the approximate posterior,
but we match the VRNN convention of referring to it as the encoder.
"""
def __init__(self,
rnn_cell,
data_encoder,
latent_encoder,
transition,
emission,
proposal_type,
proposal=None,
rev_rnn_cell=None,
tilt=None,
random_seed=None):
"""Create a trainable RNN.
Args:
rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
deterministic backbone of the SRNN. The inputs to the RNN will be the
the encoded input of the current timestep, a Tensor of shape
[batch_size, encoded_data_size].
data_encoder: A callable that accepts a batch of data x_t and
'encodes' it, e.g. runs it through a fully connected network. Must
accept as argument the inputs x_t, a Tensor of the shape
[batch_size, data_size] and return a Tensor of shape
[batch_size, encoded_data_size]. This callable will be called multiple
times in the SRNN cell so if scoping is not handled correctly then
multiple copies of the variables in this network could be made. It is
recommended to use a snt.nets.MLP module, which takes care of this for
you.
latent_encoder: A callable that accepts a latent state z_t and
'encodes' it, e.g. runs it through a fully connected network. Must
accept as argument a Tensor of shape [batch_size, latent_size] and
return a Tensor of shape [batch_size, encoded_latent_size].
This callable must also have the property 'output_size' defined,
returning encoded_latent_size.
transition: A callable that implements the transition distribution
p(z_t|h_t, z_t-1). Must accept as argument the previous RNN hidden state
and previous encoded latent state then return a tf.distributions.Normal
distribution conditioned on the input.
emission: A callable that implements the emission distribution
p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
and the RNN hidden state and return a subclass of
tf.distributions.Distribution that can be used to evaluate the logprob
of the targets.
proposal_type: A string indicating the type of proposal to use. Can
be either "filtering", "smoothing", or "prior". When proposal_type is
"filtering" or "smoothing", proposal must be provided. When
proposal_type is "smoothing", rev_rnn_cell must also be provided.
proposal: A callable that implements the proposal q(z_t| h_t, x_{1:T}).
If proposal_type is "filtering" then proposal must accept as arguments
the current rnn output, the encoded target of the current timestep,
and the mean of the prior. If proposal_type is "smoothing" then
in addition to the current rnn output and the mean of the prior
proposal must accept as arguments the output of the reverse rnn.
proposal should return a tf.distributions.Normal distribution
conditioned on its inputs. If proposal_type is "prior" this argument is
ignored.
rev_rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will aggregate
forward rnn outputs in the reverse direction. The inputs to the RNN
will be the encoded reverse input of the current timestep, a Tensor of
shape [batch_size, encoded_data_size].
tilt: A callable that implements the log of a positive tilting function
(ideally approximating log p(x_{t+1}|z_t, h_t). Must accept as arguments
the encoded latent state and the RNN hidden state and return a subclass
of tf.distributions.Distribution that can be used to evaluate the
logprob of x_{t+1}. Optionally, None and then no tilt is used.
random_seed: The seed for the random ops. Sets the seed for sample_step
and __call__.
"""
super(TrainableSRNN, self).__init__(
rnn_cell, data_encoder, latent_encoder,
transition, emission, random_seed=random_seed)
self.rev_rnn_cell = rev_rnn_cell
self._tilt = tilt
assert proposal_type in ["filtering", "smoothing", "prior"]
self._proposal = proposal
self.proposal_type = proposal_type
if proposal_type != "prior":
assert proposal, "If not proposing from the prior, must provide proposal."
if proposal_type == "smoothing":
assert rev_rnn_cell, "Must provide rev_rnn_cell for smoothing proposal."
def zero_state(self, batch_size, dtype):
super_state = super(TrainableSRNN, self).zero_state(batch_size, dtype)
return TrainableSRNNState(
rnn_out=tf.zeros([batch_size, self.rnn_cell.output_size], dtype=dtype),
**super_state._asdict())
  def set_observations(self, observations, seq_lengths):
    """Stores the model's observations.

    Stores the observations (inputs and targets) in TensorArrays and precomputes
    things for later like the reverse RNN output and encoded targets.

    Args:
      observations: The observations of the model, a tuple containing two
        Tensors of shape [max_seq_len, batch_size, data_size]. The Tensors
        should be the inputs and targets, respectively.
      seq_lengths: An int Tensor of shape [batch_size] containing the length
        of each sequence in observations.
    """
    inputs, targets = observations
    self.seq_lengths = seq_lengths
    self.max_seq_len = tf.reduce_max(seq_lengths)
    self.targets_ta = base.ta_for_tensor(targets, clear_after_read=False)
    targets_encoded = base.encode_all(targets, self.data_encoder)
    self.targets_encoded_ta = base.ta_for_tensor(targets_encoded,
                                                 clear_after_read=False)
    inputs_encoded = base.encode_all(inputs, self.data_encoder)
    # Precompute the forward RNN outputs for all timesteps in one pass.
    rnn_out, _ = tf.nn.dynamic_rnn(self.rnn_cell,
                                   inputs_encoded,
                                   time_major=True,
                                   dtype=tf.float32,
                                   scope="forward_rnn")
    self.rnn_ta = base.ta_for_tensor(rnn_out,
                                     clear_after_read=False)
    if self.rev_rnn_cell:
      # The reverse RNN (used by the smoothing proposal) consumes the
      # forward RNN outputs concatenated with the encoded targets, run
      # backwards over each sequence.
      targets_and_rnn_out = tf.concat([rnn_out, targets_encoded], 2)
      reversed_targets_and_rnn_out = tf.reverse_sequence(
          targets_and_rnn_out, seq_lengths, seq_axis=0, batch_axis=1)
      # Compute the reverse rnn over the targets.
      reverse_rnn_out, _ = tf.nn.dynamic_rnn(self.rev_rnn_cell,
                                             reversed_targets_and_rnn_out,
                                             time_major=True,
                                             dtype=tf.float32,
                                             scope="reverse_rnn")
      # Flip the reverse RNN outputs back to forward time order so they can
      # be read by timestep index t.
      reverse_rnn_out = tf.reverse_sequence(reverse_rnn_out, seq_lengths,
                                            seq_axis=0, batch_axis=1)
      self.reverse_rnn_ta = base.ta_for_tensor(reverse_rnn_out,
                                               clear_after_read=False)
def _filtering_proposal(self, rnn_out, prev_latent_encoded, prior, t):
"""Computes the filtering proposal distribution."""
return self._proposal(rnn_out,
prev_latent_encoded,
self.targets_encoded_ta.read(t),
prior_mu=prior.mean())
def _smoothing_proposal(self, rnn_out, prev_latent_encoded, prior, t):
"""Computes the smoothing proposal distribution."""
return self._proposal(rnn_out,
prev_latent_encoded,
smoothing_tensors=[self.reverse_rnn_ta.read(t)],
prior_mu=prior.mean())
def proposal(self, rnn_out, prev_latent_encoded, prior, t):
"""Computes the proposal distribution specified by proposal_type.
Args:
rnn_out: The output of the rnn for the current timestep.
prev_latent_encoded: Float Tensor of shape
[batch_size, encoded_latent_size], the previous latent state z_{t-1}
run through latent_encoder.
prior: A tf.distributions.Normal distribution representing the prior
over z_t, p(z_t | z_{1:t-1}, x_{1:t-1}). Used for 'res_q'.
t: A scalar int Tensor, the current timestep.
"""
if self.proposal_type == "filtering":
return self._filtering_proposal(rnn_out, prev_latent_encoded, prior, t)
elif self.proposal_type == "smoothing":
return self._smoothing_proposal(rnn_out, prev_latent_encoded, prior, t)
elif self.proposal_type == "prior":
return self.transition(rnn_out, prev_latent_encoded)
def tilt(self, rnn_out, latent_encoded, targets):
r_func = self._tilt(rnn_out, latent_encoded)
return tf.reduce_sum(r_func.log_prob(targets), axis=-1)
  def propose_and_weight(self, state, t):
    """Runs the model and computes importance weights for one timestep.

    Runs the model and computes importance weights, sampling from the proposal
    instead of the transition/prior.

    Args:
      state: The previous state of the model, a TrainableSRNNState containing
        the previous rnn state, the previous rnn outs, and the previous encoded
        latent.
      t: A scalar integer Tensor, the current timestep.

    Returns:
      weights: A float Tensor of shape [batch_size].
      new_state: The new state of the model.
    """
    targets = self.targets_ta.read(t)
    rnn_out = self.rnn_ta.read(t)
    # Prior over z_t given the RNN output and the encoded previous latent.
    p_zt = self.transition(rnn_out, state.latent_encoded)
    # Proposal q(z_t | ...); its exact form depends on self.proposal_type.
    q_zt = self.proposal(rnn_out, state.latent_encoded, p_zt, t)
    zt = q_zt.sample(seed=self.random_seed)
    p_xt_given_zt, latent_encoded = self.emission(zt, rnn_out)
    # Sum log-probs over the data dimension to get per-example values.
    log_p_xt_given_zt = tf.reduce_sum(p_xt_given_zt.log_prob(targets), axis=-1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=-1)
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=-1)
    # Standard importance weight: log p(z_t) + log p(x_t|z_t) - log q(z_t).
    weights = log_p_zt + log_p_xt_given_zt - log_q_zt
    if self._tilt:
      # Tilting: add log r for the next observation at this step and remove
      # the log r that was added at the previous step.
      prev_log_r = tf.cond(
          tf.greater(t, 0),
          lambda: self.tilt(state.rnn_out, state.latent_encoded, targets),
          lambda: 0.)  # On the first step, prev_log_r = 0.
      log_r = tf.cond(
          tf.less(t + 1, self.max_seq_len),
          lambda: self.tilt(rnn_out, latent_encoded, self.targets_ta.read(t+1)),
          lambda: 0.)
      # On the last step, log_r = 0. Zero out log_r per-example for sequences
      # that have already ended.
      log_r *= tf.to_float(t < self.seq_lengths - 1)
      weights += log_r - prev_log_r
    # This reshape is required because the TensorArray reports different shapes
    # than the initial state provides (where the first dimension is unknown).
    # The difference breaks the while_loop. Reshape prevents the error.
    rnn_out = tf.reshape(rnn_out, tf.shape(state.rnn_out))
    new_state = TrainableSRNNState(rnn_out=rnn_out,
                                   rnn_state=state.rnn_state,  # unmodified
                                   latent_encoded=latent_encoded)
    return weights, new_state
# Default variable initializers: Xavier for weights, zeros for biases.
_DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
                         "b": tf.zeros_initializer()}
def create_srnn(
    data_size,
    latent_size,
    emission_class,
    rnn_hidden_size=None,
    fcnet_hidden_sizes=None,
    encoded_data_size=None,
    encoded_latent_size=None,
    sigma_min=0.0,
    raw_sigma_bias=0.25,
    emission_bias_init=0.0,
    use_tilt=False,
    proposal_type="filtering",
    initializers=None,
    random_seed=None):
  """A factory method for creating SRNN cells.

  Args:
    data_size: The dimension of the vectors that make up the data sequences.
    latent_size: The size of the stochastic latent state of the SRNN.
    emission_class: The class of the emission distribution. Can be either
      ConditionalNormalDistribution or ConditionalBernoulliDistribution.
    rnn_hidden_size: The hidden state dimension of the RNN that forms the
      deterministic part of this SRNN. If None, then it defaults
      to latent_size.
    fcnet_hidden_sizes: A list of python integers, the size of the hidden
      layers of the fully connected networks that parameterize the conditional
      distributions of the SRNN. If None, then it defaults to one hidden
      layer of size latent_size.
    encoded_data_size: The size of the output of the data encoding network. If
      None, defaults to latent_size.
    encoded_latent_size: The size of the output of the latent state encoding
      network. If None, defaults to latent_size.
    sigma_min: The minimum value that the standard deviation of the
      distribution over the latent state can take.
    raw_sigma_bias: A scalar that is added to the raw standard deviation
      output from the neural networks that parameterize the prior and
      approximate posterior. Useful for preventing standard deviations close
      to zero.
    emission_bias_init: A bias to added to the raw output of the fully
      connected network that parameterizes the emission distribution. Useful
      for initalizing the mean of the distribution to a sensible starting point
      such as the mean of the training data. Only used with Bernoulli generative
      distributions.
    use_tilt: If true, create a SRNN with a tilting function.
    proposal_type: The type of proposal to use. Can be "filtering", "smoothing",
      or "prior".
    initializers: The variable intitializers to use for the fully connected
      networks and RNN cell. Must be a dictionary mapping the keys 'w' and 'b'
      to the initializers for the weights and biases. Defaults to xavier for
      the weights and zeros for the biases when initializers is None.
    random_seed: A random seed for the SRNN resampling operations.

  Returns:
    model: A TrainableSRNN object.

  Raises:
    ValueError: If proposal_type is not "filtering", "smoothing", or "prior".
  """
  # Validate early: an unknown proposal_type previously produced a model with
  # proposal=None that failed obscurely at sample time.
  if proposal_type not in ["filtering", "smoothing", "prior"]:
    raise ValueError("Unknown proposal_type: %s" % proposal_type)
  if rnn_hidden_size is None:
    rnn_hidden_size = latent_size
  if fcnet_hidden_sizes is None:
    fcnet_hidden_sizes = [latent_size]
  if encoded_data_size is None:
    encoded_data_size = latent_size
  if encoded_latent_size is None:
    encoded_latent_size = latent_size
  if initializers is None:
    initializers = _DEFAULT_INITIALIZERS
  data_encoder = snt.nets.MLP(
      output_sizes=fcnet_hidden_sizes + [encoded_data_size],
      initializers=initializers,
      name="data_encoder")
  latent_encoder = snt.nets.MLP(
      output_sizes=fcnet_hidden_sizes + [encoded_latent_size],
      initializers=initializers,
      name="latent_encoder")
  transition = base.ConditionalNormalDistribution(
      size=latent_size,
      hidden_layer_sizes=fcnet_hidden_sizes,
      sigma_min=sigma_min,
      raw_sigma_bias=raw_sigma_bias,
      initializers=initializers,
      name="prior")
  # Construct the emission distribution.
  if emission_class == base.ConditionalBernoulliDistribution:
    # For Bernoulli distributed outputs, we initialize the bias so that the
    # network generates on average the mean from the training set.
    emission_dist = functools.partial(base.ConditionalBernoulliDistribution,
                                      bias_init=emission_bias_init)
  else:
    emission_dist = base.ConditionalNormalDistribution
  emission = emission_dist(
      size=data_size,
      hidden_layer_sizes=fcnet_hidden_sizes,
      initializers=initializers,
      name="generative")
  # Construct the proposal distribution. The "prior" proposal uses the
  # transition directly and needs no separate network.
  if proposal_type in ["filtering", "smoothing"]:
    proposal = base.NormalApproximatePosterior(
        size=latent_size,
        hidden_layer_sizes=fcnet_hidden_sizes,
        sigma_min=sigma_min,
        raw_sigma_bias=raw_sigma_bias,
        initializers=initializers,
        smoothing=(proposal_type == "smoothing"),
        name="approximate_posterior")
  else:
    proposal = None
  if use_tilt:
    tilt = emission_dist(
        size=data_size,
        hidden_layer_sizes=fcnet_hidden_sizes,
        initializers=initializers,
        name="tilt")
  else:
    tilt = None
  rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
                                     initializer=initializers["w"])
  rev_rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
                                         initializer=initializers["w"])
  return TrainableSRNN(
      rnn_cell, data_encoder, latent_encoder, transition,
      emission, proposal_type, proposal=proposal, rev_rnn_cell=rev_rnn_cell,
      tilt=tilt, random_seed=random_seed)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.models.srnn."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from fivo.models import base
from fivo.test_utils import create_srnn
class SrnnTest(tf.test.TestCase):

  def test_srnn_normal_emission(self):
    self.run_srnn(base.ConditionalNormalDistribution, [-5.947752, -1.182961])

  def test_srnn_bernoulli_emission(self):
    self.run_srnn(base.ConditionalBernoulliDistribution, [-2.566631, -2.479234])

  def run_srnn(self, generative_class, gt_log_alpha):
    """Tests one step of an SRNN without a tilting function.

    All test values are 'golden values' derived by running the code and copying
    the output.

    Args:
      generative_class: The class of the generative distribution to use.
      gt_log_alpha: The ground-truth value of log alpha.
    """
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      batch_size = 2
      model, inputs, targets, _ = create_srnn(generative_class=generative_class,
                                              batch_size=batch_size,
                                              data_lengths=(1, 1),
                                              random_seed=1234)
      initial_state = model.zero_state(batch_size=batch_size, dtype=tf.float32)
      model.set_observations([inputs, targets], tf.convert_to_tensor([1, 1]))
      outputs = model.propose_and_weight(initial_state, 0)
      sess.run(tf.global_variables_initializer())
      log_alpha, state = sess.run(outputs)
      self.assertAllClose(
          state.latent_encoded,
          [[0.591787, 1.310583], [-1.523136, 0.953918]])
      self.assertAllClose(state.rnn_out,
                          [[0.041675, -0.056038, -0.001823, 0.005224],
                           [0.042925, -0.044619, 0.021401, 0.016998]])
      self.assertAllClose(log_alpha, gt_log_alpha)

  def test_srnn_with_tilt_normal_emission(self):
    self.run_srnn_with_tilt(base.ConditionalNormalDistribution, [-9.13577, -4.56725])

  def test_srnn_with_tilt_bernoulli_emission(self):
    self.run_srnn_with_tilt(base.ConditionalBernoulliDistribution, [-4.617461, -5.079248])

  def run_srnn_with_tilt(self, generative_class, gt_log_alpha):
    """Tests one step of an SRNN with a tilting function.

    All test values are 'golden values' derived by running the code and copying
    the output.

    Args:
      generative_class: The class of the generative distribution to use.
      gt_log_alpha: The ground-truth value of log alpha.
    """
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      batch_size = 2
      model, inputs, targets, _ = create_srnn(generative_class=generative_class,
                                              batch_size=batch_size,
                                              data_lengths=(3, 2),
                                              random_seed=1234,
                                              use_tilt=True)
      initial_state = model.zero_state(batch_size=batch_size, dtype=tf.float32)
      model.set_observations([inputs, targets], tf.convert_to_tensor([3, 2]))
      outputs = model.propose_and_weight(initial_state, 0)
      sess.run(tf.global_variables_initializer())
      log_alpha, state = sess.run(outputs)
      self.assertAllClose(
          state.latent_encoded,
          [[0.591787, 1.310583], [-1.523136, 0.953918]])
      self.assertAllClose(state.rnn_out,
                          [[0.041675, -0.056038, -0.001823, 0.005224],
                           [0.042925, -0.044619, 0.021401, 0.016998]])
      self.assertAllClose(log_alpha, gt_log_alpha)
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment