Commit 17ecf9db authored by Krishnan Srinivasan's avatar Krishnan Srinivasan Committed by ofirnachum
Browse files

update mujoco and tensorflow 1.14/2.0 api calls in efficient-hrl (#6525)

* updated mujoco api calls to data, and tensorflow call to CriticalSection (moved from contrib.framework)

* changed efficient-hrl envs so self.physics points to correct PyMj object based on mujoco_py version

* corrected mujoco version number checkm in AntEnv.physics and PointEnv.physics properties

* fix AntEnv.set_xy by reverting to using self.physics
parent e172ac82
...@@ -47,7 +47,7 @@ class CircularBuffer(object): ...@@ -47,7 +47,7 @@ class CircularBuffer(object):
self._tensors = collections.OrderedDict() self._tensors = collections.OrderedDict()
with tf.variable_scope(self._scope): with tf.variable_scope(self._scope):
self._num_adds = tf.Variable(0, dtype=tf.int64, name='num_adds') self._num_adds = tf.Variable(0, dtype=tf.int64, name='num_adds')
self._num_adds_cs = tf.contrib.framework.CriticalSection(name='num_adds') self._num_adds_cs = tf.CriticalSection(name='num_adds')
@property @property
def buffer_size(self): def buffer_size(self):
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import math import math
import numpy as np import numpy as np
import mujoco_py
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
...@@ -50,7 +51,13 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): ...@@ -50,7 +51,13 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@property @property
def physics(self): def physics(self):
return self.model # check mujoco version is greater than version 1.50 to call correct physics
# model containing PyMjData object for getting and setting position/velocity
# check https://github.com/openai/mujoco-py/issues/80 for updates to api
if mujoco_py.get_version() >= '1.50':
return self.sim
else:
return self.model
def _step(self, a): def _step(self, a):
return self.step(a) return self.step(a)
...@@ -117,7 +124,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): ...@@ -117,7 +124,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def get_ori(self): def get_ori(self):
ori = [0, 1, 0, 0] ori = [0, 1, 0, 0]
rot = self.model.data.qpos[self.__class__.ORI_IND:self.__class__.ORI_IND + 4] # take the quaternion rot = self.physics.data.qpos[self.__class__.ORI_IND:self.__class__.ORI_IND + 4] # take the quaternion
ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
ori = math.atan2(ori[1], ori[0]) ori = math.atan2(ori[1], ori[0])
return ori return ori
......
...@@ -228,7 +228,7 @@ class MazeEnv(gym.Env): ...@@ -228,7 +228,7 @@ class MazeEnv(gym.Env):
raise Exception("Every geom of the torso must have a name " raise Exception("Every geom of the torso must have a name "
"defined") "defined")
_, file_path = tempfile.mkstemp(text=True) _, file_path = tempfile.mkstemp(text=True, suffix='.xml')
tree.write(file_path) tree.write(file_path)
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs) self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import math import math
import numpy as np import numpy as np
import mujoco_py
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
...@@ -33,7 +34,13 @@ class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle): ...@@ -33,7 +34,13 @@ class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@property @property
def physics(self): def physics(self):
return self.model # check mujoco version is greater than version 1.50 to call correct physics
# model containing PyMjData object for getting and setting position/velocity
# check https://github.com/openai/mujoco-py/issues/80 for updates to api
if mujoco_py.get_version() >= '1.50':
return self.sim
else:
return self.model
def _step(self, a): def _step(self, a):
return self.step(a) return self.step(a)
...@@ -80,7 +87,7 @@ class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle): ...@@ -80,7 +87,7 @@ class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def get_ori(self): def get_ori(self):
return self.model.data.qpos[self.__class__.ORI_IND] return self.physics.data.qpos[self.__class__.ORI_IND]
def set_xy(self, xy): def set_xy(self, xy):
qpos = np.copy(self.physics.data.qpos) qpos = np.copy(self.physics.data.qpos)
......
...@@ -28,7 +28,7 @@ CONFIGS_PATH = './configs' ...@@ -28,7 +28,7 @@ CONFIGS_PATH = './configs'
CONTEXT_CONFIGS_PATH = './context/configs' CONTEXT_CONFIGS_PATH = './context/configs'
def main(): def main():
bb = './' bb = '.'
base_num_args = 6 base_num_args = 6
if len(sys.argv) < base_num_args: if len(sys.argv) < base_num_args:
print( print(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment