"vscode:/vscode.git/clone" did not exist on "d219867c37a6f514fed83a8c8dbb5defa0999521"
Commit 74543c03 authored by David Chen's avatar David Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 264958330
parent 6252e588
...@@ -31,6 +31,7 @@ from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark ...@@ -31,6 +31,7 @@ from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official' TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014' EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
TMP_DIR = os.getenv('TMPDIR')
class TransformerBenchmark(PerfZeroBenchmark): class TransformerBenchmark(PerfZeroBenchmark):
...@@ -57,6 +58,11 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -57,6 +58,11 @@ class TransformerBenchmark(PerfZeroBenchmark):
EN2DE_2014_BLEU_DATA_DIR_NAME, EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.de') 'newstest2014.de')
default_flags['train_steps'] = 200
default_flags['log_steps'] = 10
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__( super(TransformerBenchmark, self).__init__(
output_dir=output_dir, output_dir=output_dir,
default_flags=default_flags, default_flags=default_flags,
...@@ -619,19 +625,9 @@ class TransformerKerasBenchmark(TransformerBenchmark): ...@@ -619,19 +625,9 @@ class TransformerKerasBenchmark(TransformerBenchmark):
class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark): class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests.""" """Transformer based version real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs): def __init__(self, output_dir=TMP_DIR, root_data_dir=None, **kwargs):
train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME)
vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768')
def_flags = {} def_flags = {}
def_flags['param_set'] = 'base' def_flags['param_set'] = 'base'
def_flags['vocab_file'] = vocab_file
def_flags['data_dir'] = train_data_dir
def_flags['train_steps'] = 200
def_flags['log_steps'] = 10
super(TransformerBaseKerasBenchmarkReal, self).__init__( super(TransformerBaseKerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags, output_dir=output_dir, default_flags=def_flags,
...@@ -641,19 +637,9 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -641,19 +637,9 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests.""" """Transformer based version real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs): def __init__(self, output_dir=TMP_DIR, root_data_dir=None, **kwargs):
train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME)
vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768')
def_flags = {} def_flags = {}
def_flags['param_set'] = 'big' def_flags['param_set'] = 'big'
def_flags['vocab_file'] = vocab_file
def_flags['data_dir'] = train_data_dir
def_flags['train_steps'] = 200
def_flags['log_steps'] = 10
super(TransformerBigKerasBenchmarkReal, self).__init__( super(TransformerBigKerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags, output_dir=output_dir, default_flags=def_flags,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment