Commit 5c0c749b authored by Toby Boyd's avatar Toby Boyd
Browse files

Add tf.float32 to unittest args

parent 76dbcb5a
......@@ -61,7 +61,7 @@ class BaseTest(tf.test.TestCase):
fake_dataset = tf.data.FixedLengthRecordDataset(
filename, cifar10_main._RECORD_BYTES) # pylint: disable=protected-access
fake_dataset = fake_dataset.map(
lambda val: cifar10_main.parse_record(val, False))
lambda val: cifar10_main.parse_record(val, False, tf.float32))
image, label = fake_dataset.make_one_shot_iterator().get_next()
self.assertAllEqual(label.shape, ())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment