ModelZoo / ResNet50_tensorflow · Commit cd00b9a7

Remove unnecessary flags

PiperOrigin-RevId: 276518206

Authored Oct 24, 2019 by Hongkun Yu; committed by A. Unique TensorFlower on Oct 24, 2019.
Parent: 87d6459a
Showing 3 changed files with 17 additions and 31 deletions (+17 -31):

  official/transformer/v2/misc.py                    +0   -6
  official/transformer/v2/transformer_main.py        +7   -15
  official/transformer/v2/transformer_main_test.py   +10  -10
official/transformer/v2/misc.py

@@ -180,12 +180,6 @@ def define_transformer_flags():
       default=False,
       help=flags_core.help_wrap(
           'Whether the model runs with custom training loop.'))
-  flags.DEFINE_bool(
-      name='use_tpu_2vm_config',
-      default=False,
-      help=flags_core.help_wrap(
-          'Whether the model runs in 2VM mode, Headless server and unit test '
-          'all use 1VM config.'))
   flags.DEFINE_integer(
       name='decode_batch_size',
       default=32,
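For context on what the deleted block did: flags.DEFINE_bool registers a boolean command-line flag with absl, and every read of FLAGS.use_tpu_2vm_config elsewhere has to be removed along with the definition (which the transformer_main.py change below does). A minimal standalone sketch of that coupling, using absl's plain help= string in place of the repo's flags_core.help_wrap helper:

# Minimal sketch (assumption: plain absl, no flags_core helper) showing
# how a DEFINE_bool registration and its readers are coupled.
from absl import app, flags

FLAGS = flags.FLAGS

flags.DEFINE_bool(
    name='use_tpu_2vm_config',
    default=False,
    help='Whether the model runs in 2VM mode.')


def main(_):
  # If the DEFINE_bool above is deleted, this attribute access raises
  # AttributeError at runtime, so readers must be removed in the same
  # change as the definition.
  if FLAGS.use_tpu_2vm_config:
    print('running in 2VM mode')


if __name__ == '__main__':
  app.run(main)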
official/transformer/v2/transformer_main.py

@@ -449,7 +449,6 @@ def main(_):
   with logger.benchmark_context(flags_obj):
     task = TransformerTask(flags_obj)
 
-    def _run_task(task):
-      if flags_obj.mode == "train":
-        task.train()
-      elif flags_obj.mode == "predict":
+    if flags_obj.mode == "train":
+      task.train()
+    elif flags_obj.mode == "predict":
@@ -459,13 +458,6 @@ def main(_):
-      else:
-        raise ValueError("Invalid mode {}".format(flags_obj.mode))
-
-    if flags_obj.distribution_strategy != "tpu":
-      _run_task(task)
-    else:
-      primary_cpu_task = "/job:worker" if flags_obj.use_tpu_2vm_config else ""
-      with tf.device(primary_cpu_task):
-        _run_task(task)
+    else:
+      raise ValueError("Invalid mode {}".format(flags_obj.mode))
 
 
 if __name__ == "__main__":
   tf.compat.v1.enable_v2_behavior()
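Taken together, the two hunks inline the _run_task helper and drop the TPU 2VM device pinning, so main() now dispatches directly on flags_obj.mode. A hedged sketch of the resulting control flow, reconstructed from the hunks (the predict/eval branches are elided by the diff context but implied by the error message; module setup is assumed, not shown in the commit):

# Sketch of main() after this commit, reconstructed from the hunks.
# logger, flags_obj, and TransformerTask are the names used in the diff;
# the predict/eval branches are inferred from the visible context.
def main(_):
  flags_obj = flags.FLAGS
  with logger.benchmark_context(flags_obj):
    task = TransformerTask(flags_obj)

    if flags_obj.mode == "train":
      task.train()
    elif flags_obj.mode == "predict":
      task.predict()
    elif flags_obj.mode == "eval":
      task.eval()
    else:
      raise ValueError("Invalid mode {}".format(flags_obj.mode))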
official/transformer/v2/transformer_main_test.py

@@ -28,7 +28,7 @@ from absl.testing import flagsaver
 import tensorflow as tf
 
 from official.transformer.v2 import misc
-from official.transformer.v2 import transformer_main as tm
+from official.transformer.v2 import transformer_main
 from official.utils.misc import keras_utils
 from tensorflow.python.eager import context  # pylint: disable=ungrouped-imports
@@ -84,7 +84,7 @@ class TransformerTaskTest(tf.test.TestCase):
   def test_train_no_dist_strat(self):
     if context.num_gpus() >= 2:
       self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
-    t = tm.TransformerTask(FLAGS)
+    t = transformer_main.TransformerTask(FLAGS)
     t.train()
 
   def test_train_static_batch(self):
@@ -96,20 +96,20 @@ class TransformerTaskTest(tf.test.TestCase):
     else:
       FLAGS.num_gpus = 0
     FLAGS.static_batch = True
-    t = tm.TransformerTask(FLAGS)
+    t = transformer_main.TransformerTask(FLAGS)
     t.train()
 
   @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
   def test_train_1_gpu_with_dist_strat(self):
     FLAGS.distribution_strategy = 'one_device'
-    t = tm.TransformerTask(FLAGS)
+    t = transformer_main.TransformerTask(FLAGS)
     t.train()
 
   @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
   def test_train_fp16(self):
     FLAGS.distribution_strategy = 'one_device'
     FLAGS.dtype = 'fp16'
-    t = tm.TransformerTask(FLAGS)
+    t = transformer_main.TransformerTask(FLAGS)
     t.train()
 
   @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@@ -121,7 +121,7 @@ class TransformerTaskTest(tf.test.TestCase):
     FLAGS.distribution_strategy = 'mirrored'
     FLAGS.num_gpus = 2
     FLAGS.param_set = 'base'
-    t = tm.TransformerTask(FLAGS)
+    t = transformer_main.TransformerTask(FLAGS)
     t.train()
 
   @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@@ -134,7 +134,7 @@ class TransformerTaskTest(tf.test.TestCase):
     FLAGS.num_gpus = 2
     FLAGS.param_set = 'base'
     FLAGS.dtype = 'fp16'
-    t = tm.TransformerTask(FLAGS)
+    t = transformer_main.TransformerTask(FLAGS)
     t.train()
 
   def _prepare_files_and_flags(self, *extra_flags):
@@ -167,14 +167,14 @@ class TransformerTaskTest(tf.test.TestCase):
     if context.num_gpus() >= 2:
       self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
     self._prepare_files_and_flags()
-    t = tm.TransformerTask(FLAGS)
+    t = transformer_main.TransformerTask(FLAGS)
     t.predict()
 
   def test_predict_fp16(self):
     if context.num_gpus() >= 2:
       self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
     self._prepare_files_and_flags('--dtype=fp16')
-    t = tm.TransformerTask(FLAGS)
+    t = transformer_main.TransformerTask(FLAGS)
     t.predict()
 
   def test_eval(self):
@@ -183,7 +183,7 @@ class TransformerTaskTest(tf.test.TestCase):
     if 'test_xla' in sys.argv[0]:
       self.skipTest('TODO(xla): Make this test faster under XLA.')
     self._prepare_files_and_flags()
-    t = tm.TransformerTask(FLAGS)
+    t = transformer_main.TransformerTask(FLAGS)
     t.eval()
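The test changes are mechanical: the tm import alias is dropped in favor of the full transformer_main module name. One detail worth noting is that these tests mutate module-level FLAGS between cases (FLAGS.num_gpus, FLAGS.dtype, and so on); the flagsaver import at the top of the file is absl's standard tool for keeping such mutations from leaking across tests. A self-contained sketch of that pattern, illustrative and not part of this diff:

# Illustrative sketch (not from the diff) of the flagsaver pattern the
# test module imports: each decorated test runs against a saved copy of
# FLAGS, so assignments like FLAGS.dtype = 'fp16' are undone afterwards.
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf

FLAGS = flags.FLAGS
flags.DEFINE_string('dtype', 'fp32', 'Computation dtype.')


class FlagSaverExampleTest(tf.test.TestCase):

  @flagsaver.flagsaver
  def test_fp16_flag_is_restored(self):
    FLAGS.dtype = 'fp16'  # visible only inside this test
    self.assertEqual(FLAGS.dtype, 'fp16')

  def test_default_unchanged(self):
    # The mutation in the previous test did not leak here.
    self.assertEqual(FLAGS.dtype, 'fp32')


if __name__ == '__main__':
  tf.test.main()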