Commit cbd571f2 authored by Corey Lynch

Adding TCN.

parent 69cf6fca
......@@ -32,6 +32,7 @@ research/slim/* @sguada @nathansilberman
research/street/* @theraysmith
research/swivel/* @waterson
research/syntaxnet/* @calberti @andorardo @bogatyy @markomernick
research/tcn/* @coreylynch @sermanet
research/textsum/* @panyx0718 @peterjliu
research/transformer/* @daviddao
research/video_prediction/* @cbfinn
......
......@@ -61,6 +61,7 @@ installation](https://www.tensorflow.org/install).
using a Deep RNN.
- [swivel](swivel): the Swivel algorithm for generating word embeddings.
- [syntaxnet](syntaxnet): neural models of natural language syntax.
- [tcn](tcn): Self-supervised representation learning from multi-view video.
- [textsum](textsum): sequence-to-sequence with attention model for text
summarization.
- [transformer](transformer): spatial transformer network, which allows the
......
package(default_visibility = [":internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package_group(
name = "internal",
packages = [
"//tcn/...",
],
)
py_binary(
name = "download_pretrained",
srcs = [
"download_pretrained.py",
],
)
py_binary(
name = "generate_videos",
srcs = [
"generate_videos.py",
],
main = "generate_videos.py",
deps = [
":data_providers",
":get_estimator",
":util",
],
)
py_test(
name = "svtcn_loss_test",
size = "medium",
srcs = [
"estimators/svtcn_loss.py",
"estimators/svtcn_loss_test.py",
],
deps = [
":util",
],
)
py_library(
name = "data_providers",
srcs = [
"data_providers.py",
],
deps = [
":preprocessing",
],
)
py_test(
name = "data_providers_test",
size = "large",
srcs = ["data_providers_test.py"],
deps = [
":data_providers",
],
)
py_library(
name = "preprocessing",
srcs = [
"preprocessing.py",
],
)
py_binary(
name = "get_estimator",
srcs = [
"estimators/get_estimator.py",
],
deps = [
":mvtcn_estimator",
":svtcn_estimator",
],
)
py_binary(
name = "base_estimator",
srcs = [
"estimators/base_estimator.py",
"model.py",
],
deps = [
":data_providers",
":util",
],
)
py_library(
name = "util",
srcs = [
"utils/luatables.py",
"utils/progress.py",
"utils/util.py",
],
)
py_binary(
name = "mvtcn_estimator",
srcs = [
"estimators/mvtcn_estimator.py",
],
deps = [
":base_estimator",
],
)
py_binary(
name = "svtcn_estimator",
srcs = [
"estimators/svtcn_estimator.py",
"estimators/svtcn_loss.py",
],
deps = [
":base_estimator",
],
)
py_binary(
name = "train",
srcs = [
"train.py",
],
deps = [
":data_providers",
":get_estimator",
":util",
],
)
py_binary(
name = "labeled_eval",
srcs = [
"labeled_eval.py",
],
deps = [
":get_estimator",
],
)
py_test(
name = "labeled_eval_test",
size = "small",
srcs = ["labeled_eval_test.py"],
deps = [
":labeled_eval",
],
)
py_binary(
name = "eval",
srcs = [
"eval.py",
],
deps = [
":get_estimator",
],
)
py_binary(
name = "alignment",
srcs = [
"alignment.py",
],
deps = [
":get_estimator",
],
)
py_binary(
name = "visualize_embeddings",
srcs = [
"visualize_embeddings.py",
],
deps = [
":data_providers",
":get_estimator",
":util",
],
)
py_binary(
name = "webcam",
srcs = [
"dataset/webcam.py",
],
main = "dataset/webcam.py",
)
py_binary(
name = "images_to_videos",
srcs = [
"dataset/images_to_videos.py",
],
main = "dataset/images_to_videos.py",
)
py_binary(
name = "videos_to_tfrecords",
srcs = [
"dataset/videos_to_tfrecords.py",
],
main = "dataset/videos_to_tfrecords.py",
deps = [
":preprocessing",
],
)
# Contributing guidelines
If you have created a model and would like to publish it here, please send us a
pull request. For those just getting started with pull requests, GitHub has a
[howto](https://help.github.com/articles/using-pull-requests/).
The code for any model in this repository is licensed under the Apache License
2.0.
In order to accept your code, we have to make sure that we can publish it:
you have to sign a Contributor License Agreement (CLA).
### Contributor License Agreements
Please fill out either the individual or corporate Contributor License Agreement (CLA).
* If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html).
* If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html).
Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests.
***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the repository.
Copyright 2016 The TensorFlow Authors. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2016, The Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
\ No newline at end of file
# Time Contrastive Networks
This implements ["Time Contrastive Networks"](https://arxiv.org/abs/1704.06888),
which is part of the larger [Self-Supervised Imitation
Learning](https://sermanet.github.io/imitation/) project.
![](https://sermanet.github.io/tcn/docs/figs/mvTCN.png)
## Contacts
Maintainers of TCN:
* Corey Lynch: [github](https://github.com/coreylynch),
[twitter](https://twitter.com/coreylynch)
* Pierre Sermanet: [github](https://github.com/sermanet),
[twitter](https://twitter.com/psermanet)
## Contents
* [Getting Started](#getting-started)
* [Install Dependencies](#install-dependencies)
* [Download the Inception v3
Checkpoint](#download-pretrained-inceptionv3-checkpoint)
* [Run all the tests](#run-all-the-tests)
* [Concepts](#concepts)
* [Multi-view Webcam Video](#multi-view-webcam-video)
* [Data Pipelines](#data-pipelines)
* [Estimators](#estimators)
* [Models](#models)
* [Losses](#losses)
* [Inference](#inference)
* [Configuration](#configuration)
* [Monitoring Training](#monitoring-training)
* [KNN Classification Error](#knn-classification-error)
    * [Multi-view Alignment](#multi-view-alignment)
* [Visualization](#visualization)
* [Nearest Neighbor Imitation
Videos](#nearest-neighbor-imitation-videos)
* [PCA & T-SNE Visualization](#pca-t-sne-visualization)
* [Tutorial Part I: Collecting Multi-View Webcam
Videos](#tutorial-part-i-collecting-multi-view-webcam-videos)
* [Collect Webcam Videos](#collect-webcam-videos)
* [Create TFRecords](#create-tfrecords)
* [Tutorial Part II: Training, Evaluation, and
Visualization](#tutorial-part-ii-training-evaluation-and-visualization)
* [Download Data](#download-data)
* [Download the Inception v3
Checkpoint](#download-pretrained-inceptionv3-checkpoint)
* [Define a Config](#define-a-config)
* [Train](#train)
* [Evaluate](#evaluate)
    * [Monitor training](#monitor-training)
* [Visualize](#visualize)
* [Generate Imitation Videos](#generate-imitation-videos)
* [Run PCA & T-SNE Visualization](#t-sne-pca-visualization)
## Getting started
### Install Dependencies
* [TensorFlow nightly build](https://pypi.python.org/pypi/tf-nightly-gpu), e.g. via `pip install tf-nightly-gpu`.
* [Bazel](http://bazel.io/docs/install.html)
* matplotlib
* sklearn
* opencv
### Download Pretrained InceptionV3 Checkpoint
Run the script that downloads the pretrained InceptionV3 checkpoint:
```bash
cd tensorflow-models/tcn
python download_pretrained.py
```
### Run all the tests
```bash
bazel test :all
```
## Concepts
### Multi-View Webcam Video
We provide utilities to collect your own multi-view videos in dataset/webcam.py.
See the [webcam tutorial](#tutorial-part-i-collecting-multi-view-webcam-videos)
for an end to end example of how to collect multi-view webcam data and convert
it to the TFRecord format expected by this library.
## Data Pipelines
We use the [tf.data.Dataset
API](https://www.tensorflow.org/programmers_guide/datasets) to construct input
pipelines that feed training, evaluation, and visualization. These pipelines are
defined in `data_providers.py`.
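As a rough sketch of how one of these pipelines is consumed (the paths and the toy `preprocess_fn` below are placeholders; the `multiview_pairs_provider` signature is the one defined in `data_providers.py`):
```python
import tensorflow as tf

import data_providers


def preprocess_fn(raw_images, is_training):
  # Stand-in for the real preprocessor in preprocessing.py.
  del is_training  # Unused in this sketch.
  return tf.image.resize_images(raw_images, [299, 299])

# Hypothetical shard paths; substitute your own training TFRecords.
file_list = tf.gfile.Glob('/tmp/tcn_data/tutorial/train/*.tfrecord')

(batch_images, anchor_labels, positive_labels,
 anchor_images, pos_images) = data_providers.multiview_pairs_provider(
     file_list=file_list,
     preprocess_fn=preprocess_fn,
     num_views=2,
     window=580,
     is_training=True,
     batch_size=32)

with tf.Session() as sess:
  images, labels = sess.run([batch_images, anchor_labels])
```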
## Estimators
We define training, evaluation, and inference behavior using the
[tf.estimator.Estimator
API](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator). See
`estimators/mvtcn_estimator.py` for an example of how multi-view TCN training,
evaluation, and inference is implemented.
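For readers new to the API, the sketch below shows the generic `tf.estimator.Estimator` contract that the TCN estimators build on; the `model_fn` body here is purely illustrative and is not the TCN embedder or loss.
```python
import tensorflow as tf


def model_fn(features, labels, mode, params):
  """Toy model_fn showing the Estimator contract (not the TCN model)."""
  del labels  # In the real estimators, labels feed the metric-learning loss.
  net = tf.layers.flatten(features)
  embeddings = tf.layers.dense(net, params['embedding_size'])
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode, predictions={'embeddings': embeddings})
  loss = tf.reduce_mean(tf.square(embeddings))  # Placeholder loss.
  train_op = tf.train.AdamOptimizer(params['learning_rate']).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='/tmp/tcn/sketch',
    params={'embedding_size': 32, 'learning_rate': 1e-4})
# estimator.train(input_fn=..., max_steps=...) / estimator.predict(input_fn=...)
```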
## Models
Different embedder architectures are implemented in model.py. We used the
`InceptionConvSSFCEmbedder` in the pouring experiments, but we're also
evaluating `Resnet` embedders.
## Losses
We use the
[tf.contrib.losses.metric_learning](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/losses/metric_learning)
library's implementations of triplet loss with semi-hard negative mining and
npairs loss. In our experiments, npairs loss has better empirical convergence
and produces the best qualitative visualizations, and will likely be our choice
for future experiments. See the
[paper](http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf)
for details on the algorithm.
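A minimal sketch of how these contrib losses are called, with stand-in embeddings (check the exact signatures against your TensorFlow version):
```python
import tensorflow as tf

metric_learning = tf.contrib.losses.metric_learning

batch_size, embedding_size = 16, 32
# Stand-ins for embedder outputs: anchor[i] and positive[i] come from
# different views of the same timestep and share label i.
anchor = tf.random_normal([batch_size, embedding_size])
positive = tf.random_normal([batch_size, embedding_size])
labels = tf.range(batch_size)

# npairs loss over (anchor, positive) pairs; applied to unnormalized embeddings.
npairs = metric_learning.npairs_loss(
    labels=labels, embeddings_anchor=anchor, embeddings_positive=positive)

# Triplet loss with semi-hard mining over a single batch of l2-normalized
# embeddings, as in the default config.
embeddings = tf.nn.l2_normalize(tf.concat([anchor, positive], 0), 1)
triplet = metric_learning.triplet_semihard_loss(
    labels=tf.concat([labels, labels], 0), embeddings=embeddings, margin=0.2)
```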
## Inference
We support 3 modes of inference for trained TCN models:
* Mode 1: Input is a tf.Estimator input_fn (see
[this](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#predict)
for details). Output is an iterator over embeddings and additional metadata.
See `labeled_eval.py` for a usage example.
* Mode 2: Input is a TFRecord (or list of TFRecords). This returns an
iterator over tuples of (embeddings, raw_image_strings, sequence_name),
where embeddings is the [num views, sequence length, embedding size] numpy
array holding the full embedded sequence (for all views), raw_image_strings
is a [num views, sequence length] string array holding the jpeg-encoded raw
image strings, and sequence_name is the name of the sequence. See
`generate_videos.py` for a usage example.
* Mode 3: Input is a numpy array of size [num images, height, width, num
channels]. This returns a tuple of (embeddings, raw_image_strings), where
embeddings is a 2-D float32 numpy array holding [num_images, embedding_size]
image embeddings, and raw_image_strings is a 1-D string numpy array holding
[batch_size] jpeg-encoded image strings. This can be used as follows:
```python
images = np.random.uniform(0, 1, size=(batch_size, 1080, 1920, 3))
embeddings, _ = estimator.inference(
images, checkpoint_path=checkpoint_path)
```
See `estimators/base_estimator.py` for details.
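As a usage sketch for Mode 2, mirroring the way `alignment.py` (included below) calls `estimator.inference`; the config and checkpoint paths are placeholders:
```python
from estimators.get_estimator import get_estimator
from utils import util

# Placeholders: point these at a real config, log directory, and checkpoint.
config = util.ParseConfigsToLuaTable(
    '/tmp/tcn/pouring/final_training_config.yml', '{}')
checkpointdir = '/tmp/tcn/pouring'
checkpoint_path = checkpointdir + '/model.ckpt-40001'

estimator = get_estimator(config, checkpointdir)
records = util.GetFilesRecursively(config.data.validation)

for embeddings, raw_image_strings, seqname in estimator.inference(
    records, checkpoint_path, config.data.embed_batch_size):
  # embeddings: [num views, seq len, embedding size] numpy array.
  # raw_image_strings: [num views, seq len] jpeg-encoded strings.
  print(seqname, embeddings.shape)
```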
## Configuration
Data pipelines, training, eval, and visualization are all configured using
key-value parameters passed as [YAML](https://en.wikipedia.org/wiki/YAML) files.
Configurations can be nested, e.g.:
```yaml
learning:
optimizer: 'adam'
learning_rate: 0.001
```
### T objects
YAML configs are converted to LuaTable-like `T` objects (see
`utils/luatables.py`), which behave like a Python `dict` but allow you to use
dot notation to access (nested) keys. For example, we could access the learning
rate in the above config snippet via `config.learning.learning_rate`.
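A minimal illustration of the dot-access behavior (a sketch of the idea, not the actual implementation in `utils/luatables.py`):
```python
class T(dict):
  """Dict with attribute (dot) access; illustrative only."""

  def __getattr__(self, key):
    try:
      return self[key]
    except KeyError:
      raise AttributeError(key)

  def __setattr__(self, key, value):
    self[key] = value

config = T(learning=T(optimizer='adam', learning_rate=0.001))
assert config.learning.learning_rate == 0.001  # Dot access.
assert config['learning']['optimizer'] == 'adam'  # Still a dict.
```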
### Multiple Configs
Multiple configs can be passed to the various binaries as a comma-separated list
of config paths via the `--config_paths` flag. This allows us to specify a
default config that applies to all experiments (e.g. how often to write
checkpoints, default embedder hyperparameters) and one config per experiment
holding just the hyperparameters specific to that experiment (path to data, etc.).
See `configs/tcn_default.yml` for an example of our default config and
`configs/pouring.yml` for an example of how we define the pouring experiments.
Configs are applied left to right. For example, consider two config files:
default.yml
```yaml
learning:
learning_rate: 0.001 # Default learning rate.
optimizer: 'adam'
```
myexperiment.yml
```yaml
learning:
learning_rate: 1.0 # Experiment learning rate (overwrites default).
data:
training: '/path/to/myexperiment/training.tfrecord'
```
Running
```bash
bazel build -c opt train && bazel-bin/train --config_paths='default.yml,myexperiment.yml'
```
results in a final merged config called `final_training_config.yml`
```yaml
learning:
optimizer: 'adam'
learning_rate: 1.0
data:
training: '/path/to/myexperiment/training.tfrecord'
```
which is created automatically and stored in the experiment log directory
alongside model checkpoints and tensorboard summaries. This gives us a record of
the exact configs that went into each trial.
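Conceptually, the merge is just a recursive, left-to-right dictionary update. A sketch using PyYAML (the repo's actual config parsing lives in `utils/util.py`):
```python
import yaml


def merge(base, override):
  """Recursively merge `override` into `base`; later values win."""
  merged = dict(base)
  for key, value in override.items():
    if isinstance(value, dict) and isinstance(merged.get(key), dict):
      merged[key] = merge(merged[key], value)
    else:
      merged[key] = value
  return merged

config = {}
for path in ['default.yml', 'myexperiment.yml']:  # Applied left to right.
  with open(path) as f:
    config = merge(config, yaml.safe_load(f))
# config['learning'] == {'optimizer': 'adam', 'learning_rate': 1.0}
```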
## Monitoring training
We usually look at two validation metrics during training: KNN classification
error and multi-view alignment.
### KNN-Classification Error
In cases where we have labeled validation data, we can compute the average
cross-sequence KNN classification error (1.0 - recall@k=1) over all embedded
labeled images in the validation set. See `labeled_eval.py`.
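A rough numpy/scikit-learn sketch of this metric (the real computation in `labeled_eval.py` also averages over multiple label attributes):
```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def knn_classification_error(embeddings, labels, seq_names):
  """1.0 - recall@1, with each query's neighbors drawn from other sequences."""
  errors = []
  for i in range(len(embeddings)):
    mask = seq_names != seq_names[i]  # Cross-sequence candidates only.
    nn = NearestNeighbors(n_neighbors=1).fit(embeddings[mask])
    _, idx = nn.kneighbors(embeddings[i:i + 1])
    errors.append(float(labels[mask][idx[0, 0]] != labels[i]))
  return np.mean(errors)  # Refits per query for clarity, not speed.
```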
### Multi-view Alignment
In cases where there is no labeled validation data, we can look at how well our
model aligns multiple views of the same embedded validation sequences. That is,
for each embedded validation sequence and for all cross-view pairs, we compute
the scaled absolute distance between ground-truth time indices and KNN time
indices. See `alignment.py`.
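A numpy sketch of the alignment metric for a single cross-view pair, matching the formula used in `alignment.py`:
```python
import numpy as np


def pair_alignment(embeddings_view_i, embeddings_view_j):
  """Mean scaled |t - knn_t| between two embedded views of one sequence."""
  seq_len = len(embeddings_view_i)
  knn_times = np.array([
      np.argmin(np.linalg.norm(embeddings_view_j - q, axis=1))
      for q in embeddings_view_i])
  return np.mean(np.abs(np.arange(seq_len) - knn_times) / float(seq_len))
```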
## Visualization
We visualize the embedding space learned by our models in two ways: nearest
neighbor imitation videos and PCA/T-SNE.
### Nearest Neighbor Imitation Videos
One of the easiest ways to evaluate how well your model understands the task is
to see how well it can semantically align two videos via nearest neighbors in
embedding space.
Consider the case where we have multiple validation demo videos of a human or
robot performing the same task. For example, in the pouring experiments, we
collected many different multiview validation videos of a person pouring the
contents of one container into another, then setting the container down. If we'd
like to see how well our embeddings generalize across viewpoint, object/agent
appearance, and background, we can construct what we call "Nearest Neighbor
Imitation" videos, by embedding some validation query sequence `i` from view 1,
and finding the nearest neighbor for each query frame in some embedded target
sequence `j` filmed from view 1.
[Here's](https://sermanet.github.io/tcn/docs/figs/pouring_human.mov.gif) an
example of the final product.
See `generate_videos.py` for details.
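At its core this is a frame-by-frame nearest-neighbor lookup in embedding space; a numpy sketch (embedding, decoding, and video writing are handled by `generate_videos.py`):
```python
import numpy as np


def nearest_neighbor_frames(query_embeddings, target_embeddings):
  """For each query frame, the index of its nearest target frame."""
  # [num query, num target] matrix of pairwise euclidean distances.
  dists = np.linalg.norm(
      query_embeddings[:, None, :] - target_embeddings[None, :, :], axis=-1)
  return np.argmin(dists, axis=1)

# An imitation video then shows query frame t side by side with target frame
# nn[t], where nn = nearest_neighbor_frames(query_embeddings, target_embeddings).
```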
### PCA & T-SNE Visualization
We can also embed a set of images taken randomly from validation videos and
visualize the embedding space using PCA projection and T-SNE in the tensorboard
projector. See `visualize_embeddings.py` for details.
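For a quick offline look outside the TensorBoard projector, the same embeddings can be projected with scikit-learn and matplotlib (a sketch with random stand-in embeddings):
```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Stand-in embeddings: [num images, embedding size].
embeddings = np.random.randn(1000, 32)

pca_2d = PCA(n_components=2).fit_transform(embeddings)
tsne_2d = TSNE(n_components=2, perplexity=30).fit_transform(embeddings)

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.scatter(pca_2d[:, 0], pca_2d[:, 1], s=2)
ax1.set_title('PCA')
ax2.scatter(tsne_2d[:, 0], tsne_2d[:, 1], s=2)
ax2.set_title('T-SNE')
plt.show()
```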
## Tutorial Part I: Collecting Multi-View Webcam Videos
Here we give an end-to-end example of how to collect your own multiview webcam
videos and convert them to the TFRecord format expected by training.
Note: This was tested with up to 8 concurrent [Logitech c930e
webcams](https://www.logitech.com/en-us/product/c930e-webcam) extended with
[Plugable 5 Meter (16 Foot) USB 2.0 Active Repeater Extension
Cables](https://www.amazon.com/gp/product/B006LFL4X0/ref=oh_aui_detailpage_o05_s00?ie=UTF8&psc=1).
### Collect webcam videos
The collection script is `dataset/webcam.py`.
1. Plug your webcams in and run
```bash
ls -ltrh /dev/video*
```
You should see one device listed per connected webcam.
2. Define some environment variables describing the dataset you're collecting.
```bash
dataset=tutorial # Name of the dataset.
mode=train # E.g. 'train', 'validation', 'test', 'demo'.
num_views=2 # Number of webcams.
viddir=/tmp/tcn/videos # Output directory for the videos.
tmp_imagedir=/tmp/tcn/tmp_images # Temp directory to hold images.
debug_vids=1 # Whether or not to generate side-by-side debug videos.
export DISPLAY=:0.0 # This allows real time matplotlib display.
```
3. Run the webcam.py script.
```bash
bazel build -c opt --copt=-mavx webcam && \
bazel-bin/webcam \
--dataset $dataset \
--mode $mode \
--num_views $num_views \
--tmp_imagedir $tmp_imagedir \
--viddir $viddir \
--debug_vids 1
```
4. Hit Ctrl-C when done collecting, upon which the script will compile videos
for each view and optionally a debug video concatenating multiple
simultaneous views.
5. If the `--seqname` flag isn't set, the script will name the first sequence '0',
the second sequence '1', and so on (meaning you can just keep rerunning step
3). When you are finished, you should see an output `viddir` with the
following structure:
```bash
videos/0_view0.mov
videos/0_view1.mov
...
videos/0_viewM.mov
videos/1_viewM.mov
...
videos/N_viewM.mov
for N sequences and M webcam views.
```
### Create TFRecords
Use `dataset/videos_to_tfrecords.py` to convert the directory of videos into a
directory of TFRecords files, one per multi-view sequence.
```bash
viddir=/tmp/tcn/videos
dataset=tutorial
mode=train
videos=$viddir/$dataset
bazel build -c opt videos_to_tfrecords && \
bazel-bin/videos_to_tfrecords --logtostderr \
--input_dir $videos/$mode \
--output_dir ~/tcn_data/$dataset/$mode \
--max_per_shard 400
```
Setting `--max_per_shard` > 0 allows you to shard training data. We've observed
that sharding long training sequences provides better performance in terms of
global steps/sec.
This should be left at the default of 0 for validation / test data.
You should now have a directory of TFRecords files with the following structure:
```bash
output_dir/0.tfrecord
...
output_dir/N.tfrecord
1 TFRecord file for each of N multi-view sequences.
```
Now we're ready to move on to part II: training, evaluation, and visualization.
## Tutorial Part II: Training, Evaluation, and Visualization
Here we give an end-to-end example of how to train, evaluate, and visualize the
embedding space learned by TCN models.
### Download Data
We will be using the 'Multiview Pouring' dataset, which can be downloaded using
the download.sh script available
[here](https://sites.google.com/site/brainrobotdata/home/multiview-pouring).
The rest of the tutorial will assume that you have your data downloaded to a
folder at `~/tcn_data`.
```bash
mkdir ~/tcn_data
mv ~/Downloads/download.sh ~/tcn_data
cd ~/tcn_data
./download.sh
```
You should now have the following path containing all the data:
```bash
ls ~/tcn_data/multiview-pouring
labels README.txt tfrecords videos
```
### Download Pretrained Inception Checkpoint
If you haven't already, run the script that downloads the pretrained InceptionV3
checkpoint:
```bash
python download_pretrained.py
```
### Define A Config
For our experiment, we create 2 configs:
* `configs/tcn_default.yml`: This contains all the default hyperparameters
that generally don't vary across experiments.
* `configs/pouring.yml`: This contains all the hyperparameters that are
specific to the pouring experiment.
Important note about `configs/pouring.yml`:
* data.eval_cropping: We use 'pad200' for the pouring dataset, which was
filmed rather close up on iPhone cameras. A better choice for data filmed on
a webcam is likely 'crop_center'. See `preprocessing.py` for options.
### Train
Run the training binary:
```bash
logdir=/tmp/tcn/pouring
c=configs
configs=$c/tcn_default.yml,$c/pouring.yml
bazel build -c opt --copt=-mavx --config=cuda train && \
bazel-bin/train \
--config_paths $configs --logdir $logdir
```
### Evaluate
Run the binary that computes running validation loss. Set `export
CUDA_VISIBLE_DEVICES=` to run on CPU.
```bash
bazel build -c opt --copt=-mavx eval && \
bazel-bin/eval \
--config_paths $configs --logdir $logdir
```
Run the binary that computes running validation cross-view sequence alignment.
Set `export CUDA_VISIBLE_DEVICES=` to run on CPU.
```bash
bazel build -c opt --copt=-mavx alignment && \
bazel-bin/alignment \
--config_paths $configs --checkpointdir $logdir --outdir $logdir
```
Run the binary that computes running labeled KNN validation error. Set `export
CUDA_VISIBLE_DEVICES=` to run on CPU.
```bash
bazel build -c opt --copt=-mavx labeled_eval && \
bazel-bin/labeled_eval \
--config_paths $configs --checkpointdir $logdir --outdir $logdir
```
### Monitor training
Run `tensorboard --logdir=$logdir`. After a bit of training, you should see
curves that look like this:
#### Training loss
<img src="g3doc/loss.png" title="Training Loss" />
#### Validation loss
<img src="g3doc/val_loss.png" title="Validation Loss" />
#### Validation Alignment
<img src="g3doc/alignment.png" title="Validation Alignment" />
#### Average Validation KNN Classification Error
<img src="g3doc/avg_error.png" title="Validation Average KNN Error" />
#### Individual Validation KNN Classification Errors
<img src="g3doc/all_error.png" title="All Validation Average KNN Errors" />
### Visualize
To visualize the embedding space learned by a model, we can do the following:
#### Generate Imitation Videos
```bash
# Use the automatically generated final config file as config.
configs=$logdir/final_training_config.yml
# Visualize checkpoint 40001.
checkpoint_iter=40001
# Use validation records for visualization.
records=~/tcn_data/multiview-pouring/tfrecords/val
# Write videos to this location.
outdir=$logdir/tcn_viz/imitation_vids
```
```bash
bazel build -c opt --config=cuda --copt=-mavx generate_videos && \
bazel-bin/generate_videos \
--config_paths $configs \
--checkpointdir $logdir \
--checkpoint_iter $checkpoint_iter \
--query_records_dir $records \
--target_records_dir $records \
--outdir $outdir
```
After the script completes, you should see a directory of videos with names
like:
`$outdir/qtrain_clearodwalla_to_clear1_realv1_imtrain_clearsoda_to_white13_realv0.mp4`
that look like this: <img src="g3doc/im.gif" title="Imitation Video" />
#### T-SNE / PCA Visualization
Run the binary that generates embeddings and metadata.
```bash
outdir=$logdir/tcn_viz/embedding_viz
bazel build -c opt --config=cuda --copt=-mavx visualize_embeddings && \
bazel-bin/visualize_embeddings \
--config_paths $configs \
--checkpointdir $logdir \
--checkpoint_iter $checkpoint_iter \
--embedding_records $records \
--outdir $outdir \
--num_embed 1000 \
--sprite_dim 64
```
Run tensorboard, pointed at the embedding viz output directory.
```
tensorboard --logdir=$outdir
```
You should see something like this in tensorboard.
<img src="g3doc/pca.png" title="PCA" />
workspace(name = "tcn")
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Calculates test sequence alignment score."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from estimators.get_estimator import get_estimator
from utils import util
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
tf.flags.DEFINE_string(
'config_paths', '',
"""
    Path to YAML configuration files defining FLAG values. Multiple files
    can be separated by commas. Files are merged recursively. Setting
a key in these files is equivalent to setting the FLAG value with
the same name.
""")
tf.flags.DEFINE_string(
'model_params', '{}', 'YAML configuration string for the model parameters.')
tf.app.flags.DEFINE_string(
'checkpoint_iter', '', 'Evaluate this specific checkpoint.')
tf.app.flags.DEFINE_string(
'checkpointdir', '/tmp/tcn', 'Path to model checkpoints.')
tf.app.flags.DEFINE_string('outdir', '/tmp/tcn', 'Path to write summaries to.')
FLAGS = tf.app.flags.FLAGS
def compute_average_alignment(
seqname_to_embeddings, num_views, summary_writer, training_step):
"""Computes the average cross-view alignment for all sequence view pairs.
Args:
seqname_to_embeddings: Dict, mapping sequence name to a
[num_views, embedding size] numpy matrix holding all embedded views.
num_views: Int, number of simultaneous views in the dataset.
summary_writer: A `SummaryWriter` object.
training_step: Int, the training step of the model used to embed images.
Alignment is the scaled absolute difference between the ground truth time
and the knn aligned time.
    abs(time_i - knn_time) / sequence_length
"""
all_alignments = []
  for _, view_embeddings in seqname_to_embeddings.items():
for idx_i in range(num_views):
for idx_j in range(idx_i+1, num_views):
embeddings_view_i = view_embeddings[idx_i]
embeddings_view_j = view_embeddings[idx_j]
seq_len = len(embeddings_view_i)
times_i = np.array(range(seq_len))
# Get the nearest time_index for each embedding in view_i.
times_j = np.array([util.KNNIdsWithDistances(
q, embeddings_view_j, k=1)[0][0] for q in embeddings_view_i])
# Compute sequence view pair alignment.
alignment = np.mean(
np.abs(np.array(times_i)-np.array(times_j))/float(seq_len))
all_alignments.append(alignment)
        print('alignment so far %f' % alignment)
average_alignment = np.mean(all_alignments)
  print('Average alignment %f' % average_alignment)
summ = tf.Summary(value=[tf.Summary.Value(
tag='validation/alignment', simple_value=average_alignment)])
summary_writer.add_summary(summ, int(training_step))
def evaluate_once(
config, checkpointdir, validation_records, checkpoint_path, batch_size,
num_views):
"""Evaluates and reports the validation alignment."""
# Choose an estimator based on training strategy.
estimator = get_estimator(config, checkpointdir)
# Embed all validation sequences.
seqname_to_embeddings = {}
for (view_embeddings, _, seqname) in estimator.inference(
validation_records, checkpoint_path, batch_size):
seqname_to_embeddings[seqname] = view_embeddings
# Compute and report alignment statistics.
ckpt_step = int(checkpoint_path.split('-')[-1])
summary_dir = os.path.join(FLAGS.outdir, 'alignment_summaries')
summary_writer = tf.summary.FileWriter(summary_dir)
compute_average_alignment(
seqname_to_embeddings, num_views, summary_writer, ckpt_step)
def main(_):
# Parse config dict from yaml config files / command line flags.
config = util.ParseConfigsToLuaTable(FLAGS.config_paths, FLAGS.model_params)
num_views = config.data.num_views
validation_records = util.GetFilesRecursively(config.data.validation)
batch_size = config.data.batch_size
checkpointdir = FLAGS.checkpointdir
# If evaluating a specific checkpoint, do that.
if FLAGS.checkpoint_iter:
    checkpoint_path = os.path.join(
        checkpointdir, 'model.ckpt-%s' % FLAGS.checkpoint_iter)
evaluate_once(
config, checkpointdir, validation_records, checkpoint_path, batch_size,
num_views)
else:
for checkpoint_path in tf.contrib.training.checkpoints_iterator(
checkpointdir):
evaluate_once(
config, checkpointdir, validation_records, checkpoint_path,
batch_size, num_views)
if __name__ == '__main__':
tf.app.run()
# Train with Multi-View TCN.
training_strategy: 'mvtcn'
# Use the 'inception_conv_ss_fc' embedder, which has the structure:
# InceptionV3 -> 2 conv adaptation layers -> spatial softmax -> fully connected
# -> embedding.
embedder_strategy: 'inception_conv_ss_fc'
# Use npairs loss.
loss_strategy: 'npairs'
learning:
learning_rate: 0.0001
# Set some hyperparameters for our embedder.
inception_conv_ss_fc:
# Don't finetune the pre-trained weights.
finetune_inception: false
dropout:
# Don't dropout convolutional activations.
keep_conv: 1.0
# Use a dropout of 0.8 on the fully connected activations.
keep_fc: 0.8
# Use a dropout of 0.8 on the inception activations.
keep_pretrained: 0.8
# Size of the TCN embedding.
embedding_size: 32
data:
raw_height: 480
raw_width: 360
batch_size: 32
examples_per_sequence: 32
num_views: 2
preprocessing:
# Inference-time image cropping strategy.
eval_cropping: 'pad200'
augmentation:
# Do scale augmentation.
minscale: 0.8 # When downscaling, zoom in to 80% of the central bounding box.
maxscale: 3.0 # When upscaling, zoom out to 300% of the central bounding box.
proportion_scaled_up: 0.5 # Proportion of the time to scale up rather than down.
color: true # Do color augmentation.
fast_mode: true
# Paths to the data.
training: '~/tcn_data/multiview-pouring/tfrecords/train'
validation: '~/tcn_data/multiview-pouring/tfrecords/val'
test: 'path/to/test'
labeled:
image_attr_keys: ['image/view0', 'image/view1', 'task']
label_attr_keys: ['contact', 'distance', 'liquid_flowing', 'has_liquid', 'container_angle']
validation: '~/tcn_data/multiview-pouring/monolithic-labeled/val'
test: '~/tcn_data/multiview-pouring/monolithic-labeled/test'
logging:
checkpoint:
save_checkpoints_steps: 1000
\ No newline at end of file
# These configs are the defaults we used for both the pouring and pose
# experiments.
# Train on TPU?
use_tpu: false # Default is to run without TPU locally.
tpu:
num_shards: 1
iterations: 100
# SGD / general learning hyperparameters.
learning:
max_step: 1000000
learning_rate: 0.001
decay_steps: 10000
decay_factor: 1.00
l2_reg_weight: 0.000001
optimizer: 'adam'
# Default metric learning loss hyperparameters.
triplet_semihard:
embedding_l2: true # Suggestion from Hyun Oh Song's slides.
margin: .2 # Default value for Facenet.
npairs:
embedding_l2: false # Suggestion from Hyun Oh Song's slides.
clustering_loss:
embedding_l2: true # Suggestion from Hyun Oh Song's slides.
margin: 1.0 # Default in deep_metric_learning.
lifted_struct:
embedding_l2: false # Suggestion from Hyun Oh Song's slides.
margin: 1.0
contrastive:
embedding_l2: true # Suggestion from Hyun Oh Song's slides.
margin: 1.0
# Which method to use to train the embedding.
# Options are "mvtcn", "svtcn".
training_strategy: 'mvtcn'
# Which embedder architecture to use.
# Options are 'inception_conv_ss_fc' (used in pouring / pose experiments),
# 'resnet'.
embedder_strategy: 'inception_conv_ss_fc'
# Size of the TCN embedding.
embedding_size: 32
# Default hyperparameters for the different embedder architectures.
inception_conv_ss_fc:
pretrained_checkpoint: 'pretrained_checkpoints/inception/inception_v3.ckpt'
pretrained_layer: 'Mixed_5d'
additional_conv_sizes: [512, 512]
fc_hidden_sizes: [2048]
finetune: false
dropout:
keep_pretrained: 1.0
keep_conv: 1.0
keep_fc: 1.0
resnet:
pretrained_checkpoint: 'pretrained_checkpoints/resnet/resnet_v2_50.ckpt'
pretrained_layer: 4
finetune: false
adaptation_blocks: '512_3-512_3'
emb_connection: 'conv'
fc_hidden_sizes: 'None'
dropout:
keep_pretrained: 1.0
# Loss hyperparameters.
mvtcn:
# Size of the window in timesteps to get random anchor-positive pairs for
# training.
window: 580 # 29fps * 20 seconds.
svtcn:
pos_radius: 6 # 0.2 seconds * 29fps ~ 6 timesteps.
neg_radius: 12 # 2.0 * pos_radius.
# Data configs.
data:
height: 299
width: 299
preprocessing:
# Strategy to use when cropping images at inference time.
# See preprocessing.py for options.
eval_cropping: 'crop_center'
    # Training scale and color augmentation hyperparameters.
augmentation:
# See preprocessing.py for a discussion of how to use these parameters.
minscale: 1.0
maxscale: 1.0
proportion_scaled_up: 0.5
color: true
fast_mode: true
num_parallel_calls: 12
sequence_prefetch_size: 12
batch_prefetch_size: 12
batch_size: 36
eval_batch_size: 36
embed_batch_size: 128
val:
recall_at_k_list: [1]
num_eval_samples: 1000
eval_interval_secs: 300
logging:
summary:
image_summaries: false
save_summaries_steps: 100
flush_secs: 600
checkpoint:
num_to_keep: 0 # Keep all checkpoints.
save_checkpoints_steps: 1000
secs: 1800
\ No newline at end of file
use_tpu: False
training_strategy: 'mvtcn'
loss_strategy: 'triplet_semihard'
learning:
max_step: 2
optimizer: 'adam'
embedding_size: 8
data:
embed_batch_size: 12
batch_size: 12
examples_per_sequence: 12
num_views: 2
num_parallel_calls: 1
sequence_prefetch_size: 1
batch_prefetch_size: 1
logging:
summary:
image_summaries: false
save_summaries_steps: 100
flush_secs: 600
save_summaries_secs: 60
checkpoint:
num_to_keep: 0 # Keep all checkpoints.
save_checkpoints_steps: 1000
secs: 1800
\ No newline at end of file
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines data providers used in training and evaluating TCNs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import random
import numpy as np
import preprocessing
import tensorflow as tf
def record_dataset(filename):
"""Generate a TFRecordDataset from a `filename`."""
return tf.data.TFRecordDataset(filename)
def full_sequence_provider(file_list, num_views):
"""Provides full preprocessed image sequences.
Args:
file_list: List of strings, paths to TFRecords to preprocess.
num_views: Int, the number of simultaneous viewpoints at each timestep in
the dataset.
  Returns:
    views: A 2-D string `Tensor` holding the sequence of raw jpeg-encoded
      image strings for each view.
    task: String, the name of the sequence.
    seq_len: Int, the number of timesteps in the sequence.
  """
def _parse_sequence(x):
context, views, seq_len = parse_sequence_example(x, num_views)
task = context['task']
return views, task, seq_len
data_files = tf.contrib.slim.parallel_reader.get_data_files(file_list)
dataset = tf.data.Dataset.from_tensor_slices(data_files)
  dataset = dataset.repeat(1)
  # Build a dataset of serialized SequenceExamples from the TFRecord files.
  dataset = dataset.flat_map(record_dataset)
# Prefetch a number of opened files.
dataset = dataset.prefetch(12)
# Use _parse_sequence to deserialize (but not decode) image strings.
dataset = dataset.map(_parse_sequence, num_parallel_calls=12)
# Prefetch batches of images.
dataset = dataset.prefetch(12)
  iterator = dataset.make_one_shot_iterator()
  views, task, seq_len = iterator.get_next()
return views, task, seq_len
def parse_labeled_example(
example_proto, view_index, preprocess_fn, image_attr_keys, label_attr_keys):
"""Parses a labeled test example from a specified view.
Args:
example_proto: A scalar string Tensor.
view_index: Int, index on which view to parse.
preprocess_fn: A function with the signature (raw_images, is_training) ->
preprocessed_images, where raw_images is a 4-D float32 image `Tensor`
of raw images, is_training is a Boolean describing if we're in training,
and preprocessed_images is a 4-D float32 image `Tensor` holding
preprocessed images.
image_attr_keys: List of Strings, names for image keys.
label_attr_keys: List of Strings, names for label attributes.
Returns:
data: A tuple of images, attributes and tasks `Tensors`.
"""
features = {}
for attr_key in image_attr_keys:
features[attr_key] = tf.FixedLenFeature((), tf.string)
for attr_key in label_attr_keys:
features[attr_key] = tf.FixedLenFeature((), tf.int64)
parsed_features = tf.parse_single_example(example_proto, features)
image_only_keys = [i for i in image_attr_keys if 'image' in i]
view_image_key = image_only_keys[view_index]
image = preprocessing.decode_image(parsed_features[view_image_key])
preprocessed = preprocess_fn(image, is_training=False)
attributes = [parsed_features[k] for k in label_attr_keys]
task = parsed_features['task']
return tuple([preprocessed] + attributes + [task])
def labeled_data_provider(
filenames, preprocess_fn, view_index, image_attr_keys, label_attr_keys,
batch_size=32, num_epochs=1):
"""Gets a batched dataset iterator over annotated test images + labels.
  Provides a single view, specified in `view_index`.
Args:
filenames: List of Strings, paths to tfrecords on disk.
preprocess_fn: A function with the signature (raw_images, is_training) ->
preprocessed_images, where raw_images is a 4-D float32 image `Tensor`
of raw images, is_training is a Boolean describing if we're in training,
and preprocessed_images is a 4-D float32 image `Tensor` holding
preprocessed images.
view_index: Int, the index of the view to embed.
image_attr_keys: List of Strings, names for image keys.
label_attr_keys: List of Strings, names for label attributes.
batch_size: Int, size of the batch.
num_epochs: Int, number of epochs over the classification dataset.
Returns:
batch_images: 4-d float `Tensor` holding the batch images for the view.
labels: K-d int `Tensor` holding the K label attributes.
tasks: 1-D String `Tensor`, holding the task names for each batch element.
"""
dataset = tf.data.TFRecordDataset(filenames)
# pylint: disable=g-long-lambda
dataset = dataset.map(
lambda p: parse_labeled_example(
p, view_index, preprocess_fn, image_attr_keys, label_attr_keys))
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
data_iterator = dataset.make_one_shot_iterator()
batch_data = data_iterator.get_next()
batch_images = batch_data[0]
batch_labels = tf.stack(batch_data[1:-1], 1)
batch_tasks = batch_data[-1]
batch_images = set_image_tensor_batch_dim(batch_images, batch_size)
batch_labels.set_shape([batch_size, len(label_attr_keys)])
batch_tasks.set_shape([batch_size])
return batch_images, batch_labels, batch_tasks
def parse_sequence_example(serialized_example, num_views):
"""Parses a serialized sequence example into views, sequence length data."""
context_features = {
'task': tf.FixedLenFeature(shape=[], dtype=tf.string),
'len': tf.FixedLenFeature(shape=[], dtype=tf.int64)
}
view_names = ['view%d' % i for i in range(num_views)]
fixed_features = [
tf.FixedLenSequenceFeature(
shape=[], dtype=tf.string) for _ in range(len(view_names))]
sequence_features = dict(zip(view_names, fixed_features))
context_parse, sequence_parse = tf.parse_single_sequence_example(
serialized=serialized_example,
context_features=context_features,
sequence_features=sequence_features)
views = tf.stack([sequence_parse[v] for v in view_names])
lens = [sequence_parse[v].get_shape().as_list()[0] for v in view_names]
assert len(set(lens)) == 1
  seq_len = tf.shape(sequence_parse[view_names[0]])[0]
return context_parse, views, seq_len
def get_shuffled_input_records(file_list):
"""Build a tf.data.Dataset of shuffled input TFRecords that repeats."""
dataset = tf.data.Dataset.from_tensor_slices(file_list)
dataset = dataset.shuffle(len(file_list))
dataset = dataset.repeat()
dataset = dataset.flat_map(record_dataset)
dataset = dataset.repeat()
return dataset
def get_tcn_anchor_pos_indices(seq_len, num_views, num_pairs, window):
"""Gets batch TCN anchor positive timestep and view indices.
This gets random (anchor, positive) timesteps from a sequence, and chooses
2 random differing viewpoints for each anchor positive pair.
Args:
seq_len: Int, the size of the batch sequence in timesteps.
num_views: Int, the number of simultaneous viewpoints at each timestep.
num_pairs: Int, the number of pairs to build.
window: Int, the window (in frames) from which to take anchor, positive
and negative indices.
Returns:
ap_time_indices: 1-D Int `Tensor` with size [num_pairs], holding the
timestep for each (anchor,pos) pair.
a_view_indices: 1-D Int `Tensor` with size [num_pairs], holding the
view index for each anchor.
p_view_indices: 1-D Int `Tensor` with size [num_pairs], holding the
view index for each positive.
"""
# Get anchor, positive time indices.
def f1():
# Choose a random window-length range from the sequence.
range_min = tf.random_shuffle(tf.range(seq_len-window))[0]
range_max = range_min+window
return tf.range(range_min, range_max)
def f2():
# Consider the full sequence.
return tf.range(seq_len)
time_indices = tf.cond(tf.greater(seq_len, window), f1, f2)
shuffled_indices = tf.random_shuffle(time_indices)
num_pairs = tf.minimum(seq_len, num_pairs)
ap_time_indices = shuffled_indices[:num_pairs]
# Get opposing anchor, positive view indices.
view_indices = tf.tile(
tf.expand_dims(tf.range(num_views), 0), (num_pairs, 1))
shuffled_view_indices = tf.map_fn(tf.random_shuffle, view_indices)
a_view_indices = shuffled_view_indices[:, 0]
p_view_indices = shuffled_view_indices[:, 1]
return ap_time_indices, a_view_indices, p_view_indices
def set_image_tensor_batch_dim(tensor, batch_dim):
"""Sets the batch dimension on an image tensor."""
shape = tensor.get_shape()
tensor.set_shape([batch_dim, shape[1], shape[2], shape[3]])
return tensor
def parse_sequence_to_pairs_batch(
serialized_example, preprocess_fn, is_training, num_views, batch_size,
window):
"""Parses a serialized sequence example into a batch of preprocessed data.
Args:
serialized_example: A serialized SequenceExample.
preprocess_fn: A function with the signature (raw_images, is_training) ->
preprocessed_images.
is_training: Boolean, whether or not we're in training.
num_views: Int, the number of simultaneous viewpoints at each timestep in
the dataset.
batch_size: Int, size of the batch to get.
    window: Int, only take pairs from a maximum window of this size.
Returns:
preprocessed: A 4-D float32 `Tensor` holding preprocessed images.
anchor_images: A 4-D float32 `Tensor` holding raw anchor images.
pos_images: A 4-D float32 `Tensor` holding raw positive images.
"""
_, views, seq_len = parse_sequence_example(serialized_example, num_views)
# Get random (anchor, positive) timestep and viewpoint indices.
num_pairs = batch_size // 2
ap_time_indices, a_view_indices, p_view_indices = get_tcn_anchor_pos_indices(
seq_len, num_views, num_pairs, window)
# Gather the image strings.
combined_anchor_indices = tf.concat(
[tf.expand_dims(a_view_indices, 1),
tf.expand_dims(ap_time_indices, 1)], 1)
combined_pos_indices = tf.concat(
[tf.expand_dims(p_view_indices, 1),
tf.expand_dims(ap_time_indices, 1)], 1)
anchor_images = tf.gather_nd(views, combined_anchor_indices)
pos_images = tf.gather_nd(views, combined_pos_indices)
# Decode images.
anchor_images = tf.map_fn(
preprocessing.decode_image, anchor_images, dtype=tf.float32)
pos_images = tf.map_fn(
preprocessing.decode_image, pos_images, dtype=tf.float32)
  # Concatenate [anchor, positive] images into a batch and preprocess it.
concatenated = tf.concat([anchor_images, pos_images], 0)
preprocessed = preprocess_fn(concatenated, is_training)
anchor_prepro, positive_prepro = tf.split(preprocessed, num_or_size_splits=2,
axis=0)
# Set static batch dimensions for all image tensors
ims = [anchor_prepro, positive_prepro, anchor_images, pos_images]
ims = [set_image_tensor_batch_dim(i, num_pairs) for i in ims]
[anchor_prepro, positive_prepro, anchor_images, pos_images] = ims
# Assign each anchor and positive the same label.
anchor_labels = tf.range(1, num_pairs+1)
positive_labels = tf.range(1, num_pairs+1)
return (anchor_prepro, positive_prepro, anchor_images, pos_images,
anchor_labels, positive_labels, seq_len)
def multiview_pairs_provider(file_list,
preprocess_fn,
num_views,
window,
is_training,
batch_size,
examples_per_seq=2,
num_parallel_calls=12,
sequence_prefetch_size=12,
batch_prefetch_size=12):
"""Provides multi-view TCN anchor-positive image pairs.
Returns batches of Multi-view TCN pairs, where each pair consists of an
anchor and a positive coming from different views from the same timestep.
Batches are filled one entire sequence at a time until
batch_size is exhausted. Pairs are chosen randomly without replacement
within a sequence.
Used by:
* triplet semihard loss.
* clustering loss.
* npairs loss.
* lifted struct loss.
* contrastive loss.
Args:
file_list: List of Strings, paths to tfrecords.
preprocess_fn: A function with the signature (raw_images, is_training) ->
preprocessed_images, where raw_images is a 4-D float32 image `Tensor`
of raw images, is_training is a Boolean describing if we're in training,
and preprocessed_images is a 4-D float32 image `Tensor` holding
preprocessed images.
num_views: Int, the number of simultaneous viewpoints at each timestep.
window: Int, size of the window (in frames) from which to draw batch ids.
is_training: Boolean, whether or not we're in training.
batch_size: Int, how many examples in the batch (num pairs * 2).
examples_per_seq: Int, how many examples to take per sequence.
num_parallel_calls: Int, the number of elements to process in parallel by
mapper.
sequence_prefetch_size: Int, size of the buffer used to prefetch sequences.
batch_prefetch_size: Int, size of the buffer used to prefetch batches.
Returns:
batch_images: A 4-D float32 `Tensor` holding preprocessed batch images.
anchor_labels: A 1-D int32 `Tensor` holding anchor image labels.
anchor_images: A 4-D float32 `Tensor` holding raw anchor images.
positive_labels: A 1-D int32 `Tensor` holding positive image labels.
pos_images: A 4-D float32 `Tensor` holding raw positive images.
"""
def _parse_sequence(x):
return parse_sequence_to_pairs_batch(
x, preprocess_fn, is_training, num_views, examples_per_seq, window)
# Build a buffer of shuffled input TFRecords that repeats forever.
dataset = get_shuffled_input_records(file_list)
# Prefetch a number of opened TFRecords.
dataset = dataset.prefetch(sequence_prefetch_size)
# Use _parse_sequence to map sequences to batches (one sequence per batch).
dataset = dataset.map(
_parse_sequence, num_parallel_calls=num_parallel_calls)
# Filter out sequences that don't have at least examples_per_seq.
def seq_greater_than_min(seqlen, maximum):
return seqlen >= maximum
filter_fn = functools.partial(seq_greater_than_min, maximum=examples_per_seq)
dataset = dataset.filter(lambda a, b, c, d, e, f, seqlen: filter_fn(seqlen))
# Take a number of sequences for the batch.
assert batch_size % examples_per_seq == 0
sequences_per_batch = batch_size // examples_per_seq
dataset = dataset.batch(sequences_per_batch)
# Prefetch batches of images.
dataset = dataset.prefetch(batch_prefetch_size)
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()
# Pull out images, reshape to [batch_size, ...], concatenate anchor and pos.
ims = list(data[:4])
anchor_labels, positive_labels = data[4:6]
# Set labels shape.
anchor_labels.set_shape([sequences_per_batch, None])
positive_labels.set_shape([sequences_per_batch, None])
def _reshape_to_batchsize(im):
"""[num_sequences, num_per_seq, ...] images to [batch_size, ...]."""
sequence_ims = tf.split(im, num_or_size_splits=sequences_per_batch, axis=0)
sequence_ims = [tf.squeeze(i) for i in sequence_ims]
return tf.concat(sequence_ims, axis=0)
# Reshape labels.
anchor_labels = _reshape_to_batchsize(anchor_labels)
positive_labels = _reshape_to_batchsize(positive_labels)
def _set_shape(im):
"""Sets a static shape for an image tensor of [sequences_per_batch,...] ."""
shape = im.get_shape()
im.set_shape([sequences_per_batch, shape[1], shape[2], shape[3], shape[4]])
return im
ims = [_set_shape(im) for im in ims]
ims = [_reshape_to_batchsize(im) for im in ims]
anchor_prepro, positive_prepro, anchor_images, pos_images = ims
batch_images = tf.concat([anchor_prepro, positive_prepro], axis=0)
return batch_images, anchor_labels, positive_labels, anchor_images, pos_images
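# ----------------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the original pipeline):
# feeding the provider's outputs into one of the metric-learning losses listed
# in the docstring above. The tiny fully-connected embedder below is a
# hypothetical stand-in for the real TCN network defined elsewhere in the repo.
def _example_pairs_to_triplet_loss(batch_images, anchor_labels, positive_labels):
  """Sketch: computes a triplet semihard loss over an anchor/positive batch."""
  # Hypothetical embedder: flatten + project to a 128-d, L2-normalized space.
  flat = tf.contrib.layers.flatten(batch_images)
  embeddings = tf.contrib.layers.fully_connected(flat, 128, activation_fn=None)
  embeddings = tf.nn.l2_normalize(embeddings, dim=1)
  # batch_images stacks anchors then positives, so stack the labels to match.
  labels = tf.concat([anchor_labels, positive_labels], axis=0)
  return tf.contrib.losses.metric_learning.triplet_semihard_loss(
      labels=labels, embeddings=embeddings)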
def get_svtcn_indices(seq_len, batch_size, num_views):
"""Gets a random window of contiguous time indices from a sequence.
Args:
seq_len: Int, number of timesteps in the image sequence.
batch_size: Int, size of the batch to construct.
num_views: Int, the number of simultaneous viewpoints at each
timestep in the dataset.
Returns:
time_indices: 1-D Int `Tensor` with size [batch_size], holding the
timestep for each batch image.
view_indices: 1-D Int `Tensor` with size [batch_size], holding the
view for each batch image. This is consistent across the batch.
"""
# Get anchor, positive time indices.
def f1():
# Choose a random contiguous range from within the sequence.
range_min = tf.random_shuffle(tf.range(seq_len-batch_size))[0]
range_max = range_min+batch_size
return tf.range(range_min, range_max)
def f2():
# Consider the full sequence.
return tf.range(seq_len)
time_indices = tf.cond(tf.greater(seq_len, batch_size), f1, f2)
# Get opposing anchor, positive view indices.
random_view = tf.random_shuffle(tf.range(num_views))[0]
view_indices = tf.tile([random_view], (batch_size,))
return time_indices, view_indices
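# Illustrative sketch (editor's addition): materializing a window of indices
# for a hypothetical 600-frame, 2-view sequence. Because seq_len > batch_size,
# the timesteps form one contiguous range and every image shares one view id.
def _example_svtcn_indices():
  """Sketch: draws a contiguous 32-frame window and a single random view."""
  time_indices, view_indices = get_svtcn_indices(
      seq_len=600, batch_size=32, num_views=2)
  with tf.Session() as sess:
    np_times, np_views = sess.run([time_indices, view_indices])
  return np_times, np_views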
def parse_sequence_to_svtcn_batch(
serialized_example, preprocess_fn, is_training, num_views, batch_size):
"""Parses a serialized sequence example into a batch of SVTCN data."""
_, views, seq_len = parse_sequence_example(serialized_example, num_views)
# Get svtcn indices.
time_indices, view_indices = get_svtcn_indices(seq_len, batch_size, num_views)
combined_indices = tf.concat(
[tf.expand_dims(view_indices, 1),
tf.expand_dims(time_indices, 1)], 1)
# Gather the image strings.
images = tf.gather_nd(views, combined_indices)
# Decode images.
images = tf.map_fn(preprocessing.decode_image, images, dtype=tf.float32)
# Concatenate anchor and positive images, preprocess the batch.
preprocessed = preprocess_fn(images, is_training)
return preprocessed, images, time_indices
def singleview_tcn_provider(file_list,
preprocess_fn,
num_views,
is_training,
batch_size,
num_parallel_calls=12,
sequence_prefetch_size=12,
batch_prefetch_size=12):
"""Provides data to train singleview TCNs.
Args:
file_list: List of Strings, paths to tfrecords.
preprocess_fn: A function with the signature (raw_images, is_training) ->
preprocessed_images, where raw_images is a 4-D float32 image `Tensor`
of raw images, is_training is a Boolean describing if we're in training,
and preprocessed_images is a 4-D float32 image `Tensor` holding
preprocessed images.
num_views: Int, the number of simultaneous viewpoints at each timestep.
is_training: Boolean, whether or not we're in training.
batch_size: Int, how many examples in the batch.
num_parallel_calls: Int, the number of elements to process in parallel by
mapper.
sequence_prefetch_size: Int, size of the buffer used to prefetch sequences.
batch_prefetch_size: Int, size of the buffer used to prefetch batches.
Returns:
batch_images: A 4-D float32 `Tensor` of preprocessed images.
raw_images: A 4-D float32 `Tensor` of raw images.
timesteps: A 1-D int32 `Tensor` of timesteps associated with each image.
"""
def _parse_sequence(x):
return parse_sequence_to_svtcn_batch(
x, preprocess_fn, is_training, num_views, batch_size)
# Build a buffer of shuffled input TFRecords that repeats forever.
dataset = get_shuffled_input_records(file_list)
# Prefetch a number of opened files.
dataset = dataset.prefetch(sequence_prefetch_size)
# Use _parse_sequence to map sequences to image batches.
dataset = dataset.map(
_parse_sequence, num_parallel_calls=num_parallel_calls)
# Prefetch batches of images.
dataset = dataset.prefetch(batch_prefetch_size)
iterator = dataset.make_one_shot_iterator()
batch_images, raw_images, timesteps = iterator.get_next()
return batch_images, raw_images, timesteps
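# ----------------------------------------------------------------------------
# Illustrative usage sketch (editor's addition): wiring the single-view
# provider into a graph. The tfrecord path and the resize-only preprocess_fn
# are hypothetical placeholders; the real pipeline uses preprocessing.py.
def _example_singleview_provider():
  """Sketch: builds a single-view TCN batch from a hypothetical tfrecord."""
  def _preprocess_fn(raw_images, is_training):
    del is_training  # A real preprocessor would augment differently per mode.
    return tf.image.resize_images(raw_images, [299, 299])
  batch_images, raw_images, timesteps = singleview_tcn_provider(
      file_list=['/tmp/tcn_data/train/seq0.tfrecord'],  # hypothetical path
      preprocess_fn=_preprocess_fn,
      num_views=2,
      is_training=True,
      batch_size=32)
  return batch_images, raw_images, timesteps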
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_providers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import data_providers
import tensorflow as tf
class DataTest(tf.test.TestCase):
def testMVTripletIndices(self):
"""Ensures anchor/pos indices for a TCN batch are valid."""
tf.set_random_seed(0)
window = 580
batch_size = 36
num_pairs = batch_size // 2
num_views = 2
seq_len = 600
# Get anchor time and view indices for this sequence.
(_, a_view_indices,
p_view_indices) = data_providers.get_tcn_anchor_pos_indices(
seq_len, num_views, num_pairs, window)
with self.test_session() as sess:
(np_a_view_indices,
np_p_view_indices) = sess.run([a_view_indices, p_view_indices])
# Assert no overlap between anchor and pos view indices.
np.testing.assert_equal(
np.any(np.not_equal(np_a_view_indices, np_p_view_indices)), True)
# Assert set of view indices is a subset of expected set of view indices.
view_set = set(range(num_views))
self.assertTrue(set(np_a_view_indices).issubset(view_set))
self.assertTrue(set(np_p_view_indices).issubset(view_set))
def testSVTripletIndices(self):
"""Ensures time indices for a SV triplet batch are valid."""
seq_len = 600
batch_size = 36
num_views = 2
time_indices, _ = data_providers.get_svtcn_indices(
seq_len, batch_size, num_views)
with self.test_session() as sess:
np_time_indices = sess.run(time_indices)
first = np_time_indices[0]
last = np_time_indices[-1]
# Make sure batch time indices are a contiguous range.
self.assertTrue(np.array_equal(np_time_indices, range(first, last+1)))
if __name__ == "__main__":
tf.test.main()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts temp directories of images to videos."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import shutil
# pylint: disable=invalid-name
parser = argparse.ArgumentParser()
parser.add_argument(
'--view_dirs', type=str, default='',
help='Comma-separated list of temp view image directories.')
parser.add_argument(
'--vid_paths', type=str, default='',
help='Comma-separated list of video output paths.')
parser.add_argument(
'--debug_path', type=str, default='',
help='Output path to debug video.')
parser.add_argument(
'--debug_lhs_view', type=str, default='',
help='Output path to debug video.')
parser.add_argument(
'--debug_rhs_view', type=str, default='',
help='Output path to debug video.')
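# Example invocation (editor's addition; all paths below are hypothetical):
#   python dataset/images_to_videos.py \
#     --view_dirs /tmp/tcn/tmp_images/0_view0,/tmp/tcn/tmp_images/0_view1 \
#     --vid_paths /tmp/tcn/videos/0_view0.mp4,/tmp/tcn/videos/0_view1.mp4 \
#     --debug_path /tmp/tcn/videos_debug/0.mp4 \
#     --debug_lhs_view 0 --debug_rhs_view 1
# Requires mencoder (and avconv when --debug_path is set) on the PATH. Note
# that the temp view image directories are deleted after the videos are built.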
def create_vids(view_dirs, vid_paths, debug_path=None,
debug_lhs_view=0, debug_rhs_view=1):
"""Creates one video per view per sequence."""
# Create the view videos.
for (view_dir, vidpath) in zip(view_dirs, vid_paths):
encode_vid_cmd = r'mencoder mf://%s/*.png \
-mf fps=29:type=png \
-ovc lavc -lavcopts vcodec=mpeg4:mbd=2:trell \
-oac copy -o %s' % (view_dir, vidpath)
os.system(encode_vid_cmd)
# Optionally create a debug side-by-side video.
if debug_path:
lhs = vid_paths[int(debug_lhs_view)]
rhs = vid_paths[int(debug_rhs_view)]
os.system(r"avconv \
-i %s \
-i %s \
-filter_complex '[0:v]pad=iw*2:ih[int];[int][1:v]overlay=W/2:0[vid]' \
-map [vid] \
-c:v libx264 \
-crf 23 \
-preset veryfast \
%s" % (lhs, rhs, debug_path))
def main():
FLAGS, _ = parser.parse_known_args()
assert FLAGS.view_dirs
assert FLAGS.vid_paths
view_dirs = FLAGS.view_dirs.split(',')
vid_paths = FLAGS.vid_paths.split(',')
create_vids(view_dirs, vid_paths, FLAGS.debug_path,
FLAGS.debug_lhs_view, FLAGS.debug_rhs_view)
# Cleanup temp image dirs.
for i in view_dirs:
shutil.rmtree(i)
if __name__ == '__main__':
main()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts videos to training, validation, test, and debug tfrecords on cns.
Example usage:
# From phone videos.
x=learning/brain/research/tcn/videos_to_tfrecords && \
blaze build -c opt $x && \
set=tmp && videos=~/data/tcn/datasets/$set/ && \
blaze-bin/$x --logtostderr --output_dir /cns/oi-d/home/$USER/tcn_data/$set \
--input_dir $videos/train
--debug $dataset/debug --rotate 90 --max_per_shard 400
# From webcam videos.
mode=train
x=learning/brain/research/tcn/videos_to_tfrecords && \
blaze build -c opt $x && \
set=tmp && videos=/tmp/tcn/videos/$set/ && \
blaze-bin/$x --logtostderr \
--output_dir /cns/oi-d/home/$USER/tcn_data/$set/$mode \
--input_dir $videos/$mode --max_per_shard 400
"""
import glob
import math
import multiprocessing
from multiprocessing.pool import ThreadPool
import os
from random import shuffle
import re
from StringIO import StringIO
import cv2
from PIL import Image
from PIL import ImageFile
from preprocessing import cv2resizeminedge
from preprocessing import cv2rotateimage
from preprocessing import shapestring
from utils.progress import Progress
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.flags.DEFINE_string('view_pattern', '_view[_]*[0]+[.].*',
'view regexp pattern for first view')
tf.app.flags.DEFINE_string('input_dir', '', '''input data path''')
tf.app.flags.DEFINE_integer('resize_min_edge', 0,
'''resize the smallest edge to this size.''')
tf.app.flags.DEFINE_integer('rotate', 0, '''rotate the image in degrees.''')
tf.app.flags.DEFINE_string('rotate_if_matching', None,
'rotate only if video path matches regexp.')
tf.app.flags.DEFINE_string('output_dir', '', 'output directory for the dataset')
tf.app.flags.DEFINE_integer(
'max_per_shard', -1, 'max # of frames per data chunk')
tf.app.flags.DEFINE_integer('expected_views', 2, 'expected number of views')
tf.app.flags.DEFINE_integer('log_frequency', 50, 'frequency of logging')
tf.app.flags.DEFINE_integer(
'max_views_discrepancy', 100,
'Maximum length difference (in frames) allowed between views')
tf.app.flags.DEFINE_boolean('overwrite', False, 'overwrite output files')
FLAGS = tf.app.flags.FLAGS
feature = tf.train.Feature
bytes_feature = lambda v: feature(bytes_list=tf.train.BytesList(value=v))
int64_feature = lambda v: feature(int64_list=tf.train.Int64List(value=v))
float_feature = lambda v: feature(float_list=tf.train.FloatList(value=v))
def FindPatternFiles(path, view_pattern, errors):
"""Recursively find all files matching a certain pattern."""
if not path:
return None
tf.logging.info(
'Recursively searching for files matching pattern \'%s\' in %s' %
(view_pattern, path))
view_patt = re.compile('.*' + view_pattern)
sequences = []
for root, _, filenames in os.walk(path, followlinks=True):
path_root = root[:len(path)]
assert path_root == path
for filename in filenames:
if view_patt.match(filename):
fullpath = os.path.join(root, re.sub(view_pattern, '', filename))
shortpath = re.sub(path, '', fullpath).lstrip('/')
# Determine if this sequence should be sharded or not.
shard = False
if FLAGS.max_per_shard > 0:
shard = True
# Retrieve number of frames for this sequence.
num_views, length, view_paths, num_frames = GetViewInfo(
fullpath + view_pattern[0] + '*')
if num_views != FLAGS.expected_views:
tf.logging.info('Expected %d views but found: %s' %
(FLAGS.expected_views, str(view_paths)))
assert num_views == FLAGS.expected_views
assert length > 0
# Drop sequences if view lengths differ too much.
if max(num_frames) - min(num_frames) > FLAGS.max_views_discrepancy:
error_msg = (
'Error: ignoring sequence with view length difference > %d: '
'%s in %s') % (FLAGS.max_views_discrepancy, str(num_frames),
fullpath)
errors.append(error_msg)
tf.logging.error(error_msg)
else:
# Append sequence info.
sequences.append({'full': fullpath, 'name': shortpath, 'len': length,
'start': 0, 'end': length, 'num_views': num_views,
'shard': shard})
return sorted(sequences, key=lambda k: k['name'])
def ShardSequences(sequences, max_per_shard):
"""Find all sequences, shard and randomize them."""
total_shards_len = 0
total_shards = 0
assert max_per_shard > 0
for sequence in sequences:
if sequence['shard']:
sequence['shard'] = False # Reset shard flag.
length = sequence['len']
start = sequence['start']
end = sequence['end']
name = sequence['name']
assert end - start == length
if length > max_per_shard:
# Dividing sequence into smaller shards.
num_shards = int(math.floor(length / max_per_shard)) + 1
size = int(math.ceil(length / num_shards))
tf.logging.info(
'splitting sequence of length %d into %d shards of size %d' %
(length, num_shards, size))
last_end = 0
for i in range(num_shards):
shard_start = last_end
shard_end = min(length, shard_start + size)
if i == num_shards - 1:
shard_end = length
shard_len = shard_end - shard_start
total_shards_len += shard_len
shard_name = name + '_shard%02d' % i
last_end = shard_end
# Enqueuing shard.
if i == 0: # Replace current sequence.
sequence['len'] = shard_len
sequence['start'] = shard_start
sequence['end'] = shard_end
sequence['name'] = shard_name
else: # Enqueue new sequence.
sequences.append(
{'full': sequence['full'], 'name': shard_name,
'len': shard_len, 'start': shard_start, 'end': shard_end,
'num_views': sequence['num_views'], 'shard': False})
total_shards += num_shards
assert last_end == length
# Print resulting sharding.
if total_shards > 0:
tf.logging.info('%d shards of average length %d' %
(total_shards, total_shards_len / total_shards))
return sorted(sequences, key=lambda k: k['name'])
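# Worked example (editor's addition), assuming the Python 2 integer division
# used above: a 1000-frame sequence with --max_per_shard 400 gives
# num_shards = floor(1000 / 400) + 1 = 3 and size = ceil(1000 / 3) = 333, so
# the shards cover [0, 333), [333, 666) and [666, 1000); the final shard
# absorbs the remainder because shard_end is forced to the sequence length.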
def RandomizeSets(sets):
"""Randomize each set."""
for _, sequences in sorted(sets.iteritems()):
if sequences:
# Randomize order.
shuffle(sequences)
def GetSpecificFrame(vid_path, frame_index):
"""Gets a frame at a specified index in a video."""
cap = cv2.VideoCapture(vid_path)
cap.set(1, frame_index)
_, bgr = cap.read()
cap.release()
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
return rgb
def JpegString(image, jpeg_quality=90):
"""Returns given PIL.Image instance as jpeg string.
Args:
image: A PIL image.
jpeg_quality: The image quality, on a scale from 1 (worst) to 95 (best).
Returns:
a jpeg_string.
"""
# This fix to PIL makes sure that we don't get an error when saving large
# jpeg files. This is a workaround for a bug in PIL. The value should be
# substantially larger than the size of the image being saved.
ImageFile.MAXBLOCK = 640 * 512 * 64
output_jpeg = StringIO()
image.save(output_jpeg, 'jpeg', quality=jpeg_quality, optimize=True)
return output_jpeg.getvalue()
def ParallelPreprocessing(args):
"""Parallel preprocessing: rotation, resize and jpeg encoding to string."""
(vid_path, timestep, num_timesteps, view) = args
try:
image = GetSpecificFrame(vid_path, timestep)
# Resizing.
resize_str = ''
if FLAGS.resize_min_edge > 0:
resize_str += ', resize ' + shapestring(image)
image = cv2resizeminedge(image, FLAGS.resize_min_edge)
resize_str += ' => ' + shapestring(image)
# Rotating.
rotate = None
if FLAGS.rotate:
rotate = FLAGS.rotate
if FLAGS.rotate_if_matching is not None:
rotate = None
patt = re.compile(FLAGS.rotate_if_matching)
if patt.match(vid_path) is not None:
rotate = FLAGS.rotate
if rotate is not None:
image = cv2rotateimage(image, FLAGS.rotate)
# Jpeg encoding.
image = Image.fromarray(image)
im_string = bytes_feature([JpegString(image)])
if timestep % FLAGS.log_frequency == 0:
tf.logging.info('Loaded frame %d / %d for %s (rotation %s%s) from %s' %
(timestep, num_timesteps, view, str(rotate), resize_str,
vid_path))
return im_string
except cv2.error as e:
tf.logging.error('Error while loading frame %d of %s: %s' %
(timestep, vid_path, str(e)))
return None
def GetNumFrames(vid_path):
"""Gets the number of frames in a video."""
cap = cv2.VideoCapture(vid_path)
total_frames = cap.get(7)
cap.release()
return int(total_frames)
def GetViewInfo(views_fullname):
"""Return information about a group of views."""
view_paths = sorted(glob.glob(views_fullname))
num_frames = [GetNumFrames(i) for i in view_paths]
min_num_frames = min(num_frames)
num_views = len(view_paths)
return num_views, min_num_frames, view_paths, num_frames
def AddSequence(sequence, writer, progress, errors):
"""Converts a sequence to a SequenceExample.
Sequences have multiple viewpoint videos. Extract all frames from all
viewpoint videos in parallel, build a single SequenceExample containing
all viewpoint images for every timestep.
Args:
sequence: a dict with information on a sequence.
writer: A TFRecordWriter.
progress: A Progress object to report processing progress.
errors: a list of string to append to in case of errors.
"""
fullname = sequence['full']
shortname = sequence['name']
start = sequence['start']
end = sequence['end']
num_timesteps = sequence['len']
# Build a list of all view paths for this fullname.
path = fullname + FLAGS.view_pattern[0] + '*'
tf.logging.info('Loading sequence from ' + path)
view_paths = sorted(glob.glob(path))
# Extract all images for all views
num_frames = [GetNumFrames(i) for i in view_paths]
tf.logging.info('Loading %s with [%d, %d[ (%d frames) from: %s %s' %
(shortname, start, end, num_timesteps,
str(num_frames), str(view_paths)))
num_views = len(view_paths)
total_timesteps = int(min(num_frames))
assert num_views == FLAGS.expected_views
assert num_views == sequence['num_views']
# Create a worker pool to parallelize loading/rotating
worker_pool = ThreadPool(multiprocessing.cpu_count())
# Collect all images for each view.
view_to_feature_list = {}
view_images = []
for view_idx, view in enumerate(
['view'+str(i) for i in range(num_views)]):
# Flatten list to process in parallel
work = []
for i in range(start, end):
work.append((view_paths[view_idx], i, total_timesteps, view))
# Load and rotate images in parallel
view_images.append(worker_pool.map(ParallelPreprocessing, work))
# Report progress.
progress.Add(len(view_images[view_idx]))
tf.logging.info('%s' % str(progress))
# Remove error frames from all views
i = 0  # Index into the collected frame lists, which begin at the shard start.
num_errors = 0
while i < len(view_images[0]):
remove_frame = False
# Check if one or more views have an error for this frame.
for view_idx in range(num_views):
if view_images[view_idx][i] is None:
remove_frame = True
error_msg = 'Removing frame %d for all views for %s ' % (i, fullname)
errors.append(error_msg)
tf.logging.error(error_msg)
# Remove faulty frames.
if remove_frame:
num_errors += 1
for view_idx in range(num_views):
del view_images[view_idx][i]
else:
i += 1
# Ignore sequences that have errors.
if num_errors > 0:
error_msg = 'Dropping sequence because of frame errors for %s' % fullname
errors.append(error_msg)
tf.logging.error(error_msg)
else:
# Build FeatureList objects for each view.
for view_idx, view in enumerate(
['view'+str(i) for i in range(num_views)]):
# Construct FeatureList from repeated feature.
view_to_feature_list[view] = tf.train.FeatureList(
feature=view_images[view_idx])
context_features = tf.train.Features(feature={
'task': bytes_feature([shortname]),
'len': int64_feature([num_timesteps])
})
feature_lists = tf.train.FeatureLists(feature_list=view_to_feature_list)
ex = tf.train.SequenceExample(
context=context_features, feature_lists=feature_lists)
writer.write(ex.SerializeToString())
tf.logging.info('Done adding %s with %d timesteps'
% (fullname, num_timesteps))
def PrintSequencesInfo(sequences, prefix):
"""Print information about sequences and return the total number of frames."""
tf.logging.info('')
tf.logging.info(prefix)
num_frames = 0
for sequence in sequences:
shard_str = ''
if sequence['shard']:
shard_str = ' (sharding)'
tf.logging.info('frames [%d, %d[\t(%d frames * %d views)%s\t%s' % (
sequence['start'], sequence['end'], sequence['len'],
sequence['num_views'], shard_str, sequence['name']))
num_frames += sequence['len'] * sequence['num_views']
tf.logging.info(('%d frames (all views), %d sequences, average sequence'
' length (all views): %d') %
(num_frames, len(sequences), num_frames / len(sequences)))
tf.logging.info('')
return num_frames
def CheckRecord(filename, sequence):
"""Check that an existing tfrecord corresponds to the expected sequence."""
num_sequences = 0
total_frames = 0
for serialized_example in tf.python_io.tf_record_iterator(filename):
num_sequences += 1
example = tf.train.SequenceExample()
example.ParseFromString(serialized_example)
length = example.context.feature['len'].int64_list.value[0]
name = example.context.feature['task'].bytes_list.value[0]
total_frames += len(example.feature_lists.feature_list) * length
if sequence['name'] != name or sequence['len'] != length:
return False, total_frames
if num_sequences == 0:
return False, total_frames
return True, total_frames
def AddSequences():
"""Creates one training, validation."""
errors = []
# Generate datasets file lists.
sequences = FindPatternFiles(FLAGS.input_dir, FLAGS.view_pattern, errors)
num_frames = PrintSequencesInfo(sequences,
'Found the following datasets and files:')
# Sharding and randomizing sets.
if FLAGS.max_per_shard > 0:
sequences = ShardSequences(sequences, FLAGS.max_per_shard)
num_frames = PrintSequencesInfo(sequences, 'After sharding:')
tf.logging.info('')
# Process sets.
progress = Progress(num_frames)
output_list = []
for sequence in sequences:
record_name = os.path.join(
FLAGS.output_dir, '%s.tfrecord' % sequence['name'])
if tf.gfile.Exists(record_name) and not FLAGS.overwrite:
ok, num_frames = CheckRecord(record_name, sequence)
if ok:
progress.Add(num_frames)
tf.logging.info('Skipping existing output file: %s' % record_name)
continue
else:
tf.logging.info('File does not match sequence, reprocessing...')
output_dir = os.path.dirname(record_name)
if not tf.gfile.Exists(output_dir):
tf.logging.info('Creating output directory: %s' % output_dir)
tf.gfile.MakeDirs(output_dir)
output_list.append(record_name)
tf.logging.info('Writing to ' + record_name)
writer = tf.python_io.TFRecordWriter(record_name)
AddSequence(sequence, writer, progress, errors)
writer.close()
tf.logging.info('Wrote dataset files: ' + str(output_list))
tf.logging.info('All errors (%d): %s' % (len(errors), str(errors)))
def main(_):
AddSequences()
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Collect images from multiple simultaneous webcams.
Usage:
1. Define some environment variables that describe what you're collecting.
dataset=your_dataset_name
mode=train
num_views=2
viddir=/tmp/tcn/videos
tmp_imagedir=/tmp/tcn/tmp_images
debug_vids=1
2. Run the script.
export DISPLAY=:0.0 && \
root=learning/brain/research/tcn && \
bazel build -c opt --copt=-mavx tcn/webcam && \
bazel-bin/tcn/webcam \
--dataset $dataset \
--mode $mode \
--num_views $num_views \
--tmp_imagedir $tmp_imagedir \
--viddir $viddir \
--debug_vids 1 \
--logtostderr
3. Hit Ctrl-C when done collecting, upon which the script will compile videos
for each view and optionally a debug video concatenating multiple
simultaneous views.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
from multiprocessing import Process
import os
import subprocess
import sys
import time
import cv2
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import animation # pylint: disable=g-import-not-at-top
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
tf.flags.DEFINE_string('dataset', '', 'Name of the dataset being collected.')
tf.flags.DEFINE_string('mode', '',
'What type of data is being collected, e.g. '
'`train`, `valid`, `test`, or `demo`.')
tf.flags.DEFINE_string('seqname', '',
'Name of this sequence. If empty, the script will use '
'the name seq_N+1, where seq_N is the latest '
'integer-named sequence in the videos directory.')
tf.flags.DEFINE_integer('num_views', 2,
'Number of webcams.')
tf.flags.DEFINE_string('tmp_imagedir', '/tmp/tcn/data',
'Temporary outdir to write images.')
tf.flags.DEFINE_string('viddir', '/tmp/tcn/videos',
'Base directory to write debug videos.')
tf.flags.DEFINE_boolean('debug_vids', True,
'Whether to generate debug videos with multiple '
'concatenated views.')
tf.flags.DEFINE_string('debug_lhs_view', '0',
'Which viewpoint to use for the lhs video.')
tf.flags.DEFINE_string('debug_rhs_view', '1',
'Which viewpoint to use for the rhs video.')
tf.flags.DEFINE_integer('height', 1080, 'Raw input height.')
tf.flags.DEFINE_integer('width', 1920, 'Raw input width.')
tf.flags.DEFINE_string('webcam_ports', None,
'Comma-separated list of each webcam usb port.')
FLAGS = tf.app.flags.FLAGS
class ImageQueue(object):
"""An image queue holding each stream's most recent image.
Basically implements a process-safe collections.deque(maxlen=1).
"""
def __init__(self):
self.lock = multiprocessing.Lock()
self._queue = multiprocessing.Queue(maxsize=1)
def append(self, data):
with self.lock:
if self._queue.full():
# Pop the first element.
_ = self._queue.get()
self._queue.put(data)
def get(self):
with self.lock:
return self._queue.get()
def empty(self):
return self._queue.empty()
def close(self):
return self._queue.close()
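# Illustrative sketch (editor's addition): ImageQueue keeps only the most
# recent item, so a slow consumer always sees the latest frame. The string
# frames below are stand-ins for the numpy images produced by the webcams.
def _example_image_queue():
  """Sketch: the second append displaces the first, unread item."""
  q = ImageQueue()
  q.append('frame_0')
  q.append('frame_1')  # Replaces 'frame_0'; the queue never grows past one.
  return q.get()       # -> 'frame_1'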
class WebcamViewer(object):
"""A class which displays a live stream from the webcams."""
def __init__(self, display_queues):
"""Create a WebcamViewer instance."""
self.height = FLAGS.height
self.width = FLAGS.width
self.queues = display_queues
def _get_next_images(self):
"""Gets the next image to display."""
# Wait for one image per view.
not_found = True
while not_found:
if True in [q.empty() for q in self.queues]:
# At least one image queue is empty; wait.
continue
else:
# Retrieve the images.
latest = [q.get() for q in self.queues]
combined = np.concatenate(latest, axis=1)
not_found = False
return combined
def run(self):
"""Displays the Kcam live stream in a window.
This function blocks until the window is closed.
"""
fig, rgb_axis = plt.subplots()
image_rows = self.height
image_cols = self.width * FLAGS.num_views
initial_image = np.zeros((image_rows, image_cols, 3))
rgb_image = rgb_axis.imshow(initial_image, interpolation='nearest')
def update_figure(frame_index):
"""Animation function for matplotlib FuncAnimation. Updates the image.
Args:
frame_index: The frame number.
Returns:
An iterable of matplotlib drawables to clear.
"""
_ = frame_index
images = self._get_next_images()
images = images[..., [2, 1, 0]]
rgb_image.set_array(images)
return rgb_image,
# We must keep a reference to this animation in order for it to work.
unused_animation = animation.FuncAnimation(
fig, update_figure, interval=50, blit=True)
mng = plt.get_current_fig_manager()
mng.resize(*mng.window.maxsize())
plt.show()
def reconcile(queues, write_queue):
"""Gets a list of concurrent images from each view queue.
This waits for latest images to be available in all view queues,
then continuously:
- Creates a list of current images for each view.
- Writes the list to a queue of image lists to write to disk.
Args:
queues: A list of `ImageQueues`, holding the latest image from each webcam.
write_queue: A multiprocessing.Queue holding lists of concurrent images.
"""
# Loop forever.
while True:
# Wait till all queues have an image.
if True in [q.empty() for q in queues]:
continue
else:
# Retrieve all views' images.
latest = [q.get() for q in queues]
# Copy the list of all concurrent images to the write queue.
write_queue.put(latest)
def persist(write_queue, view_dirs):
"""Pulls lists of concurrent images off a write queue, writes them to disk.
Args:
write_queue: A multiprocessing.Queue holding lists of concurrent images;
one image per view.
view_dirs: A list of strings, holding the output image directories for each
view.
"""
timestep = 0
while True:
# Wait till there is work in the queue.
if write_queue.empty():
continue
# Get a list of concurrent images to write to disk.
view_ims = write_queue.get()
for view_idx, image in enumerate(view_ims):
view_base = view_dirs[view_idx]
# Assign all concurrent view images the same sequence timestep.
fname = os.path.join(view_base, '%s.png' % str(timestep).zfill(10))
cv2.imwrite(fname, image)
# Move to the next timestep.
timestep += 1
def get_image(camera):
"""Captures a single image from the camera and returns it in PIL format."""
data = camera.read()
_, im = data
return im
def capture_webcam(camera, display_queue, reconcile_queue):
"""Captures images from simultaneous webcams, writes them to queues.
Args:
camera: A cv2.VideoCapture object representing an open webcam stream.
display_queue: An ImageQueue.
reconcile_queue: An ImageQueue.
"""
# Take some ramp images to allow cams to adjust for brightness etc.
for i in range(60):
tf.logging.info('Taking ramp image %d.' % i)
get_image(camera)
cnt = 0
start = time.time()
while True:
# Get images for all cameras.
im = get_image(camera)
# Replace the current image in the display and reconcile queues.
display_queue.append(im)
reconcile_queue.append(im)
cnt += 1
current = time.time()
if cnt % 100 == 0:
tf.logging.info('Collected %s of video, %d frames at ~%.2f fps.' % (
timer(start, current), cnt, cnt/(current-start)))
def timer(start, end):
"""Returns a formatted time elapsed."""
hours, rem = divmod(end-start, 3600)
minutes, seconds = divmod(rem, 60)
return '{:0>2}:{:0>2}:{:05.2f}'.format(int(hours), int(minutes), seconds)
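# Example (editor's addition): timer(0, 3725.5) returns '01:02:05.50', i.e.
# 1 hour, 2 minutes and 5.5 seconds of elapsed capture time.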
def display_webcams(display_queues):
"""Builds an WebcamViewer to animate incoming images, runs it."""
viewer = WebcamViewer(display_queues)
viewer.run()
def create_vids(view_dirs, seqname):
"""Creates one video per view per sequence."""
vidbase = os.path.join(FLAGS.viddir, FLAGS.dataset, FLAGS.mode)
if not os.path.exists(vidbase):
os.makedirs(vidbase)
vidpaths = []
for idx, view_dir in enumerate(view_dirs):
vidname = os.path.join(vidbase, '%s_view%d.mp4' % (seqname, idx))
encode_vid_cmd = r'mencoder mf://%s/*.png \
-mf fps=29:type=png \
-ovc lavc -lavcopts vcodec=mpeg4:mbd=2:trell \
-oac copy -o %s' % (view_dir, vidname)
os.system(encode_vid_cmd)
vidpaths.append(vidname)
debugpath = None
if FLAGS.debug_vids:
# The view flags are strings; cast to int before indexing.
lhs = vidpaths[int(FLAGS.debug_lhs_view)]
rhs = vidpaths[int(FLAGS.debug_rhs_view)]
debug_base = os.path.join('%s_debug' % FLAGS.viddir, FLAGS.dataset,
FLAGS.mode)
if not os.path.exists(debug_base):
os.makedirs(debug_base)
debugpath = '%s/%s.mp4' % (debug_base, seqname)
os.system(r"avconv \
-i %s \
-i %s \
-filter_complex '[0:v]pad=iw*2:ih[int];[int][1:v]overlay=W/2:0[vid]' \
-map [vid] \
-c:v libx264 \
-crf 23 \
-preset veryfast \
%s" % (lhs, rhs, debugpath))
return vidpaths, debugpath
def setup_paths():
"""Sets up the necessary paths to collect videos."""
assert FLAGS.dataset
assert FLAGS.mode
assert FLAGS.num_views
# Setup directory for final images used to create videos for this sequence.
tmp_imagedir = os.path.join(FLAGS.tmp_imagedir, FLAGS.dataset, FLAGS.mode)
if not os.path.exists(tmp_imagedir):
os.makedirs(tmp_imagedir)
# Create a base directory to hold all sequence videos if it doesn't exist.
vidbase = os.path.join(FLAGS.viddir, FLAGS.dataset, FLAGS.mode)
if not os.path.exists(vidbase):
os.makedirs(vidbase)
# Get one directory per concurrent view and a sequence name.
view_dirs, seqname = get_view_dirs(vidbase, tmp_imagedir)
# Get an output path to each view's video.
vid_paths = []
for idx, _ in enumerate(view_dirs):
vid_path = os.path.join(vidbase, '%s_view%d.mp4' % (seqname, idx))
vid_paths.append(vid_path)
# Optionally build paths to debug_videos.
debug_path = None
if FLAGS.debug_vids:
debug_base = os.path.join('%s_debug' % FLAGS.viddir, FLAGS.dataset,
FLAGS.mode)
if not os.path.exists(debug_base):
os.makedirs(debug_base)
debug_path = '%s/%s.mp4' % (debug_base, seqname)
return view_dirs, vid_paths, debug_path
def get_view_dirs(vidbase, tmp_imagedir):
"""Creates and returns one view directory per webcam."""
# Create and append a sequence name.
if FLAGS.seqname:
seqname = FLAGS.seqname
else:
# If there's no video directory, this is the first sequence.
if not os.listdir(vidbase):
seqname = '0'
else:
# Otherwise, get the latest sequence name and increment it.
seq_names = [i.split('_')[0] for i in os.listdir(vidbase)]
latest_seq = sorted(map(int, seq_names), reverse=True)[0]
seqname = str(latest_seq+1)
tf.logging.info('No seqname specified, using: %s' % seqname)
view_dirs = [os.path.join(
tmp_imagedir, '%s_view%d' % (seqname, v)) for v in range(FLAGS.num_views)]
for d in view_dirs:
if not os.path.exists(d):
os.makedirs(d)
return view_dirs, seqname
def get_cameras():
"""Opens cameras using cv2, ensures they can take images."""
# Try to get free webcam ports.
if FLAGS.webcam_ports:
ports = map(int, FLAGS.webcam_ports.split(','))
else:
ports = range(FLAGS.num_views)
cameras = [cv2.VideoCapture(i) for i in ports]
if not all([i.isOpened() for i in cameras]):
try:
# Try to find and kill hanging cv2 process_ids.
output = subprocess.check_output(['lsof -t /dev/video*'], shell=True)
tf.logging.info('Found hanging cv2 process_ids: \n')
tf.logging.info(output)
tf.logging.info('Killing hanging processes...')
for process_id in output.split('\n')[:-1]:
subprocess.call(['kill %s' % process_id], shell=True)
time.sleep(3)
# Recapture webcams.
cameras = [cv2.VideoCapture(i) for i in ports]
except subprocess.CalledProcessError:
raise ValueError(
'Cannot connect to cameras. Try running: \n'
'ls -ltrh /dev/video* \n '
'to see which ports your webcams are connected to. Then hand those '
'ports as a comma-separated list to --webcam_ports, e.g. '
'--webcam_ports 0,1')
# Verify each camera is able to capture images.
ims = map(get_image, cameras)
assert False not in [i is not None for i in ims]
return cameras
def launch_images_to_videos(view_dirs, vid_paths, debug_path):
"""Launch job in separate process to convert images to videos."""
f = 'learning/brain/research/tcn/dataset/images_to_videos.py'
cmd = ['python %s ' % f]
cmd += ['--view_dirs %s ' % ','.join(i for i in view_dirs)]
cmd += ['--vid_paths %s ' % ','.join(i for i in vid_paths)]
cmd += ['--debug_path %s ' % debug_path]
cmd += ['--debug_lhs_view %s ' % FLAGS.debug_lhs_view]
cmd += ['--debug_rhs_view %s ' % FLAGS.debug_rhs_view]
cmd += [' & ']
cmd = ''.join(i for i in cmd)
# Call images_to_videos asynchronously.
fnull = open(os.devnull, 'w')
subprocess.Popen([cmd], stdout=fnull, stderr=subprocess.STDOUT, shell=True)
for p in vid_paths:
tf.logging.info('Writing final video to: %s' % p)
if debug_path:
tf.logging.info('Writing debug video to: %s' % debug_path)
def main(_):
# Initialize the camera capture objects.
cameras = get_cameras()
# Get one output directory per view.
view_dirs, vid_paths, debug_path = setup_paths()
try:
# Wait for user input.
try:
tf.logging.info('About to write to:')
for v in view_dirs:
tf.logging.info(v)
raw_input('Press Enter to continue...')
except SyntaxError:
pass
# Create a queue per view for displaying and saving images.
display_queues = [ImageQueue() for _ in range(FLAGS.num_views)]
reconcile_queues = [ImageQueue() for _ in range(FLAGS.num_views)]
# Create a queue for collecting all tuples of multi-view images to write to
# disk.
write_queue = multiprocessing.Queue()
processes = []
# Create a process to display collected images in real time.
processes.append(Process(target=display_webcams, args=(display_queues,)))
# Create a process to collect the latest simultaneous images from each view.
processes.append(Process(
target=reconcile, args=(reconcile_queues, write_queue,)))
# Create a process to write the collected multi-view images to disk.
processes.append(Process(
target=persist, args=(write_queue, view_dirs,)))
for (cam, dq, rq) in zip(cameras, display_queues, reconcile_queues):
processes.append(Process(
target=capture_webcam, args=(cam, dq, rq,)))
for p in processes:
p.start()
for p in processes:
p.join()
except KeyboardInterrupt:
# Close the queues.
for q in display_queues + reconcile_queues:
q.close()
# Release the cameras.
for cam in cameras:
cam.release()
# Launch images_to_videos script asynchronously.
launch_images_to_videos(view_dirs, vid_paths, debug_path)
try:
sys.exit(0)
except SystemExit:
os._exit(0) # pylint: disable=protected-access
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Downloads pretrained InceptionV3 and ResnetV2-50 checkpoints."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tarfile
import urllib
INCEPTION_URL = 'http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz'
RESNET_URL = 'http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz'
def DownloadWeights(model_dir, url):
  """Downloads a pretrained checkpoint tarball into model_dir and extracts it."""
  os.makedirs(model_dir)
  tar_path = os.path.join(model_dir, 'ckpt.tar.gz')
  urllib.urlretrieve(url, tar_path)
  tar = tarfile.open(tar_path)
  tar.extractall(model_dir)
if __name__ == '__main__':
# Create a directory for all pretrained checkpoints.
ckpt_dir = 'pretrained_checkpoints'
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
# Download inception.
print('Downloading inception pretrained weights...')
inception_dir = os.path.join(ckpt_dir, 'inception')
DownloadWeights(inception_dir, INCEPTION_URL)
print('Done downloading inception pretrained weights.')
print('Downloading resnet pretrained weights...')
resnet_dir = os.path.join(ckpt_dir, 'resnet')
DownloadWeights(resnet_dir, RESNET_URL)
print('Done downloading resnet pretrained weights.')
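# Usage note (editor's addition): running `python download_pretrained.py`
# downloads both tarballs and extracts them under ./pretrained_checkpoints/
# (inception and resnet subdirectories), relative to the working directory;
# training configs can then point at these as the pretrained checkpoint dirs.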
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base estimator defining TCN training, test, and inference."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import ABCMeta
from abc import abstractmethod
import os
import numpy as np
import data_providers
import preprocessing
from utils import util
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.python.training import session_run_hook
tf.app.flags.DEFINE_integer(
'tf_random_seed', 0, 'Random seed.')
FLAGS = tf.app.flags.FLAGS
class InitFromPretrainedCheckpointHook(session_run_hook.SessionRunHook):
"""Hook that can init graph from a pretrained checkpoint."""
def __init__(self, pretrained_checkpoint_dir):
"""Initializes a `InitFromPretrainedCheckpointHook`.
Args:
pretrained_checkpoint_dir: The dir of pretrained checkpoint.
Raises:
ValueError: If pretrained_checkpoint_dir is invalid.
"""
if pretrained_checkpoint_dir is None:
raise ValueError('pretrained_checkpoint_dir must be specified.')
self._pretrained_checkpoint_dir = pretrained_checkpoint_dir
def begin(self):
checkpoint_reader = tf.contrib.framework.load_checkpoint(
self._pretrained_checkpoint_dir)
variable_shape_map = checkpoint_reader.get_variable_to_shape_map()
exclude_scopes = 'logits/,final_layer/,aux_'
# Skip restoring global_step so that fine-tuning starts from step 0.
exclusions = ['global_step']
if exclude_scopes:
exclusions.extend([scope.strip() for scope in exclude_scopes.split(',')])
variable_to_restore = tf.contrib.framework.get_model_variables()
# Variable filtering by given exclude_scopes.
filtered_variables_to_restore = {}
for v in variable_to_restore:
excluded = False
for exclusion in exclusions:
if v.name.startswith(exclusion):
excluded = True
break
if not excluded:
var_name = v.name.split(':')[0]
filtered_variables_to_restore[var_name] = v
# Final filter by checking shape matching and skipping variables that
# are not in the checkpoint.
final_variables_to_restore = {}
for var_name, var_tensor in filtered_variables_to_restore.iteritems():
if var_name not in variable_shape_map:
# Try moving average version of variable.
var_name = os.path.join(var_name, 'ExponentialMovingAverage')
if var_name not in variable_shape_map:
tf.logging.info(
'Skip init [%s] because it is not in ckpt.', var_name)
# Skip variables not in the checkpoint.
continue
if not var_tensor.get_shape().is_compatible_with(
variable_shape_map[var_name]):
# Skip initializing a variable from the ckpt if its shape doesn't match.
tf.logging.info(
'Skip init [%s] from [%s] in ckpt because of shape mismatch: %s vs %s',
var_tensor.name, var_name,
var_tensor.get_shape(), variable_shape_map[var_name])
continue
tf.logging.info('Init %s from %s in ckpt' % (var_tensor, var_name))
final_variables_to_restore[var_name] = var_tensor
self._init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
self._pretrained_checkpoint_dir,
final_variables_to_restore)
def after_create_session(self, session, coord):
tf.logging.info('Restoring InceptionV3 weights.')
self._init_fn(session)
tf.logging.info('Done restoring InceptionV3 weights.')
class BaseEstimator(object):
"""Abstract TCN base estimator class."""
__metaclass__ = ABCMeta
def __init__(self, config, logdir):
"""Constructor.
Args:
config: A Luatable-like T object holding training config.
logdir: String, a directory where checkpoints and summaries are written.
"""
self._config = config
self._logdir = logdir
@abstractmethod
def construct_input_fn(self, records, is_training):
"""Builds an estimator input_fn.
The input_fn is used to pass feature and target data to the train,
evaluate, and predict methods of the Estimator.
Method to be overridden by implementations.
Args:
records: A list of Strings, paths to TFRecords with image data.
is_training: Boolean, whether or not we're training.
Returns:
Function, that has signature of ()->(dict of features, target).
features is a dict mapping feature names to `Tensors`
containing the corresponding feature data (typically, just a single
key/value pair 'raw_data' -> image `Tensor` for TCN).
labels is a 1-D int32 `Tensor` holding labels.
"""
pass
def preprocess_data(self, images, is_training):
"""Preprocesses raw images for either training or inference.
Args:
images: A 4-D float32 `Tensor` holding images to preprocess.
is_training: Boolean, whether or not we're in training.
Returns:
data_preprocessed: data after the preprocessor.
"""
config = self._config
height = config.data.height
width = config.data.width
min_scale = config.data.augmentation.minscale
max_scale = config.data.augmentation.maxscale
p_scale_up = config.data.augmentation.proportion_scaled_up
aug_color = config.data.augmentation.color
fast_mode = config.data.augmentation.fast_mode
crop_strategy = config.data.preprocessing.eval_cropping
preprocessed_images = preprocessing.preprocess_images(
images, is_training, height, width,
min_scale, max_scale, p_scale_up,
aug_color=aug_color, fast_mode=fast_mode,
crop_strategy=crop_strategy)
return preprocessed_images
@abstractmethod
def forward(self, images, is_training, reuse=False):
"""Defines the forward pass that converts batch images to embeddings.
Method to be overridden by implementations.
Args:
images: A 4-D float32 `Tensor` holding images to be embedded.
is_training: Boolean, whether or not we're in training mode.
reuse: Boolean, whether or not to reuse embedder.
Returns:
embeddings: A 2-D float32 `Tensor` holding embedded images.
"""
pass
@abstractmethod
def define_loss(self, embeddings, labels, is_training):
"""Defines the loss function on the embedding vectors.
Method to be overridden by implementations.
Args:
embeddings: A 2-D float32 `Tensor` holding embedded images.
labels: A 1-D int32 `Tensor` holding problem labels.
is_training: Boolean, whether or not we're in training mode.
Returns:
loss: tf.float32 scalar.
"""
pass
@abstractmethod
def define_eval_metric_ops(self):
"""Defines the dictionary of eval metric tensors.
Method to be overridden by implementations.
Returns:
eval_metric_ops: A dict of name/value pairs specifying the
metrics that will be calculated when the model runs in EVAL mode.
"""
pass
def get_train_op(self, loss):
"""Creates a training op.
Args:
loss: A float32 `Tensor` representing the total training loss.
Returns:
train_op: A slim.learning.create_train_op train_op.
Raises:
ValueError: If specified optimizer isn't supported.
"""
# Get variables to train (defined in subclass).
assert self.variables_to_train
# Get the learning rate schedule parameters from the config.
decay_steps = self._config.learning.decay_steps
decay_factor = self._config.learning.decay_factor
learning_rate = float(self._config.learning.learning_rate)
# Define a learning rate schedule.
global_step = slim.get_or_create_global_step()
learning_rate = tf.train.exponential_decay(
learning_rate,
global_step,
decay_steps,
decay_factor,
staircase=True)
# Create an optimizer.
opt_type = self._config.learning.optimizer
if opt_type == 'adam':
opt = tf.train.AdamOptimizer(learning_rate)
elif opt_type == 'momentum':
opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
elif opt_type == 'rmsprop':
opt = tf.train.RMSPropOptimizer(learning_rate, momentum=0.9,
epsilon=1.0, decay=0.9)
else:
raise ValueError('Unsupported optimizer %s' % opt_type)
if self._config.use_tpu:
opt = tpu_optimizer.CrossShardOptimizer(opt)
# Create a training op.
train_op = slim.learning.create_train_op(
loss,
optimizer=opt,
variables_to_train=self.variables_to_train,
update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))
return train_op
def _get_model_fn(self):
"""Defines behavior for training, evaluation, and inference (prediction).
Returns:
`model_fn` for `Estimator`.
"""
# pylint: disable=unused-argument
def model_fn(features, labels, mode, params):
"""Build the model based on features, labels, and mode.
Args:
features: Dict, strings to `Tensor` input data, returned by the
input_fn.
labels: The labels Tensor returned by the input_fn.
mode: A string indicating the mode. This will be either
tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.PREDICT,
or tf.estimator.ModeKeys.EVAL.
params: A dict holding training parameters, passed in during TPU
training.
Returns:
A tf.estimator.EstimatorSpec specifying train/test/inference behavior.
"""
is_training = mode == tf.estimator.ModeKeys.TRAIN
# Get preprocessed images from the features dict.
batch_preprocessed = features['batch_preprocessed']
# Do a forward pass to embed data.
batch_encoded = self.forward(batch_preprocessed, is_training)
# Optionally set the pretrained initialization function.
initializer_fn = None
if mode == tf.estimator.ModeKeys.TRAIN:
initializer_fn = self.pretrained_init_fn
# If we're training or evaluating, define total loss.
total_loss = None
if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
loss = self.define_loss(batch_encoded, labels, is_training)
tf.losses.add_loss(loss)
total_loss = tf.losses.get_total_loss()
# If we're training, define a train op.
train_op = None
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = self.get_train_op(total_loss)
# If we're doing inference, set the output to be the embedded images.
predictions_dict = None
if mode == tf.estimator.ModeKeys.PREDICT:
predictions_dict = {'embeddings': batch_encoded}
# Pass through additional metadata stored in features.
for k, v in features.iteritems():
predictions_dict[k] = v
# If we're evaluating, define some eval metrics.
eval_metric_ops = None
if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = self.define_eval_metric_ops()
# Define training scaffold to load pretrained weights.
num_checkpoint_to_keep = self._config.logging.checkpoint.num_to_keep
saver = tf.train.Saver(
max_to_keep=num_checkpoint_to_keep)
if is_training and self._config.use_tpu:
# TPU doesn't have a scaffold option at the moment, so initialize
# pretrained weights using a custom train_hook instead.
return tpu_estimator.TPUEstimatorSpec(
mode,
loss=total_loss,
eval_metrics=None,
train_op=train_op,
predictions=predictions_dict)
else:
# Build a scaffold to initialize pretrained weights.
scaffold = tf.train.Scaffold(
init_fn=initializer_fn,
saver=saver,
summary_op=None)
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions_dict,
loss=total_loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops,
scaffold=scaffold)
return model_fn
def train(self):
"""Runs training."""
# Get a list of training tfrecords.
config = self._config
training_dir = config.data.training
training_records = util.GetFilesRecursively(training_dir)
# Define batch size.
self._batch_size = config.data.batch_size
# Create a subclass-defined training input function.
train_input_fn = self.construct_input_fn(
training_records, is_training=True)
# Create the estimator.
estimator = self._build_estimator(is_training=True)
train_hooks = None
if config.use_tpu:
# TPU training initializes pretrained weights using a custom train hook.
train_hooks = []
if tf.train.latest_checkpoint(self._logdir) is None:
train_hooks.append(
InitFromPretrainedCheckpointHook(
config[config.embedder_strategy].pretrained_checkpoint))
# Run training.
estimator.train(input_fn=train_input_fn, hooks=train_hooks,
steps=config.learning.max_step)
def _build_estimator(self, is_training):
"""Returns an Estimator object.
Args:
is_training: Boolean, whether or not we're in training mode.
Returns:
A tf.estimator.Estimator.
"""
config = self._config
save_checkpoints_steps = config.logging.checkpoint.save_checkpoints_steps
keep_checkpoint_max = self._config.logging.checkpoint.num_to_keep
if is_training and config.use_tpu:
iterations = config.tpu.iterations
num_shards = config.tpu.num_shards
run_config = tpu_config.RunConfig(
save_checkpoints_secs=None,
save_checkpoints_steps=save_checkpoints_steps,
keep_checkpoint_max=keep_checkpoint_max,
master=FLAGS.master,
evaluation_master=FLAGS.master,
model_dir=self._logdir,
tpu_config=tpu_config.TPUConfig(
iterations_per_loop=iterations,
num_shards=num_shards,
per_host_input_for_training=num_shards <= 8),
tf_random_seed=FLAGS.tf_random_seed)
batch_size = config.data.batch_size
return tpu_estimator.TPUEstimator(
model_fn=self._get_model_fn(),
config=run_config,
use_tpu=True,
train_batch_size=batch_size,
eval_batch_size=batch_size)
else:
run_config = tf.estimator.RunConfig().replace(
model_dir=self._logdir,
save_checkpoints_steps=save_checkpoints_steps,
keep_checkpoint_max=keep_checkpoint_max,
tf_random_seed=FLAGS.tf_random_seed)
return tf.estimator.Estimator(
model_fn=self._get_model_fn(),
config=run_config)
def evaluate(self):
"""Runs `Estimator` validation.
"""
config = self._config
# Get a list of validation tfrecords.
validation_dir = config.data.validation
validation_records = util.GetFilesRecursively(validation_dir)
# Define batch size.
self._batch_size = config.data.batch_size
# Create a subclass-defined validation input function.
validation_input_fn = self.construct_input_fn(
validation_records, False)
# Create the estimator.
estimator = self._build_estimator(is_training=False)
# Run validation.
eval_batch_size = config.data.batch_size
num_eval_samples = config.val.num_eval_samples
num_eval_batches = int(num_eval_samples / eval_batch_size)
estimator.evaluate(input_fn=validation_input_fn, steps=num_eval_batches)
def inference(
self, inference_input, checkpoint_path, batch_size=None, **kwargs):
"""Defines 3 of modes of inference.
Inputs:
* Mode 1: Input is an input_fn.
* Mode 2: Input is a TFRecord (or list of TFRecords).
* Mode 3: Input is a numpy array holding an image (or array of images).
Outputs:
* Mode 1: this returns an iterator over embeddings and additional
metadata. See
https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#predict
for details.
* Mode 2: Returns an iterator over tuples of
(embeddings, raw_image_strings, sequence_name), where embeddings is a
2-D float32 numpy array holding [sequence_size, embedding_size] image
embeddings, raw_image_strings is a 1-D string numpy array holding
[sequence_size] jpeg-encoded image strings, and sequence_name is a
string holding the name of the embedded sequence.
* Mode 3: Returns a tuple of (embeddings, raw_image_strings), where
embeddings is a 2-D float32 numpy array holding
[batch_size, embedding_size] image embeddings, raw_image_strings is a
1-D string numpy array holding [batch_size] jpeg-encoded image strings.
Args:
inference_input: This can be a tf.Estimator input_fn, a TFRecord path,
a list of TFRecord paths, a numpy image, or an array of numpy images.
checkpoint_path: String, path to the checkpoint to restore for inference.
batch_size: Int, the size of the batch to use for inference.
**kwargs: Additional keyword arguments, depending on the mode.
See _input_fn_inference, _tfrecord_inference, and _np_inference.
Returns:
inference_output: Inference output depending on mode, see above for
details.
Raises:
ValueError: If inference_input isn't a tf.Estimator input_fn,
a TFRecord path, a list of TFRecord paths, or a numpy array.
"""
# Mode 1: input is a callable tf.Estimator input_fn.
if callable(inference_input):
return self._input_fn_inference(
input_fn=inference_input, checkpoint_path=checkpoint_path, **kwargs)
# Mode 2: Input is a TFRecord path (or list of TFRecord paths).
elif util.is_tfrecord_input(inference_input):
return self._tfrecord_inference(
records=inference_input, checkpoint_path=checkpoint_path,
batch_size=batch_size, **kwargs)
# Mode 3: Input is a numpy array of raw images.
elif util.is_np_array(inference_input):
return self._np_inference(
np_images=inference_input, checkpoint_path=checkpoint_path, **kwargs)
else:
raise ValueError(
'inference input must be a tf.Estimator input_fn, a TFRecord path, '
'a list of TFRecord paths, or a numpy array. Got: %s' % str(type(
inference_input)))
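# Illustrative sketch (editor's addition): Mode 3 inference over in-memory
# images, where `estimator` is an instance of a concrete BaseEstimator
# subclass. The image shape and checkpoint path are hypothetical placeholders.
#   images = np.random.uniform(size=(4, 299, 299, 3)).astype(np.float32)
#   embeddings, raw_image_strings = estimator.inference(
#       images, checkpoint_path='/tmp/tcn_logdir/model.ckpt-10000')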
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None):
"""Mode 1: tf.Estimator inference.
Args:
input_fn: Function, that has signature of ()->(dict of features, None).
This is a function called by the estimator to get input tensors (stored
in the features dict) to do inference over.
checkpoint_path: String, path to a specific checkpoint to restore.
predict_keys: List of strings, the keys of the `Tensors` in the features
dict (returned by the input_fn) to evaluate during inference.
Returns:
predictions: An Iterator, yielding evaluated values of `Tensors`
specified in `predict_keys`.
"""
# Create the estimator.
estimator = self._build_estimator(is_training=False)
# Create an iterator of predicted embeddings.
predictions = estimator.predict(input_fn=input_fn,
checkpoint_path=checkpoint_path,
predict_keys=predict_keys)
return predictions
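# Example (illustrative sketch, not executed): building a Mode 1 input_fn from
# validation records via the subclass-defined builder; the directory and
# checkpoint below are hypothetical placeholders.
#   records = util.GetFilesRecursively('/tmp/tcn/validation')
#   input_fn = self.construct_input_fn(records, False)
#   for prediction in self._input_fn_inference(input_fn, checkpoint_path=ckpt):
#     ...  # Keys depend on the predictions dict built by the model_fn.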
def _tfrecord_inference(self, records, checkpoint_path, batch_size,
num_sequences=-1, reuse=False):
"""Mode 2: TFRecord inference.
Args:
records: List of strings, paths to TFRecords.
checkpoint_path: String, path to a specific checkpoint to restore.
batch_size: Int, size of inference batch.
num_sequences: Int, number of sequences to embed. If -1,
embed everything.
reuse: Boolean, whether or not to reuse embedder weights.
Yields:
(embeddings, raw_image_strings, sequence_name):
embeddings is a 2-D float32 numpy array holding
[sequence_size, embedding_size] image embeddings.
raw_image_strings is a 1-D string numpy array holding
[sequence_size] jpeg-encoded image strings.
sequence_name is a string holding the name of the embedded sequence.
"""
tf.reset_default_graph()
if not isinstance(records, list):
records = [records]
# Map the list of tfrecords to a dataset of preprocessed images.
num_views = self._config.data.num_views
(views, task, seq_len) = data_providers.full_sequence_provider(
records, num_views)
tensor_dict = {
'raw_image_strings': views,
'task': task,
'seq_len': seq_len
}
# Create a preprocess function over raw image string placeholders.
image_str_placeholder = tf.placeholder(tf.string, shape=[None])
decoded = preprocessing.decode_images(image_str_placeholder)
decoded.set_shape([batch_size, None, None, 3])
preprocessed = self.preprocess_data(decoded, is_training=False)
# Create an inference graph over preprocessed images.
embeddings = self.forward(preprocessed, is_training=False, reuse=reuse)
# Create a saver to restore model variables.
tf.train.get_or_create_global_step()
saver = tf.train.Saver(tf.global_variables())
# Create a session and restore model variables.
with tf.train.MonitoredSession() as sess:
saver.restore(sess, checkpoint_path)
cnt = 0
# If num_sequences is specified, embed that many sequences, else embed
# everything.
try:
while num_sequences == -1 or cnt < num_sequences:
# Get a preprocessed image sequence.
np_data = sess.run(tensor_dict)
np_raw_images = np_data['raw_image_strings']
np_seq_len = np_data['seq_len']
np_task = np_data['task']
# Embed each view.
embedding_size = self._config.embedding_size
view_embeddings = [
np.zeros((0, embedding_size)) for _ in range(num_views)]
for view_index in range(num_views):
view_raw = np_raw_images[view_index]
# Embed the full sequence.
t = 0
while t < np_seq_len:
# Decode and preprocess the batch of image strings.
embeddings_np = sess.run(
embeddings, feed_dict={
image_str_placeholder: view_raw[t:t+batch_size]})
view_embeddings[view_index] = np.append(
view_embeddings[view_index], embeddings_np, axis=0)
tf.logging.info('Embedded %d images for task %s' % (t, np_task))
t += batch_size
# Done embedding for all views.
view_raw_images = np_data['raw_image_strings']
yield (view_embeddings, view_raw_images, np_task)
cnt += 1
except tf.errors.OutOfRangeError:
tf.logging.info('Done embedding entire dataset.')
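# Example (illustrative sketch, not executed): consuming the generator above;
# the record path and output path are hypothetical.
#   for view_embeddings, view_raw_images, seq_name in self._tfrecord_inference(
#       records=['/tmp/val.tfrecord'], checkpoint_path=ckpt, batch_size=32):
#     # view_embeddings[i] is a [seq_len, embedding_size] array for view i.
#     np.save('/tmp/embeddings_%s.npy' % seq_name, view_embeddings[0])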
def _np_inference(self, np_images, checkpoint_path):
"""Mode 3: Call this repeatedly to do inference over numpy images.
This mode is for when we want to do real-time inference over
some stream of images (represented as numpy arrays).
Args:
np_images: A float32 numpy array holding images to embed.
checkpoint_path: String, path to a specific checkpoint to restore.
Returns:
(embeddings, raw_image_strings):
embeddings is a 2-D float32 numpy array holding
[inferred batch_size, embedding_size] image embeddings.
raw_image_strings is a 1-D string numpy array holding
[inferred batch_size] jpeg-encoded image strings.
"""
if isinstance(np_images, list):
np_images = np.asarray(np_images)
# Add a batch dimension if only 3-dimensional.
if len(np_images.shape) == 3:
np_images = np.expand_dims(np_images, axis=0)
# If np_images are in the range [0, 255], rescale to [0, 1].
assert np.min(np_images) >= 0.
if np.max(np_images) > 1.:
  assert np.max(np_images) <= 255.
  np_images = np_images.astype(np.float32) / 255.
assert np.max(np_images) <= 1.
# If this is the first pass, set up inference graph.
if not hasattr(self, '_np_inf_tensor_dict'):
self._setup_np_inference(np_images, checkpoint_path)
# Convert np_images to embeddings.
np_tensor_dict = self._sess.run(self._np_inf_tensor_dict, feed_dict={
self._image_placeholder: np_images
})
return np_tensor_dict['embeddings'], np_tensor_dict['raw_image_strings']
def _setup_np_inference(self, np_images, checkpoint_path):
"""Sets up and restores inference graph, creates and caches a Session."""
tf.logging.info('Restoring model weights.')
# Define inference over an image placeholder.
_, height, width, _ = np.shape(np_images)
image_placeholder = tf.placeholder(
tf.float32, shape=(None, height, width, 3))
# Preprocess batch.
preprocessed = self.preprocess_data(image_placeholder, is_training=False)
# Unscale and jpeg encode preprocessed images for display purposes.
im_strings = preprocessing.unscale_jpeg_encode(preprocessed)
# Do forward pass to get embeddings.
embeddings = self.forward(preprocessed, is_training=False)
# Create a saver to restore model variables.
tf.train.get_or_create_global_step()
saver = tf.train.Saver(tf.global_variables())
self._image_placeholder = image_placeholder
self._batch_encoded = embeddings
self._np_inf_tensor_dict = {
'embeddings': embeddings,
'raw_image_strings': im_strings,
}
# Create a session and restore model variables.
self._sess = tf.Session()
saver.restore(self._sess, checkpoint_path)
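# Illustrative sketch (not part of the pipeline): real-time Mode 3 inference
# over a stream of frames. `get_frame` and the checkpoint path are hypothetical
# stand-ins; `estimator` is any configured subclass instance.
def _example_realtime_embedding(estimator, get_frame, checkpoint_path):
  """Yields one embedding per frame from a hypothetical `get_frame()` source."""
  while True:
    frame = get_frame()  # float32 numpy image of shape [height, width, 3].
    if frame is None:
      break
    # The first call builds and caches the inference graph and session;
    # subsequent calls just run the cached ops.
    embeddings, _ = estimator.inference(frame, checkpoint_path)
    yield embeddings[0]  # Shape [embedding_size].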
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Get a configured estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from estimators import mvtcn_estimator as mvtcn_estimators
from estimators import svtcn_estimator
def get_mvtcn_estimator(loss_strategy, config, logdir):
"""Returns a configured MVTCN estimator."""
loss_to_trainer = {
'triplet_semihard': mvtcn_estimators.MVTCNTripletEstimator,
'npairs': mvtcn_estimators.MVTCNNpairsEstimator,
}
if loss_strategy not in loss_to_trainer:
raise ValueError('Unknown loss for MVTCN: %s' % loss_strategy)
estimator = loss_to_trainer[loss_strategy](config, logdir)
return estimator
def get_estimator(config, logdir):
"""Returns an unsupervised model trainer based on config.
Args:
config: A T object holding training configs.
logdir: String, path to directory where model checkpoints and summaries
are saved.
Returns:
estimator: A configured `TCNEstimator` object.
Raises:
ValueError: If unknown training strategy is specified.
"""
# Get the training strategy.
training_strategy = config.training_strategy
if training_strategy == 'mvtcn':
loss_strategy = config.loss_strategy
estimator = get_mvtcn_estimator(
loss_strategy, config, logdir)
elif training_strategy == 'svtcn':
estimator = svtcn_estimator.SVTCNTripletEstimator(config, logdir)
else:
raise ValueError('Unknown training strategy: %s' % training_strategy)
return estimator
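# Illustrative sketch (not part of the pipeline): wiring a parsed config to an
# estimator. The config object and logdir are hypothetical stand-ins for
# however they are constructed elsewhere in this package.
def _example_build_and_evaluate(config, logdir='/tmp/tcn_logdir'):
  """Builds an estimator from `config` and runs one pass of validation."""
  estimator = get_estimator(config, logdir)
  estimator.evaluate()
  return estimator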
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MVTCN trainer implementations with various metric learning losses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import data_providers
import model as model_module
from estimators import base_estimator
import tensorflow as tf
class MVTCNEstimator(base_estimator.BaseEstimator):
"""Multi-view TCN base class."""
def __init__(self, config, logdir):
super(MVTCNEstimator, self).__init__(config, logdir)
def _pairs_provider(self, records, is_training):
config = self._config
num_views = config.data.num_views
window = config.mvtcn.window
num_parallel_calls = config.data.num_parallel_calls
sequence_prefetch_size = config.data.sequence_prefetch_size
batch_prefetch_size = config.data.batch_prefetch_size
examples_per_seq = config.data.examples_per_sequence
return functools.partial(
data_providers.multiview_pairs_provider,
file_list=records,
preprocess_fn=self.preprocess_data,
num_views=num_views,
window=window,
is_training=is_training,
examples_per_seq=examples_per_seq,
num_parallel_calls=num_parallel_calls,
sequence_prefetch_size=sequence_prefetch_size,
batch_prefetch_size=batch_prefetch_size)
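# Example (illustrative sketch, not executed): the partial returned above is
# invoked later by the input_fn with only the batch size, e.g.
#   provider = self._pairs_provider(records, is_training=True)
#   (images_concat, anchor_labels, positive_labels,
#    anchor_images, positive_images) = provider(batch_size=32)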
def forward(self, images_concat, is_training, reuse=False):
"""See base class."""
embedder_strategy = self._config.embedder_strategy
loss_strategy = self._config.loss_strategy
l2_normalize_embedding = self._config[loss_strategy].embedding_l2
embedder = model_module.get_embedder(
embedder_strategy,
self._config,
images_concat,
is_training=is_training,
l2_normalize_embedding=l2_normalize_embedding, reuse=reuse)
embeddings_concat = embedder.construct_embedding()
variables_to_train = embedder.get_trainable_variables()
self.variables_to_train = variables_to_train
self.pretrained_init_fn = embedder.init_fn
return embeddings_concat
def _collect_image_summaries(self, anchor_images, positive_images,
images_concat):
image_summaries = self._config.logging.summary.image_summaries
if image_summaries and not self._config.use_tpu:
batch_pairs_summary = tf.concat(
[anchor_images, positive_images], axis=2)
tf.summary.image('training/mvtcn_pairs', batch_pairs_summary)
tf.summary.image('training/images_preprocessed_concat', images_concat)
class MVTCNTripletEstimator(MVTCNEstimator):
"""Multi-View TCN with semihard triplet loss."""
def __init__(self, config, logdir):
super(MVTCNTripletEstimator, self).__init__(config, logdir)
def construct_input_fn(self, records, is_training):
"""See base class."""
def input_fn(params):
"""Provides input to MVTCN models."""
if is_training and self._config.use_tpu:
batch_size = params['batch_size']
else:
batch_size = self._batch_size
(images_concat,
anchor_labels,
positive_labels,
anchor_images,
positive_images) = self._pairs_provider(
records, is_training)(batch_size=batch_size)
if is_training:
self._collect_image_summaries(anchor_images, positive_images,
images_concat)
labels = tf.concat([anchor_labels, positive_labels], axis=0)
features = {'batch_preprocessed': images_concat}
return (features, labels)
return input_fn
def define_loss(self, embeddings, labels, is_training):
"""See base class."""
margin = self._config.triplet_semihard.margin
loss = tf.contrib.losses.metric_learning.triplet_semihard_loss(
labels=labels, embeddings=embeddings, margin=margin)
self._loss = loss
if is_training and not self._config.use_tpu:
tf.summary.scalar('training/triplet_semihard', loss)
return loss
def define_eval_metric_ops(self):
"""See base class."""
return {'validation/triplet_semihard': tf.metrics.mean(self._loss)}
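# Illustrative sketch of the label layout consumed by triplet_semihard_loss
# above: with a toy batch of 4 anchor/positive pairs,
#   labels = concat(anchor_labels, positive_labels) = [0, 1, 2, 3, 0, 1, 2, 3]
# so row i and row i + 4 of the concatenated embeddings share a label and form
# a positive pair, while rows with other labels act as candidate (semihard)
# negatives.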
class MVTCNNpairsEstimator(MVTCNEstimator):
"""Multi-View TCN with npairs loss."""
def __init__(self, config, logdir):
super(MVTCNNpairsEstimator, self).__init__(config, logdir)
def construct_input_fn(self, records, is_training):
"""See base class."""
def input_fn(params):
"""Provides input to MVTCN models."""
if is_training and self._config.use_tpu:
batch_size = params['batch_size']
else:
batch_size = self._batch_size
(images_concat,
npairs_labels,
_,
anchor_images,
positive_images) = self._pairs_provider(
records, is_training)(batch_size=batch_size)
if is_training:
self._collect_image_summaries(anchor_images, positive_images,
images_concat)
features = {'batch_preprocessed': images_concat}
return (features, npairs_labels)
return input_fn
def define_loss(self, embeddings, labels, is_training):
"""See base class."""
embeddings_anchor, embeddings_positive = tf.split(embeddings, 2, axis=0)
loss = tf.contrib.losses.metric_learning.npairs_loss(
labels=labels, embeddings_anchor=embeddings_anchor,
embeddings_positive=embeddings_positive)
self._loss = loss
if is_training and not self._config.use_tpu:
tf.summary.scalar('training/npairs', loss)
return loss
def define_eval_metric_ops(self):
"""See base class."""
return {'validation/npairs': tf.metrics.mean(self._loss)}
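# Illustrative sketch (not part of the pipeline): calling the npairs loss on a
# toy batch to show the expected layout. Sizes and values here are made up.
def _example_npairs_toy_batch():
  """Builds a toy [anchors; positives] batch and returns the npairs loss op."""
  num_pairs, embedding_size = 4, 8
  # The provider emits anchors and positives concatenated along axis 0, with
  # one label per anchor/positive pair.
  embeddings_concat = tf.random_normal([2 * num_pairs, embedding_size])
  npairs_labels = tf.range(num_pairs)
  anchors, positives = tf.split(embeddings_concat, 2, axis=0)
  return tf.contrib.losses.metric_learning.npairs_loss(
      labels=npairs_labels, embeddings_anchor=anchors,
      embeddings_positive=positives)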