Commit 513998a3 authored by Martin Wicke's avatar Martin Wicke Committed by GitHub
Browse files

Merge pull request #888 from laurent-dinh/master

Real NVP code
parents 91c7b91f e871d295
# Real NVP in TensorFlow
*A Tensorflow implementation of the training procedure of*
[*Density estimation using Real NVP*](https://arxiv.org/abs/1605.08803)*, by
Laurent Dinh, Jascha Sohl-Dickstein and Samy Bengio, for Imagenet
(32x32 and 64x64), CelebA and LSUN Including the scripts to
put the datasets in `.tfrecords` format.*
We are happy to open source the code for *Real NVP*, a novel approach to
density estimation using deep neural networks that enables tractable density
estimation and efficient one-pass inference and sampling. This model
successfully decomposes images into hierarchical features ranging from
high-level concepts to low-resolution details. Visualizations are available
[here](http://goo.gl/yco14s).
## Installation
* python 2.7:
* python 3 support is not available yet
* pip (python package manager)
* `apt-get install python-pip` on Ubuntu
* `brew` installs pip along with python on OSX
* Install the dependencies for [LSUN](https://github.com/fyu/lsun.git)
* Install [OpenCV](http://opencv.org/)
* `pip install numpy lmdb`
* Install the python dependencies
* `pip install scipy scikit-image Pillow`
* Install the
[latest Tensorflow Pip package](https://www.tensorflow.org/get_started/os_setup.html#using-pip)
for Python 2.7
## Getting Started
Once you have successfully installed the dependencies, you can start by
downloading the repository:
```shell
git clone --recursive https://github.com/tensorflow/models.git
```
Afterward, you can use the utilities in this folder prepare the datasets.
## Preparing datasets
### CelebA
For [*CelebA*](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html), download
`img_align_celeba.zip` from the Dropbox link on this
[page](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) under the
link *Align&Cropped Images* in the *Img* directory and `list_eval_partition.txt`
under the link *Train/Val/Test Partitions* in the *Eval* directory. Then do:
```shell
mkdir celeba
cd celeba
unzip img_align_celeba.zip
```
We'll format the training subset:
```shell
python2.7 ../models/real_nvp/celeba_formatting.py \
--partition_fn list_eval_partition.txt \
--file_out celeba_train \
--fn_root img_align_celeba \
--set 0
```
Then the validation subset:
```shell
python2.7 ../models/real_nvp/celeba_formatting.py \
--partition_fn list_eval_partition.txt \
--file_out celeba_valid \
--fn_root img_align_celeba \
--set 1
```
And finally the test subset:
```shell
python2.7 ../models/real_nvp/celeba_formatting.py \
--partition_fn list_eval_partition.txt \
--file_out celeba_test \
--fn_root img_align_celeba \
--set 2
```
Afterward:
```shell
cd ..
```
### Small Imagenet
Downloading the [*small Imagenet*](http://image-net.org/small/download.php)
dataset is more straightforward and can be done
entirely in Shell:
```shell
mkdir small_imnet
cd small_imnet
for FILENAME in train_32x32.tar valid_32x32.tar train_64x64.tar valid_64x64.tar
do
curl -O http://image-net.org/small/$FILENAME
tar -xvf $FILENAME
done
```
Then, you can format the datasets as follow:
```shell
for DIRNAME in train_32x32 valid_32x32 train_64x64 valid_64x64
do
python2.7 ../models/real_nvp/imnet_formatting.py \
--file_out $DIRNAME \
--fn_root $DIRNAME
done
cd ..
```
### LSUN
To prepare the [*LSUN*](http://lsun.cs.princeton.edu/2016/) dataset, we will
need to use the code associated:
```shell
git clone https://github.com/fyu/lsun.git
cd lsun
```
Then we'll download the db files:
```shell
for CATEGORY in bedroom church_outdoor tower
do
python2.7 download.py -c $CATEGORY
unzip "$CATEGORY"_train_lmdb.zip
unzip "$CATEGORY"_val_lmdb.zip
python2.7 data.py export "$CATEGORY"_train_lmdb \
--out_dir "$CATEGORY"_train --flat
python2.7 data.py export "$CATEGORY"_val_lmdb \
--out_dir "$CATEGORY"_val --flat
done
```
Finally, we then format the dataset into `.tfrecords`:
```shell
for CATEGORY in bedroom church_outdoor tower
do
python2.7 ../models/real_nvp/lsun_formatting.py \
--file_out "$CATEGORY"_train \
--fn_root "$CATEGORY"_train
python2.7 ../models/real_nvp/lsun_formatting.py \
--file_out "$CATEGORY"_val \
--fn_root "$CATEGORY"_val
done
cd ..
```
## Training
We'll give an example on how to train a model on the small Imagenet
dataset (32x32):
```shell
cd models/real_nvp/
python2.7 real_nvp_multiscale_dataset.py \
--image_size 32 \
--hpconfig=n_scale=4,base_dim=32,clip_gradient=100,residual_blocks=4 \
--dataset imnet \
--traindir /tmp/real_nvp_imnet32/train \
--logdir /tmp/real_nvp_imnet32/train \
--data_path ../../small_imnet/train_32x32_?????.tfrecords
```
In parallel, you can run the script to generate visualization from the model:
```shell
python2.7 real_nvp_multiscale_dataset.py \
--image_size 32 \
--hpconfig=n_scale=4,base_dim=32,clip_gradient=100,residual_blocks=4 \
--dataset imnet \
--traindir /tmp/real_nvp_imnet32/train \
--logdir /tmp/real_nvp_imnet32/sample \
--data_path ../../small_imnet/valid_32x32_?????.tfrecords \
--mode sample
```
Additionally, you can also run in the script to evaluate the model on the
validation set:
```shell
python2.7 real_nvp_multiscale_dataset.py \
--image_size 32 \
--hpconfig=n_scale=4,base_dim=32,clip_gradient=100,residual_blocks=4 \
--dataset imnet \
--traindir /tmp/real_nvp_imnet32/train \
--logdir /tmp/real_nvp_imnet32/eval \
--data_path ../../small_imnet/valid_32x32_?????.tfrecords \
--eval_set_size 50000
--mode eval
```
The visualizations and validation set evaluation can be seen through
[Tensorboard](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/README.md).
Another example would be how to run the model on LSUN (bedroom category):
```shell
# train the model
python2.7 real_nvp_multiscale_dataset.py \
--image_size 64 \
--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
--dataset lsun \
--traindir /tmp/real_nvp_church_outdoor/train \
--logdir /tmp/real_nvp_church_outdoor/train \
--data_path ../../lsun/church_outdoor_train_?????.tfrecords
```
```shell
# sample from the model
python2.7 real_nvp_multiscale_dataset.py \
--image_size 64 \
--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
--dataset lsun \
--traindir /tmp/real_nvp_church_outdoor/train \
--logdir /tmp/real_nvp_church_outdoor/sample \
--data_path ../../lsun/church_outdoor_val_?????.tfrecords \
--mode sample
```
```shell
# evaluate the model
python2.7 real_nvp_multiscale_dataset.py \
--image_size 64 \
--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
--dataset lsun \
--traindir /tmp/real_nvp_church_outdoor/train \
--logdir /tmp/real_nvp_church_outdoor/eval \
--data_path ../../lsun/church_outdoor_val_?????.tfrecords \
--eval_set_size 300
--mode eval
```
Finally, we'll give the commands to run the model on the CelebA dataset:
```shell
# train the model
python2.7 real_nvp_multiscale_dataset.py \
--image_size 64 \
--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
--dataset lsun \
--traindir /tmp/real_nvp_celeba/train \
--logdir /tmp/real_nvp_celeba/train \
--data_path ../../celeba/celeba_train.tfrecords
```
```shell
# sample from the model
python2.7 real_nvp_multiscale_dataset.py \
--image_size 64 \
--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
--dataset celeba \
--traindir /tmp/real_nvp_celeba/train \
--logdir /tmp/real_nvp_celeba/sample \
--data_path ../../celeba/celeba_valid.tfrecords \
--mode sample
```
```shell
# evaluate the model on validation set
python2.7 real_nvp_multiscale_dataset.py \
--image_size 64 \
--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
--dataset celeba \
--traindir /tmp/real_nvp_celeba/train \
--logdir /tmp/real_nvp_celeba/eval_valid \
--data_path ../../celeba/celeba_valid.tfrecords \
--eval_set_size 19867
--mode eval
# evaluate the model on test set
python2.7 real_nvp_multiscale_dataset.py \
--image_size 64 \
--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
--dataset celeba \
--traindir /tmp/real_nvp_celeba/train \
--logdir /tmp/real_nvp_celeba/eval_test \
--data_path ../../celeba/celeba_test.tfrecords \
--eval_set_size 19962
--mode eval
```
## Credits
This code was written by Laurent Dinh
([@laurent-dinh](https://github.com/laurent-dinh)) with
the help of
Jascha Sohl-Dickstein ([@Sohl-Dickstein](https://github.com/Sohl-Dickstein)
and [jaschasd@google.com](mailto:jaschasd@google.com)),
Samy Bengio, Jon Shlens, Sherry Moore and
David Andersen.
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""CelebA dataset formating.
Download img_align_celeba.zip from
http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html under the
link "Align&Cropped Images" in the "Img" directory and list_eval_partition.txt
under the link "Train/Val/Test Partitions" in the "Eval" directory. Then do:
unzip img_align_celeba.zip
Use the script as follow:
python celeba_formatting.py \
--partition_fn [PARTITION_FILE_PATH] \
--file_out [OUTPUT_FILE_PATH_PREFIX] \
--fn_root [CELEBA_FOLDER] \
--set [SUBSET_INDEX]
"""
import os
import os.path
import scipy.io
import scipy.io.wavfile
import scipy.ndimage
import tensorflow as tf
tf.flags.DEFINE_string("file_out", "",
"Filename of the output .tfrecords file.")
tf.flags.DEFINE_string("fn_root", "", "Name of root file path.")
tf.flags.DEFINE_string("partition_fn", "", "Partition file path.")
tf.flags.DEFINE_string("set", "", "Name of subset.")
FLAGS = tf.flags.FLAGS
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def main():
"""Main converter function."""
# Celeb A
with open(FLAGS.partition_fn, "r") as infile:
img_fn_list = infile.readlines()
img_fn_list = [elem.strip().split() for elem in img_fn_list]
img_fn_list = [elem[0] for elem in img_fn_list if elem[1] == FLAGS.set]
fn_root = FLAGS.fn_root
num_examples = len(img_fn_list)
file_out = "%s.tfrecords" % FLAGS.file_out
writer = tf.python_io.TFRecordWriter(file_out)
for example_idx, img_fn in enumerate(img_fn_list):
if example_idx % 1000 == 0:
print example_idx, "/", num_examples
image_raw = scipy.ndimage.imread(os.path.join(fn_root, img_fn))
rows = image_raw.shape[0]
cols = image_raw.shape[1]
depth = image_raw.shape[2]
image_raw = image_raw.tostring()
example = tf.train.Example(
features=tf.train.Features(
feature={
"height": _int64_feature(rows),
"width": _int64_feature(cols),
"depth": _int64_feature(depth),
"image_raw": _bytes_feature(image_raw)
}
)
)
writer.write(example.SerializeToString())
writer.close()
if __name__ == "__main__":
main()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""LSUN dataset formatting.
Download and format the Imagenet dataset as follow:
mkdir [IMAGENET_PATH]
cd [IMAGENET_PATH]
for FILENAME in train_32x32.tar valid_32x32.tar train_64x64.tar valid_64x64.tar
do
curl -O http://image-net.org/small/$FILENAME
tar -xvf $FILENAME
done
Then use the script as follow:
for DIRNAME in train_32x32 valid_32x32 train_64x64 valid_64x64
do
python imnet_formatting.py \
--file_out $DIRNAME \
--fn_root $DIRNAME
done
"""
import os
import os.path
import scipy.io
import scipy.io.wavfile
import scipy.ndimage
import tensorflow as tf
tf.flags.DEFINE_string("file_out", "",
"Filename of the output .tfrecords file.")
tf.flags.DEFINE_string("fn_root", "", "Name of root file path.")
FLAGS = tf.flags.FLAGS
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def main():
"""Main converter function."""
# LSUN
fn_root = FLAGS.fn_root
img_fn_list = os.listdir(fn_root)
img_fn_list = [img_fn for img_fn in img_fn_list
if img_fn.endswith('.png')]
num_examples = len(img_fn_list)
n_examples_per_file = 10000
for example_idx, img_fn in enumerate(img_fn_list):
if example_idx % n_examples_per_file == 0:
file_out = "%s_%05d.tfrecords"
file_out = file_out % (FLAGS.file_out,
example_idx // n_examples_per_file)
print "Writing on:", file_out
writer = tf.python_io.TFRecordWriter(file_out)
if example_idx % 1000 == 0:
print example_idx, "/", num_examples
image_raw = scipy.ndimage.imread(os.path.join(fn_root, img_fn))
rows = image_raw.shape[0]
cols = image_raw.shape[1]
depth = image_raw.shape[2]
image_raw = image_raw.astype("uint8")
image_raw = image_raw.tostring()
example = tf.train.Example(
features=tf.train.Features(
feature={
"height": _int64_feature(rows),
"width": _int64_feature(cols),
"depth": _int64_feature(depth),
"image_raw": _bytes_feature(image_raw)
}
)
)
writer.write(example.SerializeToString())
if example_idx % n_examples_per_file == (n_examples_per_file - 1):
writer.close()
writer.close()
if __name__ == "__main__":
main()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""LSUN dataset formatting.
Download and format the LSUN dataset as follow:
git clone https://github.com/fyu/lsun.git
cd lsun
python2.7 download.py -c [CATEGORY]
Then unzip the downloaded .zip files before executing:
python2.7 data.py export [IMAGE_DB_PATH] --out_dir [LSUN_FOLDER] --flat
Then use the script as follow:
python lsun_formatting.py \
--file_out [OUTPUT_FILE_PATH_PREFIX] \
--fn_root [LSUN_FOLDER]
"""
import os
import os.path
import numpy
import skimage.transform
from PIL import Image
import tensorflow as tf
tf.flags.DEFINE_string("file_out", "",
"Filename of the output .tfrecords file.")
tf.flags.DEFINE_string("fn_root", "", "Name of root file path.")
FLAGS = tf.flags.FLAGS
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def main():
"""Main converter function."""
fn_root = FLAGS.fn_root
img_fn_list = os.listdir(fn_root)
img_fn_list = [img_fn for img_fn in img_fn_list
if img_fn.endswith('.webp')]
num_examples = len(img_fn_list)
n_examples_per_file = 10000
for example_idx, img_fn in enumerate(img_fn_list):
if example_idx % n_examples_per_file == 0:
file_out = "%s_%05d.tfrecords"
file_out = file_out % (FLAGS.file_out,
example_idx // n_examples_per_file)
print "Writing on:", file_out
writer = tf.python_io.TFRecordWriter(file_out)
if example_idx % 1000 == 0:
print example_idx, "/", num_examples
image_raw = numpy.array(Image.open(os.path.join(fn_root, img_fn)))
rows = image_raw.shape[0]
cols = image_raw.shape[1]
depth = image_raw.shape[2]
downscale = min(rows / 96., cols / 96.)
image_raw = skimage.transform.pyramid_reduce(image_raw, downscale)
image_raw *= 255.
image_raw = image_raw.astype("uint8")
rows = image_raw.shape[0]
cols = image_raw.shape[1]
depth = image_raw.shape[2]
image_raw = image_raw.tostring()
example = tf.train.Example(
features=tf.train.Features(
feature={
"height": _int64_feature(rows),
"width": _int64_feature(cols),
"depth": _int64_feature(depth),
"image_raw": _bytes_feature(image_raw)
}
)
)
writer.write(example.SerializeToString())
if example_idx % n_examples_per_file == (n_examples_per_file - 1):
writer.close()
writer.close()
if __name__ == "__main__":
main()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Script for training, evaluation and sampling for Real NVP.
$ python real_nvp_multiscale_dataset.py \
--alsologtostderr \
--image_size 64 \
--hpconfig=n_scale=5,base_dim=8 \
--dataset imnet \
--data_path [DATA_PATH]
"""
import time
from datetime import datetime
import os
import numpy
import tensorflow as tf
from tensorflow import gfile
from real_nvp_utils import (
batch_norm, batch_norm_log_diff, conv_layer,
squeeze_2x2, squeeze_2x2_ordered, standard_normal_ll,
standard_normal_sample, unsqueeze_2x2, variable_on_cpu)
tf.flags.DEFINE_string("master", "local",
"BNS name of the TensorFlow master, or local.")
tf.flags.DEFINE_string("logdir", "/tmp/real_nvp_multiscale",
"Directory to which writes logs.")
tf.flags.DEFINE_string("traindir", "/tmp/real_nvp_multiscale",
"Directory to which writes logs.")
tf.flags.DEFINE_integer("train_steps", 1000000000000000000,
"Number of steps to train for.")
tf.flags.DEFINE_string("data_path", "", "Path to the data.")
tf.flags.DEFINE_string("mode", "train",
"Mode of execution. Must be 'train', "
"'sample' or 'eval'.")
tf.flags.DEFINE_string("dataset", "imnet",
"Dataset used. Must be 'imnet', "
"'celeba' or 'lsun'.")
tf.flags.DEFINE_integer("recursion_type", 2,
"Type of the recursion.")
tf.flags.DEFINE_integer("image_size", 64,
"Size of the input image.")
tf.flags.DEFINE_integer("eval_set_size", 0,
"Size of evaluation dataset.")
tf.flags.DEFINE_string(
"hpconfig", "",
"A comma separated list of hyperparameters for the model. Format is "
"hp1=value1,hp2=value2,etc. If this FLAG is set, the model will be trained "
"with the specified hyperparameters, filling in missing hyperparameters "
"from the default_values in |hyper_params|.")
FLAGS = tf.flags.FLAGS
class HParams(object):
"""Dictionary of hyperparameters."""
def __init__(self, **kwargs):
self.dict_ = kwargs
self.__dict__.update(self.dict_)
def update_config(self, in_string):
"""Update the dictionary with a comma separated list."""
pairs = in_string.split(",")
pairs = [pair.split("=") for pair in pairs]
for key, val in pairs:
self.dict_[key] = type(self.dict_[key])(val)
self.__dict__.update(self.dict_)
return self
def __getitem__(self, key):
return self.dict_[key]
def __setitem__(self, key, val):
self.dict_[key] = val
self.__dict__.update(self.dict_)
def get_default_hparams():
"""Get the default hyperparameters."""
return HParams(
batch_size=64,
residual_blocks=2,
n_couplings=2,
n_scale=4,
learning_rate=0.001,
momentum=1e-1,
decay=1e-3,
l2_coeff=0.00005,
clip_gradient=100.,
optimizer="adam",
dropout_mask=0,
base_dim=32,
bottleneck=0,
use_batch_norm=1,
alternate=1,
use_aff=1,
skip=1,
data_constraint=.9,
n_opt=0)
# RESNET UTILS
def residual_block(input_, dim, name, use_batch_norm=True,
train=True, weight_norm=True, bottleneck=False):
"""Residual convolutional block."""
with tf.variable_scope(name):
res = input_
if use_batch_norm:
res = batch_norm(
input_=res, dim=dim, name="bn_in", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res = tf.nn.relu(res)
if bottleneck:
res = conv_layer(
input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
name="h_0", stddev=numpy.sqrt(2. / (dim)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None, bias=(not use_batch_norm),
weight_norm=weight_norm, scale=False)
if use_batch_norm:
res = batch_norm(
input_=res, dim=dim,
name="bn_0", scale=False, train=train,
epsilon=1e-4, axes=[0, 1, 2])
res = tf.nn.relu(res)
res = conv_layer(
input_=res, filter_size=[3, 3], dim_in=dim,
dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)),
strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
bias=(not use_batch_norm),
weight_norm=weight_norm, scale=False)
if use_batch_norm:
res = batch_norm(
input_=res, dim=dim, name="bn_1", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res = tf.nn.relu(res)
res = conv_layer(
input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
name="out", stddev=numpy.sqrt(2. / (1. * dim)),
strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
bias=True, weight_norm=weight_norm, scale=True)
else:
res = conv_layer(
input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim,
name="h_0", stddev=numpy.sqrt(2. / (dim)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None, bias=(not use_batch_norm),
weight_norm=weight_norm, scale=False)
if use_batch_norm:
res = batch_norm(
input_=res, dim=dim, name="bn_0", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res = tf.nn.relu(res)
res = conv_layer(
input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim,
name="out", stddev=numpy.sqrt(2. / (1. * dim)),
strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
bias=True, weight_norm=weight_norm, scale=True)
res += input_
return res
def resnet(input_, dim_in, dim, dim_out, name, use_batch_norm=True,
train=True, weight_norm=True, residual_blocks=5,
bottleneck=False, skip=True):
"""Residual convolutional network."""
with tf.variable_scope(name):
res = input_
if residual_blocks != 0:
res = conv_layer(
input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim,
name="h_in", stddev=numpy.sqrt(2. / (dim_in)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None, bias=True,
weight_norm=weight_norm, scale=False)
if skip:
out = conv_layer(
input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
name="skip_in", stddev=numpy.sqrt(2. / (dim)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None, bias=True,
weight_norm=weight_norm, scale=True)
# residual blocks
for idx_block in xrange(residual_blocks):
res = residual_block(res, dim, "block_%d" % idx_block,
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
bottleneck=bottleneck)
if skip:
out += conv_layer(
input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
name="skip_%d" % idx_block, stddev=numpy.sqrt(2. / (dim)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None, bias=True,
weight_norm=weight_norm, scale=True)
# outputs
if skip:
res = out
if use_batch_norm:
res = batch_norm(
input_=res, dim=dim, name="bn_pre_out", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res = tf.nn.relu(res)
res = conv_layer(
input_=res, filter_size=[1, 1], dim_in=dim,
dim_out=dim_out,
name="out", stddev=numpy.sqrt(2. / (1. * dim)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None, bias=True,
weight_norm=weight_norm, scale=True)
else:
if bottleneck:
res = conv_layer(
input_=res, filter_size=[1, 1], dim_in=dim_in, dim_out=dim,
name="h_0", stddev=numpy.sqrt(2. / (dim_in)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None, bias=(not use_batch_norm),
weight_norm=weight_norm, scale=False)
if use_batch_norm:
res = batch_norm(
input_=res, dim=dim, name="bn_0", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res = tf.nn.relu(res)
res = conv_layer(
input_=res, filter_size=[3, 3], dim_in=dim,
dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None,
bias=(not use_batch_norm),
weight_norm=weight_norm, scale=False)
if use_batch_norm:
res = batch_norm(
input_=res, dim=dim, name="bn_1", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res = tf.nn.relu(res)
res = conv_layer(
input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim_out,
name="out", stddev=numpy.sqrt(2. / (1. * dim)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None, bias=True,
weight_norm=weight_norm, scale=True)
else:
res = conv_layer(
input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim,
name="h_0", stddev=numpy.sqrt(2. / (dim_in)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None, bias=(not use_batch_norm),
weight_norm=weight_norm, scale=False)
if use_batch_norm:
res = batch_norm(
input_=res, dim=dim, name="bn_0", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res = tf.nn.relu(res)
res = conv_layer(
input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim_out,
name="out", stddev=numpy.sqrt(2. / (1. * dim)),
strides=[1, 1, 1, 1], padding="SAME",
nonlinearity=None, bias=True,
weight_norm=weight_norm, scale=True)
return res
# COUPLING LAYERS
# masked convolution implementations
def masked_conv_aff_coupling(input_, mask_in, dim, name,
use_batch_norm=True, train=True, weight_norm=True,
reverse=False, residual_blocks=5,
bottleneck=False, use_width=1., use_height=1.,
mask_channel=0., skip=True):
"""Affine coupling with masked convolution."""
with tf.variable_scope(name) as scope:
if reverse or (not train):
scope.reuse_variables()
shape = input_.get_shape().as_list()
batch_size = shape[0]
height = shape[1]
width = shape[2]
channels = shape[3]
# build mask
mask = use_width * numpy.arange(width)
mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask
mask = mask.astype("float32")
mask = tf.mod(mask_in + mask, 2)
mask = tf.reshape(mask, [-1, height, width, 1])
if mask.get_shape().as_list()[0] == 1:
mask = tf.tile(mask, [batch_size, 1, 1, 1])
res = input_ * tf.mod(mask_channel + mask, 2)
# initial input
if use_batch_norm:
res = batch_norm(
input_=res, dim=channels, name="bn_in", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res *= 2.
res = tf.concat_v2([res, -res], 3)
res = tf.concat_v2([res, mask], 3)
dim_in = 2. * channels + 1
res = tf.nn.relu(res)
res = resnet(input_=res, dim_in=dim_in, dim=dim,
dim_out=2 * channels,
name="resnet", use_batch_norm=use_batch_norm,
train=train, weight_norm=weight_norm,
residual_blocks=residual_blocks,
bottleneck=bottleneck, skip=skip)
mask = tf.mod(mask_channel + mask, 2)
res = tf.split(res, 2, 3)
shift, log_rescaling = res[-2], res[-1]
scale = variable_on_cpu(
"rescaling_scale", [],
tf.constant_initializer(0.))
shift = tf.reshape(
shift, [batch_size, height, width, channels])
log_rescaling = tf.reshape(
log_rescaling, [batch_size, height, width, channels])
log_rescaling = scale * tf.tanh(log_rescaling)
if not use_batch_norm:
scale_shift = variable_on_cpu(
"scale_shift", [],
tf.constant_initializer(0.))
log_rescaling += scale_shift
shift *= (1. - mask)
log_rescaling *= (1. - mask)
if reverse:
res = input_
if use_batch_norm:
mean, var = batch_norm_log_diff(
input_=res * (1. - mask), dim=channels, name="bn_out",
train=False, epsilon=1e-4, axes=[0, 1, 2])
log_var = tf.log(var)
res *= tf.exp(.5 * log_var * (1. - mask))
res += mean * (1. - mask)
res *= tf.exp(-log_rescaling)
res -= shift
log_diff = -log_rescaling
if use_batch_norm:
log_diff += .5 * log_var * (1. - mask)
else:
res = input_
res += shift
res *= tf.exp(log_rescaling)
log_diff = log_rescaling
if use_batch_norm:
mean, var = batch_norm_log_diff(
input_=res * (1. - mask), dim=channels, name="bn_out",
train=train, epsilon=1e-4, axes=[0, 1, 2])
log_var = tf.log(var)
res -= mean * (1. - mask)
res *= tf.exp(-.5 * log_var * (1. - mask))
log_diff -= .5 * log_var * (1. - mask)
return res, log_diff
def masked_conv_add_coupling(input_, mask_in, dim, name,
use_batch_norm=True, train=True, weight_norm=True,
reverse=False, residual_blocks=5,
bottleneck=False, use_width=1., use_height=1.,
mask_channel=0., skip=True):
"""Additive coupling with masked convolution."""
with tf.variable_scope(name) as scope:
if reverse or (not train):
scope.reuse_variables()
shape = input_.get_shape().as_list()
batch_size = shape[0]
height = shape[1]
width = shape[2]
channels = shape[3]
# build mask
mask = use_width * numpy.arange(width)
mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask
mask = mask.astype("float32")
mask = tf.mod(mask_in + mask, 2)
mask = tf.reshape(mask, [-1, height, width, 1])
if mask.get_shape().as_list()[0] == 1:
mask = tf.tile(mask, [batch_size, 1, 1, 1])
res = input_ * tf.mod(mask_channel + mask, 2)
# initial input
if use_batch_norm:
res = batch_norm(
input_=res, dim=channels, name="bn_in", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res *= 2.
res = tf.concat_v2([res, -res], 3)
res = tf.concat_v2([res, mask], 3)
dim_in = 2. * channels + 1
res = tf.nn.relu(res)
shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels,
name="resnet", use_batch_norm=use_batch_norm,
train=train, weight_norm=weight_norm,
residual_blocks=residual_blocks,
bottleneck=bottleneck, skip=skip)
mask = tf.mod(mask_channel + mask, 2)
shift *= (1. - mask)
# use_batch_norm = False
if reverse:
res = input_
if use_batch_norm:
mean, var = batch_norm_log_diff(
input_=res * (1. - mask),
dim=channels, name="bn_out", train=False, epsilon=1e-4)
log_var = tf.log(var)
res *= tf.exp(.5 * log_var * (1. - mask))
res += mean * (1. - mask)
res -= shift
log_diff = tf.zeros_like(res)
if use_batch_norm:
log_diff += .5 * log_var * (1. - mask)
else:
res = input_
res += shift
log_diff = tf.zeros_like(res)
if use_batch_norm:
mean, var = batch_norm_log_diff(
input_=res * (1. - mask), dim=channels,
name="bn_out", train=train, epsilon=1e-4, axes=[0, 1, 2])
log_var = tf.log(var)
res -= mean * (1. - mask)
res *= tf.exp(-.5 * log_var * (1. - mask))
log_diff -= .5 * log_var * (1. - mask)
return res, log_diff
def masked_conv_coupling(input_, mask_in, dim, name,
use_batch_norm=True, train=True, weight_norm=True,
reverse=False, residual_blocks=5,
bottleneck=False, use_aff=True,
use_width=1., use_height=1.,
mask_channel=0., skip=True):
"""Coupling with masked convolution."""
if use_aff:
return masked_conv_aff_coupling(
input_=input_, mask_in=mask_in, dim=dim, name=name,
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
reverse=reverse, residual_blocks=residual_blocks,
bottleneck=bottleneck, use_width=use_width, use_height=use_height,
mask_channel=mask_channel, skip=skip)
else:
return masked_conv_add_coupling(
input_=input_, mask_in=mask_in, dim=dim, name=name,
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
reverse=reverse, residual_blocks=residual_blocks,
bottleneck=bottleneck, use_width=use_width, use_height=use_height,
mask_channel=mask_channel, skip=skip)
# channel-axis splitting implementations
def conv_ch_aff_coupling(input_, dim, name,
use_batch_norm=True, train=True, weight_norm=True,
reverse=False, residual_blocks=5,
bottleneck=False, change_bottom=True, skip=True):
"""Affine coupling with channel-wise splitting."""
with tf.variable_scope(name) as scope:
if reverse or (not train):
scope.reuse_variables()
if change_bottom:
input_, canvas = tf.split(input_, 2, 3)
else:
canvas, input_ = tf.split(input_, 2, 3)
shape = input_.get_shape().as_list()
batch_size = shape[0]
height = shape[1]
width = shape[2]
channels = shape[3]
res = input_
# initial input
if use_batch_norm:
res = batch_norm(
input_=res, dim=channels, name="bn_in", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res = tf.concat_v2([res, -res], 3)
dim_in = 2. * channels
res = tf.nn.relu(res)
res = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=2 * channels,
name="resnet", use_batch_norm=use_batch_norm,
train=train, weight_norm=weight_norm,
residual_blocks=residual_blocks,
bottleneck=bottleneck, skip=skip)
shift, log_rescaling = tf.split(res, 2, 3)
scale = variable_on_cpu(
"scale", [],
tf.constant_initializer(1.))
shift = tf.reshape(
shift, [batch_size, height, width, channels])
log_rescaling = tf.reshape(
log_rescaling, [batch_size, height, width, channels])
log_rescaling = scale * tf.tanh(log_rescaling)
if not use_batch_norm:
scale_shift = variable_on_cpu(
"scale_shift", [],
tf.constant_initializer(0.))
log_rescaling += scale_shift
if reverse:
res = canvas
if use_batch_norm:
mean, var = batch_norm_log_diff(
input_=res, dim=channels, name="bn_out", train=False,
epsilon=1e-4, axes=[0, 1, 2])
log_var = tf.log(var)
res *= tf.exp(.5 * log_var)
res += mean
res *= tf.exp(-log_rescaling)
res -= shift
log_diff = -log_rescaling
if use_batch_norm:
log_diff += .5 * log_var
else:
res = canvas
res += shift
res *= tf.exp(log_rescaling)
log_diff = log_rescaling
if use_batch_norm:
mean, var = batch_norm_log_diff(
input_=res, dim=channels, name="bn_out", train=train,
epsilon=1e-4, axes=[0, 1, 2])
log_var = tf.log(var)
res -= mean
res *= tf.exp(-.5 * log_var)
log_diff -= .5 * log_var
if change_bottom:
res = tf.concat_v2([input_, res], 3)
log_diff = tf.concat_v2([tf.zeros_like(log_diff), log_diff], 3)
else:
res = tf.concat_v2([res, input_], 3)
log_diff = tf.concat_v2([log_diff, tf.zeros_like(log_diff)], 3)
return res, log_diff
def conv_ch_add_coupling(input_, dim, name,
use_batch_norm=True, train=True, weight_norm=True,
reverse=False, residual_blocks=5,
bottleneck=False, change_bottom=True, skip=True):
"""Additive coupling with channel-wise splitting."""
with tf.variable_scope(name) as scope:
if reverse or (not train):
scope.reuse_variables()
if change_bottom:
input_, canvas = tf.split(input_, 2, 3)
else:
canvas, input_ = tf.split(input_, 2, 3)
shape = input_.get_shape().as_list()
channels = shape[3]
res = input_
# initial input
if use_batch_norm:
res = batch_norm(
input_=res, dim=channels, name="bn_in", scale=False,
train=train, epsilon=1e-4, axes=[0, 1, 2])
res = tf.concat_v2([res, -res], 3)
dim_in = 2. * channels
res = tf.nn.relu(res)
shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels,
name="resnet", use_batch_norm=use_batch_norm,
train=train, weight_norm=weight_norm,
residual_blocks=residual_blocks,
bottleneck=bottleneck, skip=skip)
if reverse:
res = canvas
if use_batch_norm:
mean, var = batch_norm_log_diff(
input_=res, dim=channels, name="bn_out", train=False,
epsilon=1e-4, axes=[0, 1, 2])
log_var = tf.log(var)
res *= tf.exp(.5 * log_var)
res += mean
res -= shift
log_diff = tf.zeros_like(res)
if use_batch_norm:
log_diff += .5 * log_var
else:
res = canvas
res += shift
log_diff = tf.zeros_like(res)
if use_batch_norm:
mean, var = batch_norm_log_diff(
input_=res, dim=channels, name="bn_out", train=train,
epsilon=1e-4, axes=[0, 1, 2])
log_var = tf.log(var)
res -= mean
res *= tf.exp(-.5 * log_var)
log_diff -= .5 * log_var
if change_bottom:
res = tf.concat_v2([input_, res], 3)
log_diff = tf.concat_v2([tf.zeros_like(log_diff), log_diff], 3)
else:
res = tf.concat_v2([res, input_], 3)
log_diff = tf.concat_v2([log_diff, tf.zeros_like(log_diff)], 3)
return res, log_diff
def conv_ch_coupling(input_, dim, name,
use_batch_norm=True, train=True, weight_norm=True,
reverse=False, residual_blocks=5,
bottleneck=False, use_aff=True, change_bottom=True,
skip=True):
"""Coupling with channel-wise splitting."""
if use_aff:
return conv_ch_aff_coupling(
input_=input_, dim=dim, name=name,
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
reverse=reverse, residual_blocks=residual_blocks,
bottleneck=bottleneck, change_bottom=change_bottom, skip=skip)
else:
return conv_ch_add_coupling(
input_=input_, dim=dim, name=name,
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
reverse=reverse, residual_blocks=residual_blocks,
bottleneck=bottleneck, change_bottom=change_bottom, skip=skip)
# RECURSIVE USE OF COUPLING LAYERS
def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale,
use_batch_norm=True, weight_norm=True,
train=True):
"""Recursion on coupling layers."""
shape = input_.get_shape().as_list()
channels = shape[3]
residual_blocks = hps.residual_blocks
base_dim = hps.base_dim
mask = 1.
use_aff = hps.use_aff
res = input_
skip = hps.skip
log_diff = tf.zeros_like(input_)
dim = base_dim
if FLAGS.recursion_type < 4:
dim *= 2 ** scale_idx
with tf.variable_scope("scale_%d" % scale_idx):
# initial coupling layers
res, inc_log_diff = masked_conv_coupling(
input_=res,
mask_in=mask, dim=dim,
name="coupling_0",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=False, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=use_aff,
use_width=1., use_height=1., skip=skip)
log_diff += inc_log_diff
res, inc_log_diff = masked_conv_coupling(
input_=res,
mask_in=1. - mask, dim=dim,
name="coupling_1",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=False, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=use_aff,
use_width=1., use_height=1., skip=skip)
log_diff += inc_log_diff
res, inc_log_diff = masked_conv_coupling(
input_=res,
mask_in=mask, dim=dim,
name="coupling_2",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=False, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=True,
use_width=1., use_height=1., skip=skip)
log_diff += inc_log_diff
if scale_idx < (n_scale - 1):
with tf.variable_scope("scale_%d" % scale_idx):
res = squeeze_2x2(res)
log_diff = squeeze_2x2(log_diff)
res, inc_log_diff = conv_ch_coupling(
input_=res,
change_bottom=True, dim=2 * dim,
name="coupling_4",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=False, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
log_diff += inc_log_diff
res, inc_log_diff = conv_ch_coupling(
input_=res,
change_bottom=False, dim=2 * dim,
name="coupling_5",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=False, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
log_diff += inc_log_diff
res, inc_log_diff = conv_ch_coupling(
input_=res,
change_bottom=True, dim=2 * dim,
name="coupling_6",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=False, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=True, skip=skip)
log_diff += inc_log_diff
res = unsqueeze_2x2(res)
log_diff = unsqueeze_2x2(log_diff)
if FLAGS.recursion_type > 1:
res = squeeze_2x2_ordered(res)
log_diff = squeeze_2x2_ordered(log_diff)
if FLAGS.recursion_type > 2:
res_1 = res[:, :, :, :channels]
res_2 = res[:, :, :, channels:]
log_diff_1 = log_diff[:, :, :, :channels]
log_diff_2 = log_diff[:, :, :, channels:]
else:
res_1, res_2 = tf.split(res, 2, 3)
log_diff_1, log_diff_2 = tf.split(log_diff, 2, 3)
res_1, inc_log_diff = rec_masked_conv_coupling(
input_=res_1, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale,
use_batch_norm=use_batch_norm, weight_norm=weight_norm,
train=train)
res = tf.concat_v2([res_1, res_2], 3)
log_diff_1 += inc_log_diff
log_diff = tf.concat_v2([log_diff_1, log_diff_2], 3)
res = squeeze_2x2_ordered(res, reverse=True)
log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
else:
res = squeeze_2x2_ordered(res)
log_diff = squeeze_2x2_ordered(log_diff)
res, inc_log_diff = rec_masked_conv_coupling(
input_=res, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale,
use_batch_norm=use_batch_norm, weight_norm=weight_norm,
train=train)
log_diff += inc_log_diff
res = squeeze_2x2_ordered(res, reverse=True)
log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
else:
with tf.variable_scope("scale_%d" % scale_idx):
res, inc_log_diff = masked_conv_coupling(
input_=res,
mask_in=1. - mask, dim=dim,
name="coupling_3",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=False, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=True,
use_width=1., use_height=1., skip=skip)
log_diff += inc_log_diff
return res, log_diff
def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale,
use_batch_norm=True, weight_norm=True,
train=True):
"""Recursion on inverting coupling layers."""
shape = input_.get_shape().as_list()
channels = shape[3]
residual_blocks = hps.residual_blocks
base_dim = hps.base_dim
mask = 1.
use_aff = hps.use_aff
res = input_
log_diff = tf.zeros_like(input_)
skip = hps.skip
dim = base_dim
if FLAGS.recursion_type < 4:
dim *= 2 ** scale_idx
if scale_idx < (n_scale - 1):
if FLAGS.recursion_type > 1:
res = squeeze_2x2_ordered(res)
log_diff = squeeze_2x2_ordered(log_diff)
if FLAGS.recursion_type > 2:
res_1 = res[:, :, :, :channels]
res_2 = res[:, :, :, channels:]
log_diff_1 = log_diff[:, :, :, :channels]
log_diff_2 = log_diff[:, :, :, channels:]
else:
res_1, res_2 = tf.split(res, 2, 3)
log_diff_1, log_diff_2 = tf.split(log_diff, 2, 3)
res_1, log_diff_1 = rec_masked_deconv_coupling(
input_=res_1, hps=hps,
scale_idx=scale_idx + 1, n_scale=n_scale,
use_batch_norm=use_batch_norm, weight_norm=weight_norm,
train=train)
res = tf.concat_v2([res_1, res_2], 3)
log_diff = tf.concat_v2([log_diff_1, log_diff_2], 3)
res = squeeze_2x2_ordered(res, reverse=True)
log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
else:
res = squeeze_2x2_ordered(res)
log_diff = squeeze_2x2_ordered(log_diff)
res, log_diff = rec_masked_deconv_coupling(
input_=res, hps=hps,
scale_idx=scale_idx + 1, n_scale=n_scale,
use_batch_norm=use_batch_norm, weight_norm=weight_norm,
train=train)
res = squeeze_2x2_ordered(res, reverse=True)
log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
with tf.variable_scope("scale_%d" % scale_idx):
res = squeeze_2x2(res)
log_diff = squeeze_2x2(log_diff)
res, inc_log_diff = conv_ch_coupling(
input_=res,
change_bottom=True, dim=2 * dim,
name="coupling_6",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=True, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=True, skip=skip)
log_diff += inc_log_diff
res, inc_log_diff = conv_ch_coupling(
input_=res,
change_bottom=False, dim=2 * dim,
name="coupling_5",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=True, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
log_diff += inc_log_diff
res, inc_log_diff = conv_ch_coupling(
input_=res,
change_bottom=True, dim=2 * dim,
name="coupling_4",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=True, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
log_diff += inc_log_diff
res = unsqueeze_2x2(res)
log_diff = unsqueeze_2x2(log_diff)
else:
with tf.variable_scope("scale_%d" % scale_idx):
res, inc_log_diff = masked_conv_coupling(
input_=res,
mask_in=1. - mask, dim=dim,
name="coupling_3",
use_batch_norm=use_batch_norm, train=train,
weight_norm=weight_norm,
reverse=True, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=True,
use_width=1., use_height=1., skip=skip)
log_diff += inc_log_diff
with tf.variable_scope("scale_%d" % scale_idx):
res, inc_log_diff = masked_conv_coupling(
input_=res,
mask_in=mask, dim=dim,
name="coupling_2",
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
reverse=True, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=True,
use_width=1., use_height=1., skip=skip)
log_diff += inc_log_diff
res, inc_log_diff = masked_conv_coupling(
input_=res,
mask_in=1. - mask, dim=dim,
name="coupling_1",
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
reverse=True, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=use_aff,
use_width=1., use_height=1., skip=skip)
log_diff += inc_log_diff
res, inc_log_diff = masked_conv_coupling(
input_=res,
mask_in=mask, dim=dim,
name="coupling_0",
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
reverse=True, residual_blocks=residual_blocks,
bottleneck=hps.bottleneck, use_aff=use_aff,
use_width=1., use_height=1., skip=skip)
log_diff += inc_log_diff
return res, log_diff
# ENCODER AND DECODER IMPLEMENTATIONS
# start the recursions
def encoder(input_, hps, n_scale, use_batch_norm=True,
weight_norm=True, train=True):
"""Encoding/gaussianization function."""
res = input_
log_diff = tf.zeros_like(input_)
res, inc_log_diff = rec_masked_conv_coupling(
input_=res, hps=hps, scale_idx=0, n_scale=n_scale,
use_batch_norm=use_batch_norm, weight_norm=weight_norm,
train=train)
log_diff += inc_log_diff
return res, log_diff
def decoder(input_, hps, n_scale, use_batch_norm=True,
weight_norm=True, train=True):
"""Decoding/generator function."""
res, log_diff = rec_masked_deconv_coupling(
input_=input_, hps=hps, scale_idx=0, n_scale=n_scale,
use_batch_norm=use_batch_norm, weight_norm=weight_norm,
train=train)
return res, log_diff
class RealNVP(object):
"""Real NVP model."""
def __init__(self, hps, sampling=False):
# DATA TENSOR INSTANTIATION
device = "/cpu:0"
if FLAGS.dataset == "imnet":
with tf.device(
tf.train.replica_device_setter(0, worker_device=device)):
filename_queue = tf.train.string_input_producer(
gfile.Glob(FLAGS.data_path), num_epochs=None)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
"image_raw": tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features["image_raw"], tf.uint8)
image.set_shape([FLAGS.image_size * FLAGS.image_size * 3])
image = tf.cast(image, tf.float32)
if FLAGS.mode == "train":
images = tf.train.shuffle_batch(
[image], batch_size=hps.batch_size, num_threads=1,
capacity=1000 + 3 * hps.batch_size,
# Ensures a minimum amount of shuffling of examples.
min_after_dequeue=1000)
else:
images = tf.train.batch(
[image], batch_size=hps.batch_size, num_threads=1,
capacity=1000 + 3 * hps.batch_size)
self.x_orig = x_orig = images
image_size = FLAGS.image_size
x_in = tf.reshape(
x_orig,
[hps.batch_size, FLAGS.image_size, FLAGS.image_size, 3])
x_in = tf.clip_by_value(x_in, 0, 255)
x_in = (tf.cast(x_in, tf.float32)
+ tf.random_uniform(tf.shape(x_in))) / 256.
elif FLAGS.dataset == "celeba":
with tf.device(
tf.train.replica_device_setter(0, worker_device=device)):
filename_queue = tf.train.string_input_producer(
gfile.Glob(FLAGS.data_path), num_epochs=None)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
"image_raw": tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features["image_raw"], tf.uint8)
image.set_shape([218 * 178 * 3]) # 218, 178
image = tf.cast(image, tf.float32)
image = tf.reshape(image, [218, 178, 3])
image = image[40:188, 15:163, :]
if FLAGS.mode == "train":
image = tf.image.random_flip_left_right(image)
images = tf.train.shuffle_batch(
[image], batch_size=hps.batch_size, num_threads=1,
capacity=1000 + 3 * hps.batch_size,
min_after_dequeue=1000)
else:
images = tf.train.batch(
[image], batch_size=hps.batch_size, num_threads=1,
capacity=1000 + 3 * hps.batch_size)
self.x_orig = x_orig = images
image_size = 64
x_in = tf.reshape(x_orig, [hps.batch_size, 148, 148, 3])
x_in = tf.image.resize_images(
x_in, [64, 64], method=0, align_corners=False)
x_in = (tf.cast(x_in, tf.float32)
+ tf.random_uniform(tf.shape(x_in))) / 256.
elif FLAGS.dataset == "lsun":
with tf.device(
tf.train.replica_device_setter(0, worker_device=device)):
filename_queue = tf.train.string_input_producer(
gfile.Glob(FLAGS.data_path), num_epochs=None)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
"image_raw": tf.FixedLenFeature([], tf.string),
"height": tf.FixedLenFeature([], tf.int64),
"width": tf.FixedLenFeature([], tf.int64),
"depth": tf.FixedLenFeature([], tf.int64)
})
image = tf.decode_raw(features["image_raw"], tf.uint8)
height = tf.reshape((features["height"], tf.int64)[0], [1])
height = tf.cast(height, tf.int32)
width = tf.reshape((features["width"], tf.int64)[0], [1])
width = tf.cast(width, tf.int32)
depth = tf.reshape((features["depth"], tf.int64)[0], [1])
depth = tf.cast(depth, tf.int32)
image = tf.reshape(image, tf.concat_v2([height, width, depth], 0))
image = tf.random_crop(image, [64, 64, 3])
if FLAGS.mode == "train":
image = tf.image.random_flip_left_right(image)
images = tf.train.shuffle_batch(
[image], batch_size=hps.batch_size, num_threads=1,
capacity=1000 + 3 * hps.batch_size,
# Ensures a minimum amount of shuffling of examples.
min_after_dequeue=1000)
else:
images = tf.train.batch(
[image], batch_size=hps.batch_size, num_threads=1,
capacity=1000 + 3 * hps.batch_size)
self.x_orig = x_orig = images
image_size = 64
x_in = tf.reshape(x_orig, [hps.batch_size, 64, 64, 3])
x_in = (tf.cast(x_in, tf.float32)
+ tf.random_uniform(tf.shape(x_in))) / 256.
else:
raise ValueError("Unknown dataset.")
x_in = tf.reshape(x_in, [hps.batch_size, image_size, image_size, 3])
side_shown = int(numpy.sqrt(hps.batch_size))
shown_x = tf.transpose(
tf.reshape(
x_in[:(side_shown * side_shown), :, :, :],
[side_shown, image_size * side_shown, image_size, 3]),
[0, 2, 1, 3])
shown_x = tf.transpose(
tf.reshape(
shown_x,
[1, image_size * side_shown, image_size * side_shown, 3]),
[0, 2, 1, 3]) * 255.
tf.summary.image(
"inputs",
tf.cast(shown_x, tf.uint8),
max_outputs=1)
# restrict the data
FLAGS.image_size = image_size
data_constraint = hps.data_constraint
pre_logit_scale = numpy.log(data_constraint)
pre_logit_scale -= numpy.log(1. - data_constraint)
pre_logit_scale = tf.cast(pre_logit_scale, tf.float32)
logit_x_in = 2. * x_in # [0, 2]
logit_x_in -= 1. # [-1, 1]
logit_x_in *= data_constraint # [-.9, .9]
logit_x_in += 1. # [.1, 1.9]
logit_x_in /= 2. # [.05, .95]
# logit the data
logit_x_in = tf.log(logit_x_in) - tf.log(1. - logit_x_in)
transform_cost = tf.reduce_sum(
tf.nn.softplus(logit_x_in) + tf.nn.softplus(-logit_x_in)
- tf.nn.softplus(-pre_logit_scale),
[1, 2, 3])
# INFERENCE AND COSTS
z_out, log_diff = encoder(
input_=logit_x_in, hps=hps, n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=True)
if FLAGS.mode != "train":
z_out, log_diff = encoder(
input_=logit_x_in, hps=hps, n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=False)
final_shape = [image_size, image_size, 3]
prior_ll = standard_normal_ll(z_out)
prior_ll = tf.reduce_sum(prior_ll, [1, 2, 3])
log_diff = tf.reduce_sum(log_diff, [1, 2, 3])
log_diff += transform_cost
cost = -(prior_ll + log_diff)
self.x_in = x_in
self.z_out = z_out
self.cost = cost = tf.reduce_mean(cost)
l2_reg = sum(
[tf.reduce_sum(tf.square(v)) for v in tf.trainable_variables()
if ("magnitude" in v.name) or ("rescaling_scale" in v.name)])
bit_per_dim = ((cost + numpy.log(256.) * image_size * image_size * 3.)
/ (image_size * image_size * 3. * numpy.log(2.)))
self.bit_per_dim = bit_per_dim
# OPTIMIZATION
momentum = 1. - hps.momentum
decay = 1. - hps.decay
if hps.optimizer == "adam":
optimizer = tf.train.AdamOptimizer(
learning_rate=hps.learning_rate,
beta1=momentum, beta2=decay, epsilon=1e-08,
use_locking=False, name="Adam")
elif hps.optimizer == "rmsprop":
optimizer = tf.train.RMSPropOptimizer(
learning_rate=hps.learning_rate, decay=decay,
momentum=momentum, epsilon=1e-04,
use_locking=False, name="RMSProp")
else:
optimizer = tf.train.MomentumOptimizer(hps.learning_rate,
momentum=momentum)
step = tf.get_variable(
"global_step", [], tf.int64,
tf.zeros_initializer(),
trainable=False)
self.step = step
grads_and_vars = optimizer.compute_gradients(
cost + hps.l2_coeff * l2_reg,
tf.trainable_variables())
grads, vars_ = zip(*grads_and_vars)
capped_grads, gradient_norm = tf.clip_by_global_norm(
grads, clip_norm=hps.clip_gradient)
gradient_norm = tf.check_numerics(gradient_norm,
"Gradient norm is NaN or Inf.")
l2_z = tf.reduce_sum(tf.square(z_out), [1, 2, 3])
if not sampling:
tf.summary.scalar("negative_log_likelihood", tf.reshape(cost, []))
tf.summary.scalar("gradient_norm", tf.reshape(gradient_norm, []))
tf.summary.scalar("bit_per_dim", tf.reshape(bit_per_dim, []))
tf.summary.scalar("log_diff", tf.reshape(tf.reduce_mean(log_diff), []))
tf.summary.scalar("prior_ll", tf.reshape(tf.reduce_mean(prior_ll), []))
tf.summary.scalar(
"log_diff_var",
tf.reshape(tf.reduce_mean(tf.square(log_diff))
- tf.square(tf.reduce_mean(log_diff)), []))
tf.summary.scalar(
"prior_ll_var",
tf.reshape(tf.reduce_mean(tf.square(prior_ll))
- tf.square(tf.reduce_mean(prior_ll)), []))
tf.summary.scalar("l2_z_mean", tf.reshape(tf.reduce_mean(l2_z), []))
tf.summary.scalar(
"l2_z_var",
tf.reshape(tf.reduce_mean(tf.square(l2_z))
- tf.square(tf.reduce_mean(l2_z)), []))
capped_grads_and_vars = zip(capped_grads, vars_)
self.train_step = optimizer.apply_gradients(
capped_grads_and_vars, global_step=step)
# SAMPLING AND VISUALIZATION
if sampling:
# SAMPLES
sample = standard_normal_sample([100] + final_shape)
sample, _ = decoder(
input_=sample, hps=hps, n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=True)
sample = tf.nn.sigmoid(sample)
sample = tf.clip_by_value(sample, 0, 1) * 255.
sample = tf.reshape(sample, [100, image_size, image_size, 3])
sample = tf.transpose(
tf.reshape(sample, [10, image_size * 10, image_size, 3]),
[0, 2, 1, 3])
sample = tf.transpose(
tf.reshape(sample, [1, image_size * 10, image_size * 10, 3]),
[0, 2, 1, 3])
tf.summary.image(
"samples",
tf.cast(sample, tf.uint8),
max_outputs=1)
# CONCATENATION
concatenation, _ = encoder(
input_=logit_x_in, hps=hps,
n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=False)
concatenation = tf.reshape(
concatenation,
[(side_shown * side_shown), image_size, image_size, 3])
concatenation = tf.transpose(
tf.reshape(
concatenation,
[side_shown, image_size * side_shown, image_size, 3]),
[0, 2, 1, 3])
concatenation = tf.transpose(
tf.reshape(
concatenation,
[1, image_size * side_shown, image_size * side_shown, 3]),
[0, 2, 1, 3])
concatenation, _ = decoder(
input_=concatenation, hps=hps, n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=False)
concatenation = tf.nn.sigmoid(concatenation) * 255.
tf.summary.image(
"concatenation",
tf.cast(concatenation, tf.uint8),
max_outputs=1)
# MANIFOLD
# Data basis
z_u, _ = encoder(
input_=logit_x_in[:8, :, :, :], hps=hps,
n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=False)
u_1 = tf.reshape(z_u[0, :, :, :], [-1])
u_2 = tf.reshape(z_u[1, :, :, :], [-1])
u_3 = tf.reshape(z_u[2, :, :, :], [-1])
u_4 = tf.reshape(z_u[3, :, :, :], [-1])
u_5 = tf.reshape(z_u[4, :, :, :], [-1])
u_6 = tf.reshape(z_u[5, :, :, :], [-1])
u_7 = tf.reshape(z_u[6, :, :, :], [-1])
u_8 = tf.reshape(z_u[7, :, :, :], [-1])
# 3D dome
manifold_side = 8
angle_1 = numpy.arange(manifold_side) * 1. / manifold_side
angle_2 = numpy.arange(manifold_side) * 1. / manifold_side
angle_1 *= 2. * numpy.pi
angle_2 *= 2. * numpy.pi
angle_1 = angle_1.astype("float32")
angle_2 = angle_2.astype("float32")
angle_1 = tf.reshape(angle_1, [1, -1, 1])
angle_1 += tf.zeros([manifold_side, manifold_side, 1])
angle_2 = tf.reshape(angle_2, [-1, 1, 1])
angle_2 += tf.zeros([manifold_side, manifold_side, 1])
n_angle_3 = 40
angle_3 = numpy.arange(n_angle_3) * 1. / n_angle_3
angle_3 *= 2 * numpy.pi
angle_3 = angle_3.astype("float32")
angle_3 = tf.reshape(angle_3, [-1, 1, 1, 1])
angle_3 += tf.zeros([n_angle_3, manifold_side, manifold_side, 1])
manifold = tf.cos(angle_1) * (
tf.cos(angle_2) * (
tf.cos(angle_3) * u_1 + tf.sin(angle_3) * u_2)
+ tf.sin(angle_2) * (
tf.cos(angle_3) * u_3 + tf.sin(angle_3) * u_4))
manifold += tf.sin(angle_1) * (
tf.cos(angle_2) * (
tf.cos(angle_3) * u_5 + tf.sin(angle_3) * u_6)
+ tf.sin(angle_2) * (
tf.cos(angle_3) * u_7 + tf.sin(angle_3) * u_8))
manifold = tf.reshape(
manifold,
[n_angle_3 * manifold_side * manifold_side] + final_shape)
manifold, _ = decoder(
input_=manifold, hps=hps, n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=False)
manifold = tf.nn.sigmoid(manifold)
manifold = tf.clip_by_value(manifold, 0, 1) * 255.
manifold = tf.reshape(
manifold,
[n_angle_3,
manifold_side * manifold_side,
image_size,
image_size,
3])
manifold = tf.transpose(
tf.reshape(
manifold,
[n_angle_3, manifold_side,
image_size * manifold_side, image_size, 3]), [0, 1, 3, 2, 4])
manifold = tf.transpose(
tf.reshape(
manifold,
[n_angle_3, image_size * manifold_side,
image_size * manifold_side, 3]),
[0, 2, 1, 3])
manifold = tf.transpose(manifold, [1, 2, 0, 3])
manifold = tf.reshape(
manifold,
[1, image_size * manifold_side,
image_size * manifold_side, 3 * n_angle_3])
tf.summary.image(
"manifold",
tf.cast(manifold[:, :, :, :3], tf.uint8),
max_outputs=1)
# COMPRESSION
z_complete, _ = encoder(
input_=logit_x_in[:hps.n_scale, :, :, :], hps=hps,
n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=False)
z_compressed_list = [z_complete]
z_noisy_list = [z_complete]
z_lost = z_complete
for scale_idx in xrange(hps.n_scale - 1):
z_lost = squeeze_2x2_ordered(z_lost)
z_lost, _ = tf.split(z_lost, 2, 3)
z_compressed = z_lost
z_noisy = z_lost
for _ in xrange(scale_idx + 1):
z_compressed = tf.concat_v2(
[z_compressed, tf.zeros_like(z_compressed)], 3)
z_compressed = squeeze_2x2_ordered(
z_compressed, reverse=True)
z_noisy = tf.concat_v2(
[z_noisy, tf.random_normal(
z_noisy.get_shape().as_list())], 3)
z_noisy = squeeze_2x2_ordered(z_noisy, reverse=True)
z_compressed_list.append(z_compressed)
z_noisy_list.append(z_noisy)
self.z_reduced = z_lost
z_compressed = tf.concat_v2(z_compressed_list, 0)
z_noisy = tf.concat_v2(z_noisy_list, 0)
noisy_images, _ = decoder(
input_=z_noisy, hps=hps, n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=False)
compressed_images, _ = decoder(
input_=z_compressed, hps=hps, n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=False)
noisy_images = tf.nn.sigmoid(noisy_images)
compressed_images = tf.nn.sigmoid(compressed_images)
noisy_images = tf.clip_by_value(noisy_images, 0, 1) * 255.
noisy_images = tf.reshape(
noisy_images,
[(hps.n_scale * hps.n_scale), image_size, image_size, 3])
noisy_images = tf.transpose(
tf.reshape(
noisy_images,
[hps.n_scale, image_size * hps.n_scale, image_size, 3]),
[0, 2, 1, 3])
noisy_images = tf.transpose(
tf.reshape(
noisy_images,
[1, image_size * hps.n_scale, image_size * hps.n_scale, 3]),
[0, 2, 1, 3])
tf.summary.image(
"noise",
tf.cast(noisy_images, tf.uint8),
max_outputs=1)
compressed_images = tf.clip_by_value(compressed_images, 0, 1) * 255.
compressed_images = tf.reshape(
compressed_images,
[(hps.n_scale * hps.n_scale), image_size, image_size, 3])
compressed_images = tf.transpose(
tf.reshape(
compressed_images,
[hps.n_scale, image_size * hps.n_scale, image_size, 3]),
[0, 2, 1, 3])
compressed_images = tf.transpose(
tf.reshape(
compressed_images,
[1, image_size * hps.n_scale, image_size * hps.n_scale, 3]),
[0, 2, 1, 3])
tf.summary.image(
"compression",
tf.cast(compressed_images, tf.uint8),
max_outputs=1)
# SAMPLES x2
final_shape[0] *= 2
final_shape[1] *= 2
big_sample = standard_normal_sample([25] + final_shape)
big_sample, _ = decoder(
input_=big_sample, hps=hps, n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=True)
big_sample = tf.nn.sigmoid(big_sample)
big_sample = tf.clip_by_value(big_sample, 0, 1) * 255.
big_sample = tf.reshape(
big_sample,
[25, image_size * 2, image_size * 2, 3])
big_sample = tf.transpose(
tf.reshape(
big_sample,
[5, image_size * 10, image_size * 2, 3]), [0, 2, 1, 3])
big_sample = tf.transpose(
tf.reshape(
big_sample,
[1, image_size * 10, image_size * 10, 3]),
[0, 2, 1, 3])
tf.summary.image(
"big_sample",
tf.cast(big_sample, tf.uint8),
max_outputs=1)
# SAMPLES x10
final_shape[0] *= 5
final_shape[1] *= 5
extra_large = standard_normal_sample([1] + final_shape)
extra_large, _ = decoder(
input_=extra_large, hps=hps, n_scale=hps.n_scale,
use_batch_norm=hps.use_batch_norm, weight_norm=True,
train=True)
extra_large = tf.nn.sigmoid(extra_large)
extra_large = tf.clip_by_value(extra_large, 0, 1) * 255.
tf.summary.image(
"extra_large",
tf.cast(extra_large, tf.uint8),
max_outputs=1)
def eval_epoch(self, hps):
"""Evaluate bits/dim."""
n_eval_dict = {
"imnet": 50000,
"lsun": 300,
"celeba": 19962,
"svhn": 26032,
}
if FLAGS.eval_set_size == 0:
num_examples_eval = n_eval_dict[FLAGS.dataset]
else:
num_examples_eval = FLAGS.eval_set_size
n_epoch = num_examples_eval / hps.batch_size
eval_costs = []
bar_len = 70
for epoch_idx in xrange(n_epoch):
n_equal = epoch_idx * bar_len * 1. / n_epoch
n_equal = numpy.ceil(n_equal)
n_equal = int(n_equal)
n_dash = bar_len - n_equal
progress_bar = "[" + "=" * n_equal + "-" * n_dash + "]\r"
print progress_bar,
cost = self.bit_per_dim.eval()
eval_costs.append(cost)
print ""
return float(numpy.mean(eval_costs))
def train_model(hps, logdir):
"""Training."""
with tf.Graph().as_default():
with tf.device(tf.train.replica_device_setter(0)):
with tf.variable_scope("model"):
model = RealNVP(hps)
saver = tf.train.Saver(tf.global_variables())
# Build the summary operation from the last tower summaries.
summary_op = tf.summary.merge_all()
# Build an initialization operation to run below.
init = tf.global_variables_initializer()
# Start running operations on the Graph. allow_soft_placement must be set to
# True to build towers on GPU, as some of the ops do not have GPU
# implementations.
sess = tf.Session(config=tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=True))
sess.run(init)
ckpt_state = tf.train.get_checkpoint_state(logdir)
if ckpt_state and ckpt_state.model_checkpoint_path:
print "Loading file %s" % ckpt_state.model_checkpoint_path
saver.restore(sess, ckpt_state.model_checkpoint_path)
# Start the queue runners.
tf.train.start_queue_runners(sess=sess)
summary_writer = tf.summary.FileWriter(
logdir,
graph=sess.graph)
local_step = 0
while True:
fetches = [model.step, model.bit_per_dim, model.train_step]
# The chief worker evaluates the summaries every 10 steps.
should_eval_summaries = local_step % 100 == 0
if should_eval_summaries:
fetches += [summary_op]
start_time = time.time()
outputs = sess.run(fetches)
global_step_val = outputs[0]
loss = outputs[1]
duration = time.time() - start_time
assert not numpy.isnan(
loss), 'Model diverged with loss = NaN'
if local_step % 10 == 0:
examples_per_sec = hps.batch_size / float(duration)
format_str = ('%s: step %d, loss = %.2f '
'(%.1f examples/sec; %.3f '
'sec/batch)')
print format_str % (datetime.now(), global_step_val, loss,
examples_per_sec, duration)
if should_eval_summaries:
summary_str = outputs[-1]
summary_writer.add_summary(summary_str, global_step_val)
# Save the model checkpoint periodically.
if local_step % 1000 == 0 or (local_step + 1) == FLAGS.train_steps:
checkpoint_path = os.path.join(logdir, 'model.ckpt')
saver.save(
sess,
checkpoint_path,
global_step=global_step_val)
if outputs[0] >= FLAGS.train_steps:
break
local_step += 1
def evaluate(hps, logdir, traindir, subset="valid", return_val=False):
"""Evaluation."""
hps.batch_size = 100
with tf.Graph().as_default():
with tf.device("/cpu:0"):
with tf.variable_scope("model") as var_scope:
eval_model = RealNVP(hps)
summary_writer = tf.summary.FileWriter(logdir)
var_scope.reuse_variables()
saver = tf.train.Saver()
sess = tf.Session(config=tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=True))
tf.train.start_queue_runners(sess)
previous_global_step = 0 # don"t run eval for step = 0
with sess.as_default():
while True:
ckpt_state = tf.train.get_checkpoint_state(traindir)
if not (ckpt_state and ckpt_state.model_checkpoint_path):
print "No model to eval yet at %s" % traindir
time.sleep(30)
continue
print "Loading file %s" % ckpt_state.model_checkpoint_path
saver.restore(sess, ckpt_state.model_checkpoint_path)
current_step = tf.train.global_step(sess, eval_model.step)
if current_step == previous_global_step:
print "Waiting for the checkpoint to be updated."
time.sleep(30)
continue
previous_global_step = current_step
print "Evaluating..."
bit_per_dim = eval_model.eval_epoch(hps)
print ("Epoch: %d, %s -> %.3f bits/dim"
% (current_step, subset, bit_per_dim))
print "Writing summary..."
summary = tf.Summary()
summary.value.extend(
[tf.Summary.Value(
tag="bit_per_dim",
simple_value=bit_per_dim)])
summary_writer.add_summary(summary, current_step)
if return_val:
return current_step, bit_per_dim
def sample_from_model(hps, logdir, traindir):
"""Sampling."""
hps.batch_size = 100
with tf.Graph().as_default():
with tf.device("/cpu:0"):
with tf.variable_scope("model") as var_scope:
eval_model = RealNVP(hps, sampling=True)
summary_writer = tf.summary.FileWriter(logdir)
var_scope.reuse_variables()
summary_op = tf.summary.merge_all()
saver = tf.train.Saver()
sess = tf.Session(config=tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=True))
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
previous_global_step = 0 # don"t run eval for step = 0
initialized = False
with sess.as_default():
while True:
ckpt_state = tf.train.get_checkpoint_state(traindir)
if not (ckpt_state and ckpt_state.model_checkpoint_path):
if not initialized:
print "No model to eval yet at %s" % traindir
time.sleep(30)
continue
else:
print ("Loading file %s"
% ckpt_state.model_checkpoint_path)
saver.restore(sess, ckpt_state.model_checkpoint_path)
current_step = tf.train.global_step(sess, eval_model.step)
if current_step == previous_global_step:
print "Waiting for the checkpoint to be updated."
time.sleep(30)
continue
previous_global_step = current_step
fetches = [summary_op]
outputs = sess.run(fetches)
summary_writer.add_summary(outputs[0], current_step)
coord.request_stop()
coord.join(threads)
def main(unused_argv):
hps = get_default_hparams().update_config(FLAGS.hpconfig)
if FLAGS.mode == "train":
train_model(hps=hps, logdir=FLAGS.logdir)
elif FLAGS.mode == "sample":
sample_from_model(hps=hps, logdir=FLAGS.logdir,
traindir=FLAGS.traindir)
else:
hps.batch_size = 100
evaluate(hps=hps, logdir=FLAGS.logdir,
traindir=FLAGS.traindir, subset=FLAGS.mode)
if __name__ == "__main__":
tf.app.run()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Utility functions for Real NVP.
"""
# pylint: disable=dangerous-default-value
import numpy
import tensorflow as tf
from tensorflow.python.framework import ops
DEFAULT_BN_LAG = .0
def stable_var(input_, mean=None, axes=[0]):
"""Numerically more stable variance computation."""
if mean is None:
mean = tf.reduce_mean(input_, axes)
res = tf.square(input_ - mean)
max_sqr = tf.reduce_max(res, axes)
res /= max_sqr
res = tf.reduce_mean(res, axes)
res *= max_sqr
return res
def variable_on_cpu(name, shape, initializer, trainable=True):
"""Helper to create a Variable stored on CPU memory.
Args:
name: name of the variable
shape: list of ints
initializer: initializer for Variable
trainable: boolean defining if the variable is for training
Returns:
Variable Tensor
"""
var = tf.get_variable(
name, shape, initializer=initializer, trainable=trainable)
return var
# layers
def conv_layer(input_,
filter_size,
dim_in,
dim_out,
name,
stddev=1e-2,
strides=[1, 1, 1, 1],
padding="SAME",
nonlinearity=None,
bias=False,
weight_norm=False,
scale=False):
"""Convolutional layer."""
with tf.variable_scope(name) as scope:
weights = variable_on_cpu(
"weights",
filter_size + [dim_in, dim_out],
tf.random_uniform_initializer(
minval=-stddev, maxval=stddev))
# weight normalization
if weight_norm:
weights /= tf.sqrt(tf.reduce_sum(tf.square(weights), [0, 1, 2]))
if scale:
magnitude = variable_on_cpu(
"magnitude", [dim_out],
tf.constant_initializer(
stddev * numpy.sqrt(dim_in * numpy.prod(filter_size) / 12.)))
weights *= magnitude
res = input_
# handling filter size bigger than image size
if hasattr(input_, "shape"):
if input_.get_shape().as_list()[1] < filter_size[0]:
pad_1 = tf.zeros([
input_.get_shape().as_list()[0],
filter_size[0] - input_.get_shape().as_list()[1],
input_.get_shape().as_list()[2],
input_.get_shape().as_list()[3]
])
pad_2 = tf.zeros([
input_.get_shape().as_list[0],
filter_size[0],
filter_size[1] - input_.get_shape().as_list()[2],
input_.get_shape().as_list()[3]
])
res = tf.concat(1, [pad_1, res])
res = tf.concat(2, [pad_2, res])
res = tf.nn.conv2d(
input=res,
filter=weights,
strides=strides,
padding=padding,
name=scope.name)
if hasattr(input_, "shape"):
if input_.get_shape().as_list()[1] < filter_size[0]:
res = tf.slice(res, [
0, filter_size[0] - input_.get_shape().as_list()[1],
filter_size[1] - input_.get_shape().as_list()[2], 0
], [-1, -1, -1, -1])
if bias:
biases = variable_on_cpu("biases", [dim_out], tf.constant_initializer(0.))
res = tf.nn.bias_add(res, biases)
if nonlinearity is not None:
res = nonlinearity(res)
return res
def max_pool_2x2(input_):
"""Max pooling."""
return tf.nn.max_pool(
input_, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
def depool_2x2(input_, stride=2):
"""Depooling."""
shape = input_.get_shape().as_list()
batch_size = shape[0]
height = shape[1]
width = shape[2]
channels = shape[3]
res = tf.reshape(input_, [batch_size, height, 1, width, 1, channels])
res = tf.concat(
2, [res, tf.zeros([batch_size, height, stride - 1, width, 1, channels])])
res = tf.concat(4, [
res, tf.zeros([batch_size, height, stride, width, stride - 1, channels])
])
res = tf.reshape(res, [batch_size, stride * height, stride * width, channels])
return res
# random flip on a batch of images
def batch_random_flip(input_):
"""Simultaneous horizontal random flip."""
if isinstance(input_, (float, int)):
return input_
shape = input_.get_shape().as_list()
batch_size = shape[0]
height = shape[1]
width = shape[2]
channels = shape[3]
res = tf.split(0, batch_size, input_)
res = [elem[0, :, :, :] for elem in res]
res = [tf.image.random_flip_left_right(elem) for elem in res]
res = [tf.reshape(elem, [1, height, width, channels]) for elem in res]
res = tf.concat(0, res)
return res
# build a one hot representation corresponding to the integer tensor
# the one-hot dimension is appended to the integer tensor shape
def as_one_hot(input_, n_indices):
"""Convert indices to one-hot."""
shape = input_.get_shape().as_list()
n_elem = numpy.prod(shape)
indices = tf.range(n_elem)
indices = tf.cast(indices, tf.int64)
indices_input = tf.concat(0, [indices, tf.reshape(input_, [-1])])
indices_input = tf.reshape(indices_input, [2, -1])
indices_input = tf.transpose(indices_input)
res = tf.sparse_to_dense(
indices_input, [n_elem, n_indices], 1., 0., name="flat_one_hot")
res = tf.reshape(res, [elem for elem in shape] + [n_indices])
return res
def squeeze_2x2(input_):
"""Squeezing operation: reshape to convert space to channels."""
return squeeze_nxn(input_, n_factor=2)
def squeeze_nxn(input_, n_factor=2):
"""Squeezing operation: reshape to convert space to channels."""
if isinstance(input_, (float, int)):
return input_
shape = input_.get_shape().as_list()
batch_size = shape[0]
height = shape[1]
width = shape[2]
channels = shape[3]
if height % n_factor != 0:
raise ValueError("Height not divisible by %d." % n_factor)
if width % n_factor != 0:
raise ValueError("Width not divisible by %d." % n_factor)
res = tf.reshape(
input_,
[batch_size,
height // n_factor,
n_factor, width // n_factor,
n_factor, channels])
res = tf.transpose(res, [0, 1, 3, 5, 2, 4])
res = tf.reshape(
res,
[batch_size,
height // n_factor,
width // n_factor,
channels * n_factor * n_factor])
return res
def unsqueeze_2x2(input_):
"""Unsqueezing operation: reshape to convert channels into space."""
if isinstance(input_, (float, int)):
return input_
shape = input_.get_shape().as_list()
batch_size = shape[0]
height = shape[1]
width = shape[2]
channels = shape[3]
if channels % 4 != 0:
raise ValueError("Number of channels not divisible by 4.")
res = tf.reshape(input_, [batch_size, height, width, channels // 4, 2, 2])
res = tf.transpose(res, [0, 1, 4, 2, 5, 3])
res = tf.reshape(res, [batch_size, 2 * height, 2 * width, channels // 4])
return res
# batch norm
def batch_norm(input_,
dim,
name,
scale=True,
train=True,
epsilon=1e-8,
decay=.1,
axes=[0],
bn_lag=DEFAULT_BN_LAG):
"""Batch normalization."""
# create variables
with tf.variable_scope(name):
var = variable_on_cpu(
"var", [dim], tf.constant_initializer(1.), trainable=False)
mean = variable_on_cpu(
"mean", [dim], tf.constant_initializer(0.), trainable=False)
step = variable_on_cpu("step", [], tf.constant_initializer(0.), trainable=False)
if scale:
gamma = variable_on_cpu("gamma", [dim], tf.constant_initializer(1.))
beta = variable_on_cpu("beta", [dim], tf.constant_initializer(0.))
# choose the appropriate moments
if train:
used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
cur_mean, cur_var = used_mean, used_var
if bn_lag > 0.:
used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean))
used_var -= (1 - bn_lag) * (used_var - tf.stop_gradient(var))
used_mean /= (1. - bn_lag**(step + 1))
used_var /= (1. - bn_lag**(step + 1))
else:
used_mean, used_var = mean, var
cur_mean, cur_var = used_mean, used_var
# normalize
res = (input_ - used_mean) / tf.sqrt(used_var + epsilon)
# de-normalize
if scale:
res *= gamma
res += beta
# update variables
if train:
with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
with ops.colocate_with(mean):
new_mean = tf.assign_sub(
mean,
tf.check_numerics(decay * (mean - cur_mean), "NaN in moving mean."))
with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
with ops.colocate_with(var):
new_var = tf.assign_sub(
var,
tf.check_numerics(decay * (var - cur_var),
"NaN in moving variance."))
with tf.name_scope(name, "IncrementTime", [step]):
with ops.colocate_with(step):
new_step = tf.assign_add(step, 1.)
res += 0. * new_mean * new_var * new_step
return res
# batch normalization taking into account the volume transformation
def batch_norm_log_diff(input_,
dim,
name,
train=True,
epsilon=1e-8,
decay=.1,
axes=[0],
reuse=None,
bn_lag=DEFAULT_BN_LAG):
"""Batch normalization with corresponding log determinant Jacobian."""
if reuse is None:
reuse = not train
# create variables
with tf.variable_scope(name) as scope:
if reuse:
scope.reuse_variables()
var = variable_on_cpu(
"var", [dim], tf.constant_initializer(1.), trainable=False)
mean = variable_on_cpu(
"mean", [dim], tf.constant_initializer(0.), trainable=False)
step = variable_on_cpu("step", [], tf.constant_initializer(0.), trainable=False)
# choose the appropriate moments
if train:
used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
cur_mean, cur_var = used_mean, used_var
if bn_lag > 0.:
used_var = stable_var(input_=input_, mean=used_mean, axes=axes)
cur_var = used_var
used_mean -= (1 - bn_lag) * (used_mean - tf.stop_gradient(mean))
used_mean /= (1. - bn_lag**(step + 1))
used_var -= (1 - bn_lag) * (used_var - tf.stop_gradient(var))
used_var /= (1. - bn_lag**(step + 1))
else:
used_mean, used_var = mean, var
cur_mean, cur_var = used_mean, used_var
# update variables
if train:
with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
with ops.colocate_with(mean):
new_mean = tf.assign_sub(
mean,
tf.check_numerics(
decay * (mean - cur_mean), "NaN in moving mean."))
with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
with ops.colocate_with(var):
new_var = tf.assign_sub(
var,
tf.check_numerics(decay * (var - cur_var),
"NaN in moving variance."))
with tf.name_scope(name, "IncrementTime", [step]):
with ops.colocate_with(step):
new_step = tf.assign_add(step, 1.)
used_var += 0. * new_mean * new_var * new_step
used_var += epsilon
return used_mean, used_var
def convnet(input_,
dim_in,
dim_hid,
filter_sizes,
dim_out,
name,
use_batch_norm=True,
train=True,
nonlinearity=tf.nn.relu):
"""Chaining of convolutional layers."""
dims_in = [dim_in] + dim_hid[:-1]
dims_out = dim_hid
res = input_
bias = (not use_batch_norm)
with tf.variable_scope(name):
for layer_idx in xrange(len(dim_hid)):
res = conv_layer(
input_=res,
filter_size=filter_sizes[layer_idx],
dim_in=dims_in[layer_idx],
dim_out=dims_out[layer_idx],
name="h_%d" % layer_idx,
stddev=1e-2,
nonlinearity=None,
bias=bias)
if use_batch_norm:
res = batch_norm(
input_=res,
dim=dims_out[layer_idx],
name="bn_%d" % layer_idx,
scale=(nonlinearity == tf.nn.relu),
train=train,
epsilon=1e-8,
axes=[0, 1, 2])
if nonlinearity is not None:
res = nonlinearity(res)
res = conv_layer(
input_=res,
filter_size=filter_sizes[-1],
dim_in=dims_out[-1],
dim_out=dim_out,
name="out",
stddev=1e-2,
nonlinearity=None)
return res
# distributions
# log-likelihood estimation
def standard_normal_ll(input_):
"""Log-likelihood of standard Gaussian distribution."""
res = -.5 * (tf.square(input_) + numpy.log(2. * numpy.pi))
return res
def standard_normal_sample(shape):
"""Samples from standard Gaussian distribution."""
return tf.random_normal(shape)
SQUEEZE_MATRIX = numpy.array([[[[1., 0., 0., 0.]], [[0., 0., 1., 0.]]],
[[[0., 0., 0., 1.]], [[0., 1., 0., 0.]]]])
def squeeze_2x2_ordered(input_, reverse=False):
"""Squeezing operation with a controlled ordering."""
shape = input_.get_shape().as_list()
batch_size = shape[0]
height = shape[1]
width = shape[2]
channels = shape[3]
if reverse:
if channels % 4 != 0:
raise ValueError("Number of channels not divisible by 4.")
channels /= 4
else:
if height % 2 != 0:
raise ValueError("Height not divisible by 2.")
if width % 2 != 0:
raise ValueError("Width not divisible by 2.")
weights = numpy.zeros((2, 2, channels, 4 * channels))
for idx_ch in xrange(channels):
slice_2 = slice(idx_ch, (idx_ch + 1))
slice_3 = slice((idx_ch * 4), ((idx_ch + 1) * 4))
weights[:, :, slice_2, slice_3] = SQUEEZE_MATRIX
shuffle_channels = [idx_ch * 4 for idx_ch in xrange(channels)]
shuffle_channels += [idx_ch * 4 + 1 for idx_ch in xrange(channels)]
shuffle_channels += [idx_ch * 4 + 2 for idx_ch in xrange(channels)]
shuffle_channels += [idx_ch * 4 + 3 for idx_ch in xrange(channels)]
shuffle_channels = numpy.array(shuffle_channels)
weights = weights[:, :, :, shuffle_channels].astype("float32")
if reverse:
res = tf.nn.conv2d_transpose(
value=input_,
filter=weights,
output_shape=[batch_size, height * 2, width * 2, channels],
strides=[1, 2, 2, 1],
padding="SAME",
name="unsqueeze_2x2")
else:
res = tf.nn.conv2d(
input=input_,
filter=weights,
strides=[1, 2, 2, 1],
padding="SAME",
name="squeeze_2x2")
return res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment