ModelZoo / ResNet50_tensorflow, commit 65da497f
Authored Dec 13, 2018 by Shining Sun

Merge branch 'master' of https://github.com/tensorflow/models into cifar_keras
Parents: 93e0022d 7d032ea3

Changes: 186 files in total; showing 20 changed files with 2700 additions and 1156 deletions (+2700 / -1156).
research/astronet/third_party/robust_mean/robust_mean.py (+0 / -72)
research/astronet/third_party/robust_mean/robust_mean_test.py (+0 / -65)
research/astronet/third_party/robust_mean/test_data/__init__.py (+0 / -0)
research/astronet/third_party/robust_mean/test_data/random_normal.py (+0 / -1011)
research/efficient-hrl/README.md (+42 / -8)
research/efficient-hrl/agent.py (+774 / -0)
research/efficient-hrl/agents/__init__.py (+1 / -0)
research/efficient-hrl/agents/circular_buffer.py (+289 / -0)
research/efficient-hrl/agents/ddpg_agent.py (+739 / -0)
research/efficient-hrl/agents/ddpg_networks.py (+150 / -0)
research/efficient-hrl/cond_fn.py (+244 / -0)
research/efficient-hrl/configs/base_uvf.gin (+68 / -0)
research/efficient-hrl/configs/eval_uvf.gin (+14 / -0)
research/efficient-hrl/configs/train_uvf.gin (+52 / -0)
research/efficient-hrl/context/__init__.py (+1 / -0)
research/efficient-hrl/context/configs/ant_block.gin (+67 / -0)
research/efficient-hrl/context/configs/ant_block_maze.gin (+67 / -0)
research/efficient-hrl/context/configs/ant_fall_multi.gin (+62 / -0)
research/efficient-hrl/context/configs/ant_fall_multi_img.gin (+68 / -0)
research/efficient-hrl/context/configs/ant_fall_single.gin (+62 / -0)
research/astronet/third_party/robust_mean/robust_mean.py (deleted, 100644 → 0)

"""Function for computing a robust mean estimate in the presence of outliers.

This is a modified Python implementation of this file:
https://idlastro.gsfc.nasa.gov/ftp/pro/robust/resistant_mean.pro
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def robust_mean(y, cut):
  """Computes a robust mean estimate in the presence of outliers.

  Args:
    y: 1D numpy array. Assumed to be normally distributed with outliers.
    cut: Points more than this number of standard deviations from the median
      are ignored.

  Returns:
    mean: A robust estimate of the mean of y.
    mean_stddev: The standard deviation of the mean.
    mask: Boolean array with the same length as y. Values corresponding to
      outliers in y are False. All other values are True.
  """
  # First, make a robust estimate of the standard deviation of y, assuming y is
  # normally distributed. The conversion factor of 1.4826 takes the median
  # absolute deviation to the standard deviation of a normal distribution.
  # See, e.g. https://www.mathworks.com/help/stats/mad.html.
  absdev = np.abs(y - np.median(y))
  sigma = 1.4826 * np.median(absdev)

  # If the previous estimate of the standard deviation using the median
  # absolute deviation is zero, fall back to a robust estimate using the mean
  # absolute deviation. This estimator has a different conversion factor of
  # 1.253. See, e.g. https://www.mathworks.com/help/stats/mad.html.
  if sigma < 1.0e-24:
    sigma = 1.253 * np.mean(absdev)

  # Identify outliers using our estimate of the standard deviation of y.
  mask = absdev <= cut * sigma

  # Now, recompute the standard deviation, using the sample standard deviation
  # of non-outlier points.
  sigma = np.std(y[mask])

  # Compensate the estimate of sigma due to trimming away outliers. The
  # following formula is an approximation, see
  # http://w.astro.berkeley.edu/~johnjohn/idlprocs/robust_mean.pro.
  sc = np.max([cut, 1.0])
  if sc <= 4.5:
    sigma /= (-0.15405 + 0.90723 * sc - 0.23584 * sc**2 + 0.020142 * sc**3)

  # Identify outliers using our second estimate of the standard deviation of y.
  mask = absdev <= cut * sigma

  # Now, recompute the standard deviation, using the sample standard deviation
  # with non-outlier points.
  sigma = np.std(y[mask])

  # Compensate the estimate of sigma due to trimming away outliers.
  sc = np.max([cut, 1.0])
  if sc <= 4.5:
    sigma /= (-0.15405 + 0.90723 * sc - 0.23584 * sc**2 + 0.020142 * sc**3)

  # Final estimate is the sample mean with outliers removed.
  mean = np.mean(y[mask])
  mean_stddev = sigma / np.sqrt(len(y) - 1.0)

  return mean, mean_stddev, mask
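
For orientation while reading the diff, a minimal usage sketch of `robust_mean` follows. It is not part of this commit; the sample data, seed, and `cut` value are illustrative assumptions.

```
# Illustrative only: call robust_mean on normally distributed data that has
# been contaminated with a couple of gross outliers.
import numpy as np

rng = np.random.RandomState(0)
y = np.concatenate([rng.normal(loc=2.0, scale=1.0, size=1000), [50.0, -40.0]])

mean, mean_stddev, mask = robust_mean(y, cut=5)
print(mean)           # close to 2.0 despite the contamination
print(mean_stddev)    # standard deviation of the mean estimate
print(np.sum(~mask))  # count of rejected points, typically just the two injected outliers
```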
research/astronet/third_party/robust_mean/robust_mean_test.py (deleted, 100644 → 0)

"""Tests for robust_mean.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest
import numpy as np

from third_party.robust_mean import robust_mean
from third_party.robust_mean.test_data import random_normal


class RobustMeanTest(absltest.TestCase):

  def testRobustMean(self):
    # To avoid non-determinism in the unit test, we use a pre-generated vector
    # of length 1,000. Each entry is independently sampled from a random normal
    # distribution with mean 2 and standard deviation 1. The maximum value of
    # y is 6.075 (+4.075 sigma from the mean) and the minimum value is -1.54
    # (-3.54 sigma from the mean).
    y = np.array(random_normal.RANDOM_NORMAL)
    self.assertAlmostEqual(np.mean(y), 2.00336615850485)
    self.assertAlmostEqual(np.std(y), 1.01690907798)

    # High cut. No points rejected, so the mean should be the sample mean, and
    # the mean standard deviation should be the sample standard deviation
    # divided by sqrt(1000 - 1).
    mean, mean_stddev, mask = robust_mean.robust_mean(y, cut=5)
    self.assertAlmostEqual(mean, 2.00336615850485)
    self.assertAlmostEqual(mean_stddev, 0.032173579)
    self.assertLen(mask, 1000)
    self.assertEqual(np.sum(mask), 1000)

    # Cut of 3 standard deviations.
    mean, mean_stddev, mask = robust_mean.robust_mean(y, cut=3)
    self.assertAlmostEqual(mean, 2.0059050070632178)
    self.assertAlmostEqual(mean_stddev, 0.03197075302321066)
    # There are exactly 3 points in the sample less than -1 or greater than 5.
    # These have indices 12, 220, 344.
    self.assertLen(mask, 1000)
    self.assertEqual(np.sum(mask), 997)
    self.assertFalse(np.any(mask[[12, 220, 344]]))

    # Add outliers. This corrupts the sample mean to 2.082.
    mean, mean_stddev, mask = robust_mean.robust_mean(
        y=np.concatenate([y, [10] * 10]), cut=5)
    self.assertAlmostEqual(mean, 2.0033661585048681)
    self.assertAlmostEqual(mean_stddev, 0.032013749413590531)
    self.assertLen(mask, 1010)
    self.assertEqual(np.sum(mask), 1000)
    self.assertFalse(np.any(mask[1000:1010]))

    # Add an outlier. This corrupts the mean to 1.002.
    mean, mean_stddev, mask = robust_mean.robust_mean(
        y=np.concatenate([y, [-1000]]), cut=5)
    self.assertAlmostEqual(mean, 2.0033661585048681)
    self.assertAlmostEqual(mean_stddev, 0.032157488597211903)
    self.assertLen(mask, 1001)
    self.assertEqual(np.sum(mask), 1000)
    self.assertFalse(mask[1000])


if __name__ == "__main__":
  absltest.main()
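
The headline behavior this test asserts, ten injected outliers at +10 push the raw sample mean to about 2.082 while the robust mean stays at the uncontaminated value, can also be reproduced outside the test harness. The sketch below assumes `robust_mean` and `random_normal` are importable under the deleted package layout shown above; that layout is the only assumption here.

```
# Illustrative sketch mirroring the cut=5 outlier case asserted in the test.
import numpy as np
from third_party.robust_mean import robust_mean
from third_party.robust_mean.test_data import random_normal

y = np.concatenate([np.array(random_normal.RANDOM_NORMAL), [10.0] * 10])
print(np.mean(y))  # ~2.082: the injected outliers corrupt the sample mean

mean, mean_stddev, mask = robust_mean.robust_mean(y, cut=5)
print(mean)               # ~2.003366, the uncorrupted sample mean
print(int(np.sum(~mask)))  # 10: every injected outlier is rejected
```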
research/astronet/third_party/robust_mean/test_data/__init__.py (deleted, 100644 → 0; empty file)
research/astronet/third_party/robust_mean/test_data/random_normal.py (deleted, 100644 → 0)

"""This file contains 1,000 points of a random normal distribution.

The mean of the distribution is 2, and the standard deviation is 1.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

RANDOM_NORMAL = [
0.692741320869
,
1.948556207658
,
-
0.140158117639
,
2.680322906859
,
1.876671492867
,
2.885286232509
,
1.482222151802
,
2.234349266246
,
2.437427989583
,
1.573053624952
,
2.169198249367
,
2.791023059619
,
-
1.053798286951
,
1.796497126664
,
3.806070621390
,
1.744055208958
,
3.474399140181
,
1.564447665560
,
1.143107137921
,
1.618376255615
,
2.615632139609
,
1.413239777404
,
1.047237320108
,
3.190489536636
,
2.918428435434
,
1.268789896280
,
0.931181066003
,
3.797790627792
,
0.493025834330
,
1.866146169585
,
0.949927834893
,
1.439666857958
,
2.705521702500
,
1.815406073907
,
1.570503841718
,
1.834429337005
,
2.903916580263
,
-
0.110549195467
,
2.065338922749
,
1.119498053048
,
0.427627428035
,
3.025052175045
,
2.645448868784
,
1.442644951218
,
0.774681298962
,
2.247561418494
,
1.743438974941
,
1.184440017832
,
1.643691193885
,
1.947748675186
,
2.178309991836
,
2.815355272672
,
2.207620544168
,
2.077889048169
,
2.915504366132
,
2.440862146850
,
2.804729838623
,
0.534712595625
,
1.956491042766
,
2.230542009671
,
2.186536281651
,
3.694129968231
,
3.313526598170
,
2.170240599444
,
2.793531289796
,
1.454464312809
,
1.197463589804
,
0.713332299712
,
1.180965411999
,
2.180022174106
,
2.861107927091
,
1.795223865106
,
1.730056153040
,
1.431404424890
,
1.839372935334
,
1.271871740741
,
3.773103777671
,
1.026069424885
,
2.006070770486
,
1.276836142291
,
1.414098998873
,
1.749117068374
,
2.040006827147
,
1.815581326626
,
2.892666735522
,
3.093934769003
,
2.129166907135
,
1.260521633663
,
3.259431640120
,
1.879415647487
,
1.368769201985
,
2.236653714367
,
2.293120875655
,
2.361086097355
,
2.140675892497
,
2.860288793716
,
3.109921655205
,
2.142509743586
,
0.661829413359
,
0.620852115030
,
2.279817287885
,
2.077609300700
,
1.917031492891
,
2.549328729021
,
1.402961147881
,
2.989802645752
,
2.126646549508
,
0.581285045065
,
3.226987223858
,
1.790860716921
,
0.998661497130
,
2.125771640271
,
2.186096892741
,
2.160189267804
,
2.206460323846
,
3.366179111195
,
-
0.125206283025
,
0.645228886619
,
0.505553980622
,
4.494406059555
,
1.291690417806
,
2.977896904657
,
2.869240282824
,
3.344192278881
,
2.487041683297
,
4.236730343795
,
3.007206122800
,
1.210065291965
,
-
0.053847768077
,
1.108953782402
,
1.843857008095
,
2.374767801329
,
1.472199059501
,
3.332198116275
,
2.027084082885
,
2.305065331530
,
3.387400013580
,
1.493365795517
,
2.344295515065
,
2.898632740793
,
3.307836869328
,
1.892766317783
,
2.348033912288
,
1.288522200888
,
2.178559140529
,
2.366037265891
,
3.468023805733
,
1.910134543982
,
1.750500687923
,
1.506717073807
,
1.345976221745
,
1.898226480175
,
2.362688287820
,
2.176558673313
,
1.716475335783
,
1.109563102324
,
1.824697060483
,
2.290331853365
,
3.660496355225
,
3.695990930547
,
0.995131810353
,
2.083740307542
,
2.515409175245
,
1.734919119633
,
0.186488629263
,
3.470910728743
,
3.503515673097
,
2.225335667636
,
4.925211524431
,
3.176405299532
,
2.938260408825
,
2.336603901159
,
2.218333712640
,
3.269148549824
,
1.921171637456
,
3.876114839719
,
1.492216718705
,
2.792835112200
,
3.563198188748
,
2.728530961520
,
3.231549893645
,
2.209018339760
,
1.081828242171
,
0.754161622090
,
1.948018149260
,
2.413945024183
,
1.425023717183
,
2.005406706788
,
0.964987890314
,
1.603414847296
,
0.132077263346
,
1.789327371404
,
1.423488299029
,
2.590160851192
,
3.131340836085
,
2.325779171436
,
2.129789552692
,
1.876126153813
,
2.667783873354
,
-
0.220464828097
,
2.285158851436
,
1.188664672684
,
1.968980968179
,
2.510328726654
,
1.690300427857
,
2.041495293673
,
2.471293710293
,
1.660589811070
,
1.801640276851
,
2.200864460731
,
1.489583958038
,
1.545725376492
,
4.208130184998
,
2.428489533380
,
3.539990060815
,
1.317090333595
,
0.785936916712
,
0.809688718378
,
1.265062896735
,
2.749291333938
,
6.075297866258
,
2.165845459075
,
2.055273600728
,
2.584618009430
,
2.782654850307
,
0.967100649409
,
2.267394795463
,
2.783350629984
,
0.238340558296
,
1.566536380829
,
1.165403279885
,
3.409015124349
,
1.047853632456
,
2.100798231132
,
1.824776518459
,
1.517825551662
,
2.148972385365
,
1.818426298006
,
1.954355115973
,
2.428393037760
,
2.225660788849
,
1.287880002052
,
3.083900598687
,
2.561457835470
,
2.547146477110
,
-
0.060868513691
,
1.917876348341
,
1.194823858275
,
1.237685798924
,
2.500081029116
,
0.605823016300
,
1.341027488293
,
1.357719149407
,
3.959221361786
,
1.457342301661
,
1.450552596247
,
3.152966485077
,
1.755910034199
,
2.252303064393
,
2.315145292843
,
2.092889154866
,
2.044536701039
,
3.078226379252
,
1.940374989780
,
0.981160719305
,
1.801484599888
,
4.599412580952
,
3.029815652986
,
2.234894233100
,
1.884862677960
,
2.703542617621
,
2.188894869734
,
1.031225637544
,
4.487470294014
,
1.916903861878
,
2.178877764206
,
2.001204233385
,
1.668533128794
,
0.118714387565
,
1.236342841750
,
0.697779517270
,
4.061304247309
,
1.873047854221
,
0.529730720609
,
0.772303413290
,
1.734928501976
,
0.830164961083
,
3.674107591296
,
3.027005867653
,
2.798171180697
,
2.754769626808
,
2.287213251879
,
0.224122591017
,
1.996907607820
,
2.272196861888
,
1.423156951562
,
2.649423732022
,
2.410425004883
,
2.348764499112
,
4.188086272873
,
2.592584804958
,
1.360716155533
,
1.089292416194
,
0.877166635938
,
2.923298927077
,
1.699602289582
,
1.764010718116
,
0.851384613856
,
1.362786130903
,
4.014401248962
,
2.004378924317
,
2.680507997712
,
4.162602009325
,
2.080304752717
,
0.758782969232
,
0.896584126809
,
1.907281638800
,
2.753415491620
,
1.571468221472
,
1.510571435517
,
3.133254430892
,
1.314198176176
,
2.871092309494
,
0.505771497509
,
0.608771053519
,
0.099600620869
,
2.202314023992
,
1.561845986404
,
1.935860544395
,
4.227485606155
,
2.507702606518
,
1.966897273255
,
3.462827375982
,
2.297865682096
,
2.018310409281
,
2.231512822040
,
2.912164920958
,
0.391926284930
,
3.233896921158
,
2.270671144478
,
2.151928087898
,
1.169376547635
,
1.410447269758
,
1.104075308499
,
-
1.542633116467
,
1.153006815104
,
1.825678952144
,
3.170518866440
,
4.259372395300
,
2.991591841969
,
2.936827860147
,
1.621450416535
,
2.022035327270
,
3.512668911326
,
2.840069655471
,
0.445725474197
,
0.462229554454
,
0.318918997270
,
2.764048560322
,
1.707769041832
,
0.354635293838
,
1.422103811424
,
1.567812002847
,
1.024884046523
,
3.417171077354
,
1.638428319488
,
3.241722761084
,
1.903144274531
,
2.386560261127
,
1.089737760638
,
0.709565288091
,
-
0.055267123709
,
2.220171017505
,
2.676992914119
,
1.795938808643
,
0.857048483230
,
2.341277450146
,
0.597747299826
,
2.172474110279
,
1.658595631706
,
1.984212673322
,
3.348561751121
,
1.548578130896
,
1.072758349253
,
0.032774558165
,
1.706534108602
,
0.755870027998
,
3.896324551791
,
1.893154948782
,
1.610014175651
,
1.689869260730
,
0.837788921577
,
2.072889043390
,
1.849998133863
,
2.312731199777
,
3.080340939707
,
3.091164029551
,
0.393260795576
,
2.003732199463
,
1.850471999505
,
1.674223365447
,
0.950827266616
,
0.893144704712
,
4.088054520567
,
2.494717669405
,
2.940185915913
,
2.362745036344
,
4.420918853822
,
1.196829910488
,
0.751131585724
,
2.572876053732
,
4.783864101630
,
1.371390533975
,
2.265749507496
,
0.980731387353
,
1.194594017621
,
1.167489912193
,
1.964259577764
,
2.911981147100
,
3.425120291588
,
-
0.257485591786
,
2.472881717211
,
3.053440640390
,
0.762578570358
,
1.132958189893
,
2.182874371350
,
3.052476057575
,
1.277863138274
,
1.639136886663
,
3.068422388091
,
4.082802262329
,
3.817537635954
,
0.097850368917
,
1.833230262781
,
1.868086753582
,
1.887983463294
,
1.651402760749
,
1.139536636683
,
1.506983113398
,
2.136499510660
,
1.554089544528
,
1.817657472715
,
1.881949483974
,
1.694259012321
,
2.466961181010
,
0.934064795958
,
1.780986169906
,
2.370334192643
,
1.384364501906
,
2.661332270733
,
0.505133486534
,
3.377981661644
,
4.528998885749
,
1.374759695170
,
4.722230929249
,
2.457241846607
,
1.089449047113
,
2.069442203989
,
2.922047383746
,
2.418239643214
,
2.102706555829
,
3.402947114317
,
0.796159207619
,
2.632564349894
,
4.200094165974
,
1.193215106012
,
3.096716644943
,
2.876829603210
,
1.921543697812
,
1.061475001173
,
1.539143201636
,
1.962758648643
,
2.295863280945
,
1.165666782756
,
2.795634751169
,
1.964614100540
,
2.881578005533
,
2.637037175067
,
2.892982065300
,
1.370612909045
,
2.259776562066
,
2.613792094772
,
1.906706647250
,
1.557148053231
,
1.133780390845
,
1.143533122599
,
4.117191444375
,
0.188018096004
,
0.214460776257
,
1.603522547618
,
1.983864405185
,
2.699735877141
,
3.298710632472
,
1.487986613587
,
1.991452281956
,
4.766298265537
,
2.586190101112
,
1.065148174656
,
2.145271421777
,
1.522765549323
,
4.396072859233
,
1.606095900438
,
1.031438092798
,
2.960703649068
,
3.605096318253
,
0.738490507896
,
3.432893690638
,
1.851816325195
,
1.776083706595
,
2.626748456025
,
3.271038565612
,
-
0.070612895830
,
1.287744477416
,
1.695753443501
,
2.436753770247
,
2.202642339826
,
1.408938299450
,
2.016160912374
,
0.898288687497
,
1.289813203546
,
2.817550580649
,
1.633807646433
,
0.729748534247
,
3.731152016795
,
4.390636356521
,
0.960267728055
,
2.664438938132
,
3.353944579683
,
3.269381121224
,
1.172165464812
,
2.400938318282
,
0.807728448589
,
1.354368440944
,
0.710514838580
,
3.856287086165
,
1.844610913086
,
0.998911164773
,
2.675023215997
,
2.434288374696
,
3.159869294202
,
2.216260317913
,
1.045656776305
,
2.009335242758
,
-
0.709434674480
,
1.331363427990
,
1.413333913927
,
0.929006400187
,
1.733184360011
,
1.592926769452
,
2.244402061817
,
-
0.685165133598
,
1.588319181570
,
2.023137693888
,
2.371988587924
,
3.467422121044
,
1.853048214949
,
2.619054775096
,
2.800663994934
,
3.602262658213
,
2.739826236907
,
1.657450448996
,
2.635627781551
,
3.426068129480
,
1.151293859797
,
2.126497682712
,
1.381359388969
,
2.247544386446
,
1.174628386484
,
0.958611426223
,
3.704308164799
,
1.239169433782
,
1.113358149317
,
2.955595260742
,
4.724760227403
,
1.980693866973
,
1.178308425749
,
1.764128983382
,
3.075383932276
,
1.825832517572
,
0.832366525583
,
1.141057225660
,
1.525888435245
,
2.385324115908
,
1.358714765290
,
0.551137984167
,
3.731145479114
,
2.129026962460
,
2.573317010477
,
1.976940325893
,
2.420475072025
,
0.684154421404
,
2.802696725755
,
2.541095686615
,
2.058591811473
,
1.640112285999
,
1.856989192038
,
2.614611193561
,
3.469421336856
,
1.053557146407
,
3.032200283499
,
2.573750297024
,
3.083216685185
,
2.404219296708
,
0.346271398570
,
2.589361666010
,
2.774804416246
,
0.445877540011
,
0.905444077995
,
0.063823875188
,
2.931316420485
,
1.682860197161
,
2.972795382257
,
2.597485175923
,
1.554827252582
,
0.938640710601
,
1.554015447012
,
0.698644188586
,
2.957760202695
,
2.706304471141
,
2.642415006150
,
1.464184232137
,
2.765792229162
,
2.039393447616
,
1.582779254230
,
1.722697961910
,
0.354842490538
,
0.839688308674
,
3.250316830782
,
3.993268587677
,
1.831751003414
,
3.737987669486
,
3.837008408003
,
3.656452995704
,
1.378085850241
,
3.992366605685
,
3.063520565655
,
1.829600671075
,
2.853149829083
,
1.948008763331
,
2.489355654745
,
2.039149456991
,
1.723308108929
,
1.530719515047
,
1.390322318375
,
1.015161747970
,
1.902647551975
,
0.587714373573
,
2.419343238401
,
2.037241109090
,
1.989108845487
,
2.555164211364
,
2.145078634562
,
2.453495232937
,
1.572091583978
,
3.017196269239
,
1.359683738353
,
1.905697793148
,
1.745346338494
,
2.410960789923
,
1.688108090624
,
2.041661869959
,
2.261146892703
,
0.108311227666
,
2.261198590438
,
1.205414457068
,
0.815680644627
,
2.373638547036
,
0.314446220857
,
2.407160216258
,
1.767921824455
,
1.812649838016
,
1.981483340407
,
2.294353826751
,
1.219794724258
,
4.384759314526
,
3.362912919095
,
0.358020839800
,
2.416111296383
,
0.772765268291
,
3.036908153028
,
3.499839475422
,
2.504401672085
,
1.000612753791
,
1.031364523108
,
2.905950640378
,
3.816584440139
,
0.846980443659
,
2.806102343007
,
1.662302388297
,
2.146698147213
,
2.247505312463
,
1.485016638026
,
3.139004503074
,
0.525587710167
,
1.271023854890
,
1.255521130946
,
1.814043296479
,
1.216959307975
,
0.978300004743
,
1.793024541935
,
2.436108214253
,
1.805501508380
,
2.289362542667
,
3.103303146056
,
3.070219780476
,
1.928865588661
,
0.671011951957
,
1.892825933013
,
2.777602823529
,
0.491871575583
,
2.240415846966
,
2.375489703208
,
2.709091612473
,
1.454643174490
,
0.932692202068
,
1.330312119137
,
2.127413235976
,
1.317902934165
,
1.580714395448
,
1.008090918724
,
0.713394722097
,
0.414109615934
,
1.497415539366
,
2.768670845431
,
2.432044584164
,
3.549640318635
,
1.342147007285
,
1.490711835094
,
2.215255822048
,
2.953966963699
,
1.397115922550
,
1.378544056641
,
0.295634610499
,
0.310858641177
,
1.767513113561
,
3.434648852323
,
1.491911596249
,
4.374871362485
,
1.373675010945
,
0.738310910553
,
1.234191541434
,
2.155481175306
,
2.958616497624
,
1.540019317971
,
1.890919744323
,
0.015363864461
,
0.611976171745
,
2.048461203755
,
0.905204536881
,
1.952638485996
,
2.425065685214
,
2.237320237401
,
1.261771567053
,
2.589404719675
,
2.475731886267
,
1.327151422229
,
1.535419742175
,
1.208022763652
,
3.436317329939
,
1.228705365902
,
1.566902016116
,
2.085976618587
,
3.604339608476
,
1.070321479131
,
2.085842592869
,
2.524588738973
,
0.844371573275
,
0.666896658382
,
3.021051396651
,
1.304696763442
,
2.885533651158
,
1.076496681998
,
2.291817051246
,
1.297874264925
,
2.181105432748
,
2.017938562177
,
1.688714920892
,
2.778982555151
,
2.464336632348
,
2.351157814765
,
0.328615108417
,
2.045729524665
,
1.097213832650
,
1.821737535398
,
1.493311782343
,
2.664732414197
,
2.409262056717
,
3.113216373314
,
-
0.736586322990
,
2.622696405459
,
2.182304470993
,
0.270476186892
,
2.628658716523
,
1.938159535287
,
3.336990992309
,
1.843049113291
,
4.411231162672
,
2.385234723754
,
2.814668480292
,
0.835204823261
,
1.486415794299
,
4.428506992622
,
2.010951035701
,
2.386497902232
,
-
0.054343008749
,
1.897798394390
,
2.128378005079
,
2.003863798772
,
1.414857649717
,
3.058382706867
,
0.734980076150
,
0.128402966890
,
0.075621261496
,
2.062530850812
,
2.257054591626
,
3.098405063129
,
1.184503294303
,
1.927098840462
,
3.590105219538
,
2.324770104189
,
2.920547827923
,
3.774469427430
,
0.643975980439
,
2.972011913013
,
1.545636773552
,
1.284276446577
,
2.116456504846
,
2.334765924705
,
1.476322485264
,
0.333938454195
,
1.740780860437
,
0.809641636242
,
2.114359904589
,
3.495010537745
,
1.394057058959
,
2.099880999687
,
1.723136191694
,
1.824145520142
,
2.206175560435
,
2.217935160800
,
3.184151380365
,
-
0.165277107839
,
2.066902569755
,
4.109207223806
,
2.639922346758
,
2.869289441530
,
2.992666432223
,
2.628328010580
,
1.318946413946
,
3.437097382310
,
2.043254488710
,
0.244000873823
,
1.857441713051
,
1.302602111278
,
2.850286225242
,
1.988609208476
,
0.406765856788
,
1.691073499692
,
0.918912799942
,
1.943198487145
,
3.174415802822
,
1.916755816708
,
2.196550794119
,
0.930720476044
,
2.032189015326
,
-
0.777034945338
,
1.406753268550
,
0.870345844705
,
1.793195283464
,
2.066120080070
,
2.916729217526
,
0.642313449142
,
2.617529572000
,
2.396572272668
,
1.942111542427
,
2.435603256612
,
3.898219347884
,
1.979409214342
,
1.235681010137
,
-
0.802441600645
,
1.927883866070
,
0.852232772749
,
2.626513188209
,
1.994232584644
,
2.677125120554
,
2.945149227801
,
0.344859114264
,
2.988484765052
,
2.221699681734
,
1.157038942208
,
2.703070759809
,
1.410436365113
,
3.056534135285
,
0.975232183559
,
1.032651705560
,
1.787301763233
,
1.587502529729
,
1.425207628405
,
2.443158189935
,
3.786205343468
,
0.240451061053
,
2.993759767949
,
2.527525916677
,
2.990777291349
,
1.458774147434
,
4.293524428909
,
-
0.116618748162
,
1.674243883127
,
2.434026351267
,
3.129729749455
,
1.532120640786
,
3.584008627649
,
2.126682783899
,
0.784920593215
,
1.954841166456
,
2.659877218373
,
2.639038968190
,
3.009597452617
,
3.820422929562
,
2.950718556164
,
2.942026969809
,
2.899140330708
,
-
0.003511099104
,
0.780849789152
,
2.375904463772
,
1.034820493941
,
2.010379907777
,
2.273452795908
,
2.508893511243
,
0.495773521197
,
3.010585297044
,
3.029210010516
,
3.973880821070
,
2.416599047057
,
1.320773195864
,
-
0.296283555404
,
3.112367101202
,
0.568454165534
,
3.950197901953
,
3.040255296379
,
2.892209686169
,
1.355195417805
,
2.139684432822
,
2.920582903729
,
1.588899963320
,
2.235959314499
,
1.768769790964
,
2.854126298598
,
-
0.132279995647
,
1.984901818097
,
0.667875459687
,
1.338320000460
,
1.394322304858
,
0.299399873843
,
1.062140558649
,
1.283070416608
,
2.043726915244
,
1.426628725160
,
1.763445352183
,
2.517159283156
,
1.334042379945
,
1.120896888394
,
1.890222921582
,
0.565772476594
,
1.106579774451
,
1.419654511698
,
2.809182593659
,
2.500132279723
,
2.818415931740
,
2.302096389328
,
2.700248827229
,
2.649016952991
,
3.051337084118
,
1.040839832658
,
1.068258609432
,
3.917982023425
,
0.893981534117
,
1.258354231702
,
0.154108914546
,
1.578873281706
,
1.438841948587
,
0.854591114516
,
3.199994919222
,
0.946279793861
,
1.911631903701
,
3.821301368668
,
1.417597738430
,
2.979514439880
,
1.202734703576
,
1.724395921518
,
3.069121580556
,
1.024707141488
,
3.309560047408
,
2.147433150108
,
2.173493008341
,
3.875831804386
,
2.160166379458
,
2.017326408423
,
3.941632320127
,
2.583832116740
,
]
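
The commit deletes this pre-generated fixture outright. If a comparable fixture were ever needed again, a short script like the sketch below could regenerate one; the seed, rounding, and output path are assumptions of this note, so it would not reproduce the exact values removed here.

```
# Illustrative sketch: regenerate a comparable 1,000-point fixture with
# mean 2 and standard deviation 1. The seed is arbitrary.
import numpy as np

rng = np.random.RandomState(seed=123)
values = rng.normal(loc=2.0, scale=1.0, size=1000)

with open('random_normal.py', 'w') as f:
  f.write('RANDOM_NORMAL = [\n')
  for v in values:
    f.write('    %.12f,\n' % v)
  f.write(']\n')
```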
research/efficient-hrl/README.md (modified)

Code for performing Hierarchical RL based on the following publications:

"Data-Efficient Hierarchical Reinforcement Learning" by
Ofir Nachum, Shixiang (Shane) Gu, Honglak Lee, and Sergey Levine
(https://arxiv.org/abs/1805.08296).

This library currently includes three of the environments used:
Ant Maze, Ant Push, and Ant Fall.
The training code is planned to be open-sourced at a later time.

"Near-Optimal Representation Learning for Hierarchical Reinforcement Learning"
by Ofir Nachum, Shixiang (Shane) Gu, Honglak Lee, and Sergey Levine
(https://arxiv.org/abs/1810.01257).

Requirements:
* TensorFlow (see http://www.tensorflow.org for how to install/upgrade)
* Gin Config (see https://github.com/google/gin-config)
* Tensorflow Agents (see https://github.com/tensorflow/agents)
* OpenAI Gym (see http://gym.openai.com/docs, be sure to install MuJoCo as well)
* NumPy (see http://www.numpy.org/)

Quick Start:

Run a random policy on AntMaze (or AntPush, AntFall):

```
python environments/__init__.py --env=AntMaze
```

Run a training job based on the original HIRO paper on Ant Maze:

```
python scripts/local_train.py test1 hiro_orig ant_maze base_uvf suite
```

Run a continuous evaluation job for that experiment:

```
python scripts/local_eval.py test1 hiro_orig ant_maze base_uvf suite
```

To run the same experiment with online representation learning (the
"Near-Optimal" paper), change `hiro_orig` to `hiro_repr`.
You can also run with `hiro_xy` to run the same experiment with HIRO on only the
xy coordinates of the agent.

To run on other environments, change `ant_maze` to something else; e.g.,
`ant_push_multi`, `ant_fall_multi`, etc. See `context/configs/*` for other options.

Basic Code Guide:

The code for training resides in train.py. The code trains a lower-level policy
(a UVF agent in the code) and a higher-level policy (a MetaAgent in the code)
concurrently. The higher-level policy communicates goals to the lower-level
policy. In the code, this is called a context. Not only does the lower-level
policy act with respect to a context (a higher-level specified goal), but the
higher-level policy also acts with respect to an environment-specified context
(corresponding to the navigation target location associated with the task).
Therefore, in `context/configs/*` you will find both specifications for task
setup as well as goal configurations. Most remaining hyperparameters used for
training/evaluation may be found in `configs/*`.

NOTE: Not all the code corresponding to the "Near-Optimal" paper is included.
Namely, changes to low-level policy training proposed in the paper (discounting
and auxiliary rewards) are not implemented here. Performance should not change
significantly.

Maintained by Ofir Nachum (ofirnachum).
research/efficient-hrl/agent.py (new file, 0 → 100644)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A UVF agent.
"""
import
tensorflow
as
tf
import
gin.tf
from
agents
import
ddpg_agent
# pylint: disable=unused-import
import
cond_fn
from
utils
import
utils
as
uvf_utils
from
context
import
gin_imports
# pylint: enable=unused-import
slim
=
tf
.
contrib
.
slim
@
gin
.
configurable
class
UvfAgentCore
(
object
):
"""Defines basic functions for UVF agent. Must be inherited with an RL agent.
Used as lower-level agent.
"""
def
__init__
(
self
,
observation_spec
,
action_spec
,
tf_env
,
tf_context
,
step_cond_fn
=
cond_fn
.
env_transition
,
reset_episode_cond_fn
=
cond_fn
.
env_restart
,
reset_env_cond_fn
=
cond_fn
.
false_fn
,
metrics
=
None
,
**
base_agent_kwargs
):
"""Constructs a UVF agent.
Args:
observation_spec: A TensorSpec defining the observations.
action_spec: A BoundedTensorSpec defining the actions.
tf_env: A Tensorflow environment object.
tf_context: A Context class.
step_cond_fn: A function indicating whether to increment the num of steps.
reset_episode_cond_fn: A function indicating whether to restart the
episode, resampling the context.
reset_env_cond_fn: A function indicating whether to perform a manual reset
of the environment.
metrics: A list of functions that evaluate metrics of the agent.
**base_agent_kwargs: A dictionary of parameters for base RL Agent.
Raises:
ValueError: If 'dqda_clipping' is < 0.
"""
self
.
_step_cond_fn
=
step_cond_fn
self
.
_reset_episode_cond_fn
=
reset_episode_cond_fn
self
.
_reset_env_cond_fn
=
reset_env_cond_fn
self
.
metrics
=
metrics
# expose tf_context methods
self
.
tf_context
=
tf_context
(
tf_env
=
tf_env
)
self
.
set_replay
=
self
.
tf_context
.
set_replay
self
.
sample_contexts
=
self
.
tf_context
.
sample_contexts
self
.
compute_rewards
=
self
.
tf_context
.
compute_rewards
self
.
gamma_index
=
self
.
tf_context
.
gamma_index
self
.
context_specs
=
self
.
tf_context
.
context_specs
self
.
context_as_action_specs
=
self
.
tf_context
.
context_as_action_specs
self
.
init_context_vars
=
self
.
tf_context
.
create_vars
self
.
env_observation_spec
=
observation_spec
[
0
]
merged_observation_spec
=
(
uvf_utils
.
merge_specs
(
(
self
.
env_observation_spec
,)
+
self
.
context_specs
),)
self
.
_context_vars
=
dict
()
self
.
_action_vars
=
dict
()
self
.
BASE_AGENT_CLASS
.
__init__
(
self
,
observation_spec
=
merged_observation_spec
,
action_spec
=
action_spec
,
**
base_agent_kwargs
)
def
set_meta_agent
(
self
,
agent
=
None
):
self
.
_meta_agent
=
agent
@
property
def
meta_agent
(
self
):
return
self
.
_meta_agent
def
actor_loss
(
self
,
states
,
actions
,
rewards
,
discounts
,
next_states
):
"""Returns the next action for the state.
Args:
state: A [num_state_dims] tensor representing a state.
context: A list of [num_context_dims] tensor representing a context.
Returns:
A [num_action_dims] tensor representing the action.
"""
return
self
.
BASE_AGENT_CLASS
.
actor_loss
(
self
,
states
)
def
action
(
self
,
state
,
context
=
None
):
"""Returns the next action for the state.
Args:
state: A [num_state_dims] tensor representing a state.
context: A list of [num_context_dims] tensor representing a context.
Returns:
A [num_action_dims] tensor representing the action.
"""
merged_state
=
self
.
merged_state
(
state
,
context
)
return
self
.
BASE_AGENT_CLASS
.
action
(
self
,
merged_state
)
def
actions
(
self
,
state
,
context
=
None
):
"""Returns the next action for the state.
Args:
state: A [-1, num_state_dims] tensor representing a state.
context: A list of [-1, num_context_dims] tensor representing a context.
Returns:
A [-1, num_action_dims] tensor representing the action.
"""
merged_states
=
self
.
merged_states
(
state
,
context
)
return
self
.
BASE_AGENT_CLASS
.
actor_net
(
self
,
merged_states
)
def
log_probs
(
self
,
states
,
actions
,
state_reprs
,
contexts
=
None
):
assert
contexts
is
not
None
batch_dims
=
[
tf
.
shape
(
states
)[
0
],
tf
.
shape
(
states
)[
1
]]
contexts
=
self
.
tf_context
.
context_multi_transition_fn
(
contexts
,
states
=
tf
.
to_float
(
state_reprs
))
flat_states
=
tf
.
reshape
(
states
,
[
batch_dims
[
0
]
*
batch_dims
[
1
],
states
.
shape
[
-
1
]])
flat_contexts
=
[
tf
.
reshape
(
tf
.
cast
(
context
,
states
.
dtype
),
[
batch_dims
[
0
]
*
batch_dims
[
1
],
context
.
shape
[
-
1
]])
for
context
in
contexts
]
flat_pred_actions
=
self
.
actions
(
flat_states
,
flat_contexts
)
pred_actions
=
tf
.
reshape
(
flat_pred_actions
,
batch_dims
+
[
flat_pred_actions
.
shape
[
-
1
]])
error
=
tf
.
square
(
actions
-
pred_actions
)
spec_range
=
(
self
.
_action_spec
.
maximum
-
self
.
_action_spec
.
minimum
)
/
2
normalized_error
=
error
/
tf
.
constant
(
spec_range
)
**
2
return
-
normalized_error
@
gin
.
configurable
(
'uvf_add_noise_fn'
)
def
add_noise_fn
(
self
,
action_fn
,
stddev
=
1.0
,
debug
=
False
,
clip
=
True
,
global_step
=
None
):
"""Returns the action_fn with additive Gaussian noise.
Args:
action_fn: A callable(`state`, `context`) which returns a
[num_action_dims] tensor representing a action.
stddev: stddev for the Ornstein-Uhlenbeck noise.
debug: Print debug messages.
Returns:
A [num_action_dims] action tensor.
"""
if
global_step
is
not
None
:
stddev
*=
tf
.
maximum
(
# Decay exploration during training.
tf
.
train
.
exponential_decay
(
1.0
,
global_step
,
1e6
,
0.8
),
0.5
)
def
noisy_action_fn
(
state
,
context
=
None
):
"""Noisy action fn."""
action
=
action_fn
(
state
,
context
)
if
debug
:
action
=
uvf_utils
.
tf_print
(
action
,
[
action
],
message
=
'[add_noise_fn] pre-noise action'
,
first_n
=
100
)
noise_dist
=
tf
.
distributions
.
Normal
(
tf
.
zeros_like
(
action
),
tf
.
ones_like
(
action
)
*
stddev
)
noise
=
noise_dist
.
sample
()
action
+=
noise
if
debug
:
action
=
uvf_utils
.
tf_print
(
action
,
[
action
],
message
=
'[add_noise_fn] post-noise action'
,
first_n
=
100
)
if
clip
:
action
=
uvf_utils
.
clip_to_spec
(
action
,
self
.
_action_spec
)
return
action
return
noisy_action_fn
def
merged_state
(
self
,
state
,
context
=
None
):
"""Returns the merged state from the environment state and contexts.
Args:
state: A [num_state_dims] tensor representing a state.
context: A list of [num_context_dims] tensor representing a context.
If None, use the internal context.
Returns:
A [num_merged_state_dims] tensor representing the merged state.
"""
if
context
is
None
:
context
=
list
(
self
.
context_vars
)
state
=
tf
.
concat
([
state
,]
+
context
,
axis
=-
1
)
self
.
_validate_states
(
self
.
_batch_state
(
state
))
return
state
def
merged_states
(
self
,
states
,
contexts
=
None
):
"""Returns the batch merged state from the batch env state and contexts.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
contexts: A list of [batch_size, num_context_dims] tensor
representing a batch of contexts. If None,
use the internal context.
Returns:
A [batch_size, num_merged_state_dims] tensor representing the batch
of merged states.
"""
if
contexts
is
None
:
contexts
=
[
tf
.
tile
(
tf
.
expand_dims
(
context
,
axis
=
0
),
(
tf
.
shape
(
states
)[
0
],
1
))
for
context
in
self
.
context_vars
]
states
=
tf
.
concat
([
states
,]
+
contexts
,
axis
=-
1
)
self
.
_validate_states
(
states
)
return
states
def
unmerged_states
(
self
,
merged_states
):
"""Returns the batch state and contexts from the batch merged state.
Args:
merged_states: A [batch_size, num_merged_state_dims] tensor
representing a batch of merged states.
Returns:
A [batch_size, num_state_dims] tensor and a list of
[batch_size, num_context_dims] tensors representing the batch state
and contexts respectively.
"""
self
.
_validate_states
(
merged_states
)
num_state_dims
=
self
.
env_observation_spec
.
shape
.
as_list
()[
0
]
num_context_dims_list
=
[
c
.
shape
.
as_list
()[
0
]
for
c
in
self
.
context_specs
]
states
=
merged_states
[:,
:
num_state_dims
]
contexts
=
[]
i
=
num_state_dims
for
num_context_dims
in
num_context_dims_list
:
contexts
.
append
(
merged_states
[:,
i
:
i
+
num_context_dims
])
i
+=
num_context_dims
return
states
,
contexts
def
sample_random_actions
(
self
,
batch_size
=
1
):
"""Return random actions.
Args:
batch_size: Batch size.
Returns:
A [batch_size, num_action_dims] tensor representing the batch of actions.
"""
actions
=
tf
.
concat
(
[
tf
.
random_uniform
(
shape
=
(
batch_size
,
1
),
minval
=
self
.
_action_spec
.
minimum
[
i
],
maxval
=
self
.
_action_spec
.
maximum
[
i
])
for
i
in
range
(
self
.
_action_spec
.
shape
[
0
].
value
)
],
axis
=
1
)
return
actions
def
clip_actions
(
self
,
actions
):
"""Clip actions to spec.
Args:
actions: A [batch_size, num_action_dims] tensor representing
the batch of actions.
Returns:
A [batch_size, num_action_dims] tensor representing the batch
of clipped actions.
"""
actions
=
tf
.
concat
(
[
tf
.
clip_by_value
(
actions
[:,
i
:
i
+
1
],
self
.
_action_spec
.
minimum
[
i
],
self
.
_action_spec
.
maximum
[
i
])
for
i
in
range
(
self
.
_action_spec
.
shape
[
0
].
value
)
],
axis
=
1
)
return
actions
def
mix_contexts
(
self
,
contexts
,
insert_contexts
,
indices
):
"""Mix two contexts based on indices.
Args:
contexts: A list of [batch_size, num_context_dims] tensor representing
the batch of contexts.
insert_contexts: A list of [batch_size, num_context_dims] tensor
representing the batch of contexts to be inserted.
indices: A list of a list of integers denoting indices to replace.
Returns:
A list of resulting contexts.
"""
if
indices
is
None
:
indices
=
[[]]
*
len
(
contexts
)
assert
len
(
contexts
)
==
len
(
indices
)
assert
all
([
spec
.
shape
.
ndims
==
1
for
spec
in
self
.
context_specs
])
mix_contexts
=
[]
for
contexts_
,
insert_contexts_
,
indices_
,
spec
in
zip
(
contexts
,
insert_contexts
,
indices
,
self
.
context_specs
):
mix_contexts
.
append
(
tf
.
concat
(
[
insert_contexts_
[:,
i
:
i
+
1
]
if
i
in
indices_
else
contexts_
[:,
i
:
i
+
1
]
for
i
in
range
(
spec
.
shape
.
as_list
()[
0
])
],
axis
=
1
))
return
mix_contexts
def
begin_episode_ops
(
self
,
mode
,
action_fn
=
None
,
state
=
None
):
"""Returns ops that reset agent at beginning of episodes.
Args:
mode: a string representing the mode=[train, explore, eval].
Returns:
A list of ops.
"""
all_ops
=
[]
for
_
,
action_var
in
sorted
(
self
.
_action_vars
.
items
()):
sample_action
=
self
.
sample_random_actions
(
1
)[
0
]
all_ops
.
append
(
tf
.
assign
(
action_var
,
sample_action
))
all_ops
+=
self
.
tf_context
.
reset
(
mode
=
mode
,
agent
=
self
.
_meta_agent
,
action_fn
=
action_fn
,
state
=
state
)
return
all_ops
def
cond_begin_episode_op
(
self
,
cond
,
input_vars
,
mode
,
meta_action_fn
):
"""Returns op that resets agent at beginning of episodes.
A new episode is begun if the cond op evalues to `False`.
Args:
cond: a Boolean tensor variable.
input_vars: A list of tensor variables.
mode: a string representing the mode=[train, explore, eval].
Returns:
Conditional begin op.
"""
(
state
,
action
,
reward
,
next_state
,
state_repr
,
next_state_repr
)
=
input_vars
def
continue_fn
():
"""Continue op fn."""
items
=
[
state
,
action
,
reward
,
next_state
,
state_repr
,
next_state_repr
]
+
list
(
self
.
context_vars
)
batch_items
=
[
tf
.
expand_dims
(
item
,
0
)
for
item
in
items
]
(
states
,
actions
,
rewards
,
next_states
,
state_reprs
,
next_state_reprs
)
=
batch_items
[:
6
]
context_reward
=
self
.
compute_rewards
(
mode
,
state_reprs
,
actions
,
rewards
,
next_state_reprs
,
batch_items
[
6
:])[
0
][
0
]
context_reward
=
tf
.
cast
(
context_reward
,
dtype
=
reward
.
dtype
)
if
self
.
meta_agent
is
not
None
:
meta_action
=
tf
.
concat
(
self
.
context_vars
,
-
1
)
items
=
[
state
,
meta_action
,
reward
,
next_state
,
state_repr
,
next_state_repr
]
+
list
(
self
.
meta_agent
.
context_vars
)
batch_items
=
[
tf
.
expand_dims
(
item
,
0
)
for
item
in
items
]
(
states
,
meta_actions
,
rewards
,
next_states
,
state_reprs
,
next_state_reprs
)
=
batch_items
[:
6
]
meta_reward
=
self
.
meta_agent
.
compute_rewards
(
mode
,
states
,
meta_actions
,
rewards
,
next_states
,
batch_items
[
6
:])[
0
][
0
]
meta_reward
=
tf
.
cast
(
meta_reward
,
dtype
=
reward
.
dtype
)
else
:
meta_reward
=
tf
.
constant
(
0
,
dtype
=
reward
.
dtype
)
with
tf
.
control_dependencies
([
context_reward
,
meta_reward
]):
step_ops
=
self
.
tf_context
.
step
(
mode
=
mode
,
agent
=
self
.
_meta_agent
,
state
=
state
,
next_state
=
next_state
,
state_repr
=
state_repr
,
next_state_repr
=
next_state_repr
,
action_fn
=
meta_action_fn
)
with
tf
.
control_dependencies
(
step_ops
):
context_reward
,
meta_reward
=
map
(
tf
.
identity
,
[
context_reward
,
meta_reward
])
return
context_reward
,
meta_reward
def
begin_episode_fn
():
"""Begin op fn."""
begin_ops
=
self
.
begin_episode_ops
(
mode
=
mode
,
action_fn
=
meta_action_fn
,
state
=
state
)
with
tf
.
control_dependencies
(
begin_ops
):
return
tf
.
zeros_like
(
reward
),
tf
.
zeros_like
(
reward
)
with
tf
.
control_dependencies
(
input_vars
):
cond_begin_episode_op
=
tf
.
cond
(
cond
,
continue_fn
,
begin_episode_fn
)
return
cond_begin_episode_op
def
get_env_base_wrapper
(
self
,
env_base
,
**
begin_kwargs
):
"""Create a wrapper around env_base, with agent-specific begin/end_episode.
Args:
env_base: A python environment base.
**begin_kwargs: Keyword args for begin_episode_ops.
Returns:
An object with begin_episode() and end_episode().
"""
begin_ops
=
self
.
begin_episode_ops
(
**
begin_kwargs
)
return
uvf_utils
.
get_contextual_env_base
(
env_base
,
begin_ops
)
def
init_action_vars
(
self
,
name
,
i
=
None
):
"""Create and return a tensorflow Variable holding an action.
Args:
name: Name of the variables.
i: Integer id.
Returns:
A [num_action_dims] tensor.
"""
if
i
is
not
None
:
name
+=
'_%d'
%
i
assert
name
not
in
self
.
_action_vars
,
(
'Conflict! %s is already '
'initialized.'
)
%
name
self
.
_action_vars
[
name
]
=
tf
.
Variable
(
self
.
sample_random_actions
(
1
)[
0
],
name
=
'%s_action'
%
(
name
))
self
.
_validate_actions
(
tf
.
expand_dims
(
self
.
_action_vars
[
name
],
0
))
return
self
.
_action_vars
[
name
]
@
gin
.
configurable
(
'uvf_critic_function'
)
def
critic_function
(
self
,
critic_vals
,
states
,
critic_fn
=
None
):
"""Computes q values based on outputs from the critic net.
Args:
critic_vals: A tf.float32 [batch_size, ...] tensor representing outputs
from the critic net.
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
critic_fn: A callable that process outputs from critic_net and
outputs a [batch_size] tensor representing q values.
Returns:
A tf.float32 [batch_size] tensor representing q values.
"""
if
critic_fn
is
not
None
:
env_states
,
contexts
=
self
.
unmerged_states
(
states
)
critic_vals
=
critic_fn
(
critic_vals
,
env_states
,
contexts
)
critic_vals
.
shape
.
assert_has_rank
(
1
)
return
critic_vals
def
get_action_vars
(
self
,
key
):
return
self
.
_action_vars
[
key
]
def
get_context_vars
(
self
,
key
):
return
self
.
tf_context
.
context_vars
[
key
]
def
step_cond_fn
(
self
,
*
args
):
return
self
.
_step_cond_fn
(
self
,
*
args
)
def
reset_episode_cond_fn
(
self
,
*
args
):
return
self
.
_reset_episode_cond_fn
(
self
,
*
args
)
def
reset_env_cond_fn
(
self
,
*
args
):
return
self
.
_reset_env_cond_fn
(
self
,
*
args
)
@
property
def
context_vars
(
self
):
return
self
.
tf_context
.
vars
@
gin
.
configurable
class
MetaAgentCore
(
UvfAgentCore
):
"""Defines basic functions for UVF Meta-agent. Must be inherited with an RL agent.
Used as higher-level agent.
"""
def
__init__
(
self
,
observation_spec
,
action_spec
,
tf_env
,
tf_context
,
sub_context
,
step_cond_fn
=
cond_fn
.
env_transition
,
reset_episode_cond_fn
=
cond_fn
.
env_restart
,
reset_env_cond_fn
=
cond_fn
.
false_fn
,
metrics
=
None
,
actions_reg
=
0.
,
k
=
2
,
**
base_agent_kwargs
):
"""Constructs a Meta agent.
Args:
observation_spec: A TensorSpec defining the observations.
action_spec: A BoundedTensorSpec defining the actions.
tf_env: A Tensorflow environment object.
tf_context: A Context class.
step_cond_fn: A function indicating whether to increment the num of steps.
reset_episode_cond_fn: A function indicating whether to restart the
episode, resampling the context.
reset_env_cond_fn: A function indicating whether to perform a manual reset
of the environment.
metrics: A list of functions that evaluate metrics of the agent.
**base_agent_kwargs: A dictionary of parameters for base RL Agent.
Raises:
ValueError: If 'dqda_clipping' is < 0.
"""
self
.
_step_cond_fn
=
step_cond_fn
self
.
_reset_episode_cond_fn
=
reset_episode_cond_fn
self
.
_reset_env_cond_fn
=
reset_env_cond_fn
self
.
metrics
=
metrics
self
.
_actions_reg
=
actions_reg
self
.
_k
=
k
# expose tf_context methods
self
.
tf_context
=
tf_context
(
tf_env
=
tf_env
)
self
.
sub_context
=
sub_context
(
tf_env
=
tf_env
)
self
.
set_replay
=
self
.
tf_context
.
set_replay
self
.
sample_contexts
=
self
.
tf_context
.
sample_contexts
self
.
compute_rewards
=
self
.
tf_context
.
compute_rewards
self
.
gamma_index
=
self
.
tf_context
.
gamma_index
self
.
context_specs
=
self
.
tf_context
.
context_specs
self
.
context_as_action_specs
=
self
.
tf_context
.
context_as_action_specs
self
.
sub_context_as_action_specs
=
self
.
sub_context
.
context_as_action_specs
self
.
init_context_vars
=
self
.
tf_context
.
create_vars
self
.
env_observation_spec
=
observation_spec
[
0
]
merged_observation_spec
=
(
uvf_utils
.
merge_specs
(
(
self
.
env_observation_spec
,)
+
self
.
context_specs
),)
self
.
_context_vars
=
dict
()
self
.
_action_vars
=
dict
()
assert
len
(
self
.
context_as_action_specs
)
==
1
self
.
BASE_AGENT_CLASS
.
__init__
(
self
,
observation_spec
=
merged_observation_spec
,
action_spec
=
self
.
sub_context_as_action_specs
,
**
base_agent_kwargs
)
@
gin
.
configurable
(
'meta_add_noise_fn'
)
def
add_noise_fn
(
self
,
action_fn
,
stddev
=
1.0
,
debug
=
False
,
global_step
=
None
):
noisy_action_fn
=
super
(
MetaAgentCore
,
self
).
add_noise_fn
(
action_fn
,
stddev
,
clip
=
True
,
global_step
=
global_step
)
return
noisy_action_fn
def
actor_loss
(
self
,
states
,
actions
,
rewards
,
discounts
,
next_states
):
"""Returns the next action for the state.
Args:
state: A [num_state_dims] tensor representing a state.
context: A list of [num_context_dims] tensor representing a context.
Returns:
A [num_action_dims] tensor representing the action.
"""
actions
=
self
.
actor_net
(
states
,
stop_gradients
=
False
)
regularizer
=
self
.
_actions_reg
*
tf
.
reduce_mean
(
tf
.
reduce_sum
(
tf
.
abs
(
actions
[:,
self
.
_k
:]),
-
1
),
0
)
loss
=
self
.
BASE_AGENT_CLASS
.
actor_loss
(
self
,
states
)
return
regularizer
+
loss
@
gin
.
configurable
class
UvfAgent
(
UvfAgentCore
,
ddpg_agent
.
TD3Agent
):
"""A DDPG agent with UVF.
"""
BASE_AGENT_CLASS
=
ddpg_agent
.
TD3Agent
ACTION_TYPE
=
'continuous'
def
__init__
(
self
,
*
args
,
**
kwargs
):
UvfAgentCore
.
__init__
(
self
,
*
args
,
**
kwargs
)
@
gin
.
configurable
class
MetaAgent
(
MetaAgentCore
,
ddpg_agent
.
TD3Agent
):
"""A DDPG meta-agent.
"""
BASE_AGENT_CLASS
=
ddpg_agent
.
TD3Agent
ACTION_TYPE
=
'continuous'
def
__init__
(
self
,
*
args
,
**
kwargs
):
MetaAgentCore
.
__init__
(
self
,
*
args
,
**
kwargs
)
@
gin
.
configurable
()
def
state_preprocess_net
(
states
,
num_output_dims
=
2
,
states_hidden_layers
=
(
100
,),
normalizer_fn
=
None
,
activation_fn
=
tf
.
nn
.
relu
,
zero_time
=
True
,
images
=
False
):
"""Creates a simple feed forward net for embedding states.
"""
with
slim
.
arg_scope
(
[
slim
.
fully_connected
],
activation_fn
=
activation_fn
,
normalizer_fn
=
normalizer_fn
,
weights_initializer
=
slim
.
variance_scaling_initializer
(
factor
=
1.0
/
3.0
,
mode
=
'FAN_IN'
,
uniform
=
True
)):
states_shape
=
tf
.
shape
(
states
)
states_dtype
=
states
.
dtype
states
=
tf
.
to_float
(
states
)
if
images
:
# Zero-out x-y
states
*=
tf
.
constant
([
0.
]
*
2
+
[
1.
]
*
(
states
.
shape
[
-
1
]
-
2
),
dtype
=
states
.
dtype
)
if
zero_time
:
states
*=
tf
.
constant
([
1.
]
*
(
states
.
shape
[
-
1
]
-
1
)
+
[
0.
],
dtype
=
states
.
dtype
)
orig_states
=
states
embed
=
states
if
states_hidden_layers
:
embed
=
slim
.
stack
(
embed
,
slim
.
fully_connected
,
states_hidden_layers
,
scope
=
'states'
)
with
slim
.
arg_scope
([
slim
.
fully_connected
],
weights_regularizer
=
None
,
weights_initializer
=
tf
.
random_uniform_initializer
(
minval
=-
0.003
,
maxval
=
0.003
)):
embed
=
slim
.
fully_connected
(
embed
,
num_output_dims
,
activation_fn
=
None
,
normalizer_fn
=
None
,
scope
=
'value'
)
output
=
embed
output
=
tf
.
cast
(
output
,
states_dtype
)
return
output
@
gin
.
configurable
()
def
action_embed_net
(
actions
,
states
=
None
,
num_output_dims
=
2
,
hidden_layers
=
(
400
,
300
),
normalizer_fn
=
None
,
activation_fn
=
tf
.
nn
.
relu
,
zero_time
=
True
,
images
=
False
):
"""Creates a simple feed forward net for embedding actions.
"""
with
slim
.
arg_scope
(
[
slim
.
fully_connected
],
activation_fn
=
activation_fn
,
normalizer_fn
=
normalizer_fn
,
weights_initializer
=
slim
.
variance_scaling_initializer
(
factor
=
1.0
/
3.0
,
mode
=
'FAN_IN'
,
uniform
=
True
)):
actions
=
tf
.
to_float
(
actions
)
if
states
is
not
None
:
if
images
:
# Zero-out x-y
states
*=
tf
.
constant
([
0.
]
*
2
+
[
1.
]
*
(
states
.
shape
[
-
1
]
-
2
),
dtype
=
states
.
dtype
)
if
zero_time
:
states
*=
tf
.
constant
([
1.
]
*
(
states
.
shape
[
-
1
]
-
1
)
+
[
0.
],
dtype
=
states
.
dtype
)
actions
=
tf
.
concat
([
actions
,
tf
.
to_float
(
states
)],
-
1
)
embed
=
actions
if
hidden_layers
:
embed
=
slim
.
stack
(
embed
,
slim
.
fully_connected
,
hidden_layers
,
scope
=
'hidden'
)
with
slim
.
arg_scope
([
slim
.
fully_connected
],
weights_regularizer
=
None
,
weights_initializer
=
tf
.
random_uniform_initializer
(
minval
=-
0.003
,
maxval
=
0.003
)):
embed
=
slim
.
fully_connected
(
embed
,
num_output_dims
,
activation_fn
=
None
,
normalizer_fn
=
None
,
scope
=
'value'
)
if
num_output_dims
==
1
:
return
embed
[:,
0
,
...]
else
:
return
embed
def
huber
(
x
,
kappa
=
0.1
):
return
(
0.5
*
tf
.
square
(
x
)
*
tf
.
to_float
(
tf
.
abs
(
x
)
<=
kappa
)
+
kappa
*
(
tf
.
abs
(
x
)
-
0.5
*
kappa
)
*
tf
.
to_float
(
tf
.
abs
(
x
)
>
kappa
)
)
/
kappa
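
Reassembled, the `huber` helper defined by the preceding lines is a scaled Huber penalty. Written out as a formula, with $\kappa$ the `kappa` argument (0.1 by default):

$$\mathrm{huber}_{\kappa}(x) \;=\; \frac{1}{\kappa}\left(\tfrac{1}{2}x^{2}\,\mathbf{1}[\,|x|\le\kappa\,] \;+\; \kappa\left(|x|-\tfrac{\kappa}{2}\right)\mathbf{1}[\,|x|>\kappa\,]\right).$$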
@
gin
.
configurable
()
class
StatePreprocess
(
object
):
STATE_PREPROCESS_NET_SCOPE
=
'state_process_net'
ACTION_EMBED_NET_SCOPE
=
'action_embed_net'
def
__init__
(
self
,
trainable
=
False
,
state_preprocess_net
=
lambda
states
:
states
,
action_embed_net
=
lambda
actions
,
*
args
,
**
kwargs
:
actions
,
ndims
=
None
):
self
.
trainable
=
trainable
self
.
_scope
=
tf
.
get_variable_scope
().
name
self
.
_ndims
=
ndims
self
.
_state_preprocess_net
=
tf
.
make_template
(
self
.
STATE_PREPROCESS_NET_SCOPE
,
state_preprocess_net
,
create_scope_now_
=
True
)
self
.
_action_embed_net
=
tf
.
make_template
(
self
.
ACTION_EMBED_NET_SCOPE
,
action_embed_net
,
create_scope_now_
=
True
)
def
__call__
(
self
,
states
):
batched
=
states
.
get_shape
().
ndims
!=
1
if
not
batched
:
states
=
tf
.
expand_dims
(
states
,
0
)
embedded
=
self
.
_state_preprocess_net
(
states
)
if
self
.
_ndims
is
not
None
:
embedded
=
embedded
[...,
:
self
.
_ndims
]
if
not
batched
:
return
embedded
[
0
]
return
embedded
def
loss
(
self
,
states
,
next_states
,
low_actions
,
low_states
):
batch_size
=
tf
.
shape
(
states
)[
0
]
d
=
int
(
low_states
.
shape
[
1
])
# Sample indices into meta-transition to train on.
probs
=
0.99
**
tf
.
range
(
d
,
dtype
=
tf
.
float32
)
probs
*=
tf
.
constant
([
1.0
]
*
(
d
-
1
)
+
[
1.0
/
(
1
-
0.99
)],
dtype
=
tf
.
float32
)
probs
/=
tf
.
reduce_sum
(
probs
)
index_dist
=
tf
.
distributions
.
Categorical
(
probs
=
probs
,
dtype
=
tf
.
int64
)
indices
=
index_dist
.
sample
(
batch_size
)
batch_size
=
tf
.
cast
(
batch_size
,
tf
.
int64
)
next_indices
=
tf
.
concat
(
[
tf
.
range
(
batch_size
,
dtype
=
tf
.
int64
)[:,
None
],
(
1
+
indices
[:,
None
])
%
d
],
-
1
)
new_next_states
=
tf
.
where
(
indices
<
d
-
1
,
tf
.
gather_nd
(
low_states
,
next_indices
),
next_states
)
next_states
=
new_next_states
embed1
=
tf
.
to_float
(
self
.
_state_preprocess_net
(
states
))
embed2
=
tf
.
to_float
(
self
.
_state_preprocess_net
(
next_states
))
action_embed
=
self
.
_action_embed_net
(
tf
.
layers
.
flatten
(
low_actions
),
states
=
states
)
tau
=
2.0
fn
=
lambda
z
:
tau
*
tf
.
reduce_sum
(
huber
(
z
),
-
1
)
all_embed
=
tf
.
get_variable
(
'all_embed'
,
[
1024
,
int
(
embed1
.
shape
[
-
1
])],
initializer
=
tf
.
zeros_initializer
())
upd
=
all_embed
.
assign
(
tf
.
concat
([
all_embed
[
batch_size
:],
embed2
],
0
))
with
tf
.
control_dependencies
([
upd
]):
close
=
1
*
tf
.
reduce_mean
(
fn
(
embed1
+
action_embed
-
embed2
))
prior_log_probs
=
tf
.
reduce_logsumexp
(
-
fn
((
embed1
+
action_embed
)[:,
None
,
:]
-
all_embed
[
None
,
:,
:]),
axis
=-
1
)
-
tf
.
log
(
tf
.
to_float
(
all_embed
.
shape
[
0
]))
far
=
tf
.
reduce_mean
(
tf
.
exp
(
-
fn
((
embed1
+
action_embed
)[
1
:]
-
embed2
[:
-
1
])
-
tf
.
stop_gradient
(
prior_log_probs
[
1
:])))
repr_log_probs
=
tf
.
stop_gradient
(
-
fn
(
embed1
+
action_embed
-
embed2
)
-
prior_log_probs
)
/
tau
return
close
+
far
,
repr_log_probs
,
indices
def
get_trainable_vars
(
self
):
return
(
slim
.
get_trainable_variables
(
uvf_utils
.
join_scope
(
self
.
_scope
,
self
.
STATE_PREPROCESS_NET_SCOPE
))
+
slim
.
get_trainable_variables
(
uvf_utils
.
join_scope
(
self
.
_scope
,
self
.
ACTION_EMBED_NET_SCOPE
)))
@
gin
.
configurable
()
class
InverseDynamics
(
object
):
INVERSE_DYNAMICS_NET_SCOPE
=
'inverse_dynamics'
def
__init__
(
self
,
spec
):
self
.
_spec
=
spec
def
sample
(
self
,
states
,
next_states
,
num_samples
,
orig_goals
,
sc
=
0.5
):
goal_dim
=
orig_goals
.
shape
[
-
1
]
spec_range
=
(
self
.
_spec
.
maximum
-
self
.
_spec
.
minimum
)
/
2
*
tf
.
ones
([
goal_dim
])
loc
=
tf
.
cast
(
next_states
-
states
,
tf
.
float32
)[:,
:
goal_dim
]
scale
=
sc
*
tf
.
tile
(
tf
.
reshape
(
spec_range
,
[
1
,
goal_dim
]),
[
tf
.
shape
(
states
)[
0
],
1
])
dist
=
tf
.
distributions
.
Normal
(
loc
,
scale
)
if
num_samples
==
1
:
return
dist
.
sample
()
samples
=
tf
.
concat
([
dist
.
sample
(
num_samples
-
2
),
tf
.
expand_dims
(
loc
,
0
),
tf
.
expand_dims
(
orig_goals
,
0
)],
0
)
return
uvf_utils
.
clip_to_spec
(
samples
,
self
.
_spec
)
research/efficient-hrl/agents/__init__.py (new file, 0 → 100644)
research/efficient-hrl/agents/circular_buffer.py
0 → 100644
View file @
65da497f
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A circular buffer where each element is a list of tensors.
Each element of the buffer is a list of tensors. An example use case is a replay
buffer in reinforcement learning, where each element is a list of tensors
representing the state, action, reward etc.
New elements are added sequentially, and once the buffer is full, we
start overwriting them in a circular fashion. Reading does not remove any
elements, only adding new elements does.
"""
import collections
import numpy as np
import tensorflow as tf

import gin.tf


@gin.configurable
class CircularBuffer(object):
  """A circular buffer where each element is a list of tensors."""

  def __init__(self, buffer_size=1000, scope='replay_buffer'):
    """Circular buffer of list of tensors.

    Args:
      buffer_size: (integer) maximum number of tensor lists the buffer can hold.
      scope: (string) variable scope for creating the variables.
    """
    self._buffer_size = np.int64(buffer_size)
    self._scope = scope
    self._tensors = collections.OrderedDict()
    with tf.variable_scope(self._scope):
      self._num_adds = tf.Variable(0, dtype=tf.int64, name='num_adds')
      self._num_adds_cs = tf.contrib.framework.CriticalSection(name='num_adds')

  @property
  def buffer_size(self):
    return self._buffer_size

  @property
  def scope(self):
    return self._scope

  @property
  def num_adds(self):
    return self._num_adds

  def _create_variables(self, tensors):
    with tf.variable_scope(self._scope):
      for name in tensors.keys():
        tensor = tensors[name]
        self._tensors[name] = tf.get_variable(
            name='BufferVariable_' + name,
            shape=[self._buffer_size] + tensor.get_shape().as_list(),
            dtype=tensor.dtype,
            trainable=False)

  def _validate(self, tensors):
    """Validate shapes of tensors."""
    if len(tensors) != len(self._tensors):
      raise ValueError('Expected tensors to have %d elements. Received %d '
                       'instead.' % (len(self._tensors), len(tensors)))
    if self._tensors.keys() != tensors.keys():
      raise ValueError('The keys of tensors should always be the same. '
                       'Received %s instead of %s.' %
                       (tensors.keys(), self._tensors.keys()))
    for name, tensor in tensors.items():
      if (tensor.get_shape().as_list() !=
          self._tensors[name].get_shape().as_list()[1:]):
        raise ValueError('Tensor %s has incorrect shape.' % name)
      if not tensor.dtype.is_compatible_with(self._tensors[name].dtype):
        raise ValueError(
            'Tensor %s has incorrect data type. Expected %s, received %s' %
            (name, self._tensors[name].read_value().dtype, tensor.dtype))

  def add(self, tensors):
    """Adds an element (list/tuple/dict of tensors) to the buffer.

    Args:
      tensors: (list/tuple/dict of tensors) to be added to the buffer.
    Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
        an enqueue_op.
    Raises:
      ValueError: If the shapes and data types of input `tensors' are not the
        same across calls to the add function.
    """
    return self.maybe_add(tensors, True)

  def maybe_add(self, tensors, condition):
    """Adds an element (tensors) to the buffer based on the condition.

    Args:
      tensors: (list/tuple of tensors) to be added to the buffer.
      condition: A boolean Tensor controlling whether the tensors would be added
        to the buffer or not.
    Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
        a maybe_enqueue_op.
    Raises:
      ValueError: If the shapes and data types of input `tensors' are not the
        same across calls to the add function.
    """
    if not isinstance(tensors, dict):
      names = [str(i) for i in range(len(tensors))]
      tensors = collections.OrderedDict(zip(names, tensors))
    if not isinstance(tensors, collections.OrderedDict):
      tensors = collections.OrderedDict(
          sorted(tensors.items(), key=lambda t: t[0]))
    if not self._tensors:
      self._create_variables(tensors)
    else:
      self._validate(tensors)

    #@tf.critical_section(self._position_mutex)
    def _increment_num_adds():
      # Adding 0 to the num_adds variable is a trick to read the value of the
      # variable and return a read-only tensor. Doing this in a critical
      # section allows us to capture a snapshot of the variable that will
      # not be affected by other threads updating num_adds.
      return self._num_adds.assign_add(1) + 0

    def _add():
      num_adds_inc = self._num_adds_cs.execute(_increment_num_adds)
      current_pos = tf.mod(num_adds_inc - 1, self._buffer_size)
      update_ops = []
      for name in self._tensors.keys():
        update_ops.append(
            tf.scatter_update(self._tensors[name], current_pos, tensors[name]))
      return tf.group(*update_ops)

    return tf.contrib.framework.smart_cond(condition, _add, tf.no_op)

  def get_random_batch(self, batch_size, keys=None, num_steps=1):
    """Samples a batch of tensors from the buffer with replacement.

    Args:
      batch_size: (integer) number of elements to sample.
      keys: List of keys of tensors to retrieve. If None retrieve all.
      num_steps: (integer) length of trajectories to return. If > 1 will return
        a list of lists, where each internal list represents a trajectory of
        length num_steps.
    Returns:
      A list of tensors, where each element in the list is a batch sampled from
        one of the tensors in the buffer.
    Raises:
      ValueError: If get_random_batch is called before calling the add function.
      tf.errors.InvalidArgumentError: If this operation is executed before any
        items are added to the buffer.
    """
    if not self._tensors:
      raise ValueError('The add function must be called before '
                       'get_random_batch.')
    if keys is None:
      keys = self._tensors.keys()

    latest_start_index = self.get_num_adds() - num_steps + 1
    empty_buffer_assert = tf.Assert(
        tf.greater(latest_start_index, 0),
        ['Not enough elements have been added to the buffer.'])
    with tf.control_dependencies([empty_buffer_assert]):
      max_index = tf.minimum(self._buffer_size, latest_start_index)
      indices = tf.random_uniform(
          [batch_size], minval=0, maxval=max_index, dtype=tf.int64)
      if num_steps == 1:
        return self.gather(indices, keys)
      else:
        return self.gather_nstep(num_steps, indices, keys)

  def gather(self, indices, keys=None):
    """Returns elements at the specified indices from the buffer.

    Args:
      indices: (list of integers or rank 1 int Tensor) indices in the buffer to
        retrieve elements from.
      keys: List of keys of tensors to retrieve. If None retrieve all.
    Returns:
      A list of tensors, where each element in the list is obtained by indexing
        one of the tensors in the buffer.
    Raises:
      ValueError: If gather is called before calling the add function.
      tf.errors.InvalidArgumentError: If indices are bigger than the number of
        items in the buffer.
    """
    if not self._tensors:
      raise ValueError('The add function must be called before calling gather.')
    if keys is None:
      keys = self._tensors.keys()
    with tf.name_scope('Gather'):
      index_bound_assert = tf.Assert(
          tf.less(
              tf.to_int64(tf.reduce_max(indices)),
              tf.minimum(self.get_num_adds(), self._buffer_size)),
          ['Index out of bounds.'])
      with tf.control_dependencies([index_bound_assert]):
        indices = tf.convert_to_tensor(indices)

      batch = []
      for key in keys:
        batch.append(tf.gather(self._tensors[key], indices, name=key))
      return batch

  def gather_nstep(self, num_steps, indices, keys=None):
    """Returns elements at the specified indices from the buffer.

    Args:
      num_steps: (integer) length of trajectories to return.
      indices: (list of rank num_steps int Tensor) indices in the buffer to
        retrieve elements from for multiple trajectories. Each Tensor in the
        list represents the indices for a trajectory.
      keys: List of keys of tensors to retrieve. If None retrieve all.
    Returns:
      A list of list-of-tensors, where each element in the list is obtained by
        indexing one of the tensors in the buffer.
    Raises:
      ValueError: If gather is called before calling the add function.
      tf.errors.InvalidArgumentError: If indices are bigger than the number of
        items in the buffer.
    """
    if not self._tensors:
      raise ValueError('The add function must be called before calling gather.')
    if keys is None:
      keys = self._tensors.keys()
    with tf.name_scope('Gather'):
      index_bound_assert = tf.Assert(
          tf.less_equal(
              tf.to_int64(tf.reduce_max(indices) + num_steps),
              self.get_num_adds()),
          ['Trajectory indices go out of bounds.'])
      with tf.control_dependencies([index_bound_assert]):
        indices = tf.map_fn(
            lambda x: tf.mod(tf.range(x, x + num_steps), self._buffer_size),
            indices,
            dtype=tf.int64)

      batch = []
      for key in keys:

        def SampleTrajectories(trajectory_indices, key=key,
                               num_steps=num_steps):
          trajectory_indices.set_shape([num_steps])
          return tf.gather(self._tensors[key], trajectory_indices, name=key)

        batch.append(tf.map_fn(SampleTrajectories, indices,
                               dtype=self._tensors[key].dtype))
      return batch

  def get_position(self):
    """Returns the position at which the last element was added.

    Returns:
      An int tensor representing the index at which the last element was added
        to the buffer or -1 if no elements were added.
    """
    return tf.cond(self.get_num_adds() < 1,
                   lambda: self.get_num_adds() - 1,
                   lambda: tf.mod(self.get_num_adds() - 1, self._buffer_size))

  def get_num_adds(self):
    """Returns the number of additions to the buffer.

    Returns:
      An int tensor representing the number of elements that were added.
    """
    def num_adds():
      return self._num_adds.value()

    return self._num_adds_cs.execute(num_adds)

  def get_num_tensors(self):
    """Returns the number of tensors (slots) in the buffer."""
    return len(self._tensors)
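A minimal usage sketch of the buffer (illustrative only, not part of the committed file): store (state, action, reward) transitions via placeholders and sample a random batch; all shapes below are made up.

# Usage sketch; TF1 graph mode, variables are created lazily on the first add().
import tensorflow as tf

buffer = CircularBuffer(buffer_size=1000, scope='example_buffer')
state = tf.placeholder(tf.float32, [30])
action = tf.placeholder(tf.float32, [8])
reward = tf.placeholder(tf.float32, [])

add_op = buffer.add([state, action, reward])     # creates the buffer variables
batch = buffer.get_random_batch(batch_size=64)   # list of [64, ...] tensors

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(100):
    sess.run(add_op, feed_dict={state: [0.0] * 30,
                                action: [0.0] * 8,
                                reward: 0.0})
  sampled_states, sampled_actions, sampled_rewards = sess.run(batch)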
research/efficient-hrl/agents/ddpg_agent.py
0 → 100644
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A DDPG/NAF agent.
Implements the Deep Deterministic Policy Gradient (DDPG) algorithm from
"Continuous control with deep reinforcement learning" - Lilicrap et al.
https://arxiv.org/abs/1509.02971, and the Normalized Advantage Functions (NAF)
algorithm "Continuous Deep Q-Learning with Model-based Acceleration" - Gu et al.
https://arxiv.org/pdf/1603.00748.
"""
import tensorflow as tf
slim = tf.contrib.slim
import gin.tf
from utils import utils
from agents import ddpg_networks as networks


@gin.configurable
class DdpgAgent(object):
  """An RL agent that learns using the DDPG algorithm.

  Example usage:

  def critic_net(states, actions):
    ...
  def actor_net(states, num_action_dims):
    ...

  Given a tensorflow environment tf_env,
  (of type learning.deepmind.rl.environments.tensorflow.python.tfpyenvironment)

  obs_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  ddpg_agent = agent.DdpgAgent(obs_spec,
                               action_spec,
                               actor_net=actor_net,
                               critic_net=critic_net)

  we can perform actions on the environment as follows:

  state = tf_env.observations()[0]
  action = ddpg_agent.actor_net(tf.expand_dims(state, 0))[0, :]
  transition_type, reward, discount = tf_env.step([action])

  Train:

  critic_loss = ddpg_agent.critic_loss(states, actions, rewards, discounts,
                                       next_states)
  actor_loss = ddpg_agent.actor_loss(states)

  critic_train_op = slim.learning.create_train_op(
      critic_loss,
      critic_optimizer,
      variables_to_train=ddpg_agent.get_trainable_critic_vars(),
  )

  actor_train_op = slim.learning.create_train_op(
      actor_loss,
      actor_optimizer,
      variables_to_train=ddpg_agent.get_trainable_actor_vars(),
  )
  """

  ACTOR_NET_SCOPE = 'actor_net'
  CRITIC_NET_SCOPE = 'critic_net'
  TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
  TARGET_CRITIC_NET_SCOPE = 'target_critic_net'

  def __init__(self,
               observation_spec,
               action_spec,
               actor_net=networks.actor_net,
               critic_net=networks.critic_net,
               td_errors_loss=tf.losses.huber_loss,
               dqda_clipping=0.,
               actions_regularizer=0.,
               target_q_clipping=None,
               residual_phi=0.0,
               debug_summaries=False):
    """Constructs a DDPG agent.

    Args:
      observation_spec: A TensorSpec defining the observations.
      action_spec: A BoundedTensorSpec defining the actions.
      actor_net: A callable that creates the actor network. Must take the
        following arguments: states, num_actions. Please see networks.actor_net
        for an example.
      critic_net: A callable that creates the critic network. Must take the
        following arguments: states, actions. Please see networks.critic_net
        for an example.
      td_errors_loss: A callable defining the loss function for the critic
        td error.
      dqda_clipping: (float) clips the gradient dqda element-wise between
        [-dqda_clipping, dqda_clipping]. Does not perform clipping if
        dqda_clipping == 0.
      actions_regularizer: A scalar, when positive penalizes the norm of the
        actions. This can prevent saturation of actions for the actor_loss.
      target_q_clipping: (tuple of floats) clips target q values within
        (low, high) values when computing the critic loss.
      residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
        interpolates between Q-learning and residual gradient algorithm.
        http://www.leemon.com/papers/1995b.pdf
      debug_summaries: If True, add summaries to help debug behavior.
    Raises:
      ValueError: If 'dqda_clipping' is < 0.
    """
    self._observation_spec = observation_spec[0]
    self._action_spec = action_spec[0]
    self._state_shape = tf.TensorShape([None]).concatenate(
        self._observation_spec.shape)
    self._action_shape = tf.TensorShape([None]).concatenate(
        self._action_spec.shape)
    self._num_action_dims = self._action_spec.shape.num_elements()
    self._scope = tf.get_variable_scope().name
    self._actor_net = tf.make_template(
        self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
    self._critic_net = tf.make_template(
        self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
    self._target_actor_net = tf.make_template(
        self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
    self._target_critic_net = tf.make_template(
        self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
    self._td_errors_loss = td_errors_loss
    if dqda_clipping < 0:
      raise ValueError('dqda_clipping must be >= 0.')
    self._dqda_clipping = dqda_clipping
    self._actions_regularizer = actions_regularizer
    self._target_q_clipping = target_q_clipping
    self._residual_phi = residual_phi
    self._debug_summaries = debug_summaries

  def _batch_state(self, state):
    """Convert state to a batched state.

    Args:
      state: Either a list/tuple with an state tensor [num_state_dims].
    Returns:
      A tensor [1, num_state_dims]
    """
    if isinstance(state, (tuple, list)):
      state = state[0]
    if state.get_shape().ndims == 1:
      state = tf.expand_dims(state, 0)
    return state

  def action(self, state):
    """Returns the next action for the state.

    Args:
      state: A [num_state_dims] tensor representing a state.
    Returns:
      A [num_action_dims] tensor representing the action.
    """
    return self.actor_net(self._batch_state(state), stop_gradients=True)[0, :]
  @gin.configurable('ddpg_sample_action')
  def sample_action(self, state, stddev=1.0):
    """Returns the action for the state with additive noise.

    Args:
      state: A [num_state_dims] tensor representing a state.
      stddev: stddev for the Ornstein-Uhlenbeck noise.
    Returns:
      A [num_action_dims] action tensor.
    """
    agent_action = self.action(state)
    agent_action += tf.random_normal(tf.shape(agent_action)) * stddev
    return utils.clip_to_spec(agent_action, self._action_spec)

  def actor_net(self, states, stop_gradients=False):
    """Returns the output of the actor network.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      stop_gradients: (boolean) if true, gradients cannot be propagated through
        this operation.
    Returns:
      A [batch_size, num_action_dims] tensor of actions.
    Raises:
      ValueError: If `states` does not have the expected dimensions.
    """
    self._validate_states(states)
    actions = self._actor_net(states, self._action_spec)
    if stop_gradients:
      actions = tf.stop_gradient(actions)
    return actions

  def critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the critic network.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
    Returns:
      q values: A [batch_size] tensor of q values.
    Raises:
      ValueError: If `states` or `actions' do not have the expected dimensions.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    return self._critic_net(states, actions, for_critic_loss=for_critic_loss)

  def target_actor_net(self, states):
    """Returns the output of the target actor network.

    The target network is used to compute stable targets for training.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      A [batch_size, num_action_dims] tensor of actions.
    Raises:
      ValueError: If `states` does not have the expected dimensions.
    """
    self._validate_states(states)
    actions = self._target_actor_net(states, self._action_spec)
    return tf.stop_gradient(actions)

  def target_critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the target critic network.

    The target network is used to compute stable targets for training.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
    Returns:
      q values: A [batch_size] tensor of q values.
    Raises:
      ValueError: If `states` or `actions' do not have the expected dimensions.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    return tf.stop_gradient(
        self._target_critic_net(states, actions,
                                for_critic_loss=for_critic_loss))

  def value_net(self, states, for_critic_loss=False):
    """Returns the output of the critic evaluated with the actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      q values: A [batch_size] tensor of q values.
    """
    actions = self.actor_net(states)
    return self.critic_net(states, actions, for_critic_loss=for_critic_loss)

  def target_value_net(self, states, for_critic_loss=False):
    """Returns the output of the target critic evaluated with the target actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      q values: A [batch_size] tensor of q values.
    """
    target_actions = self.target_actor_net(states)
    return self.target_critic_net(states, target_actions,
                                  for_critic_loss=for_critic_loss)

  def critic_loss(self, states, actions, rewards, discounts, next_states):
    """Computes a loss for training the critic network.

    The loss is the mean squared error between the Q value predictions of the
    critic and Q values estimated using TD-lambda.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      rewards: A [batch_size, ...] tensor representing a batch of rewards,
        broadcastable to the critic net output.
      discounts: A [batch_size, ...] tensor representing a batch of discounts,
        broadcastable to the critic net output.
      next_states: A [batch_size, num_state_dims] tensor representing a batch
        of next states.
    Returns:
      A rank-0 tensor representing the critic loss.
    Raises:
      ValueError: If any of the inputs do not have the expected dimensions, or
        if their batch_sizes do not match.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    self._validate_states(next_states)

    target_q_values = self.target_value_net(next_states, for_critic_loss=True)
    td_targets = target_q_values * discounts + rewards
    if self._target_q_clipping is not None:
      td_targets = tf.clip_by_value(td_targets, self._target_q_clipping[0],
                                    self._target_q_clipping[1])
    q_values = self.critic_net(states, actions, for_critic_loss=True)
    td_errors = td_targets - q_values
    if self._debug_summaries:
      gen_debug_td_error_summaries(
          target_q_values, q_values, td_targets, td_errors)

    loss = self._td_errors_loss(td_targets, q_values)

    if self._residual_phi > 0.0:  # compute residual gradient loss
      residual_q_values = self.value_net(next_states, for_critic_loss=True)
      residual_td_targets = residual_q_values * discounts + rewards
      if self._target_q_clipping is not None:
        residual_td_targets = tf.clip_by_value(residual_td_targets,
                                               self._target_q_clipping[0],
                                               self._target_q_clipping[1])
      residual_td_errors = residual_td_targets - q_values
      residual_loss = self._td_errors_loss(
          residual_td_targets, residual_q_values)
      loss = (loss * (1.0 - self._residual_phi) +
              residual_loss * self._residual_phi)
    return loss

  def actor_loss(self, states):
    """Computes a loss for training the actor network.

    Note that output does not represent an actual loss. It is called a loss only
    in the sense that its gradient w.r.t. the actor network weights is the
    correct gradient for training the actor network,
    i.e. dloss/dweights = (dq/da)*(da/dweights)
    which is the gradient used in Algorithm 1 of Lillicrap et al.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      A rank-0 tensor representing the actor loss.
    Raises:
      ValueError: If `states` does not have the expected dimensions.
    """
    self._validate_states(states)
    actions = self.actor_net(states, stop_gradients=False)
    critic_values = self.critic_net(states, actions)
    q_values = self.critic_function(critic_values, states)
    dqda = tf.gradients([q_values], [actions])[0]
    dqda_unclipped = dqda
    if self._dqda_clipping > 0:
      dqda = tf.clip_by_value(dqda, -self._dqda_clipping, self._dqda_clipping)

    actions_norm = tf.norm(actions)
    if self._debug_summaries:
      with tf.name_scope('dqda'):
        tf.summary.scalar('actions_norm', actions_norm)
        tf.summary.histogram('dqda', dqda)
        tf.summary.histogram('dqda_unclipped', dqda_unclipped)
        tf.summary.histogram('actions', actions)
        for a in range(self._num_action_dims):
          tf.summary.histogram('dqda_unclipped_%d' % a, dqda_unclipped[:, a])
          tf.summary.histogram('dqda_%d' % a, dqda[:, a])

    actions_norm *= self._actions_regularizer
    return slim.losses.mean_squared_error(tf.stop_gradient(dqda + actions),
                                          actions,
                                          scope='actor_loss') + actions_norm
  @gin.configurable('ddpg_critic_function')
  def critic_function(self, critic_values, states, weights=None):
    """Computes q values based on critic_net outputs, states, and weights.

    Args:
      critic_values: A tf.float32 [batch_size, ...] tensor representing outputs
        from the critic net.
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      weights: A list or Numpy array or tensor with a shape broadcastable to
        `critic_values`.
    Returns:
      A tf.float32 [batch_size] tensor representing q values.
    """
    del states  # unused args
    if weights is not None:
      weights = tf.convert_to_tensor(weights, dtype=critic_values.dtype)
      critic_values *= weights
    if critic_values.shape.ndims > 1:
      critic_values = tf.reduce_sum(critic_values,
                                    range(1, critic_values.shape.ndims))
    critic_values.shape.assert_has_rank(1)
    return critic_values

  @gin.configurable('ddpg_update_targets')
  def update_targets(self, tau=1.0):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the actor/critic networks, and its corresponding
    weight w_t in the target actor/critic networks, a soft update is:
    w_t = (1 - tau) * w_t + tau * w_s

    Args:
      tau: A float scalar in [0, 1]
    Returns:
      An operation that performs a soft update of the target network parameters.
    Raises:
      ValueError: If `tau` is not in [0, 1].
    """
    if tau < 0 or tau > 1:
      raise ValueError('Input `tau` should be in [0, 1].')
    update_actor = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
        tau)
    update_critic = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
        tau)
    return tf.group(update_actor, update_critic, name='update_targets')

  def get_trainable_critic_vars(self):
    """Returns a list of trainable variables in the critic network.

    Returns:
      A list of trainable variables in the critic network.
    """
    return slim.get_trainable_variables(
        utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))

  def get_trainable_actor_vars(self):
    """Returns a list of trainable variables in the actor network.

    Returns:
      A list of trainable variables in the actor network.
    """
    return slim.get_trainable_variables(
        utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))

  def get_critic_vars(self):
    """Returns a list of all variables in the critic network.

    Returns:
      A list of trainable variables in the critic network.
    """
    return slim.get_model_variables(
        utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))

  def get_actor_vars(self):
    """Returns a list of all variables in the actor network.

    Returns:
      A list of trainable variables in the actor network.
    """
    return slim.get_model_variables(
        utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))

  def _validate_states(self, states):
    """Raises a value error if `states` does not have the expected shape.

    Args:
      states: A tensor.
    Raises:
      ValueError: If states.shape or states.dtype are not compatible with
        observation_spec.
    """
    states.shape.assert_is_compatible_with(self._state_shape)
    if not states.dtype.is_compatible_with(self._observation_spec.dtype):
      raise ValueError('states.dtype={} is not compatible with'
                       ' observation_spec.dtype={}'.format(
                           states.dtype, self._observation_spec.dtype))

  def _validate_actions(self, actions):
    """Raises a value error if `actions` does not have the expected shape.

    Args:
      actions: A tensor.
    Raises:
      ValueError: If actions.shape or actions.dtype are not compatible with
        action_spec.
    """
    actions.shape.assert_is_compatible_with(self._action_shape)
    if not actions.dtype.is_compatible_with(self._action_spec.dtype):
      raise ValueError('actions.dtype={} is not compatible with'
                       ' action_spec.dtype={}'.format(
                           actions.dtype, self._action_spec.dtype))
@gin.configurable
class TD3Agent(DdpgAgent):
  """An RL agent that learns using the TD3 algorithm."""

  ACTOR_NET_SCOPE = 'actor_net'
  CRITIC_NET_SCOPE = 'critic_net'
  CRITIC_NET2_SCOPE = 'critic_net2'
  TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
  TARGET_CRITIC_NET_SCOPE = 'target_critic_net'
  TARGET_CRITIC_NET2_SCOPE = 'target_critic_net2'

  def __init__(self,
               observation_spec,
               action_spec,
               actor_net=networks.actor_net,
               critic_net=networks.critic_net,
               td_errors_loss=tf.losses.huber_loss,
               dqda_clipping=0.,
               actions_regularizer=0.,
               target_q_clipping=None,
               residual_phi=0.0,
               debug_summaries=False):
    """Constructs a TD3 agent.

    Args:
      observation_spec: A TensorSpec defining the observations.
      action_spec: A BoundedTensorSpec defining the actions.
      actor_net: A callable that creates the actor network. Must take the
        following arguments: states, num_actions. Please see networks.actor_net
        for an example.
      critic_net: A callable that creates the critic network. Must take the
        following arguments: states, actions. Please see networks.critic_net
        for an example.
      td_errors_loss: A callable defining the loss function for the critic
        td error.
      dqda_clipping: (float) clips the gradient dqda element-wise between
        [-dqda_clipping, dqda_clipping]. Does not perform clipping if
        dqda_clipping == 0.
      actions_regularizer: A scalar, when positive penalizes the norm of the
        actions. This can prevent saturation of actions for the actor_loss.
      target_q_clipping: (tuple of floats) clips target q values within
        (low, high) values when computing the critic loss.
      residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
        interpolates between Q-learning and residual gradient algorithm.
        http://www.leemon.com/papers/1995b.pdf
      debug_summaries: If True, add summaries to help debug behavior.
    Raises:
      ValueError: If 'dqda_clipping' is < 0.
    """
    self._observation_spec = observation_spec[0]
    self._action_spec = action_spec[0]
    self._state_shape = tf.TensorShape([None]).concatenate(
        self._observation_spec.shape)
    self._action_shape = tf.TensorShape([None]).concatenate(
        self._action_spec.shape)
    self._num_action_dims = self._action_spec.shape.num_elements()
    self._scope = tf.get_variable_scope().name
    self._actor_net = tf.make_template(
        self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
    self._critic_net = tf.make_template(
        self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
    self._critic_net2 = tf.make_template(
        self.CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
    self._target_actor_net = tf.make_template(
        self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
    self._target_critic_net = tf.make_template(
        self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
    self._target_critic_net2 = tf.make_template(
        self.TARGET_CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
    self._td_errors_loss = td_errors_loss
    if dqda_clipping < 0:
      raise ValueError('dqda_clipping must be >= 0.')
    self._dqda_clipping = dqda_clipping
    self._actions_regularizer = actions_regularizer
    self._target_q_clipping = target_q_clipping
    self._residual_phi = residual_phi
    self._debug_summaries = debug_summaries

  def get_trainable_critic_vars(self):
    """Returns a list of trainable variables in the critic network.
    NOTE: This gets the vars of both critic networks.

    Returns:
      A list of trainable variables in the critic network.
    """
    return (
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)))

  def critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the critic network.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
    Returns:
      q values: A [batch_size] tensor of q values.
    Raises:
      ValueError: If `states` or `actions' do not have the expected dimensions.
    """
    values1 = self._critic_net(states, actions,
                               for_critic_loss=for_critic_loss)
    values2 = self._critic_net2(states, actions,
                                for_critic_loss=for_critic_loss)
    if for_critic_loss:
      return values1, values2
    return values1

  def target_critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the target critic network.

    The target network is used to compute stable targets for training.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
    Returns:
      q values: A [batch_size] tensor of q values.
    Raises:
      ValueError: If `states` or `actions' do not have the expected dimensions.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    values1 = tf.stop_gradient(
        self._target_critic_net(states, actions,
                                for_critic_loss=for_critic_loss))
    values2 = tf.stop_gradient(
        self._target_critic_net2(states, actions,
                                 for_critic_loss=for_critic_loss))
    if for_critic_loss:
      return values1, values2
    return values1

  def value_net(self, states, for_critic_loss=False):
    """Returns the output of the critic evaluated with the actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      q values: A [batch_size] tensor of q values.
    """
    actions = self.actor_net(states)
    return self.critic_net(states, actions, for_critic_loss=for_critic_loss)

  def target_value_net(self, states, for_critic_loss=False):
    """Returns the output of the target critic evaluated with the target actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      q values: A [batch_size] tensor of q values.
    """
    target_actions = self.target_actor_net(states)
    noise = tf.clip_by_value(
        tf.random_normal(tf.shape(target_actions), stddev=0.2), -0.5, 0.5)
    values1, values2 = self.target_critic_net(
        states, target_actions + noise, for_critic_loss=for_critic_loss)
    values = tf.minimum(values1, values2)
    return values, values

  @gin.configurable('td3_update_targets')
  def update_targets(self, tau=1.0):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the actor/critic networks, and its corresponding
    weight w_t in the target actor/critic networks, a soft update is:
    w_t = (1 - tau) * w_t + tau * w_s

    Args:
      tau: A float scalar in [0, 1]
    Returns:
      An operation that performs a soft update of the target network parameters.
    Raises:
      ValueError: If `tau` is not in [0, 1].
    """
    if tau < 0 or tau > 1:
      raise ValueError('Input `tau` should be in [0, 1].')
    update_actor = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
        tau)
    # NOTE: This updates both critic networks.
    update_critic = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
        tau)
    return tf.group(update_actor, update_critic, name='update_targets')
def gen_debug_td_error_summaries(
    target_q_values, q_values, td_targets, td_errors):
  """Generates debug summaries for critic given a set of batch samples.

  Args:
    target_q_values: set of predicted next stage values.
    q_values: current predicted value for the critic network.
    td_targets: discounted target_q_values with added next stage reward.
    td_errors: the difference between td_targets and q_values.
  """
  with tf.name_scope('td_errors'):
    tf.summary.histogram('td_targets', td_targets)
    tf.summary.histogram('q_values', q_values)
    tf.summary.histogram('target_q_values', target_q_values)
    tf.summary.histogram('td_errors', td_errors)

    with tf.name_scope('td_targets'):
      tf.summary.scalar('mean', tf.reduce_mean(td_targets))
      tf.summary.scalar('max', tf.reduce_max(td_targets))
      tf.summary.scalar('min', tf.reduce_min(td_targets))

    with tf.name_scope('q_values'):
      tf.summary.scalar('mean', tf.reduce_mean(q_values))
      tf.summary.scalar('max', tf.reduce_max(q_values))
      tf.summary.scalar('min', tf.reduce_min(q_values))

    with tf.name_scope('target_q_values'):
      tf.summary.scalar('mean', tf.reduce_mean(target_q_values))
      tf.summary.scalar('max', tf.reduce_max(target_q_values))
      tf.summary.scalar('min', tf.reduce_min(target_q_values))

    with tf.name_scope('td_errors'):
      tf.summary.scalar('mean', tf.reduce_mean(td_errors))
      tf.summary.scalar('max', tf.reduce_max(td_errors))
      tf.summary.scalar('min', tf.reduce_min(td_errors))
      tf.summary.scalar('mean_abs', tf.reduce_mean(tf.abs(td_errors)))
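For reference, here is a small standalone sketch of the element-wise soft update that update_targets delegates to utils.soft_variables_update (this is an illustration of the rule w_t <- (1 - tau) * w_t + tau * w_s, not the repo's utils code; variable values are made up, and tau matches td3_update_targets.tau from configs/base_uvf.gin):

# Soft target update sketch (illustrative only).
import tensorflow as tf

tau = 0.005
source = tf.Variable([1.0, 2.0, 3.0])   # online network weights
target = tf.Variable([0.0, 0.0, 0.0])   # target network weights

# w_t <- (1 - tau) * w_t + tau * w_s, applied element-wise.
soft_update = target.assign((1.0 - tau) * target + tau * source)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(soft_update))  # [0.005, 0.01, 0.015]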
research/efficient-hrl/agents/ddpg_networks.py
0 → 100644
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sample actor(policy) and critic(q) networks to use with DDPG/NAF agents.
The DDPG networks are defined in "Section 7: Experiment Details" of
"Continuous control with deep reinforcement learning" - Lilicrap et al.
https://arxiv.org/abs/1509.02971
The NAF critic network is based on "Section 4" of "Continuous deep Q-learning
with model-based acceleration" - Gu et al. https://arxiv.org/pdf/1603.00748.
"""
import tensorflow as tf
slim = tf.contrib.slim
import gin.tf


@gin.configurable('ddpg_critic_net')
def critic_net(states, actions,
               for_critic_loss=False,
               num_reward_dims=1,
               states_hidden_layers=(400,),
               actions_hidden_layers=None,
               joint_hidden_layers=(300,),
               weight_decay=0.0001,
               normalizer_fn=None,
               activation_fn=tf.nn.relu,
               zero_obs=False,
               images=False):
  """Creates a critic that returns q values for the given states and actions.

  Args:
    states: (castable to tf.float32) a [batch_size, num_state_dims] tensor
      representing a batch of states.
    actions: (castable to tf.float32) a [batch_size, num_action_dims] tensor
      representing a batch of actions.
    num_reward_dims: Number of reward dimensions.
    states_hidden_layers: tuple of hidden layers units for states.
    actions_hidden_layers: tuple of hidden layers units for actions.
    joint_hidden_layers: tuple of hidden layers units after joining states
      and actions using tf.concat().
    weight_decay: Weight decay for l2 weights regularizer.
    normalizer_fn: Normalizer function, i.e. slim.layer_norm,
    activation_fn: Activation function, i.e. tf.nn.relu, slim.leaky_relu, ...
  Returns:
    A tf.float32 [batch_size] tensor of q values, or a tf.float32
      [batch_size, num_reward_dims] tensor of vector q values if
      num_reward_dims > 1.
  """
  with slim.arg_scope(
      [slim.fully_connected],
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=slim.variance_scaling_initializer(
          factor=1.0/3.0, mode='FAN_IN', uniform=True)):

    orig_states = tf.to_float(states)
    #states = tf.to_float(states)
    states = tf.concat([tf.to_float(states), tf.to_float(actions)], -1)  #TD3
    if images or zero_obs:
      states *= tf.constant([0.0] * 2 + [1.0] * (states.shape[1] - 2))  #LALA
    actions = tf.to_float(actions)
    if states_hidden_layers:
      states = slim.stack(states, slim.fully_connected, states_hidden_layers,
                          scope='states')
    if actions_hidden_layers:
      actions = slim.stack(actions, slim.fully_connected,
                           actions_hidden_layers, scope='actions')
    joint = tf.concat([states, actions], 1)
    if joint_hidden_layers:
      joint = slim.stack(joint, slim.fully_connected, joint_hidden_layers,
                         scope='joint')
    with slim.arg_scope([slim.fully_connected],
                        weights_regularizer=None,
                        weights_initializer=tf.random_uniform_initializer(
                            minval=-0.003, maxval=0.003)):
      value = slim.fully_connected(joint, num_reward_dims,
                                   activation_fn=None,
                                   normalizer_fn=None,
                                   scope='q_value')
    if num_reward_dims == 1:
      value = tf.reshape(value, [-1])
    if not for_critic_loss and num_reward_dims > 1:
      value = tf.reduce_sum(
          value * tf.abs(orig_states[:, -num_reward_dims:]), -1)
  return value


@gin.configurable('ddpg_actor_net')
def actor_net(states, action_spec,
              hidden_layers=(400, 300),
              normalizer_fn=None,
              activation_fn=tf.nn.relu,
              zero_obs=False,
              images=False):
  """Creates an actor that returns actions for the given states.

  Args:
    states: (castable to tf.float32) a [batch_size, num_state_dims] tensor
      representing a batch of states.
    action_spec: (BoundedTensorSpec) A tensor spec indicating the shape
      and range of actions.
    hidden_layers: tuple of hidden layers units.
    normalizer_fn: Normalizer function, i.e. slim.layer_norm,
    activation_fn: Activation function, i.e. tf.nn.relu, slim.leaky_relu, ...
  Returns:
    A tf.float32 [batch_size, num_action_dims] tensor of actions.
  """
  with slim.arg_scope(
      [slim.fully_connected],
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      weights_initializer=slim.variance_scaling_initializer(
          factor=1.0/3.0, mode='FAN_IN', uniform=True)):

    states = tf.to_float(states)
    orig_states = states
    if images or zero_obs:  # Zero-out x, y position. Hacky.
      states *= tf.constant([0.0] * 2 + [1.0] * (states.shape[1] - 2))
    if hidden_layers:
      states = slim.stack(states, slim.fully_connected, hidden_layers,
                          scope='states')
    with slim.arg_scope([slim.fully_connected],
                        weights_initializer=tf.random_uniform_initializer(
                            minval=-0.003, maxval=0.003)):
      actions = slim.fully_connected(states,
                                     action_spec.shape.num_elements(),
                                     scope='actions',
                                     normalizer_fn=None,
                                     activation_fn=tf.nn.tanh)
      action_means = (action_spec.maximum + action_spec.minimum) / 2.0
      action_magnitudes = (action_spec.maximum - action_spec.minimum) / 2.0
      actions = action_means + action_magnitudes * actions

  return actions
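A minimal sketch of building the actor head on its own (illustrative only, not part of the committed file). FakeActionSpec is a hypothetical stand-in exposing only the attributes actor_net reads: .shape, .minimum and .maximum; in the real training code the spec comes from the environment's action_spec().

# Usage sketch for actor_net; shapes are made up.
import collections
import tensorflow as tf

FakeActionSpec = collections.namedtuple(
    'FakeActionSpec', ['shape', 'minimum', 'maximum'])
action_spec = FakeActionSpec(
    shape=tf.TensorShape([8]), minimum=-1.0, maximum=1.0)

states = tf.placeholder(tf.float32, [None, 30])
with tf.variable_scope('example_actor'):
  # Output is [None, 8], squashed by tanh and rescaled into [minimum, maximum].
  actions = actor_net(states, action_spec, hidden_layers=(400, 300))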
research/efficient-hrl/cond_fn.py
0 → 100644
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines many boolean functions indicating when to step and reset.
"""
import tensorflow as tf
import gin.tf


@gin.configurable
def env_transition(agent, state, action, transition_type, environment_steps,
                   num_episodes):
  """True if the transition_type is TRANSITION or FINAL_TRANSITION.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
  Returns:
    cond: Returns an op that evaluates to true if the transition type is
      not RESTARTING
  """
  del agent, state, action, num_episodes, environment_steps
  cond = tf.logical_not(transition_type)
  return cond


@gin.configurable
def env_restart(agent, state, action, transition_type, environment_steps,
                num_episodes):
  """True if the transition_type is RESTARTING.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
  Returns:
    cond: Returns an op that evaluates to true if the transition type equals
      RESTARTING.
  """
  del agent, state, action, num_episodes, environment_steps
  cond = tf.identity(transition_type)
  return cond


@gin.configurable
def every_n_steps(agent,
                  state,
                  action,
                  transition_type,
                  environment_steps,
                  num_episodes,
                  n=150):
  """True once every n steps.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
    n: Return true once every n steps.
  Returns:
    cond: Returns an op that evaluates to true if environment_steps
      equals 0 mod n. We increment the step before checking this condition, so
      we do not need to add one to environment_steps.
  """
  del agent, state, action, transition_type, num_episodes
  cond = tf.equal(tf.mod(environment_steps, n), 0)
  return cond


@gin.configurable
def every_n_episodes(agent,
                     state,
                     action,
                     transition_type,
                     environment_steps,
                     num_episodes,
                     n=2,
                     steps_per_episode=None):
  """True once every n episodes.

  Specifically, evaluates to True on the 0th step of every nth episode.
  Unlike environment_steps, num_episodes starts at 0, so we do want to add
  one to ensure it does not reset on the first call.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
    n: Return true once every n episodes.
    steps_per_episode: How many steps per episode. Needed to determine when a
      new episode starts.
  Returns:
    cond: Returns an op that evaluates to true on the last step of the episode
      (i.e. if num_episodes equals 0 mod n).
  """
  assert steps_per_episode is not None
  del agent, action, transition_type
  ant_fell = tf.logical_or(state[2] < 0.2, state[2] > 1.0)
  cond = tf.logical_and(
      tf.logical_or(
          ant_fell,
          tf.equal(tf.mod(num_episodes + 1, n), 0)),
      tf.equal(tf.mod(environment_steps, steps_per_episode), 0))
  return cond
@gin.configurable
def failed_reset_after_n_episodes(agent,
                                  state,
                                  action,
                                  transition_type,
                                  environment_steps,
                                  num_episodes,
                                  steps_per_episode=None,
                                  reset_state=None,
                                  max_dist=1.0,
                                  epsilon=1e-10):
  """Every n episodes, returns True if the reset agent fails to return.

  Specifically, evaluates to True if the distance between the state and the
  reset state is greater than max_dist at the end of the episode.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
    steps_per_episode: How many steps per episode. Needed to determine when a
      new episode starts.
    reset_state: State to which the reset controller should return.
    max_dist: Agent is considered to have successfully reset if its distance
      from the reset_state is less than max_dist.
    epsilon: small offset to ensure non-negative/zero distance.
  Returns:
    cond: Returns an op that evaluates to true if num_episodes+1 equals 0
    mod n. We add one to the num_episodes so the environment is not reset after
    the 0th step.
  """
  assert steps_per_episode is not None
  assert reset_state is not None
  # Note: `state` is kept (not deleted) because it is used to compute the
  # distance from the reset state below.
  del agent, action, transition_type, num_episodes
  dist = tf.sqrt(
      tf.reduce_sum(tf.squared_difference(state, reset_state)) + epsilon)
  cond = tf.logical_and(
      tf.greater(dist, tf.constant(max_dist)),
      tf.equal(tf.mod(environment_steps, steps_per_episode), 0))
  return cond
@gin.configurable
def q_too_small(agent,
                state,
                action,
                transition_type,
                environment_steps,
                num_episodes,
                q_min=0.5):
  """True if q is too small.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
    q_min: Returns true if the qval is less than q_min
  Returns:
    cond: Returns an op that evaluates to true if qval is less than q_min.
  """
  del transition_type, environment_steps, num_episodes
  state_for_reset_agent = tf.stack(state[:-1],
                                   tf.constant([0], dtype=tf.float32))
  qval = agent.BASE_AGENT_CLASS.critic_net(
      tf.expand_dims(state_for_reset_agent, 0), tf.expand_dims(action, 0))[0, :]
  cond = tf.greater(tf.constant(q_min), qval)
  return cond
@gin.configurable
def true_fn(agent, state, action, transition_type, environment_steps,
            num_episodes):
  """Returns an op that evaluates to true.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
  Returns:
    cond: op that always evaluates to True.
  """
  del agent, state, action, transition_type, environment_steps, num_episodes
  cond = tf.constant(True, dtype=tf.bool)
  return cond


@gin.configurable
def false_fn(agent, state, action, transition_type, environment_steps,
             num_episodes):
  """Returns an op that evaluates to false.

  Args:
    agent: RL agent.
    state: A [num_state_dims] tensor representing a state.
    action: Action performed.
    transition_type: Type of transition after action
    environment_steps: Number of steps performed by environment.
    num_episodes: Number of episodes.
  Returns:
    cond: op that always evaluates to False.
  """
  del agent, state, action, transition_type, environment_steps, num_episodes
  cond = tf.constant(False, dtype=tf.bool)
  return cond
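A quick check of the every_n_steps condition defined above (illustrative only): with n=150 it evaluates to True exactly when the running environment step count is a multiple of 150. The unused arguments can be passed as None because the function deletes them immediately.

# Sketch: evaluating every_n_steps directly.
import tensorflow as tf

environment_steps = tf.placeholder(tf.int64, [])
cond = every_n_steps(agent=None, state=None, action=None, transition_type=None,
                     environment_steps=environment_steps, num_episodes=None,
                     n=150)

with tf.Session() as sess:
  print(sess.run(cond, feed_dict={environment_steps: 150}))  # True
  print(sess.run(cond, feed_dict={environment_steps: 151}))  # False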
research/efficient-hrl/configs/base_uvf.gin
0 → 100644
#-*-Python-*-
import gin.tf.external_configurables
create_maze_env.top_down_view = %IMAGES
## Create the agent
AGENT_CLASS = @UvfAgent
UvfAgent.tf_context = %CONTEXT
UvfAgent.actor_net = @agent/ddpg_actor_net
UvfAgent.critic_net = @agent/ddpg_critic_net
UvfAgent.dqda_clipping = 0.0
UvfAgent.td_errors_loss = @tf.losses.huber_loss
UvfAgent.target_q_clipping = %TARGET_Q_CLIPPING
# Create meta agent
META_CLASS = @MetaAgent
MetaAgent.tf_context = %META_CONTEXT
MetaAgent.sub_context = %CONTEXT
MetaAgent.actor_net = @meta/ddpg_actor_net
MetaAgent.critic_net = @meta/ddpg_critic_net
MetaAgent.dqda_clipping = 0.0
MetaAgent.td_errors_loss = @tf.losses.huber_loss
MetaAgent.target_q_clipping = %TARGET_Q_CLIPPING
# Create state preprocess
STATE_PREPROCESS_CLASS = @StatePreprocess
StatePreprocess.ndims = %SUBGOAL_DIM
state_preprocess_net.states_hidden_layers = (100, 100)
state_preprocess_net.num_output_dims = %SUBGOAL_DIM
state_preprocess_net.images = %IMAGES
action_embed_net.num_output_dims = %SUBGOAL_DIM
INVERSE_DYNAMICS_CLASS = @InverseDynamics
# actor_net
ACTOR_HIDDEN_SIZE_1 = 300
ACTOR_HIDDEN_SIZE_2 = 300
agent/ddpg_actor_net.hidden_layers = (%ACTOR_HIDDEN_SIZE_1, %ACTOR_HIDDEN_SIZE_2)
agent/ddpg_actor_net.activation_fn = @tf.nn.relu
agent/ddpg_actor_net.zero_obs = %ZERO_OBS
agent/ddpg_actor_net.images = %IMAGES
meta/ddpg_actor_net.hidden_layers = (%ACTOR_HIDDEN_SIZE_1, %ACTOR_HIDDEN_SIZE_2)
meta/ddpg_actor_net.activation_fn = @tf.nn.relu
meta/ddpg_actor_net.zero_obs = False
meta/ddpg_actor_net.images = %IMAGES
# critic_net
CRITIC_HIDDEN_SIZE_1 = 300
CRITIC_HIDDEN_SIZE_2 = 300
agent/ddpg_critic_net.states_hidden_layers = (%CRITIC_HIDDEN_SIZE_1,)
agent/ddpg_critic_net.actions_hidden_layers = None
agent/ddpg_critic_net.joint_hidden_layers = (%CRITIC_HIDDEN_SIZE_2,)
agent/ddpg_critic_net.weight_decay = 0.0
agent/ddpg_critic_net.activation_fn = @tf.nn.relu
agent/ddpg_critic_net.zero_obs = %ZERO_OBS
agent/ddpg_critic_net.images = %IMAGES
meta/ddpg_critic_net.states_hidden_layers = (%CRITIC_HIDDEN_SIZE_1,)
meta/ddpg_critic_net.actions_hidden_layers = None
meta/ddpg_critic_net.joint_hidden_layers = (%CRITIC_HIDDEN_SIZE_2,)
meta/ddpg_critic_net.weight_decay = 0.0
meta/ddpg_critic_net.activation_fn = @tf.nn.relu
meta/ddpg_critic_net.zero_obs = False
meta/ddpg_critic_net.images = %IMAGES
tf.losses.huber_loss.delta = 1.0
# Sample action
uvf_add_noise_fn.stddev = 1.0
meta_add_noise_fn.stddev = %META_EXPLORE_NOISE
# Update targets
ddpg_update_targets.tau = 0.001
td3_update_targets.tau = 0.005
research/efficient-hrl/configs/eval_uvf.gin
0 → 100644
#-*-Python-*-
# Config eval
evaluate.environment = @create_maze_env()
evaluate.agent_class = %AGENT_CLASS
evaluate.meta_agent_class = %META_CLASS
evaluate.state_preprocess_class = %STATE_PREPROCESS_CLASS
evaluate.num_episodes_eval = 50
evaluate.num_episodes_videos = 1
evaluate.gamma = 1.0
evaluate.eval_interval_secs = 1
evaluate.generate_videos = False
evaluate.generate_summaries = True
evaluate.eval_modes = %EVAL_MODES
evaluate.max_steps_per_episode = %RESET_EPISODE_PERIOD
research/efficient-hrl/configs/train_uvf.gin
0 → 100644
#-*-Python-*-
# Create replay_buffer
agent/CircularBuffer.buffer_size = 200000
meta/CircularBuffer.buffer_size = 200000
agent/CircularBuffer.scope = "agent"
meta/CircularBuffer.scope = "meta"
# Config train
train_uvf.environment = @create_maze_env()
train_uvf.agent_class = %AGENT_CLASS
train_uvf.meta_agent_class = %META_CLASS
train_uvf.state_preprocess_class = %STATE_PREPROCESS_CLASS
train_uvf.inverse_dynamics_class = %INVERSE_DYNAMICS_CLASS
train_uvf.replay_buffer = @agent/CircularBuffer()
train_uvf.meta_replay_buffer = @meta/CircularBuffer()
train_uvf.critic_optimizer = @critic/AdamOptimizer()
train_uvf.actor_optimizer = @actor/AdamOptimizer()
train_uvf.meta_critic_optimizer = @meta_critic/AdamOptimizer()
train_uvf.meta_actor_optimizer = @meta_actor/AdamOptimizer()
train_uvf.repr_optimizer = @repr/AdamOptimizer()
train_uvf.num_episodes_train = 25000
train_uvf.batch_size = 100
train_uvf.initial_episodes = 5
train_uvf.gamma = 0.99
train_uvf.meta_gamma = 0.99
train_uvf.reward_scale_factor = 1.0
train_uvf.target_update_period = 2
train_uvf.num_updates_per_observation = 1
train_uvf.num_collect_per_update = 1
train_uvf.num_collect_per_meta_update = 10
train_uvf.debug_summaries = False
train_uvf.log_every_n_steps = 1000
train_uvf.save_policy_every_n_steps = 100000
# Config Optimizers
critic/AdamOptimizer.learning_rate = 0.001
critic/AdamOptimizer.beta1 = 0.9
critic/AdamOptimizer.beta2 = 0.999
actor/AdamOptimizer.learning_rate = 0.0001
actor/AdamOptimizer.beta1 = 0.9
actor/AdamOptimizer.beta2 = 0.999
meta_critic/AdamOptimizer.learning_rate = 0.001
meta_critic/AdamOptimizer.beta1 = 0.9
meta_critic/AdamOptimizer.beta2 = 0.999
meta_actor/AdamOptimizer.learning_rate = 0.0001
meta_actor/AdamOptimizer.beta1 = 0.9
meta_actor/AdamOptimizer.beta2 = 0.999
repr/AdamOptimizer.learning_rate = 0.0001
repr/AdamOptimizer.beta1 = 0.9
repr/AdamOptimizer.beta2 = 0.999
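Configs like the one above are typically consumed through gin's file-parsing API; the sketch below shows the general pattern (a sketch only: it assumes the gin-config package's parse_config_files_and_bindings function, hypothetical relative paths under research/efficient-hrl, and that the project's @gin.configurable modules have already been imported, as the training entry point would do).

# Sketch of loading the gin configuration before calling train_uvf.
import gin.tf

gin.parse_config_files_and_bindings(
    ['configs/base_uvf.gin',
     'configs/train_uvf.gin',
     'context/configs/ant_block.gin'],
    bindings=['train_uvf.num_episodes_train = 1000'])
# After parsing, @gin.configurable functions such as train_uvf pick up these
# parameter values when called.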
research/efficient-hrl/context/__init__.py
0 → 100644
research/efficient-hrl/context/configs/ant_block.gin
0 → 100644
#-*-Python-*-
create_maze_env.env_name = "AntBlock"
ZERO_OBS = False
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4), (20, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1", "eval2", "eval3"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
"eval2": [@eval2/ConstantSampler],
"eval3": [@eval3/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [3, 4]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [16, 0]
eval2/ConstantSampler.value = [16, 16]
eval3/ConstantSampler.value = [0, 16]
research/efficient-hrl/context/configs/ant_block_maze.gin
0 → 100644
#-*-Python-*-
create_maze_env.env_name = "AntBlockMaze"
ZERO_OBS = False
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4), (12, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1", "eval2", "eval3"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
"eval2": [@eval2/ConstantSampler],
"eval3": [@eval3/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [3, 4]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [8, 0]
eval2/ConstantSampler.value = [8, 16]
eval3/ConstantSampler.value = [0, 16]
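The relative_context_transition_fn bound in these configs implements the HIRO-style relative subgoal update: when the low-level agent moves from state s to s', a subgoal g expressed relative to s is re-expressed relative to s' so that the absolute target s + g stays fixed between meta-actions. A minimal sketch under that assumption (the actual signature in this project may differ):
import numpy as np
def relative_context_transition_fn(state, next_state, context, k):
  # Keep the absolute target (state[:k] + context) fixed while the agent
  # moves; only the first k observation dimensions participate.
  return (np.asarray(state[:k], dtype=np.float32)
          + np.asarray(context, dtype=np.float32)
          - np.asarray(next_state[:k], dtype=np.float32))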
research/efficient-hrl/context/configs/ant_fall_multi.gin
0 → 100644
#-*-Python-*-
create_maze_env.env_name = "AntFall"
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4, 0), (12, 28, 5))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [3]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1, 2]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [0, 27, 4.5]
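The reset predicates every_n_steps and every_n_episodes used above end an episode every RESET_EPISODE_PERIOD environment steps and force a hard environment reset every RESET_ENV_PERIOD episodes (every episode here, since RESET_ENV_PERIOD = 1). The sketch below uses hypothetical simplified signatures; the actual condition functions live in cond_fn.py and take more arguments.
def every_n_steps(step_count, n):
  # Hypothetical simplified signature: true once every n environment steps.
  return step_count > 0 and step_count % n == 0
def every_n_episodes(step_count, n, steps_per_episode):
  # Hypothetical simplified signature: true at an episode boundary whose
  # episode index is a multiple of n.
  at_boundary = step_count > 0 and step_count % steps_per_episode == 0
  episode_index = step_count // steps_per_episode
  return at_boundary and episode_index % n == 0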
research/efficient-hrl/context/configs/ant_fall_multi_img.gin
0 → 100644
#-*-Python-*-
create_maze_env.env_name = "AntFall"
IMAGES = True
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4, 0), (12, 28, 5))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [3]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
}
meta/Context.context_transition_fn = @task/relative_context_transition_fn
meta/Context.context_multi_transition_fn = @task/relative_context_multi_transition_fn
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1, 2]
task/negative_distance.relative_context = True
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
task/relative_context_transition_fn.k = 3
task/relative_context_multi_transition_fn.k = 3
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [0, 27, 0]
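Three sampler types appear throughout these task configs: RandomSampler draws a goal uniformly from the meta context range, DirectionSampler draws a k-dimensional subgoal from the low-level context range for training and exploration, and ConstantSampler returns a fixed evaluation goal. A minimal sketch of the first and last under those assumptions (the real classes carry additional state and batch handling):
import numpy as np
class RandomSampler(object):
  """Uniform goal sampler over a (min, max) context range (assumed behavior)."""
  def __init__(self, context_range):
    low, high = context_range
    self.low = np.asarray(low, dtype=np.float32)
    self.high = np.asarray(high, dtype=np.float32)
  def sample(self):
    return np.random.uniform(self.low, self.high)
class ConstantSampler(object):
  """Always returns a fixed goal, e.g. the eval1 value [0, 27, 0] above (assumed behavior)."""
  def __init__(self, value):
    self.value = np.asarray(value, dtype=np.float32)
  def sample(self):
    return self.value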
research/efficient-hrl/context/configs/ant_fall_single.gin
0 → 100644
#-*-Python-*-
create_maze_env.env_name = "AntFall"
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4, 0), (12, 28, 5))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [3]
meta/Context.samplers = {
"train": [@eval1/ConstantSampler],
"explore": [@eval1/ConstantSampler],
"eval1": [@eval1/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1, 2]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [0, 27, 4.5]
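Finally, a minimal sketch of how one of these task configs can be composed with the shared UVF configs via gin. This is an assumption about the training entry point, and the SUBGOAL_DIM binding shown is purely illustrative; in practice the macros referenced above (%SUBGOAL_DIM, %CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX) are expected to come from the project's base configuration.
import gin
gin.parse_config_files_and_bindings(
    config_files=[
        'research/efficient-hrl/configs/base_uvf.gin',
        'research/efficient-hrl/configs/train_uvf.gin',
        'research/efficient-hrl/context/configs/ant_fall_single.gin',
    ],
    # Hypothetical extra binding for illustration only.
    bindings=['SUBGOAL_DIM = 15'],
)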