Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1453d070
Commit
1453d070
authored
Mar 06, 2018
by
MTDzi
Browse files
Fix to Memory.query + other small changes in Learning to Remember Rare Events
parent
e029542a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
44 additions
and
46 deletions
+44
-46
research/learning_to_remember_rare_events/memory.py
research/learning_to_remember_rare_events/memory.py
+15
-10
research/learning_to_remember_rare_events/model.py
research/learning_to_remember_rare_events/model.py
+22
-28
research/learning_to_remember_rare_events/train.py
research/learning_to_remember_rare_events/train.py
+7
-8
No files found.
research/learning_to_remember_rare_events/memory.py
View file @
1453d070
...
@@ -173,6 +173,21 @@ class Memory(object):
...
@@ -173,6 +173,21 @@ class Memory(object):
softmax_temp
=
max
(
1.0
,
np
.
log
(
0.2
*
self
.
choose_k
)
/
self
.
alpha
)
softmax_temp
=
max
(
1.0
,
np
.
log
(
0.2
*
self
.
choose_k
)
/
self
.
alpha
)
mask
=
tf
.
nn
.
softmax
(
hint_pool_sims
[:,
:
choose_k
-
1
]
*
softmax_temp
)
mask
=
tf
.
nn
.
softmax
(
hint_pool_sims
[:,
:
choose_k
-
1
]
*
softmax_temp
)
# prepare returned values
nearest_neighbor
=
tf
.
to_int32
(
tf
.
argmax
(
hint_pool_sims
[:,
:
choose_k
-
1
],
1
))
no_teacher_idxs
=
tf
.
gather
(
tf
.
reshape
(
hint_pool_idxs
,
[
-
1
]),
nearest_neighbor
+
choose_k
*
tf
.
range
(
batch_size
))
with
tf
.
device
(
self
.
var_cache_device
):
result
=
tf
.
gather
(
self
.
mem_vals
,
tf
.
reshape
(
no_teacher_idxs
,
[
-
1
]))
if
not
output_given
:
teacher_loss
=
None
return
result
,
mask
,
teacher_loss
# prepare hints from the teacher on hint pool
# prepare hints from the teacher on hint pool
teacher_hints
=
tf
.
to_float
(
teacher_hints
=
tf
.
to_float
(
tf
.
abs
(
tf
.
expand_dims
(
intended_output
,
1
)
-
hint_pool_mem_vals
))
tf
.
abs
(
tf
.
expand_dims
(
intended_output
,
1
)
-
hint_pool_mem_vals
))
...
@@ -192,13 +207,6 @@ class Memory(object):
...
@@ -192,13 +207,6 @@ class Memory(object):
teacher_vals
*=
(
teacher_vals
*=
(
1
-
tf
.
to_float
(
tf
.
equal
(
0.0
,
tf
.
reduce_sum
(
teacher_hints
,
1
))))
1
-
tf
.
to_float
(
tf
.
equal
(
0.0
,
tf
.
reduce_sum
(
teacher_hints
,
1
))))
# prepare returned values
nearest_neighbor
=
tf
.
to_int32
(
tf
.
argmax
(
hint_pool_sims
[:,
:
choose_k
-
1
],
1
))
no_teacher_idxs
=
tf
.
gather
(
tf
.
reshape
(
hint_pool_idxs
,
[
-
1
]),
nearest_neighbor
+
choose_k
*
tf
.
range
(
batch_size
))
# we'll determine whether to do an update to memory based on whether
# we'll determine whether to do an update to memory based on whether
# memory was queried correctly
# memory was queried correctly
sliced_hints
=
tf
.
slice
(
teacher_hints
,
[
0
,
0
],
[
-
1
,
self
.
correct_in_top
])
sliced_hints
=
tf
.
slice
(
teacher_hints
,
[
0
,
0
],
[
-
1
,
self
.
correct_in_top
])
...
@@ -208,9 +216,6 @@ class Memory(object):
...
@@ -208,9 +216,6 @@ class Memory(object):
teacher_loss
=
(
tf
.
nn
.
relu
(
neg_teacher_vals
-
teacher_vals
+
self
.
alpha
)
teacher_loss
=
(
tf
.
nn
.
relu
(
neg_teacher_vals
-
teacher_vals
+
self
.
alpha
)
-
self
.
alpha
)
-
self
.
alpha
)
with
tf
.
device
(
self
.
var_cache_device
):
result
=
tf
.
gather
(
self
.
mem_vals
,
tf
.
reshape
(
no_teacher_idxs
,
[
-
1
]))
# prepare memory updates
# prepare memory updates
update_keys
=
normalized_query
update_keys
=
normalized_query
update_vals
=
intended_output
update_vals
=
intended_output
...
...
research/learning_to_remember_rare_events/model.py
View file @
1453d070
...
@@ -178,27 +178,13 @@ class Model(object):
...
@@ -178,27 +178,13 @@ class Model(object):
self
.
x
,
self
.
y
=
self
.
get_xy_placeholders
()
self
.
x
,
self
.
y
=
self
.
get_xy_placeholders
()
# This context creates variables
with
tf
.
variable_scope
(
'core'
,
reuse
=
None
):
with
tf
.
variable_scope
(
'core'
,
reuse
=
None
):
self
.
loss
,
self
.
gradient_ops
=
self
.
train
(
self
.
x
,
self
.
y
)
self
.
loss
,
self
.
gradient_ops
=
self
.
train
(
self
.
x
,
self
.
y
)
# And this one re-uses them (thus the `reuse=True`)
with
tf
.
variable_scope
(
'core'
,
reuse
=
True
):
with
tf
.
variable_scope
(
'core'
,
reuse
=
True
):
self
.
y_preds
=
self
.
eval
(
self
.
x
,
self
.
y
)
self
.
y_preds
=
self
.
eval
(
self
.
x
,
self
.
y
)
# setup memory "reset" ops
(
self
.
mem_keys
,
self
.
mem_vals
,
self
.
mem_age
,
self
.
recent_idx
)
=
self
.
memory
.
get
()
self
.
mem_keys_reset
=
tf
.
placeholder
(
self
.
mem_keys
.
dtype
,
tf
.
identity
(
self
.
mem_keys
).
shape
)
self
.
mem_vals_reset
=
tf
.
placeholder
(
self
.
mem_vals
.
dtype
,
tf
.
identity
(
self
.
mem_vals
).
shape
)
self
.
mem_age_reset
=
tf
.
placeholder
(
self
.
mem_age
.
dtype
,
tf
.
identity
(
self
.
mem_age
).
shape
)
self
.
recent_idx_reset
=
tf
.
placeholder
(
self
.
recent_idx
.
dtype
,
tf
.
identity
(
self
.
recent_idx
).
shape
)
self
.
mem_reset_op
=
self
.
memory
.
set
(
self
.
mem_keys_reset
,
self
.
mem_vals_reset
,
self
.
mem_age_reset
,
None
)
def
training_ops
(
self
,
loss
):
def
training_ops
(
self
,
loss
):
opt
=
self
.
get_optimizer
()
opt
=
self
.
get_optimizer
()
params
=
tf
.
trainable_variables
()
params
=
tf
.
trainable_variables
()
...
@@ -254,8 +240,14 @@ class Model(object):
...
@@ -254,8 +240,14 @@ class Model(object):
Predicted y.
Predicted y.
"""
"""
cur_memory
=
sess
.
run
([
self
.
mem_keys
,
self
.
mem_vals
,
# Storing current memory state to restore it after prediction
self
.
mem_age
])
mem_keys
,
mem_vals
,
mem_age
,
_
=
self
.
memory
.
get
()
cur_memory
=
(
tf
.
identity
(
mem_keys
),
tf
.
identity
(
mem_vals
),
tf
.
identity
(
mem_age
),
None
,
)
outputs
=
[
self
.
y_preds
]
outputs
=
[
self
.
y_preds
]
if
y
is
None
:
if
y
is
None
:
...
@@ -263,10 +255,8 @@ class Model(object):
...
@@ -263,10 +255,8 @@ class Model(object):
else
:
else
:
ret
=
sess
.
run
(
outputs
,
feed_dict
=
{
self
.
x
:
x
,
self
.
y
:
y
})
ret
=
sess
.
run
(
outputs
,
feed_dict
=
{
self
.
x
:
x
,
self
.
y
:
y
})
sess
.
run
([
self
.
mem_reset_op
],
# Restoring memory state
feed_dict
=
{
self
.
mem_keys_reset
:
cur_memory
[
0
],
self
.
memory
.
set
(
*
cur_memory
)
self
.
mem_vals_reset
:
cur_memory
[
1
],
self
.
mem_age_reset
:
cur_memory
[
2
]})
return
ret
return
ret
...
@@ -284,8 +274,14 @@ class Model(object):
...
@@ -284,8 +274,14 @@ class Model(object):
List of predicted y.
List of predicted y.
"""
"""
cur_memory
=
sess
.
run
([
self
.
mem_keys
,
self
.
mem_vals
,
# Storing current memory state to restore it after prediction
self
.
mem_age
])
mem_keys
,
mem_vals
,
mem_age
,
_
=
self
.
memory
.
get
()
cur_memory
=
(
tf
.
identity
(
mem_keys
),
tf
.
identity
(
mem_vals
),
tf
.
identity
(
mem_age
),
None
,
)
if
clear_memory
:
if
clear_memory
:
self
.
clear_memory
(
sess
)
self
.
clear_memory
(
sess
)
...
@@ -297,10 +293,8 @@ class Model(object):
...
@@ -297,10 +293,8 @@ class Model(object):
y_pred
=
out
[
0
]
y_pred
=
out
[
0
]
y_preds
.
append
(
y_pred
)
y_preds
.
append
(
y_pred
)
sess
.
run
([
self
.
mem_reset_op
],
# Restoring memory state
feed_dict
=
{
self
.
mem_keys_reset
:
cur_memory
[
0
],
self
.
memory
.
set
(
*
cur_memory
)
self
.
mem_vals_reset
:
cur_memory
[
1
],
self
.
mem_age_reset
:
cur_memory
[
2
]})
return
y_preds
return
y_preds
...
...
research/learning_to_remember_rare_events/train.py
View file @
1453d070
...
@@ -208,17 +208,16 @@ class Trainer(object):
...
@@ -208,17 +208,16 @@ class Trainer(object):
correct
.
append
(
self
.
compute_correct
(
np
.
array
(
y
),
y_preds
))
correct
.
append
(
self
.
compute_correct
(
np
.
array
(
y
),
y_preds
))
# compute per-shot accuracies
# compute per-shot accuracies
seen_counts
=
[
[
0
]
*
episode_width
for
_
in
xrange
(
batch_size
)]
seen_counts
=
[
0
]
*
episode_width
# loop over episode steps
# loop over episode steps
for
yy
,
yy_preds
in
zip
(
y
,
y_preds
):
for
yy
,
yy_preds
in
zip
(
y
,
y_preds
):
# loop over batch examples
# loop over batch examples
for
k
,
(
yyy
,
yyy_preds
)
in
enumerate
(
zip
(
yy
,
yy_preds
)):
yyy
,
yyy_preds
=
int
(
yy
[
0
]),
int
(
yy_preds
[
0
])
yyy
,
yyy_preds
=
int
(
yyy
),
int
(
yyy_preds
)
count
=
seen_counts
[
yyy
%
episode_width
]
count
=
seen_counts
[
k
][
yyy
%
episode_width
]
if
count
in
correct_by_shot
:
if
count
in
correct_by_shot
:
correct_by_shot
[
count
].
append
(
correct_by_shot
[
count
].
append
(
self
.
individual_compute_correct
(
yyy
,
yyy_preds
))
self
.
individual_compute_correct
(
yyy
,
yyy_preds
))
seen_counts
[
yyy
%
episode_width
]
=
count
+
1
seen_counts
[
k
][
yyy
%
episode_width
]
=
count
+
1
logging
.
info
(
'validation overall accuracy %f'
,
np
.
mean
(
correct
))
logging
.
info
(
'validation overall accuracy %f'
,
np
.
mean
(
correct
))
logging
.
info
(
'%d-shot: %.3f, '
*
num_shots
,
logging
.
info
(
'%d-shot: %.3f, '
*
num_shots
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment