Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
dd376f53
Commit
dd376f53
authored
Aug 14, 2019
by
Igor Saprykin
Committed by
A. Unique TensorFlower
Aug 14, 2019
Browse files
Internal change
PiperOrigin-RevId: 263463300
parent
c8660848
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
5 deletions
+15
-5
official/transformer/v2/transformer_main_test.py
official/transformer/v2/transformer_main_test.py
+15
-5
No files found.
official/transformer/v2/transformer_main_test.py
View file @
dd376f53
...
...
@@ -80,10 +80,14 @@ class TransformerTaskTest(tf.test.TestCase):
self
.
assertTrue
(
os
.
path
.
exists
(
filepath
))
def
test_train_no_dist_strat
(
self
):
if
context
.
num_gpus
()
>=
2
:
self
.
skipTest
(
'No need to test 2+ GPUs without a distribution strategy.'
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
train
()
def
test_train_static_batch
(
self
):
if
context
.
num_gpus
()
>=
2
:
self
.
skipTest
(
'No need to test 2+ GPUs without a distribution strategy.'
)
FLAGS
.
distribution_strategy
=
'one_device'
FLAGS
.
static_batch
=
True
t
=
tm
.
TransformerTask
(
FLAGS
)
...
...
@@ -105,8 +109,8 @@ class TransformerTaskTest(tf.test.TestCase):
def
test_train_2_gpu
(
self
):
if
context
.
num_gpus
()
<
2
:
self
.
skipTest
(
'{} GPUs are not available for this test. {} GPUs are available'
.
format
(
2
,
context
.
num_gpus
()))
'{} GPUs are not available for this test. {} GPUs are available'
.
format
(
2
,
context
.
num_gpus
()))
FLAGS
.
distribution_strategy
=
'mirrored'
FLAGS
.
num_gpus
=
2
FLAGS
.
param_set
=
'base'
...
...
@@ -117,8 +121,8 @@ class TransformerTaskTest(tf.test.TestCase):
def
test_train_2_gpu_fp16
(
self
):
if
context
.
num_gpus
()
<
2
:
self
.
skipTest
(
'{} GPUs are not available for this test. {} GPUs are available'
.
format
(
2
,
context
.
num_gpus
()))
'{} GPUs are not available for this test. {} GPUs are available'
.
format
(
2
,
context
.
num_gpus
()))
FLAGS
.
distribution_strategy
=
'mirrored'
FLAGS
.
num_gpus
=
2
FLAGS
.
param_set
=
'base'
...
...
@@ -153,16 +157,22 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS
(
update_flags
)
def
test_predict
(
self
):
if
context
.
num_gpus
()
>=
2
:
self
.
skipTest
(
'No need to test 2+ GPUs without a distribution strategy.'
)
self
.
_prepare_files_and_flags
()
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
predict
()
def
test_predict_fp16
(
self
):
if
context
.
num_gpus
()
>=
2
:
self
.
skipTest
(
'No need to test 2+ GPUs without a distribution strategy.'
)
self
.
_prepare_files_and_flags
(
'--dtype=fp16'
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
predict
()
def
test_eval
(
self
):
if
context
.
num_gpus
()
>=
2
:
self
.
skipTest
(
'No need to test 2+ GPUs without a distribution strategy.'
)
self
.
_prepare_files_and_flags
()
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
eval
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment