Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c173bd9b
Commit
c173bd9b
authored
Mar 07, 2018
by
MTDzi
Browse files
Changes to "Learning to Remember..." to make it runnable in Python3.5
parent
1453d070
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
8 deletions
+8
-8
research/learning_to_remember_rare_events/data_utils.py
research/learning_to_remember_rare_events/data_utils.py
+5
-5
research/learning_to_remember_rare_events/train.py
research/learning_to_remember_rare_events/train.py
+3
-3
No files found.
research/learning_to_remember_rare_events/data_utils.py
View file @
c173bd9b
...
@@ -20,10 +20,10 @@ Simply call
...
@@ -20,10 +20,10 @@ Simply call
python data_utils.py
python data_utils.py
"""
"""
import
cPickle
as
pickle
import
logging
import
logging
import
os
import
os
import
subprocess
import
subprocess
from
six.moves
import
cPickle
as
pickle
import
numpy
as
np
import
numpy
as
np
from
scipy.misc
import
imresize
from
scipy.misc
import
imresize
...
@@ -54,9 +54,9 @@ def get_data():
...
@@ -54,9 +54,9 @@ def get_data():
Train and test data as dictionaries mapping
Train and test data as dictionaries mapping
label to list of examples.
label to list of examples.
"""
"""
with
tf
.
gfile
.
GFile
(
DATA_FILE_FORMAT
%
'train'
)
as
f
:
with
tf
.
gfile
.
GFile
(
DATA_FILE_FORMAT
%
'train'
,
'rb'
)
as
f
:
processed_train_data
=
pickle
.
load
(
f
)
processed_train_data
=
pickle
.
load
(
f
)
with
tf
.
gfile
.
GFile
(
DATA_FILE_FORMAT
%
'test'
)
as
f
:
with
tf
.
gfile
.
GFile
(
DATA_FILE_FORMAT
%
'test'
,
'rb'
)
as
f
:
processed_test_data
=
pickle
.
load
(
f
)
processed_test_data
=
pickle
.
load
(
f
)
train_data
=
{}
train_data
=
{}
...
@@ -72,9 +72,9 @@ def get_data():
...
@@ -72,9 +72,9 @@ def get_data():
intersection
=
set
(
train_data
.
keys
())
&
set
(
test_data
.
keys
())
intersection
=
set
(
train_data
.
keys
())
&
set
(
test_data
.
keys
())
assert
not
intersection
,
'Train and test data intersect.'
assert
not
intersection
,
'Train and test data intersect.'
ok_num_examples
=
[
len
(
ll
)
==
20
for
_
,
ll
in
train_data
.
iter
items
()]
ok_num_examples
=
[
len
(
ll
)
==
20
for
_
,
ll
in
train_data
.
items
()]
assert
all
(
ok_num_examples
),
'Bad number of examples in train data.'
assert
all
(
ok_num_examples
),
'Bad number of examples in train data.'
ok_num_examples
=
[
len
(
ll
)
==
20
for
_
,
ll
in
test_data
.
iter
items
()]
ok_num_examples
=
[
len
(
ll
)
==
20
for
_
,
ll
in
test_data
.
items
()]
assert
all
(
ok_num_examples
),
'Bad number of examples in test data.'
assert
all
(
ok_num_examples
),
'Bad number of examples in test data.'
logging
.
info
(
'Number of labels in train data: %d.'
,
len
(
train_data
))
logging
.
info
(
'Number of labels in train data: %d.'
,
len
(
train_data
))
...
...
research/learning_to_remember_rare_events/train.py
View file @
c173bd9b
...
@@ -112,7 +112,7 @@ class Trainer(object):
...
@@ -112,7 +112,7 @@ class Trainer(object):
remainders
=
[
0
]
*
(
episode_width
-
remainder
)
+
[
1
]
*
remainder
remainders
=
[
0
]
*
(
episode_width
-
remainder
)
+
[
1
]
*
remainder
episode_x
=
[
episode_x
=
[
random
.
sample
(
data
[
lab
],
random
.
sample
(
data
[
lab
],
r
+
(
episode_length
-
remainder
)
/
episode_width
)
r
+
(
episode_length
-
remainder
)
/
/
episode_width
)
for
lab
,
r
in
zip
(
episode_labels
,
remainders
)]
for
lab
,
r
in
zip
(
episode_labels
,
remainders
)]
episode
=
sum
([[(
x
,
i
,
ii
)
for
ii
,
x
in
enumerate
(
xx
)]
episode
=
sum
([[(
x
,
i
,
ii
)
for
ii
,
x
in
enumerate
(
xx
)]
for
i
,
xx
in
enumerate
(
episode_x
)],
[])
for
i
,
xx
in
enumerate
(
episode_x
)],
[])
...
@@ -160,9 +160,9 @@ class Trainer(object):
...
@@ -160,9 +160,9 @@ class Trainer(object):
logging
.
info
(
'batch_size %d'
,
batch_size
)
logging
.
info
(
'batch_size %d'
,
batch_size
)
assert
all
(
len
(
v
)
>=
float
(
episode_length
)
/
episode_width
assert
all
(
len
(
v
)
>=
float
(
episode_length
)
/
episode_width
for
v
in
train_data
.
iter
values
())
for
v
in
train_data
.
values
())
assert
all
(
len
(
v
)
>=
float
(
episode_length
)
/
episode_width
assert
all
(
len
(
v
)
>=
float
(
episode_length
)
/
episode_width
for
v
in
valid_data
.
iter
values
())
for
v
in
valid_data
.
values
())
output_dim
=
episode_width
output_dim
=
episode_width
self
.
model
=
self
.
get_model
()
self
.
model
=
self
.
get_model
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment