Unverified Commit 4a733c29 authored by chengpohi's avatar chengpohi Committed by GitHub
Browse files

replace `FLAGS.batch_size` by `images.get_shape()[0]`

We can directly get `batch_size` by `images.get_shape()[0]` in `inference` method, since maybe we will not use `cifar.inputs` method to build the input.
parent 059c79ac
......@@ -240,7 +240,7 @@ def inference(images):
# local3
with tf.variable_scope('local3') as scope:
# Move everything into depth so we can perform a single matrix multiply.
reshape = tf.reshape(pool2, [FLAGS.batch_size, -1])
reshape = tf.reshape(pool2, [images.get_shape()[0], -1])
dim = reshape.get_shape()[1].value
weights = _variable_with_weight_decay('weights', shape=[dim, 384],
stddev=0.04, wd=0.004)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment