Commit dff0f0c1 authored by Alexander Gorban

Merge branch 'master' of github.com:tensorflow/models

parents da341f70 36203f09
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Replay buffer.
Implements replay buffer in Python.
"""
import random
import numpy as np
class ReplayBuffer(object):
def __init__(self, max_size):
self.max_size = max_size
self.cur_size = 0
self.buffer = {}
self.init_length = 0
def __len__(self):
return self.cur_size
def seed_buffer(self, episodes):
self.init_length = len(episodes)
self.add(episodes, np.ones(self.init_length))
def add(self, episodes, *args):
"""Add episodes to buffer."""
idx = 0
while self.cur_size < self.max_size and idx < len(episodes):
self.buffer[self.cur_size] = episodes[idx]
self.cur_size += 1
idx += 1
if idx < len(episodes):
remove_idxs = self.remove_n(len(episodes) - idx)
for remove_idx in remove_idxs:
self.buffer[remove_idx] = episodes[idx]
idx += 1
assert len(self.buffer) == self.cur_size
def remove_n(self, n):
"""Get n items for removal."""
# random removal
idxs = random.sample(xrange(self.init_length, self.cur_size), n)
return idxs
def get_batch(self, n):
"""Get batch of episodes to train on."""
# random batch
idxs = random.sample(xrange(self.cur_size), n)
return [self.buffer[idx] for idx in idxs], None
def update_last_batch(self, delta):
pass
class PrioritizedReplayBuffer(ReplayBuffer):
def __init__(self, max_size, alpha=0.2,
eviction_strategy='rand'):
self.max_size = max_size
self.alpha = alpha
self.eviction_strategy = eviction_strategy
assert self.eviction_strategy in ['rand', 'fifo', 'rank']
self.remove_idx = 0
self.cur_size = 0
self.buffer = {}
self.priorities = np.zeros(self.max_size)
self.init_length = 0
def __len__(self):
return self.cur_size
def add(self, episodes, priorities, new_idxs=None):
"""Add episodes to buffer."""
if new_idxs is None:
idx = 0
new_idxs = []
while self.cur_size < self.max_size and idx < len(episodes):
self.buffer[self.cur_size] = episodes[idx]
new_idxs.append(self.cur_size)
self.cur_size += 1
idx += 1
if idx < len(episodes):
remove_idxs = self.remove_n(len(episodes) - idx)
for remove_idx in remove_idxs:
self.buffer[remove_idx] = episodes[idx]
new_idxs.append(remove_idx)
idx += 1
else:
assert len(new_idxs) == len(episodes)
for new_idx, ep in zip(new_idxs, episodes):
self.buffer[new_idx] = ep
self.priorities[new_idxs] = priorities
self.priorities[0:self.init_length] = np.max(
self.priorities[self.init_length:])
assert len(self.buffer) == self.cur_size
return new_idxs
def remove_n(self, n):
"""Get n items for removal."""
assert self.init_length + n <= self.cur_size
if self.eviction_strategy == 'rand':
# random removal
idxs = random.sample(xrange(self.init_length, self.cur_size), n)
elif self.eviction_strategy == 'fifo':
# overwrite elements in cyclical fashion
idxs = [
self.init_length +
(self.remove_idx + i) % (self.max_size - self.init_length)
for i in xrange(n)]
self.remove_idx = idxs[-1] + 1 - self.init_length
elif self.eviction_strategy == 'rank':
# remove lowest-priority indices
idxs = np.argpartition(self.priorities, n)[:n]
return idxs
def sampling_distribution(self):
p = self.priorities[:self.cur_size]
p = np.exp(self.alpha * (p - np.max(p)))
norm = np.sum(p)
if norm > 0:
uniform = 0.0
p = p / norm * (1 - uniform) + 1.0 / self.cur_size * uniform
else:
p = np.ones(self.cur_size) / self.cur_size
return p
def get_batch(self, n):
"""Get batch of episodes to train on."""
p = self.sampling_distribution()
idxs = np.random.choice(self.cur_size, size=n, replace=False, p=p)
self.last_batch = idxs
return [self.buffer[idx] for idx in idxs], p[idxs]
def update_last_batch(self, delta):
"""Update last batch idxs with new priority."""
self.priorities[self.last_batch] = np.abs(delta)
self.priorities[0:self.init_length] = np.max(
self.priorities[self.init_length:])
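# A minimal usage sketch (illustration only, not part of the original
# module; the episode contents below are hypothetical placeholders).
if __name__ == '__main__':
  buf = PrioritizedReplayBuffer(max_size=8, alpha=0.2,
                                eviction_strategy='rand')
  episodes = [['episode_%d' % i] for i in range(8)]
  buf.add(episodes, np.ones(8))             # insert with unit priorities
  batch, probs = buf.get_batch(4)           # priority-weighted sample of 4
  buf.update_last_batch(np.random.rand(4))  # reprioritize the sampled batch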
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer for coordinating single or multi-replica training.
Main point of entry for running models. Specifies most of
the parameters used by different algorithms.
"""
import tensorflow as tf
import numpy as np
import random
import os
import pickle
import controller
import model
import policy
import baseline
import objective
import full_episode_objective
import trust_region
import optimizers
import replay_buffer
import expert_paths
import gym_wrapper
import env_spec
app = tf.app
flags = tf.flags
logging = tf.logging
gfile = tf.gfile
FLAGS = flags.FLAGS
flags.DEFINE_string('env', 'Copy-v0', 'environment name')
flags.DEFINE_integer('batch_size', 100, 'batch size')
flags.DEFINE_integer('replay_batch_size', None, 'replay batch size; defaults to batch_size')
flags.DEFINE_integer('num_samples', 1,
'number of samples from each random seed initialization')
flags.DEFINE_integer('max_step', 200, 'max number of steps to train on')
flags.DEFINE_integer('cutoff_agent', 0,
                     'number of steps at which to cut off the agent; '
                     'the default of 0 always cuts off')
flags.DEFINE_integer('num_steps', 100000, 'number of training steps')
flags.DEFINE_integer('validation_frequency', 100,
'every so many steps, output some stats')
flags.DEFINE_float('target_network_lag', 0.95,
'This exponential decay on online network yields target '
'network')
flags.DEFINE_string('sample_from', 'online',
'Sample actions from "online" network or "target" network')
flags.DEFINE_string('objective', 'pcl',
'pcl/upcl/a3c/trpo/reinforce/urex')
flags.DEFINE_bool('trust_region_p', False,
'use trust region for policy optimization')
flags.DEFINE_string('value_opt', None,
'leave as None to optimize it along with policy '
'(using critic_weight). Otherwise set to '
'"best_fit" (least squares regression), "lbfgs", or "grad"')
flags.DEFINE_float('max_divergence', 0.01,
'max divergence (i.e. KL) to allow during '
'trust region optimization')
flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
flags.DEFINE_float('clip_norm', 5.0, 'clip norm')
flags.DEFINE_float('clip_adv', 0.0, 'Clip advantages at this value. '
'Leave as 0 to not clip at all.')
flags.DEFINE_float('critic_weight', 0.1, 'critic weight')
flags.DEFINE_float('tau', 0.1, 'entropy regularizer. '
                   'If using decaying tau, this is the final value.')
flags.DEFINE_float('tau_decay', None,
'decay tau by this much every 100 steps')
flags.DEFINE_float('tau_start', 0.1,
'start tau at this value')
flags.DEFINE_float('eps_lambda', 0.0, 'relative entropy regularizer.')
flags.DEFINE_bool('update_eps_lambda', False,
'Update lambda automatically based on last 100 episodes.')
flags.DEFINE_float('gamma', 1.0, 'discount')
flags.DEFINE_integer('rollout', 10, 'rollout')
flags.DEFINE_bool('fixed_std', True,
'fix the std in Gaussian distributions')
flags.DEFINE_bool('input_prev_actions', True,
'input previous actions to policy network')
flags.DEFINE_bool('recurrent', True,
'use recurrent connections')
flags.DEFINE_bool('input_time_step', False,
                  'input time step into value calculations')
flags.DEFINE_bool('use_online_batch', True, 'train on batches as they are sampled')
flags.DEFINE_bool('batch_by_steps', False,
'ensure each training batch has batch_size * max_step steps')
flags.DEFINE_bool('unify_episodes', False,
'Make sure replay buffer holds entire episodes, '
'even across distinct sampling steps')
flags.DEFINE_integer('replay_buffer_size', 5000, 'replay buffer size')
flags.DEFINE_float('replay_buffer_alpha', 0.5, 'replay buffer alpha param')
flags.DEFINE_integer('replay_buffer_freq', 0,
'replay buffer frequency (only supports -1/0/1)')
flags.DEFINE_string('eviction', 'rand',
'how to evict from replay buffer: rand/rank/fifo')
flags.DEFINE_string('prioritize_by', 'rewards',
'Prioritize replay buffer by "rewards" or "step"')
flags.DEFINE_integer('num_expert_paths', 0,
'number of expert paths to seed replay buffer with')
flags.DEFINE_integer('internal_dim', 256, 'RNN internal dim')
flags.DEFINE_integer('value_hidden_layers', 0,
'number of hidden layers in value estimate')
flags.DEFINE_integer('tf_seed', 42, 'random seed for tensorflow')
flags.DEFINE_string('save_trajectories_dir', None,
'directory to save trajectories to, if desired')
flags.DEFINE_string('load_trajectories_file', None,
'file to load expert trajectories from')
# supervisor flags
flags.DEFINE_bool('supervisor', False, 'use supervisor training')
flags.DEFINE_integer('task_id', 0, 'task id')
flags.DEFINE_integer('ps_tasks', 0, 'number of ps tasks')
flags.DEFINE_integer('num_replicas', 1, 'number of replicas used')
flags.DEFINE_string('master', 'local', 'name of master')
flags.DEFINE_string('save_dir', '', 'directory to save model to')
flags.DEFINE_string('load_path', '', 'path of saved model to load (if none in save_dir)')
class Trainer(object):
"""Coordinates single or multi-replica training."""
def __init__(self):
self.batch_size = FLAGS.batch_size
self.replay_batch_size = FLAGS.replay_batch_size
if self.replay_batch_size is None:
self.replay_batch_size = self.batch_size
self.num_samples = FLAGS.num_samples
self.env_str = FLAGS.env
self.env = gym_wrapper.GymWrapper(self.env_str,
distinct=FLAGS.batch_size // self.num_samples,
count=self.num_samples)
self.env_spec = env_spec.EnvSpec(self.env.get_one())
self.max_step = FLAGS.max_step
self.cutoff_agent = FLAGS.cutoff_agent
self.num_steps = FLAGS.num_steps
self.validation_frequency = FLAGS.validation_frequency
self.target_network_lag = FLAGS.target_network_lag
self.sample_from = FLAGS.sample_from
assert self.sample_from in ['online', 'target']
self.critic_weight = FLAGS.critic_weight
self.objective = FLAGS.objective
self.trust_region_p = FLAGS.trust_region_p
self.value_opt = FLAGS.value_opt
assert not self.trust_region_p or self.objective in ['pcl', 'trpo']
assert self.objective != 'trpo' or self.trust_region_p
assert self.value_opt is None or self.critic_weight == 0.0
self.max_divergence = FLAGS.max_divergence
self.learning_rate = FLAGS.learning_rate
self.clip_norm = FLAGS.clip_norm
self.clip_adv = FLAGS.clip_adv
self.tau = FLAGS.tau
self.tau_decay = FLAGS.tau_decay
self.tau_start = FLAGS.tau_start
self.eps_lambda = FLAGS.eps_lambda
self.update_eps_lambda = FLAGS.update_eps_lambda
self.gamma = FLAGS.gamma
self.rollout = FLAGS.rollout
self.fixed_std = FLAGS.fixed_std
self.input_prev_actions = FLAGS.input_prev_actions
self.recurrent = FLAGS.recurrent
assert not self.trust_region_p or not self.recurrent
self.input_time_step = FLAGS.input_time_step
assert not self.input_time_step or (self.cutoff_agent <= self.max_step)
self.use_online_batch = FLAGS.use_online_batch
self.batch_by_steps = FLAGS.batch_by_steps
self.unify_episodes = FLAGS.unify_episodes
if self.unify_episodes:
assert self.batch_size == 1
self.replay_buffer_size = FLAGS.replay_buffer_size
self.replay_buffer_alpha = FLAGS.replay_buffer_alpha
self.replay_buffer_freq = FLAGS.replay_buffer_freq
assert self.replay_buffer_freq in [-1, 0, 1]
self.eviction = FLAGS.eviction
self.prioritize_by = FLAGS.prioritize_by
assert self.prioritize_by in ['rewards', 'step']
self.num_expert_paths = FLAGS.num_expert_paths
self.internal_dim = FLAGS.internal_dim
self.value_hidden_layers = FLAGS.value_hidden_layers
self.tf_seed = FLAGS.tf_seed
self.save_trajectories_dir = (
FLAGS.save_trajectories_dir or FLAGS.save_dir)
self.save_trajectories_file = (
os.path.join(
self.save_trajectories_dir, self.env_str.replace('-', '_'))
if self.save_trajectories_dir else None)
self.load_trajectories_file = FLAGS.load_trajectories_file
self.hparams = dict((attr, getattr(self, attr))
for attr in dir(self)
if not attr.startswith('__') and
not callable(getattr(self, attr)))
def hparams_string(self):
return '\n'.join('%s: %s' % item for item in sorted(self.hparams.items()))
def get_objective(self):
tau = self.tau
if self.tau_decay is not None:
assert self.tau_start >= self.tau
tau = tf.maximum(
tf.train.exponential_decay(
self.tau_start, self.global_step, 100, self.tau_decay),
self.tau)
if self.objective in ['pcl', 'a3c', 'trpo', 'upcl']:
cls = (objective.PCL if self.objective in ['pcl', 'upcl'] else
objective.TRPO if self.objective == 'trpo' else
objective.ActorCritic)
policy_weight = 1.0
return cls(self.learning_rate,
clip_norm=self.clip_norm,
policy_weight=policy_weight,
critic_weight=self.critic_weight,
tau=tau, gamma=self.gamma, rollout=self.rollout,
eps_lambda=self.eps_lambda, clip_adv=self.clip_adv)
elif self.objective in ['reinforce', 'urex']:
cls = (full_episode_objective.Reinforce
if self.objective == 'reinforce' else
full_episode_objective.UREX)
return cls(self.learning_rate,
clip_norm=self.clip_norm,
num_samples=self.num_samples,
tau=tau, bonus_weight=1.0) # TODO: bonus weight?
else:
assert False, 'Unknown objective %s' % self.objective
def get_policy(self):
if self.recurrent:
cls = policy.Policy
else:
cls = policy.MLPPolicy
return cls(self.env_spec, self.internal_dim,
fixed_std=self.fixed_std,
recurrent=self.recurrent,
input_prev_actions=self.input_prev_actions)
def get_baseline(self):
cls = (baseline.UnifiedBaseline if self.objective == 'upcl' else
baseline.Baseline)
return cls(self.env_spec, self.internal_dim,
input_prev_actions=self.input_prev_actions,
input_time_step=self.input_time_step,
input_policy_state=self.recurrent, # may want to change this
n_hidden_layers=self.value_hidden_layers,
hidden_dim=self.internal_dim,
tau=self.tau)
def get_trust_region_p_opt(self):
if self.trust_region_p:
return trust_region.TrustRegionOptimization(
max_divergence=self.max_divergence)
else:
return None
def get_value_opt(self):
if self.value_opt == 'grad':
return optimizers.GradOptimization(
learning_rate=self.learning_rate, max_iter=5, mix_frac=0.05)
elif self.value_opt == 'lbfgs':
return optimizers.LbfgsOptimization(max_iter=25, mix_frac=0.1)
elif self.value_opt == 'best_fit':
return optimizers.BestFitOptimization(mix_frac=1.0)
else:
return None
def get_model(self):
cls = model.Model
return cls(self.env_spec, self.global_step,
target_network_lag=self.target_network_lag,
sample_from=self.sample_from,
get_policy=self.get_policy,
get_baseline=self.get_baseline,
get_objective=self.get_objective,
get_trust_region_p_opt=self.get_trust_region_p_opt,
get_value_opt=self.get_value_opt)
def get_replay_buffer(self):
if self.replay_buffer_freq <= 0:
return None
else:
assert self.objective in ['pcl', 'upcl'], 'Can\'t use replay buffer with %s' % (
self.objective)
cls = replay_buffer.PrioritizedReplayBuffer
return cls(self.replay_buffer_size,
alpha=self.replay_buffer_alpha,
eviction_strategy=self.eviction)
def get_buffer_seeds(self):
return expert_paths.sample_expert_paths(
self.num_expert_paths, self.env_str, self.env_spec,
load_trajectories_file=self.load_trajectories_file)
def get_controller(self):
"""Get controller."""
cls = controller.Controller
return cls(self.env, self.env_spec, self.internal_dim,
use_online_batch=self.use_online_batch,
batch_by_steps=self.batch_by_steps,
unify_episodes=self.unify_episodes,
replay_batch_size=self.replay_batch_size,
max_step=self.max_step,
cutoff_agent=self.cutoff_agent,
save_trajectories_file=self.save_trajectories_file,
use_trust_region=self.trust_region_p,
use_value_opt=self.value_opt is not None,
update_eps_lambda=self.update_eps_lambda,
prioritize_by=self.prioritize_by,
get_model=self.get_model,
get_replay_buffer=self.get_replay_buffer,
get_buffer_seeds=self.get_buffer_seeds)
def do_before_step(self, step):
pass
def run(self):
"""Run training."""
is_chief = FLAGS.task_id == 0 or not FLAGS.supervisor
sv = None
def init_fn(sess, saver):
ckpt = None
if FLAGS.save_dir and sv is None:
load_dir = FLAGS.save_dir
ckpt = tf.train.get_checkpoint_state(load_dir)
if ckpt and ckpt.model_checkpoint_path:
logging.info('restoring from %s', ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
      elif FLAGS.load_path:
        logging.info('restoring from %s', FLAGS.load_path)
        saver.restore(sess, FLAGS.load_path)
if FLAGS.supervisor:
      with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks,
                                                    merge_devices=True)):
self.global_step = tf.contrib.framework.get_or_create_global_step()
tf.set_random_seed(FLAGS.tf_seed)
self.controller = self.get_controller()
self.model = self.controller.model
self.controller.setup()
saver = tf.train.Saver(max_to_keep=10)
step = self.model.global_step
      sv = tf.train.Supervisor(logdir=FLAGS.save_dir,
                               is_chief=is_chief,
                               saver=saver,
                               save_model_secs=600,
                               summary_op=None,  # we define it ourselves
                               save_summaries_secs=60,
                               global_step=step,
                               init_fn=lambda sess: init_fn(sess, saver))
sess = sv.PrepareSession(FLAGS.master)
else:
tf.set_random_seed(FLAGS.tf_seed)
self.global_step = tf.contrib.framework.get_or_create_global_step()
self.controller = self.get_controller()
self.model = self.controller.model
self.controller.setup()
saver = tf.train.Saver(max_to_keep=10)
sess = tf.Session()
      sess.run(tf.global_variables_initializer())
init_fn(sess, saver)
self.sv = sv
self.sess = sess
logging.info('hparams:\n%s', self.hparams_string())
model_step = sess.run(self.model.global_step)
if model_step >= self.num_steps:
logging.info('training has reached final step')
return
losses = []
rewards = []
all_ep_rewards = []
for step in xrange(1 + self.num_steps):
if sv is not None and sv.ShouldStop():
logging.info('stopping supervisor')
break
self.do_before_step(step)
(loss, summary,
total_rewards, episode_rewards) = self.controller.train(sess)
losses.append(loss)
rewards.append(total_rewards)
all_ep_rewards.extend(episode_rewards)
if random.random() < 1 and is_chief and sv and sv._summary_writer:
sv.summary_computed(sess, summary)
model_step = sess.run(self.model.global_step)
if is_chief and step % self.validation_frequency == 0:
logging.info('at training step %d, model step %d: '
'avg loss %f, avg reward %f, '
'episode rewards: %f',
step, model_step,
np.mean(losses), np.mean(rewards),
np.mean(all_ep_rewards))
losses = []
rewards = []
all_ep_rewards = []
if model_step >= self.num_steps:
logging.info('training has reached final step')
break
if is_chief and sv is not None:
logging.info('saving final model to %s', sv.save_path)
sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
def main(unused_argv):
logging.set_verbosity(logging.INFO)
trainer = Trainer()
trainer.run()
if __name__ == '__main__':
app.run()
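# Example invocation (illustration only; flag values are hypothetical):
#   python trainer.py --env=Copy-v0 --objective=pcl --batch_size=100 \
#     --replay_buffer_freq=1 --save_dir=/tmp/pcl_copy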
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trust region optimization.
A lot of this is adapted from others' code.
See Schulman's Modular RL, wojzaremba's TRPO, etc.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
def var_size(v):
return int(np.prod([int(d) for d in v.shape]))
def gradients(loss, var_list):
grads = tf.gradients(loss, var_list)
return [g if g is not None else tf.zeros(v.shape)
for g, v in zip(grads, var_list)]
def flatgrad(loss, var_list):
grads = gradients(loss, var_list)
return tf.concat([tf.reshape(grad, [-1])
for (v, grad) in zip(var_list, grads)
if grad is not None], 0)
def get_flat(var_list):
return tf.concat([tf.reshape(v, [-1]) for v in var_list], 0)
def set_from_flat(var_list, flat_theta):
assigns = []
shapes = [v.shape for v in var_list]
sizes = [var_size(v) for v in var_list]
start = 0
assigns = []
for (shape, size, v) in zip(shapes, sizes, var_list):
assigns.append(v.assign(
tf.reshape(flat_theta[start:start + size], shape)))
start += size
assert start == sum(sizes)
return tf.group(*assigns)
class TrustRegionOptimization(object):
def __init__(self, max_divergence=0.1, cg_damping=0.1):
self.max_divergence = max_divergence
self.cg_damping = cg_damping
def setup_placeholders(self):
self.flat_tangent = tf.placeholder(tf.float32, [None], 'flat_tangent')
self.flat_theta = tf.placeholder(tf.float32, [None], 'flat_theta')
def setup(self, var_list, raw_loss, self_divergence,
divergence=None):
self.setup_placeholders()
self.raw_loss = raw_loss
self.divergence = divergence
self.loss_flat_gradient = flatgrad(raw_loss, var_list)
self.divergence_gradient = gradients(self_divergence, var_list)
shapes = [var.shape for var in var_list]
sizes = [var_size(var) for var in var_list]
start = 0
tangents = []
for shape, size in zip(shapes, sizes):
param = tf.reshape(self.flat_tangent[start:start + size], shape)
tangents.append(param)
start += size
assert start == sum(sizes)
self.grad_vector_product = sum(
tf.reduce_sum(g * t) for (g, t) in zip(self.divergence_gradient, tangents))
self.fisher_vector_product = flatgrad(self.grad_vector_product, var_list)
self.flat_vars = get_flat(var_list)
self.set_vars = set_from_flat(var_list, self.flat_theta)
def optimize(self, sess, feed_dict):
old_theta = sess.run(self.flat_vars)
loss_flat_grad = sess.run(self.loss_flat_gradient,
feed_dict=feed_dict)
def calc_fisher_vector_product(tangent):
feed_dict[self.flat_tangent] = tangent
fvp = sess.run(self.fisher_vector_product,
feed_dict=feed_dict)
fvp += self.cg_damping * tangent
return fvp
step_dir = conjugate_gradient(calc_fisher_vector_product, -loss_flat_grad)
shs = 0.5 * step_dir.dot(calc_fisher_vector_product(step_dir))
lm = np.sqrt(shs / self.max_divergence)
fullstep = step_dir / lm
neggdotstepdir = -loss_flat_grad.dot(step_dir)
def calc_loss(theta):
sess.run(self.set_vars, feed_dict={self.flat_theta: theta})
if self.divergence is None:
return sess.run(self.raw_loss, feed_dict=feed_dict), True
else:
raw_loss, divergence = sess.run(
[self.raw_loss, self.divergence], feed_dict=feed_dict)
return raw_loss, divergence < self.max_divergence
# find optimal theta
theta = linesearch(calc_loss, old_theta, fullstep, neggdotstepdir / lm)
if self.divergence is not None:
final_divergence = sess.run(self.divergence, feed_dict=feed_dict)
else:
final_divergence = None
# set vars accordingly
if final_divergence is None or final_divergence < self.max_divergence:
sess.run(self.set_vars, feed_dict={self.flat_theta: theta})
else:
sess.run(self.set_vars, feed_dict={self.flat_theta: old_theta})
def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):
p = b.copy()
r = b.copy()
x = np.zeros_like(b)
rdotr = r.dot(r)
for i in xrange(cg_iters):
z = f_Ax(p)
v = rdotr / p.dot(z)
x += v * p
r -= v * z
newrdotr = r.dot(r)
mu = newrdotr / rdotr
p = r + mu * p
rdotr = newrdotr
if rdotr < residual_tol:
break
return x
def linesearch(f, x, fullstep, expected_improve_rate):
accept_ratio = 0.1
max_backtracks = 10
fval, _ = f(x)
for (_n_backtracks, stepfrac) in enumerate(.5 ** np.arange(max_backtracks)):
xnew = x + stepfrac * fullstep
newfval, valid = f(xnew)
if not valid:
continue
actual_improve = fval - newfval
expected_improve = expected_improve_rate * stepfrac
ratio = actual_improve / expected_improve
if ratio > accept_ratio and actual_improve > 0:
return xnew
return x
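# Sanity-check sketch (illustration only): the conjugate-gradient routine
# above solves A x = b for a matrix-free positive-definite operator, so its
# answer should match a direct solve on a small SPD system.
if __name__ == '__main__':
  A = np.array([[4., 1.], [1., 3.]])
  b = np.array([1., 2.])
  x = conjugate_gradient(lambda v: A.dot(v), b)
  print(x, np.linalg.solve(A, b))  # the two solutions should agree closely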
bazel
.idea
bazel-bin
bazel-out
bazel-genfiles
bazel-ptn
bazel-testlogs
WORKSPACE
*.pyc
py_library(
name = "input_generator",
srcs = ["input_generator.py"],
deps = [
],
)
py_library(
name = "losses",
srcs = ["losses.py"],
deps = [
],
)
py_library(
name = "metrics",
srcs = ["metrics.py"],
deps = [
],
)
py_library(
name = "utils",
srcs = ["utils.py"],
deps = [
],
)
# Defines the Rotator model here
py_library(
name = "model_rotator",
srcs = ["model_rotator.py"],
deps = [
":input_generator",
":losses",
":metrics",
":utils",
"//nets:deeprotator_factory",
],
)
# Defines the Im2vox model here
py_library(
name = "model_voxel_generation",
srcs = ["model_voxel_generation.py"],
deps = [
":input_generator",
"//nets:im2vox_factory",
],
)
py_library(
name = "model_ptn",
srcs = ["model_ptn.py"],
deps = [
":losses",
":metrics",
":model_voxel_generation",
":utils",
"//nets:im2vox_factory",
],
)
py_binary(
name = "train_ptn",
srcs = ["train_ptn.py"],
deps = [
":model_ptn",
],
)
py_binary(
name = "eval_ptn",
srcs = ["eval_ptn.py"],
deps = [
":model_ptn",
],
)
py_binary(
name = "pretrain_rotator",
srcs = ["pretrain_rotator.py"],
deps = [
":model_rotator",
],
)
py_binary(
name = "eval_rotator",
srcs = ["eval_rotator.py"],
deps = [
":model_rotator",
],
)
# Perspective Transformer Nets
## Introduction
This is the TensorFlow implementation for the NIPS 2016 work ["Perspective Transformer Nets: Learning Single-View 3D Object Reconstruction without 3D Supervision"](https://papers.nips.cc/paper/6206-perspective-transformer-nets-learning-single-view-3d-object-reconstruction-without-3d-supervision.pdf)
Re-implemented by Xinchen Yan, Arkanath Pathak, Jasmine Hsu, Honglak Lee
Reference: [Original implementation in Torch](https://github.com/xcyan/nips16_PTN)
## How to run this code
This implementation is ready to be run locally or [distributed across multiple machines/tasks](https://www.tensorflow.org/deploy/distributed).
You will need to set the task number flag for each task when running in a distributed fashion.
Please refer to the original paper for parameter explanations and training details.
### Installation
* TensorFlow
* This code requires the latest open-source TensorFlow that you will need to build manually.
The [documentation](https://www.tensorflow.org/install/install_sources) provides the steps required for that.
* Bazel
* Follow the instructions [here](http://bazel.build/docs/install.html).
* Alternatively, download Bazel from
[https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases)
for your system configuration.
* Check the installed Bazel version with `bazel version`.
* matplotlib
* Follow the instructions [here](https://matplotlib.org/users/installing.html).
* You can use a package repository like pip.
* scikit-image
* Follow the instructions [here](http://scikit-image.org/docs/dev/install.html).
* You can use a package repository like pip.
* PIL
* Install from [here](https://pypi.python.org/pypi/Pillow/2.2.1).
### Dataset
This code requires the dataset to be in *tfrecords* format with the following features:
* image
* Flattened list of images (float representations), one per viewpoint.
* mask
* Flattened list of image masks (float representations), one per viewpoint.
* vox
* Flattened list of voxels (float representations) for the object.
* This is needed for the vox loss and for prediction comparison.
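For reference, here is a minimal sketch (not part of this repo) of serializing one object into this layout, assuming float32 arrays flattened into `float_list` features and the chair file-naming pattern used by the input pipeline:

```python
import numpy as np
import tensorflow as tf

def make_example(images, masks, vox):
  # images: [num_views, 64, 64, 3]; masks: [num_views, 64, 64, 1];
  # vox: [32, 32, 32, 1]; all float32, stored flattened.
  def float_feature(arr):
    return tf.train.Feature(
        float_list=tf.train.FloatList(value=arr.ravel().tolist()))
  return tf.train.Example(features=tf.train.Features(feature={
      'image': float_feature(images),
      'mask': float_feature(masks),
      'vox': float_feature(vox),
  }))

example = make_example(np.zeros((24, 64, 64, 3), np.float32),
                       np.zeros((24, 64, 64, 1), np.float32),
                       np.zeros((32, 32, 32, 1), np.float32))
writer = tf.python_io.TFRecordWriter('03001627_train.tfrecords')
writer.write(example.SerializeToString())
writer.close()
```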
You can download the ShapeNet Dataset in tfrecords format from [here](https://drive.google.com/file/d/0B12XukcbU7T7OHQ4MGh6d25qQlk)<sup>*</sup>.
<sup>*</sup> Disclaimer: This data is hosted personally by Arkanath Pathak for non-commercial research purposes. Please cite the [ShapeNet paper](https://arxiv.org/pdf/1512.03012.pdf) in your works when using ShapeNet for non-commercial research purposes.
### Pretraining: pretrain_rotator.py for each RNN step
$ bazel run -c opt :pretrain_rotator -- --step_size={} --init_model={}
Pass init_model as the checkpoint path of the model trained at the previous step.
You'll also need to set the inp_dir flag to where your data resides.
### Training: train_ptn.py with last pretrained model.
$ bazel run -c opt :train_ptn -- --init_model={}
### Example TensorBoard Visualizations
To compare visualizations, set the model_name flag to a different value for each parametric setting.
This code adds summaries for each loss. For instance, these are the losses we encountered in distributed pretraining on the ShapeNet Chair dataset with 10 workers and 16 parameter servers:
![ShapeNet Chair Pretraining](https://drive.google.com/uc?export=view&id=0B12XukcbU7T7bWdlTjhzbGJVaWs "ShapeNet Chair Experiment Pretraining Losses")
After fine-tuning, you can expect images like the following as "grid_vis" under **Image** summaries in TensorBoard:
![ShapeNet Chair experiments with projection weight of 1](https://drive.google.com/uc?export=view&id=0B12XukcbU7T7ZFV6aEVBSDdCMjQ "ShapeNet Chair Dataset Predictions")
Here the third and fifth columns are the predicted masks and voxels respectively, alongside their ground truth values.
A similar image when trained on all ShapeNet categories (voxel visualizations may be skewed):
![ShapeNet All Categories experiments](https://drive.google.com/uc?export=view&id=0B12XukcbU7T7bDZKNFlkTVAzZmM "ShapeNet All Categories Dataset Predictions")
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains evaluation plan for the Im2vox model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from tensorflow import app
import model_ptn
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_string('inp_dir',
'',
'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
'dataset_name', 'shapenet_chair',
'Dataset name that is to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('vox_size', 32, 'Voxel prediction dimension.')
flags.DEFINE_integer('step_size', 24, '')
flags.DEFINE_integer('batch_size', 1, 'Batch size while training.')
flags.DEFINE_float('focal_length', 0.866, '')
flags.DEFINE_float('focal_range', 1.732, '')
flags.DEFINE_string('encoder_name', 'ptn_encoder',
'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_vox_decoder',
'Name of the decoder network being used.')
flags.DEFINE_string('projector_name', 'ptn_projector',
'Name of the projector network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn/eval/',
'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'ptn_proj',
'Name of the model used in naming the TF job. Must be different for each run.')
flags.DEFINE_string('eval_set', 'val', 'Data partition to perform evaluation on.')
# Optimization
flags.DEFINE_float('proj_weight', 10, 'Weighting factor for projection loss.')
flags.DEFINE_float('volume_weight', 0, 'Weighting factor for volume loss.')
flags.DEFINE_float('viewpoint_weight', 1,
'Weighting factor for viewpoint loss.')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, '')
flags.DEFINE_float('clip_gradient_norm', 0, '')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, '')
flags.DEFINE_integer('eval_interval_secs', 60 * 5, '')
# Distribution
flags.DEFINE_string('master', '', '')
FLAGS = flags.FLAGS
def main(argv=()):
del argv # Unused.
eval_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name,
'eval_%s' % FLAGS.eval_set)
if not os.path.exists(eval_dir):
os.makedirs(eval_dir)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
g = tf.Graph()
with g.as_default():
eval_params = FLAGS
eval_params.batch_size = 1
eval_params.step_size = FLAGS.num_views
###########
## model ##
###########
model = model_ptn.model_PTN(eval_params)
##########
## data ##
##########
eval_data = model.get_inputs(
        FLAGS.inp_dir,
FLAGS.dataset_name,
eval_params.eval_set,
eval_params.batch_size,
eval_params.image_size,
eval_params.vox_size,
is_training=False)
inputs = model.preprocess_with_all_views(eval_data)
##############
## model_fn ##
##############
model_fn = model.get_model_fn(is_training=False, run_projection=False)
outputs = model_fn(inputs)
#############
## metrics ##
#############
names_to_values, names_to_updates = model.get_metrics(inputs, outputs)
del names_to_values
################
## evaluation ##
################
num_batches = eval_data['num_samples']
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=eval_dir,
logdir=log_dir,
num_evals=num_batches,
eval_op=names_to_updates.values(),
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
app.run()
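# Example invocation (illustration only; paths are hypothetical):
#   python eval_ptn.py --inp_dir=/path/to/tfrecords \
#     --dataset_name=shapenet_chair --checkpoint_dir=/tmp/ptn/ --eval_set=val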
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains evaluation plan for the Rotator model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from tensorflow import app
import model_rotator as model
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_string('inp_dir',
'',
'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
'dataset_name', 'shapenet_chair',
'Dataset name that is to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('a_dim', 3, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('step_size', 24, '')
flags.DEFINE_integer('batch_size', 2, '')
flags.DEFINE_string('encoder_name', 'ptn_encoder',
'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_im_decoder',
'Name of the decoder network being used.')
flags.DEFINE_string('rotator_name', 'ptn_rotator',
'Name of the rotator network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn_train/',
'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'ptn_proj',
'Name of the model used in naming the TF job. Must be different for each run.')
# Optimization
flags.DEFINE_float('image_weight', 10, '')
flags.DEFINE_float('mask_weight', 1, '')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, '')
flags.DEFINE_float('clip_gradient_norm', 0, '')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, '')
flags.DEFINE_integer('eval_interval_secs', 60 * 5, '')
# Scheduling
flags.DEFINE_string('master', 'local', '')
FLAGS = flags.FLAGS
def main(argv=()):
del argv # Unused.
eval_dir = os.path.join(FLAGS.checkpoint_dir,
FLAGS.model_name, 'train')
log_dir = os.path.join(FLAGS.checkpoint_dir,
FLAGS.model_name, 'eval')
if not os.path.exists(eval_dir):
os.makedirs(eval_dir)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
  if FLAGS.step_size < FLAGS.num_views:
    raise ValueError('Invalid step_size: must not be less than num_views.')
  g = tf.Graph()
with g.as_default():
##########
## data ##
##########
val_data = model.get_inputs(
        FLAGS.inp_dir,
FLAGS.dataset_name,
'val',
FLAGS.batch_size,
FLAGS.image_size,
is_training=False)
inputs = model.preprocess(val_data, FLAGS.step_size)
###########
## model ##
###########
model_fn = model.get_model_fn(FLAGS, is_training=False)
outputs = model_fn(inputs)
#############
## metrics ##
#############
names_to_values, names_to_updates = model.get_metrics(
inputs, outputs, FLAGS)
del names_to_values
################
## evaluation ##
################
num_batches = int(val_data['num_samples'] / FLAGS.batch_size)
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=eval_dir,
logdir=log_dir,
num_evals=num_batches,
eval_op=names_to_updates.values(),
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides dataset dictionaries as used in our network models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.data import dataset
from tensorflow.contrib.slim.python.slim.data import dataset_data_provider
from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
_ITEMS_TO_DESCRIPTIONS = {
'image': 'Images',
'mask': 'Masks',
'vox': 'Voxels'
}
def _get_split(file_pattern, num_samples, num_views, image_size, vox_size):
"""Get dataset.Dataset for the given dataset file pattern and properties."""
# A dictionary from TF-Example keys to tf.FixedLenFeature instance.
keys_to_features = {
'image': tf.FixedLenFeature(
shape=[num_views, image_size, image_size, 3],
dtype=tf.float32, default_value=None),
'mask': tf.FixedLenFeature(
shape=[num_views, image_size, image_size, 1],
dtype=tf.float32, default_value=None),
'vox': tf.FixedLenFeature(
shape=[vox_size, vox_size, vox_size, 1],
dtype=tf.float32, default_value=None),
}
items_to_handler = {
'image': tfexample_decoder.Tensor(
'image', shape=[num_views, image_size, image_size, 3]),
'mask': tfexample_decoder.Tensor(
'mask', shape=[num_views, image_size, image_size, 1]),
'vox': tfexample_decoder.Tensor(
'vox', shape=[vox_size, vox_size, vox_size, 1])
}
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handler)
return dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=num_samples,
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS)
def get(dataset_dir,
dataset_name,
split_name,
shuffle=True,
num_readers=1,
common_queue_capacity=64,
common_queue_min=50):
"""Provides input data for a specified dataset and split."""
dataset_to_kwargs = {
'shapenet_chair': {
'file_pattern': '03001627_%s.tfrecords' % split_name,
'num_views': 24,
'image_size': 64,
'vox_size': 32,
}, 'shapenet_all': {
'file_pattern': '*_%s.tfrecords' % split_name,
'num_views': 24,
'image_size': 64,
'vox_size': 32,
},
}
split_sizes = {
'shapenet_chair': {
'train': 4744,
'val': 678,
'test': 1356,
},
'shapenet_all': {
'train': 30643,
'val': 4378,
'test': 8762,
}
}
kwargs = dataset_to_kwargs[dataset_name]
kwargs['file_pattern'] = os.path.join(dataset_dir, kwargs['file_pattern'])
kwargs['num_samples'] = split_sizes[dataset_name][split_name]
dataset_split = _get_split(**kwargs)
data_provider = dataset_data_provider.DatasetDataProvider(
dataset_split,
num_readers=num_readers,
common_queue_capacity=common_queue_capacity,
common_queue_min=common_queue_min,
shuffle=shuffle)
inputs = {
'num_samples': dataset_split.num_samples,
}
[image, mask, vox] = data_provider.get(['image', 'mask', 'vox'])
inputs['image'] = image
inputs['mask'] = mask
inputs['voxel'] = vox
return inputs
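# Usage sketch (illustration only; the directory path is hypothetical):
#   inputs = get('/path/to/tfrecords', 'shapenet_chair', 'train')
#   inputs['image']        # float32 Tensor, shape [24, 64, 64, 3]
#   inputs['mask']         # float32 Tensor, shape [24, 64, 64, 1]
#   inputs['voxel']        # float32 Tensor, shape [32, 32, 32, 1]
#   inputs['num_samples']  # 4744 for the shapenet_chair train split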
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the various loss functions in use by the PTN model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def add_rotator_image_loss(inputs, outputs, step_size, weight_scale):
"""Computes the image loss of deep rotator model.
Args:
inputs: Input dictionary to the model containing keys
such as `images_k'.
outputs: Output dictionary returned by the model containing keys
such as `images_k'.
step_size: A scalar representing the number of recurrent
steps (number of repeated out-of-plane rotations)
in the deep rotator network (int).
weight_scale: A reweighting factor applied over the image loss (float).
Returns:
A `Tensor' scalar that returns averaged L2 loss
(divided by batch_size and step_size) between the
ground-truth images (RGB) and predicted images (tf.float32).
"""
batch_size = tf.shape(inputs['images_0'])[0]
image_loss = 0
for k in range(1, step_size + 1):
image_loss += tf.nn.l2_loss(
inputs['images_%d' % k] - outputs['images_%d' % k])
image_loss /= tf.to_float(step_size * batch_size)
slim.summaries.add_scalar_summary(
image_loss, 'image_loss', prefix='losses')
image_loss *= weight_scale
return image_loss
def add_rotator_mask_loss(inputs, outputs, step_size, weight_scale):
"""Computes the mask loss of deep rotator model.
Args:
inputs: Input dictionary to the model containing keys
such as `masks_k'.
outputs: Output dictionary returned by the model containing
keys such as `masks_k'.
step_size: A scalar representing the number of recurrent
steps (number of repeated out-of-plane rotations)
in the deep rotator network (int).
weight_scale: A reweighting factor applied over the mask loss (float).
Returns:
A `Tensor' that returns averaged L2 loss
(divided by batch_size and step_size) between the ground-truth masks
(object silhouettes) and predicted masks (tf.float32).
"""
batch_size = tf.shape(inputs['images_0'])[0]
mask_loss = 0
for k in range(1, step_size + 1):
mask_loss += tf.nn.l2_loss(
inputs['masks_%d' % k] - outputs['masks_%d' % k])
mask_loss /= tf.to_float(step_size * batch_size)
slim.summaries.add_scalar_summary(
mask_loss, 'mask_loss', prefix='losses')
mask_loss *= weight_scale
return mask_loss
def add_volume_proj_loss(inputs, outputs, num_views, weight_scale):
"""Computes the projection loss of voxel generation model.
Args:
inputs: Input dictionary to the model containing keys such as
`images_1'.
    outputs: Output dictionary returned by the model containing keys
      such as `masks_k' and `projs_k'.
    num_views: An integer scalar representing the total number of
      viewpoints for each object (int).
weight_scale: A reweighting factor applied over the projection loss (float).
Returns:
A `Tensor' that returns the averaged L2 loss
(divided by batch_size and num_views) between the ground-truth
masks (object silhouettes) and predicted masks (tf.float32).
"""
batch_size = tf.shape(inputs['images_1'])[0]
proj_loss = 0
for k in range(num_views):
proj_loss += tf.nn.l2_loss(
outputs['masks_%d' % (k + 1)] - outputs['projs_%d' % (k + 1)])
proj_loss /= tf.to_float(num_views * batch_size)
slim.summaries.add_scalar_summary(
proj_loss, 'proj_loss', prefix='losses')
proj_loss *= weight_scale
return proj_loss
def add_volume_loss(inputs, outputs, num_views, weight_scale):
"""Computes the volume loss of voxel generation model.
Args:
inputs: Input dictionary to the model containing keys such as
`images_1' and `voxels'.
outputs: Output dictionary returned by the model containing keys
such as `voxels_k'.
num_views: A scalar representing the total number of
viewpoints for each object (int).
weight_scale: A reweighting factor applied over the volume
loss (tf.float32).
Returns:
A `Tensor' that returns the averaged L2 loss
(divided by batch_size and num_views) between the ground-truth
volumes and predicted volumes (tf.float32).
"""
batch_size = tf.shape(inputs['images_1'])[0]
vol_loss = 0
for k in range(num_views):
vol_loss += tf.nn.l2_loss(
inputs['voxels'] - outputs['voxels_%d' % (k + 1)])
vol_loss /= tf.to_float(num_views * batch_size)
slim.summaries.add_scalar_summary(
vol_loss, 'vol_loss', prefix='losses')
vol_loss *= weight_scale
return vol_loss
def regularization_loss(scopes, params):
"""Computes the weight decay as regularization during training.
Args:
scopes: A list of different components of the model such as
``encoder'', ``decoder'' and ``projector''.
params: Parameters of the model.
Returns:
Regularization loss (tf.float32).
"""
reg_loss = tf.zeros(dtype=tf.float32, shape=[])
if params.weight_decay > 0:
is_trainable = lambda x: x in tf.trainable_variables()
is_weights = lambda x: 'weights' in x.name
for scope in scopes:
scope_vars = filter(is_trainable,
tf.contrib.framework.get_model_variables(scope))
scope_vars = filter(is_weights, scope_vars)
if scope_vars:
reg_loss += tf.add_n([tf.nn.l2_loss(var) for var in scope_vars])
slim.summaries.add_scalar_summary(
reg_loss, 'reg_loss', prefix='losses')
reg_loss *= params.weight_decay
return reg_loss
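# Note on scaling (illustration only): since tf.nn.l2_loss(t) returns
# sum(t ** 2) / 2, the rotator image loss above equals
#   sum_{k=1..K} ||I_k - I_hat_k||^2 / (2 * K * B)
# for step_size K and batch_size B, before the weight_scale factor;
# the mask, projection, and volume losses follow the same pattern.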
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides metrics used by PTN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def add_image_pred_metrics(
inputs, outputs, num_views, upscale_factor):
"""Computes the image prediction metrics.
Args:
inputs: Input dictionary of the deep rotator model (model_rotator.py).
outputs: Output dictionary of the deep rotator model (model_rotator.py).
num_views: An integer scalar representing the total number
of different viewpoints for each object in the dataset.
upscale_factor: A float scalar representing the number of pixels
per image (num_channels x image_height x image_width).
Returns:
names_to_values: A dictionary representing the current value
of the metric.
names_to_updates: A dictionary representing the operation
that accumulates the error from a batch of data.
"""
names_to_values = dict()
names_to_updates = dict()
for k in xrange(num_views):
tmp_value, tmp_update = tf.contrib.metrics.streaming_mean_squared_error(
outputs['images_%d' % (k + 1)], inputs['images_%d' % (k + 1)])
name = 'image_pred/rnn_%d' % (k + 1)
names_to_values.update({name: tmp_value * upscale_factor})
names_to_updates.update({name: tmp_update})
return names_to_values, names_to_updates
def add_mask_pred_metrics(
inputs, outputs, num_views, upscale_factor):
"""Computes the mask prediction metrics.
Args:
inputs: Input dictionary of the deep rotator model (model_rotator.py).
outputs: Output dictionary of the deep rotator model (model_rotator.py).
num_views: An integer scalar representing the total number
of different viewpoints for each object in the dataset.
upscale_factor: A float scalar representing the number of pixels
per image (num_channels x image_height x image_width).
Returns:
names_to_values: A dictionary representing the current value
of the metric.
names_to_updates: A dictionary representing the operation
that accumulates the error from a batch of data.
"""
names_to_values = dict()
names_to_updates = dict()
for k in xrange(num_views):
tmp_value, tmp_update = tf.contrib.metrics.streaming_mean_squared_error(
outputs['masks_%d' % (k + 1)], inputs['masks_%d' % (k + 1)])
name = 'mask_pred/rnn_%d' % (k + 1)
names_to_values.update({name: tmp_value * upscale_factor})
names_to_updates.update({name: tmp_update})
return names_to_values, names_to_updates
def add_volume_iou_metrics(inputs, outputs):
"""Computes the per-instance volume IOU.
Args:
inputs: Input dictionary of the voxel generation model.
outputs: Output dictionary returned by the voxel generation model.
Returns:
names_to_values: metrics->values (dict).
names_to_updates: metrics->ops (dict).
"""
names_to_values = dict()
names_to_updates = dict()
labels = tf.greater_equal(inputs['voxels'], 0.5)
predictions = tf.greater_equal(outputs['voxels_1'], 0.5)
labels = 2 - tf.to_int32(labels)
predictions = 3 - tf.to_int32(predictions) * 2
tmp_values, tmp_updates = tf.metrics.mean_iou(
labels=labels,
predictions=predictions,
num_classes=3)
names_to_values['volume_iou'] = tmp_values * 3.0
names_to_updates['volume_iou'] = tmp_updates
return names_to_values, names_to_updates
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementations for Im2Vox PTN (NIPS16) model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import losses
import metrics
import model_voxel_generation
import utils
from nets import im2vox_factory
slim = tf.contrib.slim
class model_PTN(model_voxel_generation.Im2Vox): # pylint:disable=invalid-name
"""Inherits the generic Im2Vox model class and implements the functions."""
def __init__(self, params):
super(model_PTN, self).__init__(params)
# For testing, this selects all views in input
def preprocess_with_all_views(self, raw_inputs):
(quantity, num_views) = raw_inputs['images'].get_shape().as_list()[:2]
inputs = dict()
inputs['voxels'] = []
inputs['images_1'] = []
for k in xrange(num_views):
inputs['matrix_%d' % (k + 1)] = []
inputs['matrix_1'] = []
for n in xrange(quantity):
for k in xrange(num_views):
inputs['images_1'].append(raw_inputs['images'][n, k, :, :, :])
inputs['voxels'].append(raw_inputs['voxels'][n, :, :, :, :])
tf_matrix = self.get_transform_matrix(k)
inputs['matrix_%d' % (k + 1)].append(tf_matrix)
inputs['images_1'] = tf.stack(inputs['images_1'])
inputs['voxels'] = tf.stack(inputs['voxels'])
for k in xrange(num_views):
inputs['matrix_%d' % (k + 1)] = tf.stack(inputs['matrix_%d' % (k + 1)])
return inputs
def get_model_fn(self, is_training=True, reuse=False, run_projection=True):
return im2vox_factory.get(self._params, is_training, reuse, run_projection)
def get_regularization_loss(self, scopes):
return losses.regularization_loss(scopes, self._params)
def get_loss(self, inputs, outputs):
"""Computes the loss used for PTN paper (projection + volume loss)."""
g_loss = tf.zeros(dtype=tf.float32, shape=[])
if self._params.proj_weight:
g_loss += losses.add_volume_proj_loss(
inputs, outputs, self._params.step_size, self._params.proj_weight)
if self._params.volume_weight:
g_loss += losses.add_volume_loss(inputs, outputs, 1,
self._params.volume_weight)
slim.summaries.add_scalar_summary(g_loss, 'im2vox_loss', prefix='losses')
return g_loss
def get_metrics(self, inputs, outputs):
"""Aggregate the metrics for voxel generation model.
Args:
inputs: Input dictionary of the voxel generation model.
outputs: Output dictionary returned by the voxel generation model.
Returns:
names_to_values: metrics->values (dict).
names_to_updates: metrics->ops (dict).
"""
names_to_values = dict()
names_to_updates = dict()
tmp_values, tmp_updates = metrics.add_volume_iou_metrics(inputs, outputs)
names_to_values.update(tmp_values)
names_to_updates.update(tmp_updates)
for name, value in names_to_values.iteritems():
slim.summaries.add_scalar_summary(
value, name, prefix='eval', print_summary=True)
return names_to_values, names_to_updates
def write_disk_grid(self,
global_step,
log_dir,
input_images,
gt_projs,
pred_projs,
input_voxels=None,
output_voxels=None):
"""Function called by TF to save the prediction periodically."""
summary_freq = self._params.save_every
def write_grid(input_images, gt_projs, pred_projs, global_step,
input_voxels, output_voxels):
"""Native python function to call for writing images to files."""
grid = _build_image_grid(
input_images,
gt_projs,
pred_projs,
input_voxels=input_voxels,
output_voxels=output_voxels)
if global_step % summary_freq == 0:
img_path = os.path.join(log_dir, '%s.jpg' % str(global_step))
utils.save_image(grid, img_path)
return grid
save_op = tf.py_func(write_grid, [
input_images, gt_projs, pred_projs, global_step, input_voxels,
output_voxels
], [tf.uint8], 'write_grid')[0]
slim.summaries.add_image_summary(
tf.expand_dims(save_op, axis=0), name='grid_vis')
return save_op
def get_transform_matrix(self, view_out):
"""Get the 4x4 Perspective Transfromation matrix used for PTN."""
num_views = self._params.num_views
focal_length = self._params.focal_length
focal_range = self._params.focal_range
phi = 30
theta_interval = 360.0 / num_views
theta = theta_interval * view_out
# pylint: disable=invalid-name
camera_matrix = np.zeros((4, 4), dtype=np.float32)
intrinsic_matrix = np.eye(4, dtype=np.float32)
extrinsic_matrix = np.eye(4, dtype=np.float32)
sin_phi = np.sin(float(phi) / 180.0 * np.pi)
cos_phi = np.cos(float(phi) / 180.0 * np.pi)
sin_theta = np.sin(float(-theta) / 180.0 * np.pi)
cos_theta = np.cos(float(-theta) / 180.0 * np.pi)
rotation_azimuth = np.zeros((3, 3), dtype=np.float32)
rotation_azimuth[0, 0] = cos_theta
rotation_azimuth[2, 2] = cos_theta
rotation_azimuth[0, 2] = -sin_theta
rotation_azimuth[2, 0] = sin_theta
rotation_azimuth[1, 1] = 1.0
## rotation axis -- x
rotation_elevation = np.zeros((3, 3), dtype=np.float32)
rotation_elevation[0, 0] = cos_phi
rotation_elevation[0, 1] = sin_phi
rotation_elevation[1, 0] = -sin_phi
rotation_elevation[1, 1] = cos_phi
rotation_elevation[2, 2] = 1.0
rotation_matrix = np.matmul(rotation_azimuth, rotation_elevation)
displacement = np.zeros((3, 1), dtype=np.float32)
displacement[0, 0] = float(focal_length) + float(focal_range) / 2.0
displacement = np.matmul(rotation_matrix, displacement)
extrinsic_matrix[0:3, 0:3] = rotation_matrix
extrinsic_matrix[0:3, 3:4] = -displacement
intrinsic_matrix[2, 2] = 1.0 / float(focal_length)
intrinsic_matrix[1, 1] = 1.0 / float(focal_length)
camera_matrix = np.matmul(extrinsic_matrix, intrinsic_matrix)
return camera_matrix
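# Worked example of get_transform_matrix (hypothetical parameters): with
# num_views=8 and view_out=2, theta = 90 degrees of azimuth at the fixed
# 30-degree elevation. The camera sits at distance
# focal_length + focal_range / 2 along the rotated viewing axis, and the
# returned 4x4 matrix composes that extrinsic pose with the 1/focal_length
# intrinsic scaling via np.matmul(extrinsic_matrix, intrinsic_matrix).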
def _build_image_grid(input_images,
gt_projs,
pred_projs,
input_voxels,
output_voxels,
vis_size=128):
"""Builds a grid image by concatenating the input images."""
quantity = input_images.shape[0]
for row in xrange(int(quantity / 3)):
for col in xrange(3):
index = row * 3 + col
input_img_ = utils.resize_image(input_images[index, :, :, :], vis_size,
vis_size)
gt_proj_ = utils.resize_image(gt_projs[index, :, :, :], vis_size,
vis_size)
pred_proj_ = utils.resize_image(pred_projs[index, :, :, :], vis_size,
vis_size)
gt_voxel_vis = utils.resize_image(
utils.display_voxel(input_voxels[index, :, :, :, 0]), vis_size,
vis_size)
pred_voxel_vis = utils.resize_image(
utils.display_voxel(output_voxels[index, :, :, :, 0]), vis_size,
vis_size)
if col == 0:
tmp_ = np.concatenate(
[input_img_, gt_proj_, pred_proj_, gt_voxel_vis, pred_voxel_vis], 1)
else:
tmp_ = np.concatenate([
tmp_, input_img_, gt_proj_, pred_proj_, gt_voxel_vis, pred_voxel_vis
], 1)
if row == 0:
out_grid = tmp_
else:
out_grid = np.concatenate([out_grid, tmp_], 0)
return out_grid
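# Layout sketch: for a batch of B examples the grid has B // 3 rows of three
# examples each, and every example contributes five vis_size x vis_size
# panels left-to-right: input image, ground-truth projection, predicted
# projection, ground-truth voxel rendering, predicted voxel rendering.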
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for pretraining (rotator) as described in PTN paper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import input_generator
import losses
import metrics
import utils
from nets import deeprotator_factory
slim = tf.contrib.slim
def _get_data_from_provider(inputs, batch_size, split_name):
"""Returns dictionary of batch input data processed by tf.train.batch."""
images, masks = tf.train.batch(
[inputs['image'], inputs['mask']],
batch_size=batch_size,
num_threads=8,
capacity=8 * batch_size,
name='batching_queues/%s' % (split_name))
outputs = dict()
outputs['images'] = images
outputs['masks'] = masks
outputs['num_samples'] = inputs['num_samples']
return outputs
def get_inputs(dataset_dir, dataset_name, split_name, batch_size, image_size,
is_training):
"""Loads the given dataset and split."""
del image_size # Unused
with tf.variable_scope('data_loading_%s/%s' % (dataset_name, split_name)):
common_queue_min = 50
common_queue_capacity = 256
num_readers = 4
inputs = input_generator.get(
dataset_dir,
dataset_name,
split_name,
shuffle=is_training,
num_readers=num_readers,
common_queue_min=common_queue_min,
common_queue_capacity=common_queue_capacity)
return _get_data_from_provider(inputs, batch_size, split_name)
def preprocess(raw_inputs, step_size):
"""Selects the subset of viewpoints to train on."""
shp = raw_inputs['images'].get_shape().as_list()
quantity = shp[0]
num_views = shp[1]
image_size = shp[2]
del image_size # Unused
batch_rot = np.zeros((quantity, 3), dtype=np.float32)
inputs = dict()
for n in xrange(step_size + 1):
inputs['images_%d' % n] = []
inputs['masks_%d' % n] = []
for n in xrange(quantity):
view_in = np.random.randint(0, num_views)
rng_rot = np.random.randint(0, 2)
if step_size == 1:
rng_rot = np.random.randint(0, 3)
delta = 0
if rng_rot == 0:
delta = -1
batch_rot[n, 2] = 1
elif rng_rot == 1:
delta = 1
batch_rot[n, 0] = 1
else:
delta = 0
batch_rot[n, 1] = 1
inputs['images_0'].append(raw_inputs['images'][n, view_in, :, :, :])
inputs['masks_0'].append(raw_inputs['masks'][n, view_in, :, :, :])
view_out = view_in
for k in xrange(1, step_size + 1):
view_out += delta
if view_out >= num_views:
view_out = 0
if view_out < 0:
view_out = num_views - 1
inputs['images_%d' % k].append(raw_inputs['images'][n, view_out, :, :, :])
inputs['masks_%d' % k].append(raw_inputs['masks'][n, view_out, :, :, :])
for n in xrange(step_size + 1):
inputs['images_%d' % n] = tf.stack(inputs['images_%d' % n])
inputs['masks_%d' % n] = tf.stack(inputs['masks_%d' % n])
inputs['actions'] = tf.constant(batch_rot, dtype=tf.float32)
return inputs
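# Action encoding sketch: batch_rot[n] is a one-hot vector over the three
# rotation actions, matched to the delta applied to the view index:
#   [1, 0, 0] -> rotate forward  (delta = +1)
#   [0, 1, 0] -> stay            (delta = 0, only sampled when step_size == 1)
#   [0, 0, 1] -> rotate backward (delta = -1)
# The view index wraps around at the num_views boundary.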
def get_init_fn(scopes, params):
"""Initialization assignment operator function used while training."""
if not params.init_model:
return None
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
params.init_model, var_list)
def init_assign_function(sess):
sess.run(init_assign_op, init_feed_dict)
return init_assign_function
def get_model_fn(params, is_training, reuse=False):
return deeprotator_factory.get(params, is_training, reuse)
def get_regularization_loss(scopes, params):
return losses.regularization_loss(scopes, params)
def get_loss(inputs, outputs, params):
"""Computes the rotator loss."""
g_loss = tf.zeros(dtype=tf.float32, shape=[])
if hasattr(params, 'image_weight'):
g_loss += losses.add_rotator_image_loss(inputs, outputs, params.step_size,
params.image_weight)
if hasattr(params, 'mask_weight'):
g_loss += losses.add_rotator_mask_loss(inputs, outputs, params.step_size,
params.mask_weight)
slim.summaries.add_scalar_summary(
g_loss, 'rotator_loss', prefix='losses')
return g_loss
def get_train_op_for_scope(loss, optimizer, scopes, params):
"""Train operation function for the given scope used file training."""
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
update_ops = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
return slim.learning.create_train_op(
loss,
optimizer,
update_ops=update_ops,
variables_to_train=var_list,
clip_gradient_norm=params.clip_gradient_norm)
def get_metrics(inputs, outputs, params):
names_to_values, names_to_updates = metrics.rotator_metrics(
inputs, outputs, params)
return names_to_values, names_to_updates
def write_disk_grid(global_step, summary_freq, log_dir, input_images,
output_images, pred_images, pred_masks):
"""Function called by TF to save the prediction periodically."""
def write_grid(grid, global_step):
"""Native python function to call for writing images to files."""
if global_step % summary_freq == 0:
img_path = os.path.join(log_dir, '%s.jpg' % str(global_step))
utils.save_image(grid, img_path)
return 0
grid = _build_image_grid(input_images, output_images, pred_images, pred_masks)
slim.summaries.add_image_summary(
tf.expand_dims(grid, axis=0), name='grid_vis')
save_op = tf.py_func(write_grid, [grid, global_step], [tf.int64],
'write_grid')[0]
return save_op
def _build_image_grid(input_images, output_images, pred_images, pred_masks):
"""Builds a grid image by concatenating the input images."""
quantity = input_images.get_shape().as_list()[0]
for row in xrange(int(quantity / 4)):
for col in xrange(4):
index = row * 4 + col
input_img_ = input_images[index, :, :, :]
output_img_ = output_images[index, :, :, :]
pred_img_ = pred_images[index, :, :, :]
pred_mask_ = tf.tile(pred_masks[index, :, :, :], [1, 1, 3])
if col == 0:
tmp_ = tf.concat([input_img_, output_img_, pred_img_, pred_mask_],
1) ## to the right
else:
tmp_ = tf.concat([tmp_, input_img_, output_img_, pred_img_, pred_mask_],
1)
if row == 0:
out_grid = tmp_
else:
out_grid = tf.concat([out_grid, tmp_], 0)
return out_grid
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for voxel generation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import os
import numpy as np
import tensorflow as tf
import input_generator
import utils
slim = tf.contrib.slim
class Im2Vox(object):
"""Defines the voxel generation model."""
__metaclass__ = abc.ABCMeta
def __init__(self, params):
self._params = params
@abc.abstractmethod
def get_metrics(self, inputs, outputs):
"""Gets dictionaries from metrics to value `Tensors` & update `Tensors`."""
pass
@abc.abstractmethod
def get_loss(self, inputs, outputs):
pass
@abc.abstractmethod
def get_regularization_loss(self, scopes):
pass
def set_params(self, params):
self._params = params
def get_inputs(self,
dataset_dir,
dataset_name,
split_name,
batch_size,
image_size,
vox_size,
is_training=True):
"""Loads data for a specified dataset and split."""
del image_size, vox_size
with tf.variable_scope('data_loading_%s/%s' % (dataset_name, split_name)):
common_queue_min = 64
common_queue_capacity = 256
num_readers = 4
inputs = input_generator.get(
dataset_dir,
dataset_name,
split_name,
shuffle=is_training,
num_readers=num_readers,
common_queue_min=common_queue_min,
common_queue_capacity=common_queue_capacity)
images, voxels = tf.train.batch(
[inputs['image'], inputs['voxel']],
batch_size=batch_size,
num_threads=8,
capacity=8 * batch_size,
name='batching_queues/%s/%s' % (dataset_name, split_name))
outputs = dict()
outputs['images'] = images
outputs['voxels'] = voxels
outputs['num_samples'] = inputs['num_samples']
return outputs
def preprocess(self, raw_inputs, step_size):
"""Selects the subset of viewpoints to train on."""
(quantity, num_views) = raw_inputs['images'].get_shape().as_list()[:2]
inputs = dict()
inputs['voxels'] = raw_inputs['voxels']
for k in xrange(step_size):
inputs['images_%d' % (k + 1)] = []
inputs['matrix_%d' % (k + 1)] = []
for n in xrange(quantity):
selected_views = np.random.choice(num_views, step_size, replace=False)
for k in xrange(step_size):
view_selected = selected_views[k]
inputs['images_%d' %
(k + 1)].append(raw_inputs['images'][n, view_selected, :, :, :])
tf_matrix = self.get_transform_matrix(view_selected)
inputs['matrix_%d' % (k + 1)].append(tf_matrix)
for k in xrange(step_size):
inputs['images_%d' % (k + 1)] = tf.stack(inputs['images_%d' % (k + 1)])
inputs['matrix_%d' % (k + 1)] = tf.stack(inputs['matrix_%d' % (k + 1)])
return inputs
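  # Preprocessing output sketch: for step_size = K the returned dict holds
  #   inputs['voxels']              ground-truth voxels, one per example
  #   inputs['images_k'], k = 1..K  one randomly chosen view per example
  #   inputs['matrix_k'], k = 1..K  the matching 4x4 camera matrix from
  #                                 get_transform_matrix(view).
  # Views are sampled without replacement, so the K views are distinct.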
def get_init_fn(self, scopes):
"""Initialization assignment operator function used while training."""
if not self._params.init_model:
return None
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
self._params.init_model, var_list)
def init_assign_function(sess):
sess.run(init_assign_op, init_feed_dict)
return init_assign_function
def get_train_op_for_scope(self, loss, optimizer, scopes):
"""Train operation function for the given scope used file training."""
is_trainable = lambda x: x in tf.trainable_variables()
var_list = []
update_ops = []
for scope in scopes:
var_list.extend(
filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
return slim.learning.create_train_op(
loss,
optimizer,
update_ops=update_ops,
variables_to_train=var_list,
clip_gradient_norm=self._params.clip_gradient_norm)
def write_disk_grid(self,
global_step,
log_dir,
input_images,
gt_projs,
pred_projs,
pred_voxels=None):
"""Function called by TF to save the prediction periodically."""
summary_freq = self._params.save_every
def write_grid(input_images, gt_projs, pred_projs, pred_voxels,
global_step):
"""Native python function to call for writing images to files."""
grid = _build_image_grid(input_images, gt_projs, pred_projs, pred_voxels)
if global_step % summary_freq == 0:
img_path = os.path.join(log_dir, '%s.jpg' % str(global_step))
utils.save_image(grid, img_path)
with open(
os.path.join(log_dir, 'pred_voxels_%s' % str(global_step)),
            'wb') as fout:
np.save(fout, pred_voxels)
with open(
os.path.join(log_dir, 'input_images_%s' % str(global_step)),
            'wb') as fout:
np.save(fout, input_images)
return grid
py_func_args = [
input_images, gt_projs, pred_projs, pred_voxels, global_step
]
save_grid_op = tf.py_func(write_grid, py_func_args, [tf.uint8],
                              'write_grid')[0]
slim.summaries.add_image_summary(
tf.expand_dims(save_grid_op, axis=0), name='grid_vis')
return save_grid_op
def _build_image_grid(input_images, gt_projs, pred_projs, pred_voxels):
"""Build the visualization grid with py_func."""
quantity, img_height, img_width = input_images.shape[:3]
for row in xrange(int(quantity / 3)):
for col in xrange(3):
index = row * 3 + col
input_img_ = input_images[index, :, :, :]
gt_proj_ = gt_projs[index, :, :, :]
pred_proj_ = pred_projs[index, :, :, :]
pred_voxel_ = utils.display_voxel(pred_voxels[index, :, :, :, 0])
pred_voxel_ = utils.resize_image(pred_voxel_, img_height, img_width)
if col == 0:
tmp_ = np.concatenate([input_img_, gt_proj_, pred_proj_, pred_voxel_],
1)
else:
tmp_ = np.concatenate(
[tmp_, input_img_, gt_proj_, pred_proj_, pred_voxel_], 1)
if row == 0:
out_grid = tmp_
else:
out_grid = np.concatenate([out_grid, tmp_], 0)
out_grid = out_grid.astype(np.uint8)
return out_grid
package(default_visibility = ["//visibility:public"])
py_library(
name = "deeprotator_factory",
srcs = ["deeprotator_factory.py"],
deps = [
":ptn_encoder",
":ptn_im_decoder",
":ptn_rotator",
],
)
py_library(
name = "im2vox_factory",
srcs = ["im2vox_factory.py"],
deps = [
":perspective_projector",
":ptn_encoder",
":ptn_vox_decoder",
],
)
py_library(
name = "perspective_projector",
srcs = ["perspective_projector.py"],
deps = [
":perspective_transform",
],
)
py_library(
name = "perspective_transform",
srcs = ["perspective_transform.py"],
deps = [
],
)
py_library(
name = "ptn_encoder",
srcs = ["ptn_encoder.py"],
deps = [
],
)
py_library(
name = "ptn_im_decoder",
srcs = ["ptn_im_decoder.py"],
deps = [
],
)
py_library(
name = "ptn_rotator",
srcs = ["ptn_rotator.py"],
deps = [
],
)
py_library(
name = "ptn_vox_decoder",
srcs = ["ptn_vox_decoder.py"],
deps = [
],
)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory module for different encoder/decoder network models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import ptn_encoder
from nets import ptn_im_decoder
from nets import ptn_rotator
_NAME_TO_NETS = {
'ptn_encoder': ptn_encoder,
'ptn_rotator': ptn_rotator,
'ptn_im_decoder': ptn_im_decoder,
}
def _get_network(name):
"""Gets a single network component."""
if name not in _NAME_TO_NETS:
raise ValueError('Network name [%s] not recognized.' % name)
return _NAME_TO_NETS[name].model
def get(params, is_training=False, reuse=False):
"""Factory function to retrieve a network model.
Args:
    params: Different parameters used throughout ptn, typically FLAGS (dict)
    is_training: Set to True while training (boolean)
reuse: Set as True if either using a pre-trained model or
in the training loop while the graph has already been built (boolean)
Returns:
Model function for network (inputs to outputs)
"""
def model(inputs):
"""Model function corresponding to a specific network architecture."""
outputs = {}
# First, build the encoder.
encoder_fn = _get_network(params.encoder_name)
with tf.variable_scope('encoder', reuse=reuse):
# Produces id/pose units
features = encoder_fn(inputs['images_0'], params, is_training)
outputs['ids'] = features['ids']
outputs['poses_0'] = features['poses']
# Second, build the rotator and decoder.
rotator_fn = _get_network(params.rotator_name)
with tf.variable_scope('rotator', reuse=reuse):
outputs['poses_1'] = rotator_fn(outputs['poses_0'], inputs['actions'],
params, is_training)
decoder_fn = _get_network(params.decoder_name)
with tf.variable_scope('decoder', reuse=reuse):
dec_output = decoder_fn(outputs['ids'], outputs['poses_1'], params,
is_training)
outputs['images_1'] = dec_output['images']
outputs['masks_1'] = dec_output['masks']
# Third, build the recurrent connection
for k in range(1, params.step_size):
with tf.variable_scope('rotator', reuse=True):
outputs['poses_%d' % (k + 1)] = rotator_fn(
outputs['poses_%d' % k], inputs['actions'], params, is_training)
with tf.variable_scope('decoder', reuse=True):
dec_output = decoder_fn(outputs['ids'], outputs['poses_%d' % (k + 1)],
params, is_training)
outputs['images_%d' % (k + 1)] = dec_output['images']
outputs['masks_%d' % (k + 1)] = dec_output['masks']
return outputs
return model
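# Usage sketch (hypothetical FLAGS object): build the graph once for
# training, then share its variables for evaluation via reuse=True.
#   train_model = get(params, is_training=True)
#   outputs = train_model(inputs)  # inputs needs 'images_0' and 'actions'
#   eval_model = get(params, is_training=False, reuse=True)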
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory module for getting the complete image to voxel generation network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import perspective_projector
from nets import ptn_encoder
from nets import ptn_vox_decoder
_NAME_TO_NETS = {
'ptn_encoder': ptn_encoder,
'ptn_vox_decoder': ptn_vox_decoder,
'perspective_projector': perspective_projector,
}
def _get_network(name):
"""Gets a single encoder/decoder network model."""
if name not in _NAME_TO_NETS:
raise ValueError('Network name [%s] not recognized.' % name)
return _NAME_TO_NETS[name].model
def get(params, is_training=False, reuse=False, run_projection=True):
"""Factory function to get the training/pretraining im->vox model (NIPS16).
Args:
    params: Different parameters used throughout ptn, typically FLAGS (dict).
    is_training: Set to True while training (boolean).
reuse: Set as True if sharing variables with a model that has already
been built (boolean).
run_projection: Set as False if not interested in mask and projection
images. Useful in evaluation routine (boolean).
Returns:
Model function for network (inputs to outputs).
"""
def model(inputs):
"""Model function corresponding to a specific network architecture."""
outputs = {}
# First, build the encoder
encoder_fn = _get_network(params.encoder_name)
with tf.variable_scope('encoder', reuse=reuse):
# Produces id/pose units
enc_outputs = encoder_fn(inputs['images_1'], params, is_training)
outputs['ids_1'] = enc_outputs['ids']
# Second, build the decoder and projector
decoder_fn = _get_network(params.decoder_name)
with tf.variable_scope('decoder', reuse=reuse):
outputs['voxels_1'] = decoder_fn(outputs['ids_1'], params, is_training)
if run_projection:
projector_fn = _get_network(params.projector_name)
with tf.variable_scope('projector', reuse=reuse):
outputs['projs_1'] = projector_fn(
outputs['voxels_1'], inputs['matrix_1'], params, is_training)
# Infer the ground-truth mask
with tf.variable_scope('oracle', reuse=reuse):
outputs['masks_1'] = projector_fn(inputs['voxels'], inputs['matrix_1'],
params, False)
# Third, build the entire graph (bundled strategy described in PTN paper)
for k in range(1, params.step_size):
with tf.variable_scope('projector', reuse=True):
outputs['projs_%d' % (k + 1)] = projector_fn(
outputs['voxels_1'], inputs['matrix_%d' %
(k + 1)], params, is_training)
with tf.variable_scope('oracle', reuse=True):
outputs['masks_%d' % (k + 1)] = projector_fn(
inputs['voxels'], inputs['matrix_%d' % (k + 1)], params, False)
return outputs
return model
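# Usage sketch (hypothetical params): for step_size = K the model consumes
# 'images_1', 'voxels' and 'matrix_1'..'matrix_K', and its outputs hold
# 'ids_1' and 'voxels_1' plus, when run_projection=True, a projection
# 'projs_k' and an oracle mask 'masks_k' for every k = 1..K.
#   model_fn = get(params, is_training=True)
#   outputs = model_fn(inputs)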
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""3D->2D projector model as used in PTN (NIPS16)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import perspective_transform
def model(voxels, transform_matrix, params, is_training):
"""Model transforming the 3D voxels into 2D projections.
Args:
voxels: A tensor of size [batch, depth, height, width, channel]
representing the input of projection layer (tf.float32).
transform_matrix: A tensor of size [batch, 16] representing
the flattened 4-by-4 matrix for transformation (tf.float32).
params: Model parameters (dict).
    is_training: Set to True while training (boolean).
Returns:
A transformed tensor (tf.float32)
"""
del is_training # Doesn't make a difference for projector
  # Rearrangement (batch, z, y, x, channel) --> (batch, y, z, x, channel).
  # By convention the projection happens along the z-axis, but the voxels
  # are stored with the y and z axes swapped, so we swap them back before
  # applying the transformation.
voxels = tf.transpose(voxels, [0, 2, 1, 3, 4])
z_near = params.focal_length
z_far = params.focal_length + params.focal_range
transformed_voxels = perspective_transform.transformer(
voxels, transform_matrix, [params.vox_size] * 3, z_near, z_far)
views = tf.reduce_max(transformed_voxels, [1])
views = tf.reverse(views, [1])
return views
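# Minimal smoke test below, with hypothetical shapes and camera parameters
# (the real values come from the training FLAGS); not part of the pipeline.
if __name__ == '__main__':
  import collections
  _Params = collections.namedtuple(
      '_Params', ['focal_length', 'focal_range', 'vox_size'])
  params = _Params(focal_length=0.866, focal_range=1.732, vox_size=32)
  voxels = tf.placeholder(tf.float32, [4, 32, 32, 32, 1])
  matrix = tf.placeholder(tf.float32, [4, 16])
  views = model(voxels, matrix, params, is_training=False)
  print(views)  # A [4, 32, 32, 1] tensor: one silhouette per batch element.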
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Perspective Transformer Layer Implementation.
Transform the volume based on 4 x 4 perspective projection matrix.
Reference:
(1) "Perspective Transformer Nets: Perspective Transformer Nets:
Learning Single-View 3D Object Reconstruction without 3D Supervision."
Xinchen Yan, Jimei Yang, Ersin Yumer, Yijie Guo, Honglak Lee. In NIPS 2016
https://papers.nips.cc/paper/6206-perspective-transformer-nets-learning-single-view-3d-object-reconstruction-without-3d-supervision.pdf
(2) Official implementation in Torch: https://github.com/xcyan/ptnbhwd
(3) 2D Transformer implementation in TF:
github.com/tensorflow/models/tree/master/transformer
"""
import tensorflow as tf
def transformer(voxels,
theta,
out_size,
z_near,
z_far,
name='PerspectiveTransformer'):
"""Perspective Transformer Layer.
Args:
voxels: A tensor of size [num_batch, depth, height, width, num_channels].
It is the output of a deconv/upsampling conv network (tf.float32).
theta: A tensor of size [num_batch, 16].
It is the inverse camera transformation matrix (tf.float32).
    out_size: A tuple of (out_depth, out_height, out_width) giving the
      output size of the transformer layer (int).
z_near: A number representing the near clipping plane (float).
z_far: A number representing the far clipping plane (float).
Returns:
A transformed tensor (tf.float32).
"""
def _repeat(x, n_repeats):
with tf.variable_scope('_repeat'):
rep = tf.transpose(
tf.expand_dims(tf.ones(shape=tf.stack([
n_repeats,
])), 1), [1, 0])
rep = tf.to_int32(rep)
x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
return tf.reshape(x, [-1])
def _interpolate(im, x, y, z, out_size):
"""Bilinear interploation layer.
Args:
im: A 5D tensor of size [num_batch, depth, height, width, num_channels].
It is the input volume for the transformation layer (tf.float32).
x: A tensor of size [num_batch, out_depth, out_height, out_width]
representing the inverse coordinate mapping for x (tf.float32).
y: A tensor of size [num_batch, out_depth, out_height, out_width]
representing the inverse coordinate mapping for y (tf.float32).
z: A tensor of size [num_batch, out_depth, out_height, out_width]
representing the inverse coordinate mapping for z (tf.float32).
      out_size: A tuple of (out_depth, out_height, out_width) giving the
        output size of the transformation layer (int).
Returns:
A transformed tensor (tf.float32).
"""
with tf.variable_scope('_interpolate'):
num_batch = im.get_shape().as_list()[0]
depth = im.get_shape().as_list()[1]
height = im.get_shape().as_list()[2]
width = im.get_shape().as_list()[3]
channels = im.get_shape().as_list()[4]
x = tf.to_float(x)
y = tf.to_float(y)
z = tf.to_float(z)
depth_f = tf.to_float(depth)
height_f = tf.to_float(height)
width_f = tf.to_float(width)
      # Number of depth, height and width samples in the output volume.
out_depth = out_size[0]
out_height = out_size[1]
out_width = out_size[2]
zero = tf.zeros([], dtype='int32')
# 0 <= z < depth, 0 <= y < height & 0 <= x < width.
max_z = tf.to_int32(tf.shape(im)[1] - 1)
max_y = tf.to_int32(tf.shape(im)[2] - 1)
max_x = tf.to_int32(tf.shape(im)[3] - 1)
# Converts scale indices from [-1, 1] to [0, width/height/depth].
x = (x + 1.0) * (width_f) / 2.0
y = (y + 1.0) * (height_f) / 2.0
z = (z + 1.0) * (depth_f) / 2.0
x0 = tf.to_int32(tf.floor(x))
x1 = x0 + 1
y0 = tf.to_int32(tf.floor(y))
y1 = y0 + 1
z0 = tf.to_int32(tf.floor(z))
z1 = z0 + 1
x0_clip = tf.clip_by_value(x0, zero, max_x)
x1_clip = tf.clip_by_value(x1, zero, max_x)
y0_clip = tf.clip_by_value(y0, zero, max_y)
y1_clip = tf.clip_by_value(y1, zero, max_y)
z0_clip = tf.clip_by_value(z0, zero, max_z)
z1_clip = tf.clip_by_value(z1, zero, max_z)
dim3 = width
dim2 = width * height
dim1 = width * height * depth
base = _repeat(
tf.range(num_batch) * dim1, out_depth * out_height * out_width)
base_z0_y0 = base + z0_clip * dim2 + y0_clip * dim3
base_z0_y1 = base + z0_clip * dim2 + y1_clip * dim3
base_z1_y0 = base + z1_clip * dim2 + y0_clip * dim3
base_z1_y1 = base + z1_clip * dim2 + y1_clip * dim3
idx_z0_y0_x0 = base_z0_y0 + x0_clip
idx_z0_y0_x1 = base_z0_y0 + x1_clip
idx_z0_y1_x0 = base_z0_y1 + x0_clip
idx_z0_y1_x1 = base_z0_y1 + x1_clip
idx_z1_y0_x0 = base_z1_y0 + x0_clip
idx_z1_y0_x1 = base_z1_y0 + x1_clip
idx_z1_y1_x0 = base_z1_y1 + x0_clip
idx_z1_y1_x1 = base_z1_y1 + x1_clip
# Use indices to lookup pixels in the flat image and restore
# channels dim
im_flat = tf.reshape(im, tf.stack([-1, channels]))
im_flat = tf.to_float(im_flat)
i_z0_y0_x0 = tf.gather(im_flat, idx_z0_y0_x0)
i_z0_y0_x1 = tf.gather(im_flat, idx_z0_y0_x1)
i_z0_y1_x0 = tf.gather(im_flat, idx_z0_y1_x0)
i_z0_y1_x1 = tf.gather(im_flat, idx_z0_y1_x1)
i_z1_y0_x0 = tf.gather(im_flat, idx_z1_y0_x0)
i_z1_y0_x1 = tf.gather(im_flat, idx_z1_y0_x1)
i_z1_y1_x0 = tf.gather(im_flat, idx_z1_y1_x0)
i_z1_y1_x1 = tf.gather(im_flat, idx_z1_y1_x1)
# Finally calculate interpolated values.
x0_f = tf.to_float(x0)
x1_f = tf.to_float(x1)
y0_f = tf.to_float(y0)
y1_f = tf.to_float(y1)
z0_f = tf.to_float(z0)
z1_f = tf.to_float(z1)
# Check the out-of-boundary case.
x0_valid = tf.to_float(
tf.less_equal(x0, max_x) & tf.greater_equal(x0, 0))
x1_valid = tf.to_float(
tf.less_equal(x1, max_x) & tf.greater_equal(x1, 0))
y0_valid = tf.to_float(
tf.less_equal(y0, max_y) & tf.greater_equal(y0, 0))
y1_valid = tf.to_float(
tf.less_equal(y1, max_y) & tf.greater_equal(y1, 0))
z0_valid = tf.to_float(
tf.less_equal(z0, max_z) & tf.greater_equal(z0, 0))
z1_valid = tf.to_float(
tf.less_equal(z1, max_z) & tf.greater_equal(z1, 0))
w_z0_y0_x0 = tf.expand_dims(((x1_f - x) * (y1_f - y) *
(z1_f - z) * x1_valid * y1_valid * z1_valid),
1)
w_z0_y0_x1 = tf.expand_dims(((x - x0_f) * (y1_f - y) *
(z1_f - z) * x0_valid * y1_valid * z1_valid),
1)
w_z0_y1_x0 = tf.expand_dims(((x1_f - x) * (y - y0_f) *
(z1_f - z) * x1_valid * y0_valid * z1_valid),
1)
w_z0_y1_x1 = tf.expand_dims(((x - x0_f) * (y - y0_f) *
(z1_f - z) * x0_valid * y0_valid * z1_valid),
1)
w_z1_y0_x0 = tf.expand_dims(((x1_f - x) * (y1_f - y) *
(z - z0_f) * x1_valid * y1_valid * z0_valid),
1)
w_z1_y0_x1 = tf.expand_dims(((x - x0_f) * (y1_f - y) *
(z - z0_f) * x0_valid * y1_valid * z0_valid),
1)
w_z1_y1_x0 = tf.expand_dims(((x1_f - x) * (y - y0_f) *
(z - z0_f) * x1_valid * y0_valid * z0_valid),
1)
w_z1_y1_x1 = tf.expand_dims(((x - x0_f) * (y - y0_f) *
(z - z0_f) * x0_valid * y0_valid * z0_valid),
1)
output = tf.add_n([
w_z0_y0_x0 * i_z0_y0_x0, w_z0_y0_x1 * i_z0_y0_x1,
w_z0_y1_x0 * i_z0_y1_x0, w_z0_y1_x1 * i_z0_y1_x1,
w_z1_y0_x0 * i_z1_y0_x0, w_z1_y0_x1 * i_z1_y0_x1,
w_z1_y1_x0 * i_z1_y1_x0, w_z1_y1_x1 * i_z1_y1_x1
])
return output
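  # Trilinear weighting sketch: each of the eight corner samples is scaled
  # by the product of the distances to the opposite faces of its cell, e.g.
  # for the (x0, y0, z0) corner:
  #   w_z0_y0_x0 = (x1_f - x) * (y1_f - y) * (z1_f - z)
  # and each weight is further multiplied by validity masks that zero out
  # corners clipped at the volume boundary, so out-of-range rays fade to 0.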
def _meshgrid(depth, height, width, z_near, z_far):
with tf.variable_scope('_meshgrid'):
x_t = tf.reshape(
tf.tile(tf.linspace(-1.0, 1.0, width), [height * depth]),
[depth, height, width])
y_t = tf.reshape(
tf.tile(tf.linspace(-1.0, 1.0, height), [width * depth]),
[depth, width, height])
y_t = tf.transpose(y_t, [0, 2, 1])
sample_grid = tf.tile(
tf.linspace(float(z_near), float(z_far), depth), [width * height])
z_t = tf.reshape(sample_grid, [height, width, depth])
z_t = tf.transpose(z_t, [2, 0, 1])
z_t = 1 / z_t
d_t = 1 / z_t
x_t /= z_t
y_t /= z_t
x_t_flat = tf.reshape(x_t, (1, -1))
y_t_flat = tf.reshape(y_t, (1, -1))
d_t_flat = tf.reshape(d_t, (1, -1))
ones = tf.ones_like(x_t_flat)
grid = tf.concat([d_t_flat, y_t_flat, x_t_flat, ones], 0)
return grid
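  # _meshgrid builds a viewing frustum in camera coordinates: for each
  # normalized image location (x, y) in [-1, 1]^2 and each sampled depth
  # d in [z_near, z_far] the homogeneous sample point is (d, y*d, x*d, 1),
  # so the rays fan out from the camera center. theta, the inverse camera
  # matrix, then maps these samples back into voxel coordinates below.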
def _transform(theta, input_dim, out_size, z_near, z_far):
with tf.variable_scope('_transform'):
num_batch = input_dim.get_shape().as_list()[0]
num_channels = input_dim.get_shape().as_list()[4]
theta = tf.reshape(theta, (-1, 4, 4))
theta = tf.cast(theta, 'float32')
out_depth = out_size[0]
out_height = out_size[1]
out_width = out_size[2]
grid = _meshgrid(out_depth, out_height, out_width, z_near, z_far)
grid = tf.expand_dims(grid, 0)
grid = tf.reshape(grid, [-1])
grid = tf.tile(grid, tf.stack([num_batch]))
grid = tf.reshape(grid, tf.stack([num_batch, 4, -1]))
      # Transform A x (d_t, y_t', x_t', 1)^T -> (z_s, y_s, x_s, 1).
t_g = tf.matmul(theta, grid)
z_s = tf.slice(t_g, [0, 0, 0], [-1, 1, -1])
y_s = tf.slice(t_g, [0, 1, 0], [-1, 1, -1])
x_s = tf.slice(t_g, [0, 2, 0], [-1, 1, -1])
z_s_flat = tf.reshape(z_s, [-1])
y_s_flat = tf.reshape(y_s, [-1])
x_s_flat = tf.reshape(x_s, [-1])
input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, z_s_flat,
out_size)
output = tf.reshape(
input_transformed,
tf.stack([num_batch, out_depth, out_height, out_width, num_channels]))
return output
with tf.variable_scope(name):
output = _transform(theta, voxels, out_size, z_near, z_far)
return output
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training/Pretraining encoder as used in PTN (NIPS16)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def _preprocess(images):
return images * 2 - 1
def model(images, params, is_training):
"""Model encoding the images into view-invariant embedding."""
del is_training # Unused
image_size = images.get_shape().as_list()[1]
f_dim = params.f_dim
fc_dim = params.fc_dim
z_dim = params.z_dim
outputs = dict()
images = _preprocess(images)
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_initializer=tf.truncated_normal_initializer(stddev=0.02, seed=1)):
h0 = slim.conv2d(images, f_dim, [5, 5], stride=2, activation_fn=tf.nn.relu)
h1 = slim.conv2d(h0, f_dim * 2, [5, 5], stride=2, activation_fn=tf.nn.relu)
h2 = slim.conv2d(h1, f_dim * 4, [5, 5], stride=2, activation_fn=tf.nn.relu)
# Reshape layer
s8 = image_size // 8
h2 = tf.reshape(h2, [-1, s8 * s8 * f_dim * 4])
h3 = slim.fully_connected(h2, fc_dim, activation_fn=tf.nn.relu)
h4 = slim.fully_connected(h3, fc_dim, activation_fn=tf.nn.relu)
outputs['ids'] = slim.fully_connected(h4, z_dim, activation_fn=tf.nn.relu)
outputs['poses'] = slim.fully_connected(h4, z_dim, activation_fn=tf.nn.relu)
return outputs
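# Minimal smoke test below, with hypothetical hyper-parameters (the real
# values come from the training FLAGS); not part of the PTN pipeline.
if __name__ == '__main__':
  import collections
  _Params = collections.namedtuple('_Params', ['f_dim', 'fc_dim', 'z_dim'])
  images = tf.placeholder(tf.float32, [8, 64, 64, 3])
  outputs = model(images, _Params(f_dim=64, fc_dim=1024, z_dim=512),
                  is_training=False)
  # For 64x64 inputs the three stride-2 convs leave an 8x8 map, so h2 is
  # flattened to 8 * 8 * 256 features before the fully connected layers.
  print(outputs['ids'], outputs['poses'])  # Two [8, 512] tensors.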