Commit 46dea625 authored by Dieterich Lawson

Updating fivo codebase

parent 5856878d
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +21,8 @@ from __future__ import division
from __future__ import print_function
import pickle
import numpy as np
from scipy.sparse import coo_matrix
import tensorflow as tf
@@ -147,6 +149,95 @@ def create_pianoroll_dataset(path,
  return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
def create_human_pose_dataset(
path,
split,
batch_size,
num_parallel_calls=DEFAULT_PARALLELISM,
shuffle=False,
repeat=False,):
"""Creates a human pose dataset.
Args:
path: The path of a pickle file containing the dataset to load.
split: The split to use, can be train, test, or valid.
batch_size: The batch size. If repeat is False then it is not guaranteed
that the true batch size will match for all batches since batch_size
may not necessarily evenly divide the number of elements.
num_parallel_calls: The number of threads to use for parallel processing of
the data.
shuffle: If true, shuffles the order of the dataset.
repeat: If true, repeats the dataset endlessly.
Returns:
inputs: A batch of input sequences represented as a dense Tensor of shape
[time, batch_size, data_dimension]. The sequences in inputs are the
sequences in targets shifted one timestep into the future, padded with
zeros. This tensor is mean-centered, with the mean taken from the pickle
file key 'train_mean'.
targets: A batch of target sequences represented as a dense Tensor of
shape [time, batch_size, data_dimension].
lens: An int Tensor of shape [batch_size] representing the lengths of each
sequence in the batch.
mean: A float Tensor of shape [data_dimension] containing the mean loaded
from the pickle file.
"""
# Load the data from disk.
with tf.gfile.Open(path, "r") as f:
raw_data = pickle.load(f)
mean = raw_data["train_mean"]
pose_sequences = raw_data[split]
num_examples = len(pose_sequences)
num_features = pose_sequences[0].shape[1]
def pose_generator():
"""A generator that yields pose data sequences."""
    # Each timestep has 32 x values followed by 32 y values, so each
    # timestep is 64-dimensional.
for pose_sequence in pose_sequences:
yield pose_sequence, pose_sequence.shape[0]
dataset = tf.data.Dataset.from_generator(
pose_generator,
output_types=(tf.float64, tf.int64),
output_shapes=([None, num_features], []))
if repeat:
dataset = dataset.repeat()
if shuffle:
dataset = dataset.shuffle(num_examples)
  # Batch sequences together, padding them to a common length in time.
dataset = dataset.padded_batch(
batch_size, padded_shapes=([None, num_features], []))
# Post-process each batch, ensuring that it is mean-centered and time-major.
def process_pose_data(data, lengths):
"""Creates Tensors for next step prediction and mean-centers the input."""
data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
lengths = tf.to_int32(lengths)
targets = data
# Mean center the inputs.
inputs = data - tf.constant(
mean, dtype=tf.float32, shape=[1, 1, mean.shape[0]])
# Shift the inputs one step forward in time. Also remove the last timestep
# so that targets and inputs are the same length.
inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
# Mask out unused timesteps.
inputs *= tf.expand_dims(
tf.transpose(tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
return inputs, targets, lengths
dataset = dataset.map(
process_pose_data,
num_parallel_calls=num_parallel_calls)
dataset = dataset.prefetch(num_examples)
itr = dataset.make_one_shot_iterator()
inputs, targets, lengths = itr.get_next()
return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
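A minimal usage sketch of the new loader (the pickle path below is hypothetical; the file must provide the requested split and a 'train_mean' key as described in the docstring):

# Hypothetical usage sketch -- the path is a placeholder, not a real dataset.
inputs, targets, lengths, mean = create_human_pose_dataset(
    "/tmp/pose_dataset.pkl", split="train", batch_size=4,
    shuffle=True, repeat=True)
with tf.Session() as sess:
  batch_inputs, batch_targets, batch_lengths = sess.run(
      [inputs, targets, lengths])
  # batch_inputs and batch_targets are [time, batch_size, data_dimension]
  # float32 arrays; batch_lengths holds the unpadded length of each sequence.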
def create_speech_dataset(path,
                          batch_size,
                          samples_per_timestep=200,
@@ -221,3 +312,142 @@ def create_speech_dataset(path,
  itr = dataset.make_one_shot_iterator()
  inputs, targets, lengths = itr.get_next()
  return inputs, targets, lengths
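For parity, a usage sketch of create_speech_dataset mirroring the call in the tests further down (the TFRecord path is a placeholder):

# Hypothetical usage sketch -- the TFRecord path is a placeholder.
speech_inputs, speech_targets, speech_lens = create_speech_dataset(
    "/tmp/speech_train.tfrecord", batch_size=8, samples_per_timestep=200,
    shuffle=True, repeat=True)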
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
def create_chain_graph_dataset(
batch_size,
num_timesteps,
steps_per_observation=None,
state_size=1,
transition_variance=1.,
observation_variance=1.,
transition_type=STANDARD_TRANSITION,
observation_type=STANDARD_OBSERVATION,
fixed_observation=None,
prefetch_buffer_size=2048,
dtype="float32"):
"""Creates a toy chain graph dataset.
Creates a dataset where the data are sampled from a diffusion process. The
'latent' states of the process are sampled as a chain of Normals:
z0 ~ N(0, transition_variance)
z1 ~ N(transition_fn(z0), transition_variance)
...
  where transition_fn either rounds z0 or passes it through unchanged.
The observations are produced every steps_per_observation timesteps as a
function of the latent zs. For example if steps_per_observation is 3 then the
first observation will be produced as a function of z3:
x1 ~ N(observation_fn(z3), observation_variance)
where observation_fn could square z3, take the absolute value, or pass
it through unchanged.
Only the observations are returned.
Args:
batch_size: The batch size. The number of trajectories to run in parallel.
    num_timesteps: The length of the chain of latent states (i.e. the
      number of z's excluding z0).
steps_per_observation: The number of latent states between each observation,
must evenly divide num_timesteps.
state_size: The size of the latent state and observation, must be a
python int.
transition_variance: The variance of the transition density.
observation_variance: The variance of the observation density.
transition_type: Must be one of "round" or "standard". "round" means that
the transition density is centered at the rounded previous latent state.
"standard" centers the transition density at the previous latent state,
unchanged.
observation_type: Must be one of "squared", "abs" or "standard". "squared"
centers the observation density at the squared latent state. "abs"
      centers the observation density at the absolute value of the current
latent state. "standard" centers the observation density at the current
latent state.
fixed_observation: If not None, fixes all observations to be a constant.
Must be a scalar.
prefetch_buffer_size: The size of the prefetch queues to use after reading
and processing the raw data.
dtype: A string convertible to a tensorflow datatype. The datatype used
to represent the states and observations.
Returns:
observations: A batch of observations represented as a dense Tensor of
shape [num_observations, batch_size, state_size]. num_observations is
num_timesteps/steps_per_observation.
    lens: An int Tensor of shape [batch_size] representing the lengths of each
      sequence in the batch. Every entry equals num_observations.
Raises:
ValueError: Raised if steps_per_observation does not evenly divide
num_timesteps.
"""
if steps_per_observation is None:
steps_per_observation = num_timesteps
if num_timesteps % steps_per_observation != 0:
raise ValueError("steps_per_observation must evenly divide num_timesteps.")
num_observations = int(num_timesteps / steps_per_observation)
def data_generator():
"""An infinite generator of latents and observations from the model."""
transition_std = np.sqrt(transition_variance)
observation_std = np.sqrt(observation_variance)
while True:
states = []
observations = []
# Sample z0 ~ Normal(0, sqrt(variance)).
states.append(
np.random.normal(size=[state_size],
scale=observation_std).astype(dtype))
# Start the range at 1 because we've already generated z0.
# The range ends at num_timesteps+1 because we want to include the
# num_timesteps-th step.
for t in xrange(1, num_timesteps+1):
if transition_type == ROUND_TRANSITION:
loc = np.round(states[-1])
elif transition_type == STANDARD_TRANSITION:
loc = states[-1]
z_t = np.random.normal(size=[state_size], loc=loc, scale=transition_std)
states.append(z_t.astype(dtype))
if t % steps_per_observation == 0:
if fixed_observation is None:
if observation_type == SQUARED_OBSERVATION:
loc = np.square(states[-1])
elif observation_type == ABS_OBSERVATION:
loc = np.abs(states[-1])
elif observation_type == STANDARD_OBSERVATION:
loc = states[-1]
x_t = np.random.normal(size=[state_size],
loc=loc,
scale=observation_std).astype(dtype)
else:
x_t = np.ones([state_size]) * fixed_observation
observations.append(x_t)
yield states, observations
dataset = tf.data.Dataset.from_generator(
data_generator,
output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
output_shapes=([num_timesteps+1, state_size],
[num_observations, state_size])
)
dataset = dataset.repeat().batch(batch_size)
dataset = dataset.prefetch(prefetch_buffer_size)
itr = dataset.make_one_shot_iterator()
_, observations = itr.get_next()
# Transpose observations from [batch, time, state_size] to
# [time, batch, state_size].
observations = tf.transpose(observations, perm=[1, 0, 2])
lengths = tf.ones([batch_size], dtype=tf.int32) * num_observations
return observations, lengths
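A short usage sketch of the chain graph generator (the parameter values here are illustrative only):

# Illustrative usage sketch -- the parameter values are arbitrary.
observations, lengths = create_chain_graph_dataset(
    batch_size=4,
    num_timesteps=10,
    steps_per_observation=5,
    state_size=2,
    transition_type=ROUND_TRANSITION,
    observation_type=ABS_OBSERVATION)
with tf.Session() as sess:
  obs, lens = sess.run([observations, lengths])
  # obs has shape [2, 4, 2]: num_timesteps/steps_per_observation observations,
  # batch_size trajectories, and state_size dimensions per observation.
  # lens is [2, 2, 2, 2] because every trajectory has the same length.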
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.data.datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import os
import numpy as np
import tensorflow as tf
from fivo.data import datasets
FLAGS = tf.app.flags.FLAGS
class DatasetsTest(tf.test.TestCase):
def test_sparse_pianoroll_to_dense_empty_at_end(self):
sparse_pianoroll = [(0, 1), (1, 0), (), (1,), (), ()]
dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
sparse_pianoroll, min_note=0, num_notes=2)
self.assertEqual(num_timesteps, 6)
self.assertAllEqual([[1, 1],
[1, 1],
[0, 0],
[0, 1],
[0, 0],
[0, 0]], dense_pianoroll)
def test_sparse_pianoroll_to_dense_with_chord(self):
sparse_pianoroll = [(0, 1), (1, 0), (), (1,)]
dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
sparse_pianoroll, min_note=0, num_notes=2)
self.assertEqual(num_timesteps, 4)
self.assertAllEqual([[1, 1],
[1, 1],
[0, 0],
[0, 1]], dense_pianoroll)
def test_sparse_pianoroll_to_dense_simple(self):
sparse_pianoroll = [(0,), (), (1,)]
dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
sparse_pianoroll, min_note=0, num_notes=2)
self.assertEqual(num_timesteps, 3)
self.assertAllEqual([[1, 0],
[0, 0],
[0, 1]], dense_pianoroll)
def test_sparse_pianoroll_to_dense_subtracts_min_note(self):
sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
sparse_pianoroll, min_note=4, num_notes=2)
self.assertEqual(num_timesteps, 6)
self.assertAllEqual([[1, 1],
[1, 1],
[0, 0],
[0, 1],
[0, 0],
[0, 0]], dense_pianoroll)
def test_sparse_pianoroll_to_dense_uses_num_notes(self):
sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
sparse_pianoroll, min_note=4, num_notes=3)
self.assertEqual(num_timesteps, 6)
self.assertAllEqual([[1, 1, 0],
[1, 1, 0],
[0, 0, 0],
[0, 1, 0],
[0, 0, 0],
[0, 0, 0]], dense_pianoroll)
def test_pianoroll_dataset(self):
pianoroll_data = [[(0,), (), (1,)],
[(0, 1), (1,)],
[(1,), (0,), (), (0, 1), (), ()]]
pianoroll_mean = np.zeros([3])
pianoroll_mean[-1] = 1
data = {"train": pianoroll_data, "train_mean": pianoroll_mean}
path = os.path.join(tf.test.get_temp_dir(), "test.pkl")
pickle.dump(data, open(path, "wb"))
with self.test_session() as sess:
inputs, targets, lens, mean = datasets.create_pianoroll_dataset(
path, "train", 2, num_parallel_calls=1,
shuffle=False, repeat=False,
min_note=0, max_note=2)
i1, t1, l1 = sess.run([inputs, targets, lens])
i2, t2, l2 = sess.run([inputs, targets, lens])
m = sess.run(mean)
# Check the lengths.
self.assertAllEqual([3, 2], l1)
self.assertAllEqual([6], l2)
# Check the mean.
self.assertAllEqual(pianoroll_mean, m)
# Check the targets. The targets should not be mean-centered and should
# be padded with zeros to a common length within a batch.
self.assertAllEqual([[1, 0, 0],
[0, 0, 0],
[0, 1, 0]], t1[:, 0, :])
self.assertAllEqual([[1, 1, 0],
[0, 1, 0],
[0, 0, 0]], t1[:, 1, :])
self.assertAllEqual([[0, 1, 0],
[1, 0, 0],
[0, 0, 0],
[1, 1, 0],
[0, 0, 0],
[0, 0, 0]], t2[:, 0, :])
# Check the inputs. Each sequence should start with zeros on the first
# timestep. Each sequence should be padded with zeros to a common length
# within a batch. The mean should be subtracted from all timesteps except
# the first and the padding.
self.assertAllEqual([[0, 0, 0],
[1, 0, -1],
[0, 0, -1]], i1[:, 0, :])
self.assertAllEqual([[0, 0, 0],
[1, 1, -1],
[0, 0, 0]], i1[:, 1, :])
self.assertAllEqual([[0, 0, 0],
[0, 1, -1],
[1, 0, -1],
[0, 0, -1],
[1, 1, -1],
[0, 0, -1]], i2[:, 0, :])
def test_human_pose_dataset(self):
pose_data = [
[[0, 0], [2, 2]],
[[2, 2]],
[[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]],
]
pose_data = [np.array(x, dtype=np.float64) for x in pose_data]
pose_data_mean = np.array([1, 1], dtype=np.float64)
data = {
"train": pose_data,
"train_mean": pose_data_mean,
}
path = os.path.join(tf.test.get_temp_dir(), "test_human_pose_dataset.pkl")
with open(path, "wb") as out:
pickle.dump(data, out)
with self.test_session() as sess:
inputs, targets, lens, mean = datasets.create_human_pose_dataset(
path, "train", 2, num_parallel_calls=1, shuffle=False, repeat=False)
i1, t1, l1 = sess.run([inputs, targets, lens])
i2, t2, l2 = sess.run([inputs, targets, lens])
m = sess.run(mean)
# Check the lengths.
self.assertAllEqual([2, 1], l1)
self.assertAllEqual([5], l2)
# Check the mean.
self.assertAllEqual(pose_data_mean, m)
# Check the targets. The targets should not be mean-centered and should
# be padded with zeros to a common length within a batch.
self.assertAllEqual([[0, 0], [2, 2]], t1[:, 0, :])
self.assertAllEqual([[2, 2], [0, 0]], t1[:, 1, :])
self.assertAllEqual([[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]], t2[:, 0, :])
# Check the inputs. Each sequence should start with zeros on the first
# timestep. Each sequence should be padded with zeros to a common length
# within a batch. The mean should be subtracted from all timesteps except
# the first and the padding.
self.assertAllEqual([[0, 0], [-1, -1]], i1[:, 0, :])
self.assertAllEqual([[0, 0], [0, 0]], i1[:, 1, :])
self.assertAllEqual([[0, 0], [-1, -1], [-1, -1], [1, 1], [1, 1]],
i2[:, 0, :])
def test_speech_dataset(self):
with self.test_session() as sess:
path = os.path.join(
os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
"test_data",
"tiny_speech_dataset.tfrecord")
inputs, targets, lens = datasets.create_speech_dataset(
path, 3, samples_per_timestep=2, num_parallel_calls=1,
prefetch_buffer_size=3, shuffle=False, repeat=False)
inputs1, targets1, lengths1 = sess.run([inputs, targets, lens])
inputs2, targets2, lengths2 = sess.run([inputs, targets, lens])
# Check the lengths.
self.assertAllEqual([1, 2, 3], lengths1)
self.assertAllEqual([4], lengths2)
# Check the targets. The targets should be padded with zeros to a common
# length within a batch.
self.assertAllEqual([[[0., 1.], [0., 1.], [0., 1.]],
[[0., 0.], [2., 3.], [2., 3.]],
[[0., 0.], [0., 0.], [4., 5.]]],
targets1)
self.assertAllEqual([[[0., 1.]],
[[2., 3.]],
[[4., 5.]],
[[6., 7.]]],
targets2)
# Check the inputs. Each sequence should start with zeros on the first
# timestep. Each sequence should be padded with zeros to a common length
# within a batch.
self.assertAllEqual([[[0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 1.], [0., 1.]],
[[0., 0.], [0., 0.], [2., 3.]]],
inputs1)
self.assertAllEqual([[[0., 0.]],
[[0., 1.]],
[[2., 3.]],
[[4., 5.]]],
inputs2)
def test_chain_graph_raises_error_on_wrong_steps_per_observation(self):
with self.assertRaises(ValueError):
datasets.create_chain_graph_dataset(
batch_size=4,
num_timesteps=10,
steps_per_observation=9)
def test_chain_graph_single_obs(self):
with self.test_session() as sess:
np.random.seed(1234)
num_observations = 1
num_timesteps = 5
batch_size = 2
state_size = 1
observations, lengths = datasets.create_chain_graph_dataset(
batch_size=batch_size,
num_timesteps=num_timesteps,
state_size=state_size)
out_observations, out_lengths = sess.run([observations, lengths])
self.assertAllEqual([num_observations, num_observations], out_lengths)
self.assertAllClose(
[[[1.426677], [-1.789461]]],
out_observations)
def test_chain_graph_multiple_obs(self):
with self.test_session() as sess:
np.random.seed(1234)
num_observations = 3
num_timesteps = 6
batch_size = 2
state_size = 1
observations, lengths = datasets.create_chain_graph_dataset(
batch_size=batch_size,
num_timesteps=num_timesteps,
steps_per_observation=num_timesteps/num_observations,
state_size=state_size)
out_observations, out_lengths = sess.run([observations, lengths])
self.assertAllEqual([num_observations, num_observations], out_lengths)
self.assertAllClose(
[[[0.40051451], [1.07405114]],
[[1.73932898], [3.16880035]],
[[-1.98377144], [2.82669163]]],
out_observations)
def test_chain_graph_state_dims(self):
with self.test_session() as sess:
np.random.seed(1234)
num_observations = 1
num_timesteps = 5
batch_size = 2
state_size = 3
observations, lengths = datasets.create_chain_graph_dataset(
batch_size=batch_size,
num_timesteps=num_timesteps,
state_size=state_size)
out_observations, out_lengths = sess.run([observations, lengths])
self.assertAllEqual([num_observations, num_observations], out_lengths)
self.assertAllClose(
[[[1.052287, -4.560759, 3.07988],
[2.008926, 0.495567, 3.488678]]],
out_observations)
def test_chain_graph_fixed_obs(self):
with self.test_session() as sess:
np.random.seed(1234)
num_observations = 3
num_timesteps = 6
batch_size = 2
state_size = 1
observations, lengths = datasets.create_chain_graph_dataset(
batch_size=batch_size,
num_timesteps=num_timesteps,
steps_per_observation=num_timesteps/num_observations,
state_size=state_size,
fixed_observation=4.)
out_observations, out_lengths = sess.run([observations, lengths])
self.assertAllEqual([num_observations, num_observations], out_lengths)
self.assertAllClose(
np.ones([num_observations, batch_size, state_size]) * 4.,
out_observations)
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates and runs Gaussian HMM-related graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from fivo import smc
from fivo import bounds
from fivo.data import datasets
from fivo.models import ghmm
def run_train(config):
"""Runs training for a Gaussian HMM setup."""
def create_logging_hook(step, bound_value, likelihood, bound_gap):
"""Creates a logging hook that prints the bound value periodically."""
bound_label = config.bound + "/t"
def summary_formatter(log_dict):
string = ("Step {step}, %s: {value:.3f}, "
"likelihood: {ll:.3f}, gap: {gap:.3e}") % bound_label
return string.format(**log_dict)
logging_hook = tf.train.LoggingTensorHook(
{"step": step, "value": bound_value,
"ll": likelihood, "gap": bound_gap},
every_n_iter=config.summarize_every,
formatter=summary_formatter)
return logging_hook
def create_losses(model, observations, lengths):
"""Creates the loss to be optimized.
Args:
model: A Trainable GHMM model.
observations: A set of observations.
lengths: The lengths of each sequence in the observations.
Returns:
loss: A float Tensor that when differentiated yields the gradients
to apply to the model. Should be optimized via gradient descent.
bound: A float Tensor containing the value of the bound that is
being optimized.
true_ll: The true log-likelihood of the data under the model.
bound_gap: The gap between the bound and the true log-likelihood.
"""
# Compute lower bounds on the log likelihood.
if config.bound == "elbo":
ll_per_seq, _, _ = bounds.iwae(
model, observations, lengths, num_samples=1,
parallel_iterations=config.parallel_iterations
)
elif config.bound == "iwae":
ll_per_seq, _, _ = bounds.iwae(
model, observations, lengths, num_samples=config.num_samples,
parallel_iterations=config.parallel_iterations
)
elif config.bound == "fivo":
if config.resampling_type == "relaxed":
ll_per_seq, _, _, _ = bounds.fivo(
model,
observations,
lengths,
num_samples=config.num_samples,
resampling_criterion=smc.ess_criterion,
resampling_type=config.resampling_type,
          relaxed_resampling_temperature=config.relaxed_resampling_temperature,
random_seed=config.random_seed,
parallel_iterations=config.parallel_iterations)
else:
ll_per_seq, _, _, _ = bounds.fivo(
model, observations, lengths,
num_samples=config.num_samples,
resampling_criterion=smc.ess_criterion,
resampling_type=config.resampling_type,
random_seed=config.random_seed,
parallel_iterations=config.parallel_iterations
)
ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
# Compute the data's true likelihood under the model and the bound gap.
true_ll_per_seq = model.likelihood(tf.squeeze(observations))
true_ll_per_t = tf.reduce_mean(true_ll_per_seq / tf.to_float(lengths))
bound_gap = true_ll_per_seq - ll_per_seq
    bound_gap = tf.reduce_mean(bound_gap / tf.to_float(lengths))
tf.summary.scalar("train_ll_bound", ll_per_t)
tf.summary.scalar("train_true_ll", true_ll_per_t)
tf.summary.scalar("bound_gap", bound_gap)
return -ll_per_t, ll_per_t, true_ll_per_t, bound_gap
def create_graph():
"""Creates the training graph."""
global_step = tf.train.get_or_create_global_step()
xs, lengths = datasets.create_chain_graph_dataset(
config.batch_size,
config.num_timesteps,
steps_per_observation=1,
state_size=1,
transition_variance=config.variance,
observation_variance=config.variance)
model = ghmm.TrainableGaussianHMM(
config.num_timesteps,
config.proposal_type,
transition_variances=config.variance,
emission_variances=config.variance,
random_seed=config.random_seed)
loss, bound, true_ll, gap = create_losses(model, xs, lengths)
opt = tf.train.AdamOptimizer(config.learning_rate)
grads = opt.compute_gradients(loss, var_list=tf.trainable_variables())
train_op = opt.apply_gradients(grads, global_step=global_step)
return bound, true_ll, gap, train_op, global_step
with tf.Graph().as_default():
if config.random_seed:
tf.set_random_seed(config.random_seed)
np.random.seed(config.random_seed)
bound, true_ll, gap, train_op, global_step = create_graph()
log_hook = create_logging_hook(global_step, bound, true_ll, gap)
with tf.train.MonitoredTrainingSession(
master="",
hooks=[log_hook],
checkpoint_dir=config.logdir,
save_checkpoint_secs=120,
save_summaries_steps=config.summarize_every,
log_step_count_steps=config.summarize_every*20) as sess:
cur_step = -1
while cur_step <= config.max_steps and not sess.should_stop():
cur_step = sess.run(global_step)
_, cur_step = sess.run([train_op, global_step])
def run_eval(config):
"""Evaluates a Gaussian HMM using the given config."""
def create_bound(model, xs, lengths):
"""Creates the bound to be evaluated."""
if config.bound == "elbo":
ll_per_seq, log_weights, _ = bounds.iwae(
model, xs, lengths, num_samples=1,
parallel_iterations=config.parallel_iterations
)
elif config.bound == "iwae":
ll_per_seq, log_weights, _ = bounds.iwae(
model, xs, lengths, num_samples=config.num_samples,
parallel_iterations=config.parallel_iterations
)
elif config.bound == "fivo":
ll_per_seq, log_weights, resampled, _ = bounds.fivo(
model, xs, lengths,
num_samples=config.num_samples,
resampling_criterion=smc.ess_criterion,
resampling_type=config.resampling_type,
random_seed=config.random_seed,
parallel_iterations=config.parallel_iterations
)
# Compute bound scaled by number of timesteps.
bound_per_t = ll_per_seq / tf.to_float(lengths)
if config.bound == "fivo":
return bound_per_t, log_weights, resampled
else:
return bound_per_t, log_weights
def create_graph():
"""Creates the dataset, model, and bound."""
xs, lengths = datasets.create_chain_graph_dataset(
config.batch_size,
config.num_timesteps,
steps_per_observation=1,
state_size=1,
transition_variance=config.variance,
observation_variance=config.variance)
model = ghmm.TrainableGaussianHMM(
config.num_timesteps,
config.proposal_type,
transition_variances=config.variance,
emission_variances=config.variance,
random_seed=config.random_seed)
true_likelihood = tf.reduce_mean(
model.likelihood(tf.squeeze(xs)) / tf.to_float(lengths))
outs = [true_likelihood]
outs.extend(list(create_bound(model, xs, lengths)))
return outs
with tf.Graph().as_default():
if config.random_seed:
tf.set_random_seed(config.random_seed)
np.random.seed(config.random_seed)
graph_outs = create_graph()
with tf.train.SingularMonitoredSession(
checkpoint_dir=config.logdir) as sess:
outs = sess.run(graph_outs)
likelihood = outs[0]
avg_bound = np.mean(outs[1])
std = np.std(outs[1])
log_weights = outs[2]
log_weight_variances = np.var(log_weights, axis=2)
avg_log_weight_variance = np.var(log_weight_variances, axis=1)
avg_log_weight = np.mean(log_weights, axis=(1, 2))
data = {"mean": avg_bound, "std": std, "log_weights": log_weights,
"log_weight_means": avg_log_weight,
"log_weight_variances": avg_log_weight_variance}
if len(outs) == 4:
data["resampled"] = outs[3]
data["avg_resampled"] = np.mean(outs[3], axis=1)
# Log some useful statistics.
tf.logging.info("Evaled bound %s with batch_size: %d, num_samples: %d."
% (config.bound, config.batch_size, config.num_samples))
tf.logging.info("mean: %f, std: %f" % (avg_bound, std))
tf.logging.info("true likelihood: %s" % likelihood)
tf.logging.info("avg log weight: %s" % avg_log_weight)
tf.logging.info("log weight variance: %s" % avg_log_weight_variance)
if len(outs) == 4:
tf.logging.info("avg resamples per t: %s" % data["avg_resampled"])
if not tf.gfile.Exists(config.logdir):
tf.gfile.MakeDirs(config.logdir)
with tf.gfile.Open(os.path.join(config.logdir, "out.npz"), "w") as fout:
np.save(fout, data)
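For orientation, a hedged sketch of the kind of config object run_train and run_eval read; the field names come from the code above, but the values (and the proposal_type string) are illustrative assumptions, not the package's actual flag definitions:

# Illustrative only: a namespace exposing the config fields referenced by
# run_train/run_eval. The real flags are defined elsewhere in the package.
class ExampleConfig(object):
  bound = "fivo"                       # "elbo", "iwae", or "fivo"
  num_samples = 4
  resampling_type = "relaxed"          # takes the relaxed-resampling branch
  relaxed_resampling_temperature = 0.5
  random_seed = 1234
  batch_size = 4
  num_timesteps = 10
  variance = 0.1
  proposal_type = "prior"              # assumed; must be a type accepted by ghmm.TrainableGaussianHMM
  learning_rate = 1e-4
  logdir = "/tmp/ghmm"                 # placeholder directory
  summarize_every = 100
  max_steps = 10000
  parallel_iterations = 30

run_train(ExampleConfig())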