Unverified commit 6b9d5fba authored by Toby Boyd, committed by GitHub

Merge branch 'master' into patch-1

parents 5fd687c5 5fa2a4e6
@@ -42,7 +42,8 @@ class Reinforce(objective.Objective):
   def get(self, rewards, pads, values, final_values,
           log_probs, prev_log_probs, target_log_probs,
-          entropies, logits):
+          entropies, logits,
+          target_values, final_target_values):
     seq_length = tf.shape(rewards)[0]
     not_pad = tf.reshape(1 - pads, [seq_length, -1, self.num_samples])
...
@@ -22,6 +22,7 @@ import gym
 import numpy as np
 import random
+from six.moves import xrange

 import env_spec
@@ -92,14 +93,14 @@ class GymWrapper(object):
   def step(self, actions):
-    def env_step(action):
+    def env_step(env, action):
       action = self.env_spec.convert_action_to_gym(action)
       obs, reward, done, tt = env.step(action)
       obs = self.env_spec.convert_obs_to_list(obs)
       return obs, reward, done, tt

     actions = zip(*actions)
-    outputs = [env_step(action)
+    outputs = [env_step(env, action)
                if not done else (self.env_spec.initial_obs(None), 0, True, None)
                for action, env, done in zip(actions, self.envs, self.dones)]
     for i, (_, _, done, _) in enumerate(outputs):
...
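Review note: passing `env` into `env_step` explicitly is the right fix here. The old form relied on the comprehension variable `env` being visible to the nested function, which works under Python 2 (where comprehension variables leak into the enclosing scope) but raises NameError under Python 3. A minimal, self-contained sketch of the fixed pattern; `FakeEnv`, `step_all`, and the sample values are illustrative, not from this repo:

class FakeEnv(object):
  def __init__(self, name):
    self.name = name

  def step(self, action):
    # Identify which environment actually handled the action.
    return '%s handled %s' % (self.name, action)


def step_all(envs, actions):
  # Fixed pattern from the diff: the environment is passed in explicitly,
  # so each action is routed to its paired environment regardless of how
  # the comprehension scopes its loop variable.
  def env_step(env, action):
    return env.step(action)

  return [env_step(env, action) for action, env in zip(actions, envs)]


print(step_all([FakeEnv('env0'), FakeEnv('env1')], ['a', 'b']))
# ['env0 handled a', 'env1 handled b']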
@@ -57,6 +57,8 @@ class Model(object):
     # summary placeholder
     self.avg_episode_reward = tf.placeholder(
         tf.float32, [], 'avg_episode_reward')
+    self.greedy_episode_reward = tf.placeholder(
+        tf.float32, [], 'greedy_episode_reward')

     # sampling placeholders
     self.internal_state = tf.placeholder(tf.float32,
@@ -118,12 +120,13 @@ class Model(object):
     self.prev_log_probs = tf.placeholder(tf.float32, [None, None],
                                          'prev_log_probs')

-  def setup(self):
+  def setup(self, train=True):
     """Setup Tensorflow Graph."""
     self.setup_placeholders()

     tf.summary.scalar('avg_episode_reward', self.avg_episode_reward)
+    tf.summary.scalar('greedy_episode_reward', self.greedy_episode_reward)

     with tf.variable_scope('model', reuse=None):
       # policy network
@@ -174,6 +177,7 @@ class Model(object):
           target_p.assign(aa * target_p + (1 - aa) * online_p)
           for online_p, target_p in zip(online_vars, target_vars)])

+    if train:
       # evaluate objective
       (self.loss, self.raw_loss, self.regression_target,
        self.gradient_ops, self.summary) = self.objective.get(
@@ -181,8 +185,8 @@ class Model(object):
           self.values[:-1, :],
           self.values[-1, :] * (1 - self.terminated),
           self.log_probs, self.prev_log_probs, self.target_log_probs,
-          self.entropies,
-          self.logits)
+          self.entropies, self.logits, self.target_values[:-1, :],
+          self.target_values[-1, :] * (1 - self.terminated))

       self.regression_target = tf.reshape(self.regression_target, [-1])
@@ -249,32 +253,42 @@ class Model(object):
   def train_step(self, sess,
                  observations, internal_state, actions,
                  rewards, terminated, pads,
-                 avg_episode_reward=0):
+                 avg_episode_reward=0, greedy_episode_reward=0):
     """Train network using standard gradient descent."""
     outputs = [self.raw_loss, self.gradient_ops, self.summary]
     feed_dict = {self.internal_state: internal_state,
                  self.rewards: rewards,
                  self.terminated: terminated,
                  self.pads: pads,
-                 self.avg_episode_reward: avg_episode_reward}
+                 self.avg_episode_reward: avg_episode_reward,
+                 self.greedy_episode_reward: greedy_episode_reward}
+    time_len = None
     for action_place, action in zip(self.actions, actions):
+      if time_len is None:
+        time_len = len(action)
+      assert time_len == len(action)
       feed_dict[action_place] = action
     for obs_place, obs in zip(self.observations, observations):
+      assert time_len == len(obs)
       feed_dict[obs_place] = obs
+    assert len(rewards) == time_len - 1

     return sess.run(outputs, feed_dict=feed_dict)

   def trust_region_step(self, sess,
                         observations, internal_state, actions,
                         rewards, terminated, pads,
-                        avg_episode_reward=0):
+                        avg_episode_reward=0,
+                        greedy_episode_reward=0):
     """Train policy using trust region step."""
     feed_dict = {self.internal_state: internal_state,
                  self.rewards: rewards,
                  self.terminated: terminated,
                  self.pads: pads,
-                 self.avg_episode_reward: avg_episode_reward}
+                 self.avg_episode_reward: avg_episode_reward,
+                 self.greedy_episode_reward: greedy_episode_reward}
     for action_place, action in zip(self.actions, actions):
       feed_dict[action_place] = action
     for obs_place, obs in zip(self.observations, observations):
...
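Review note: the new `time_len` checks in `train_step` encode the expected episode layout: every action sequence and observation sequence in the batch must have the same number of time steps, and `rewards` must be exactly one step shorter, since the final observation has no reward. A standalone sketch of that invariant (the function name and sample data are illustrative, not the repo's API):

def check_batch_shapes(observations, actions, rewards):
  # observations/actions: lists of per-factor sequences, each of length T.
  # rewards: a single sequence of length T - 1.
  time_len = None
  for action in actions:
    if time_len is None:
      time_len = len(action)
    assert time_len == len(action)
  for obs in observations:
    assert time_len == len(obs)
  assert len(rewards) == time_len - 1
  return time_len


# Example: 3 time steps of observations/actions, 2 rewards.
print(check_batch_shapes([[0, 1, 2]], [[1, 0, 1]], [0.5, 1.0]))  # 3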
@@ -46,7 +46,8 @@ class Objective(object):
   def get(self, rewards, pads, values, final_values,
           log_probs, prev_log_probs, target_log_probs,
-          entropies, logits):
+          entropies, logits,
+          target_values, final_target_values):
     """Get objective calculations."""
     raise NotImplementedError()
@@ -101,7 +102,8 @@ class ActorCritic(Objective):
   def __init__(self, learning_rate, clip_norm=5,
                policy_weight=1.0, critic_weight=0.1,
                tau=0.1, gamma=1.0, rollout=10,
-               eps_lambda=0.0, clip_adv=None):
+               eps_lambda=0.0, clip_adv=None,
+               use_target_values=False):
     super(ActorCritic, self).__init__(learning_rate, clip_norm=clip_norm)
     self.policy_weight = policy_weight
     self.critic_weight = critic_weight
@@ -111,14 +113,17 @@ class ActorCritic(Objective):
     self.clip_adv = clip_adv
     self.eps_lambda = tf.get_variable(  # TODO: need a better way
-        'eps_lambda', [], initializer=tf.constant_initializer(eps_lambda))
+        'eps_lambda', [], initializer=tf.constant_initializer(eps_lambda),
+        trainable=False)
     self.new_eps_lambda = tf.placeholder(tf.float32, [])
     self.assign_eps_lambda = self.eps_lambda.assign(
-        0.95 * self.eps_lambda + 0.05 * self.new_eps_lambda)
+        0.99 * self.eps_lambda + 0.01 * self.new_eps_lambda)
+    self.use_target_values = use_target_values

   def get(self, rewards, pads, values, final_values,
           log_probs, prev_log_probs, target_log_probs,
-          entropies, logits):
+          entropies, logits,
+          target_values, final_target_values):
     not_pad = 1 - pads
     batch_size = tf.shape(rewards)[1]
@@ -126,8 +131,15 @@ class ActorCritic(Objective):
     rewards = not_pad * rewards
     value_estimates = not_pad * values
     log_probs = not_pad * sum(log_probs)
+    target_values = not_pad * tf.stop_gradient(target_values)
+    final_target_values = tf.stop_gradient(final_target_values)

     sum_rewards = discounted_future_sum(rewards, self.gamma, self.rollout)
-    last_values = shift_values(value_estimates, self.gamma, self.rollout,
-                               final_values)
+    if self.use_target_values:
+      last_values = shift_values(
+          target_values, self.gamma, self.rollout,
+          final_target_values)
+    else:
+      last_values = shift_values(value_estimates, self.gamma, self.rollout,
+                                 final_values)
@@ -183,7 +195,8 @@ class PCL(ActorCritic):
   def get(self, rewards, pads, values, final_values,
           log_probs, prev_log_probs, target_log_probs,
-          entropies, logits):
+          entropies, logits,
+          target_values, final_target_values):
     not_pad = 1 - pads
     batch_size = tf.shape(rewards)[1]
@@ -192,6 +205,8 @@ class PCL(ActorCritic):
     log_probs = not_pad * sum(log_probs)
     target_log_probs = not_pad * tf.stop_gradient(sum(target_log_probs))
     relative_log_probs = not_pad * (log_probs - target_log_probs)
+    target_values = not_pad * tf.stop_gradient(target_values)
+    final_target_values = tf.stop_gradient(final_target_values)

     # Prepend.
     not_pad = tf.concat([tf.ones([self.rollout - 1, batch_size]),
@@ -210,12 +225,24 @@ class PCL(ActorCritic):
                                 prev_log_probs], 0)
     relative_log_probs = tf.concat([tf.zeros([self.rollout - 1, batch_size]),
                                     relative_log_probs], 0)
+    target_values = tf.concat(
+        [self.gamma ** tf.expand_dims(
+            tf.range(float(self.rollout - 1), 0, -1), 1) *
+         tf.ones([self.rollout - 1, batch_size]) *
+         target_values[0:1, :],
+         target_values], 0)

     sum_rewards = discounted_future_sum(rewards, self.gamma, self.rollout)
     sum_log_probs = discounted_future_sum(log_probs, self.gamma, self.rollout)
     sum_prev_log_probs = discounted_future_sum(prev_log_probs, self.gamma, self.rollout)
     sum_relative_log_probs = discounted_future_sum(
         relative_log_probs, self.gamma, self.rollout)
-    last_values = shift_values(value_estimates, self.gamma, self.rollout,
-                               final_values)
+    if self.use_target_values:
+      last_values = shift_values(
+          target_values, self.gamma, self.rollout,
+          final_target_values)
+    else:
+      last_values = shift_values(value_estimates, self.gamma, self.rollout,
+                                 final_values)
@@ -272,7 +299,8 @@ class TRPO(ActorCritic):
   def get(self, rewards, pads, values, final_values,
           log_probs, prev_log_probs, target_log_probs,
-          entropies, logits):
+          entropies, logits,
+          target_values, final_target_values):
     not_pad = 1 - pads
     batch_size = tf.shape(rewards)[1]
@@ -280,8 +308,16 @@ class TRPO(ActorCritic):
     value_estimates = not_pad * values
     log_probs = not_pad * sum(log_probs)
     prev_log_probs = not_pad * prev_log_probs
+    target_values = not_pad * tf.stop_gradient(target_values)
+    final_target_values = tf.stop_gradient(final_target_values)

     sum_rewards = discounted_future_sum(rewards, self.gamma, self.rollout)
-    last_values = shift_values(value_estimates, self.gamma, self.rollout,
-                               final_values)
+    if self.use_target_values:
+      last_values = shift_values(
+          target_values, self.gamma, self.rollout,
+          final_target_values)
+    else:
+      last_values = shift_values(value_estimates, self.gamma, self.rollout,
+                                 final_values)
...
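Review note: across ActorCritic, PCL, and TRPO, `use_target_values` swaps the bootstrap value at the end of each rollout window from the online critic to the stop-gradient target network. A rough NumPy sketch of what the `shift_values` bootstrap amounts to is below; this is my reading of how it is used in this diff (values shifted forward by `rollout` steps and discounted by `gamma ** rollout`, with `final_values` filling the tail), so treat it as an assumption rather than the repo's exact implementation:

import numpy as np


def shift_values_sketch(values, gamma, rollout, final_values):
  # values: [T, batch] per-step value estimates (online or target network).
  # Returns roughly gamma**rollout * V(s_{t+rollout}); once the window runs
  # past the end of the sequence, the final value is used instead.
  T, _ = values.shape
  out = np.zeros_like(values)
  for t in range(T):
    if t + rollout < T:
      out[t] = (gamma ** rollout) * values[t + rollout]
    else:
      steps_left = T - t
      out[t] = (gamma ** steps_left) * final_values
  return out


values = np.array([[1.0], [2.0], [3.0], [4.0]])
print(shift_values_sketch(values, gamma=0.9, rollout=2,
                          final_values=np.array([0.5])))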
@@ -25,6 +25,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from six.moves import xrange
 import tensorflow as tf
 import numpy as np
 import scipy.optimize
...
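Review note: the `from six.moves import xrange` lines added throughout this commit are the standard way to keep Python 2-style loops working on Python 3, where the `xrange` builtin no longer exists; `six.moves.xrange` aliases `xrange` on Python 2 and `range` on Python 3. A trivial illustration:

from six.moves import xrange

# On Python 2 this is the builtin xrange; on Python 3 it is range.
total = 0
for i in xrange(5):
  total += i
print(total)  # 10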
@@ -20,6 +20,7 @@ Implements replay buffer in Python.
 import random
 import numpy as np
+from six.moves import xrange

 class ReplayBuffer(object):
@@ -150,7 +151,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
   def get_batch(self, n):
     """Get batch of episodes to train on."""
     p = self.sampling_distribution()
-    idxs = np.random.choice(self.cur_size, size=n, replace=False, p=p)
+    idxs = np.random.choice(self.cur_size, size=int(n), replace=False, p=p)
     self.last_batch = idxs
     return [self.buffer[idx] for idx in idxs], p[idxs]
...
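Review note: wrapping the batch size in `int(n)` matters because `np.random.choice` rejects (or at best warns about) a non-integer `size` on modern NumPy, and callers can easily end up passing a float after a division or fraction-of-buffer computation. A small sketch of the behaviour; the uniform sampling distribution and the value of `n` are made up for the example:

import numpy as np

cur_size = 8
p = np.ones(cur_size) / cur_size  # uniform sampling distribution
n = 4.0                           # e.g. the result of a float computation

# np.random.choice(cur_size, size=n, ...) fails for a float size;
# casting to int restores the intended behaviour.
idxs = np.random.choice(cur_size, size=int(n), replace=False, p=p)
print(sorted(idxs))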
@@ -25,6 +25,7 @@ import random
 import os
 import pickle
+from six.moves import xrange

 import controller
 import model
 import policy
@@ -92,6 +93,8 @@ flags.DEFINE_bool('update_eps_lambda', False,
                   'Update lambda automatically based on last 100 episodes.')
 flags.DEFINE_float('gamma', 1.0, 'discount')
 flags.DEFINE_integer('rollout', 10, 'rollout')
+flags.DEFINE_bool('use_target_values', False,
+                  'use target network for value estimates')
 flags.DEFINE_bool('fixed_std', True,
                   'fix the std in Gaussian distributions')
 flags.DEFINE_bool('input_prev_actions', True,
@@ -152,6 +155,10 @@ class Trainer(object):
     self.env = gym_wrapper.GymWrapper(self.env_str,
                                       distinct=FLAGS.batch_size // self.num_samples,
                                       count=self.num_samples)
+    self.eval_env = gym_wrapper.GymWrapper(
+        self.env_str,
+        distinct=FLAGS.batch_size // self.num_samples,
+        count=self.num_samples)
     self.env_spec = env_spec.EnvSpec(self.env.get_one())
     self.max_step = FLAGS.max_step
@@ -169,7 +176,8 @@ class Trainer(object):
     self.value_opt = FLAGS.value_opt
     assert not self.trust_region_p or self.objective in ['pcl', 'trpo']
     assert self.objective != 'trpo' or self.trust_region_p
-    assert self.value_opt is None or self.critic_weight == 0.0
+    assert self.value_opt is None or self.value_opt == 'None' or \
+        self.critic_weight == 0.0
     self.max_divergence = FLAGS.max_divergence
     self.learning_rate = FLAGS.learning_rate
@@ -182,6 +190,7 @@ class Trainer(object):
     self.update_eps_lambda = FLAGS.update_eps_lambda
     self.gamma = FLAGS.gamma
     self.rollout = FLAGS.rollout
+    self.use_target_values = FLAGS.use_target_values
     self.fixed_std = FLAGS.fixed_std
     self.input_prev_actions = FLAGS.input_prev_actions
     self.recurrent = FLAGS.recurrent
@@ -208,8 +217,7 @@ class Trainer(object):
     self.value_hidden_layers = FLAGS.value_hidden_layers
     self.tf_seed = FLAGS.tf_seed

-    self.save_trajectories_dir = (
-        FLAGS.save_trajectories_dir or FLAGS.save_dir)
+    self.save_trajectories_dir = FLAGS.save_trajectories_dir
     self.save_trajectories_file = (
         os.path.join(
             self.save_trajectories_dir, self.env_str.replace('-', '_'))
@@ -244,7 +252,8 @@ class Trainer(object):
           policy_weight=policy_weight,
           critic_weight=self.critic_weight,
           tau=tau, gamma=self.gamma, rollout=self.rollout,
-          eps_lambda=self.eps_lambda, clip_adv=self.clip_adv)
+          eps_lambda=self.eps_lambda, clip_adv=self.clip_adv,
+          use_target_values=self.use_target_values)
     elif self.objective in ['reinforce', 'urex']:
       cls = (full_episode_objective.Reinforce
              if self.objective == 'reinforce' else
@@ -322,10 +331,10 @@ class Trainer(object):
         self.num_expert_paths, self.env_str, self.env_spec,
         load_trajectories_file=self.load_trajectories_file)

-  def get_controller(self):
+  def get_controller(self, env):
     """Get controller."""
     cls = controller.Controller
-    return cls(self.env, self.env_spec, self.internal_dim,
+    return cls(env, self.env_spec, self.internal_dim,
                use_online_batch=self.use_online_batch,
                batch_by_steps=self.batch_by_steps,
                unify_episodes=self.unify_episodes,
@@ -334,7 +343,7 @@ class Trainer(object):
                cutoff_agent=self.cutoff_agent,
                save_trajectories_file=self.save_trajectories_file,
                use_trust_region=self.trust_region_p,
-               use_value_opt=self.value_opt is not None,
+               use_value_opt=self.value_opt not in [None, 'None'],
                update_eps_lambda=self.update_eps_lambda,
                prioritize_by=self.prioritize_by,
                get_model=self.get_model,
@@ -359,16 +368,19 @@ class Trainer(object):
       saver.restore(sess, ckpt.model_checkpoint_path)
     elif FLAGS.load_path:
       logging.info('restoring from %s', FLAGS.load_path)
-      with gfile.AsUser('distbelief-brain-gpu'):
-        saver.restore(sess, FLAGS.load_path)
+      saver.restore(sess, FLAGS.load_path)

     if FLAGS.supervisor:
       with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks, merge_devices=True)):
-        self.global_step = tf.train.get_or_create_global_step()
+        self.global_step = tf.contrib.framework.get_or_create_global_step()
         tf.set_random_seed(FLAGS.tf_seed)
-        self.controller = self.get_controller()
+        self.controller = self.get_controller(self.env)
         self.model = self.controller.model
         self.controller.setup()
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+          self.eval_controller = self.get_controller(self.eval_env)
+          self.eval_controller.setup(train=False)
         saver = tf.train.Saver(max_to_keep=10)
         step = self.model.global_step
         sv = tf.Supervisor(logdir=FLAGS.save_dir,
@@ -382,10 +394,14 @@ class Trainer(object):
       sess = sv.PrepareSession(FLAGS.master)
     else:
       tf.set_random_seed(FLAGS.tf_seed)
-      self.global_step = tf.train.get_or_create_global_step()
-      self.controller = self.get_controller()
+      self.global_step = tf.contrib.framework.get_or_create_global_step()
+      self.controller = self.get_controller(self.env)
       self.model = self.controller.model
       self.controller.setup()
+      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+        self.eval_controller = self.get_controller(self.eval_env)
+        self.eval_controller.setup(train=False)
       saver = tf.train.Saver(max_to_keep=10)
       sess = tf.Session()
       sess.run(tf.initialize_all_variables())
@@ -414,21 +430,25 @@ class Trainer(object):
       (loss, summary,
        total_rewards, episode_rewards) = self.controller.train(sess)
+      _, greedy_episode_rewards = self.eval_controller.eval(sess)
+      self.controller.greedy_episode_rewards = greedy_episode_rewards
       losses.append(loss)
       rewards.append(total_rewards)
       all_ep_rewards.extend(episode_rewards)

-      if random.random() < 1 and is_chief and sv and sv._summary_writer:
+      if (random.random() < 0.1 and summary and episode_rewards and
+          is_chief and sv and sv._summary_writer):
         sv.summary_computed(sess, summary)

       model_step = sess.run(self.model.global_step)
       if is_chief and step % self.validation_frequency == 0:
         logging.info('at training step %d, model step %d: '
                      'avg loss %f, avg reward %f, '
-                     'episode rewards: %f',
+                     'episode rewards: %f, greedy rewards: %f',
                      step, model_step,
                      np.mean(losses), np.mean(rewards),
-                     np.mean(all_ep_rewards))
+                     np.mean(all_ep_rewards),
+                     np.mean(greedy_episode_rewards))

         losses = []
         rewards = []
...
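Review note: the evaluation controller is built inside `tf.variable_scope(tf.get_variable_scope(), reuse=True)` so that it reuses the training model's weights instead of creating a second set of variables, and its `greedy_episode_rewards` are then logged and summarized alongside the training rewards. A minimal TF1-style sketch of that reuse pattern (the single-layer `build_policy` network is illustrative only, not this repo's model):

import tensorflow as tf


def build_policy(obs):
  # tf.get_variable participates in variable-scope reuse, so the second
  # call below returns the same weights instead of allocating new ones.
  w = tf.get_variable('w', [4, 2])
  return tf.matmul(obs, w)


obs = tf.placeholder(tf.float32, [None, 4])
train_logits = build_policy(obs)

with tf.variable_scope(tf.get_variable_scope(), reuse=True):
  eval_logits = build_policy(obs)  # shares 'w' with train_logits

print(len(tf.trainable_variables()))  # 1: only one copy of 'w' exists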
@@ -24,6 +24,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from six.moves import xrange
 import tensorflow as tf
 import numpy as np
...
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from six.moves import xrange
 import tensorflow as tf

 slim = tf.contrib.slim
@@ -108,5 +109,3 @@ def add_volume_iou_metrics(inputs, outputs):
   names_to_values['volume_iou'] = tmp_values * 3.0
   names_to_updates['volume_iou'] = tmp_updates
   return names_to_values, names_to_updates
@@ -21,6 +21,7 @@ from __future__ import print_function
 import os

 import numpy as np
+from six.moves import xrange
 import tensorflow as tf

 import losses
...
@@ -21,6 +21,7 @@ from __future__ import print_function
 import os

 import numpy as np
+from six.moves import xrange
 import tensorflow as tf

 import input_generator
...
@@ -22,6 +22,7 @@ import abc
 import os

 import numpy as np
+from six.moves import xrange
 import tensorflow as tf

 import input_generator
...
@@ -21,6 +21,7 @@ from __future__ import print_function
 import os

 import numpy as np
+from six.moves import xrange
 import tensorflow as tf
 from tensorflow import app
...
@@ -28,6 +28,7 @@ from mpl_toolkits.mplot3d import axes3d as p3  # pylint:disable=unused-import
 import numpy as np
 from PIL import Image
 from skimage import measure
+from six.moves import xrange
 import tensorflow as tf
@@ -116,4 +117,3 @@ def visualize_voxel_scatter(points, vis_size=128):
       vis_size, vis_size, 3)
   p.close('all')
   return data
@@ -30,6 +30,8 @@ python celeba_formatting.py \
 """

+from __future__ import print_function
+
 import os
 import os.path
@@ -70,7 +72,7 @@ def main():
   writer = tf.python_io.TFRecordWriter(file_out)
   for example_idx, img_fn in enumerate(img_fn_list):
     if example_idx % 1000 == 0:
-      print example_idx, "/", num_examples
+      print(example_idx, "/", num_examples)
     image_raw = scipy.ndimage.imread(os.path.join(fn_root, img_fn))
     rows = image_raw.shape[0]
     cols = image_raw.shape[1]
...
@@ -34,6 +34,8 @@ done
 """

+from __future__ import print_function
+
 import os
 import os.path
@@ -73,10 +75,10 @@ def main():
       file_out = "%s_%05d.tfrecords"
       file_out = file_out % (FLAGS.file_out,
                              example_idx // n_examples_per_file)
-      print "Writing on:", file_out
+      print("Writing on:", file_out)
       writer = tf.python_io.TFRecordWriter(file_out)
     if example_idx % 1000 == 0:
-      print example_idx, "/", num_examples
+      print(example_idx, "/", num_examples)
     image_raw = scipy.ndimage.imread(os.path.join(fn_root, img_fn))
     rows = image_raw.shape[0]
     cols = image_raw.shape[1]
...
@@ -29,6 +29,7 @@ python lsun_formatting.py \
   --fn_root [LSUN_FOLDER]
 """

+from __future__ import print_function
 import os
 import os.path
@@ -68,10 +69,10 @@ def main():
       file_out = "%s_%05d.tfrecords"
       file_out = file_out % (FLAGS.file_out,
                              example_idx // n_examples_per_file)
-      print "Writing on:", file_out
+      print("Writing on:", file_out)
       writer = tf.python_io.TFRecordWriter(file_out)
     if example_idx % 1000 == 0:
-      print example_idx, "/", num_examples
+      print(example_idx, "/", num_examples)
     image_raw = numpy.array(Image.open(os.path.join(fn_root, img_fn)))
     rows = image_raw.shape[0]
     cols = image_raw.shape[1]
...
@@ -23,11 +23,14 @@ $ python real_nvp_multiscale_dataset.py \
     --data_path [DATA_PATH]
 """

+from __future__ import print_function
+
 import time
 from datetime import datetime
 import os

 import numpy
+from six.moves import xrange
 import tensorflow as tf
 from tensorflow import gfile
@@ -1435,10 +1438,10 @@ class RealNVP(object):
         n_equal = int(n_equal)
         n_dash = bar_len - n_equal
         progress_bar = "[" + "=" * n_equal + "-" * n_dash + "]\r"
-        print progress_bar,
+        print(progress_bar, end=' ')
         cost = self.bit_per_dim.eval()
         eval_costs.append(cost)
-      print ""
+      print("")
       return float(numpy.mean(eval_costs))
@@ -1467,7 +1470,7 @@ def train_model(hps, logdir):
     ckpt_state = tf.train.get_checkpoint_state(logdir)
     if ckpt_state and ckpt_state.model_checkpoint_path:
-      print "Loading file %s" % ckpt_state.model_checkpoint_path
+      print("Loading file %s" % ckpt_state.model_checkpoint_path)
       saver.restore(sess, ckpt_state.model_checkpoint_path)

     # Start the queue runners.
@@ -1499,8 +1502,8 @@ def train_model(hps, logdir):
         format_str = ('%s: step %d, loss = %.2f '
                       '(%.1f examples/sec; %.3f '
                       'sec/batch)')
-        print format_str % (datetime.now(), global_step_val, loss,
-                            examples_per_sec, duration)
+        print(format_str % (datetime.now(), global_step_val, loss,
+                            examples_per_sec, duration))

       if should_eval_summaries:
         summary_str = outputs[-1]
@@ -1542,24 +1545,24 @@ def evaluate(hps, logdir, traindir, subset="valid", return_val=False):
     while True:
       ckpt_state = tf.train.get_checkpoint_state(traindir)
       if not (ckpt_state and ckpt_state.model_checkpoint_path):
-        print "No model to eval yet at %s" % traindir
+        print("No model to eval yet at %s" % traindir)
         time.sleep(30)
         continue
-      print "Loading file %s" % ckpt_state.model_checkpoint_path
+      print("Loading file %s" % ckpt_state.model_checkpoint_path)
       saver.restore(sess, ckpt_state.model_checkpoint_path)
       current_step = tf.train.global_step(sess, eval_model.step)
       if current_step == previous_global_step:
-        print "Waiting for the checkpoint to be updated."
+        print("Waiting for the checkpoint to be updated.")
        time.sleep(30)
        continue
       previous_global_step = current_step
-      print "Evaluating..."
+      print("Evaluating...")
       bit_per_dim = eval_model.eval_epoch(hps)
-      print ("Epoch: %d, %s -> %.3f bits/dim"
+      print("Epoch: %d, %s -> %.3f bits/dim"
             % (current_step, subset, bit_per_dim))
-      print "Writing summary..."
+      print("Writing summary...")
       summary = tf.Summary()
       summary.value.extend(
           [tf.Summary.Value(
@@ -1597,7 +1600,7 @@ def sample_from_model(hps, logdir, traindir):
       ckpt_state = tf.train.get_checkpoint_state(traindir)
       if not (ckpt_state and ckpt_state.model_checkpoint_path):
         if not initialized:
-          print "No model to eval yet at %s" % traindir
+          print("No model to eval yet at %s" % traindir)
           time.sleep(30)
           continue
         else:
@@ -1607,7 +1610,7 @@ def sample_from_model(hps, logdir, traindir):
       current_step = tf.train.global_step(sess, eval_model.step)
       if current_step == previous_global_step:
-        print "Waiting for the checkpoint to be updated."
+        print("Waiting for the checkpoint to be updated.")
         time.sleep(30)
         continue
       previous_global_step = current_step
...
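Review note: the changes in this file are mechanical Python 3 print conversions; the only subtle case is the trailing-comma form, where `print x,` (suppress the newline) becomes `print(x, end=' ')` once `print_function` is imported. For reference, a tiny side-by-side of the conversions used above:

from __future__ import print_function

# Python 2: print "done"          ->  print("done")
# Python 2: print "step", 3       ->  print("step", 3)
# Python 2: print progress_bar,   ->  print(progress_bar, end=' ')
print("step", 3)
print("[====------]\r", end=' ')
print("")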
@@ -19,6 +19,7 @@ r"""Utility functions for Real NVP.
 # pylint: disable=dangerous-default-value

 import numpy
+from six.moves import xrange
 import tensorflow as tf
 from tensorflow.python.framework import ops
...
@@ -94,6 +94,7 @@ import threading

 import google3
 import numpy as np
+from six.moves import xrange
 import tensorflow as tf

 tf.app.flags.DEFINE_string('train_directory', '/tmp/',
...