Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d3628a74
Commit
d3628a74
authored
Apr 10, 2017
by
Alex Lee
Browse files
Fixes for compatibility with TF 1.0 and python 3.
parent
405bb623
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
19 deletions
+15
-19
video_prediction/lstm_ops.py
video_prediction/lstm_ops.py
+1
-7
video_prediction/prediction_train.py
video_prediction/prediction_train.py
+14
-12
No files found.
video_prediction/lstm_ops.py
View file @
d3628a74
...
...
@@ -38,17 +38,11 @@ def init_state(inputs,
if
inputs
is
not
None
:
# Handle both the dynamic shape as well as the inferred shape.
inferred_batch_size
=
inputs
.
get_shape
().
with_rank_at_least
(
1
)[
0
]
batch_size
=
tf
.
shape
(
inputs
)[
0
]
dtype
=
inputs
.
dtype
else
:
inferred_batch_size
=
0
batch_size
=
0
initial_state
=
state_initializer
(
tf
.
stack
([
batch_size
]
+
state_shape
),
dtype
=
dtype
)
initial_state
.
set_shape
([
inferred_batch_size
]
+
state_shape
)
[
inferred_batch_size
]
+
state_shape
,
dtype
=
dtype
)
return
initial_state
...
...
video_prediction/prediction_train.py
View file @
d3628a74
...
...
@@ -103,21 +103,24 @@ class Model(object):
actions
=
None
,
states
=
None
,
sequence_length
=
None
,
reuse_scope
=
None
):
reuse_scope
=
None
,
prefix
=
None
):
if
sequence_length
is
None
:
sequence_length
=
FLAGS
.
sequence_length
self
.
prefix
=
prefix
=
tf
.
placeholder
(
tf
.
string
,
[])
if
prefix
is
None
:
prefix
=
tf
.
placeholder
(
tf
.
string
,
[])
self
.
prefix
=
prefix
self
.
iter_num
=
tf
.
placeholder
(
tf
.
float32
,
[])
summaries
=
[]
# Split into timesteps.
actions
=
tf
.
split
(
axis
=
1
,
num_or_size_splits
=
actions
.
get_shape
()[
1
],
value
=
actions
)
actions
=
tf
.
split
(
axis
=
1
,
num_or_size_splits
=
int
(
actions
.
get_shape
()[
1
]
)
,
value
=
actions
)
actions
=
[
tf
.
squeeze
(
act
)
for
act
in
actions
]
states
=
tf
.
split
(
axis
=
1
,
num_or_size_splits
=
states
.
get_shape
()[
1
],
value
=
states
)
states
=
tf
.
split
(
axis
=
1
,
num_or_size_splits
=
int
(
states
.
get_shape
()[
1
]
)
,
value
=
states
)
states
=
[
tf
.
squeeze
(
st
)
for
st
in
states
]
images
=
tf
.
split
(
axis
=
1
,
num_or_size_splits
=
images
.
get_shape
()[
1
],
value
=
images
)
images
=
tf
.
split
(
axis
=
1
,
num_or_size_splits
=
int
(
images
.
get_shape
()[
1
]
)
,
value
=
images
)
images
=
[
tf
.
squeeze
(
img
)
for
img
in
images
]
if
reuse_scope
is
None
:
...
...
@@ -183,17 +186,18 @@ class Model(object):
def
main
(
unused_argv
):
print
'Constructing models and inputs.'
print
(
'Constructing models and inputs.'
)
with
tf
.
variable_scope
(
'model'
,
reuse
=
None
)
as
training_scope
:
images
,
actions
,
states
=
build_tfrecord_input
(
training
=
True
)
model
=
Model
(
images
,
actions
,
states
,
FLAGS
.
sequence_length
)
model
=
Model
(
images
,
actions
,
states
,
FLAGS
.
sequence_length
,
prefix
=
'train'
)
with
tf
.
variable_scope
(
'val_model'
,
reuse
=
None
):
val_images
,
val_actions
,
val_states
=
build_tfrecord_input
(
training
=
False
)
val_model
=
Model
(
val_images
,
val_actions
,
val_states
,
FLAGS
.
sequence_length
,
training_scope
)
FLAGS
.
sequence_length
,
training_scope
,
prefix
=
'val'
)
print
'Constructing saver.'
print
(
'Constructing saver.'
)
# Make saver.
saver
=
tf
.
train
.
Saver
(
tf
.
get_collection
(
tf
.
GraphKeys
.
GLOBAL_VARIABLES
),
max_to_keep
=
0
)
...
...
@@ -214,8 +218,7 @@ def main(unused_argv):
# Run training.
for
itr
in
range
(
FLAGS
.
num_iterations
):
# Generate new batch of data.
feed_dict
=
{
model
.
prefix
:
'train'
,
model
.
iter_num
:
np
.
float32
(
itr
),
feed_dict
=
{
model
.
iter_num
:
np
.
float32
(
itr
),
model
.
lr
:
FLAGS
.
learning_rate
}
cost
,
_
,
summary_str
=
sess
.
run
([
model
.
loss
,
model
.
train_op
,
model
.
summ_op
],
feed_dict
)
...
...
@@ -226,7 +229,6 @@ def main(unused_argv):
if
(
itr
)
%
VAL_INTERVAL
==
2
:
# Run through validation set.
feed_dict
=
{
val_model
.
lr
:
0.0
,
val_model
.
prefix
:
'val'
,
val_model
.
iter_num
:
np
.
float32
(
itr
)}
_
,
val_summary_str
=
sess
.
run
([
val_model
.
train_op
,
val_model
.
summ_op
],
feed_dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment