Commit 154e8c46 authored by David Chen's avatar David Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 267063328
parent b903049a
......@@ -21,6 +21,7 @@ import os
import time
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.staging.shakespeare import shakespeare_main
from official.utils.flags import core as flags_core
......@@ -28,6 +29,7 @@ from official.utils.misc import keras_utils
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt'
TMP_DIR = os.getenv('TMPDIR')
FLAGS = flags.FLAGS
......@@ -212,7 +214,7 @@ class ShakespeareAccuracy(ShakespeareBenchmarkBase):
class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
"""Benchmark accuracy tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def __init__(self, output_dir=None, root_data_dir=TMP_DIR, **kwargs):
"""Benchmark tests w/Keras.
Args:
......@@ -369,3 +371,7 @@ class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
"""Run and report benchmark."""
super(ShakespeareKerasBenchmarkReal, self)._run_and_report_benchmark(
top_1_train_min=None, log_steps=FLAGS.log_steps)
if __name__ == '__main__':
tf.test.main()
......@@ -94,7 +94,7 @@ def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH):
A tuple, consisting of the Dataset and the class to character mapping
and character to class mapping.
"""
with open(path_to_file, 'rb') as train_data:
with tf.io.gfile.GFile(path_to_file, 'rb') as train_data:
text = train_data.read().decode(encoding='utf-8')
# Create vocab
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment