Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
74543c03
Commit
74543c03
authored
Aug 22, 2019
by
David Chen
Committed by
A. Unique TensorFlower
Aug 22, 2019
Browse files
Internal change
PiperOrigin-RevId: 264958330
parent
6252e588
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
22 deletions
+8
-22
official/transformer/v2/transformer_benchmark.py
official/transformer/v2/transformer_benchmark.py
+8
-22
No files found.
official/transformer/v2/transformer_benchmark.py
View file @
74543c03
...
@@ -31,6 +31,7 @@ from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
...
@@ -31,6 +31,7 @@ from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
TRANSFORMER_EN2DE_DATA_DIR_NAME
=
'wmt32k-en2de-official'
TRANSFORMER_EN2DE_DATA_DIR_NAME
=
'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME
=
'newstest2014'
EN2DE_2014_BLEU_DATA_DIR_NAME
=
'newstest2014'
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
TMP_DIR
=
os
.
getenv
(
'TMPDIR'
)
class
TransformerBenchmark
(
PerfZeroBenchmark
):
class
TransformerBenchmark
(
PerfZeroBenchmark
):
...
@@ -57,6 +58,11 @@ class TransformerBenchmark(PerfZeroBenchmark):
...
@@ -57,6 +58,11 @@ class TransformerBenchmark(PerfZeroBenchmark):
EN2DE_2014_BLEU_DATA_DIR_NAME
,
EN2DE_2014_BLEU_DATA_DIR_NAME
,
'newstest2014.de'
)
'newstest2014.de'
)
default_flags
[
'train_steps'
]
=
200
default_flags
[
'log_steps'
]
=
10
default_flags
[
'data_dir'
]
=
self
.
train_data_dir
default_flags
[
'vocab_file'
]
=
self
.
vocab_file
super
(
TransformerBenchmark
,
self
).
__init__
(
super
(
TransformerBenchmark
,
self
).
__init__
(
output_dir
=
output_dir
,
output_dir
=
output_dir
,
default_flags
=
default_flags
,
default_flags
=
default_flags
,
...
@@ -619,19 +625,9 @@ class TransformerKerasBenchmark(TransformerBenchmark):
...
@@ -619,19 +625,9 @@ class TransformerKerasBenchmark(TransformerBenchmark):
class
TransformerBaseKerasBenchmarkReal
(
TransformerKerasBenchmark
):
class
TransformerBaseKerasBenchmarkReal
(
TransformerKerasBenchmark
):
"""Transformer based version real data benchmark tests."""
"""Transformer based version real data benchmark tests."""
def
__init__
(
self
,
output_dir
=
None
,
root_data_dir
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
root_data_dir
=
None
,
**
kwargs
):
train_data_dir
=
os
.
path
.
join
(
root_data_dir
,
TRANSFORMER_EN2DE_DATA_DIR_NAME
)
vocab_file
=
os
.
path
.
join
(
root_data_dir
,
TRANSFORMER_EN2DE_DATA_DIR_NAME
,
'vocab.ende.32768'
)
def_flags
=
{}
def_flags
=
{}
def_flags
[
'param_set'
]
=
'base'
def_flags
[
'param_set'
]
=
'base'
def_flags
[
'vocab_file'
]
=
vocab_file
def_flags
[
'data_dir'
]
=
train_data_dir
def_flags
[
'train_steps'
]
=
200
def_flags
[
'log_steps'
]
=
10
super
(
TransformerBaseKerasBenchmarkReal
,
self
).
__init__
(
super
(
TransformerBaseKerasBenchmarkReal
,
self
).
__init__
(
output_dir
=
output_dir
,
default_flags
=
def_flags
,
output_dir
=
output_dir
,
default_flags
=
def_flags
,
...
@@ -641,19 +637,9 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
...
@@ -641,19 +637,9 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
class
TransformerBigKerasBenchmarkReal
(
TransformerKerasBenchmark
):
class
TransformerBigKerasBenchmarkReal
(
TransformerKerasBenchmark
):
"""Transformer based version real data benchmark tests."""
"""Transformer based version real data benchmark tests."""
def
__init__
(
self
,
output_dir
=
None
,
root_data_dir
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
root_data_dir
=
None
,
**
kwargs
):
train_data_dir
=
os
.
path
.
join
(
root_data_dir
,
TRANSFORMER_EN2DE_DATA_DIR_NAME
)
vocab_file
=
os
.
path
.
join
(
root_data_dir
,
TRANSFORMER_EN2DE_DATA_DIR_NAME
,
'vocab.ende.32768'
)
def_flags
=
{}
def_flags
=
{}
def_flags
[
'param_set'
]
=
'big'
def_flags
[
'param_set'
]
=
'big'
def_flags
[
'vocab_file'
]
=
vocab_file
def_flags
[
'data_dir'
]
=
train_data_dir
def_flags
[
'train_steps'
]
=
200
def_flags
[
'log_steps'
]
=
10
super
(
TransformerBigKerasBenchmarkReal
,
self
).
__init__
(
super
(
TransformerBigKerasBenchmarkReal
,
self
).
__init__
(
output_dir
=
output_dir
,
default_flags
=
def_flags
,
output_dir
=
output_dir
,
default_flags
=
def_flags
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment