Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3aee5697
Commit
3aee5697
authored
Jun 05, 2019
by
guptapriya
Committed by
guptapriya
Jun 05, 2019
Browse files
Fix existing tests
parent
6cfa81a1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
30 deletions
+13
-30
official/transformer/v2/transformer_main.py
official/transformer/v2/transformer_main.py
+2
-1
official/transformer/v2/transformer_main_test.py
official/transformer/v2/transformer_main_test.py
+11
-29
No files found.
official/transformer/v2/transformer_main.py
View file @
3aee5697
...
@@ -190,7 +190,8 @@ class TransformerTask(object):
...
@@ -190,7 +190,8 @@ class TransformerTask(object):
with
tf
.
name_scope
(
"model"
):
with
tf
.
name_scope
(
"model"
):
model
=
transformer
.
create_model
(
params
,
is_train
)
model
=
transformer
.
create_model
(
params
,
is_train
)
self
.
_load_weights_if_possible
(
model
,
flags_obj
.
init_weight_path
)
self
.
_load_weights_if_possible
(
model
,
tf
.
train
.
latest_checkpoint
(
self
.
flags_obj
.
model_dir
))
model
.
summary
()
model
.
summary
()
subtokenizer
=
tokenizer
.
Subtokenizer
(
flags_obj
.
vocab_file
)
subtokenizer
=
tokenizer
.
Subtokenizer
(
flags_obj
.
vocab_file
)
...
...
official/transformer/v2/transformer_main_test.py
View file @
3aee5697
...
@@ -42,21 +42,19 @@ class TransformerTaskTest(tf.test.TestCase):
...
@@ -42,21 +42,19 @@ class TransformerTaskTest(tf.test.TestCase):
def
setUp
(
self
):
def
setUp
(
self
):
temp_dir
=
self
.
get_temp_dir
()
temp_dir
=
self
.
get_temp_dir
()
FLAGS
.
model_dir
=
temp_dir
FLAGS
.
model_dir
=
os
.
path
.
join
(
temp_dir
,
FIXED_TIMESTAMP
)
FLAGS
.
init_logdir_timestamp
=
FIXED_TIMESTAMP
FLAGS
.
param_set
=
param_set
=
"tiny"
FLAGS
.
param_set
=
param_set
=
"tiny"
FLAGS
.
use_synthetic_data
=
True
FLAGS
.
use_synthetic_data
=
True
FLAGS
.
steps_per_epoch
=
1
FLAGS
.
steps_between_evals
=
1
FLAGS
.
train_steps
=
2
FLAGS
.
validation_steps
=
1
FLAGS
.
validation_steps
=
1
FLAGS
.
train_epochs
=
1
FLAGS
.
batch_size
=
8
FLAGS
.
batch_size
=
8
FLAGS
.
init_weight_path
=
None
self
.
model_dir
=
FLAGS
.
model_dir
self
.
cur_log_dir
=
os
.
path
.
join
(
temp_dir
,
FIXED_TIMESTAMP
)
self
.
temp_dir
=
temp_dir
self
.
vocab_file
=
os
.
path
.
join
(
self
.
cur_log
_dir
,
"vocab"
)
self
.
vocab_file
=
os
.
path
.
join
(
temp
_dir
,
"vocab"
)
self
.
vocab_size
=
misc
.
get_model_params
(
param_set
,
0
)[
"vocab_size"
]
self
.
vocab_size
=
misc
.
get_model_params
(
param_set
,
0
)[
"vocab_size"
]
self
.
bleu_source
=
os
.
path
.
join
(
self
.
cur_log_dir
,
"bleu_source"
)
self
.
bleu_source
=
os
.
path
.
join
(
temp_dir
,
"bleu_source"
)
self
.
bleu_ref
=
os
.
path
.
join
(
self
.
cur_log_dir
,
"bleu_ref"
)
self
.
bleu_ref
=
os
.
path
.
join
(
temp_dir
,
"bleu_ref"
)
self
.
flags_file
=
os
.
path
.
join
(
self
.
cur_log_dir
,
"flags"
)
def
_assert_exists
(
self
,
filepath
):
def
_assert_exists
(
self
,
filepath
):
self
.
assertTrue
(
os
.
path
.
exists
(
filepath
))
self
.
assertTrue
(
os
.
path
.
exists
(
filepath
))
...
@@ -64,27 +62,11 @@ class TransformerTaskTest(tf.test.TestCase):
...
@@ -64,27 +62,11 @@ class TransformerTaskTest(tf.test.TestCase):
def
test_train
(
self
):
def
test_train
(
self
):
t
=
tm
.
TransformerTask
(
FLAGS
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
train
()
t
.
train
()
# Test model dir.
self
.
_assert_exists
(
self
.
cur_log_dir
)
# Test saving models.
self
.
_assert_exists
(
os
.
path
.
join
(
self
.
cur_log_dir
,
"saves-model-weights.hdf5"
))
self
.
_assert_exists
(
os
.
path
.
join
(
self
.
cur_log_dir
,
"saves-model.hdf5"
))
# Test callbacks:
# TensorBoard file.
self
.
_assert_exists
(
os
.
path
.
join
(
self
.
cur_log_dir
,
"logs"
))
# CSVLogger file.
self
.
_assert_exists
(
os
.
path
.
join
(
self
.
cur_log_dir
,
"result.csv"
))
# Checkpoint file.
filenames
=
os
.
listdir
(
self
.
cur_log_dir
)
matched_weight_file
=
any
([
WEIGHT_PATTERN
.
match
(
f
)
for
f
in
filenames
])
self
.
assertTrue
(
matched_weight_file
)
def
_prepare_files_and_flags
(
self
,
*
extra_flags
):
def
_prepare_files_and_flags
(
self
,
*
extra_flags
):
# Make log dir.
# Make log dir.
if
not
os
.
path
.
exists
(
self
.
cur_log
_dir
):
if
not
os
.
path
.
exists
(
self
.
temp
_dir
):
os
.
makedirs
(
self
.
cur_log
_dir
)
os
.
makedirs
(
self
.
temp
_dir
)
# Fake vocab, bleu_source and bleu_ref.
# Fake vocab, bleu_source and bleu_ref.
tokens
=
[
tokens
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment