Commit cf4cae61 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 273966871
parent 8110bb64
......@@ -88,7 +88,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
root_data_dir=None,
default_flags=None,
**kwargs):
root_data_dir = root_data_dir if root_data_dir else ''
default_flags = {}
default_flags['dataset'] = 'ml-20m'
default_flags['num_gpus'] = 1
......
......@@ -44,6 +44,8 @@ class TransformerBenchmark(PerfZeroBenchmark):
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None):
assert tf.version.VERSION.startswith('2.')
root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME)
......@@ -650,7 +652,7 @@ class TransformerKerasBenchmark(TransformerBenchmark):
class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests."""
def __init__(self, output_dir=TMP_DIR, root_data_dir=None, **kwargs):
def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR, **kwargs):
def_flags = {}
def_flags['param_set'] = 'base'
def_flags['train_steps'] = 50
......@@ -664,7 +666,7 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests."""
def __init__(self, output_dir=TMP_DIR, root_data_dir=None, **kwargs):
def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR, **kwargs):
def_flags = {}
def_flags['param_set'] = 'big'
def_flags['train_steps'] = 50
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment