Commit 46dea625 authored by Dieterich Lawson's avatar Dieterich Lawson
Browse files

Updating fivo codebase

parent 5856878d
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing FIVO.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from fivo.models import base
from fivo.models import srnn
from fivo.models import vrnn
def create_vrnn(generative_class=base.ConditionalNormalDistribution,
batch_size=2, data_size=3, rnn_hidden_size=4,
latent_size=5, fcnet_hidden_size=7, encoded_data_size=9,
encoded_latent_size=11, num_timesteps=7, data_lengths=(7, 4),
use_tilt=False, random_seed=None):
"""Creates a VRNN and some dummy data to feed it for testing purposes.
Args:
generative_class: The class of the generative distribution.
batch_size: The number of elements per batch.
data_size: The dimension of the vectors that make up the data sequences.
rnn_hidden_size: The hidden state dimension of the RNN that forms the
deterministic part of this VRNN.
latent_size: The size of the stochastic latent state of the VRNN.
fcnet_hidden_size: The size of the hidden layer of the fully connected
networks that parameterize the conditional probability distributions
of the VRNN.
encoded_data_size: The size of the output of the data encoding network.
encoded_latent_size: The size of the output of the latent state encoding
network.
num_timesteps: The maximum number of timesteps in the data.
data_lengths: A tuple of size batch_size that contains the desired lengths
of each sequence in the dummy data.
use_tilt: Use a tilting function.
random_seed: A random seed to feed the VRNN, mainly useful for testing
purposes.
Returns:
model: A VRNN object.
inputs: A Tensor of shape [num_timesteps, batch_size, data_size], the inputs
to the model, also known as the observations.
targets: A Tensor of shape [num_timesteps, batch_size, data_size], the
desired outputs of the model.
lengths: A Tensor of shape [batch_size], the lengths of the sequences in the
batch.
"""
fcnet_hidden_sizes = [fcnet_hidden_size]
initializers = {"w": tf.contrib.layers.xavier_initializer(seed=random_seed),
"b": tf.zeros_initializer()}
model = vrnn.create_vrnn(
data_size,
latent_size,
generative_class,
rnn_hidden_size=rnn_hidden_size,
fcnet_hidden_sizes=fcnet_hidden_sizes,
encoded_data_size=encoded_data_size,
encoded_latent_size=encoded_latent_size,
use_tilt=use_tilt,
initializers=initializers,
random_seed=random_seed)
inputs = tf.random_uniform([num_timesteps, batch_size, data_size],
seed=random_seed, dtype=tf.float32)
targets = tf.random_uniform([num_timesteps, batch_size, data_size],
seed=random_seed, dtype=tf.float32)
lengths = tf.constant(data_lengths, dtype=tf.int32)
return model, inputs, targets, lengths
def create_srnn(generative_class=base.ConditionalNormalDistribution,
batch_size=2, data_size=3, rnn_hidden_size=4,
latent_size=5, fcnet_hidden_size=7, encoded_data_size=3,
encoded_latent_size=2, num_timesteps=7, data_lengths=(7, 4),
use_tilt=False, random_seed=None):
"""Creates a SRNN and some dummy data to feed it for testing purposes.
Args:
generative_class: The class of the generative distribution.
batch_size: The number of elements per batch.
data_size: The dimension of the vectors that make up the data sequences.
rnn_hidden_size: The hidden state dimension of the RNN that forms the
deterministic part of this SRNN.
latent_size: The size of the stochastic latent state of the SRNN.
fcnet_hidden_size: The size of the hidden layer of the fully connected
networks that parameterize the conditional probability distributions
of the SRNN.
encoded_data_size: The size of the output of the data encoding network.
encoded_latent_size: The size of the output of the latent state encoding
network.
num_timesteps: The maximum number of timesteps in the data.
data_lengths: A tuple of size batch_size that contains the desired lengths
of each sequence in the dummy data.
use_tilt: Use a tilting function.
random_seed: A random seed to feed the SRNN, mainly useful for testing
purposes.
Returns:
model: A SRNN object.
inputs: A Tensor of shape [num_timesteps, batch_size, data_size], the inputs
to the model, also known as the observations.
targets: A Tensor of shape [num_timesteps, batch_size, data_size], the
desired outputs of the model.
lengths: A Tensor of shape [batch_size], the lengths of the sequences in the
batch.
"""
fcnet_hidden_sizes = [fcnet_hidden_size]
initializers = {"w": tf.contrib.layers.xavier_initializer(seed=random_seed),
"b": tf.zeros_initializer()}
model = srnn.create_srnn(
data_size,
latent_size,
generative_class,
rnn_hidden_size=rnn_hidden_size,
fcnet_hidden_sizes=fcnet_hidden_sizes,
encoded_data_size=encoded_data_size,
encoded_latent_size=encoded_latent_size,
use_tilt=use_tilt,
initializers=initializers,
random_seed=random_seed)
inputs = tf.random_uniform([num_timesteps, batch_size, data_size],
seed=random_seed, dtype=tf.float32)
targets = tf.random_uniform([num_timesteps, batch_size, data_size],
seed=random_seed, dtype=tf.float32)
lengths = tf.constant(data_lengths, dtype=tf.int32)
return model, inputs, targets, lengths
# Copyright 2017 The TensorFlow Authors All Rights Reserved. # Copyright 2018 The TensorFlow Authors All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -22,17 +22,21 @@ from __future__ import print_function ...@@ -22,17 +22,21 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
import runners from fivo import ghmm_runners
from fivo import runners
# Shared flags. # Shared flags.
tf.app.flags.DEFINE_string("mode", "train", tf.app.flags.DEFINE_enum("mode", "train",
"The mode of the binary. Must be 'train' or 'test'.") ["train", "eval", "sample"],
tf.app.flags.DEFINE_string("model", "vrnn", "The mode of the binary.")
"Model choice. Currently only 'vrnn' is supported.") tf.app.flags.DEFINE_enum("model", "vrnn",
["vrnn", "ghmm", "srnn"],
"Model choice.")
tf.app.flags.DEFINE_integer("latent_size", 64, tf.app.flags.DEFINE_integer("latent_size", 64,
"The size of the latent state of the model.") "The size of the latent state of the model.")
tf.app.flags.DEFINE_string("dataset_type", "pianoroll", tf.app.flags.DEFINE_enum("dataset_type", "pianoroll",
"The type of dataset, either 'pianoroll' or 'speech'.") ["pianoroll", "speech", "pose"],
"The type of dataset.")
tf.app.flags.DEFINE_string("dataset_path", "", tf.app.flags.DEFINE_string("dataset_path", "",
"Path to load the dataset from.") "Path to load the dataset from.")
tf.app.flags.DEFINE_integer("data_dimension", None, tf.app.flags.DEFINE_integer("data_dimension", None,
...@@ -43,16 +47,20 @@ tf.app.flags.DEFINE_integer("data_dimension", None, ...@@ -43,16 +47,20 @@ tf.app.flags.DEFINE_integer("data_dimension", None,
tf.app.flags.DEFINE_integer("batch_size", 4, tf.app.flags.DEFINE_integer("batch_size", 4,
"Batch size.") "Batch size.")
tf.app.flags.DEFINE_integer("num_samples", 4, tf.app.flags.DEFINE_integer("num_samples", 4,
"The number of samples (or particles) for multisample " "The number of samples (or particles) for multisample "
"algorithms.") "algorithms.")
tf.app.flags.DEFINE_string("logdir", "/tmp/smc_vi", tf.app.flags.DEFINE_string("logdir", "/tmp/smc_vi",
"The directory to keep checkpoints and summaries in.") "The directory to keep checkpoints and summaries in.")
tf.app.flags.DEFINE_integer("random_seed", None, tf.app.flags.DEFINE_integer("random_seed", None,
"A random seed for seeding the TensorFlow graph.") "A random seed for seeding the TensorFlow graph.")
tf.app.flags.DEFINE_integer("parallel_iterations", 30,
"The number of parallel iterations to use for the while "
"loop that computes the bounds.")
# Training flags. # Training flags.
tf.app.flags.DEFINE_string("bound", "fivo", tf.app.flags.DEFINE_enum("bound", "fivo",
"The bound to optimize. Can be 'elbo', 'iwae', or 'fivo'.") ["elbo", "iwae", "fivo", "fivo-aux"],
"The bound to optimize.")
tf.app.flags.DEFINE_boolean("normalize_by_seq_len", True, tf.app.flags.DEFINE_boolean("normalize_by_seq_len", True,
"If true, normalize the loss by the number of timesteps " "If true, normalize the loss by the number of timesteps "
"per sequence.") "per sequence.")
...@@ -62,6 +70,17 @@ tf.app.flags.DEFINE_integer("max_steps", int(1e9), ...@@ -62,6 +70,17 @@ tf.app.flags.DEFINE_integer("max_steps", int(1e9),
"The number of gradient update steps to train for.") "The number of gradient update steps to train for.")
tf.app.flags.DEFINE_integer("summarize_every", 50, tf.app.flags.DEFINE_integer("summarize_every", 50,
"The number of steps between summaries.") "The number of steps between summaries.")
tf.app.flags.DEFINE_enum("resampling_type", "multinomial",
["multinomial", "relaxed"],
"The resampling strategy to use for training.")
tf.app.flags.DEFINE_float("relaxed_resampling_temperature", 0.5,
"The relaxation temperature for relaxed resampling.")
tf.app.flags.DEFINE_enum("proposal_type", "filtering",
["prior", "filtering", "smoothing",
"true-filtering", "true-smoothing"],
"The type of proposal to use. true-filtering and true-smoothing "
"are only available for the GHMM. The specific implementation "
"of each proposal type is left to model-writers.")
# Distributed training flags. # Distributed training flags.
tf.app.flags.DEFINE_string("master", "", tf.app.flags.DEFINE_string("master", "",
...@@ -74,9 +93,25 @@ tf.app.flags.DEFINE_boolean("stagger_workers", True, ...@@ -74,9 +93,25 @@ tf.app.flags.DEFINE_boolean("stagger_workers", True,
"If true, bring one worker online every 1000 steps.") "If true, bring one worker online every 1000 steps.")
# Evaluation flags. # Evaluation flags.
tf.app.flags.DEFINE_string("split", "train", tf.app.flags.DEFINE_enum("split", "train",
"Split to evaluate the model on. Can be 'train', 'valid', or 'test'.") ["train", "test", "valid"],
"Split to evaluate the model on.")
# Sampling flags.
tf.app.flags.DEFINE_integer("sample_length", 50,
"The number of timesteps to sample for.")
tf.app.flags.DEFINE_integer("prefix_length", 25,
"The number of timesteps to condition the model on "
"before sampling.")
tf.app.flags.DEFINE_string("sample_out_dir", None,
"The directory to write the samples to. "
"Defaults to logdir.")
# GHMM flags.
tf.app.flags.DEFINE_float("variance", 0.1,
"The variance of the ghmm.")
tf.app.flags.DEFINE_integer("num_timesteps", 5,
"The number of timesteps to run the gmp for.")
FLAGS = tf.app.flags.FLAGS FLAGS = tf.app.flags.FLAGS
PIANOROLL_DEFAULT_DATA_DIMENSION = 88 PIANOROLL_DEFAULT_DATA_DIMENSION = 88
...@@ -85,15 +120,23 @@ SPEECH_DEFAULT_DATA_DIMENSION = 200 ...@@ -85,15 +120,23 @@ SPEECH_DEFAULT_DATA_DIMENSION = 200
def main(unused_argv): def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.data_dimension is None: if FLAGS.model in ["vrnn", "srnn"]:
if FLAGS.dataset_type == "pianoroll": if FLAGS.data_dimension is None:
FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION if FLAGS.dataset_type == "pianoroll":
elif FLAGS.dataset_type == "speech": FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION elif FLAGS.dataset_type == "speech":
if FLAGS.mode == "train": FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
runners.run_train(FLAGS) if FLAGS.mode == "train":
elif FLAGS.mode == "eval": runners.run_train(FLAGS)
runners.run_eval(FLAGS) elif FLAGS.mode == "eval":
runners.run_eval(FLAGS)
elif FLAGS.mode == "sample":
runners.run_sample(FLAGS)
elif FLAGS.model == "ghmm":
if FLAGS.mode == "train":
ghmm_runners.run_train(FLAGS)
elif FLAGS.mode == "eval":
ghmm_runners.run_eval(FLAGS)
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() tf.app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment