Commit 748eceae authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub

Merge branch 'master' into cifar10_experiment

parents 40e906d2 ed65b632
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the various loss functions in use by the PTN model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def add_rotator_image_loss(inputs, outputs, step_size, weight_scale):
"""Computes the image loss of deep rotator model.
Args:
inputs: Input dictionary to the model containing keys
such as `images_k'.
outputs: Output dictionary returned by the model containing keys
such as `images_k'.
step_size: A scalar representing the number of recurrent
steps (number of repeated out-of-plane rotations)
in the deep rotator network (int).
weight_scale: A reweighting factor applied over the image loss (float).
Returns:
A scalar `Tensor' holding the averaged L2 loss
(divided by batch_size and step_size) between the
ground-truth images (RGB) and predicted images (tf.float32).
"""
batch_size = tf.shape(inputs['images_0'])[0]
image_loss = 0
for k in range(1, step_size + 1):
image_loss += tf.nn.l2_loss(
inputs['images_%d' % k] - outputs['images_%d' % k])
image_loss /= tf.to_float(step_size * batch_size)
slim.summaries.add_scalar_summary(
image_loss, 'image_loss', prefix='losses')
image_loss *= weight_scale
return image_loss
def add_rotator_mask_loss(inputs, outputs, step_size, weight_scale):
"""Computes the mask loss of deep rotator model.
Args:
inputs: Input dictionary to the model containing keys
such as `masks_k'.
outputs: Output dictionary returned by the model containing
keys such as `masks_k'.
step_size: A scalar representing the number of recurrent
steps (number of repeated out-of-plane rotations)
in the deep rotator network (int).
weight_scale: A reweighting factor applied over the mask loss (float).
Returns:
A scalar `Tensor' holding the averaged L2 loss
(divided by batch_size and step_size) between the ground-truth masks
(object silhouettes) and predicted masks (tf.float32).
"""
batch_size = tf.shape(inputs['images_0'])[0]
mask_loss = 0
for k in range(1, step_size + 1):
mask_loss += tf.nn.l2_loss(
inputs['masks_%d' % k] - outputs['masks_%d' % k])
mask_loss /= tf.to_float(step_size * batch_size)
slim.summaries.add_scalar_summary(
mask_loss, 'mask_loss', prefix='losses')
mask_loss *= weight_scale
return mask_loss
def add_volume_proj_loss(inputs, outputs, num_views, weight_scale):
"""Computes the projection loss of voxel generation model.
Args:
inputs: Input dictionary to the model containing keys such as
`images_1'.
outputs: Output dictionary returned by the model containing keys
such as `masks_k' and `projs_k'.
num_views: An integer scalar representing the total number of
viewpoints for each object (int).
weight_scale: A reweighting factor applied over the projection loss (float).
Returns:
A scalar `Tensor' holding the averaged L2 loss
(divided by batch_size and num_views) between the ground-truth
masks (object silhouettes) and predicted masks (tf.float32).
"""
batch_size = tf.shape(inputs['images_1'])[0]
proj_loss = 0
for k in range(num_views):
proj_loss += tf.nn.l2_loss(
outputs['masks_%d' % (k + 1)] - outputs['projs_%d' % (k + 1)])
proj_loss /= tf.to_float(num_views * batch_size)
slim.summaries.add_scalar_summary(
proj_loss, 'proj_loss', prefix='losses')
proj_loss *= weight_scale
return proj_loss
def add_volume_loss(inputs, outputs, num_views, weight_scale):
"""Computes the volume loss of voxel generation model.
Args:
inputs: Input dictionary to the model containing keys such as
`images_1' and `voxels'.
outputs: Output dictionary returned by the model containing keys
such as `voxels_k'.
num_views: A scalar representing the total number of
viewpoints for each object (int).
weight_scale: A reweighting factor applied over the volume
loss (tf.float32).
Returns:
A scalar `Tensor' holding the averaged L2 loss
(divided by batch_size and num_views) between the ground-truth
volumes and predicted volumes (tf.float32).
"""
batch_size = tf.shape(inputs['images_1'])[0]
vol_loss = 0
for k in range(num_views):
vol_loss += tf.nn.l2_loss(
inputs['voxels'] - outputs['voxels_%d' % (k + 1)])
vol_loss /= tf.to_float(num_views * batch_size)
slim.summaries.add_scalar_summary(
vol_loss, 'vol_loss', prefix='losses')
vol_loss *= weight_scale
return vol_loss
def regularization_loss(scopes, params):
"""Computes the weight decay as regularization during training.
Args:
scopes: A list of different components of the model such as
``encoder'', ``decoder'' and ``projector''.
params: Parameters of the model.
Returns:
Regularization loss (tf.float32).
"""
reg_loss = tf.zeros(dtype=tf.float32, shape=[])
if params.weight_decay > 0:
is_trainable = lambda x: x in tf.trainable_variables()
is_weights = lambda x: 'weights' in x.name
for scope in scopes:
scope_vars = filter(is_trainable,
tf.contrib.framework.get_model_variables(scope))
scope_vars = filter(is_weights, scope_vars)
if scope_vars:
reg_loss += tf.add_n([tf.nn.l2_loss(var) for var in scope_vars])
slim.summaries.add_scalar_summary(
reg_loss, 'reg_loss', prefix='losses')
reg_loss *= params.weight_decay
return reg_loss
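# Editor's sketch (not part of the original file): tf.nn.l2_loss(t) computes
# sum(t ** 2) / 2, so each loss above is the sum of squared errors scaled by
# 1 / (2 * batch_size * step_size) (or num_views) before reweighting. A small
# NumPy check of that identity, with illustrative shapes:
def _example_l2_loss_normalization():
  import numpy as np
  diff = np.random.RandomState(0).rand(4, 64, 64, 3)  # stand-in for I - I_hat
  l2 = (diff ** 2).sum() / 2.0  # NumPy analogue of tf.nn.l2_loss(diff)
  batch_size, step_size = 4, 1  # hypothetical sizes
  return l2 / float(step_size * batch_size)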
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides metrics used by PTN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def add_image_pred_metrics(
inputs, outputs, num_views, upscale_factor):
"""Computes the image prediction metrics.
Args:
inputs: Input dictionary of the deep rotator model (model_rotator.py).
outputs: Output dictionary of the deep rotator model (model_rotator.py).
num_views: An integer scalar representing the total number
of different viewpoints for each object in the dataset.
upscale_factor: A float scalar representing the number of pixels
per image (num_channels x image_height x image_width).
Returns:
names_to_values: A dictionary representing the current value
of the metric.
names_to_updates: A dictionary representing the operation
that accumulates the error from a batch of data.
"""
names_to_values = dict()
names_to_updates = dict()
for k in xrange(num_views):
tmp_value, tmp_update = tf.contrib.metrics.streaming_mean_squared_error(
outputs['images_%d' % (k + 1)], inputs['images_%d' % (k + 1)])
name = 'image_pred/rnn_%d' % (k + 1)
names_to_values.update({name: tmp_value * upscale_factor})
names_to_updates.update({name: tmp_update})
return names_to_values, names_to_updates
def add_mask_pred_metrics(
inputs, outputs, num_views, upscale_factor):
"""Computes the mask prediction metrics.
Args:
inputs: Input dictionary of the deep rotator model (model_rotator.py).
outputs: Output dictionary of the deep rotator model (model_rotator.py).
num_views: An integer scalar representing the total number
of different viewpoints for each object in the dataset.
upscale_factor: A float scalar representing the number of pixels
per image (num_channels x image_height x image_width).
Returns:
names_to_values: A dictionary representing the current value
of the metric.
names_to_updates: A dictionary representing the operation
that accumulates the error from a batch of data.
"""
names_to_values = dict()
names_to_updates = dict()
for k in xrange(num_views):
tmp_value, tmp_update = tf.contrib.metrics.streaming_mean_squared_error(
outputs['masks_%d' % (k + 1)], inputs['masks_%d' % (k + 1)])
name = 'mask_pred/rnn_%d' % (k + 1)
names_to_values.update({name: tmp_value * upscale_factor})
names_to_updates.update({name: tmp_update})
return names_to_values, names_to_updates
def add_volume_iou_metrics(inputs, outputs):
"""Computes the per-instance volume IOU.
Args:
inputs: Input dictionary of the voxel generation model.
outputs: Output dictionary returned by the voxel generation model.
Returns:
names_to_values: metrics->values (dict).
names_to_updates: metrics->ops (dict).
"""
names_to_values = dict()
names_to_updates = dict()
# Binarize both volumes at 0.5 and relabel them with disjoint codes so that
# only occupied cells can match: the ground truth becomes
# {occupied: 1, empty: 2} and the prediction {occupied: 1, empty: 3}.
# Only class 1 can then intersect, and scaling the resulting mean IoU by
# 3.0 (below) recovers the foreground volume IoU.
labels = tf.greater_equal(inputs['voxels'], 0.5)
predictions = tf.greater_equal(outputs['voxels_1'], 0.5)
labels = 2 - tf.to_int32(labels)
predictions = 3 - tf.to_int32(predictions) * 2
tmp_values, tmp_updates = tf.metrics.mean_iou(
labels=labels,
predictions=predictions,
num_classes=3)
names_to_values['volume_iou'] = tmp_values * 3.0
names_to_updates['volume_iou'] = tmp_updates
return names_to_values, names_to_updates
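# Editor's sketch (not part of the original file): a dense NumPy reference for
# the quantity tracked by add_volume_iou_metrics. The names and the 0.5
# threshold mirror the code above, but the helper itself is illustrative.
def _example_foreground_iou(gt_voxels, pred_voxels, threshold=0.5):
  import numpy as np
  gt = np.asarray(gt_voxels) >= threshold
  pred = np.asarray(pred_voxels) >= threshold
  union = np.logical_or(gt, pred).sum()
  return np.logical_and(gt, pred).sum() / float(union) if union else 1.0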
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementations for Im2Vox PTN (NIPS16) model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import losses
import metrics
import model_voxel_generation
import utils
from nets import im2vox_factory
slim = tf.contrib.slim
class model_PTN(model_voxel_generation.Im2Vox): # pylint:disable=invalid-name
"""Inherits the generic Im2Vox model class and implements the functions."""
def __init__(self, params):
super(model_PTN, self).__init__(params)
# For testing, this selects all views in input
def preprocess_with_all_views(self, raw_inputs):
(quantity, num_views) = raw_inputs['images'].get_shape().as_list()[:2]
inputs = dict()
inputs['voxels'] = []
inputs['images_1'] = []
for k in xrange(num_views):
inputs['matrix_%d' % (k + 1)] = []
inputs['matrix_1'] = []
for n in xrange(quantity):
for k in xrange(num_views):
inputs['images_1'].append(raw_inputs['images'][n, k, :, :, :])
inputs['voxels'].append(raw_inputs['voxels'][n, :, :, :, :])
tf_matrix = self.get_transform_matrix(k)
inputs['matrix_%d' % (k + 1)].append(tf_matrix)
inputs['images_1'] = tf.stack(inputs['images_1'])
inputs['voxels'] = tf.stack(inputs['voxels'])
for k in xrange(num_views):
inputs['matrix_%d' % (k + 1)] = tf.stack(inputs['matrix_%d' % (k + 1)])
return inputs
def get_model_fn(self, is_training=True, reuse=False, run_projection=True):
return im2vox_factory.get(self._params, is_training, reuse, run_projection)
def get_regularization_loss(self, scopes):
return losses.regularization_loss(scopes, self._params)
def get_loss(self, inputs, outputs):
"""Computes the loss used for PTN paper (projection + volume loss)."""
g_loss = tf.zeros(dtype=tf.float32, shape=[])
if self._params.proj_weight:
g_loss += losses.add_volume_proj_loss(
inputs, outputs, self._params.step_size, self._params.proj_weight)
if self._params.volume_weight:
g_loss += losses.add_volume_loss(inputs, outputs, 1,
self._params.volume_weight)
slim.summaries.add_scalar_summary(g_loss, 'im2vox_loss', prefix='losses')
return g_loss
def get_metrics(self, inputs, outputs):
"""Aggregate the metrics for voxel generation model.
Args:
inputs: Input dictionary of the voxel generation model.
outputs: Output dictionary returned by the voxel generation model.
Returns:
names_to_values: metrics->values (dict).
names_to_updates: metrics->ops (dict).
"""
names_to_values = dict()
names_to_updates = dict()
tmp_values, tmp_updates = metrics.add_volume_iou_metrics(inputs, outputs)
names_to_values.update(tmp_values)
names_to_updates.update(tmp_updates)
for name, value in names_to_values.iteritems():
slim.summaries.add_scalar_summary(
value, name, prefix='eval', print_summary=True)
return names_to_values, names_to_updates
def write_disk_grid(self,
global_step,
log_dir,
input_images,
gt_projs,
pred_projs,
input_voxels=None,
output_voxels=None):
"""Function called by TF to save the prediction periodically."""
summary_freq = self._params.save_every
def write_grid(input_images, gt_projs, pred_projs, global_step,
input_voxels, output_voxels):
"""Native python function to call for writing images to files."""
grid = _build_image_grid(
input_images,
gt_projs,
pred_projs,
input_voxels=input_voxels,
output_voxels=output_voxels)
if global_step % summary_freq == 0:
img_path = os.path.join(log_dir, '%s.jpg' % str(global_step))
utils.save_image(grid, img_path)
return grid
save_op = tf.py_func(write_grid, [
input_images, gt_projs, pred_projs, global_step, input_voxels,
output_voxels
], [tf.uint8], 'write_grid')[0]
slim.summaries.add_image_summary(
tf.expand_dims(save_op, axis=0), name='grid_vis')
return save_op
def get_transform_matrix(self, view_out):
"""Get the 4x4 Perspective Transfromation matrix used for PTN."""
num_views = self._params.num_views
focal_length = self._params.focal_length
focal_range = self._params.focal_range
phi = 30
theta_interval = 360.0 / num_views
theta = theta_interval * view_out
# pylint: disable=invalid-name
camera_matrix = np.zeros((4, 4), dtype=np.float32)
intrinsic_matrix = np.eye(4, dtype=np.float32)
extrinsic_matrix = np.eye(4, dtype=np.float32)
sin_phi = np.sin(float(phi) / 180.0 * np.pi)
cos_phi = np.cos(float(phi) / 180.0 * np.pi)
sin_theta = np.sin(float(-theta) / 180.0 * np.pi)
cos_theta = np.cos(float(-theta) / 180.0 * np.pi)
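## rotation axis -- y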
rotation_azimuth = np.zeros((3, 3), dtype=np.float32)
rotation_azimuth[0, 0] = cos_theta
rotation_azimuth[2, 2] = cos_theta
rotation_azimuth[0, 2] = -sin_theta
rotation_azimuth[2, 0] = sin_theta
rotation_azimuth[1, 1] = 1.0
## rotation axis -- x
rotation_elevation = np.zeros((3, 3), dtype=np.float32)
rotation_elevation[0, 0] = cos_phi
rotation_elevation[0, 1] = sin_phi
rotation_elevation[1, 0] = -sin_phi
rotation_elevation[1, 1] = cos_phi
rotation_elevation[2, 2] = 1.0
rotation_matrix = np.matmul(rotation_azimuth, rotation_elevation)
displacement = np.zeros((3, 1), dtype=np.float32)
displacement[0, 0] = float(focal_length) + float(focal_range) / 2.0
displacement = np.matmul(rotation_matrix, displacement)
extrinsic_matrix[0:3, 0:3] = rotation_matrix
extrinsic_matrix[0:3, 3:4] = -displacement
intrinsic_matrix[2, 2] = 1.0 / float(focal_length)
intrinsic_matrix[1, 1] = 1.0 / float(focal_length)
camera_matrix = np.matmul(extrinsic_matrix, intrinsic_matrix)
return camera_matrix
def _build_image_grid(input_images,
gt_projs,
pred_projs,
input_voxels,
output_voxels,
vis_size=128):
"""Builds a grid image by concatenating the input images."""
quantity = input_images.shape[0]
for row in xrange(int(quantity / 3)):
for col in xrange(3):
index = row * 3 + col
input_img_ = utils.resize_image(input_images[index, :, :, :], vis_size,
vis_size)
gt_proj_ = utils.resize_image(gt_projs[index, :, :, :], vis_size,
vis_size)
pred_proj_ = utils.resize_image(pred_projs[index, :, :, :], vis_size,
vis_size)
gt_voxel_vis = utils.resize_image(
utils.display_voxel(input_voxels[index, :, :, :, 0]), vis_size,
vis_size)
pred_voxel_vis = utils.resize_image(
utils.display_voxel(output_voxels[index, :, :, :, 0]), vis_size,
vis_size)
if col == 0:
tmp_ = np.concatenate(
[input_img_, gt_proj_, pred_proj_, gt_voxel_vis, pred_voxel_vis], 1)
else:
tmp_ = np.concatenate([
tmp_, input_img_, gt_proj_, pred_proj_, gt_voxel_vis, pred_voxel_vis
], 1)
if row == 0:
out_grid = tmp_
else:
out_grid = np.concatenate([out_grid, tmp_], 0)
return out_grid
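# Editor's sketch (not part of the original file): the extrinsic block built
# in get_transform_matrix is a product of an azimuth and an elevation
# rotation, so its 3x3 rotation sub-matrix should satisfy R * R^T = I.
# A NumPy property check with illustrative angles (degrees):
def _example_rotation_is_orthonormal(theta=15.0, phi=30.0):
  sin_t, cos_t = np.sin(np.deg2rad(-theta)), np.cos(np.deg2rad(-theta))
  sin_p, cos_p = np.sin(np.deg2rad(phi)), np.cos(np.deg2rad(phi))
  rot_azimuth = np.array(
      [[cos_t, 0, -sin_t], [0, 1, 0], [sin_t, 0, cos_t]], np.float32)
  rot_elevation = np.array(
      [[cos_p, sin_p, 0], [-sin_p, cos_p, 0], [0, 0, 1]], np.float32)
  rotation = np.matmul(rot_azimuth, rot_elevation)
  return np.allclose(np.matmul(rotation, rotation.T), np.eye(3), atol=1e-6)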
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for pretraining (rotator) as described in PTN paper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import input_generator
import losses
import metrics
import utils
from nets import deeprotator_factory
slim = tf.contrib.slim
def _get_data_from_provider(inputs, batch_size, split_name):
"""Returns dictionary of batch input data processed by tf.train.batch."""
images, masks = tf.train.batch(
[inputs['image'], inputs['mask']],
batch_size=batch_size,
num_threads=8,
capacity=8 * batch_size,
name='batching_queues/%s' % (split_name))
outputs = dict()
outputs['images'] = images
outputs['masks'] = masks
outputs['num_samples'] = inputs['num_samples']
return outputs
def get_inputs(dataset_dir, dataset_name, split_name, batch_size, image_size,
is_training):
"""Loads the given dataset and split."""
del image_size # Unused
with tf.variable_scope('data_loading_%s/%s' % (dataset_name, split_name)):
common_queue_min = 50
common_queue_capacity = 256
num_readers = 4
inputs = input_generator.get(
dataset_dir,
dataset_name,
split_name,
shuffle=is_training,
num_readers=num_readers,
common_queue_min=common_queue_min,
common_queue_capacity=common_queue_capacity)
return _get_data_from_provider(inputs, batch_size, split_name)
def preprocess(raw_inputs, step_size):
"""Selects the subset of viewpoints to train on."""
shp = raw_inputs['images'].get_shape().as_list()
quantity = shp[0]
num_views = shp[1]
image_size = shp[2]
del image_size # Unused
batch_rot = np.zeros((quantity, 3), dtype=np.float32)
inputs = dict()
for n in xrange(step_size + 1):
inputs['images_%d' % n] = []
inputs['masks_%d' % n] = []
for n in xrange(quantity):
view_in = np.random.randint(0, num_views)
rng_rot = np.random.randint(0, 2)
if step_size == 1:
rng_rot = np.random.randint(0, 3)
delta = 0
if rng_rot == 0:
delta = -1
batch_rot[n, 2] = 1
elif rng_rot == 1:
delta = 1
batch_rot[n, 0] = 1
else:
delta = 0
batch_rot[n, 1] = 1
inputs['images_0'].append(raw_inputs['images'][n, view_in, :, :, :])
inputs['masks_0'].append(raw_inputs['masks'][n, view_in, :, :, :])
view_out = view_in
for k in xrange(1, step_size + 1):
view_out += delta
if view_out >= num_views:
view_out = 0
if view_out < 0:
view_out = num_views - 1
inputs['images_%d' % k].append(raw_inputs['images'][n, view_out, :, :, :])
inputs['masks_%d' % k].append(raw_inputs['masks'][n, view_out, :, :, :])
for n in xrange(step_size + 1):
inputs['images_%d' % n] = tf.stack(inputs['images_%d' % n])
inputs['masks_%d' % n] = tf.stack(inputs['masks_%d' % n])
inputs['actions'] = tf.constant(batch_rot, dtype=tf.float32)
return inputs
def get_init_fn(scopes, params):
"""Initialization assignment operator function used while training."""
if not params.init_model:
return None
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
params.init_model, var_list)
def init_assign_function(sess):
sess.run(init_assign_op, init_feed_dict)
return init_assign_function
def get_model_fn(params, is_training, reuse=False):
return deeprotator_factory.get(params, is_training, reuse)
def get_regularization_loss(scopes, params):
return losses.regularization_loss(scopes, params)
def get_loss(inputs, outputs, params):
"""Computes the rotator loss."""
g_loss = tf.zeros(dtype=tf.float32, shape=[])
if hasattr(params, 'image_weight'):
g_loss += losses.add_rotator_image_loss(inputs, outputs, params.step_size,
params.image_weight)
if hasattr(params, 'mask_weight'):
g_loss += losses.add_rotator_mask_loss(inputs, outputs, params.step_size,
params.mask_weight)
slim.summaries.add_scalar_summary(
g_loss, 'rotator_loss', prefix='losses')
return g_loss
def get_train_op_for_scope(loss, optimizer, scopes, params):
"""Train operation function for the given scope used file training."""
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
update_ops = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
return slim.learning.create_train_op(
loss,
optimizer,
update_ops=update_ops,
variables_to_train=var_list,
clip_gradient_norm=params.clip_gradient_norm)
def get_metrics(inputs, outputs, params):
names_to_values, names_to_updates = metrics.rotator_metrics(
inputs, outputs, params)
return names_to_values, names_to_updates
def write_disk_grid(global_step, summary_freq, log_dir, input_images,
output_images, pred_images, pred_masks):
"""Function called by TF to save the prediction periodically."""
def write_grid(grid, global_step):
"""Native python function to call for writing images to files."""
if global_step % summary_freq == 0:
img_path = os.path.join(log_dir, '%s.jpg' % str(global_step))
utils.save_image(grid, img_path)
return 0
grid = _build_image_grid(input_images, output_images, pred_images, pred_masks)
slim.summaries.add_image_summary(
tf.expand_dims(grid, axis=0), name='grid_vis')
save_op = tf.py_func(write_grid, [grid, global_step], [tf.int64],
'write_grid')[0]
return save_op
def _build_image_grid(input_images, output_images, pred_images, pred_masks):
"""Builds a grid image by concatenating the input images."""
quantity = input_images.get_shape().as_list()[0]
for row in xrange(int(quantity / 4)):
for col in xrange(4):
index = row * 4 + col
input_img_ = input_images[index, :, :, :]
output_img_ = output_images[index, :, :, :]
pred_img_ = pred_images[index, :, :, :]
pred_mask_ = tf.tile(pred_masks[index, :, :, :], [1, 1, 3])
if col == 0:
tmp_ = tf.concat([input_img_, output_img_, pred_img_, pred_mask_],
1) ## to the right
else:
tmp_ = tf.concat([tmp_, input_img_, output_img_, pred_img_, pred_mask_],
1)
if row == 0:
out_grid = tmp_
else:
out_grid = tf.concat([out_grid, tmp_], 0)
return out_grid
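# Editor's sketch (not part of the original file): the action encoding used in
# preprocess above, isolated for clarity. rng_rot == 0 steps the viewpoint by
# -1 (one-hot index 2), rng_rot == 1 steps it by +1 (index 0), and anything
# else is a no-op (index 1); the one-hot vector becomes inputs['actions'].
def _example_action_encoding(rng_rot):
  action = np.zeros(3, dtype=np.float32)
  if rng_rot == 0:
    delta, action[2] = -1, 1
  elif rng_rot == 1:
    delta, action[0] = 1, 1
  else:
    delta, action[1] = 0, 1
  return delta, action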
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for voxel generation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import os
import numpy as np
import tensorflow as tf
import input_generator
import utils
slim = tf.contrib.slim
class Im2Vox(object):
"""Defines the voxel generation model."""
__metaclass__ = abc.ABCMeta
def __init__(self, params):
self._params = params
@abc.abstractmethod
def get_metrics(self, inputs, outputs):
"""Gets dictionaries from metrics to value `Tensors` & update `Tensors`."""
pass
@abc.abstractmethod
def get_loss(self, inputs, outputs):
pass
@abc.abstractmethod
def get_regularization_loss(self, scopes):
pass
def set_params(self, params):
self._params = params
def get_inputs(self,
dataset_dir,
dataset_name,
split_name,
batch_size,
image_size,
vox_size,
is_training=True):
"""Loads data for a specified dataset and split."""
del image_size, vox_size
with tf.variable_scope('data_loading_%s/%s' % (dataset_name, split_name)):
common_queue_min = 64
common_queue_capacity = 256
num_readers = 4
inputs = input_generator.get(
dataset_dir,
dataset_name,
split_name,
shuffle=is_training,
num_readers=num_readers,
common_queue_min=common_queue_min,
common_queue_capacity=common_queue_capacity)
images, voxels = tf.train.batch(
[inputs['image'], inputs['voxel']],
batch_size=batch_size,
num_threads=8,
capacity=8 * batch_size,
name='batching_queues/%s/%s' % (dataset_name, split_name))
outputs = dict()
outputs['images'] = images
outputs['voxels'] = voxels
outputs['num_samples'] = inputs['num_samples']
return outputs
def preprocess(self, raw_inputs, step_size):
"""Selects the subset of viewpoints to train on."""
(quantity, num_views) = raw_inputs['images'].get_shape().as_list()[:2]
inputs = dict()
inputs['voxels'] = raw_inputs['voxels']
for k in xrange(step_size):
inputs['images_%d' % (k + 1)] = []
inputs['matrix_%d' % (k + 1)] = []
for n in xrange(quantity):
selected_views = np.random.choice(num_views, step_size, replace=False)
for k in xrange(step_size):
view_selected = selected_views[k]
inputs['images_%d' %
(k + 1)].append(raw_inputs['images'][n, view_selected, :, :, :])
tf_matrix = self.get_transform_matrix(view_selected)
inputs['matrix_%d' % (k + 1)].append(tf_matrix)
for k in xrange(step_size):
inputs['images_%d' % (k + 1)] = tf.stack(inputs['images_%d' % (k + 1)])
inputs['matrix_%d' % (k + 1)] = tf.stack(inputs['matrix_%d' % (k + 1)])
return inputs
def get_init_fn(self, scopes):
"""Initialization assignment operator function used while training."""
if not self._params.init_model:
return None
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
self._params.init_model, var_list)
def init_assign_function(sess):
sess.run(init_assign_op, init_feed_dict)
return init_assign_function
def get_train_op_for_scope(self, loss, optimizer, scopes):
"""Train operation function for the given scope used file training."""
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
update_ops = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
return slim.learning.create_train_op(
loss,
optimizer,
update_ops=update_ops,
variables_to_train=var_list,
clip_gradient_norm=self._params.clip_gradient_norm)
def write_disk_grid(self,
global_step,
log_dir,
input_images,
gt_projs,
pred_projs,
pred_voxels=None):
"""Function called by TF to save the prediction periodically."""
summary_freq = self._params.save_every
def write_grid(input_images, gt_projs, pred_projs, pred_voxels,
global_step):
"""Native python function to call for writing images to files."""
grid = _build_image_grid(input_images, gt_projs, pred_projs, pred_voxels)
if global_step % summary_freq == 0:
img_path = os.path.join(log_dir, '%s.jpg' % str(global_step))
utils.save_image(grid, img_path)
with open(
os.path.join(log_dir, 'pred_voxels_%s' % str(global_step)),
'w') as fout:
np.save(fout, pred_voxels)
with open(
os.path.join(log_dir, 'input_images_%s' % str(global_step)),
'w') as fout:
np.save(fout, input_images)
return grid
py_func_args = [
input_images, gt_projs, pred_projs, pred_voxels, global_step
]
save_grid_op = tf.py_func(write_grid, py_func_args, [tf.uint8],
'write_grid')[0]
slim.summaries.add_image_summary(
tf.expand_dims(save_grid_op, axis=0), name='grid_vis')
return save_grid_op
def _build_image_grid(input_images, gt_projs, pred_projs, pred_voxels):
"""Build the visualization grid with py_func."""
quantity, img_height, img_width = input_images.shape[:3]
for row in xrange(int(quantity / 3)):
for col in xrange(3):
index = row * 3 + col
input_img_ = input_images[index, :, :, :]
gt_proj_ = gt_projs[index, :, :, :]
pred_proj_ = pred_projs[index, :, :, :]
pred_voxel_ = utils.display_voxel(pred_voxels[index, :, :, :, 0])
pred_voxel_ = utils.resize_image(pred_voxel_, img_height, img_width)
if col == 0:
tmp_ = np.concatenate([input_img_, gt_proj_, pred_proj_, pred_voxel_],
1)
else:
tmp_ = np.concatenate(
[tmp_, input_img_, gt_proj_, pred_proj_, pred_voxel_], 1)
if row == 0:
out_grid = tmp_
else:
out_grid = np.concatenate([out_grid, tmp_], 0)
out_grid = out_grid.astype(np.uint8)
return out_grid
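# Editor's sketch (not part of the original file): Im2Vox is abstract, so a
# concrete model only needs to supply the three abstract methods;
# model_ptn.model_PTN is the real implementation. A minimal, hypothetical
# skeleton:
class _ExampleIm2Vox(Im2Vox):

  def get_metrics(self, inputs, outputs):
    return dict(), dict()  # names_to_values, names_to_updates

  def get_loss(self, inputs, outputs):
    return tf.zeros(dtype=tf.float32, shape=[])

  def get_regularization_loss(self, scopes):
    return tf.zeros(dtype=tf.float32, shape=[])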
package(default_visibility = ["//visibility:public"])
py_library(
name = "deeprotator_factory",
srcs = ["deeprotator_factory.py"],
deps = [
":ptn_encoder",
":ptn_im_decoder",
":ptn_rotator",
],
)
py_library(
name = "im2vox_factory",
srcs = ["im2vox_factory.py"],
deps = [
":perspective_projector",
":ptn_encoder",
":ptn_vox_decoder",
],
)
py_library(
name = "perspective_projector",
srcs = ["perspective_projector.py"],
deps = [
":perspective_transform",
],
)
py_library(
name = "perspective_transform",
srcs = ["perspective_transform.py"],
deps = [
],
)
py_library(
name = "ptn_encoder",
srcs = ["ptn_encoder.py"],
deps = [
],
)
py_library(
name = "ptn_im_decoder",
srcs = ["ptn_im_decoder.py"],
deps = [
],
)
py_library(
name = "ptn_rotator",
srcs = ["ptn_rotator.py"],
deps = [
],
)
py_library(
name = "ptn_vox_decoder",
srcs = ["ptn_vox_decoder.py"],
deps = [
],
)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory module for different encoder/decoder network models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import ptn_encoder
from nets import ptn_im_decoder
from nets import ptn_rotator
_NAME_TO_NETS = {
'ptn_encoder': ptn_encoder,
'ptn_rotator': ptn_rotator,
'ptn_im_decoder': ptn_im_decoder,
}
def _get_network(name):
"""Gets a single network component."""
if name not in _NAME_TO_NETS:
raise ValueError('Network name [%s] not recognized.' % name)
return _NAME_TO_NETS[name].model
def get(params, is_training=False, reuse=False):
"""Factory function to retrieve a network model.
Args:
params: Different parameters used throughout ptn, typically FLAGS (dict)
is_training: Set to True while training (boolean)
reuse: Set as True if either using a pre-trained model or
in the training loop while the graph has already been built (boolean)
Returns:
Model function for network (inputs to outputs)
"""
def model(inputs):
"""Model function corresponding to a specific network architecture."""
outputs = {}
# First, build the encoder.
encoder_fn = _get_network(params.encoder_name)
with tf.variable_scope('encoder', reuse=reuse):
# Produces id/pose units
features = encoder_fn(inputs['images_0'], params, is_training)
outputs['ids'] = features['ids']
outputs['poses_0'] = features['poses']
# Second, build the rotator and decoder.
rotator_fn = _get_network(params.rotator_name)
with tf.variable_scope('rotator', reuse=reuse):
outputs['poses_1'] = rotator_fn(outputs['poses_0'], inputs['actions'],
params, is_training)
decoder_fn = _get_network(params.decoder_name)
with tf.variable_scope('decoder', reuse=reuse):
dec_output = decoder_fn(outputs['ids'], outputs['poses_1'], params,
is_training)
outputs['images_1'] = dec_output['images']
outputs['masks_1'] = dec_output['masks']
# Third, build the recurrent connection
for k in range(1, params.step_size):
with tf.variable_scope('rotator', reuse=True):
outputs['poses_%d' % (k + 1)] = rotator_fn(
outputs['poses_%d' % k], inputs['actions'], params, is_training)
with tf.variable_scope('decoder', reuse=True):
dec_output = decoder_fn(outputs['ids'], outputs['poses_%d' % (k + 1)],
params, is_training)
outputs['images_%d' % (k + 1)] = dec_output['images']
outputs['masks_%d' % (k + 1)] = dec_output['masks']
return outputs
return model
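# Editor's sketch (not part of the original file): expected call pattern for
# this factory. `params` is assumed to carry the FLAGS from train_rotator.py
# (encoder_name, rotator_name, decoder_name, step_size, ...) and `inputs` the
# dictionary built by model_rotator.preprocess (`images_0`, `actions`).
def _example_build_rotator(params, inputs):
  model_fn = get(params, is_training=True, reuse=False)
  # Returns images_k/masks_k for k = 1..step_size, plus poses and ids.
  return model_fn(inputs)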
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory module for getting the complete image to voxel generation network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import perspective_projector
from nets import ptn_encoder
from nets import ptn_vox_decoder
_NAME_TO_NETS = {
'ptn_encoder': ptn_encoder,
'ptn_vox_decoder': ptn_vox_decoder,
'perspective_projector': perspective_projector,
}
def _get_network(name):
"""Gets a single encoder/decoder network model."""
if name not in _NAME_TO_NETS:
raise ValueError('Network name [%s] not recognized.' % name)
return _NAME_TO_NETS[name].model
def get(params, is_training=False, reuse=False, run_projection=True):
"""Factory function to get the training/pretraining im->vox model (NIPS16).
Args:
params: Different parameters used throughout ptn, typically FLAGS (dict).
is_training: Set to True while training (boolean).
reuse: Set as True if sharing variables with a model that has already
been built (boolean).
run_projection: Set as False if not interested in mask and projection
images. Useful in evaluation routine (boolean).
Returns:
Model function for network (inputs to outputs).
"""
def model(inputs):
"""Model function corresponding to a specific network architecture."""
outputs = {}
# First, build the encoder
encoder_fn = _get_network(params.encoder_name)
with tf.variable_scope('encoder', reuse=reuse):
# Produces id/pose units
enc_outputs = encoder_fn(inputs['images_1'], params, is_training)
outputs['ids_1'] = enc_outputs['ids']
# Second, build the decoder and projector
decoder_fn = _get_network(params.decoder_name)
with tf.variable_scope('decoder', reuse=reuse):
outputs['voxels_1'] = decoder_fn(outputs['ids_1'], params, is_training)
if run_projection:
projector_fn = _get_network(params.projector_name)
with tf.variable_scope('projector', reuse=reuse):
outputs['projs_1'] = projector_fn(
outputs['voxels_1'], inputs['matrix_1'], params, is_training)
# Infer the ground-truth mask
with tf.variable_scope('oracle', reuse=reuse):
outputs['masks_1'] = projector_fn(inputs['voxels'], inputs['matrix_1'],
params, False)
# Third, build the entire graph (bundled strategy described in PTN paper)
for k in range(1, params.step_size):
with tf.variable_scope('projector', reuse=True):
outputs['projs_%d' % (k + 1)] = projector_fn(
outputs['voxels_1'], inputs['matrix_%d' %
(k + 1)], params, is_training)
with tf.variable_scope('oracle', reuse=True):
outputs['masks_%d' % (k + 1)] = projector_fn(
inputs['voxels'], inputs['matrix_%d' % (k + 1)], params, False)
return outputs
return model
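# Editor's sketch (not part of the original file): during evaluation only the
# decoded voxels may be needed, in which case the projector and oracle-mask
# branches can be skipped entirely via run_projection=False.
def _example_build_eval_model(params, inputs):
  model_fn = get(params, is_training=False, reuse=False, run_projection=False)
  return model_fn(inputs)  # contains only 'ids_1' and 'voxels_1'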
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""3D->2D projector model as used in PTN (NIPS16)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import perspective_transform
def model(voxels, transform_matrix, params, is_training):
"""Model transforming the 3D voxels into 2D projections.
Args:
voxels: A tensor of size [batch, depth, height, width, channel]
representing the input of projection layer (tf.float32).
transform_matrix: A tensor of size [batch, 16] representing
the flattened 4-by-4 matrix for transformation (tf.float32).
params: Model parameters (dict).
is_training: Set to True while training (boolean).
Returns:
A transformed tensor (tf.float32)
"""
del is_training # Doesn't make a difference for projector
# Rearrangement (batch, z, y, x, channel) --> (batch, y, z, x, channel).
# By the standard, projection happens along z-axis but the voxels
# are stored in a different way. So we need to switch the y and z
# axis for transformation operation.
voxels = tf.transpose(voxels, [0, 2, 1, 3, 4])
z_near = params.focal_length
z_far = params.focal_length + params.focal_range
transformed_voxels = perspective_transform.transformer(
voxels, transform_matrix, [params.vox_size] * 3, z_near, z_far)
views = tf.reduce_max(transformed_voxels, [1])
views = tf.reverse(views, [1])
return views
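# Editor's sketch (not part of the original file): the projection itself is a
# max-reduction along the depth axis, which for a binary volume is exactly a
# silhouette. A NumPy toy with no camera transform (orthographic, axis 0):
def _example_silhouette_by_max_projection():
  import numpy as np
  voxels = np.zeros((8, 8, 8), dtype=np.float32)
  voxels[2:6, 3:5, 3:5] = 1.0  # a small occupied box
  return voxels.max(axis=0)  # analogue of tf.reduce_max(..., [1]); 8x8 mask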
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Perspective Transformer Layer Implementation.
Transform the volume based on 4 x 4 perspective projection matrix.
Reference:
(1) "Perspective Transformer Nets: Perspective Transformer Nets:
Learning Single-View 3D Object Reconstruction without 3D Supervision."
Xinchen Yan, Jimei Yang, Ersin Yumer, Yijie Guo, Honglak Lee. In NIPS 2016
https://papers.nips.cc/paper/6206-perspective-transformer-nets-learning-single-view-3d-object-reconstruction-without-3d-supervision.pdf
(2) Official implementation in Torch: https://github.com/xcyan/ptnbhwd
(3) 2D Transformer implementation in TF:
github.com/tensorflow/models/tree/master/transformer
"""
import tensorflow as tf
def transformer(voxels,
theta,
out_size,
z_near,
z_far,
name='PerspectiveTransformer'):
"""Perspective Transformer Layer.
Args:
voxels: A tensor of size [num_batch, depth, height, width, num_channels].
It is the output of a deconv/upsampling conv network (tf.float32).
theta: A tensor of size [num_batch, 16].
It is the inverse camera transformation matrix (tf.float32).
out_size: A tuple representing the size of the output of the
transformer layer (float).
z_near: A number representing the near clipping plane (float).
z_far: A number representing the far clipping plane (float).
Returns:
A transformed tensor (tf.float32).
"""
def _repeat(x, n_repeats):
with tf.variable_scope('_repeat'):
rep = tf.transpose(
tf.expand_dims(tf.ones(shape=tf.stack([
n_repeats,
])), 1), [1, 0])
rep = tf.to_int32(rep)
x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
return tf.reshape(x, [-1])
def _interpolate(im, x, y, z, out_size):
"""Bilinear interploation layer.
Args:
im: A 5D tensor of size [num_batch, depth, height, width, num_channels].
It is the input volume for the transformation layer (tf.float32).
x: A tensor of size [num_batch, out_depth, out_height, out_width]
representing the inverse coordinate mapping for x (tf.float32).
y: A tensor of size [num_batch, out_depth, out_height, out_width]
representing the inverse coordinate mapping for y (tf.float32).
z: A tensor of size [num_batch, out_depth, out_height, out_width]
representing the inverse coordinate mapping for z (tf.float32).
out_size: A tuple representing the output size of the transformation
layer (float).
Returns:
A transformed tensor (tf.float32).
"""
with tf.variable_scope('_interpolate'):
num_batch = im.get_shape().as_list()[0]
depth = im.get_shape().as_list()[1]
height = im.get_shape().as_list()[2]
width = im.get_shape().as_list()[3]
channels = im.get_shape().as_list()[4]
x = tf.to_float(x)
y = tf.to_float(y)
z = tf.to_float(z)
depth_f = tf.to_float(depth)
height_f = tf.to_float(height)
width_f = tf.to_float(width)
# Dimensions of the interpolated output volume.
out_depth = out_size[0]
out_height = out_size[1]
out_width = out_size[2]
zero = tf.zeros([], dtype='int32')
# 0 <= z < depth, 0 <= y < height & 0 <= x < width.
max_z = tf.to_int32(tf.shape(im)[1] - 1)
max_y = tf.to_int32(tf.shape(im)[2] - 1)
max_x = tf.to_int32(tf.shape(im)[3] - 1)
# Converts scale indices from [-1, 1] to [0, width/height/depth].
x = (x + 1.0) * (width_f) / 2.0
y = (y + 1.0) * (height_f) / 2.0
z = (z + 1.0) * (depth_f) / 2.0
x0 = tf.to_int32(tf.floor(x))
x1 = x0 + 1
y0 = tf.to_int32(tf.floor(y))
y1 = y0 + 1
z0 = tf.to_int32(tf.floor(z))
z1 = z0 + 1
x0_clip = tf.clip_by_value(x0, zero, max_x)
x1_clip = tf.clip_by_value(x1, zero, max_x)
y0_clip = tf.clip_by_value(y0, zero, max_y)
y1_clip = tf.clip_by_value(y1, zero, max_y)
z0_clip = tf.clip_by_value(z0, zero, max_z)
z1_clip = tf.clip_by_value(z1, zero, max_z)
dim3 = width
dim2 = width * height
dim1 = width * height * depth
base = _repeat(
tf.range(num_batch) * dim1, out_depth * out_height * out_width)
base_z0_y0 = base + z0_clip * dim2 + y0_clip * dim3
base_z0_y1 = base + z0_clip * dim2 + y1_clip * dim3
base_z1_y0 = base + z1_clip * dim2 + y0_clip * dim3
base_z1_y1 = base + z1_clip * dim2 + y1_clip * dim3
idx_z0_y0_x0 = base_z0_y0 + x0_clip
idx_z0_y0_x1 = base_z0_y0 + x1_clip
idx_z0_y1_x0 = base_z0_y1 + x0_clip
idx_z0_y1_x1 = base_z0_y1 + x1_clip
idx_z1_y0_x0 = base_z1_y0 + x0_clip
idx_z1_y0_x1 = base_z1_y0 + x1_clip
idx_z1_y1_x0 = base_z1_y1 + x0_clip
idx_z1_y1_x1 = base_z1_y1 + x1_clip
# Use indices to lookup pixels in the flat image and restore
# channels dim
im_flat = tf.reshape(im, tf.stack([-1, channels]))
im_flat = tf.to_float(im_flat)
i_z0_y0_x0 = tf.gather(im_flat, idx_z0_y0_x0)
i_z0_y0_x1 = tf.gather(im_flat, idx_z0_y0_x1)
i_z0_y1_x0 = tf.gather(im_flat, idx_z0_y1_x0)
i_z0_y1_x1 = tf.gather(im_flat, idx_z0_y1_x1)
i_z1_y0_x0 = tf.gather(im_flat, idx_z1_y0_x0)
i_z1_y0_x1 = tf.gather(im_flat, idx_z1_y0_x1)
i_z1_y1_x0 = tf.gather(im_flat, idx_z1_y1_x0)
i_z1_y1_x1 = tf.gather(im_flat, idx_z1_y1_x1)
# Finally calculate interpolated values.
x0_f = tf.to_float(x0)
x1_f = tf.to_float(x1)
y0_f = tf.to_float(y0)
y1_f = tf.to_float(y1)
z0_f = tf.to_float(z0)
z1_f = tf.to_float(z1)
# Check the out-of-boundary case.
x0_valid = tf.to_float(
tf.less_equal(x0, max_x) & tf.greater_equal(x0, 0))
x1_valid = tf.to_float(
tf.less_equal(x1, max_x) & tf.greater_equal(x1, 0))
y0_valid = tf.to_float(
tf.less_equal(y0, max_y) & tf.greater_equal(y0, 0))
y1_valid = tf.to_float(
tf.less_equal(y1, max_y) & tf.greater_equal(y1, 0))
z0_valid = tf.to_float(
tf.less_equal(z0, max_z) & tf.greater_equal(z0, 0))
z1_valid = tf.to_float(
tf.less_equal(z1, max_z) & tf.greater_equal(z1, 0))
w_z0_y0_x0 = tf.expand_dims(((x1_f - x) * (y1_f - y) *
(z1_f - z) * x1_valid * y1_valid * z1_valid),
1)
w_z0_y0_x1 = tf.expand_dims(((x - x0_f) * (y1_f - y) *
(z1_f - z) * x0_valid * y1_valid * z1_valid),
1)
w_z0_y1_x0 = tf.expand_dims(((x1_f - x) * (y - y0_f) *
(z1_f - z) * x1_valid * y0_valid * z1_valid),
1)
w_z0_y1_x1 = tf.expand_dims(((x - x0_f) * (y - y0_f) *
(z1_f - z) * x0_valid * y0_valid * z1_valid),
1)
w_z1_y0_x0 = tf.expand_dims(((x1_f - x) * (y1_f - y) *
(z - z0_f) * x1_valid * y1_valid * z0_valid),
1)
w_z1_y0_x1 = tf.expand_dims(((x - x0_f) * (y1_f - y) *
(z - z0_f) * x0_valid * y1_valid * z0_valid),
1)
w_z1_y1_x0 = tf.expand_dims(((x1_f - x) * (y - y0_f) *
(z - z0_f) * x1_valid * y0_valid * z0_valid),
1)
w_z1_y1_x1 = tf.expand_dims(((x - x0_f) * (y - y0_f) *
(z - z0_f) * x0_valid * y0_valid * z0_valid),
1)
output = tf.add_n([
w_z0_y0_x0 * i_z0_y0_x0, w_z0_y0_x1 * i_z0_y0_x1,
w_z0_y1_x0 * i_z0_y1_x0, w_z0_y1_x1 * i_z0_y1_x1,
w_z1_y0_x0 * i_z1_y0_x0, w_z1_y0_x1 * i_z1_y0_x1,
w_z1_y1_x0 * i_z1_y1_x0, w_z1_y1_x1 * i_z1_y1_x1
])
return output
def _meshgrid(depth, height, width, z_near, z_far):
with tf.variable_scope('_meshgrid'):
x_t = tf.reshape(
tf.tile(tf.linspace(-1.0, 1.0, width), [height * depth]),
[depth, height, width])
y_t = tf.reshape(
tf.tile(tf.linspace(-1.0, 1.0, height), [width * depth]),
[depth, width, height])
y_t = tf.transpose(y_t, [0, 2, 1])
sample_grid = tf.tile(
tf.linspace(float(z_near), float(z_far), depth), [width * height])
z_t = tf.reshape(sample_grid, [height, width, depth])
z_t = tf.transpose(z_t, [2, 0, 1])
z_t = 1 / z_t
d_t = 1 / z_t
x_t /= z_t
y_t /= z_t
x_t_flat = tf.reshape(x_t, (1, -1))
y_t_flat = tf.reshape(y_t, (1, -1))
d_t_flat = tf.reshape(d_t, (1, -1))
ones = tf.ones_like(x_t_flat)
grid = tf.concat([d_t_flat, y_t_flat, x_t_flat, ones], 0)
return grid
def _transform(theta, input_dim, out_size, z_near, z_far):
with tf.variable_scope('_transform'):
num_batch = input_dim.get_shape().as_list()[0]
num_channels = input_dim.get_shape().as_list()[4]
theta = tf.reshape(theta, (-1, 4, 4))
theta = tf.cast(theta, 'float32')
out_depth = out_size[0]
out_height = out_size[1]
out_width = out_size[2]
grid = _meshgrid(out_depth, out_height, out_width, z_near, z_far)
grid = tf.expand_dims(grid, 0)
grid = tf.reshape(grid, [-1])
grid = tf.tile(grid, tf.stack([num_batch]))
grid = tf.reshape(grid, tf.stack([num_batch, 4, -1]))
# Transform A x (x_t', y_t', 1, d_t)^T -> (x_s, y_s, z_s, 1).
t_g = tf.matmul(theta, grid)
z_s = tf.slice(t_g, [0, 0, 0], [-1, 1, -1])
y_s = tf.slice(t_g, [0, 1, 0], [-1, 1, -1])
x_s = tf.slice(t_g, [0, 2, 0], [-1, 1, -1])
z_s_flat = tf.reshape(z_s, [-1])
y_s_flat = tf.reshape(y_s, [-1])
x_s_flat = tf.reshape(x_s, [-1])
input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, z_s_flat,
out_size)
output = tf.reshape(
input_transformed,
tf.stack([num_batch, out_depth, out_height, out_width, num_channels]))
return output
with tf.variable_scope(name):
output = _transform(theta, voxels, out_size, z_near, z_far)
return output
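# Editor's sketch (not part of the original file): a single-point NumPy
# reference for the eight corner weights computed in _interpolate above
# (boundary masking omitted). Each weight is the product of distances to the
# opposite corner, and since x1 - x0 = y1 - y0 = z1 - z0 = 1 they sum to 1.
def _example_trilinear_weights(x, y, z):
  import numpy as np
  x0, y0, z0 = (int(np.floor(v)) for v in (x, y, z))
  x1, y1, z1 = x0 + 1, y0 + 1, z0 + 1
  weights = {
      (z0, y0, x0): (x1 - x) * (y1 - y) * (z1 - z),
      (z0, y0, x1): (x - x0) * (y1 - y) * (z1 - z),
      (z0, y1, x0): (x1 - x) * (y - y0) * (z1 - z),
      (z0, y1, x1): (x - x0) * (y - y0) * (z1 - z),
      (z1, y0, x0): (x1 - x) * (y1 - y) * (z - z0),
      (z1, y0, x1): (x - x0) * (y1 - y) * (z - z0),
      (z1, y1, x0): (x1 - x) * (y - y0) * (z - z0),
      (z1, y1, x1): (x - x0) * (y - y0) * (z - z0),
  }
  assert abs(sum(weights.values()) - 1.0) < 1e-6
  return weights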
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training/Pretraining encoder as used in PTN (NIPS16)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def _preprocess(images):
return images * 2 - 1
def model(images, params, is_training):
"""Model encoding the images into view-invariant embedding."""
del is_training # Unused
image_size = images.get_shape().as_list()[1]
f_dim = params.f_dim
fc_dim = params.fc_dim
z_dim = params.z_dim
outputs = dict()
images = _preprocess(images)
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_initializer=tf.truncated_normal_initializer(stddev=0.02, seed=1)):
h0 = slim.conv2d(images, f_dim, [5, 5], stride=2, activation_fn=tf.nn.relu)
h1 = slim.conv2d(h0, f_dim * 2, [5, 5], stride=2, activation_fn=tf.nn.relu)
h2 = slim.conv2d(h1, f_dim * 4, [5, 5], stride=2, activation_fn=tf.nn.relu)
# Reshape layer
s8 = image_size // 8
h2 = tf.reshape(h2, [-1, s8 * s8 * f_dim * 4])
h3 = slim.fully_connected(h2, fc_dim, activation_fn=tf.nn.relu)
h4 = slim.fully_connected(h3, fc_dim, activation_fn=tf.nn.relu)
outputs['ids'] = slim.fully_connected(h4, z_dim, activation_fn=tf.nn.relu)
outputs['poses'] = slim.fully_connected(h4, z_dim, activation_fn=tf.nn.relu)
return outputs
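# Editor's note (not part of the original file): with the train_rotator.py
# defaults (image_size=64, f_dim=64), the three stride-2 convolutions above
# reduce 64x64x3 inputs to 32x32x64, 16x16x128 and 8x8x256 feature maps, so
# the reshape flattens 8 * 8 * 256 = 16384 features before the FC layers.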
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image/Mask decoder used while pretraining the network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
_FEATURE_MAP_SIZE = 8
def _postprocess_im(images):
"""Performs post-processing for the images returned from conv net.
Transforms the value from [-1, 1] to [0, 1].
"""
return (images + 1) * 0.5
def model(identities, poses, params, is_training):
"""Decoder model to get image and mask from latent embedding."""
del is_training
f_dim = params.f_dim
fc_dim = params.fc_dim
outputs = dict()
with slim.arg_scope(
[slim.fully_connected, slim.conv2d_transpose],
weights_initializer=tf.truncated_normal_initializer(stddev=0.02, seed=1)):
# Concatenate the identity and pose units
h0 = tf.concat([identities, poses], 1)
h0 = slim.fully_connected(h0, fc_dim, activation_fn=tf.nn.relu)
h1 = slim.fully_connected(h0, fc_dim, activation_fn=tf.nn.relu)
# Mask decoder
dec_m0 = slim.fully_connected(
h1, (_FEATURE_MAP_SIZE**2) * f_dim * 2, activation_fn=tf.nn.relu)
dec_m0 = tf.reshape(
dec_m0, [-1, _FEATURE_MAP_SIZE, _FEATURE_MAP_SIZE, f_dim * 2])
dec_m1 = slim.conv2d_transpose(
dec_m0, f_dim, [5, 5], stride=2, activation_fn=tf.nn.relu)
dec_m2 = slim.conv2d_transpose(
dec_m1, int(f_dim / 2), [5, 5], stride=2, activation_fn=tf.nn.relu)
dec_m3 = slim.conv2d_transpose(
dec_m2, 1, [5, 5], stride=2, activation_fn=tf.nn.sigmoid)
# Image decoder
dec_i0 = slim.fully_connected(
h1, (_FEATURE_MAP_SIZE**2) * f_dim * 4, activation_fn=tf.nn.relu)
dec_i0 = tf.reshape(
dec_i0, [-1, _FEATURE_MAP_SIZE, _FEATURE_MAP_SIZE, f_dim * 4])
dec_i1 = slim.conv2d_transpose(
dec_i0, f_dim * 2, [5, 5], stride=2, activation_fn=tf.nn.relu)
dec_i2 = slim.conv2d_transpose(
dec_i1, f_dim * 2, [5, 5], stride=2, activation_fn=tf.nn.relu)
dec_i3 = slim.conv2d_transpose(
dec_i2, 3, [5, 5], stride=2, activation_fn=tf.nn.tanh)
outputs = dict()
outputs['images'] = _postprocess_im(dec_i3)
outputs['masks'] = dec_m3
return outputs
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates rotator network model.
This model performs out-of-plane rotations given an input image and an action.
The action is either a no-op, a clockwise rotation, or a counter-clockwise
rotation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def bilinear(input_x, input_y, output_size):
"""Define the bilinear transformation layer."""
shape_x = input_x.get_shape().as_list()
shape_y = input_y.get_shape().as_list()
weights_initializer = tf.truncated_normal_initializer(stddev=0.02,
seed=1)
biases_initializer = tf.constant_initializer(0.0)
matrix = tf.get_variable("Matrix", [shape_x[1], shape_y[1], output_size],
tf.float32, initializer=weights_initializer)
bias = tf.get_variable("Bias", [output_size],
initializer=biases_initializer)
# Add to GraphKeys.MODEL_VARIABLES
tf.contrib.framework.add_model_variable(matrix)
tf.contrib.framework.add_model_variable(bias)
# Define the transformation
h0 = tf.matmul(input_x, tf.reshape(matrix,
[shape_x[1], shape_y[1]*output_size]))
h0 = tf.reshape(h0, [-1, shape_y[1], output_size])
h1 = tf.tile(tf.reshape(input_y, [-1, shape_y[1], 1]),
[1, 1, output_size])
h1 = tf.multiply(h0, h1)
return tf.reduce_sum(h1, 1) + bias
def model(poses, actions, params, is_training):
"""Model for performing rotation."""
del is_training # Unused
return bilinear(poses, actions, params.z_dim)
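# Editor's sketch (not part of the original file): a NumPy reference for the
# reshape/tile/multiply sequence in bilinear above, which computes
# output[b, o] = sum_ij x[b, i] * matrix[i, j, o] * y[b, j] + bias[o].
def _example_bilinear_reference(input_x, input_y, matrix, bias):
  import numpy as np
  return np.einsum('bi,ijo,bj->bo', input_x, matrix, input_y) + bias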
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training decoder as used in PTN (NIPS16)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
@tf.contrib.framework.add_arg_scope
def conv3d_transpose(inputs,
num_outputs,
kernel_size,
stride=1,
padding='SAME',
activation_fn=tf.nn.relu,
weights_initializer=tf.contrib.layers.xavier_initializer(),
biases_initializer=tf.zeros_initializer(),
reuse=None,
trainable=True,
scope=None):
"""Wrapper for conv3d_transpose layer.
This function wraps tf.nn.conv3d_transpose with a basic non-linearity.
It creates a variable called `weights`, representing the kernel,
that is convolved with the input. A second variable called `biases'
is added to the result of the operation.
"""
with tf.variable_scope(
scope, 'Conv3d_transpose', [inputs], reuse=reuse):
dtype = inputs.dtype.base_dtype
kernel_d, kernel_h, kernel_w = kernel_size[0:3]
num_filters_in = inputs.get_shape()[4]
weights_shape = [kernel_d, kernel_h, kernel_w, num_outputs, num_filters_in]
weights = tf.get_variable('weights',
shape=weights_shape,
dtype=dtype,
initializer=weights_initializer,
trainable=trainable)
tf.contrib.framework.add_model_variable(weights)
input_shape = inputs.get_shape().as_list()
batch_size = input_shape[0]
depth = input_shape[1]
height = input_shape[2]
width = input_shape[3]
def get_deconv_dim(dim_size, stride_size):
# Only support padding='SAME'.
if isinstance(dim_size, tf.Tensor):
dim_size = tf.multiply(dim_size, stride_size)
elif dim_size is not None:
dim_size *= stride_size
return dim_size
out_depth = get_deconv_dim(depth, stride)
out_height = get_deconv_dim(height, stride)
out_width = get_deconv_dim(width, stride)
out_shape = [batch_size, out_depth, out_height, out_width, num_outputs]
outputs = tf.nn.conv3d_transpose(inputs, weights, out_shape,
[1, stride, stride, stride, 1],
padding=padding)
outputs.set_shape(out_shape)
if biases_initializer is not None:
biases = tf.get_variable('biases',
shape=[num_outputs,],
dtype=dtype,
initializer=biases_initializer,
trainable=trainable)
tf.contrib.framework.add_model_variable(biases)
outputs = tf.nn.bias_add(outputs, biases)
if activation_fn:
outputs = activation_fn(outputs)
return outputs
def model(identities, params, is_training):
"""Model transforming embedding to voxels."""
del is_training # Unused
f_dim = params.f_dim
# Please refer to the original implementation: github.com/xcyan/nips16_PTN
  # In this TF replication, we use a slightly different architecture.
with slim.arg_scope(
[slim.fully_connected, conv3d_transpose],
weights_initializer=tf.truncated_normal_initializer(stddev=0.02, seed=1)):
h0 = slim.fully_connected(
identities, 4 * 4 * 4 * f_dim * 8, activation_fn=tf.nn.relu)
h1 = tf.reshape(h0, [-1, 4, 4, 4, f_dim * 8])
h1 = conv3d_transpose(
h1, f_dim * 4, [4, 4, 4], stride=2, activation_fn=tf.nn.relu)
h2 = conv3d_transpose(
h1, int(f_dim * 3 / 2), [5, 5, 5], stride=2, activation_fn=tf.nn.relu)
h3 = conv3d_transpose(
h2, 1, [6, 6, 6], stride=2, activation_fn=tf.nn.sigmoid)
return h3
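# --- Illustrative shape check (not part of the original file) ---
# A minimal, hypothetical sketch of the decoder above. With the training
# default f_dim=64, three stride-2 transposed convolutions upsample the
# 4x4x4 seed grid to the 32x32x32 voxel output (4 -> 8 -> 16 -> 32).
if __name__ == '__main__':
  import collections
  params = collections.namedtuple('Params', ['f_dim'])(f_dim=64)
  identities = tf.placeholder(tf.float32, [8, 512])  # batch of embeddings
  voxels = model(identities, params, is_training=False)
  print(voxels.get_shape())  # expected: (8, 32, 32, 32, 1)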
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains training plan for the Rotator model (Pretraining in NIPS16)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from tensorflow import app
import model_rotator as model
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_string('inp_dir', '',
'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
'dataset_name', 'shapenet_chair',
'Dataset name that is to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('a_dim', 3, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('step_size', 1, 'Steps to take for rotation in pretraining.')
flags.DEFINE_integer('batch_size', 32, 'Batch size for training.')
flags.DEFINE_string('encoder_name', 'ptn_encoder',
'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_im_decoder',
'Name of the decoder network being used.')
flags.DEFINE_string('rotator_name', 'ptn_rotator',
'Name of the rotator network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn_train/',
'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'deeprotator_pretrain',
'Name of the model used in naming the TF job. Must be different for each run.')
flags.DEFINE_string('init_model', None,
'Checkpoint path of the model to initialize with.')
flags.DEFINE_integer('save_every', 1000,
                     'Average number of steps between model saves.')
# Optimization
flags.DEFINE_float('image_weight', 10, 'Weighting factor for image loss.')
flags.DEFINE_float('mask_weight', 1, 'Weighting factor for mask loss.')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, 'Weight decay parameter while training.')
flags.DEFINE_float('clip_gradient_norm', 0, 'Gradient clip norm; set to 0 to disable gradient clipping.')
flags.DEFINE_integer('max_number_of_steps', 320000, 'Maximum number of steps for training.')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, 'Seconds interval for dumping TF summaries.')
flags.DEFINE_integer('save_interval_secs', 60 * 5, 'Seconds interval to save models.')
# Distribution
flags.DEFINE_string('master', '', 'The address of the tensorflow master if running distributed.')
flags.DEFINE_bool('sync_replicas', False, 'Whether to sync gradients between replicas for optimizer.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas (train tasks).')
flags.DEFINE_integer('backup_workers', 0, 'Number of backup workers.')
flags.DEFINE_integer('ps_tasks', 0, 'Number of ps tasks.')
flags.DEFINE_integer('task', 0,
'Task identifier flag to be set for each task running in distributed manner. Task number 0 '
'will be chosen as the chief.')
FLAGS = flags.FLAGS
def main(_):
train_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
save_image_dir = os.path.join(train_dir, 'images')
if not os.path.exists(train_dir):
os.makedirs(train_dir)
if not os.path.exists(save_image_dir):
os.makedirs(save_image_dir)
g = tf.Graph()
with g.as_default():
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
global_step = slim.get_or_create_global_step()
##########
## data ##
##########
train_data = model.get_inputs(
FLAGS.inp_dir,
FLAGS.dataset_name,
'train',
FLAGS.batch_size,
FLAGS.image_size,
is_training=True)
inputs = model.preprocess(train_data, FLAGS.step_size)
###########
## model ##
###########
model_fn = model.get_model_fn(FLAGS, is_training=True)
outputs = model_fn(inputs)
##########
## loss ##
##########
task_loss = model.get_loss(inputs, outputs, FLAGS)
regularization_loss = model.get_regularization_loss(
['encoder', 'rotator', 'decoder'], FLAGS)
loss = task_loss + regularization_loss
###############
## optimizer ##
###############
optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
if FLAGS.sync_replicas:
optimizer = tf.train.SyncReplicasOptimizer(
optimizer,
            replicas_to_aggregate=FLAGS.worker_replicas - FLAGS.backup_workers,
total_num_replicas=FLAGS.worker_replicas)
##############
## train_op ##
##############
train_op = model.get_train_op_for_scope(
loss, optimizer, ['encoder', 'rotator', 'decoder'], FLAGS)
###########
## saver ##
###########
saver = tf.train.Saver(max_to_keep=np.minimum(5,
FLAGS.worker_replicas + 1))
if FLAGS.task == 0:
val_data = model.get_inputs(
FLAGS.inp_dir,
FLAGS.dataset_name,
'val',
FLAGS.batch_size,
FLAGS.image_size,
is_training=False)
val_inputs = model.preprocess(val_data, FLAGS.step_size)
# Note: don't compute loss here
reused_model_fn = model.get_model_fn(
FLAGS, is_training=False, reuse=True)
val_outputs = reused_model_fn(val_inputs)
with tf.device(tf.DeviceSpec(device_type='CPU')):
if FLAGS.step_size == 1:
vis_input_images = val_inputs['images_0'] * 255.0
vis_output_images = val_inputs['images_1'] * 255.0
vis_pred_images = val_outputs['images_1'] * 255.0
vis_pred_masks = (val_outputs['masks_1'] * (-1) + 1) * 255.0
else:
rep_times = int(np.ceil(32.0 / float(FLAGS.step_size)))
vis_list_1 = []
vis_list_2 = []
vis_list_3 = []
vis_list_4 = []
            for j in range(rep_times):
              for k in range(FLAGS.step_size):
                vis_input_image = val_inputs['images_0'][j]
vis_output_image = val_inputs['images_%d' % (k + 1)][j]
vis_pred_image = val_outputs['images_%d' % (k + 1)][j]
vis_pred_mask = val_outputs['masks_%d' % (k + 1)][j]
vis_list_1.append(tf.expand_dims(vis_input_image, 0))
vis_list_2.append(tf.expand_dims(vis_output_image, 0))
vis_list_3.append(tf.expand_dims(vis_pred_image, 0))
vis_list_4.append(tf.expand_dims(vis_pred_mask, 0))
vis_list_1 = tf.reshape(
tf.stack(vis_list_1), [
rep_times * FLAGS.step_size, FLAGS.image_size,
FLAGS.image_size, 3
])
vis_list_2 = tf.reshape(
tf.stack(vis_list_2), [
rep_times * FLAGS.step_size, FLAGS.image_size,
FLAGS.image_size, 3
])
vis_list_3 = tf.reshape(
tf.stack(vis_list_3), [
rep_times * FLAGS.step_size, FLAGS.image_size,
FLAGS.image_size, 3
])
vis_list_4 = tf.reshape(
tf.stack(vis_list_4), [
rep_times * FLAGS.step_size, FLAGS.image_size,
FLAGS.image_size, 1
])
vis_input_images = vis_list_1 * 255.0
vis_output_images = vis_list_2 * 255.0
vis_pred_images = vis_list_3 * 255.0
vis_pred_masks = (vis_list_4 * (-1) + 1) * 255.0
write_disk_op = model.write_disk_grid(
global_step=global_step,
summary_freq=FLAGS.save_every,
log_dir=save_image_dir,
input_images=vis_input_images,
output_images=vis_output_images,
pred_images=vis_pred_images,
pred_masks=vis_pred_masks)
with tf.control_dependencies([write_disk_op]):
train_op = tf.identity(train_op)
#############
## init_fn ##
#############
      init_fn = model.get_init_fn(['encoder', 'rotator', 'decoder'], FLAGS)
##############
## training ##
##############
slim.learning.train(
train_op=train_op,
logdir=train_dir,
init_fn=init_fn,
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
number_of_steps=FLAGS.max_number_of_steps,
saver=saver,
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains training plan for the Im2vox model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from tensorflow import app
import model_ptn
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_string('inp_dir',
'',
'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
'dataset_name', 'shapenet_chair',
'Dataset name that is to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('vox_size', 32, 'Voxel prediction dimension.')
flags.DEFINE_integer('step_size', 24, 'Steps to take in rotation to fetch viewpoints.')
flags.DEFINE_integer('batch_size', 1, 'Batch size while training.')
flags.DEFINE_float('focal_length', 0.866, 'Focal length parameter used in perspective projection.')
flags.DEFINE_float('focal_range', 1.732, 'Focal range parameter used in perspective projection.')
flags.DEFINE_string('encoder_name', 'ptn_encoder',
'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_vox_decoder',
'Name of the decoder network being used.')
flags.DEFINE_string('projector_name', 'perspective_projector',
'Name of the projector network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn_train/',
'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'ptn_finetune',
'Name of the model used in naming the TF job. Must be different for each run.')
flags.DEFINE_string('init_model', None,
'Checkpoint path of the model to initialize with.')
flags.DEFINE_integer('save_every', 1000,
                     'Average number of steps between model saves.')
# Optimization
flags.DEFINE_float('proj_weight', 10, 'Weighting factor for projection loss.')
flags.DEFINE_float('volume_weight', 0, 'Weighting factor for volume loss.')
flags.DEFINE_float('viewpoint_weight', 1, 'Weighting factor for viewpoint loss.')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, 'Weight decay parameter while training.')
flags.DEFINE_float('clip_gradient_norm', 0, 'Gradient clip norm; set to 0 to disable gradient clipping.')
flags.DEFINE_integer('max_number_of_steps', 10000, 'Maximum number of steps for training.')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, 'Seconds interval for dumping TF summaries.')
flags.DEFINE_integer('save_interval_secs', 60 * 5, 'Seconds interval to save models.')
# Scheduling
flags.DEFINE_string('master', '', 'The address of the TensorFlow master.')
flags.DEFINE_bool('sync_replicas', False, 'Whether to sync gradients between replicas for optimizer.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas (train tasks).')
flags.DEFINE_integer('backup_workers', 0, 'Number of backup workers.')
flags.DEFINE_integer('ps_tasks', 0, 'Number of ps tasks.')
flags.DEFINE_integer('task', 0,
'Task identifier flag to be set for each task running in distributed manner. Task number 0 '
'will be chosen as the chief.')
FLAGS = flags.FLAGS
def main(_):
train_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
save_image_dir = os.path.join(train_dir, 'images')
if not os.path.exists(train_dir):
os.makedirs(train_dir)
if not os.path.exists(save_image_dir):
os.makedirs(save_image_dir)
g = tf.Graph()
with g.as_default():
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
global_step = slim.get_or_create_global_step()
###########
## model ##
###########
model = model_ptn.model_PTN(FLAGS)
##########
## data ##
##########
train_data = model.get_inputs(
FLAGS.inp_dir,
FLAGS.dataset_name,
'train',
FLAGS.batch_size,
FLAGS.image_size,
FLAGS.vox_size,
is_training=True)
inputs = model.preprocess(train_data, FLAGS.step_size)
##############
## model_fn ##
##############
model_fn = model.get_model_fn(
is_training=True, reuse=False, run_projection=True)
outputs = model_fn(inputs)
##################
## train_scopes ##
##################
if FLAGS.init_model:
train_scopes = ['decoder']
init_scopes = ['encoder']
else:
train_scopes = ['encoder', 'decoder']
##########
## loss ##
##########
task_loss = model.get_loss(inputs, outputs)
regularization_loss = model.get_regularization_loss(train_scopes)
loss = task_loss + regularization_loss
###############
## optimizer ##
###############
optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
if FLAGS.sync_replicas:
optimizer = tf.train.SyncReplicasOptimizer(
optimizer,
            replicas_to_aggregate=FLAGS.worker_replicas - FLAGS.backup_workers,
total_num_replicas=FLAGS.worker_replicas)
##############
## train_op ##
##############
train_op = model.get_train_op_for_scope(loss, optimizer, train_scopes)
###########
## saver ##
###########
saver = tf.train.Saver(max_to_keep=np.minimum(5,
FLAGS.worker_replicas + 1))
if FLAGS.task == 0:
params = FLAGS
params.batch_size = params.num_views
params.step_size = 1
model.set_params(params)
val_data = model.get_inputs(
params.inp_dir,
params.dataset_name,
'val',
params.batch_size,
params.image_size,
params.vox_size,
is_training=False)
val_inputs = model.preprocess(val_data, params.step_size)
# Note: don't compute loss here
reused_model_fn = model.get_model_fn(is_training=False, reuse=True)
val_outputs = reused_model_fn(val_inputs)
with tf.device(tf.DeviceSpec(device_type='CPU')):
vis_input_images = val_inputs['images_1'] * 255.0
vis_gt_projs = (val_outputs['masks_1'] * (-1) + 1) * 255.0
vis_pred_projs = (val_outputs['projs_1'] * (-1) + 1) * 255.0
vis_gt_projs = tf.concat([vis_gt_projs] * 3, axis=3)
vis_pred_projs = tf.concat([vis_pred_projs] * 3, axis=3)
# rescale
new_size = [FLAGS.image_size] * 2
vis_gt_projs = tf.image.resize_nearest_neighbor(
vis_gt_projs, new_size)
vis_pred_projs = tf.image.resize_nearest_neighbor(
vis_pred_projs, new_size)
# flip
# vis_gt_projs = utils.image_flipud(vis_gt_projs)
# vis_pred_projs = utils.image_flipud(vis_pred_projs)
# vis_gt_projs is of shape [batch, height, width, channels]
write_disk_op = model.write_disk_grid(
global_step=global_step,
log_dir=save_image_dir,
input_images=vis_input_images,
gt_projs=vis_gt_projs,
pred_projs=vis_pred_projs,
input_voxels=val_inputs['voxels'],
output_voxels=val_outputs['voxels_1'])
with tf.control_dependencies([write_disk_op]):
train_op = tf.identity(train_op)
#############
## init_fn ##
#############
if FLAGS.init_model:
init_fn = model.get_init_fn(init_scopes)
else:
init_fn = None
##############
## training ##
##############
slim.learning.train(
train_op=train_op,
logdir=train_dir,
init_fn=init_fn,
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
number_of_steps=FLAGS.max_number_of_steps,
saver=saver,
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import StringIO
from matplotlib import pylab as p
# axes3d is used implicitly for visualization.
from mpl_toolkits.mplot3d import axes3d as p3 # pylint:disable=unused-import
import numpy as np
from PIL import Image
from skimage import measure
import tensorflow as tf
def save_image(inp_array, image_file):
"""Function that dumps the image to disk."""
inp_array = np.clip(inp_array, 0, 255).astype(np.uint8)
image = Image.fromarray(inp_array)
buf = StringIO.StringIO()
image.save(buf, format='JPEG')
  with open(image_file, 'wb') as f:
f.write(buf.getvalue())
return None
def image_flipud(images):
"""Function that flip (up-down) the np image."""
quantity = images.get_shape().as_list()[0]
image_list = []
  for k in range(quantity):
image_list.append(tf.image.flip_up_down(images[k, :, :, :]))
outputs = tf.stack(image_list)
return outputs
def resize_image(inp_array, new_height, new_width):
"""Function that resize the np image."""
inp_array = np.clip(inp_array, 0, 255).astype(np.uint8)
image = Image.fromarray(inp_array)
  # PIL's resize takes (width, height), the reverse of numpy's (height, width).
  image = image.resize((new_width, new_height))
return np.array(image)
def display_voxel(points, vis_size=128):
"""Function to display 3D voxel."""
try:
data = visualize_voxel_spectral(points, vis_size)
except ValueError:
data = visualize_voxel_scatter(points, vis_size)
return data
def visualize_voxel_spectral(points, vis_size=128):
"""Function to visualize voxel (spectral)."""
points = np.rint(points)
points = np.swapaxes(points, 0, 2)
fig = p.figure(figsize=(1, 1), dpi=vis_size)
verts, faces = measure.marching_cubes(points, 0, spacing=(0.1, 0.1, 0.1))
ax = fig.add_subplot(111, projection='3d')
ax.plot_trisurf(
verts[:, 0], verts[:, 1], faces, verts[:, 2], cmap='Spectral_r', lw=0.1)
ax.set_axis_off()
fig.tight_layout(pad=0)
fig.canvas.draw()
data = np.fromstring(
fig.canvas.tostring_rgb(), dtype=np.uint8, sep='').reshape(
vis_size, vis_size, 3)
p.close('all')
return data
def visualize_voxel_scatter(points, vis_size=128):
"""Function to visualize voxel (scatter)."""
points = np.rint(points)
points = np.swapaxes(points, 0, 2)
fig = p.figure(figsize=(1, 1), dpi=vis_size)
ax = fig.add_subplot(111, projection='3d')
x = []
y = []
z = []
(x_dimension, y_dimension, z_dimension) = points.shape
for i in range(x_dimension):
for j in range(y_dimension):
for k in range(z_dimension):
if points[i, j, k]:
x.append(i)
y.append(j)
z.append(k)
ax.scatter3D(x, y, z)
ax.set_axis_off()
fig.tight_layout(pad=0)
fig.canvas.draw()
data = np.fromstring(
fig.canvas.tostring_rgb(), dtype=np.uint8, sep='').reshape(
vis_size, vis_size, 3)
p.close('all')
return data
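# --- Illustrative usage sketch (not part of the original file) ---
# A minimal, hypothetical pipeline tying the helpers above together:
# render a random occupancy grid and dump the visualization to disk.
if __name__ == '__main__':
  voxels = (np.random.rand(32, 32, 32) > 0.95).astype(np.float32)
  rendering = display_voxel(voxels, vis_size=128)  # [128, 128, 3] uint8 array
  save_image(rendering.astype(np.float32), '/tmp/voxel_vis.jpg')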
# REINFORCing Concrete with REBAR
*Implementation of REBAR (and other closely related methods) as described
in "REBAR: Low-variance, unbiased gradient estimates for discrete latent variable models" by
George Tucker, Andriy Mnih, Chris J. Maddison, Dieterich Lawson, Jascha Sohl-Dickstein ([arXiv:1703.07370](https://arxiv.org/abs/1703.07370)).*
Learning in models with discrete latent variables is challenging due to high variance gradient estimators. Generally, approaches have relied on control variates to reduce the variance of the REINFORCE estimator. Recent work ([Jang et al. 2016](https://arxiv.org/abs/1611.01144); [Maddison et al. 2016](https://arxiv.org/abs/1611.00712)) has taken a different approach, introducing a continuous relaxation of discrete variables to produce low-variance, but biased, gradient estimates. In this work, we combine the two approaches through a novel control variate that produces low-variance, unbiased gradient estimates. Then, we introduce a novel continuous relaxation and show that the tightness of the relaxation can be adapted online, removing it as a hyperparameter. We show state-of-the-art variance reduction on several benchmark generative modeling tasks, generally leading to faster convergence to a better final log likelihood.
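As a toy illustration of the control-variate idea (not the REBAR estimator itself), the numpy sketch below compares plain REINFORCE against a baseline-subtracted version for a single Bernoulli latent; all names and values here are hypothetical:
```
import numpy as np

f = lambda b: (b - 0.45) ** 2          # toy objective on a Bernoulli sample
theta = 0.3                            # Bernoulli parameter
rng = np.random.RandomState(0)

def reinforce(baseline, n=100000):
  b = (rng.rand(n) < theta).astype(float)
  score = b / theta - (1 - b) / (1 - theta)   # d/dtheta log p(b)
  return (f(b) - baseline) * score            # per-sample gradient estimates

plain = reinforce(baseline=0.0)
centered = reinforce(baseline=f(theta))       # constant control variate
print(plain.mean(), centered.mean())   # both ~0.1: estimator stays unbiased
print(plain.var(), centered.var())     # the baseline lowers the variance
```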
REBAR applied to multilayer sigmoid belief networks is implemented in rebar.py, and rebar_train.py provides a training/evaluation setup. As a comparison, we also implemented the following methods:
* [NVIL](https://arxiv.org/abs/1402.0030)
* [MuProp](https://arxiv.org/abs/1511.05176)
* [Gumbel-Softmax](https://arxiv.org/abs/1611.01144)
The code is not optimized and some computation is repeated for ease of
implementation. We hope that this code will be a useful starting point for future research in this area.
## Quick Start:
Requirements:
* TensorFlow (see tensorflow.org for how to install)
* MNIST dataset
* Omniglot dataset
First, choose URLs to download each dataset from and fill them into the
download_data.py script like so:
```
MNIST_URL = 'http://yann.lecun.com/exdb/mnist'
MNIST_BINARIZED_URL = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist'
OMNIGLOT_URL = 'https://github.com/yburda/iwae/raw/master/datasets/OMNIGLOT/chardata.mat'
```
Then run the script to download the data:
```
python download_data.py
```
Then run the training script:
```
python rebar_train.py --hparams="model=SBNDynamicRebar,learning_rate=0.0003,n_layer=2,task=sbn"
```
and you should see something like:
```
Step 2084: [-231.026474 0.3711713 1. 1.06934261 1.07023323
1.02173257 1.02171052 1. 1. 1. 1. ]
-3.6465678215
Step 4168: [-156.86795044 0.3097114 1. 1.03964758 1.03936625
1.02627242 1.02629256 1. 1. 1. 1. ]
-4.42727231979
Step 6252: [-143.4650116 0.26153237 1. 1.03633797 1.03600132
1.02639604 1.02639794 1. 1. 1. 1. ]
-4.85577583313
Step 8336: [-137.65275574 0.22313026 1. 1.03467286 1.03428006
1.02336085 1.02335203 0.99999988 1. 0.99999988
1. ]
-4.95563364029
```
The first number in the list is the log likelihood lower bound and the number
after the list is the log of the variance of the gradient estimator. The rest of
the numbers are for debugging.
We can also compare the variance between methods:
```
python rebar_train.py \
--hparams="model=SBNTrackGradVariances,learning_rate=0.0003,n_layer=2,task=omni"
```
and you should see something like:
```
Step 959: [ -2.60478699e+02 3.84281784e-01 6.31126612e-02 3.27319391e-02
6.13379292e-03 1.98278503e-04 1.96425783e-04 8.83973844e-04
8.70995224e-04 -inf]
('DynamicREBAR', -3.725339889526367)
('MuProp', -0.033569782972335815)
('NVIL', 2.7640280723571777)
('REBAR', -3.539274215698242)
('SimpleMuProp', -0.040744658559560776)
Step 1918: [ -2.06948471e+02 3.35904926e-01 5.20901568e-03 7.81541676e-05
2.06885766e-03 1.08521657e-04 1.07351625e-04 2.30646547e-04
2.26554010e-04 -8.22885323e+00]
('DynamicREBAR', -3.864381790161133)
('MuProp', -0.7183765172958374)
('NVIL', 2.266523599624634)
('REBAR', -3.662022113800049)
('SimpleMuProp', -0.7071359157562256)
```
where the tuples show the log of the variance of the gradient estimators.
The training script has a number of hyperparameter configuration flags (an example invocation follows the list):
* task (sbn): one of {sbn, sp, omni}, which correspond to MNIST generative
modeling, structured prediction on MNIST, and Omniglot generative modeling,
respectively
* model (SBNGumbel): one of {SBN, SBNNVIL, SBNMuProp, SBNSimpleMuProp,
SBNRebar, SBNDynamicRebar, SBNGumbel, SBNTrackGradVariances}. DynamicRebar automatically
adjusts the temperature, whereas Rebar and Gumbel-Softmax require tuning the
temperature. The models named after
methods use that method to estimate the gradients (SBN refers to
REINFORCE). SBNTrackGradVariances runs multiple methods and follows a single
optimization trajectory
* n_hidden (200): number of hidden nodes per layer
* n_layer (1): number of layers in the model
* nonlinear (false): if true use 2 x tanh layers between each stochastic layer,
otherwise use a linear layer
* learning_rate (0.001): learning rate
* temperature (0.5): temperature hyperparameter (for DynamicRebar, this is the initial
value of the temperature)
* n_samples (1): number of samples used to compute the gradient estimator (for the
experiments in the paper, set to 1)
* batch_size (24): batch size
* muprop_relaxation (true): if true use the new relaxation described in the paper,
otherwise use the Concrete/Gumbel softmax relaxation
* dynamic_b (false): if true, dynamically binarize the training set. This
increases the effective training dataset size and reduces overfitting, though
the result is not a standard dataset
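For example, a hypothetical run combining several of these flags (values chosen purely for illustration):
```
python rebar_train.py \
  --hparams="model=SBNMuProp,task=omni,n_layer=2,nonlinear=true,learning_rate=0.0003"
```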
Maintained by George Tucker (gjt@google.com, github user: gjtucker).
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configuration variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
DATA_DIR = 'data'
MNIST_BINARIZED = 'mnist_salakhutdinov_07-19-2017.pkl'
MNIST_FLOAT = 'mnist_train_xs_07-19-2017.npy'
OMNIGLOT = 'omniglot_07-19-2017.mat'
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of datasets for REBAR."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import os
import scipy.io
import numpy as np
import cPickle as pickle
import tensorflow as tf
import config
gfile = tf.gfile
def load_data(hparams):
# Load data
if hparams.task in ['sbn', 'sp']:
reader = read_MNIST
elif hparams.task == 'omni':
reader = read_omniglot
x_train, x_valid, x_test = reader(binarize=not hparams.dynamic_b)
return x_train, x_valid, x_test
def read_MNIST(binarize=False):
"""Reads in MNIST images.
Args:
binarize: whether to use the fixed binarization
Returns:
x_train: 50k training images
x_valid: 10k validation images
x_test: 10k test images
"""
  with gfile.FastGFile(os.path.join(config.DATA_DIR, config.MNIST_BINARIZED), 'rb') as f:
(x_train, _), (x_valid, _), (x_test, _) = pickle.load(f)
if not binarize:
    with gfile.FastGFile(os.path.join(config.DATA_DIR, config.MNIST_FLOAT), 'rb') as f:
x_train = np.load(f).reshape(-1, 784)
return x_train, x_valid, x_test
def read_omniglot(binarize=False):
"""Reads in Omniglot images.
Args:
binarize: whether to use the fixed binarization
Returns:
x_train: training images
x_valid: validation images
x_test: test images
"""
  n_validation = 1345
  def reshape_data(data):
    # Omniglot is stored column-major; reshape in Fortran ('F') order.
    return data.reshape((-1, 28, 28)).reshape((-1, 28 * 28), order='F')
omni_raw = scipy.io.loadmat(os.path.join(config.DATA_DIR, config.OMNIGLOT))
train_data = reshape_data(omni_raw['data'].T.astype('float32'))
test_data = reshape_data(omni_raw['testdata'].T.astype('float32'))
# Binarize the data with a fixed seed
if binarize:
np.random.seed(5)
train_data = (np.random.rand(*train_data.shape) < train_data).astype(float)
test_data = (np.random.rand(*test_data.shape) < test_data).astype(float)
shuffle_seed = 123
permutation = np.random.RandomState(seed=shuffle_seed).permutation(train_data.shape[0])
train_data = train_data[permutation]
x_train = train_data[:-n_validation]
x_valid = train_data[-n_validation:]
x_test = test_data
return x_train, x_valid, x_test
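# --- Illustrative usage sketch (not part of the original file) ---
# A minimal, hypothetical example of driving load_data; the namedtuple is a
# stand-in for the project's real hyperparameter object.
if __name__ == '__main__':
  import collections
  HParams = collections.namedtuple('HParams', ['task', 'dynamic_b'])
  x_train, x_valid, x_test = load_data(HParams(task='sbn', dynamic_b=False))
  print(x_train.shape, x_valid.shape, x_test.shape)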