"src/graph/vscode:/vscode.git/clone" did not exist on "79b057f0cf493995c133c1ff8fe2fa90f802ed42"
Unverified Commit 4a733c29 authored by chengpohi's avatar chengpohi Committed by GitHub
Browse files

replace `FLAGS.batch_size` by `images.get_shape()[0]`

We can directly get `batch_size` by `images.get_shape()[0]` in `inference` method, since maybe we will not use `cifar.inputs` method to build the input.
parent 059c79ac
...@@ -240,7 +240,7 @@ def inference(images): ...@@ -240,7 +240,7 @@ def inference(images):
# local3 # local3
with tf.variable_scope('local3') as scope: with tf.variable_scope('local3') as scope:
# Move everything into depth so we can perform a single matrix multiply. # Move everything into depth so we can perform a single matrix multiply.
reshape = tf.reshape(pool2, [FLAGS.batch_size, -1]) reshape = tf.reshape(pool2, [images.get_shape()[0], -1])
dim = reshape.get_shape()[1].value dim = reshape.get_shape()[1].value
weights = _variable_with_weight_decay('weights', shape=[dim, 384], weights = _variable_with_weight_decay('weights', shape=[dim, 384],
stddev=0.04, wd=0.004) stddev=0.04, wd=0.004)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment