Commit 703bcbab authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #2029 from gjtucker/master

Add rebar to models.
parents e88d0cf4 a978aa9c
......@@ -19,6 +19,7 @@ object_detection/* @jch1 @tombstone @derekjchow @jesu9 @dreamdragon
pcl_rl/* @ofirnachum
ptn/* @xcyan @arkanath @hellojas @honglaklee
real_nvp/* @laurent-dinh
rebar/* @gjtucker
resnet/* @panyx0718
skip_thoughts/* @cshallue
slim/* @sguada @nathansilberman
......
......@@ -30,6 +30,7 @@ running TensorFlow 0.12 or earlier, please
- [next_frame_prediction](next_frame_prediction): probabilistic future frame synthesis via cross convolutional networks.
- [object_detection](object_detection): localizing and identifying multiple objects in a single image.
- [real_nvp](real_nvp): density estimation using real-valued non-volume preserving (real NVP) transformations.
- [rebar](rebar): low-variance, unbiased gradient estimates for discrete latent variable models.
- [resnet](resnet): deep and wide residual networks.
- [skip_thoughts](skip_thoughts): recurrent neural network sentence-to-vector encoder.
- [slim](slim): image classification models in TF-Slim.
......
# REINFORCing Concrete with REBAR
*Implemention of REBAR (and other closely related methods) as described
in "REBAR: Low-variance, unbiased gradient estimates for discrete latent variable models" by
George Tucker, Andriy Mnih, Chris J. Maddison, Dieterich Lawson, Jascha Sohl-Dickstein [(https://arxiv.org/abs/1703.07370)](https://arxiv.org/abs/1703.07370).*
Learning in models with discrete latent variables is challenging due to high variance gradient estimators. Generally, approaches have relied on control variates to reduce the variance of the REINFORCE estimator. Recent work ([Jang et al. 2016](https://arxiv.org/abs/1611.01144); [Maddison et al. 2016](https://arxiv.org/abs/1611.00712)) has taken a different approach, introducing a continuous relaxation of discrete variables to produce low-variance, but biased, gradient estimates. In this work, we combine the two approaches through a novel control variate that produces low-variance, unbiased gradient estimates. Then, we introduce a novel continuous relaxation and show that the tightness of the relaxation can be adapted online, removing it as a hyperparameter. We show state-of-the-art variance reduction on several benchmark generative modeling tasks, generally leading to faster convergence to a better final log likelihood.
REBAR applied to multilayer sigmoid belief networks is implemented in rebar.py and rebar_train.py provides a training/evaluation setup. As a comparison, we also implemented the following methods:
* [NVIL](https://arxiv.org/abs/1402.0030)
* [MuProp](https://arxiv.org/abs/1511.05176)
* [Gumbel-Softmax](https://arxiv.org/abs/1611.01144)
The code is not optimized and some computation is repeated for ease of
implementation. We hope that this code will be a useful starting point for future research in this area.
## Quick Start:
Requirements:
* TensorFlow (see tensorflow.org for how to install)
* MNIST dataset
* Omniglot dataset
First download datasets by selecting URLs to download the data from. Then
fill in the download_data.py script like so:
```
MNIST_URL = 'http://yann.lecun.com/exdb/mnist'
MNIST_BINARIZED_URL = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist'
OMNIGLOT_URL = 'https://github.com/yburda/iwae/raw/master/datasets/OMNIGLOT/chardata.mat'
```
Then run the script to download the data:
```
python download_data.py
```
Then run the training script:
```
python rebar_train.py --hparams="model=SBNDynamicRebar,learning_rate=0.0003,n_layer=2,task=sbn"
```
and you should see something like:
```
Step 2084: [-231.026474 0.3711713 1. 1.06934261 1.07023323
1.02173257 1.02171052 1. 1. 1. 1. ]
-3.6465678215
Step 4168: [-156.86795044 0.3097114 1. 1.03964758 1.03936625
1.02627242 1.02629256 1. 1. 1. 1. ]
-4.42727231979
Step 6252: [-143.4650116 0.26153237 1. 1.03633797 1.03600132
1.02639604 1.02639794 1. 1. 1. 1. ]
-4.85577583313
Step 8336: [-137.65275574 0.22313026 1. 1.03467286 1.03428006
1.02336085 1.02335203 0.99999988 1. 0.99999988
1. ]
-4.95563364029
```
The first number in the list is the log likelihood lower bound and the number
after the list is the log of the variance of the gradient estimator. The rest of
the numbers are for debugging.
We can also compare the variance between methods:
```
python rebar_train.py \
--hparams="model=SBNTrackGradVariances,learning_rate=0.0003,n_layer=2,task=omni"
```
and you should see something like:
```
Step 959: [ -2.60478699e+02 3.84281784e-01 6.31126612e-02 3.27319391e-02
6.13379292e-03 1.98278503e-04 1.96425783e-04 8.83973844e-04
8.70995224e-04 -inf]
('DynamicREBAR', -3.725339889526367)
('MuProp', -0.033569782972335815)
('NVIL', 2.7640280723571777)
('REBAR', -3.539274215698242)
('SimpleMuProp', -0.040744658559560776)
Step 1918: [ -2.06948471e+02 3.35904926e-01 5.20901568e-03 7.81541676e-05
2.06885766e-03 1.08521657e-04 1.07351625e-04 2.30646547e-04
2.26554010e-04 -8.22885323e+00]
('DynamicREBAR', -3.864381790161133)
('MuProp', -0.7183765172958374)
('NVIL', 2.266523599624634)
('REBAR', -3.662022113800049)
('SimpleMuProp', -0.7071359157562256)
```
where the tuples show the log of the variance of the gradient estimators.
The training script has a number of hyperparameter configuration flags:
* task (sbn): one of {sbn, sp, omni} which correspond to MNIST generative
modeling, structured prediction on MNIST, and Omniglot generative modeling,
respectively
* model (SBNGumbel) : one of {SBN, SBNNVIL, SBNMuProp, SBNSimpleMuProp,
SBNRebar, SBNDynamicRebar, SBNGumbel SBNTrackGradVariances}. DynamicRebar automatically
adjusts the temperature, whereas Rebar and Gumbel-Softmax require tuning the
temperature. The ones named after
methods uses that method to estimate the gradients (SBN refers to
REINFORCE). SBNTrackGradVariances runs multiple methods and follows a single
optimization trajectory
* n_hidden (200): number of hidden nodes per layer
* n_layer (1): number of layers in the model
* nonlinear (false): if true use 2 x tanh layers between each stochastic layer,
otherwise use a linear layer
* learning_rate (0.001): learning rate
* temperature (0.5): temperature hyperparameter (for DynamicRebar, this is the initial
value of the temperature)
* n_samples (1): number of samples used to compute the gradient estimator (for the
experiments in the paper, set to 1)
* batch_size (24): batch size
* muprop_relaxation (true): if true use the new relaxation described in the paper,
otherwise use the Concrete/Gumbel softmax relaxation
* dynamic_b (false): if true dynamically binarize the training set. This
increases the effective training dataset size and reduces overfitting, though
it is not a standard dataset
Maintained by George Tucker (gjt@google.com, github user: gjtucker).
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configuration variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
DATA_DIR = 'data'
MNIST_BINARIZED = 'mnist_salakhutdinov_07-19-2017.pkl'
MNIST_FLOAT = 'mnist_train_xs_07-19-2017.npy'
OMNIGLOT = 'omniglot_07-19-2017.mat'
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of datasets for REBAR."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import os
import scipy.io
import numpy as np
import cPickle as pickle
import tensorflow as tf
import config
gfile = tf.gfile
def load_data(hparams):
# Load data
if hparams.task in ['sbn', 'sp']:
reader = read_MNIST
elif hparams.task == 'omni':
reader = read_omniglot
x_train, x_valid, x_test = reader(binarize=not hparams.dynamic_b)
return x_train, x_valid, x_test
def read_MNIST(binarize=False):
"""Reads in MNIST images.
Args:
binarize: whether to use the fixed binarization
Returns:
x_train: 50k training images
x_valid: 10k validation images
x_test: 10k test images
"""
with gfile.FastGFile(os.path.join(config.DATA_DIR, config.MNIST_BINARIZED), 'r') as f:
(x_train, _), (x_valid, _), (x_test, _) = pickle.load(f)
if not binarize:
with gfile.FastGFile(os.path.join(config.DATA_DIR, config.MNIST_FLOAT), 'r') as f:
x_train = np.load(f).reshape(-1, 784)
return x_train, x_valid, x_test
def read_omniglot(binarize=False):
"""Reads in Omniglot images.
Args:
binarize: whether to use the fixed binarization
Returns:
x_train: training images
x_valid: validation images
x_test: test images
"""
n_validation=1345
def reshape_data(data):
return data.reshape((-1, 28, 28)).reshape((-1, 28*28), order='fortran')
omni_raw = scipy.io.loadmat(os.path.join(config.DATA_DIR, config.OMNIGLOT))
train_data = reshape_data(omni_raw['data'].T.astype('float32'))
test_data = reshape_data(omni_raw['testdata'].T.astype('float32'))
# Binarize the data with a fixed seed
if binarize:
np.random.seed(5)
train_data = (np.random.rand(*train_data.shape) < train_data).astype(float)
test_data = (np.random.rand(*test_data.shape) < test_data).astype(float)
shuffle_seed = 123
permutation = np.random.RandomState(seed=shuffle_seed).permutation(train_data.shape[0])
train_data = train_data[permutation]
x_train = train_data[:-n_validation]
x_valid = train_data[-n_validation:]
x_test = test_data
return x_train, x_valid, x_test
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download MNIST, Omniglot datasets for Rebar."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import urllib
import gzip
import os
import config
import struct
import numpy as np
import cPickle as pickle
import datasets
MNIST_URL = 'see README'
MNIST_BINARIZED_URL = 'see README'
OMNIGLOT_URL = 'see README'
MNIST_FLOAT_TRAIN = 'train-images-idx3-ubyte'
def load_mnist_float(local_filename):
with open(local_filename, 'rb') as f:
f.seek(4)
nimages, rows, cols = struct.unpack('>iii', f.read(12))
dim = rows*cols
images = np.fromfile(f, dtype=np.dtype(np.ubyte))
images = (images/255.0).astype('float32').reshape((nimages, dim))
return images
if __name__ == '__main__':
if not os.path.exists(config.DATA_DIR):
os.makedirs(config.DATA_DIR)
# Get MNIST and convert to npy file
local_filename = os.path.join(config.DATA_DIR, MNIST_FLOAT_TRAIN)
if not os.path.exists(local_filename):
urllib.urlretrieve("%s/%s.gz" % (MNIST_URL, MNIST_FLOAT_TRAIN), local_filename+'.gz')
with gzip.open(local_filename+'.gz', 'rb') as f:
file_content = f.read()
with open(local_filename, 'wb') as f:
f.write(file_content)
os.remove(local_filename+'.gz')
mnist_float_train = load_mnist_float(local_filename)[:-10000]
# save in a nice format
np.save(os.path.join(config.DATA_DIR, config.MNIST_FLOAT), mnist_float_train)
# Get binarized MNIST
splits = ['train', 'valid', 'test']
mnist_binarized = []
for split in splits:
filename = 'binarized_mnist_%s.amat' % split
url = '%s/binarized_mnist_%s.amat' % (MNIST_BINARIZED_URL, split)
local_filename = os.path.join(config.DATA_DIR, filename)
if not os.path.exists(local_filename):
urllib.urlretrieve(url, local_filename)
with open(local_filename, 'rb') as f:
mnist_binarized.append((np.array([map(int, line.split()) for line in f.readlines()]).astype('float32'), None))
# save in a nice format
with open(os.path.join(config.DATA_DIR, config.MNIST_BINARIZED), 'w') as out:
pickle.dump(mnist_binarized, out)
# Get Omniglot
local_filename = os.path.join(config.DATA_DIR, config.OMNIGLOT)
if not os.path.exists(local_filename):
urllib.urlretrieve(OMNIGLOT_URL,
local_filename)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logger for REBAR"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class Logger:
def __init__(self):
pass
def log(self, key, value):
pass
def flush(self):
pass
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
import numpy as np
from scipy.misc import logsumexp
import tensorflow.contrib.slim as slim
from tensorflow.python.ops import init_ops
import utils as U
FLAGS = tf.flags.FLAGS
Q_COLLECTION = "q_collection"
P_COLLECTION = "p_collection"
class SBN(object): # REINFORCE
def __init__(self,
hparams,
activation_func=tf.nn.sigmoid,
mean_xs = None,
eval_mode=False):
self.eval_mode = eval_mode
self.hparams = hparams
self.mean_xs = mean_xs
self.train_bias= -np.log(1./np.clip(mean_xs, 0.001, 0.999)-1.).astype(np.float32)
self.activation_func = activation_func
self.n_samples = tf.placeholder('int32')
self.x = tf.placeholder('float', [None, self.hparams.n_input])
self._x = tf.tile(self.x, [self.n_samples, 1])
self.batch_size = tf.shape(self._x)[0]
self.uniform_samples = dict()
self.uniform_samples_v = dict()
self.prior = tf.Variable(tf.zeros([self.hparams.n_hidden],
dtype=tf.float32),
name='p_prior',
collections=[tf.GraphKeys.GLOBAL_VARIABLES, P_COLLECTION])
self.run_recognition_network = False
self.run_generator_network = False
# Initialize temperature
self.pre_temperature_variable = tf.Variable(
np.log(self.hparams.temperature),
trainable=False,
dtype=tf.float32)
self.temperature_variable = tf.exp(self.pre_temperature_variable)
self.global_step = tf.Variable(0, trainable=False)
self.baseline_loss = []
self.ema = tf.train.ExponentialMovingAverage(decay=0.999)
self.maintain_ema_ops = []
self.optimizer_class = tf.train.AdamOptimizer(
learning_rate=1*self.hparams.learning_rate,
beta2=self.hparams.beta2)
self._generate_randomness()
self._create_network()
def initialize(self, sess):
self.sess = sess
def _create_eta(self, shape=[], collection='CV'):
return 2 * tf.sigmoid(tf.Variable(tf.zeros(shape), trainable=False,
collections=[collection, tf.GraphKeys.GLOBAL_VARIABLES, Q_COLLECTION]))
def _create_baseline(self, n_output=1, n_hidden=100,
is_zero_init=False,
collection='BASELINE'):
# center input
h = self._x
if self.mean_xs is not None:
h -= self.mean_xs
if is_zero_init:
initializer = init_ops.zeros_initializer()
else:
initializer = slim.variance_scaling_initializer()
with slim.arg_scope([slim.fully_connected],
variables_collections=[collection, Q_COLLECTION],
trainable=False,
weights_initializer=initializer):
h = slim.fully_connected(h, n_hidden, activation_fn=tf.nn.tanh)
baseline = slim.fully_connected(h, n_output, activation_fn=None)
if n_output == 1:
baseline = tf.reshape(baseline, [-1]) # very important to reshape
return baseline
def _create_transformation(self, input, n_output, reuse, scope_prefix):
"""Create the deterministic transformation between stochastic layers.
If self.hparam.nonlinear:
2 x tanh layers
Else:
1 x linear layer
"""
if self.hparams.nonlinear:
h = slim.fully_connected(input,
self.hparams.n_hidden,
reuse=reuse,
activation_fn=tf.nn.tanh,
scope='%s_nonlinear_1' % scope_prefix)
h = slim.fully_connected(h,
self.hparams.n_hidden,
reuse=reuse,
activation_fn=tf.nn.tanh,
scope='%s_nonlinear_2' % scope_prefix)
h = slim.fully_connected(h,
n_output,
reuse=reuse,
activation_fn=None,
scope='%s' % scope_prefix)
else:
h = slim.fully_connected(input,
n_output,
reuse=reuse,
activation_fn=None,
scope='%s' % scope_prefix)
return h
def _recognition_network(self, sampler=None, log_likelihood_func=None):
"""x values -> samples from Q and return log Q(h|x)."""
samples = {}
reuse = None if not self.run_recognition_network else True
# Set defaults
if sampler is None:
sampler = self._random_sample
if log_likelihood_func is None:
log_likelihood_func = lambda sample, log_params: (
U.binary_log_likelihood(sample['activation'], log_params))
logQ = []
if self.hparams.task in ['sbn', 'omni']:
# Initialize the edge case
samples[-1] = {'activation': self._x}
if self.mean_xs is not None:
samples[-1]['activation'] -= self.mean_xs # center the input
samples[-1]['activation'] = (samples[-1]['activation'] + 1)/2.0
with slim.arg_scope([slim.fully_connected],
weights_initializer=slim.variance_scaling_initializer(),
variables_collections=[Q_COLLECTION]):
for i in xrange(self.hparams.n_layer):
# Set up the input to the layer
input = 2.0*samples[i-1]['activation'] - 1.0
# Create the conditional distribution (output is the logits)
h = self._create_transformation(input,
n_output=self.hparams.n_hidden,
reuse=reuse,
scope_prefix='q_%d' % i)
samples[i] = sampler(h, self.uniform_samples[i], i)
logQ.append(log_likelihood_func(samples[i], h))
self.run_recognition_network = True
return logQ, samples
elif self.hparams.task == 'sp':
# Initialize the edge case
samples[-1] = {'activation': tf.split(self._x,
num_or_size_splits=2,
axis=1)[0]} # top half of digit
if self.mean_xs is not None:
samples[-1]['activation'] -= np.split(self.mean_xs, 2, 0)[0] # center the input
samples[-1]['activation'] = (samples[-1]['activation'] + 1)/2.0
with slim.arg_scope([slim.fully_connected],
weights_initializer=slim.variance_scaling_initializer(),
variables_collections=[Q_COLLECTION]):
for i in xrange(self.hparams.n_layer):
# Set up the input to the layer
input = 2.0*samples[i-1]['activation'] - 1.0
# Create the conditional distribution (output is the logits)
h = self._create_transformation(input,
n_output=self.hparams.n_hidden,
reuse=reuse,
scope_prefix='q_%d' % i)
samples[i] = sampler(h, self.uniform_samples[i], i)
logQ.append(log_likelihood_func(samples[i], h))
self.run_recognition_network = True
return logQ, samples
def _generator_network(self, samples, logQ, log_likelihood_func=None):
'''Returns learning signal and function.
This is the implementation for SBNs for the ELBO.
Args:
samples: dictionary of sampled latent variables
logQ: list of log q(h_i) terms
log_likelihood_func: function used to compute log probs for the latent
variables
Returns:
learning_signal: the "reward" function
function_term: part of the function that depends on the parameters
and needs to have the gradient taken through
'''
reuse=None if not self.run_generator_network else True
if self.hparams.task in ['sbn', 'omni']:
if log_likelihood_func is None:
log_likelihood_func = lambda sample, log_params: (
U.binary_log_likelihood(sample['activation'], log_params))
logPPrior = log_likelihood_func(
samples[self.hparams.n_layer-1],
tf.expand_dims(self.prior, 0))
with slim.arg_scope([slim.fully_connected],
weights_initializer=slim.variance_scaling_initializer(),
variables_collections=[P_COLLECTION]):
for i in reversed(xrange(self.hparams.n_layer)):
if i == 0:
n_output = self.hparams.n_input
else:
n_output = self.hparams.n_hidden
input = 2.0*samples[i]['activation']-1.0
h = self._create_transformation(input,
n_output,
reuse=reuse,
scope_prefix='p_%d' % i)
if i == 0:
# Assume output is binary
logP = U.binary_log_likelihood(self._x, h + self.train_bias)
else:
logPPrior += log_likelihood_func(samples[i-1], h)
self.run_generator_network = True
return logP + logPPrior - tf.add_n(logQ), logP + logPPrior
elif self.hparams.task == 'sp':
with slim.arg_scope([slim.fully_connected],
weights_initializer=slim.variance_scaling_initializer(),
variables_collections=[P_COLLECTION]):
n_output = int(self.hparams.n_input/2)
i = self.hparams.n_layer - 1 # use the last layer
input = 2.0*samples[i]['activation']-1.0
h = self._create_transformation(input,
n_output,
reuse=reuse,
scope_prefix='p_%d' % i)
# Predict on the lower half of the image
logP = U.binary_log_likelihood(tf.split(self._x,
num_or_size_splits=2,
axis=1)[1],
h + np.split(self.train_bias, 2, 0)[1])
self.run_generator_network = True
return logP, logP
def _create_loss(self):
# Hard loss
logQHard, samples = self._recognition_network()
reinforce_learning_signal, reinforce_model_grad = self._generator_network(samples, logQHard)
logQHard = tf.add_n(logQHard)
# REINFORCE
learning_signal = tf.stop_gradient(center(reinforce_learning_signal))
self.optimizerLoss = -(learning_signal*logQHard +
reinforce_model_grad)
self.lHat = map(tf.reduce_mean, [
reinforce_learning_signal,
U.rms(learning_signal),
])
return reinforce_learning_signal
def _reshape(self, t):
return tf.transpose(tf.reshape(t,
[self.n_samples, -1]))
def compute_tensor_variance(self, t):
"""Compute the mean per component variance.
Use a moving average to estimate the required moments.
"""
t_sq = tf.reduce_mean(tf.square(t))
self.maintain_ema_ops.append(self.ema.apply([t, t_sq]))
# mean per component variance
variance_estimator = (self.ema.average(t_sq) -
tf.reduce_mean(
tf.square(self.ema.average(t))))
return variance_estimator
def _create_train_op(self, grads_and_vars, extra_grads_and_vars=[]):
'''
Args:
grads_and_vars: gradients to apply and compute running average variance
extra_grads_and_vars: gradients to apply (not used to compute average variance)
'''
# Variance summaries
first_moment = U.vectorize(grads_and_vars, skip_none=True)
second_moment = tf.square(first_moment)
self.maintain_ema_ops.append(self.ema.apply([first_moment, second_moment]))
# Add baseline losses
if len(self.baseline_loss) > 0:
mean_baseline_loss = tf.reduce_mean(tf.add_n(self.baseline_loss))
extra_grads_and_vars += self.optimizer_class.compute_gradients(
mean_baseline_loss,
var_list=tf.get_collection('BASELINE'))
# Ensure that all required tensors are computed before updates are executed
extra_optimizer = tf.train.AdamOptimizer(
learning_rate=10*self.hparams.learning_rate,
beta2=self.hparams.beta2)
with tf.control_dependencies(
[tf.group(*[g for g, _ in (grads_and_vars + extra_grads_and_vars) if g is not None])]):
# Filter out the P_COLLECTION variables if we're in eval mode
if self.eval_mode:
grads_and_vars = [(g, v) for g, v in grads_and_vars
if v not in tf.get_collection(P_COLLECTION)]
train_op = self.optimizer_class.apply_gradients(grads_and_vars,
global_step=self.global_step)
if len(extra_grads_and_vars) > 0:
extra_train_op = extra_optimizer.apply_gradients(extra_grads_and_vars)
else:
extra_train_op = tf.no_op()
self.optimizer = tf.group(train_op, extra_train_op, *self.maintain_ema_ops)
# per parameter variance
variance_estimator = (self.ema.average(second_moment) -
tf.square(self.ema.average(first_moment)))
self.grad_variance = tf.reduce_mean(variance_estimator)
def _create_network(self):
logF = self._create_loss()
self.optimizerLoss = tf.reduce_mean(self.optimizerLoss)
# Setup optimizer
grads_and_vars = self.optimizer_class.compute_gradients(self.optimizerLoss)
self._create_train_op(grads_and_vars)
# Create IWAE lower bound for evaluation
self.logF = self._reshape(logF)
self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) -
tf.log(tf.to_float(self.n_samples)))
def partial_fit(self, X, n_samples=1):
if hasattr(self, 'grad_variances'):
grad_variance_field_to_return = self.grad_variances
else:
grad_variance_field_to_return = self.grad_variance
_, res, grad_variance, step, temperature = self.sess.run(
(self.optimizer, self.lHat, grad_variance_field_to_return, self.global_step, self.temperature_variable),
feed_dict={self.x: X, self.n_samples: n_samples})
return res, grad_variance, step, temperature
def partial_grad(self, X, n_samples=1):
control_variate_grads, step = self.sess.run(
(self.control_variate_grads, self.global_step),
feed_dict={self.x: X, self.n_samples: n_samples})
return control_variate_grads, step
def partial_eval(self, X, n_samples=5):
if n_samples < 1000:
res, iwae = self.sess.run(
(self.lHat, self.iwae),
feed_dict={self.x: X, self.n_samples: n_samples})
res = [iwae] + res
else: # special case to handle OOM
assert n_samples % 100 == 0, "When using large # of samples, it must be divisble by 100"
res = []
for i in xrange(int(n_samples/100)):
logF, = self.sess.run(
(self.logF,),
feed_dict={self.x: X, self.n_samples: 100})
res.append(logsumexp(logF, axis=1))
res = [np.mean(logsumexp(res, axis=0) - np.log(n_samples))]
return res
# Random samplers
def _mean_sample(self, log_alpha, _, layer):
"""Returns mean of random variables parameterized by log_alpha."""
mu = tf.nn.sigmoid(log_alpha)
return {
'preactivation': mu,
'activation': mu,
'log_param': log_alpha,
}
def _generate_randomness(self):
for i in xrange(self.hparams.n_layer):
self.uniform_samples[i] = tf.stop_gradient(tf.random_uniform(
[self.batch_size, self.hparams.n_hidden]))
def _u_to_v(self, log_alpha, u, eps = 1e-8):
"""Convert u to tied randomness in v."""
u_prime = tf.nn.sigmoid(-log_alpha) # g(u') = 0
v_1 = (u - u_prime) / tf.clip_by_value(1 - u_prime, eps, 1)
v_1 = tf.clip_by_value(v_1, 0, 1)
v_1 = tf.stop_gradient(v_1)
v_1 = v_1*(1 - u_prime) + u_prime
v_0 = u / tf.clip_by_value(u_prime, eps, 1)
v_0 = tf.clip_by_value(v_0, 0, 1)
v_0 = tf.stop_gradient(v_0)
v_0 = v_0 * u_prime
v = tf.where(u > u_prime, v_1, v_0)
v = tf.check_numerics(v, 'v sampling is not numerically stable.')
v = v + tf.stop_gradient(-v + u) # v and u are the same up to numerical errors
return v
def _random_sample(self, log_alpha, u, layer):
"""Returns sampled random variables parameterized by log_alpha."""
# Generate tied randomness for later
if layer not in self.uniform_samples_v:
self.uniform_samples_v[layer] = self._u_to_v(log_alpha, u)
# Sample random variable underlying softmax/argmax
x = log_alpha + U.safe_log_prob(u) - U.safe_log_prob(1 - u)
samples = tf.stop_gradient(tf.to_float(x > 0))
return {
'preactivation': x,
'activation': samples,
'log_param': log_alpha,
}
def _random_sample_soft(self, log_alpha, u, layer, temperature=None):
"""Returns sampled random variables parameterized by log_alpha."""
if temperature is None:
temperature = self.hparams.temperature
# Sample random variable underlying softmax/argmax
x = log_alpha + U.safe_log_prob(u) - U.safe_log_prob(1 - u)
x /= tf.expand_dims(temperature, -1)
if self.hparams.muprop_relaxation:
y = tf.nn.sigmoid(x + log_alpha * tf.expand_dims(temperature/(temperature + 1), -1))
else:
y = tf.nn.sigmoid(x)
return {
'preactivation': x,
'activation': y,
'log_param': log_alpha
}
def _random_sample_soft_v(self, log_alpha, _, layer, temperature=None):
"""Returns sampled random variables parameterized by log_alpha."""
v = self.uniform_samples_v[layer]
return self._random_sample_soft(log_alpha, v, layer, temperature)
def get_gumbel_gradient(self):
logQ, softSamples = self._recognition_network(sampler=self._random_sample_soft)
logQ = tf.add_n(logQ)
logPPrior, logP = self._generator_network(softSamples)
softELBO = logPPrior + logP - logQ
gumbel_gradient = (self.optimizer_class.
compute_gradients(softELBO))
debug = {
'softELBO': softELBO,
}
return gumbel_gradient, debug
# samplers used for quadratic version
def _random_sample_switch(self, log_alpha, u, layer, switch_layer, temperature=None):
"""Run partial discrete, then continuous path.
Args:
switch_layer: this layer and beyond will be continuous
"""
if layer < switch_layer:
return self._random_sample(log_alpha, u, layer)
else:
return self._random_sample_soft(log_alpha, u, layer, temperature)
def _random_sample_switch_v(self, log_alpha, u, layer, switch_layer, temperature=None):
"""Run partial discrete, then continuous path.
Args:
switch_layer: this layer and beyond will be continuous
"""
if layer < switch_layer:
return self._random_sample(log_alpha, u, layer)
else:
return self._random_sample_soft_v(log_alpha, u, layer, temperature)
# #####
# Gradient computation
# #####
def get_nvil_gradient(self):
"""Compute the NVIL gradient."""
# Hard loss
logQHard, samples = self._recognition_network()
ELBO, reinforce_model_grad = self._generator_network(samples, logQHard)
logQHard = tf.add_n(logQHard)
# Add baselines (no variance normalization)
learning_signal = tf.stop_gradient(ELBO) - self._create_baseline()
# Set up losses
self.baseline_loss.append(tf.square(learning_signal))
optimizerLoss = -(tf.stop_gradient(learning_signal)*logQHard +
reinforce_model_grad)
optimizerLoss = tf.reduce_mean(optimizerLoss)
nvil_gradient = self.optimizer_class.compute_gradients(optimizerLoss)
debug = {
'ELBO': ELBO,
'RMS of centered learning signal': U.rms(learning_signal),
}
return nvil_gradient, debug
def get_simple_muprop_gradient(self):
""" Computes the simple muprop gradient.
This muprop control variate does not include the linear term.
"""
# Hard loss
logQHard, hardSamples = self._recognition_network()
hardELBO, reinforce_model_grad = self._generator_network(hardSamples, logQHard)
# Soft loss
logQ, muSamples = self._recognition_network(sampler=self._mean_sample)
muELBO, _ = self._generator_network(muSamples, logQ)
scaling_baseline = self._create_eta(collection='BASELINE')
learning_signal = (hardELBO
- scaling_baseline * muELBO
- self._create_baseline())
self.baseline_loss.append(tf.square(learning_signal))
optimizerLoss = -(tf.stop_gradient(learning_signal) * tf.add_n(logQHard)
+ reinforce_model_grad)
optimizerLoss = tf.reduce_mean(optimizerLoss)
simple_muprop_gradient = (self.optimizer_class.
compute_gradients(optimizerLoss))
debug = {
'ELBO': hardELBO,
'muELBO': muELBO,
'RMS': U.rms(learning_signal),
}
return simple_muprop_gradient, debug
def get_muprop_gradient(self):
"""
random sample function that actually returns mean
new forward pass that returns logQ as a list
can get x_i from samples
"""
# Hard loss
logQHard, hardSamples = self._recognition_network()
hardELBO, reinforce_model_grad = self._generator_network(hardSamples, logQHard)
# Soft loss
logQ, muSamples = self._recognition_network(sampler=self._mean_sample)
muELBO, _ = self._generator_network(muSamples, logQ)
# Compute gradients
muELBOGrads = tf.gradients(tf.reduce_sum(muELBO),
[ muSamples[i]['activation'] for
i in xrange(self.hparams.n_layer) ])
# Compute MuProp gradient estimates
learning_signal = hardELBO
optimizerLoss = 0.0
learning_signals = []
for i in xrange(self.hparams.n_layer):
dfDiff = tf.reduce_sum(
muELBOGrads[i] * (hardSamples[i]['activation'] -
muSamples[i]['activation']),
axis=1)
dfMu = tf.reduce_sum(
tf.stop_gradient(muELBOGrads[i]) *
tf.nn.sigmoid(hardSamples[i]['log_param']),
axis=1)
scaling_baseline_0 = self._create_eta(collection='BASELINE')
scaling_baseline_1 = self._create_eta(collection='BASELINE')
learning_signals.append(learning_signal - scaling_baseline_0 * muELBO - scaling_baseline_1 * dfDiff - self._create_baseline())
self.baseline_loss.append(tf.square(learning_signals[i]))
optimizerLoss += (
logQHard[i] * tf.stop_gradient(learning_signals[i]) +
tf.stop_gradient(scaling_baseline_1) * dfMu)
optimizerLoss += reinforce_model_grad
optimizerLoss *= -1
optimizerLoss = tf.reduce_mean(optimizerLoss)
muprop_gradient = self.optimizer_class.compute_gradients(optimizerLoss)
debug = {
'ELBO': hardELBO,
'muELBO': muELBO,
}
debug.update(dict([
('RMS learning signal layer %d' % i, U.rms(learning_signal))
for (i, learning_signal) in enumerate(learning_signals)]))
return muprop_gradient, debug
# REBAR gradient helper functions
def _create_gumbel_control_variate(self, logQHard, temperature=None):
'''Calculate gumbel control variate.
'''
if temperature is None:
temperature = self.hparams.temperature
logQ, softSamples = self._recognition_network(sampler=functools.partial(
self._random_sample_soft, temperature=temperature))
softELBO, _ = self._generator_network(softSamples, logQ)
logQ = tf.add_n(logQ)
# Generate the softELBO_v (should be the same value but different grads)
logQ_v, softSamples_v = self._recognition_network(sampler=functools.partial(
self._random_sample_soft_v, temperature=temperature))
softELBO_v, _ = self._generator_network(softSamples_v, logQ_v)
logQ_v = tf.add_n(logQ_v)
# Compute losses
learning_signal = tf.stop_gradient(softELBO_v)
# Control variate
h = (tf.stop_gradient(learning_signal) * tf.add_n(logQHard)
- softELBO + softELBO_v)
extra = (softELBO_v, -softELBO + softELBO_v)
return h, extra
def _create_gumbel_control_variate_quadratic(self, logQHard, temperature=None):
'''Calculate gumbel control variate.
'''
if temperature is None:
temperature = self.hparams.temperature
h = 0
extra = []
for layer in xrange(self.hparams.n_layer):
logQ, softSamples = self._recognition_network(sampler=functools.partial(
self._random_sample_switch, switch_layer=layer, temperature=temperature))
softELBO, _ = self._generator_network(softSamples, logQ)
# Generate the softELBO_v (should be the same value but different grads)
logQ_v, softSamples_v = self._recognition_network(sampler=functools.partial(
self._random_sample_switch_v, switch_layer=layer, temperature=temperature))
softELBO_v, _ = self._generator_network(softSamples_v, logQ_v)
# Compute losses
learning_signal = tf.stop_gradient(softELBO_v)
# Control variate
h += (tf.stop_gradient(learning_signal) * logQHard[layer]
- softELBO + softELBO_v)
extra.append((softELBO_v, -softELBO + softELBO_v))
return h, extra
def _create_hard_elbo(self):
logQHard, hardSamples = self._recognition_network()
hardELBO, reinforce_model_grad = self._generator_network(hardSamples, logQHard)
reinforce_learning_signal = tf.stop_gradient(hardELBO)
# Center learning signal
baseline = self._create_baseline(collection='CV')
reinforce_learning_signal = tf.stop_gradient(reinforce_learning_signal) - baseline
nvil_gradient = (tf.stop_gradient(hardELBO) - baseline) * tf.add_n(logQHard) + reinforce_model_grad
return hardELBO, nvil_gradient, logQHard
def multiply_by_eta(self, h_grads, eta):
# Modifies eta
res = []
eta_statistics = []
for (g, v) in h_grads:
if g is None:
res.append((g, v))
else:
if 'network' not in eta:
eta['network'] = self._create_eta()
res.append((g*eta['network'], v))
eta_statistics.append(eta['network'])
return res, eta_statistics
def multiply_by_eta_per_layer(self, h_grads, eta):
# Modifies eta
res = []
eta_statistics = []
for (g, v) in h_grads:
if g is None:
res.append((g, v))
else:
if v not in eta:
eta[v] = self._create_eta()
res.append((g*eta[v], v))
eta_statistics.append(eta[v])
return res, eta_statistics
def multiply_by_eta_per_unit(self, h_grads, eta):
# Modifies eta
res = []
eta_statistics = []
for (g, v) in h_grads:
if g is None:
res.append((g, v))
else:
if v not in eta:
g_shape = g.shape_as_list()
assert len(g_shape) <= 2, 'Gradient has too many dimensions'
if len(g_shape) == 1:
eta[v] = self._create_eta(g_shape)
else:
eta[v] = self._create_eta([1, g_shape[1]])
h_grads.append((g*eta[v], v))
eta_statistics.extend(tf.nn.moments(tf.squeeze(eta[v]), axes=[0]))
return res, eta_statistics
def get_dynamic_rebar_gradient(self):
"""Get the dynamic rebar gradient (t, eta optimized)."""
tiled_pre_temperature = tf.tile([self.pre_temperature_variable],
[self.batch_size])
temperature = tf.exp(tiled_pre_temperature)
hardELBO, nvil_gradient, logQHard = self._create_hard_elbo()
if self.hparams.quadratic:
gumbel_cv, extra = self._create_gumbel_control_variate_quadratic(logQHard, temperature=temperature)
else:
gumbel_cv, extra = self._create_gumbel_control_variate(logQHard, temperature=temperature)
f_grads = self.optimizer_class.compute_gradients(tf.reduce_mean(-nvil_gradient))
eta = {}
h_grads, eta_statistics = self.multiply_by_eta_per_layer(
self.optimizer_class.compute_gradients(tf.reduce_mean(gumbel_cv)),
eta)
model_grads = U.add_grads_and_vars(f_grads, h_grads)
total_grads = model_grads
# Construct the variance objective
g = U.vectorize(model_grads, set_none_to_zero=True)
self.maintain_ema_ops.append(self.ema.apply([g]))
gbar = 0 #tf.stop_gradient(self.ema.average(g))
variance_objective = tf.reduce_mean(tf.square(g - gbar))
reinf_g_t = 0
if self.hparams.quadratic:
for layer in xrange(self.hparams.n_layer):
gumbel_learning_signal, _ = extra[layer]
df_dt = tf.gradients(gumbel_learning_signal, tiled_pre_temperature)[0]
reinf_g_t_i, _ = self.multiply_by_eta_per_layer(
self.optimizer_class.compute_gradients(tf.reduce_mean(tf.stop_gradient(df_dt) * logQHard[layer])),
eta)
reinf_g_t += U.vectorize(reinf_g_t_i, set_none_to_zero=True)
reparam = tf.add_n([reparam_i for _, reparam_i in extra])
else:
gumbel_learning_signal, reparam = extra
df_dt = tf.gradients(gumbel_learning_signal, tiled_pre_temperature)[0]
reinf_g_t, _ = self.multiply_by_eta_per_layer(
self.optimizer_class.compute_gradients(tf.reduce_mean(tf.stop_gradient(df_dt) * tf.add_n(logQHard))),
eta)
reinf_g_t = U.vectorize(reinf_g_t, set_none_to_zero=True)
reparam_g, _ = self.multiply_by_eta_per_layer(
self.optimizer_class.compute_gradients(tf.reduce_mean(reparam)),
eta)
reparam_g = U.vectorize(reparam_g, set_none_to_zero=True)
reparam_g_t = tf.gradients(tf.reduce_mean(2*tf.stop_gradient(g - gbar)*reparam_g), self.pre_temperature_variable)[0]
variance_objective_grad = tf.reduce_mean(2*(g - gbar)*reinf_g_t) + reparam_g_t
debug = { 'ELBO': hardELBO,
'etas': eta_statistics,
'variance_objective': variance_objective,
}
return total_grads, debug, variance_objective, variance_objective_grad
def get_rebar_gradient(self):
"""Get the rebar gradient."""
hardELBO, nvil_gradient, logQHard = self._create_hard_elbo()
if self.hparams.quadratic:
gumbel_cv, _ = self._create_gumbel_control_variate_quadratic(logQHard)
else:
gumbel_cv, _ = self._create_gumbel_control_variate(logQHard)
f_grads = self.optimizer_class.compute_gradients(tf.reduce_mean(-nvil_gradient))
eta = {}
h_grads, eta_statistics = self.multiply_by_eta_per_layer(
self.optimizer_class.compute_gradients(tf.reduce_mean(gumbel_cv)),
eta)
model_grads = U.add_grads_and_vars(f_grads, h_grads)
total_grads = model_grads
# Construct the variance objective
variance_objective = tf.reduce_mean(tf.square(U.vectorize(model_grads, set_none_to_zero=True)))
debug = { 'ELBO': hardELBO,
'etas': eta_statistics,
'variance_objective': variance_objective,
}
return total_grads, debug, variance_objective
###
# Create varaints
###
class SBNSimpleMuProp(SBN):
def _create_loss(self):
simple_muprop_gradient, debug = self.get_simple_muprop_gradient()
self.lHat = map(tf.reduce_mean, [
debug['ELBO'],
debug['muELBO'],
])
return debug['ELBO'], simple_muprop_gradient
def _create_network(self):
logF, loss_grads = self._create_loss()
self._create_train_op(loss_grads)
# Create IWAE lower bound for evaluation
self.logF = self._reshape(logF)
self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) -
tf.log(tf.to_float(self.n_samples)))
class SBNMuProp(SBN):
def _create_loss(self):
muprop_gradient, debug = self.get_muprop_gradient()
self.lHat = map(tf.reduce_mean, [
debug['ELBO'],
debug['muELBO'],
])
return debug['ELBO'], muprop_gradient
def _create_network(self):
logF, loss_grads = self._create_loss()
self._create_train_op(loss_grads)
# Create IWAE lower bound for evaluation
self.logF = self._reshape(logF)
self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) -
tf.log(tf.to_float(self.n_samples)))
class SBNNVIL(SBN):
def _create_loss(self):
nvil_gradient, debug = self.get_nvil_gradient()
self.lHat = map(tf.reduce_mean, [
debug['ELBO'],
])
return debug['ELBO'], nvil_gradient
def _create_network(self):
logF, loss_grads = self._create_loss()
self._create_train_op(loss_grads)
# Create IWAE lower bound for evaluation
self.logF = self._reshape(logF)
self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) -
tf.log(tf.to_float(self.n_samples)))
class SBNRebar(SBN):
def _create_loss(self):
rebar_gradient, debug, variance_objective = self.get_rebar_gradient()
self.lHat = map(tf.reduce_mean, [
debug['ELBO'],
])
self.lHat.extend(map(tf.reduce_mean, debug['etas']))
return debug['ELBO'], rebar_gradient, variance_objective
def _create_network(self):
logF, loss_grads, variance_objective = self._create_loss()
# Create additional updates for control variates and temperature
eta_grads = (self.optimizer_class.compute_gradients(variance_objective,
var_list=tf.get_collection('CV')))
self._create_train_op(loss_grads, eta_grads)
# Create IWAE lower bound for evaluation
self.logF = self._reshape(logF)
self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) -
tf.log(tf.to_float(self.n_samples)))
class SBNDynamicRebar(SBN):
def _create_loss(self):
rebar_gradient, debug, variance_objective, variance_objective_grad = self.get_dynamic_rebar_gradient()
self.lHat = map(tf.reduce_mean, [
debug['ELBO'],
self.temperature_variable,
])
self.lHat.extend(debug['etas'])
return debug['ELBO'], rebar_gradient, variance_objective, variance_objective_grad
def _create_network(self):
logF, loss_grads, variance_objective, variance_objective_grad = self._create_loss()
# Create additional updates for control variates and temperature
eta_grads = (self.optimizer_class.compute_gradients(variance_objective,
var_list=tf.get_collection('CV'))
+ [(variance_objective_grad, self.pre_temperature_variable)])
self._create_train_op(loss_grads, eta_grads)
# Create IWAE lower bound for evaluation
self.logF = self._reshape(logF)
self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) -
tf.log(tf.to_float(self.n_samples)))
class SBNTrackGradVariances(SBN):
"""Follow NVIL, compute gradient variances for NVIL, MuProp and REBAR."""
def compute_gradient_moments(self, grads_and_vars):
first_moment = U.vectorize(grads_and_vars, set_none_to_zero=True)
second_moment = tf.square(first_moment)
self.maintain_ema_ops.append(self.ema.apply([first_moment, second_moment]))
return self.ema.average(first_moment), self.ema.average(second_moment)
def _create_loss(self):
self.losses = [
('NVIL', self.get_nvil_gradient),
('SimpleMuProp', self.get_simple_muprop_gradient),
('MuProp', self.get_muprop_gradient),
]
moments = []
for k, v in self.losses:
print(k)
gradient, debug = v()
if k == 'SimpleMuProp':
ELBO = debug['ELBO']
gradient_to_follow = gradient
moments.append(self.compute_gradient_moments(
gradient))
self.losses.append(('DynamicREBAR', self.get_dynamic_rebar_gradient))
dynamic_rebar_gradient, _, variance_objective, variance_objective_grad = self.get_dynamic_rebar_gradient()
moments.append(self.compute_gradient_moments(dynamic_rebar_gradient))
self.losses.append(('REBAR', self.get_rebar_gradient))
rebar_gradient, _, variance_objective2 = self.get_rebar_gradient()
moments.append(self.compute_gradient_moments(rebar_gradient))
mu = tf.reduce_mean(tf.stack([f for f, _ in moments]), axis=0)
self.grad_variances = []
deviations = []
for f, s in moments:
self.grad_variances.append(tf.reduce_mean(s - tf.square(mu)))
deviations.append(tf.reduce_mean(tf.square(f - mu)))
self.lHat = map(tf.reduce_mean, [
ELBO,
self.temperature_variable,
variance_objective_grad,
variance_objective_grad*variance_objective_grad,
])
self.lHat.extend(deviations)
self.lHat.append(tf.log(tf.reduce_mean(mu*mu)))
# self.lHat.extend(map(tf.log, grad_variances))
return ELBO, gradient_to_follow, variance_objective + variance_objective2, variance_objective_grad
def _create_network(self):
logF, loss_grads, variance_objective, variance_objective_grad = self._create_loss()
eta_grads = (self.optimizer_class.compute_gradients(variance_objective,
var_list=tf.get_collection('CV'))
+ [(variance_objective_grad, self.pre_temperature_variable)])
self._create_train_op(loss_grads, eta_grads)
# Create IWAE lower bound for evaluation
self.logF = self._reshape(logF)
self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) -
tf.log(tf.to_float(self.n_samples)))
class SBNGumbel(SBN):
def _random_sample_soft(self, log_alpha, u, layer, temperature=None):
"""Returns sampled random variables parameterized by log_alpha."""
if temperature is None:
temperature = self.hparams.temperature
# Sample random variable underlying softmax/argmax
x = log_alpha + U.safe_log_prob(u) - U.safe_log_prob(1 - u)
x /= temperature
if self.hparams.muprop_relaxation:
x += temperature/(temperature + 1)*log_alpha
y = tf.nn.sigmoid(x)
return {
'preactivation': x,
'activation': y,
'log_param': log_alpha
}
def _create_loss(self):
# Hard loss
logQHard, hardSamples = self._recognition_network()
hardELBO, _ = self._generator_network(hardSamples, logQHard)
logQ, softSamples = self._recognition_network(sampler=self._random_sample_soft)
softELBO, _ = self._generator_network(softSamples, logQ)
self.optimizerLoss = -softELBO
self.lHat = map(tf.reduce_mean, [
hardELBO,
softELBO,
])
return hardELBO
default_hparams = tf.contrib.training.HParams(model='SBNGumbel',
n_hidden=200,
n_input=784,
n_layer=1,
nonlinear=False,
learning_rate=0.001,
temperature=0.5,
n_samples=1,
batch_size=24,
trial=1,
muprop_relaxation=True,
dynamic_b=False, # dynamic binarization
quadratic=True,
beta2=0.99999,
task='sbn',
)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import random
import sys
import os
import numpy as np
import tensorflow as tf
import rebar
import datasets
import logger as L
gfile = tf.gfile
tf.app.flags.DEFINE_string("working_dir", "/tmp/rebar",
"""Directory where to save data, write logs, etc.""")
tf.app.flags.DEFINE_string('hparams', '',
'''Comma separated list of name=value pairs.''')
tf.app.flags.DEFINE_integer('eval_freq', 20,
'''How often to run the evaluation step.''')
FLAGS = tf.flags.FLAGS
def manual_scalar_summary(name, value):
value = tf.Summary.Value(tag=name, simple_value=value)
summary_str = tf.Summary(value=[value])
return summary_str
def eval(sbn, eval_xs, n_samples=100, batch_size=5):
n = eval_xs.shape[0]
i = 0
res = []
while i < n:
batch_xs = eval_xs[i:min(i+batch_size, n)]
res.append(sbn.partial_eval(batch_xs, n_samples))
i += batch_size
res = np.mean(res, axis=0)
return res
def train(sbn, train_xs, valid_xs, test_xs, training_steps, debug=False):
hparams = sorted(sbn.hparams.values().items())
hparams = (map(str, x) for x in hparams)
hparams = ('_'.join(x) for x in hparams)
hparams_str = '.'.join(hparams)
logger = L.Logger()
# Create the experiment name from the hparams
experiment_name = ([str(sbn.hparams.n_hidden) for i in xrange(sbn.hparams.n_layer)] +
[str(sbn.hparams.n_input)])
if sbn.hparams.nonlinear:
experiment_name = '~'.join(experiment_name)
else:
experiment_name = '-'.join(experiment_name)
experiment_name = 'SBN_%s' % experiment_name
rowkey = {'experiment': experiment_name,
'model': hparams_str}
# Create summary writer
summ_dir = os.path.join(FLAGS.working_dir, hparams_str)
summary_writer = tf.summary.FileWriter(
summ_dir, flush_secs=15, max_queue=100)
sv = tf.train.Supervisor(logdir=os.path.join(
FLAGS.working_dir, hparams_str),
save_summaries_secs=0,
save_model_secs=1200,
summary_op=None,
recovery_wait_secs=30,
global_step=sbn.global_step)
with sv.managed_session() as sess:
# Dump hparams to file
with gfile.Open(os.path.join(FLAGS.working_dir,
hparams_str,
'hparams.json'),
'w') as out:
json.dump(sbn.hparams.values(), out)
sbn.initialize(sess)
batch_size = sbn.hparams.batch_size
scores = []
n = train_xs.shape[0]
index = range(n)
while not sv.should_stop():
lHats = []
grad_variances = []
temperatures = []
random.shuffle(index)
i = 0
while i < n:
batch_index = index[i:min(i+batch_size, n)]
batch_xs = train_xs[batch_index, :]
if sbn.hparams.dynamic_b:
# Dynamically binarize the batch data
batch_xs = (np.random.rand(*batch_xs.shape) < batch_xs).astype(float)
lHat, grad_variance, step, temperature = sbn.partial_fit(batch_xs,
sbn.hparams.n_samples)
if debug:
print(i, lHat)
if i > 100:
return
lHats.append(lHat)
grad_variances.append(grad_variance)
temperatures.append(temperature)
i += batch_size
grad_variances = np.log(np.mean(grad_variances, axis=0)).tolist()
summary_strings = []
if isinstance(grad_variances, list):
grad_variances = dict(zip([k for (k, v) in sbn.losses], map(float, grad_variances)))
rowkey['step'] = step
logger.log(rowkey, {'step': step,
'train': np.mean(lHats, axis=0)[0],
'grad_variances': grad_variances,
'temperature': np.mean(temperatures), })
grad_variances = '\n'.join(map(str, sorted(grad_variances.iteritems())))
else:
rowkey['step'] = step
logger.log(rowkey, {'step': step,
'train': np.mean(lHats, axis=0)[0],
'grad_variance': grad_variances,
'temperature': np.mean(temperatures), })
summary_strings.append(manual_scalar_summary("log grad variance", grad_variances))
print('Step %d: %s\n%s' % (step, str(np.mean(lHats, axis=0)), str(grad_variances)))
# Every few epochs compute test and validation scores
epoch = int(step / (train_xs.shape[0] / sbn.hparams.batch_size))
if epoch % FLAGS.eval_freq == 0:
valid_res = eval(sbn, valid_xs)
test_res= eval(sbn, test_xs)
print('\nValid %d: %s' % (step, str(valid_res)))
print('Test %d: %s\n' % (step, str(test_res)))
logger.log(rowkey, {'step': step,
'valid': valid_res[0],
'test': test_res[0]})
logger.flush() # Flush infrequently
# Create summaries
summary_strings.extend([
manual_scalar_summary("Train ELBO", np.mean(lHats, axis=0)[0]),
manual_scalar_summary("Temperature", np.mean(temperatures)),
])
for summ_str in summary_strings:
summary_writer.add_summary(summ_str, global_step=step)
summary_writer.flush()
sys.stdout.flush()
scores.append(np.mean(lHats, axis=0))
if step > training_steps:
break
return scores
def main():
# Parse hyperparams
hparams = rebar.default_hparams
hparams.parse(FLAGS.hparams)
print(hparams.values())
train_xs, valid_xs, test_xs = datasets.load_data(hparams)
mean_xs = np.mean(train_xs, axis=0) # Compute mean centering on training
training_steps = 2000000
model = getattr(rebar, hparams.model)
sbn = model(hparams, mean_xs=mean_xs)
scores = train(sbn, train_xs, valid_xs, test_xs,
training_steps=training_steps, debug=False)
if __name__ == '__main__':
main()
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic data management and plotting utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import cPickle as pickle
import getpass
import numpy as np
import gc
import tensorflow as tf
#
# Python utlities
#
def exp_moving_average(x, alpha=0.9):
res = []
mu = 0
alpha_factor = 1
for x_i in x:
mu += (1 - alpha)*(x_i - mu)
alpha_factor *= alpha
res.append(mu/(1 - alpha_factor))
return np.array(res)
def sanitize(s):
return s.replace('.', '_')
#
# Tensorflow utilities
#
def softplus(x):
'''
Let m = max(0, x), then,
sofplus(x) = log(1 + e(x)) = log(e(0) + e(x)) = log(e(m)(e(-m) + e(x-m)))
= m + log(e(-m) + e(x - m))
The term inside of the log is guaranteed to be between 1 and 2.
'''
m = tf.maximum(tf.zeros_like(x), x)
return m + tf.log(tf.exp(-m) + tf.exp(x - m))
def safe_log_prob(x, eps=1e-8):
return tf.log(tf.clip_by_value(x, eps, 1.0))
def rms(x):
return tf.sqrt(tf.reduce_mean(tf.square(x)))
def center(x):
mu = (tf.reduce_sum(x) - x)/tf.to_float(tf.shape(x)[0] - 1)
return x - mu
def vectorize(grads_and_vars, set_none_to_zero=False, skip_none=False):
if set_none_to_zero:
return tf.concat([tf.reshape(g, [-1]) if g is not None else
tf.reshape(tf.zeros_like(v), [-1]) for g, v in grads_and_vars], 0)
elif skip_none:
return tf.concat([tf.reshape(g, [-1]) for g, v in grads_and_vars if g is not None], 0)
else:
return tf.concat([tf.reshape(g, [-1]) for g, v in grads_and_vars], 0)
def add_grads_and_vars(a, b):
'''Add grads_and_vars from two calls to tf.compute_gradients.'''
res = []
for (g_a, v_a), (g_b, v_b) in zip(a, b):
assert v_a == v_b
if g_a is None:
res.append((g_b, v_b))
elif g_b is None:
res.append((g_a, v_a))
else:
res.append((g_a + g_b, v_a))
return res
def binary_log_likelihood(y, log_y_hat):
"""Computes binary log likelihood.
Args:
y: observed data
log_y_hat: parameters of the binary variables
Returns:
log_likelihood
"""
return tf.reduce_sum(y*(-softplus(-log_y_hat)) +
(1 - y)*(-log_y_hat-softplus(-log_y_hat)),
1)
def cov(a, b):
"""Compute the sample covariance between two vectors."""
mu_a = tf.reduce_mean(a)
mu_b = tf.reduce_mean(b)
n = tf.to_float(tf.shape(a)[0])
return tf.reduce_sum((a - mu_a)*(b - mu_b))/(n - 1.0)
def corr(a, b):
return cov(a, b)*tf.rsqrt(cov(a, a))*tf.rsqrt(cov(b, b))
def logSumExp(t, axis=0, keep_dims = False):
'''Computes the log(sum(exp(t))) numerically stabily.
Args:
t: input tensor
axis: which axis to sum over
keep_dims: whether to keep the dim or not
Returns:
tensor with result
'''
m = tf.reduce_max(t, [axis])
res = m + tf.log(tf.reduce_sum(tf.exp(t - tf.expand_dims(m, axis)), [axis]))
if keep_dims:
return tf.expand_dims(res, axis)
else:
return res
if __name__ == '__main__':
app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment