Unverified Commit a8ba923c authored by Jaeyoun Kim's avatar Jaeyoun Kim Committed by GitHub
Browse files

Deprecate old models (#8934)

Deprecate old models
parent 5eb294f8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import tensorflow as tf
import summary_utils as summ
# A named loss paired with the variable collection it should be optimized
# over:
#   name: string name for the loss.
#   loss: scalar Tensor to minimize.
#   vars: variable collection name (or list of variables) to differentiate
#     against; defaults to the trainable-variables collection key.
Loss = namedtuple("Loss", "name loss vars")
Loss.__new__.__defaults__ = (tf.GraphKeys.TRAINABLE_VARIABLES,)
def iwae(model, observation, num_timesteps, num_samples=1,
         summarize=False):
  """Compute the IWAE evidence lower bound.

  Args:
    model: A callable that computes one timestep of the model.
    observation: A shape [batch_size*num_samples, state_size] Tensor
      containing z_n, the observation for each sequence in the batch.
    num_timesteps: The number of timesteps in each sequence, an integer.
    num_samples: The number of samples to use to compute the IWAE bound.
    summarize: If True, emit a per-timestep scalar summary of the entropy of
      the normalized importance weights.
  Returns:
    log_p_hat: The IWAE estimator of the lower bound on the log marginal,
      averaged over timesteps.
    losses: A list containing a single Loss for optimizing the bound.
    maintain_ema_op: A no-op included for compatibility with FIVO.
    states: The sequence of states sampled.
    log_weights: The accumulated log weight Tensor from each timestep.
  """
  # Initialization
  num_instances = tf.shape(observation)[0]
  batch_size = tf.cast(num_instances / num_samples, tf.int32)
  states = [model.zero_state(num_instances)]
  log_weights = []
  log_weight_acc = tf.zeros([num_samples, batch_size], dtype=observation.dtype)
  # range (not the Python 2-only xrange) keeps this Python 3 compatible.
  for t in range(num_timesteps):
    # Run the model for one timestep: sample z_t and score it under the
    # proposal q, the prior, and the likelihood.
    (zt, log_q_zt, log_p_zt, log_p_x_given_z, _) = model(
        states[-1], observation, t)
    # update accumulators
    states.append(zt)
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
    if summarize:
      # Entropy of the normalized weights; low entropy indicates weight
      # degeneracy.
      weight_dist = tf.contrib.distributions.Categorical(
          logits=tf.transpose(log_weight_acc, perm=[1, 0]),
          allow_nan_stats=False)
      weight_entropy = weight_dist.entropy()
      weight_entropy = tf.reduce_mean(weight_entropy)
      tf.summary.scalar("weight_entropy/%d" % t, weight_entropy)
    log_weights.append(log_weight_acc)
  # Compute the lower bound on the log evidence: logsumexp over the sample
  # dimension minus log(num_samples), averaged per timestep.
  log_p_hat = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
               tf.log(tf.cast(num_samples, observation.dtype))) / num_timesteps
  loss = -tf.reduce_mean(log_p_hat)
  losses = [Loss("log_p_hat", loss)]
  # We clip off the initial state before returning.
  # There are no emas for iwae, so we return a no-op for that.
  return log_p_hat, losses, tf.no_op(), states[1:], log_weights
def multinomial_resampling(log_weights, states, n, b):
  """Resample states with multinomial resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distribution.
    states: A list of (b*n x d) Tensors that will be resample in from the groups
      of every n-th row.
    n: The number of particles (samples) per batch element.
    b: The batch size.
  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via multinomial
      sampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
    resampling_parameters: The Tensor of parameters of the resampling distribution.
    ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
    resampling_dist: The distribution object for resampling.
  """
  log_weights = tf.convert_to_tensor(log_weights)
  states = [tf.convert_to_tensor(state) for state in states]
  # The Categorical wants a batch of b distributions over n outcomes, so
  # transpose the (n x b) logits to (b x n).
  resampling_parameters = tf.transpose(log_weights, perm=[1,0])
  resampling_dist = tf.contrib.distributions.Categorical(logits=resampling_parameters)
  # Ancestry indices carry no gradient.
  ancestors = tf.stop_gradient(
      resampling_dist.sample(sample_shape=n))
  log_probs = resampling_dist.log_prob(ancestors)
  # States are laid out sample-major: row (sample_idx * b + batch_idx), so
  # flat gather indices are ancestor * b + batch offset.
  offset = tf.expand_dims(tf.range(b), 0)
  ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
  resampled_states = []
  for state in states:
    resampled_states.append(tf.gather(state, ancestor_inds))
  return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def stratified_resampling(log_weights, states, n, b):
  """Resample states with stratified resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distribution.
    states: A list of (b*n x d) Tensors that will be resample in from the groups
      of every n-th row.
    n: The number of particles (samples) per batch element.
    b: The batch size.
  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
    resampling_parameters: The Tensor of parameters of the resampling distribution.
    ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
    resampling_dist: The distribution object for resampling.
  """
  log_weights = tf.convert_to_tensor(log_weights)
  states = [tf.convert_to_tensor(state) for state in states]
  log_weights = tf.transpose(log_weights, perm=[1,0])
  # Broadcast the (b x n) weight distribution to one copy per stratum,
  # giving a (b x n x n) tensor of probabilities.
  probs = tf.nn.softmax(
      tf.tile(tf.expand_dims(log_weights, axis=1),
              [1, n, 1])
  )
  # CDF of the weights with a leading zero column: (b x n x n+1).
  cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
  # Left endpoints of the n equal strata [i/n, (i+1)/n).
  bins = tf.range(n, dtype=probs.dtype) / n
  bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
  # Restrict each CDF to its stratum and rescale it to [0, 1].
  strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
  # Per-stratum categorical probabilities are differences of adjacent
  # stratified CDF values.
  resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
  resampling_dist = tf.contrib.distributions.Categorical(
      probs = resampling_parameters,
      allow_nan_stats=False)
  # Ancestry indices carry no gradient.
  ancestors = tf.stop_gradient(
      resampling_dist.sample())
  log_probs = resampling_dist.log_prob(ancestors)
  # Back to (n x b) layout to match the other resamplers.
  ancestors = tf.transpose(ancestors, perm=[1,0])
  log_probs = tf.transpose(log_probs, perm=[1,0])
  # States are laid out sample-major: row (sample_idx * b + batch_idx).
  offset = tf.expand_dims(tf.range(b), 0)
  ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
  resampled_states = []
  for state in states:
    resampled_states.append(tf.gather(state, ancestor_inds))
  return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def systematic_resampling(log_weights, states, n, b):
  """Resample states with systematic resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distribution.
    states: A list of (b*n x d) Tensors that will be resample in from the groups
      of every n-th row.
    n: The number of particles (samples) per batch element.
    b: The batch size.
  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via systematic sampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
    resampling_parameters: The Tensor of parameters of the resampling distribution.
    ancestors: An (n x b) Tensor of indices representing the ancestry decisions.
    resampling_dist: The distribution object for resampling.
  """
  log_weights = tf.convert_to_tensor(log_weights)
  states = [tf.convert_to_tensor(state) for state in states]
  log_weights = tf.transpose(log_weights, perm=[1,0])
  # Broadcast the (b x n) weight distribution to one copy per stratum,
  # giving a (b x n x n) tensor of probabilities.
  probs = tf.nn.softmax(
      tf.tile(tf.expand_dims(log_weights, axis=1),
              [1, n, 1])
  )
  # CDF of the weights with a leading zero column: (b x n x n+1).
  cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
  # Left endpoints of the n equal strata [i/n, (i+1)/n).
  bins = tf.range(n, dtype=probs.dtype) / n
  bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
  # Restrict each CDF to its stratum and rescale it to [0, 1].
  strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
  resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
  resampling_dist = tf.contrib.distributions.Categorical(
      probs=resampling_parameters,
      allow_nan_stats=True)
  # Systematic: a single uniform per batch element is shared across all
  # strata; the ancestor is the count of CDF values below it.
  U = tf.random_uniform((b, 1, 1), dtype=probs.dtype)
  ancestors = tf.stop_gradient(tf.reduce_sum(tf.to_float(U > strat_cdfs[:,:,1:]), axis=-1))
  log_probs = resampling_dist.log_prob(ancestors)
  # Back to (n x b) layout to match the other resamplers.
  ancestors = tf.transpose(ancestors, perm=[1,0])
  log_probs = tf.transpose(log_probs, perm=[1,0])
  # Ancestors are float here, so the batch offset is cast to match.
  offset = tf.expand_dims(tf.range(b, dtype=probs.dtype), 0)
  ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
  resampled_states = []
  for state in states:
    resampled_states.append(tf.gather(state, ancestor_inds))
  return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def log_blend(inputs, weights):
  """Blends state in the log space.

  Args:
    inputs: A set of scalar states, one for each particle in each particle filter.
      Should be [num_samples, batch_size].
    weights: A set of weights used to blend the state. Each set of weights
      should be of dimension [num_samples] (one weight for each previous particle).
      There should be one set of weights for each new particle in each particle filter.
      Thus the shape should be [num_samples, batch_size, num_samples] where
      the first axis indexes new particle and the last axis indexes old particles.
  Returns:
    blended: The blended states, a tensor of shape [num_samples, batch_size].
  """
  # Per-column max used to stabilize the weighted log-sum-exp below.
  raw_max = tf.reduce_max(inputs, axis=0, keepdims=True)
  # Sanitized copy of the max: non-finite entries are replaced by zero, and
  # the gradient path through the max is cut.
  my_max = tf.stop_gradient(
      tf.where(tf.is_finite(raw_max), raw_max, tf.zeros_like(raw_max))
  )
  # NOTE(review): the shift subtracts raw_max inside exp but adds back my_max;
  # the two differ only when raw_max is non-finite — confirm this asymmetry
  # is intended.
  # Don't ask.
  blended = tf.log(tf.einsum("ijk,kj->ij", weights, tf.exp(inputs - raw_max))) + my_max
  return blended
def relaxed_resampling(log_weights, states, num_samples, batch_size,
                       log_r_x=None, blend_type="log", temperature=0.5,
                       straight_through=False):
  """Resample states with relaxed resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distribution.
    states: A list of (b*n x d) Tensors that will be resample in from the groups
      of every n-th row.
    num_samples: n, the number of particles per batch element.
    batch_size: b, the batch size.
    log_r_x: An optional [num_samples*batch_size] Tensor of log r values to be
      blended alongside the states.
    blend_type: How to blend log_r_x: "log" uses log_blend, "linear" blends it
      linearly like a state.
    temperature: Positive temperature of the RelaxedOneHotCategorical.
    straight_through: If True, use hard one-hot choices on the forward pass
      and the soft relaxation on the backward pass.
  Returns:
    resampled_states: A list of (b*n x d) Tensors blended via relaxed sampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
    log_r_x: The blended log r values (or None if log_r_x was None).
    resampling_parameters: The Tensor of parameters of the resampling distribution.
    ancestors: An (n x b x n) Tensor of relaxed one hot representations of the ancestry decisions.
    resampling_dist: The distribution object for resampling.
  """
  assert blend_type in ["log", "linear"], "Blend type must be 'log' or 'linear'."
  log_weights = tf.convert_to_tensor(log_weights)
  states = [tf.convert_to_tensor(state) for state in states]
  state_dim = states[0].get_shape().as_list()[-1]
  # weights are num_samples by batch_size, so we transpose to get a
  # set of batch_size distributions over [0,num_samples).
  resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
  resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
      temperature,
      logits=resampling_parameters)
  # sample num_samples samples from the distribution, resulting in a
  # [num_samples, batch_size, num_samples] Tensor that represents a set of
  # [num_samples, batch_size] blending weights. The dimensions represent
  # [sample index, batch index, blending weight index]
  ancestors = resampling_dist.sample(sample_shape=num_samples)
  if straight_through:
    # Forward pass discrete choices, backwards pass soft choices
    hard_ancestor_indices = tf.argmax(ancestors, axis=-1)
    hard_ancestors = tf.one_hot(hard_ancestor_indices, num_samples,
                                dtype=ancestors.dtype)
    ancestors = tf.stop_gradient(hard_ancestors - ancestors) + ancestors
  log_probs = resampling_dist.log_prob(ancestors)
  if log_r_x is not None and blend_type == "log":
    # Blend log r in log space with the sampled weights.
    log_r_x = tf.reshape(log_r_x, [num_samples, batch_size])
    log_r_x = log_blend(log_r_x, ancestors)
    log_r_x = tf.reshape(log_r_x, [num_samples*batch_size])
  elif log_r_x is not None and blend_type == "linear":
    # If blend type is linear just add log_r to the states that will be blended
    # linearly.
    # NOTE(review): the reshape below uses state_dim from states[0]; this
    # appears to assume log_r_x has the same trailing dimension — verify for
    # state_dim != 1.
    states.append(log_r_x)
  # transpose the 'indices' to be [batch_index, blending weight index, sample index]
  ancestor_inds = tf.transpose(ancestors, perm=[1, 2, 0])
  resampled_states = []
  for state in states:
    # state is currently [num_samples * batch_size, state_dim] so we reshape
    # to [num_samples, batch_size, state_dim] and then transpose to
    # [batch_size, state_size, num_samples]
    state = tf.transpose(tf.reshape(state, [num_samples, batch_size, -1]), perm=[1, 2, 0])
    # state is now (batch_size, state_size, num_samples)
    # and ancestor is (batch index, blending weight index, sample index)
    # multiplying these gives a matrix of size [batch_size, state_size, num_samples]
    next_state = tf.matmul(state, ancestor_inds)
    # transpose the state to be [num_samples, batch_size, state_size]
    # and then reshape it to match the state format.
    next_state = tf.reshape(tf.transpose(next_state, perm=[2,0,1]), [num_samples*batch_size, state_dim])
    resampled_states.append(next_state)
  new_dist = tf.contrib.distributions.Categorical(
      logits=resampling_parameters)
  if log_r_x is not None and blend_type == "linear":
    # If blend type is linear pop off log_r that we added to the states.
    log_r_x = tf.squeeze(resampled_states[-1])
    resampled_states = resampled_states[:-1]
  return resampled_states, log_probs, log_r_x, resampling_parameters, ancestors, new_dist
def fivo(model,
         observation,
         num_timesteps,
         resampling_schedule,
         num_samples=1,
         use_resampling_grads=True,
         resampling_type="multinomial",
         resampling_temperature=0.5,
         aux=True,
         summarize=False):
  """Compute the FIVO evidence lower bound.

  Args:
    model: A callable that computes one timestep of the model.
    observation: A shape [batch_size*num_samples, state_size] Tensor
      containing z_n, the observation for each sequence in the batch.
    num_timesteps: The number of timesteps in each sequence, an integer.
    resampling_schedule: A list of booleans of length num_timesteps, contains
      True if a resampling should occur on a specific timestep.
    num_samples: The number of samples to use to compute the bound.
    use_resampling_grads: Whether or not to include the resampling gradients
      in loss.
    resampling_type: The type of resampling, one of "multinomial",
      "stratified", "systematic", "relaxed-logblend", "relaxed-linearblend",
      "relaxed-stateblend", or "relaxed-stateblend-st".
    resampling_temperature: A positive temperature only used for relaxed
      resampling.
    aux: If true, compute the FIVO-AUX bound.
    summarize: If true, emit summaries of the resampling learning signals.
  Returns:
    log_p_hat: The estimator of the lower bound on the log marginal, averaged
      over timesteps.
    losses: A list of Loss tuples to perform gradient descent on.
    maintain_baseline_op: An op to update the baseline ema used for the
      resampling gradients (a no-op when resampling gradients are unused).
    states: The sequence of states sampled.
    log_weights_all: The accumulated log weight Tensor from each timestep.
  """
  # Initialization
  num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
  batch_size = tf.cast(num_instances / num_samples, tf.int32)
  states = [model.zero_state(num_instances)]
  prev_state = states[0]
  log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
  prev_log_r_zt = tf.zeros([num_instances], dtype=observation.dtype)
  log_weights = []
  log_weights_all = []
  log_p_hats = []
  resampling_log_probs = []
  # range (not the Python 2-only xrange) keeps this Python 3 compatible.
  for t in range(num_timesteps):
    # run the model for one timestep
    (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_zt) = model(
        prev_state, observation, t)
    # update accumulators
    states.append(zt)
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    if aux:
      # The auxiliary bound telescopes log r: add the new log r and remove
      # the previous one. On the last step r is defined to be 1.
      if t == num_timesteps - 1:
        log_weight -= prev_log_r_zt
      else:
        log_weight += log_r_zt - prev_log_r_zt
        prev_log_r_zt = log_r_zt
    log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
    log_weights_all.append(log_weight_acc)
    if resampling_schedule[t]:
      # These objects will be resampled
      to_resample = [states[-1]]
      if aux and "relaxed" not in resampling_type:
        to_resample.append(prev_log_r_zt)
      # do the resampling
      if resampling_type == "multinomial":
        (resampled,
         resampling_log_prob,
         _, _, _) = multinomial_resampling(log_weight_acc,
                                           to_resample,
                                           num_samples,
                                           batch_size)
      elif resampling_type == "stratified":
        (resampled,
         resampling_log_prob,
         _, _, _) = stratified_resampling(log_weight_acc,
                                          to_resample,
                                          num_samples,
                                          batch_size)
      elif resampling_type == "systematic":
        (resampled,
         resampling_log_prob,
         _, _, _) = systematic_resampling(log_weight_acc,
                                          to_resample,
                                          num_samples,
                                          batch_size)
      elif "relaxed" in resampling_type:
        if aux:
          if resampling_type == "relaxed-logblend":
            # Blend log r in log space along with the states.
            (resampled,
             resampling_log_prob,
             prev_log_r_zt,
             _, _, _) = relaxed_resampling(log_weight_acc,
                                           to_resample,
                                           num_samples,
                                           batch_size,
                                           temperature=resampling_temperature,
                                           log_r_x=prev_log_r_zt,
                                           blend_type="log")
          elif resampling_type == "relaxed-linearblend":
            # Blend log r linearly, like a state.
            (resampled,
             resampling_log_prob,
             prev_log_r_zt,
             _, _, _) = relaxed_resampling(log_weight_acc,
                                           to_resample,
                                           num_samples,
                                           batch_size,
                                           temperature=resampling_temperature,
                                           log_r_x=prev_log_r_zt,
                                           blend_type="linear")
          elif resampling_type == "relaxed-stateblend":
            (resampled,
             resampling_log_prob,
             _, _, _, _) = relaxed_resampling(log_weight_acc,
                                              to_resample,
                                              num_samples,
                                              batch_size,
                                              temperature=resampling_temperature)
            # Calculate prev_log_r_zt from the post-resampling state
            prev_r_zt = model.r.r_xn(resampled[0], t)
            prev_log_r_zt = tf.reduce_sum(
                prev_r_zt.log_prob(observation), axis=[1])
          elif resampling_type == "relaxed-stateblend-st":
            (resampled,
             resampling_log_prob,
             _, _, _, _) = relaxed_resampling(log_weight_acc,
                                              to_resample,
                                              num_samples,
                                              batch_size,
                                              temperature=resampling_temperature,
                                              straight_through=True)
            # Calculate prev_log_r_zt from the post-resampling state
            prev_r_zt = model.r.r_xn(resampled[0], t)
            prev_log_r_zt = tf.reduce_sum(
                prev_r_zt.log_prob(observation), axis=[1])
        else:
          (resampled,
           resampling_log_prob,
           _, _, _, _) = relaxed_resampling(log_weight_acc,
                                            to_resample,
                                            num_samples,
                                            batch_size,
                                            temperature=resampling_temperature)
      # Sum the per-particle ancestry log probs to one [batch_size] term
      # per resampling event.
      resampling_log_probs.append(tf.reduce_sum(resampling_log_prob, axis=0))
      prev_state = resampled[0]
      if aux and "relaxed" not in resampling_type:
        # Squeeze out the extra dim potentially added by resampling.
        # prev_log_r_zt should always be [num_instances]
        prev_log_r_zt = tf.squeeze(resampled[1])
      # Update the log p hat estimate, taking a log sum exp over the sample
      # dimension. The appended tensor is [batch_size].
      log_p_hats.append(
          tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
              tf.cast(num_samples, dtype=observation.dtype)))
      # reset the weights
      log_weights.append(log_weight_acc)
      log_weight_acc = tf.zeros_like(log_weight_acc)
    else:
      prev_state = states[-1]
  # Compute the final weight update. If we just resampled this will be zero.
  final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
                  tf.log(tf.cast(num_samples, dtype=observation.dtype)))
  # If we ever resampled, then sum up the previous log p hat terms
  if len(log_p_hats) > 0:
    log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
  else:  # otherwise, log_p_hat only comes from the final update
    log_p_hat = final_update
  if use_resampling_grads and any(resampling_schedule):
    # compute the rewards
    # cumsum([a, b, c]) => [a, a+b, a+b+c]
    # learning signal at timestep t is
    # [sum from i=t+1 to T of log_p_hat_i for t=1:T]
    # so we will compute (sum from i=1 to T of log_p_hat_i)
    # and at timestep t will subtract off (sum from i=1 to t of log_p_hat_i)
    # rewards is a [num_resampling_events, batch_size] Tensor
    rewards = tf.stop_gradient(
        tf.expand_dims(log_p_hat, 0) - tf.cumsum(log_p_hats, axis=0))
    batch_avg_rewards = tf.reduce_mean(rewards, axis=1)
    # compute ema baseline.
    # centered_rewards is [num_resampling_events, batch_size]
    baseline_ema = tf.train.ExponentialMovingAverage(decay=0.94)
    maintain_baseline_op = baseline_ema.apply([batch_avg_rewards])
    baseline = tf.expand_dims(baseline_ema.average(batch_avg_rewards), 1)
    centered_rewards = rewards - baseline
    if summarize:
      summ.summarize_learning_signal(rewards, "rewards")
      summ.summarize_learning_signal(centered_rewards, "centered_rewards")
    # Score-function (REINFORCE-style) term for the resampling decisions.
    resampling_grads = tf.reduce_sum(
        tf.stop_gradient(centered_rewards) * resampling_log_probs, axis=0)
    losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps),
              Loss("resampling_grads", -tf.reduce_mean(resampling_grads)/num_timesteps)]
  else:
    losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps)]
    maintain_baseline_op = tf.no_op()
  log_p_hat /= num_timesteps
  # we clip off the initial state before returning.
  return log_p_hat, losses, maintain_baseline_op, states[1:], log_weights_all
def fivo_aux_td(
    model,
    observation,
    num_timesteps,
    resampling_schedule,
    num_samples=1,
    summarize=False):
  """Compute the FIVO-AUX evidence lower bound with a TD-trained r-tilde.

  Args:
    model: A callable that computes one timestep of the model, returning
      (zt, log_q_zt, log_p_zt, log_p_x_given_z, r_tilde_mu,
      r_tilde_sigma_sq, p_ztplus1). Must also expose model.zero_state and
      model.r_tilde.r_zt.
    observation: A shape [batch_size*num_samples, state_size] Tensor
      containing the observation for each sequence in the batch.
    num_timesteps: The number of timesteps in each sequence, an integer.
    resampling_schedule: A list of booleans of length num_timesteps, contains
      True if a resampling should occur on a specific timestep.
    num_samples: The number of samples to use to compute the bound.
    summarize: Kept for interface compatibility with fivo(); the bellman loss
      scalar summary is always emitted.
  Returns:
    log_p_hat: The estimator of the lower bound on the log marginal, averaged
      over timesteps.
    losses: A list of Loss tuples: the bound over the "LOG_P_HAT_VARS"
      collection and the bellman loss over the "R_TILDE_VARS" collection.
    maintain_ema_op: A no-op included for interface compatibility.
    states: The sequence of states sampled.
    log_weights_all: The accumulated log weight Tensor from each timestep.
  """
  # Initialization
  num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
  batch_size = tf.cast(num_instances / num_samples, tf.int32)
  states = [model.zero_state(num_instances)]
  prev_state = states[0]
  log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
  prev_log_r = tf.zeros([num_instances], dtype=observation.dtype)
  # must be pre-resampling
  log_rs = []
  # must be post-resampling
  r_tilde_params = [model.r_tilde.r_zt(states[0], observation, 0)]
  log_r_tildes = []
  log_p_xs = []
  # contains the weight at each timestep before resampling only on resampling timesteps
  log_weights = []
  # contains weight at each timestep before resampling
  log_weights_all = []
  log_p_hats = []
  # range (not the Python 2-only xrange) keeps this Python 3 compatible.
  for t in range(num_timesteps):
    # Run the model for one timestep.
    # zt is [num_instances, state_dim]; the log probs are [num_instances];
    # r_tilde_mu / r_tilde_sigma_sq are [num_instances, state_dim];
    # p_ztplus1 is a distribution over [num_instances, state_dim].
    (zt, log_q_zt, log_p_zt, log_p_x_given_z,
     r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) = model(prev_state, observation, t)
    # Compute the log weight without log r.
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    # Compute log r. On the last timestep r is defined to be 1, so log r = 0.
    if t == num_timesteps - 1:
      log_r = tf.zeros_like(prev_log_r)
    else:
      # Gaussian integral of r_tilde against the transition p(z_{t+1}|z_t),
      # summed over the state dimension.
      p_mu = p_ztplus1.mean()
      p_sigma_sq = p_ztplus1.variance()
      log_r = (tf.log(r_tilde_sigma_sq) -
               tf.log(r_tilde_sigma_sq + p_sigma_sq) -
               tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
      log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
    # Telescoping auxiliary term.
    log_weight += log_r - prev_log_r
    log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
    # Update accumulators
    states.append(zt)
    log_weights_all.append(log_weight_acc)
    log_p_xs.append(log_p_x_given_z)
    log_rs.append(log_r)
    # Score the new state under the previous timestep's r_tilde parameters,
    # giving a [num_instances] Tensor.
    prev_r_tilde_mu, prev_r_tilde_sigma_sq = r_tilde_params[-1]
    prev_log_r_tilde = -0.5*tf.reduce_sum(
        tf.square(zt - prev_r_tilde_mu)/prev_r_tilde_sigma_sq, axis=-1)
    log_r_tildes.append(prev_log_r_tilde)
    # optionally resample
    if resampling_schedule[t]:
      # These objects will be resampled; r_tilde params only exist for
      # timesteps before the last one.
      if t < num_timesteps - 1:
        to_resample = [zt, log_r, r_tilde_mu, r_tilde_sigma_sq]
      else:
        to_resample = [zt, log_r]
      (resampled,
       _, _, _, _) = multinomial_resampling(log_weight_acc,
                                            to_resample,
                                            num_samples,
                                            batch_size)
      prev_state = resampled[0]
      # Squeeze out the extra dim potentially added by resampling.
      # prev_log_r should always be [num_instances]
      prev_log_r = tf.squeeze(resampled[1])
      if t < num_timesteps - 1:
        r_tilde_params.append((resampled[2], resampled[3]))
      # Update the log p hat estimate, taking a log sum exp over the sample
      # dimension. The appended tensor is [batch_size].
      log_p_hats.append(
          tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
              tf.cast(num_samples, dtype=observation.dtype)))
      # reset the weights
      log_weights.append(log_weight_acc)
      log_weight_acc = tf.zeros_like(log_weight_acc)
    else:
      prev_state = zt
      prev_log_r = log_r
      if t < num_timesteps - 1:
        r_tilde_params.append((r_tilde_mu, r_tilde_sigma_sq))
  # Compute the final weight update. If we just resampled this will be zero.
  final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
                  tf.log(tf.cast(num_samples, dtype=observation.dtype)))
  # If we ever resampled, then sum up the previous log p hat terms
  if len(log_p_hats) > 0:
    log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
  else:  # otherwise, log_p_hat only comes from the final update
    log_p_hat = final_update
  # Compute the bellman loss. log p(x_t|z_t), log r_t, and log r_tilde_t are
  # stored for timestep t in row t-1; the last row of log_r is zeros because
  # r at timestep T (num_timesteps) is 1.
  log_p_x = tf.reshape(tf.stack(log_p_xs),
                       [num_timesteps, num_samples, batch_size])
  log_r = tf.reshape(tf.stack(log_rs),
                     [num_timesteps, num_samples, batch_size])
  log_r_tilde = tf.reshape(tf.stack(log_r_tildes),
                           [num_timesteps, num_samples, batch_size])
  # Per-timestep normalizer, averaged over the sample dimension.
  log_lambda = tf.reduce_mean(log_r_tilde - log_p_x - log_r, axis=1,
                              keepdims=True)
  bellman_sos = tf.reduce_mean(tf.square(
      log_r_tilde - tf.stop_gradient(log_lambda + log_p_x + log_r)), axis=[0, 1])
  bellman_loss = tf.reduce_mean(bellman_sos)/num_timesteps
  tf.summary.scalar("bellman_loss", bellman_loss)
  # Partition the trainable variables once: everything that is not an
  # r_tilde variable receives gradients from the bound.
  if len(tf.get_collection("LOG_P_HAT_VARS")) == 0:
    log_p_hat_collection = list(set(tf.trainable_variables()) -
                                set(tf.get_collection("R_TILDE_VARS")))
    for v in log_p_hat_collection:
      tf.add_to_collection("LOG_P_HAT_VARS", v)
  log_p_hat /= num_timesteps
  losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat), "LOG_P_HAT_VARS"),
            Loss("bellman_loss", bellman_loss, "R_TILDE_VARS")]
  return log_p_hat, losses, tf.no_op(), states[1:], log_weights_all
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import models
def make_long_chain_dataset(
    state_size=1,
    num_obs=5,
    steps_per_obs=3,
    variance=1.,
    observation_variance=1.,
    batch_size=4,
    num_samples=1,
    observation_type=models.STANDARD_OBSERVATION,
    transition_type=models.STANDARD_TRANSITION,
    fixed_observation=None,
    dtype="float32"):
  """Creates a long chain data generating process.

  Creates a tf.data.Dataset that provides batches of data from a long
  chain.

  Args:
    state_size: The dimension of the state space of the process.
    num_obs: The number of observations in the chain.
    steps_per_obs: The number of steps between each observation.
    variance: The variance of the normal transition distributions.
    observation_variance: The variance of the normal observation
      distributions.
    batch_size: The number of trajectories to include in each batch.
    num_samples: The number of replicas of each trajectory to include in each
      batch.
    observation_type: One of models.OBSERVATION_TYPES; how the latent state
      is mapped to the observation mean.
    transition_type: One of models.TRANSITION_TYPES; how the previous state
      is mapped to the next transition mean.
    fixed_observation: If not None, emit this constant value as every
      observation instead of sampling one.
    dtype: The numpy dtype name of the states and observations.
  Returns:
    dataset: A tf.data.Dataset that can be iterated over.
  """
  num_timesteps = num_obs * steps_per_obs

  def data_generator():
    """An infinite generator of latents and observations from the model."""
    while True:
      states = []
      observations = []
      # z0 ~ Normal(0, sqrt(variance)).
      states.append(
          np.random.normal(size=[state_size],
                           scale=np.sqrt(variance)).astype(dtype))
      # Start at 1 because we've already generated z0; go to num_timesteps+1
      # because we want to include the num_timesteps-th step.
      for t in range(1, num_timesteps+1):
        if transition_type == models.ROUND_TRANSITION:
          loc = np.round(states[-1])
        elif transition_type == models.STANDARD_TRANSITION:
          loc = states[-1]
        new_state = np.random.normal(size=[state_size],
                                     loc=loc,
                                     scale=np.sqrt(variance))
        states.append(new_state.astype(dtype))
        # Emit an observation every steps_per_obs steps.
        if t % steps_per_obs == 0:
          if fixed_observation is None:
            if observation_type == models.SQUARED_OBSERVATION:
              loc = np.square(states[-1])
            elif observation_type == models.ABS_OBSERVATION:
              loc = np.abs(states[-1])
            elif observation_type == models.STANDARD_OBSERVATION:
              loc = states[-1]
            new_obs = np.random.normal(
                size=[state_size],
                loc=loc,
                scale=np.sqrt(observation_variance)).astype(dtype)
          else:
            # Cast so the fixed observation matches the declared output
            # dtype (np.ones defaults to float64).
            new_obs = (np.ones([state_size]) * fixed_observation).astype(dtype)
          observations.append(new_obs)
      yield states, observations

  dataset = tf.data.Dataset.from_generator(
      data_generator,
      output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
      output_shapes=([num_timesteps+1, state_size], [num_obs, state_size]))
  dataset = dataset.repeat().batch(batch_size)

  def tile_batch(state, observation):
    # Replicate each trajectory num_samples times along the batch dimension.
    state = tf.tile(state, [num_samples, 1, 1])
    observation = tf.tile(observation, [num_samples, 1, 1])
    return state, observation

  dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
  return dataset
def make_dataset(bs=None,
                 state_size=1,
                 num_timesteps=10,
                 variance=1.,
                 prior_type="unimodal",
                 bimodal_prior_weight=0.5,
                 bimodal_prior_mean=1,
                 transition_type=models.STANDARD_TRANSITION,
                 fixed_observation=None,
                 batch_size=4,
                 num_samples=1,
                 dtype='float32'):
  """Creates a data generating process.

  Creates a tf.data.Dataset that provides batches of data.

  Args:
    bs: The parameters of the data generating process. If None, new bs are
      randomly generated.
    state_size: The dimension of the state space of the process.
    num_timesteps: The length of the state sequences in the process.
    variance: The variance of the normal distributions used at each timestep.
    prior_type: One of "unimodal", "nonlinear", or "bimodal"; selects the
      distribution of the initial state.
    bimodal_prior_weight: For the bimodal prior, the probability of drawing
      the mode at -bimodal_prior_mean.
    bimodal_prior_mean: For the bimodal prior, the magnitude of the two mode
      locations.
    transition_type: One of models.TRANSITION_TYPES; how the previous state
      is mapped to the next transition mean.
    fixed_observation: If not None, emit this constant value as the
      observation instead of the final state.
    batch_size: The number of trajectories to include in each batch.
    num_samples: The number of replicas of each trajectory to include in each
      batch.
    dtype: The numpy dtype name of the states and observations.
  Returns:
    bs: The true bs used to generate the data
    dataset: A tf.data.Dataset that can be iterated over.
  """
  if bs is None:
    bs = [np.random.uniform(size=[state_size]).astype(dtype)
          for _ in range(num_timesteps)]
  tf.logging.info("data generating process bs: %s",
                  np.array(bs).reshape(num_timesteps))

  def data_generator():
    """An infinite generator of latents and observations from the model."""
    while True:
      states = []
      if prior_type == "unimodal" or prior_type == "nonlinear":
        # Prior is Normal(0, sqrt(variance)).
        states.append(np.random.normal(size=[state_size],
                                       scale=np.sqrt(variance)).astype(dtype))
      elif prior_type == "bimodal":
        # Mixture of two normals centered at +/- bimodal_prior_mean.
        if np.random.uniform() > bimodal_prior_weight:
          loc = bimodal_prior_mean
        else:
          loc = -bimodal_prior_mean
        states.append(np.random.normal(size=[state_size],
                                       loc=loc,
                                       scale=np.sqrt(variance)
                                      ).astype(dtype))
      for t in range(num_timesteps):
        if transition_type == models.ROUND_TRANSITION:
          loc = np.round(states[-1])
        elif transition_type == models.STANDARD_TRANSITION:
          loc = states[-1]
        # Add bs[t] without `loc += bs[t]`: for the standard transition loc
        # aliases states[-1], and the in-place add would corrupt the stored
        # latent trajectory that is yielded below.
        loc = loc + bs[t]
        new_state = np.random.normal(size=[state_size],
                                     loc=loc,
                                     scale=np.sqrt(variance)).astype(dtype)
        states.append(new_state)
      if fixed_observation is None:
        observation = states[-1]
      else:
        observation = np.ones_like(states[-1]) * fixed_observation
      yield np.array(states[:-1]), observation

  dataset = tf.data.Dataset.from_generator(
      data_generator,
      output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
      output_shapes=([num_timesteps, state_size], [state_size]))
  dataset = dataset.repeat().batch(batch_size)

  def tile_batch(state, observation):
    # Replicate each trajectory num_samples times along the batch dimension.
    state = tf.tile(state, [num_samples, 1, 1])
    observation = tf.tile(observation, [num_samples, 1])
    return state, observation

  dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
  return np.array(bs), dataset
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import sonnet as snt
import tensorflow as tf
import numpy as np
import math
# How the generative model maps a latent state to an observation mean:
# z^2, |z|, or z itself.
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]
# How the transition mean is computed from the previous state:
# round(z_{t-1}) or z_{t-1} unchanged.
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
class Q(object):
  """Amortized variational posterior q(z_t | observation, z_{t-1}).

  Each timestep t has its own linear layer producing the proposal mean from
  the concatenated [observation, previous state], plus a free-standing raw
  sigma variable that is shared across the batch (tiled in q_zt).
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    """Creates the per-timestep mean layers and sigma variables.

    Args:
      state_size: Dimension of the latent state, an int.
      num_timesteps: Number of timesteps in the chain, an int.
      sigma_min: Floor applied to the variance after the softplus, keeping
        the Normal scale strictly positive.
      dtype: dtype of the created sigma variables.
      random_seed: Seed for the random uniform initializers.
      init_mu0_to_zero: If True, the t=0 mean layer is initialized to zeros.
      graph_collection_name: Graph collection every Q variable is added to,
        so training code can fetch q's variables separately.
    """
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    for t in xrange(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(
            {"w": tf.zeros_initializer, "b": tf.zeros_initializer})
      else:
        initializers.append(
            {"w": tf.random_uniform_initializer(seed=random_seed),
             "b": tf.zeros_initializer})

    # Route every variable created inside the snt.Linear modules into
    # graph_collection_name, appending each variable at most once.
    def custom_getter(getter, *args, **kwargs):
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    # Per-timestep linear layers computing the proposal mean.
    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers[t],
                   name="q_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in xrange(num_timesteps)
    ]
    # Per-timestep raw sigma parameters; softplus'd and floored in q_zt.
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(num_timesteps)
    ]

  def q_zt(self, observation, prev_state, t):
    """Returns the Normal proposal q(z_t | observation, z_{t-1}) for timestep t.

    Args:
      observation: [batch_size, state_size] Tensor.
      prev_state: [batch_size, state_size] Tensor, z_{t-1}.
      t: Current timestep, a Python int indexing the per-timestep parameters.
    """
    batch_size = tf.shape(prev_state)[0]
    q_mu = self.mus[t](tf.concat([observation, prev_state], axis=1))
    # softplus + floor keeps the variance positive; sqrt converts it to scale.
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    """Writes scalar summaries of the first element of each parameter."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0])
      if t != 0:
        tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[1,0])
class PreviousStateQ(Q):
  """Variational posterior q(z_t | z_{t-1}) that ignores the observation."""

  def q_zt(self, unused_observation, prev_state, t):
    """Returns the Normal proposal for timestep t, conditioned on prev_state only."""
    n = tf.shape(prev_state)[0]
    mean = self.mus[t](prev_state)
    # softplus + floor keeps the variance positive, tiled across the batch.
    variance = tf.tile(
        tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)[tf.newaxis, :],
        [n, 1])
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(variance))

  def summarize_weights(self):
    """Writes scalar summaries of the first element of each parameter."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, layer in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, layer.b[0])
      tf.summary.scalar("q_mu/w_prev_state_%d" % t, layer.w[0, 0])
class ObservationQ(Q):
  """Variational posterior q(z_t | observation) that ignores the previous state."""

  def q_zt(self, observation, prev_state, t):
    """Returns the Normal proposal for timestep t, conditioned on the observation only."""
    n = tf.shape(prev_state)[0]
    mean = self.mus[t](observation)
    # softplus + floor keeps the variance positive, tiled across the batch.
    variance = tf.tile(
        tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)[tf.newaxis, :],
        [n, 1])
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(variance))

  def summarize_weights(self):
    """Writes scalar summaries of the first element of each parameter."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, layer in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, layer.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % t, layer.w[0, 0])
class SimpleMeanQ(object):
  """Variational posterior with free per-timestep parameters.

  q(z_t) is a Normal whose mean and raw sigma are standalone variables (one
  pair per timestep); it conditions on neither the observation nor the
  previous state.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    """Creates the per-timestep mean and sigma variables.

    Args:
      state_size: Dimension of the latent state, an int.
      num_timesteps: Number of timesteps in the chain, an int.
      sigma_min: Floor applied to the variance after the softplus.
      dtype: dtype of the created variables.
      random_seed: Seed for the random uniform initializers.
      init_mu0_to_zero: If True, the t=0 mean is initialized to zeros.
      graph_collection_name: Graph collection every variable is added to.
    """
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    for t in xrange(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(tf.zeros_initializer)
      else:
        initializers.append(tf.random_uniform_initializer(seed=random_seed))
    # Free mean parameters, one vector per timestep.
    self.mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_mu_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=initializers[t])
        for t in xrange(num_timesteps)
    ]
    # Raw sigma parameters; softplus'd and floored in q_zt.
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(num_timesteps)
    ]

  def q_zt(self, unused_observation, prev_state, t):
    """Returns the Normal q(z_t); prev_state is used only for the batch size."""
    batch_size = tf.shape(prev_state)[0]
    q_mu = tf.tile(self.mus[t][tf.newaxis, :], [batch_size, 1])
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    """Writes scalar summaries of the first element of each parameter."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/%d" % t, f[0])
class R(object):
  """Reverse model r(x_n | z_t): a learned Gaussian over the final observation.

  Mirrors the structure of Q: each timestep t has its own linear layer for
  the mean and a free raw sigma variable shared across the batch.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               sigma_init=1.,
               random_seed=None,
               graph_collection_name="R_VARS"):
    """Creates the per-timestep mean layers and sigma variables.

    Args:
      state_size: Dimension of the latent state, an int.
      num_timesteps: Number of timesteps in the chain, an int.
      sigma_min: Floor applied to the variance after the softplus.
      dtype: dtype of the created sigma variables.
      sigma_init: Constant initial value for the raw sigma parameters.
      random_seed: Seed for the weight initializers.
      graph_collection_name: Graph collection every R variable is added to.
    """
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}
    self.graph_collection_name = graph_collection_name

    # Route every variable created inside the snt.Linear modules into
    # graph_collection_name, appending each variable at most once.
    def custom_getter(getter, *args, **kwargs):
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    # Per-timestep linear layers computing the mean of r(x_n | z_t).
    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers,
                   name="r_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in xrange(num_timesteps)
    ]
    # Per-timestep raw sigma parameters; softplus'd and floored in r_xn.
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.constant_initializer(sigma_init))
        for t in xrange(num_timesteps)
    ]

  def r_xn(self, z_t, t):
    """Returns the Normal r(x_n | z_t) for timestep t.

    Args:
      z_t: [batch_size, state_size] Tensor, the current latent state.
      t: Current timestep, a Python int.
    """
    batch_size = tf.shape(z_t)[0]
    r_mu = self.mus[t](z_t)
    # softplus + floor keeps the variance positive; sqrt converts it to scale.
    r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    return tf.contrib.distributions.Normal(
        loc=r_mu, scale=tf.sqrt(r_sigma))

  def summarize_weights(self):
    """Writes scalar summaries of the r parameters.

    Bug fix: the previous code indexed the snt.Linear module itself
    (self.mus[t][0]), which is not subscriptable and raised a TypeError
    when called. Summarize the bias term instead, matching
    Q.summarize_weights.
    """
    for t in range(len(self.mus) - 1):
      tf.summary.scalar("r_mu/%d" % t, self.mus[t].b[0])
      tf.summary.scalar("r_sigma/%d" % t, self.sigmas[t][0])
class P(object):
  """Linear Gaussian chain model.

  The model is z_0 ~ N(0, variance) and z_t ~ N(z_{t-1} + b_t, variance),
  with the final state z_n playing the role of the observation. Because
  everything is linear Gaussian, the posterior, lookahead, and marginal
  likelihood all have closed forms.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    """Creates the per-timestep drift variables b_1 ... b_n.

    Args:
      state_size: Dimension of the latent state, an int.
      num_timesteps: Number of timesteps in the chain, an int.
      sigma_min: Stored but not used by this class's methods.
      variance: The (shared) transition variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the random uniform initializers.
      trainable: Whether the drift variables are trainable.
      init_bs_to_zero: If True, the drifts start at zero.
      graph_collection_name: Graph collection the drifts are added to.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.graph_collection_name = graph_collection_name
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed) for _ in xrange(num_timesteps)]
    # Per-timestep drift vectors b_1 ... b_n.
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="p_b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            trainable=trainable) for t in xrange(num_timesteps)
    ]
    # Bs[i] is the suffix sum sum_{k=i+1}^{n} b_k, used by the closed-form
    # posterior and lookahead below.
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
    # bs[0] is really b_1
    # Bs[i] is sum from k=i+1^n b_k
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    # Gaussian product of the transition term and the lookahead term.
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n|z_t)."""
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    """Closed-form log marginal likelihood, averaged over the batch."""
    batch_size = tf.shape(observation)[0]
    mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(mu) * self.variance * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
    else:  # p(z_0) is N(0, variance)
      z_mu_p = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class ShortChainNonlinearP(object):
  """Short-chain model with a configurable (possibly nonlinear) transition.

  The transition mean is either round(z_{t-1}) or z_{t-1}, and the noise
  distribution class is injectable (Normal by default, e.g. Cauchy for
  heavy tails).
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               transition_type=STANDARD_TRANSITION,
               transition_dist=tf.contrib.distributions.Normal,
               dtype=tf.float32,
               random_seed=None):
    """Stores the model configuration.

    Args:
      state_size: Dimension of the latent state, an int.
      num_timesteps: Number of timesteps in the chain, an int.
      sigma_min: Stored but not used by this class's methods.
      variance: Transition variance.
      observation_variance: Variance of the generative distribution.
      transition_type: One of TRANSITION_TYPES.
      transition_dist: Distribution class used for transitions and emissions.
      dtype: dtype used for the initial-state mean.
      random_seed: Stored implicitly unused; kept for interface parity.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.transition_type = transition_type
    self.transition_dist = transition_dist

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance))
      else:
        # Previously an unknown transition_type fell through with `loc`
        # unset, raising a confusing UnboundLocalError; fail fast instead.
        raise ValueError("Unknown transition_type: %s" % self.transition_type)
    else:  # p(z_0) is Normal(0,1)
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
    p_zt = self.transition_dist(
        loc=loc,
        scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, unused_obs, z_ni):
    """Computes the model's generative distribution p(x_i| z_{ni})."""
    if self.transition_type == ROUND_TRANSITION:
      loc = tf.round(z_ni)
    elif self.transition_type == STANDARD_TRANSITION:
      loc = z_ni
    else:
      # Previously an unknown transition_type left `loc` unset, raising a
      # confusing UnboundLocalError; fail fast instead.
      raise ValueError("Unknown transition_type: %s" % self.transition_type)
    generative_sigma_sq = tf.ones_like(loc) * self.observation_variance
    return self.transition_dist(
        loc=loc, scale=tf.sqrt(generative_sigma_sq))
class BimodalPriorP(object):
  """Linear Gaussian chain whose prior over z_0 is a two-mode Gaussian mixture.

  The modes sit at +prior_mode_mean and -prior_mode_mean with weights
  [mixing_coeff, 1 - mixing_coeff]; the rest of the chain matches P:
  z_t ~ N(z_{t-1} + b_t, variance).
  """

  def __init__(self,
               state_size,
               num_timesteps,
               mixing_coeff=0.5,
               prior_mode_mean=1,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    """Creates the per-timestep drift variables b_1 ... b_n.

    Args:
      state_size: Dimension of the latent state, an int.
      num_timesteps: Number of timesteps in the chain, an int.
      mixing_coeff: Weight of the positive prior mode.
      prior_mode_mean: Magnitude of each prior mode's mean.
      sigma_min: Stored but not used by this class's methods.
      variance: The (shared) transition variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the random uniform initializers.
      trainable: Whether the drift variables are trainable.
      init_bs_to_zero: If True, the drifts start at zero.
      graph_collection_name: Graph collection the drifts are added to.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.mixing_coeff = mixing_coeff
    self.prior_mode_mean = prior_mode_mean
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed) for _ in xrange(num_timesteps)]
    # Per-timestep drift vectors b_1 ... b_n.
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            trainable=trainable) for t in xrange(num_timesteps)
    ]
    # Bs[i] is the suffix sum sum_{k=i+1}^{n} b_k.
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    # NOTE: This is currently wrong, but would require a refactoring of
    # summarize_q to fix as kl is not defined for a mixture
    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
    # bs[0] is really b_1
    # Bs[i] is sum from k=i+1^n b_k
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    # Gaussian product of the transition term and the lookahead term; this
    # ignores the bimodal prior (see the NOTE above).
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n|z_t)."""
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    """Log marginal likelihood under the two-mode mixture, mean over batch."""
    batch_size = tf.shape(observation)[0]
    sum_of_bs = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(sum_of_bs) * self.variance * (self.num_timesteps + 1)
    # One Gaussian component per prior mode, each shifted by the total drift.
    mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean) + sum_of_bs
    mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean) + sum_of_bs
    zn_pos = tf.contrib.distributions.Normal(
        loc=mu_pos,
        scale=tf.sqrt(sigma))
    zn_neg = tf.contrib.distributions.Normal(
        loc=mu_neg,
        scale=tf.sqrt(sigma))
    # NOTE(review): mode_probs is hard-coded to tf.float64 regardless of
    # self.dtype (float32 by default) — confirm this matches what
    # Categorical/Mixture expect here.
    mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1-self.mixing_coeff], dtype=tf.float64)
    mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
    mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs)
    zn_dist = tf.contrib.distributions.Mixture(
        cat=mode_selection_dist,
        components=[zn_pos, zn_neg],
        validate_args=True)
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(tf.reduce_sum(zn_dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
      p_zt = tf.contrib.distributions.Normal(
          loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
      return p_zt
    else:  # p(z_0) is mixture of two Normals
      mu_pos = tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean
      mu_neg = tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean
      z0_pos = tf.contrib.distributions.Normal(
          loc=mu_pos,
          scale=tf.sqrt(tf.ones_like(mu_pos) * self.variance))
      z0_neg = tf.contrib.distributions.Normal(
          loc=mu_neg,
          scale=tf.sqrt(tf.ones_like(mu_neg) * self.variance))
      # NOTE(review): tf.float64 hard-coded here as well — see likelihood.
      mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1-self.mixing_coeff], dtype=tf.float64)
      mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
      mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs)
      z0_dist = tf.contrib.distributions.Mixture(
          cat=mode_selection_dist,
          components=[z0_pos, z0_neg],
          validate_args=False)
      return z0_dist

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class Model(object):
  """Bundles a generative model p, proposal q, and reverse model r.

  Calling the model runs one timestep of sampling and scoring, returning
  the quantities needed by SMC/IWAE-style bound computations.
  """

  def __init__(self,
               p,
               q,
               r,
               state_size,
               num_timesteps,
               dtype=tf.float32):
    """Stores the component models and chain dimensions.

    Args:
      p: Generative model exposing p_zt and generative (e.g. P).
      q: Proposal exposing q_zt (e.g. Q).
      r: Reverse model exposing r_xn (e.g. R).
      state_size: Dimension of the latent state, an int.
      num_timesteps: Number of timesteps in the chain, an int.
      dtype: dtype used for zero_state.
    """
    self.p = p
    self.q = q
    self.r = r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zeros initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    """Samples z_t from q and computes the log-probs for the weight update.

    Args:
      prev_state: [batch_size, state_size] Tensor, z_{t-1}.
      observation: [batch_size, state_size] Tensor, the final observation z_n.
      t: Current timestep, a Python int.

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn). Each
      log-prob is summed over the state dimension; log_p_x_given_z is
      zeros except at the final timestep.
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q
    zt = q_zt.sample()
    r_xn = self.r.r_xn(zt, t)
    # Calculate the logprobs and sum over the state size.
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    log_r_xn = tf.reduce_sum(r_xn.log_prob(observation), axis=1)
    # If we're at the last timestep, also calc the logprob of the observation.
    if t == self.num_timesteps - 1:
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)

  @staticmethod
  def create(state_size,
             num_timesteps,
             sigma_min=1e-5,
             r_sigma_init=1,
             variance=1.0,
             mixing_coeff=0.5,
             prior_mode_mean=1.0,
             dtype=tf.float32,
             random_seed=None,
             train_p=True,
             p_type="unimodal",
             q_type="normal",
             observation_variance=1.0,
             transition_type=STANDARD_TRANSITION,
             use_bs=True):
    """Factory that builds a Model from configuration strings.

    Args:
      state_size: Dimension of the latent state, an int.
      num_timesteps: Number of timesteps in the chain, an int.
      sigma_min: Variance floor passed to p, q, and r.
      r_sigma_init: Initial raw sigma value for R.
      variance: Transition variance of p.
      mixing_coeff: Positive-mode weight (bimodal p only).
      prior_mode_mean: Prior mode magnitude (bimodal p only).
      dtype: dtype for all created variables.
      random_seed: Seed forwarded to all initializers.
      train_p: Whether p's drift variables are trainable.
      p_type: "unimodal", "bimodal", or a string containing "nonlinear"
        (optionally also "cauchy" for Cauchy transitions).
      q_type: "normal", "simple_mean", "prev_state", or "observation".
      observation_variance: Emission variance (nonlinear p only).
      transition_type: One of TRANSITION_TYPES (nonlinear p only).
      use_bs: If False, p's drifts and q's t=0 mean are initialized to zero.

    Returns:
      A fully constructed Model.

    Raises:
      ValueError: If p_type or q_type is not a recognized option.
    """
    if p_type == "unimodal":
      p = P(state_size,
            num_timesteps,
            sigma_min=sigma_min,
            variance=variance,
            dtype=dtype,
            random_seed=random_seed,
            trainable=train_p,
            init_bs_to_zero=not use_bs)
    elif p_type == "bimodal":
      p = BimodalPriorP(
          state_size,
          num_timesteps,
          mixing_coeff=mixing_coeff,
          prior_mode_mean=prior_mode_mean,
          sigma_min=sigma_min,
          variance=variance,
          dtype=dtype,
          random_seed=random_seed,
          trainable=train_p,
          init_bs_to_zero=not use_bs)
    elif "nonlinear" in p_type:
      # "cauchy" in the type string selects heavy-tailed transitions.
      if "cauchy" in p_type:
        trans_dist = tf.contrib.distributions.Cauchy
      else:
        trans_dist = tf.contrib.distributions.Normal
      p = ShortChainNonlinearP(
          state_size,
          num_timesteps,
          sigma_min=sigma_min,
          variance=variance,
          observation_variance=observation_variance,
          transition_type=transition_type,
          transition_dist=trans_dist,
          dtype=dtype,
          random_seed=random_seed
      )
    else:
      # Previously an unrecognized p_type left `p` undefined, surfacing
      # later as a confusing NameError; fail fast instead.
      raise ValueError("Unknown p_type: %s" % p_type)
    if q_type == "normal":
      q_class = Q
    elif q_type == "simple_mean":
      q_class = SimpleMeanQ
    elif q_type == "prev_state":
      q_class = PreviousStateQ
    elif q_type == "observation":
      q_class = ObservationQ
    else:
      # Same fail-fast treatment for an unrecognized q_type.
      raise ValueError("Unknown q_type: %s" % q_type)
    q = q_class(state_size,
                num_timesteps,
                sigma_min=sigma_min,
                dtype=dtype,
                random_seed=random_seed,
                init_mu0_to_zero=not use_bs)
    r = R(state_size,
          num_timesteps,
          sigma_min=sigma_min,
          sigma_init=r_sigma_init,
          dtype=dtype,
          random_seed=random_seed)
    model = Model(p, q, r, state_size, num_timesteps, dtype=dtype)
    return model
class BackwardsModel(object):
  """Model that runs the chain backwards, proposing z_t from z_{t+1} and z_n.

  Unlike Model, this class owns its own drift, proposal, and r parameters
  rather than delegating to separate p/q/r objects. The transition scale is
  fixed to ones. Parameter lists are indexed backwards in time via
  t_backwards = num_timesteps - t - 1.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32):
    """Creates the drift, proposal, and r variables.

    Args:
      state_size: Dimension of the latent state, an int.
      num_timesteps: Number of timesteps in the chain, an int.
      sigma_min: Floor applied to variances after the softplus.
      dtype: dtype of the created variables.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    # Drift vectors b_1 ... b_n of the forward model.
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    # Bs[i] is the suffix sum sum_{k=i+1}^{n} b_k.
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
    # Proposal mean layers and raw sigma parameters (indexed backwards).
    self.q_mus = [
        snt.Linear(output_size=state_size) for _ in xrange(num_timesteps)
    ]
    self.q_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    # Parameters of the per-timestep r distribution over z_t.
    self.r_mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_mu_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    self.r_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]

  def zero_state(self, batch_size):
    """Returns an all-zeros initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def posterior(self, unused_observation, prev_state, unused_t):
    # TODO(dieterichl): Correct this.
    # Placeholder: returns a degenerate Normal with zero scale.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(prev_state), scale=tf.zeros_like(prev_state))

  def lookahead(self, state, unused_t):
    # TODO(dieterichl): Correct this.
    # Placeholder: returns a degenerate Normal with zero scale.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(state), scale=tf.zeros_like(state))

  def q_zt(self, observation, next_state, t):
    """Computes the variational posterior q(z_{t}|z_{t+1}, z_n)."""
    # Parameter lists run backwards in time.
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(next_state)[0]
    q_mu = self.q_mus[t_backwards](tf.concat([observation, next_state], axis=1))
    # softplus + floor keeps the variance positive; sqrt converts to scale.
    q_sigma = tf.maximum(
        tf.nn.softplus(self.q_sigmas[t_backwards]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def p_zt(self, prev_state, t):
    """Computes the model p(z_{t+1}| z_{t})."""
    t_backwards = self.num_timesteps - t - 1
    z_mu_p = prev_state + self.bs[t_backwards]
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.ones_like(z_mu_p))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu, scale=tf.ones_like(generative_p_mu))

  def r(self, z_t, t):
    """Returns the Normal r distribution over z_t for timestep t."""
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(z_t)[0]
    r_mu = tf.tile(self.r_mus[t_backwards][tf.newaxis, :], [batch_size, 1])
    r_sigma = tf.maximum(
        tf.nn.softplus(self.r_sigmas[t_backwards]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))

  def likelihood(self, observation):
    """Log-likelihood of the observation under the unit-variance chain."""
    batch_size = tf.shape(observation)[0]
    mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(mu) * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def __call__(self, next_state, observation, t):
    """Runs one backwards timestep; returns the sample and log-probs.

    Note that unlike Model.__call__, the returned log-probs are NOT reduced
    over the state dimension here.
    """
    # next state = z_{t+1}
    # Compute the q distribution over z, q(z_{t}|z_n, z_{t+1}).
    q_zt = self.q_zt(observation, next_state, t)
    # sample from q
    zt = q_zt.sample()
    # Compute the p distribution over z, p(z_{t+1}|z_{t}).
    p_zt = self.p_zt(zt, t)
    # Compute log p(z_{t+1} | z_t)
    if t == 0:
      log_p_zt = p_zt.log_prob(observation)
    else:
      log_p_zt = p_zt.log_prob(next_state)
    # Compute r prior over zt
    r_zt = self.r(zt, t)
    log_r_zt = r_zt.log_prob(zt)
    # Compute proposal density at zt
    log_q_zt = q_zt.log_prob(zt)
    # If we're at the last timestep, also calc the logprob of the observation.
    if t == self.num_timesteps - 1:
      p_z0_dist = tf.contrib.distributions.Normal(
          loc=tf.zeros_like(zt), scale=tf.ones_like(zt))
      z0_log_prob = p_z0_dist.log_prob(zt)
    else:
      z0_log_prob = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt)
class LongChainP(object):
  """Generative model for long chains with intermittent observations.

  The latent chain has steps_per_obs*num_obs + 1 timesteps; every
  steps_per_obs-th state emits an observation through a configurable link
  function (identity, absolute value, or square).
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               observation_type=STANDARD_OBSERVATION,
               transition_type=STANDARD_TRANSITION,
               dtype=tf.float32,
               random_seed=None):
    """Stores the model configuration.

    Args:
      state_size: Dimension of the latent state, an int.
      num_obs: Number of observations, an int.
      steps_per_obs: Latent steps between consecutive observations, an int.
      sigma_min: Stored but not used by this class's methods.
      variance: Transition variance.
      observation_variance: Emission variance.
      observation_type: One of OBSERVATION_TYPES.
      transition_type: One of TRANSITION_TYPES.
      dtype: dtype used for the initial-state mean.
      random_seed: Unused here; kept for interface parity.
    """
    self.state_size = state_size
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = steps_per_obs*num_obs + 1
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.observation_type = observation_type
    self.transition_type = transition_type

  def likelihood(self, observations):
    """Computes the model's true likelihood of the observations.

    Args:
      observations: A [batch_size, m, state_size] Tensor representing each of
        the m observations.

    Returns:
      logprob: The true likelihood of the observations given the model.

    Raises:
      ValueError: Always; no closed form is implemented for long chains.
    """
    raise ValueError("Likelihood is not defined for long-chain models")
    # The linear-Gaussian closed form below was abandoned; kept for reference.
    # batch_size = tf.shape(observations)[0]
    # mu = tf.zeros([batch_size, self.state_size, self.num_obs], dtype=self.dtype)
    # sigma = np.fromfunction(
    #     lambda i, j: 1 + self.steps_per_obs*np.minimum(i+1, j+1),
    #     [self.num_obs, self.num_obs])
    # sigma += np.eye(self.num_obs)
    # sigma = tf.convert_to_tensor(sigma * self.variance, dtype=self.dtype)
    # sigma = tf.tile(sigma[tf.newaxis, tf.newaxis, ...],
    #                 [batch_size, self.state_size, 1, 1])
    # dist = tf.contrib.distributions.MultivariateNormalFullCovariance(
    #     loc=mu,
    #     covariance_matrix=sigma)
    # Average over the batch and take the sum over the state size
    # return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observations), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance))
      else:
        # Previously an unknown transition_type fell through with `loc`
        # unset, raising a confusing UnboundLocalError; fail fast instead.
        raise ValueError("Unknown transition_type: %s" % self.transition_type)
    else:  # p(z_0) is Normal(0,1)
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
    p_zt = tf.contrib.distributions.Normal(
        loc=loc,
        scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, z_ni, t):
    """Computes the model's generative distribution p(x_i| z_{ni})."""
    if self.observation_type == SQUARED_OBSERVATION:
      generative_mu = tf.square(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d^2, %0.1f)" % (t, t, t, self.variance))
    elif self.observation_type == ABS_OBSERVATION:
      generative_mu = tf.abs(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(|z_%d|, %0.1f)" % (t, t, t, self.variance))
    elif self.observation_type == STANDARD_OBSERVATION:
      generative_mu = z_ni
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t, t, self.variance))
    else:
      # Previously an unknown observation_type left `generative_mu` unset,
      # raising a confusing UnboundLocalError; fail fast instead.
      raise ValueError("Unknown observation_type: %s" % self.observation_type)
    generative_sigma_sq = tf.ones_like(generative_mu) * self.observation_variance
    return tf.contrib.distributions.Normal(
        loc=generative_mu, scale=tf.sqrt(generative_sigma_sq))
class LongChainQ(object):
  """Variational posterior for the long chain.

  At each timestep the proposal mean is a linear function of the previous
  state concatenated with every observation that is still relevant at that
  timestep.
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    """Creates the per-timestep mean layers and sigma variables.

    Args:
      state_size: Dimension of the latent state, an int.
      num_obs: Number of observations, an int.
      steps_per_obs: Latent steps between consecutive observations, an int.
      sigma_min: Floor applied to the variance after the softplus.
      dtype: dtype of the created sigma variables.
      random_seed: Seed for the initializers.
    """
    self.state_size = state_size
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs*steps_per_obs +1
    initializers = {
        "w": tf.random_uniform_initializer(seed=random_seed),
        "b": tf.zeros_initializer
    }
    # Per-timestep linear layers for the proposal mean. Each layer's input
    # size is fixed by the number of relevant observations at that timestep.
    self.mus = [
        snt.Linear(output_size=state_size, initializers=initializers)
        for t in xrange(self.num_timesteps)
    ]
    # Per-timestep raw sigma parameters; softplus'd and floored in q_zt.
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(self.num_timesteps)
    ]

  def first_relevant_obs_index(self, t):
    """Index of the first observation still relevant at timestep t.

    NOTE(review): q_zt below inlines an equivalent computation instead of
    calling this helper.
    """
    return int(max((t-1)/self.steps_per_obs, 0))

  def q_zt(self, observations, prev_state, t):
    """Computes a distribution over z_t.

    Args:
      observations: a [batch_size, num_observations, state_size] Tensor.
      prev_state: a [batch_size, state_size] Tensor.
      t: The current timestep, a Python int (indexes per-timestep params).
    """
    # filter out unneeded past obs
    first_relevant_obs_index = int(math.floor(max(t-1, 0) / self.steps_per_obs))
    num_relevant_observations = self.num_obs - first_relevant_obs_index
    observations = observations[:,first_relevant_obs_index:,:]
    batch_size = tf.shape(prev_state)[0]
    # concatenate the prev state and observations along the second axis (that is
    # not the batch or state size axis), and then flatten it to
    # [batch_size, (num_relevant_observations + 1) * state_size] to feed it into
    # the linear layer.
    q_input = tf.concat([observations, prev_state[:,tf.newaxis, :]], axis=1)
    q_input = tf.reshape(q_input,
                         [batch_size, (num_relevant_observations + 1) * self.state_size])
    q_mu = self.mus[t](q_input)
    # softplus + floor keeps the variance positive; sqrt converts to scale.
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    tf.logging.info(
        "q(z_{t} | z_{tm1}, x_{obsf}:{obst}) ~ N(Linear([z_{tm1},x_{obsf}:{obst}]), sigma_{t})".format(
            **{"t": t,
               "tm1": t-1,
               "obsf": (first_relevant_obs_index+1)*self.steps_per_obs,
               "obst":self.steps_per_obs*self.num_obs}))
    return q_zt

  def summarize_weights(self):
    """No-op; this proposal writes no summaries."""
    pass
class LongChainR(object):
  """Reverse model scoring all future observations from the current state.

  For each timestep the future observations are scored with a Normal
  centered at the current state; only the per-observation variances are
  learned (one raw sigma per future observation per timestep).
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    """Creates the per-timestep sigma variables.

    Args:
      state_size: Dimension of the latent state, an int.
      num_obs: Number of observations, an int.
      steps_per_obs: Latent steps between consecutive observations, an int.
      sigma_min: Floor applied to the variance after the softplus.
      dtype: dtype of the created variables.
      random_seed: Unused by the active initializer; kept for parity.
    """
    self.state_size = state_size
    self.dtype = dtype
    self.sigma_min = sigma_min
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs*steps_per_obs + 1
    # One raw sigma per future observation at each timestep.
    self.sigmas = [
        tf.get_variable(
            shape=[self.num_future_obs(t)],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            #initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100))
            initializer=tf.constant_initializer(1.0))
        for t in range(self.num_timesteps)
    ]

  def first_future_obs_index(self, t):
    """Index of the first observation that is still in the future at t."""
    return int(math.floor(t / self.steps_per_obs))

  def num_future_obs(self, t):
    """Number of observations not yet emitted at timestep t."""
    return int(self.num_obs - self.first_future_obs_index(t))

  def r_xn(self, z_t, t):
    """Computes a distribution over the future observations given current latent
    state.

    The indexing in these messages is 1 indexed and inclusive. This is
    consistent with the latex documents.

    Args:
      z_t: [batch_size, state_size] Tensor
      t: Current timestep
    """
    tf.logging.info(
        "r(x_{start}:{end} | z_{t}) ~ N(z_{t}, sigma_{t})".format(
            **{"t": t,
               "start": (self.first_future_obs_index(t)+1)*self.steps_per_obs,
               "end": self.num_timesteps-1}))
    batch_size = tf.shape(z_t)[0]
    # the mean for all future observations is the same.
    # this tiling results in a [batch_size, num_future_obs, state_size] Tensor
    r_mu = tf.tile(z_t[:,tf.newaxis,:], [1, self.num_future_obs(t), 1])
    # compute the variance
    r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    # the variance is the same across all state dimensions, so we only have to
    # tile sigma to be [batch_size, num_future_obs, state_size].
    r_sigma = tf.tile(r_sigma[tf.newaxis,:, tf.newaxis], [batch_size, 1, self.state_size])
    return tf.contrib.distributions.Normal(
        loc=r_mu, scale=tf.sqrt(r_sigma))

  def summarize_weights(self):
    """No-op; this model writes no summaries."""
    pass
class LongChainModel(object):
  """Bundles LongChainP/Q/R into a per-timestep callable for the long chain."""

  def __init__(self,
               p,
               q,
               r,
               state_size,
               num_obs,
               steps_per_obs,
               dtype=tf.float32,
               disable_r=False):
    """Stores the component models and chain dimensions.

    Args:
      p: Generative model exposing p_zt and generative (e.g. LongChainP).
      q: Proposal exposing q_zt (e.g. LongChainQ).
      r: Reverse model exposing r_xn (e.g. LongChainR).
      state_size: Dimension of the latent state, an int.
      num_obs: Number of observations, an int.
      steps_per_obs: Latent steps between consecutive observations, an int.
      dtype: dtype used for zero_state.
      disable_r: If True, r is never called and log_r_xn is always zeros.
    """
    self.p = p
    self.q = q
    self.r = r
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_obs = num_obs
    self.steps_per_obs = steps_per_obs
    self.num_timesteps = steps_per_obs*num_obs + 1
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zeros initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def next_obs_ind(self, t):
    """Observation index associated with timestep t."""
    return int(math.floor(max(t-1,0)/self.steps_per_obs))

  def __call__(self, prev_state, observations, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor
      observations: [batch_size, num_observations, state_size] Tensor

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn), each
      reduced to shape [batch_size]. log_p_x_given_z is nonzero only at
      observation timesteps; log_r_xn is zeros when r is disabled or at
      the final timestep.
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observations, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q and evaluate the logprobs, summing over the state size
    zt = q_zt.sample()
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps-1:
      # score the remaining observations using r
      r_xn = self.r.r_xn(zt, t)
      log_r_xn = r_xn.log_prob(observations[:, self.next_obs_ind(t+1):, :])
      # sum over state size and observation, leaving the batch index
      log_r_xn = tf.reduce_sum(log_r_xn, axis=[1,2])
    else:
      log_r_xn = tf.zeros_like(log_p_zt)
    # Every steps_per_obs-th timestep (except t=0) emits an observation.
    if t != 0 and t % self.steps_per_obs == 0:
      generative_dist = self.p.generative(zt, t)
      log_p_x_given_z = generative_dist.log_prob(observations[:,self.next_obs_ind(t),:])
      log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)

  @staticmethod
  def create(state_size,
             num_obs,
             steps_per_obs,
             sigma_min=1e-5,
             variance=1.0,
             observation_variance=1.0,
             observation_type=STANDARD_OBSERVATION,
             transition_type=STANDARD_TRANSITION,
             dtype=tf.float32,
             random_seed=None,
             disable_r=False):
    """Factory constructing a LongChainModel with matching p, q, and r.

    Args mirror LongChainP/Q/R constructors; disable_r turns off the
    reverse-model scores entirely.
    """
    p = LongChainP(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        variance=variance,
        observation_variance=observation_variance,
        observation_type=observation_type,
        transition_type=transition_type,
        dtype=dtype,
        random_seed=random_seed)
    q = LongChainQ(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    r = LongChainR(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    model = LongChainModel(
        p, q, r, state_size, num_obs, steps_per_obs,
        dtype=dtype,
        disable_r=disable_r)
    return model
class RTilde(object):
  """Per-timestep network producing the mean and variance of r~(z_t | x).

  Each timestep has one linear layer whose output is split into a mean and
  a raw variance; the variance goes through a floored softplus.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               graph_collection_name="R_TILDE_VARS"):
    """Creates one linear layer per timestep, registered in a collection."""
    self.dtype = dtype
    self.sigma_min = sigma_min
    self.graph_collection_name = graph_collection_name

    def custom_getter(getter, *args, **kwargs):
      # Funnel every created variable into our collection, at most once.
      variable = getter(*args, **kwargs)
      collection = tf.get_collection_ref(self.graph_collection_name)
      if variable not in collection:
        collection.append(variable)
      return variable

    weight_inits = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}
    self.fns = [
        snt.Linear(output_size=2*state_size,
                   initializers=weight_inits,
                   name="r_tilde_%d" % t,
                   custom_getter=custom_getter)
        for t in xrange(num_timesteps)
    ]

  def r_zt(self, z_t, observation, t):
    """Returns (mu, sigma_sq) of r~ at timestep t given z_t and the observation."""
    joint_input = tf.concat([z_t, observation], axis=1)
    mu, raw_sigma_sq = tf.split(self.fns[t](joint_input), 2, axis=1)
    sigma_sq = tf.maximum(tf.nn.softplus(raw_sigma_sq), self.sigma_min)
    return mu, sigma_sq
class TDModel(object):
  """Container for the p, q, and r-tilde components used by the
  FIVO-AUX-TD bound.

  Per timestep it samples z_t from q and returns the log-probs under q and
  p, the (final-step) observation log-likelihood, the r-tilde parameters,
  and the next-step prior distribution.
  """

  def __init__(self,
               p,
               q,
               r_tilde,
               state_size,
               num_timesteps,
               dtype=tf.float32,
               disable_r=False):
    # p: prior/generative model exposing p_zt(prev_state, t) and generative(...).
    # q: proposal exposing q_zt(observation, prev_state, t).
    # r_tilde: RTilde instance scoring future observations.
    # disable_r: when True, r_tilde is never evaluated (mu/sigma stay None).
    self.p = p
    self.q = q
    self.r_tilde = r_tilde
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zeros initial state, shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor
      observations: [batch_size, num_observations, state_size] Tensor

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, r_tilde_mu,
      r_tilde_sigma_sq, p_ztplus1); the r-tilde entries and p_ztplus1 may be
      None depending on t and disable_r.
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q and evaluate the logprobs, summing over the state size
    zt = q_zt.sample()
    # If it isn't the last timestep, compute the distribution over the next z.
    if t < self.num_timesteps - 1:
      p_ztplus1 = self.p.p_zt(zt, t+1)
    else:
      p_ztplus1 = None
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps-1:
      # score the remaining observations using r
      r_tilde_mu, r_tilde_sigma_sq = self.r_tilde.r_zt(zt, observation, t+1)
    else:
      r_tilde_mu = None
      r_tilde_sigma_sq = None
    if t == self.num_timesteps - 1:
      # The observation is only scored on the final timestep.
      # NOTE(review): generative is called here as (observation, zt) whereas
      # LongChainModel calls p.generative(zt, t) -- presumably the two P
      # classes have different signatures; confirm against the P definition.
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z,
            r_tilde_mu, r_tilde_sigma_sq, p_ztplus1)

  @staticmethod
  def create(state_size,
             num_timesteps,
             sigma_min=1e-5,
             variance=1.0,
             dtype=tf.float32,
             random_seed=None,
             train_p=True,
             p_type="unimodal",
             q_type="normal",
             mixing_coeff=0.5,
             prior_mode_mean=1.0,
             observation_variance=1.0,
             transition_type=STANDARD_TRANSITION,
             use_bs=True):
    """Factory building a TDModel with the requested p and q variants.

    p_type selects the prior family ("unimodal", "bimodal", or a
    "nonlinear" variant); q_type selects the proposal parameterization.
    """
    if p_type == "unimodal":
      p = P(state_size,
            num_timesteps,
            sigma_min=sigma_min,
            variance=variance,
            dtype=dtype,
            random_seed=random_seed,
            trainable=train_p,
            init_bs_to_zero=not use_bs)
    elif p_type == "bimodal":
      p = BimodalPriorP(
          state_size,
          num_timesteps,
          mixing_coeff=mixing_coeff,
          prior_mode_mean=prior_mode_mean,
          sigma_min=sigma_min,
          variance=variance,
          dtype=dtype,
          random_seed=random_seed,
          trainable=train_p,
          init_bs_to_zero=not use_bs)
    elif "nonlinear" in p_type:
      # "cauchy" in the p_type string selects Cauchy transitions; any other
      # nonlinear variant uses Normal transitions.
      if "cauchy" in p_type:
        trans_dist = tf.contrib.distributions.Cauchy
      else:
        trans_dist = tf.contrib.distributions.Normal
      p = ShortChainNonlinearP(
          state_size,
          num_timesteps,
          sigma_min=sigma_min,
          variance=variance,
          observation_variance=observation_variance,
          transition_type=transition_type,
          transition_dist=trans_dist,
          dtype=dtype,
          random_seed=random_seed
      )
    if q_type == "normal":
      q_class = Q
    elif q_type == "simple_mean":
      q_class = SimpleMeanQ
    elif q_type == "prev_state":
      q_class = PreviousStateQ
    elif q_type == "observation":
      q_class = ObservationQ
    # NOTE(review): an unrecognized p_type or q_type falls through silently
    # and surfaces as a NameError below; the flag enums restrict the values
    # in practice.
    q = q_class(state_size,
                num_timesteps,
                sigma_min=sigma_min,
                dtype=dtype,
                random_seed=random_seed,
                init_mu0_to_zero=not use_bs)
    r_tilde = RTilde(
        state_size,
        num_timesteps,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    model = TDModel(p, q, r_tilde, state_size, num_timesteps, dtype=dtype)
    return model
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Launches train.py for the "forward" model on the fivo-aux bound with
# multinomial resampling and resampling gradients enabled.

# Experiment configuration; these values are also baked into LOGDIR below.
model="forward"
T=5                                # number of timesteps in each sequence
num_obs=1                          # number of observations
var=0.1                            # variance of the data generating process
n=4                                # number of particles
lr=0.0001                          # learning rate for ADAM
bound="fivo-aux"
q_type="normal"
resampling_method="multinomial"
rgrad="true"                       # use resampling gradients
p_type="unimodal"
use_bs=false

# The log directory name encodes the full configuration so runs are
# self-describing in TensorBoard.
LOGDIR=/tmp/fivo/model-$model-$bound-$resampling_method-resampling-rgrad-$rgrad-T-$T-var-$var-n-$n-lr-$lr-q-$q_type-p-$p_type

python train.py \
  --logdir=$LOGDIR \
  --model=$model \
  --bound=$bound \
  --q_type=$q_type \
  --p_type=$p_type \
  --variance=$var \
  --use_resampling_grads=$rgrad \
  --resampling=always \
  --resampling_method=$resampling_method \
  --batch_size=4 \
  --num_samples=$n \
  --num_timesteps=$T \
  --num_eval_samples=256 \
  --summarize_every=100 \
  --learning_rate=$lr \
  --decay_steps=1000000 \
  --max_steps=1000000000 \
  --random_seed=1234 \
  --train_p=false \
  --use_bs=$use_bs \
  --alsologtostderr
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for plotting and summarizing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import scipy
import tensorflow as tf
import models
def summarize_ess(weights, only_last_timestep=False):
  """Adds per-timestep summaries of the effective sample size.

  Args:
    weights: List of length num_timesteps Tensors of shape
      [num_samples, batch_size] containing unnormalized log weights.
    only_last_timestep: If True, only the final timestep is summarized.
  """
  batch_size = tf.cast(tf.shape(weights[0])[1], dtype=tf.float64)
  last = len(weights) - 1
  for t, log_w in enumerate(weights):
    if only_last_timestep and t != last:
      continue
    w = tf.nn.softmax(log_w, dim=0)
    centered = w - tf.reduce_mean(w, axis=0, keepdims=True)
    # NOTE(review): the weights are centered over the sample axis but the
    # sum of squares is normalized by (batch_size - 1) -- confirm this
    # normalizer is intended.
    variance = tf.reduce_sum(tf.square(centered)) / (batch_size - 1)
    ess = 1. / tf.reduce_mean(tf.reduce_sum(tf.square(w), axis=0))
    tf.summary.scalar("ess/%d" % t, ess)
    tf.summary.scalar("ese/%d" % t, ess / batch_size)
    tf.summary.scalar("weight_variance/%d" % t, variance)
def summarize_particles(states, weights, observation, model):
  """Plots particle locations and weights.

  Renders, for each batch element, one figure with the q/posterior/prior
  densities on top and a bar plot plus location histogram of the particles
  per timestep, and attaches them as infrequent image summaries.

  Args:
    states: List of length num_timesteps Tensors of shape
      [batch_size*num_particles, state_size].
    weights: List of length num_timesteps Tensors of shape [num_samples,
      batch_size]
    observation: Tensor of shape [batch_size*num_samples, state_size]
    model: Model whose q and bimodal prior p supply the plotted densities.
  """
  num_timesteps = len(weights)
  num_samples, batch_size = weights[0].get_shape().as_list()
  # get q0 information for plotting
  q0_dist = model.q.q_zt(observation, tf.zeros_like(states[0]), 0)
  q0_loc = q0_dist.loc[0:batch_size, 0]
  q0_scale = q0_dist.scale[0:batch_size, 0]
  # get posterior information for plotting
  post = (model.p.mixing_coeff, model.p.prior_mode_mean, model.p.variance,
          tf.reduce_sum(model.p.bs), model.p.num_timesteps)
  # Reshape states and weights to be [time, num_samples, batch_size]
  states = tf.stack(states)
  weights = tf.stack(weights)
  # normalize the weights over the sample dimension
  weights = tf.nn.softmax(weights, dim=1)
  states = tf.reshape(states, tf.shape(weights))
  # Effective sample size per (timestep, batch element).
  ess = 1./tf.reduce_sum(tf.square(weights), axis=1)

  def _plot_states(states_batch, weights_batch, observation_batch, ess_batch, q0, post):
    """Renders one figure per batch element; runs as a py_func on numpy arrays.

    states: [time, num_samples, batch_size]
    weights [time, num_samples, batch_size]
    observation: [batch_size, 1]
    q0: ([batch_size], [batch_size])
    post: ...
    """
    num_timesteps, _, batch_size = states_batch.shape
    plots = []
    for i in range(batch_size):
      states = states_batch[:,:,i]
      weights = weights_batch[:,:,i]
      observation = observation_batch[i]
      ess = ess_batch[:,i]
      q0_loc = q0[0][i]
      q0_scale = q0[1][i]
      fig = plt.figure(figsize=(7, (num_timesteps + 1) * 2))
      # Each timestep gets two plots -- a bar plot and a histogram of state locs.
      # The bar plot will be bar_rows rows tall.
      # The histogram will be 1 row tall.
      # There is also 1 extra plot at the top showing the posterior and q.
      bar_rows = 8
      num_rows = (num_timesteps + 1) * (bar_rows + 1)
      gs = gridspec.GridSpec(num_rows, 1)
      # Figure out how wide to make the plot
      prior_lims = (post[1] * -2, post[1] * 2)
      q_lims = (scipy.stats.norm.ppf(0.01, loc=q0_loc, scale=q0_scale),
                scipy.stats.norm.ppf(0.99, loc=q0_loc, scale=q0_scale))
      state_width = states.max() - states.min()
      state_lims = (states.min() - state_width * 0.15,
                    states.max() + state_width * 0.15)
      lims = (min(prior_lims[0], q_lims[0], state_lims[0]),
              max(prior_lims[1], q_lims[1], state_lims[1]))
      # plot the posterior (closed form for the bimodal-prior Gaussian chain)
      z0 = np.arange(lims[0], lims[1], 0.1)
      alpha, pos_mu, sigma_sq, B, T = post
      neg_mu = -pos_mu
      scale = np.sqrt((T + 1) * sigma_sq)
      p_zn = (
          alpha * scipy.stats.norm.pdf(
              observation, loc=pos_mu + B, scale=scale) + (1 - alpha) *
          scipy.stats.norm.pdf(observation, loc=neg_mu + B, scale=scale))
      p_z0 = (
          alpha * scipy.stats.norm.pdf(z0, loc=pos_mu, scale=np.sqrt(sigma_sq))
          + (1 - alpha) * scipy.stats.norm.pdf(
              z0, loc=neg_mu, scale=np.sqrt(sigma_sq)))
      p_zn_given_z0 = scipy.stats.norm.pdf(
          observation, loc=z0 + B, scale=np.sqrt(T * sigma_sq))
      # Bayes rule: p(z0 | zn) = p(z0) p(zn | z0) / p(zn).
      post_z0 = (p_z0 * p_zn_given_z0) / p_zn
      # plot q
      q_z0 = scipy.stats.norm.pdf(z0, loc=q0_loc, scale=q0_scale)
      ax = plt.subplot(gs[0:bar_rows, :])
      ax.plot(z0, q_z0, color="blue")
      ax.plot(z0, post_z0, color="green")
      ax.plot(z0, p_z0, color="red")
      ax.legend(("q", "posterior", "prior"), loc="best", prop={"size": 10})
      ax.set_xticks([])
      ax.set_xlim(*lims)
      # plot the states
      for t in range(num_timesteps):
        start = (t + 1) * (bar_rows + 1)
        ax1 = plt.subplot(gs[start:start + bar_rows, :])
        ax2 = plt.subplot(gs[start + bar_rows:start + bar_rows + 1, :])
        # plot the states barplot
        # ax1.hist(
        #     states[t, :],
        #     weights=weights[t, :],
        #     bins=50,
        #     edgecolor="none",
        #     alpha=0.2)
        ax1.bar(states[t,:], weights[t,:], width=0.02, alpha=0.2, edgecolor = "none")
        ax1.set_ylabel("t=%d" % t)
        ax1.set_xticks([])
        ax1.grid(True, which="both")
        ax1.set_xlim(*lims)
        # plot the observation
        ax1.axvline(x=observation, color="red", linestyle="dashed")
        # add the ESS
        ax1.text(0.1, 0.9, "ESS: %0.2f" % ess[t],
                 ha='center', va='center', transform=ax1.transAxes)
        # plot the state location histogram
        ax2.hist2d(
            states[t, :], np.zeros_like(states[t, :]), bins=[50, 1], cmap="Greys")
        ax2.grid(False)
        ax2.set_yticks([])
        ax2.set_xlim(*lims)
        if t != num_timesteps - 1:
          ax2.set_xticks([])
      fig.canvas.draw()
      # NOTE(review): np.fromstring on binary data is deprecated; newer numpy
      # versions require np.frombuffer here.
      p = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
      plots.append(p.reshape(fig.canvas.get_width_height()[::-1] + (3,)))
      plt.close(fig)
    return np.stack(plots)

  plots = tf.py_func(_plot_states,
                     [states, weights, observation, ess, (q0_loc, q0_scale), post],
                     [tf.uint8])[0]
  tf.summary.image("states", plots, 5, collections=["infrequent_summaries"])
def plot_weights(weights, resampled=None):
  """Plots the weights and effective sample size from an SMC rollout.

  One stacked-area plot of the normalized weights is produced per batch
  element and attached as an infrequent image summary; timesteps at which
  resampling occurred are marked with dashed red vertical lines.

  Args:
    weights: [num_timesteps, num_samples, batch_size] importance weights
    resampled: [num_timesteps] 0/1 indicating if resampling occurred.
      Defaults to all zeros (no resampling markers).
  """
  weights = tf.convert_to_tensor(weights)

  def _make_plots(weights, resampled):
    """Renders the per-batch-element figures; runs as a py_func on numpy."""
    num_timesteps, num_samples, batch_size = weights.shape
    plots = []
    for i in range(batch_size):
      fig, axes = plt.subplots(nrows=1, sharex=True, figsize=(8, 4))
      axes.stackplot(np.arange(num_timesteps), np.transpose(weights[:, :, i]))
      axes.set_title("Weights")
      axes.set_xlabel("Steps")
      axes.set_ylim([0, 1])
      axes.set_xlim([0, num_timesteps - 1])
      for j in np.where(resampled > 0)[0]:
        axes.axvline(x=j, color="red", linestyle="dashed", ymin=0.0, ymax=1.0)
      fig.canvas.draw()
      # np.fromstring on binary data is deprecated; np.frombuffer is the
      # supported replacement and avoids a copy.
      data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
      data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
      plots.append(data)
      plt.close(fig)
    return np.stack(plots, axis=0)

  if resampled is None:
    num_timesteps, _, batch_size = weights.get_shape().as_list()
    resampled = tf.zeros([num_timesteps], dtype=tf.float32)
  plots = tf.py_func(_make_plots,
                     [tf.nn.softmax(weights, dim=1),
                      tf.to_float(resampled)], [tf.uint8])[0]
  batch_size = weights.get_shape().as_list()[-1]
  tf.summary.image(
      "weights", plots, batch_size, collections=["infrequent_summaries"])
def summarize_weights(weights, num_timesteps, num_samples):
  """Adds per-timestep variance, magnitude, and histogram weight summaries.

  Args:
    weights: [num_timesteps, num_samples, batch_size] importance weights.
    num_timesteps: Number of timesteps to summarize.
    num_samples: Number of particles (normalizer for the unbiased variance).
  """
  weights = tf.convert_to_tensor(weights)
  # Unbiased variance over the sample axis, then averaged over the batch.
  centered = weights - tf.reduce_mean(weights, axis=1, keepdims=True)
  variances = tf.reduce_mean(
      tf.reduce_sum(tf.square(centered), axis=1) / (num_samples - 1), axis=1)
  avg_magnitude = tf.reduce_mean(tf.abs(weights), axis=[1, 2])
  for t in xrange(num_timesteps):
    tf.summary.scalar("weights/variance_%d" % t, variances[t])
    tf.summary.scalar("weights/magnitude_%d" % t, avg_magnitude[t])
    tf.summary.histogram("weights/step_%d" % t, weights[t])
def summarize_learning_signal(rewards, tag):
  """Summarizes mean, magnitude, and second moment of the learning signal.

  Args:
    rewards: [num_resampling_events, batch_size] learning signal values.
    tag: Prefix for the emitted summary names.
  """
  num_resampling_events, _ = rewards.get_shape().as_list()
  means = tf.reduce_mean(rewards, axis=1)
  magnitudes = tf.reduce_mean(tf.abs(rewards), axis=1)
  second_moments = tf.reduce_mean(tf.square(rewards), axis=1)
  for t in xrange(num_resampling_events):
    tf.summary.scalar("%s/mean_%d" % (tag, t), means[t])
    tf.summary.scalar("%s/magnitude_%d" % (tag, t), magnitudes[t])
    tf.summary.scalar("%s/squared_%d" % (tag, t), second_moments[t])
    tf.summary.histogram("%s/step_%d" % (tag, t), rewards[t])
def summarize_qs(model, observation, states):
  """Summarizes how close q is to the true posterior, when p exposes one.

  Emits per-timestep KL(posterior || q) plus absolute and relative errors
  of q's mean and variance against the posterior's.
  """
  model.q.summarize_weights()
  if not (hasattr(model.p, "posterior") and
          callable(getattr(model.p, "posterior"))):
    return
  # Shift right by one so entry t holds the state preceding timestep t.
  prev_states = [tf.zeros_like(states[0])] + states[:-1]
  for t, prev_state in enumerate(prev_states):
    posterior = model.p.posterior(observation, prev_state, t)
    proposal = model.q.q_zt(observation, prev_state, t)
    kl = tf.reduce_mean(
        tf.contrib.distributions.kl_divergence(posterior, proposal))
    tf.summary.scalar("kl_q/%d" % t, tf.reduce_mean(kl))
    mean_diff = proposal.loc - posterior.loc
    tf.summary.scalar("q_mean_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(mean_diff)))
    tf.summary.scalar("q_mean_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(mean_diff / posterior.loc)))
    sigma_diff = tf.square(proposal.scale) - tf.square(posterior.scale)
    tf.summary.scalar("q_variance_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(sigma_diff)))
    tf.summary.scalar("q_variance_convergence/relative_error_%d" % t,
                      tf.reduce_mean(
                          tf.abs(sigma_diff / tf.square(posterior.scale))))
def summarize_rs(model, states):
  """Summarizes how close the learned r is to the exact lookahead.

  Emits per-timestep KL(lookahead || r) plus absolute and relative errors
  of r's mean and variance against the exact lookahead distribution's.
  """
  model.r.summarize_weights()
  for t, state in enumerate(states):
    exact = model.p.lookahead(state, t)
    learned = model.r.r_xn(state, t)
    kl = tf.reduce_mean(
        tf.contrib.distributions.kl_divergence(exact, learned))
    tf.summary.scalar("kl_r/%d" % t, tf.reduce_mean(kl))
    mean_diff = exact.loc - learned.loc
    tf.summary.scalar("r_mean_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(mean_diff)))
    tf.summary.scalar("r_mean_convergence/relative_error_%d" % t,
                      tf.reduce_mean(tf.abs(mean_diff / exact.loc)))
    sigma_diff = tf.square(learned.scale) - tf.square(exact.scale)
    tf.summary.scalar("r_variance_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(tf.abs(sigma_diff)))
    tf.summary.scalar("r_variance_convergence/relative_error_%d" % t,
                      tf.reduce_mean(
                          tf.abs(sigma_diff / tf.square(exact.scale))))
def summarize_model(model, true_bs, observation, states, bound, summarize_r=True):
  """Summarizes the model's learned bs against the data generating process.

  The observation/states/bound/summarize_r arguments are retained for the
  currently disabled q and r diagnostics below.
  """
  if hasattr(model.p, "bs"):
    learned_b = tf.reduce_sum(model.p.bs, axis=0)
    true_b = tf.reduce_sum(true_bs, axis=0)
    abs_err = tf.abs(learned_b - true_b)
    tf.summary.scalar("sum_of_bs/data_generating_process",
                      tf.reduce_mean(true_b))
    tf.summary.scalar("sum_of_bs/model", tf.reduce_mean(learned_b))
    tf.summary.scalar("sum_of_bs/absolute_error", tf.reduce_mean(abs_err))
    tf.summary.scalar("sum_of_bs/relative_error",
                      tf.reduce_mean(abs_err / true_b))
  #summarize_qs(model, observation, states)
  #if bound == "fivo-aux" and summarize_r:
  #  summarize_rs(model, states)
def summarize_grads(grads, loss_name):
  """Summarizes EMA estimates of gradient mean and variance.

  Args:
    grads: List of (gradient, variable) pairs from compute_gradients.
    loss_name: Name used to tag the emitted summaries.

  Returns:
    The op that maintains the gradient EMAs.
  """
  ema = tf.train.ExponentialMovingAverage(decay=0.99)
  # Flatten all gradients into one vector so moments are tracked jointly.
  flat_grads = tf.concat(
      [tf.reshape(g, [-1]) for g, _ in grads if g is not None], axis=0)
  first_moments = flat_grads
  second_moments = tf.square(flat_grads)
  maintain_grad_ema_op = ema.apply([first_moments, second_moments])
  mean_estimate = ema.average(first_moments)
  square_estimate = ema.average(second_moments)
  variance_estimate = square_estimate - tf.square(mean_estimate)
  tf.summary.scalar("grad_variance/%s" % loss_name,
                    tf.reduce_mean(variance_estimate))
  tf.summary.histogram("grad_variance/%s" % loss_name, variance_estimate)
  tf.summary.histogram("grad_mean/%s" % loss_name, mean_estimate)
  return maintain_grad_ema_op
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main script for running fivo"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import numpy as np
import tensorflow as tf
import bounds
import data
import models
import summary_utils as summ
tf.logging.set_verbosity(tf.logging.INFO)

# --- Experiment setup. ---
tf.app.flags.DEFINE_integer("random_seed", None,
                            "A random seed for the data generating process. Same seed "
                            "-> same data generating process and initialization.")
tf.app.flags.DEFINE_enum("bound", "fivo", ["iwae", "fivo", "fivo-aux", "fivo-aux-td"],
                         "The bound to optimize.")
tf.app.flags.DEFINE_enum("model", "forward", ["forward", "long_chain"],
                         "The model to use.")
tf.app.flags.DEFINE_enum("q_type", "normal",
                         ["normal", "simple_mean", "prev_state", "observation"],
                         "The parameterization to use for q")
tf.app.flags.DEFINE_enum("p_type", "unimodal", ["unimodal", "bimodal", "nonlinear"],
                         "The type of prior.")
tf.app.flags.DEFINE_boolean("train_p", True,
                            "If false, do not train the model p.")

# --- Model / data-generating-process parameters. ---
tf.app.flags.DEFINE_integer("state_size", 1,
                            "The dimensionality of the state space.")
tf.app.flags.DEFINE_float("variance", 1.0,
                          "The variance of the data generating process.")
tf.app.flags.DEFINE_boolean("use_bs", True,
                            "If False, initialize all bs to 0.")
tf.app.flags.DEFINE_float("bimodal_prior_weight", 0.5,
                          "The weight assigned to the positive mode of the prior in "
                          "both the data generating process and p.")
tf.app.flags.DEFINE_float("bimodal_prior_mean", None,
                          "If supplied, sets the mean of the 2 modes of the prior to "
                          "be 1 and -1 times the supplied value. This is for both the "
                          "data generating process and p.")
tf.app.flags.DEFINE_float("fixed_observation", None,
                          "If supplied, fix the observation to a constant value in the"
                          " data generating process only.")
tf.app.flags.DEFINE_float("r_sigma_init", 1.,
                          "Value to initialize variance of r to.")
tf.app.flags.DEFINE_enum("observation_type",
                         models.STANDARD_OBSERVATION, models.OBSERVATION_TYPES,
                         "The type of observation for the long chain model.")
tf.app.flags.DEFINE_enum("transition_type",
                         models.STANDARD_TRANSITION, models.TRANSITION_TYPES,
                         "The type of transition for the long chain model.")
tf.app.flags.DEFINE_float("observation_variance", None,
                          "The variance of the observation. Defaults to 'variance'")
tf.app.flags.DEFINE_integer("num_timesteps", 5,
                            "Number of timesteps in the sequence.")
tf.app.flags.DEFINE_integer("num_observations", 1,
                            "The number of observations.")
tf.app.flags.DEFINE_integer("steps_per_observation", 5,
                            "The number of timesteps between each observation.")

# --- Training / inference parameters. ---
tf.app.flags.DEFINE_integer("batch_size", 4,
                            "The number of examples per batch.")
tf.app.flags.DEFINE_integer("num_samples", 4,
                            "The number particles to use.")
tf.app.flags.DEFINE_integer("num_eval_samples", 512,
                            "The batch size and # of particles to use for eval.")
tf.app.flags.DEFINE_string("resampling", "always",
                           "How to resample. Accepts 'always','never', or a "
                           "comma-separated list of booleans like 'true,true,false'.")
tf.app.flags.DEFINE_enum("resampling_method", "multinomial", ["multinomial",
                                                              "stratified",
                                                              "systematic",
                                                              "relaxed-logblend",
                                                              "relaxed-stateblend",
                                                              "relaxed-linearblend",
                                                              "relaxed-stateblend-st",],
                         "Type of resampling method to use.")
tf.app.flags.DEFINE_boolean("use_resampling_grads", True,
                            "Whether or not to use resampling grads to optimize FIVO."
                            "Disabled automatically if resampling_method=relaxed.")
# Help text fixed: r is zeroed out when this flag is True, not False (the
# models guard r usage with `if not self.disable_r`).
tf.app.flags.DEFINE_boolean("disable_r", False,
                            "If true, r is not used for fivo-aux and is set to zeros.")
tf.app.flags.DEFINE_float("learning_rate", 1e-4,
                          "The learning rate to use for ADAM or SGD.")
tf.app.flags.DEFINE_integer("decay_steps", 25000,
                            "The number of steps before the learning rate is halved.")
tf.app.flags.DEFINE_integer("max_steps", int(1e6),
                            "The number of steps to run training for.")

# --- Logging. ---
tf.app.flags.DEFINE_string("logdir", "/tmp/fivo-aux",
                           "Directory for summaries and checkpoints.")
tf.app.flags.DEFINE_integer("summarize_every", int(1e3),
                            "The number of steps between each evaluation.")

FLAGS = tf.app.flags.FLAGS
def combine_grad_lists(grad_lists):
  """Sums gradients for each variable across multiple per-loss grad lists.

  Args:
    grad_lists: A list (one entry per loss) of lists of (gradient, variable)
      pairs as produced by Optimizer.compute_gradients. Different losses may
      cover different variable sets.

  Returns:
    A single list of (gradient, variable) pairs where each gradient is the
    sum of that variable's gradients across all losses, or None if no loss
    produced a gradient for it.
  """
  grads_dict = defaultdict(list)
  var_dict = {}
  for grad_list in grad_lists:
    for grad, var in grad_list:
      if grad is not None:
        grads_dict[var.name].append(grad)
      var_dict[var.name] = var
  final_grads = []
  # .items() instead of the Python-2-only .iteritems(); identical behavior
  # and compatible with Python 3 (this file already imports __future__).
  for var_name, var in var_dict.items():
    grads = grads_dict[var_name]
    if grads:
      tf.logging.info("Var %s has combined grads from %s." %
                      (var_name, [g.name for g in grads]))
      grad = tf.reduce_sum(grads, axis=0)
    else:
      tf.logging.info("Var %s has no grads" % var_name)
      grad = None
    final_grads.append((grad, var))
  return final_grads
def make_apply_grads_op(losses, global_step, learning_rate, lr_decay_steps):
  """Builds an op that applies the summed gradients of several losses.

  Args:
    losses: List of bounds.Loss namedtuples (name, loss, variable collection).
    global_step: Global step variable, incremented on each apply.
    learning_rate: Base learning rate for the Adam optimizer.
    lr_decay_steps: Number of steps over which the learning rate halves.

  Returns:
    A train op that applies all gradients and then updates the grad EMAs.
  """
  for entry in losses:
    assert isinstance(entry, bounds.Loss)
  lr = tf.train.exponential_decay(
      learning_rate, global_step, lr_decay_steps, 0.5, staircase=False)
  tf.summary.scalar("learning_rate", lr)
  opt = tf.train.AdamOptimizer(lr)
  ema_ops = []
  grads = []
  for loss_name, loss, loss_var_collection in losses:
    tf.logging.info("Computing grads of %s w.r.t. vars in collection %s" %
                    (loss_name, loss_var_collection))
    loss_grads = opt.compute_gradients(
        loss, var_list=tf.get_collection(loss_var_collection))
    ema_ops.append(summ.summarize_grads(loss_grads, loss_name))
    grads.append(loss_grads)
  all_grads = combine_grad_lists(grads)
  apply_grads_op = opt.apply_gradients(all_grads, global_step=global_step)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads_op]):
    return tf.group(*ema_ops)
def add_check_numerics_ops():
  """Adds chained check_numerics ops for all float tensors in the graph.

  Ops whose names match a known-noisy substring are skipped. Each check is
  chained on the previous one via control dependencies so they run in order.

  Returns:
    A grouped op that validates the last eligible tensor (and, transitively,
    all earlier ones).

  Raises:
    ValueError: If an eligible op lives inside a control flow context.
  """
  skip_substrings = ("logits/Log", "sample/Reshape", "log_prob/mul",
                     "log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape",
                     "entropy/Reshape", "entropy/LogSoftmax", "Categorical",
                     "Mean")
  check_op = []
  for op in tf.get_default_graph().get_operations():
    if any(s in op.name for s in skip_substrings):
      continue
    for output in op.outputs:
      if output.dtype not in [tf.float16, tf.float32, tf.float64]:
        continue
      if op._get_control_flow_context() is not None:  # pylint: disable=protected-access
        raise ValueError("`tf.add_check_numerics_ops() is not compatible "
                         "with TensorFlow control flow operations such as "
                         "`tf.cond()` or `tf.while_loop()`.")
      message = op.name + ":" + str(output.value_index)
      with tf.control_dependencies(check_op):
        check_op = [tf.check_numerics(output, message=message)]
  return tf.group(*check_op)
def create_long_chain_graph(bound, state_size, num_obs, steps_per_obs,
                            batch_size, num_samples, num_eval_samples,
                            resampling_schedule, use_resampling_grads,
                            learning_rate, lr_decay_steps, dtype="float64"):
  """Builds the train and eval graphs for the long chain model.

  Args:
    bound: The bound to optimize; "iwae" or a "fivo" variant.
    state_size: Dimensionality of the latent state.
    num_obs: Number of observations per sequence.
    steps_per_obs: Number of latent timesteps between observations.
    batch_size: Number of sequences per training batch.
    num_samples: Number of particles used for training.
    num_eval_samples: Number of particles used for eval.
    resampling_schedule: List of booleans indicating when to resample.
    use_resampling_grads: Whether FIVO uses resampling gradients.
    learning_rate: Base learning rate.
    lr_decay_steps: Number of steps over which the learning rate halves.
    dtype: Numpy-style dtype string for data and model.

  Returns:
    Tuple of (global_step, train_op, eval_log_p_hat, eval_likelihood);
    eval_likelihood is always zero for this model family.
  """
  num_timesteps = num_obs * steps_per_obs + 1
  # Make the dataset.
  dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Make the dataset for eval
  eval_dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()
  # Make the model.
  model = models.LongChainModel.create(
      state_size,
      num_obs,
      steps_per_obs,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=tf.as_dtype(dtype),
      disable_r=FLAGS.disable_r)
  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  elif "fivo" in bound:
    # Bug fix: this condition was `bound == "fivo" or "fivo-aux"`, which is
    # always truthy because "fivo-aux" is a non-empty string. The substring
    # test matches create_graph and is behavior-identical for all valid
    # bound values.
    (_, losses, ema_op, _, _) = bounds.fivo(
        model,
        observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=use_resampling_grads,
        resampling_type=FLAGS.resampling_method,
        aux=("aux" in bound),
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
        model,
        eval_observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=False,
        resampling_type="multinomial",
        aux=("aux" in bound),
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  summ.summarize_ess(eval_log_weights, only_last_timestep=True)
  tf.summary.scalar("log_p_hat", eval_log_p_hat)
  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
  # We can't calculate the likelihood for most of these models
  # so we just return zeros.
  eval_likelihood = tf.zeros([], dtype=dtype)
  return global_step, train_op, eval_log_p_hat, eval_likelihood
def create_graph(bound, state_size, num_timesteps, batch_size,
                 num_samples, num_eval_samples, resampling_schedule,
                 use_resampling_grads, learning_rate, lr_decay_steps,
                 train_p, dtype='float64'):
  """Builds the train and eval graphs for the forward model.

  Args:
    bound: One of "iwae", "fivo", "fivo-aux", or "fivo-aux-td".
    state_size: Dimensionality of the latent state.
    num_timesteps: Number of timesteps per sequence.
    batch_size: Number of sequences per training batch.
    num_samples: Number of particles used for training.
    num_eval_samples: Batch size AND number of particles used for eval.
    resampling_schedule: List of booleans indicating when to resample.
    use_resampling_grads: Whether FIVO uses resampling gradients.
    learning_rate: Base learning rate.
    lr_decay_steps: Number of steps over which the learning rate halves.
    train_p: Whether the model p is trained.
    dtype: Numpy-style dtype string for data and model.

  Returns:
    Tuple of (global_step, train_op, eval_log_p_hat, eval_likelihood).
  """
  if FLAGS.use_bs:
    true_bs = None
  else:
    # Force the data generating process's bs to zero.
    true_bs = [np.zeros([state_size]).astype(dtype) for _ in xrange(num_timesteps)]
  # Make the dataset.
  true_bs, dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Make the dataset for eval
  _, eval_dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=num_eval_samples,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()
  # Make the model. The TD bound needs the TDModel (with r_tilde); all other
  # bounds share the standard Model.
  if bound == "fivo-aux-td":
    model = models.TDModel.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)
  else:
    model = models.Model.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        r_sigma_init=FLAGS.r_sigma_init,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)
  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  elif "fivo" in bound:
    if bound == "fivo-aux-td":
      (_, losses, ema_op, _, _) = bounds.fivo_aux_td(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo_aux_td(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_eval_samples,
          summarize=True)
    else:
      (_, losses, ema_op, _, _) = bounds.fivo(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=use_resampling_grads,
          resampling_type=FLAGS.resampling_method,
          aux=("aux" in bound),
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=False,
          resampling_type="multinomial",
          aux=("aux" in bound),
          num_samples=num_eval_samples,
          summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  summ.summarize_ess(eval_log_weights, only_last_timestep=True)
  # Disabled diagnostics for the bimodal prior, kept for reference.
  # if FLAGS.p_type == "bimodal":
  #   # create the observations that showcase the model.
  #   mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
  #                                          dtype=tf.float64)
  #   mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
  #   k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
  #   explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
  #   explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
  #   # run the model on the explainable observations
  #   if bound == "iwae":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         num_samples=num_eval_samples)
  #   elif bound == "fivo" or "fivo-aux":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         resampling_schedule=resampling_schedule,
  #         use_resampling_grads=False,
  #         resampling_type="multinomial",
  #         aux=("aux" in bound),
  #         num_samples=num_eval_samples)
  #   summ.summarize_particles(explain_states,
  #                            explain_log_weights,
  #                            explain_obs,
  #                            model)
  # Calculate the true likelihood (per timestep) when p exposes one.
  if hasattr(model.p, 'likelihood') and callable(getattr(model.p, 'likelihood')):
    eval_likelihood = model.p.likelihood(eval_observations)/ FLAGS.num_timesteps
  else:
    eval_likelihood = tf.zeros_like(eval_log_p_hat)
  tf.summary.scalar("log_p_hat", eval_log_p_hat)
  tf.summary.scalar("likelihood", eval_likelihood)
  tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
  summ.summarize_model(model, true_bs, eval_observations, eval_states, bound,
                       summarize_r=not bound == "fivo-aux-td")
  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)
  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
    #train_op = tf.group(ema_op, add_check_numerics_ops())
  return global_step, train_op, eval_log_p_hat, eval_likelihood
def parse_resampling_schedule(schedule, num_timesteps):
  """Parses a resampling schedule string into a list of per-timestep booleans.

  Supported formats (case-insensitive, surrounding whitespace ignored):
    "always":   resample at every timestep except the last.
    "never":    never resample.
    "every_N":  resample every N timesteps (on timesteps N-1, 2N-1, ...).
    otherwise:  an explicit comma-separated list of "true"/"false" values,
                one per timestep.

  Args:
    schedule: The schedule specification string.
    num_timesteps: The number of timesteps in the sequence, an integer.

  Returns:
    A list of num_timesteps booleans, True where resampling should occur.

  Raises:
    AssertionError: If an explicit schedule has the wrong number of entries.
  """
  schedule = schedule.strip().lower()
  if schedule == "always":
    return [True] * (num_timesteps - 1) + [False]
  elif schedule == "never":
    return [False] * num_timesteps
  elif "every" in schedule:
    n = int(schedule.split("_")[1])
    # xrange was Python-2-only; range is the portable equivalent here.
    return [(i+1) % n == 0 for i in range(num_timesteps)]
  else:
    sched = [x.strip() == "true" for x in schedule.split(",")]
    assert len(
        sched
    ) == num_timesteps, "Wrong number of timesteps in resampling schedule."
    return sched
def create_log_hook(step, eval_log_p_hat, eval_likelihood):
  """Builds a LoggingTensorHook that periodically prints eval metrics.

  Args:
    step: The global step Tensor.
    eval_log_p_hat: Scalar Tensor with the current bound estimate.
    eval_likelihood: Scalar Tensor with the current true likelihood estimate.

  Returns:
    A tf.train.LoggingTensorHook that logs every FLAGS.summarize_every steps.
  """
  tensors_to_log = {
      "step": step,
      "log_p_hat": eval_log_p_hat,
      "likelihood": eval_likelihood,
  }
  def summ_formatter(d):
    return ("Step {step}, log p_hat: {log_p_hat:.5f} likelihood: {likelihood:.5f}".format(**d))
  return tf.train.LoggingTensorHook(
      tensors_to_log,
      every_n_iter=FLAGS.summarize_every,
      formatter=summ_formatter)
def create_infrequent_summary_hook():
  """Builds a hook that writes the 'infrequent_summaries' collection.

  Returns:
    A tf.train.SummarySaverHook saving to FLAGS.logdir every 10000 steps.
  """
  merged_infrequent = tf.summary.merge_all(key="infrequent_summaries")
  return tf.train.SummarySaverHook(
      save_steps=10000,
      output_dir=FLAGS.logdir,
      summary_op=merged_infrequent)
def main(unused_argv):
  """Validates flag combinations, builds the training graph, and trains.

  Builds either the long-chain graph or the standard graph depending on
  FLAGS.model, then runs the train op inside a MonitoredTrainingSession
  until FLAGS.max_steps is exceeded.
  """
  # The long chain model uses num_timesteps + 1 resampling slots; all other
  # models use exactly num_timesteps.
  if FLAGS.model == "long_chain":
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps + 1)
  else:
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps)
  # Pick a random seed unless one was supplied; log it so runs can be
  # reproduced later.
  if FLAGS.random_seed is None:
    seed = np.random.randint(0, high=10000)
  else:
    seed = FLAGS.random_seed
  tf.logging.info("Using random seed %d", seed)
  # Reject flag combinations the long chain model does not support.
  if FLAGS.model == "long_chain":
    assert FLAGS.q_type == "normal", "Q type %s not supported for long chain models" % FLAGS.q_type
    assert FLAGS.p_type == "unimodal", "Bimodal priors are not supported for long chain models"
    assert not FLAGS.use_bs, "Bs are not supported with long chain models"
    assert FLAGS.num_timesteps == FLAGS.num_observations * FLAGS.steps_per_observation, "Num timesteps does not match."
    assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with long chain models."
  # Reject flag combinations the forward model does not support.
  if FLAGS.model == "forward":
    if "nonlinear" not in FLAGS.p_type:
      assert FLAGS.transition_type == models.STANDARD_TRANSITION, "Non-standard transitions not supported by the forward model."
      assert FLAGS.observation_type == models.STANDARD_OBSERVATION, "Non-standard observations not supported by the forward model."
    assert FLAGS.observation_variance is None, "Forward model does not support observation variance."
    assert FLAGS.num_observations == 1, "Forward model only supports 1 observation."
  # Relaxed resampling is differentiable, so resampling gradients are
  # disabled for it.
  if "relaxed" in FLAGS.resampling_method:
    FLAGS.use_resampling_grads = False
    assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with relaxed resampling."
  # Default the observation variance to the transition variance.
  if FLAGS.observation_variance is None:
    FLAGS.observation_variance = FLAGS.variance
  if FLAGS.p_type == "bimodal":
    assert FLAGS.bimodal_prior_mean is not None, "Must specify prior mean if using bimodal p."
  if FLAGS.p_type == "nonlinear" or FLAGS.p_type == "nonlinear-cauchy":
    assert not FLAGS.use_bs, "Using bs is not compatible with the nonlinear model."
  g = tf.Graph()
  with g.as_default():
    # Set the seeds.
    tf.set_random_seed(seed)
    np.random.seed(seed)
    # Build the training graph for the requested model family.
    if FLAGS.model == "long_chain":
      (global_step, train_op, eval_log_p_hat,
       eval_likelihood) = create_long_chain_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_observations,
           FLAGS.steps_per_observation,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps)
    else:
      (global_step, train_op,
       eval_log_p_hat, eval_likelihood) = create_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_timesteps,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps,
           FLAGS.train_p)
    log_hooks = [create_log_hook(global_step, eval_log_p_hat, eval_likelihood)]
    if len(tf.get_collection("infrequent_summaries")) > 0:
      log_hooks.append(create_infrequent_summary_hook())
    # Log the variable partitions so it is easy to check which parameters
    # belong to p, q, r, and r-tilde.
    tf.logging.info("trainable variables:")
    tf.logging.info([v.name for v in tf.trainable_variables()])
    tf.logging.info("p vars:")
    tf.logging.info([v.name for v in tf.get_collection("P_VARS")])
    tf.logging.info("q vars:")
    tf.logging.info([v.name for v in tf.get_collection("Q_VARS")])
    tf.logging.info("r vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_VARS")])
    tf.logging.info("r tilde vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_TILDE_VARS")])
    # MonitoredTrainingSession handles checkpointing, summaries, and step
    # counting; the loop only runs train ops and checks the stop condition.
    with tf.train.MonitoredTrainingSession(
        master="",
        is_chief=True,
        hooks=log_hooks,
        checkpoint_dir=FLAGS.logdir,
        save_checkpoint_secs=120,
        save_summaries_steps=FLAGS.summarize_every,
        log_step_count_steps=FLAGS.summarize_every) as sess:
      cur_step = -1
      while True:
        if sess.should_stop() or cur_step > FLAGS.max_steps:
          break
        # run a step
        _, cur_step = sess.run([train_op, global_step])
# Script entry point: tf.app.run parses flags and invokes main.
if __name__ == "__main__":
  tf.app.run(main)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of objectives for training stochastic latent variable models.
Contains implementations of the Importance Weighted Autoencoder objective (IWAE)
and the Filtering Variational objective (FIVO).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
from fivo import nested_utils as nested
from fivo import smc
def iwae(model,
         observations,
         seq_lengths,
         num_samples=1,
         parallel_iterations=30,
         swap_memory=True):
  """Computes the IWAE lower bound on the log marginal probability.

  The IWAE bound ("Importance Weighted Autoencoders", Burda et al.,
  https://arxiv.org/abs/1509.00519) averages multiple importance weights to
  form a stochastic lower bound on the log marginal probability of the
  observations. It is implemented here as a particle filter that never
  resamples; with num_samples=1 it reduces to the ELBO.

  Args:
    model: A subclass of ELBOTrainableSequenceModel implementing one timestep
      of the model (see models/vrnn.py for an example).
    observations: The inputs to the model; a potentially nested structure of
      Tensors, each of shape [max_seq_len, batch_size, ...] with matching
      time and batch dimensions. Provided to the model before computing the
      bound.
    seq_lengths: A [batch_size] int Tensor of sequence lengths (sequences may
      be padded to a common length).
    num_samples: The number of importance samples to use.
    parallel_iterations: Parallel iterations for the internal while loop.
    swap_memory: Whether GPU-CPU memory swapping is enabled for the internal
      while loop.

  Returns:
    log_p_hat: A [batch_size] Tensor with IWAE's estimate of the log marginal
      probability of the observations.
    log_weights: A [max_seq_len, batch_size, num_samples] Tensor of the log
      weights at each timestep (invalid past the end of a sequence).
    final_state: The final state of the sampler.
  """
  # IWAE is exactly FIVO with resampling disabled; drop the (constant)
  # resampled indicator from FIVO's outputs.
  bound_outs = fivo(
      model,
      observations,
      seq_lengths,
      num_samples=num_samples,
      resampling_criterion=smc.never_resample_criterion,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)
  log_p_hat, log_weights, _, final_state = bound_outs
  return log_p_hat, log_weights, final_state
def fivo(model,
         observations,
         seq_lengths,
         num_samples=1,
         resampling_criterion=smc.ess_criterion,
         resampling_type='multinomial',
         relaxed_resampling_temperature=0.5,
         parallel_iterations=30,
         swap_memory=True,
         random_seed=None):
  """Computes the FIVO lower bound on the log marginal probability.

  Runs a batch of particle filters over the observations and uses each
  filter's unbiased estimate of the marginal likelihood as a stochastic lower
  bound, as described in "Filtering Variational Objectives" (Maddison et al.,
  https://arxiv.org/abs/1705.09279). With a never-resample criterion this
  bound becomes IWAE.

  Args:
    model: A subclass of ELBOTrainableSequenceModel implementing one timestep
      of the model (see models/vrnn.py for an example).
    observations: The inputs to the model; a potentially nested structure of
      Tensors, each of shape [max_seq_len, batch_size, ...] with matching
      time and batch dimensions. Provided to the model before computing the
      bound.
    seq_lengths: A [batch_size] int Tensor of sequence lengths (sequences may
      be padded to a common length).
    num_samples: The number of particles per particle filter.
    resampling_criterion: Callable deciding when to resample. Accepts the
      number of samples, the current log weights, and the current timestep,
      and returns a [batch_size] boolean Tensor. When it is
      never_resample_criterion, the resampling function is never called.
    resampling_type: One of "multinomial" or "relaxed".
    relaxed_resampling_temperature: A positive temperature, used only for
      relaxed resampling.
    parallel_iterations: Parallel iterations for the internal while loop.
      Values greater than 1 can be non-deterministic even with random_seed.
    swap_memory: Whether GPU-CPU memory swapping is enabled for the internal
      while loop.
    random_seed: Seed forwarded to the resampling ops (mainly for testing).

  Returns:
    log_p_hat: A [batch_size] Tensor with FIVO's estimate of the log marginal
      probability of the observations.
    log_weights: A [max_seq_len, batch_size, num_samples] Tensor of the log
      weights at each timestep (reset to 0 when a resampling op fires;
      invalid past the end of a sequence).
    resampled: A [max_seq_len, batch_size] Tensor that is 1.0 on timesteps
      where resampling occurred and 0.0 otherwise.
    final_state: The final state of each particle filter.
  """
  # One independent particle filter runs per batch element.
  num_filters = tf.shape(seq_lengths)[0]
  # Lay the tiled batch out particle-major:
  #   particle 1 of filter 1, particle 1 of filter 2, ...,
  #   particle 1 of filter num_filters, particle 2 of filter 1, ...,
  #   particle num_samples of filter num_filters.
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)
  if resampling_type == 'multinomial':
    resample = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resample = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  resample = functools.partial(resample, random_seed=random_seed)

  def transition_fn(prev_state, t):
    # A None state signals the first call; return the model's zero state.
    if prev_state is None:
      return model.zero_state(num_filters * num_samples, tf.float32)
    return model.propose_and_weight(prev_state, t)

  log_p_hat, log_weights, resampled, final_state, _ = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resample,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)
  return log_p_hat, log_weights, resampled, final_state
def fivo_aux_td(
    model,
    observations,
    seq_lengths,
    num_samples=1,
    resampling_criterion=smc.ess_criterion,
    resampling_type='multinomial',
    relaxed_resampling_temperature=0.5,
    parallel_iterations=30,
    swap_memory=True,
    random_seed=None):
  """Experimental: FIVO with an auxiliary tilt (r) trained by a TD update.

  Runs a particle filter like fivo() but augments the per-timestep weights
  with a stop-gradient log_r - prev_log_r term and accumulates a Bellman-style
  squared-error loss for training the r-tilde approximation. The argument
  semantics match fivo(); the return signature differs (see below).

  Returns:
    loss_per_seq: A list [-log_p_hat, bellman_loss] of per-sequence losses.
    log_p_hat: Accumulated log marginal estimates.
    log_weights: Per-timestep log weights from the particle filter.
    resampled: Per-timestep resampling indicators.
  """
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]
  max_seq_len = tf.reduce_max(seq_lengths)
  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)
  # Select the resampling function; same scheme as in fivo().
  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    # State is (log_r, (r_tilde_mu, r_tilde_sigma_sq), model_state); on the
    # first call (prev_state is None) everything starts at zero.
    if prev_state is None:
      model_init_state = model.zero_state(batch_size * num_samples, tf.float32)
      return (tf.zeros([num_samples*batch_size], dtype=tf.float32),
              (tf.zeros([num_samples*batch_size, model.latent_size], dtype=tf.float32),
               tf.zeros([num_samples*batch_size, model.latent_size], dtype=tf.float32)),
              model_init_state)
    prev_log_r, prev_log_r_tilde, prev_model_state = prev_state
    (new_model_state, zt, log_q_zt, log_p_zt,
     log_p_x_given_z, log_r_tilde, p_ztplus1) = model(prev_model_state, t)
    # NOTE(review): despite its name, log_r_tilde is a (mu, sigma_sq) pair
    # here, not a log density.
    r_tilde_mu, r_tilde_sigma_sq = log_r_tilde
    # Compute the weight without r.
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    # Compute log_r and log_r_tilde. The prior moments are stop-gradiented so
    # log_r only trains the r-tilde parameters.
    p_mu = tf.stop_gradient(p_ztplus1.mean())
    p_sigma_sq = tf.stop_gradient(p_ztplus1.variance())
    log_r = (tf.log(r_tilde_sigma_sq) -
             tf.log(r_tilde_sigma_sq + p_sigma_sq) -
             tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
    # log_r is [num_samples*batch_size, latent_size]. We sum it along the last
    # dimension to compute log r.
    log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
    # Compute prev log r tilde
    prev_r_tilde_mu, prev_r_tilde_sigma_sq = prev_log_r_tilde
    prev_log_r_tilde = -0.5*tf.reduce_sum(
        tf.square(tf.stop_gradient(zt) - prev_r_tilde_mu)/prev_r_tilde_sigma_sq, axis=-1)
    # If the sequence is on the last timestep, log_r and log_r_tilde are just zeros.
    last_timestep = t >= (tiled_seq_lengths - 1)
    log_r = tf.where(last_timestep,
                     tf.zeros_like(log_r),
                     log_r)
    prev_log_r_tilde = tf.where(last_timestep,
                                tf.zeros_like(prev_log_r_tilde),
                                prev_log_r_tilde)
    # The auxiliary term enters the weight through a stop_gradient, so it
    # shifts the bound but contributes no gradient of its own.
    log_weight += tf.stop_gradient(log_r - prev_log_r)
    new_state = (log_r, log_r_tilde, new_model_state)
    loop_fn_args = (log_r, prev_log_r_tilde, log_p_x_given_z, log_r - prev_log_r)
    return log_weight, new_state, loop_fn_args

  def loop_fn(loop_state, loop_args, unused_model_state, log_weights, resampled, mask, t):
    # Accumulators: (log_p_hat, bellman_loss, log_r_diff); None means init.
    if loop_state is None:
      return (tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([num_samples, batch_size], dtype=tf.float32))
    log_p_hat_acc, bellman_loss_acc, log_r_diff_acc = loop_state
    log_r, prev_log_r_tilde, log_p_x_given_z, log_r_diff = loop_args
    # Compute the log_p_hat update
    log_p_hat_update = tf.reduce_logsumexp(
        log_weights, axis=0) - tf.log(tf.to_float(num_samples))
    # If it is the last timestep, we always add the update.
    log_p_hat_acc += tf.cond(t >= max_seq_len-1,
                             lambda: log_p_hat_update,
                             lambda: log_p_hat_update * resampled)
    # Compute the Bellman update.
    log_r = tf.reshape(log_r, [num_samples, batch_size])
    prev_log_r_tilde = tf.reshape(prev_log_r_tilde, [num_samples, batch_size])
    log_p_x_given_z = tf.reshape(log_p_x_given_z, [num_samples, batch_size])
    mask = tf.reshape(mask, [num_samples, batch_size])
    # On the first timestep there is no bellman error because there is no
    # prev_log_r_tilde.
    mask = tf.cond(tf.equal(t, 0),
                   lambda: tf.zeros_like(mask),
                   lambda: mask)
    # On the first timestep also fix up prev_log_r_tilde, which will be -inf.
    prev_log_r_tilde = tf.where(
        tf.is_inf(prev_log_r_tilde),
        tf.zeros_like(prev_log_r_tilde),
        prev_log_r_tilde)
    # log_lambda is [num_samples, batch_size]
    log_lambda = tf.reduce_mean(prev_log_r_tilde - log_p_x_given_z - log_r,
                                axis=0, keepdims=True)
    # Squared TD-style error; the target is stop-gradiented so only
    # prev_log_r_tilde is trained by this term.
    bellman_error = mask * tf.square(
        prev_log_r_tilde -
        tf.stop_gradient(log_lambda + log_p_x_given_z + log_r)
    )
    bellman_loss_acc += tf.reduce_mean(bellman_error, axis=0)
    # Compute the log_r_diff update
    log_r_diff_acc += mask * tf.reshape(log_r_diff, [num_samples, batch_size])
    return (log_p_hat_acc, bellman_loss_acc, log_r_diff_acc)

  log_weights, resampled, accs = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      loop_fn=loop_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)
  log_p_hat, bellman_loss, log_r_diff = accs
  loss_per_seq = [- log_p_hat, bellman_loss]
  # Summaries are normalized per timestep for comparability across lengths.
  tf.summary.scalar("bellman_loss",
                    tf.reduce_mean(bellman_loss / tf.to_float(seq_lengths)))
  tf.summary.scalar("log_r_diff",
                    tf.reduce_mean(tf.reduce_mean(log_r_diff, axis=0) / tf.to_float(seq_lengths)))
  return loss_per_seq, log_p_hat, log_weights, resampled
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.bounds"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from fivo.test_utils import create_vrnn
from fivo import bounds
class BoundsTest(tf.test.TestCase):
  """Golden-value regression tests for the IWAE and FIVO bounds.

  The expected values below were captured from runs at fixed random seeds;
  they pin current behavior rather than derive it analytically.
  """

  def test_elbo(self):
    """A golden-value test for the ELBO (the IWAE bound with num_samples=1)."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=1,
                         parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, _, _ = sess.run(outs)
      self.assertAllClose([-21.615765, -13.614225], log_p_hat)

  def test_iwae(self):
    """A golden-value test for the IWAE bound."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=4,
                         parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, _ = sess.run(outs)
      self.assertAllClose([-23.301426, -13.64028], log_p_hat)
      # weights_gt is [max_seq_len, batch_size, num_samples]; later rows of
      # the shorter sequence repeat because they are past its end.
      weights_gt = np.array(
          [[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
            [-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
           [[-6.2539978, -4.37615728, -7.43738699, -7.85044909],
            [-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
           [[-9.19093227, -8.01637268, -11.64603615, -10.51128292],
            [-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
           [[-12.20609856, -10.47217369, -13.66270638, -13.46115875],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-16.14766312, -15.57472229, -17.47755432, -17.98189926],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-20.07182884, -18.43191147, -20.1606636, -21.45263863],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
           [[-24.10270691, -22.20865822, -24.14675522, -25.27248383],
            [-17.17656708, -16.25190353, -15.28658581, -12.33067703]]])
      self.assertAllClose(weights_gt, weights)

  def test_fivo(self):
    """A golden-value test for the FIVO bound."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1)
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-22.98902512, -14.21689224], log_p_hat)
      weights_gt = np.array(
          [[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
            [-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
           [[-2.67100811, -2.30541706, -2.34178066, -2.81751347],
            [-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
           [[-5.65190411, -5.94563246, -6.55041981, -5.4783473],
            [-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
           [[-8.71947861, -8.40143299, -8.54593086, -8.42822266],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-12.7003831, -13.5039815, -12.3569726, -12.9489622],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-16.4520301, -16.3611698, -15.0314846, -16.4197006],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
           [[-20.7010765, -20.1379165, -19.0020351, -20.2395458],
            [-4.28782988, -4.50591278, -3.40847206, -2.63650274]]])
      self.assertAllClose(weights_gt, weights)
      # resampled_gt is [max_seq_len, batch_size]; 1.0 marks a resampling
      # step for that particle filter.
      resampled_gt = np.array(
          [[1., 0.],
           [0., 0.],
           [0., 1.],
           [0., 0.],
           [0., 0.],
           [0., 0.],
           [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)

  def test_fivo_relaxed(self):
    """A golden-value test for the FIVO bound with relaxed sampling."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      model, inputs, targets, lengths = create_vrnn(random_seed=1234)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1,
                         resampling_type="relaxed")
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-22.942394, -14.273882], log_p_hat)
      weights_gt = np.array(
          [[[-3.66708851, -2.07074118, -4.91751575, -5.03293514],
            [-2.99690628, -3.17782831, -4.50084877, -3.48536515]],
           [[-2.84939098, -2.30087185, -2.35649204, -2.48417377],
            [-8.27518654, -6.71545172, -8.96199131, -7.05567837]],
           [[-5.92327023, -5.9433074, -6.5826683, -5.04259014],
            [-12.34527206, -11.54284668, -11.86675072, -9.69417477]],
           [[-8.95323944, -8.40061855, -8.52760506, -7.99130583],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-12.87836456, -13.49628639, -12.31680107, -12.74228859],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-16.78347397, -16.35150909, -14.98797417, -16.35162735],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
           [[-20.81165886, -20.1307621, -18.92229652, -20.17458153],
            [-4.58102798, -4.56017351, -3.46283388, -2.65550804]]])
      self.assertAllClose(weights_gt, weights)
      resampled_gt = np.array(
          [[1., 0.],
           [0., 0.],
           [0., 1.],
           [0., 0.],
           [0., 0.],
           [0., 0.],
           [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)

  def test_fivo_aux_relaxed(self):
    """A golden-value test for the FIVO-AUX bound with relaxed sampling."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      # use_tilt=True enables the auxiliary tilt (the AUX in FIVO-AUX).
      model, inputs, targets, lengths = create_vrnn(random_seed=1234,
                                                    use_tilt=True)
      outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
                         random_seed=1234, parallel_iterations=1,
                         resampling_type="relaxed")
      sess.run(tf.global_variables_initializer())
      log_p_hat, weights, resampled, _ = sess.run(outs)
      self.assertAllClose([-23.1395, -14.271059], log_p_hat)
      weights_gt = np.array(
          [[[-5.19826221, -3.55476403, -5.98663855, -6.08058834],
            [-6.31685925, -5.70243931, -7.07638931, -6.18138981]],
           [[-3.97986865, -3.58831525, -3.85753584, -3.5010016],
            [-11.38203049, -8.66213989, -11.23646641, -10.02024746]],
           [[-6.62269831, -6.36680222, -6.78096485, -5.80072498],
            [-3.55419445, -8.11326408, -3.48766923, -3.08593249]],
           [[-10.56472301, -10.16084099, -9.96741676, -8.5270071],
            [-6.04880285, -7.80853653, -4.72652149, -3.49711013]],
           [[-13.36585426, -16.08720398, -13.33416367, -13.1017189],
            [-0., -0., -0., -0.]],
           [[-17.54233551, -17.35167503, -16.79163361, -16.51471138],
            [0., -0., -0., -0.]],
           [[-19.74024963, -18.69452858, -17.76246452, -18.76182365],
            [0., -0., -0., -0.]]])
      self.assertAllClose(weights_gt, weights)
      resampled_gt = np.array([[1., 0.],
                               [0., 1.],
                               [0., 0.],
                               [0., 1.],
                               [0., 0.],
                               [0., 0.],
                               [0., 0.]])
      self.assertAllClose(resampled_gt, resampled)
if __name__ == "__main__":
  # np.nan is not a valid threshold on modern NumPy (an int is required);
  # sys.maxsize gives the intended "print everything" behavior.
  import sys
  np.set_printoptions(threshold=sys.maxsize)  # Used to easily see the gold values.
  # Use print(repr(numpy_array)) to print the values.
  tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to calculate the mean of a pianoroll dataset.
Given a pianoroll pickle file, this script loads the dataset and
calculates the mean of the training set. Then it updates the pickle file
so that the key "train_mean" points to the mean vector.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import numpy as np
import tensorflow as tf
from datasets import sparse_pianoroll_to_dense
# Command-line flags for the pianoroll mean script.
tf.app.flags.DEFINE_string('in_file', None,
                           'Filename of the pickled pianoroll dataset to load.')
tf.app.flags.DEFINE_string('out_file', None,
                           'Name of the output pickle file. Defaults to in_file, '
                           'updating the input pickle file.')
tf.app.flags.mark_flag_as_required('in_file')
FLAGS = tf.app.flags.FLAGS
# Note range of the pianoroll representation — presumably MIDI pitch numbers
# 21..108 (A0..C8); confirm against the dataset docs.
MIN_NOTE = 21
MAX_NOTE = 108
NUM_NOTES = MAX_NOTE - MIN_NOTE + 1
def main(unused_argv):
  """Computes the train-set mean of a pianoroll pickle and writes it back.

  Loads the pickle at FLAGS.in_file, densifies the 'train' split, computes
  the per-note mean over all timesteps, stores it under 'train_mean', and
  writes the whole dataset to FLAGS.out_file (defaulting to in_file).
  """
  if FLAGS.out_file is None:
    FLAGS.out_file = FLAGS.in_file
  # 'rb' is required: pickle data is binary and text-mode reads fail on
  # Python 3.
  with tf.gfile.Open(FLAGS.in_file, 'rb') as f:
    pianorolls = pickle.load(f)
  dense_pianorolls = [sparse_pianoroll_to_dense(p, MIN_NOTE, NUM_NOTES)[0]
                      for p in pianorolls['train']]
  # Concatenate all elements along the time axis.
  concatenated = np.concatenate(dense_pianorolls, axis=0)
  mean = np.mean(concatenated, axis=0)
  pianorolls['train_mean'] = mean
  # Write out the whole pickle file, including the train mean. Use a context
  # manager so the file is flushed and closed even if pickling fails (the
  # original leaked the handle).
  with open(FLAGS.out_file, 'wb') as f:
    pickle.dump(pianorolls, f)
# Script entry point.
if __name__ == '__main__':
  tf.app.run()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocesses TIMIT from raw wavfiles to create a set of TFRecords.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
import random
import re
import numpy as np
import tensorflow as tf
# Command-line flags for the TIMIT preprocessing script.
tf.app.flags.DEFINE_string("raw_timit_dir", None,
                           "Directory containing TIMIT files.")
tf.app.flags.DEFINE_string("out_dir", None,
                           "Output directory for TFRecord files.")
tf.app.flags.DEFINE_float("valid_frac", 0.05,
                          "Fraction of train set to use as valid set. "
                          "Must be between 0.0 and 1.0.")
tf.app.flags.mark_flag_as_required("raw_timit_dir")
tf.app.flags.mark_flag_as_required("out_dir")
FLAGS = tf.app.flags.FLAGS
# Expected file counts for the full TIMIT corpus, used as sanity checks.
NUM_TRAIN_FILES = 4620
NUM_TEST_FILES = 1680
# Number of raw audio samples grouped into one model timestep.
SAMPLES_PER_TIMESTEP = 200
# Regexes for reading SPHERE header files.
SAMPLE_COUNT_REGEX = re.compile(r"sample_count -i (\d+)")
SAMPLE_MIN_REGEX = re.compile(r"sample_min -i (-?\d+)")
SAMPLE_MAX_REGEX = re.compile(r"sample_max -i (-?\d+)")
def get_filenames(split):
  """Returns sorted wav paths for one TIMIT split (e.g. "TRAIN" or "TEST")."""
  pattern = os.path.join(FLAGS.raw_timit_dir, "TIMIT", split, "*", "*", "*.WAV")
  # Sorting keeps the ordering deterministic across filesystems and runs.
  return sorted(glob.glob(pattern))
def load_timit_wav(filename):
  """Loads a TIMIT wavfile into a numpy array.

  TIMIT wavfiles include a SPHERE header, detailed in the TIMIT docs. The first
  line is the header type and the second is the length of the header in bytes.
  After the header, the remaining bytes are actual WAV data.

  The header includes information about the WAV data such as the number of
  samples and minimum and maximum amplitude. This function asserts that the
  loaded wav data matches the header.

  Args:
    filename: The name of the TIMIT wavfile to load.

  Returns:
    wav: A float32 numpy array containing the loaded wav data.
  """
  # Use a context manager so the file handle is always closed (the original
  # code leaked it).
  with open(filename, "rb") as wav_file:
    header_type = wav_file.readline()
    header_length_str = wav_file.readline()
    # The header length includes the length of the first two lines.
    header_remaining_bytes = (int(header_length_str) - len(header_type) -
                              len(header_length_str))
    # Decode so the str-pattern regexes below work on Python 3, where reads
    # from a binary file return bytes (SPHERE headers are ASCII text).
    header = wav_file.read(header_remaining_bytes).decode("ascii", "replace")
    # Read the relevant header fields.
    sample_count = int(SAMPLE_COUNT_REGEX.search(header).group(1))
    sample_min = int(SAMPLE_MIN_REGEX.search(header).group(1))
    sample_max = int(SAMPLE_MAX_REGEX.search(header).group(1))
    # np.fromstring is deprecated for binary input; frombuffer is the
    # documented equivalent.
    wav = np.frombuffer(wav_file.read(), dtype="int16").astype("float32")
  # Check that the loaded data conforms to the header description.
  assert len(wav) == sample_count
  assert wav.min() == sample_min
  assert wav.max() == sample_max
  return wav
def preprocess(wavs, block_size, mean, std):
  """Standardizes each wav and chops it into [num_blocks, block_size] rows.

  Args:
    wavs: A list of 1-D numpy arrays of raw samples.
    block_size: Number of samples per output row.
    mean: Mean to subtract before scaling.
    std: Standard deviation to divide by.

  Returns:
    A list of 2-D arrays, one per input wav, each zero-padded at the end to a
    multiple of block_size and reshaped to (-1, block_size).
  """
  out = []
  for raw in wavs:
    normalized = (raw - mean) / std
    remainder = normalized.shape[0] % block_size
    if remainder != 0:
      # Zero-pad the tail so the data divides evenly into blocks.
      normalized = np.pad(normalized, (0, block_size - remainder), "constant")
    assert normalized.shape[0] % block_size == 0
    out.append(normalized.reshape((-1, block_size)))
  return out
def create_tfrecord_from_wavs(wavs, output_file):
  """Writes processed wav files to disk as sharded TFRecord files."""
  # NOTE(review): despite the docstring, this writes a single (unsharded)
  # record file; each record is the raw float32 bytes of one wav array.
  with tf.python_io.TFRecordWriter(output_file) as builder:
    for wav in wavs:
      builder.write(wav.astype(np.float32).tobytes())
def main(unused_argv):
  """Builds train/valid/test TFRecord datasets from the raw TIMIT corpus."""
  all_train_filenames = get_filenames("TRAIN")
  test_filenames = get_filenames("TEST")
  # Carve a validation set out of the official TRAIN split.
  num_valid_files = int(len(all_train_filenames) * FLAGS.valid_frac)
  num_train_files = len(all_train_filenames) - num_valid_files
  print("%d train / %d valid / %d test" % (
      num_train_files, num_valid_files, len(test_filenames)))
  # Shuffle with a fixed seed so the train/valid split is reproducible.
  random.seed(1234)
  random.shuffle(all_train_filenames)
  valid_filenames = all_train_filenames[:num_valid_files]
  train_filenames = all_train_filenames[num_valid_files:]
  # The three splits must be pairwise disjoint.
  # Disable explicit length testing to make the assertions more readable.
  # pylint: disable=g-explicit-length-test
  assert len(set(train_filenames) & set(test_filenames)) == 0
  assert len(set(train_filenames) & set(valid_filenames)) == 0
  assert len(set(valid_filenames) & set(test_filenames)) == 0
  # pylint: enable=g-explicit-length-test
  train_wavs = [load_timit_wav(f) for f in train_filenames]
  valid_wavs = [load_timit_wav(f) for f in valid_filenames]
  test_wavs = [load_timit_wav(f) for f in test_filenames]
  assert len(train_wavs) + len(valid_wavs) == NUM_TRAIN_FILES
  assert len(test_wavs) == NUM_TEST_FILES
  # Normalization statistics come from the training set only.
  train_stacked = np.hstack(train_wavs)
  train_mean = np.mean(train_stacked)
  train_std = np.std(train_stacked)
  print("train mean: %f train std: %f" % (train_mean, train_std))
  # Normalize every split with the train statistics and write it to disk.
  for wavs, split_name in [(train_wavs, "train"),
                           (valid_wavs, "valid"),
                           (test_wavs, "test")]:
    processed = preprocess(wavs, SAMPLES_PER_TIMESTEP, train_mean, train_std)
    create_tfrecord_from_wavs(processed,
                              os.path.join(FLAGS.out_dir, split_name))
# Entry point: tf.app.run() parses command-line flags and then invokes main().
if __name__ == "__main__":
  tf.app.run()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for creating sequence datasets.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import numpy as np
from scipy.sparse import coo_matrix
import tensorflow as tf
# The default number of threads used to process data in parallel. Used as the
# default num_parallel_calls for the tf.data map stages in the dataset
# creators below.
DEFAULT_PARALLELISM = 12
def sparse_pianoroll_to_dense(pianoroll, min_note, num_notes):
  """Converts a sparse pianoroll into a dense float array.

  Produces a [num_timesteps, num_notes] array whose (i, j) entry is 1.0 when
  note j is active on timestep i and 0.0 otherwise.

  Args:
    pianoroll: A sparse pianoroll object, a list of tuples where the i'th
      tuple contains the indices of the notes active at timestep i.
    min_note: Subtracted from every note index so that the lowest note maps
      to column 0.
    num_notes: The number of possible different note indices, determines the
      second dimension of the resulting dense array.

  Returns:
    dense_pianoroll: A [num_timesteps, num_notes] numpy array of floats.
    num_timesteps: A python int, the number of timesteps in the pianoroll.
  """
  num_timesteps = len(pianoroll)
  row_inds = []
  col_inds = []
  for step, chord in enumerate(pianoroll):
    for note in chord:
      row_inds.append(step)
      # Shift note indices so that min_note becomes column 0.
      col_inds.append(note - min_note)
  ones = [1.] * len(row_inds)
  sparse = coo_matrix((ones, (row_inds, col_inds)),
                      shape=[num_timesteps, num_notes])
  return sparse.toarray(), num_timesteps
def create_pianoroll_dataset(path,
                             split,
                             batch_size,
                             num_parallel_calls=DEFAULT_PARALLELISM,
                             shuffle=False,
                             repeat=False,
                             min_note=21,
                             max_note=108):
  """Creates a pianoroll dataset.

  Args:
    path: The path of a pickle file containing the dataset to load.
    split: The split to use, can be train, test, or valid.
    batch_size: The batch size. If repeat is False then it is not guaranteed
      that the true batch size will match for all batches since batch_size
      may not necessarily evenly divide the number of elements.
    num_parallel_calls: The number of threads to use for parallel processing of
      the data.
    shuffle: If true, shuffles the order of the dataset.
    repeat: If true, repeats the dataset endlessly.
    min_note: The minimum note number of the dataset. For all pianoroll datasets
      the minimum note is number 21, and changing this affects the dimension of
      the data. This is useful mostly for testing.
    max_note: The maximum note number of the dataset. For all pianoroll datasets
      the maximum note is number 108, and changing this affects the dimension of
      the data. This is useful mostly for testing.

  Returns:
    inputs: A batch of input sequences represented as a dense Tensor of shape
      [time, batch_size, data_dimension]. The sequences in inputs are the
      sequences in targets shifted one timestep into the future, padded with
      zeros. This tensor is mean-centered, with the mean taken from the pickle
      file key 'train_mean'.
    targets: A batch of target sequences represented as a dense Tensor of
      shape [time, batch_size, data_dimension].
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch.
    mean: A float Tensor of shape [data_dimension] containing the mean loaded
      from the pickle file.
  """
  # Load the data from disk. Pickle streams are binary, so the file must be
  # opened in "rb" mode -- text mode corrupts the stream under Python 3.
  # NOTE(review): pickles written by Python 2 that contain numpy arrays may
  # additionally need encoding="latin1" under Python 3 -- confirm against the
  # actual dataset files.
  num_notes = max_note - min_note + 1
  with tf.gfile.Open(path, "rb") as f:
    raw_data = pickle.load(f)
  pianorolls = raw_data[split]
  mean = raw_data["train_mean"]
  num_examples = len(pianorolls)

  def pianoroll_generator():
    """Yields (dense_pianoroll, num_timesteps) pairs for the chosen split."""
    for sparse_pianoroll in pianorolls:
      yield sparse_pianoroll_to_dense(sparse_pianoroll, min_note, num_notes)

  dataset = tf.data.Dataset.from_generator(
      pianoroll_generator,
      output_types=(tf.float64, tf.int64),
      output_shapes=([None, num_notes], []))
  if repeat: dataset = dataset.repeat()
  if shuffle: dataset = dataset.shuffle(num_examples)
  # Batch sequences together, padding them to a common length in time.
  dataset = dataset.padded_batch(batch_size,
                                 padded_shapes=([None, num_notes], []))

  def process_pianoroll_batch(data, lengths):
    """Create mean-centered and time-major next-step prediction Tensors."""
    data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
    lengths = tf.to_int32(lengths)
    targets = data
    # Mean center the inputs.
    inputs = data - tf.constant(mean, dtype=tf.float32,
                                shape=[1, 1, mean.shape[0]])
    # Shift the inputs one step forward in time. Also remove the last timestep
    # so that targets and inputs are the same length.
    inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
    # Mask out unused timesteps.
    inputs *= tf.expand_dims(tf.transpose(
        tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
    return inputs, targets, lengths

  dataset = dataset.map(process_pianoroll_batch,
                        num_parallel_calls=num_parallel_calls)
  dataset = dataset.prefetch(num_examples)
  itr = dataset.make_one_shot_iterator()
  inputs, targets, lengths = itr.get_next()
  return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
def create_human_pose_dataset(
    path,
    split,
    batch_size,
    num_parallel_calls=DEFAULT_PARALLELISM,
    shuffle=False,
    repeat=False):
  """Creates a human pose dataset.

  Args:
    path: The path of a pickle file containing the dataset to load.
    split: The split to use, can be train, test, or valid.
    batch_size: The batch size. If repeat is False then it is not guaranteed
      that the true batch size will match for all batches since batch_size
      may not necessarily evenly divide the number of elements.
    num_parallel_calls: The number of threads to use for parallel processing of
      the data.
    shuffle: If true, shuffles the order of the dataset.
    repeat: If true, repeats the dataset endlessly.

  Returns:
    inputs: A batch of input sequences represented as a dense Tensor of shape
      [time, batch_size, data_dimension]. The sequences in inputs are the
      sequences in targets shifted one timestep into the future, padded with
      zeros. This tensor is mean-centered, with the mean taken from the pickle
      file key 'train_mean'.
    targets: A batch of target sequences represented as a dense Tensor of
      shape [time, batch_size, data_dimension].
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch.
    mean: A float Tensor of shape [data_dimension] containing the mean loaded
      from the pickle file.
  """
  # Load the data from disk. Pickle streams are binary, so the file must be
  # opened in "rb" mode -- text mode corrupts the stream under Python 3.
  # NOTE(review): pickles written by Python 2 that contain numpy arrays may
  # additionally need encoding="latin1" under Python 3 -- confirm against the
  # actual dataset files.
  with tf.gfile.Open(path, "rb") as f:
    raw_data = pickle.load(f)
  mean = raw_data["train_mean"]
  pose_sequences = raw_data[split]
  num_examples = len(pose_sequences)
  num_features = pose_sequences[0].shape[1]

  def pose_generator():
    """A generator that yields pose data sequences."""
    # Each timestep has 32 x values followed by 32 y values so is 64
    # dimensional.
    for pose_sequence in pose_sequences:
      yield pose_sequence, pose_sequence.shape[0]

  dataset = tf.data.Dataset.from_generator(
      pose_generator,
      output_types=(tf.float64, tf.int64),
      output_shapes=([None, num_features], []))
  if repeat:
    dataset = dataset.repeat()
  if shuffle:
    dataset = dataset.shuffle(num_examples)
  # Batch sequences together, padding them to a common length in time.
  dataset = dataset.padded_batch(
      batch_size, padded_shapes=([None, num_features], []))

  # Post-process each batch, ensuring that it is mean-centered and time-major.
  def process_pose_data(data, lengths):
    """Creates Tensors for next step prediction and mean-centers the input."""
    data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
    lengths = tf.to_int32(lengths)
    targets = data
    # Mean center the inputs.
    inputs = data - tf.constant(
        mean, dtype=tf.float32, shape=[1, 1, mean.shape[0]])
    # Shift the inputs one step forward in time. Also remove the last timestep
    # so that targets and inputs are the same length.
    inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
    # Mask out unused timesteps.
    inputs *= tf.expand_dims(
        tf.transpose(tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
    return inputs, targets, lengths

  dataset = dataset.map(
      process_pose_data,
      num_parallel_calls=num_parallel_calls)
  dataset = dataset.prefetch(num_examples)
  itr = dataset.make_one_shot_iterator()
  inputs, targets, lengths = itr.get_next()
  return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
def create_speech_dataset(path,
                          batch_size,
                          samples_per_timestep=200,
                          num_parallel_calls=DEFAULT_PARALLELISM,
                          prefetch_buffer_size=2048,
                          shuffle=False,
                          repeat=False):
  """Builds a speech dataset pipeline from a TFRecord file.

  Args:
    path: The path of a possibly sharded TFRecord file containing the data.
    batch_size: The batch size. If repeat is False then it is not guaranteed
      that the true batch size will match for all batches since batch_size
      may not necessarily evenly divide the number of elements.
    samples_per_timestep: The number of audio samples per timestep. Used to
      reshape the data into sequences of shape [time, samples_per_timestep].
      Should not change except for testing -- in all speech datasets 200 is
      the number of samples per timestep.
    num_parallel_calls: The number of threads to use for parallel processing
      of the data.
    prefetch_buffer_size: The size of the prefetch queues to use after reading
      and processing the raw data.
    shuffle: If true, shuffles the order of the dataset.
    repeat: If true, repeats the dataset endlessly.

  Returns:
    inputs: A batch of input sequences represented as a dense Tensor of shape
      [time, batch_size, samples_per_timestep]. The sequences in inputs are
      the sequences in targets shifted one timestep into the future, padded
      with zeros.
    targets: A batch of target sequences represented as a dense Tensor of
      shape [time, batch_size, samples_per_timestep].
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch.
  """
  filenames = [path]

  def parse_raw_record(value):
    """Decodes one raw record into a [time, samples_per_timestep] example."""
    samples = tf.decode_raw(value, out_type=tf.float32)
    example = tf.reshape(samples, [-1, samples_per_timestep])
    return example, tf.shape(example)[0]

  # Read and decode the raw records.
  dataset = tf.data.TFRecordDataset(filenames).map(
      parse_raw_record, num_parallel_calls=num_parallel_calls)
  dataset = dataset.prefetch(prefetch_buffer_size)
  if repeat:
    dataset = dataset.repeat()
  if shuffle:
    dataset = dataset.shuffle(prefetch_buffer_size)
  # Pad the sequences in each batch to a common length in time.
  dataset = dataset.padded_batch(
      batch_size, padded_shapes=([None, samples_per_timestep], []))

  def to_next_step_prediction(batch, seq_lengths):
    """Makes time-major input/target Tensors for next-step prediction."""
    batch = tf.transpose(batch, perm=[1, 0, 2])
    seq_lengths = tf.to_int32(seq_lengths)
    targets = batch
    # Inputs are the targets delayed one step, with the final timestep dropped
    # so inputs and targets have the same length.
    inputs = tf.pad(batch, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
    # Zero out timesteps beyond each sequence's length.
    mask = tf.sequence_mask(seq_lengths, dtype=inputs.dtype)
    inputs *= tf.expand_dims(tf.transpose(mask), 2)
    return inputs, targets, seq_lengths

  dataset = dataset.map(to_next_step_prediction,
                        num_parallel_calls=num_parallel_calls)
  dataset = dataset.prefetch(prefetch_buffer_size)
  inputs, targets, lengths = dataset.make_one_shot_iterator().get_next()
  return inputs, targets, lengths
# Valid observation functions for the chain graph dataset. See
# create_chain_graph_dataset for how each one transforms the latent state.
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]
# Valid transition functions for the chain graph dataset.
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
def create_chain_graph_dataset(
    batch_size,
    num_timesteps,
    steps_per_observation=None,
    state_size=1,
    transition_variance=1.,
    observation_variance=1.,
    transition_type=STANDARD_TRANSITION,
    observation_type=STANDARD_OBSERVATION,
    fixed_observation=None,
    prefetch_buffer_size=2048,
    dtype="float32"):
  """Creates a toy chain graph dataset.

  Creates a dataset where the data are sampled from a diffusion process. The
  'latent' states of the process are sampled as a chain of Normals:

  z0 ~ N(0, transition_variance)
  z1 ~ N(transition_fn(z0), transition_variance)
  ...

  where transition_fn could be round z0 or pass it through unchanged.

  The observations are produced every steps_per_observation timesteps as a
  function of the latent zs. For example if steps_per_observation is 3 then the
  first observation will be produced as a function of z3:

  x1 ~ N(observation_fn(z3), observation_variance)

  where observation_fn could square z3, take the absolute value, or pass
  it through unchanged.

  Only the observations are returned.

  Args:
    batch_size: The batch size. The number of trajectories to run in parallel.
    num_timesteps: The length of the chain of latent states (i.e. the
      number of z's excluding z0.
    steps_per_observation: The number of latent states between each observation,
      must evenly divide num_timesteps.
    state_size: The size of the latent state and observation, must be a
      python int.
    transition_variance: The variance of the transition density.
    observation_variance: The variance of the observation density.
    transition_type: Must be one of "round" or "standard". "round" means that
      the transition density is centered at the rounded previous latent state.
      "standard" centers the transition density at the previous latent state,
      unchanged.
    observation_type: Must be one of "squared", "abs" or "standard". "squared"
      centers the observation density at the squared latent state. "abs"
      centers the observation density at the absolute value of the current
      latent state. "standard" centers the observation density at the current
      latent state.
    fixed_observation: If not None, fixes all observations to be a constant.
      Must be a scalar.
    prefetch_buffer_size: The size of the prefetch queues to use after reading
      and processing the raw data.
    dtype: A string convertible to a tensorflow datatype. The datatype used
      to represent the states and observations.

  Returns:
    observations: A batch of observations represented as a dense Tensor of
      shape [num_observations, batch_size, state_size]. num_observations is
      num_timesteps/steps_per_observation.
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch. Will contain num_observations as each entry.

  Raises:
    ValueError: Raised if steps_per_observation does not evenly divide
      num_timesteps, or if transition_type or observation_type is not one of
      the supported values.
  """
  if steps_per_observation is None:
    steps_per_observation = num_timesteps
  if num_timesteps % steps_per_observation != 0:
    raise ValueError("steps_per_observation must evenly divide num_timesteps.")
  # Validate the type arguments up front; previously an unsupported value
  # surfaced only as a NameError deep inside the generator at iteration time.
  if transition_type not in TRANSITION_TYPES:
    raise ValueError("Unknown transition_type: %s" % transition_type)
  if observation_type not in OBSERVATION_TYPES:
    raise ValueError("Unknown observation_type: %s" % observation_type)
  num_observations = int(num_timesteps / steps_per_observation)

  def data_generator():
    """An infinite generator of latents and observations from the model."""
    transition_std = np.sqrt(transition_variance)
    observation_std = np.sqrt(observation_variance)
    while True:
      states = []
      observations = []
      # Sample z0 ~ Normal(0, sqrt(variance)).
      # NOTE(review): z0 is drawn with observation_std here, while the
      # docstring says z0 ~ N(0, transition_variance). Left unchanged because
      # existing golden-value tests depend on the current draw sequence --
      # confirm which is intended.
      states.append(
          np.random.normal(size=[state_size],
                           scale=observation_std).astype(dtype))
      # Start the range at 1 because we've already generated z0.
      # The range ends at num_timesteps+1 because we want to include the
      # num_timesteps-th step.
      for t in range(1, num_timesteps+1):
        if transition_type == ROUND_TRANSITION:
          loc = np.round(states[-1])
        elif transition_type == STANDARD_TRANSITION:
          loc = states[-1]
        z_t = np.random.normal(size=[state_size], loc=loc, scale=transition_std)
        states.append(z_t.astype(dtype))
        if t % steps_per_observation == 0:
          if fixed_observation is None:
            if observation_type == SQUARED_OBSERVATION:
              loc = np.square(states[-1])
            elif observation_type == ABS_OBSERVATION:
              loc = np.abs(states[-1])
            elif observation_type == STANDARD_OBSERVATION:
              loc = states[-1]
            x_t = np.random.normal(size=[state_size],
                                   loc=loc,
                                   scale=observation_std).astype(dtype)
          else:
            x_t = np.ones([state_size]) * fixed_observation
          observations.append(x_t)
      yield states, observations

  dataset = tf.data.Dataset.from_generator(
      data_generator,
      output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
      output_shapes=([num_timesteps+1, state_size],
                     [num_observations, state_size])
  )
  dataset = dataset.repeat().batch(batch_size)
  dataset = dataset.prefetch(prefetch_buffer_size)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Transpose observations from [batch, time, state_size] to
  # [time, batch, state_size].
  observations = tf.transpose(observations, perm=[1, 0, 2])
  lengths = tf.ones([batch_size], dtype=tf.int32) * num_observations
  return observations, lengths
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.data.datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import os
import numpy as np
import tensorflow as tf
from fivo.data import datasets
FLAGS = tf.app.flags.FLAGS
class DatasetsTest(tf.test.TestCase):
  """Tests for the pianoroll, human pose, speech, and chain graph datasets."""

  def test_sparse_pianoroll_to_dense_empty_at_end(self):
    sparse_pianoroll = [(0, 1), (1, 0), (), (1,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1],
                         [0, 0],
                         [0, 0]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_with_chord(self):
    sparse_pianoroll = [(0, 1), (1, 0), (), (1,)]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 4)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_simple(self):
    sparse_pianoroll = [(0,), (), (1,)]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=0, num_notes=2)
    self.assertEqual(num_timesteps, 3)
    self.assertAllEqual([[1, 0],
                         [0, 0],
                         [0, 1]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_subtracts_min_note(self):
    sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=4, num_notes=2)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1],
                         [1, 1],
                         [0, 0],
                         [0, 1],
                         [0, 0],
                         [0, 0]], dense_pianoroll)

  def test_sparse_pianoroll_to_dense_uses_num_notes(self):
    sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
    dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
        sparse_pianoroll, min_note=4, num_notes=3)
    self.assertEqual(num_timesteps, 6)
    self.assertAllEqual([[1, 1, 0],
                         [1, 1, 0],
                         [0, 0, 0],
                         [0, 1, 0],
                         [0, 0, 0],
                         [0, 0, 0]], dense_pianoroll)

  def test_pianoroll_dataset(self):
    pianoroll_data = [[(0,), (), (1,)],
                      [(0, 1), (1,)],
                      [(1,), (0,), (), (0, 1), (), ()]]
    pianoroll_mean = np.zeros([3])
    pianoroll_mean[-1] = 1
    data = {"train": pianoroll_data, "train_mean": pianoroll_mean}
    path = os.path.join(tf.test.get_temp_dir(), "test.pkl")
    # Write via a context manager so the file is flushed and closed before the
    # dataset reads it back (consistent with test_human_pose_dataset).
    with open(path, "wb") as out:
      pickle.dump(data, out)
    with self.test_session() as sess:
      inputs, targets, lens, mean = datasets.create_pianoroll_dataset(
          path, "train", 2, num_parallel_calls=1,
          shuffle=False, repeat=False,
          min_note=0, max_note=2)
      i1, t1, l1 = sess.run([inputs, targets, lens])
      i2, t2, l2 = sess.run([inputs, targets, lens])
      m = sess.run(mean)
      # Check the lengths.
      self.assertAllEqual([3, 2], l1)
      self.assertAllEqual([6], l2)
      # Check the mean.
      self.assertAllEqual(pianoroll_mean, m)
      # Check the targets. The targets should not be mean-centered and should
      # be padded with zeros to a common length within a batch.
      self.assertAllEqual([[1, 0, 0],
                           [0, 0, 0],
                           [0, 1, 0]], t1[:, 0, :])
      self.assertAllEqual([[1, 1, 0],
                           [0, 1, 0],
                           [0, 0, 0]], t1[:, 1, :])
      self.assertAllEqual([[0, 1, 0],
                           [1, 0, 0],
                           [0, 0, 0],
                           [1, 1, 0],
                           [0, 0, 0],
                           [0, 0, 0]], t2[:, 0, :])
      # Check the inputs. Each sequence should start with zeros on the first
      # timestep. Each sequence should be padded with zeros to a common length
      # within a batch. The mean should be subtracted from all timesteps except
      # the first and the padding.
      self.assertAllEqual([[0, 0, 0],
                           [1, 0, -1],
                           [0, 0, -1]], i1[:, 0, :])
      self.assertAllEqual([[0, 0, 0],
                           [1, 1, -1],
                           [0, 0, 0]], i1[:, 1, :])
      self.assertAllEqual([[0, 0, 0],
                           [0, 1, -1],
                           [1, 0, -1],
                           [0, 0, -1],
                           [1, 1, -1],
                           [0, 0, -1]], i2[:, 0, :])

  def test_human_pose_dataset(self):
    pose_data = [
        [[0, 0], [2, 2]],
        [[2, 2]],
        [[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]],
    ]
    pose_data = [np.array(x, dtype=np.float64) for x in pose_data]
    pose_data_mean = np.array([1, 1], dtype=np.float64)
    data = {
        "train": pose_data,
        "train_mean": pose_data_mean,
    }
    path = os.path.join(tf.test.get_temp_dir(), "test_human_pose_dataset.pkl")
    with open(path, "wb") as out:
      pickle.dump(data, out)
    with self.test_session() as sess:
      inputs, targets, lens, mean = datasets.create_human_pose_dataset(
          path, "train", 2, num_parallel_calls=1, shuffle=False, repeat=False)
      i1, t1, l1 = sess.run([inputs, targets, lens])
      i2, t2, l2 = sess.run([inputs, targets, lens])
      m = sess.run(mean)
      # Check the lengths.
      self.assertAllEqual([2, 1], l1)
      self.assertAllEqual([5], l2)
      # Check the mean.
      self.assertAllEqual(pose_data_mean, m)
      # Check the targets. The targets should not be mean-centered and should
      # be padded with zeros to a common length within a batch.
      self.assertAllEqual([[0, 0], [2, 2]], t1[:, 0, :])
      self.assertAllEqual([[2, 2], [0, 0]], t1[:, 1, :])
      self.assertAllEqual([[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]], t2[:, 0, :])
      # Check the inputs. Each sequence should start with zeros on the first
      # timestep. Each sequence should be padded with zeros to a common length
      # within a batch. The mean should be subtracted from all timesteps except
      # the first and the padding.
      self.assertAllEqual([[0, 0], [-1, -1]], i1[:, 0, :])
      self.assertAllEqual([[0, 0], [0, 0]], i1[:, 1, :])
      self.assertAllEqual([[0, 0], [-1, -1], [-1, -1], [1, 1], [1, 1]],
                          i2[:, 0, :])

  def test_speech_dataset(self):
    with self.test_session() as sess:
      path = os.path.join(
          os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
          "test_data",
          "tiny_speech_dataset.tfrecord")
      inputs, targets, lens = datasets.create_speech_dataset(
          path, 3, samples_per_timestep=2, num_parallel_calls=1,
          prefetch_buffer_size=3, shuffle=False, repeat=False)
      inputs1, targets1, lengths1 = sess.run([inputs, targets, lens])
      inputs2, targets2, lengths2 = sess.run([inputs, targets, lens])
      # Check the lengths.
      self.assertAllEqual([1, 2, 3], lengths1)
      self.assertAllEqual([4], lengths2)
      # Check the targets. The targets should be padded with zeros to a common
      # length within a batch.
      self.assertAllEqual([[[0., 1.], [0., 1.], [0., 1.]],
                           [[0., 0.], [2., 3.], [2., 3.]],
                           [[0., 0.], [0., 0.], [4., 5.]]],
                          targets1)
      self.assertAllEqual([[[0., 1.]],
                           [[2., 3.]],
                           [[4., 5.]],
                           [[6., 7.]]],
                          targets2)
      # Check the inputs. Each sequence should start with zeros on the first
      # timestep. Each sequence should be padded with zeros to a common length
      # within a batch.
      self.assertAllEqual([[[0., 0.], [0., 0.], [0., 0.]],
                           [[0., 0.], [0., 1.], [0., 1.]],
                           [[0., 0.], [0., 0.], [2., 3.]]],
                          inputs1)
      self.assertAllEqual([[[0., 0.]],
                           [[0., 1.]],
                           [[2., 3.]],
                           [[4., 5.]]],
                          inputs2)

  def test_chain_graph_raises_error_on_wrong_steps_per_observation(self):
    with self.assertRaises(ValueError):
      datasets.create_chain_graph_dataset(
          batch_size=4,
          num_timesteps=10,
          steps_per_observation=9)

  def test_chain_graph_single_obs(self):
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 1
      num_timesteps = 5
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          [[[1.426677], [-1.789461]]],
          out_observations)

  def test_chain_graph_multiple_obs(self):
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 3
      num_timesteps = 6
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          steps_per_observation=num_timesteps/num_observations,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          [[[0.40051451], [1.07405114]],
           [[1.73932898], [3.16880035]],
           [[-1.98377144], [2.82669163]]],
          out_observations)

  def test_chain_graph_state_dims(self):
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 1
      num_timesteps = 5
      batch_size = 2
      state_size = 3
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          state_size=state_size)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          [[[1.052287, -4.560759, 3.07988],
            [2.008926, 0.495567, 3.488678]]],
          out_observations)

  def test_chain_graph_fixed_obs(self):
    with self.test_session() as sess:
      np.random.seed(1234)
      num_observations = 3
      num_timesteps = 6
      batch_size = 2
      state_size = 1
      observations, lengths = datasets.create_chain_graph_dataset(
          batch_size=batch_size,
          num_timesteps=num_timesteps,
          steps_per_observation=num_timesteps/num_observations,
          state_size=state_size,
          fixed_observation=4.)
      out_observations, out_lengths = sess.run([observations, lengths])
      self.assertAllEqual([num_observations, num_observations], out_lengths)
      self.assertAllClose(
          np.ones([num_observations, batch_size, state_size]) * 4.,
          out_observations)
# Entry point: defer to the TensorFlow test runner to discover and run the
# test cases above.
if __name__ == "__main__":
  tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates and runs Gaussian HMM-related graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from fivo import smc
from fivo import bounds
from fivo.data import datasets
from fivo.models import ghmm
def run_train(config):
"""Runs training for a Gaussian HMM setup."""
def create_logging_hook(step, bound_value, likelihood, bound_gap):
"""Creates a logging hook that prints the bound value periodically."""
bound_label = config.bound + "/t"
def summary_formatter(log_dict):
string = ("Step {step}, %s: {value:.3f}, "
"likelihood: {ll:.3f}, gap: {gap:.3e}") % bound_label
return string.format(**log_dict)
logging_hook = tf.train.LoggingTensorHook(
{"step": step, "value": bound_value,
"ll": likelihood, "gap": bound_gap},
every_n_iter=config.summarize_every,
formatter=summary_formatter)
return logging_hook
def create_losses(model, observations, lengths):
"""Creates the loss to be optimized.
Args:
model: A Trainable GHMM model.
observations: A set of observations.
lengths: The lengths of each sequence in the observations.
Returns:
loss: A float Tensor that when differentiated yields the gradients
to apply to the model. Should be optimized via gradient descent.
bound: A float Tensor containing the value of the bound that is
being optimized.
true_ll: The true log-likelihood of the data under the model.
bound_gap: The gap between the bound and the true log-likelihood.
"""
# Compute lower bounds on the log likelihood.
if config.bound == "elbo":
ll_per_seq, _, _ = bounds.iwae(
model, observations, lengths, num_samples=1,
parallel_iterations=config.parallel_iterations
)
elif config.bound == "iwae":
ll_per_seq, _, _ = bounds.iwae(
model, observations, lengths, num_samples=config.num_samples,
parallel_iterations=config.parallel_iterations
)
elif config.bound == "fivo":
if config.resampling_type == "relaxed":
ll_per_seq, _, _, _ = bounds.fivo(
model,
observations,
lengths,
num_samples=config.num_samples,
resampling_criterion=smc.ess_criterion,
resampling_type=config.resampling_type,
relaxed_resampling_temperature=config.
relaxed_resampling_temperature,
random_seed=config.random_seed,
parallel_iterations=config.parallel_iterations)
else:
ll_per_seq, _, _, _ = bounds.fivo(
model, observations, lengths,
num_samples=config.num_samples,
resampling_criterion=smc.ess_criterion,
resampling_type=config.resampling_type,
random_seed=config.random_seed,
parallel_iterations=config.parallel_iterations
)
ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
# Compute the data's true likelihood under the model and the bound gap.
true_ll_per_seq = model.likelihood(tf.squeeze(observations))
true_ll_per_t = tf.reduce_mean(true_ll_per_seq / tf.to_float(lengths))
bound_gap = true_ll_per_seq - ll_per_seq
bound_gap = tf.reduce_mean(bound_gap/ tf.to_float(lengths))
tf.summary.scalar("train_ll_bound", ll_per_t)
tf.summary.scalar("train_true_ll", true_ll_per_t)
tf.summary.scalar("bound_gap", bound_gap)
return -ll_per_t, ll_per_t, true_ll_per_t, bound_gap
def create_graph():
  """Builds the dataset, model, losses, and training op.

  Uses `config` from the enclosing scope. Returns the bound, true
  log-likelihood, bound gap, train op, and global step tensor.
  """
  global_step = tf.train.get_or_create_global_step()
  # Synthetic Gaussian chain-graph observations with one observation per step.
  observations, seq_lengths = datasets.create_chain_graph_dataset(
      config.batch_size,
      config.num_timesteps,
      steps_per_observation=1,
      state_size=1,
      transition_variance=config.variance,
      observation_variance=config.variance)
  hmm = ghmm.TrainableGaussianHMM(
      config.num_timesteps,
      config.proposal_type,
      transition_variances=config.variance,
      emission_variances=config.variance,
      random_seed=config.random_seed)
  loss, bound, true_ll, gap = create_losses(hmm, observations, seq_lengths)
  optimizer = tf.train.AdamOptimizer(config.learning_rate)
  # minimize() composes compute_gradients() and apply_gradients().
  train_op = optimizer.minimize(
      loss, global_step=global_step, var_list=tf.trainable_variables())
  return bound, true_ll, gap, train_op, global_step
# Build the training graph and run the train loop under a monitored session,
# which handles checkpointing and summaries automatically.
with tf.Graph().as_default():
  if config.random_seed:
    # Seed both TF and numpy so training runs are reproducible.
    tf.set_random_seed(config.random_seed)
    np.random.seed(config.random_seed)
  bound, true_ll, gap, train_op, global_step = create_graph()
  log_hook = create_logging_hook(global_step, bound, true_ll, gap)
  with tf.train.MonitoredTrainingSession(
      master="",
      hooks=[log_hook],
      checkpoint_dir=config.logdir,
      save_checkpoint_secs=120,
      save_summaries_steps=config.summarize_every,
      log_step_count_steps=config.summarize_every*20) as sess:
    cur_step = -1
    while cur_step <= config.max_steps and not sess.should_stop():
      # NOTE(review): this read of global_step is immediately overwritten by
      # the combined run below — it looks redundant; confirm before removing.
      cur_step = sess.run(global_step)
      _, cur_step = sess.run([train_op, global_step])
def run_eval(config):
  """Evaluates a Gaussian HMM using the given config.

  Builds the configured bound (elbo, iwae, or fivo) on freshly created
  chain-graph data, restores the latest checkpoint from config.logdir,
  logs summary statistics, and writes them to <config.logdir>/out.npz.

  Args:
    config: The configuration object, e.g. parsed flags.
  """

  def create_bound(model, xs, lengths):
    """Creates the bound to be evaluated."""
    if config.bound == "elbo":
      # The ELBO is the IWAE bound with a single sample.
      ll_per_seq, log_weights, _ = bounds.iwae(
          model, xs, lengths, num_samples=1,
          parallel_iterations=config.parallel_iterations
      )
    elif config.bound == "iwae":
      ll_per_seq, log_weights, _ = bounds.iwae(
          model, xs, lengths, num_samples=config.num_samples,
          parallel_iterations=config.parallel_iterations
      )
    elif config.bound == "fivo":
      ll_per_seq, log_weights, resampled, _ = bounds.fivo(
          model, xs, lengths,
          num_samples=config.num_samples,
          resampling_criterion=smc.ess_criterion,
          resampling_type=config.resampling_type,
          random_seed=config.random_seed,
          parallel_iterations=config.parallel_iterations
      )
    # Compute bound scaled by number of timesteps.
    bound_per_t = ll_per_seq / tf.to_float(lengths)
    if config.bound == "fivo":
      # FIVO additionally reports where resampling events occurred.
      return bound_per_t, log_weights, resampled
    else:
      return bound_per_t, log_weights

  def create_graph():
    """Creates the dataset, model, and bound."""
    xs, lengths = datasets.create_chain_graph_dataset(
        config.batch_size,
        config.num_timesteps,
        steps_per_observation=1,
        state_size=1,
        transition_variance=config.variance,
        observation_variance=config.variance)
    model = ghmm.TrainableGaussianHMM(
        config.num_timesteps,
        config.proposal_type,
        transition_variances=config.variance,
        emission_variances=config.variance,
        random_seed=config.random_seed)
    # The GHMM likelihood is tractable, so the exact per-timestep
    # log-likelihood can be computed for comparison with the bound.
    true_likelihood = tf.reduce_mean(
        model.likelihood(tf.squeeze(xs)) / tf.to_float(lengths))
    outs = [true_likelihood]
    outs.extend(list(create_bound(model, xs, lengths)))
    return outs

  with tf.Graph().as_default():
    if config.random_seed:
      # Seed both TF and numpy for reproducible evaluation.
      tf.set_random_seed(config.random_seed)
      np.random.seed(config.random_seed)
    graph_outs = create_graph()
    # SingularMonitoredSession restores the latest checkpoint in logdir.
    with tf.train.SingularMonitoredSession(
        checkpoint_dir=config.logdir) as sess:
      outs = sess.run(graph_outs)
      likelihood = outs[0]
      avg_bound = np.mean(outs[1])
      std = np.std(outs[1])
      log_weights = outs[2]
      # NOTE(review): variance is taken over axis 2 and then *variance* again
      # over axis 1 (not a mean) — presumably samples then batch; confirm the
      # intended reduction against the log_weights layout.
      log_weight_variances = np.var(log_weights, axis=2)
      avg_log_weight_variance = np.var(log_weight_variances, axis=1)
      avg_log_weight = np.mean(log_weights, axis=(1, 2))
      data = {"mean": avg_bound, "std": std, "log_weights": log_weights,
              "log_weight_means": avg_log_weight,
              "log_weight_variances": avg_log_weight_variance}
      if len(outs) == 4:
        # The FIVO bound also produced resampling indicators.
        data["resampled"] = outs[3]
        data["avg_resampled"] = np.mean(outs[3], axis=1)
      # Log some useful statistics.
      tf.logging.info("Evaled bound %s with batch_size: %d, num_samples: %d."
                      % (config.bound, config.batch_size, config.num_samples))
      tf.logging.info("mean: %f, std: %f" % (avg_bound, std))
      tf.logging.info("true likelihood: %s" % likelihood)
      tf.logging.info("avg log weight: %s" % avg_log_weight)
      tf.logging.info("log weight variance: %s" % avg_log_weight_variance)
      if len(outs) == 4:
        tf.logging.info("avg resamples per t: %s" % data["avg_resampled"])
      if not tf.gfile.Exists(config.logdir):
        tf.gfile.MakeDirs(config.logdir)
      # NOTE(review): np.save writes .npy format despite the .npz name, and
      # the file is opened in text mode "w" — verify on Python 3.
      with tf.gfile.Open(os.path.join(config.logdir, "out.npz"), "w") as fout:
        np.save(fout, data)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.ghmm_runners."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from fivo import ghmm_runners
class GHMMRunnersTest(tf.test.TestCase):
  """Smoke tests for the Gaussian HMM train/eval runners."""

  def default_config(self):
    """Returns a fresh config object populated with the shared defaults."""

    class Config(object):
      pass

    settings = {
        "model": "ghmm",
        "bound": "fivo",
        "proposal_type": "prior",
        "batch_size": 4,
        "num_samples": 4,
        "num_timesteps": 10,
        "variance": 0.1,
        "resampling_type": "multinomial",
        "random_seed": 1234,
        "parallel_iterations": 1,
        "learning_rate": 1e-4,
        "summarize_every": 1,
        "max_steps": 1,
    }
    config = Config()
    for attr, value in settings.items():
      setattr(config, attr, value)
    return config

  def test_eval_ghmm_notraining_fivo_prior(self):
    self.eval_ghmm_notraining("fivo", "prior", -3.063864)

  def test_eval_ghmm_notraining_fivo_true_filtering(self):
    self.eval_ghmm_notraining("fivo", "true-filtering", -1.1409812)

  def test_eval_ghmm_notraining_fivo_true_smoothing(self):
    self.eval_ghmm_notraining("fivo", "true-smoothing", -0.85592091)

  def test_eval_ghmm_notraining_iwae_prior(self):
    self.eval_ghmm_notraining("iwae", "prior", -5.9730167)

  def test_eval_ghmm_notraining_iwae_true_filtering(self):
    self.eval_ghmm_notraining("iwae", "true-filtering", -1.1485999)

  def test_eval_ghmm_notraining_iwae_true_smoothing(self):
    self.eval_ghmm_notraining("iwae", "true-smoothing", -0.85592091)

  def eval_ghmm_notraining(self, bound, proposal_type, expected_bound_avg):
    """Evals an untrained model and checks the average bound value."""
    config = self.default_config()
    config.bound = bound
    config.proposal_type = proposal_type
    config.logdir = os.path.join(
        tf.test.get_temp_dir(), "test-ghmm-%s-%s" % (proposal_type, bound))
    ghmm_runners.run_eval(config)
    results = np.load(os.path.join(config.logdir, "out.npz")).item()
    self.assertAlmostEqual(expected_bound_avg, results["mean"], places=3)

  def test_train_ghmm_for_one_step_and_eval_fivo_filtering(self):
    self.train_ghmm_for_one_step_and_eval("fivo", "filtering", -16.727108)

  def test_train_ghmm_for_one_step_and_eval_fivo_smoothing(self):
    self.train_ghmm_for_one_step_and_eval("fivo", "smoothing", -19.381277)

  def test_train_ghmm_for_one_step_and_eval_iwae_filtering(self):
    self.train_ghmm_for_one_step_and_eval("iwae", "filtering", -33.31966)

  def test_train_ghmm_for_one_step_and_eval_iwae_smoothing(self):
    self.train_ghmm_for_one_step_and_eval("iwae", "smoothing", -46.388447)

  def train_ghmm_for_one_step_and_eval(self, bound, proposal_type,
                                       expected_bound_avg):
    """Trains for one step, evals, and checks the average bound value."""
    config = self.default_config()
    config.bound = bound
    config.proposal_type = proposal_type
    config.max_steps = 1
    config.logdir = os.path.join(
        tf.test.get_temp_dir(),
        "test-ghmm-training-%s-%s" % (proposal_type, bound))
    ghmm_runners.run_train(config)
    ghmm_runners.run_eval(config)
    results = np.load(os.path.join(config.logdir, "out.npz")).item()
    self.assertAlmostEqual(expected_bound_avg, results["mean"], places=2)
# Run the tests via the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reusable model classes for FIVO."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sonnet as snt
import tensorflow as tf
from fivo import nested_utils as nested
tfd = tf.contrib.distributions  # Short alias for TF1's contrib distributions.
class ELBOTrainableSequenceModel(object):
  """An abstract class for ELBO-trainable sequence models to extend.

  Because the ELBO, IWAE, and FIVO bounds all accept the same arguments,
  any model that is ELBO-trainable is also IWAE- and FIVO-trainable.
  """

  def zero_state(self, batch_size, dtype):
    """Returns the model's initial state as a Tensor or tuple of Tensors.

    Args:
      batch_size: The batch size.
      dtype: The datatype to use for the state.
    """
    raise NotImplementedError("zero_state not yet implemented.")

  def set_observations(self, observations, seq_lengths):
    """Stores the observed variables and prepares per-timestep TensorArrays.

    Provides the model with all observed variables, inputs and targets
    alike. Called before any computation that needs the observations
    (training, bound evaluation); subclasses may override it to perform
    additional preprocessing.

    Args:
      observations: A potentially nested set of Tensors containing
        all observations for the model, both inputs and targets. Typically
        a set of Tensors with shape [max_seq_len, batch_size, data_size].
      seq_lengths: A [batch_size] Tensor of ints encoding the length of each
        sequence in the batch (sequences can be padded to a common length).
    """
    self.observations = observations
    self.seq_lengths = seq_lengths
    self.max_seq_len = tf.reduce_max(seq_lengths)
    # clear_after_read=False lets each timestep be read more than once.
    self.observations_ta = nested.tas_for_tensors(
        observations, self.max_seq_len, clear_after_read=False)

  def propose_and_weight(self, state, t):
    """Propagates model state one timestep and computes log weights.

    Accepts the current state of the model and computes the state for the
    next timestep as well as the incremental log weight of each element in
    the batch.

    Args:
      state: The current state of the model.
      t: A scalar integer Tensor representing the current timestep.

    Returns:
      next_state: The state of the model after one timestep.
      log_weights: A [batch_size] Tensor containing the incremental log
        weights.
    """
    raise NotImplementedError("propose_and_weight not yet implemented.")
# Default variable initializers for the snt.nets.MLP networks in this module:
# Xavier for weights ('w'), zeros for biases ('b').
DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
                        "b": tf.zeros_initializer()}
class ConditionalNormalDistribution(object):
  """A Normal distribution conditioned on Tensor inputs via a fc network."""

  def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
               raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
               initializers=None, name="conditional_normal_distribution"):
    """Creates a conditional Normal distribution.

    Args:
      size: The dimension of the random variable.
      hidden_layer_sizes: The sizes of the hidden layers of the fully
        connected network used to condition the distribution on the inputs.
      sigma_min: The minimum standard deviation allowed, a scalar.
      raw_sigma_bias: A scalar added to the raw standard deviation output
        from the fully connected network. Defaults to 0.25 to keep standard
        deviations away from 0.
      hidden_activation_fn: The activation function applied to the hidden
        layers of the fully connected network.
      initializers: The variable initializers for the fully connected
        network. The network is built with snt.nets.MLP, so this must map
        the keys 'w' and 'b' to weight and bias initializers. When None,
        defaults to Xavier weights and zero biases.
      name: The name of this distribution, used for sonnet scoping.
    """
    self.sigma_min = sigma_min
    self.raw_sigma_bias = raw_sigma_bias
    self.name = name
    self.size = size
    if initializers is None:
      initializers = DEFAULT_INITIALIZERS
    # One network outputs both the mean and raw scale, hence 2*size units.
    self.fcnet = snt.nets.MLP(
        output_sizes=hidden_layer_sizes + [2*size],
        activation=hidden_activation_fn,
        initializers=initializers,
        activate_final=False,
        use_bias=True,
        name=name + "_fcnet")

  def condition(self, tensor_list, **unused_kwargs):
    """Computes the parameters of a normal distribution based on the inputs."""
    net_out = self.fcnet(tf.concat(tensor_list, axis=1))
    mu, raw_sigma = tf.split(net_out, 2, axis=1)
    # Softplus keeps sigma positive; the floor keeps it >= sigma_min.
    sigma = tf.maximum(
        tf.nn.softplus(raw_sigma + self.raw_sigma_bias), self.sigma_min)
    return mu, sigma

  def __call__(self, *args, **kwargs):
    """Creates a normal distribution conditioned on the inputs."""
    mu, sigma = self.condition(args, **kwargs)
    return tf.contrib.distributions.Normal(loc=mu, scale=sigma)
class ConditionalBernoulliDistribution(object):
  """A Bernoulli distribution conditioned on Tensor inputs via a fc net."""

  def __init__(self, size, hidden_layer_sizes, hidden_activation_fn=tf.nn.relu,
               initializers=None, bias_init=0.0,
               name="conditional_bernoulli_distribution"):
    """Creates a conditional Bernoulli distribution.

    Args:
      size: The dimension of the random variable.
      hidden_layer_sizes: The sizes of the hidden layers of the fully
        connected network used to condition the distribution on the inputs.
      hidden_activation_fn: The activation function applied to the hidden
        layers of the fully connected network.
      initializers: The variable initializers for the fully connected
        network. The network is built with snt.nets.MLP, so this must map
        the keys 'w' and 'b' to weight and bias initializers. When None,
        defaults to Xavier weights and zero biases.
      bias_init: A scalar or vector Tensor added to the network output that
        parameterizes the mean of this distribution.
      name: The name of this distribution, used for sonnet scoping.
    """
    self.bias_init = bias_init
    self.size = size
    if initializers is None:
      initializers = DEFAULT_INITIALIZERS
    self.fcnet = snt.nets.MLP(
        output_sizes=hidden_layer_sizes + [size],
        activation=hidden_activation_fn,
        initializers=initializers,
        activate_final=False,
        use_bias=True,
        name=name + "_fcnet")

  def condition(self, tensor_list):
    """Computes the p parameter of the Bernoulli distribution."""
    net_input = tf.concat(tensor_list, axis=1)
    return self.fcnet(net_input) + self.bias_init

  def __call__(self, *args):
    """Creates a Bernoulli distribution conditioned on the inputs."""
    logits = self.condition(args)
    return tf.contrib.distributions.Bernoulli(logits=logits)
class NormalApproximatePosterior(ConditionalNormalDistribution):
  """A Normally-distributed approx. posterior with res_q parameterization."""

  def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
               raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
               initializers=None, smoothing=False,
               name="conditional_normal_distribution"):
    """Creates a NormalApproximatePosterior.

    Args:
      size: The dimension of the random variable.
      hidden_layer_sizes: The sizes of the hidden layers of the fully
        connected network used to condition the distribution on the inputs.
      sigma_min: The minimum standard deviation allowed, a scalar.
      raw_sigma_bias: A scalar added to the raw standard deviation output of
        the network, preventing standard deviations close to 0.
      hidden_activation_fn: The activation function applied to the hidden
        layers of the fully connected network.
      initializers: The variable initializers for the fully connected
        network; see ConditionalNormalDistribution.
      smoothing: If True, condition() additionally conditions on the
        provided smoothing_tensors.
      name: The name of this distribution, used for sonnet scoping.
    """
    super(NormalApproximatePosterior, self).__init__(
        size, hidden_layer_sizes, sigma_min=sigma_min,
        raw_sigma_bias=raw_sigma_bias,
        hidden_activation_fn=hidden_activation_fn, initializers=initializers,
        name=name)
    self.smoothing = smoothing

  def condition(self, tensor_list, prior_mu, smoothing_tensors=None):
    """Generates the mean and variance of the normal distribution.

    Args:
      tensor_list: The list of Tensors to condition on. Will be concatenated
        and fed through a fully connected network.
      prior_mu: The mean of the prior distribution associated with this
        approximate posterior. Will be added to the mean produced by
        this approximate posterior, in res_q fashion.
      smoothing_tensors: A list of Tensors. If smoothing is True, these
        Tensors will be concatenated with the tensors in tensor_list.

    Returns:
      mu: The mean of the approximate posterior.
      sigma: The standard deviation of the approximate posterior.
    """
    if self.smoothing:
      # Bug fix: build a new list instead of calling tensor_list.extend().
      # The original mutated the caller's list in place, and raised
      # AttributeError when tensor_list was a tuple (e.g. the *args tuple
      # forwarded by the parent class's __call__).
      tensor_list = list(tensor_list) + list(smoothing_tensors)
    mu, sigma = super(NormalApproximatePosterior, self).condition(tensor_list)
    return mu + prior_mu, sigma
class NonstationaryLinearDistribution(object):
  """A set of loc-scale distributions that are linear functions of inputs.

  This class defines a series of location-scale distributions such that
  the means are learnable linear functions of the inputs and the log variances
  are learnable constants. The functions and log variances are different
  across timesteps, allowing the distributions to be nonstationary.
  """

  def __init__(self,
               num_timesteps,
               inputs_per_timestep=None,
               outputs_per_timestep=None,
               initializers=None,
               variance_min=0.0,
               output_distribution=tfd.Normal,
               dtype=tf.float32):
    """Creates a NonstationaryLinearDistribution.

    Args:
      num_timesteps: The number of timesteps, i.e. the number of
        distributions.
      inputs_per_timestep: A list of python ints, the dimension of inputs to
        the linear function at each timestep. If not provided, the dimension
        at each timestep is assumed to be 1.
      outputs_per_timestep: A list of python ints, the dimension of the
        output distribution at each timestep. If not provided, the dimension
        at each timestep is assumed to be 1.
      initializers: A dictionary containing initializers for the variables.
        The initializer under the key 'w' is used for the weights in the
        linear function and the initializer under the key 'b' is used for
        the biases. Defaults to xavier initialization for the weights and
        zeros for the biases.
      variance_min: Python float, the minimum variance of each distribution.
      output_distribution: A location-scale subclass of tfd.Distribution
        that defines the output distribution, e.g. Normal.
      dtype: The dtype of the weights and biases.
    """
    if not initializers:
      initializers = DEFAULT_INITIALIZERS
    if not inputs_per_timestep:
      inputs_per_timestep = [1] * num_timesteps
    if not outputs_per_timestep:
      outputs_per_timestep = [1] * num_timesteps
    self.num_timesteps = num_timesteps
    self.variance_min = variance_min
    self.initializers = initializers
    self.dtype = dtype
    self.output_distribution = output_distribution

    def _get_variables_ta(shapes, name, initializer, trainable=True):
      """Creates a sequence of variables and stores them in a TensorArray."""
      # Infer shape if all shapes are equal.
      first_shape = shapes[0]
      infer_shape = all(shape == first_shape for shape in shapes)
      ta = tf.TensorArray(
          dtype=dtype, size=len(shapes), dynamic_size=False,
          clear_after_read=False, infer_shape=infer_shape)
      for t, shape in enumerate(shapes):
        var = tf.get_variable(
            name % t, shape=shape, initializer=initializer,
            trainable=trainable)
        ta = ta.write(t, var)
      return ta

    bias_shapes = [[num_outputs] for num_outputs in outputs_per_timestep]
    self.log_variances = _get_variables_ta(
        bias_shapes, "proposal_log_variance_%d", initializers["b"])
    self.mean_biases = _get_variables_ta(
        bias_shapes, "proposal_b_%d", initializers["b"])
    # Bug fix: materialize the zip as a list. On Python 3 zip() returns a
    # one-shot iterator, but weight_shapes is indexed ([0]), measured with
    # len(), and consumed twice (by _get_variables_ta and unstack below).
    weight_shapes = list(zip(inputs_per_timestep, outputs_per_timestep))
    self.mean_weights = _get_variables_ta(
        weight_shapes, "proposal_w_%d", initializers["w"])
    self.shapes = tf.TensorArray(
        dtype=tf.int32, size=num_timesteps,
        dynamic_size=False, clear_after_read=False).unstack(weight_shapes)

  def __call__(self, t, inputs):
    """Computes the distribution at timestep t.

    Args:
      t: Scalar integer Tensor, the current timestep. Must be in
        [0, num_timesteps).
      inputs: The inputs to the linear function parameterizing the mean of
        the current distribution. A Tensor of shape
        [batch_size, num_inputs_t].

    Returns:
      A tfd.Distribution subclass representing the distribution at
      timestep t.
    """
    b = self.mean_biases.read(t)
    w = self.mean_weights.read(t)
    shape = self.shapes.read(t)
    # The weights/biases were stored flat; restore their per-timestep shapes.
    w = tf.reshape(w, shape)
    b = tf.reshape(b, [shape[1], 1])
    log_variance = self.log_variances.read(t)
    # Clamp the variance from below before taking the square root.
    scale = tf.sqrt(tf.maximum(tf.exp(log_variance), self.variance_min))
    loc = tf.matmul(w, inputs, transpose_a=True) + b
    return self.output_distribution(loc=loc, scale=scale)
def encode_all(inputs, encoder):
  """Applies a time-independent encoder to every timestep of a sequence.

  Args:
    inputs: A [time, batch, feature_dimensions] tensor.
    encoder: A network that takes a [batch, feature_dimensions] input and
      encodes the input.

  Returns:
    A [time, batch, encoded_feature_dimensions] output tensor.
  """
  dims = tf.shape(inputs)
  # Fold time into the batch dimension, encode, then unfold.
  flat_inputs = tf.reshape(inputs, [-1, inputs.shape[-1]])
  flat_encoded = encoder(flat_inputs)
  return tf.reshape(flat_encoded, [dims[0], dims[1], encoder.output_size])
def ta_for_tensor(x, **kwargs):
  """Unstacks Tensor x along axis 0 into a fixed-size TensorArray."""
  num_elements = tf.shape(x)[0]
  ta = tf.TensorArray(x.dtype, num_elements, dynamic_size=False, **kwargs)
  return ta.unstack(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment