"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "55cb4c90649ac6858a769a83ff5e8e4130d6c9ae"
Commit 3d76b69f authored by Lukasz Kaiser, committed by GitHub

Merge pull request #1999 from arkanath/master

Pull request for models/ptn
parents b528ede6 7a5a5836
@@ -17,6 +17,7 @@ neural_programmer/* @arvind2505
 next_frame_prediction/* @panyx0718
 object_detection/* @jch1 @tombstone @derekjchow @jesu9 @dreamdragon
 pcl_rl/* @ofirnachum
+ptn/* @xcyan @arkanath @hellojas @honglaklee
 real_nvp/* @laurent-dinh
 resnet/* @panyx0718
 skip_thoughts/* @cshallue
bazel
.idea
bazel-bin
bazel-out
bazel-genfiles
bazel-ptn
bazel-testlogs
WORKSPACE
*.pyc
py_library(
name = "input_generator",
srcs = ["input_generator.py"],
deps = [
],
)
py_library(
name = "losses",
srcs = ["losses.py"],
deps = [
],
)
py_library(
name = "metrics",
srcs = ["metrics.py"],
deps = [
],
)
py_library(
name = "utils",
srcs = ["utils.py"],
deps = [
],
)
# Defines the Rotator model here
py_library(
name = "model_rotator",
srcs = ["model_rotator.py"],
deps = [
":input_generator",
":losses",
":metrics",
":utils",
"//nets:deeprotator_factory",
],
)
# Defines the Im2vox model here
py_library(
name = "model_voxel_generation",
srcs = ["model_voxel_generation.py"],
deps = [
":input_generator",
"//nets:im2vox_factory",
],
)
py_library(
name = "model_ptn",
srcs = ["model_ptn.py"],
deps = [
":losses",
":metrics",
":model_voxel_generation",
":utils",
"//nets:im2vox_factory",
],
)
py_binary(
name = "train_ptn",
srcs = ["train_ptn.py"],
deps = [
":model_ptn",
],
)
py_binary(
name = "eval_ptn",
srcs = ["eval_ptn.py"],
deps = [
":model_ptn",
],
)
py_binary(
name = "pretrain_rotator",
srcs = ["pretrain_rotator.py"],
deps = [
":model_rotator",
],
)
py_binary(
name = "eval_rotator",
srcs = ["eval_rotator.py"],
deps = [
":model_rotator",
],
)
# Perspective Transformer Nets
## Introduction
This is the TensorFlow implementation for the NIPS 2016 work ["Perspective Transformer Nets: Learning Single-View 3D Object Reconstruction without 3D Supervision"](https://papers.nips.cc/paper/6206-perspective-transformer-nets-learning-single-view-3d-object-reconstruction-without-3d-supervision.pdf).
Re-implemented by Xinchen Yan, Arkanath Pathak, Jasmine Hsu, Honglak Lee
Reference: [Original implementation in Torch](https://github.com/xcyan/nips16_PTN)
## How to run this code
This implementation is ready to be run locally or [distributed across multiple machines/tasks](https://www.tensorflow.org/deploy/distributed).
You will need to set the task number flag for each task when running in a distributed fashion.
Please refer to the original paper for parameter explanations and training details.
### Installation
* TensorFlow
* This code requires the latest open-source TensorFlow, which you will need to build from source.
The [documentation](https://www.tensorflow.org/install/install_sources) provides the steps required for that.
* Bazel
* Follow the instructions [here](http://bazel.build/docs/install.html).
* Alternatively, download Bazel from
[https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases)
for your system configuration.
* Check the installed Bazel version with: `bazel version`
* matplotlib
* Follow the instructions [here](https://matplotlib.org/users/installing.html).
* You can install it with a package manager such as pip.
* scikit-image
* Follow the instructions [here](http://scikit-image.org/docs/dev/install.html).
* You can install it with a package manager such as pip.
* PIL
* Install from [here](https://pypi.python.org/pypi/Pillow/2.2.1).
### Dataset
This code requires the dataset to be in *tfrecords* format with the following features:
* image
* Flattened list of images (float representations) for each viewpoint.
* mask
* Flattened list of image masks (float representations) for each viewpoint.
* vox
* Flattened list of voxels (float representations) for the object.
* This is needed for using vox loss and for prediction comparison.
You can download the ShapeNet Dataset in tfrecords format from [here](https://drive.google.com/file/d/0B12XukcbU7T7OHQ4MGh6d25qQlk)<sup>*</sup>.
<sup>*</sup> Disclaimer: This data is hosted personally by Arkanath Pathak for non-commercial research purposes. Please cite the [ShapeNet paper](https://arxiv.org/pdf/1512.03012.pdf) in your works when using ShapeNet for non-commercial research purposes.
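For reference, the snippet below is a minimal sketch (not part of this code base) of how a single training example matching this schema could be serialized. It assumes the default flag values (num_views=24, image_size=64, vox_size=32), uses random placeholder data, and follows the `03001627_*.tfrecords` naming that input_generator.py expects for the shapenet_chair dataset:

```python
import numpy as np
import tensorflow as tf

# Default flag values used by this model.
num_views, image_size, vox_size = 24, 64, 32

# Random placeholder data standing in for rendered views, silhouettes and voxels.
images = np.random.rand(num_views, image_size, image_size, 3).astype(np.float32)
masks = np.random.rand(num_views, image_size, image_size, 1).astype(np.float32)
vox = np.random.rand(vox_size, vox_size, vox_size, 1).astype(np.float32)

def _float_feature(array):
  # Each feature is stored as a flattened list of floats.
  return tf.train.Feature(float_list=tf.train.FloatList(value=array.ravel()))

example = tf.train.Example(features=tf.train.Features(feature={
    'image': _float_feature(images),
    'mask': _float_feature(masks),
    'vox': _float_feature(vox),
}))

# shapenet_chair expects files named 03001627_{train,val,test}.tfrecords.
with tf.python_io.TFRecordWriter('03001627_train.tfrecords') as writer:
  writer.write(example.SerializeToString())
```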
### Pretraining: pretrain_rotator.py for each RNN step
$ bazel run -c opt :pretrain_rotator -- --step_size={} --init_model={}
Pass init_model as the checkpoint path of the model trained at the previous step.
You'll also need to set the inp_dir flag to where your data resides.
### Training: train_ptn.py with the last pretrained model
$ bazel run -c opt :train_ptn -- --init_model={}
### Example TensorBoard Visualizations
To compare visualizations, make sure to set a different model_name flag for each parameter setting.
This code adds summaries for each loss. For instance, these are the losses we encountered while pretraining on the ShapeNet chair dataset distributed over 10 workers and 16 parameter servers:
![ShapeNet Chair Pretraining](https://drive.google.com/uc?export=view&id=0B12XukcbU7T7bWdlTjhzbGJVaWs "ShapeNet Chair Experiment Pretraining Losses")
After fine-tuning the training, you can expect images like the following as "grid_vis" under the **Image** summaries in TensorBoard:
![ShapeNet Chair experiments with projection weight of 1](https://drive.google.com/uc?export=view&id=0B12XukcbU7T7ZFV6aEVBSDdCMjQ "ShapeNet Chair Dataset Predictions")
Here the third and fifth columns are the predicted masks and voxels respectively, alongside their ground truth values.
A similar image when trained on all ShapeNet categories (voxel visualizations might be skewed):
![ShapeNet All Categories experiments](https://drive.google.com/uc?export=view&id=0B12XukcbU7T7bDZKNFlkTVAzZmM "ShapeNet All Categories Dataset Predictions")
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains evaluation plan for the Im2vox model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from tensorflow import app
import model_ptn
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_string('inp_dir',
'',
'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
'dataset_name', 'shapenet_chair',
'Dataset name that is to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('vox_size', 32, 'Voxel prediction dimension.')
flags.DEFINE_integer('step_size', 24, '')
flags.DEFINE_integer('batch_size', 1, 'Batch size while training.')
flags.DEFINE_float('focal_length', 0.866, '')
flags.DEFINE_float('focal_range', 1.732, '')
flags.DEFINE_string('encoder_name', 'ptn_encoder',
'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_vox_decoder',
'Name of the decoder network being used.')
flags.DEFINE_string('projector_name', 'ptn_projector',
'Name of the projector network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn/eval/',
'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'ptn_proj',
'Name of the model used in naming the TF job. Must be different for each run.')
flags.DEFINE_string('eval_set', 'val', 'Data partition to perform evaluation on.')
# Optimization
flags.DEFINE_float('proj_weight', 10, 'Weighting factor for projection loss.')
flags.DEFINE_float('volume_weight', 0, 'Weighting factor for volume loss.')
flags.DEFINE_float('viewpoint_weight', 1,
'Weighting factor for viewpoint loss.')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, '')
flags.DEFINE_float('clip_gradient_norm', 0, '')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, '')
flags.DEFINE_integer('eval_interval_secs', 60 * 5, '')
# Distribution
flags.DEFINE_string('master', '', '')
FLAGS = flags.FLAGS
def main(argv=()):
del argv # Unused.
eval_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name,
'eval_%s' % FLAGS.eval_set)
if not os.path.exists(eval_dir):
os.makedirs(eval_dir)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
g = tf.Graph()
with g.as_default():
eval_params = FLAGS
eval_params.batch_size = 1
eval_params.step_size = FLAGS.num_views
###########
## model ##
###########
model = model_ptn.model_PTN(eval_params)
##########
## data ##
##########
eval_data = model.get_inputs(
FLAGS.inp_dir,
FLAGS.dataset_name,
eval_params.eval_set,
eval_params.batch_size,
eval_params.image_size,
eval_params.vox_size,
is_training=False)
inputs = model.preprocess_with_all_views(eval_data)
##############
## model_fn ##
##############
model_fn = model.get_model_fn(is_training=False, run_projection=False)
outputs = model_fn(inputs)
#############
## metrics ##
#############
names_to_values, names_to_updates = model.get_metrics(inputs, outputs)
del names_to_values
################
## evaluation ##
################
num_batches = eval_data['num_samples']
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=eval_dir,
logdir=log_dir,
num_evals=num_batches,
eval_op=names_to_updates.values(),
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains evaluation plan for the Rotator model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from tensorflow import app
import model_rotator as model
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_string('inp_dir',
'',
'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
'dataset_name', 'shapenet_chair',
'Dataset name that is to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('a_dim', 3, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('step_size', 24, '')
flags.DEFINE_integer('batch_size', 2, '')
flags.DEFINE_string('encoder_name', 'ptn_encoder',
'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_im_decoder',
'Name of the decoder network being used.')
flags.DEFINE_string('rotator_name', 'ptn_rotator',
'Name of the rotator network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn_train/',
'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'ptn_proj',
'Name of the model used in naming the TF job. Must be different for each run.')
# Optimization
flags.DEFINE_float('image_weight', 10, '')
flags.DEFINE_float('mask_weight', 1, '')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, '')
flags.DEFINE_float('clip_gradient_norm', 0, '')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, '')
flags.DEFINE_integer('eval_interval_secs', 60 * 5, '')
# Scheduling
flags.DEFINE_string('master', 'local', '')
FLAGS = flags.FLAGS
def main(argv=()):
del argv # Unused.
eval_dir = os.path.join(FLAGS.checkpoint_dir,
FLAGS.model_name, 'train')
log_dir = os.path.join(FLAGS.checkpoint_dir,
FLAGS.model_name, 'eval')
if not os.path.exists(eval_dir):
os.makedirs(eval_dir)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if FLAGS.step_size < FLAGS.num_views:
raise ValueError('Impossible step_size, must not be less than num_views.')
g = tf.Graph()
with g.as_default():
##########
## data ##
##########
val_data = model.get_inputs(
FLAGS.inp_dir,
FLAGS.dataset_name,
'val',
FLAGS.batch_size,
FLAGS.image_size,
is_training=False)
inputs = model.preprocess(val_data, FLAGS.step_size)
###########
## model ##
###########
model_fn = model.get_model_fn(FLAGS, is_training=False)
outputs = model_fn(inputs)
#############
## metrics ##
#############
names_to_values, names_to_updates = model.get_metrics(
inputs, outputs, FLAGS)
del names_to_values
################
## evaluation ##
################
num_batches = int(val_data['num_samples'] / FLAGS.batch_size)
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=eval_dir,
logdir=log_dir,
num_evals=num_batches,
eval_op=names_to_updates.values(),
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides dataset dictionaries as used in our network models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.data import dataset
from tensorflow.contrib.slim.python.slim.data import dataset_data_provider
from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
_ITEMS_TO_DESCRIPTIONS = {
'image': 'Images',
'mask': 'Masks',
'vox': 'Voxels'
}
def _get_split(file_pattern, num_samples, num_views, image_size, vox_size):
"""Get dataset.Dataset for the given dataset file pattern and properties."""
# A dictionary from TF-Example keys to tf.FixedLenFeature instance.
keys_to_features = {
'image': tf.FixedLenFeature(
shape=[num_views, image_size, image_size, 3],
dtype=tf.float32, default_value=None),
'mask': tf.FixedLenFeature(
shape=[num_views, image_size, image_size, 1],
dtype=tf.float32, default_value=None),
'vox': tf.FixedLenFeature(
shape=[vox_size, vox_size, vox_size, 1],
dtype=tf.float32, default_value=None),
}
items_to_handler = {
'image': tfexample_decoder.Tensor(
'image', shape=[num_views, image_size, image_size, 3]),
'mask': tfexample_decoder.Tensor(
'mask', shape=[num_views, image_size, image_size, 1]),
'vox': tfexample_decoder.Tensor(
'vox', shape=[vox_size, vox_size, vox_size, 1])
}
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handler)
return dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=num_samples,
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS)
def get(dataset_dir,
dataset_name,
split_name,
shuffle=True,
num_readers=1,
common_queue_capacity=64,
common_queue_min=50):
"""Provides input data for a specified dataset and split."""
dataset_to_kwargs = {
'shapenet_chair': {
'file_pattern': '03001627_%s.tfrecords' % split_name,
'num_views': 24,
'image_size': 64,
'vox_size': 32,
}, 'shapenet_all': {
'file_pattern': '*_%s.tfrecords' % split_name,
'num_views': 24,
'image_size': 64,
'vox_size': 32,
},
}
split_sizes = {
'shapenet_chair': {
'train': 4744,
'val': 678,
'test': 1356,
},
'shapenet_all': {
'train': 30643,
'val': 4378,
'test': 8762,
}
}
kwargs = dataset_to_kwargs[dataset_name]
kwargs['file_pattern'] = os.path.join(dataset_dir, kwargs['file_pattern'])
kwargs['num_samples'] = split_sizes[dataset_name][split_name]
dataset_split = _get_split(**kwargs)
data_provider = dataset_data_provider.DatasetDataProvider(
dataset_split,
num_readers=num_readers,
common_queue_capacity=common_queue_capacity,
common_queue_min=common_queue_min,
shuffle=shuffle)
inputs = {
'num_samples': dataset_split.num_samples,
}
[image, mask, vox] = data_provider.get(['image', 'mask', 'vox'])
inputs['image'] = image
inputs['mask'] = mask
inputs['voxel'] = vox
return inputs
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the various loss functions in use by the PTN model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def add_rotator_image_loss(inputs, outputs, step_size, weight_scale):
"""Computes the image loss of deep rotator model.
Args:
inputs: Input dictionary to the model containing keys
such as `images_k'.
outputs: Output dictionary returned by the model containing keys
such as `images_k'.
step_size: A scalar representing the number of recurrent
steps (number of repeated out-of-plane rotations)
in the deep rotator network (int).
weight_scale: A reweighting factor applied over the image loss (float).
Returns:
A `Tensor' scalar that returns averaged L2 loss
(divided by batch_size and step_size) between the
ground-truth images (RGB) and predicted images (tf.float32).
"""
batch_size = tf.shape(inputs['images_0'])[0]
image_loss = 0
for k in range(1, step_size + 1):
image_loss += tf.nn.l2_loss(
inputs['images_%d' % k] - outputs['images_%d' % k])
image_loss /= tf.to_float(step_size * batch_size)
slim.summaries.add_scalar_summary(
image_loss, 'image_loss', prefix='losses')
image_loss *= weight_scale
return image_loss
def add_rotator_mask_loss(inputs, outputs, step_size, weight_scale):
"""Computes the mask loss of deep rotator model.
Args:
inputs: Input dictionary to the model containing keys
such as `masks_k'.
outputs: Output dictionary returned by the model containing
keys such as `masks_k'.
step_size: A scalar representing the number of recurrent
steps (number of repeated out-of-plane rotations)
in the deep rotator network (int).
weight_scale: A reweighting factor applied over the mask loss (float).
Returns:
A `Tensor' that returns averaged L2 loss
(divided by batch_size and step_size) between the ground-truth masks
(object silhouettes) and predicted masks (tf.float32).
"""
batch_size = tf.shape(inputs['images_0'])[0]
mask_loss = 0
for k in range(1, step_size + 1):
mask_loss += tf.nn.l2_loss(
inputs['masks_%d' % k] - outputs['masks_%d' % k])
mask_loss /= tf.to_float(step_size * batch_size)
slim.summaries.add_scalar_summary(
mask_loss, 'mask_loss', prefix='losses')
mask_loss *= weight_scale
return mask_loss
def add_volume_proj_loss(inputs, outputs, num_views, weight_scale):
"""Computes the projection loss of voxel generation model.
Args:
inputs: Input dictionary to the model containing keys such as
`images_1'.
outputs: Output dictionary returned by the model containing keys
such as `masks_k' and `projs_k'.
num_views: An integer scalar representing the total number of
viewpoints for each object (int).
weight_scale: A reweighting factor applied over the projection loss (float).
Returns:
A `Tensor' that returns the averaged L2 loss
(divided by batch_size and num_views) between the ground-truth
masks (object silhouettes) and predicted masks (tf.float32).
"""
batch_size = tf.shape(inputs['images_1'])[0]
proj_loss = 0
for k in range(num_views):
proj_loss += tf.nn.l2_loss(
outputs['masks_%d' % (k + 1)] - outputs['projs_%d' % (k + 1)])
proj_loss /= tf.to_float(num_views * batch_size)
slim.summaries.add_scalar_summary(
proj_loss, 'proj_loss', prefix='losses')
proj_loss *= weight_scale
return proj_loss
def add_volume_loss(inputs, outputs, num_views, weight_scale):
"""Computes the volume loss of voxel generation model.
Args:
inputs: Input dictionary to the model containing keys such as
`images_1' and `voxels'.
outputs: Output dictionary returned by the model containing keys
such as `voxels_k'.
num_views: A scalar representing the total number of
viewpoints for each object (int).
weight_scale: A reweighting factor applied over the volume
loss (tf.float32).
Returns:
A `Tensor' that returns the averaged L2 loss
(divided by batch_size and num_views) between the ground-truth
volumes and predicted volumes (tf.float32).
"""
batch_size = tf.shape(inputs['images_1'])[0]
vol_loss = 0
for k in range(num_views):
vol_loss += tf.nn.l2_loss(
inputs['voxels'] - outputs['voxels_%d' % (k + 1)])
vol_loss /= tf.to_float(num_views * batch_size)
slim.summaries.add_scalar_summary(
vol_loss, 'vol_loss', prefix='losses')
vol_loss *= weight_scale
return vol_loss
def regularization_loss(scopes, params):
"""Computes the weight decay as regularization during training.
Args:
scopes: A list of different components of the model such as
``encoder'', ``decoder'' and ``projector''.
params: Parameters of the model.
Returns:
Regularization loss (tf.float32).
"""
reg_loss = tf.zeros(dtype=tf.float32, shape=[])
if params.weight_decay > 0:
is_trainable = lambda x: x in tf.trainable_variables()
is_weights = lambda x: 'weights' in x.name
for scope in scopes:
scope_vars = filter(is_trainable,
tf.contrib.framework.get_model_variables(scope))
scope_vars = filter(is_weights, scope_vars)
if scope_vars:
reg_loss += tf.add_n([tf.nn.l2_loss(var) for var in scope_vars])
slim.summaries.add_scalar_summary(
reg_loss, 'reg_loss', prefix='losses')
reg_loss *= params.weight_decay
return reg_loss
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides metrics used by PTN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def add_image_pred_metrics(
inputs, outputs, num_views, upscale_factor):
"""Computes the image prediction metrics.
Args:
inputs: Input dictionary of the deep rotator model (model_rotator.py).
outputs: Output dictionary of the deep rotator model (model_rotator.py).
num_views: An integer scalar representing the total number
of different viewpoints for each object in the dataset.
upscale_factor: A float scalar representing the number of pixels
per image (num_channels x image_height x image_width).
Returns:
names_to_values: A dictionary representing the current value
of the metric.
names_to_updates: A dictionary representing the operation
that accumulates the error from a batch of data.
"""
names_to_values = dict()
names_to_updates = dict()
for k in xrange(num_views):
tmp_value, tmp_update = tf.contrib.metrics.streaming_mean_squared_error(
outputs['images_%d' % (k + 1)], inputs['images_%d' % (k + 1)])
name = 'image_pred/rnn_%d' % (k + 1)
names_to_values.update({name: tmp_value * upscale_factor})
names_to_updates.update({name: tmp_update})
return names_to_values, names_to_updates
def add_mask_pred_metrics(
inputs, outputs, num_views, upscale_factor):
"""Computes the mask prediction metrics.
Args:
inputs: Input dictionary of the deep rotator model (model_rotator.py).
outputs: Output dictionary of the deep rotator model (model_rotator.py).
num_views: An integer scalar representing the total number
of different viewpoints for each object in the dataset.
upscale_factor: A float scalar representing the number of pixels
per image (num_channels x image_height x image_width).
Returns:
names_to_values: A dictionary representing the current value
of the metric.
names_to_updates: A dictionary representing the operation
that accumulates the error from a batch of data.
"""
names_to_values = dict()
names_to_updates = dict()
for k in xrange(num_views):
tmp_value, tmp_update = tf.contrib.metrics.streaming_mean_squared_error(
outputs['masks_%d' % (k + 1)], inputs['masks_%d' % (k + 1)])
name = 'mask_pred/rnn_%d' % (k + 1)
names_to_values.update({name: tmp_value * upscale_factor})
names_to_updates.update({name: tmp_update})
return names_to_values, names_to_updates
def add_volume_iou_metrics(inputs, outputs):
"""Computes the per-instance volume IOU.
Args:
inputs: Input dictionary of the voxel generation model.
outputs: Output dictionary returned by the voxel generation model.
Returns:
names_to_values: metrics->values (dict).
names_to_updates: metrics->ops (dict).
"""
names_to_values = dict()
names_to_updates = dict()
labels = tf.greater_equal(inputs['voxels'], 0.5)
predictions = tf.greater_equal(outputs['voxels_1'], 0.5)
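# Recode occupancy into integer class ids so that only the "occupied" class
# is shared between labels and predictions: ground truth becomes
# {occupied: 1, empty: 2} and predictions become {occupied: 1, empty: 3}.
# mean_iou then effectively measures foreground IoU averaged over the class
# count, which the rescaling by 3.0 below undoes.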
labels = 2 - tf.to_int32(labels)
predictions = 3 - tf.to_int32(predictions) * 2
tmp_values, tmp_updates = tf.metrics.mean_iou(
labels=labels,
predictions=predictions,
num_classes=3)
names_to_values['volume_iou'] = tmp_values * 3.0
names_to_updates['volume_iou'] = tmp_updates
return names_to_values, names_to_updates
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementations for Im2Vox PTN (NIPS16) model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import losses
import metrics
import model_voxel_generation
import utils
from nets import im2vox_factory
slim = tf.contrib.slim
class model_PTN(model_voxel_generation.Im2Vox): # pylint:disable=invalid-name
"""Inherits the generic Im2Vox model class and implements the functions."""
def __init__(self, params):
super(model_PTN, self).__init__(params)
# For testing, this selects all views in input
def preprocess_with_all_views(self, raw_inputs):
(quantity, num_views) = raw_inputs['images'].get_shape().as_list()[:2]
inputs = dict()
inputs['voxels'] = []
inputs['images_1'] = []
for k in xrange(num_views):
inputs['matrix_%d' % (k + 1)] = []
inputs['matrix_1'] = []
for n in xrange(quantity):
for k in xrange(num_views):
inputs['images_1'].append(raw_inputs['images'][n, k, :, :, :])
inputs['voxels'].append(raw_inputs['voxels'][n, :, :, :, :])
tf_matrix = self.get_transform_matrix(k)
inputs['matrix_%d' % (k + 1)].append(tf_matrix)
inputs['images_1'] = tf.stack(inputs['images_1'])
inputs['voxels'] = tf.stack(inputs['voxels'])
for k in xrange(num_views):
inputs['matrix_%d' % (k + 1)] = tf.stack(inputs['matrix_%d' % (k + 1)])
return inputs
def get_model_fn(self, is_training=True, reuse=False, run_projection=True):
return im2vox_factory.get(self._params, is_training, reuse, run_projection)
def get_regularization_loss(self, scopes):
return losses.regularization_loss(scopes, self._params)
def get_loss(self, inputs, outputs):
"""Computes the loss used for PTN paper (projection + volume loss)."""
g_loss = tf.zeros(dtype=tf.float32, shape=[])
if self._params.proj_weight:
g_loss += losses.add_volume_proj_loss(
inputs, outputs, self._params.step_size, self._params.proj_weight)
if self._params.volume_weight:
g_loss += losses.add_volume_loss(inputs, outputs, 1,
self._params.volume_weight)
slim.summaries.add_scalar_summary(g_loss, 'im2vox_loss', prefix='losses')
return g_loss
def get_metrics(self, inputs, outputs):
"""Aggregate the metrics for voxel generation model.
Args:
inputs: Input dictionary of the voxel generation model.
outputs: Output dictionary returned by the voxel generation model.
Returns:
names_to_values: metrics->values (dict).
names_to_updates: metrics->ops (dict).
"""
names_to_values = dict()
names_to_updates = dict()
tmp_values, tmp_updates = metrics.add_volume_iou_metrics(inputs, outputs)
names_to_values.update(tmp_values)
names_to_updates.update(tmp_updates)
for name, value in names_to_values.iteritems():
slim.summaries.add_scalar_summary(
value, name, prefix='eval', print_summary=True)
return names_to_values, names_to_updates
def write_disk_grid(self,
global_step,
log_dir,
input_images,
gt_projs,
pred_projs,
input_voxels=None,
output_voxels=None):
"""Function called by TF to save the prediction periodically."""
summary_freq = self._params.save_every
def write_grid(input_images, gt_projs, pred_projs, global_step,
input_voxels, output_voxels):
"""Native python function to call for writing images to files."""
grid = _build_image_grid(
input_images,
gt_projs,
pred_projs,
input_voxels=input_voxels,
output_voxels=output_voxels)
if global_step % summary_freq == 0:
img_path = os.path.join(log_dir, '%s.jpg' % str(global_step))
utils.save_image(grid, img_path)
return grid
save_op = tf.py_func(write_grid, [
input_images, gt_projs, pred_projs, global_step, input_voxels,
output_voxels
], [tf.uint8], 'write_grid')[0]
slim.summaries.add_image_summary(
tf.expand_dims(save_op, axis=0), name='grid_vis')
return save_op
def get_transform_matrix(self, view_out):
"""Get the 4x4 Perspective Transfromation matrix used for PTN."""
num_views = self._params.num_views
focal_length = self._params.focal_length
focal_range = self._params.focal_range
phi = 30
theta_interval = 360.0 / num_views
theta = theta_interval * view_out
# pylint: disable=invalid-name
camera_matrix = np.zeros((4, 4), dtype=np.float32)
intrinsic_matrix = np.eye(4, dtype=np.float32)
extrinsic_matrix = np.eye(4, dtype=np.float32)
sin_phi = np.sin(float(phi) / 180.0 * np.pi)
cos_phi = np.cos(float(phi) / 180.0 * np.pi)
sin_theta = np.sin(float(-theta) / 180.0 * np.pi)
cos_theta = np.cos(float(-theta) / 180.0 * np.pi)
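## rotation axis -- y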
rotation_azimuth = np.zeros((3, 3), dtype=np.float32)
rotation_azimuth[0, 0] = cos_theta
rotation_azimuth[2, 2] = cos_theta
rotation_azimuth[0, 2] = -sin_theta
rotation_azimuth[2, 0] = sin_theta
rotation_azimuth[1, 1] = 1.0
## rotation axis -- x
rotation_elevation = np.zeros((3, 3), dtype=np.float32)
rotation_elevation[0, 0] = cos_phi
rotation_elevation[0, 1] = sin_phi
rotation_elevation[1, 0] = -sin_phi
rotation_elevation[1, 1] = cos_phi
rotation_elevation[2, 2] = 1.0
rotation_matrix = np.matmul(rotation_azimuth, rotation_elevation)
displacement = np.zeros((3, 1), dtype=np.float32)
displacement[0, 0] = float(focal_length) + float(focal_range) / 2.0
displacement = np.matmul(rotation_matrix, displacement)
extrinsic_matrix[0:3, 0:3] = rotation_matrix
extrinsic_matrix[0:3, 3:4] = -displacement
intrinsic_matrix[2, 2] = 1.0 / float(focal_length)
intrinsic_matrix[1, 1] = 1.0 / float(focal_length)
camera_matrix = np.matmul(extrinsic_matrix, intrinsic_matrix)
return camera_matrix
def _build_image_grid(input_images,
gt_projs,
pred_projs,
input_voxels,
output_voxels,
vis_size=128):
"""Builds a grid image by concatenating the input images."""
quantity = input_images.shape[0]
for row in xrange(int(quantity / 3)):
for col in xrange(3):
index = row * 3 + col
input_img_ = utils.resize_image(input_images[index, :, :, :], vis_size,
vis_size)
gt_proj_ = utils.resize_image(gt_projs[index, :, :, :], vis_size,
vis_size)
pred_proj_ = utils.resize_image(pred_projs[index, :, :, :], vis_size,
vis_size)
gt_voxel_vis = utils.resize_image(
utils.display_voxel(input_voxels[index, :, :, :, 0]), vis_size,
vis_size)
pred_voxel_vis = utils.resize_image(
utils.display_voxel(output_voxels[index, :, :, :, 0]), vis_size,
vis_size)
if col == 0:
tmp_ = np.concatenate(
[input_img_, gt_proj_, pred_proj_, gt_voxel_vis, pred_voxel_vis], 1)
else:
tmp_ = np.concatenate([
tmp_, input_img_, gt_proj_, pred_proj_, gt_voxel_vis, pred_voxel_vis
], 1)
if row == 0:
out_grid = tmp_
else:
out_grid = np.concatenate([out_grid, tmp_], 0)
return out_grid
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for pretraining (rotator) as described in PTN paper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import input_generator
import losses
import metrics
import utils
from nets import deeprotator_factory
slim = tf.contrib.slim
def _get_data_from_provider(inputs, batch_size, split_name):
"""Returns dictionary of batch input data processed by tf.train.batch."""
images, masks = tf.train.batch(
[inputs['image'], inputs['mask']],
batch_size=batch_size,
num_threads=8,
capacity=8 * batch_size,
name='batching_queues/%s' % (split_name))
outputs = dict()
outputs['images'] = images
outputs['masks'] = masks
outputs['num_samples'] = inputs['num_samples']
return outputs
def get_inputs(dataset_dir, dataset_name, split_name, batch_size, image_size,
is_training):
"""Loads the given dataset and split."""
del image_size # Unused
with tf.variable_scope('data_loading_%s/%s' % (dataset_name, split_name)):
common_queue_min = 50
common_queue_capacity = 256
num_readers = 4
inputs = input_generator.get(
dataset_dir,
dataset_name,
split_name,
shuffle=is_training,
num_readers=num_readers,
common_queue_min=common_queue_min,
common_queue_capacity=common_queue_capacity)
return _get_data_from_provider(inputs, batch_size, split_name)
def preprocess(raw_inputs, step_size):
"""Selects the subset of viewpoints to train on."""
shp = raw_inputs['images'].get_shape().as_list()
quantity = shp[0]
num_views = shp[1]
image_size = shp[2]
del image_size # Unused
batch_rot = np.zeros((quantity, 3), dtype=np.float32)
inputs = dict()
for n in xrange(step_size + 1):
inputs['images_%d' % n] = []
inputs['masks_%d' % n] = []
for n in xrange(quantity):
view_in = np.random.randint(0, num_views)
rng_rot = np.random.randint(0, 2)
if step_size == 1:
rng_rot = np.random.randint(0, 3)
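# rng_rot picks the rotation action, encoded one-hot in batch_rot:
# 0 = step backward through the views, 1 = step forward,
# 2 = stay at the same view (only sampled when step_size == 1).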
delta = 0
if rng_rot == 0:
delta = -1
batch_rot[n, 2] = 1
elif rng_rot == 1:
delta = 1
batch_rot[n, 0] = 1
else:
delta = 0
batch_rot[n, 1] = 1
inputs['images_0'].append(raw_inputs['images'][n, view_in, :, :, :])
inputs['masks_0'].append(raw_inputs['masks'][n, view_in, :, :, :])
view_out = view_in
for k in xrange(1, step_size + 1):
view_out += delta
if view_out >= num_views:
view_out = 0
if view_out < 0:
view_out = num_views - 1
inputs['images_%d' % k].append(raw_inputs['images'][n, view_out, :, :, :])
inputs['masks_%d' % k].append(raw_inputs['masks'][n, view_out, :, :, :])
for n in xrange(step_size + 1):
inputs['images_%d' % n] = tf.stack(inputs['images_%d' % n])
inputs['masks_%d' % n] = tf.stack(inputs['masks_%d' % n])
inputs['actions'] = tf.constant(batch_rot, dtype=tf.float32)
return inputs
def get_init_fn(scopes, params):
"""Initialization assignment operator function used while training."""
if not params.init_model:
return None
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
params.init_model, var_list)
def init_assign_function(sess):
sess.run(init_assign_op, init_feed_dict)
return init_assign_function
def get_model_fn(params, is_training, reuse=False):
return deeprotator_factory.get(params, is_training, reuse)
def get_regularization_loss(scopes, params):
return losses.regularization_loss(scopes, params)
def get_loss(inputs, outputs, params):
"""Computes the rotator loss."""
g_loss = tf.zeros(dtype=tf.float32, shape=[])
if hasattr(params, 'image_weight'):
g_loss += losses.add_rotator_image_loss(inputs, outputs, params.step_size,
params.image_weight)
if hasattr(params, 'mask_weight'):
g_loss += losses.add_rotator_mask_loss(inputs, outputs, params.step_size,
params.mask_weight)
slim.summaries.add_scalar_summary(
g_loss, 'rotator_loss', prefix='losses')
return g_loss
def get_train_op_for_scope(loss, optimizer, scopes, params):
"""Train operation function for the given scope used file training."""
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
update_ops = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
return slim.learning.create_train_op(
loss,
optimizer,
update_ops=update_ops,
variables_to_train=var_list,
clip_gradient_norm=params.clip_gradient_norm)
def get_metrics(inputs, outputs, params):
names_to_values, names_to_updates = metrics.rotator_metrics(
inputs, outputs, params)
return names_to_values, names_to_updates
def write_disk_grid(global_step, summary_freq, log_dir, input_images,
output_images, pred_images, pred_masks):
"""Function called by TF to save the prediction periodically."""
def write_grid(grid, global_step):
"""Native python function to call for writing images to files."""
if global_step % summary_freq == 0:
img_path = os.path.join(log_dir, '%s.jpg' % str(global_step))
utils.save_image(grid, img_path)
return 0
grid = _build_image_grid(input_images, output_images, pred_images, pred_masks)
slim.summaries.add_image_summary(
tf.expand_dims(grid, axis=0), name='grid_vis')
save_op = tf.py_func(write_grid, [grid, global_step], [tf.int64],
'write_grid')[0]
return save_op
def _build_image_grid(input_images, output_images, pred_images, pred_masks):
"""Builds a grid image by concatenating the input images."""
quantity = input_images.get_shape().as_list()[0]
for row in xrange(int(quantity / 4)):
for col in xrange(4):
index = row * 4 + col
input_img_ = input_images[index, :, :, :]
output_img_ = output_images[index, :, :, :]
pred_img_ = pred_images[index, :, :, :]
pred_mask_ = tf.tile(pred_masks[index, :, :, :], [1, 1, 3])
if col == 0:
tmp_ = tf.concat([input_img_, output_img_, pred_img_, pred_mask_],
1) ## to the right
else:
tmp_ = tf.concat([tmp_, input_img_, output_img_, pred_img_, pred_mask_],
1)
if row == 0:
out_grid = tmp_
else:
out_grid = tf.concat([out_grid, tmp_], 0)
return out_grid
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for voxel generation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import os
import numpy as np
import tensorflow as tf
import input_generator
import utils
slim = tf.contrib.slim
class Im2Vox(object):
"""Defines the voxel generation model."""
__metaclass__ = abc.ABCMeta
def __init__(self, params):
self._params = params
@abc.abstractmethod
def get_metrics(self, inputs, outputs):
"""Gets dictionaries from metrics to value `Tensors` & update `Tensors`."""
pass
@abc.abstractmethod
def get_loss(self, inputs, outputs):
pass
@abc.abstractmethod
def get_regularization_loss(self, scopes):
pass
def set_params(self, params):
self._params = params
def get_inputs(self,
dataset_dir,
dataset_name,
split_name,
batch_size,
image_size,
vox_size,
is_training=True):
"""Loads data for a specified dataset and split."""
del image_size, vox_size
with tf.variable_scope('data_loading_%s/%s' % (dataset_name, split_name)):
common_queue_min = 64
common_queue_capacity = 256
num_readers = 4
inputs = input_generator.get(
dataset_dir,
dataset_name,
split_name,
shuffle=is_training,
num_readers=num_readers,
common_queue_min=common_queue_min,
common_queue_capacity=common_queue_capacity)
images, voxels = tf.train.batch(
[inputs['image'], inputs['voxel']],
batch_size=batch_size,
num_threads=8,
capacity=8 * batch_size,
name='batching_queues/%s/%s' % (dataset_name, split_name))
outputs = dict()
outputs['images'] = images
outputs['voxels'] = voxels
outputs['num_samples'] = inputs['num_samples']
return outputs
def preprocess(self, raw_inputs, step_size):
"""Selects the subset of viewpoints to train on."""
(quantity, num_views) = raw_inputs['images'].get_shape().as_list()[:2]
inputs = dict()
inputs['voxels'] = raw_inputs['voxels']
for k in xrange(step_size):
inputs['images_%d' % (k + 1)] = []
inputs['matrix_%d' % (k + 1)] = []
for n in xrange(quantity):
selected_views = np.random.choice(num_views, step_size, replace=False)
for k in xrange(step_size):
view_selected = selected_views[k]
inputs['images_%d' %
(k + 1)].append(raw_inputs['images'][n, view_selected, :, :, :])
tf_matrix = self.get_transform_matrix(view_selected)
inputs['matrix_%d' % (k + 1)].append(tf_matrix)
for k in xrange(step_size):
inputs['images_%d' % (k + 1)] = tf.stack(inputs['images_%d' % (k + 1)])
inputs['matrix_%d' % (k + 1)] = tf.stack(inputs['matrix_%d' % (k + 1)])
return inputs
def get_init_fn(self, scopes):
"""Initialization assignment operator function used while training."""
if not self._params.init_model:
return None
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
self._params.init_model, var_list)
def init_assign_function(sess):
sess.run(init_assign_op, init_feed_dict)
return init_assign_function
def get_train_op_for_scope(self, loss, optimizer, scopes):
"""Train operation function for the given scope used file training."""
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
update_ops = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
return slim.learning.create_train_op(
loss,
optimizer,
update_ops=update_ops,
variables_to_train=var_list,
clip_gradient_norm=self._params.clip_gradient_norm)
def write_disk_grid(self,
global_step,
log_dir,
input_images,
gt_projs,
pred_projs,
pred_voxels=None):
"""Function called by TF to save the prediction periodically."""
summary_freq = self._params.save_every
def write_grid(input_images, gt_projs, pred_projs, pred_voxels,
global_step):
"""Native python function to call for writing images to files."""
grid = _build_image_grid(input_images, gt_projs, pred_projs, pred_voxels)
if global_step % summary_freq == 0:
img_path = os.path.join(log_dir, '%s.jpg' % str(global_step))
utils.save_image(grid, img_path)
with open(
os.path.join(log_dir, 'pred_voxels_%s' % str(global_step)),
'w') as fout:
np.save(fout, pred_voxels)
with open(
os.path.join(log_dir, 'input_images_%s' % str(global_step)),
'w') as fout:
np.save(fout, input_images)
return grid
py_func_args = [
input_images, gt_projs, pred_projs, pred_voxels, global_step
]
save_grid_op = tf.py_func(write_grid, py_func_args, [tf.uint8],
'write_grid')[0]
slim.summaries.add_image_summary(
tf.expand_dims(save_grid_op, axis=0), name='grid_vis')
return save_grid_op
def _build_image_grid(input_images, gt_projs, pred_projs, pred_voxels):
"""Build the visualization grid with py_func."""
quantity, img_height, img_width = input_images.shape[:3]
for row in xrange(int(quantity / 3)):
for col in xrange(3):
index = row * 3 + col
input_img_ = input_images[index, :, :, :]
gt_proj_ = gt_projs[index, :, :, :]
pred_proj_ = pred_projs[index, :, :, :]
pred_voxel_ = utils.display_voxel(pred_voxels[index, :, :, :, 0])
pred_voxel_ = utils.resize_image(pred_voxel_, img_height, img_width)
if col == 0:
tmp_ = np.concatenate([input_img_, gt_proj_, pred_proj_, pred_voxel_],
1)
else:
tmp_ = np.concatenate(
[tmp_, input_img_, gt_proj_, pred_proj_, pred_voxel_], 1)
if row == 0:
out_grid = tmp_
else:
out_grid = np.concatenate([out_grid, tmp_], 0)
out_grid = out_grid.astype(np.uint8)
return out_grid
package(default_visibility = ["//visibility:public"])
py_library(
name = "deeprotator_factory",
srcs = ["deeprotator_factory.py"],
deps = [
":ptn_encoder",
":ptn_im_decoder",
":ptn_rotator",
],
)
py_library(
name = "im2vox_factory",
srcs = ["im2vox_factory.py"],
deps = [
":perspective_projector",
":ptn_encoder",
":ptn_vox_decoder",
],
)
py_library(
name = "perspective_projector",
srcs = ["perspective_projector.py"],
deps = [
":perspective_transform",
],
)
py_library(
name = "perspective_transform",
srcs = ["perspective_transform.py"],
deps = [
],
)
py_library(
name = "ptn_encoder",
srcs = ["ptn_encoder.py"],
deps = [
],
)
py_library(
name = "ptn_im_decoder",
srcs = ["ptn_im_decoder.py"],
deps = [
],
)
py_library(
name = "ptn_rotator",
srcs = ["ptn_rotator.py"],
deps = [
],
)
py_library(
name = "ptn_vox_decoder",
srcs = ["ptn_vox_decoder.py"],
deps = [
],
)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory module for different encoder/decoder network models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import ptn_encoder
from nets import ptn_im_decoder
from nets import ptn_rotator
_NAME_TO_NETS = {
'ptn_encoder': ptn_encoder,
'ptn_rotator': ptn_rotator,
'ptn_im_decoder': ptn_im_decoder,
}
def _get_network(name):
"""Gets a single network component."""
if name not in _NAME_TO_NETS:
raise ValueError('Network name [%s] not recognized.' % name)
return _NAME_TO_NETS[name].model
def get(params, is_training=False, reuse=False):
"""Factory function to retrieve a network model.
Args:
params: Different parameters used throughout ptn, typically FLAGS (dict)
is_training: Set to True while training (boolean)
reuse: Set as True if either using a pre-trained model or
in the training loop while the graph has already been built (boolean)
Returns:
Model function for network (inputs to outputs)
"""
def model(inputs):
"""Model function corresponding to a specific network architecture."""
outputs = {}
# First, build the encoder.
encoder_fn = _get_network(params.encoder_name)
with tf.variable_scope('encoder', reuse=reuse):
# Produces id/pose units
features = encoder_fn(inputs['images_0'], params, is_training)
outputs['ids'] = features['ids']
outputs['poses_0'] = features['poses']
# Second, build the rotator and decoder.
rotator_fn = _get_network(params.rotator_name)
with tf.variable_scope('rotator', reuse=reuse):
outputs['poses_1'] = rotator_fn(outputs['poses_0'], inputs['actions'],
params, is_training)
decoder_fn = _get_network(params.decoder_name)
with tf.variable_scope('decoder', reuse=reuse):
dec_output = decoder_fn(outputs['ids'], outputs['poses_1'], params,
is_training)
outputs['images_1'] = dec_output['images']
outputs['masks_1'] = dec_output['masks']
# Third, build the recurrent connection
for k in range(1, params.step_size):
with tf.variable_scope('rotator', reuse=True):
outputs['poses_%d' % (k + 1)] = rotator_fn(
outputs['poses_%d' % k], inputs['actions'], params, is_training)
with tf.variable_scope('decoder', reuse=True):
dec_output = decoder_fn(outputs['ids'], outputs['poses_%d' % (k + 1)],
params, is_training)
outputs['images_%d' % (k + 1)] = dec_output['images']
outputs['masks_%d' % (k + 1)] = dec_output['masks']
return outputs
return model
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory module for getting the complete image to voxel generation network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import perspective_projector
from nets import ptn_encoder
from nets import ptn_vox_decoder
_NAME_TO_NETS = {
'ptn_encoder': ptn_encoder,
'ptn_vox_decoder': ptn_vox_decoder,
'perspective_projector': perspective_projector,
}
def _get_network(name):
"""Gets a single encoder/decoder network model."""
if name not in _NAME_TO_NETS:
raise ValueError('Network name [%s] not recognized.' % name)
return _NAME_TO_NETS[name].model
def get(params, is_training=False, reuse=False, run_projection=True):
"""Factory function to get the training/pretraining im->vox model (NIPS16).
Args:
params: Different parameters used throughout ptn, typically FLAGS (dict).
is_training: Set to True while training (boolean).
reuse: Set as True if sharing variables with a model that has already
been built (boolean).
run_projection: Set as False if not interested in mask and projection
images. Useful in evaluation routine (boolean).
Returns:
Model function for network (inputs to outputs).
"""
def model(inputs):
"""Model function corresponding to a specific network architecture."""
outputs = {}
# First, build the encoder
encoder_fn = _get_network(params.encoder_name)
with tf.variable_scope('encoder', reuse=reuse):
# Produces id/pose units
enc_outputs = encoder_fn(inputs['images_1'], params, is_training)
outputs['ids_1'] = enc_outputs['ids']
# Second, build the decoder and projector
decoder_fn = _get_network(params.decoder_name)
with tf.variable_scope('decoder', reuse=reuse):
outputs['voxels_1'] = decoder_fn(outputs['ids_1'], params, is_training)
if run_projection:
projector_fn = _get_network(params.projector_name)
with tf.variable_scope('projector', reuse=reuse):
outputs['projs_1'] = projector_fn(
outputs['voxels_1'], inputs['matrix_1'], params, is_training)
# Infer the ground-truth mask
with tf.variable_scope('oracle', reuse=reuse):
outputs['masks_1'] = projector_fn(inputs['voxels'], inputs['matrix_1'],
params, False)
# Third, build the entire graph (bundled strategy described in PTN paper)
for k in range(1, params.step_size):
with tf.variable_scope('projector', reuse=True):
outputs['projs_%d' % (k + 1)] = projector_fn(
outputs['voxels_1'], inputs['matrix_%d' %
(k + 1)], params, is_training)
with tf.variable_scope('oracle', reuse=True):
outputs['masks_%d' % (k + 1)] = projector_fn(
inputs['voxels'], inputs['matrix_%d' % (k + 1)], params, False)
return outputs
return model
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""3D->2D projector model as used in PTN (NIPS16)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import perspective_transform
def model(voxels, transform_matrix, params, is_training):
"""Model transforming the 3D voxels into 2D projections.
Args:
voxels: A tensor of size [batch, depth, height, width, channel]
representing the input of projection layer (tf.float32).
transform_matrix: A tensor of size [batch, 16] representing
the flattened 4-by-4 matrix for transformation (tf.float32).
params: Model parameters (dict).
is_training: Set to True while training (boolean).
Returns:
A transformed tensor (tf.float32)
"""
del is_training # Doesn't make a difference for projector
# Rearrangement (batch, z, y, x, channel) --> (batch, y, z, x, channel).
# By the standard, projection happens along z-axis but the voxels
# are stored in a different way. So we need to switch the y and z
# axis for transformation operation.
voxels = tf.transpose(voxels, [0, 2, 1, 3, 4])
z_near = params.focal_length
z_far = params.focal_length + params.focal_range
transformed_voxels = perspective_transform.transformer(
voxels, transform_matrix, [params.vox_size] * 3, z_near, z_far)
views = tf.reduce_max(transformed_voxels, [1])
views = tf.reverse(views, [1])
return views
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Perspective Transformer Layer Implementation.
Transform the volume based on 4 x 4 perspective projection matrix.
Reference:
(1) "Perspective Transformer Nets: Perspective Transformer Nets:
Learning Single-View 3D Object Reconstruction without 3D Supervision."
Xinchen Yan, Jimei Yang, Ersin Yumer, Yijie Guo, Honglak Lee. In NIPS 2016
https://papers.nips.cc/paper/6206-perspective-transformer-nets-learning-single-view-3d-object-reconstruction-without-3d-supervision.pdf
(2) Official implementation in Torch: https://github.com/xcyan/ptnbhwd
(3) 2D Transformer implementation in TF:
github.com/tensorflow/models/tree/master/transformer
"""
import tensorflow as tf
def transformer(voxels,
theta,
out_size,
z_near,
z_far,
name='PerspectiveTransformer'):
"""Perspective Transformer Layer.
Args:
voxels: A tensor of size [num_batch, depth, height, width, num_channels].
It is the output of a deconv/upsampling conv network (tf.float32).
theta: A tensor of size [num_batch, 16].
It is the inverse camera transformation matrix (tf.float32).
    out_size: A list of 3 integers [out_depth, out_height, out_width]
      representing the output size of the transformer layer.
    z_near: A number representing the near clipping plane (float).
    z_far: A number representing the far clipping plane (float).
    name: Optional name of the variable scope used for the layer (string).
Returns:
A transformed tensor (tf.float32).
"""
def _repeat(x, n_repeats):
with tf.variable_scope('_repeat'):
rep = tf.transpose(
tf.expand_dims(tf.ones(shape=tf.stack([
n_repeats,
])), 1), [1, 0])
rep = tf.to_int32(rep)
x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
return tf.reshape(x, [-1])
def _interpolate(im, x, y, z, out_size):
"""Bilinear interploation layer.
Args:
im: A 5D tensor of size [num_batch, depth, height, width, num_channels].
It is the input volume for the transformation layer (tf.float32).
x: A tensor of size [num_batch, out_depth, out_height, out_width]
representing the inverse coordinate mapping for x (tf.float32).
y: A tensor of size [num_batch, out_depth, out_height, out_width]
representing the inverse coordinate mapping for y (tf.float32).
z: A tensor of size [num_batch, out_depth, out_height, out_width]
representing the inverse coordinate mapping for z (tf.float32).
      out_size: A list of 3 integers [out_depth, out_height, out_width]
        representing the output size of the transformation layer.
Returns:
A transformed tensor (tf.float32).
"""
with tf.variable_scope('_interpolate'):
num_batch = im.get_shape().as_list()[0]
depth = im.get_shape().as_list()[1]
height = im.get_shape().as_list()[2]
width = im.get_shape().as_list()[3]
channels = im.get_shape().as_list()[4]
x = tf.to_float(x)
y = tf.to_float(y)
z = tf.to_float(z)
depth_f = tf.to_float(depth)
height_f = tf.to_float(height)
width_f = tf.to_float(width)
      # Dimensions of the output volume (depth, height, width).
out_depth = out_size[0]
out_height = out_size[1]
out_width = out_size[2]
zero = tf.zeros([], dtype='int32')
# 0 <= z < depth, 0 <= y < height & 0 <= x < width.
max_z = tf.to_int32(tf.shape(im)[1] - 1)
max_y = tf.to_int32(tf.shape(im)[2] - 1)
max_x = tf.to_int32(tf.shape(im)[3] - 1)
# Converts scale indices from [-1, 1] to [0, width/height/depth].
x = (x + 1.0) * (width_f) / 2.0
y = (y + 1.0) * (height_f) / 2.0
z = (z + 1.0) * (depth_f) / 2.0
x0 = tf.to_int32(tf.floor(x))
x1 = x0 + 1
y0 = tf.to_int32(tf.floor(y))
y1 = y0 + 1
z0 = tf.to_int32(tf.floor(z))
z1 = z0 + 1
x0_clip = tf.clip_by_value(x0, zero, max_x)
x1_clip = tf.clip_by_value(x1, zero, max_x)
y0_clip = tf.clip_by_value(y0, zero, max_y)
y1_clip = tf.clip_by_value(y1, zero, max_y)
z0_clip = tf.clip_by_value(z0, zero, max_z)
z1_clip = tf.clip_by_value(z1, zero, max_z)
dim3 = width
dim2 = width * height
dim1 = width * height * depth
base = _repeat(
tf.range(num_batch) * dim1, out_depth * out_height * out_width)
base_z0_y0 = base + z0_clip * dim2 + y0_clip * dim3
base_z0_y1 = base + z0_clip * dim2 + y1_clip * dim3
base_z1_y0 = base + z1_clip * dim2 + y0_clip * dim3
base_z1_y1 = base + z1_clip * dim2 + y1_clip * dim3
idx_z0_y0_x0 = base_z0_y0 + x0_clip
idx_z0_y0_x1 = base_z0_y0 + x1_clip
idx_z0_y1_x0 = base_z0_y1 + x0_clip
idx_z0_y1_x1 = base_z0_y1 + x1_clip
idx_z1_y0_x0 = base_z1_y0 + x0_clip
idx_z1_y0_x1 = base_z1_y0 + x1_clip
idx_z1_y1_x0 = base_z1_y1 + x0_clip
idx_z1_y1_x1 = base_z1_y1 + x1_clip
# Use indices to lookup pixels in the flat image and restore
# channels dim
im_flat = tf.reshape(im, tf.stack([-1, channels]))
im_flat = tf.to_float(im_flat)
i_z0_y0_x0 = tf.gather(im_flat, idx_z0_y0_x0)
i_z0_y0_x1 = tf.gather(im_flat, idx_z0_y0_x1)
i_z0_y1_x0 = tf.gather(im_flat, idx_z0_y1_x0)
i_z0_y1_x1 = tf.gather(im_flat, idx_z0_y1_x1)
i_z1_y0_x0 = tf.gather(im_flat, idx_z1_y0_x0)
i_z1_y0_x1 = tf.gather(im_flat, idx_z1_y0_x1)
i_z1_y1_x0 = tf.gather(im_flat, idx_z1_y1_x0)
i_z1_y1_x1 = tf.gather(im_flat, idx_z1_y1_x1)
# Finally calculate interpolated values.
x0_f = tf.to_float(x0)
x1_f = tf.to_float(x1)
y0_f = tf.to_float(y0)
y1_f = tf.to_float(y1)
z0_f = tf.to_float(z0)
z1_f = tf.to_float(z1)
# Check the out-of-boundary case.
x0_valid = tf.to_float(
tf.less_equal(x0, max_x) & tf.greater_equal(x0, 0))
x1_valid = tf.to_float(
tf.less_equal(x1, max_x) & tf.greater_equal(x1, 0))
y0_valid = tf.to_float(
tf.less_equal(y0, max_y) & tf.greater_equal(y0, 0))
y1_valid = tf.to_float(
tf.less_equal(y1, max_y) & tf.greater_equal(y1, 0))
z0_valid = tf.to_float(
tf.less_equal(z0, max_z) & tf.greater_equal(z0, 0))
z1_valid = tf.to_float(
tf.less_equal(z1, max_z) & tf.greater_equal(z1, 0))
w_z0_y0_x0 = tf.expand_dims(((x1_f - x) * (y1_f - y) *
(z1_f - z) * x1_valid * y1_valid * z1_valid),
1)
w_z0_y0_x1 = tf.expand_dims(((x - x0_f) * (y1_f - y) *
(z1_f - z) * x0_valid * y1_valid * z1_valid),
1)
w_z0_y1_x0 = tf.expand_dims(((x1_f - x) * (y - y0_f) *
(z1_f - z) * x1_valid * y0_valid * z1_valid),
1)
w_z0_y1_x1 = tf.expand_dims(((x - x0_f) * (y - y0_f) *
(z1_f - z) * x0_valid * y0_valid * z1_valid),
1)
w_z1_y0_x0 = tf.expand_dims(((x1_f - x) * (y1_f - y) *
(z - z0_f) * x1_valid * y1_valid * z0_valid),
1)
w_z1_y0_x1 = tf.expand_dims(((x - x0_f) * (y1_f - y) *
(z - z0_f) * x0_valid * y1_valid * z0_valid),
1)
w_z1_y1_x0 = tf.expand_dims(((x1_f - x) * (y - y0_f) *
(z - z0_f) * x1_valid * y0_valid * z0_valid),
1)
w_z1_y1_x1 = tf.expand_dims(((x - x0_f) * (y - y0_f) *
(z - z0_f) * x0_valid * y0_valid * z0_valid),
1)
output = tf.add_n([
w_z0_y0_x0 * i_z0_y0_x0, w_z0_y0_x1 * i_z0_y0_x1,
w_z0_y1_x0 * i_z0_y1_x0, w_z0_y1_x1 * i_z0_y1_x1,
w_z1_y0_x0 * i_z1_y0_x0, w_z1_y0_x1 * i_z1_y0_x1,
w_z1_y1_x0 * i_z1_y1_x0, w_z1_y1_x1 * i_z1_y1_x1
])
return output
def _meshgrid(depth, height, width, z_near, z_far):
with tf.variable_scope('_meshgrid'):
x_t = tf.reshape(
tf.tile(tf.linspace(-1.0, 1.0, width), [height * depth]),
[depth, height, width])
y_t = tf.reshape(
tf.tile(tf.linspace(-1.0, 1.0, height), [width * depth]),
[depth, width, height])
y_t = tf.transpose(y_t, [0, 2, 1])
sample_grid = tf.tile(
tf.linspace(float(z_near), float(z_far), depth), [width * height])
z_t = tf.reshape(sample_grid, [height, width, depth])
z_t = tf.transpose(z_t, [2, 0, 1])
z_t = 1 / z_t
d_t = 1 / z_t
x_t /= z_t
y_t /= z_t
x_t_flat = tf.reshape(x_t, (1, -1))
y_t_flat = tf.reshape(y_t, (1, -1))
d_t_flat = tf.reshape(d_t, (1, -1))
ones = tf.ones_like(x_t_flat)
grid = tf.concat([d_t_flat, y_t_flat, x_t_flat, ones], 0)
return grid
def _transform(theta, input_dim, out_size, z_near, z_far):
with tf.variable_scope('_transform'):
num_batch = input_dim.get_shape().as_list()[0]
num_channels = input_dim.get_shape().as_list()[4]
theta = tf.reshape(theta, (-1, 4, 4))
theta = tf.cast(theta, 'float32')
out_depth = out_size[0]
out_height = out_size[1]
out_width = out_size[2]
grid = _meshgrid(out_depth, out_height, out_width, z_near, z_far)
grid = tf.expand_dims(grid, 0)
grid = tf.reshape(grid, [-1])
grid = tf.tile(grid, tf.stack([num_batch]))
grid = tf.reshape(grid, tf.stack([num_batch, 4, -1]))
# Transform A x (x_t', y_t', 1, d_t)^T -> (x_s, y_s, z_s, 1).
t_g = tf.matmul(theta, grid)
z_s = tf.slice(t_g, [0, 0, 0], [-1, 1, -1])
y_s = tf.slice(t_g, [0, 1, 0], [-1, 1, -1])
x_s = tf.slice(t_g, [0, 2, 0], [-1, 1, -1])
z_s_flat = tf.reshape(z_s, [-1])
y_s_flat = tf.reshape(y_s, [-1])
x_s_flat = tf.reshape(x_s, [-1])
input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, z_s_flat,
out_size)
output = tf.reshape(
input_transformed,
tf.stack([num_batch, out_depth, out_height, out_width, num_channels]))
return output
with tf.variable_scope(name):
output = _transform(theta, voxels, out_size, z_near, z_far)
return output
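# --- Usage sketch (not part of the original file) ---
# Applies the transformer to a dummy volume with identity camera matrices.
# The near/far clipping planes below are illustrative assumptions only.
if __name__ == '__main__':
  batch, vox = 2, 16
  dummy_voxels = tf.zeros([batch, vox, vox, vox, 1], dtype=tf.float32)
  # One flattened 4x4 identity matrix per batch element.
  identity_theta = tf.tile(tf.reshape(tf.eye(4), [1, 16]), [batch, 1])
  warped = transformer(dummy_voxels, identity_theta, [vox, vox, vox],
                       z_near=0.866, z_far=2.598)
  with tf.Session() as sess:
    print(sess.run(warped).shape)  # (2, 16, 16, 16, 1)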
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training/Pretraining encoder as used in PTN (NIPS16)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def _preprocess(images):
return images * 2 - 1
def model(images, params, is_training):
"""Model encoding the images into view-invariant embedding."""
del is_training # Unused
image_size = images.get_shape().as_list()[1]
f_dim = params.f_dim
fc_dim = params.fc_dim
z_dim = params.z_dim
outputs = dict()
images = _preprocess(images)
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_initializer=tf.truncated_normal_initializer(stddev=0.02, seed=1)):
h0 = slim.conv2d(images, f_dim, [5, 5], stride=2, activation_fn=tf.nn.relu)
h1 = slim.conv2d(h0, f_dim * 2, [5, 5], stride=2, activation_fn=tf.nn.relu)
h2 = slim.conv2d(h1, f_dim * 4, [5, 5], stride=2, activation_fn=tf.nn.relu)
# Reshape layer
s8 = image_size // 8
h2 = tf.reshape(h2, [-1, s8 * s8 * f_dim * 4])
h3 = slim.fully_connected(h2, fc_dim, activation_fn=tf.nn.relu)
h4 = slim.fully_connected(h3, fc_dim, activation_fn=tf.nn.relu)
outputs['ids'] = slim.fully_connected(h4, z_dim, activation_fn=tf.nn.relu)
outputs['poses'] = slim.fully_connected(h4, z_dim, activation_fn=tf.nn.relu)
return outputs
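# --- Usage sketch (not part of the original file) ---
# Runs the encoder on a dummy batch. The dimensions (f_dim, fc_dim, z_dim)
# and the 64x64 input size are assumptions chosen for illustration, not
# necessarily the flag values used elsewhere in this repository.
if __name__ == '__main__':
  import collections

  EncoderParams = collections.namedtuple('EncoderParams',
                                         ['f_dim', 'fc_dim', 'z_dim'])
  params = EncoderParams(f_dim=64, fc_dim=1024, z_dim=512)
  images = tf.zeros([8, 64, 64, 3], dtype=tf.float32)
  with tf.variable_scope('encoder'):
    enc = model(images, params, is_training=True)
  print(enc['ids'], enc['poses'])  # Identity/pose embeddings, each [8, 512].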
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image/Mask decoder used while pretraining the network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
_FEATURE_MAP_SIZE = 8
def _postprocess_im(images):
"""Performs post-processing for the images returned from conv net.
Transforms the value from [-1, 1] to [0, 1].
"""
return (images + 1) * 0.5
def model(identities, poses, params, is_training):
"""Decoder model to get image and mask from latent embedding."""
del is_training
f_dim = params.f_dim
fc_dim = params.fc_dim
outputs = dict()
with slim.arg_scope(
[slim.fully_connected, slim.conv2d_transpose],
weights_initializer=tf.truncated_normal_initializer(stddev=0.02, seed=1)):
# Concatenate the identity and pose units
h0 = tf.concat([identities, poses], 1)
h0 = slim.fully_connected(h0, fc_dim, activation_fn=tf.nn.relu)
h1 = slim.fully_connected(h0, fc_dim, activation_fn=tf.nn.relu)
# Mask decoder
dec_m0 = slim.fully_connected(
h1, (_FEATURE_MAP_SIZE**2) * f_dim * 2, activation_fn=tf.nn.relu)
dec_m0 = tf.reshape(
dec_m0, [-1, _FEATURE_MAP_SIZE, _FEATURE_MAP_SIZE, f_dim * 2])
dec_m1 = slim.conv2d_transpose(
dec_m0, f_dim, [5, 5], stride=2, activation_fn=tf.nn.relu)
dec_m2 = slim.conv2d_transpose(
dec_m1, int(f_dim / 2), [5, 5], stride=2, activation_fn=tf.nn.relu)
dec_m3 = slim.conv2d_transpose(
dec_m2, 1, [5, 5], stride=2, activation_fn=tf.nn.sigmoid)
# Image decoder
dec_i0 = slim.fully_connected(
h1, (_FEATURE_MAP_SIZE**2) * f_dim * 4, activation_fn=tf.nn.relu)
dec_i0 = tf.reshape(
dec_i0, [-1, _FEATURE_MAP_SIZE, _FEATURE_MAP_SIZE, f_dim * 4])
dec_i1 = slim.conv2d_transpose(
dec_i0, f_dim * 2, [5, 5], stride=2, activation_fn=tf.nn.relu)
dec_i2 = slim.conv2d_transpose(
dec_i1, f_dim * 2, [5, 5], stride=2, activation_fn=tf.nn.relu)
dec_i3 = slim.conv2d_transpose(
dec_i2, 3, [5, 5], stride=2, activation_fn=tf.nn.tanh)
outputs['images'] = _postprocess_im(dec_i3)
outputs['masks'] = dec_m3
return outputs
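# --- Usage sketch (not part of the original file) ---
# Decodes dummy identity/pose embeddings back to an image and a mask. The
# embedding size and the f_dim/fc_dim values are illustrative assumptions.
if __name__ == '__main__':
  import collections

  DecoderParams = collections.namedtuple('DecoderParams', ['f_dim', 'fc_dim'])
  params = DecoderParams(f_dim=64, fc_dim=1024)
  identities = tf.zeros([8, 512], dtype=tf.float32)
  poses = tf.zeros([8, 512], dtype=tf.float32)
  with tf.variable_scope('decoder'):
    dec = model(identities, poses, params, is_training=True)
  # With _FEATURE_MAP_SIZE = 8 and three stride-2 deconvolutions, the outputs
  # are 64x64: images [8, 64, 64, 3] in [0, 1], masks [8, 64, 64, 1] in (0, 1).
  print(dec['images'], dec['masks'])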
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates rotator network model.
This model performs out-of-plane rotations given an input image and an action.
The action is one of: no-op, rotate clockwise, or rotate counter-clockwise.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def bilinear(input_x, input_y, output_size):
"""Define the bilinear transformation layer."""
shape_x = input_x.get_shape().as_list()
shape_y = input_y.get_shape().as_list()
weights_initializer = tf.truncated_normal_initializer(stddev=0.02,
seed=1)
biases_initializer = tf.constant_initializer(0.0)
matrix = tf.get_variable("Matrix", [shape_x[1], shape_y[1], output_size],
tf.float32, initializer=weights_initializer)
bias = tf.get_variable("Bias", [output_size],
initializer=biases_initializer)
# Add to GraphKeys.MODEL_VARIABLES
tf.contrib.framework.add_model_variable(matrix)
tf.contrib.framework.add_model_variable(bias)
# Define the transformation
h0 = tf.matmul(input_x, tf.reshape(matrix,
[shape_x[1], shape_y[1]*output_size]))
h0 = tf.reshape(h0, [-1, shape_y[1], output_size])
h1 = tf.tile(tf.reshape(input_y, [-1, shape_y[1], 1]),
[1, 1, output_size])
h1 = tf.multiply(h0, h1)
return tf.reduce_sum(h1, 1) + bias
def model(poses, actions, params, is_training):
"""Model for performing rotation."""
del is_training # Unused
return bilinear(poses, actions, params.z_dim)
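# --- Usage sketch (not part of the original file) ---
# Applies one rotation step to a batch of pose embeddings. The one-hot action
# encoding (3 actions) and the z_dim value are assumptions for illustration.
if __name__ == '__main__':
  import collections

  RotatorParams = collections.namedtuple('RotatorParams', ['z_dim'])
  params = RotatorParams(z_dim=512)
  poses = tf.zeros([8, 512], dtype=tf.float32)
  # Action 1 ("rotate clockwise") for every example in the batch.
  actions = tf.one_hot(tf.fill([8], 1), depth=3, dtype=tf.float32)
  with tf.variable_scope('rotator'):
    rotated = model(poses, actions, params, is_training=True)
  print(rotated)  # Rotated pose embeddings, shape [8, 512].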