Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
39638d66
Commit
39638d66
authored
May 28, 2019
by
guptapriya
Browse files
Reduce max_length to 64 in static_batch cases.
parent
3bb5dd6c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
official/transformer/v2/transformer_benchmark.py
official/transformer/v2/transformer_benchmark.py
+5
-5
No files found.
official/transformer/v2/transformer_benchmark.py
View file @
39638d66
...
...
@@ -169,8 +169,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
FLAGS
.
batch_size
=
4096
FLAGS
.
train_steps
=
100000
FLAGS
.
steps_between_evals
=
5000
# TODO(guptapriya): Add max_length
FLAGS
.
static_batch
=
True
FLAGS
.
max_length
=
64
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_1_gpu_static_batch'
)
# These bleu scores are based on test runs after at this limited
# number of steps and batch size after verifying SOTA at 8xV100s.
...
...
@@ -216,8 +216,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
FLAGS
.
param_set
=
'base'
FLAGS
.
batch_size
=
4096
*
8
FLAGS
.
train_steps
=
100000
# TODO(guptapriya): Add max_length
FLAGS
.
static_batch
=
True
FLAGS
.
max_length
=
64
FLAGS
.
steps_between_evals
=
5000
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_8_gpu_static_batch'
)
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
...
...
@@ -280,8 +280,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
FLAGS
[
'bleu_ref'
].
value
=
self
.
bleu_ref
FLAGS
.
param_set
=
'big'
FLAGS
.
batch_size
=
3072
*
8
# TODO(guptapriya): Add max_length
FLAGS
.
static_batch
=
True
FLAGS
.
max_length
=
64
FLAGS
.
train_steps
=
100000
FLAGS
.
steps_between_evals
=
5000
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_8_gpu_static_batch'
)
...
...
@@ -330,8 +330,8 @@ class TransformerKerasBenchmark(TransformerBenchmark):
FLAGS
.
distribution_strategy
=
'off'
FLAGS
.
batch_size
=
self
.
batch_per_gpu
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_1_gpu_static_batch'
)
# TODO(guptapriya): Add max_length
FLAGS
.
static_batch
=
True
FLAGS
.
max_length
=
64
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
log_steps
=
FLAGS
.
log_steps
)
...
...
@@ -350,8 +350,8 @@ class TransformerKerasBenchmark(TransformerBenchmark):
FLAGS
.
num_gpus
=
8
FLAGS
.
batch_size
=
self
.
batch_per_gpu
*
8
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_8_gpu_static_batch'
)
# TODO(guptapriya): Add max_length
FLAGS
.
static_batch
=
True
FLAGS
.
max_length
=
64
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
log_steps
=
FLAGS
.
log_steps
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment