"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "d0854bf1e663f77a537c5fcccdd97577fa02b686"
Commit 061142a0 authored by Xin Pan, committed by GitHub

Merge pull request #800 from panyx0718/master

Add cross conv model for next frame prediction.
<font size=4><b>Visual Dynamics: Probabilistic Future Frame Synthesis via Cross Convolutional Networks.</b></font>
<b>Introduction</b>
https://arxiv.org/pdf/1607.02586v1.pdf
This is an implementation based on my understanding of the paper, with small
variations. It does not necessarily represent the method exactly as published
by the original authors.
Authors: Xin Pan (Github: panyx0718), Anelia Angelova
<b>Results:</b>
![Sample1](g3doc/cross_conv.png)
![Sample2](g3doc/cross_conv2.png)
![Loss](g3doc/cross_conv3.png)
<b>Prerequisite:</b>
1. Install TensorFlow (r0.12) and Bazel.
2. Download the Sprites dataset or generate the moving object dataset.
Sprites data is located here:
http://www.scottreed.info/files/nips2015-analogy-data.tar.gz
Convert the .mat files into images, then use sprites_gen.py to convert them
to tf.SequenceExample (a conversion sketch follows this list).
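The .mat-to-image conversion itself is not included in this repo. Below is a minimal, hypothetical sketch of the idea; the file name, the 'sprites' key, and the array layout are assumptions that need to be adapted to the actual archive contents.
```python
# Hypothetical sketch: dump the frames of one .mat file as JPEG images named
# image_<character>_<pose>_<frame>.jpg, the pattern sprites_gen.py expects.
# The file name, the 'sprites' key, and the array layout are assumptions.
import scipy.io
import scipy.misc

mat = scipy.io.loadmat('sprites_0001.mat')  # hypothetical file name
frames = mat['sprites']  # assumed shape: [num_frames, height, width, 3]
for k in xrange(frames.shape[0]):
  scipy.misc.imsave('image_1_0_%d.jpg' % k, frames[k])
```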
<b>How to run:</b>
```shell
ls -R
.:
data next_frame_prediction WORKSPACE
./data:
tfrecords tfrecords_test
./next_frame_prediction:
cross_conv g3doc README.md
./next_frame_prediction/cross_conv:
BUILD eval.py example_gen.py model.py reader.py sprites_gen.py train.py
./next_frame_prediction/g3doc:
cross_conv2.png cross_conv3.png cross_conv.png
# Build everything.
bazel build -c opt next_frame_prediction/...
# The following example runs the generated 2d objects.
# For Sprites dataset, image_size should be 60, norm_scale should be 255.0.
# Batch size is normally 16~64, depending on your memory size.
#
# Run training.
bazel-bin/next_frame_prediction/cross_conv/train \
--batch_size=1 \
--data_filepattern=data/tfrecords \
--image_size=64 \
--log_root=/tmp/predict
step: 1, loss: 24.428671
step: 2, loss: 19.211605
step: 3, loss: 5.543143
step: 4, loss: 3.035339
step: 5, loss: 1.771392
step: 6, loss: 2.099824
step: 7, loss: 1.747665
step: 8, loss: 1.572436
step: 9, loss: 1.586816
step: 10, loss: 1.434191
#
# Run eval.
bazel-bin/next_frame_prediction/cross_conv/eval \
--batch_size=1 \
--data_filepattern=data/tfrecords_test \
--image_size=64 \
--log_root=/tmp/predict
```
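Training and eval summaries are written under the log root, so progress can be monitored by pointing TensorBoard at it, e.g. `tensorboard --logdir=/tmp/predict`.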
licenses(["notice"]) # Apache 2.0
package_group(
name = "internal",
packages = [
"//next_frame_prediction/...",
],
)
package(default_visibility = [":internal"])
py_library(
name = "model",
srcs = ["model.py"],
)
py_library(
name = "reader",
srcs = ["reader.py"],
)
py_binary(
name = "train",
srcs = ["train.py"],
deps = [
":model",
":reader",
],
)
py_binary(
name = "eval",
srcs = ["eval.py"],
deps = [
":model",
":reader",
],
)
py_binary(
name = "example_gen",
srcs = ["example_gen.py"],
)
py_binary(
name = "sprites_gen",
srcs = ["sprites_gen.py"],
)
# next_frame_prediction/cross_conv/eval.py
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Eval Cross Convolutional Model."""
import io
import os
import sys
import time
import numpy as np
import tensorflow as tf
import model as cross_conv_model
import reader
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('log_root', '/tmp/moving_obj', 'The root dir of output.')
tf.flags.DEFINE_string('data_filepattern',
'',
'Eval data file pattern.')
tf.flags.DEFINE_integer('batch_size', 1, 'Batch size.')
tf.flags.DEFINE_integer('image_size', 64, 'Image height and width.')
tf.flags.DEFINE_float('norm_scale', 1.0,
'Divide the original image values by this scale.')
tf.flags.DEFINE_float('scale', 10.0,
'Scale the image after norm_scale and move the diff '
'to the positive realm.')
tf.flags.DEFINE_integer('sequence_length', 2, 'tf.SequenceExample length.')
tf.flags.DEFINE_integer('eval_batch_count', 100,
'Number of batches to average the result over.')
tf.flags.DEFINE_bool('l2_loss', True, 'If true, include l2_loss.')
tf.flags.DEFINE_bool('reconstr_loss', False, 'If true, include reconstr_loss.')
tf.flags.DEFINE_bool('kl_loss', True, 'If true, include KL loss.')
slim = tf.contrib.slim
def _Eval():
params = dict()
params['batch_size'] = FLAGS.batch_size
params['seq_len'] = FLAGS.sequence_length
params['image_size'] = FLAGS.image_size
params['is_training'] = False
params['norm_scale'] = FLAGS.norm_scale
params['scale'] = FLAGS.scale
params['l2_loss'] = FLAGS.l2_loss
params['reconstr_loss'] = FLAGS.reconstr_loss
params['kl_loss'] = FLAGS.kl_loss
eval_dir = os.path.join(FLAGS.log_root, 'eval')
images = reader.ReadInput(
FLAGS.data_filepattern, shuffle=False, params=params)
images *= params['scale']
# Increasing the scale makes training much faster.
image_diff_list = reader.SequenceToImageAndDiff(images)
model = cross_conv_model.CrossConvModel(image_diff_list, params)
model.Build()
summary_writer = tf.summary.FileWriter(eval_dir)
saver = tf.train.Saver()
sess = tf.Session('', config=tf.ConfigProto(allow_soft_placement=True))
tf.train.start_queue_runners(sess)
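# Poll for a new checkpoint every 60 seconds and evaluate the latest one.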
while True:
time.sleep(60)
try:
ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
except tf.errors.OutOfRangeError as e:
sys.stderr.write('Cannot restore checkpoint: %s\n' % e)
continue
if not (ckpt_state and ckpt_state.model_checkpoint_path):
sys.stderr.write('No model to eval yet at %s\n' % FLAGS.log_root)
continue
sys.stderr.write('Loading checkpoint %s\n' %
ckpt_state.model_checkpoint_path)
saver.restore(sess, ckpt_state.model_checkpoint_path)
# Use the empirical distribution of z from training set.
if not tf.gfile.Exists(os.path.join(FLAGS.log_root, 'z_mean.npy')):
sys.stderr.write('No z at %s\n' % FLAGS.log_root)
continue
with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy')) as f:
sample_z_mean = np.load(io.BytesIO(f.read()))
with tf.gfile.Open(
os.path.join(FLAGS.log_root, 'z_stddev_log.npy')) as f:
sample_z_stddev_log = np.load(io.BytesIO(f.read()))
total_loss = 0.0
for _ in xrange(FLAGS.eval_batch_count):
loss_val, total_steps, summaries = sess.run(
[model.loss, model.global_step, model.summary_op],
feed_dict={model.z_mean: sample_z_mean,
model.z_stddev_log: sample_z_stddev_log})
total_loss += loss_val
summary_writer.add_summary(summaries, total_steps)
sys.stderr.write('steps: %d, loss: %f\n' %
(total_steps, total_loss / FLAGS.eval_batch_count))
def main(_):
_Eval()
if __name__ == '__main__':
tf.app.run()
# next_frame_prediction/cross_conv/example_gen.py
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generate examples of two objects moving in different directions."""
import random
import sys
import numpy as np
import tensorflow as tf
tf.flags.DEFINE_string('out_file', '',
'Output file for the tfrecords.')
def _add_object(obj_type, image, image2, xpos, ypos):
"""Add a moving obj to two consecutive images."""
obj_size = random.randint(8, 10)
channel = random.randint(0, 2)
move = random.randint(6, 10)
obj = np.zeros([obj_size, obj_size, 3])
if obj_type == 'rectangle':
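# Despite the name, this draws a lower-triangular shape moved along the x axis.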
xpos2 = xpos + move
ypos2 = ypos
for i in xrange(obj_size):
obj[i, 0:i+1, channel] = 1.0
elif obj_type == 'square':
xpos2 = xpos
ypos2 = ypos + move
obj[:, :, channel] = 1.0
for x in xrange(obj_size):
for y in xrange(obj_size):
if obj[x, y, channel] == 1.0:
image[xpos+x, ypos+y, channel] = 1.0
image2[xpos2+x, ypos2+y, channel] = 1.0
def _images_to_example(image, image2):
"""Convert two consecutive images to SequenceExample."""
example = tf.SequenceExample()
feature_list = example.feature_lists.feature_list['moving_objs']
feature = feature_list.feature.add()
feature.float_list.value.extend(np.reshape(image, [-1]).tolist())
feature = feature_list.feature.add()
feature.float_list.value.extend(np.reshape(image2, [-1]).tolist())
return example
def generate_input():
"""Generate tfrecords."""
writer = tf.python_io.TFRecordWriter(tf.flags.FLAGS.out_file)
writer2 = tf.python_io.TFRecordWriter(tf.flags.FLAGS.out_file + '_test')
examples = []
for xpos in xrange(0, 40, 3):
for ypos in xrange(0, 40, 3):
for xpos2 in xrange(0, 40, 3):
for ypos2 in xrange(0, 40, 3):
image = np.zeros([64, 64, 3])
image2 = np.zeros([64, 64, 3])
_add_object('rectangle', image, image2, xpos, ypos)
_add_object('square', image, image2, xpos2, ypos2)
examples.append(_images_to_example(image, image2))
sys.stderr.write('Finished generating examples.\n')
random.shuffle(examples)
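# Hold out every 10th example for the test set; the rest go to training.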
for count, ex in enumerate(examples):
if count % 10 == 0:
writer2.write(ex.SerializeToString())
else:
writer.write(ex.SerializeToString())
writer.close()
writer2.close()
def main(_):
generate_input()
if __name__ == '__main__':
tf.app.run()
# next_frame_prediction/cross_conv/model.py
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cross Convolutional Model.
https://arxiv.org/pdf/1607.02586v1.pdf
"""
import math
import sys
import tensorflow as tf
slim = tf.contrib.slim
class CrossConvModel(object):
def __init__(self, image_diff_list, params):
"""Constructor.
Args:
image_diff_list: A list of (image, diff) tuples, each with shape
[batch_size, image_size, image_size, 3], where image_size is one of
[32, 64, 128, 256].
params: Dict of parameters.
"""
self.images = [i for (i, _) in image_diff_list]
# Move the diff to the positive realm.
self.diffs = [(d + params['scale']) / 2 for (i, d) in image_diff_list]
self.params = params
def Build(self):
with tf.device('/gpu:0'):
with slim.arg_scope([slim.conv2d],
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params={'is_training':
self.params['is_training']}):
self._BuildMotionKernel()
encoded_images = self._BuildImageEncoder()
cross_conved_images = self._CrossConv(encoded_images)
self._BuildImageDecoder(cross_conved_images)
self._BuildLoss()
image = self.images[1]
diff = self.diffs[1]
self.global_step = tf.Variable(0, name='global_step', trainable=False)
if self.params['is_training']:
self._BuildTrainOp()
diff = diff * 2.0 - self.params['scale']
diff_output = self.diff_output * 2.0 - self.params['scale']
concat_image = tf.concat(
1, [image, image + diff_output, image + diff, diff_output])
tf.summary.image('origin_predict_expect_predictdiff', concat_image)
self.summary_op = tf.summary.merge_all()
return self.loss
def _BuildTrainOp(self):
lrn_rate = tf.maximum(
0.01, # min_lr_rate.
tf.train.exponential_decay(
self.params['learning_rate'], self.global_step, 10000, 0.5))
tf.summary.scalar('learning_rate', lrn_rate)
optimizer = tf.train.GradientDescentOptimizer(lrn_rate)
self.train_op = slim.learning.create_train_op(
self.loss, optimizer, global_step=self.global_step)
def _BuildLoss(self):
# 1. reconstr_loss doesn't seem to do better than l2 loss.
# 2. Only works when using reduce_mean. reduce_sum doesn't work.
# 3. It seems kl loss doesn't play an important role.
self.loss = 0
with tf.variable_scope('loss'):
if self.params['l2_loss']:
l2_loss = tf.reduce_mean(tf.square(self.diff_output - self.diffs[1]))
tf.summary.scalar('l2_loss', l2_loss)
self.loss += l2_loss
if self.params['reconstr_loss']:
reconstr_loss = (-tf.reduce_mean(
self.diffs[1] * tf.log(1e-10 + self.diff_output) +
(1 - self.diffs[1]) * tf.log(1e-10 + 1 - self.diff_output)))
reconstr_loss = tf.check_numerics(reconstr_loss, 'reconstr_loss')
tf.summary.scalar('reconstr_loss', reconstr_loss)
self.loss += reconstr_loss
if self.params['kl_loss']:
kl_loss = (0.5 * tf.reduce_mean(
tf.square(self.z_mean) + tf.square(self.z_stddev) -
2 * self.z_stddev_log - 1))
tf.summary.scalar('kl_loss', kl_loss)
self.loss += kl_loss
tf.summary.scalar('loss', self.loss)
def _BuildMotionKernel(self):
image = self.images[-2]
diff = self.diffs[-2]
shape = image.get_shape().as_list()
assert shape[1] == shape[2] and shape[1] == 128
batch_size = shape[0]
net = tf.concat(3, [image, diff])
with tf.variable_scope('motion_encoder'):
with slim.arg_scope([slim.conv2d], padding='VALID'):
net = slim.conv2d(net, 96, [5, 5], stride=1)
net = slim.max_pool2d(net, [2, 2])
net = slim.conv2d(net, 96, [5, 5], stride=1)
net = slim.max_pool2d(net, [2, 2])
net = slim.conv2d(net, 128, [5, 5], stride=1)
net = slim.conv2d(net, 128, [5, 5], stride=1)
net = slim.max_pool2d(net, [2, 2])
net = slim.conv2d(net, 256, [4, 4], stride=1)
net = slim.conv2d(net, 256, [3, 3], stride=1)
z = tf.reshape(net, shape=[batch_size, -1])
self.z_mean, self.z_stddev_log = tf.split(
split_dim=1, num_split=2, value=z)
self.z_stddev = tf.exp(self.z_stddev_log)
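# Reparameterization trick: sample the motion kernel as mean + stddev * noise.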
epsilon = tf.random_normal(
self.z_mean.get_shape().as_list(), 0, 1, dtype=tf.float32)
kernel = self.z_mean + tf.multiply(self.z_stddev, epsilon)
width = int(math.sqrt(kernel.get_shape().as_list()[1] // 128))
kernel = tf.reshape(kernel, [batch_size, width, width, 128])
with tf.variable_scope('kernel_decoder'):
with slim.arg_scope([slim.conv2d], padding='SAME'):
kernel = slim.conv2d(kernel, 128, [5, 5], stride=1)
self.kernel = slim.conv2d(kernel, 128, [5, 5], stride=1)
sys.stderr.write('kernel shape: %s\n' % kernel.get_shape())
def _BuildImageEncoder(self):
feature_maps = []
for (i, image) in enumerate(self.images):
with tf.variable_scope('image_encoder_%d' % i):
with slim.arg_scope([slim.conv2d, slim.max_pool2d], padding='SAME'):
net = slim.conv2d(image, 64, [5, 5], stride=1)
net = slim.conv2d(net, 64, [5, 5], stride=1)
net = slim.max_pool2d(net, [5, 5])
net = slim.conv2d(net, 64, [5, 5], stride=1)
net = slim.conv2d(net, 32, [5, 5], stride=1)
net = slim.max_pool2d(net, [2, 2])
sys.stderr.write('image_conv shape: %s\n' % net.get_shape())
feature_maps.append(net)
return feature_maps
def _CrossConvHelper(self, encoded_image, kernel):
"""Cross Convolution.
The encoded image and kernel have matching channel counts. Each example's
[kernel_size, kernel_size] kernel square is used as a depthwise filter to
convolve that example's [image_size, image_size] encoded image square.
"""
images = tf.expand_dims(encoded_image, 0)
kernels = tf.expand_dims(kernel, 3)
return tf.nn.depthwise_conv2d(images, kernels, [1, 1, 1, 1], 'SAME')
def _CrossConv(self, encoded_images):
"""Apply the motion kernel on the encoded_images."""
cross_conved_images = []
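# The motion kernel's 128 channels are split into 4 groups of 32, one per scale.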
kernels = tf.split(split_dim=3, num_split=4, value=self.kernel)
for (i, encoded_image) in enumerate(encoded_images):
with tf.variable_scope('cross_conv_%d' % i):
kernel = kernels[i]
encoded_image = tf.unstack(encoded_image, axis=0)
kernel = tf.unstack(kernel, axis=0)
assert len(encoded_image) == len(kernel)
assert len(encoded_image) == self.params['batch_size']
conved_image = []
for j in xrange(len(encoded_image)):
conved_image.append(self._CrossConvHelper(
encoded_image[j], kernel[j]))
cross_conved_images.append(tf.concat(0, conved_image))
sys.stderr.write('cross_conved shape: %s\n' %
cross_conved_images[-1].get_shape())
return cross_conved_images
def _Deconv(self, net, out_filters, kernel_size, stride):
shape = net.get_shape().as_list()
in_filters = shape[3]
kernel_shape = [kernel_size, kernel_size, out_filters, in_filters]
weights = tf.get_variable(
name='weights',
shape=kernel_shape,
dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.01))
out_height = shape[1] * stride
out_width = shape[2] * stride
batch_size = shape[0]
output_shape = [batch_size, out_height, out_width, out_filters]
net = tf.nn.conv2d_transpose(net, weights, output_shape,
[1, stride, stride, 1], padding='SAME')
slim.batch_norm(net)
return net
def _BuildImageDecoder(self, cross_conved_images):
"""Decode the cross_conved feature maps into the predicted images."""
nets = []
for i, cross_conved_image in enumerate(cross_conved_images):
with tf.variable_scope('image_decoder_%d' % i):
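# Upsample each scale's feature map back to 64x64 before concatenation.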
stride = 64 // cross_conved_image.get_shape().as_list()[1]
# TODO(xpan): Alternative solution for upsampling?
nets.append(self._Deconv(
cross_conved_image, 64, kernel_size=3, stride=stride))
net = tf.concat(3, nets)
net = slim.conv2d(net, 128, [9, 9], padding='SAME', stride=1)
net = slim.conv2d(net, 128, [1, 1], padding='SAME', stride=1)
net = slim.conv2d(net, 3, [1, 1], padding='SAME', stride=1)
self.diff_output = net
sys.stderr.write('diff_output shape: %s\n' % self.diff_output.get_shape())
# next_frame_prediction/cross_conv/reader.py
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Read image sequence."""
import tensorflow as tf
def SequenceToImageAndDiff(images):
"""Convert image sequence batch into image and diff batch.
Each image pair is converted to the first image and their diff.
Batch size will increase if sequence length is larger than 2.
Args:
images: Image sequence with shape
[batch_size, seq_len, image_size, image_size, channel]
Returns:
A list of (image, diff) tuples with shape
[batch_size2, image_size, image_size, channel], one tuple per image_size
in [32, 64, 128, 256].
"""
image_diff_list = []
image_seq = tf.unstack(images, axis=1)
for size in [32, 64, 128, 256]:
resized_images = [
tf.image.resize_images(i, [size, size]) for i in image_seq]
diffs = []
for i in xrange(0, len(resized_images)-1):
diffs.append(resized_images[i+1] - resized_images[i])
image_diff_list.append(
(tf.concat(0, resized_images[:-1]), tf.concat(0, diffs)))
return image_diff_list
def ReadInput(data_filepattern, shuffle, params):
"""Read the tf.SequenceExample tfrecord files.
Args:
data_filepattern: tf.SequenceExample tfrecord filepattern.
shuffle: Whether to shuffle the examples.
params: parameter dict.
Returns:
image sequence batch [batch_size, seq_len, image_size, image_size, channel].
"""
image_size = params['image_size']
filenames = tf.gfile.Glob(data_filepattern)
filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle)
reader = tf.TFRecordReader()
_, example = reader.read(filename_queue)
feature_spec = {
'moving_objs': tf.FixedLenSequenceFeature(
shape=[image_size * image_size * 3], dtype=tf.float32)}
_, features = tf.parse_single_sequence_example(
example, sequence_features=feature_spec)
moving_objs = tf.reshape(
features['moving_objs'], [params['seq_len'], image_size, image_size, 3])
if shuffle:
examples = tf.train.shuffle_batch(
[moving_objs],
batch_size=params['batch_size'],
num_threads=64,
capacity=params['batch_size'] * 100,
min_after_dequeue=params['batch_size'] * 4)
else:
examples = tf.train.batch([moving_objs],
batch_size=params['batch_size'],
num_threads=16,
capacity=params['batch_size'])
examples /= params['norm_scale']
return examples
# next_frame_prediction/cross_conv/sprites_gen.py
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generate the sprites tfrecords from raw_images."""
import os
import random
import re
import sys
import numpy as np
import scipy.misc
import tensorflow as tf
tf.flags.DEFINE_string('data_filepattern', '', 'The raw images.')
tf.flags.DEFINE_string('out_file', '',
'File name for the tfrecord output.')
def _read_images():
"""Read images from image files into data structure."""
sprites = dict()
files = tf.gfile.Glob(tf.flags.FLAGS.data_filepattern)
for f in files:
image = scipy.misc.imread(f)
m = re.search(r'image_([0-9]+)_([0-9]+)_([0-9]+)\.jpg', os.path.basename(f))
if m.group(1) not in sprites:
sprites[m.group(1)] = dict()
character = sprites[m.group(1)]
if m.group(2) not in character:
character[m.group(2)] = dict()
pose = character[m.group(2)]
pose[int(m.group(3))] = image
return sprites
def _images_to_example(image, image2):
"""Convert 2 consecutive image to a SequenceExample."""
example = tf.SequenceExample()
feature_list = example.feature_lists.feature_list['moving_objs']
feature = feature_list.feature.add()
feature.float_list.value.extend(np.reshape(image, [-1]).tolist())
feature = feature_list.feature.add()
feature.float_list.value.extend(np.reshape(image2, [-1]).tolist())
return example
def generate_input():
"""Generate tfrecords."""
sprites = _read_images()
sys.stderr.write('Finished reading images.\n')
train_writer = tf.python_io.TFRecordWriter(
tf.flags.FLAGS.out_file.replace('sprites', 'sprites_train'))
test_writer = tf.python_io.TFRecordWriter(
tf.flags.FLAGS.out_file.replace('sprites', 'sprites_test'))
train_examples = []
test_examples = []
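# Split by character id: characters 0-23 form the test set, the rest train.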
for i in sprites:
if int(i) < 24:
examples = test_examples
else:
examples = train_examples
character = sprites[i]
for j in character.keys():
pose = character[j]
for k in xrange(1, len(pose), 1):
image = pose[k]
image2 = pose[k+1]
examples.append(_images_to_example(image, image2))
sys.stderr.write('Finished generating examples: %d, %d.\n' %
(len(train_examples), len(test_examples)))
random.shuffle(train_examples)
_ = [train_writer.write(ex.SerializeToString()) for ex in train_examples]
_ = [test_writer.write(ex.SerializeToString()) for ex in test_examples]
train_writer.close()
test_writer.close()
def main(_):
generate_input()
if __name__ == '__main__':
tf.app.run()
# next_frame_prediction/cross_conv/train.py
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train the cross convolutional model."""
import os
import sys
import numpy as np
import tensorflow as tf
import model as cross_conv_model
import reader
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('master', '', 'Session address.')
tf.flags.DEFINE_string('log_root', '/tmp/moving_obj', 'The root dir of output.')
tf.flags.DEFINE_string('data_filepattern', '',
'training data file pattern.')
tf.flags.DEFINE_integer('image_size', 64, 'Image height and width.')
tf.flags.DEFINE_integer('batch_size', 1, 'Batch size.')
tf.flags.DEFINE_float('norm_scale', 1.0,
'Divide the original image values by this scale.')
tf.flags.DEFINE_float('scale', 10.0,
'Scale the image after norm_scale and move the diff '
'to the positive realm.')
tf.flags.DEFINE_integer('sequence_length', 2, 'tf.SequenceExample length.')
tf.flags.DEFINE_float('learning_rate', 0.8, 'Learning rate.')
tf.flags.DEFINE_bool('l2_loss', True, 'If true, include l2_loss.')
tf.flags.DEFINE_bool('reconstr_loss', False, 'If true, include reconstr_loss.')
tf.flags.DEFINE_bool('kl_loss', True, 'If true, include KL loss.')
slim = tf.contrib.slim
def _Train():
params = dict()
params['batch_size'] = FLAGS.batch_size
params['seq_len'] = FLAGS.sequence_length
params['image_size'] = FLAGS.image_size
params['is_training'] = True
params['norm_scale'] = FLAGS.norm_scale
params['scale'] = FLAGS.scale
params['learning_rate'] = FLAGS.learning_rate
params['l2_loss'] = FLAGS.l2_loss
params['reconstr_loss'] = FLAGS.reconstr_loss
params['kl_loss'] = FLAGS.kl_loss
train_dir = os.path.join(FLAGS.log_root, 'train')
images = reader.ReadInput(FLAGS.data_filepattern, shuffle=True, params=params)
images *= params['scale']
# Increasing the scale makes training much faster.
image_diff_list = reader.SequenceToImageAndDiff(images)
model = cross_conv_model.CrossConvModel(image_diff_list, params)
model.Build()
tf.contrib.tfprof.model_analyzer.print_model_analysis(tf.get_default_graph())
summary_writer = tf.summary.FileWriter(train_dir)
sv = tf.train.Supervisor(logdir=FLAGS.log_root,
summary_op=None,
is_chief=True,
save_model_secs=60,
global_step=model.global_step)
sess = sv.prepare_or_wait_for_session(
FLAGS.master, config=tf.ConfigProto(allow_soft_placement=True))
total_loss = 0.0
step = 0
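# Running sums of the posterior statistics; their averages are saved as the
# empirical z distribution used by eval.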
sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
sample_z_stddev_log = np.zeros(model.z_stddev_log.get_shape().as_list())
sample_step = 0
while True:
_, loss_val, total_steps, summaries, z_mean, z_stddev_log = sess.run(
[model.train_op, model.loss, model.global_step,
model.summary_op,
model.z_mean, model.z_stddev_log])
sample_z_mean += z_mean
sample_z_stddev_log += z_stddev_log
total_loss += loss_val
step += 1
sample_step += 1
if step % 100 == 0:
summary_writer.add_summary(summaries, total_steps)
sys.stderr.write('step: %d, loss: %f\n' %
(total_steps, total_loss / step))
total_loss = 0.0
step = 0
# Sampled z is used for eval.
# It seems 10k is better than 1k. Maybe try 100k next?
if sample_step % 10000 == 0:
with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy'), 'w') as f:
np.save(f, sample_z_mean / sample_step)
with tf.gfile.Open(
os.path.join(FLAGS.log_root, 'z_stddev_log.npy'), 'w') as f:
np.save(f, sample_z_stddev_log / sample_step)
sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
sample_z_stddev_log = np.zeros(
model.z_stddev_log.get_shape().as_list())
sample_step = 0
def main(_):
_Train()
if __name__ == '__main__':
tf.app.run()