"vscode:/vscode.git/clone" did not exist on "08f46125a007d5b710a83c4f4cda7e95114456ac"
Unverified commit c9f03bf6, authored by Neal Wu, committed by GitHub

Merge pull request #5870 from ofirnachum/master

Add training and eval code for efficient-hrl
parents 2c181308 052361de
#-*-Python-*-
create_maze_env.env_name = "AntPush"
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-16, -4), (16, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval2"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@eval2/ConstantSampler],
"explore": [@eval2/ConstantSampler],
"eval2": [@eval2/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval2/ConstantSampler.value = [0, 19]
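The config above leans on gin's macro (%NAME) and reference (@name) syntax. As a quick illustration of how such bindings resolve (a minimal sketch using the gin-config library; the tiny every_n_steps stand-in below is written for this example and is not the repository's implementation):

import gin


@gin.configurable
def every_n_steps(t, n=100):
  # True whenever t is a multiple of n; mirrors the reset condition bound above.
  return t % n == 0


# Macros (%NAME) are substituted into bindings, and parameters of
# @gin.configurable functions pick up the configured values.
gin.parse_config("""
RESET_EPISODE_PERIOD = 500
every_n_steps.n = %RESET_EPISODE_PERIOD
""")

print(every_n_steps(1000))  # True: n is bound to 500 by the config.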
#-*-Python-*-
ENV_CONTEXT = None
EVAL_MODES = ["eval"]
TARGET_Q_CLIPPING = None
RESET_EPISODE_PERIOD = None
ZERO_OBS = False
CONTEXT_RANGE_MIN = -10
CONTEXT_RANGE_MAX = 10
SUBGOAL_DIM = 2
uvf/negative_distance.summarize = False
uvf/negative_distance.relative_context = True
#-*-Python-*-
ENV_CONTEXT = None
EVAL_MODES = ["eval"]
TARGET_Q_CLIPPING = None
RESET_EPISODE_PERIOD = None
ZERO_OBS = True
IMAGES = False
CONTEXT_RANGE_MIN = (-10, -10, -0.5, -1, -1, -1, -1, -0.5, -0.3, -0.5, -0.3, -0.5, -0.3, -0.5, -0.3)
CONTEXT_RANGE_MAX = ( 10, 10, 0.5, 1, 1, 1, 1, 0.5, 0.3, 0.5, 0.3, 0.5, 0.3, 0.5, 0.3)
SUBGOAL_DIM = 15
META_EXPLORE_NOISE = 1.0
uvf/negative_distance.summarize = False
uvf/negative_distance.relative_context = True
#-*-Python-*-
ENV_CONTEXT = None
EVAL_MODES = ["eval"]
TARGET_Q_CLIPPING = None
RESET_EPISODE_PERIOD = None
ZERO_OBS = False
IMAGES = False
CONTEXT_RANGE_MIN = -10
CONTEXT_RANGE_MAX = 10
SUBGOAL_DIM = 2
META_EXPLORE_NOISE = 5.0
StatePreprocess.trainable = True
StatePreprocess.state_preprocess_net = @state_preprocess_net
StatePreprocess.action_embed_net = @action_embed_net
uvf/negative_distance.summarize = False
uvf/negative_distance.relative_context = True
#-*-Python-*-
ENV_CONTEXT = None
EVAL_MODES = ["eval"]
TARGET_Q_CLIPPING = None
RESET_EPISODE_PERIOD = None
ZERO_OBS = False
IMAGES = False
CONTEXT_RANGE_MIN = -10
CONTEXT_RANGE_MAX = 10
SUBGOAL_DIM = 2
META_EXPLORE_NOISE = 1.0
uvf/negative_distance.summarize = False
uvf/negative_distance.relative_context = True
#-*-Python-*-
# NOTE: For best training, low-level exploration (uvf_add_noise_fn.stddev)
# should be reduced to around 0.1.
create_maze_env.env_name = "PointMaze"
context_range_min = -10
context_range_max = 10
context_range = (%context_range_min, %context_range_max)
meta_context_range = ((-2, -2), (10, 10))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1", "eval2", "eval3"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
"eval1": [@uvf_eval1/ConstantSampler],
"eval2": [@uvf_eval2/ConstantSampler],
"eval3": [@uvf_eval3/ConstantSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
"eval2": [@eval2/ConstantSampler],
"eval3": [@eval3/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [8, 0]
eval2/ConstantSampler.value = [8, 8]
eval3/ConstantSampler.value = [0, 8]
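For intuition, the task reward bound above (task/negative_distance with state_indices = [0, 1], relative_context = False, diff = False, offset = 0.0) amounts to the negative Euclidean distance between the agent's (x, y) position and the sampled goal, consistent with get_reward_fn in the random-policy script later in this commit. A minimal numpy sketch (the helper below is illustrative, not the repository's rewards_functions implementation):

import numpy as np


def negative_distance(state, goal, state_indices=(0, 1), offset=0.0):
  # Negative Euclidean distance between the selected state dims and the goal.
  return -np.linalg.norm(state[list(state_indices)] - goal) + offset


state = np.array([7.5, 7.0, 0.1])  # (x, y, ...) observation
goal = np.array([8.0, 8.0])        # the eval2 goal above
print(negative_distance(state, goal))  # ~ -1.118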
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Context for Universal Value Function agents.
A context specifies a list of contextual variables, each with its
own sampling and reward computation methods.
Examples of contextual variables include
goal states, reward combination vectors, etc.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tf_agents import specs
import gin.tf
from utils import utils as uvf_utils
@gin.configurable
class Context(object):
"""Base context."""
VAR_NAME = 'action'
def __init__(self,
tf_env,
context_ranges=None,
context_shapes=None,
state_indices=None,
variable_indices=None,
gamma_index=None,
settable_context=False,
timers=None,
samplers=None,
reward_weights=None,
reward_fn=None,
random_sampler_mode='random',
normalizers=None,
context_transition_fn=None,
context_multi_transition_fn=None,
meta_action_every_n=None):
self._tf_env = tf_env
self.variable_indices = variable_indices
self.gamma_index = gamma_index
self._settable_context = settable_context
self.timers = timers
self._context_transition_fn = context_transition_fn
self._context_multi_transition_fn = context_multi_transition_fn
self._random_sampler_mode = random_sampler_mode
# assign specs
self._obs_spec = self._tf_env.observation_spec()
self._context_shapes = tuple([
shape if shape is not None else self._obs_spec.shape
for shape in context_shapes
])
self.context_specs = tuple([
specs.TensorSpec(dtype=self._obs_spec.dtype, shape=shape)
for shape in self._context_shapes
])
if context_ranges is not None:
self.context_ranges = context_ranges
else:
self.context_ranges = [None] * len(self._context_shapes)
self.context_as_action_specs = tuple([
specs.BoundedTensorSpec(
shape=shape,
dtype=(tf.float32 if self._obs_spec.dtype in
[tf.float32, tf.float64] else self._obs_spec.dtype),
minimum=context_range[0],
maximum=context_range[-1])
for shape, context_range in zip(self._context_shapes, self.context_ranges)
])
if state_indices is not None:
self.state_indices = state_indices
else:
self.state_indices = [None] * len(self._context_shapes)
if self.variable_indices is not None and self.n != len(
self.variable_indices):
raise ValueError(
'variable_indices (%s) must have the same length as contexts (%s).' %
(self.variable_indices, self.context_specs))
assert self.n == len(self.context_ranges)
assert self.n == len(self.state_indices)
# assign reward/sampler fns
self._sampler_fns = dict()
self._samplers = dict()
self._reward_fns = dict()
# assign reward fns
self._add_custom_reward_fns()
reward_weights = reward_weights or None
self._reward_fn = self._make_reward_fn(reward_fn, reward_weights)
# assign samplers
self._add_custom_sampler_fns()
for mode, sampler_fns in samplers.items():
self._make_sampler_fn(sampler_fns, mode)
# create normalizers
if normalizers is None:
self._normalizers = [None] * len(self.context_specs)
else:
self._normalizers = [
normalizer(tf.zeros(shape=spec.shape, dtype=spec.dtype))
if normalizer is not None else None
for normalizer, spec in zip(normalizers, self.context_specs)
]
assert self.n == len(self._normalizers)
self.meta_action_every_n = meta_action_every_n
# create vars
self.context_vars = {}
self.timer_vars = {}
self.create_vars(self.VAR_NAME)
self.t = tf.Variable(
tf.zeros(shape=(), dtype=tf.int32), name='num_timer_steps')
def _add_custom_reward_fns(self):
pass
def _add_custom_sampler_fns(self):
pass
def sample_random_contexts(self, batch_size):
"""Sample random batch contexts."""
assert self._random_sampler_mode is not None
return self.sample_contexts(self._random_sampler_mode, batch_size)[0]
def sample_contexts(self, mode, batch_size, state=None, next_state=None,
**kwargs):
"""Sample a batch of contexts.
Args:
mode: A string representing the mode [`train`, `explore`, `eval`].
batch_size: Batch size.
state: Optional [batch_size, num_state_dims] tensor of current states.
next_state: Optional [batch_size, num_state_dims] tensor of next states.
**kwargs: Additional keyword arguments forwarded to the sampler fns.
Returns:
Two lists of [batch_size, num_context_dims] contexts.
"""
contexts, next_contexts = self._sampler_fns[mode](
batch_size, state=state, next_state=next_state,
**kwargs)
self._validate_contexts(contexts)
self._validate_contexts(next_contexts)
return contexts, next_contexts
def compute_rewards(self, mode, states, actions, rewards, next_states,
contexts):
"""Compute context-based rewards.
Args:
mode: A string representing the mode ['uvf', 'task'].
states: A [batch_size, num_state_dims] tensor.
actions: A [batch_size, num_action_dims] tensor.
rewards: A [batch_size] tensor representing unmodified rewards.
next_states: A [batch_size, num_state_dims] tensor.
contexts: A list of [batch_size, num_context_dims] tensors.
Returns:
A [batch_size] tensor representing rewards.
"""
return self._reward_fn(states, actions, rewards, next_states,
contexts)
def _make_reward_fn(self, reward_fns_list, reward_weights):
"""Returns a fn that computes rewards.
Args:
reward_fns_list: A fn or a list of reward fns.
reward_weights: A list of reward weights.
"""
if not isinstance(reward_fns_list, (list, tuple)):
reward_fns_list = [reward_fns_list]
if reward_weights is None:
reward_weights = [1.0] * len(reward_fns_list)
assert len(reward_fns_list) == len(reward_weights)
reward_fns_list = [
self._custom_reward_fns[fn] if isinstance(fn, (str,)) else fn
for fn in reward_fns_list
]
def reward_fn(*args, **kwargs):
"""Returns rewards, discounts."""
reward_tuples = [
reward_fn(*args, **kwargs) for reward_fn in reward_fns_list
]
rewards_list = [reward_tuple[0] for reward_tuple in reward_tuples]
discounts_list = [reward_tuple[1] for reward_tuple in reward_tuples]
ndims = max([r.shape.ndims for r in rewards_list])
if ndims > 1: # expand reward shapes to allow broadcasting
for i in range(len(rewards_list)):
for _ in range(ndims - rewards_list[i].shape.ndims):
rewards_list[i] = tf.expand_dims(rewards_list[i], axis=-1)
for _ in range(ndims - discounts_list[i].shape.ndims):
discounts_list[i] = tf.expand_dims(discounts_list[i], axis=-1)
rewards = tf.add_n(
[r * tf.to_float(w) for r, w in zip(rewards_list, reward_weights)])
discounts = discounts_list[0]
for d in discounts_list[1:]:
discounts *= d
return rewards, discounts
return reward_fn
def _make_sampler_fn(self, sampler_cls_list, mode):
"""Returns a fn that samples a list of context vars.
Args:
sampler_cls_list: A list of sampler classes.
mode: A string representing the operating mode.
"""
if not isinstance(sampler_cls_list, (list, tuple)):
sampler_cls_list = [sampler_cls_list]
self._samplers[mode] = []
sampler_fns = []
for spec, sampler in zip(self.context_specs, sampler_cls_list):
if isinstance(sampler, (str,)):
sampler_fn = self._custom_sampler_fns[sampler]
else:
sampler_fn = sampler(context_spec=spec)
self._samplers[mode].append(sampler_fn)
sampler_fns.append(sampler_fn)
def batch_sampler_fn(batch_size, state=None, next_state=None, **kwargs):
"""Sampler fn."""
contexts_tuples = [
sampler(batch_size, state=state, next_state=next_state, **kwargs)
for sampler in sampler_fns]
contexts = [c[0] for c in contexts_tuples]
next_contexts = [c[1] for c in contexts_tuples]
contexts = [
normalizer.update_apply(c) if normalizer is not None else c
for normalizer, c in zip(self._normalizers, contexts)
]
next_contexts = [
normalizer.apply(c) if normalizer is not None else c
for normalizer, c in zip(self._normalizers, next_contexts)
]
return contexts, next_contexts
self._sampler_fns[mode] = batch_sampler_fn
def set_env_context_op(self, context, disable_unnormalizer=False):
"""Returns a TensorFlow op that sets the environment context.
Args:
context: A list of context Tensor variables.
disable_unnormalizer: Disable unnormalization.
Returns:
A TensorFlow op that sets the environment context.
"""
ret_val = np.array(1.0, dtype=np.float32)
if not self._settable_context:
return tf.identity(ret_val)
if not disable_unnormalizer:
context = [
normalizer.unapply(tf.expand_dims(c, 0))[0]
if normalizer is not None else c
for normalizer, c in zip(self._normalizers, context)
]
def set_context_func(*env_context_values):
tf.logging.info('[set_env_context_op] Setting gym environment context.')
# pylint: disable=protected-access
self.gym_env.set_context(*env_context_values)
return ret_val
# pylint: enable=protected-access
with tf.name_scope('set_env_context'):
set_op = tf.py_func(set_context_func, context, tf.float32,
name='set_env_context_py_func')
set_op.set_shape([])
return set_op
def set_replay(self, replay):
"""Set replay buffer for samplers.
Args:
replay: A replay buffer.
"""
for _, samplers in self._samplers.items():
for sampler in samplers:
sampler.set_replay(replay)
def get_clip_fns(self):
"""Returns a list of clip fns for contexts.
Returns:
A list of fns that clip context tensors.
"""
clip_fns = []
for context_range in self.context_ranges:
def clip_fn(var_, range_=context_range):
"""Clip a tensor."""
if range_ is None:
clipped_var = tf.identity(var_)
elif isinstance(range_[0], (int, long, float, list, np.ndarray)):
clipped_var = tf.clip_by_value(
var_,
range_[0],
range_[1],)
else: raise NotImplementedError(range_)
return clipped_var
clip_fns.append(clip_fn)
return clip_fns
def _validate_contexts(self, contexts):
"""Validate if contexts have right specs.
Args:
contexts: A list of [batch_size, num_context_dim] tensors.
Raises:
ValueError: If shape or dtype mismatches that of spec.
"""
for i, (context, spec) in enumerate(zip(contexts, self.context_specs)):
if context[0].shape != spec.shape:
raise ValueError('contexts[%d] has invalid shape %s wrt spec shape %s' %
(i, context[0].shape, spec.shape))
if context.dtype != spec.dtype:
raise ValueError('contexts[%d] has invalid dtype %s wrt spec dtype %s' %
(i, context.dtype, spec.dtype))
def context_multi_transition_fn(self, contexts, **kwargs):
"""Returns multiple future contexts starting from a batch."""
assert self._context_multi_transition_fn
return self._context_multi_transition_fn(contexts, None, None, **kwargs)
def step(self, mode, agent=None, action_fn=None, **kwargs):
"""Returns [next_contexts..., next_timer] list of ops.
Args:
mode: a string representing the mode=[train, explore, eval].
agent: an optional higher-level (meta) agent; if provided, its context is stepped first and this low-level context is updated from it.
action_fn: an optional fn used to sample a new goal every meta_action_every_n steps.
**kwargs: kwargs for context_transition_fn.
Returns:
a list of ops that set the context.
"""
if agent is None:
ops = []
if self._context_transition_fn is not None:
def sampler_fn():
samples = self.sample_contexts(mode, 1)[0]
return [s[0] for s in samples]
values = self._context_transition_fn(self.vars, self.t, sampler_fn, **kwargs)
ops += [tf.assign(var, value) for var, value in zip(self.vars, values)]
ops.append(tf.assign_add(self.t, 1)) # increment timer
return ops
else:
ops = agent.tf_context.step(mode, **kwargs)
state = kwargs['state']
next_state = kwargs['next_state']
state_repr = kwargs['state_repr']
next_state_repr = kwargs['next_state_repr']
with tf.control_dependencies(ops): # Step high level context before computing low level one.
# Get the context transition function output.
values = self._context_transition_fn(self.vars, self.t, None,
state=state_repr,
next_state=next_state_repr)
# Select a new goal every C steps, otherwise use context transition.
low_level_context = [
tf.cond(tf.equal(self.t % self.meta_action_every_n, 0),
lambda: tf.cast(action_fn(next_state, context=None), tf.float32),
lambda: values)]
ops = [tf.assign(var, value)
for var, value in zip(self.vars, low_level_context)]
with tf.control_dependencies(ops):
return [tf.assign_add(self.t, 1)] # increment timer
return ops
def reset(self, mode, agent=None, action_fn=None, state=None):
"""Returns ops that reset the context.
Args:
mode: a string representing the mode=[train, explore, eval].
agent: an optional higher-level (meta) agent; if provided, its context is reset and this low-level context is zeroed out.
Returns:
a list of ops that reset the context.
"""
if agent is None:
values = self.sample_contexts(mode=mode, batch_size=1)[0]
if values is None:
return []
values = [value[0] for value in values]
values[0] = uvf_utils.tf_print(
values[0],
values,
message='context:reset, mode=%s' % mode,
first_n=10,
name='context:reset:%s' % mode)
all_ops = []
for _, context_vars in sorted(self.context_vars.items()):
ops = [tf.assign(var, value) for var, value in zip(context_vars, values)]
all_ops += ops
all_ops.append(self.set_env_context_op(values))
all_ops.append(tf.assign(self.t, 0)) # reset timer
return all_ops
else:
ops = agent.tf_context.reset(mode)
# NOTE: The code is currently written in such a way that the higher level
# policy does not provide a low-level context until the second
# observation. Instead, we just zero out the low-level contexts.
for key, context_vars in sorted(self.context_vars.items()):
ops += [tf.assign(var, tf.zeros_like(var)) for var, meta_var in
zip(context_vars, agent.tf_context.context_vars[key])]
ops.append(tf.assign(self.t, 0)) # reset timer
return ops
def create_vars(self, name, agent=None):
"""Create tf variables for contexts.
Args:
name: Name of the variables.
Returns:
A list of [num_context_dims] tensors.
"""
if agent is not None:
meta_vars = agent.create_vars(name)
else:
meta_vars = {}
assert name not in self.context_vars, ('Conflict! %s is already '
'initialized.') % name
self.context_vars[name] = tuple([
tf.Variable(
tf.zeros(shape=spec.shape, dtype=spec.dtype),
name='%s_context_%d' % (name, i))
for i, spec in enumerate(self.context_specs)
])
return self.context_vars[name], meta_vars
@property
def n(self):
return len(self.context_specs)
@property
def vars(self):
return self.context_vars[self.VAR_NAME]
# pylint: disable=protected-access
@property
def gym_env(self):
return self._tf_env.pyenv._gym_env
@property
def tf_env(self):
return self._tf_env
# pylint: enable=protected-access
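The reward combination built by _make_reward_fn above reduces to a weighted sum of the individual rewards and a product of the individual discounts. A small numpy sketch of that reduction (illustrative values, outside of TensorFlow):

import numpy as np

rewards_list = [np.array([1.0, 0.0]), np.array([0.5, 0.5])]
discounts_list = [np.array([0.99, 0.99]), np.array([1.0, 0.0])]
reward_weights = [1.0, 2.0]

# Weighted sum of rewards, elementwise product of discounts.
rewards = sum(r * w for r, w in zip(rewards_list, reward_weights))
discounts = discounts_list[0] * discounts_list[1]
print(rewards)    # [2. 1.]
print(discounts)  # [0.99 0.]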
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Context functions.
Given the current contexts, timer and context sampler, returns new contexts
after an environment step. This can be used to define a high-level policy
that controls contexts as its actions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import gin.tf
import utils as uvf_utils
@gin.configurable
def periodic_context_fn(contexts, timer, sampler_fn, period=1):
"""Periodically samples contexts.
Args:
contexts: a list of [num_context_dims] tensor variables representing
current contexts.
timer: a scalar integer tensor variable holding the current time step.
sampler_fn: a sampler function that samples a list of [num_context_dims]
tensors.
period: (integer) period of update.
Returns:
a list of [num_context_dims] tensors.
"""
contexts = list(contexts[:]) # create copy
return tf.cond(tf.equal(tf.mod(timer, period), 0), sampler_fn, lambda: contexts)
@gin.configurable
def timer_context_fn(contexts,
timer,
sampler_fn,
period=1,
timer_index=-1,
debug=False):
"""Samples contexts based on timer in contexts.
Args:
contexts: a list of [num_context_dims] tensor variables representing
current contexts.
timer: a scalar integer tensor variable holding the current time step.
sampler_fn: a sampler function that samples a list of [num_context_dims]
tensors.
period: (integer) period of update; the effective period is `period` + 1.
timer_index: (integer) index of the context-list entry that holds the timer.
debug: (boolean) Print debug messages.
Returns:
a list of [num_context_dims] tensors.
"""
contexts = list(contexts[:]) # create copy
cond = tf.equal(contexts[timer_index][0], 0)
def reset():
"""Sample context and reset the timer."""
new_contexts = sampler_fn()
new_contexts[timer_index] = tf.zeros_like(
contexts[timer_index]) + period
return new_contexts
def update():
"""Decrement the timer."""
contexts[timer_index] -= 1
return contexts
values = tf.cond(cond, reset, update)
if debug:
values[0] = uvf_utils.tf_print(
values[0],
values + [timer],
'timer_context_fn',
first_n=200,
name='timer_context_fn:contexts')
return values
@gin.configurable
def relative_context_transition_fn(
contexts, timer, sampler_fn,
k=2, state=None, next_state=None,
**kwargs):
"""Contexts updated to be relative to next state.
"""
contexts = list(contexts[:]) # create copy
assert len(contexts) == 1
new_contexts = [
tf.concat(
[contexts[0][:k] + state[:k] - next_state[:k],
contexts[0][k:]], -1)]
return new_contexts
@gin.configurable
def relative_context_multi_transition_fn(
contexts, timer, sampler_fn,
k=2, states=None,
**kwargs):
"""Given contexts at first state and sequence of states, derives sequence of all contexts.
"""
contexts = list(contexts[:]) # create copy
assert len(contexts) == 1
contexts = [
tf.concat(
[tf.expand_dims(contexts[0][:, :k] + states[:, 0, :k], 1) - states[:, :, :k],
contexts[0][:, None, k:] * tf.ones_like(states[:, :, :1])], -1)]
return contexts
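As a sanity check on relative_context_transition_fn above: the update g' = g + s[:k] - s'[:k] keeps the absolute target s[:k] + g fixed while the agent moves, which is the relative-goal transition used for the low-level context. A small numpy sketch with made-up values:

import numpy as np

k = 2
state = np.array([1.0, 2.0, 0.3])
next_state = np.array([1.5, 2.5, 0.3])
goal = np.array([4.0, 4.0])  # subgoal expressed relative to `state`

new_goal = goal + state[:k] - next_state[:k]
# The absolute target is unchanged: s[:k] + g == s'[:k] + g'.
assert np.allclose(state[:k] + goal, next_state[:k] + new_goal)
print(new_goal)  # [3.5 3.5]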
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import gin configurable modules.
"""
# pylint: disable=unused-import
from context import context
from context import context_transition_functions
from context import gin_utils
from context import rewards_functions
from context import samplers
# pylint: enable=unused-import
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gin configurable utility functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import gin.tf
@gin.configurable
def gin_sparse_array(size, values, indices, fill_value=0):
arr = np.zeros(size)
arr.fill(fill_value)
arr[indices] = values
return arr
@gin.configurable
def gin_sum(values):
result = values[0]
for value in values[1:]:
result += value
return result
@gin.configurable
def gin_range(n):
return range(n)
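For reference, gin_sparse_array above builds a dense array of the requested size, fills it with fill_value, and scatters values at indices. A short usage sketch (assuming the module is importable as context.gin_utils, per the imports earlier in this commit):

from context.gin_utils import gin_sparse_array

arr = gin_sparse_array(5, values=[1.0, 2.0], indices=[0, 3], fill_value=0)
print(arr)  # [1. 0. 0. 2. 0.]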
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Random policy on an environment."""
import tensorflow as tf
import numpy as np
import random
import create_maze_env
app = tf.app
flags = tf.flags
logging = tf.logging
FLAGS = flags.FLAGS
flags.DEFINE_string('env', 'AntMaze', 'environment name: AntMaze, AntPush, or AntFall')
flags.DEFINE_integer('episode_length', 500, 'episode length')
flags.DEFINE_integer('num_episodes', 50, 'number of episodes')
def get_goal_sample_fn(env_name):
if env_name == 'AntMaze':
# NOTE: When evaluating (i.e., for the metrics shown in the paper),
# we use the commented-out goal sampling function. The uncommented
# one is only used for training.
#return lambda: np.array([0., 16.])
return lambda: np.random.uniform((-4, -4), (20, 20))
elif env_name == 'AntPush':
return lambda: np.array([0., 19.])
elif env_name == 'AntFall':
return lambda: np.array([0., 27., 4.5])
else:
assert False, 'Unknown env'
def get_reward_fn(env_name):
if env_name == 'AntMaze':
return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
elif env_name == 'AntPush':
return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
elif env_name == 'AntFall':
return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
else:
assert False, 'Unknown env'
def success_fn(last_reward):
return last_reward > -5.0
class EnvWithGoal(object):
def __init__(self, base_env, env_name):
self.base_env = base_env
self.goal_sample_fn = get_goal_sample_fn(env_name)
self.reward_fn = get_reward_fn(env_name)
self.goal = None
def reset(self):
obs = self.base_env.reset()
self.goal = self.goal_sample_fn()
return np.concatenate([obs, self.goal])
def step(self, a):
obs, _, done, info = self.base_env.step(a)
reward = self.reward_fn(obs, self.goal)
return np.concatenate([obs, self.goal]), reward, done, info
@property
def action_space(self):
return self.base_env.action_space
def run_environment(env_name, episode_length, num_episodes):
env = EnvWithGoal(
create_maze_env.create_maze_env(env_name),
env_name)
def action_fn(obs):
action_space = env.action_space
action_space_mean = (action_space.low + action_space.high) / 2.0
action_space_magn = (action_space.high - action_space.low) / 2.0
random_action = (action_space_mean +
action_space_magn *
np.random.uniform(low=-1.0, high=1.0,
size=action_space.shape))
return random_action
rewards = []
successes = []
for ep in range(num_episodes):
rewards.append(0.0)
successes.append(False)
obs = env.reset()
for _ in range(episode_length):
obs, reward, done, _ = env.step(action_fn(obs))
rewards[-1] += reward
successes[-1] = success_fn(reward)
if done:
break
logging.info('Episode %d reward: %.2f, Success: %d', ep + 1, rewards[-1], successes[-1])
logging.info('Average Reward over %d episodes: %.2f',
num_episodes, np.mean(rewards))
logging.info('Average Success over %d episodes: %.2f',
num_episodes, np.mean(successes))
def main(unused_argv):
logging.set_verbosity(logging.INFO)
run_environment(FLAGS.env, FLAGS.episode_length, FLAGS.num_episodes)
if __name__ == '__main__':
app.run()
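For reference, the evaluation loop above can also be driven directly, bypassing the tf.app flags (a sketch; it requires MuJoCo and the maze environments to be importable):

# Runs 50 random-policy episodes of length 500 on AntPush and logs the
# per-episode reward/success plus the averages, as in main() above.
run_environment('AntPush', episode_length=500, num_episodes=50)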
@@ -21,8 +21,21 @@ from gym import utils
from gym.envs.mujoco import mujoco_env
def q_inv(a):
return [a[0], -a[1], -a[2], -a[3]]
def q_mult(a, b):  # multiply two quaternions
w = a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3]
i = a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2]
j = a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1]
k = a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0]
return [w, i, j, k]
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
FILE = "ant.xml"
ORI_IND = 3
def __init__(self, file_path=None, expose_all_qpos=True,
expose_body_coms=None, expose_body_comvels=None):
@@ -101,3 +114,21 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def viewer_setup(self):
self.viewer.cam.distance = self.model.stat.extent * 0.5
def get_ori(self):
ori = [0, 1, 0, 0]
rot = self.model.data.qpos[self.__class__.ORI_IND:self.__class__.ORI_IND + 4] # take the quaternion
ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
ori = math.atan2(ori[1], ori[0])
return ori
def set_xy(self, xy):
qpos = np.copy(self.physics.data.qpos)
qpos[0] = xy[0]
qpos[1] = xy[1]
qvel = self.physics.data.qvel
self.set_state(qpos, qvel)
def get_xy(self):
return self.physics.data.qpos[:2]
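As a quick check on get_ori above (a sketch, not part of the commit): a yaw of theta about the z axis is the quaternion [cos(theta/2), 0, 0, sin(theta/2)]; conjugating the x-axis vector [0, 1, 0, 0] with it and keeping components 1:3 yields [cos(theta), sin(theta)], so atan2 recovers theta.

import math

theta = 0.7
rot = [math.cos(theta / 2), 0.0, 0.0, math.sin(theta / 2)]  # yaw by theta
ori = q_mult(q_mult(rot, [0, 1, 0, 0]), q_inv(rot))[1:3]    # -> [cos(theta), sin(theta)]
print(math.atan2(ori[1], ori[0]))  # ~0.7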
@@ -13,8 +13,8 @@
# limitations under the License.
# ==============================================================================
-from maze_env import MazeEnv
from environments.maze_env import MazeEnv
-from ant import AntEnv
from environments.ant import AntEnv
class AntMazeEnv(MazeEnv):
...
@@ -13,18 +13,85 @@
# limitations under the License.
# ==============================================================================
-from ant_maze_env import AntMazeEnv
from environments.ant_maze_env import AntMazeEnv
from environments.point_maze_env import PointMazeEnv
import tensorflow as tf
import gin.tf
from tf_agents.environments import gym_wrapper
from tf_agents.environments import tf_py_environment
-def create_maze_env(env_name=None):
@gin.configurable
def create_maze_env(env_name=None, top_down_view=False):
n_bins = 0
manual_collision = False
if env_name.startswith('Ego'):
n_bins = 8
env_name = env_name[3:]
if env_name.startswith('Ant'):
cls = AntMazeEnv
env_name = env_name[3:]
maze_size_scaling = 8
elif env_name.startswith('Point'):
cls = PointMazeEnv
manual_collision = True
env_name = env_name[5:]
maze_size_scaling = 4
else:
assert False, 'unknown env %s' % env_name
maze_id = None
observe_blocks = False
put_spin_near_agent = False
-if env_name.startswith('AntMaze'):
if env_name == 'Maze':
maze_id = 'Maze'
-elif env_name.startswith('AntPush'):
elif env_name == 'Push':
maze_id = 'Push'
-elif env_name.startswith('AntFall'):
elif env_name == 'Fall':
maze_id = 'Fall'
elif env_name == 'Block':
maze_id = 'Block'
put_spin_near_agent = True
observe_blocks = True
elif env_name == 'BlockMaze':
maze_id = 'BlockMaze'
put_spin_near_agent = True
observe_blocks = True
else:
raise ValueError('Unknown maze environment %s' % env_name)
-return AntMazeEnv(maze_id=maze_id)
gym_mujoco_kwargs = {
'maze_id': maze_id,
'n_bins': n_bins,
'observe_blocks': observe_blocks,
'put_spin_near_agent': put_spin_near_agent,
'top_down_view': top_down_view,
'manual_collision': manual_collision,
'maze_size_scaling': maze_size_scaling
}
gym_env = cls(**gym_mujoco_kwargs)
gym_env.reset()
wrapped_env = gym_wrapper.GymWrapper(gym_env)
return wrapped_env
class TFPyEnvironment(tf_py_environment.TFPyEnvironment):
def __init__(self, *args, **kwargs):
super(TFPyEnvironment, self).__init__(*args, **kwargs)
def start_collect(self):
pass
def current_obs(self):
time_step = self.current_time_step()
return time_step.observation[0] # For some reason, there is an extra dim.
def step(self, actions):
actions = tf.expand_dims(actions, 0)
next_step = super(TFPyEnvironment, self).step(actions)
return next_step.is_last()[0], next_step.reward[0], next_step.discount[0]
def reset(self):
return super(TFPyEnvironment, self).reset()
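A minimal usage sketch for the updated factory and wrapper above (requires MuJoCo; 'AntPush' follows the <agent prefix><maze id> naming convention parsed by create_maze_env):

# Builds an AntMazeEnv with maze_id='Push', wrapped in a tf_agents GymWrapper,
# then in the TFPyEnvironment subclass defined above (TF1 graph-mode tensors).
gym_env = create_maze_env(env_name='AntPush')
tf_env = TFPyEnvironment(gym_env)
reset_op = tf_env.reset()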
@@ -26,20 +26,27 @@ class Move(object):
XZ = 15
YZ = 16
XYZ = 17
SpinXY = 18
def can_move_x(movable):
-return movable in [Move.X, Move.XY, Move.XZ, Move.XYZ]
return movable in [Move.X, Move.XY, Move.XZ, Move.XYZ,
Move.SpinXY]
def can_move_y(movable):
-return movable in [Move.Y, Move.XY, Move.YZ, Move.XYZ]
return movable in [Move.Y, Move.XY, Move.YZ, Move.XYZ,
Move.SpinXY]
def can_move_z(movable):
return movable in [Move.Z, Move.XZ, Move.YZ, Move.XYZ]
def can_spin(movable):
return movable in [Move.SpinXY]
def can_move(movable):
return can_move_x(movable) or can_move_y(movable) or can_move_z(movable)
@@ -70,7 +77,88 @@ def construct_maze(maze_id='Maze'):
[1, 0, 0, 1],
[1, 1, 1, 1],
]
elif maze_id == 'Block':
O = 'r'
structure = [
[1, 1, 1, 1, 1],
[1, O, 0, 0, 1],
[1, 0, 0, 0, 1],
[1, 0, 0, 0, 1],
[1, 1, 1, 1, 1],
]
elif maze_id == 'BlockMaze':
O = 'r'
structure = [
[1, 1, 1, 1],
[1, O, 0, 1],
[1, 1, 0, 1],
[1, 0, 0, 1],
[1, 1, 1, 1],
]
else:
raise NotImplementedError('The provided MazeId %s is not recognized' % maze_id)
return structure
def line_intersect(pt1, pt2, ptA, ptB):
"""
Taken from https://www.cs.hmc.edu/ACM/lectures/intersections.html
this returns the intersection of Line(pt1,pt2) and Line(ptA,ptB)
"""
DET_TOLERANCE = 0.00000001
# the first line is pt1 + r*(pt2-pt1)
# in component form:
x1, y1 = pt1
x2, y2 = pt2
dx1 = x2 - x1
dy1 = y2 - y1
# the second line is ptA + s*(ptB-ptA)
x, y = ptA
xB, yB = ptB
dx = xB - x
dy = yB - y
DET = (-dx1 * dy + dy1 * dx)
if math.fabs(DET) < DET_TOLERANCE: return (0, 0, 0, 0, 0)
# now, the determinant should be OK
DETinv = 1.0 / DET
# find the scalar amount along the "self" segment
r = DETinv * (-dy * (x - x1) + dx * (y - y1))
# find the scalar amount along the input line
s = DETinv * (-dy1 * (x - x1) + dx1 * (y - y1))
# return the average of the two descriptions
xi = (x1 + r * dx1 + x + s * dx) / 2.0
yi = (y1 + r * dy1 + y + s * dy) / 2.0
return (xi, yi, 1, r, s)
def ray_segment_intersect(ray, segment):
"""
Check if the ray originated from (x, y) with direction theta intersects the line segment (x1, y1) -- (x2, y2),
and return the intersection point if there is one
"""
(x, y), theta = ray
# (x1, y1), (x2, y2) = segment
pt1 = (x, y)
len = 1
pt2 = (x + len * math.cos(theta), y + len * math.sin(theta))
xo, yo, valid, r, s = line_intersect(pt1, pt2, *segment)
if valid and r >= 0 and 0 <= s <= 1:
return (xo, yo)
return None
def point_distance(p1, p2):
x1, y1 = p1
x2, y2 = p2
return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
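A quick worked example of ray_segment_intersect above: a ray starting at the origin and heading along +x hits a vertical wall segment at x = 1 at the point (1, 0). Sketch (assumes the helpers above and math are in scope):

ray = ((0.0, 0.0), 0.0)              # origin (0, 0), heading theta = 0 (+x)
segment = ((1.0, -1.0), (1.0, 1.0))  # vertical wall at x = 1
print(ray_segment_intersect(ray, segment))  # (1.0, 0.0)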
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from environments.maze_env import MazeEnv
from environments.point import PointEnv
class PointMazeEnv(MazeEnv):
MODEL_CLASS = PointEnv