Commit 052361de authored by ofirnachum

add training code

parent 9b969ca5
Code for performing Hierarchical RL based on the following publications:
"Data-Efficient Hierarchical Reinforcement Learning" by
Ofir Nachum, Shixiang (Shane) Gu, Honglak Lee, and Sergey Levine
(https://arxiv.org/abs/1805.08296).
"Near-Optimal Representation Learning for Hierarchical Reinforcement Learning"
by Ofir Nachum, Shixiang (Shane) Gu, Honglak Lee, and Sergey Levine
(https://arxiv.org/abs/1810.01257).
Requirements:
* TensorFlow (see http://www.tensorflow.org for how to install/upgrade)
* Gin Config (see https://github.com/google/gin-config)
* Tensorflow Agents (see https://github.com/tensorflow/agents)
* OpenAI Gym (see http://gym.openai.com/docs, be sure to install MuJoCo as well)
* NumPy (see http://www.numpy.org/)
Quick Start:
Run a training job based on the original HIRO paper on Ant Maze:
```
python scripts/local_train.py test1 hiro_orig ant_maze base_uvf suite
```
Run a continuous evaluation job for that experiment:
```
python scripts/local_eval.py test1 hiro_orig ant_maze base_uvf suite
```
To run the same experiment with online representation learning (the
"Near-Optimal" paper), change `hiro_orig` to `hiro_repr`.
You can also use `hiro_xy` to run the same experiment with HIRO on only the
xy coordinates of the agent.
To run on other environments, change `ant_maze` to something else; e.g.,
`ant_push_multi`, `ant_fall_multi`, etc. See `context/configs/*` for other options.
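Because the positional arguments always appear in the same order, switching configurations only means changing one token. A minimal sketch of a helper that assembles these command lines follows; `build_command` and its parameter names are hypothetical (they are not part of the repository), and the argument order is copied from the examples above.

```python
def build_command(script, exp_name, agent_config, env_config,
                  fourth_arg="base_uvf", fifth_arg="suite"):
    """Assemble an argv list in the positional order used by the README examples.

    The meaning of the last two positional arguments is not documented above,
    so they are passed through unchanged under placeholder names.
    """
    if script not in ("train", "eval"):
        raise ValueError("script must be 'train' or 'eval'")
    return ["python", "scripts/local_%s.py" % script,
            exp_name, agent_config, env_config, fourth_arg, fifth_arg]


# Same experiment as above, but with representation learning on Ant Push.
print(" ".join(build_command("train", "test1", "hiro_repr", "ant_push_multi")))
```

The returned list can be handed to `subprocess.run` directly, which avoids shell quoting issues.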
Basic Code Guide:
The code for training resides in train.py. The code trains a lower-level policy
(a UVF agent in the code) and a higher-level policy (a MetaAgent in the code)
concurrently. The higher-level policy communicates goals to the lower-level
policy. In the code, this is called a context. Not only does the lower-level
policy act with respect to a context (a higher-level specified goal), but the
higher-level policy also acts with respect to an environment-specified context
(corresponding to the navigation target location associated with the task).
Therefore, in `context/configs/*` you will find both specifications for task setup
as well as goal configurations. Most remaining hyperparameters used for
training/evaluation may be found in `configs/*`.
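The relationship between the two policies can be sketched in plain NumPy. This is a toy illustration and not the repository's training loop: both policies, the dynamics, the dimensions, and the `meta_period` of 10 steps are placeholders chosen for the example. The goal-relative reward and goal re-expression are assumed to take the form described in the HIRO paper cited above.

```python
import numpy as np

state_dim, goal_dim, meta_period = 4, 2, 10  # illustrative sizes only


def meta_policy(state, task_context):
    # Placeholder higher-level policy: aim halfway toward the task target.
    return 0.5 * (task_context - state[:goal_dim])


def low_policy(state, goal):
    # Placeholder lower-level policy: it sees the state merged with its context.
    merged = np.concatenate([state, goal])
    assert merged.shape == (state_dim + goal_dim,)
    return np.clip(goal, -1.0, 1.0)  # move along the goal direction


def low_reward(state, goal, next_state):
    # Negative distance between the desired and the reached position.
    return -np.linalg.norm(state[:goal_dim] + goal - next_state[:goal_dim])


state = np.zeros(state_dim)
task_context = np.array([3.0, -2.0])  # environment-specified target location
for t in range(30):
    if t % meta_period == 0:
        goal = meta_policy(state, task_context)  # new context for the lower level
    action = low_policy(state, goal)
    next_state = state.copy()
    next_state[:goal_dim] += 0.1 * action  # toy dynamics
    r = low_reward(state, goal, next_state)
    # Between meta-steps, re-express the goal so its absolute target stays fixed.
    goal = state[:goal_dim] + goal - next_state[:goal_dim]
    state = next_state
```

Here the goal plays the role of the lower-level context, while `task_context` plays the role of the environment-specified context seen by the higher level.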
NOTE: Not all the code corresponding to the "Near-Optimal" paper is included.
Namely, changes to low-level policy training proposed in the paper (discounting
and auxiliary rewards) are not implemented here. Performance should not change
significantly.
Maintained by Ofir Nachum (ofirnachum).
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A UVF agent.
"""
import tensorflow as tf
import gin.tf
from agents import ddpg_agent
# pylint: disable=unused-import
import cond_fn
from utils import utils as uvf_utils
from context import gin_imports
# pylint: enable=unused-import
slim = tf.contrib.slim
@gin.configurable
class UvfAgentCore(object):
  """Defines basic functions for UVF agent. Must be inherited with an RL agent.

  Used as lower-level agent.
  """

  def __init__(self,
               observation_spec,
               action_spec,
               tf_env,
               tf_context,
               step_cond_fn=cond_fn.env_transition,
               reset_episode_cond_fn=cond_fn.env_restart,
               reset_env_cond_fn=cond_fn.false_fn,
               metrics=None,
               **base_agent_kwargs):
    """Constructs a UVF agent.

    Args:
      observation_spec: A TensorSpec defining the observations.
      action_spec: A BoundedTensorSpec defining the actions.
      tf_env: A Tensorflow environment object.
      tf_context: A Context class.
      step_cond_fn: A function indicating whether to increment the num of steps.
      reset_episode_cond_fn: A function indicating whether to restart the
        episode, resampling the context.
      reset_env_cond_fn: A function indicating whether to perform a manual reset
        of the environment.
      metrics: A list of functions that evaluate metrics of the agent.
      **base_agent_kwargs: A dictionary of parameters for base RL Agent.
    Raises:
      ValueError: If 'dqda_clipping' is < 0.
    """
    self._step_cond_fn = step_cond_fn
    self._reset_episode_cond_fn = reset_episode_cond_fn
    self._reset_env_cond_fn = reset_env_cond_fn
    self.metrics = metrics

    # expose tf_context methods
    self.tf_context = tf_context(tf_env=tf_env)
    self.set_replay = self.tf_context.set_replay
    self.sample_contexts = self.tf_context.sample_contexts
    self.compute_rewards = self.tf_context.compute_rewards
    self.gamma_index = self.tf_context.gamma_index
    self.context_specs = self.tf_context.context_specs
    self.context_as_action_specs = self.tf_context.context_as_action_specs
    self.init_context_vars = self.tf_context.create_vars

    self.env_observation_spec = observation_spec[0]
    merged_observation_spec = (uvf_utils.merge_specs(
        (self.env_observation_spec,) + self.context_specs),)
    self._context_vars = dict()
    self._action_vars = dict()

    self.BASE_AGENT_CLASS.__init__(
        self,
        observation_spec=merged_observation_spec,
        action_spec=action_spec,
        **base_agent_kwargs
    )

  def set_meta_agent(self, agent=None):
    self._meta_agent = agent

  @property
  def meta_agent(self):
    return self._meta_agent

  def actor_loss(self, states, actions, rewards, discounts,
                 next_states):
    """Returns the actor loss of the base agent.

    Args:
      states: A tensor representing a batch of states.
      actions: Unused; accepted for interface compatibility.
      rewards: Unused; accepted for interface compatibility.
      discounts: Unused; accepted for interface compatibility.
      next_states: Unused; accepted for interface compatibility.
    Returns:
      The loss computed by BASE_AGENT_CLASS.actor_loss on `states`.
    """
    return self.BASE_AGENT_CLASS.actor_loss(self, states)

  def action(self, state, context=None):
    """Returns the next action for the state.

    Args:
      state: A [num_state_dims] tensor representing a state.
      context: A list of [num_context_dims] tensor representing a context.
    Returns:
      A [num_action_dims] tensor representing the action.
    """
    merged_state = self.merged_state(state, context)
    return self.BASE_AGENT_CLASS.action(self, merged_state)

  def actions(self, state, context=None):
    """Returns the next actions for a batch of states.

    Args:
      state: A [-1, num_state_dims] tensor representing a batch of states.
      context: A list of [-1, num_context_dims] tensor representing a batch
        of contexts.
    Returns:
      A [-1, num_action_dims] tensor representing the batch of actions.
    """
    merged_states = self.merged_states(state, context)
    return self.BASE_AGENT_CLASS.actor_net(self, merged_states)

  def log_probs(self, states, actions, state_reprs, contexts=None):
    assert contexts is not None
    batch_dims = [tf.shape(states)[0], tf.shape(states)[1]]
    contexts = self.tf_context.context_multi_transition_fn(
        contexts, states=tf.to_float(state_reprs))

    flat_states = tf.reshape(states,
                             [batch_dims[0] * batch_dims[1], states.shape[-1]])
    flat_contexts = [tf.reshape(tf.cast(context, states.dtype),
                                [batch_dims[0] * batch_dims[1], context.shape[-1]])
                     for context in contexts]
    flat_pred_actions = self.actions(flat_states, flat_contexts)
    pred_actions = tf.reshape(flat_pred_actions,
                              batch_dims + [flat_pred_actions.shape[-1]])

    error = tf.square(actions - pred_actions)
    spec_range = (self._action_spec.maximum - self._action_spec.minimum) / 2
    normalized_error = error / tf.constant(spec_range) ** 2
    return -normalized_error

  @gin.configurable('uvf_add_noise_fn')
  def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
                   clip=True, global_step=None):
    """Returns the action_fn with additive Gaussian noise.

    Args:
      action_fn: A callable(`state`, `context`) which returns a
        [num_action_dims] tensor representing an action.
      stddev: Standard deviation of the Gaussian noise.
      debug: Print debug messages.
      clip: Whether to clip the noisy action to the action spec.
      global_step: Optional global step tensor. If given, stddev is decayed
        exponentially during training, to no less than half its initial value.
    Returns:
      A callable(`state`, `context`) which returns a noisy
      [num_action_dims] action tensor.
    """
    if global_step is not None:
      stddev *= tf.maximum(  # Decay exploration during training.
          tf.train.exponential_decay(1.0, global_step, 1e6, 0.8), 0.5)
    def noisy_action_fn(state, context=None):
      """Noisy action fn."""
      action = action_fn(state, context)
      if debug:
        action = uvf_utils.tf_print(
            action, [action],
            message='[add_noise_fn] pre-noise action',
            first_n=100)
      noise_dist = tf.distributions.Normal(tf.zeros_like(action),
                                           tf.ones_like(action) * stddev)
      noise = noise_dist.sample()
      action += noise
      if debug:
        action = uvf_utils.tf_print(
            action, [action],
            message='[add_noise_fn] post-noise action',
            first_n=100)
      if clip:
        action = uvf_utils.clip_to_spec(action, self._action_spec)
      return action
    return noisy_action_fn

  def merged_state(self, state, context=None):
    """Returns the merged state from the environment state and contexts.

    Args:
      state: A [num_state_dims] tensor representing a state.
      context: A list of [num_context_dims] tensor representing a context.
        If None, use the internal context.
    Returns:
      A [num_merged_state_dims] tensor representing the merged state.
    """
    if context is None:
      context = list(self.context_vars)
    state = tf.concat([state,] + context, axis=-1)
    self._validate_states(self._batch_state(state))
    return state

  def merged_states(self, states, contexts=None):
    """Returns the batch merged state from the batch env state and contexts.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      contexts: A list of [batch_size, num_context_dims] tensor
        representing a batch of contexts. If None,
        use the internal context.
    Returns:
      A [batch_size, num_merged_state_dims] tensor representing the batch
        of merged states.
    """
    if contexts is None:
      contexts = [tf.tile(tf.expand_dims(context, axis=0),
                          (tf.shape(states)[0], 1)) for
                  context in self.context_vars]
    states = tf.concat([states,] + contexts, axis=-1)
    self._validate_states(states)
    return states

  def unmerged_states(self, merged_states):
    """Returns the batch state and contexts from the batch merged state.

    Args:
      merged_states: A [batch_size, num_merged_state_dims] tensor
        representing a batch of merged states.
    Returns:
      A [batch_size, num_state_dims] tensor and a list of
        [batch_size, num_context_dims] tensors representing the batch state
        and contexts respectively.
    """
    self._validate_states(merged_states)
    num_state_dims = self.env_observation_spec.shape.as_list()[0]
    num_context_dims_list = [c.shape.as_list()[0] for c in self.context_specs]
    states = merged_states[:, :num_state_dims]
    contexts = []
    i = num_state_dims
    for num_context_dims in num_context_dims_list:
      contexts.append(merged_states[:, i: i+num_context_dims])
      i += num_context_dims
    return states, contexts

  def sample_random_actions(self, batch_size=1):
    """Return random actions.

    Args:
      batch_size: Batch size.
    Returns:
      A [batch_size, num_action_dims] tensor representing the batch of actions.
    """
    actions = tf.concat(
        [
            tf.random_uniform(
                shape=(batch_size, 1),
                minval=self._action_spec.minimum[i],
                maxval=self._action_spec.maximum[i])
            for i in range(self._action_spec.shape[0].value)
        ],
        axis=1)
    return actions

  def clip_actions(self, actions):
    """Clip actions to spec.

    Args:
      actions: A [batch_size, num_action_dims] tensor representing
        the batch of actions.
    Returns:
      A [batch_size, num_action_dims] tensor representing the batch
      of clipped actions.
    """
    actions = tf.concat(
        [
            tf.clip_by_value(
                actions[:, i:i+1],
                self._action_spec.minimum[i],
                self._action_spec.maximum[i])
            for i in range(self._action_spec.shape[0].value)
        ],
        axis=1)
    return actions

  def mix_contexts(self, contexts, insert_contexts, indices):
    """Mix two contexts based on indices.

    Args:
      contexts: A list of [batch_size, num_context_dims] tensor representing
        the batch of contexts.
      insert_contexts: A list of [batch_size, num_context_dims] tensor
        representing the batch of contexts to be inserted.
      indices: A list of a list of integers denoting indices to replace.
    Returns:
      A list of resulting contexts.
    """
    if indices is None: indices = [[]] * len(contexts)
    assert len(contexts) == len(indices)
    assert all([spec.shape.ndims == 1 for spec in self.context_specs])
    mix_contexts = []
    for contexts_, insert_contexts_, indices_, spec in zip(
        contexts, insert_contexts, indices, self.context_specs):
      mix_contexts.append(
          tf.concat(
              [
                  insert_contexts_[:, i:i + 1] if i in indices_ else
                  contexts_[:, i:i + 1] for i in range(spec.shape.as_list()[0])
              ],
              axis=1))
    return mix_contexts

  def begin_episode_ops(self, mode, action_fn=None, state=None):
    """Returns ops that reset agent at beginning of episodes.

    Args:
      mode: a string representing the mode=[train, explore, eval].
    Returns:
      A list of ops.
    """
    all_ops = []
    for _, action_var in sorted(self._action_vars.items()):
      sample_action = self.sample_random_actions(1)[0]
      all_ops.append(tf.assign(action_var, sample_action))
    all_ops += self.tf_context.reset(mode=mode, agent=self._meta_agent,
                                     action_fn=action_fn, state=state)
    return all_ops

  def cond_begin_episode_op(self, cond, input_vars, mode, meta_action_fn):
    """Returns op that resets agent at beginning of episodes.

    A new episode is begun if the cond op evaluates to `False`.

    Args:
      cond: a Boolean tensor variable.
      input_vars: A list of tensor variables.
      mode: a string representing the mode=[train, explore, eval].
      meta_action_fn: An action function passed to the context's step and
        reset ops.
    Returns:
      Conditional begin op.
    """
    (state, action, reward, next_state,
     state_repr, next_state_repr) = input_vars
    def continue_fn():
      """Continue op fn."""
      items = [state, action, reward, next_state,
               state_repr, next_state_repr] + list(self.context_vars)
      batch_items = [tf.expand_dims(item, 0) for item in items]
      (states, actions, rewards, next_states,
       state_reprs, next_state_reprs) = batch_items[:6]
      context_reward = self.compute_rewards(
          mode, state_reprs, actions, rewards, next_state_reprs,
          batch_items[6:])[0][0]
      context_reward = tf.cast(context_reward, dtype=reward.dtype)
      if self.meta_agent is not None:
        meta_action = tf.concat(self.context_vars, -1)
        items = [state, meta_action, reward, next_state,
                 state_repr, next_state_repr] + list(self.meta_agent.context_vars)
        batch_items = [tf.expand_dims(item, 0) for item in items]
        (states, meta_actions, rewards, next_states,
         state_reprs, next_state_reprs) = batch_items[:6]
        meta_reward = self.meta_agent.compute_rewards(
            mode, states, meta_actions, rewards,
            next_states, batch_items[6:])[0][0]
        meta_reward = tf.cast(meta_reward, dtype=reward.dtype)
      else:
        meta_reward = tf.constant(0, dtype=reward.dtype)

      with tf.control_dependencies([context_reward, meta_reward]):
        step_ops = self.tf_context.step(mode=mode, agent=self._meta_agent,
                                        state=state,
                                        next_state=next_state,
                                        state_repr=state_repr,
                                        next_state_repr=next_state_repr,
                                        action_fn=meta_action_fn)
      with tf.control_dependencies(step_ops):
        context_reward, meta_reward = map(tf.identity, [context_reward, meta_reward])
      return context_reward, meta_reward
    def begin_episode_fn():
      """Begin op fn."""
      begin_ops = self.begin_episode_ops(mode=mode, action_fn=meta_action_fn, state=state)
      with tf.control_dependencies(begin_ops):
        return tf.zeros_like(reward), tf.zeros_like(reward)
    with tf.control_dependencies(input_vars):
      cond_begin_episode_op = tf.cond(cond, continue_fn, begin_episode_fn)
    return cond_begin_episode_op

  def get_env_base_wrapper(self, env_base, **begin_kwargs):
    """Create a wrapper around env_base, with agent-specific begin/end_episode.

    Args:
      env_base: A python environment base.
      **begin_kwargs: Keyword args for begin_episode_ops.
    Returns:
      An object with begin_episode() and end_episode().
    """
    begin_ops = self.begin_episode_ops(**begin_kwargs)
    return uvf_utils.get_contextual_env_base(env_base, begin_ops)

  def init_action_vars(self, name, i=None):
    """Create and return a tensorflow Variable holding an action.

    Args:
      name: Name of the variables.
      i: Integer id.
    Returns:
      A [num_action_dims] tensor.
    """
    if i is not None:
      name += '_%d' % i
    assert name not in self._action_vars, ('Conflict! %s is already '
                                           'initialized.') % name
    self._action_vars[name] = tf.Variable(
        self.sample_random_actions(1)[0], name='%s_action' % (name))
    self._validate_actions(tf.expand_dims(self._action_vars[name], 0))
    return self._action_vars[name]

  @gin.configurable('uvf_critic_function')
  def critic_function(self, critic_vals, states, critic_fn=None):
    """Computes q values based on outputs from the critic net.

    Args:
      critic_vals: A tf.float32 [batch_size, ...] tensor representing outputs
        from the critic net.
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      critic_fn: A callable that processes outputs from critic_net and
        outputs a [batch_size] tensor representing q values.
    Returns:
      A tf.float32 [batch_size] tensor representing q values.
    """
    if critic_fn is not None:
      env_states, contexts = self.unmerged_states(states)
      critic_vals = critic_fn(critic_vals, env_states, contexts)
    critic_vals.shape.assert_has_rank(1)
    return critic_vals

  def get_action_vars(self, key):
    return self._action_vars[key]

  def get_context_vars(self, key):
    return self.tf_context.context_vars[key]

  def step_cond_fn(self, *args):
    return self._step_cond_fn(self, *args)

  def reset_episode_cond_fn(self, *args):
    return self._reset_episode_cond_fn(self, *args)

  def reset_env_cond_fn(self, *args):
    return self._reset_env_cond_fn(self, *args)

  @property
  def context_vars(self):
    return self.tf_context.vars

@gin.configurable
class MetaAgentCore(UvfAgentCore):
  """Defines basic functions for UVF Meta-agent. Must be inherited with an RL agent.

  Used as higher-level agent.
  """

  def __init__(self,
               observation_spec,
               action_spec,
               tf_env,
               tf_context,
               sub_context,
               step_cond_fn=cond_fn.env_transition,
               reset_episode_cond_fn=cond_fn.env_restart,
               reset_env_cond_fn=cond_fn.false_fn,
               metrics=None,
               actions_reg=0.,
               k=2,
               **base_agent_kwargs):
    """Constructs a Meta agent.

    Args:
      observation_spec: A TensorSpec defining the observations.
      action_spec: A BoundedTensorSpec defining the actions.
      tf_env: A Tensorflow environment object.
      tf_context: A Context class.
      sub_context: A Context class whose context-as-action specs define the
        action space of the meta-agent.
      step_cond_fn: A function indicating whether to increment the num of steps.
      reset_episode_cond_fn: A function indicating whether to restart the
        episode, resampling the context.
      reset_env_cond_fn: A function indicating whether to perform a manual reset
        of the environment.
      metrics: A list of functions that evaluate metrics of the agent.
      actions_reg: Coefficient of the L1 action regularizer in actor_loss.
      k: Index of the first action dimension subject to the regularizer.
      **base_agent_kwargs: A dictionary of parameters for base RL Agent.
    Raises:
      ValueError: If 'dqda_clipping' is < 0.
    """
    self._step_cond_fn = step_cond_fn
    self._reset_episode_cond_fn = reset_episode_cond_fn
    self._reset_env_cond_fn = reset_env_cond_fn
    self.metrics = metrics
    self._actions_reg = actions_reg
    self._k = k

    # expose tf_context methods
    self.tf_context = tf_context(tf_env=tf_env)
    self.sub_context = sub_context(tf_env=tf_env)
    self.set_replay = self.tf_context.set_replay
    self.sample_contexts = self.tf_context.sample_contexts
    self.compute_rewards = self.tf_context.compute_rewards
    self.gamma_index = self.tf_context.gamma_index
    self.context_specs = self.tf_context.context_specs
    self.context_as_action_specs = self.tf_context.context_as_action_specs
    self.sub_context_as_action_specs = self.sub_context.context_as_action_specs
    self.init_context_vars = self.tf_context.create_vars

    self.env_observation_spec = observation_spec[0]
    merged_observation_spec = (uvf_utils.merge_specs(
        (self.env_observation_spec,) + self.context_specs),)
    self._context_vars = dict()
    self._action_vars = dict()

    assert len(self.context_as_action_specs) == 1
    self.BASE_AGENT_CLASS.__init__(
        self,
        observation_spec=merged_observation_spec,
        action_spec=self.sub_context_as_action_specs,
        **base_agent_kwargs
    )

  @gin.configurable('meta_add_noise_fn')
  def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
                   global_step=None):
    noisy_action_fn = super(MetaAgentCore, self).add_noise_fn(
        action_fn, stddev,
        clip=True, global_step=global_step)
    return noisy_action_fn

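  def actor_loss(self, states, actions, rewards, discounts,
                 next_states):
    """Returns the base actor loss plus an L1 regularizer on the actions.

    The regularizer penalizes the absolute values of the actor's output
    dimensions from index `k` onward, scaled by `actions_reg`.

    Args:
      states: A tensor representing a batch of states.
      actions: Unused; the actions are recomputed from the actor net.
      rewards: Unused; accepted for interface compatibility.
      discounts: Unused; accepted for interface compatibility.
      next_states: Unused; accepted for interface compatibility.
    Returns:
      A scalar tensor representing the regularized actor loss.
    """
    actions = self.actor_net(states, stop_gradients=False)
    regularizer = self._actions_reg * tf.reduce_mean(
        tf.reduce_sum(tf.abs(actions[:, self._k:]), -1), 0)
    loss = self.BASE_AGENT_CLASS.actor_loss(self, states)
    return regularizer + loss
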
@gin.configurable
class UvfAgent(UvfAgentCore, ddpg_agent.TD3Agent):
  """A DDPG agent with UVF.
  """
  BASE_AGENT_CLASS = ddpg_agent.TD3Agent
  ACTION_TYPE = 'continuous'

  def __init__(self, *args, **kwargs):
    UvfAgentCore.__init__(self, *args, **kwargs)


@gin.configurable
class MetaAgent(MetaAgentCore, ddpg_agent.TD3Agent):
  """A DDPG meta-agent.
  """
  BASE_AGENT_CLASS = ddpg_agent.TD3Agent
  ACTION_TYPE = 'continuous'

  def __init__(self, *args, **kwargs):
    MetaAgentCore.__init__(self, *args, **kwargs)

@gin.configurable()
def state_preprocess_net(
    states,
    num_output_dims=2,
    states_hidden_layers=(100,),
    normalizer_fn=None,
    activation_fn=tf.nn.relu,
    zero_time=True,
    images=False):
  """Creates a simple feed forward net for embedding states.
  """
  with slim.arg_scope(
      [slim.fully_connected],
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      weights_initializer=slim.variance_scaling_initializer(
          factor=1.0/3.0, mode='FAN_IN', uniform=True)):

    states_shape = tf.shape(states)
    states_dtype = states.dtype
    states = tf.to_float(states)
    if images:  # Zero-out x-y
      states *= tf.constant([0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype)
    if zero_time:
      states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)
    orig_states = states
    embed = states
    if states_hidden_layers:
      embed = slim.stack(embed, slim.fully_connected, states_hidden_layers,
                         scope='states')

    with slim.arg_scope([slim.fully_connected],
                        weights_regularizer=None,
                        weights_initializer=tf.random_uniform_initializer(
                            minval=-0.003, maxval=0.003)):
      embed = slim.fully_connected(embed, num_output_dims,
                                   activation_fn=None,
                                   normalizer_fn=None,
                                   scope='value')

    output = embed
    output = tf.cast(output, states_dtype)
  return output

@gin.configurable()
def action_embed_net(
    actions,
    states=None,
    num_output_dims=2,
    hidden_layers=(400, 300),
    normalizer_fn=None,
    activation_fn=tf.nn.relu,
    zero_time=True,
    images=False):
  """Creates a simple feed forward net for embedding actions.
  """
  with slim.arg_scope(
      [slim.fully_connected],
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      weights_initializer=slim.variance_scaling_initializer(
          factor=1.0/3.0, mode='FAN_IN', uniform=True)):

    actions = tf.to_float(actions)
    if states is not None:
      if images:  # Zero-out x-y
        states *= tf.constant([0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype)
      if zero_time:
        states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)
      actions = tf.concat([actions, tf.to_float(states)], -1)

    embed = actions
    if hidden_layers:
      embed = slim.stack(embed, slim.fully_connected, hidden_layers,
                         scope='hidden')

    with slim.arg_scope([slim.fully_connected],
                        weights_regularizer=None,
                        weights_initializer=tf.random_uniform_initializer(
                            minval=-0.003, maxval=0.003)):
      embed = slim.fully_connected(embed, num_output_dims,
                                   activation_fn=None,
                                   normalizer_fn=None,
                                   scope='value')
      if num_output_dims == 1:
        return embed[:, 0, ...]
      else:
        return embed

def huber(x, kappa=0.1):
  return (0.5 * tf.square(x) * tf.to_float(tf.abs(x) <= kappa) +
          kappa * (tf.abs(x) - 0.5 * kappa) * tf.to_float(tf.abs(x) > kappa)
         ) / kappa
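
The `huber` helper above is a Huber loss divided by `kappa`, so it is quadratic for `|x| <= kappa` and has slope 1 beyond that point. A small NumPy mirror, written here only to make the piecewise form easy to check by hand (`huber_np` is an illustrative name, not part of the repository):

```python
import numpy as np


def huber_np(x, kappa=0.1):
    """NumPy mirror of the normalized Huber loss defined above."""
    x = np.asarray(x, dtype=np.float64)
    quadratic = 0.5 * np.square(x)              # used where |x| <= kappa
    linear = kappa * (np.abs(x) - 0.5 * kappa)  # used where |x| > kappa
    return np.where(np.abs(x) <= kappa, quadratic, linear) / kappa


print(huber_np([0.05, 0.1, 1.0]))
```

The two branches agree at `|x| = kappa`, where both evaluate to `kappa / 2`, so the loss is continuous.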
@gin.configurable()
class StatePreprocess(object):
  STATE_PREPROCESS_NET_SCOPE = 'state_process_net'
  ACTION_EMBED_NET_SCOPE = 'action_embed_net'

  def __init__(self, trainable=False,
               state_preprocess_net=lambda states: states,
               action_embed_net=lambda actions, *args, **kwargs: actions,
               ndims=None):
    self.trainable = trainable
    self._scope = tf.get_variable_scope().name
    self._ndims = ndims
    self._state_preprocess_net = tf.make_template(
        self.STATE_PREPROCESS_NET_SCOPE, state_preprocess_net,
        create_scope_now_=True)
    self._action_embed_net = tf.make_template(
        self.ACTION_EMBED_NET_SCOPE, action_embed_net,
        create_scope_now_=True)

  def __call__(self, states):
    batched = states.get_shape().ndims != 1
    if not batched:
      states = tf.expand_dims(states, 0)
    embedded = self._state_preprocess_net(states)
    if self._ndims is not None:
      embedded = embedded[..., :self._ndims]
    if not batched:
      return embedded[0]
    return embedded

  def loss(self, states, next_states, low_actions, low_states):
    batch_size = tf.shape(states)[0]
    d = int(low_states.shape[1])
    # Sample indices into meta-transition to train on.
    probs = 0.99 ** tf.range(d, dtype=tf.float32)
    probs *= tf.constant([1.0] * (d - 1) + [1.0 / (1 - 0.99)],
                         dtype=tf.float32)
    probs /= tf.reduce_sum(probs)
    index_dist = tf.distributions.Categorical(probs=probs, dtype=tf.int64)
    indices = index_dist.sample(batch_size)
    batch_size = tf.cast(batch_size, tf.int64)
    next_indices = tf.concat(
        [tf.range(batch_size, dtype=tf.int64)[:, None],
         (1 + indices[:, None]) % d], -1)
    new_next_states = tf.where(indices < d - 1,
                               tf.gather_nd(low_states, next_indices),
                               next_states)
    next_states = new_next_states

    embed1 = tf.to_float(self._state_preprocess_net(states))
    embed2 = tf.to_float(self._state_preprocess_net(next_states))
    action_embed = self._action_embed_net(
        tf.layers.flatten(low_actions), states=states)

    tau = 2.0
    fn = lambda z: tau * tf.reduce_sum(huber(z), -1)
    all_embed = tf.get_variable('all_embed', [1024, int(embed1.shape[-1])],
                                initializer=tf.zeros_initializer())
    upd = all_embed.assign(tf.concat([all_embed[batch_size:], embed2], 0))
    with tf.control_dependencies([upd]):
      close = 1 * tf.reduce_mean(fn(embed1 + action_embed - embed2))
      prior_log_probs = tf.reduce_logsumexp(
          -fn((embed1 + action_embed)[:, None, :] - all_embed[None, :, :]),
          axis=-1) - tf.log(tf.to_float(all_embed.shape[0]))
      far = tf.reduce_mean(tf.exp(-fn((embed1 + action_embed)[1:] - embed2[:-1])
                                  - tf.stop_gradient(prior_log_probs[1:])))
      repr_log_probs = tf.stop_gradient(
          -fn(embed1 + action_embed - embed2) - prior_log_probs) / tau
    return close + far, repr_log_probs, indices

  def get_trainable_vars(self):
    return (
        slim.get_trainable_variables(
            uvf_utils.join_scope(self._scope, self.STATE_PREPROCESS_NET_SCOPE)) +
        slim.get_trainable_variables(
            uvf_utils.join_scope(self._scope, self.ACTION_EMBED_NET_SCOPE)))

@gin.configurable()
class InverseDynamics(object):
  INVERSE_DYNAMICS_NET_SCOPE = 'inverse_dynamics'

  def __init__(self, spec):
    self._spec = spec

  def sample(self, states, next_states, num_samples, orig_goals, sc=0.5):
    goal_dim = orig_goals.shape[-1]
    spec_range = (self._spec.maximum - self._spec.minimum) / 2 * tf.ones([goal_dim])
    loc = tf.cast(next_states - states, tf.float32)[:, :goal_dim]
    scale = sc * tf.tile(tf.reshape(spec_range, [1, goal_dim]),
                         [tf.shape(states)[0], 1])
    dist = tf.distributions.Normal(loc, scale)
    if num_samples == 1:
      return dist.sample()
    samples = tf.concat([dist.sample(num_samples - 2),
                         tf.expand_dims(loc, 0),
                         tf.expand_dims(orig_goals, 0)], 0)
    return uvf_utils.clip_to_spec(samples, self._spec)
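
`InverseDynamics.sample` above builds a set of candidate goals: Gaussian samples centered on the observed state difference, plus that difference itself and the original goal, all clipped to the goal spec. This looks like the candidate set used for goal relabeling in the HIRO paper, though that reading is an inference from the code. A NumPy sketch of the same construction follows; the function name, the scalar bounds `goal_min` and `goal_max` (standing in for the spec's minimum and maximum), and the `seed` argument are assumptions made for the example.

```python
import numpy as np


def sample_goal_candidates(states, next_states, orig_goals, num_samples,
                           goal_min, goal_max, sc=0.5, seed=0):
    """Return candidates of shape [num_samples, batch_size, goal_dim]."""
    rng = np.random.default_rng(seed)
    goal_dim = orig_goals.shape[-1]
    spec_range = (goal_max - goal_min) / 2.0 * np.ones(goal_dim)
    loc = (next_states - states)[:, :goal_dim]  # observed state difference
    scale = sc * spec_range
    if num_samples == 1:
        # As in the TensorFlow code, a single sample is returned unclipped.
        return rng.normal(loc, scale)
    random_part = rng.normal(loc, scale, size=(num_samples - 2,) + loc.shape)
    candidates = np.concatenate(
        [random_part, loc[None], orig_goals[None]], axis=0)
    return np.clip(candidates, goal_min, goal_max)


states = np.zeros((3, 4))
next_states = np.full((3, 4), 0.5)
orig_goals = np.full((3, 2), 20.0)
candidates = sample_goal_candidates(
    states, next_states, orig_goals, num_samples=10,
    goal_min=-10.0, goal_max=10.0)
print(candidates.shape)
```

The last two slices along the first axis always hold the state difference and the (clipped) original goal, so those two candidates are considered regardless of the random draws.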
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A circular buffer where each element is a list of tensors.
Each element of the buffer is a list of tensors. An example use case is a replay
buffer in reinforcement learning, where each element is a list of tensors
representing the state, action, reward etc.
New elements are added sequentially, and once the buffer is full, we
start overwriting them in a circular fashion. Reading does not remove any
elements, only adding new elements does.
"""
import collections
import numpy as np
import tensorflow as tf
import gin.tf
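
The overwrite behaviour described in the module docstring above can be sketched without TensorFlow. This is an illustration of the idea only: `NumpyCircularBuffer` is a hypothetical class, it stores NumPy arrays rather than graph variables, and it ignores the thread-safety that the TensorFlow implementation handles with a critical section.

```python
import numpy as np


class NumpyCircularBuffer:
    """Fixed-size buffer that overwrites its oldest entries once full."""

    def __init__(self, buffer_size=1000):
        self.buffer_size = buffer_size
        self.num_adds = 0
        self._storage = {}

    def add(self, tensors):
        # Allocate storage lazily from the shapes and dtypes of the first element.
        if not self._storage:
            for name, value in tensors.items():
                value = np.asarray(value)
                self._storage[name] = np.zeros(
                    (self.buffer_size,) + value.shape, dtype=value.dtype)
        pos = self.num_adds % self.buffer_size  # wraps around when full
        for name, value in tensors.items():
            self._storage[name][pos] = value
        self.num_adds += 1

    def get_random_batch(self, batch_size, rng):
        if self.num_adds == 0:
            raise ValueError("add must be called before get_random_batch")
        # Only sample from slots that have been written at least once.
        max_index = min(self.buffer_size, self.num_adds)
        indices = rng.integers(0, max_index, size=batch_size)
        return {name: data[indices] for name, data in self._storage.items()}


buf = NumpyCircularBuffer(buffer_size=3)
for i in range(5):
    buf.add({"state": np.array([i, i]), "reward": np.float32(i)})
batch = buf.get_random_batch(4, np.random.default_rng(0))
```

After five adds into a buffer of size three, the two oldest elements have been overwritten, and sampling never removes anything.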
@gin.configurable
class CircularBuffer(object):
  """A circular buffer where each element is a list of tensors."""

  def __init__(self, buffer_size=1000, scope='replay_buffer'):
    """Circular buffer of list of tensors.

    Args:
      buffer_size: (integer) maximum number of tensor lists the buffer can hold.
      scope: (string) variable scope for creating the variables.
    """
    self._buffer_size = np.int64(buffer_size)
    self._scope = scope
    self._tensors = collections.OrderedDict()
    with tf.variable_scope(self._scope):
      self._num_adds = tf.Variable(0, dtype=tf.int64, name='num_adds')
    self._num_adds_cs = tf.contrib.framework.CriticalSection(name='num_adds')

  @property
  def buffer_size(self):
    return self._buffer_size

  @property
  def scope(self):
    return self._scope

  @property
  def num_adds(self):
    return self._num_adds

  def _create_variables(self, tensors):
    with tf.variable_scope(self._scope):
      for name in tensors.keys():
        tensor = tensors[name]
        self._tensors[name] = tf.get_variable(
            name='BufferVariable_' + name,
            shape=[self._buffer_size] + tensor.get_shape().as_list(),
            dtype=tensor.dtype,
            trainable=False)

  def _validate(self, tensors):
    """Validate shapes of tensors."""
    if len(tensors) != len(self._tensors):
      raise ValueError('Expected tensors to have %d elements. Received %d '
                       'instead.' % (len(self._tensors), len(tensors)))
    if self._tensors.keys() != tensors.keys():
      raise ValueError('The keys of tensors should always be the same. '
                       'Received %s instead of %s.' %
                       (tensors.keys(), self._tensors.keys()))
    for name, tensor in tensors.items():
      if tensor.get_shape().as_list() != self._tensors[
          name].get_shape().as_list()[1:]:
        raise ValueError('Tensor %s has incorrect shape.' % name)
      if not tensor.dtype.is_compatible_with(self._tensors[name].dtype):
        raise ValueError(
            'Tensor %s has incorrect data type. Expected %s, received %s' %
            (name, self._tensors[name].read_value().dtype, tensor.dtype))

  def add(self, tensors):
    """Adds an element (list/tuple/dict of tensors) to the buffer.

    Args:
      tensors: (list/tuple/dict of tensors) to be added to the buffer.
    Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
        an enqueue_op.
    Raises:
      ValueError: If the shapes and data types of input `tensors` are not the
        same across calls to the add function.
    """
    return self.maybe_add(tensors, True)

  def maybe_add(self, tensors, condition):
    """Adds an element (tensors) to the buffer based on the condition.

    Args:
      tensors: (list/tuple/dict of tensors) to be added to the buffer.
      condition: A boolean Tensor controlling whether the tensors would be added
        to the buffer or not.
    Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
        a maybe_enqueue_op.
    Raises:
      ValueError: If the shapes and data types of input `tensors` are not the
        same across calls to the add function.
    """
    if not isinstance(tensors, dict):
      names = [str(i) for i in range(len(tensors))]
      tensors = collections.OrderedDict(zip(names, tensors))
    if not isinstance(tensors, collections.OrderedDict):
      tensors = collections.OrderedDict(
          sorted(tensors.items(), key=lambda t: t[0]))
    if not self._tensors:
      self._create_variables(tensors)
    else:
      self._validate(tensors)

    #@tf.critical_section(self._position_mutex)
    def _increment_num_adds():
      # Adding 0 to the num_adds variable is a trick to read the value of the
      # variable and return a read-only tensor. Doing this in a critical
      # section allows us to capture a snapshot of the variable that will
      # not be affected by other threads updating num_adds.
      return self._num_adds.assign_add(1) + 0
    def _add():
      num_adds_inc = self._num_adds_cs.execute(_increment_num_adds)
      current_pos = tf.mod(num_adds_inc - 1, self._buffer_size)
      update_ops = []
      for name in self._tensors.keys():
        update_ops.append(
            tf.scatter_update(self._tensors[name], current_pos, tensors[name]))
      return tf.group(*update_ops)

    return tf.contrib.framework.smart_cond(condition, _add, tf.no_op)

def get_random_batch(self, batch_size, keys=None, num_steps=1):
"""Samples a batch of tensors from the buffer with replacement.
Args:
batch_size: (integer) number of elements to sample.
keys: List of keys of tensors to retrieve. If None retrieve all.
num_steps: (integer) length of trajectories to return. If > 1 will return
a list of lists, where each internal list represents a trajectory of
length num_steps.
Returns:
A list of tensors, where each element in the list is a batch sampled from
one of the tensors in the buffer.
Raises:
ValueError: If get_random_batch is called before calling the add function.
tf.errors.InvalidArgumentError: If this operation is executed before any
items are added to the buffer.
"""
if not self._tensors:
raise ValueError('The add function must be called before get_random_batch.')
if keys is None:
keys = self._tensors.keys()
latest_start_index = self.get_num_adds() - num_steps + 1
empty_buffer_assert = tf.Assert(
tf.greater(latest_start_index, 0),
['Not enough elements have been added to the buffer.'])
with tf.control_dependencies([empty_buffer_assert]):
max_index = tf.minimum(self._buffer_size, latest_start_index)
indices = tf.random_uniform(
[batch_size],
minval=0,
maxval=max_index,
dtype=tf.int64)
if num_steps == 1:
return self.gather(indices, keys)
else:
return self.gather_nstep(num_steps, indices, keys)
def gather(self, indices, keys=None):
"""Returns elements at the specified indices from the buffer.
Args:
indices: (list of integers or rank 1 int Tensor) indices in the buffer to
retrieve elements from.
keys: List of keys of tensors to retrieve. If None retrieve all.
Returns:
A list of tensors, where each element in the list is obtained by indexing
one of the tensors in the buffer.
Raises:
ValueError: If gather is called before calling the add function.
tf.errors.InvalidArgumentError: If indices are bigger than the number of
items in the buffer.
"""
if not self._tensors:
raise ValueError('The add function must be called before calling gather.')
if keys is None:
keys = self._tensors.keys()
with tf.name_scope('Gather'):
index_bound_assert = tf.Assert(
tf.less(
tf.to_int64(tf.reduce_max(indices)),
tf.minimum(self.get_num_adds(), self._buffer_size)),
['Index out of bounds.'])
with tf.control_dependencies([index_bound_assert]):
indices = tf.convert_to_tensor(indices)
batch = []
for key in keys:
batch.append(tf.gather(self._tensors[key], indices, name=key))
return batch
def gather_nstep(self, num_steps, indices, keys=None):
"""Returns trajectories of num_steps elements starting at the given indices.
Args:
num_steps: (integer) length of trajectories to return.
indices: (rank 1 int Tensor) start indices in the buffer, one per
trajectory. Each trajectory covers num_steps consecutive positions,
wrapping around the end of the buffer.
keys: List of keys of tensors to retrieve. If None retrieve all.
Returns:
A list of list-of-tensors, where each element in the list is obtained by
indexing one of the tensors in the buffer.
Raises:
ValueError: If gather_nstep is called before calling the add function.
tf.errors.InvalidArgumentError: If indices are bigger than the number of
items in the buffer.
"""
if not self._tensors:
raise ValueError('The add function must be called before calling gather.')
if keys is None:
keys = self._tensors.keys()
with tf.name_scope('Gather'):
index_bound_assert = tf.Assert(
tf.less_equal(
tf.to_int64(tf.reduce_max(indices) + num_steps),
self.get_num_adds()),
['Trajectory indices go out of bounds.'])
with tf.control_dependencies([index_bound_assert]):
indices = tf.map_fn(
lambda x: tf.mod(tf.range(x, x + num_steps), self._buffer_size),
indices,
dtype=tf.int64)
batch = []
for key in keys:
def SampleTrajectories(trajectory_indices, key=key,
num_steps=num_steps):
trajectory_indices.set_shape([num_steps])
return tf.gather(self._tensors[key], trajectory_indices, name=key)
batch.append(tf.map_fn(SampleTrajectories, indices,
dtype=self._tensors[key].dtype))
return batch
def get_position(self):
"""Returns the position at which the last element was added.
Returns:
An int tensor representing the index at which the last element was added
to the buffer or -1 if no elements were added.
"""
return tf.cond(self.get_num_adds() < 1,
lambda: self.get_num_adds() - 1,
lambda: tf.mod(self.get_num_adds() - 1, self._buffer_size))
def get_num_adds(self):
"""Returns the number of additions to the buffer.
Returns:
An int tensor representing the number of elements that were added.
"""
def num_adds():
return self._num_adds.value()
return self._num_adds_cs.execute(num_adds)
def get_num_tensors(self):
"""Returns the number of tensors (slots) in the buffer."""
return len(self._tensors)
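The index arithmetic used by `maybe_add`, `get_position`, and `gather_nstep` above can be sketched without TensorFlow. The following is a minimal pure-Python illustration, not the library's implementation: the `RingBuffer` class and its `slots` attribute are hypothetical names introduced here, and it stores a single value per slot instead of a dict of tensors.

```python
class RingBuffer:
    """Hypothetical pure-Python stand-in for the buffer's index arithmetic."""

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.slots = [None] * buffer_size
        self.num_adds = 0

    def add(self, item):
        # Mirrors `tf.mod(num_adds_inc - 1, buffer_size)` in maybe_add:
        # increment the counter first, then write at (num_adds - 1) mod size.
        self.num_adds += 1
        self.slots[(self.num_adds - 1) % self.buffer_size] = item

    def get_position(self):
        # -1 while the buffer is empty, else the most recently written slot.
        if self.num_adds < 1:
            return -1
        return (self.num_adds - 1) % self.buffer_size

    def gather_nstep(self, num_steps, start_indices):
        # Same bound as the tf.Assert in gather_nstep.
        if max(start_indices) + num_steps > self.num_adds:
            raise IndexError('Trajectory indices go out of bounds.')
        # Each trajectory reads num_steps consecutive positions, wrapping
        # around the end of the buffer.
        return [[self.slots[i % self.buffer_size]
                 for i in range(start, start + num_steps)]
                for start in start_indices]


buf = RingBuffer(buffer_size=4)
for step in range(6):
    buf.add(step)

print(buf.slots)                    # oldest entries 0 and 1 were overwritten
print(buf.get_position())
print(buf.gather_nstep(2, [3]))     # trajectory wraps from slot 3 to slot 0
```

Once the buffer has wrapped, a start index near the end of the buffer yields a trajectory that continues at slot 0, which is why the n-step gather takes indices modulo the buffer size.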
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A DDPG/NAF agent.
Implements the Deep Deterministic Policy Gradient (DDPG) algorithm from
"Continuous control with deep reinforcement learning" - Lillicrap et al.
https://arxiv.org/abs/1509.02971, and the Normalized Advantage Functions (NAF)
algorithm "Continuous Deep Q-Learning with Model-based Acceleration" - Gu et al.
https://arxiv.org/pdf/1603.00748.
"""
import tensorflow as tf
slim = tf.contrib.slim
import gin.tf
from utils import utils
from agents import ddpg_networks as networks
@gin.configurable
class DdpgAgent(object):
"""An RL agent that learns using the DDPG algorithm.
Example usage:
def critic_net(states, actions):
...
def actor_net(states, num_action_dims):
...
Given a tensorflow environment tf_env,
(of type learning.deepmind.rl.environments.tensorflow.python.tfpyenvironment)
obs_spec = tf_env.observation_spec()
action_spec = tf_env.action_spec()
ddpg_agent = agent.DdpgAgent(obs_spec,
action_spec,
actor_net=actor_net,
critic_net=critic_net)
we can perform actions on the environment as follows:
state = tf_env.observations()[0]
action = ddpg_agent.actor_net(tf.expand_dims(state, 0))[0, :]
transition_type, reward, discount = tf_env.step([action])
Train:
critic_loss = ddpg_agent.critic_loss(states, actions, rewards, discounts,
next_states)
actor_loss = ddpg_agent.actor_loss(states)
critic_train_op = slim.learning.create_train_op(
critic_loss,
critic_optimizer,
variables_to_train=ddpg_agent.get_trainable_critic_vars(),
)
actor_train_op = slim.learning.create_train_op(
actor_loss,
actor_optimizer,
variables_to_train=ddpg_agent.get_trainable_actor_vars(),
)
"""
ACTOR_NET_SCOPE = 'actor_net'
CRITIC_NET_SCOPE = 'critic_net'
TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
TARGET_CRITIC_NET_SCOPE = 'target_critic_net'
def __init__(self,
observation_spec,
action_spec,
actor_net=networks.actor_net,
critic_net=networks.critic_net,
td_errors_loss=tf.losses.huber_loss,
dqda_clipping=0.,
actions_regularizer=0.,
target_q_clipping=None,
residual_phi=0.0,
debug_summaries=False):
"""Constructs a DDPG agent.
Args:
observation_spec: A TensorSpec defining the observations.
action_spec: A BoundedTensorSpec defining the actions.
actor_net: A callable that creates the actor network. Must take the
following arguments: states, num_actions. Please see networks.actor_net
for an example.
critic_net: A callable that creates the critic network. Must take the
following arguments: states, actions. Please see networks.critic_net
for an example.
td_errors_loss: A callable defining the loss function for the critic
td error.
dqda_clipping: (float) clips the gradient dqda element-wise between
[-dqda_clipping, dqda_clipping]. Does not perform clipping if
dqda_clipping == 0.
actions_regularizer: A scalar, when positive penalizes the norm of the
actions. This can prevent saturation of actions for the actor_loss.
target_q_clipping: (tuple of floats) clips target q values within
(low, high) values when computing the critic loss.
residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
interpolates between Q-learning and residual gradient algorithm.
http://www.leemon.com/papers/1995b.pdf
debug_summaries: If True, add summaries to help debug behavior.
Raises:
ValueError: If 'dqda_clipping' is < 0.
"""
self._observation_spec = observation_spec[0]
self._action_spec = action_spec[0]
self._state_shape = tf.TensorShape([None]).concatenate(
self._observation_spec.shape)
self._action_shape = tf.TensorShape([None]).concatenate(
self._action_spec.shape)
self._num_action_dims = self._action_spec.shape.num_elements()
self._scope = tf.get_variable_scope().name
self._actor_net = tf.make_template(
self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
self._critic_net = tf.make_template(
self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
self._target_actor_net = tf.make_template(
self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
self._target_critic_net = tf.make_template(
self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
self._td_errors_loss = td_errors_loss
if dqda_clipping < 0:
raise ValueError('dqda_clipping must be >= 0.')
self._dqda_clipping = dqda_clipping
self._actions_regularizer = actions_regularizer
self._target_q_clipping = target_q_clipping
self._residual_phi = residual_phi
self._debug_summaries = debug_summaries
def _batch_state(self, state):
"""Convert state to a batched state.
Args:
state: Either a state tensor [num_state_dims], or a list/tuple whose first
element is such a tensor.
Returns:
A tensor [1, num_state_dims]
"""
if isinstance(state, (tuple, list)):
state = state[0]
if state.get_shape().ndims == 1:
state = tf.expand_dims(state, 0)
return state
def action(self, state):
"""Returns the next action for the state.
Args:
state: A [num_state_dims] tensor representing a state.
Returns:
A [num_action_dims] tensor representing the action.
"""
return self.actor_net(self._batch_state(state), stop_gradients=True)[0, :]
@gin.configurable('ddpg_sample_action')
def sample_action(self, state, stddev=1.0):
"""Returns the action for the state with additive noise.
Args:
state: A [num_state_dims] tensor representing a state.
stddev: standard deviation of the additive Gaussian noise.
Returns:
A [num_action_dims] action tensor.
"""
agent_action = self.action(state)
agent_action += tf.random_normal(tf.shape(agent_action)) * stddev
return utils.clip_to_spec(agent_action, self._action_spec)
def actor_net(self, states, stop_gradients=False):
"""Returns the output of the actor network.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
stop_gradients: (boolean) if true, gradients cannot be propagated through
this operation.
Returns:
A [batch_size, num_action_dims] tensor of actions.
Raises:
ValueError: If `states` does not have the expected dimensions.
"""
self._validate_states(states)
actions = self._actor_net(states, self._action_spec)
if stop_gradients:
actions = tf.stop_gradient(actions)
return actions
def critic_net(self, states, actions, for_critic_loss=False):
"""Returns the output of the critic network.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
actions: A [batch_size, num_action_dims] tensor representing a batch
of actions.
Returns:
q values: A [batch_size] tensor of q values.
Raises:
ValueError: If `states` or `actions` do not have the expected dimensions.
"""
self._validate_states(states)
self._validate_actions(actions)
return self._critic_net(states, actions,
for_critic_loss=for_critic_loss)
def target_actor_net(self, states):
"""Returns the output of the target actor network.
The target network is used to compute stable targets for training.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
A [batch_size, num_action_dims] tensor of actions.
Raises:
ValueError: If `states` does not have the expected dimensions.
"""
self._validate_states(states)
actions = self._target_actor_net(states, self._action_spec)
return tf.stop_gradient(actions)
def target_critic_net(self, states, actions, for_critic_loss=False):
"""Returns the output of the target critic network.
The target network is used to compute stable targets for training.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
actions: A [batch_size, num_action_dims] tensor representing a batch
of actions.
Returns:
q values: A [batch_size] tensor of q values.
Raises:
ValueError: If `states` or `actions` do not have the expected dimensions.
"""
self._validate_states(states)
self._validate_actions(actions)
return tf.stop_gradient(
self._target_critic_net(states, actions,
for_critic_loss=for_critic_loss))
def value_net(self, states, for_critic_loss=False):
"""Returns the output of the critic evaluated with the actor.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
q values: A [batch_size] tensor of q values.
"""
actions = self.actor_net(states)
return self.critic_net(states, actions,
for_critic_loss=for_critic_loss)
def target_value_net(self, states, for_critic_loss=False):
"""Returns the output of the target critic evaluated with the target actor.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
q values: A [batch_size] tensor of q values.
"""
target_actions = self.target_actor_net(states)
return self.target_critic_net(states, target_actions,
for_critic_loss=for_critic_loss)
def critic_loss(self, states, actions, rewards, discounts,
next_states):
"""Computes a loss for training the critic network.
The loss applies td_errors_loss (Huber loss by default) to the Q value
predictions of the critic and one-step TD targets computed from the target
networks.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
actions: A [batch_size, num_action_dims] tensor representing a batch
of actions.
rewards: A [batch_size, ...] tensor representing a batch of rewards,
broadcastable to the critic net output.
discounts: A [batch_size, ...] tensor representing a batch of discounts,
broadcastable to the critic net output.
next_states: A [batch_size, num_state_dims] tensor representing a batch
of next states.
Returns:
A rank-0 tensor representing the critic loss.
Raises:
ValueError: If any of the inputs do not have the expected dimensions, or
if their batch_sizes do not match.
"""
self._validate_states(states)
self._validate_actions(actions)
self._validate_states(next_states)
target_q_values = self.target_value_net(next_states, for_critic_loss=True)
td_targets = target_q_values * discounts + rewards
if self._target_q_clipping is not None:
td_targets = tf.clip_by_value(td_targets, self._target_q_clipping[0],
self._target_q_clipping[1])
q_values = self.critic_net(states, actions, for_critic_loss=True)
td_errors = td_targets - q_values
if self._debug_summaries:
gen_debug_td_error_summaries(
target_q_values, q_values, td_targets, td_errors)
loss = self._td_errors_loss(td_targets, q_values)
if self._residual_phi > 0.0: # compute residual gradient loss
residual_q_values = self.value_net(next_states, for_critic_loss=True)
residual_td_targets = residual_q_values * discounts + rewards
if self._target_q_clipping is not None:
residual_td_targets = tf.clip_by_value(residual_td_targets,
self._target_q_clipping[0],
self._target_q_clipping[1])
residual_td_errors = residual_td_targets - q_values
residual_loss = self._td_errors_loss(
residual_td_targets, residual_q_values)
loss = (loss * (1.0 - self._residual_phi) +
residual_loss * self._residual_phi)
return loss
def actor_loss(self, states):
"""Computes a loss for training the actor network.
Note that output does not represent an actual loss. It is called a loss only
in the sense that its gradient w.r.t. the actor network weights is the
correct gradient for training the actor network,
i.e. dloss/dweights = (dq/da)*(da/dweights)
which is the gradient used in Algorithm 1 of Lillicrap et al.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
A rank-0 tensor representing the actor loss.
Raises:
ValueError: If `states` does not have the expected dimensions.
"""
self._validate_states(states)
actions = self.actor_net(states, stop_gradients=False)
critic_values = self.critic_net(states, actions)
q_values = self.critic_function(critic_values, states)
dqda = tf.gradients([q_values], [actions])[0]
dqda_unclipped = dqda
if self._dqda_clipping > 0:
dqda = tf.clip_by_value(dqda, -self._dqda_clipping, self._dqda_clipping)
actions_norm = tf.norm(actions)
if self._debug_summaries:
with tf.name_scope('dqda'):
tf.summary.scalar('actions_norm', actions_norm)
tf.summary.histogram('dqda', dqda)
tf.summary.histogram('dqda_unclipped', dqda_unclipped)
tf.summary.histogram('actions', actions)
for a in range(self._num_action_dims):
tf.summary.histogram('dqda_unclipped_%d' % a, dqda_unclipped[:, a])
tf.summary.histogram('dqda_%d' % a, dqda[:, a])
actions_norm *= self._actions_regularizer
return slim.losses.mean_squared_error(tf.stop_gradient(dqda + actions),
actions,
scope='actor_loss') + actions_norm
@gin.configurable('ddpg_critic_function')
def critic_function(self, critic_values, states, weights=None):
"""Computes q values based on critic_net outputs, states, and weights.
Args:
critic_values: A tf.float32 [batch_size, ...] tensor representing outputs
from the critic net.
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
weights: A list or Numpy array or tensor with a shape broadcastable to
`critic_values`.
Returns:
A tf.float32 [batch_size] tensor representing q values.
"""
del states # unused args
if weights is not None:
weights = tf.convert_to_tensor(weights, dtype=critic_values.dtype)
critic_values *= weights
if critic_values.shape.ndims > 1:
critic_values = tf.reduce_sum(critic_values,
range(1, critic_values.shape.ndims))
critic_values.shape.assert_has_rank(1)
return critic_values
@gin.configurable('ddpg_update_targets')
def update_targets(self, tau=1.0):
"""Performs a soft update of the target network parameters.
For each weight w_s in the actor/critic networks, and its corresponding
weight w_t in the target actor/critic networks, a soft update is:
w_t = (1 - tau) * w_t + tau * w_s
Args:
tau: A float scalar in [0, 1]
Returns:
An operation that performs a soft update of the target network parameters.
Raises:
ValueError: If `tau` is not in [0, 1].
"""
if tau < 0 or tau > 1:
raise ValueError('Input `tau` should be in [0, 1].')
update_actor = utils.soft_variables_update(
slim.get_trainable_variables(
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
slim.get_trainable_variables(
utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
tau)
update_critic = utils.soft_variables_update(
slim.get_trainable_variables(
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
slim.get_trainable_variables(
utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
tau)
return tf.group(update_actor, update_critic, name='update_targets')
def get_trainable_critic_vars(self):
"""Returns a list of trainable variables in the critic network.
Returns:
A list of trainable variables in the critic network.
"""
return slim.get_trainable_variables(
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))
def get_trainable_actor_vars(self):
"""Returns a list of trainable variables in the actor network.
Returns:
A list of trainable variables in the actor network.
"""
return slim.get_trainable_variables(
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))
def get_critic_vars(self):
"""Returns a list of all variables in the critic network.
Returns:
A list of model variables in the critic network.
"""
return slim.get_model_variables(
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))
def get_actor_vars(self):
"""Returns a list of all variables in the actor network.
Returns:
A list of model variables in the actor network.
"""
return slim.get_model_variables(
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))
def _validate_states(self, states):
"""Raises a value error if `states` does not have the expected shape.
Args:
states: A tensor.
Raises:
ValueError: If states.shape or states.dtype are not compatible with
observation_spec.
"""
states.shape.assert_is_compatible_with(self._state_shape)
if not states.dtype.is_compatible_with(self._observation_spec.dtype):
raise ValueError('states.dtype={} is not compatible with'
' observation_spec.dtype={}'.format(
states.dtype, self._observation_spec.dtype))
def _validate_actions(self, actions):
"""Raises a value error if `actions` does not have the expected shape.
Args:
actions: A tensor.
Raises:
ValueError: If actions.shape or actions.dtype are not compatible with
action_spec.
"""
actions.shape.assert_is_compatible_with(self._action_shape)
if not actions.dtype.is_compatible_with(self._action_spec.dtype):
raise ValueError('actions.dtype={} is not compatible with'
' action_spec.dtype={}'.format(
actions.dtype, self._action_spec.dtype))
@gin.configurable
class TD3Agent(DdpgAgent):
"""An RL agent that learns using the TD3 algorithm."""
ACTOR_NET_SCOPE = 'actor_net'
CRITIC_NET_SCOPE = 'critic_net'
CRITIC_NET2_SCOPE = 'critic_net2'
TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
TARGET_CRITIC_NET_SCOPE = 'target_critic_net'
TARGET_CRITIC_NET2_SCOPE = 'target_critic_net2'
def __init__(self,
observation_spec,
action_spec,
actor_net=networks.actor_net,
critic_net=networks.critic_net,
td_errors_loss=tf.losses.huber_loss,
dqda_clipping=0.,
actions_regularizer=0.,
target_q_clipping=None,
residual_phi=0.0,
debug_summaries=False):
"""Constructs a TD3 agent.
Args:
observation_spec: A TensorSpec defining the observations.
action_spec: A BoundedTensorSpec defining the actions.
actor_net: A callable that creates the actor network. Must take the
following arguments: states, num_actions. Please see networks.actor_net
for an example.
critic_net: A callable that creates the critic network. Must take the
following arguments: states, actions. Please see networks.critic_net
for an example.
td_errors_loss: A callable defining the loss function for the critic
td error.
dqda_clipping: (float) clips the gradient dqda element-wise between
[-dqda_clipping, dqda_clipping]. Does not perform clipping if
dqda_clipping == 0.
actions_regularizer: A scalar, when positive penalizes the norm of the
actions. This can prevent saturation of actions for the actor_loss.
target_q_clipping: (tuple of floats) clips target q values within
(low, high) values when computing the critic loss.
residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
interpolates between Q-learning and residual gradient algorithm.
http://www.leemon.com/papers/1995b.pdf
debug_summaries: If True, add summaries to help debug behavior.
Raises:
ValueError: If 'dqda_clipping' is < 0.
"""
self._observation_spec = observation_spec[0]
self._action_spec = action_spec[0]
self._state_shape = tf.TensorShape([None]).concatenate(
self._observation_spec.shape)
self._action_shape = tf.TensorShape([None]).concatenate(
self._action_spec.shape)
self._num_action_dims = self._action_spec.shape.num_elements()
self._scope = tf.get_variable_scope().name
self._actor_net = tf.make_template(
self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
self._critic_net = tf.make_template(
self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
self._critic_net2 = tf.make_template(
self.CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
self._target_actor_net = tf.make_template(
self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
self._target_critic_net = tf.make_template(
self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
self._target_critic_net2 = tf.make_template(
self.TARGET_CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
self._td_errors_loss = td_errors_loss
if dqda_clipping < 0:
raise ValueError('dqda_clipping must be >= 0.')
self._dqda_clipping = dqda_clipping
self._actions_regularizer = actions_regularizer
self._target_q_clipping = target_q_clipping
self._residual_phi = residual_phi
self._debug_summaries = debug_summaries
def get_trainable_critic_vars(self):
"""Returns a list of trainable variables in the critic network.
NOTE: This gets the vars of both critic networks.
Returns:
A list of trainable variables in the critic network.
"""
return (
slim.get_trainable_variables(
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)))
def critic_net(self, states, actions, for_critic_loss=False):
"""Returns the output of the critic network.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
actions: A [batch_size, num_action_dims] tensor representing a batch
of actions.
Returns:
q values: A [batch_size] tensor of q values.
Raises:
ValueError: If `states` or `actions` do not have the expected dimensions.
"""
values1 = self._critic_net(states, actions,
for_critic_loss=for_critic_loss)
values2 = self._critic_net2(states, actions,
for_critic_loss=for_critic_loss)
if for_critic_loss:
return values1, values2
return values1
def target_critic_net(self, states, actions, for_critic_loss=False):
"""Returns the output of the target critic network.
The target network is used to compute stable targets for training.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
actions: A [batch_size, num_action_dims] tensor representing a batch
of actions.
Returns:
q values: A [batch_size] tensor of q values.
Raises:
ValueError: If `states` or `actions` do not have the expected dimensions.
"""
self._validate_states(states)
self._validate_actions(actions)
values1 = tf.stop_gradient(
self._target_critic_net(states, actions,
for_critic_loss=for_critic_loss))
values2 = tf.stop_gradient(
self._target_critic_net2(states, actions,
for_critic_loss=for_critic_loss))
if for_critic_loss:
return values1, values2
return values1
def value_net(self, states, for_critic_loss=False):
"""Returns the output of the critic evaluated with the actor.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
q values: A [batch_size] tensor of q values.
"""
actions = self.actor_net(states)
return self.critic_net(states, actions,
for_critic_loss=for_critic_loss)
def target_value_net(self, states, for_critic_loss=False):
"""Returns the output of the target critic evaluated with the target actor.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
q values: A [batch_size] tensor of q values.
"""
target_actions = self.target_actor_net(states)
noise = tf.clip_by_value(
tf.random_normal(tf.shape(target_actions), stddev=0.2), -0.5, 0.5)
values1, values2 = self.target_critic_net(
states, target_actions + noise,
for_critic_loss=for_critic_loss)
values = tf.minimum(values1, values2)
return values, values
@gin.configurable('td3_update_targets')
def update_targets(self, tau=1.0):
"""Performs a soft update of the target network parameters.
For each weight w_s in the actor/critic networks, and its corresponding
weight w_t in the target actor/critic networks, a soft update is:
w_t = (1 - tau) * w_t + tau * w_s
Args:
tau: A float scalar in [0, 1]
Returns:
An operation that performs a soft update of the target network parameters.
Raises:
ValueError: If `tau` is not in [0, 1].
"""
if tau < 0 or tau > 1:
raise ValueError('Input `tau` should be in [0, 1].')
update_actor = utils.soft_variables_update(
slim.get_trainable_variables(
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
slim.get_trainable_variables(
utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
tau)
# NOTE: This updates both critic networks.
update_critic = utils.soft_variables_update(
slim.get_trainable_variables(
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
slim.get_trainable_variables(
utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
tau)
return tf.group(update_actor, update_critic, name='update_targets')
def gen_debug_td_error_summaries(
target_q_values, q_values, td_targets, td_errors):
"""Generates debug summaries for critic given a set of batch samples.
Args:
target_q_values: set of predicted next stage values.
q_values: current predicted value for the critic network.
td_targets: discounted target_q_values with added next stage reward.
td_errors: the difference between td_targets and q_values.
"""
with tf.name_scope('td_errors'):
tf.summary.histogram('td_targets', td_targets)
tf.summary.histogram('q_values', q_values)
tf.summary.histogram('target_q_values', target_q_values)
tf.summary.histogram('td_errors', td_errors)
with tf.name_scope('td_targets'):
tf.summary.scalar('mean', tf.reduce_mean(td_targets))
tf.summary.scalar('max', tf.reduce_max(td_targets))
tf.summary.scalar('min', tf.reduce_min(td_targets))
with tf.name_scope('q_values'):
tf.summary.scalar('mean', tf.reduce_mean(q_values))
tf.summary.scalar('max', tf.reduce_max(q_values))
tf.summary.scalar('min', tf.reduce_min(q_values))
with tf.name_scope('target_q_values'):
tf.summary.scalar('mean', tf.reduce_mean(target_q_values))
tf.summary.scalar('max', tf.reduce_max(target_q_values))
tf.summary.scalar('min', tf.reduce_min(target_q_values))
with tf.name_scope('td_errors'):
tf.summary.scalar('mean', tf.reduce_mean(td_errors))
tf.summary.scalar('max', tf.reduce_max(td_errors))
tf.summary.scalar('min', tf.reduce_min(td_errors))
tf.summary.scalar('mean_abs', tf.reduce_mean(tf.abs(td_errors)))
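The two numerical rules at the core of the agents above, the soft target update in `update_targets` and the TD targets built in `critic_loss`, can be checked by hand on small inputs. The sketch below is a pure-Python illustration under simplifying assumptions, not the agents' TensorFlow code: `soft_update` and `td_targets` are hypothetical helper names, weights and Q values are plain floats, and the clipped target-action noise that `TD3Agent.target_value_net` adds is omitted.

```python
def soft_update(target_weights, source_weights, tau):
    """Returns (1 - tau) * w_t + tau * w_s for each (w_t, w_s) pair."""
    if tau < 0 or tau > 1:
        raise ValueError('Input `tau` should be in [0, 1].')
    return [(1.0 - tau) * w_t + tau * w_s
            for w_t, w_s in zip(target_weights, source_weights)]


def td_targets(rewards, discounts, next_q1, next_q2=None, clipping=None):
    """One-step TD targets: reward + discount * target Q of the next state.

    If next_q2 is given, the element-wise minimum of the two target critics
    is used, as in TD3. If clipping=(low, high) is given, targets are clipped
    to that range, mirroring target_q_clipping.
    """
    targets = []
    for i, (reward, discount) in enumerate(zip(rewards, discounts)):
        q = next_q1[i]
        if next_q2 is not None:
            q = min(q, next_q2[i])
        target = reward + discount * q
        if clipping is not None:
            low, high = clipping
            target = max(low, min(high, target))
        targets.append(target)
    return targets


# tau=0.25 moves each target weight a quarter of the way to the source weight.
new_target = soft_update([0.0, 2.0], [1.0, 4.0], tau=0.25)

# DDPG-style targets; a discount of 0 (terminal step) leaves only the reward.
ddpg = td_targets([1.0, 0.0], [0.5, 0.0], [10.0, 10.0])

# TD3-style targets with a second critic and clipping to [-3, 3].
td3 = td_targets([1.0, 0.0], [0.5, 0.0], [10.0, 10.0],
                 next_q2=[8.0, 12.0], clipping=(-3.0, 3.0))

print(new_target, ddpg, td3)
```

With `tau=1.0`, the default in `update_targets`, the soft update reduces to a hard copy of the source weights into the target network.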
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sample actor(policy) and critic(q) networks to use with DDPG/NAF agents.
The DDPG networks are defined in "Section 7: Experiment Details" of
"Continuous control with deep reinforcement learning" - Lillicrap et al.
https://arxiv.org/abs/1509.02971
The NAF critic network is based on "Section 4" of "Continuous deep Q-learning
with model-based acceleration" - Gu et al. https://arxiv.org/pdf/1603.00748.
"""
import tensorflow as tf
slim = tf.contrib.slim
import gin.tf
@gin.configurable('ddpg_critic_net')
def critic_net(states, actions,
for_critic_loss=False,
num_reward_dims=1,
states_hidden_layers=(400,),
actions_hidden_layers=None,
joint_hidden_layers=(300,),
weight_decay=0.0001,
normalizer_fn=None,
activation_fn=tf.nn.relu,
zero_obs=False,
images=False):
"""Creates a critic that returns q values for the given states and actions.
Args:
states: (castable to tf.float32) a [batch_size, num_state_dims] tensor
representing a batch of states.
actions: (castable to tf.float32) a [batch_size, num_action_dims] tensor
representing a batch of actions.
num_reward_dims: Number of reward dimensions.
states_hidden_layers: tuple of hidden layers units for states.
actions_hidden_layers: tuple of hidden layers units for actions.
joint_hidden_layers: tuple of hidden layers units after joining states
and actions using tf.concat().
weight_decay: Weight decay for l2 weights regularizer.
normalizer_fn: Normalizer function, e.g. slim.layer_norm.
activation_fn: Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
Returns:
A tf.float32 [batch_size] tensor of q values, or a tf.float32
[batch_size, num_reward_dims] tensor of vector q values if
num_reward_dims > 1.
"""
with slim.arg_scope(
[slim.fully_connected],
activation_fn=activation_fn,
normalizer_fn=normalizer_fn,
weights_regularizer=slim.l2_regularizer(weight_decay),
weights_initializer=slim.variance_scaling_initializer(
factor=1.0/3.0, mode='FAN_IN', uniform=True)):
orig_states = tf.to_float(states)
#states = tf.to_float(states)
states = tf.concat([tf.to_float(states), tf.to_float(actions)], -1) #TD3
if images or zero_obs:
states *= tf.constant([0.0] * 2 + [1.0] * (states.shape[1] - 2)) #LALA
actions = tf.to_float(actions)
if states_hidden_layers:
states = slim.stack(states, slim.fully_connected, states_hidden_layers,
scope='states')
if actions_hidden_layers:
actions = slim.stack(actions, slim.fully_connected, actions_hidden_layers,
scope='actions')
joint = tf.concat([states, actions], 1)
if joint_hidden_layers:
joint = slim.stack(joint, slim.fully_connected, joint_hidden_layers,
scope='joint')
with slim.arg_scope([slim.fully_connected],
weights_regularizer=None,
weights_initializer=tf.random_uniform_initializer(
minval=-0.003, maxval=0.003)):
value = slim.fully_connected(joint, num_reward_dims,
activation_fn=None,
normalizer_fn=None,
scope='q_value')
if num_reward_dims == 1:
value = tf.reshape(value, [-1])
if not for_critic_loss and num_reward_dims > 1:
value = tf.reduce_sum(
value * tf.abs(orig_states[:, -num_reward_dims:]), -1)
return value

@gin.configurable('ddpg_actor_net')
def actor_net(states, action_spec,
              hidden_layers=(400, 300),
              normalizer_fn=None,
              activation_fn=tf.nn.relu,
              zero_obs=False,
              images=False):
  """Creates an actor that returns actions for the given states.

  Args:
    states: (castable to tf.float32) a [batch_size, num_state_dims] tensor
      representing a batch of states.
    action_spec: (BoundedTensorSpec) A tensor spec indicating the shape
      and range of actions.
    hidden_layers: tuple of hidden layers units.
    normalizer_fn: Normalizer function, e.g. slim.layer_norm.
    activation_fn: Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
    zero_obs: If True, zero out the first two state dimensions (x, y position).
    images: If True, zero out the first two state dimensions as well.
  Returns:
    A tf.float32 [batch_size, num_action_dims] tensor of actions.
  """
  with slim.arg_scope(
      [slim.fully_connected],
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      weights_initializer=slim.variance_scaling_initializer(
          factor=1.0/3.0, mode='FAN_IN', uniform=True)):

    states = tf.to_float(states)
    orig_states = states
    if images or zero_obs:  # Zero-out x, y position. Hacky.
      states *= tf.constant([0.0] * 2 + [1.0] * (states.shape[1] - 2))
    if hidden_layers:
      states = slim.stack(states, slim.fully_connected, hidden_layers,
                          scope='states')
    with slim.arg_scope([slim.fully_connected],
                        weights_initializer=tf.random_uniform_initializer(
                            minval=-0.003, maxval=0.003)):
      actions = slim.fully_connected(states,
                                     action_spec.shape.num_elements(),
                                     scope='actions',
                                     normalizer_fn=None,
                                     activation_fn=tf.nn.tanh)
      # Rescale the tanh output from [-1, 1] to the action bounds.
      action_means = (action_spec.maximum + action_spec.minimum) / 2.0
      action_magnitudes = (action_spec.maximum - action_spec.minimum) / 2.0
      actions = action_means + action_magnitudes * actions

  return actions
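
The last lines of `actor_net` map the tanh output, which lies in [-1, 1], onto the bounds given by `action_spec`. The sketch below reproduces only that arithmetic in plain Python, with lists standing in for tensors. The helper name `rescale_action` and the bounds in the usage line are hypothetical and chosen for illustration.

```python
def rescale_action(tanh_output, minimum, maximum):
  """Maps per-dimension values in [-1, 1] onto [minimum, maximum].

  Hypothetical helper: mirrors the mean/magnitude arithmetic at the end of
  actor_net, using Python lists instead of TensorFlow tensors.
  """
  scaled = []
  for out, low, high in zip(tanh_output, minimum, maximum):
    mean = (high + low) / 2.0        # centre of the allowed range
    magnitude = (high - low) / 2.0   # half-width of the allowed range
    scaled.append(mean + magnitude * out)
  return scaled


# Made-up bounds; -1 maps to the minimum, 0 to the centre, +1 to the maximum.
example = rescale_action([-1.0, 0.0, 1.0], [-2.0, 0.0, 10.0], [2.0, 4.0, 30.0])
print(example)
```

Because tanh saturates at the ends of [-1, 1], the rescaled action never leaves the range declared in the spec.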
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines many boolean functions indicating when to step and reset.
"""
import tensorflow as tf
import gin.tf

@gin.configurable
def env_transition(agent, state, action, transition_type, environment_steps,
                   num_episodes):
  """True if the transition_type is TRANSITION or FINAL_TRANSITION.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
  Returns:
    cond: Returns an op that evaluates to true if the transition type is
    not RESTARTING
  """
  del agent, state, action, num_episodes, environment_steps
  cond = tf.logical_not(transition_type)
  return cond

@gin.configurable
def env_restart(agent, state, action, transition_type, environment_steps,
                num_episodes):
  """True if the transition_type is RESTARTING.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
  Returns:
    cond: Returns an op that evaluates to true if the transition type equals
    RESTARTING.
  """
  del agent, state, action, num_episodes, environment_steps
  cond = tf.identity(transition_type)
  return cond

@gin.configurable
def every_n_steps(agent,
                  state,
                  action,
                  transition_type,
                  environment_steps,
                  num_episodes,
                  n=150):
  """True once every n steps.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
    n: Return true once every n steps.
  Returns:
    cond: Returns an op that evaluates to true if environment_steps
    equals 0 mod n. We increment the step before checking this condition, so
    we do not need to add one to environment_steps.
  """
  del agent, state, action, transition_type, num_episodes
  cond = tf.equal(tf.mod(environment_steps, n), 0)
  return cond

@gin.configurable
def every_n_episodes(agent,
                     state,
                     action,
                     transition_type,
                     environment_steps,
                     num_episodes,
                     n=2,
                     steps_per_episode=None):
  """True once every n episodes, or whenever the ant has fallen.

  Specifically, evaluates to True on the 0th step of every nth episode.
  Unlike environment_steps, num_episodes starts at 0, so we do want to add
  one to ensure it does not reset on the first call.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
    n: Return true once every n episodes.
    steps_per_episode: How many steps per episode. Needed to determine when a
      new episode starts.
  Returns:
    cond: Returns an op that evaluates to true at an episode boundary
    (environment_steps equals 0 mod steps_per_episode) if either
    num_episodes + 1 equals 0 mod n or state[2] lies outside [0.2, 1.0].
  """
  assert steps_per_episode is not None
  del agent, action, transition_type
  # state[2] outside [0.2, 1.0] is treated as the ant having fallen.
  ant_fell = tf.logical_or(state[2] < 0.2, state[2] > 1.0)
  cond = tf.logical_and(
      tf.logical_or(
          ant_fell,
          tf.equal(tf.mod(num_episodes + 1, n), 0)),
      tf.equal(tf.mod(environment_steps, steps_per_episode), 0))
  return cond

@gin.configurable
def failed_reset_after_n_episodes(agent,
                                  state,
                                  action,
                                  transition_type,
                                  environment_steps,
                                  num_episodes,
                                  steps_per_episode=None,
                                  reset_state=None,
                                  max_dist=1.0,
                                  epsilon=1e-10):
  """True at the end of an episode if the reset agent failed to return.

  Specifically, evaluates to True if the distance between the state and the
  reset state is greater than max_dist at the end of the episode.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
    steps_per_episode: How many steps per episode. Needed to determine when a
      new episode starts.
    reset_state: State to which the reset controller should return.
    max_dist: Agent is considered to have successfully reset if its distance
      from the reset_state is less than max_dist.
    epsilon: small offset that keeps the argument of the square root positive.
  Returns:
    cond: Returns an op that evaluates to true if environment_steps equals 0
    mod steps_per_episode and the distance between state and reset_state is
    greater than max_dist.
  """
  assert steps_per_episode is not None
  assert reset_state is not None
  # state is used below to compute the distance, so it must not be deleted.
  del agent, action, transition_type, num_episodes
  dist = tf.sqrt(
      tf.reduce_sum(tf.squared_difference(state, reset_state)) + epsilon)
  cond = tf.logical_and(
      tf.greater(dist, tf.constant(max_dist)),
      tf.equal(tf.mod(environment_steps, steps_per_episode), 0))
  return cond

@gin.configurable
def q_too_small(agent,
                state,
                action,
                transition_type,
                environment_steps,
                num_episodes,
                q_min=0.5):
  """True if q is too small.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
    q_min: Returns true if the qval is less than q_min
  Returns:
    cond: Returns an op that evaluates to true if qval is less than q_min.
  """
  del transition_type, environment_steps, num_episodes
  # Assumed intent: replace the last state entry with 0. The original call,
  # tf.stack(state[:-1], tf.constant([0], dtype=tf.float)), passes the constant
  # as the axis argument and refers to the non-existent dtype tf.float.
  state_for_reset_agent = tf.concat(
      [state[:-1], tf.constant([0], dtype=tf.float32)], 0)
  qval = agent.BASE_AGENT_CLASS.critic_net(
      tf.expand_dims(state_for_reset_agent, 0), tf.expand_dims(action, 0))[0, :]
  cond = tf.greater(tf.constant(q_min), qval)
  return cond

@gin.configurable
def true_fn(agent, state, action, transition_type, environment_steps,
            num_episodes):
  """Returns an op that evaluates to true.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
  Returns:
    cond: op that always evaluates to True.
  """
  del agent, state, action, transition_type, environment_steps, num_episodes
  cond = tf.constant(True, dtype=tf.bool)
  return cond

@gin.configurable
def false_fn(agent, state, action, transition_type, environment_steps,
             num_episodes):
  """Returns an op that evaluates to false.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
  Returns:
    cond: op that always evaluates to False.
  """
  del agent, state, action, transition_type, environment_steps, num_episodes
  cond = tf.constant(False, dtype=tf.bool)
  return cond
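
The step and episode conditions above are easier to check without a TensorFlow graph. The sketch below restates `every_n_steps` and `every_n_episodes` with plain Python numbers. The reduced signatures are hypothetical, and `z` stands in for `state[2]`, which the variable name `ant_fell` suggests is the torso height (an assumption, not something this file states).

```python
def every_n_steps(environment_steps, n=150):
  """True when environment_steps is a multiple of n."""
  return environment_steps % n == 0


def every_n_episodes(z, environment_steps, num_episodes, n, steps_per_episode):
  """True at an episode boundary if the ant fell or the nth episode ended.

  z is a stand-in for state[2] (assumed to be the torso height).
  """
  ant_fell = z < 0.2 or z > 1.0
  nth_episode = (num_episodes + 1) % n == 0
  at_boundary = environment_steps % steps_per_episode == 0
  return (ant_fell or nth_episode) and at_boundary


# With n=2 and 500 steps per episode, the reset fires after every second
# episode, or earlier at a boundary if z has left the range [0.2, 1.0].
print(every_n_episodes(0.5, 500, 0, 2, 500))   # upright, first episode
print(every_n_episodes(0.5, 1000, 1, 2, 500))  # upright, second episode
print(every_n_episodes(1.5, 500, 0, 2, 500))   # fallen, first episode
```

Note that a fall only triggers a reset at an episode boundary; mid-episode steps never satisfy the second clause.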
#-*-Python-*-
import gin.tf.external_configurables
create_maze_env.top_down_view = %IMAGES
## Create the agent
AGENT_CLASS = @UvfAgent
UvfAgent.tf_context = %CONTEXT
UvfAgent.actor_net = @agent/ddpg_actor_net
UvfAgent.critic_net = @agent/ddpg_critic_net
UvfAgent.dqda_clipping = 0.0
UvfAgent.td_errors_loss = @tf.losses.huber_loss
UvfAgent.target_q_clipping = %TARGET_Q_CLIPPING
# Create meta agent
META_CLASS = @MetaAgent
MetaAgent.tf_context = %META_CONTEXT
MetaAgent.sub_context = %CONTEXT
MetaAgent.actor_net = @meta/ddpg_actor_net
MetaAgent.critic_net = @meta/ddpg_critic_net
MetaAgent.dqda_clipping = 0.0
MetaAgent.td_errors_loss = @tf.losses.huber_loss
MetaAgent.target_q_clipping = %TARGET_Q_CLIPPING
# Create state preprocess
STATE_PREPROCESS_CLASS = @StatePreprocess
StatePreprocess.ndims = %SUBGOAL_DIM
state_preprocess_net.states_hidden_layers = (100, 100)
state_preprocess_net.num_output_dims = %SUBGOAL_DIM
state_preprocess_net.images = %IMAGES
action_embed_net.num_output_dims = %SUBGOAL_DIM
INVERSE_DYNAMICS_CLASS = @InverseDynamics
# actor_net
ACTOR_HIDDEN_SIZE_1 = 300
ACTOR_HIDDEN_SIZE_2 = 300
agent/ddpg_actor_net.hidden_layers = (%ACTOR_HIDDEN_SIZE_1, %ACTOR_HIDDEN_SIZE_2)
agent/ddpg_actor_net.activation_fn = @tf.nn.relu
agent/ddpg_actor_net.zero_obs = %ZERO_OBS
agent/ddpg_actor_net.images = %IMAGES
meta/ddpg_actor_net.hidden_layers = (%ACTOR_HIDDEN_SIZE_1, %ACTOR_HIDDEN_SIZE_2)
meta/ddpg_actor_net.activation_fn = @tf.nn.relu
meta/ddpg_actor_net.zero_obs = False
meta/ddpg_actor_net.images = %IMAGES
# critic_net
CRITIC_HIDDEN_SIZE_1 = 300
CRITIC_HIDDEN_SIZE_2 = 300
agent/ddpg_critic_net.states_hidden_layers = (%CRITIC_HIDDEN_SIZE_1,)
agent/ddpg_critic_net.actions_hidden_layers = None
agent/ddpg_critic_net.joint_hidden_layers = (%CRITIC_HIDDEN_SIZE_2,)
agent/ddpg_critic_net.weight_decay = 0.0
agent/ddpg_critic_net.activation_fn = @tf.nn.relu
agent/ddpg_critic_net.zero_obs = %ZERO_OBS
agent/ddpg_critic_net.images = %IMAGES
meta/ddpg_critic_net.states_hidden_layers = (%CRITIC_HIDDEN_SIZE_1,)
meta/ddpg_critic_net.actions_hidden_layers = None
meta/ddpg_critic_net.joint_hidden_layers = (%CRITIC_HIDDEN_SIZE_2,)
meta/ddpg_critic_net.weight_decay = 0.0
meta/ddpg_critic_net.activation_fn = @tf.nn.relu
meta/ddpg_critic_net.zero_obs = False
meta/ddpg_critic_net.images = %IMAGES
tf.losses.huber_loss.delta = 1.0
# Sample action
uvf_add_noise_fn.stddev = 1.0
meta_add_noise_fn.stddev = %META_EXPLORE_NOISE
# Update targets
ddpg_update_targets.tau = 0.001
td3_update_targets.tau = 0.005
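
The two `tau` bindings above are not explained in the config. Assuming they are the interpolation coefficient of the usual DDPG/TD3 soft target update, target = (1 - tau) * target + tau * source (the update functions themselves are not shown here), a minimal plain-Python sketch looks like this. `soft_update` is a hypothetical name and weights are flat lists of floats.

```python
def soft_update(target_weights, source_weights, tau):
  """Moves each target weight a fraction tau towards the source weight.

  Hypothetical helper illustrating a standard soft (Polyak) update; it is
  not the update function bound in the config above.
  """
  return [(1.0 - tau) * target + tau * source
          for target, source in zip(target_weights, source_weights)]


# A small tau means the target network trails the trained network slowly.
target = [0.0, 1.0]
source = [1.0, 1.0]
for _ in range(3):
  target = soft_update(target, source, tau=0.005)
print(target)
```

Under this reading, the larger TD3 value (0.005) lets the target track the online network about five times faster per update than the DDPG value (0.001).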
#-*-Python-*-
# Config eval
evaluate.environment = @create_maze_env()
evaluate.agent_class = %AGENT_CLASS
evaluate.meta_agent_class = %META_CLASS
evaluate.state_preprocess_class = %STATE_PREPROCESS_CLASS
evaluate.num_episodes_eval = 50
evaluate.num_episodes_videos = 1
evaluate.gamma = 1.0
evaluate.eval_interval_secs = 1
evaluate.generate_videos = False
evaluate.generate_summaries = True
evaluate.eval_modes = %EVAL_MODES
evaluate.max_steps_per_episode = %RESET_EPISODE_PERIOD
#-*-Python-*-
# Create replay_buffer
agent/CircularBuffer.buffer_size = 200000
meta/CircularBuffer.buffer_size = 200000
agent/CircularBuffer.scope = "agent"
meta/CircularBuffer.scope = "meta"
# Config train
train_uvf.environment = @create_maze_env()
train_uvf.agent_class = %AGENT_CLASS
train_uvf.meta_agent_class = %META_CLASS
train_uvf.state_preprocess_class = %STATE_PREPROCESS_CLASS
train_uvf.inverse_dynamics_class = %INVERSE_DYNAMICS_CLASS
train_uvf.replay_buffer = @agent/CircularBuffer()
train_uvf.meta_replay_buffer = @meta/CircularBuffer()
train_uvf.critic_optimizer = @critic/AdamOptimizer()
train_uvf.actor_optimizer = @actor/AdamOptimizer()
train_uvf.meta_critic_optimizer = @meta_critic/AdamOptimizer()
train_uvf.meta_actor_optimizer = @meta_actor/AdamOptimizer()
train_uvf.repr_optimizer = @repr/AdamOptimizer()
train_uvf.num_episodes_train = 25000
train_uvf.batch_size = 100
train_uvf.initial_episodes = 5
train_uvf.gamma = 0.99
train_uvf.meta_gamma = 0.99
train_uvf.reward_scale_factor = 1.0
train_uvf.target_update_period = 2
train_uvf.num_updates_per_observation = 1
train_uvf.num_collect_per_update = 1
train_uvf.num_collect_per_meta_update = 10
train_uvf.debug_summaries = False
train_uvf.log_every_n_steps = 1000
train_uvf.save_policy_every_n_steps = 100000
# Config Optimizers
critic/AdamOptimizer.learning_rate = 0.001
critic/AdamOptimizer.beta1 = 0.9
critic/AdamOptimizer.beta2 = 0.999
actor/AdamOptimizer.learning_rate = 0.0001
actor/AdamOptimizer.beta1 = 0.9
actor/AdamOptimizer.beta2 = 0.999
meta_critic/AdamOptimizer.learning_rate = 0.001
meta_critic/AdamOptimizer.beta1 = 0.9
meta_critic/AdamOptimizer.beta2 = 0.999
meta_actor/AdamOptimizer.learning_rate = 0.0001
meta_actor/AdamOptimizer.beta1 = 0.9
meta_actor/AdamOptimizer.beta2 = 0.999
repr/AdamOptimizer.learning_rate = 0.0001
repr/AdamOptimizer.beta1 = 0.9
repr/AdamOptimizer.beta2 = 0.999
#-*-Python-*-
create_maze_env.env_name = "AntBlock"
ZERO_OBS = False
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4), (20, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1", "eval2", "eval3"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
"eval2": [@eval2/ConstantSampler],
"eval3": [@eval3/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [3, 4]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [16, 0]
eval2/ConstantSampler.value = [16, 16]
eval3/ConstantSampler.value = [0, 16]
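
The config above binds `relative_context_transition_fn` and two `negative_distance` rewards, but their implementations are not part of this section. The sketch below follows the definitions in the HIRO paper, where a relative subgoal is carried forward as s_t + g_t - s_{t+1}, and treats the task reward as the negative Euclidean distance between selected state entries and an absolute goal. Function names and list-based signatures are hypothetical.

```python
import math


def relative_goal_transition(state, goal, next_state, k):
  """Carries a relative goal forward: g' = s[:k] + g - s'[:k].

  Hypothetical helper; k plays the role of the `k` binding (SUBGOAL_DIM).
  """
  return [s + g - ns for s, g, ns in zip(state[:k], goal, next_state[:k])]


def negative_distance(state, goal, state_indices):
  """Negative Euclidean distance from state[state_indices] to an absolute goal."""
  squared = sum((state[i] - g) ** 2 for i, g in zip(state_indices, goal))
  return -math.sqrt(squared)


# The agent moved by (1, 1), so a relative goal of (3, 4) shrinks to (2, 3).
print(relative_goal_transition([0.0, 0.0, 5.0], [3.0, 4.0], [1.0, 1.0, 5.0], 2))
# With state_indices = [3, 4], only entries 3 and 4 of the state are compared.
print(negative_distance([9.0, 9.0, 9.0, 3.0, 4.0], [0.0, 0.0], [3, 4]))
```

The relative transition keeps the absolute target position fixed while the agent moves, which is why the lower-level goal changes at every step even between meta actions.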
#-*-Python-*-
create_maze_env.env_name = "AntBlockMaze"
ZERO_OBS = False
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4), (12, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1", "eval2", "eval3"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
"eval2": [@eval2/ConstantSampler],
"eval3": [@eval3/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [3, 4]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [8, 0]
eval2/ConstantSampler.value = [8, 16]
eval3/ConstantSampler.value = [0, 16]
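
`train/RandomSampler.context_range` is bound to `meta_context_range`, a pair of per-dimension lower and upper bounds. Assuming the sampler draws each goal coordinate uniformly between those bounds (the sampler code is not shown in this section), the training goals can be sketched as follows. `sample_goal` is a hypothetical name.

```python
import random


def sample_goal(context_range, rng):
  """Draws one goal with each coordinate uniform in [low, high].

  Hypothetical helper; context_range is ((low_0, low_1, ...), (high_0, ...)).
  """
  lows, highs = context_range
  return [rng.uniform(low, high) for low, high in zip(lows, highs)]


# Range copied from the AntBlockMaze config above; the seed is arbitrary.
goal = sample_goal(((-4, -4), (12, 20)), random.Random(0))
print(goal)
```

The evaluation modes, by contrast, use `ConstantSampler` with fixed goals, so evaluation results are comparable across runs.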
#-*-Python-*-
create_maze_env.env_name = "AntFall"
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4, 0), (12, 28, 5))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [3]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1, 2]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [0, 27, 4.5]
#-*-Python-*-
create_maze_env.env_name = "AntFall"
IMAGES = True
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4, 0), (12, 28, 5))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [3]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
}
meta/Context.context_transition_fn = @task/relative_context_transition_fn
meta/Context.context_multi_transition_fn = @task/relative_context_multi_transition_fn
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1, 2]
task/negative_distance.relative_context = True
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
task/relative_context_transition_fn.k = 3
task/relative_context_multi_transition_fn.k = 3
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [0, 27, 0]
#-*-Python-*-
create_maze_env.env_name = "AntFall"
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4, 0), (12, 28, 5))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [3]
meta/Context.samplers = {
"train": [@eval1/ConstantSampler],
"explore": [@eval1/ConstantSampler],
"eval1": [@eval1/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1, 2]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [0, 27, 4.5]
#-*-Python-*-
create_maze_env.env_name = "AntMaze"
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4), (20, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1", "eval2", "eval3"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
"eval2": [@eval2/ConstantSampler],
"eval3": [@eval3/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [16, 0]
eval2/ConstantSampler.value = [16, 16]
eval3/ConstantSampler.value = [0, 16]
#-*-Python-*-
create_maze_env.env_name = "AntMaze"
IMAGES = True
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4), (20, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1", "eval2", "eval3"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
"eval2": [@eval2/ConstantSampler],
"eval3": [@eval3/ConstantSampler],
}
meta/Context.context_transition_fn = @task/relative_context_transition_fn
meta/Context.context_multi_transition_fn = @task/relative_context_multi_transition_fn
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1]
task/negative_distance.relative_context = True
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
task/relative_context_transition_fn.k = 2
task/relative_context_multi_transition_fn.k = 2
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [16, 0]
eval2/ConstantSampler.value = [16, 16]
eval3/ConstantSampler.value = [0, 16]
#-*-Python-*-
create_maze_env.env_name = "AntPush"
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-16, -4), (16, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval2"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval2": [@eval2/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval2/ConstantSampler.value = [0, 19]
#-*-Python-*-
create_maze_env.env_name = "AntPush"
IMAGES = True
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-16, -4), (16, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval2"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval2": [@eval2/ConstantSampler],
}
meta/Context.context_transition_fn = @task/relative_context_transition_fn
meta/Context.context_multi_transition_fn = @task/relative_context_multi_transition_fn
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1]
task/negative_distance.relative_context = True
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
task/relative_context_transition_fn.k = 2
task/relative_context_multi_transition_fn.k = 2
MetaAgent.k = %SUBGOAL_DIM
eval2/ConstantSampler.value = [0, 19]
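
One possible reading of the `task/negative_distance` settings used in the configs above, as a minimal standard-library sketch. This is not the repository's implementation: the function name `negative_distance_sketch` and its signature are hypothetical, and the meaning given to `relative_context` (the context is an offset from the current state rather than an absolute target) and to `diff` (the reward is the change in distance rather than the distance itself) is an assumption inferred from the parameter names.

```python
import math


def _l2(a, b):
    """Euclidean distance between two equal-length sequences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def negative_distance_sketch(state, next_state, context,
                             state_indices=(0, 1),
                             relative_context=False,
                             diff=False,
                             offset=0.0):
    """Hypothetical illustration of a negative-distance reward.

    Assumptions (not confirmed by the source):
      * state_indices selects the coordinates compared against the goal
        (here the ant's x, y position).
      * relative_context=True treats the context as an offset from the
        current state; False treats it as an absolute target.
      * diff=True rewards progress (reduction in distance) instead of
        the raw negative distance.
    """
    s = [float(state[i]) for i in state_indices]
    s_next = [float(next_state[i]) for i in state_indices]
    goal = [float(c) for c in context]
    if relative_context:
        # Convert the offset into an absolute target.
        goal = [a + g for a, g in zip(s, goal)]
    dist = _l2(s_next, goal)
    if diff:
        dist -= _l2(s, goal)
    return -dist + offset


# Absolute target, using the eval2 goal value from the config above.
reward = negative_distance_sketch([0.0, 0.0, 0.5], [0.0, 3.0, 0.5], [0, 19])
```

Under these assumptions, the first AntPush config (`relative_context = False`) would compare the ant's position directly against the sampled goal, while the second (`relative_context = True`, with the `task/` transition functions bound) would treat the goal as a displacement that is updated as the agent moves.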