Unverified Commit 35495758 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Change to gfile.Exists and deeper link to CIFAR-10 (#6066)

* Change to gfile.Exists and deeper link to CIFAR-10

* fix lint and use tf.gfile
parent 373776f2
......@@ -24,10 +24,10 @@ from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.resnet import resnet_model
from official.resnet import resnet_run_loop
from official.utils.flags import core as flags_core
from official.utils.logs import logger
HEIGHT = 32
WIDTH = 32
......@@ -52,7 +52,7 @@ DATASET_NAME = 'CIFAR-10'
###############################################################################
def get_filenames(is_training, data_dir):
"""Returns a list of filenames."""
assert os.path.exists(data_dir), (
assert tf.gfile.Exists(data_dir), (
'Run cifar10_download_and_extract.py first to download and extract the '
'CIFAR-10 data.')
......
......@@ -26,7 +26,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
DATA_DIR = '/data/cifar10_data/'
DATA_DIR = '/data/cifar10_data/cifar-10-batches-bin'
class EstimatorCifar10BenchmarkTests(object):
......
......@@ -85,6 +85,7 @@ class KerasCifar10BenchmarkTests(object):
"""Test keras based model with Keras fit and distribution strategies."""
self._setup()
flags.FLAGS.num_gpus = 2
flags.FLAGS.data_dir = DATA_DIR
flags.FLAGS.data_dir = self._get_model_dir('keras_resnet56_2_gpu')
flags.FLAGS.batch_size = 128
flags.FLAGS.train_epochs = 182
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment