Commit 8a891524 authored by ofirnachum

Updates to pcl_rl according to the most recent version of the Trust-PCL paper

parent 1887a5fe
@@ -67,20 +67,27 @@ python trainer.py --logtostderr --batch_size=25 --env=HalfCheetah-v1 \
   --max_divergence=0.05 --value_opt=best_fit --critic_weight=0.0 \
 ```
-Run Mujoco task with Trust-PCL:
+To run a Mujoco task using Trust-PCL (off-policy), use the command below.
+It should work well across all environments, given that you
+search sufficiently among
+(1) max_divergence (0.001, 0.0005, 0.002 are good values),
+(2) rollout (1, 5, 10 are good values),
+(3) tf_seed (need to average over enough random seeds).
 ```
 python trainer.py --logtostderr --batch_size=1 --env=HalfCheetah-v1 \
-  --validation_frequency=50 --rollout=10 --critic_weight=0.0 \
-  --gamma=0.995 --clip_norm=40 --learning_rate=0.002 \
-  --replay_buffer_freq=1 --replay_buffer_size=20000 \
-  --replay_buffer_alpha=0.1 --norecurrent --objective=pcl \
-  --max_step=100 --tau=0.0 --eviction=fifo --max_divergence=0.001 \
-  --internal_dim=64 --cutoff_agent=1000 \
-  --replay_batch_size=25 --nouse_online_batch --batch_by_steps \
-  --sample_from=target --value_opt=grad --value_hidden_layers=2 \
-  --update_eps_lambda --unify_episodes --clip_adv=1.0 \
-  --target_network_lag=0.99 --prioritize_by=step
+  --validation_frequency=250 --rollout=1 --critic_weight=1.0 --gamma=0.995 \
+  --clip_norm=40 --learning_rate=0.0001 --replay_buffer_freq=1 \
+  --replay_buffer_size=5000 --replay_buffer_alpha=0.001 --norecurrent \
+  --objective=pcl --max_step=10 --cutoff_agent=1000 --tau=0.0 --eviction=fifo \
+  --max_divergence=0.001 --internal_dim=256 --replay_batch_size=64 \
+  --nouse_online_batch --batch_by_steps --value_hidden_layers=2 \
+  --update_eps_lambda --nounify_episodes --target_network_lag=0.99 \
+  --sample_from=online --clip_adv=1 --prioritize_by=step --num_steps=1000000 \
+  --noinput_prev_actions --use_target_values --tf_seed=57
 ```
 Run Mujoco task with PCL constraint trust region:
...
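Note: the new README text above recommends sweeping max_divergence, rollout, and tf_seed. As a minimal illustration (not part of the commit), the sketch below launches one trainer.py run per combination via subprocess; the flag names are the ones shown in the command above, while the base command, paths, and scheduling are assumptions.
```
import itertools
import subprocess

# Hypothetical base command; reuse whichever fixed flags you settled on above.
BASE_CMD = ['python', 'trainer.py', '--logtostderr', '--batch_size=1',
            '--env=HalfCheetah-v1', '--objective=pcl', '--norecurrent']

for max_div, rollout, seed in itertools.product(
    [0.0005, 0.001, 0.002],  # (1) max_divergence
    [1, 5, 10],              # (2) rollout
    range(5)):               # (3) tf_seed values to average over later
  cmd = BASE_CMD + ['--max_divergence=%g' % max_div,
                    '--rollout=%d' % rollout,
                    '--tf_seed=%d' % seed]
  subprocess.check_call(cmd)  # runs sequentially; parallelize as needed
```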
@@ -109,13 +109,14 @@ class Controller(object):
     self.episode_running_rewards = np.zeros(len(self.env))
     self.episode_running_lengths = np.zeros(len(self.env))
     self.episode_rewards = []
+    self.greedy_episode_rewards = []
     self.episode_lengths = []
     self.total_rewards = []
     self.best_batch_rewards = None
 
-  def setup(self):
-    self.model.setup()
+  def setup(self, train=True):
+    self.model.setup(train=train)
 
   def initial_internal_state(self):
     return np.zeros(self.model.policy.rnn_state_dim)
@@ -187,7 +188,7 @@ class Controller(object):
     return initial_state, all_obs, all_act, rewards, all_pad
 
-  def sample_episodes(self, sess):
+  def sample_episodes(self, sess, greedy=False):
     """Sample steps from the environment until we have enough for a batch."""
     # check if last batch ended with episode that was not terminated
@@ -200,7 +201,7 @@ class Controller(object):
     while total_steps < self.max_step * len(self.env):
       (initial_state,
        observations, actions, rewards,
-       pads) = self._sample_episodes(sess)
+       pads) = self._sample_episodes(sess, greedy=greedy)
 
       observations = zip(*observations)
       actions = zip(*actions)
@@ -249,19 +250,26 @@ class Controller(object):
              observations, initial_state, actions,
              rewards, terminated, pads):
     """Train model using batch."""
+    avg_episode_reward = np.mean(self.episode_rewards)
+    greedy_episode_reward = (np.mean(self.greedy_episode_rewards)
+                             if self.greedy_episode_rewards else
+                             avg_episode_reward)
+    loss, summary = None, None
     if self.use_trust_region:
       # use trust region to optimize policy
       loss, _, summary = self.model.trust_region_step(
           sess,
           observations, initial_state, actions,
           rewards, terminated, pads,
-          avg_episode_reward=np.mean(self.episode_rewards))
+          avg_episode_reward=avg_episode_reward,
+          greedy_episode_reward=greedy_episode_reward)
     else:  # otherwise use simple gradient descent on policy
       loss, _, summary = self.model.train_step(
           sess,
           observations, initial_state, actions,
           rewards, terminated, pads,
-          avg_episode_reward=np.mean(self.episode_rewards))
+          avg_episode_reward=avg_episode_reward,
+          greedy_episode_reward=greedy_episode_reward)
 
     if self.use_value_opt:  # optionally perform specific value optimization
       self.model.fit_values(
@@ -305,7 +313,8 @@ class Controller(object):
     if self.update_eps_lambda:
       episode_rewards = np.array(self.episode_rewards)
       episode_lengths = np.array(self.episode_lengths)
-      eps_lambda = find_best_eps_lambda(episode_rewards, episode_lengths)
+      eps_lambda = find_best_eps_lambda(
+          episode_rewards[-20:], episode_lengths[-20:])
       sess.run(self.model.objective.assign_eps_lambda,
                feed_dict={self.model.objective.new_eps_lambda: eps_lambda})
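For context on the eps_lambda update above: the commit narrows the estimate to the 20 most recent episodes, and (in the objective changes further down) slows the moving average from 0.95/0.05 to 0.99/0.01. A minimal NumPy sketch of the resulting update, with find_best_eps_lambda treated as the repo's existing helper (its signature assumed from the call above):
```
import numpy as np

def updated_eps_lambda(current_eps_lambda, episode_rewards, episode_lengths,
                       find_best_eps_lambda, window=20, decay=0.99):
  """Mirrors the TF assign op: EMA toward the freshly estimated lambda."""
  rewards = np.asarray(episode_rewards)[-window:]
  lengths = np.asarray(episode_lengths)[-window:]
  new_lambda = find_best_eps_lambda(rewards, lengths)
  return decay * current_eps_lambda + (1.0 - decay) * new_lambda
```
The slower decay makes the Lagrange multiplier track the recent-episode estimate more smoothly, at the cost of reacting more slowly to changes in reward scale.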
@@ -328,10 +337,10 @@ class Controller(object):
     """Use greedy sampling."""
     (initial_state,
      observations, actions, rewards,
-     pads) = self._sample_episodes(sess, greedy=True)
+     pads, terminated) = self.sample_episodes(sess, greedy=True)
     total_rewards = np.sum(np.array(rewards) * (1 - np.array(pads)), axis=0)
-    return np.mean(total_rewards)
+    return total_rewards, self.episode_rewards
 
   def convert_from_batched_episodes(
       self, initial_state, observations, actions, rewards,
@@ -351,7 +360,7 @@ class Controller(object):
     for i in xrange(num_episodes):
       length = total_length[i]
       ep_initial = initial_state[i]
-      ep_obs = [obs[:length, i, ...] for obs in observations]
+      ep_obs = [obs[:length + 1, i, ...] for obs in observations]
       ep_act = [act[:length + 1, i, ...] for act in actions]
       ep_rewards = rewards[:length, i]
...
@@ -42,7 +42,8 @@ class Reinforce(objective.Objective):
   def get(self, rewards, pads, values, final_values,
           log_probs, prev_log_probs, target_log_probs,
-          entropies, logits):
+          entropies, logits,
+          target_values, final_target_values):
     seq_length = tf.shape(rewards)[0]
     not_pad = tf.reshape(1 - pads, [seq_length, -1, self.num_samples])
...
@@ -57,6 +57,8 @@ class Model(object):
     # summary placeholder
     self.avg_episode_reward = tf.placeholder(
         tf.float32, [], 'avg_episode_reward')
+    self.greedy_episode_reward = tf.placeholder(
+        tf.float32, [], 'greedy_episode_reward')
 
     # sampling placeholders
     self.internal_state = tf.placeholder(tf.float32,
@@ -118,12 +120,13 @@ class Model(object):
     self.prev_log_probs = tf.placeholder(tf.float32, [None, None],
                                          'prev_log_probs')
 
-  def setup(self):
+  def setup(self, train=True):
     """Setup Tensorflow Graph."""
     self.setup_placeholders()
 
     tf.summary.scalar('avg_episode_reward', self.avg_episode_reward)
+    tf.summary.scalar('greedy_episode_reward', self.greedy_episode_reward)
 
     with tf.variable_scope('model', reuse=None):
       # policy network
@@ -174,45 +177,46 @@ class Model(object):
         target_p.assign(aa * target_p + (1 - aa) * online_p)
         for online_p, target_p in zip(online_vars, target_vars)])
 
-    # evaluate objective
-    (self.loss, self.raw_loss, self.regression_target,
-     self.gradient_ops, self.summary) = self.objective.get(
-         self.rewards, self.pads,
-         self.values[:-1, :],
-         self.values[-1, :] * (1 - self.terminated),
-         self.log_probs, self.prev_log_probs, self.target_log_probs,
-         self.entropies,
-         self.logits)
-
-    self.regression_target = tf.reshape(self.regression_target, [-1])
-
-    self.policy_vars = [
-        v for v in tf.trainable_variables()
-        if '/policy_net' in v.name]
-    self.value_vars = [
-        v for v in tf.trainable_variables()
-        if '/value_net' in v.name]
-
-    # trust region optimizer
-    if self.trust_region_policy_opt is not None:
-      with tf.variable_scope('trust_region_policy', reuse=None):
-        avg_self_kl = (
-            tf.reduce_sum(sum(self.self_kls) * (1 - self.pads)) /
-            tf.reduce_sum(1 - self.pads))
-
-        self.trust_region_policy_opt.setup(
-            self.policy_vars, self.raw_loss, avg_self_kl,
-            self.avg_kl)
-
-    # value optimizer
-    if self.value_opt is not None:
-      with tf.variable_scope('trust_region_value', reuse=None):
-        self.value_opt.setup(
-            self.value_vars,
-            tf.reshape(self.values[:-1, :], [-1]),
-            self.regression_target,
-            tf.reshape(self.pads, [-1]),
-            self.regression_input, self.regression_weight)
+    if train:
+      # evaluate objective
+      (self.loss, self.raw_loss, self.regression_target,
+       self.gradient_ops, self.summary) = self.objective.get(
+           self.rewards, self.pads,
+           self.values[:-1, :],
+           self.values[-1, :] * (1 - self.terminated),
+           self.log_probs, self.prev_log_probs, self.target_log_probs,
+           self.entropies, self.logits, self.target_values[:-1, :],
+           self.target_values[-1, :] * (1 - self.terminated))
+
+      self.regression_target = tf.reshape(self.regression_target, [-1])
+
+      self.policy_vars = [
+          v for v in tf.trainable_variables()
+          if '/policy_net' in v.name]
+      self.value_vars = [
+          v for v in tf.trainable_variables()
+          if '/value_net' in v.name]
+
+      # trust region optimizer
+      if self.trust_region_policy_opt is not None:
+        with tf.variable_scope('trust_region_policy', reuse=None):
+          avg_self_kl = (
+              tf.reduce_sum(sum(self.self_kls) * (1 - self.pads)) /
+              tf.reduce_sum(1 - self.pads))
+
+          self.trust_region_policy_opt.setup(
+              self.policy_vars, self.raw_loss, avg_self_kl,
+              self.avg_kl)
+
+      # value optimizer
+      if self.value_opt is not None:
+        with tf.variable_scope('trust_region_value', reuse=None):
+          self.value_opt.setup(
+              self.value_vars,
+              tf.reshape(self.values[:-1, :], [-1]),
+              self.regression_target,
+              tf.reshape(self.pads, [-1]),
+              self.regression_input, self.regression_weight)
 
     # we re-use variables for the sampling operations
     with tf.variable_scope('model', reuse=True):
@@ -249,32 +253,42 @@ class Model(object):
   def train_step(self, sess,
                  observations, internal_state, actions,
                  rewards, terminated, pads,
-                 avg_episode_reward=0):
+                 avg_episode_reward=0, greedy_episode_reward=0):
     """Train network using standard gradient descent."""
     outputs = [self.raw_loss, self.gradient_ops, self.summary]
     feed_dict = {self.internal_state: internal_state,
                  self.rewards: rewards,
                  self.terminated: terminated,
                  self.pads: pads,
-                 self.avg_episode_reward: avg_episode_reward}
+                 self.avg_episode_reward: avg_episode_reward,
+                 self.greedy_episode_reward: greedy_episode_reward}
+
+    time_len = None
     for action_place, action in zip(self.actions, actions):
+      if time_len is None:
+        time_len = len(action)
+      assert time_len == len(action)
       feed_dict[action_place] = action
     for obs_place, obs in zip(self.observations, observations):
+      assert time_len == len(obs)
       feed_dict[obs_place] = obs
+    assert len(rewards) == time_len - 1
 
     return sess.run(outputs, feed_dict=feed_dict)
 
   def trust_region_step(self, sess,
                         observations, internal_state, actions,
                         rewards, terminated, pads,
-                        avg_episode_reward=0):
+                        avg_episode_reward=0,
+                        greedy_episode_reward=0):
     """Train policy using trust region step."""
     feed_dict = {self.internal_state: internal_state,
                  self.rewards: rewards,
                  self.terminated: terminated,
                  self.pads: pads,
-                 self.avg_episode_reward: avg_episode_reward}
+                 self.avg_episode_reward: avg_episode_reward,
+                 self.greedy_episode_reward: greedy_episode_reward}
     for action_place, action in zip(self.actions, actions):
       feed_dict[action_place] = action
     for obs_place, obs in zip(self.observations, observations):
...
@@ -46,7 +46,8 @@ class Objective(object):
   def get(self, rewards, pads, values, final_values,
           log_probs, prev_log_probs, target_log_probs,
-          entropies, logits):
+          entropies, logits,
+          target_values, final_target_values):
     """Get objective calculations."""
     raise NotImplementedError()
@@ -101,7 +102,8 @@ class ActorCritic(Objective):
   def __init__(self, learning_rate, clip_norm=5,
                policy_weight=1.0, critic_weight=0.1,
                tau=0.1, gamma=1.0, rollout=10,
-               eps_lambda=0.0, clip_adv=None):
+               eps_lambda=0.0, clip_adv=None,
+               use_target_values=False):
     super(ActorCritic, self).__init__(learning_rate, clip_norm=clip_norm)
     self.policy_weight = policy_weight
     self.critic_weight = critic_weight
@@ -111,14 +113,17 @@ class ActorCritic(Objective):
     self.clip_adv = clip_adv
     self.eps_lambda = tf.get_variable(  # TODO: need a better way
-        'eps_lambda', [], initializer=tf.constant_initializer(eps_lambda))
+        'eps_lambda', [], initializer=tf.constant_initializer(eps_lambda),
+        trainable=False)
     self.new_eps_lambda = tf.placeholder(tf.float32, [])
     self.assign_eps_lambda = self.eps_lambda.assign(
-        0.95 * self.eps_lambda + 0.05 * self.new_eps_lambda)
+        0.99 * self.eps_lambda + 0.01 * self.new_eps_lambda)
+    self.use_target_values = use_target_values
 
   def get(self, rewards, pads, values, final_values,
           log_probs, prev_log_probs, target_log_probs,
-          entropies, logits):
+          entropies, logits,
+          target_values, final_target_values):
     not_pad = 1 - pads
     batch_size = tf.shape(rewards)[1]
@@ -126,10 +131,17 @@ class ActorCritic(Objective):
     rewards = not_pad * rewards
     value_estimates = not_pad * values
     log_probs = not_pad * sum(log_probs)
+    target_values = not_pad * tf.stop_gradient(target_values)
+    final_target_values = tf.stop_gradient(final_target_values)
 
     sum_rewards = discounted_future_sum(rewards, self.gamma, self.rollout)
-    last_values = shift_values(value_estimates, self.gamma, self.rollout,
-                               final_values)
+    if self.use_target_values:
+      last_values = shift_values(
+          target_values, self.gamma, self.rollout,
+          final_target_values)
+    else:
+      last_values = shift_values(value_estimates, self.gamma, self.rollout,
+                                 final_values)
 
     future_values = sum_rewards + last_values
     baseline_values = value_estimates
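For orientation: sum_rewards + last_values above is a rollout-step discounted return bootstrapped from either the online value network or, with use_target_values, the target network. The NumPy sketch below is my reading of what discounted_future_sum and shift_values compute together; the repo's exact edge-of-episode padding may differ slightly.
```
import numpy as np

def n_step_bootstrapped_values(rewards, values, final_value, gamma, rollout):
  """rewards: [T]; values: [T] giving V(s_0..s_{T-1}); final_value: V(s_T)."""
  ext_values = np.append(values, final_value)  # V(s_0), ..., V(s_T)
  T = len(rewards)
  out = np.zeros(T)
  for t in range(T):
    d = min(rollout, T - t)  # shrink the window near the end of the episode
    out[t] = sum(gamma ** k * rewards[t + k] for k in range(d))
    out[t] += gamma ** d * ext_values[t + d]  # bootstrap (target net if enabled)
  return out
```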
@@ -183,7 +195,8 @@ class PCL(ActorCritic):
   def get(self, rewards, pads, values, final_values,
           log_probs, prev_log_probs, target_log_probs,
-          entropies, logits):
+          entropies, logits,
+          target_values, final_target_values):
     not_pad = 1 - pads
     batch_size = tf.shape(rewards)[1]
@@ -192,6 +205,8 @@ class PCL(ActorCritic):
     log_probs = not_pad * sum(log_probs)
     target_log_probs = not_pad * tf.stop_gradient(sum(target_log_probs))
     relative_log_probs = not_pad * (log_probs - target_log_probs)
+    target_values = not_pad * tf.stop_gradient(target_values)
+    final_target_values = tf.stop_gradient(final_target_values)
 
     # Prepend.
     not_pad = tf.concat([tf.ones([self.rollout - 1, batch_size]),
@@ -210,14 +225,26 @@ class PCL(ActorCritic):
                                prev_log_probs], 0)
     relative_log_probs = tf.concat([tf.zeros([self.rollout - 1, batch_size]),
                                     relative_log_probs], 0)
+    target_values = tf.concat(
+        [self.gamma ** tf.expand_dims(
+            tf.range(float(self.rollout - 1), 0, -1), 1) *
+         tf.ones([self.rollout - 1, batch_size]) *
+         target_values[0:1, :],
+         target_values], 0)
 
     sum_rewards = discounted_future_sum(rewards, self.gamma, self.rollout)
     sum_log_probs = discounted_future_sum(log_probs, self.gamma, self.rollout)
     sum_prev_log_probs = discounted_future_sum(prev_log_probs, self.gamma, self.rollout)
     sum_relative_log_probs = discounted_future_sum(
         relative_log_probs, self.gamma, self.rollout)
-    last_values = shift_values(value_estimates, self.gamma, self.rollout,
-                               final_values)
+
+    if self.use_target_values:
+      last_values = shift_values(
+          target_values, self.gamma, self.rollout,
+          final_target_values)
+    else:
+      last_values = shift_values(value_estimates, self.gamma, self.rollout,
+                                 final_values)
 
     future_values = (
         - self.tau * sum_log_probs
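The future_values expression being built above is the path-consistency target of PCL. Below is a hedged NumPy sketch of the base per-step consistency error from the PCL paper; the Trust-PCL code additionally folds in the eps_lambda-weighted relative log-probs against the target policy, which this sketch omits.
```
import numpy as np

def pcl_consistency_errors(rewards, log_pis, values, gamma, tau, rollout):
  """rewards, log_pis: [T]; values: [T + 1] giving V(s_0), ..., V(s_T)."""
  T = len(rewards)
  errors = np.zeros(T)
  for t in range(T):
    d = min(rollout, T - t)
    soft_return = sum(gamma ** k * (rewards[t + k] - tau * log_pis[t + k])
                      for k in range(d))
    # consistency: -V(s_t) + gamma^d V(s_{t+d}) + discounted soft rewards
    errors[t] = -values[t] + gamma ** d * values[t + d] + soft_return
  return errors
```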
@@ -272,7 +299,8 @@ class TRPO(ActorCritic):
   def get(self, rewards, pads, values, final_values,
           log_probs, prev_log_probs, target_log_probs,
-          entropies, logits):
+          entropies, logits,
+          target_values, final_target_values):
     not_pad = 1 - pads
     batch_size = tf.shape(rewards)[1]
@@ -280,10 +308,18 @@ class TRPO(ActorCritic):
     value_estimates = not_pad * values
     log_probs = not_pad * sum(log_probs)
     prev_log_probs = not_pad * prev_log_probs
+    target_values = not_pad * tf.stop_gradient(target_values)
+    final_target_values = tf.stop_gradient(final_target_values)
 
     sum_rewards = discounted_future_sum(rewards, self.gamma, self.rollout)
-    last_values = shift_values(value_estimates, self.gamma, self.rollout,
-                               final_values)
+
+    if self.use_target_values:
+      last_values = shift_values(
+          target_values, self.gamma, self.rollout,
+          final_target_values)
+    else:
+      last_values = shift_values(value_estimates, self.gamma, self.rollout,
+                                 final_values)
 
     future_values = sum_rewards + last_values
     baseline_values = value_estimates
...
@@ -150,7 +150,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
   def get_batch(self, n):
     """Get batch of episodes to train on."""
     p = self.sampling_distribution()
-    idxs = np.random.choice(self.cur_size, size=n, replace=False, p=p)
+    idxs = np.random.choice(self.cur_size, size=int(n), replace=False, p=p)
     self.last_batch = idxs
     return [self.buffer[idx] for idx in idxs], p[idxs]
...
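The one-character fix above matters because np.random.choice requires an integer size. For readers unfamiliar with the buffer, here is a generic sketch of prioritized sampling without replacement (not the repo's exact sampling_distribution; the alpha sharpening is an assumption):
```
import numpy as np

def sample_prioritized(priorities, batch_size, alpha=1.0):
  """Draw indices in proportion to priorities**alpha, without replacement."""
  p = np.asarray(priorities, dtype=np.float64) ** alpha
  p /= p.sum()  # normalize into a proper distribution
  idxs = np.random.choice(len(p), size=int(batch_size), replace=False, p=p)
  return idxs, p[idxs]
```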
@@ -92,6 +92,8 @@ flags.DEFINE_bool('update_eps_lambda', False,
                   'Update lambda automatically based on last 100 episodes.')
 flags.DEFINE_float('gamma', 1.0, 'discount')
 flags.DEFINE_integer('rollout', 10, 'rollout')
+flags.DEFINE_bool('use_target_values', False,
+                  'use target network for value estimates')
 flags.DEFINE_bool('fixed_std', True,
                   'fix the std in Gaussian distributions')
 flags.DEFINE_bool('input_prev_actions', True,
@@ -152,6 +154,10 @@ class Trainer(object):
     self.env = gym_wrapper.GymWrapper(self.env_str,
                                       distinct=FLAGS.batch_size // self.num_samples,
                                       count=self.num_samples)
+    self.eval_env = gym_wrapper.GymWrapper(
+        self.env_str,
+        distinct=FLAGS.batch_size // self.num_samples,
+        count=self.num_samples)
     self.env_spec = env_spec.EnvSpec(self.env.get_one())
 
     self.max_step = FLAGS.max_step
@@ -169,7 +175,8 @@ class Trainer(object):
     self.value_opt = FLAGS.value_opt
     assert not self.trust_region_p or self.objective in ['pcl', 'trpo']
     assert self.objective != 'trpo' or self.trust_region_p
-    assert self.value_opt is None or self.critic_weight == 0.0
+    assert self.value_opt is None or self.value_opt == 'None' or \
+        self.critic_weight == 0.0
     self.max_divergence = FLAGS.max_divergence
     self.learning_rate = FLAGS.learning_rate
@@ -182,6 +189,7 @@ class Trainer(object):
     self.update_eps_lambda = FLAGS.update_eps_lambda
     self.gamma = FLAGS.gamma
     self.rollout = FLAGS.rollout
+    self.use_target_values = FLAGS.use_target_values
     self.fixed_std = FLAGS.fixed_std
     self.input_prev_actions = FLAGS.input_prev_actions
     self.recurrent = FLAGS.recurrent
@@ -208,8 +216,7 @@ class Trainer(object):
     self.value_hidden_layers = FLAGS.value_hidden_layers
     self.tf_seed = FLAGS.tf_seed
 
-    self.save_trajectories_dir = (
-        FLAGS.save_trajectories_dir or FLAGS.save_dir)
+    self.save_trajectories_dir = FLAGS.save_trajectories_dir
     self.save_trajectories_file = (
         os.path.join(
             self.save_trajectories_dir, self.env_str.replace('-', '_'))
@@ -244,7 +251,8 @@ class Trainer(object):
           policy_weight=policy_weight,
           critic_weight=self.critic_weight,
           tau=tau, gamma=self.gamma, rollout=self.rollout,
-          eps_lambda=self.eps_lambda, clip_adv=self.clip_adv)
+          eps_lambda=self.eps_lambda, clip_adv=self.clip_adv,
+          use_target_values=self.use_target_values)
     elif self.objective in ['reinforce', 'urex']:
       cls = (full_episode_objective.Reinforce
              if self.objective == 'reinforce' else
@@ -322,10 +330,10 @@ class Trainer(object):
         self.num_expert_paths, self.env_str, self.env_spec,
         load_trajectories_file=self.load_trajectories_file)
 
-  def get_controller(self):
+  def get_controller(self, env):
     """Get controller."""
     cls = controller.Controller
-    return cls(self.env, self.env_spec, self.internal_dim,
+    return cls(env, self.env_spec, self.internal_dim,
               use_online_batch=self.use_online_batch,
               batch_by_steps=self.batch_by_steps,
               unify_episodes=self.unify_episodes,
@@ -334,7 +342,7 @@ class Trainer(object):
               cutoff_agent=self.cutoff_agent,
               save_trajectories_file=self.save_trajectories_file,
               use_trust_region=self.trust_region_p,
-              use_value_opt=self.value_opt is not None,
+              use_value_opt=self.value_opt not in [None, 'None'],
               update_eps_lambda=self.update_eps_lambda,
               prioritize_by=self.prioritize_by,
               get_model=self.get_model,
@@ -359,16 +367,19 @@ class Trainer(object):
       saver.restore(sess, ckpt.model_checkpoint_path)
     elif FLAGS.load_path:
       logging.info('restoring from %s', FLAGS.load_path)
-      with gfile.AsUser('distbelief-brain-gpu'):
-        saver.restore(sess, FLAGS.load_path)
+      saver.restore(sess, FLAGS.load_path)
 
     if FLAGS.supervisor:
       with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks, merge_devices=True)):
-        self.global_step = tf.train.get_or_create_global_step()
+        self.global_step = tf.contrib.framework.get_or_create_global_step()
         tf.set_random_seed(FLAGS.tf_seed)
-        self.controller = self.get_controller()
+        self.controller = self.get_controller(self.env)
         self.model = self.controller.model
         self.controller.setup()
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+          self.eval_controller = self.get_controller(self.eval_env)
+          self.eval_controller.setup(train=False)
         saver = tf.train.Saver(max_to_keep=10)
         step = self.model.global_step
         sv = tf.Supervisor(logdir=FLAGS.save_dir,
@@ -382,10 +393,14 @@ class Trainer(object):
       sess = sv.PrepareSession(FLAGS.master)
     else:
       tf.set_random_seed(FLAGS.tf_seed)
-      self.global_step = tf.train.get_or_create_global_step()
-      self.controller = self.get_controller()
+      self.global_step = tf.contrib.framework.get_or_create_global_step()
+      self.controller = self.get_controller(self.env)
       self.model = self.controller.model
       self.controller.setup()
+      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+        self.eval_controller = self.get_controller(self.eval_env)
+        self.eval_controller.setup(train=False)
       saver = tf.train.Saver(max_to_keep=10)
       sess = tf.Session()
       sess.run(tf.initialize_all_variables())
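The eval controller above is built under reuse=True so it reads the same weights as the training controller rather than creating a second set of variables. A minimal TF 1.x illustration of that sharing pattern (the names here are made up for the example, not the repo's):
```
import tensorflow as tf

def build_net(x):
  w = tf.get_variable('w', [4, 2])  # created once, then reused
  return tf.matmul(x, w)

x = tf.placeholder(tf.float32, [None, 4])
with tf.variable_scope('model'):
  train_out = build_net(x)
# Reopen the current scope with reuse=True: no new variables are created.
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
  with tf.variable_scope('model'):
    eval_out = build_net(x)  # shares 'model/w' with train_out
```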
@@ -414,21 +429,25 @@ class Trainer(object):
         (loss, summary,
          total_rewards, episode_rewards) = self.controller.train(sess)
+        _, greedy_episode_rewards = self.eval_controller.eval(sess)
+        self.controller.greedy_episode_rewards = greedy_episode_rewards
         losses.append(loss)
         rewards.append(total_rewards)
         all_ep_rewards.extend(episode_rewards)
 
-        if random.random() < 1 and is_chief and sv and sv._summary_writer:
+        if (random.random() < 0.1 and summary and episode_rewards and
+            is_chief and sv and sv._summary_writer):
           sv.summary_computed(sess, summary)
 
         model_step = sess.run(self.model.global_step)
         if is_chief and step % self.validation_frequency == 0:
           logging.info('at training step %d, model step %d: '
                        'avg loss %f, avg reward %f, '
-                       'episode rewards: %f',
+                       'episode rewards: %f, greedy rewards: %f',
                        step, model_step,
                        np.mean(losses), np.mean(rewards),
-                       np.mean(all_ep_rewards))
+                       np.mean(all_ep_rewards),
+                       np.mean(greedy_episode_rewards))
 
           losses = []
           rewards = []
...