Commit cf4cae61 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 273966871
parent 8110bb64
...@@ -88,7 +88,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase): ...@@ -88,7 +88,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
root_data_dir=None, root_data_dir=None,
default_flags=None, default_flags=None,
**kwargs): **kwargs):
root_data_dir = root_data_dir if root_data_dir else ''
default_flags = {} default_flags = {}
default_flags['dataset'] = 'ml-20m' default_flags['dataset'] = 'ml-20m'
default_flags['num_gpus'] = 1 default_flags['num_gpus'] = 1
......
...@@ -44,6 +44,8 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -44,6 +44,8 @@ class TransformerBenchmark(PerfZeroBenchmark):
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None, def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None): flag_methods=None):
assert tf.version.VERSION.startswith('2.') assert tf.version.VERSION.startswith('2.')
root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir, self.train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME) TRANSFORMER_EN2DE_DATA_DIR_NAME)
...@@ -650,7 +652,7 @@ class TransformerKerasBenchmark(TransformerBenchmark): ...@@ -650,7 +652,7 @@ class TransformerKerasBenchmark(TransformerBenchmark):
class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark): class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests.""" """Transformer based version real data benchmark tests."""
def __init__(self, output_dir=TMP_DIR, root_data_dir=None, **kwargs): def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR, **kwargs):
def_flags = {} def_flags = {}
def_flags['param_set'] = 'base' def_flags['param_set'] = 'base'
def_flags['train_steps'] = 50 def_flags['train_steps'] = 50
...@@ -664,7 +666,7 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -664,7 +666,7 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests.""" """Transformer based version real data benchmark tests."""
def __init__(self, output_dir=TMP_DIR, root_data_dir=None, **kwargs): def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR, **kwargs):
def_flags = {} def_flags = {}
def_flags['param_set'] = 'big' def_flags['param_set'] = 'big'
def_flags['train_steps'] = 50 def_flags['train_steps'] = 50
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment