Commit 65da497f authored by Shining Sun

Merge branch 'master' of https://github.com/tensorflow/models into cifar_keras

parents 93e0022d 7d032ea3
"""Function for computing a robust mean estimate in the presence of outliers.
This is a modified Python implementation of this file:
https://idlastro.gsfc.nasa.gov/ftp/pro/robust/resistant_mean.pro
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def robust_mean(y, cut):
  """Computes a robust mean estimate in the presence of outliers.

  Args:
    y: 1D numpy array. Assumed to be normally distributed with outliers.
    cut: Points more than this number of standard deviations from the median
      are ignored.

  Returns:
    mean: A robust estimate of the mean of y.
    mean_stddev: The standard deviation of the mean.
    mask: Boolean array with the same length as y. Values corresponding to
      outliers in y are False. All other values are True.
  """
  # First, make a robust estimate of the standard deviation of y, assuming y
  # is normally distributed. The conversion factor of 1.4826 takes the median
  # absolute deviation to the standard deviation of a normal distribution.
  # See, e.g. https://www.mathworks.com/help/stats/mad.html.
  absdev = np.abs(y - np.median(y))
  sigma = 1.4826 * np.median(absdev)

  # If the previous estimate of the standard deviation using the median
  # absolute deviation is zero, fall back to a robust estimate using the mean
  # absolute deviation. This estimator has a different conversion factor of
  # 1.253. See, e.g. https://www.mathworks.com/help/stats/mad.html.
  if sigma < 1.0e-24:
    sigma = 1.253 * np.mean(absdev)

  # Identify outliers using our estimate of the standard deviation of y.
  mask = absdev <= cut * sigma

  # Now, recompute the standard deviation, using the sample standard deviation
  # of non-outlier points.
  sigma = np.std(y[mask])

  # Compensate the estimate of sigma due to trimming away outliers. The
  # following formula is an approximation, see
  # http://w.astro.berkeley.edu/~johnjohn/idlprocs/robust_mean.pro.
  sc = np.max([cut, 1.0])
  if sc <= 4.5:
    sigma /= (-0.15405 + 0.90723 * sc - 0.23584 * sc**2 + 0.020142 * sc**3)

  # Identify outliers using our second estimate of the standard deviation of y.
  mask = absdev <= cut * sigma

  # Now, recompute the standard deviation, using the sample standard deviation
  # with non-outlier points.
  sigma = np.std(y[mask])

  # Compensate the estimate of sigma due to trimming away outliers.
  sc = np.max([cut, 1.0])
  if sc <= 4.5:
    sigma /= (-0.15405 + 0.90723 * sc - 0.23584 * sc**2 + 0.020142 * sc**3)

  # Final estimate is the sample mean with outliers removed.
  mean = np.mean(y[mask])
  mean_stddev = sigma / np.sqrt(len(y) - 1.0)
  return mean, mean_stddev, mask
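For reference, a minimal usage sketch of `robust_mean` (assuming the module is importable as in the test below; the data here is synthetic and purely illustrative):

```
import numpy as np

from third_party.robust_mean import robust_mean

# 100 well-behaved points around 5.0 plus two gross outliers.
np.random.seed(0)
y = np.concatenate(
    [np.random.normal(loc=5.0, scale=0.1, size=100), [50.0, -40.0]])

mean, mean_stddev, mask = robust_mean.robust_mean(y, cut=5)
print(mean)           # Close to 5.0; the outliers are trimmed away.
print(mean_stddev)    # Standard deviation of the mean estimate.
print(np.sum(~mask))  # 2: only the two injected outliers are masked out.
```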
"""Tests for robust_mean.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from third_party.robust_mean import robust_mean
from third_party.robust_mean.test_data import random_normal
class RobustMeanTest(absltest.TestCase):

  def testRobustMean(self):
    # To avoid non-determinism in the unit test, we use a pre-generated vector
    # of length 1,000. Each entry is independently sampled from a random
    # normal distribution with mean 2 and standard deviation 1. The maximum
    # value of y is 6.075 (+4.075 sigma from the mean) and the minimum value
    # is -1.54 (-3.54 sigma from the mean).
    y = np.array(random_normal.RANDOM_NORMAL)
    self.assertAlmostEqual(np.mean(y), 2.00336615850485)
    self.assertAlmostEqual(np.std(y), 1.01690907798)

    # High cut. No points rejected, so the mean should be the sample mean, and
    # the mean standard deviation should be the sample standard deviation
    # divided by sqrt(1000 - 1).
    mean, mean_stddev, mask = robust_mean.robust_mean(y, cut=5)
    self.assertAlmostEqual(mean, 2.00336615850485)
    self.assertAlmostEqual(mean_stddev, 0.032173579)
    self.assertLen(mask, 1000)
    self.assertEqual(np.sum(mask), 1000)

    # Cut of 3 standard deviations.
    mean, mean_stddev, mask = robust_mean.robust_mean(y, cut=3)
    self.assertAlmostEqual(mean, 2.0059050070632178)
    self.assertAlmostEqual(mean_stddev, 0.03197075302321066)
    # There are exactly 3 points in the sample less than -1 or greater than 5.
    # These have indices 12, 220, 344.
    self.assertLen(mask, 1000)
    self.assertEqual(np.sum(mask), 997)
    self.assertFalse(np.any(mask[[12, 220, 344]]))

    # Add outliers. This corrupts the sample mean to 2.082.
    mean, mean_stddev, mask = robust_mean.robust_mean(
        y=np.concatenate([y, [10] * 10]), cut=5)
    self.assertAlmostEqual(mean, 2.0033661585048681)
    self.assertAlmostEqual(mean_stddev, 0.032013749413590531)
    self.assertLen(mask, 1010)
    self.assertEqual(np.sum(mask), 1000)
    self.assertFalse(np.any(mask[1000:1010]))

    # Add an outlier. This corrupts the mean to 1.002.
    mean, mean_stddev, mask = robust_mean.robust_mean(
        y=np.concatenate([y, [-1000]]), cut=5)
    self.assertAlmostEqual(mean, 2.0033661585048681)
    self.assertAlmostEqual(mean_stddev, 0.032157488597211903)
    self.assertLen(mask, 1001)
    self.assertEqual(np.sum(mask), 1000)
    self.assertFalse(mask[1000])


if __name__ == "__main__":
  absltest.main()
"""This file contains 1,000 points of a random normal distribution.
The mean of the distribution is 2, and the standard deviation is 1.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
RANDOM_NORMAL = [
0.692741320869,
1.948556207658,
-0.140158117639,
2.680322906859,
1.876671492867,
2.885286232509,
1.482222151802,
2.234349266246,
2.437427989583,
1.573053624952,
2.169198249367,
2.791023059619,
-1.053798286951,
1.796497126664,
3.806070621390,
1.744055208958,
3.474399140181,
1.564447665560,
1.143107137921,
1.618376255615,
2.615632139609,
1.413239777404,
1.047237320108,
3.190489536636,
2.918428435434,
1.268789896280,
0.931181066003,
3.797790627792,
0.493025834330,
1.866146169585,
0.949927834893,
1.439666857958,
2.705521702500,
1.815406073907,
1.570503841718,
1.834429337005,
2.903916580263,
-0.110549195467,
2.065338922749,
1.119498053048,
0.427627428035,
3.025052175045,
2.645448868784,
1.442644951218,
0.774681298962,
2.247561418494,
1.743438974941,
1.184440017832,
1.643691193885,
1.947748675186,
2.178309991836,
2.815355272672,
2.207620544168,
2.077889048169,
2.915504366132,
2.440862146850,
2.804729838623,
0.534712595625,
1.956491042766,
2.230542009671,
2.186536281651,
3.694129968231,
3.313526598170,
2.170240599444,
2.793531289796,
1.454464312809,
1.197463589804,
0.713332299712,
1.180965411999,
2.180022174106,
2.861107927091,
1.795223865106,
1.730056153040,
1.431404424890,
1.839372935334,
1.271871740741,
3.773103777671,
1.026069424885,
2.006070770486,
1.276836142291,
1.414098998873,
1.749117068374,
2.040006827147,
1.815581326626,
2.892666735522,
3.093934769003,
2.129166907135,
1.260521633663,
3.259431640120,
1.879415647487,
1.368769201985,
2.236653714367,
2.293120875655,
2.361086097355,
2.140675892497,
2.860288793716,
3.109921655205,
2.142509743586,
0.661829413359,
0.620852115030,
2.279817287885,
2.077609300700,
1.917031492891,
2.549328729021,
1.402961147881,
2.989802645752,
2.126646549508,
0.581285045065,
3.226987223858,
1.790860716921,
0.998661497130,
2.125771640271,
2.186096892741,
2.160189267804,
2.206460323846,
3.366179111195,
-0.125206283025,
0.645228886619,
0.505553980622,
4.494406059555,
1.291690417806,
2.977896904657,
2.869240282824,
3.344192278881,
2.487041683297,
4.236730343795,
3.007206122800,
1.210065291965,
-0.053847768077,
1.108953782402,
1.843857008095,
2.374767801329,
1.472199059501,
3.332198116275,
2.027084082885,
2.305065331530,
3.387400013580,
1.493365795517,
2.344295515065,
2.898632740793,
3.307836869328,
1.892766317783,
2.348033912288,
1.288522200888,
2.178559140529,
2.366037265891,
3.468023805733,
1.910134543982,
1.750500687923,
1.506717073807,
1.345976221745,
1.898226480175,
2.362688287820,
2.176558673313,
1.716475335783,
1.109563102324,
1.824697060483,
2.290331853365,
3.660496355225,
3.695990930547,
0.995131810353,
2.083740307542,
2.515409175245,
1.734919119633,
0.186488629263,
3.470910728743,
3.503515673097,
2.225335667636,
4.925211524431,
3.176405299532,
2.938260408825,
2.336603901159,
2.218333712640,
3.269148549824,
1.921171637456,
3.876114839719,
1.492216718705,
2.792835112200,
3.563198188748,
2.728530961520,
3.231549893645,
2.209018339760,
1.081828242171,
0.754161622090,
1.948018149260,
2.413945024183,
1.425023717183,
2.005406706788,
0.964987890314,
1.603414847296,
0.132077263346,
1.789327371404,
1.423488299029,
2.590160851192,
3.131340836085,
2.325779171436,
2.129789552692,
1.876126153813,
2.667783873354,
-0.220464828097,
2.285158851436,
1.188664672684,
1.968980968179,
2.510328726654,
1.690300427857,
2.041495293673,
2.471293710293,
1.660589811070,
1.801640276851,
2.200864460731,
1.489583958038,
1.545725376492,
4.208130184998,
2.428489533380,
3.539990060815,
1.317090333595,
0.785936916712,
0.809688718378,
1.265062896735,
2.749291333938,
6.075297866258,
2.165845459075,
2.055273600728,
2.584618009430,
2.782654850307,
0.967100649409,
2.267394795463,
2.783350629984,
0.238340558296,
1.566536380829,
1.165403279885,
3.409015124349,
1.047853632456,
2.100798231132,
1.824776518459,
1.517825551662,
2.148972385365,
1.818426298006,
1.954355115973,
2.428393037760,
2.225660788849,
1.287880002052,
3.083900598687,
2.561457835470,
2.547146477110,
-0.060868513691,
1.917876348341,
1.194823858275,
1.237685798924,
2.500081029116,
0.605823016300,
1.341027488293,
1.357719149407,
3.959221361786,
1.457342301661,
1.450552596247,
3.152966485077,
1.755910034199,
2.252303064393,
2.315145292843,
2.092889154866,
2.044536701039,
3.078226379252,
1.940374989780,
0.981160719305,
1.801484599888,
4.599412580952,
3.029815652986,
2.234894233100,
1.884862677960,
2.703542617621,
2.188894869734,
1.031225637544,
4.487470294014,
1.916903861878,
2.178877764206,
2.001204233385,
1.668533128794,
0.118714387565,
1.236342841750,
0.697779517270,
4.061304247309,
1.873047854221,
0.529730720609,
0.772303413290,
1.734928501976,
0.830164961083,
3.674107591296,
3.027005867653,
2.798171180697,
2.754769626808,
2.287213251879,
0.224122591017,
1.996907607820,
2.272196861888,
1.423156951562,
2.649423732022,
2.410425004883,
2.348764499112,
4.188086272873,
2.592584804958,
1.360716155533,
1.089292416194,
0.877166635938,
2.923298927077,
1.699602289582,
1.764010718116,
0.851384613856,
1.362786130903,
4.014401248962,
2.004378924317,
2.680507997712,
4.162602009325,
2.080304752717,
0.758782969232,
0.896584126809,
1.907281638800,
2.753415491620,
1.571468221472,
1.510571435517,
3.133254430892,
1.314198176176,
2.871092309494,
0.505771497509,
0.608771053519,
0.099600620869,
2.202314023992,
1.561845986404,
1.935860544395,
4.227485606155,
2.507702606518,
1.966897273255,
3.462827375982,
2.297865682096,
2.018310409281,
2.231512822040,
2.912164920958,
0.391926284930,
3.233896921158,
2.270671144478,
2.151928087898,
1.169376547635,
1.410447269758,
1.104075308499,
-1.542633116467,
1.153006815104,
1.825678952144,
3.170518866440,
4.259372395300,
2.991591841969,
2.936827860147,
1.621450416535,
2.022035327270,
3.512668911326,
2.840069655471,
0.445725474197,
0.462229554454,
0.318918997270,
2.764048560322,
1.707769041832,
0.354635293838,
1.422103811424,
1.567812002847,
1.024884046523,
3.417171077354,
1.638428319488,
3.241722761084,
1.903144274531,
2.386560261127,
1.089737760638,
0.709565288091,
-0.055267123709,
2.220171017505,
2.676992914119,
1.795938808643,
0.857048483230,
2.341277450146,
0.597747299826,
2.172474110279,
1.658595631706,
1.984212673322,
3.348561751121,
1.548578130896,
1.072758349253,
0.032774558165,
1.706534108602,
0.755870027998,
3.896324551791,
1.893154948782,
1.610014175651,
1.689869260730,
0.837788921577,
2.072889043390,
1.849998133863,
2.312731199777,
3.080340939707,
3.091164029551,
0.393260795576,
2.003732199463,
1.850471999505,
1.674223365447,
0.950827266616,
0.893144704712,
4.088054520567,
2.494717669405,
2.940185915913,
2.362745036344,
4.420918853822,
1.196829910488,
0.751131585724,
2.572876053732,
4.783864101630,
1.371390533975,
2.265749507496,
0.980731387353,
1.194594017621,
1.167489912193,
1.964259577764,
2.911981147100,
3.425120291588,
-0.257485591786,
2.472881717211,
3.053440640390,
0.762578570358,
1.132958189893,
2.182874371350,
3.052476057575,
1.277863138274,
1.639136886663,
3.068422388091,
4.082802262329,
3.817537635954,
0.097850368917,
1.833230262781,
1.868086753582,
1.887983463294,
1.651402760749,
1.139536636683,
1.506983113398,
2.136499510660,
1.554089544528,
1.817657472715,
1.881949483974,
1.694259012321,
2.466961181010,
0.934064795958,
1.780986169906,
2.370334192643,
1.384364501906,
2.661332270733,
0.505133486534,
3.377981661644,
4.528998885749,
1.374759695170,
4.722230929249,
2.457241846607,
1.089449047113,
2.069442203989,
2.922047383746,
2.418239643214,
2.102706555829,
3.402947114317,
0.796159207619,
2.632564349894,
4.200094165974,
1.193215106012,
3.096716644943,
2.876829603210,
1.921543697812,
1.061475001173,
1.539143201636,
1.962758648643,
2.295863280945,
1.165666782756,
2.795634751169,
1.964614100540,
2.881578005533,
2.637037175067,
2.892982065300,
1.370612909045,
2.259776562066,
2.613792094772,
1.906706647250,
1.557148053231,
1.133780390845,
1.143533122599,
4.117191444375,
0.188018096004,
0.214460776257,
1.603522547618,
1.983864405185,
2.699735877141,
3.298710632472,
1.487986613587,
1.991452281956,
4.766298265537,
2.586190101112,
1.065148174656,
2.145271421777,
1.522765549323,
4.396072859233,
1.606095900438,
1.031438092798,
2.960703649068,
3.605096318253,
0.738490507896,
3.432893690638,
1.851816325195,
1.776083706595,
2.626748456025,
3.271038565612,
-0.070612895830,
1.287744477416,
1.695753443501,
2.436753770247,
2.202642339826,
1.408938299450,
2.016160912374,
0.898288687497,
1.289813203546,
2.817550580649,
1.633807646433,
0.729748534247,
3.731152016795,
4.390636356521,
0.960267728055,
2.664438938132,
3.353944579683,
3.269381121224,
1.172165464812,
2.400938318282,
0.807728448589,
1.354368440944,
0.710514838580,
3.856287086165,
1.844610913086,
0.998911164773,
2.675023215997,
2.434288374696,
3.159869294202,
2.216260317913,
1.045656776305,
2.009335242758,
-0.709434674480,
1.331363427990,
1.413333913927,
0.929006400187,
1.733184360011,
1.592926769452,
2.244402061817,
-0.685165133598,
1.588319181570,
2.023137693888,
2.371988587924,
3.467422121044,
1.853048214949,
2.619054775096,
2.800663994934,
3.602262658213,
2.739826236907,
1.657450448996,
2.635627781551,
3.426068129480,
1.151293859797,
2.126497682712,
1.381359388969,
2.247544386446,
1.174628386484,
0.958611426223,
3.704308164799,
1.239169433782,
1.113358149317,
2.955595260742,
4.724760227403,
1.980693866973,
1.178308425749,
1.764128983382,
3.075383932276,
1.825832517572,
0.832366525583,
1.141057225660,
1.525888435245,
2.385324115908,
1.358714765290,
0.551137984167,
3.731145479114,
2.129026962460,
2.573317010477,
1.976940325893,
2.420475072025,
0.684154421404,
2.802696725755,
2.541095686615,
2.058591811473,
1.640112285999,
1.856989192038,
2.614611193561,
3.469421336856,
1.053557146407,
3.032200283499,
2.573750297024,
3.083216685185,
2.404219296708,
0.346271398570,
2.589361666010,
2.774804416246,
0.445877540011,
0.905444077995,
0.063823875188,
2.931316420485,
1.682860197161,
2.972795382257,
2.597485175923,
1.554827252582,
0.938640710601,
1.554015447012,
0.698644188586,
2.957760202695,
2.706304471141,
2.642415006150,
1.464184232137,
2.765792229162,
2.039393447616,
1.582779254230,
1.722697961910,
0.354842490538,
0.839688308674,
3.250316830782,
3.993268587677,
1.831751003414,
3.737987669486,
3.837008408003,
3.656452995704,
1.378085850241,
3.992366605685,
3.063520565655,
1.829600671075,
2.853149829083,
1.948008763331,
2.489355654745,
2.039149456991,
1.723308108929,
1.530719515047,
1.390322318375,
1.015161747970,
1.902647551975,
0.587714373573,
2.419343238401,
2.037241109090,
1.989108845487,
2.555164211364,
2.145078634562,
2.453495232937,
1.572091583978,
3.017196269239,
1.359683738353,
1.905697793148,
1.745346338494,
2.410960789923,
1.688108090624,
2.041661869959,
2.261146892703,
0.108311227666,
2.261198590438,
1.205414457068,
0.815680644627,
2.373638547036,
0.314446220857,
2.407160216258,
1.767921824455,
1.812649838016,
1.981483340407,
2.294353826751,
1.219794724258,
4.384759314526,
3.362912919095,
0.358020839800,
2.416111296383,
0.772765268291,
3.036908153028,
3.499839475422,
2.504401672085,
1.000612753791,
1.031364523108,
2.905950640378,
3.816584440139,
0.846980443659,
2.806102343007,
1.662302388297,
2.146698147213,
2.247505312463,
1.485016638026,
3.139004503074,
0.525587710167,
1.271023854890,
1.255521130946,
1.814043296479,
1.216959307975,
0.978300004743,
1.793024541935,
2.436108214253,
1.805501508380,
2.289362542667,
3.103303146056,
3.070219780476,
1.928865588661,
0.671011951957,
1.892825933013,
2.777602823529,
0.491871575583,
2.240415846966,
2.375489703208,
2.709091612473,
1.454643174490,
0.932692202068,
1.330312119137,
2.127413235976,
1.317902934165,
1.580714395448,
1.008090918724,
0.713394722097,
0.414109615934,
1.497415539366,
2.768670845431,
2.432044584164,
3.549640318635,
1.342147007285,
1.490711835094,
2.215255822048,
2.953966963699,
1.397115922550,
1.378544056641,
0.295634610499,
0.310858641177,
1.767513113561,
3.434648852323,
1.491911596249,
4.374871362485,
1.373675010945,
0.738310910553,
1.234191541434,
2.155481175306,
2.958616497624,
1.540019317971,
1.890919744323,
0.015363864461,
0.611976171745,
2.048461203755,
0.905204536881,
1.952638485996,
2.425065685214,
2.237320237401,
1.261771567053,
2.589404719675,
2.475731886267,
1.327151422229,
1.535419742175,
1.208022763652,
3.436317329939,
1.228705365902,
1.566902016116,
2.085976618587,
3.604339608476,
1.070321479131,
2.085842592869,
2.524588738973,
0.844371573275,
0.666896658382,
3.021051396651,
1.304696763442,
2.885533651158,
1.076496681998,
2.291817051246,
1.297874264925,
2.181105432748,
2.017938562177,
1.688714920892,
2.778982555151,
2.464336632348,
2.351157814765,
0.328615108417,
2.045729524665,
1.097213832650,
1.821737535398,
1.493311782343,
2.664732414197,
2.409262056717,
3.113216373314,
-0.736586322990,
2.622696405459,
2.182304470993,
0.270476186892,
2.628658716523,
1.938159535287,
3.336990992309,
1.843049113291,
4.411231162672,
2.385234723754,
2.814668480292,
0.835204823261,
1.486415794299,
4.428506992622,
2.010951035701,
2.386497902232,
-0.054343008749,
1.897798394390,
2.128378005079,
2.003863798772,
1.414857649717,
3.058382706867,
0.734980076150,
0.128402966890,
0.075621261496,
2.062530850812,
2.257054591626,
3.098405063129,
1.184503294303,
1.927098840462,
3.590105219538,
2.324770104189,
2.920547827923,
3.774469427430,
0.643975980439,
2.972011913013,
1.545636773552,
1.284276446577,
2.116456504846,
2.334765924705,
1.476322485264,
0.333938454195,
1.740780860437,
0.809641636242,
2.114359904589,
3.495010537745,
1.394057058959,
2.099880999687,
1.723136191694,
1.824145520142,
2.206175560435,
2.217935160800,
3.184151380365,
-0.165277107839,
2.066902569755,
4.109207223806,
2.639922346758,
2.869289441530,
2.992666432223,
2.628328010580,
1.318946413946,
3.437097382310,
2.043254488710,
0.244000873823,
1.857441713051,
1.302602111278,
2.850286225242,
1.988609208476,
0.406765856788,
1.691073499692,
0.918912799942,
1.943198487145,
3.174415802822,
1.916755816708,
2.196550794119,
0.930720476044,
2.032189015326,
-0.777034945338,
1.406753268550,
0.870345844705,
1.793195283464,
2.066120080070,
2.916729217526,
0.642313449142,
2.617529572000,
2.396572272668,
1.942111542427,
2.435603256612,
3.898219347884,
1.979409214342,
1.235681010137,
-0.802441600645,
1.927883866070,
0.852232772749,
2.626513188209,
1.994232584644,
2.677125120554,
2.945149227801,
0.344859114264,
2.988484765052,
2.221699681734,
1.157038942208,
2.703070759809,
1.410436365113,
3.056534135285,
0.975232183559,
1.032651705560,
1.787301763233,
1.587502529729,
1.425207628405,
2.443158189935,
3.786205343468,
0.240451061053,
2.993759767949,
2.527525916677,
2.990777291349,
1.458774147434,
4.293524428909,
-0.116618748162,
1.674243883127,
2.434026351267,
3.129729749455,
1.532120640786,
3.584008627649,
2.126682783899,
0.784920593215,
1.954841166456,
2.659877218373,
2.639038968190,
3.009597452617,
3.820422929562,
2.950718556164,
2.942026969809,
2.899140330708,
-0.003511099104,
0.780849789152,
2.375904463772,
1.034820493941,
2.010379907777,
2.273452795908,
2.508893511243,
0.495773521197,
3.010585297044,
3.029210010516,
3.973880821070,
2.416599047057,
1.320773195864,
-0.296283555404,
3.112367101202,
0.568454165534,
3.950197901953,
3.040255296379,
2.892209686169,
1.355195417805,
2.139684432822,
2.920582903729,
1.588899963320,
2.235959314499,
1.768769790964,
2.854126298598,
-0.132279995647,
1.984901818097,
0.667875459687,
1.338320000460,
1.394322304858,
0.299399873843,
1.062140558649,
1.283070416608,
2.043726915244,
1.426628725160,
1.763445352183,
2.517159283156,
1.334042379945,
1.120896888394,
1.890222921582,
0.565772476594,
1.106579774451,
1.419654511698,
2.809182593659,
2.500132279723,
2.818415931740,
2.302096389328,
2.700248827229,
2.649016952991,
3.051337084118,
1.040839832658,
1.068258609432,
3.917982023425,
0.893981534117,
1.258354231702,
0.154108914546,
1.578873281706,
1.438841948587,
0.854591114516,
3.199994919222,
0.946279793861,
1.911631903701,
3.821301368668,
1.417597738430,
2.979514439880,
1.202734703576,
1.724395921518,
3.069121580556,
1.024707141488,
3.309560047408,
2.147433150108,
2.173493008341,
3.875831804386,
2.160166379458,
2.017326408423,
3.941632320127,
2.583832116740,
]
Code for performing Hierarchical RL based on the following publications:
"Data-Efficient Hierarchical Reinforcement Learning" by
Ofir Nachum, Shixiang (Shane) Gu, Honglak Lee, and Sergey Levine
(https://arxiv.org/abs/1805.08296).
This library includes three of the environments used in the papers (Ant Maze,
Ant Push, and Ant Fall), together with the training and evaluation code
described below.
"Near-Optimal Representation Learning for Hierarchical Reinforcement Learning"
by Ofir Nachum, Shixiang (Shane) Gu, Honglak Lee, and Sergey Levine
(https://arxiv.org/abs/1810.01257).
Requirements:
* TensorFlow (see http://www.tensorflow.org for how to install/upgrade)
* Gin Config (see https://github.com/google/gin-config)
* TensorFlow Agents (see https://github.com/tensorflow/agents)
* OpenAI Gym (see http://gym.openai.com/docs, be sure to install MuJoCo as well)
* NumPy (see http://www.numpy.org/)
Quick Start:

Run a random policy on AntMaze (or AntPush, AntFall):

```
python environments/__init__.py --env=AntMaze
```

Run a training job based on the original HIRO paper on Ant Maze:

```
python scripts/local_train.py test1 hiro_orig ant_maze base_uvf suite
```

Run a continuous evaluation job for that experiment:

```
python scripts/local_eval.py test1 hiro_orig ant_maze base_uvf suite
```
To run the same experiment with online representation learning (the
"Near-Optimal" paper), change `hiro_orig` to `hiro_repr`.
You can also run with `hiro_xy` to run the same experiment with HIRO on only the
xy coordinates of the agent.
To run on other environments, change `ant_maze` to something else; e.g.,
`ant_push_multi`, `ant_fall_multi`, etc. See `context/configs/*` for other options.
Basic Code Guide:
The code for training resides in train.py. The code trains a lower-level policy
(a UVF agent in the code) and a higher-level policy (a MetaAgent in the code)
concurrently. The higher-level policy communicates goals to the lower-level
policy. In the code, this is called a context. Not only does the lower-level
policy act with respect to a context (a higher-level specified goal), but the
higher-level policy also acts with respect to an environment-specified context
(corresponding to the navigation target location associated with the task).
Therefore, in `context/configs/*` you will find specifications both for task
setup and for goal configurations. Most remaining hyperparameters used for
training/evaluation may be found in `configs/*`.
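To make the two-level control flow above concrete, here is a rough,
hypothetical sketch of a HIRO-style episode loop (the class and method names
are placeholders for illustration, not the actual train.py API):

```
# Hypothetical sketch: a meta agent emits a goal (context) every goal_period
# steps; the lower-level UVF agent acts conditioned on that goal.
def run_episode(env, meta_agent, uvf_agent, goal_period=10, max_steps=500):
  state = env.reset()
  goal = meta_agent.act(state)           # higher-level action = goal/context
  for step in range(max_steps):
    action = uvf_agent.act(state, goal)  # lower-level action given the goal
    next_state, _, done = env.step(action)
    if (step + 1) % goal_period == 0:
      goal = meta_agent.act(next_state)  # periodically resample the goal
    state = next_state
    if done:
      break
```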
NOTE: Not all the code corresponding to the "Near-Optimal" paper is included.
Namely, changes to low-level policy training proposed in the paper (discounting
and auxiliary rewards) are not implemented here. Performance should not change
significantly.
Maintained by Ofir Nachum (ofirnachum).
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A UVF agent.
"""
import tensorflow as tf
import gin.tf
from agents import ddpg_agent
# pylint: disable=unused-import
import cond_fn
from utils import utils as uvf_utils
from context import gin_imports
# pylint: enable=unused-import
slim = tf.contrib.slim
@gin.configurable
class UvfAgentCore(object):
"""Defines basic functions for UVF agent. Must be inherited with an RL agent.
Used as lower-level agent.
"""
def __init__(self,
observation_spec,
action_spec,
tf_env,
tf_context,
step_cond_fn=cond_fn.env_transition,
reset_episode_cond_fn=cond_fn.env_restart,
reset_env_cond_fn=cond_fn.false_fn,
metrics=None,
**base_agent_kwargs):
"""Constructs a UVF agent.
Args:
observation_spec: A TensorSpec defining the observations.
action_spec: A BoundedTensorSpec defining the actions.
tf_env: A Tensorflow environment object.
tf_context: A Context class.
step_cond_fn: A function indicating whether to increment the num of steps.
reset_episode_cond_fn: A function indicating whether to restart the
episode, resampling the context.
reset_env_cond_fn: A function indicating whether to perform a manual reset
of the environment.
metrics: A list of functions that evaluate metrics of the agent.
**base_agent_kwargs: A dictionary of parameters for base RL Agent.
Raises:
ValueError: If 'dqda_clipping' is < 0.
"""
self._step_cond_fn = step_cond_fn
self._reset_episode_cond_fn = reset_episode_cond_fn
self._reset_env_cond_fn = reset_env_cond_fn
self.metrics = metrics
# expose tf_context methods
self.tf_context = tf_context(tf_env=tf_env)
self.set_replay = self.tf_context.set_replay
self.sample_contexts = self.tf_context.sample_contexts
self.compute_rewards = self.tf_context.compute_rewards
self.gamma_index = self.tf_context.gamma_index
self.context_specs = self.tf_context.context_specs
self.context_as_action_specs = self.tf_context.context_as_action_specs
self.init_context_vars = self.tf_context.create_vars
self.env_observation_spec = observation_spec[0]
merged_observation_spec = (uvf_utils.merge_specs(
(self.env_observation_spec,) + self.context_specs),)
self._context_vars = dict()
self._action_vars = dict()
self.BASE_AGENT_CLASS.__init__(
self,
observation_spec=merged_observation_spec,
action_spec=action_spec,
**base_agent_kwargs
)
def set_meta_agent(self, agent=None):
self._meta_agent = agent
@property
def meta_agent(self):
return self._meta_agent
  def actor_loss(self, states, actions, rewards, discounts,
                 next_states):
    """Returns the actor loss for a batch of states.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch of
        states.
      actions: Unused.
      rewards: Unused.
      discounts: Unused.
      next_states: Unused.
    Returns:
      A scalar tensor representing the actor loss.
    """
    return self.BASE_AGENT_CLASS.actor_loss(self, states)
def action(self, state, context=None):
"""Returns the next action for the state.
Args:
state: A [num_state_dims] tensor representing a state.
context: A list of [num_context_dims] tensor representing a context.
Returns:
A [num_action_dims] tensor representing the action.
"""
merged_state = self.merged_state(state, context)
return self.BASE_AGENT_CLASS.action(self, merged_state)
  def actions(self, state, context=None):
    """Returns the next actions for a batch of states.

    Args:
      state: A [-1, num_state_dims] tensor representing a batch of states.
      context: A list of [-1, num_context_dims] tensors representing a batch
        of contexts.
    Returns:
      A [-1, num_action_dims] tensor representing the batch of actions.
    """
    merged_states = self.merged_states(state, context)
    return self.BASE_AGENT_CLASS.actor_net(self, merged_states)
  def log_probs(self, states, actions, state_reprs, contexts=None):
    """Returns a log-probability surrogate for actions under the policy.

    The surrogate is the negative squared error between the given actions and
    the actions predicted by the actor for each (state, context) pair,
    normalized by the squared half-range of the action spec.
    """
    assert contexts is not None
    batch_dims = [tf.shape(states)[0], tf.shape(states)[1]]
    contexts = self.tf_context.context_multi_transition_fn(
        contexts, states=tf.to_float(state_reprs))
    flat_states = tf.reshape(states,
                             [batch_dims[0] * batch_dims[1], states.shape[-1]])
    flat_contexts = [tf.reshape(tf.cast(context, states.dtype),
                                [batch_dims[0] * batch_dims[1],
                                 context.shape[-1]])
                     for context in contexts]
    flat_pred_actions = self.actions(flat_states, flat_contexts)
    pred_actions = tf.reshape(flat_pred_actions,
                              batch_dims + [flat_pred_actions.shape[-1]])
    error = tf.square(actions - pred_actions)
    spec_range = (self._action_spec.maximum - self._action_spec.minimum) / 2
    normalized_error = error / tf.constant(spec_range) ** 2
    return -normalized_error
  @gin.configurable('uvf_add_noise_fn')
  def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
                   clip=True, global_step=None):
    """Returns the action_fn with additive Gaussian noise.

    Args:
      action_fn: A callable(`state`, `context`) which returns a
        [num_action_dims] tensor representing an action.
      stddev: Standard deviation of the additive Gaussian noise.
      debug: Print debug messages.
      clip: Whether to clip the noisy action to the action spec.
      global_step: If not None, decay the noise stddev as training progresses.
    Returns:
      A callable(`state`, `context`) which returns a noisy [num_action_dims]
      action tensor.
    """
if global_step is not None:
stddev *= tf.maximum( # Decay exploration during training.
tf.train.exponential_decay(1.0, global_step, 1e6, 0.8), 0.5)
def noisy_action_fn(state, context=None):
"""Noisy action fn."""
action = action_fn(state, context)
if debug:
action = uvf_utils.tf_print(
action, [action],
message='[add_noise_fn] pre-noise action',
first_n=100)
noise_dist = tf.distributions.Normal(tf.zeros_like(action),
tf.ones_like(action) * stddev)
noise = noise_dist.sample()
action += noise
if debug:
action = uvf_utils.tf_print(
action, [action],
message='[add_noise_fn] post-noise action',
first_n=100)
if clip:
action = uvf_utils.clip_to_spec(action, self._action_spec)
return action
return noisy_action_fn
def merged_state(self, state, context=None):
"""Returns the merged state from the environment state and contexts.
Args:
state: A [num_state_dims] tensor representing a state.
context: A list of [num_context_dims] tensor representing a context.
If None, use the internal context.
Returns:
A [num_merged_state_dims] tensor representing the merged state.
"""
if context is None:
context = list(self.context_vars)
state = tf.concat([state,] + context, axis=-1)
self._validate_states(self._batch_state(state))
return state
def merged_states(self, states, contexts=None):
"""Returns the batch merged state from the batch env state and contexts.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
contexts: A list of [batch_size, num_context_dims] tensor
representing a batch of contexts. If None,
use the internal context.
Returns:
A [batch_size, num_merged_state_dims] tensor representing the batch
of merged states.
"""
if contexts is None:
contexts = [tf.tile(tf.expand_dims(context, axis=0),
(tf.shape(states)[0], 1)) for
context in self.context_vars]
states = tf.concat([states,] + contexts, axis=-1)
self._validate_states(states)
return states
def unmerged_states(self, merged_states):
"""Returns the batch state and contexts from the batch merged state.
Args:
merged_states: A [batch_size, num_merged_state_dims] tensor
representing a batch of merged states.
Returns:
A [batch_size, num_state_dims] tensor and a list of
[batch_size, num_context_dims] tensors representing the batch state
and contexts respectively.
"""
self._validate_states(merged_states)
num_state_dims = self.env_observation_spec.shape.as_list()[0]
num_context_dims_list = [c.shape.as_list()[0] for c in self.context_specs]
states = merged_states[:, :num_state_dims]
contexts = []
i = num_state_dims
for num_context_dims in num_context_dims_list:
contexts.append(merged_states[:, i: i+num_context_dims])
i += num_context_dims
return states, contexts
def sample_random_actions(self, batch_size=1):
"""Return random actions.
Args:
batch_size: Batch size.
Returns:
A [batch_size, num_action_dims] tensor representing the batch of actions.
"""
actions = tf.concat(
[
tf.random_uniform(
shape=(batch_size, 1),
minval=self._action_spec.minimum[i],
maxval=self._action_spec.maximum[i])
for i in range(self._action_spec.shape[0].value)
],
axis=1)
return actions
def clip_actions(self, actions):
"""Clip actions to spec.
Args:
actions: A [batch_size, num_action_dims] tensor representing
the batch of actions.
Returns:
A [batch_size, num_action_dims] tensor representing the batch
of clipped actions.
"""
actions = tf.concat(
[
tf.clip_by_value(
actions[:, i:i+1],
self._action_spec.minimum[i],
self._action_spec.maximum[i])
for i in range(self._action_spec.shape[0].value)
],
axis=1)
return actions
def mix_contexts(self, contexts, insert_contexts, indices):
"""Mix two contexts based on indices.
Args:
contexts: A list of [batch_size, num_context_dims] tensor representing
the batch of contexts.
insert_contexts: A list of [batch_size, num_context_dims] tensor
representing the batch of contexts to be inserted.
indices: A list of a list of integers denoting indices to replace.
Returns:
A list of resulting contexts.
"""
if indices is None: indices = [[]] * len(contexts)
assert len(contexts) == len(indices)
assert all([spec.shape.ndims == 1 for spec in self.context_specs])
mix_contexts = []
for contexts_, insert_contexts_, indices_, spec in zip(
contexts, insert_contexts, indices, self.context_specs):
mix_contexts.append(
tf.concat(
[
insert_contexts_[:, i:i + 1] if i in indices_ else
contexts_[:, i:i + 1] for i in range(spec.shape.as_list()[0])
],
axis=1))
return mix_contexts
def begin_episode_ops(self, mode, action_fn=None, state=None):
"""Returns ops that reset agent at beginning of episodes.
Args:
mode: a string representing the mode=[train, explore, eval].
Returns:
A list of ops.
"""
all_ops = []
for _, action_var in sorted(self._action_vars.items()):
sample_action = self.sample_random_actions(1)[0]
all_ops.append(tf.assign(action_var, sample_action))
all_ops += self.tf_context.reset(mode=mode, agent=self._meta_agent,
action_fn=action_fn, state=state)
return all_ops
def cond_begin_episode_op(self, cond, input_vars, mode, meta_action_fn):
"""Returns op that resets agent at beginning of episodes.
    A new episode is begun if the cond op evaluates to `False`.
Args:
cond: a Boolean tensor variable.
input_vars: A list of tensor variables.
mode: a string representing the mode=[train, explore, eval].
Returns:
Conditional begin op.
"""
(state, action, reward, next_state,
state_repr, next_state_repr) = input_vars
def continue_fn():
"""Continue op fn."""
items = [state, action, reward, next_state,
state_repr, next_state_repr] + list(self.context_vars)
batch_items = [tf.expand_dims(item, 0) for item in items]
(states, actions, rewards, next_states,
state_reprs, next_state_reprs) = batch_items[:6]
context_reward = self.compute_rewards(
mode, state_reprs, actions, rewards, next_state_reprs,
batch_items[6:])[0][0]
context_reward = tf.cast(context_reward, dtype=reward.dtype)
if self.meta_agent is not None:
meta_action = tf.concat(self.context_vars, -1)
items = [state, meta_action, reward, next_state,
state_repr, next_state_repr] + list(self.meta_agent.context_vars)
batch_items = [tf.expand_dims(item, 0) for item in items]
(states, meta_actions, rewards, next_states,
state_reprs, next_state_reprs) = batch_items[:6]
meta_reward = self.meta_agent.compute_rewards(
mode, states, meta_actions, rewards,
next_states, batch_items[6:])[0][0]
meta_reward = tf.cast(meta_reward, dtype=reward.dtype)
else:
meta_reward = tf.constant(0, dtype=reward.dtype)
with tf.control_dependencies([context_reward, meta_reward]):
step_ops = self.tf_context.step(mode=mode, agent=self._meta_agent,
state=state,
next_state=next_state,
state_repr=state_repr,
next_state_repr=next_state_repr,
action_fn=meta_action_fn)
with tf.control_dependencies(step_ops):
context_reward, meta_reward = map(tf.identity, [context_reward, meta_reward])
return context_reward, meta_reward
def begin_episode_fn():
"""Begin op fn."""
begin_ops = self.begin_episode_ops(mode=mode, action_fn=meta_action_fn, state=state)
with tf.control_dependencies(begin_ops):
return tf.zeros_like(reward), tf.zeros_like(reward)
with tf.control_dependencies(input_vars):
cond_begin_episode_op = tf.cond(cond, continue_fn, begin_episode_fn)
return cond_begin_episode_op
def get_env_base_wrapper(self, env_base, **begin_kwargs):
"""Create a wrapper around env_base, with agent-specific begin/end_episode.
Args:
env_base: A python environment base.
**begin_kwargs: Keyword args for begin_episode_ops.
Returns:
An object with begin_episode() and end_episode().
"""
begin_ops = self.begin_episode_ops(**begin_kwargs)
return uvf_utils.get_contextual_env_base(env_base, begin_ops)
def init_action_vars(self, name, i=None):
"""Create and return a tensorflow Variable holding an action.
Args:
name: Name of the variables.
i: Integer id.
Returns:
A [num_action_dims] tensor.
"""
if i is not None:
name += '_%d' % i
assert name not in self._action_vars, ('Conflict! %s is already '
'initialized.') % name
self._action_vars[name] = tf.Variable(
self.sample_random_actions(1)[0], name='%s_action' % (name))
self._validate_actions(tf.expand_dims(self._action_vars[name], 0))
return self._action_vars[name]
@gin.configurable('uvf_critic_function')
def critic_function(self, critic_vals, states, critic_fn=None):
"""Computes q values based on outputs from the critic net.
Args:
critic_vals: A tf.float32 [batch_size, ...] tensor representing outputs
from the critic net.
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
critic_fn: A callable that process outputs from critic_net and
outputs a [batch_size] tensor representing q values.
Returns:
A tf.float32 [batch_size] tensor representing q values.
"""
if critic_fn is not None:
env_states, contexts = self.unmerged_states(states)
critic_vals = critic_fn(critic_vals, env_states, contexts)
critic_vals.shape.assert_has_rank(1)
return critic_vals
def get_action_vars(self, key):
return self._action_vars[key]
def get_context_vars(self, key):
return self.tf_context.context_vars[key]
def step_cond_fn(self, *args):
return self._step_cond_fn(self, *args)
def reset_episode_cond_fn(self, *args):
return self._reset_episode_cond_fn(self, *args)
def reset_env_cond_fn(self, *args):
return self._reset_env_cond_fn(self, *args)
@property
def context_vars(self):
return self.tf_context.vars
@gin.configurable
class MetaAgentCore(UvfAgentCore):
"""Defines basic functions for UVF Meta-agent. Must be inherited with an RL agent.
Used as higher-level agent.
"""
def __init__(self,
observation_spec,
action_spec,
tf_env,
tf_context,
sub_context,
step_cond_fn=cond_fn.env_transition,
reset_episode_cond_fn=cond_fn.env_restart,
reset_env_cond_fn=cond_fn.false_fn,
metrics=None,
actions_reg=0.,
k=2,
**base_agent_kwargs):
"""Constructs a Meta agent.
Args:
observation_spec: A TensorSpec defining the observations.
action_spec: A BoundedTensorSpec defining the actions.
tf_env: A Tensorflow environment object.
tf_context: A Context class.
step_cond_fn: A function indicating whether to increment the num of steps.
reset_episode_cond_fn: A function indicating whether to restart the
episode, resampling the context.
reset_env_cond_fn: A function indicating whether to perform a manual reset
of the environment.
metrics: A list of functions that evaluate metrics of the agent.
**base_agent_kwargs: A dictionary of parameters for base RL Agent.
Raises:
ValueError: If 'dqda_clipping' is < 0.
"""
self._step_cond_fn = step_cond_fn
self._reset_episode_cond_fn = reset_episode_cond_fn
self._reset_env_cond_fn = reset_env_cond_fn
self.metrics = metrics
self._actions_reg = actions_reg
self._k = k
# expose tf_context methods
self.tf_context = tf_context(tf_env=tf_env)
self.sub_context = sub_context(tf_env=tf_env)
self.set_replay = self.tf_context.set_replay
self.sample_contexts = self.tf_context.sample_contexts
self.compute_rewards = self.tf_context.compute_rewards
self.gamma_index = self.tf_context.gamma_index
self.context_specs = self.tf_context.context_specs
self.context_as_action_specs = self.tf_context.context_as_action_specs
self.sub_context_as_action_specs = self.sub_context.context_as_action_specs
self.init_context_vars = self.tf_context.create_vars
self.env_observation_spec = observation_spec[0]
merged_observation_spec = (uvf_utils.merge_specs(
(self.env_observation_spec,) + self.context_specs),)
self._context_vars = dict()
self._action_vars = dict()
assert len(self.context_as_action_specs) == 1
self.BASE_AGENT_CLASS.__init__(
self,
observation_spec=merged_observation_spec,
action_spec=self.sub_context_as_action_specs,
**base_agent_kwargs
)
@gin.configurable('meta_add_noise_fn')
def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
global_step=None):
noisy_action_fn = super(MetaAgentCore, self).add_noise_fn(
action_fn, stddev,
clip=True, global_step=global_step)
return noisy_action_fn
  def actor_loss(self, states, actions, rewards, discounts,
                 next_states):
    """Returns the actor loss, with an L1 regularizer on the meta-actions.

    The regularizer penalizes the magnitude of the meta-action (goal)
    dimensions beyond the first `self._k`.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch of
        states.
      actions: Unused.
      rewards: Unused.
      discounts: Unused.
      next_states: Unused.
    Returns:
      A scalar tensor representing the regularized actor loss.
    """
    actions = self.actor_net(states, stop_gradients=False)
    regularizer = self._actions_reg * tf.reduce_mean(
        tf.reduce_sum(tf.abs(actions[:, self._k:]), -1), 0)
    loss = self.BASE_AGENT_CLASS.actor_loss(self, states)
    return regularizer + loss
@gin.configurable
class UvfAgent(UvfAgentCore, ddpg_agent.TD3Agent):
"""A DDPG agent with UVF.
"""
BASE_AGENT_CLASS = ddpg_agent.TD3Agent
ACTION_TYPE = 'continuous'
def __init__(self, *args, **kwargs):
UvfAgentCore.__init__(self, *args, **kwargs)
@gin.configurable
class MetaAgent(MetaAgentCore, ddpg_agent.TD3Agent):
"""A DDPG meta-agent.
"""
BASE_AGENT_CLASS = ddpg_agent.TD3Agent
ACTION_TYPE = 'continuous'
def __init__(self, *args, **kwargs):
MetaAgentCore.__init__(self, *args, **kwargs)
@gin.configurable()
def state_preprocess_net(
states,
num_output_dims=2,
states_hidden_layers=(100,),
normalizer_fn=None,
activation_fn=tf.nn.relu,
zero_time=True,
images=False):
"""Creates a simple feed forward net for embedding states.
"""
with slim.arg_scope(
[slim.fully_connected],
activation_fn=activation_fn,
normalizer_fn=normalizer_fn,
weights_initializer=slim.variance_scaling_initializer(
factor=1.0/3.0, mode='FAN_IN', uniform=True)):
states_shape = tf.shape(states)
states_dtype = states.dtype
states = tf.to_float(states)
if images: # Zero-out x-y
states *= tf.constant([0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype)
if zero_time:
states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)
orig_states = states
embed = states
if states_hidden_layers:
embed = slim.stack(embed, slim.fully_connected, states_hidden_layers,
scope='states')
with slim.arg_scope([slim.fully_connected],
weights_regularizer=None,
weights_initializer=tf.random_uniform_initializer(
minval=-0.003, maxval=0.003)):
embed = slim.fully_connected(embed, num_output_dims,
activation_fn=None,
normalizer_fn=None,
scope='value')
output = embed
output = tf.cast(output, states_dtype)
return output
@gin.configurable()
def action_embed_net(
actions,
states=None,
num_output_dims=2,
hidden_layers=(400, 300),
normalizer_fn=None,
activation_fn=tf.nn.relu,
zero_time=True,
images=False):
"""Creates a simple feed forward net for embedding actions.
"""
with slim.arg_scope(
[slim.fully_connected],
activation_fn=activation_fn,
normalizer_fn=normalizer_fn,
weights_initializer=slim.variance_scaling_initializer(
factor=1.0/3.0, mode='FAN_IN', uniform=True)):
actions = tf.to_float(actions)
if states is not None:
if images: # Zero-out x-y
states *= tf.constant([0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype)
if zero_time:
states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)
actions = tf.concat([actions, tf.to_float(states)], -1)
embed = actions
if hidden_layers:
embed = slim.stack(embed, slim.fully_connected, hidden_layers,
scope='hidden')
with slim.arg_scope([slim.fully_connected],
weights_regularizer=None,
weights_initializer=tf.random_uniform_initializer(
minval=-0.003, maxval=0.003)):
embed = slim.fully_connected(embed, num_output_dims,
activation_fn=None,
normalizer_fn=None,
scope='value')
if num_output_dims == 1:
return embed[:, 0, ...]
else:
return embed
def huber(x, kappa=0.1):
  # Huber-style penalty scaled by 1/kappa: quadratic (0.5 * x**2) for
  # |x| <= kappa, linear (kappa * (|x| - 0.5 * kappa)) otherwise.
  return (0.5 * tf.square(x) * tf.to_float(tf.abs(x) <= kappa) +
          kappa * (tf.abs(x) - 0.5 * kappa) * tf.to_float(tf.abs(x) > kappa)
          ) / kappa
@gin.configurable()
class StatePreprocess(object):
STATE_PREPROCESS_NET_SCOPE = 'state_process_net'
ACTION_EMBED_NET_SCOPE = 'action_embed_net'
def __init__(self, trainable=False,
state_preprocess_net=lambda states: states,
action_embed_net=lambda actions, *args, **kwargs: actions,
ndims=None):
self.trainable = trainable
self._scope = tf.get_variable_scope().name
self._ndims = ndims
self._state_preprocess_net = tf.make_template(
self.STATE_PREPROCESS_NET_SCOPE, state_preprocess_net,
create_scope_now_=True)
self._action_embed_net = tf.make_template(
self.ACTION_EMBED_NET_SCOPE, action_embed_net,
create_scope_now_=True)
def __call__(self, states):
batched = states.get_shape().ndims != 1
if not batched:
states = tf.expand_dims(states, 0)
embedded = self._state_preprocess_net(states)
if self._ndims is not None:
embedded = embedded[..., :self._ndims]
if not batched:
return embedded[0]
return embedded
def loss(self, states, next_states, low_actions, low_states):
batch_size = tf.shape(states)[0]
d = int(low_states.shape[1])
# Sample indices into meta-transition to train on.
probs = 0.99 ** tf.range(d, dtype=tf.float32)
probs *= tf.constant([1.0] * (d - 1) + [1.0 / (1 - 0.99)],
dtype=tf.float32)
probs /= tf.reduce_sum(probs)
index_dist = tf.distributions.Categorical(probs=probs, dtype=tf.int64)
indices = index_dist.sample(batch_size)
batch_size = tf.cast(batch_size, tf.int64)
next_indices = tf.concat(
[tf.range(batch_size, dtype=tf.int64)[:, None],
(1 + indices[:, None]) % d], -1)
new_next_states = tf.where(indices < d - 1,
tf.gather_nd(low_states, next_indices),
next_states)
next_states = new_next_states
embed1 = tf.to_float(self._state_preprocess_net(states))
embed2 = tf.to_float(self._state_preprocess_net(next_states))
action_embed = self._action_embed_net(
tf.layers.flatten(low_actions), states=states)
tau = 2.0
fn = lambda z: tau * tf.reduce_sum(huber(z), -1)
all_embed = tf.get_variable('all_embed', [1024, int(embed1.shape[-1])],
initializer=tf.zeros_initializer())
upd = all_embed.assign(tf.concat([all_embed[batch_size:], embed2], 0))
with tf.control_dependencies([upd]):
close = 1 * tf.reduce_mean(fn(embed1 + action_embed - embed2))
prior_log_probs = tf.reduce_logsumexp(
-fn((embed1 + action_embed)[:, None, :] - all_embed[None, :, :]),
axis=-1) - tf.log(tf.to_float(all_embed.shape[0]))
far = tf.reduce_mean(tf.exp(-fn((embed1 + action_embed)[1:] - embed2[:-1])
- tf.stop_gradient(prior_log_probs[1:])))
repr_log_probs = tf.stop_gradient(
-fn(embed1 + action_embed - embed2) - prior_log_probs) / tau
return close + far, repr_log_probs, indices
def get_trainable_vars(self):
return (
slim.get_trainable_variables(
uvf_utils.join_scope(self._scope, self.STATE_PREPROCESS_NET_SCOPE)) +
slim.get_trainable_variables(
uvf_utils.join_scope(self._scope, self.ACTION_EMBED_NET_SCOPE)))
@gin.configurable()
class InverseDynamics(object):
INVERSE_DYNAMICS_NET_SCOPE = 'inverse_dynamics'
def __init__(self, spec):
self._spec = spec
def sample(self, states, next_states, num_samples, orig_goals, sc=0.5):
goal_dim = orig_goals.shape[-1]
spec_range = (self._spec.maximum - self._spec.minimum) / 2 * tf.ones([goal_dim])
loc = tf.cast(next_states - states, tf.float32)[:, :goal_dim]
scale = sc * tf.tile(tf.reshape(spec_range, [1, goal_dim]),
[tf.shape(states)[0], 1])
dist = tf.distributions.Normal(loc, scale)
if num_samples == 1:
return dist.sample()
samples = tf.concat([dist.sample(num_samples - 2),
tf.expand_dims(loc, 0),
tf.expand_dims(orig_goals, 0)], 0)
return uvf_utils.clip_to_spec(samples, self._spec)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A circular buffer where each element is a list of tensors.
Each element of the buffer is a list of tensors. An example use case is a replay
buffer in reinforcement learning, where each element is a list of tensors
representing the state, action, reward etc.
New elements are added sequentially, and once the buffer is full, we
start overwriting them in a circular fashion. Reading does not remove any
elements, only adding new elements does.
"""
import collections
import numpy as np
import tensorflow as tf
import gin.tf
@gin.configurable
class CircularBuffer(object):
"""A circular buffer where each element is a list of tensors."""
def __init__(self, buffer_size=1000, scope='replay_buffer'):
"""Circular buffer of list of tensors.
Args:
buffer_size: (integer) maximum number of tensor lists the buffer can hold.
scope: (string) variable scope for creating the variables.
"""
self._buffer_size = np.int64(buffer_size)
self._scope = scope
self._tensors = collections.OrderedDict()
with tf.variable_scope(self._scope):
self._num_adds = tf.Variable(0, dtype=tf.int64, name='num_adds')
self._num_adds_cs = tf.contrib.framework.CriticalSection(name='num_adds')
@property
def buffer_size(self):
return self._buffer_size
@property
def scope(self):
return self._scope
@property
def num_adds(self):
return self._num_adds
def _create_variables(self, tensors):
with tf.variable_scope(self._scope):
for name in tensors.keys():
tensor = tensors[name]
self._tensors[name] = tf.get_variable(
name='BufferVariable_' + name,
shape=[self._buffer_size] + tensor.get_shape().as_list(),
dtype=tensor.dtype,
trainable=False)
def _validate(self, tensors):
"""Validate shapes of tensors."""
if len(tensors) != len(self._tensors):
raise ValueError('Expected tensors to have %d elements. Received %d '
'instead.' % (len(self._tensors), len(tensors)))
if self._tensors.keys() != tensors.keys():
      raise ValueError('The keys of tensors should always be the same. '
                       'Received %s instead of %s.' %
                       (tensors.keys(), self._tensors.keys()))
for name, tensor in tensors.items():
if tensor.get_shape().as_list() != self._tensors[
name].get_shape().as_list()[1:]:
raise ValueError('Tensor %s has incorrect shape.' % name)
if not tensor.dtype.is_compatible_with(self._tensors[name].dtype):
raise ValueError(
'Tensor %s has incorrect data type. Expected %s, received %s' %
(name, self._tensors[name].read_value().dtype, tensor.dtype))
def add(self, tensors):
"""Adds an element (list/tuple/dict of tensors) to the buffer.
Args:
tensors: (list/tuple/dict of tensors) to be added to the buffer.
Returns:
An add operation that adds the input `tensors` to the buffer. Similar to
an enqueue_op.
Raises:
      ValueError: If the shapes and data types of the input `tensors` are not
        the same across calls to the add function.
"""
return self.maybe_add(tensors, True)
def maybe_add(self, tensors, condition):
"""Adds an element (tensors) to the buffer based on the condition..
Args:
tensors: (list/tuple of tensors) to be added to the buffer.
condition: A boolean Tensor controlling whether the tensors would be added
to the buffer or not.
Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
      a maybe_enqueue_op.
    Raises:
      ValueError: If the shapes and data types of the input `tensors` are not
        the same across calls to the add function.
"""
if not isinstance(tensors, dict):
names = [str(i) for i in range(len(tensors))]
tensors = collections.OrderedDict(zip(names, tensors))
if not isinstance(tensors, collections.OrderedDict):
tensors = collections.OrderedDict(
sorted(tensors.items(), key=lambda t: t[0]))
if not self._tensors:
self._create_variables(tensors)
else:
self._validate(tensors)
#@tf.critical_section(self._position_mutex)
def _increment_num_adds():
# Adding 0 to the num_adds variable is a trick to read the value of the
# variable and return a read-only tensor. Doing this in a critical
# section allows us to capture a snapshot of the variable that will
# not be affected by other threads updating num_adds.
return self._num_adds.assign_add(1) + 0
def _add():
num_adds_inc = self._num_adds_cs.execute(_increment_num_adds)
current_pos = tf.mod(num_adds_inc - 1, self._buffer_size)
update_ops = []
for name in self._tensors.keys():
update_ops.append(
tf.scatter_update(self._tensors[name], current_pos, tensors[name]))
return tf.group(*update_ops)
return tf.contrib.framework.smart_cond(condition, _add, tf.no_op)
def get_random_batch(self, batch_size, keys=None, num_steps=1):
"""Samples a batch of tensors from the buffer with replacement.
Args:
batch_size: (integer) number of elements to sample.
keys: List of keys of tensors to retrieve. If None retrieve all.
num_steps: (integer) length of trajectories to return. If > 1 will return
a list of lists, where each internal list represents a trajectory of
length num_steps.
Returns:
A list of tensors, where each element in the list is a batch sampled from
one of the tensors in the buffer.
Raises:
ValueError: If get_random_batch is called before calling the add function.
tf.errors.InvalidArgumentError: If this operation is executed before any
items are added to the buffer.
"""
if not self._tensors:
raise ValueError('The add function must be called before get_random_batch.')
if keys is None:
keys = self._tensors.keys()
latest_start_index = self.get_num_adds() - num_steps + 1
empty_buffer_assert = tf.Assert(
tf.greater(latest_start_index, 0),
['Not enough elements have been added to the buffer.'])
with tf.control_dependencies([empty_buffer_assert]):
max_index = tf.minimum(self._buffer_size, latest_start_index)
indices = tf.random_uniform(
[batch_size],
minval=0,
maxval=max_index,
dtype=tf.int64)
if num_steps == 1:
return self.gather(indices, keys)
else:
return self.gather_nstep(num_steps, indices, keys)
def gather(self, indices, keys=None):
"""Returns elements at the specified indices from the buffer.
Args:
indices: (list of integers or rank 1 int Tensor) indices in the buffer to
retrieve elements from.
keys: List of keys of tensors to retrieve. If None retrieve all.
Returns:
A list of tensors, where each element in the list is obtained by indexing
one of the tensors in the buffer.
Raises:
ValueError: If gather is called before calling the add function.
tf.errors.InvalidArgumentError: If indices are bigger than the number of
items in the buffer.
"""
if not self._tensors:
raise ValueError('The add function must be called before calling gather.')
if keys is None:
keys = self._tensors.keys()
with tf.name_scope('Gather'):
index_bound_assert = tf.Assert(
tf.less(
tf.to_int64(tf.reduce_max(indices)),
tf.minimum(self.get_num_adds(), self._buffer_size)),
['Index out of bounds.'])
with tf.control_dependencies([index_bound_assert]):
indices = tf.convert_to_tensor(indices)
batch = []
for key in keys:
batch.append(tf.gather(self._tensors[key], indices, name=key))
return batch
def gather_nstep(self, num_steps, indices, keys=None):
"""Returns elements at the specified indices from the buffer.
Args:
num_steps: (integer) length of trajectories to return.
indices: (list of rank num_steps int Tensor) indices in the buffer to
retrieve elements from for multiple trajectories. Each Tensor in the
list represents the indices for a trajectory.
keys: List of keys of tensors to retrieve. If None retrieve all.
Returns:
A list of list-of-tensors, where each element in the list is obtained by
indexing one of the tensors in the buffer.
Raises:
ValueError: If gather is called before calling the add function.
tf.errors.InvalidArgumentError: If indices are bigger than the number of
items in the buffer.
"""
if not self._tensors:
raise ValueError('The add function must be called before calling gather.')
if keys is None:
keys = self._tensors.keys()
with tf.name_scope('Gather'):
index_bound_assert = tf.Assert(
tf.less_equal(
tf.to_int64(tf.reduce_max(indices) + num_steps),
self.get_num_adds()),
['Trajectory indices go out of bounds.'])
with tf.control_dependencies([index_bound_assert]):
indices = tf.map_fn(
lambda x: tf.mod(tf.range(x, x + num_steps), self._buffer_size),
indices,
dtype=tf.int64)
batch = []
for key in keys:
def SampleTrajectories(trajectory_indices, key=key,
num_steps=num_steps):
trajectory_indices.set_shape([num_steps])
return tf.gather(self._tensors[key], trajectory_indices, name=key)
batch.append(tf.map_fn(SampleTrajectories, indices,
dtype=self._tensors[key].dtype))
return batch
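# Editorial note (not part of the original file): tf.map_fn above expands each
# start index into num_steps consecutive positions and wraps them around the
# circular buffer with tf.mod. For example, with buffer_size = 5, a start
# index of 3 and num_steps = 3 give positions mod([3, 4, 5], 5) = [3, 4, 0].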
def get_position(self):
"""Returns the position at which the last element was added.
Returns:
An int tensor representing the index at which the last element was added
to the buffer or -1 if no elements were added.
"""
return tf.cond(self.get_num_adds() < 1,
lambda: self.get_num_adds() - 1,
lambda: tf.mod(self.get_num_adds() - 1, self._buffer_size))
def get_num_adds(self):
"""Returns the number of additions to the buffer.
Returns:
An int tensor representing the number of elements that were added.
"""
def num_adds():
return self._num_adds.value()
return self._num_adds_cs.execute(num_adds)
def get_num_tensors(self):
"""Returns the number of tensors (slots) in the buffer."""
return len(self._tensors)
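# Hypothetical usage sketch (editorial addition, not part of the original
# commit). The class name CircularBuffer and its (buffer_size, scope)
# constructor are assumptions taken from the gin configs later in this commit;
# add() is assumed to accept a dict keyed by tensor name, matching the update
# loop above. Adjust the names if the actual class differs.
def _example_replay_buffer_usage():
  buf = CircularBuffer(buffer_size=1000, scope='example_buffer')  # assumed ctor
  transition = {
      'state': tf.zeros([4], dtype=tf.float32),
      'action': tf.zeros([2], dtype=tf.float32),
      'reward': tf.zeros([], dtype=tf.float32),
  }
  add_op = buf.add(transition)  # one transition is written per session run
  batch = buf.get_random_batch(batch_size=32)  # a list of tensors, one per key
  return add_op, batch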
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A DDPG/NAF agent.
Implements the Deep Deterministic Policy Gradient (DDPG) algorithm from
"Continuous control with deep reinforcement learning" - Lilicrap et al.
https://arxiv.org/abs/1509.02971, and the Normalized Advantage Functions (NAF)
algorithm "Continuous Deep Q-Learning with Model-based Acceleration" - Gu et al.
https://arxiv.org/pdf/1603.00748.
"""
import tensorflow as tf
slim = tf.contrib.slim
import gin.tf
from utils import utils
from agents import ddpg_networks as networks
@gin.configurable
class DdpgAgent(object):
"""An RL agent that learns using the DDPG algorithm.
Example usage:
def critic_net(states, actions):
...
def actor_net(states, num_action_dims):
...
Given a tensorflow environment tf_env,
(of type learning.deepmind.rl.environments.tensorflow.python.tfpyenvironment)
obs_spec = tf_env.observation_spec()
action_spec = tf_env.action_spec()
ddpg_agent = agent.DdpgAgent(obs_spec,
action_spec,
actor_net=actor_net,
critic_net=critic_net)
we can perform actions on the environment as follows:
state = tf_env.observations()[0]
action = ddpg_agent.actor_net(tf.expand_dims(state, 0))[0, :]
transition_type, reward, discount = tf_env.step([action])
Train:
critic_loss = ddpg_agent.critic_loss(states, actions, rewards, discounts,
next_states)
actor_loss = ddpg_agent.actor_loss(states)
critic_train_op = slim.learning.create_train_op(
critic_loss,
critic_optimizer,
variables_to_train=ddpg_agent.get_trainable_critic_vars(),
)
actor_train_op = slim.learning.create_train_op(
actor_loss,
actor_optimizer,
variables_to_train=ddpg_agent.get_trainable_actor_vars(),
)
"""
ACTOR_NET_SCOPE = 'actor_net'
CRITIC_NET_SCOPE = 'critic_net'
TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
TARGET_CRITIC_NET_SCOPE = 'target_critic_net'
def __init__(self,
observation_spec,
action_spec,
actor_net=networks.actor_net,
critic_net=networks.critic_net,
td_errors_loss=tf.losses.huber_loss,
dqda_clipping=0.,
actions_regularizer=0.,
target_q_clipping=None,
residual_phi=0.0,
debug_summaries=False):
"""Constructs a DDPG agent.
Args:
observation_spec: A TensorSpec defining the observations.
action_spec: A BoundedTensorSpec defining the actions.
actor_net: A callable that creates the actor network. Must take the
following arguments: states, num_actions. Please see networks.actor_net
for an example.
critic_net: A callable that creates the critic network. Must take the
following arguments: states, actions. Please see networks.critic_net
for an example.
td_errors_loss: A callable defining the loss function for the critic
td error.
dqda_clipping: (float) clips the gradient dqda element-wise between
[-dqda_clipping, dqda_clipping]. Does not perform clipping if
dqda_clipping == 0.
actions_regularizer: A scalar, when positive penalizes the norm of the
actions. This can prevent saturation of actions for the actor_loss.
target_q_clipping: (tuple of floats) clips target q values within
(low, high) values when computing the critic loss.
residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
interpolates between Q-learning and residual gradient algorithm.
http://www.leemon.com/papers/1995b.pdf
debug_summaries: If True, add summaries to help debug behavior.
Raises:
ValueError: If 'dqda_clipping' is < 0.
"""
self._observation_spec = observation_spec[0]
self._action_spec = action_spec[0]
self._state_shape = tf.TensorShape([None]).concatenate(
self._observation_spec.shape)
self._action_shape = tf.TensorShape([None]).concatenate(
self._action_spec.shape)
self._num_action_dims = self._action_spec.shape.num_elements()
self._scope = tf.get_variable_scope().name
self._actor_net = tf.make_template(
self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
self._critic_net = tf.make_template(
self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
self._target_actor_net = tf.make_template(
self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
self._target_critic_net = tf.make_template(
self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
self._td_errors_loss = td_errors_loss
if dqda_clipping < 0:
raise ValueError('dqda_clipping must be >= 0.')
self._dqda_clipping = dqda_clipping
self._actions_regularizer = actions_regularizer
self._target_q_clipping = target_q_clipping
self._residual_phi = residual_phi
self._debug_summaries = debug_summaries
def _batch_state(self, state):
"""Convert state to a batched state.
Args:
state: Either a [num_state_dims] state tensor, or a list/tuple whose first
element is such a tensor.
Returns:
A tensor [1, num_state_dims]
"""
if isinstance(state, (tuple, list)):
state = state[0]
if state.get_shape().ndims == 1:
state = tf.expand_dims(state, 0)
return state
def action(self, state):
"""Returns the next action for the state.
Args:
state: A [num_state_dims] tensor representing a state.
Returns:
A [num_action_dims] tensor representing the action.
"""
return self.actor_net(self._batch_state(state), stop_gradients=True)[0, :]
@gin.configurable('ddpg_sample_action')
def sample_action(self, state, stddev=1.0):
"""Returns the action for the state with additive noise.
Args:
state: A [num_state_dims] tensor representing a state.
stddev: stddev for the Ornstein-Uhlenbeck noise.
Returns:
A [num_action_dims] action tensor.
"""
agent_action = self.action(state)
agent_action += tf.random_normal(tf.shape(agent_action)) * stddev
return utils.clip_to_spec(agent_action, self._action_spec)
def actor_net(self, states, stop_gradients=False):
"""Returns the output of the actor network.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
stop_gradients: (boolean) if true, gradients cannot be propagated through
this operation.
Returns:
A [batch_size, num_action_dims] tensor of actions.
Raises:
ValueError: If `states` does not have the expected dimensions.
"""
self._validate_states(states)
actions = self._actor_net(states, self._action_spec)
if stop_gradients:
actions = tf.stop_gradient(actions)
return actions
def critic_net(self, states, actions, for_critic_loss=False):
"""Returns the output of the critic network.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
actions: A [batch_size, num_action_dims] tensor representing a batch
of actions.
Returns:
q values: A [batch_size] tensor of q values.
Raises:
ValueError: If `states` or `actions` do not have the expected dimensions.
"""
self._validate_states(states)
self._validate_actions(actions)
return self._critic_net(states, actions,
for_critic_loss=for_critic_loss)
def target_actor_net(self, states):
"""Returns the output of the target actor network.
The target network is used to compute stable targets for training.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
A [batch_size, num_action_dims] tensor of actions.
Raises:
ValueError: If `states` does not have the expected dimensions.
"""
self._validate_states(states)
actions = self._target_actor_net(states, self._action_spec)
return tf.stop_gradient(actions)
def target_critic_net(self, states, actions, for_critic_loss=False):
"""Returns the output of the target critic network.
The target network is used to compute stable targets for training.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
actions: A [batch_size, num_action_dims] tensor representing a batch
of actions.
Returns:
q values: A [batch_size] tensor of q values.
Raises:
ValueError: If `states` or `actions` do not have the expected dimensions.
"""
self._validate_states(states)
self._validate_actions(actions)
return tf.stop_gradient(
self._target_critic_net(states, actions,
for_critic_loss=for_critic_loss))
def value_net(self, states, for_critic_loss=False):
"""Returns the output of the critic evaluated with the actor.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
q values: A [batch_size] tensor of q values.
"""
actions = self.actor_net(states)
return self.critic_net(states, actions,
for_critic_loss=for_critic_loss)
def target_value_net(self, states, for_critic_loss=False):
"""Returns the output of the target critic evaluated with the target actor.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
q values: A [batch_size] tensor of q values.
"""
target_actions = self.target_actor_net(states)
return self.target_critic_net(states, target_actions,
for_critic_loss=for_critic_loss)
def critic_loss(self, states, actions, rewards, discounts,
next_states):
"""Computes a loss for training the critic network.
The loss is the mean squared error between the Q value predictions of the
critic and Q values estimated using TD-lambda.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
actions: A [batch_size, num_action_dims] tensor representing a batch
of actions.
rewards: A [batch_size, ...] tensor representing a batch of rewards,
broadcastable to the critic net output.
discounts: A [batch_size, ...] tensor representing a batch of discounts,
broadcastable to the critic net output.
next_states: A [batch_size, num_state_dims] tensor representing a batch
of next states.
Returns:
A rank-0 tensor representing the critic loss.
Raises:
ValueError: If any of the inputs do not have the expected dimensions, or
if their batch_sizes do not match.
"""
self._validate_states(states)
self._validate_actions(actions)
self._validate_states(next_states)
target_q_values = self.target_value_net(next_states, for_critic_loss=True)
td_targets = target_q_values * discounts + rewards
if self._target_q_clipping is not None:
td_targets = tf.clip_by_value(td_targets, self._target_q_clipping[0],
self._target_q_clipping[1])
q_values = self.critic_net(states, actions, for_critic_loss=True)
td_errors = td_targets - q_values
if self._debug_summaries:
gen_debug_td_error_summaries(
target_q_values, q_values, td_targets, td_errors)
loss = self._td_errors_loss(td_targets, q_values)
if self._residual_phi > 0.0: # compute residual gradient loss
residual_q_values = self.value_net(next_states, for_critic_loss=True)
residual_td_targets = residual_q_values * discounts + rewards
if self._target_q_clipping is not None:
residual_td_targets = tf.clip_by_value(residual_td_targets,
self._target_q_clipping[0],
self._target_q_clipping[1])
residual_td_errors = residual_td_targets - q_values
residual_loss = self._td_errors_loss(
residual_td_targets, residual_q_values)
loss = (loss * (1.0 - self._residual_phi) +
residual_loss * self._residual_phi)
return loss
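# Worked example (editorial, not in the original file): with reward = 1.0,
# discount = 0.99 and a target value of 10.0 for the next state, the TD target
# is 1.0 + 0.99 * 10.0 = 10.9, and the TD error is 10.9 minus the critic's
# current q value for (state, action).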
def actor_loss(self, states):
"""Computes a loss for training the actor network.
Note that output does not represent an actual loss. It is called a loss only
in the sense that its gradient w.r.t. the actor network weights is the
correct gradient for training the actor network,
i.e. dloss/dweights = (dq/da)*(da/dweights)
which is the gradient used in Algorithm 1 of Lillicrap et al.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
A rank-0 tensor representing the actor loss.
Raises:
ValueError: If `states` does not have the expected dimensions.
"""
self._validate_states(states)
actions = self.actor_net(states, stop_gradients=False)
critic_values = self.critic_net(states, actions)
q_values = self.critic_function(critic_values, states)
dqda = tf.gradients([q_values], [actions])[0]
dqda_unclipped = dqda
if self._dqda_clipping > 0:
dqda = tf.clip_by_value(dqda, -self._dqda_clipping, self._dqda_clipping)
actions_norm = tf.norm(actions)
if self._debug_summaries:
with tf.name_scope('dqda'):
tf.summary.scalar('actions_norm', actions_norm)
tf.summary.histogram('dqda', dqda)
tf.summary.histogram('dqda_unclipped', dqda_unclipped)
tf.summary.histogram('actions', actions)
for a in range(self._num_action_dims):
tf.summary.histogram('dqda_unclipped_%d' % a, dqda_unclipped[:, a])
tf.summary.histogram('dqda_%d' % a, dqda[:, a])
actions_norm *= self._actions_regularizer
return slim.losses.mean_squared_error(tf.stop_gradient(dqda + actions),
actions,
scope='actor_loss') + actions_norm
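# Editorial note on the surrogate loss above: for a single action a with
# g = dq/da held fixed by stop_gradient, (stop_gradient(a + g) - a)**2 has
# derivative -2 * g with respect to a, so gradient descent on this loss moves
# the actions (and, through da/dweights, the actor weights) in the direction
# of increasing q, i.e. it implements the (dq/da) * (da/dweights) update.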
@gin.configurable('ddpg_critic_function')
def critic_function(self, critic_values, states, weights=None):
"""Computes q values based on critic_net outputs, states, and weights.
Args:
critic_values: A tf.float32 [batch_size, ...] tensor representing outputs
from the critic net.
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
weights: A list or Numpy array or tensor with a shape broadcastable to
`critic_values`.
Returns:
A tf.float32 [batch_size] tensor representing q values.
"""
del states # unused args
if weights is not None:
weights = tf.convert_to_tensor(weights, dtype=critic_values.dtype)
critic_values *= weights
if critic_values.shape.ndims > 1:
critic_values = tf.reduce_sum(critic_values,
range(1, critic_values.shape.ndims))
critic_values.shape.assert_has_rank(1)
return critic_values
@gin.configurable('ddpg_update_targets')
def update_targets(self, tau=1.0):
"""Performs a soft update of the target network parameters.
For each weight w_s in the actor/critic networks, and its corresponding
weight w_t in the target actor/critic networks, a soft update is:
w_t = (1 - tau) * w_t + tau * w_s
Args:
tau: A float scalar in [0, 1]
Returns:
An operation that performs a soft update of the target network parameters.
Raises:
ValueError: If `tau` is not in [0, 1].
"""
if tau < 0 or tau > 1:
raise ValueError('Input `tau` should be in [0, 1].')
update_actor = utils.soft_variables_update(
slim.get_trainable_variables(
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
slim.get_trainable_variables(
utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
tau)
update_critic = utils.soft_variables_update(
slim.get_trainable_variables(
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
slim.get_trainable_variables(
utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
tau)
return tf.group(update_actor, update_critic, name='update_targets')
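# Editorial worked example: with the DDPG setting tau = 0.001 used in the gin
# configs below, every target weight is updated as
# w_t <- 0.999 * w_t + 0.001 * w_s, so the target networks track the online
# networks with a time constant of roughly 1 / tau = 1000 updates.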
def get_trainable_critic_vars(self):
"""Returns a list of trainable variables in the critic network.
Returns:
A list of trainable variables in the critic network.
"""
return slim.get_trainable_variables(
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))
def get_trainable_actor_vars(self):
"""Returns a list of trainable variables in the actor network.
Returns:
A list of trainable variables in the actor network.
"""
return slim.get_trainable_variables(
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))
def get_critic_vars(self):
"""Returns a list of all variables in the critic network.
Returns:
A list of all model variables in the critic network.
"""
return slim.get_model_variables(
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))
def get_actor_vars(self):
"""Returns a list of all variables in the actor network.
Returns:
A list of all model variables in the actor network.
"""
return slim.get_model_variables(
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))
def _validate_states(self, states):
"""Raises a value error if `states` does not have the expected shape.
Args:
states: A tensor.
Raises:
ValueError: If states.shape or states.dtype are not compatible with
observation_spec.
"""
states.shape.assert_is_compatible_with(self._state_shape)
if not states.dtype.is_compatible_with(self._observation_spec.dtype):
raise ValueError('states.dtype={} is not compatible with'
' observation_spec.dtype={}'.format(
states.dtype, self._observation_spec.dtype))
def _validate_actions(self, actions):
"""Raises a value error if `actions` does not have the expected shape.
Args:
actions: A tensor.
Raises:
ValueError: If actions.shape or actions.dtype are not compatible with
action_spec.
"""
actions.shape.assert_is_compatible_with(self._action_shape)
if not actions.dtype.is_compatible_with(self._action_spec.dtype):
raise ValueError('actions.dtype={} is not compatible with'
' action_spec.dtype={}'.format(
actions.dtype, self._action_spec.dtype))
@gin.configurable
class TD3Agent(DdpgAgent):
"""An RL agent that learns using the TD3 algorithm."""
ACTOR_NET_SCOPE = 'actor_net'
CRITIC_NET_SCOPE = 'critic_net'
CRITIC_NET2_SCOPE = 'critic_net2'
TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
TARGET_CRITIC_NET_SCOPE = 'target_critic_net'
TARGET_CRITIC_NET2_SCOPE = 'target_critic_net2'
def __init__(self,
observation_spec,
action_spec,
actor_net=networks.actor_net,
critic_net=networks.critic_net,
td_errors_loss=tf.losses.huber_loss,
dqda_clipping=0.,
actions_regularizer=0.,
target_q_clipping=None,
residual_phi=0.0,
debug_summaries=False):
"""Constructs a TD3 agent.
Args:
observation_spec: A TensorSpec defining the observations.
action_spec: A BoundedTensorSpec defining the actions.
actor_net: A callable that creates the actor network. Must take the
following arguments: states, num_actions. Please see networks.actor_net
for an example.
critic_net: A callable that creates the critic network. Must take the
following arguments: states, actions. Please see networks.critic_net
for an example.
td_errors_loss: A callable defining the loss function for the critic
td error.
dqda_clipping: (float) clips the gradient dqda element-wise between
[-dqda_clipping, dqda_clipping]. Does not perform clipping if
dqda_clipping == 0.
actions_regularizer: A scalar, when positive penalizes the norm of the
actions. This can prevent saturation of actions for the actor_loss.
target_q_clipping: (tuple of floats) clips target q values within
(low, high) values when computing the critic loss.
residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
interpolates between Q-learning and residual gradient algorithm.
http://www.leemon.com/papers/1995b.pdf
debug_summaries: If True, add summaries to help debug behavior.
Raises:
ValueError: If 'dqda_clipping' is < 0.
"""
self._observation_spec = observation_spec[0]
self._action_spec = action_spec[0]
self._state_shape = tf.TensorShape([None]).concatenate(
self._observation_spec.shape)
self._action_shape = tf.TensorShape([None]).concatenate(
self._action_spec.shape)
self._num_action_dims = self._action_spec.shape.num_elements()
self._scope = tf.get_variable_scope().name
self._actor_net = tf.make_template(
self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
self._critic_net = tf.make_template(
self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
self._critic_net2 = tf.make_template(
self.CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
self._target_actor_net = tf.make_template(
self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
self._target_critic_net = tf.make_template(
self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
self._target_critic_net2 = tf.make_template(
self.TARGET_CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
self._td_errors_loss = td_errors_loss
if dqda_clipping < 0:
raise ValueError('dqda_clipping must be >= 0.')
self._dqda_clipping = dqda_clipping
self._actions_regularizer = actions_regularizer
self._target_q_clipping = target_q_clipping
self._residual_phi = residual_phi
self._debug_summaries = debug_summaries
def get_trainable_critic_vars(self):
"""Returns a list of trainable variables in the critic network.
NOTE: This gets the vars of both critic networks.
Returns:
A list of trainable variables in both critic networks.
"""
return (
slim.get_trainable_variables(
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)))
def critic_net(self, states, actions, for_critic_loss=False):
"""Returns the output of the critic network.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
actions: A [batch_size, num_action_dims] tensor representing a batch
of actions.
Returns:
q values: A [batch_size] tensor of q values.
Raises:
ValueError: If `states` or `actions` do not have the expected dimensions.
"""
values1 = self._critic_net(states, actions,
for_critic_loss=for_critic_loss)
values2 = self._critic_net2(states, actions,
for_critic_loss=for_critic_loss)
if for_critic_loss:
return values1, values2
return values1
def target_critic_net(self, states, actions, for_critic_loss=False):
"""Returns the output of the target critic network.
The target network is used to compute stable targets for training.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
actions: A [batch_size, num_action_dims] tensor representing a batch
of actions.
Returns:
q values: A [batch_size] tensor of q values.
Raises:
ValueError: If `states` or `actions` do not have the expected dimensions.
"""
self._validate_states(states)
self._validate_actions(actions)
values1 = tf.stop_gradient(
self._target_critic_net(states, actions,
for_critic_loss=for_critic_loss))
values2 = tf.stop_gradient(
self._target_critic_net2(states, actions,
for_critic_loss=for_critic_loss))
if for_critic_loss:
return values1, values2
return values1
def value_net(self, states, for_critic_loss=False):
"""Returns the output of the critic evaluated with the actor.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
q values: A [batch_size] tensor of q values.
"""
actions = self.actor_net(states)
return self.critic_net(states, actions,
for_critic_loss=for_critic_loss)
def target_value_net(self, states, for_critic_loss=False):
"""Returns the output of the target critic evaluated with the target actor.
Args:
states: A [batch_size, num_state_dims] tensor representing a batch
of states.
Returns:
q values: A [batch_size] tensor of q values.
"""
target_actions = self.target_actor_net(states)
noise = tf.clip_by_value(
tf.random_normal(tf.shape(target_actions), stddev=0.2), -0.5, 0.5)
values1, values2 = self.target_critic_net(
states, target_actions + noise,
for_critic_loss=for_critic_loss)
values = tf.minimum(values1, values2)
return values, values
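# Editorial note: the two TD3 ingredients above are target policy smoothing
# (Gaussian noise with stddev 0.2, clipped to [-0.5, 0.5], added to the target
# action) and clipped double-Q learning (the elementwise minimum of the two
# target critics); see Fujimoto et al., https://arxiv.org/abs/1802.09477.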
@gin.configurable('td3_update_targets')
def update_targets(self, tau=1.0):
"""Performs a soft update of the target network parameters.
For each weight w_s in the actor/critic networks, and its corresponding
weight w_t in the target actor/critic networks, a soft update is:
w_t = (1 - tau) * w_t + tau * w_s
Args:
tau: A float scalar in [0, 1]
Returns:
An operation that performs a soft update of the target network parameters.
Raises:
ValueError: If `tau` is not in [0, 1].
"""
if tau < 0 or tau > 1:
raise ValueError('Input `tau` should be in [0, 1].')
update_actor = utils.soft_variables_update(
slim.get_trainable_variables(
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
slim.get_trainable_variables(
utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
tau)
# NOTE: This updates both critic networks.
update_critic = utils.soft_variables_update(
slim.get_trainable_variables(
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
slim.get_trainable_variables(
utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
tau)
return tf.group(update_actor, update_critic, name='update_targets')
def gen_debug_td_error_summaries(
target_q_values, q_values, td_targets, td_errors):
"""Generates debug summaries for critic given a set of batch samples.
Args:
target_q_values: A batch of predicted values for the next states.
q_values: A batch of values currently predicted by the critic network.
td_targets: Discounted target_q_values with the rewards added.
td_errors: The difference between td_targets and q_values.
"""
with tf.name_scope('td_errors'):
tf.summary.histogram('td_targets', td_targets)
tf.summary.histogram('q_values', q_values)
tf.summary.histogram('target_q_values', target_q_values)
tf.summary.histogram('td_errors', td_errors)
with tf.name_scope('td_targets'):
tf.summary.scalar('mean', tf.reduce_mean(td_targets))
tf.summary.scalar('max', tf.reduce_max(td_targets))
tf.summary.scalar('min', tf.reduce_min(td_targets))
with tf.name_scope('q_values'):
tf.summary.scalar('mean', tf.reduce_mean(q_values))
tf.summary.scalar('max', tf.reduce_max(q_values))
tf.summary.scalar('min', tf.reduce_min(q_values))
with tf.name_scope('target_q_values'):
tf.summary.scalar('mean', tf.reduce_mean(target_q_values))
tf.summary.scalar('max', tf.reduce_max(target_q_values))
tf.summary.scalar('min', tf.reduce_min(target_q_values))
with tf.name_scope('td_errors'):
tf.summary.scalar('mean', tf.reduce_mean(td_errors))
tf.summary.scalar('max', tf.reduce_max(td_errors))
tf.summary.scalar('min', tf.reduce_min(td_errors))
tf.summary.scalar('mean_abs', tf.reduce_mean(tf.abs(td_errors)))
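# Editorial note: the training configs later in this commit set
# train_uvf.target_update_period = 2, in line with TD3's delayed policy and
# target-network updates.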
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sample actor(policy) and critic(q) networks to use with DDPG/NAF agents.
The DDPG networks are defined in "Section 7: Experiment Details" of
"Continuous control with deep reinforcement learning" - Lilicrap et al.
https://arxiv.org/abs/1509.02971
The NAF critic network is based on "Section 4" of "Continuous deep Q-learning
with model-based acceleration" - Gu et al. https://arxiv.org/pdf/1603.00748.
"""
import tensorflow as tf
slim = tf.contrib.slim
import gin.tf
@gin.configurable('ddpg_critic_net')
def critic_net(states, actions,
for_critic_loss=False,
num_reward_dims=1,
states_hidden_layers=(400,),
actions_hidden_layers=None,
joint_hidden_layers=(300,),
weight_decay=0.0001,
normalizer_fn=None,
activation_fn=tf.nn.relu,
zero_obs=False,
images=False):
"""Creates a critic that returns q values for the given states and actions.
Args:
states: (castable to tf.float32) a [batch_size, num_state_dims] tensor
representing a batch of states.
actions: (castable to tf.float32) a [batch_size, num_action_dims] tensor
representing a batch of actions.
num_reward_dims: Number of reward dimensions.
states_hidden_layers: tuple of hidden layers units for states.
actions_hidden_layers: tuple of hidden layers units for actions.
joint_hidden_layers: tuple of hidden layers units after joining states
and actions using tf.concat().
weight_decay: Weight decay for l2 weights regularizer.
normalizer_fn: Normalizer function, e.g. slim.layer_norm.
activation_fn: Activation function, e.g. tf.nn.relu or slim.leaky_relu.
zero_obs: If True, zero out the first two state dimensions (the x, y
position).
images: Treated the same way as zero_obs.
for_critic_loss: If True and num_reward_dims > 1, keep the vector of q values
instead of collapsing them to a scalar.
Returns:
A tf.float32 [batch_size] tensor of q values, or a tf.float32
[batch_size, num_reward_dims] tensor of vector q values if
num_reward_dims > 1.
"""
with slim.arg_scope(
[slim.fully_connected],
activation_fn=activation_fn,
normalizer_fn=normalizer_fn,
weights_regularizer=slim.l2_regularizer(weight_decay),
weights_initializer=slim.variance_scaling_initializer(
factor=1.0/3.0, mode='FAN_IN', uniform=True)):
orig_states = tf.to_float(states)
# As in TD3, the actions are concatenated with the states before the first
# hidden layer.
states = tf.concat([tf.to_float(states), tf.to_float(actions)], -1)
if images or zero_obs:
  # Zero out the first two state dimensions (the x, y position).
  states *= tf.constant([0.0] * 2 + [1.0] * (states.shape[1] - 2))
actions = tf.to_float(actions)
if states_hidden_layers:
states = slim.stack(states, slim.fully_connected, states_hidden_layers,
scope='states')
if actions_hidden_layers:
actions = slim.stack(actions, slim.fully_connected, actions_hidden_layers,
scope='actions')
joint = tf.concat([states, actions], 1)
if joint_hidden_layers:
joint = slim.stack(joint, slim.fully_connected, joint_hidden_layers,
scope='joint')
with slim.arg_scope([slim.fully_connected],
weights_regularizer=None,
weights_initializer=tf.random_uniform_initializer(
minval=-0.003, maxval=0.003)):
value = slim.fully_connected(joint, num_reward_dims,
activation_fn=None,
normalizer_fn=None,
scope='q_value')
if num_reward_dims == 1:
value = tf.reshape(value, [-1])
if not for_critic_loss and num_reward_dims > 1:
value = tf.reduce_sum(
value * tf.abs(orig_states[:, -num_reward_dims:]), -1)
return value
@gin.configurable('ddpg_actor_net')
def actor_net(states, action_spec,
hidden_layers=(400, 300),
normalizer_fn=None,
activation_fn=tf.nn.relu,
zero_obs=False,
images=False):
"""Creates an actor that returns actions for the given states.
Args:
states: (castable to tf.float32) a [batch_size, num_state_dims] tensor
representing a batch of states.
action_spec: (BoundedTensorSpec) A tensor spec indicating the shape
and range of actions.
hidden_layers: tuple of hidden layers units.
normalizer_fn: Normalizer function, e.g. slim.layer_norm.
activation_fn: Activation function, e.g. tf.nn.relu or slim.leaky_relu.
zero_obs: If True, zero out the first two state dimensions (the x, y
position).
images: Treated the same way as zero_obs.
Returns:
A tf.float32 [batch_size, num_action_dims] tensor of actions.
"""
with slim.arg_scope(
[slim.fully_connected],
activation_fn=activation_fn,
normalizer_fn=normalizer_fn,
weights_initializer=slim.variance_scaling_initializer(
factor=1.0/3.0, mode='FAN_IN', uniform=True)):
states = tf.to_float(states)
orig_states = states
if images or zero_obs: # Zero-out x, y position. Hacky.
states *= tf.constant([0.0] * 2 + [1.0] * (states.shape[1] - 2))
if hidden_layers:
states = slim.stack(states, slim.fully_connected, hidden_layers,
scope='states')
with slim.arg_scope([slim.fully_connected],
weights_initializer=tf.random_uniform_initializer(
minval=-0.003, maxval=0.003)):
actions = slim.fully_connected(states,
action_spec.shape.num_elements(),
scope='actions',
normalizer_fn=None,
activation_fn=tf.nn.tanh)
action_means = (action_spec.maximum + action_spec.minimum) / 2.0
action_magnitudes = (action_spec.maximum - action_spec.minimum) / 2.0
actions = action_means + action_magnitudes * actions
return actions
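# Editorial sketch (not part of the original file): a plain-Python check of
# the rescaling above, with hypothetical bounds minimum = -2.0, maximum = 2.0.
def _example_action_rescaling(tanh_output=0.5, minimum=-2.0, maximum=2.0):
  action_mean = (maximum + minimum) / 2.0       # 0.0 for the defaults
  action_magnitude = (maximum - minimum) / 2.0  # 2.0 for the defaults
  # A tanh output in [-1, 1] is mapped affinely into [minimum, maximum].
  return action_mean + action_magnitude * tanh_output  # 1.0 for the defaults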
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines many boolean functions indicating when to step and reset.
"""
import tensorflow as tf
import gin.tf
@gin.configurable
def env_transition(agent, state, action, transition_type, environment_steps,
num_episodes):
"""True if the transition_type is TRANSITION or FINAL_TRANSITION.
Args:
agent: RL agent.
state: A [num_state_dims] tensor representing a state.
action: Action performed.
transition_type: Type of transition after action
environment_steps: Number of steps performed by environment.
num_episodes: Number of episodes.
Returns:
cond: Returns an op that evaluates to true if the transition type is
not RESTARTING
"""
del agent, state, action, num_episodes, environment_steps
cond = tf.logical_not(transition_type)
return cond
@gin.configurable
def env_restart(agent, state, action, transition_type, environment_steps,
num_episodes):
"""True if the transition_type is RESTARTING.
Args:
agent: RL agent.
state: A [num_state_dims] tensor representing a state.
action: Action performed.
transition_type: Type of transition after action
environment_steps: Number of steps performed by environment.
num_episodes: Number of episodes.
Returns:
cond: Returns an op that evaluates to true if the transition type equals
RESTARTING.
"""
del agent, state, action, num_episodes, environment_steps
cond = tf.identity(transition_type)
return cond
@gin.configurable
def every_n_steps(agent,
state,
action,
transition_type,
environment_steps,
num_episodes,
n=150):
"""True once every n steps.
Args:
agent: RL agent.
state: A [num_state_dims] tensor representing a state.
action: Action performed.
transition_type: Type of transition after action
environment_steps: Number of steps performed by environment.
num_episodes: Number of episodes.
n: Return true once every n steps.
Returns:
cond: Returns an op that evaluates to true if environment_steps
equals 0 mod n. We increment the step before checking this condition, so
we do not need to add one to environment_steps.
"""
del agent, state, action, transition_type, num_episodes
cond = tf.equal(tf.mod(environment_steps, n), 0)
return cond
@gin.configurable
def every_n_episodes(agent,
state,
action,
transition_type,
environment_steps,
num_episodes,
n=2,
steps_per_episode=None):
"""True once every n episodes.
Specifically, evaluates to True on the 0th step of every nth episode.
Unlike environment_steps, num_episodes starts at 0, so we do want to add
one to ensure it does not reset on the first call.
Args:
agent: RL agent.
state: A [num_state_dims] tensor representing a state.
action: Action performed.
transition_type: Type of transition after action
environment_steps: Number of steps performed by environment.
num_episodes: Number of episodes.
n: Return true once every n episodes.
steps_per_episode: How many steps per episode. Needed to determine when a
new episode starts.
Returns:
cond: Returns an op that evaluates to true at an episode boundary (i.e. when
environment_steps is a multiple of steps_per_episode) if the ant has fallen
or num_episodes + 1 equals 0 mod n.
"""
assert steps_per_episode is not None
del agent, action, transition_type
ant_fell = tf.logical_or(state[2] < 0.2, state[2] > 1.0)
cond = tf.logical_and(
tf.logical_or(
ant_fell,
tf.equal(tf.mod(num_episodes + 1, n), 0)),
tf.equal(tf.mod(environment_steps, steps_per_episode), 0))
return cond
@gin.configurable
def failed_reset_after_n_episodes(agent,
state,
action,
transition_type,
environment_steps,
num_episodes,
steps_per_episode=None,
reset_state=None,
max_dist=1.0,
epsilon=1e-10):
"""Every n episodes, returns True if the reset agent fails to return.
Specifically, evaluates to True if the distance between the state and the
reset state is greater than max_dist at the end of the episode.
Args:
agent: RL agent.
state: A [num_state_dims] tensor representing a state.
action: Action performed.
transition_type: Type of transition after action
environment_steps: Number of steps performed by environment.
num_episodes: Number of episodes.
steps_per_episode: How many steps per episode. Needed to determine when a
new episode starts.
reset_state: State to which the reset controller should return.
max_dist: Agent is considered to have successfully reset if its distance
from the reset_state is less than max_dist.
epsilon: small offset to ensure non-negative/zero distance.
Returns:
cond: Returns an op that evaluates to true at an episode boundary (i.e. when
environment_steps is a multiple of steps_per_episode) if the distance between
the state and reset_state is greater than max_dist.
"""
assert steps_per_episode is not None
assert reset_state is not None
del agent, action, transition_type, num_episodes  # state is used below.
dist = tf.sqrt(
tf.reduce_sum(tf.squared_difference(state, reset_state)) + epsilon)
cond = tf.logical_and(
tf.greater(dist, tf.constant(max_dist)),
tf.equal(tf.mod(environment_steps, steps_per_episode), 0))
return cond
@gin.configurable
def q_too_small(agent,
state,
action,
transition_type,
environment_steps,
num_episodes,
q_min=0.5):
"""True of q is too small.
Args:
agent: RL agent.
state: A [num_state_dims] tensor representing a state.
action: Action performed.
transition_type: Type of transition after action
environment_steps: Number of steps performed by environment.
num_episodes: Number of episodes.
q_min: Returns true if the qval is less than q_min
Returns:
cond: Returns an op that evaluates to true if qval is less than q_min.
"""
del transition_type, environment_steps, num_episodes
state_for_reset_agent = tf.stack(state[:-1], tf.constant([0], dtype=tf.float32))
qval = agent.BASE_AGENT_CLASS.critic_net(
tf.expand_dims(state_for_reset_agent, 0), tf.expand_dims(action, 0))[0, :]
cond = tf.greater(tf.constant(q_min), qval)
return cond
@gin.configurable
def true_fn(agent, state, action, transition_type, environment_steps,
num_episodes):
"""Returns an op that evaluates to true.
Args:
agent: RL agent.
state: A [num_state_dims] tensor representing a state.
action: Action performed.
transition_type: Type of transition after action
environment_steps: Number of steps performed by environment.
num_episodes: Number of episodes.
Returns:
cond: op that always evaluates to True.
"""
del agent, state, action, transition_type, environment_steps, num_episodes
cond = tf.constant(True, dtype=tf.bool)
return cond
@gin.configurable
def false_fn(agent, state, action, transition_type, environment_steps,
num_episodes):
"""Returns an op that evaluates to false.
Args:
agent: RL agent.
state: A [num_state_dims] tensor representing a state.
action: Action performed.
transition_type: Type of transition after action
environment_steps: Number of steps performed by environment.
num_episodes: Number of episodes.
Returns:
cond: op that always evaluates to False.
"""
del agent, state, action, transition_type, environment_steps, num_episodes
cond = tf.constant(False, dtype=tf.bool)
return cond
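# Editorial note: all of these predicates share the signature
# (agent, state, action, transition_type, environment_steps, num_episodes) so
# they can be swapped in through gin; for instance, the configs below bind
# UvfAgent.reset_episode_cond_fn = @every_n_steps and
# UvfAgent.reset_env_cond_fn = @every_n_episodes.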
#-*-Python-*-
import gin.tf.external_configurables
create_maze_env.top_down_view = %IMAGES
## Create the agent
AGENT_CLASS = @UvfAgent
UvfAgent.tf_context = %CONTEXT
UvfAgent.actor_net = @agent/ddpg_actor_net
UvfAgent.critic_net = @agent/ddpg_critic_net
UvfAgent.dqda_clipping = 0.0
UvfAgent.td_errors_loss = @tf.losses.huber_loss
UvfAgent.target_q_clipping = %TARGET_Q_CLIPPING
# Create meta agent
META_CLASS = @MetaAgent
MetaAgent.tf_context = %META_CONTEXT
MetaAgent.sub_context = %CONTEXT
MetaAgent.actor_net = @meta/ddpg_actor_net
MetaAgent.critic_net = @meta/ddpg_critic_net
MetaAgent.dqda_clipping = 0.0
MetaAgent.td_errors_loss = @tf.losses.huber_loss
MetaAgent.target_q_clipping = %TARGET_Q_CLIPPING
# Create state preprocess
STATE_PREPROCESS_CLASS = @StatePreprocess
StatePreprocess.ndims = %SUBGOAL_DIM
state_preprocess_net.states_hidden_layers = (100, 100)
state_preprocess_net.num_output_dims = %SUBGOAL_DIM
state_preprocess_net.images = %IMAGES
action_embed_net.num_output_dims = %SUBGOAL_DIM
INVERSE_DYNAMICS_CLASS = @InverseDynamics
# actor_net
ACTOR_HIDDEN_SIZE_1 = 300
ACTOR_HIDDEN_SIZE_2 = 300
agent/ddpg_actor_net.hidden_layers = (%ACTOR_HIDDEN_SIZE_1, %ACTOR_HIDDEN_SIZE_2)
agent/ddpg_actor_net.activation_fn = @tf.nn.relu
agent/ddpg_actor_net.zero_obs = %ZERO_OBS
agent/ddpg_actor_net.images = %IMAGES
meta/ddpg_actor_net.hidden_layers = (%ACTOR_HIDDEN_SIZE_1, %ACTOR_HIDDEN_SIZE_2)
meta/ddpg_actor_net.activation_fn = @tf.nn.relu
meta/ddpg_actor_net.zero_obs = False
meta/ddpg_actor_net.images = %IMAGES
# critic_net
CRITIC_HIDDEN_SIZE_1 = 300
CRITIC_HIDDEN_SIZE_2 = 300
agent/ddpg_critic_net.states_hidden_layers = (%CRITIC_HIDDEN_SIZE_1,)
agent/ddpg_critic_net.actions_hidden_layers = None
agent/ddpg_critic_net.joint_hidden_layers = (%CRITIC_HIDDEN_SIZE_2,)
agent/ddpg_critic_net.weight_decay = 0.0
agent/ddpg_critic_net.activation_fn = @tf.nn.relu
agent/ddpg_critic_net.zero_obs = %ZERO_OBS
agent/ddpg_critic_net.images = %IMAGES
meta/ddpg_critic_net.states_hidden_layers = (%CRITIC_HIDDEN_SIZE_1,)
meta/ddpg_critic_net.actions_hidden_layers = None
meta/ddpg_critic_net.joint_hidden_layers = (%CRITIC_HIDDEN_SIZE_2,)
meta/ddpg_critic_net.weight_decay = 0.0
meta/ddpg_critic_net.activation_fn = @tf.nn.relu
meta/ddpg_critic_net.zero_obs = False
meta/ddpg_critic_net.images = %IMAGES
tf.losses.huber_loss.delta = 1.0
# Sample action
uvf_add_noise_fn.stddev = 1.0
meta_add_noise_fn.stddev = %META_EXPLORE_NOISE
# Update targets
ddpg_update_targets.tau = 0.001
td3_update_targets.tau = 0.005
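# Editorial note (not part of the original config): a base file like this one
# is meant to be combined with one of the environment-specific configs below
# and loaded through the gin-config library, e.g. roughly
#   gin.parse_config_files_and_bindings([base_config, env_config], bindings=[])
# where base_config and env_config are placeholder file paths.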
#-*-Python-*-
# Config eval
evaluate.environment = @create_maze_env()
evaluate.agent_class = %AGENT_CLASS
evaluate.meta_agent_class = %META_CLASS
evaluate.state_preprocess_class = %STATE_PREPROCESS_CLASS
evaluate.num_episodes_eval = 50
evaluate.num_episodes_videos = 1
evaluate.gamma = 1.0
evaluate.eval_interval_secs = 1
evaluate.generate_videos = False
evaluate.generate_summaries = True
evaluate.eval_modes = %EVAL_MODES
evaluate.max_steps_per_episode = %RESET_EPISODE_PERIOD
#-*-Python-*-
# Create replay_buffer
agent/CircularBuffer.buffer_size = 200000
meta/CircularBuffer.buffer_size = 200000
agent/CircularBuffer.scope = "agent"
meta/CircularBuffer.scope = "meta"
# Config train
train_uvf.environment = @create_maze_env()
train_uvf.agent_class = %AGENT_CLASS
train_uvf.meta_agent_class = %META_CLASS
train_uvf.state_preprocess_class = %STATE_PREPROCESS_CLASS
train_uvf.inverse_dynamics_class = %INVERSE_DYNAMICS_CLASS
train_uvf.replay_buffer = @agent/CircularBuffer()
train_uvf.meta_replay_buffer = @meta/CircularBuffer()
train_uvf.critic_optimizer = @critic/AdamOptimizer()
train_uvf.actor_optimizer = @actor/AdamOptimizer()
train_uvf.meta_critic_optimizer = @meta_critic/AdamOptimizer()
train_uvf.meta_actor_optimizer = @meta_actor/AdamOptimizer()
train_uvf.repr_optimizer = @repr/AdamOptimizer()
train_uvf.num_episodes_train = 25000
train_uvf.batch_size = 100
train_uvf.initial_episodes = 5
train_uvf.gamma = 0.99
train_uvf.meta_gamma = 0.99
train_uvf.reward_scale_factor = 1.0
train_uvf.target_update_period = 2
train_uvf.num_updates_per_observation = 1
train_uvf.num_collect_per_update = 1
train_uvf.num_collect_per_meta_update = 10
train_uvf.debug_summaries = False
train_uvf.log_every_n_steps = 1000
train_uvf.save_policy_every_n_steps = 100000
# Config Optimizers
critic/AdamOptimizer.learning_rate = 0.001
critic/AdamOptimizer.beta1 = 0.9
critic/AdamOptimizer.beta2 = 0.999
actor/AdamOptimizer.learning_rate = 0.0001
actor/AdamOptimizer.beta1 = 0.9
actor/AdamOptimizer.beta2 = 0.999
meta_critic/AdamOptimizer.learning_rate = 0.001
meta_critic/AdamOptimizer.beta1 = 0.9
meta_critic/AdamOptimizer.beta2 = 0.999
meta_actor/AdamOptimizer.learning_rate = 0.0001
meta_actor/AdamOptimizer.beta1 = 0.9
meta_actor/AdamOptimizer.beta2 = 0.999
repr/AdamOptimizer.learning_rate = 0.0001
repr/AdamOptimizer.beta1 = 0.9
repr/AdamOptimizer.beta2 = 0.999
#-*-Python-*-
create_maze_env.env_name = "AntBlock"
ZERO_OBS = False
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4), (20, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1", "eval2", "eval3"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
"eval2": [@eval2/ConstantSampler],
"eval3": [@eval3/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [3, 4]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [16, 0]
eval2/ConstantSampler.value = [16, 16]
eval3/ConstantSampler.value = [0, 16]
#-*-Python-*-
create_maze_env.env_name = "AntBlockMaze"
ZERO_OBS = False
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4), (12, 20))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1", "eval2", "eval3"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [2]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
"eval2": [@eval2/ConstantSampler],
"eval3": [@eval3/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [3, 4]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [8, 0]
eval2/ConstantSampler.value = [8, 16]
eval3/ConstantSampler.value = [0, 16]
#-*-Python-*-
create_maze_env.env_name = "AntFall"
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4, 0), (12, 28, 5))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [3]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1, 2]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [0, 27, 4.5]
#-*-Python-*-
create_maze_env.env_name = "AntFall"
IMAGES = True
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4, 0), (12, 28, 5))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [3]
meta/Context.samplers = {
"train": [@train/RandomSampler],
"explore": [@train/RandomSampler],
"eval1": [@eval1/ConstantSampler],
}
meta/Context.context_transition_fn = @task/relative_context_transition_fn
meta/Context.context_multi_transition_fn = @task/relative_context_multi_transition_fn
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1, 2]
task/negative_distance.relative_context = True
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
task/relative_context_transition_fn.k = 3
task/relative_context_multi_transition_fn.k = 3
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [0, 27, 0]
#-*-Python-*-
create_maze_env.env_name = "AntFall"
context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
meta_context_range = ((-4, -4, 0), (12, 28, 5))
RESET_EPISODE_PERIOD = 500
RESET_ENV_PERIOD = 1
# End episode every N steps
UvfAgent.reset_episode_cond_fn = @every_n_steps
every_n_steps.n = %RESET_EPISODE_PERIOD
train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
# Do a manual reset every N episodes
UvfAgent.reset_env_cond_fn = @every_n_episodes
every_n_episodes.n = %RESET_ENV_PERIOD
every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
## Config defaults
EVAL_MODES = ["eval1"]
## Config agent
CONTEXT = @agent/Context
META_CONTEXT = @meta/Context
## Config agent context
agent/Context.context_ranges = [%context_range]
agent/Context.context_shapes = [%SUBGOAL_DIM]
agent/Context.meta_action_every_n = 10
agent/Context.samplers = {
"train": [@train/DirectionSampler],
"explore": [@train/DirectionSampler],
}
agent/Context.context_transition_fn = @relative_context_transition_fn
agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
agent/Context.reward_fn = @uvf/negative_distance
## Config meta context
meta/Context.context_ranges = [%meta_context_range]
meta/Context.context_shapes = [3]
meta/Context.samplers = {
"train": [@eval1/ConstantSampler],
"explore": [@eval1/ConstantSampler],
"eval1": [@eval1/ConstantSampler],
}
meta/Context.reward_fn = @task/negative_distance
## Config rewards
task/negative_distance.state_indices = [0, 1, 2]
task/negative_distance.relative_context = False
task/negative_distance.diff = False
task/negative_distance.offset = 0.0
## Config samplers
train/RandomSampler.context_range = %meta_context_range
train/DirectionSampler.context_range = %context_range
train/DirectionSampler.k = %SUBGOAL_DIM
relative_context_transition_fn.k = %SUBGOAL_DIM
relative_context_multi_transition_fn.k = %SUBGOAL_DIM
MetaAgent.k = %SUBGOAL_DIM
eval1/ConstantSampler.value = [0, 27, 4.5]