Commit 811fa209 authored by Anelia Angelova's avatar Anelia Angelova

Added struct2depth model

parent 4d4eb85c
@@ -48,6 +48,7 @@
/research/slim/ @sguada @nathansilberman
/research/steve/ @buckman-google
/research/street/ @theraysmith
/research/struct2depth/ @aneliaangelova
/research/swivel/ @waterson
/research/syntaxnet/ @calberti @andorardo @bogatyy @markomernick
/research/tcn/ @coreylynch @sermanet
...
@@ -74,6 +74,7 @@ request.
- [slim](slim): image classification models in TF-Slim.
- [street](street): identify the name of a street (in France) from an image
  using a Deep RNN.
- [struct2depth](struct2depth): unsupervised learning of depth and ego-motion.
- [swivel](swivel): the Swivel algorithm for generating word embeddings.
- [syntaxnet](syntaxnet): neural models of natural language syntax.
- [tcn](tcn): Self-supervised representation learning from multi-view video.
...
package(default_visibility = ["//visibility:public"])
# struct2depth
This is a method for unsupervised learning of depth and ego-motion from monocular video. It achieves new state-of-the-art results on both tasks by explicitly modeling 3D object motion, performing on-line refinement, and improving quality for moving objects through novel loss formulations. It will appear in the following paper:
**V. Casser, S. Pirk, R. Mahjourian, A. Angelova, Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos, AAAI Conference on Artificial Intelligence, 2019**
https://arxiv.org/pdf/1811.06152.pdf
This code is implemented and supported by Vincent Casser (git username: VincentCa) and Anelia Angelova (git username: AneliaAngelova). Please contact anelia@google.com for questions.
Project website: https://sites.google.com/view/struct2depth.
## Quick start: Running training
Before training, run the gen_data_* script for the respective dataset (KITTI or Cityscapes) to generate the data in the appropriate format; input and output locations are set via the constants at the top of each script. It is assumed that motion (segmentation) masks have already been generated and stored as images.
Models are trained starting from an ImageNet-pretrained ResNet checkpoint.
```shell
ckpt_dir="your/checkpoint/folder"
data_dir="KITTI_SEQ2_LR/" # Set for KITTI
data_dir="CITYSCAPES_SEQ2_LR/" # Set for Cityscapes
imagenet_ckpt="resnet_pretrained/model.ckpt"
python train.py \
--logtostderr \
--checkpoint_dir $ckpt_dir \
--data_dir $data_dir \
--architecture resnet \
--imagenet_ckpt $imagenet_ckpt \
--imagenet_norm true \
--joint_encoder false
```
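For reference, each pre-processed training sample produced by the gen_data_* scripts is a horizontally stacked image triplet `<id>.png` together with an intrinsics file `<id>_cam.txt` containing the flattened 3x3 camera matrix (the Cityscapes script additionally writes an aligned segmentation triplet `<id>-fseg.png`). A minimal sketch of reading the intrinsics back, with a placeholder path:
```python
import numpy as np

# Placeholder path; point this at a sample produced by one of the gen_data_* scripts.
with open('KITTI_SEQ2_LR/some_sequence/0000000001_cam.txt') as f:
    values = [float(v) for v in f.read().split(',')]
intrinsics = np.array(values).reshape(3, 3)  # [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
print(intrinsics)
```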
## Running depth/egomotion inference on an image folder
KITTI models are trained on the raw images resized to 416 x 128 (inputs are standardized before being fed to the network), while Cityscapes images are additionally cropped using the cropping parameters (192, 1856, 256, 768). If a different crop is used, additional training will likely be necessary, so please follow the inference example shown below when using one of the models. The right choice of settings can depend on several factors: for example, if a checkpoint is used for odometry with a motion model, passing segmentation masks can improve the odometry (set *use_masks=true* for inference). On the other hand, all models can be used for single-frame depth estimation without any additional information.
```shell
input_dir="your/image/folder"
output_dir="your/output/folder"
model_checkpoint="your/model/checkpoint"
python inference.py \
--logtostderr \
--file_extension png \
--depth \
--egomotion true \
--input_dir $input_dir \
--output_dir $output_dir \
--model_ckpt $model_checkpoint
```
Note that the egomotion prediction expects the files in the input directory to form a consecutive sequence, and that sorting the filenames alphabetically puts them in the correct temporal order.
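For reference, a minimal sketch of parsing the resulting `egomotion.txt` (as written by `inference.py`: one line per frame, containing the frame index followed by one comma-separated 6-DoF vector `tx,ty,tz,rx,ry,rz` per adjacent frame pair; the path is a placeholder):
```python
import numpy as np

# Placeholder path; egomotion.txt is written into the corresponding output folder.
with open('your/output/folder/egomotion.txt') as f:
    for line in f:
        parts = line.strip().split(' ')
        frame_index = int(parts[0])
        transforms = [np.array([float(v) for v in p.split(',')]) for p in parts[1:]]
        print(frame_index, transforms)
```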
One can also run inference on KITTI by providing
```shell
--input_list_file ~/kitti-raw-uncompressed/test_files_eigen.txt
```
and on Cityscapes by passing
```shell
--input_list_file CITYSCAPES_FULL/test_files_cityscapes.txt
```
instead of *input_dir*.
Alternatively, inference can also be run on pre-processed images.
## Running on-line refinement
On-line refinement is executed on top of an existing inference folder, so make sure to run regular inference first. Then you can run the on-line fusion procedure as follows:
```shell
prediction_dir="some/prediction/dir"
model_ckpt="checkpoints/checkpoints_baseline/model-199160"
handle_motion="false"
size_constraint_weight="0" # This must be zero when not handling motion.
# If running on KITTI, set as follows:
data_dir="KITTI_SEQ2_LR_EIGEN/"
triplet_list_file="$data_dir/test_files_eigen_triplets.txt"
triplet_list_file_remains="$data_dir/test_files_eigen_triplets_remains.txt"
ft_name="kitti"
# If running on Cityscapes, set as follows:
data_dir="CITYSCAPES_SEQ2_LR_TEST/" # Set for Cityscapes
triplet_list_file="/CITYSCAPES_SEQ2_LR_TEST/test_files_cityscapes_triplets.txt"
triplet_list_file_remains="CITYSCAPES_SEQ2_LR_TEST/test_files_cityscapes_triplets_remains.txt"
ft_name="cityscapes"
python optimize.py \
--logtostderr \
--output_dir $prediction_dir \
--data_dir $data_dir \
--triplet_list_file $triplet_list_file \
--triplet_list_file_remains $triplet_list_file_remains \
--ft_name $ft_name \
--model_ckpt $model_ckpt \
--file_extension png \
--handle_motion $handle_motion \
--size_constraint_weight $size_constraint_weight
```
## Running evaluation
```shell
prediction_dir="some/prediction/dir"
# Use these settings for KITTI:
eval_list_file="KITTI_FULL/kitti-raw-uncompressed/test_files_eigen.txt"
eval_crop="garg"
eval_mode="kitti"
# Use these settings for Cityscapes:
eval_list_file="CITYSCAPES_FULL/test_files_cityscapes.txt"
eval_crop="none"
eval_mode="cityscapes"
python evaluate.py \
--logtostderr \
--prediction_dir $prediction_dir \
--eval_list_file $eval_list_file \
--eval_crop $eval_crop \
--eval_mode $eval_mode
```
## Credits
This code is implemented and supported by Vincent Casser and Anelia Angelova and can be found at
https://sites.google.com/view/struct2depth.
The core implementation is derived from
[vid2depth](https://github.com/tensorflow/models/tree/master/research/vid2depth)
by [Reza Mahjourian](mailto:rezama@google.com), which in turn is based on
[SfMLearner](https://github.com/tinghuiz/SfMLearner)
by [Tinghui Zhou](https://github.com/tinghuiz).
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common utilities for data pre-processing, e.g. matching moving object across frames."""
import numpy as np
def compute_overlap(mask1, mask2):
# Use IoU here.
return np.sum(mask1 & mask2)/np.sum(mask1 | mask2)
def align(seg_img1, seg_img2, seg_img3, threshold_same=0.3):
res_img1 = np.zeros_like(seg_img1)
res_img2 = np.zeros_like(seg_img2)
res_img3 = np.zeros_like(seg_img3)
remaining_objects2 = list(np.unique(seg_img2.flatten()))
remaining_objects3 = list(np.unique(seg_img3.flatten()))
for seg_id in np.unique(seg_img1):
# See if we can find correspondences to seg_id in seg_img2.
max_overlap2 = float('-inf')
max_segid2 = -1
for seg_id2 in remaining_objects2:
overlap = compute_overlap(seg_img1==seg_id, seg_img2==seg_id2)
if overlap>max_overlap2:
max_overlap2 = overlap
max_segid2 = seg_id2
if max_overlap2 > threshold_same:
max_overlap3 = float('-inf')
max_segid3 = -1
for seg_id3 in remaining_objects3:
overlap = compute_overlap(seg_img2==max_segid2, seg_img3==seg_id3)
if overlap>max_overlap3:
max_overlap3 = overlap
max_segid3 = seg_id3
if max_overlap3 > threshold_same:
res_img1[seg_img1==seg_id] = seg_id
res_img2[seg_img2==max_segid2] = seg_id
res_img3[seg_img3==max_segid3] = seg_id
remaining_objects2.remove(max_segid2)
remaining_objects3.remove(max_segid3)
return res_img1, res_img2, res_img3
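# Minimal, illustrative self-check with synthetic masks (not from any dataset):
# label 7 in frame 1 overlaps label 9 in frame 2 and label 3 in frame 3 by more
# than threshold_same, so align() relabels all three regions to the common id 7,
# while the background (id 0) simply matches itself.
if __name__ == '__main__':
  demo1 = np.array([[0, 7, 7, 0, 0, 0], [0, 7, 7, 0, 0, 0]])
  demo2 = np.array([[0, 0, 9, 9, 0, 0], [0, 0, 9, 9, 0, 0]])
  demo3 = np.array([[0, 0, 0, 3, 3, 0], [0, 0, 0, 3, 3, 0]])
  for aligned in align(demo1, demo2, demo3):
    print(aligned)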
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Offline data generation for the Cityscapes dataset."""
import os
from absl import app
from absl import flags
from absl import logging
import numpy as np
import cv2
import glob
import alignment
from alignment import compute_overlap
from alignment import align
SKIP = 2
WIDTH = 416
HEIGHT = 128
SUB_FOLDER = 'train'
INPUT_DIR = '/usr/local/google/home/anelia/struct2depth/CITYSCAPES_FULL/'
OUTPUT_DIR = '/usr/local/google/home/anelia/struct2depth/CITYSCAPES_Processed/'
def crop(img, segimg, fx, fy, cx, cy):
# Perform center cropping, preserving 50% vertically.
middle_perc = 0.50
left = 1 - middle_perc
half = left / 2
a = img[int(img.shape[0]*(half)):int(img.shape[0]*(1-half)), :]
aseg = segimg[int(segimg.shape[0]*(half)):int(segimg.shape[0]*(1-half)), :]
cy /= (1 / middle_perc)
# Resize to match target height while preserving aspect ratio.
wdt = int((float(HEIGHT)*a.shape[1]/a.shape[0]))
x_scaling = float(wdt)/a.shape[1]
y_scaling = float(HEIGHT)/a.shape[0]
b = cv2.resize(a, (wdt, HEIGHT))
bseg = cv2.resize(aseg, (wdt, HEIGHT))
# Adjust intrinsics.
fx*=x_scaling
fy*=y_scaling
cx*=x_scaling
cy*=y_scaling
# Perform center cropping horizontally.
remain = b.shape[1] - WIDTH
cx /= (b.shape[1] / WIDTH)
c = b[:, int(remain/2):b.shape[1]-int(remain/2)]
cseg = bseg[:, int(remain/2):b.shape[1]-int(remain/2)]
return c, cseg, fx, fy, cx, cy
def run_all():
dir_name=INPUT_DIR + '/leftImg8bit_sequence/' + SUB_FOLDER + '/*'
print('Processing directory', dir_name)
for location in glob.glob(INPUT_DIR + '/leftImg8bit_sequence/' + SUB_FOLDER + '/*'):
location_name = os.path.basename(location)
print('Processing location', location_name)
files = sorted(glob.glob(location + '/*.png'))
files = [file for file in files if '-seg.png' not in file]
# Break down into sequences
sequences = {}
seq_nr = 0
last_seq = ''
last_imgnr = -1
for i in range(len(files)):
seq = os.path.basename(files[i]).split('_')[1]
nr = int(os.path.basename(files[i]).split('_')[2])
if seq!=last_seq or last_imgnr+1!=nr:
seq_nr+=1
last_imgnr = nr
last_seq = seq
if not seq_nr in sequences:
sequences[seq_nr] = []
sequences[seq_nr].append(files[i])
for (k,v) in sequences.items():
print('Processing sequence', k, 'with', len(v), 'elements...')
output_dir = OUTPUT_DIR + '/' + location_name + '_' + str(k)
if not os.path.isdir(output_dir):
os.mkdir(output_dir)
files = sorted(v)
triplet = []
seg_triplet = []
ct = 1
# Find applicable intrinsics.
for j in range(len(files)):
osegname = os.path.basename(files[j]).split('_')[1]
oimgnr = os.path.basename(files[j]).split('_')[2]
applicable_intrinsics = INPUT_DIR + '/camera/' + SUB_FOLDER + '/' + location_name + '/' + location_name + '_' + osegname + '_' + oimgnr + '_camera.json'
# Get the intrinsics from one of the files of the sequence.
if os.path.isfile(applicable_intrinsics):
f = open(applicable_intrinsics, 'r')
lines = f.readlines()
f.close()
lines = [line.rstrip() for line in lines]
fx = float(lines[11].split(': ')[1].replace(',', ''))
fy = float(lines[12].split(': ')[1].replace(',', ''))
cx = float(lines[13].split(': ')[1].replace(',', ''))
cy = float(lines[14].split(': ')[1].replace(',', ''))
for j in range(0, len(files), SKIP):
img = cv2.imread(files[j])
segimg = cv2.imread(files[j].replace('.png', '-seg.png'))
smallimg, segimg, fx_this, fy_this, cx_this, cy_this = crop(img, segimg, fx, fy, cx, cy)
triplet.append(smallimg)
seg_triplet.append(segimg)
if len(triplet)==3:
cmb = np.hstack(triplet)
align1, align2, align3 = align(seg_triplet[0], seg_triplet[1], seg_triplet[2])
cmb_seg = np.hstack([align1, align2, align3])
cv2.imwrite(os.path.join(output_dir, str(ct).zfill(10) + '.png'), cmb)
cv2.imwrite(os.path.join(output_dir, str(ct).zfill(10) + '-fseg.png'), cmb_seg)
f = open(os.path.join(output_dir, str(ct).zfill(10) + '_cam.txt'), 'w')
f.write(str(fx_this) + ',0.0,' + str(cx_this) + ',0.0,' + str(fy_this) + ',' + str(cy_this) + ',0.0,0.0,1.0')
f.close()
del triplet[0]
del seg_triplet[0]
ct+=1
# Create file list for training. Be careful as it collects and includes all files recursively.
fn = open(OUTPUT_DIR + '/' + SUB_FOLDER + '.txt', 'w')
for f in glob.glob(OUTPUT_DIR + '/*/*.png'):
if '-seg.png' in f or '-fseg.png' in f:
continue
folder_name = f.split('/')[-2]
img_name = f.split('/')[-1].replace('.png', '')
fn.write(folder_name + ' ' + img_name + '\n')
fn.close()
def main(_):
run_all()
if __name__ == '__main__':
app.run(main)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Offline data generation for the KITTI dataset."""
import os
from absl import app
from absl import flags
from absl import logging
import numpy as np
import cv2
import glob
import alignment
from alignment import compute_overlap
from alignment import align
SEQ_LENGTH = 3
WIDTH = 416
HEIGHT = 128
STEPSIZE = 1
INPUT_DIR = '/usr/local/google/home/anelia/struct2depth/KITTI_FULL/kitti-raw-uncompressed'
OUTPUT_DIR = '/usr/local/google/home/anelia/struct2depth/KITTI_procesed/'
def get_line(file, start):
file = open(file, 'r')
lines = file.readlines()
lines = [line.rstrip() for line in lines]
ret = None
for line in lines:
nline = line.split(': ')
if nline[0]==start:
ret = nline[1].split(' ')
ret = np.array([float(r) for r in ret], dtype=float)
ret = ret.reshape((3,4))[0:3, 0:3]
break
file.close()
return ret
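# Illustrative example (made-up numbers): given a calib_cam_to_cam.txt line such as
#   P_rect_02: 700.0 0.0 600.0 0.0 0.0 700.0 180.0 0.0 0.0 0.0 1.0 0.0
# get_line(path, 'P_rect_02') returns the left 3x3 block of the 3x4 projection
# matrix, i.e. the camera intrinsics [[700, 0, 600], [0, 700, 180], [0, 0, 1]].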
def crop(img, segimg, fx, fy, cx, cy):
# Perform center cropping, preserving 50% vertically.
middle_perc = 0.50
left = 1-middle_perc
half = left/2
a = img[int(img.shape[0]*(half)):int(img.shape[0]*(1-half)), :]
aseg = segimg[int(segimg.shape[0]*(half)):int(segimg.shape[0]*(1-half)), :]
cy /= (1/middle_perc)
# Resize to match target height while preserving aspect ratio.
wdt = int((128*a.shape[1]/a.shape[0]))
x_scaling = float(wdt)/a.shape[1]
y_scaling = 128.0/a.shape[0]
b = cv2.resize(a, (wdt, 128))
bseg = cv2.resize(aseg, (wdt, 128))
# Adjust intrinsics.
fx*=x_scaling
fy*=y_scaling
cx*=x_scaling
cy*=y_scaling
# Perform center cropping horizontally.
remain = b.shape[1] - 416
cx /= (b.shape[1]/416)
c = b[:, int(remain/2):b.shape[1]-int(remain/2)]
cseg = bseg[:, int(remain/2):b.shape[1]-int(remain/2)]
return c, cseg, fx, fy, cx, cy
def run_all():
global OUTPUT_DIR  # OUTPUT_DIR is reassigned below, so it must be declared global.
ct = 0
if not OUTPUT_DIR.endswith('/'):
OUTPUT_DIR = OUTPUT_DIR + '/'
for d in glob.glob(INPUT_DIR + '/*/'):
date = d.split('/')[-2]
file_calibration = d + 'calib_cam_to_cam.txt'
calib_raw = [get_line(file_calibration, 'P_rect_02'), get_line(file_calibration, 'P_rect_03')]
for d2 in glob.glob(d + '*/'):
seqname = d2.split('/')[-2]
print('Processing sequence', seqname)
for subfolder in ['image_02/data', 'image_03/data']:
ct = 1
seqname = d2.split('/')[-2] + subfolder.replace('image', '').replace('/data', '')
if not os.path.exists(OUTPUT_DIR + seqname):
os.mkdir(OUTPUT_DIR + seqname)
calib_camera = calib_raw[0] if subfolder=='image_02/data' else calib_raw[1]
folder = d2 + subfolder
files = glob.glob(folder + '/*.png')
files = [file for file in files if not 'disp' in file and not 'flip' in file and not 'seg' in file]
files = sorted(files)
for i in range(SEQ_LENGTH, len(files)+1, STEPSIZE):
imgnum = str(ct).zfill(10)
if os.path.exists(OUTPUT_DIR + seqname + '/' + imgnum + '.png'):
ct+=1
continue
big_img = np.zeros(shape=(HEIGHT, WIDTH*SEQ_LENGTH, 3))
wct = 0
for j in range(i-SEQ_LENGTH, i): # Collect frames for this sample.
img = cv2.imread(files[j])
ORIGINAL_HEIGHT, ORIGINAL_WIDTH, _ = img.shape
zoom_x = WIDTH/ORIGINAL_WIDTH
zoom_y = HEIGHT/ORIGINAL_HEIGHT
# Adjust intrinsics.
calib_current = calib_camera.copy()
calib_current[0, 0] *= zoom_x
calib_current[0, 2] *= zoom_x
calib_current[1, 1] *= zoom_y
calib_current[1, 2] *= zoom_y
calib_representation = ','.join([str(c) for c in calib_current.flatten()])
img = cv2.resize(img, (WIDTH, HEIGHT))
big_img[:,wct*WIDTH:(wct+1)*WIDTH] = img
wct+=1
cv2.imwrite(OUTPUT_DIR + seqname + '/' + imgnum + '.png', big_img)
f = open(OUTPUT_DIR + seqname + '/' + imgnum + '_cam.txt', 'w')
f.write(calib_representation)
f.close()
ct+=1
def main(_):
run_all()
if __name__ == '__main__':
app.run(main)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs struct2depth at inference. Produces depth estimates, ego-motion and object motion."""
# Example usage:
#
# python inference.py \
# --input_dir ~/struct2depth/kitti-raw-uncompressed/ \
# --output_dir ~/struct2depth/output \
# --model_ckpt ~/struct2depth/model/model-199160 \
# --file_extension png \
# --depth \
# --egomotion true \
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from absl import logging
#import matplotlib.pyplot as plt
import model
import numpy as np
import fnmatch
from functools import reduce  # mask_image_stack uses reduce(); required on Python 3.
import tensorflow as tf
import nets
import util
gfile = tf.gfile
# CMAP = 'plasma'
INFERENCE_MODE_SINGLE = 'single' # Take plain single-frame input.
INFERENCE_MODE_TRIPLETS = 'triplets' # Take image triplets as input.
# For KITTI, we just resize input images and do not perform cropping. For
# Cityscapes, the car hood and additional image content have been cropped in
# order to fit the aspect ratio and to remove static content from the images.
# The same crop has to be applied at inference time.
INFERENCE_CROP_NONE = 'none'
INFERENCE_CROP_CITYSCAPES = 'cityscapes'
flags.DEFINE_string('output_dir', None, 'Directory to store predictions.')
flags.DEFINE_string('file_extension', 'png', 'Image data file extension of '
'files provided with input_dir. Also determines the output '
'file format of depth prediction images.')
flags.DEFINE_bool('depth', True, 'Determines if the depth prediction network '
'should be executed and its predictions be saved.')
flags.DEFINE_bool('egomotion', False, 'Determines if the egomotion prediction '
'network should be executed and its predictions be saved. If '
'inference is run in single inference mode, it is assumed '
'that files in the same directory belong in the same '
'sequence, and sorting them alphabetically establishes the '
'right temporal order.')
flags.DEFINE_string('model_ckpt', None, 'Model checkpoint to evaluate.')
flags.DEFINE_string('input_dir', None, 'Directory containing image files to '
'evaluate. This crawls recursively for images in the '
'directory, mirroring relative subdirectory structures '
'into the output directory.')
flags.DEFINE_string('input_list_file', None, 'Text file containing paths to '
'image files to process. Paths should be relative with '
'respect to the list file location. Relative path '
'structures will be mirrored in the output directory.')
flags.DEFINE_integer('batch_size', 1, 'The size of a sample batch')
flags.DEFINE_integer('img_height', 128, 'Input frame height.')
flags.DEFINE_integer('img_width', 416, 'Input frame width.')
flags.DEFINE_integer('seq_length', 3, 'Number of frames in sequence.')
flags.DEFINE_enum('architecture', nets.RESNET, nets.ARCHITECTURES,
'Defines the architecture to use for the depth prediction '
'network. Defaults to ResNet-based encoder and accompanying '
'decoder.')
flags.DEFINE_boolean('imagenet_norm', True, 'Whether to normalize the input '
'images channel-wise so that they match the distribution '
'most ImageNet-models were trained on.')
flags.DEFINE_bool('use_skip', True, 'Whether to use skip connections in the '
'encoder-decoder architecture.')
flags.DEFINE_bool('joint_encoder', False, 'Whether to share parameters '
'between the depth and egomotion networks by using a joint '
'encoder architecture. The egomotion network is then '
'operating only on the hidden representation provided by the '
'joint encoder.')
flags.DEFINE_bool('shuffle', False, 'Whether to shuffle the order in which '
'images are processed.')
flags.DEFINE_bool('flip', False, 'Whether images should be flipped as well as '
'resulting predictions (for test-time augmentation). This '
'currently applies to the depth network only.')
flags.DEFINE_enum('inference_mode', INFERENCE_MODE_SINGLE,
[INFERENCE_MODE_SINGLE,
INFERENCE_MODE_TRIPLETS],
'Whether to use triplet mode for inference, which accepts '
'triplets instead of single frames.')
flags.DEFINE_enum('inference_crop', INFERENCE_CROP_NONE,
[INFERENCE_CROP_NONE,
INFERENCE_CROP_CITYSCAPES],
'Whether to apply a Cityscapes-specific crop on the input '
'images first before running inference.')
flags.DEFINE_bool('use_masks', False, 'Whether to mask out potentially '
'moving objects when feeding image input to the egomotion '
'network. This might improve odometry results when using '
'a motion model. For this, pre-computed segmentation '
'masks have to be available for every image, with the '
'background being zero.')
FLAGS = flags.FLAGS
flags.mark_flag_as_required('output_dir')
flags.mark_flag_as_required('model_ckpt')
def _run_inference(output_dir=None,
file_extension='png',
depth=True,
egomotion=False,
model_ckpt=None,
input_dir=None,
input_list_file=None,
batch_size=1,
img_height=128,
img_width=416,
seq_length=3,
architecture=nets.RESNET,
imagenet_norm=True,
use_skip=True,
joint_encoder=True,
shuffle=False,
flip_for_depth=False,
inference_mode=INFERENCE_MODE_SINGLE,
inference_crop=INFERENCE_CROP_NONE,
use_masks=False):
"""Runs inference. Refer to flags in inference.py for details."""
inference_model = model.Model(is_training=False,
batch_size=batch_size,
img_height=img_height,
img_width=img_width,
seq_length=seq_length,
architecture=architecture,
imagenet_norm=imagenet_norm,
use_skip=use_skip,
joint_encoder=joint_encoder)
vars_to_restore = util.get_vars_to_save_and_restore(model_ckpt)
saver = tf.train.Saver(vars_to_restore)
sv = tf.train.Supervisor(logdir='/tmp/', saver=None)
with sv.managed_session() as sess:
saver.restore(sess, model_ckpt)
if not gfile.Exists(output_dir):
gfile.MakeDirs(output_dir)
logging.info('Predictions will be saved in %s.', output_dir)
# Collect all images to run inference on.
im_files, basepath_in = collect_input_images(input_dir, input_list_file,
file_extension)
if shuffle:
logging.info('Shuffling data...')
np.random.shuffle(im_files)
logging.info('Running inference on %d files.', len(im_files))
# Create missing output folders and pre-compute target directories.
output_dirs = create_output_dirs(im_files, basepath_in, output_dir)
# Run depth prediction network.
if depth:
im_batch = []
for i in range(len(im_files)):
if i % 100 == 0:
logging.info('%s of %s files processed.', i, len(im_files))
# Read image and run inference.
if inference_mode == INFERENCE_MODE_SINGLE:
if inference_crop == INFERENCE_CROP_NONE:
im = util.load_image(im_files[i], resize=(img_width, img_height))
elif inference_crop == INFERENCE_CROP_CITYSCAPES:
im = util.crop_cityscapes(util.load_image(im_files[i]),
resize=(img_width, img_height))
elif inference_mode == INFERENCE_MODE_TRIPLETS:
im = util.load_image(im_files[i], resize=(img_width * 3, img_height))
im = im[:, img_width:img_width*2]
if flip_for_depth:
im = np.flip(im, axis=1)
im_batch.append(im)
if len(im_batch) == batch_size or i == len(im_files) - 1:
# Call inference on batch.
for _ in range(batch_size - len(im_batch)): # Fill up batch.
im_batch.append(np.zeros(shape=(img_height, img_width, 3),
dtype=np.float32))
im_batch = np.stack(im_batch, axis=0)
est_depth = inference_model.inference_depth(im_batch, sess)
if flip_for_depth:
est_depth = np.flip(est_depth, axis=2)
im_batch = np.flip(im_batch, axis=2)
for j in range(len(im_batch)):
color_map = util.normalize_depth_for_display(
np.squeeze(est_depth[j]))
visualization = np.concatenate((im_batch[j], color_map), axis=0)
# Save raw prediction and color visualization. Extract filename
# without extension from full path: e.g. path/to/input_dir/folder1/
# file1.png -> file1
k = i - len(im_batch) + 1 + j
filename_root = os.path.splitext(os.path.basename(im_files[k]))[0]
pref = '_flip' if flip_for_depth else ''
output_raw = os.path.join(
output_dirs[k], filename_root + pref + '.npy')
output_vis = os.path.join(
output_dirs[k], filename_root + pref + '.png')
with gfile.Open(output_raw, 'wb') as f:
np.save(f, est_depth[j])
util.save_image(output_vis, visualization, file_extension)
im_batch = []
# Run egomotion network.
if egomotion:
if inference_mode == INFERENCE_MODE_SINGLE:
# Run regular egomotion inference loop.
input_image_seq = []
input_seg_seq = []
current_sequence_dir = None
current_output_handle = None
for i in range(len(im_files)):
sequence_dir = os.path.dirname(im_files[i])
if sequence_dir != current_sequence_dir:
# Assume start of a new sequence, since this image lies in a
# different directory than the previous ones.
# Clear egomotion input buffer.
output_filepath = os.path.join(output_dirs[i], 'egomotion.txt')
if current_output_handle is not None:
current_output_handle.close()
current_sequence_dir = sequence_dir
logging.info('Writing egomotion sequence to %s.', output_filepath)
current_output_handle = gfile.Open(output_filepath, 'w')
input_image_seq = []
im = util.load_image(im_files[i], resize=(img_width, img_height))
input_image_seq.append(im)
if use_masks:
im_seg_path = im_files[i].replace('.%s' % file_extension,
'-seg.%s' % file_extension)
if not gfile.Exists(im_seg_path):
raise ValueError('No segmentation mask %s has been found for '
'image %s. If none are available, disable '
'use_masks.' % (im_seg_path, im_files[i]))
input_seg_seq.append(util.load_image(im_seg_path,
resize=(img_width, img_height),
interpolation='nn'))
if len(input_image_seq) < seq_length: # Buffer not filled yet.
continue
if len(input_image_seq) > seq_length: # Remove oldest entry.
del input_image_seq[0]
if use_masks:
del input_seg_seq[0]
input_image_stack = np.concatenate(input_image_seq, axis=2)
input_image_stack = np.expand_dims(input_image_stack, axis=0)
if use_masks:
input_image_stack = mask_image_stack(input_image_stack,
input_seg_seq)
est_egomotion = np.squeeze(inference_model.inference_egomotion(
input_image_stack, sess))
egomotion_str = []
for j in range(seq_length - 1):
egomotion_str.append(','.join([str(d) for d in est_egomotion[j]]))
current_output_handle.write(
str(i) + ' ' + ' '.join(egomotion_str) + '\n')
if current_output_handle is not None:
current_output_handle.close()
elif inference_mode == INFERENCE_MODE_TRIPLETS:
written_before = []
for i in range(len(im_files)):
im = util.load_image(im_files[i], resize=(img_width * 3, img_height))
input_image_stack = np.concatenate(
[im[:, :img_width], im[:, img_width:img_width*2],
im[:, img_width*2:]], axis=2)
input_image_stack = np.expand_dims(input_image_stack, axis=0)
if use_masks:
im_seg_path = im_files[i].replace('.%s' % file_extension,
'-seg.%s' % file_extension)
if not gfile.Exists(im_seg_path):
raise ValueError('No segmentation mask %s has been found for '
'image %s. If none are available, disable '
'use_masks.' % (im_seg_path, im_files[i]))
seg = util.load_image(im_seg_path,
resize=(img_width * 3, img_height),
interpolation='nn')
input_seg_seq = [seg[:, :img_width], seg[:, img_width:img_width*2],
seg[:, img_width*2:]]
input_image_stack = mask_image_stack(input_image_stack,
input_seg_seq)
est_egomotion = inference_model.inference_egomotion(
input_image_stack, sess)
est_egomotion = np.squeeze(est_egomotion)
egomotion_1_2 = ','.join([str(d) for d in est_egomotion[0]])
egomotion_2_3 = ','.join([str(d) for d in est_egomotion[1]])
output_filepath = os.path.join(output_dirs[i], 'egomotion.txt')
file_mode = 'w' if output_filepath not in written_before else 'a'
with gfile.Open(output_filepath, file_mode) as current_output_handle:
current_output_handle.write(str(i) + ' ' + egomotion_1_2 + ' ' +
egomotion_2_3 + '\n')
written_before.append(output_filepath)
logging.info('Done.')
def mask_image_stack(input_image_stack, input_seg_seq):
"""Masks out moving image contents by using the segmentation masks provided.
This can lead to better odometry accuracy for motion models, but is optional
to use. Is only called if use_masks is enabled.
Args:
input_image_stack: The input image stack of shape (1, H, W, 3 * seq_length).
input_seg_seq: List of segmentation masks with seq_length elements of shape
(H, W, C) for some number of channels C.
Returns:
Input image stack with detections provided by segmentation mask removed.
"""
background = [mask == 0 for mask in input_seg_seq]
background = reduce(lambda m1, m2: m1 & m2, background)
# If masks are RGB, assume all channels to be the same. Reduce to the first.
if background.ndim == 3 and background.shape[2] > 1:
background = np.expand_dims(background[:, :, 0], axis=2)
elif background.ndim == 2: # Expand.
background = np.expand_dims(background, axis=2)
# background is now of shape (H, W, 1).
background_stack = np.tile(background, [1, 1, input_image_stack.shape[3]])
return np.multiply(input_image_stack, background_stack)
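# Illustrative usage sketch (synthetic shapes, kept as a comment): three stacked
# RGB frames yield an input_image_stack of shape (1, H, W, 9); pixels that are
# non-zero in any of the three masks are zeroed out across all nine channels.
#   stack = np.ones((1, 128, 416, 9), dtype=np.float32)
#   segs = [np.zeros((128, 416, 1), dtype=np.uint8) for _ in range(3)]
#   segs[1][40:80, 100:200, 0] = 5          # a moving object in the middle frame
#   masked = mask_image_stack(stack, segs)  # those pixels become 0 in all frames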
def collect_input_images(input_dir, input_list_file, file_extension):
"""Collects all input images that are to be processed."""
if input_dir is not None:
im_files = _recursive_glob(input_dir, '*.' + file_extension)
basepath_in = os.path.normpath(input_dir)
elif input_list_file is not None:
im_files = util.read_text_lines(input_list_file)
basepath_in = os.path.dirname(input_list_file)
im_files = [os.path.join(basepath_in, f) for f in im_files]
im_files = [f for f in im_files if 'disp' not in f and '-seg' not in f and
'-fseg' not in f and '-flip' not in f]
return sorted(im_files), basepath_in
def create_output_dirs(im_files, basepath_in, output_dir):
"""Creates required directories, and returns output dir for each file."""
output_dirs = []
for i in range(len(im_files)):
relative_folder_in = os.path.relpath(
os.path.dirname(im_files[i]), basepath_in)
absolute_folder_out = os.path.join(output_dir, relative_folder_in)
if not gfile.IsDirectory(absolute_folder_out):
gfile.MakeDirs(absolute_folder_out)
output_dirs.append(absolute_folder_out)
return output_dirs
def _recursive_glob(treeroot, pattern):
results = []
for base, _, files in os.walk(treeroot):
files = fnmatch.filter(files, pattern)
results.extend(os.path.join(base, f) for f in files)
return results
def main(_):
#if (flags.input_dir is None) == (flags.input_list_file is None):
# raise ValueError('Exactly one of either input_dir or input_list_file has '
# 'to be provided.')
#if not flags.depth and not flags.egomotion:
# raise ValueError('At least one of the depth and egomotion network has to '
# 'be called for inference.')
#if (flags.inference_mode == inference_lib.INFERENCE_MODE_TRIPLETS and
# flags.seq_length != 3):
# raise ValueError('For sequence lengths other than three, single inference '
# 'mode has to be used.')
_run_inference(output_dir=FLAGS.output_dir,
file_extension=FLAGS.file_extension,
depth=FLAGS.depth,
egomotion=FLAGS.egomotion,
model_ckpt=FLAGS.model_ckpt,
input_dir=FLAGS.input_dir,
input_list_file=FLAGS.input_list_file,
batch_size=FLAGS.batch_size,
img_height=FLAGS.img_height,
img_width=FLAGS.img_width,
seq_length=FLAGS.seq_length,
architecture=FLAGS.architecture,
imagenet_norm=FLAGS.imagenet_norm,
use_skip=FLAGS.use_skip,
joint_encoder=FLAGS.joint_encoder,
shuffle=FLAGS.shuffle,
flip_for_depth=FLAGS.flip,
inference_mode=FLAGS.inference_mode,
inference_crop=FLAGS.inference_crop,
use_masks=FLAGS.use_masks)
if __name__ == '__main__':
app.run(main)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Applies online refinement while running inference.
Instructions: Run static inference first before calling this script. Make sure
to point output_dir to the same folder where static inference results were
saved previously.
For example use, please refer to README.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
import os
import random
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
import model
import nets
import reader
import util
gfile = tf.gfile
SAVE_EVERY = 1 # Defines the interval that predictions should be saved at.
SAVE_PREVIEWS = True # If set, will save image previews of depth predictions.
FIXED_SEED = 8964 # Fixed seed for repeatability.
flags.DEFINE_string('output_dir', None, 'Directory to store predictions. '
'Assumes that regular inference has been executed before '
'and results were stored in this folder.')
flags.DEFINE_string('data_dir', None, 'Folder pointing to preprocessed '
'triplets to fine-tune on.')
flags.DEFINE_string('triplet_list_file', None, 'Text file containing paths to '
'image files to process. Paths should be relative with '
'respect to the list file location. Every line should be '
'of the form [input_folder_name] [input_frame_num] '
'[output_path], where [output_path] is optional to specify '
'a different path to store the prediction.')
flags.DEFINE_string('triplet_list_file_remains', None, 'Optional text file '
'containing relative paths to image files which should not '
'be fine-tuned, e.g. because of missing adjacent frames. '
'For all files listed, the static prediction will be '
'copied instead. File can be empty. If not, every line '
'should be of the form [input_folder_name] '
'[input_frame_num] [output_path], where [output_path] is '
'optional to specify a different path to take and store '
'the unrefined prediction from/to.')
flags.DEFINE_string('model_ckpt', None, 'Model checkpoint to optimize.')
flags.DEFINE_string('ft_name', '', 'Optional prefix for temporary files.')
flags.DEFINE_string('file_extension', 'png', 'Image data file extension.')
flags.DEFINE_float('learning_rate', 0.0001, 'Adam learning rate.')
flags.DEFINE_float('beta1', 0.9, 'Adam momentum.')
flags.DEFINE_float('reconstr_weight', 0.85, 'Frame reconstruction loss weight.')
flags.DEFINE_float('ssim_weight', 0.15, 'SSIM loss weight.')
flags.DEFINE_float('smooth_weight', 0.01, 'Smoothness loss weight.')
flags.DEFINE_float('icp_weight', 0.0, 'ICP loss weight.')
flags.DEFINE_float('size_constraint_weight', 0.0005, 'Weight of the object '
'size constraint loss. Use only with motion handling.')
flags.DEFINE_integer('batch_size', 1, 'The size of a sample batch')
flags.DEFINE_integer('img_height', 128, 'Input frame height.')
flags.DEFINE_integer('img_width', 416, 'Input frame width.')
flags.DEFINE_integer('seq_length', 3, 'Number of frames in sequence.')
flags.DEFINE_enum('architecture', nets.RESNET, nets.ARCHITECTURES,
'Defines the architecture to use for the depth prediction '
'network. Defaults to ResNet-based encoder and accompanying '
'decoder.')
flags.DEFINE_boolean('imagenet_norm', True, 'Whether to normalize the input '
'images channel-wise so that they match the distribution '
'most ImageNet-models were trained on.')
flags.DEFINE_float('weight_reg', 0.05, 'The amount of weight regularization to '
'apply. This has no effect on the ResNet-based encoder '
'architecture.')
flags.DEFINE_boolean('exhaustive_mode', False, 'Whether to exhaustively warp '
'from any frame to any other instead of just considering '
'adjacent frames. Where necessary, multiple egomotion '
'estimates will be applied. Does not have an effect if '
'compute_minimum_loss is enabled.')
flags.DEFINE_boolean('random_scale_crop', False, 'Whether to apply random '
'image scaling and center cropping during training.')
flags.DEFINE_bool('depth_upsampling', True, 'Whether to apply depth '
'upsampling of lower-scale representations before warping to '
'compute reconstruction loss on full-resolution image.')
flags.DEFINE_bool('depth_normalization', True, 'Whether to apply depth '
'normalization, that is, normalizing inverse depth '
'prediction maps by their mean to avoid degeneration towards '
'small values.')
flags.DEFINE_bool('compute_minimum_loss', True, 'Whether to take the '
'element-wise minimum of the reconstruction/SSIM error in '
'order to avoid overly penalizing dis-occlusion effects.')
flags.DEFINE_bool('use_skip', True, 'Whether to use skip connections in the '
'encoder-decoder architecture.')
flags.DEFINE_bool('joint_encoder', False, 'Whether to share parameters '
'between the depth and egomotion networks by using a joint '
'encoder architecture. The egomotion network is then '
'operating only on the hidden representation provided by the '
'joint encoder.')
flags.DEFINE_float('egomotion_threshold', 0.01, 'Minimum egomotion magnitude '
'to apply finetuning. If lower, just forwards the ordinary '
'prediction.')
flags.DEFINE_integer('num_steps', 20, 'Number of optimization steps to run.')
flags.DEFINE_boolean('handle_motion', True, 'Whether the checkpoint was '
'trained with motion handling.')
flags.DEFINE_bool('flip', False, 'Whether images should be flipped as well as '
'resulting predictions (for test-time augmentation). This '
'currently applies to the depth network only.')
FLAGS = flags.FLAGS
flags.mark_flag_as_required('output_dir')
flags.mark_flag_as_required('data_dir')
flags.mark_flag_as_required('model_ckpt')
flags.mark_flag_as_required('triplet_list_file')
def main(_):
"""Runs fine-tuning and inference.
There are three categories of images.
1) Images where we have previous and next frame, and that are not filtered
out by the heuristic. For them, we will use the fine-tuned predictions.
2) Images where we have previous and next frame, but that were filtered out
by our heuristic. For them, we will use the ordinary prediction instead.
3) Images where we have at least one missing adjacent frame. For them, we will
use the ordinary prediction as indicated by triplet_list_file_remains (if
provided). They will also not be part of the generated inference list in
the first place.
Raises:
ValueError: Invalid parameters have been passed.
"""
if FLAGS.handle_motion and FLAGS.joint_encoder:
raise ValueError('Using a joint encoder is currently not supported when '
'modeling object motion.')
if FLAGS.handle_motion and FLAGS.seq_length != 3:
raise ValueError('The current motion model implementation only supports '
'using a sequence length of three.')
if FLAGS.handle_motion and not FLAGS.compute_minimum_loss:
raise ValueError('Computing the minimum photometric loss is required when '
'enabling object motion handling.')
if FLAGS.size_constraint_weight > 0 and not FLAGS.handle_motion:
raise ValueError('To enforce object size constraints, enable motion '
'handling.')
if FLAGS.icp_weight > 0.0:
raise ValueError('ICP is currently not supported.')
if FLAGS.compute_minimum_loss and FLAGS.seq_length % 2 != 1:
raise ValueError('Compute minimum loss requires using an odd number of '
'images in a sequence.')
if FLAGS.compute_minimum_loss and FLAGS.exhaustive_mode:
raise ValueError('Exhaustive mode has no effect when compute_minimum_loss '
'is enabled.')
if FLAGS.img_width % (2 ** 5) != 0 or FLAGS.img_height % (2 ** 5) != 0:
logging.warn('Image size is not divisible by 2^5. For the architecture '
'employed, this could cause artifacts due to resizing at '
'lower resolutions.')
if FLAGS.output_dir.endswith('/'):
FLAGS.output_dir = FLAGS.output_dir[:-1]
# Create file lists to prepare fine-tuning, save it to unique_file.
unique_file_name = (str(datetime.datetime.now().date()) + '_' +
str(datetime.datetime.now().time()).replace(':', '_'))
unique_file = os.path.join(FLAGS.data_dir, unique_file_name + '.txt')
with gfile.FastGFile(FLAGS.triplet_list_file, 'r') as f:
files_to_process = f.readlines()
files_to_process = [line.rstrip() for line in files_to_process]
files_to_process = [line for line in files_to_process if len(line)]
logging.info('Creating unique file list %s with %s entries.', unique_file,
len(files_to_process))
with gfile.FastGFile(unique_file, 'w') as f_out:
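# Each sample line is duplicated once per training fetch (num_steps * batch_size)
# plus three times per saved step: saving in finetune_inference additionally
# evaluates the depth, the input image stack and the egomotion, and each of
# those evaluations consumes one batch from the input queue.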
fetches_network = FLAGS.num_steps * FLAGS.batch_size
fetches_saves = FLAGS.batch_size * int(np.floor(FLAGS.num_steps/SAVE_EVERY))
repetitions = fetches_network + 3 * fetches_saves
for i in range(len(files_to_process)):
for _ in range(repetitions):
f_out.write(files_to_process[i] + '\n')
# Read remaining files.
remaining = []
if gfile.Exists(FLAGS.triplet_list_file_remains):
with gfile.FastGFile(FLAGS.triplet_list_file_remains, 'r') as f:
remaining = f.readlines()
remaining = [line.rstrip() for line in remaining]
remaining = [line for line in remaining if len(line)]
logging.info('Running fine-tuning on %s files, %s files are remaining.',
len(files_to_process), len(remaining))
# Run fine-tuning process and save predictions in id-folders.
tf.set_random_seed(FIXED_SEED)
np.random.seed(FIXED_SEED)
random.seed(FIXED_SEED)
flipping_mode = reader.FLIP_ALWAYS if FLAGS.flip else reader.FLIP_NONE
train_model = model.Model(data_dir=FLAGS.data_dir,
file_extension=FLAGS.file_extension,
is_training=True,
learning_rate=FLAGS.learning_rate,
beta1=FLAGS.beta1,
reconstr_weight=FLAGS.reconstr_weight,
smooth_weight=FLAGS.smooth_weight,
ssim_weight=FLAGS.ssim_weight,
icp_weight=FLAGS.icp_weight,
batch_size=FLAGS.batch_size,
img_height=FLAGS.img_height,
img_width=FLAGS.img_width,
seq_length=FLAGS.seq_length,
architecture=FLAGS.architecture,
imagenet_norm=FLAGS.imagenet_norm,
weight_reg=FLAGS.weight_reg,
exhaustive_mode=FLAGS.exhaustive_mode,
random_scale_crop=FLAGS.random_scale_crop,
flipping_mode=flipping_mode,
random_color=False,
depth_upsampling=FLAGS.depth_upsampling,
depth_normalization=FLAGS.depth_normalization,
compute_minimum_loss=FLAGS.compute_minimum_loss,
use_skip=FLAGS.use_skip,
joint_encoder=FLAGS.joint_encoder,
build_sum=False,
shuffle=False,
input_file=unique_file_name,
handle_motion=FLAGS.handle_motion,
size_constraint_weight=FLAGS.size_constraint_weight,
train_global_scale_var=False)
failed_heuristic_ids = finetune_inference(train_model, FLAGS.model_ckpt,
FLAGS.output_dir + '_ft')
logging.info('Fine-tuning completed, %s files were filtered out by '
'heuristic.', len(failed_heuristic_ids))
for failed_id in failed_heuristic_ids:
failed_entry = files_to_process[failed_id]
remaining.append(failed_entry)
logging.info('In total, %s images were fine-tuned, while %s were not.',
len(files_to_process)-len(failed_heuristic_ids), len(remaining))
# Copy all results to have the same structural output as running ordinary
# inference.
for i in range(len(files_to_process)):
if files_to_process[i] not in remaining: # Use fine-tuned result.
elements = files_to_process[i].split(' ')
source_file = os.path.join(FLAGS.output_dir + '_ft', FLAGS.ft_name +
'id_' + str(i),
str(FLAGS.num_steps).zfill(10) +
('_flip' if FLAGS.flip else ''))
if len(elements) == 2: # No differing mapping defined.
target_dir = os.path.join(FLAGS.output_dir + '_ft', elements[0])
target_file = os.path.join(
target_dir, elements[1] + ('_flip' if FLAGS.flip else ''))
else: # Other mapping for file defined, copy to this location instead.
target_dir = os.path.join(
FLAGS.output_dir + '_ft', os.path.dirname(elements[2]))
target_file = os.path.join(
target_dir,
os.path.basename(elements[2]) + ('_flip' if FLAGS.flip else ''))
if not gfile.Exists(target_dir):
gfile.MakeDirs(target_dir)
logging.info('Copy refined result %s to %s.', source_file, target_file)
gfile.Copy(source_file + '.npy', target_file + '.npy', overwrite=True)
gfile.Copy(source_file + '.txt', target_file + '.txt', overwrite=True)
gfile.Copy(source_file + '.%s' % FLAGS.file_extension,
target_file + '.%s' % FLAGS.file_extension, overwrite=True)
for j in range(len(remaining)):
elements = remaining[j].split(' ')
if len(elements) == 2: # No differing mapping defined.
target_dir = os.path.join(FLAGS.output_dir + '_ft', elements[0])
target_file = os.path.join(
target_dir, elements[1] + ('_flip' if FLAGS.flip else ''))
else: # Other mapping for file defined, copy to this location instead.
target_dir = os.path.join(
FLAGS.output_dir + '_ft', os.path.dirname(elements[2]))
target_file = os.path.join(
target_dir,
os.path.basename(elements[2]) + ('_flip' if FLAGS.flip else ''))
if not gfile.Exists(target_dir):
gfile.MakeDirs(target_dir)
source_file = target_file.replace('_ft', '')
logging.info('Copy unrefined result %s to %s.', source_file, target_file)
gfile.Copy(source_file + '.npy', target_file + '.npy', overwrite=True)
gfile.Copy(source_file + '.%s' % FLAGS.file_extension,
target_file + '.%s' % FLAGS.file_extension, overwrite=True)
logging.info('Done, predictions saved in %s.', FLAGS.output_dir + '_ft')
def finetune_inference(train_model, model_ckpt, output_dir):
"""Train model."""
vars_to_restore = None
if model_ckpt is not None:
vars_to_restore = util.get_vars_to_save_and_restore(model_ckpt)
ckpt_path = model_ckpt
pretrain_restorer = tf.train.Saver(vars_to_restore)
sv = tf.train.Supervisor(logdir=None, save_summaries_secs=0, saver=None,
summary_op=None)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
img_nr = 0
failed_heuristic = []
with sv.managed_session(config=config) as sess:
# TODO(casser): Caching the weights would be better to avoid I/O bottleneck.
while True: # Loop terminates when all examples have been processed.
if model_ckpt is not None:
logging.info('Restored weights from %s', ckpt_path)
pretrain_restorer.restore(sess, ckpt_path)
logging.info('Running fine-tuning, image %s...', img_nr)
img_pred_folder = os.path.join(
output_dir, FLAGS.ft_name + 'id_' + str(img_nr))
if not gfile.Exists(img_pred_folder):
gfile.MakeDirs(img_pred_folder)
step = 1
# Run fine-tuning.
while step <= FLAGS.num_steps:
logging.info('Running step %s of %s.', step, FLAGS.num_steps)
fetches = {
'train': train_model.train_op,
'global_step': train_model.global_step,
'incr_global_step': train_model.incr_global_step
}
_ = sess.run(fetches)
if step % SAVE_EVERY == 0:
# Get latest prediction for middle frame, highest scale.
pred = train_model.depth[1][0].eval(session=sess)
if FLAGS.flip:
pred = np.flip(pred, axis=2)
input_img = train_model.image_stack.eval(session=sess)
input_img_prev = input_img[0, :, :, 0:3]
input_img_center = input_img[0, :, :, 3:6]
input_img_next = input_img[0, :, :, 6:]
img_pred_file = os.path.join(
img_pred_folder,
str(step).zfill(10) + ('_flip' if FLAGS.flip else '') + '.npy')
motion = np.squeeze(train_model.egomotion.eval(session=sess))
# motion of shape (seq_length - 1, 6).
motion = np.mean(motion, axis=0) # Average egomotion across frames.
if SAVE_PREVIEWS or step == FLAGS.num_steps:
# Also save preview of depth map.
color_map = util.normalize_depth_for_display(
np.squeeze(pred[0, :, :]))
visualization = np.concatenate(
(input_img_prev, input_img_center, input_img_next, color_map))
motion_s = [str(m) for m in motion]
s_rep = ','.join(motion_s)
with gfile.Open(img_pred_file.replace('.npy', '.txt'), 'w') as f:
f.write(s_rep)
util.save_image(
img_pred_file.replace('.npy', '.%s' % FLAGS.file_extension),
visualization, FLAGS.file_extension)
with gfile.Open(img_pred_file, 'wb') as f:
np.save(f, pred)
# Apply heuristic to not finetune if egomotion magnitude is too low.
ego_magnitude = np.linalg.norm(motion[:3], ord=2)
heuristic = ego_magnitude >= FLAGS.egomotion_threshold
if not heuristic and step == FLAGS.num_steps:
failed_heuristic.append(img_nr)
step += 1
img_nr += 1
return failed_heuristic
if __name__ == '__main__':
app.run(main)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Geometry utilities for projecting frames based on depth and motion.
Modified from Spatial Transformer Networks:
https://github.com/tensorflow/models/blob/master/transformer/spatial_transformer.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import numpy as np
import tensorflow as tf
def inverse_warp(img, depth, egomotion_mat, intrinsic_mat,
intrinsic_mat_inv):
"""Inverse warp a source image to the target image plane.
Args:
img: The source image (to sample pixels from) -- [B, H, W, 3].
depth: Depth map of the target image -- [B, H, W].
egomotion_mat: Matrix defining egomotion transform -- [B, 4, 4].
intrinsic_mat: Camera intrinsic matrix -- [B, 3, 3].
intrinsic_mat_inv: Inverse of the intrinsic matrix -- [B, 3, 3].
Returns:
Projected source image and a mask indicating which projected pixel locations are valid.
"""
dims = tf.shape(img)
batch_size, img_height, img_width = dims[0], dims[1], dims[2]
depth = tf.reshape(depth, [batch_size, 1, img_height * img_width])
grid = _meshgrid_abs(img_height, img_width)
grid = tf.tile(tf.expand_dims(grid, 0), [batch_size, 1, 1])
cam_coords = _pixel2cam(depth, grid, intrinsic_mat_inv)
ones = tf.ones([batch_size, 1, img_height * img_width])
cam_coords_hom = tf.concat([cam_coords, ones], axis=1)
# Get projection matrix for target camera frame to source pixel frame
hom_filler = tf.constant([0.0, 0.0, 0.0, 1.0], shape=[1, 1, 4])
hom_filler = tf.tile(hom_filler, [batch_size, 1, 1])
intrinsic_mat_hom = tf.concat(
[intrinsic_mat, tf.zeros([batch_size, 3, 1])], axis=2)
intrinsic_mat_hom = tf.concat([intrinsic_mat_hom, hom_filler], axis=1)
proj_target_cam_to_source_pixel = tf.matmul(intrinsic_mat_hom, egomotion_mat)
source_pixel_coords = _cam2pixel(cam_coords_hom,
proj_target_cam_to_source_pixel)
source_pixel_coords = tf.reshape(source_pixel_coords,
[batch_size, 2, img_height, img_width])
source_pixel_coords = tf.transpose(source_pixel_coords, perm=[0, 2, 3, 1])
projected_img, mask = _spatial_transformer(img, source_pixel_coords)
return projected_img, mask
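# The projection above, written out in plain numpy for a single pixel (all
# numbers are illustrative): lift the target pixel to 3D with the predicted
# depth, apply the egomotion transform, and project into the source frame.
#   K = np.array([[200., 0., 208.], [0., 200., 64.], [0., 0., 1.]])
#   u, v, d = 100., 50., 5.                      # target pixel and its depth
#   cam = d * np.linalg.inv(K).dot([u, v, 1.])   # pixel -> target camera frame
#   T = np.eye(4); T[2, 3] = 0.1                 # egomotion: small forward motion
#   p = np.hstack([K, np.zeros((3, 1))]).dot(T).dot(np.append(cam, 1.))
#   u_src, v_src = p[:2] / p[2]                  # source pixel to sample from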
def get_transform_mat(egomotion_vecs, i, j):
"""Returns a transform matrix defining the transform from frame i to j."""
egomotion_transforms = []
batchsize = tf.shape(egomotion_vecs)[0]
if i == j:
return tf.tile(tf.expand_dims(tf.eye(4, 4), axis=0), [batchsize, 1, 1])
for k in range(min(i, j), max(i, j)):
transform_matrix = _egomotion_vec2mat(egomotion_vecs[:, k, :], batchsize)
if i > j: # Going back in sequence, need to invert egomotion.
egomotion_transforms.insert(0, tf.linalg.inv(transform_matrix))
else: # Going forward in sequence
egomotion_transforms.append(transform_matrix)
# Multiply all matrices.
egomotion_mat = egomotion_transforms[0]
for i in range(1, len(egomotion_transforms)):
egomotion_mat = tf.matmul(egomotion_mat, egomotion_transforms[i])
return egomotion_mat
def _pixel2cam(depth, pixel_coords, intrinsic_mat_inv):
"""Transform coordinates in the pixel frame to the camera frame."""
cam_coords = tf.matmul(intrinsic_mat_inv, pixel_coords) * depth
return cam_coords
def _cam2pixel(cam_coords, proj_c2p):
"""Transform coordinates in the camera frame to the pixel frame."""
pcoords = tf.matmul(proj_c2p, cam_coords)
x = tf.slice(pcoords, [0, 0, 0], [-1, 1, -1])
y = tf.slice(pcoords, [0, 1, 0], [-1, 1, -1])
z = tf.slice(pcoords, [0, 2, 0], [-1, 1, -1])
# Not tested if adding a small number is necessary
x_norm = x / (z + 1e-10)
y_norm = y / (z + 1e-10)
pixel_coords = tf.concat([x_norm, y_norm], axis=1)
return pixel_coords
def _meshgrid_abs(height, width):
"""Meshgrid in the absolute coordinates."""
x_t = tf.matmul(
tf.ones(shape=tf.stack([height, 1])),
tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
y_t = tf.matmul(
tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
tf.ones(shape=tf.stack([1, width])))
x_t = (x_t + 1.0) * 0.5 * tf.cast(width - 1, tf.float32)
y_t = (y_t + 1.0) * 0.5 * tf.cast(height - 1, tf.float32)
x_t_flat = tf.reshape(x_t, (1, -1))
y_t_flat = tf.reshape(y_t, (1, -1))
ones = tf.ones_like(x_t_flat)
grid = tf.concat([x_t_flat, y_t_flat, ones], axis=0)
return grid
def _euler2mat(z, y, x):
"""Converts euler angles to rotation matrix.
From:
https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174
TODO: Remove the dimension for 'N' (deprecated for converting all source
poses altogether).
Args:
z: rotation angle along z axis (in radians) -- size = [B, n]
y: rotation angle along y axis (in radians) -- size = [B, n]
x: rotation angle along x axis (in radians) -- size = [B, n]
Returns:
Rotation matrix corresponding to the euler angles, with shape [B, n, 3, 3].
"""
batch_size = tf.shape(z)[0]
n = 1
z = tf.clip_by_value(z, -np.pi, np.pi)
y = tf.clip_by_value(y, -np.pi, np.pi)
x = tf.clip_by_value(x, -np.pi, np.pi)
# Expand to B x N x 1 x 1
z = tf.expand_dims(tf.expand_dims(z, -1), -1)
y = tf.expand_dims(tf.expand_dims(y, -1), -1)
x = tf.expand_dims(tf.expand_dims(x, -1), -1)
zeros = tf.zeros([batch_size, n, 1, 1])
ones = tf.ones([batch_size, n, 1, 1])
cosz = tf.cos(z)
sinz = tf.sin(z)
rotz_1 = tf.concat([cosz, -sinz, zeros], axis=3)
rotz_2 = tf.concat([sinz, cosz, zeros], axis=3)
rotz_3 = tf.concat([zeros, zeros, ones], axis=3)
zmat = tf.concat([rotz_1, rotz_2, rotz_3], axis=2)
cosy = tf.cos(y)
siny = tf.sin(y)
roty_1 = tf.concat([cosy, zeros, siny], axis=3)
roty_2 = tf.concat([zeros, ones, zeros], axis=3)
roty_3 = tf.concat([-siny, zeros, cosy], axis=3)
ymat = tf.concat([roty_1, roty_2, roty_3], axis=2)
cosx = tf.cos(x)
sinx = tf.sin(x)
rotx_1 = tf.concat([ones, zeros, zeros], axis=3)
rotx_2 = tf.concat([zeros, cosx, -sinx], axis=3)
rotx_3 = tf.concat([zeros, sinx, cosx], axis=3)
xmat = tf.concat([rotx_1, rotx_2, rotx_3], axis=2)
return tf.matmul(tf.matmul(xmat, ymat), zmat)
def _egomotion_vec2mat(vec, batch_size):
"""Converts 6DoF transform vector to transformation matrix.
Args:
vec: 6DoF parameters [tx, ty, tz, rx, ry, rz] -- [B, 6].
batch_size: Batch size.
Returns:
A transformation matrix -- [B, 4, 4].
"""
translation = tf.slice(vec, [0, 0], [-1, 3])
translation = tf.expand_dims(translation, -1)
rx = tf.slice(vec, [0, 3], [-1, 1])
ry = tf.slice(vec, [0, 4], [-1, 1])
rz = tf.slice(vec, [0, 5], [-1, 1])
rot_mat = _euler2mat(rz, ry, rx)
rot_mat = tf.squeeze(rot_mat, squeeze_dims=[1])
filler = tf.constant([0.0, 0.0, 0.0, 1.0], shape=[1, 1, 4])
filler = tf.tile(filler, [batch_size, 1, 1])
transform_mat = tf.concat([rot_mat, translation], axis=2)
transform_mat = tf.concat([transform_mat, filler], axis=1)
return transform_mat
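# Illustrative NumPy counterpart of the layout built above: a 6DoF vector
# [tx, ty, tz, rx, ry, rz] becomes [[R, t], [0, 0, 0, 1]]. Relies on the
# _euler2mat_np sketch defined earlier; both are for demonstration only.
def _egomotion_vec2mat_np(vec):
  mat = np.eye(4)
  mat[:3, :3] = _euler2mat_np(vec[5], vec[4], vec[3])  # rz, ry, rx.
  mat[:3, 3] = vec[:3]  # tx, ty, tz.
  return mat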
def _bilinear_sampler(im, x, y, name='bilinear_sampler'):
"""Perform bilinear sampling on im given list of x, y coordinates.
Implements the differentiable sampling mechanism with bilinear kernel
in https://arxiv.org/abs/1506.02025.
x,y are tensors specifying normalized coordinates [-1, 1] to be sampled on im.
For example, (-1, -1) in (x, y) corresponds to pixel location (0, 0) in im,
and (1, 1) in (x, y) corresponds to the bottom right pixel in im.
Args:
im: Batch of images with shape [B, h, w, channels].
x: Tensor of normalized x coordinates in [-1, 1], with shape [B, h, w, 1].
y: Tensor of normalized y coordinates in [-1, 1], with shape [B, h, w, 1].
name: Name scope for ops.
Returns:
Sampled image with shape [B, h, w, channels].
Principled mask with shape [B, h, w, 1], dtype:float32. A value of 1.0
in the mask indicates that the corresponding coordinate in the sampled
image is valid.
"""
with tf.variable_scope(name):
x = tf.reshape(x, [-1])
y = tf.reshape(y, [-1])
# Constants.
batch_size = tf.shape(im)[0]
_, height, width, channels = im.get_shape().as_list()
x = tf.to_float(x)
y = tf.to_float(y)
height_f = tf.cast(height, 'float32')
width_f = tf.cast(width, 'float32')
zero = tf.constant(0, dtype=tf.int32)
max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
# Scale indices from [-1, 1] to [0, width - 1] or [0, height - 1].
x = (x + 1.0) * (width_f - 1.0) / 2.0
y = (y + 1.0) * (height_f - 1.0) / 2.0
# Compute the coordinates of the 4 pixels to sample from.
x0 = tf.cast(tf.floor(x), 'int32')
x1 = x0 + 1
y0 = tf.cast(tf.floor(y), 'int32')
y1 = y0 + 1
mask = tf.logical_and(
tf.logical_and(x0 >= zero, x1 <= max_x),
tf.logical_and(y0 >= zero, y1 <= max_y))
mask = tf.to_float(mask)
x0 = tf.clip_by_value(x0, zero, max_x)
x1 = tf.clip_by_value(x1, zero, max_x)
y0 = tf.clip_by_value(y0, zero, max_y)
y1 = tf.clip_by_value(y1, zero, max_y)
dim2 = width
dim1 = width * height
# Create base index.
base = tf.range(batch_size) * dim1
base = tf.reshape(base, [-1, 1])
base = tf.tile(base, [1, height * width])
base = tf.reshape(base, [-1])
base_y0 = base + y0 * dim2
base_y1 = base + y1 * dim2
idx_a = base_y0 + x0
idx_b = base_y1 + x0
idx_c = base_y0 + x1
idx_d = base_y1 + x1
# Use indices to lookup pixels in the flat image and restore channels dim.
im_flat = tf.reshape(im, tf.stack([-1, channels]))
im_flat = tf.to_float(im_flat)
pixel_a = tf.gather(im_flat, idx_a)
pixel_b = tf.gather(im_flat, idx_b)
pixel_c = tf.gather(im_flat, idx_c)
pixel_d = tf.gather(im_flat, idx_d)
x1_f = tf.to_float(x1)
y1_f = tf.to_float(y1)
# And finally calculate interpolated values.
wa = tf.expand_dims(((x1_f - x) * (y1_f - y)), 1)
wb = tf.expand_dims((x1_f - x) * (1.0 - (y1_f - y)), 1)
wc = tf.expand_dims(((1.0 - (x1_f - x)) * (y1_f - y)), 1)
wd = tf.expand_dims(((1.0 - (x1_f - x)) * (1.0 - (y1_f - y))), 1)
output = tf.add_n([wa * pixel_a, wb * pixel_b, wc * pixel_c, wd * pixel_d])
output = tf.reshape(output, tf.stack([batch_size, height, width, channels]))
mask = tf.reshape(mask, tf.stack([batch_size, height, width, 1]))
return output, mask
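# Short NumPy illustration of the bilinear kernel applied above at a single
# absolute (x, y) location of a one-channel image; names and the lack of
# border handling are simplifications for demonstration only.
def _bilinear_at_np(im, x, y):
  """im: [h, w] array; x, y: absolute float coordinates inside the image."""
  x0, y0 = int(np.floor(x)), int(np.floor(y))
  x1, y1 = x0 + 1, y0 + 1
  wa = (x1 - x) * (y1 - y)  # Weight of the top-left neighbor.
  wb = (x1 - x) * (y - y0)  # Bottom-left.
  wc = (x - x0) * (y1 - y)  # Top-right.
  wd = (x - x0) * (y - y0)  # Bottom-right.
  return (wa * im[y0, x0] + wb * im[y1, x0] +
          wc * im[y0, x1] + wd * im[y1, x1])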
def _spatial_transformer(img, coords):
"""A wrapper over binlinear_sampler(), taking absolute coords as input."""
img_height = tf.cast(tf.shape(img)[1], tf.float32)
img_width = tf.cast(tf.shape(img)[2], tf.float32)
px = coords[:, :, :, :1]
py = coords[:, :, :, 1:]
# Normalize coordinates to [-1, 1] to send to _bilinear_sampler.
px = px / (img_width - 1) * 2.0 - 1.0
py = py / (img_height - 1) * 2.0 - 1.0
output_img, mask = _bilinear_sampler(img, px, py)
return output_img, mask
def get_cloud(depth, intrinsics_inv, name=None):
"""Convert depth map to 3D point cloud."""
with tf.name_scope(name):
dims = depth.shape.as_list()
batch_size, img_height, img_width = dims[0], dims[1], dims[2]
depth = tf.reshape(depth, [batch_size, 1, img_height * img_width])
grid = _meshgrid_abs(img_height, img_width)
grid = tf.tile(tf.expand_dims(grid, 0), [batch_size, 1, 1])
cam_coords = _pixel2cam(depth, grid, intrinsics_inv)
cam_coords = tf.transpose(cam_coords, [0, 2, 1])
cam_coords = tf.reshape(cam_coords, [batch_size, img_height, img_width, 3])
logging.info('depth -> cloud: %s', cam_coords)
return cam_coords
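# Back-of-the-envelope NumPy version of the depth -> point-cloud mapping above
# for a [h, w] depth map; the intrinsics defaults are placeholders chosen only
# to keep the sketch short.
def _depth_to_cloud_np(depth, fx=100.0, fy=100.0, cx=208.0, cy=64.0):
  """Returns [h, w, 3] camera-frame coordinates for a [h, w] depth map."""
  height, width = depth.shape
  xs, ys = np.meshgrid(np.arange(width, dtype=np.float32),
                       np.arange(height, dtype=np.float32))
  x = (xs - cx) / fx * depth
  y = (ys - cy) / fy * depth
  return np.stack([x, y, depth], axis=-1)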
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reads data that is produced by dataset/gen_data.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
from absl import logging
import tensorflow as tf
import util
gfile = tf.gfile
QUEUE_SIZE = 2000
QUEUE_BUFFER = 3
# See nets.encoder_resnet as reference for below input-normalizing constants.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_SD = (0.229, 0.224, 0.225)
FLIP_RANDOM = 'random'  # Perform random flipping (training-time augmentation).
FLIP_ALWAYS = 'always'  # Always flip image input, used for test augmentation.
FLIP_NONE = 'none'  # Never flip image input.
class DataReader(object):
"""Reads stored sequences which are produced by dataset/gen_data.py."""
def __init__(self, data_dir, batch_size, img_height, img_width, seq_length,
num_scales, file_extension, random_scale_crop, flipping_mode,
random_color, imagenet_norm, shuffle, input_file='train'):
self.data_dir = data_dir
self.batch_size = batch_size
self.img_height = img_height
self.img_width = img_width
self.seq_length = seq_length
self.num_scales = num_scales
self.file_extension = file_extension
self.random_scale_crop = random_scale_crop
self.flipping_mode = flipping_mode
self.random_color = random_color
self.imagenet_norm = imagenet_norm
self.shuffle = shuffle
self.input_file = input_file
def read_data(self):
"""Provides images and camera intrinsics."""
with tf.name_scope('data_loading'):
with tf.name_scope('enqueue_paths'):
seed = random.randint(0, 2**31 - 1)
self.file_lists = self.compile_file_list(self.data_dir, self.input_file)
image_paths_queue = tf.train.string_input_producer(
self.file_lists['image_file_list'], seed=seed,
shuffle=self.shuffle,
num_epochs=(1 if not self.shuffle else None)
)
seg_paths_queue = tf.train.string_input_producer(
self.file_lists['segment_file_list'], seed=seed,
shuffle=self.shuffle,
num_epochs=(1 if not self.shuffle else None))
cam_paths_queue = tf.train.string_input_producer(
self.file_lists['cam_file_list'], seed=seed,
shuffle=self.shuffle,
num_epochs=(1 if not self.shuffle else None))
img_reader = tf.WholeFileReader()
_, image_contents = img_reader.read(image_paths_queue)
seg_reader = tf.WholeFileReader()
_, seg_contents = seg_reader.read(seg_paths_queue)
if self.file_extension == 'jpg':
image_seq = tf.image.decode_jpeg(image_contents)
seg_seq = tf.image.decode_jpeg(seg_contents, channels=3)
elif self.file_extension == 'png':
image_seq = tf.image.decode_png(image_contents, channels=3)
seg_seq = tf.image.decode_png(seg_contents, channels=3)
with tf.name_scope('load_intrinsics'):
cam_reader = tf.TextLineReader()
_, raw_cam_contents = cam_reader.read(cam_paths_queue)
rec_def = []
for _ in range(9):
rec_def.append([1.0])
raw_cam_vec = tf.decode_csv(raw_cam_contents, record_defaults=rec_def)
raw_cam_vec = tf.stack(raw_cam_vec)
intrinsics = tf.reshape(raw_cam_vec, [3, 3])
with tf.name_scope('convert_image'):
image_seq = self.preprocess_image(image_seq) # Converts to float.
if self.random_color:
with tf.name_scope('image_augmentation'):
image_seq = self.augment_image_colorspace(image_seq)
image_stack = self.unpack_images(image_seq)
seg_stack = self.unpack_images(seg_seq)
if self.flipping_mode != FLIP_NONE:
random_flipping = (self.flipping_mode == FLIP_RANDOM)
with tf.name_scope('image_augmentation_flip'):
image_stack, seg_stack, intrinsics = self.augment_images_flip(
image_stack, seg_stack, intrinsics,
randomized=random_flipping)
if self.random_scale_crop:
with tf.name_scope('image_augmentation_scale_crop'):
image_stack, seg_stack, intrinsics = self.augment_images_scale_crop(
image_stack, seg_stack, intrinsics, self.img_height,
self.img_width)
with tf.name_scope('multi_scale_intrinsics'):
intrinsic_mat = self.get_multi_scale_intrinsics(intrinsics,
self.num_scales)
intrinsic_mat.set_shape([self.num_scales, 3, 3])
intrinsic_mat_inv = tf.matrix_inverse(intrinsic_mat)
intrinsic_mat_inv.set_shape([self.num_scales, 3, 3])
if self.imagenet_norm:
im_mean = tf.tile(
tf.constant(IMAGENET_MEAN), multiples=[self.seq_length])
im_sd = tf.tile(
tf.constant(IMAGENET_SD), multiples=[self.seq_length])
image_stack_norm = (image_stack - im_mean) / im_sd
else:
image_stack_norm = image_stack
with tf.name_scope('batching'):
if self.shuffle:
(image_stack, image_stack_norm, seg_stack, intrinsic_mat,
intrinsic_mat_inv) = tf.train.shuffle_batch(
[image_stack, image_stack_norm, seg_stack, intrinsic_mat,
intrinsic_mat_inv],
batch_size=self.batch_size,
capacity=QUEUE_SIZE + QUEUE_BUFFER * self.batch_size,
min_after_dequeue=QUEUE_SIZE)
else:
(image_stack, image_stack_norm, seg_stack, intrinsic_mat,
intrinsic_mat_inv) = tf.train.batch(
[image_stack, image_stack_norm, seg_stack, intrinsic_mat,
intrinsic_mat_inv],
batch_size=self.batch_size,
num_threads=1,
capacity=QUEUE_SIZE + QUEUE_BUFFER * self.batch_size)
logging.info('image_stack: %s', util.info(image_stack))
return (image_stack, image_stack_norm, seg_stack, intrinsic_mat,
intrinsic_mat_inv)
def unpack_images(self, image_seq):
"""[h, w * seq_length, 3] -> [h, w, 3 * seq_length]."""
with tf.name_scope('unpack_images'):
image_list = [
image_seq[:, i * self.img_width:(i + 1) * self.img_width, :]
for i in range(self.seq_length)
]
image_stack = tf.concat(image_list, axis=2)
image_stack.set_shape(
[self.img_height, self.img_width, self.seq_length * 3])
return image_stack
@classmethod
def preprocess_image(cls, image):
# Convert from uint8 to float.
return tf.image.convert_image_dtype(image, dtype=tf.float32)
@classmethod
def augment_image_colorspace(cls, image_stack):
"""Apply data augmentation to inputs."""
image_stack_aug = image_stack
# Randomly shift brightness.
apply_brightness = tf.less(tf.random_uniform(
shape=[], minval=0.0, maxval=1.0, dtype=tf.float32), 0.5)
image_stack_aug = tf.cond(
apply_brightness,
lambda: tf.image.random_brightness(image_stack_aug, max_delta=0.1),
lambda: image_stack_aug)
# Randomly shift contrast.
apply_contrast = tf.less(tf.random_uniform(
shape=[], minval=0.0, maxval=1.0, dtype=tf.float32), 0.5)
image_stack_aug = tf.cond(
apply_contrast,
lambda: tf.image.random_contrast(image_stack_aug, 0.85, 1.15),
lambda: image_stack_aug)
# Randomly change saturation.
apply_saturation = tf.less(tf.random_uniform(
shape=[], minval=0.0, maxval=1.0, dtype=tf.float32), 0.5)
image_stack_aug = tf.cond(
apply_saturation,
lambda: tf.image.random_saturation(image_stack_aug, 0.85, 1.15),
lambda: image_stack_aug)
# Randomly change hue.
apply_hue = tf.less(tf.random_uniform(
shape=[], minval=0.0, maxval=1.0, dtype=tf.float32), 0.5)
image_stack_aug = tf.cond(
apply_hue,
lambda: tf.image.random_hue(image_stack_aug, max_delta=0.1),
lambda: image_stack_aug)
image_stack_aug = tf.clip_by_value(image_stack_aug, 0, 1)
return image_stack_aug
@classmethod
def augment_images_flip(cls, image_stack, seg_stack, intrinsics,
randomized=True):
"""Randomly flips the image horizontally."""
def flip(cls, image_stack, seg_stack, intrinsics):
_, in_w, _ = image_stack.get_shape().as_list()
fx = intrinsics[0, 0]
fy = intrinsics[1, 1]
cx = in_w - intrinsics[0, 2]
cy = intrinsics[1, 2]
intrinsics = cls.make_intrinsics_matrix(fx, fy, cx, cy)
return (tf.image.flip_left_right(image_stack),
tf.image.flip_left_right(seg_stack), intrinsics)
if randomized:
prob = tf.random_uniform(shape=[], minval=0.0, maxval=1.0,
dtype=tf.float32)
predicate = tf.less(prob, 0.5)
return tf.cond(predicate,
lambda: flip(cls, image_stack, seg_stack, intrinsics),
lambda: (image_stack, seg_stack, intrinsics))
else:
return flip(cls, image_stack, seg_stack, intrinsics)
@classmethod
def augment_images_scale_crop(cls, im, seg, intrinsics, out_h, out_w):
"""Randomly scales and crops image."""
def scale_randomly(im, seg, intrinsics):
"""Scales image and adjust intrinsics accordingly."""
in_h, in_w, _ = im.get_shape().as_list()
scaling = tf.random_uniform([2], 1, 1.15)
x_scaling = scaling[0]
y_scaling = scaling[1]
out_h = tf.cast(in_h * y_scaling, dtype=tf.int32)
out_w = tf.cast(in_w * x_scaling, dtype=tf.int32)
# Add batch.
im = tf.expand_dims(im, 0)
im = tf.image.resize_area(im, [out_h, out_w])
im = im[0]
seg = tf.expand_dims(seg, 0)
seg = tf.image.resize_area(seg, [out_h, out_w])
seg = seg[0]
fx = intrinsics[0, 0] * x_scaling
fy = intrinsics[1, 1] * y_scaling
cx = intrinsics[0, 2] * x_scaling
cy = intrinsics[1, 2] * y_scaling
intrinsics = cls.make_intrinsics_matrix(fx, fy, cx, cy)
return im, seg, intrinsics
# Random cropping
def crop_randomly(im, seg, intrinsics, out_h, out_w):
"""Crops image and adjust intrinsics accordingly."""
# batch_size, in_h, in_w, _ = im.get_shape().as_list()
in_h, in_w, _ = tf.unstack(tf.shape(im))
offset_y = tf.random_uniform([1], 0, in_h - out_h + 1, dtype=tf.int32)[0]
offset_x = tf.random_uniform([1], 0, in_w - out_w + 1, dtype=tf.int32)[0]
im = tf.image.crop_to_bounding_box(im, offset_y, offset_x, out_h, out_w)
seg = tf.image.crop_to_bounding_box(seg, offset_y, offset_x, out_h, out_w)
fx = intrinsics[0, 0]
fy = intrinsics[1, 1]
cx = intrinsics[0, 2] - tf.cast(offset_x, dtype=tf.float32)
cy = intrinsics[1, 2] - tf.cast(offset_y, dtype=tf.float32)
intrinsics = cls.make_intrinsics_matrix(fx, fy, cx, cy)
return im, seg, intrinsics
im, seg, intrinsics = scale_randomly(im, seg, intrinsics)
im, seg, intrinsics = crop_randomly(im, seg, intrinsics, out_h, out_w)
return im, seg, intrinsics
def compile_file_list(self, data_dir, split, load_pose=False):
"""Creates a list of input files."""
logging.info('data_dir: %s', data_dir)
with gfile.Open(os.path.join(data_dir, '%s.txt' % split), 'r') as f:
frames = f.readlines()
frames = [k.rstrip() for k in frames]
subfolders = [x.split(' ')[0] for x in frames]
frame_ids = [x.split(' ')[1] for x in frames]
image_file_list = [
os.path.join(data_dir, subfolders[i], frame_ids[i] + '.' +
self.file_extension)
for i in range(len(frames))
]
segment_file_list = [
os.path.join(data_dir, subfolders[i], frame_ids[i] + '-fseg.' +
self.file_extension)
for i in range(len(frames))
]
cam_file_list = [
os.path.join(data_dir, subfolders[i], frame_ids[i] + '_cam.txt')
for i in range(len(frames))
]
file_lists = {}
file_lists['image_file_list'] = image_file_list
file_lists['segment_file_list'] = segment_file_list
file_lists['cam_file_list'] = cam_file_list
if load_pose:
pose_file_list = [
os.path.join(data_dir, subfolders[i], frame_ids[i] + '_pose.txt')
for i in range(len(frames))
]
file_lists['pose_file_list'] = pose_file_list
self.steps_per_epoch = len(image_file_list) // self.batch_size
return file_lists
@classmethod
def make_intrinsics_matrix(cls, fx, fy, cx, cy):
r1 = tf.stack([fx, 0, cx])
r2 = tf.stack([0, fy, cy])
r3 = tf.constant([0., 0., 1.])
intrinsics = tf.stack([r1, r2, r3])
return intrinsics
@classmethod
def get_multi_scale_intrinsics(cls, intrinsics, num_scales):
"""Returns multiple intrinsic matrices for different scales."""
intrinsics_multi_scale = []
# Scale the intrinsics accordingly for each scale
for s in range(num_scales):
fx = intrinsics[0, 0] / (2**s)
fy = intrinsics[1, 1] / (2**s)
cx = intrinsics[0, 2] / (2**s)
cy = intrinsics[1, 2] / (2**s)
intrinsics_multi_scale.append(cls.make_intrinsics_matrix(fx, fy, cx, cy))
intrinsics_multi_scale = tf.stack(intrinsics_multi_scale)
return intrinsics_multi_scale
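# Hedged usage sketch (never executed by this module): constructing a
# DataReader roughly the way the training code does. All argument values below
# are placeholders, not recommended settings.
def _example_build_reader():
  return DataReader(data_dir='KITTI_SEQ2_LR/', batch_size=4, img_height=128,
                    img_width=416, seq_length=3, num_scales=4,
                    file_extension='png', random_scale_crop=False,
                    flipping_mode=FLIP_RANDOM, random_color=True,
                    imagenet_norm=True, shuffle=True)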
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train the model. Please refer to README for example usage."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import random
import time
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
import model
import nets
import reader
import util
gfile = tf.gfile
MAX_TO_KEEP = 1000000 # Maximum number of checkpoints to keep.
flags.DEFINE_string('data_dir', None, 'Preprocessed data.')
flags.DEFINE_string('file_extension', 'png', 'Image data file extension.')
flags.DEFINE_float('learning_rate', 0.0002, 'Adam learning rate.')
flags.DEFINE_float('beta1', 0.9, 'Adam momentum.')
flags.DEFINE_float('reconstr_weight', 0.85, 'Frame reconstruction loss weight.')
flags.DEFINE_float('ssim_weight', 0.15, 'SSIM loss weight.')
flags.DEFINE_float('smooth_weight', 0.04, 'Smoothness loss weight.')
flags.DEFINE_float('icp_weight', 0.0, 'ICP loss weight.')
flags.DEFINE_float('size_constraint_weight', 0.0005, 'Weight of the object '
'size constraint loss. Use only when motion handling is '
'enabled.')
flags.DEFINE_integer('batch_size', 4, 'The size of a sample batch.')
flags.DEFINE_integer('img_height', 128, 'Input frame height.')
flags.DEFINE_integer('img_width', 416, 'Input frame width.')
flags.DEFINE_integer('seq_length', 3, 'Number of frames in sequence.')
flags.DEFINE_enum('architecture', nets.RESNET, nets.ARCHITECTURES,
'Defines the architecture to use for the depth prediction '
'network. Defaults to ResNet-based encoder and accompanying '
'decoder.')
flags.DEFINE_boolean('imagenet_norm', True, 'Whether to normalize the input '
'images channel-wise so that they match the distribution '
'most ImageNet-models were trained on.')
flags.DEFINE_float('weight_reg', 0.05, 'The amount of weight regularization to '
'apply. This has no effect on the ResNet-based encoder '
'architecture.')
flags.DEFINE_boolean('exhaustive_mode', False, 'Whether to exhaustively warp '
'from any frame to any other instead of just considering '
'adjacent frames. Where necessary, multiple egomotion '
'estimates will be applied. Does not have an effect if '
'compute_minimum_loss is enabled.')
flags.DEFINE_boolean('random_scale_crop', False, 'Whether to apply random '
'image scaling and center cropping during training.')
flags.DEFINE_enum('flipping_mode', reader.FLIP_RANDOM,
[reader.FLIP_RANDOM, reader.FLIP_ALWAYS, reader.FLIP_NONE],
'Determines the image flipping mode: if random, performs '
'on-the-fly augmentation. Otherwise, flips the input images '
'always or never, respectively.')
flags.DEFINE_string('pretrained_ckpt', None, 'Path to checkpoint with '
'pretrained weights. Do not include .data* extension.')
flags.DEFINE_string('imagenet_ckpt', None, 'Initialize the weights according '
'to an ImageNet-pretrained checkpoint. Requires '
'architecture to be ResNet-18.')
flags.DEFINE_string('checkpoint_dir', None, 'Directory to save model '
'checkpoints.')
flags.DEFINE_integer('train_steps', 10000000, 'Number of training steps.')
flags.DEFINE_integer('summary_freq', 100, 'Save summaries every N steps.')
flags.DEFINE_bool('depth_upsampling', True, 'Whether to apply depth '
'upsampling of lower-scale representations before warping to '
'compute reconstruction loss on full-resolution image.')
flags.DEFINE_bool('depth_normalization', True, 'Whether to apply depth '
'normalization, that is, normalizing inverse depth '
'prediction maps by their mean to avoid degeneration towards '
'small values.')
flags.DEFINE_bool('compute_minimum_loss', True, 'Whether to take the '
'element-wise minimum of the reconstruction/SSIM error in '
'order to avoid overly penalizing dis-occlusion effects.')
flags.DEFINE_bool('use_skip', True, 'Whether to use skip connections in the '
'encoder-decoder architecture.')
flags.DEFINE_bool('equal_weighting', False, 'Whether to use equal weighting '
'of the smoothing loss term, regardless of resolution.')
flags.DEFINE_bool('joint_encoder', False, 'Whether to share parameters '
'between the depth and egomotion networks by using a joint '
'encoder architecture. The egomotion network is then '
'operating only on the hidden representation provided by the '
'joint encoder.')
flags.DEFINE_bool('handle_motion', True, 'Whether to try to handle motion by '
'using the provided segmentation masks.')
flags.DEFINE_string('master', 'local', 'Location of the session.')
FLAGS = flags.FLAGS
flags.mark_flag_as_required('data_dir')
flags.mark_flag_as_required('checkpoint_dir')
def main(_):
# Fixed seed for repeatability
seed = 8964
tf.set_random_seed(seed)
np.random.seed(seed)
random.seed(seed)
if FLAGS.handle_motion and FLAGS.joint_encoder:
raise ValueError('Using a joint encoder is currently not supported when '
'modeling object motion.')
if FLAGS.handle_motion and FLAGS.seq_length != 3:
raise ValueError('The current motion model implementation only supports '
'using a sequence length of three.')
if FLAGS.handle_motion and not FLAGS.compute_minimum_loss:
raise ValueError('Computing the minimum photometric loss is required when '
'enabling object motion handling.')
if FLAGS.size_constraint_weight > 0 and not FLAGS.handle_motion:
raise ValueError('To enforce object size constraints, enable motion '
'handling.')
if FLAGS.imagenet_ckpt and not FLAGS.imagenet_norm:
logging.warn('When initializing with an ImageNet-pretrained model, it is '
'recommended to normalize the image inputs accordingly using '
'imagenet_norm.')
if FLAGS.compute_minimum_loss and FLAGS.seq_length % 2 != 1:
raise ValueError('Compute minimum loss requires using an odd number of '
'images in a sequence.')
if FLAGS.architecture != nets.RESNET and FLAGS.imagenet_ckpt:
raise ValueError('Can only load weights from pre-trained ImageNet model '
'when using ResNet-architecture.')
if FLAGS.compute_minimum_loss and FLAGS.exhaustive_mode:
raise ValueError('Exhaustive mode has no effect when compute_minimum_loss '
'is enabled.')
if FLAGS.img_width % (2 ** 5) != 0 or FLAGS.img_height % (2 ** 5) != 0:
    logging.warn('Image size is not divisible by 2^5. For the architecture '
                 'employed, this could cause artefacts due to resizing at '
                 'lower resolutions.')
if FLAGS.icp_weight > 0.0:
# TODO(casser): Change ICP interface to take matrix instead of vector.
raise ValueError('ICP is currently not supported.')
if not gfile.Exists(FLAGS.checkpoint_dir):
gfile.MakeDirs(FLAGS.checkpoint_dir)
train_model = model.Model(data_dir=FLAGS.data_dir,
file_extension=FLAGS.file_extension,
is_training=True,
learning_rate=FLAGS.learning_rate,
beta1=FLAGS.beta1,
reconstr_weight=FLAGS.reconstr_weight,
smooth_weight=FLAGS.smooth_weight,
ssim_weight=FLAGS.ssim_weight,
icp_weight=FLAGS.icp_weight,
batch_size=FLAGS.batch_size,
img_height=FLAGS.img_height,
img_width=FLAGS.img_width,
seq_length=FLAGS.seq_length,
architecture=FLAGS.architecture,
imagenet_norm=FLAGS.imagenet_norm,
weight_reg=FLAGS.weight_reg,
exhaustive_mode=FLAGS.exhaustive_mode,
random_scale_crop=FLAGS.random_scale_crop,
flipping_mode=FLAGS.flipping_mode,
depth_upsampling=FLAGS.depth_upsampling,
depth_normalization=FLAGS.depth_normalization,
compute_minimum_loss=FLAGS.compute_minimum_loss,
use_skip=FLAGS.use_skip,
joint_encoder=FLAGS.joint_encoder,
handle_motion=FLAGS.handle_motion,
equal_weighting=FLAGS.equal_weighting,
size_constraint_weight=FLAGS.size_constraint_weight)
train(train_model, FLAGS.pretrained_ckpt, FLAGS.imagenet_ckpt,
FLAGS.checkpoint_dir, FLAGS.train_steps, FLAGS.summary_freq)
def train(train_model, pretrained_ckpt, imagenet_ckpt, checkpoint_dir,
train_steps, summary_freq):
"""Train model."""
vars_to_restore = None
if pretrained_ckpt is not None:
vars_to_restore = util.get_vars_to_save_and_restore(pretrained_ckpt)
ckpt_path = pretrained_ckpt
elif imagenet_ckpt:
vars_to_restore = util.get_imagenet_vars_to_restore(imagenet_ckpt)
ckpt_path = imagenet_ckpt
pretrain_restorer = tf.train.Saver(vars_to_restore)
vars_to_save = util.get_vars_to_save_and_restore()
vars_to_save[train_model.global_step.op.name] = train_model.global_step
saver = tf.train.Saver(vars_to_save, max_to_keep=MAX_TO_KEEP)
sv = tf.train.Supervisor(logdir=checkpoint_dir, save_summaries_secs=0,
saver=None)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with sv.managed_session(config=config) as sess:
if pretrained_ckpt is not None or imagenet_ckpt:
logging.info('Restoring pretrained weights from %s', ckpt_path)
pretrain_restorer.restore(sess, ckpt_path)
logging.info('Attempting to resume training from %s...', checkpoint_dir)
checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
logging.info('Last checkpoint found: %s', checkpoint)
if checkpoint:
saver.restore(sess, checkpoint)
logging.info('Training...')
start_time = time.time()
last_summary_time = time.time()
steps_per_epoch = train_model.reader.steps_per_epoch
step = 1
while step <= train_steps:
fetches = {
'train': train_model.train_op,
'global_step': train_model.global_step,
'incr_global_step': train_model.incr_global_step
}
if step % summary_freq == 0:
fetches['loss'] = train_model.total_loss
fetches['summary'] = sv.summary_op
results = sess.run(fetches)
global_step = results['global_step']
if step % summary_freq == 0:
sv.summary_writer.add_summary(results['summary'], global_step)
train_epoch = math.ceil(global_step / steps_per_epoch)
train_step = global_step - (train_epoch - 1) * steps_per_epoch
this_cycle = time.time() - last_summary_time
last_summary_time += this_cycle
logging.info(
'Epoch: [%2d] [%5d/%5d] time: %4.2fs (%ds total) loss: %.3f',
train_epoch, train_step, steps_per_epoch, this_cycle,
time.time() - start_time, results['loss'])
if step % steps_per_epoch == 0:
logging.info('[*] Saving checkpoint to %s...', checkpoint_dir)
saver.save(sess, os.path.join(checkpoint_dir, 'model'),
global_step=global_step)
# Setting step to global_step allows for training for a total of
# train_steps even if the program is restarted during training.
step = global_step + 1
if __name__ == '__main__':
app.run(main)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common utilities and functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import locale
import os
import re
from absl import logging
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import cv2
gfile = tf.gfile
CMAP_DEFAULT = 'plasma'
# Defines the cropping that is applied to the Cityscapes dataset with respect to
# the original raw input resolution.
CITYSCAPES_CROP = [256, 768, 192, 1856]
def crop_cityscapes(im, resize=None):
ymin, ymax, xmin, xmax = CITYSCAPES_CROP
im = im[ymin:ymax, xmin:xmax]
if resize is not None:
im = cv2.resize(im, resize)
return im
def gray2rgb(im, cmap=CMAP_DEFAULT):
cmap = plt.get_cmap(cmap)
result_img = cmap(im.astype(np.float32))
if result_img.shape[2] > 3:
result_img = np.delete(result_img, 3, 2)
return result_img
def load_image(img_file, resize=None, interpolation='linear'):
"""Load image from disk. Output value range: [0,1]."""
im_data = np.fromstring(gfile.Open(img_file).read(), np.uint8)
im = cv2.imdecode(im_data, cv2.IMREAD_COLOR)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if resize and resize != im.shape[:2]:
ip = cv2.INTER_LINEAR if interpolation == 'linear' else cv2.INTER_NEAREST
im = cv2.resize(im, resize, interpolation=ip)
return np.array(im, dtype=np.float32) / 255.0
def save_image(img_file, im, file_extension):
"""Save image from disk. Expected input value range: [0,1]."""
im = (im * 255.0).astype(np.uint8)
with gfile.Open(img_file, 'w') as f:
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
_, im_data = cv2.imencode('.%s' % file_extension, im)
f.write(im_data.tostring())
def normalize_depth_for_display(depth, pc=95, crop_percent=0, normalizer=None,
cmap=CMAP_DEFAULT):
"""Converts a depth map to an RGB image."""
# Convert to disparity.
disp = 1.0 / (depth + 1e-6)
if normalizer is not None:
disp /= normalizer
else:
disp /= (np.percentile(disp, pc) + 1e-6)
disp = np.clip(disp, 0, 1)
disp = gray2rgb(disp, cmap=cmap)
keep_h = int(disp.shape[0] * (1 - crop_percent))
disp = disp[:keep_h]
return disp
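# Hedged usage sketch: turning a depth map into a color visualization and
# writing it out with save_image above. The random input and the output path
# are placeholders for illustration only.
def _example_depth_visualization(out_path='/tmp/depth_vis.png'):
  depth = np.random.uniform(low=1.0, high=50.0, size=(128, 416))
  vis = normalize_depth_for_display(depth)
  save_image(out_path, vis, 'png')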
def get_seq_start_end(target_index, seq_length, sample_every=1):
"""Returns absolute seq start and end indices for a given target frame."""
half_offset = int((seq_length - 1) / 2) * sample_every
end_index = target_index + half_offset
start_index = end_index - (seq_length - 1) * sample_every
return start_index, end_index
def get_seq_middle(seq_length):
"""Returns relative index for the middle frame in sequence."""
half_offset = int((seq_length - 1) / 2)
return seq_length - 1 - half_offset
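# Worked example (assumed values only) for the two index helpers above: a
# 3-frame sequence centered on frame 10.
def _example_seq_indices():
  start, end = get_seq_start_end(target_index=10, seq_length=3)  # (9, 11).
  middle = get_seq_middle(seq_length=3)  # 1, i.e. the center frame.
  return start, end, middle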
def info(obj):
"""Return info on shape and dtype of a numpy array or TensorFlow tensor."""
if obj is None:
return 'None.'
elif isinstance(obj, list):
if obj:
return 'List of %d... %s' % (len(obj), info(obj[0]))
else:
return 'Empty list.'
elif isinstance(obj, tuple):
if obj:
return 'Tuple of %d... %s' % (len(obj), info(obj[0]))
else:
return 'Empty tuple.'
else:
if is_a_numpy_array(obj):
return 'Array with shape: %s, dtype: %s' % (obj.shape, obj.dtype)
else:
return str(obj)
def is_a_numpy_array(obj):
"""Returns true if obj is a numpy array."""
return type(obj).__module__ == np.__name__
def count_parameters(also_print=True):
"""Cound the number of parameters in the model.
Args:
also_print: Boolean. If True also print the numbers.
Returns:
The total number of parameters.
"""
total = 0
if also_print:
logging.info('Model Parameters:')
for (_, v) in get_vars_to_save_and_restore().items():
shape = v.get_shape()
if also_print:
logging.info('%s %s: %s', v.op.name, shape,
format_number(shape.num_elements()))
total += shape.num_elements()
if also_print:
logging.info('Total: %s', format_number(total))
return total
def get_vars_to_save_and_restore(ckpt=None):
"""Returns list of variables that should be saved/restored.
Args:
ckpt: Path to existing checkpoint. If present, returns only the subset of
variables that exist in given checkpoint.
Returns:
List of all variables that need to be saved/restored.
"""
model_vars = tf.trainable_variables()
# Add batchnorm variables.
bn_vars = [v for v in tf.global_variables()
if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name or
'mu' in v.op.name or 'sigma' in v.op.name or
'global_scale_var' in v.op.name]
model_vars.extend(bn_vars)
model_vars = sorted(model_vars, key=lambda x: x.op.name)
mapping = {}
if ckpt is not None:
ckpt_var = tf.contrib.framework.list_variables(ckpt)
ckpt_var_names = [name for (name, unused_shape) in ckpt_var]
ckpt_var_shapes = [shape for (unused_name, shape) in ckpt_var]
not_loaded = list(ckpt_var_names)
for v in model_vars:
if v.op.name not in ckpt_var_names:
# For backward compatibility, try additional matching.
v_additional_name = v.op.name.replace('egomotion_prediction/', '')
if v_additional_name in ckpt_var_names:
# Check if shapes match.
ind = ckpt_var_names.index(v_additional_name)
if ckpt_var_shapes[ind] == v.get_shape():
mapping[v_additional_name] = v
not_loaded.remove(v_additional_name)
continue
else:
logging.warn('Shape mismatch, will not restore %s.', v.op.name)
logging.warn('Did not find var %s in checkpoint: %s', v.op.name,
os.path.basename(ckpt))
else:
# Check if shapes match.
ind = ckpt_var_names.index(v.op.name)
if ckpt_var_shapes[ind] == v.get_shape():
mapping[v.op.name] = v
not_loaded.remove(v.op.name)
else:
logging.warn('Shape mismatch, will not restore %s.', v.op.name)
if not_loaded:
logging.warn('The following variables in the checkpoint were not loaded:')
for varname_not_loaded in not_loaded:
logging.info('%s', varname_not_loaded)
else: # just get model vars.
for v in model_vars:
mapping[v.op.name] = v
return mapping
def get_imagenet_vars_to_restore(imagenet_ckpt):
"""Returns dict of variables to restore from ImageNet-checkpoint."""
vars_to_restore_imagenet = {}
ckpt_var_names = tf.contrib.framework.list_variables(imagenet_ckpt)
ckpt_var_names = [name for (name, unused_shape) in ckpt_var_names]
model_vars = tf.global_variables()
for v in model_vars:
if 'global_step' in v.op.name: continue
mvname_noprefix = v.op.name.replace('depth_prediction/', '')
mvname_noprefix = mvname_noprefix.replace('moving_mean', 'mu')
mvname_noprefix = mvname_noprefix.replace('moving_variance', 'sigma')
if mvname_noprefix in ckpt_var_names:
vars_to_restore_imagenet[mvname_noprefix] = v
else:
logging.info('The following variable will not be restored from '
'pretrained ImageNet-checkpoint: %s', mvname_noprefix)
return vars_to_restore_imagenet
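# Illustrative only: the renaming applied above for a single variable name.
# The example name is an assumption; actual names depend on the network.
def _example_imagenet_name(name='depth_prediction/conv1/moving_mean'):
  name = name.replace('depth_prediction/', '')
  name = name.replace('moving_mean', 'mu').replace('moving_variance', 'sigma')
  return name  # -> 'conv1/mu'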
def format_number(n):
"""Formats number with thousands commas."""
locale.setlocale(locale.LC_ALL, 'en_US')
return locale.format('%d', n, grouping=True)
def atoi(text):
return int(text) if text.isdigit() else text
def natural_keys(text):
return [atoi(c) for c in re.split(r'(\d+)', text)]
def read_text_lines(filepath):
with tf.gfile.Open(filepath, 'r') as f:
lines = f.readlines()
lines = [l.rstrip() for l in lines]
return lines