Commit 154e8c46 authored by David Chen's avatar David Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 267063328
parent b903049a
...@@ -21,6 +21,7 @@ import os ...@@ -21,6 +21,7 @@ import os
import time import time
from absl import flags from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.staging.shakespeare import shakespeare_main from official.staging.shakespeare import shakespeare_main
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
...@@ -28,6 +29,7 @@ from official.utils.misc import keras_utils ...@@ -28,6 +29,7 @@ from official.utils.misc import keras_utils
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt' SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt'
TMP_DIR = os.getenv('TMPDIR')
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -212,7 +214,7 @@ class ShakespeareAccuracy(ShakespeareBenchmarkBase): ...@@ -212,7 +214,7 @@ class ShakespeareAccuracy(ShakespeareBenchmarkBase):
class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase): class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
"""Benchmark accuracy tests.""" """Benchmark accuracy tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=TMP_DIR, **kwargs):
"""Benchmark tests w/Keras. """Benchmark tests w/Keras.
Args: Args:
...@@ -369,3 +371,7 @@ class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase): ...@@ -369,3 +371,7 @@ class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
"""Run and report benchmark.""" """Run and report benchmark."""
super(ShakespeareKerasBenchmarkReal, self)._run_and_report_benchmark( super(ShakespeareKerasBenchmarkReal, self)._run_and_report_benchmark(
top_1_train_min=None, log_steps=FLAGS.log_steps) top_1_train_min=None, log_steps=FLAGS.log_steps)
if __name__ == '__main__':
tf.test.main()
...@@ -94,7 +94,7 @@ def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH): ...@@ -94,7 +94,7 @@ def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH):
A tuple, consisting of the Dataset and the class to character mapping A tuple, consisting of the Dataset and the class to character mapping
and character to class mapping. and character to class mapping.
""" """
with open(path_to_file, 'rb') as train_data: with tf.io.gfile.GFile(path_to_file, 'rb') as train_data:
text = train_data.read().decode(encoding='utf-8') text = train_data.read().decode(encoding='utf-8')
# Create vocab # Create vocab
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment