Commit d3628a74 authored by Alex Lee
Browse files

Fixes for compatibility with TF 1.0 and python 3.

parent 405bb623
...@@ -38,17 +38,11 @@ def init_state(inputs, ...@@ -38,17 +38,11 @@ def init_state(inputs,
if inputs is not None: if inputs is not None:
# Handle both the dynamic shape as well as the inferred shape. # Handle both the dynamic shape as well as the inferred shape.
inferred_batch_size = inputs.get_shape().with_rank_at_least(1)[0] inferred_batch_size = inputs.get_shape().with_rank_at_least(1)[0]
batch_size = tf.shape(inputs)[0]
dtype = inputs.dtype dtype = inputs.dtype
else: else:
inferred_batch_size = 0 inferred_batch_size = 0
batch_size = 0
initial_state = state_initializer( initial_state = state_initializer(
tf.stack([batch_size] + state_shape), [inferred_batch_size] + state_shape, dtype=dtype)
dtype=dtype)
initial_state.set_shape([inferred_batch_size] + state_shape)
return initial_state return initial_state
......
...@@ -103,21 +103,24 @@ class Model(object): ...@@ -103,21 +103,24 @@ class Model(object):
actions=None, actions=None,
states=None, states=None,
sequence_length=None, sequence_length=None,
reuse_scope=None): reuse_scope=None,
prefix=None):
if sequence_length is None: if sequence_length is None:
sequence_length = FLAGS.sequence_length sequence_length = FLAGS.sequence_length
self.prefix = prefix = tf.placeholder(tf.string, []) if prefix is None:
prefix = tf.placeholder(tf.string, [])
self.prefix = prefix
self.iter_num = tf.placeholder(tf.float32, []) self.iter_num = tf.placeholder(tf.float32, [])
summaries = [] summaries = []
# Split into timesteps. # Split into timesteps.
actions = tf.split(axis=1, num_or_size_splits=actions.get_shape()[1], value=actions) actions = tf.split(axis=1, num_or_size_splits=int(actions.get_shape()[1]), value=actions)
actions = [tf.squeeze(act) for act in actions] actions = [tf.squeeze(act) for act in actions]
states = tf.split(axis=1, num_or_size_splits=states.get_shape()[1], value=states) states = tf.split(axis=1, num_or_size_splits=int(states.get_shape()[1]), value=states)
states = [tf.squeeze(st) for st in states] states = [tf.squeeze(st) for st in states]
images = tf.split(axis=1, num_or_size_splits=images.get_shape()[1], value=images) images = tf.split(axis=1, num_or_size_splits=int(images.get_shape()[1]), value=images)
images = [tf.squeeze(img) for img in images] images = [tf.squeeze(img) for img in images]
if reuse_scope is None: if reuse_scope is None:
...@@ -183,17 +186,18 @@ class Model(object): ...@@ -183,17 +186,18 @@ class Model(object):
def main(unused_argv): def main(unused_argv):
print 'Constructing models and inputs.' print('Constructing models and inputs.')
with tf.variable_scope('model', reuse=None) as training_scope: with tf.variable_scope('model', reuse=None) as training_scope:
images, actions, states = build_tfrecord_input(training=True) images, actions, states = build_tfrecord_input(training=True)
model = Model(images, actions, states, FLAGS.sequence_length) model = Model(images, actions, states, FLAGS.sequence_length,
prefix='train')
with tf.variable_scope('val_model', reuse=None): with tf.variable_scope('val_model', reuse=None):
val_images, val_actions, val_states = build_tfrecord_input(training=False) val_images, val_actions, val_states = build_tfrecord_input(training=False)
val_model = Model(val_images, val_actions, val_states, val_model = Model(val_images, val_actions, val_states,
FLAGS.sequence_length, training_scope) FLAGS.sequence_length, training_scope, prefix='val')
print 'Constructing saver.' print('Constructing saver.')
# Make saver. # Make saver.
saver = tf.train.Saver( saver = tf.train.Saver(
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0) tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)
...@@ -214,8 +218,7 @@ def main(unused_argv): ...@@ -214,8 +218,7 @@ def main(unused_argv):
# Run training. # Run training.
for itr in range(FLAGS.num_iterations): for itr in range(FLAGS.num_iterations):
# Generate new batch of data. # Generate new batch of data.
feed_dict = {model.prefix: 'train', feed_dict = {model.iter_num: np.float32(itr),
model.iter_num: np.float32(itr),
model.lr: FLAGS.learning_rate} model.lr: FLAGS.learning_rate}
cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op], cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op],
feed_dict) feed_dict)
...@@ -226,7 +229,6 @@ def main(unused_argv): ...@@ -226,7 +229,6 @@ def main(unused_argv):
if (itr) % VAL_INTERVAL == 2: if (itr) % VAL_INTERVAL == 2:
# Run through validation set. # Run through validation set.
feed_dict = {val_model.lr: 0.0, feed_dict = {val_model.lr: 0.0,
val_model.prefix: 'val',
val_model.iter_num: np.float32(itr)} val_model.iter_num: np.float32(itr)}
_, val_summary_str = sess.run([val_model.train_op, val_model.summ_op], _, val_summary_str = sess.run([val_model.train_op, val_model.summ_op],
feed_dict) feed_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment