Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
aae0a947
Commit
aae0a947
authored
Aug 21, 2017
by
Lukasz Kaiser
Committed by
GitHub
Aug 21, 2017
Browse files
Merge pull request #2116 from cclauss/patch-3
rebar: center() is defined in utils
parents
55b440f3
1d5dba69
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
4 deletions
+12
-4
rebar/rebar.py
rebar/rebar.py
+6
-1
rebar/rebar_train.py
rebar/rebar_train.py
+6
-0
rebar/utils.py
rebar/utils.py
+0
-3
No files found.
rebar/rebar.py
View file @
aae0a947
...
@@ -26,6 +26,11 @@ import tensorflow.contrib.slim as slim
...
@@ -26,6 +26,11 @@ import tensorflow.contrib.slim as slim
from
tensorflow.python.ops
import
init_ops
from
tensorflow.python.ops
import
init_ops
import
utils
as
U
import
utils
as
U
try
:
xrange
# Python 2
except
NameError
:
xrange
=
range
# Python 3
FLAGS
=
tf
.
flags
.
FLAGS
FLAGS
=
tf
.
flags
.
FLAGS
Q_COLLECTION
=
"q_collection"
Q_COLLECTION
=
"q_collection"
...
@@ -293,7 +298,7 @@ class SBN(object): # REINFORCE
...
@@ -293,7 +298,7 @@ class SBN(object): # REINFORCE
logQHard
=
tf
.
add_n
(
logQHard
)
logQHard
=
tf
.
add_n
(
logQHard
)
# REINFORCE
# REINFORCE
learning_signal
=
tf
.
stop_gradient
(
center
(
reinforce_learning_signal
))
learning_signal
=
tf
.
stop_gradient
(
U
.
center
(
reinforce_learning_signal
))
self
.
optimizerLoss
=
-
(
learning_signal
*
logQHard
+
self
.
optimizerLoss
=
-
(
learning_signal
*
logQHard
+
reinforce_model_grad
)
reinforce_model_grad
)
self
.
lHat
=
map
(
tf
.
reduce_mean
,
[
self
.
lHat
=
map
(
tf
.
reduce_mean
,
[
...
...
rebar/rebar_train.py
View file @
aae0a947
...
@@ -28,6 +28,12 @@ import tensorflow as tf
...
@@ -28,6 +28,12 @@ import tensorflow as tf
import
rebar
import
rebar
import
datasets
import
datasets
import
logger
as
L
import
logger
as
L
try
:
xrange
# Python 2
except
NameError
:
xrange
=
range
# Python 3
gfile
=
tf
.
gfile
gfile
=
tf
.
gfile
tf
.
app
.
flags
.
DEFINE_string
(
"working_dir"
,
"/tmp/rebar"
,
tf
.
app
.
flags
.
DEFINE_string
(
"working_dir"
,
"/tmp/rebar"
,
...
...
rebar/utils.py
View file @
aae0a947
...
@@ -134,6 +134,3 @@ def logSumExp(t, axis=0, keep_dims = False):
...
@@ -134,6 +134,3 @@ def logSumExp(t, axis=0, keep_dims = False):
return
tf
.
expand_dims
(
res
,
axis
)
return
tf
.
expand_dims
(
res
,
axis
)
else
:
else
:
return
res
return
res
if
__name__
==
'__main__'
:
app
.
run
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment