Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d967bfae
Commit
d967bfae
authored
Jun 05, 2019
by
guptapriya
Committed by
guptapriya
Jun 05, 2019
Browse files
add strategy specific tests
parent
3aee5697
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
1 deletion
+25
-1
official/transformer/v2/transformer_main.py
official/transformer/v2/transformer_main.py
+6
-0
official/transformer/v2/transformer_main_test.py
official/transformer/v2/transformer_main_test.py
+19
-1
No files found.
official/transformer/v2/transformer_main.py
View file @
d967bfae
...
@@ -97,6 +97,12 @@ class TransformerTask(object):
...
@@ -97,6 +97,12 @@ class TransformerTask(object):
distribution_strategy
=
flags_obj
.
distribution_strategy
,
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
flags_core
.
get_num_gpus
(
flags_obj
))
num_gpus
=
flags_core
.
get_num_gpus
(
flags_obj
))
print
(
"Running transformer with num_gpus ="
,
num_gpus
)
if
self
.
distribution_strategy
:
print
(
"For training, using distribution strategy: "
,
self
.
distribution_strategy
)
else
:
print
(
"Not using any distribution strategy."
)
self
.
params
=
params
=
misc
.
get_model_params
(
flags_obj
.
param_set
,
num_gpus
)
self
.
params
=
params
=
misc
.
get_model_params
(
flags_obj
.
param_set
,
num_gpus
)
params
[
"num_gpus"
]
=
num_gpus
params
[
"num_gpus"
]
=
num_gpus
...
...
official/transformer/v2/transformer_main_test.py
View file @
d967bfae
...
@@ -49,6 +49,8 @@ class TransformerTaskTest(tf.test.TestCase):
...
@@ -49,6 +49,8 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS
.
train_steps
=
2
FLAGS
.
train_steps
=
2
FLAGS
.
validation_steps
=
1
FLAGS
.
validation_steps
=
1
FLAGS
.
batch_size
=
8
FLAGS
.
batch_size
=
8
FLAGS
.
num_gpus
=
1
FLAGS
.
distribution_strategy
=
"off"
self
.
model_dir
=
FLAGS
.
model_dir
self
.
model_dir
=
FLAGS
.
model_dir
self
.
temp_dir
=
temp_dir
self
.
temp_dir
=
temp_dir
self
.
vocab_file
=
os
.
path
.
join
(
temp_dir
,
"vocab"
)
self
.
vocab_file
=
os
.
path
.
join
(
temp_dir
,
"vocab"
)
...
@@ -63,6 +65,22 @@ class TransformerTaskTest(tf.test.TestCase):
...
@@ -63,6 +65,22 @@ class TransformerTaskTest(tf.test.TestCase):
t
=
tm
.
TransformerTask
(
FLAGS
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
train
()
t
.
train
()
def
test_train_static_batch
(
self
):
FLAGS
.
static_batch
=
True
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
train
()
def
test_train_1_gpu_with_dist_strat
(
self
):
FLAGS
.
distribution_strategy
=
"one_device"
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
train
()
def
test_train_2_gpu
(
self
):
FLAGS
.
distribution_strategy
=
"mirrored"
FLAGS
.
num_gpus
=
2
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
train
()
def
_prepare_files_and_flags
(
self
,
*
extra_flags
):
def
_prepare_files_and_flags
(
self
,
*
extra_flags
):
# Make log dir.
# Make log dir.
if
not
os
.
path
.
exists
(
self
.
temp_dir
):
if
not
os
.
path
.
exists
(
self
.
temp_dir
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment