Commit 748eceae authored by Marianne Linhares Monteiro, committed by GitHub

Merge branch 'master' into cifar10_experiment

parents 40e906d2 ed65b632
@@ -116,24 +116,46 @@ class SsdMetaArchTest(tf.test.TestCase):
localization_loss_weight, normalize_loss_by_num_matches,
hard_example_miner)
def test_preprocess_preserves_input_shapes(self):
image_shapes = [(3, None, None, 3),
(None, 10, 10, 3),
(None, None, None, 3)]
for image_shape in image_shapes:
image_placeholder = tf.placeholder(tf.float32, shape=image_shape)
preprocessed_inputs = self._model.preprocess(image_placeholder)
self.assertAllEqual(preprocessed_inputs.shape.as_list(), image_shape)
def test_predict_results_have_correct_keys_and_shapes(self):
batch_size = 3
preprocessed_input = tf.random_uniform((batch_size, 2, 2, 3),
dtype=tf.float32)
prediction_dict = self._model.predict(preprocessed_input)
self.assertTrue('box_encodings' in prediction_dict)
self.assertTrue('class_predictions_with_background' in prediction_dict)
self.assertTrue('feature_maps' in prediction_dict)
image_size = 2
input_shapes = [(batch_size, image_size, image_size, 3),
(None, image_size, image_size, 3),
(batch_size, None, None, 3),
(None, None, None, 3)]
expected_box_encodings_shape_out = (
batch_size, self._num_anchors, self._code_size)
expected_class_predictions_with_background_shape_out = (
batch_size, self._num_anchors, self._num_classes+1)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
prediction_out = sess.run(prediction_dict)
for input_shape in input_shapes:
tf_graph = tf.Graph()
with tf_graph.as_default():
preprocessed_input_placeholder = tf.placeholder(tf.float32,
shape=input_shape)
prediction_dict = self._model.predict(preprocessed_input_placeholder)
self.assertTrue('box_encodings' in prediction_dict)
self.assertTrue('class_predictions_with_background' in prediction_dict)
self.assertTrue('feature_maps' in prediction_dict)
init_op = tf.global_variables_initializer()
with self.test_session(graph=tf_graph) as sess:
sess.run(init_op)
prediction_out = sess.run(prediction_dict,
feed_dict={
preprocessed_input_placeholder:
np.random.uniform(
size=(batch_size, 2, 2, 3))})
self.assertAllEqual(prediction_out['box_encodings'].shape,
expected_box_encodings_shape_out)
self.assertAllEqual(
@@ -142,10 +164,11 @@ class SsdMetaArchTest(tf.test.TestCase):
def test_postprocess_results_are_correct(self):
batch_size = 2
preprocessed_input = tf.random_uniform((batch_size, 2, 2, 3),
dtype=tf.float32)
prediction_dict = self._model.predict(preprocessed_input)
detections = self._model.postprocess(prediction_dict)
image_size = 2
input_shapes = [(batch_size, image_size, image_size, 3),
(None, image_size, image_size, 3),
(batch_size, None, None, 3),
(None, None, None, 3)]
expected_boxes = np.array([[[0, 0, .5, .5],
[0, .5, .5, 1],
@@ -163,15 +186,25 @@ class SsdMetaArchTest(tf.test.TestCase):
[0, 0, 0, 0, 0]])
expected_num_detections = np.array([4, 4])
self.assertTrue('detection_boxes' in detections)
self.assertTrue('detection_scores' in detections)
self.assertTrue('detection_classes' in detections)
self.assertTrue('num_detections' in detections)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
detections_out = sess.run(detections)
for input_shape in input_shapes:
tf_graph = tf.Graph()
with tf_graph.as_default():
preprocessed_input_placeholder = tf.placeholder(tf.float32,
shape=input_shape)
prediction_dict = self._model.predict(preprocessed_input_placeholder)
detections = self._model.postprocess(prediction_dict)
self.assertTrue('detection_boxes' in detections)
self.assertTrue('detection_scores' in detections)
self.assertTrue('detection_classes' in detections)
self.assertTrue('num_detections' in detections)
init_op = tf.global_variables_initializer()
with self.test_session(graph=tf_graph) as sess:
sess.run(init_op)
detections_out = sess.run(detections,
feed_dict={
preprocessed_input_placeholder:
np.random.uniform(
size=(batch_size, 2, 2, 3))})
self.assertAllClose(detections_out['detection_boxes'], expected_boxes)
self.assertAllClose(detections_out['detection_scores'], expected_scores)
self.assertAllClose(detections_out['detection_classes'], expected_classes)
@@ -207,20 +240,21 @@ class SsdMetaArchTest(tf.test.TestCase):
self.assertAllClose(losses_out['classification_loss'],
expected_classification_loss)
def test_restore_fn_detection(self):
def test_restore_map_for_detection_ckpt(self):
init_op = tf.global_variables_initializer()
saver = tf_saver.Saver()
save_path = self.get_temp_dir()
with self.test_session() as sess:
sess.run(init_op)
saved_model_path = saver.save(sess, save_path)
restore_fn = self._model.restore_fn(saved_model_path,
from_detection_checkpoint=True)
restore_fn(sess)
var_map = self._model.restore_map(from_detection_checkpoint=True)
self.assertIsInstance(var_map, dict)
saver = tf.train.Saver(var_map)
saver.restore(sess, saved_model_path)
for var in sess.run(tf.report_uninitialized_variables()):
self.assertNotIn('FeatureExtractor', var.name)
def test_restore_fn_classification(self):
def test_restore_map_for_classification_ckpt(self):
# Define mock tensorflow classification graph and save variables.
test_graph_classification = tf.Graph()
with test_graph_classification.as_default():
@@ -246,10 +280,11 @@ class SsdMetaArchTest(tf.test.TestCase):
preprocessed_inputs = self._model.preprocess(inputs)
prediction_dict = self._model.predict(preprocessed_inputs)
self._model.postprocess(prediction_dict)
restore_fn = self._model.restore_fn(saved_model_path,
from_detection_checkpoint=False)
var_map = self._model.restore_map(from_detection_checkpoint=False)
self.assertIsInstance(var_map, dict)
saver = tf.train.Saver(var_map)
with self.test_session() as sess:
restore_fn(sess)
saver.restore(sess, saved_model_path)
for var in sess.run(tf.report_uninitialized_variables()):
self.assertNotIn('FeatureExtractor', var.name)
@@ -94,7 +94,6 @@ py_library(
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch",
"//tensorflow_models/object_detection/utils:variables_helper",
"//tensorflow_models/slim:inception_resnet_v2",
],
)
@@ -25,7 +25,6 @@ Huang et al. (https://arxiv.org/abs/1611.10012)
import tensorflow as tf
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.utils import variables_helper
from nets import inception_resnet_v2
slim = tf.contrib.slim
@@ -168,30 +167,30 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
def restore_from_classification_checkpoint_fn(
self,
checkpoint_path,
first_stage_feature_extractor_scope,
second_stage_feature_extractor_scope):
"""Returns callable for loading a checkpoint into the tensorflow graph.
"""Returns a map of variables to load from a foreign checkpoint.
Note that this overrides the default implementation in
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
InceptionResnetV2 checkpoints.
TODO: revisit whether it's possible to force the `Repeat` namescope as
created in `_extract_box_classifier_features` to start counting at 2 (e.g.
`Repeat_2`) so that the default restore_fn can be used.
TODO: revisit whether it's possible to force the
`Repeat` namescope as created in `_extract_box_classifier_features` to
start counting at 2 (e.g. `Repeat_2`) so that the default restore_fn can
be used.
Args:
checkpoint_path: Path to checkpoint to restore.
first_stage_feature_extractor_scope: A scope name for the first stage
feature extractor.
second_stage_feature_extractor_scope: A scope name for the second stage
feature extractor.
Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
A dict mapping variable names (to load from a checkpoint) to variables in
the model graph.
"""
variables_to_restore = {}
for variable in tf.global_variables():
if variable.op.name.startswith(
@@ -207,10 +206,4 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
var_name = var_name.replace(
second_stage_feature_extractor_scope + '/', '')
variables_to_restore[var_name] = variable
variables_to_restore = (
variables_helper.get_variables_available_in_checkpoint(
variables_to_restore, checkpoint_path))
saver = tf.train.Saver(variables_to_restore)
def restore(sess):
saver.restore(sess, checkpoint_path)
return restore
return variables_to_restore
@@ -63,7 +63,7 @@ class MultiResolutionFeatureMapGeneratorTest(tf.test.TestCase):
sess.run(init_op)
out_feature_maps = sess.run(feature_maps)
out_feature_map_shapes = dict(
(key, value.shape) for key, value in out_feature_maps.iteritems())
(key, value.shape) for key, value in out_feature_maps.items())
self.assertDictEqual(out_feature_map_shapes, expected_feature_map_shapes)
def test_get_expected_feature_map_shapes_with_inception_v3(self):
@@ -93,7 +93,7 @@ class MultiResolutionFeatureMapGeneratorTest(tf.test.TestCase):
sess.run(init_op)
out_feature_maps = sess.run(feature_maps)
out_feature_map_shapes = dict(
(key, value.shape) for key, value in out_feature_maps.iteritems())
(key, value.shape) for key, value in out_feature_maps.items())
self.assertDictEqual(out_feature_map_shapes, expected_feature_map_shapes)
@@ -140,9 +140,9 @@
"opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)\n",
"tar_file = tarfile.open(MODEL_FILE)\n",
"for file in tar_file.getmembers():\n",
" file_name = os.path.basename(file.name)\n",
" if 'frozen_inference_graph.pb' in file_name:\n",
" tar_file.extract(file, os.getcwd())"
" file_name = os.path.basename(file.name)\n",
" if 'frozen_inference_graph.pb' in file_name:\n",
" tar_file.extract(file, os.getcwd())"
]
},
{
@@ -162,11 +162,11 @@
"source": [
"detection_graph = tf.Graph()\n",
"with detection_graph.as_default():\n",
" od_graph_def = tf.GraphDef()\n",
" with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:\n",
" serialized_graph = fid.read()\n",
" od_graph_def.ParseFromString(serialized_graph)\n",
" tf.import_graph_def(od_graph_def, name='')"
" od_graph_def = tf.GraphDef()\n",
" with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:\n",
" serialized_graph = fid.read()\n",
" od_graph_def.ParseFromString(serialized_graph)\n",
" tf.import_graph_def(od_graph_def, name='')"
]
},
{
@@ -211,9 +211,15 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
# Create ops required to initialize the model from a given checkpoint.
init_fn = None
if train_config.fine_tune_checkpoint:
init_fn = detection_model.restore_fn(
train_config.fine_tune_checkpoint,
var_map = detection_model.restore_map(
from_detection_checkpoint=train_config.from_detection_checkpoint)
available_var_map = (variables_helper.
get_variables_available_in_checkpoint(
var_map, train_config.fine_tune_checkpoint))
init_saver = tf.train.Saver(available_var_map)
def initializer_fn(sess):
init_saver.restore(sess, train_config.fine_tune_checkpoint)
init_fn = initializer_fn
with tf.device(deploy_config.optimizer_device()):
total_loss, grads_and_vars = model_deploy.optimize_clones(
@@ -139,21 +139,18 @@ class FakeDetectionModel(model.DetectionModel):
}
return loss_dict
def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
"""Return callable for loading a checkpoint into the tensorflow graph.
def restore_map(self, from_detection_checkpoint=True):
"""Returns a map of variables to load from a foreign checkpoint.
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Returns:
a callable which takes a tf.Session and does nothing.
A dict mapping variable names to variables.
"""
def restore(unused_sess):
return
return restore
return {var.op.name: var for var in tf.global_variables()}
class TrainerTest(tf.test.TestCase):
@@ -120,6 +120,7 @@ py_library(
"//tensorflow_models/object_detection/core:box_list",
"//tensorflow_models/object_detection/core:box_predictor",
"//tensorflow_models/object_detection/core:matcher",
"//tensorflow_models/object_detection/utils:shape_utils"
],
)
@@ -22,6 +22,20 @@ from google.protobuf import text_format
from object_detection.protos import string_int_label_map_pb2
def _validate_label_map(label_map):
"""Checks if a label map is valid.
Args:
label_map: StringIntLabelMap to validate.
Raises:
ValueError: if label map is invalid.
"""
for item in label_map.item:
if item.id < 1:
raise ValueError('Label map ids should be >= 1.')
def create_category_index(categories):
"""Creates dictionary of COCO compatible categories keyed by category id.
@@ -61,7 +75,7 @@ def convert_label_map_to_categories(label_map,
list is created with max_num_classes categories.
max_num_classes: maximum number of (consecutive) label indices to include.
use_display_name: (boolean) choose whether to load 'display_name' field
as category name. If False of if the display_name field does not exist,
as category name. If False or if the display_name field does not exist,
uses 'name' field as category names instead.
Returns:
categories: a list of dictionaries representing all possible categories.
@@ -91,7 +105,6 @@ def convert_label_map_to_categories(label_map,
return categories
# TODO: double check documentaion.
def load_labelmap(path):
"""Loads label map proto.
@@ -107,6 +120,7 @@ def load_labelmap(path):
text_format.Merge(label_map_string, label_map)
except text_format.ParseError:
label_map.ParseFromString(label_map_string)
_validate_label_map(label_map)
return label_map
@@ -53,7 +53,29 @@ class LabelMapUtilTest(tf.test.TestCase):
self.assertEqual(label_map_dict['dog'], 1)
self.assertEqual(label_map_dict['cat'], 2)
def test_keep_categories_with_unique_id(self):
def test_load_bad_label_map(self):
label_map_string = """
item {
id:0
name:'class that should not be indexed at zero'
}
item {
id:2
name:'cat'
}
item {
id:1
name:'dog'
}
"""
label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
with tf.gfile.Open(label_map_path, 'wb') as f:
f.write(label_map_string)
with self.assertRaises(ValueError):
label_map_util.load_labelmap(label_map_path)
def test_keep_categories_with_unique_id(self):
label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
label_map_string = """
item {
@@ -111,3 +111,26 @@ def pad_or_clip_tensor(t, length):
if not _is_tensor(length):
processed_t = _set_dim_0(processed_t, length)
return processed_t
def combined_static_and_dynamic_shape(tensor):
"""Returns a list containing static and dynamic values for the dimensions.
Returns a list of static and dynamic values for shape dimensions. This is
useful to preserve static shapes when available in reshape operation.
Args:
tensor: A tensor of any type.
Returns:
A list of size tensor.shape.ndims whose entries are Python integers for static dimensions and scalar int32 tensors for dynamic ones.
"""
static_shape = tensor.shape.as_list()
dynamic_shape = tf.shape(tensor)
combined_shape = []
for index, dim in enumerate(static_shape):
if dim is not None:
combined_shape.append(dim)
else:
combined_shape.append(dynamic_shape[index])
return combined_shape
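The list returned above can be fed directly into `tf.reshape`, preserving whichever dimensions are statically known. A minimal usage sketch (the placeholder and its shape are illustrative, not part of this commit):

```python
import tensorflow as tf
from object_detection.utils import shape_utils

# Hypothetical feature map: unknown batch size, static spatial dims and depth.
features = tf.placeholder(tf.float32, shape=(None, 4, 4, 32))
shape = shape_utils.combined_static_and_dynamic_shape(features)
# shape == [<scalar int32 tensor>, 4, 4, 32]; static entries are Python ints.
flattened = tf.reshape(features, [shape[0], shape[1] * shape[2], shape[3]])
print(flattened.shape)  # (?, 16, 32): the spatial product stays static
```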
@@ -115,6 +115,13 @@ class UtilTest(tf.test.TestCase):
self.assertAllEqual([1, 2], tt3_result)
self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], tt4_result)
def test_combines_static_dynamic_shape(self):
tensor = tf.placeholder(tf.float32, shape=(None, 2, 3))
combined_shape = shape_utils.combined_static_and_dynamic_shape(
tensor)
self.assertTrue(tf.contrib.framework.is_tensor(combined_shape[0]))
self.assertListEqual(combined_shape[1:], [2, 3])
if __name__ == '__main__':
tf.test.main()
@@ -22,6 +22,7 @@ from object_detection.core import box_coder
from object_detection.core import box_list
from object_detection.core import box_predictor
from object_detection.core import matcher
from object_detection.utils import shape_utils
class MockBoxCoder(box_coder.BoxCoder):
@@ -45,9 +46,10 @@ class MockBoxPredictor(box_predictor.BoxPredictor):
super(MockBoxPredictor, self).__init__(is_training, num_classes)
def _predict(self, image_features, num_predictions_per_location):
batch_size = image_features.get_shape().as_list()[0]
num_anchors = (image_features.get_shape().as_list()[1]
* image_features.get_shape().as_list()[2])
combined_feature_shape = shape_utils.combined_static_and_dynamic_shape(
image_features)
batch_size = combined_feature_shape[0]
num_anchors = (combined_feature_shape[1] * combined_feature_shape[2])
code_size = 4
zero = tf.reduce_sum(0 * image_features)
box_encodings = zero + tf.zeros(
@@ -398,7 +398,7 @@ def visualize_boxes_and_labels_on_image_array(image,
classes[i] % len(STANDARD_COLORS)]
# Draw all boxes onto image.
for box, color in six.iteritems(box_to_color_map):
for box, color in box_to_color_map.items():
ymin, xmin, ymax, xmax = box
if instance_masks is not None:
draw_mask_on_image_array(
bazel
.idea
bazel-bin
bazel-out
bazel-genfiles
bazel-ptn
bazel-testlogs
WORKSPACE
*.pyc
py_library(
name = "input_generator",
srcs = ["input_generator.py"],
deps = [
],
)
py_library(
name = "losses",
srcs = ["losses.py"],
deps = [
],
)
py_library(
name = "metrics",
srcs = ["metrics.py"],
deps = [
],
)
py_library(
name = "utils",
srcs = ["utils.py"],
deps = [
],
)
# Defines the Rotator model here
py_library(
name = "model_rotator",
srcs = ["model_rotator.py"],
deps = [
":input_generator",
":losses",
":metrics",
":utils",
"//nets:deeprotator_factory",
],
)
# Defines the Im2vox model here
py_library(
name = "model_voxel_generation",
srcs = ["model_voxel_generation.py"],
deps = [
":input_generator",
"//nets:im2vox_factory",
],
)
py_library(
name = "model_ptn",
srcs = ["model_ptn.py"],
deps = [
":losses",
":metrics",
":model_voxel_generation",
":utils",
"//nets:im2vox_factory",
],
)
py_binary(
name = "train_ptn",
srcs = ["train_ptn.py"],
deps = [
":model_ptn",
],
)
py_binary(
name = "eval_ptn",
srcs = ["eval_ptn.py"],
deps = [
":model_ptn",
],
)
py_binary(
name = "pretrain_rotator",
srcs = ["pretrain_rotator.py"],
deps = [
":model_rotator",
],
)
py_binary(
name = "eval_rotator",
srcs = ["eval_rotator.py"],
deps = [
":model_rotator",
],
)
# Perspective Transformer Nets
## Introduction
This is the TensorFlow implementation for the NIPS 2016 work ["Perspective Transformer Nets: Learning Single-View 3D Object Reconstruction without 3D Supervision"](https://papers.nips.cc/paper/6206-perspective-transformer-nets-learning-single-view-3d-object-reconstruction-without-3d-supervision.pdf)
Re-implemented by Xinchen Yan, Arkanath Pathak, Jasmine Hsu, Honglak Lee
Reference: [Original implementation in Torch](https://github.com/xcyan/nips16_PTN)
## How to run this code
This implementation is ready to be run locally or [distributed across multiple machines/tasks](https://www.tensorflow.org/deploy/distributed).
You will need to set the task number flag for each task when running in a distributed fashion.
Please refer to the original paper for parameter explanations and training details.
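For instance, a single worker in a distributed run might be launched as follows; the master address and task number are placeholders, and the exact flag names are an assumption that should be checked against the training script's flag definitions:

$ bazel run -c opt :train_ptn -- --init_model={} --master=grpc://example-host:2222 --task=1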
### Installation
* TensorFlow
* This code requires the latest open-source TensorFlow that you will need to build manually.
The [documentation](https://www.tensorflow.org/install/install_sources) provides the steps required for that.
* Bazel
* Follow the instructions [here](http://bazel.build/docs/install.html).
* Alternatively, download Bazel from
[https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases)
for your system configuration.
* Check the installed Bazel version with the command: `bazel version`
* matplotlib
* Follow the instructions [here](https://matplotlib.org/users/installing.html).
* You can use a package repository like pip.
* scikit-image
* Follow the instructions [here](http://scikit-image.org/docs/dev/install.html).
* You can use a package repository like pip.
* PIL
* Install from [here](https://pypi.python.org/pypi/Pillow/2.2.1).
### Dataset
This code requires the dataset to be in *tfrecords* format with the following features:
* image
* Flattened list of image (float representations) for each view point.
* mask
* Flattened list of image masks (float representations) for each view point.
* vox
* Flattened list of voxels (float representations) for the object.
* This is needed for using vox loss and for prediction comparison.
You can download the ShapeNet Dataset in tfrecords format from [here](https://drive.google.com/file/d/0B12XukcbU7T7OHQ4MGh6d25qQlk)<sup>*</sup>.
<sup>*</sup> Disclaimer: This data is hosted personally by Arkanath Pathak for non-commercial research purposes. Please cite the [ShapeNet paper](https://arxiv.org/pdf/1512.03012.pdf) in your works when using ShapeNet for non-commercial research purposes.
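To make the expected record layout concrete, here is a minimal sketch of serializing one example; it assumes plain float_list features flattened in view order, the shapes used by the provided ShapeNet records (24 views, 64x64 images, 32^3 voxels), and a hypothetical output path:

```python
import numpy as np
import tensorflow as tf

def _float_feature(values):
  return tf.train.Feature(float_list=tf.train.FloatList(value=values))

num_views, image_size, vox_size = 24, 64, 32
# Dummy arrays standing in for real renders, masks, and voxel occupancies.
image = np.zeros((num_views, image_size, image_size, 3), np.float32)
mask = np.zeros((num_views, image_size, image_size, 1), np.float32)
vox = np.zeros((vox_size, vox_size, vox_size, 1), np.float32)

example = tf.train.Example(features=tf.train.Features(feature={
    'image': _float_feature(image.ravel()),
    'mask': _float_feature(mask.ravel()),
    'vox': _float_feature(vox.ravel()),
}))
with tf.python_io.TFRecordWriter('/tmp/03001627_train.tfrecords') as writer:
  writer.write(example.SerializeToString())
```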
### Pretraining: pretrain_rotator.py for each RNN step
$ bazel run -c opt :pretrain_rotator -- --step_size={} --init_model={}
Pass init_model the checkpoint path of the model trained at the previous step.
You'll also need to set the inp_dir flag to where your data resides.
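For example, pretraining step 2 initialized from the step-1 checkpoint (paths are placeholders):

$ bazel run -c opt :pretrain_rotator -- --step_size=2 --init_model=/tmp/ptn_train/ptn_proj/train --inp_dir=/path/to/tfrecords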
### Training: train_ptn.py with last pretrained model.
$ bazel run -c opt :train_ptn -- --init_model={}
### Example TensorBoard Visualizations
To compare visualizations, make sure to set the model_name flag to a different value for each parametric setting.
This code adds summaries for each loss. For instance, these are the losses we encountered in distributed pretraining on the ShapeNet Chair dataset with 10 workers and 16 parameter servers:
![ShapeNet Chair Pretraining](https://drive.google.com/uc?export=view&id=0B12XukcbU7T7bWdlTjhzbGJVaWs "ShapeNet Chair Experiment Pretraining Losses")
After fine-tuning the training, you can expect images like the following as "grid_vis" under **Image** summaries in TensorBoard:
![ShapeNet Chair experiments with projection weight of 1](https://drive.google.com/uc?export=view&id=0B12XukcbU7T7ZFV6aEVBSDdCMjQ "ShapeNet Chair Dataset Predictions")
Here the third and fifth columns are the predicted masks and voxels respectively, alongside their ground truth values.
A similar image when trained on all ShapeNet categories (voxel visualizations might be skewed):
![ShapeNet All Categories experiments](https://drive.google.com/uc?export=view&id=0B12XukcbU7T7bDZKNFlkTVAzZmM "ShapeNet All Categories Dataset Predictions")
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains evaluation plan for the Im2vox model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from tensorflow import app
import model_ptn
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_string('inp_dir',
'',
'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
'dataset_name', 'shapenet_chair',
'Dataset name that is to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('vox_size', 32, 'Voxel prediction dimension.')
flags.DEFINE_integer('step_size', 24, '')
flags.DEFINE_integer('batch_size', 1, 'Batch size while training.')
flags.DEFINE_float('focal_length', 0.866, '')
flags.DEFINE_float('focal_range', 1.732, '')
flags.DEFINE_string('encoder_name', 'ptn_encoder',
'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_vox_decoder',
'Name of the decoder network being used.')
flags.DEFINE_string('projector_name', 'ptn_projector',
'Name of the projector network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn/eval/',
'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'ptn_proj',
'Name of the model used in naming the TF job. Must be different for each run.')
flags.DEFINE_string('eval_set', 'val', 'Data partition to perform evaluation on.')
# Optimization
flags.DEFINE_float('proj_weight', 10, 'Weighting factor for projection loss.')
flags.DEFINE_float('volume_weight', 0, 'Weighting factor for volume loss.')
flags.DEFINE_float('viewpoint_weight', 1,
'Weighting factor for viewpoint loss.')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, '')
flags.DEFINE_float('clip_gradient_norm', 0, '')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, '')
flags.DEFINE_integer('eval_interval_secs', 60 * 5, '')
# Distribution
flags.DEFINE_string('master', '', '')
FLAGS = flags.FLAGS
def main(argv=()):
del argv # Unused.
eval_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name,
'eval_%s' % FLAGS.eval_set)
if not os.path.exists(eval_dir):
os.makedirs(eval_dir)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
g = tf.Graph()
with g.as_default():
eval_params = FLAGS
eval_params.batch_size = 1
eval_params.step_size = FLAGS.num_views
###########
## model ##
###########
model = model_ptn.model_PTN(eval_params)
##########
## data ##
##########
eval_data = model.get_inputs(
FLAGS.inp_dir,
FLAGS.dataset_name,
eval_params.eval_set,
eval_params.batch_size,
eval_params.image_size,
eval_params.vox_size,
is_training=False)
inputs = model.preprocess_with_all_views(eval_data)
##############
## model_fn ##
##############
model_fn = model.get_model_fn(is_training=False, run_projection=False)
outputs = model_fn(inputs)
#############
## metrics ##
#############
names_to_values, names_to_updates = model.get_metrics(inputs, outputs)
del names_to_values
################
## evaluation ##
################
num_batches = eval_data['num_samples']
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=eval_dir,
logdir=log_dir,
num_evals=num_batches,
eval_op=names_to_updates.values(),
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains evaluation plan for the Rotator model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from tensorflow import app
import model_rotator as model
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_string('inp_dir',
'',
'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
'dataset_name', 'shapenet_chair',
'Dataset name that is to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('a_dim', 3, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('step_size', 24, '')
flags.DEFINE_integer('batch_size', 2, '')
flags.DEFINE_string('encoder_name', 'ptn_encoder',
'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_im_decoder',
'Name of the decoder network being used.')
flags.DEFINE_string('rotator_name', 'ptn_rotator',
'Name of the rotator network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn_train/',
'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'ptn_proj',
'Name of the model used in naming the TF job. Must be different for each run.')
# Optimization
flags.DEFINE_float('image_weight', 10, '')
flags.DEFINE_float('mask_weight', 1, '')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, '')
flags.DEFINE_float('clip_gradient_norm', 0, '')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, '')
flags.DEFINE_integer('eval_interval_secs', 60 * 5, '')
# Scheduling
flags.DEFINE_string('master', 'local', '')
FLAGS = flags.FLAGS
def main(argv=()):
del argv # Unused.
eval_dir = os.path.join(FLAGS.checkpoint_dir,
FLAGS.model_name, 'train')
log_dir = os.path.join(FLAGS.checkpoint_dir,
FLAGS.model_name, 'eval')
if not os.path.exists(eval_dir):
os.makedirs(eval_dir)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if FLAGS.step_size < FLAGS.num_views:
  raise ValueError('Impossible step_size: it must not be less than num_views.')
g = tf.Graph()
with g.as_default():
##########
## data ##
##########
val_data = model.get_inputs(
FLAGS.inp_dir,
FLAGS.dataset_name,
'val',
FLAGS.batch_size,
FLAGS.image_size,
is_training=False)
inputs = model.preprocess(val_data, FLAGS.step_size)
###########
## model ##
###########
model_fn = model.get_model_fn(FLAGS, is_training=False)
outputs = model_fn(inputs)
#############
## metrics ##
#############
names_to_values, names_to_updates = model.get_metrics(
inputs, outputs, FLAGS)
del names_to_values
################
## evaluation ##
################
num_batches = int(val_data['num_samples'] / FLAGS.batch_size)
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=eval_dir,
logdir=log_dir,
num_evals=num_batches,
eval_op=names_to_updates.values(),
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides dataset dictionaries as used in our network models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.data import dataset
from tensorflow.contrib.slim.python.slim.data import dataset_data_provider
from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
_ITEMS_TO_DESCRIPTIONS = {
'image': 'Images',
'mask': 'Masks',
'vox': 'Voxels'
}
def _get_split(file_pattern, num_samples, num_views, image_size, vox_size):
"""Get dataset.Dataset for the given dataset file pattern and properties."""
# A dictionary from TF-Example keys to tf.FixedLenFeature instances.
keys_to_features = {
'image': tf.FixedLenFeature(
shape=[num_views, image_size, image_size, 3],
dtype=tf.float32, default_value=None),
'mask': tf.FixedLenFeature(
shape=[num_views, image_size, image_size, 1],
dtype=tf.float32, default_value=None),
'vox': tf.FixedLenFeature(
shape=[vox_size, vox_size, vox_size, 1],
dtype=tf.float32, default_value=None),
}
items_to_handler = {
'image': tfexample_decoder.Tensor(
'image', shape=[num_views, image_size, image_size, 3]),
'mask': tfexample_decoder.Tensor(
'mask', shape=[num_views, image_size, image_size, 1]),
'vox': tfexample_decoder.Tensor(
'vox', shape=[vox_size, vox_size, vox_size, 1])
}
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handler)
return dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=num_samples,
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS)
def get(dataset_dir,
dataset_name,
split_name,
shuffle=True,
num_readers=1,
common_queue_capacity=64,
common_queue_min=50):
"""Provides input data for a specified dataset and split."""
dataset_to_kwargs = {
'shapenet_chair': {
'file_pattern': '03001627_%s.tfrecords' % split_name,
'num_views': 24,
'image_size': 64,
'vox_size': 32,
}, 'shapenet_all': {
'file_pattern': '*_%s.tfrecords' % split_name,
'num_views': 24,
'image_size': 64,
'vox_size': 32,
},
}
split_sizes = {
'shapenet_chair': {
'train': 4744,
'val': 678,
'test': 1356,
},
'shapenet_all': {
'train': 30643,
'val': 4378,
'test': 8762,
}
}
kwargs = dataset_to_kwargs[dataset_name]
kwargs['file_pattern'] = os.path.join(dataset_dir, kwargs['file_pattern'])
kwargs['num_samples'] = split_sizes[dataset_name][split_name]
dataset_split = _get_split(**kwargs)
data_provider = dataset_data_provider.DatasetDataProvider(
dataset_split,
num_readers=num_readers,
common_queue_capacity=common_queue_capacity,
common_queue_min=common_queue_min,
shuffle=shuffle)
inputs = {
'num_samples': dataset_split.num_samples,
}
[image, mask, vox] = data_provider.get(['image', 'mask', 'vox'])
inputs['image'] = image
inputs['mask'] = mask
inputs['voxel'] = vox
return inputs
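For reference, a minimal sketch of consuming these inputs (paths hypothetical): get() returns single-example tensors plus the split size, so batching happens downstream, e.g. with tf.train.batch:

```python
inputs = get('/path/to/tfrecords', 'shapenet_chair', 'train')
image, mask, vox = tf.train.batch(
    [inputs['image'], inputs['mask'], inputs['voxel']],
    batch_size=2, capacity=32)
num_batches = inputs['num_samples'] // 2
```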