ModelZoo / ResNet50_tensorflow · Commit d56d51d0
"docs_zh-CN/vscode:/vscode.git/clone" did not exist on "26ab7ff294c7a66c7533442ee8d52182c02faf62"
Commit d56d51d0, authored Jul 14, 2020 by Kaushik Shivakumar

    Merge remote-tracking branch 'upstream/master' into context_tf2

Parents: ea550ca9, 73a911c0
Showing 17 changed files with 493 additions and 1381 deletions (+493, -1381).
official/benchmark/resnet_ctl_imagenet_benchmark.py                       +11   -2
official/nlp/data/create_pretraining_data.py                              +233  -58
official/nlp/data/sentence_prediction_dataloader.py                       +6    -1
official/nlp/tasks/sentence_prediction.py                                 +85   -8
official/nlp/tasks/sentence_prediction_test.py                            +95   -0
official/nlp/tasks/tagging.py                                             +1    -1
official/staging/training/controller.py                                   +0    -337
official/staging/training/controller_test.py                              +0    -308
official/staging/training/runnable.py                                     +0    -79
official/staging/training/standard_runnable.py                            +0    -181
official/staging/training/utils.py                                        +0    -342
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py   +24   -25
official/vision/image_classification/resnet/resnet_runnable.py            +18   -30
orbit/controller.py                                                       +4    -2
research/object_detection/README.md                                       +12   -7
research/object_detection/g3doc/tf1.md                                    +2    -0
research/object_detection/g3doc/tf2.md                                    +2    -0
official/benchmark/resnet_ctl_imagenet_benchmark.py

@@ -389,6 +389,15 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.dtype = 'bf16'
     self._run_and_report_benchmark()

+  @owner_utils.Owner('tf-graph-compiler')
+  def benchmark_2x2_tpu_bf16_mlir(self):
+    self._setup()
+    self._set_df_common()
+    FLAGS.batch_size = 1024
+    FLAGS.dtype = 'bf16'
+    tf.config.experimental.enable_mlir_bridge()
+    self._run_and_report_benchmark()
+
   def benchmark_4x4_tpu_bf16(self):
     self._setup()
     self._set_df_common()

@@ -426,7 +435,7 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
     def_flags['skip_eval'] = True
     def_flags['use_synthetic_data'] = True
     def_flags['train_steps'] = 110
-    def_flags['steps_per_loop'] = 20
+    def_flags['steps_per_loop'] = 10
     def_flags['log_steps'] = 10

     super(Resnet50CtlBenchmarkSynth, self).__init__(

@@ -441,7 +450,7 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
     def_flags['skip_eval'] = True
     def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
     def_flags['train_steps'] = 110
-    def_flags['steps_per_loop'] = 20
+    def_flags['steps_per_loop'] = 10
     def_flags['log_steps'] = 10

     super(Resnet50CtlBenchmarkReal, self).__init__(
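The only new benchmark here, benchmark_2x2_tpu_bf16_mlir, is the existing 2x2 bf16 run with TensorFlow's experimental MLIR TPU bridge switched on first. A minimal standalone sketch of that switch (outside the benchmark harness; nothing below is taken from this commit beyond the API call itself):

import tensorflow as tf

# Opt in to the experimental MLIR-based compilation bridge. This is a
# process-wide toggle, so it must be called before building or running any
# computation you want compiled through the new bridge.
tf.config.experimental.enable_mlir_bridge()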
official/nlp/data/create_pretraining_data.py

@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import collections
+import itertools
 import random

 from absl import app

@@ -48,6 +49,12 @@ flags.DEFINE_bool(
     "do_whole_word_mask", False,
     "Whether to use whole word masking rather than per-WordPiece masking.")

+flags.DEFINE_integer(
+    "max_ngram_size", None,
+    "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
+    "weighting scheme to favor shorter n-grams. "
+    "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
+
 flags.DEFINE_bool(
     "gzip_compress", False,
     "Whether to use `GZIP` compress option to get compressed TFRecord files.")

@@ -192,7 +199,8 @@ def create_training_instances(input_files,
                               masked_lm_prob,
                               max_predictions_per_seq,
                               rng,
-                              do_whole_word_mask=False):
+                              do_whole_word_mask=False,
+                              max_ngram_size=None):
   """Create `TrainingInstance`s from raw text."""
   all_documents = [[]]

@@ -229,7 +237,7 @@ def create_training_instances(input_files,
           create_instances_from_document(
               all_documents, document_index, max_seq_length, short_seq_prob,
               masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-              do_whole_word_mask))
+              do_whole_word_mask, max_ngram_size))

   rng.shuffle(instances)
   return instances

@@ -238,7 +246,8 @@ def create_training_instances(input_files,
 def create_instances_from_document(
     all_documents, document_index, max_seq_length, short_seq_prob,
     masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-    do_whole_word_mask=False):
+    do_whole_word_mask=False,
+    max_ngram_size=None):
   """Creates `TrainingInstance`s for a single document."""
   document = all_documents[document_index]

@@ -337,7 +346,7 @@ def create_instances_from_document(
       (tokens, masked_lm_positions,
        masked_lm_labels) = create_masked_lm_predictions(
            tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-           do_whole_word_mask)
+           do_whole_word_mask, max_ngram_size)
       instance = TrainingInstance(
           tokens=tokens,
           segment_ids=segment_ids,

@@ -355,72 +364,238 @@ def create_instances_from_document(
 MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                           ["index", "label"])

+# A _Gram is a [half-open) interval of token indices which form a word.
+# E.g.,
+#   words:  ["The", "doghouse"]
+#   tokens: ["The", "dog", "##house"]
+#   grams:  [(0,1), (1,3)]
+_Gram = collections.namedtuple("_Gram", ["begin", "end"])
+
+
+def _window(iterable, size):
+  """Helper to create a sliding window iterator with a given size.
+
+  E.g.,
+    input = [1, 2, 3, 4]
+    _window(input, 1) => [1], [2], [3], [4]
+    _window(input, 2) => [1, 2], [2, 3], [3, 4]
+    _window(input, 3) => [1, 2, 3], [2, 3, 4]
+    _window(input, 4) => [1, 2, 3, 4]
+    _window(input, 5) => None
+
+  Arguments:
+    iterable: elements to iterate over.
+    size: size of the window.
+
+  Yields:
+    Elements of `iterable` batched into a sliding window of length `size`.
+  """
+  i = iter(iterable)
+  window = []
+  try:
+    for e in range(0, size):
+      window.append(next(i))
+    yield window
+  except StopIteration:
+    # handle the case where iterable's length is less than the window size.
+    return
+  for e in i:
+    window = window[1:] + [e]
+    yield window
+
+
+def _contiguous(sorted_grams):
+  """Test whether a sequence of grams is contiguous.
+
+  Arguments:
+    sorted_grams: _Grams which are sorted in increasing order.
+
+  Returns:
+    True if `sorted_grams` are touching each other.
+
+    E.g.,
+      _contiguous([(1, 4), (4, 5), (5, 10)]) == True
+      _contiguous([(1, 2), (4, 5)]) == False
+  """
+  for a, b in _window(sorted_grams, 2):
+    if a.end != b.begin:
+      return False
+  return True
+
+
+def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
+  """Create a list of masking {1, ..., n}-grams from a list of one-grams.
+
+  This is an extention of 'whole word masking' to mask multiple, contiguous
+  words such as (e.g., "the red boat").
+
+  Each input gram represents the token indices of a single word,
+     words:  ["the", "red", "boat"]
+     tokens: ["the", "red", "boa", "##t"]
+     grams:  [(0,1), (1,2), (2,4)]
+
+  For a `max_ngram_size` of three, possible outputs masks include:
+    1-grams: (0,1), (1,2), (2,4)
+    2-grams: (0,2), (1,4)
+    3-grams; (0,4)
+
+  Output masks will not overlap and contain less than `max_masked_tokens` total
+  tokens.  E.g., for the example above with `max_masked_tokens` as three,
+  valid outputs are,
+       [(0,1), (1,2)]  # "the", "red" covering two tokens
+       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens
+
+  The length of the selected n-gram follows a zipf weighting to
+  favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
+
+  Arguments:
+    grams: List of one-grams.
+    max_ngram_size: Maximum number of contiguous one-grams combined to create
+      an n-gram.
+    max_masked_tokens: Maximum total number of tokens to be masked.
+    rng: `random.Random` generator.
+
+  Returns:
+    A list of n-grams to be used as masks.
+  """
+  if not grams:
+    return None
+
+  grams = sorted(grams)
+  num_tokens = grams[-1].end
+
+  # Ensure our grams are valid (i.e., they don't overlap).
+  for a, b in _window(grams, 2):
+    if a.end > b.begin:
+      raise ValueError("overlapping grams: {}".format(grams))
+
+  # Build map from n-gram length to list of n-grams.
+  ngrams = {i: [] for i in range(1, max_ngram_size + 1)}
+  for gram_size in range(1, max_ngram_size + 1):
+    for g in _window(grams, gram_size):
+      if _contiguous(g):
+        # Add an n-gram which spans these one-grams.
+        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
+
+  # Shuffle each list of n-grams.
+  for v in ngrams.values():
+    rng.shuffle(v)
+
+  # Create the weighting for n-gram length selection.
+  # Stored cummulatively for `random.choices` below.
+  cummulative_weights = list(
+      itertools.accumulate([1. / n for n in range(1, max_ngram_size + 1)]))
+
+  output_ngrams = []
+  # Keep a bitmask of which tokens have been masked.
+  masked_tokens = [False] * num_tokens
+  # Loop until we have enough masked tokens or there are no more candidate
+  # n-grams of any length.
+  # Each code path should ensure one or more elements from `ngrams` are removed
+  # to guarentee this loop terminates.
+  while (sum(masked_tokens) < max_masked_tokens and
+         sum(len(s) for s in ngrams.values())):
+    # Pick an n-gram size based on our weights.
+    sz = random.choices(range(1, max_ngram_size + 1),
+                        cum_weights=cummulative_weights)[0]
+
+    # Ensure this size doesn't result in too many masked tokens.
+    # E.g., a two-gram contains _at least_ two tokens.
+    if sum(masked_tokens) + sz > max_masked_tokens:
+      # All n-grams of this length are too long and can be removed from
+      # consideration.
+      ngrams[sz].clear()
+      continue
+
+    # All of the n-grams of this size have been used.
+    if not ngrams[sz]:
+      continue
+
+    # Choose a random n-gram of the given size.
+    gram = ngrams[sz].pop()
+    num_gram_tokens = gram.end - gram.begin
+
+    # Check if this would add too many tokens.
+    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
+      continue
+
+    # Check if any of the tokens in this gram have already been masked.
+    if sum(masked_tokens[gram.begin:gram.end]):
+      continue
+
+    # Found a usable n-gram!  Mark its tokens as masked and add it to return.
+    masked_tokens[gram.begin:gram.end] = [True] * (gram.end - gram.begin)
+    output_ngrams.append(gram)
+  return output_ngrams
+
+
+def _wordpieces_to_grams(tokens):
+  """Reconstitue grams (words) from `tokens`.
+
+  E.g.,
+     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
+     grams:  [          [1,2),          [2,         4),  [4,5) , [5, 6)]
+
+  Arguments:
+    tokens: list of wordpieces
+
+  Returns:
+    List of _Grams representing spans of whole words
+    (without "[CLS]" and "[SEP]").
+  """
+  grams = []
+  gram_start_pos = None
+  for i, token in enumerate(tokens):
+    if gram_start_pos is not None and token.startswith("##"):
+      continue
+    if gram_start_pos is not None:
+      grams.append(_Gram(gram_start_pos, i))
+    if token not in ["[CLS]", "[SEP]"]:
+      gram_start_pos = i
+    else:
+      gram_start_pos = None
+  if gram_start_pos is not None:
+    grams.append(_Gram(gram_start_pos, len(tokens)))
+  return grams
+
-def create_masked_lm_predictions(tokens, masked_lm_prob,
-                                 max_predictions_per_seq, vocab_words, rng,
-                                 do_whole_word_mask):
-  """Creates the predictions for the masked LM objective."""
-
-  cand_indexes = []
-  for (i, token) in enumerate(tokens):
-    if token == "[CLS]" or token == "[SEP]":
-      continue
-    # Whole Word Masking means that if we mask all of the wordpieces
-    # corresponding to an original word. When a word has been split into
-    # WordPieces, the first token does not have any marker and any subsequence
-    # tokens are prefixed with ##. So whenever we see the ## token, we
-    # append it to the previous set of word indexes.
-    #
-    # Note that Whole Word Masking does *not* change the training code
-    # at all -- we still predict each WordPiece independently, softmaxed
-    # over the entire vocabulary.
-    if (do_whole_word_mask and len(cand_indexes) >= 1 and
-        token.startswith("##")):
-      cand_indexes[-1].append(i)
-    else:
-      cand_indexes.append([i])
-
-  rng.shuffle(cand_indexes)
-
-  output_tokens = list(tokens)
-
-  num_to_predict = min(max_predictions_per_seq,
-                       max(1, int(round(len(tokens) * masked_lm_prob))))
-
-  masked_lms = []
-  covered_indexes = set()
-  for index_set in cand_indexes:
-    if len(masked_lms) >= num_to_predict:
-      break
-    # If adding a whole-word mask would exceed the maximum number of
-    # predictions, then just skip this candidate.
-    if len(masked_lms) + len(index_set) > num_to_predict:
-      continue
-    is_any_index_covered = False
-    for index in index_set:
-      if index in covered_indexes:
-        is_any_index_covered = True
-        break
-    if is_any_index_covered:
-      continue
-    for index in index_set:
-      covered_indexes.add(index)
-
-      masked_token = None
-      # 80% of the time, replace with [MASK]
-      if rng.random() < 0.8:
-        masked_token = "[MASK]"
-      else:
-        # 10% of the time, keep original
-        if rng.random() < 0.5:
-          masked_token = tokens[index]
-        # 10% of the time, replace with random word
-        else:
-          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
-
-      output_tokens[index] = masked_token
-
-      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+def create_masked_lm_predictions(tokens, masked_lm_prob,
+                                 max_predictions_per_seq, vocab_words, rng,
+                                 do_whole_word_mask,
+                                 max_ngram_size=None):
+  """Creates the predictions for the masked LM objective."""
+  if do_whole_word_mask:
+    grams = _wordpieces_to_grams(tokens)
+  else:
+    # Here we consider each token to be a word to allow for sub-word masking.
+    if max_ngram_size:
+      raise ValueError("cannot use ngram masking without whole word masking")
+    grams = [_Gram(i, i + 1) for i in range(0, len(tokens))
+             if tokens[i] not in ["[CLS]", "[SEP]"]]
+
+  num_to_predict = min(max_predictions_per_seq,
+                       max(1, int(round(len(tokens) * masked_lm_prob))))
+  # Generate masks.  If `max_ngram_size` in [0, None] it means we're doing
+  # whole word masking or token level masking.  Both of these can be treated
+  # as the `max_ngram_size=1` case.
+  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
+                                 num_to_predict, rng)
+  masked_lms = []
+  output_tokens = list(tokens)
+  for gram in masked_grams:
+    # 80% of the time, replace all n-gram tokens with [MASK]
+    if rng.random() < 0.8:
+      replacement_action = lambda idx: "[MASK]"
+    else:
+      # 10% of the time, keep all the original n-gram tokens.
+      if rng.random() < 0.5:
+        replacement_action = lambda idx: tokens[idx]
+      # 10% of the time, replace each n-gram token with a random word.
+      else:
+        replacement_action = lambda idx: rng.choice(vocab_words)
+
+    for idx in range(gram.begin, gram.end):
+      output_tokens[idx] = replacement_action(idx)
+      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))

   assert len(masked_lms) <= num_to_predict
   masked_lms = sorted(masked_lms, key=lambda x: x.index)

@@ -467,7 +642,7 @@ def main(_):
   instances = create_training_instances(
       input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
       FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
-      rng, FLAGS.do_whole_word_mask)
+      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)

   output_files = FLAGS.output_file.split(",")
   logging.info("*** Writing to output files ***")
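A note on the n-gram length selection above: the Zipf-style weights 1, 1/2, ..., 1/n are accumulated with itertools.accumulate because random.choices takes cumulative weights via its cum_weights argument. A self-contained sketch of just that sampling step (seed and sample count are arbitrary; also note the diff calls module-level random.choices rather than the rng parameter, as reproduced here):

import itertools
import random

max_ngram_size = 3
# Zipf weights 1, 1/2, 1/3 stored cumulatively: [1.0, 1.5, 1.8333...].
cum_weights = list(
    itertools.accumulate(1. / n for n in range(1, max_ngram_size + 1)))

random.seed(0)
sizes = [
    random.choices(range(1, max_ngram_size + 1), cum_weights=cum_weights)[0]
    for _ in range(10000)
]
# Expected proportions after normalizing by 1 + 1/2 + 1/3 = 11/6:
#   1-grams ~0.545, 2-grams ~0.273, 3-grams ~0.182.
print({n: sizes.count(n) / len(sizes) for n in (1, 2, 3)})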
official/nlp/data/sentence_prediction_dataloader.py

@@ -23,6 +23,9 @@ from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.data import data_loader_factory

+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
+
+
 @dataclasses.dataclass
 class SentencePredictionDataConfig(cfg.DataConfig):
   """Data config for sentence prediction task (tasks/sentence_prediction)."""

@@ -30,6 +33,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
   global_batch_size: int = 32
   is_training: bool = True
   seq_length: int = 128
+  label_type: str = 'int'

 @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)

@@ -42,11 +46,12 @@ class SentencePredictionDataLoader:
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
+    label_type = LABEL_TYPES_MAP[self._params.label_type]
     name_to_features = {
         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'label_ids': tf.io.FixedLenFeature([], tf.int64),
+        'label_ids': tf.io.FixedLenFeature([], label_type),
     }
     example = tf.io.parse_single_example(record, name_to_features)
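With the new label_type field, a regression dataset (float labels) can be described directly in the data config; _decode then parses label_ids as tf.float32 via LABEL_TYPES_MAP. A minimal sketch, assuming the loader constructor takes the config (the input path is a placeholder; in the task flow the loader is normally created through data_loader_factory rather than directly):

from official.nlp.data import sentence_prediction_dataloader as dataloader

# 'float' selects tf.float32 label parsing; the default 'int' keeps tf.int64.
config = dataloader.SentencePredictionDataConfig(
    input_path='/tmp/stsb_train.tf_record',  # placeholder path
    seq_length=128,
    global_batch_size=32,
    is_training=True,
    label_type='float')
loader = dataloader.SentencePredictionDataLoader(config)  # assumed signature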
official/nlp/tasks/sentence_prediction.py

@@ -14,9 +14,12 @@
 # limitations under the License.
 # ==============================================================================
 """Sentence prediction (classification) task."""
+from typing import List, Union
 from absl import logging
 import dataclasses
 import numpy as np
+import orbit
 from scipy import stats
 from sklearn import metrics as sklearn_metrics
 import tensorflow as tf

@@ -31,6 +34,10 @@ from official.nlp.modeling import models
 from official.nlp.tasks import utils

+METRIC_TYPES = frozenset(
+    ['accuracy', 'matthews_corrcoef', 'pearson_spearman_corr'])
+
+
 @dataclasses.dataclass
 class ModelConfig(base_config.Config):
   """A classifier/regressor configuration."""

@@ -68,6 +75,9 @@ class SentencePredictionTask(base_task.Task):
       self._hub_module = hub.load(params.hub_module_url)
     else:
       self._hub_module = None
+    if params.metric_type not in METRIC_TYPES:
+      raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
+    self.metric_type = params.metric_type

   def build_model(self):

@@ -77,7 +87,7 @@ class SentencePredictionTask(base_task.Task):
     encoder_network = encoders.instantiate_encoder_from_cfg(
         self.task_config.model.encoder)
-    # Currently, we only supports bert-style sentence prediction finetuning.
+    # Currently, we only support bert-style sentence prediction finetuning.
     return models.BertClassifier(
         network=encoder_network,
         num_classes=self.task_config.model.num_classes,

@@ -86,8 +96,11 @@ class SentencePredictionTask(base_task.Task):
         use_encoder_pooler=self.task_config.model.use_encoder_pooler)

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
-    loss = tf.keras.losses.sparse_categorical_crossentropy(
-        labels, tf.cast(model_outputs, tf.float32), from_logits=True)
+    if self.task_config.model.num_classes == 1:
+      loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
+    else:
+      loss = tf.keras.losses.sparse_categorical_crossentropy(
+          labels, tf.cast(model_outputs, tf.float32), from_logits=True)

     if aux_losses:
       loss += tf.add_n(aux_losses)

@@ -103,8 +116,12 @@ class SentencePredictionTask(base_task.Task):
           input_word_ids=dummy_ids,
           input_mask=dummy_ids,
           input_type_ids=dummy_ids)
-      y = tf.zeros((1, 1), dtype=tf.int32)
-      return (x, y)
+      if self.task_config.model.num_classes == 1:
+        y = tf.zeros((1,), dtype=tf.float32)
+      else:
+        y = tf.zeros((1, 1), dtype=tf.int32)
+      return x, y

     dataset = tf.data.Dataset.range(1)
     dataset = dataset.repeat()

@@ -116,7 +133,11 @@ class SentencePredictionTask(base_task.Task):
   def build_metrics(self, training=None):
     del training
-    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
+    if self.task_config.model.num_classes == 1:
+      metrics = [tf.keras.metrics.MeanSquaredError()]
+    else:
+      metrics = [
+          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
     return metrics

   def process_metrics(self, metrics, labels, model_outputs):

@@ -154,6 +175,7 @@ class SentencePredictionTask(base_task.Task):
       return None
     if state is None:
       state = {'sentence_prediction': [], 'labels': []}
+    # TODO(b/160712818): Add support for concatenating partial batches.
     state['sentence_prediction'].append(
         np.concatenate([v.numpy() for v in step_outputs['sentence_prediction']],
                        axis=0))

@@ -162,15 +184,21 @@ class SentencePredictionTask(base_task.Task):
     return state

   def reduce_aggregated_logs(self, aggregated_logs):
-    if self.metric_type == 'matthews_corrcoef':
+    if self.metric_type == 'accuracy':
+      return None
+    elif self.metric_type == 'matthews_corrcoef':
       preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
+      preds = np.reshape(preds, -1)
       labels = np.concatenate(aggregated_logs['labels'], axis=0)
+      labels = np.reshape(labels, -1)
       return {
           self.metric_type: sklearn_metrics.matthews_corrcoef(preds, labels)
       }
-    if self.metric_type == 'pearson_spearman_corr':
+    elif self.metric_type == 'pearson_spearman_corr':
       preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
+      preds = np.reshape(preds, -1)
       labels = np.concatenate(aggregated_logs['labels'], axis=0)
+      labels = np.reshape(labels, -1)
       pearson_corr = stats.pearsonr(preds, labels)[0]
       spearman_corr = stats.spearmanr(preds, labels)[0]
       corr_metric = (pearson_corr + spearman_corr) / 2

@@ -198,3 +226,52 @@ class SentencePredictionTask(base_task.Task):
     status.expect_partial().assert_existing_objects_matched()
     logging.info('Finished loading pretrained checkpoint from %s',
                  ckpt_dir_or_file)
+
+
+def predict(task: SentencePredictionTask, params: cfg.DataConfig,
+            model: tf.keras.Model) -> List[Union[int, float]]:
+  """Predicts on the input data.
+
+  Args:
+    task: A `SentencePredictionTask` object.
+    params: A `cfg.DataConfig` object.
+    model: A keras.Model.
+
+  Returns:
+    A list of predictions with length of `num_examples`. For regression task,
+      each element in the list is the predicted score; for classification task,
+      each element is the predicted class id.
+  """
+  is_regression = task.task_config.model.num_classes == 1
+
+  @tf.function
+  def predict_step(iterator):
+    """Predicts on distributed devices."""
+
+    def _replicated_step(inputs):
+      """Replicated prediction calculation."""
+      x, _ = inputs
+      outputs = task.inference_step(x, model)
+      if is_regression:
+        return outputs
+      else:
+        return tf.argmax(outputs, axis=-1)
+
+    outputs = tf.distribute.get_strategy().run(
+        _replicated_step, args=(next(iterator),))
+    return tf.nest.map_structure(
+        tf.distribute.get_strategy().experimental_local_results, outputs)
+
+  def reduce_fn(state, outputs):
+    """Concatenates model's outputs."""
+    for per_replica_batch_predictions in outputs:
+      state.extend(per_replica_batch_predictions)
+    return state
+
+  loop_fn = orbit.utils.create_loop_fn(predict_step)
+  dataset = orbit.utils.make_distributed_dataset(
+      tf.distribute.get_strategy(), task.build_inputs, params)
+  # Set `num_steps` to -1 to exhaust the dataset.
+  predictions = loop_fn(iter(dataset), num_steps=-1, state=[],
+                        reduce_fn=reduce_fn)
+  return predictions
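A sketch of driving the new predict helper end to end, mirroring test_prediction in the test file below; task_config is assumed to be a valid SentencePredictionConfig built elsewhere:

from official.nlp.data import sentence_prediction_dataloader
from official.nlp.tasks import sentence_prediction

# Assumes task_config: sentence_prediction.SentencePredictionConfig.
task = sentence_prediction.SentencePredictionTask(task_config)
model = task.build_model()

data_config = sentence_prediction_dataloader.SentencePredictionDataConfig(
    input_path="/tmp/test.tf_record",  # placeholder path
    seq_length=16,
    is_training=False,
    label_type="int",  # use "float" with a regression head (num_classes == 1)
    global_batch_size=16,
    drop_remainder=False)

# Runs under the current tf.distribute strategy; returns class ids for
# classification heads and float scores when num_classes == 1.
predictions = sentence_prediction.predict(task, data_config, model)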
official/nlp/tasks/sentence_prediction_test.py

@@ -18,6 +18,7 @@ import functools
 import os

 from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf

 from official.nlp.bert import configs

@@ -28,6 +29,35 @@ from official.nlp.data import sentence_prediction_dataloader
 from official.nlp.tasks import sentence_prediction

+def _create_fake_dataset(output_path, seq_length, num_classes, num_examples):
+  """Creates a fake dataset."""
+  writer = tf.io.TFRecordWriter(output_path)
+
+  def create_int_feature(values):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+
+  def create_float_feature(values):
+    return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+
+  for _ in range(num_examples):
+    features = {}
+    input_ids = np.random.randint(100, size=(seq_length))
+    features["input_ids"] = create_int_feature(input_ids)
+    features["input_mask"] = create_int_feature(np.ones_like(input_ids))
+    features["segment_ids"] = create_int_feature(np.ones_like(input_ids))
+    if num_classes == 1:
+      features["label_ids"] = create_float_feature([np.random.random()])
+    else:
+      features["label_ids"] = create_int_feature(
+          [np.random.random_integers(0, num_classes - 1, size=())])
+
+    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+    writer.write(tf_example.SerializeToString())
+  writer.close()
+
+
 class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):

@@ -85,6 +115,42 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     ckpt.save(config.init_checkpoint)
     task.initialize(model)

+  @parameterized.named_parameters(
+      {
+          "testcase_name": "regression",
+          "num_classes": 1,
+      },
+      {
+          "testcase_name": "classification",
+          "num_classes": 2,
+      },
+  )
+  def test_metrics_and_losses(self, num_classes):
+    config = sentence_prediction.SentencePredictionConfig(
+        init_checkpoint=self.get_temp_dir(),
+        model=self.get_model_config(num_classes),
+        train_data=self._train_data_config)
+    task = sentence_prediction.SentencePredictionTask(config)
+    model = task.build_model()
+    metrics = task.build_metrics()
+    if num_classes == 1:
+      self.assertIsInstance(metrics[0], tf.keras.metrics.MeanSquaredError)
+    else:
+      self.assertIsInstance(metrics[0],
+                            tf.keras.metrics.SparseCategoricalAccuracy)
+
+    dataset = task.build_inputs(config.train_data)
+    iterator = iter(dataset)
+    optimizer = tf.keras.optimizers.SGD(lr=0.1)
+    task.train_step(next(iterator), model, optimizer, metrics=metrics)
+
+    logs = task.validation_step(next(iterator), model, metrics=metrics)
+    loss = logs["loss"].numpy()
+    if num_classes == 1:
+      self.assertAlmostEqual(loss, 42.77483, places=3)
+    else:
+      self.assertAlmostEqual(loss, 3.57627e-6, places=3)
+
   @parameterized.parameters(("matthews_corrcoef", 2),
                             ("pearson_spearman_corr", 1))
   def test_np_metrics(self, metric_type, num_classes):

@@ -153,6 +219,35 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
         train_data=self._train_data_config)
     self._run_task(config)

+  @parameterized.named_parameters(("classification", 5), ("regression", 1))
+  def test_prediction(self, num_classes):
+    task_config = sentence_prediction.SentencePredictionConfig(
+        model=self.get_model_config(num_classes=num_classes),
+        train_data=self._train_data_config)
+    task = sentence_prediction.SentencePredictionTask(task_config)
+    model = task.build_model()
+
+    test_data_path = os.path.join(self.get_temp_dir(), "test.tf_record")
+    seq_length = 16
+    num_examples = 100
+    _create_fake_dataset(
+        test_data_path,
+        seq_length=seq_length,
+        num_classes=num_classes,
+        num_examples=num_examples)
+
+    test_data_config = (
+        sentence_prediction_dataloader.SentencePredictionDataConfig(
+            input_path=test_data_path,
+            seq_length=seq_length,
+            is_training=False,
+            label_type="int" if num_classes > 1 else "float",
+            global_batch_size=16,
+            drop_remainder=False))
+
+    predictions = sentence_prediction.predict(task, test_data_config, model)
+    self.assertLen(predictions, num_examples)
+
 if __name__ == "__main__":
   tf.test.main()
official/nlp/tasks/tagging.py

@@ -262,7 +262,7 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
           label_mask=label_mask,
           sentence_ids=sentence_ids)

-    outputs = tf.distribute.get_strategy().experimental_run_v2(
+    outputs = tf.distribute.get_strategy().run(
         _replicated_step, args=(next(iterator),))
     return tf.nest.map_structure(
         tf.distribute.get_strategy().experimental_local_results, outputs)
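This one-line change tracks a TF 2.x API rename: tf.distribute.Strategy.experimental_run_v2 became Strategy.run (the old name was deprecated around TF 2.2). The calling convention is unchanged, so the migration is mechanical; a minimal sketch:

import tensorflow as tf

strategy = tf.distribute.get_strategy()

def step_fn(x):
  return x * 2.0

# Before: strategy.experimental_run_v2(step_fn, args=(inputs,))
# After:
per_replica = strategy.run(step_fn, args=(tf.constant(1.0),))
local = tf.nest.map_structure(strategy.experimental_local_results, per_replica)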
official/staging/training/controller.py  (deleted, 100644 → 0)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A light weight utilities to train TF2 models."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import time
from absl import logging

import tensorflow.compat.v2 as tf
from typing import Callable, Dict, Optional, Text

from official.staging.training import utils


class Controller(object):
  """Class that facilitates training and evaluation of models."""

  def __init__(
      self,
      strategy: Optional[tf.distribute.Strategy] = None,
      train_fn: Optional[Callable[[tf.Tensor],
                                  Optional[Dict[Text, tf.Tensor]]]] = None,
      eval_fn: Optional[Callable[[tf.Tensor],
                                 Optional[Dict[Text, tf.Tensor]]]] = None,
      global_step: Optional[tf.Variable] = None,
      # Train related
      train_steps: Optional[int] = None,
      steps_per_loop: Optional[int] = None,
      summary_dir: Optional[Text] = None,
      checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
      # summary related
      summary_interval: Optional[int] = None,
      # Evaluation related
      eval_summary_dir: Optional[Text] = None,
      eval_steps: Optional[int] = None,
      eval_interval: Optional[int] = None):
    """Constructs a `Controller` instance.

    Args:
      strategy: An instance of `tf.distribute.Strategy`.
      train_fn: A callable defined as `def train_fn(num_steps)`, which
        `num_steps` indicates the number of steps to run for each loop.
      eval_fn: A callable defined as `def eval_fn(num_steps)`, which `num_steps`
        indicates the number of steps for one evaluation.
      global_step: An integer `tf.Variable` indicating the global training step
        number. Usually this can be obtained from `iterations` property of the
        model's optimizer (e.g. `self.optimizer.iterations`), or users can
        create their own global step variable as well. If the users create
        their own global step variable, it is recommended to create the
        `tf.Variable` inside strategy scope, and with
        `aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA`.
      train_steps: The total (maximum) number of training steps to perform.
      steps_per_loop: The number of steps to run in each "inner loop" of
        training (passed to the `num_steps` parameter of `train_fn`).
      summary_dir: The directory to restore and write checkpoints and
        summaries. If None, it will be set to `checkpoint_manager.directory`.
      checkpoint_manager: An instance of `tf.train.CheckpointManager`.
      summary_interval: Step interval for training summaries. Note that this
        argument only applies to the summaries outside the training loop. If
        the value is None, then training summaries are not enabled.
      eval_summary_dir: The directory to write eval summaries. If None, it will
        be set to `summary_dir`.
      eval_steps: Number of steps to run evaluation.
      eval_interval: Step interval for evaluation. If None, will skip
        evaluation in the middle of training. Note that evaluation only happens
        outside the training loop, which the loop iteration is specify by
        `steps_per_loop` parameter.

    Raises:
      ValueError: If both `train_fn` and `eval_fn` are None.
      ValueError: If `train_fn` is not None and `train_steps` is None.
      ValueError: If `steps_per_loop` is None when `train_fn` is provided.
      ValueError: If `steps_per_loop` is not a positive integer.
    """
    if train_fn is None and eval_fn is None:
      raise ValueError("`train_fn` and `eval_fn` should not both be None")

    # TODO(rxsang): Support training until exhaustion by passing
    # `train_steps=-1`. Currently it cannot be supported with a host training
    # loop because break statements are not supported with distributed dataset.
    if train_fn is not None:
      if train_steps is None:
        raise ValueError("`train_steps` is required when `train_fn` is "
                         "provided.")
      if steps_per_loop is None:
        raise ValueError("`steps_per_loop` is required when `train_fn is "
                         "provided.")
      if not isinstance(steps_per_loop, int) or steps_per_loop < 1:
        raise ValueError("`steps_per_loop` should be a positive integer")
    if summary_interval is not None and summary_interval <= 0:
      raise ValueError("`summary_interval` should be larger than 0")

    self.strategy = strategy or tf.distribute.get_strategy()

    self.train_fn = train_fn
    self.eval_fn = eval_fn
    self.global_step = global_step
    self.checkpoint_manager = checkpoint_manager

    if self.train_fn is not None:
      self.train_steps = train_steps
      self.steps_per_loop = steps_per_loop
      if summary_dir:
        self.summary_dir = summary_dir
      elif checkpoint_manager:
        self.summary_dir = checkpoint_manager.directory
      else:
        self.summary_dir = None

      self.summary_interval = summary_interval
      if self.summary_dir and self.summary_interval:
        summary_writer = tf.summary.create_file_writer(self.summary_dir)
      else:
        summary_writer = None
      # TODO(rxsang): Consider pass SummaryManager directly into Controller for
      # maximum customizability.
      self.summary_manager = utils.SummaryManager(
          summary_writer,
          tf.summary.scalar,
          global_step=self.global_step,
          summary_interval=self.summary_interval)

    if self.eval_fn is not None:
      eval_summary_dir = eval_summary_dir or self.summary_dir
      eval_summary_writer = tf.summary.create_file_writer(
          eval_summary_dir) if eval_summary_dir else None
      self.eval_summary_manager = utils.SummaryManager(
          eval_summary_writer, tf.summary.scalar, global_step=self.global_step)

      self.eval_steps = eval_steps
      self.eval_interval = eval_interval

      # Creates and initializes the interval triggers.
      self.eval_trigger = utils.IntervalTrigger(
          self.eval_interval,
          self.global_step.numpy())  # pytype: disable=attribute-error

    if self.global_step:
      tf.summary.experimental.set_step(self.global_step)

    # Restores the model if needed.
    if self.checkpoint_manager is not None:
      model_restored = self._restore_model()
      if not model_restored and self.checkpoint_manager.checkpoint_interval:
        # If the model is not restored from a checkpoint, save an initial
        # checkpoint.
        ckpt_path = self.checkpoint_manager.save(
            checkpoint_number=self.global_step)
        logging.info("Saved checkpoins in %s", ckpt_path)

  def _restore_model(self, checkpoint_path=None):
    """Restore or initialize the model.

    Args:
      checkpoint_path: An optional string indicates the checkpoint path to
        restore. If None, will restore from `self.checkpoint_manager`.

    Returns:
      True if the latest checkpoint is found or restored. Otherwise False.
    """
    with self.strategy.scope():
      # Checkpoint restoring should be inside scope. b/139450638
      if checkpoint_path is not None:
        self.checkpoint_manager.checkpoint.restore(checkpoint_path)
        return True
      return self.checkpoint_manager.restore_or_initialize()

  def _evaluate_once(self, current_step):
    """Runs the evaluation once."""
    logging.info("Start evaluation at step: %s", current_step)

    with self.eval_summary_manager.summary_writer.as_default():
      eval_outputs = self.eval_fn(self.eval_steps)

    if eval_outputs:
      eval_outputs = tf.nest.map_structure(lambda x: x.numpy(), eval_outputs)

    info = "step: {} evaluation metric: {}".format(current_step, eval_outputs)
    self._log_info(info)

    self.eval_summary_manager.write_summaries(eval_outputs)
    self.eval_summary_manager.flush()

  def _maybe_save_checkpoints(self, current_step, force_trigger=False):
    if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
      ckpt_path = self.checkpoint_manager.save(
          checkpoint_number=current_step, check_interval=not force_trigger)
      if ckpt_path is not None:
        logging.info("Saved checkpoins in %s", ckpt_path)

  def _maybe_evaluate(self, current_step, force_trigger=False):
    if self.eval_trigger(current_step, force_trigger):
      self._evaluate_once(current_step)

  def _log_info(self, message):
    """Logs `message` to the `info` log, and also prints to stdout."""
    logging.info(message)
    print(message)

  def train(self, evaluate=True):
    """Runs the training, with optional evaluation.

    This handles evaluation, gathering summaries, and saving checkpoints.

    Args:
      evaluate: A boolean indicates whether to perform evaluation during
        training.

    Raises:
      RuntimeError: If `global_step` is not updated correctly in `train_fn`.
    """
    if self.train_fn is None:
      raise ValueError("`self.train_fn` is required when calling `train` "
                       "method.")
    if self.global_step is None:
      raise ValueError("`self.global_step` is required when calling `train` "
                       "method.")
    if evaluate and self.eval_fn is None:
      raise ValueError("`self.eval_fn` is required when calling `train` "
                       "method with `evaluate=True`")

    step_timer = _StepTimer(self.global_step)
    current_step = self.global_step.numpy()
    logging.info("Train at step %s of %s", current_step, self.train_steps)
    while current_step < self.train_steps:
      # Calculates steps to run for the next train loop.
      steps_per_loop = min(self.train_steps - current_step,
                           self.steps_per_loop)
      logging.info("Entering training loop with %s steps, at step %s of %s",
                   steps_per_loop, current_step, self.train_steps)
      current_step += steps_per_loop
      steps_per_loop = tf.convert_to_tensor(steps_per_loop, dtype=tf.int32)

      with self.summary_manager.summary_writer.as_default():
        train_outputs = self.train_fn(steps_per_loop)

      # Updates and verifies the current step after a training loop finishes.
      if current_step != self.global_step.numpy():
        raise RuntimeError("`self.train_fn` is not updating `global_step` "
                           "correctly, expected: %s, actual: %s" %
                           (current_step, self.global_step.numpy()))

      # Print information like metrics and steps_per_second after a training
      # loop.
      if train_outputs:
        train_outputs = tf.nest.map_structure(lambda x: x.numpy(),
                                              train_outputs)
      steps_per_second = step_timer.steps_per_second()
      info = "step: {} steps_per_second: {:.2f} {}".format(
          current_step, steps_per_second, train_outputs)
      self._log_info(info)

      train_outputs = train_outputs or {}
      train_outputs["steps_per_second"] = steps_per_second
      self.summary_manager.write_summaries(train_outputs)

      self._maybe_save_checkpoints(current_step)

      if evaluate:
        self._maybe_evaluate(current_step)

    self.summary_manager.write_summaries(train_outputs, always_write=True)
    self.summary_manager.flush()
    self._maybe_save_checkpoints(current_step, force_trigger=True)
    if evaluate:
      self._maybe_evaluate(current_step, force_trigger=True)

  def evaluate(self, continuous=False, timeout_fn=None):
    """Runs the evaluation.

    Args:
      continuous: If `True`, will continously monitor the checkpoint directory
        to evaluate on the latest checkpoint. If `False`, will do the
        evaluation once.
      timeout_fn: Optional callable to call after a timeout. If the function
        returns True, then it means that no new checkpoints will be generated
        and the iterator will exit.

    Raises:
      ValueError: If no checkpoint found in
        `self.checkpoint_manager.directory`.
    """
    if self.eval_fn is None:
      raise ValueError("`self.eval_fn` should not be None to call "
                       "`evaluate()` method.")

    if not continuous and timeout_fn is not None:
      raise ValueError("`timeout_fn` can be only passed when `continuous` is "
                       "True")

    if continuous:
      for checkpoint_path in tf.train.checkpoints_iterator(
          self.checkpoint_manager.directory, timeout_fn=timeout_fn):
        self._restore_model(checkpoint_path)
        self._evaluate_once(self.global_step.numpy())
      return

    latest_checkpoint = self.checkpoint_manager.latest_checkpoint
    if not latest_checkpoint:
      raise ValueError("no checkpoint found in dir %s" %
                       self.checkpoint_manager.directory)
    self._restore_model()
    self._evaluate_once(self.global_step.numpy())


class _StepTimer(object):
  """Utility class for measuring steps/second."""

  def __init__(self, step):
    self.step = step
    self.start()

  def start(self):
    self.last_iteration = self.step.numpy()
    self.last_time = time.time()

  def steps_per_second(self, restart=True):
    value = ((self.step.numpy() - self.last_iteration) /
             (time.time() - self.last_time))
    if restart:
      self.start()
    return value
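This merge deletes the official/staging/training loop utilities outright while touching orbit/controller.py in the same commit (see the diffstat above), consistent with this staging code having graduated into the Orbit package. For reference, the deleted Controller was driven as below; the sketch follows its own docstring and the tests that follow, with my_runnable standing in for a user-defined trainable/evaluable object:

import tensorflow as tf
from official.staging.training import controller  # removed by this commit

# my_runnable exposes train(num_steps), evaluate(num_steps) and a global_step
# variable (typically the optimizer's `iterations`).
ctrl = controller.Controller(
    strategy=tf.distribute.get_strategy(),
    train_fn=my_runnable.train,
    eval_fn=my_runnable.evaluate,
    global_step=my_runnable.global_step,
    train_steps=10,
    steps_per_loop=2,
    eval_steps=2,
    eval_interval=5)
ctrl.train(evaluate=True)  # interleaves eval every `eval_interval` steps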
official/staging/training/controller_test.py
deleted
100644 → 0
View file @
ea550ca9
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.staging.training.controller."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.staging.training
import
controller
from
official.staging.training
import
standard_runnable
def
all_strategy_combinations
():
"""Gets combinations of distribution strategies."""
return
combinations
.
combine
(
strategy
=
[
strategy_combinations
.
one_device_strategy
,
strategy_combinations
.
tpu_strategy
,
strategy_combinations
.
one_device_strategy_gpu
,
strategy_combinations
.
mirrored_strategy_with_gpu_and_cpu
,
],
mode
=
"eager"
,
)
def
create_model
():
x
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
3
,),
name
=
"input"
)
y
=
tf
.
keras
.
layers
.
Dense
(
4
,
name
=
"dense"
)(
x
)
model
=
tf
.
keras
.
Model
(
x
,
y
)
return
model
def
summaries_with_matching_keyword
(
keyword
,
summary_dir
):
"""Yields summary protos matching given keyword from event file."""
event_paths
=
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
summary_dir
,
"events*"
))
for
event
in
tf
.
compat
.
v1
.
train
.
summary_iterator
(
event_paths
[
-
1
]):
if
event
.
summary
is
not
None
:
for
value
in
event
.
summary
.
value
:
if
keyword
in
value
.
tag
:
tf
.
compat
.
v1
.
logging
.
error
(
event
)
yield
event
.
summary
def
check_eventfile_for_keyword
(
keyword
,
summary_dir
):
"""Checks event files for the keyword."""
return
any
(
summaries_with_matching_keyword
(
keyword
,
summary_dir
))
def
dataset_fn
(
ctx
):
del
ctx
inputs
=
np
.
zeros
((
10
,
3
),
dtype
=
np
.
float32
)
targets
=
np
.
zeros
((
10
,
4
),
dtype
=
np
.
float32
)
dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
((
inputs
,
targets
))
dataset
=
dataset
.
repeat
(
100
)
dataset
=
dataset
.
batch
(
10
,
drop_remainder
=
True
)
return
dataset
class
TestRunnable
(
standard_runnable
.
StandardTrainable
,
standard_runnable
.
StandardEvaluable
):
"""Implements the training and evaluation APIs for the test model."""
def
__init__
(
self
):
standard_runnable
.
StandardTrainable
.
__init__
(
self
)
standard_runnable
.
StandardEvaluable
.
__init__
(
self
)
self
.
strategy
=
tf
.
distribute
.
get_strategy
()
self
.
model
=
create_model
()
self
.
optimizer
=
tf
.
keras
.
optimizers
.
RMSprop
()
self
.
global_step
=
self
.
optimizer
.
iterations
self
.
train_loss
=
tf
.
keras
.
metrics
.
Mean
(
"train_loss"
,
dtype
=
tf
.
float32
)
self
.
eval_loss
=
tf
.
keras
.
metrics
.
Mean
(
"eval_loss"
,
dtype
=
tf
.
float32
)
def
build_train_dataset
(
self
):
return
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
def
train_step
(
self
,
iterator
):
def
_replicated_step
(
inputs
):
"""Replicated training step."""
inputs
,
targets
=
inputs
with
tf
.
GradientTape
()
as
tape
:
outputs
=
self
.
model
(
inputs
)
loss
=
tf
.
math
.
reduce_sum
(
outputs
-
targets
)
grads
=
tape
.
gradient
(
loss
,
self
.
model
.
variables
)
self
.
optimizer
.
apply_gradients
(
zip
(
grads
,
self
.
model
.
variables
))
self
.
train_loss
.
update_state
(
loss
)
self
.
strategy
.
run
(
_replicated_step
,
args
=
(
next
(
iterator
),))
def
train_loop_end
(
self
):
return
{
"loss"
:
self
.
train_loss
.
result
(),
}
def
build_eval_dataset
(
self
):
return
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
def
eval_begin
(
self
):
self
.
eval_loss
.
reset_states
()
def
eval_step
(
self
,
iterator
):
def
_replicated_step
(
inputs
):
"""Replicated evaluation step."""
inputs
,
targets
=
inputs
outputs
=
self
.
model
(
inputs
)
loss
=
tf
.
math
.
reduce_sum
(
outputs
-
targets
)
self
.
eval_loss
.
update_state
(
loss
)
self
.
strategy
.
run
(
_replicated_step
,
args
=
(
next
(
iterator
),))
def
eval_end
(
self
):
return
{
"eval_loss"
:
self
.
eval_loss
.
result
(),
}
class ControllerTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(ControllerTest, self).setUp()
    self.model_dir = self.get_temp_dir()

  def test_no_checkpoint(self):
    test_runnable = TestRunnable()
    # No checkpoint manager and no strategy.
    test_controller = controller.Controller(
        train_fn=test_runnable.train,
        eval_fn=test_runnable.evaluate,
        global_step=test_runnable.global_step,
        train_steps=10,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        eval_steps=2,
        eval_interval=5)
    test_controller.train(evaluate=True)
    self.assertEqual(test_runnable.global_step.numpy(), 10)
    # Loss and accuracy values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
    # No checkpoint, so global step starts from 0.
    test_runnable.global_step.assign(0)
    test_controller.train(evaluate=True)
    self.assertEqual(test_runnable.global_step.numpy(), 10)

  def test_no_checkpoint_and_summaries(self):
    test_runnable = TestRunnable()
    # No checkpoint + summary directories.
    test_controller = controller.Controller(
        train_fn=test_runnable.train,
        eval_fn=test_runnable.evaluate,
        global_step=test_runnable.global_step,
        train_steps=10,
        steps_per_loop=2,
        eval_steps=2,
        eval_interval=5)
    test_controller.train(evaluate=True)
    self.assertEqual(test_runnable.global_step.numpy(), 10)

  @combinations.generate(all_strategy_combinations())
  def test_train_and_evaluate(self, strategy):
    with strategy.scope():
      test_runnable = TestRunnable()

    checkpoint = tf.train.Checkpoint(
        model=test_runnable.model, optimizer=test_runnable.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runnable.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        strategy=strategy,
        train_fn=test_runnable.train,
        eval_fn=test_runnable.evaluate,
        global_step=test_runnable.global_step,
        train_steps=10,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        eval_steps=2,
        eval_interval=5)
    test_controller.train(evaluate=True)

    # Checkpoints are saved.
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

    # Loss and accuracy values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  @combinations.generate(all_strategy_combinations())
  def test_train_only(self, strategy):
    with strategy.scope():
      test_runnable = TestRunnable()

    checkpoint = tf.train.Checkpoint(
        model=test_runnable.model, optimizer=test_runnable.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runnable.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        strategy=strategy,
        train_fn=test_runnable.train,
        global_step=test_runnable.global_step,
        train_steps=10,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
    )
    test_controller.train(evaluate=False)

    # Checkpoints are saved.
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  @combinations.generate(all_strategy_combinations())
  def test_evaluate_only(self, strategy):
    with strategy.scope():
      test_runnable = TestRunnable()

    checkpoint = tf.train.Checkpoint(model=test_runnable.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))

    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runnable.global_step)
    test_controller = controller.Controller(
        strategy=strategy,
        eval_fn=test_runnable.evaluate,
        global_step=test_runnable.global_step,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        eval_steps=2,
        eval_interval=5)
    test_controller.evaluate()

    # Only eval summaries are written.
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))


if __name__ == "__main__":
  tf.test.main()
official/staging/training/runnable.py
deleted
100644 → 0
View file @
ea550ca9
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
import
six
import
tensorflow.compat.v2
as
tf
from
typing
import
Dict
,
Optional
,
Text
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
AbstractTrainable
(
tf
.
Module
):
"""An abstract class defining the APIs required for training."""
@
abc
.
abstractmethod
def
train
(
self
,
num_steps
:
Optional
[
tf
.
Tensor
])
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
"""Implements model training with multiple steps.
In training, it is common to break the total training steps into several
training loops, so users can do checkpointing, write summaries and run some
python callbacks. This is necessary for getting good performance in TPU
training, as the overhead for launching a multi worker tf.function may be
large in Eager mode. It is usually encouraged to create a host training loop
(e.g. using a `tf.range` wrapping `strategy.run` inside a
`tf.function`) in the TPU case. For the cases that don't require host
training loop to acheive peak performance, users can just implement a simple
python loop to drive each step.
Args:
num_steps: A guideline for how many training steps to run. Note that it is
up to the model what constitutes a "step" (this may involve more than
one update to model parameters, e.g. if training a GAN).
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries.
"""
pass
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
AbstractEvaluable
(
tf
.
Module
):
"""An abstract class defining the APIs required for evaluation."""
@
abc
.
abstractmethod
def
evaluate
(
self
,
num_steps
:
Optional
[
tf
.
Tensor
])
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
"""Implements model evaluation.
Args:
num_steps: A guideline for how many evaluation steps to run. Note that it
is up to the model what constitutes a "step". Generally, it may be
desirable to support both a limited number of eval steps and iterating
over a full dataset (however many steps are required) when `num_steps`
is `None`.
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries.
"""
pass
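# --- Illustrative sketch (not part of the original file) ---
# The host training loop mentioned in the `train` docstring above could look
# roughly like this; `strategy`, `replicated_step` and `iterator` are assumed
# to be supplied by the implementer.
@tf.function
def host_loop(iterator, num_steps):
  # tf.range keeps the whole multi-step loop inside one traced function,
  # avoiding per-step launch overhead in the TPU case.
  for _ in tf.range(num_steps):
    strategy.run(replicated_step, args=(next(iterator),))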
official/staging/training/standard_runnable.py
deleted
100644 → 0
View file @
ea550ca9
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
import
six
import
tensorflow.compat.v2
as
tf
from
typing
import
Dict
,
Optional
,
Text
from
official.staging.training
import
runnable
from
official.staging.training
import
utils
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
StandardTrainable
(
runnable
.
AbstractTrainable
):
"""Implements the standard functionality of AbstractTrainable APIs."""
def
__init__
(
self
,
use_tf_while_loop
=
True
,
use_tf_function
=
True
):
if
use_tf_while_loop
and
not
use_tf_function
:
raise
ValueError
(
"`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported"
)
self
.
use_tf_while_loop
=
use_tf_while_loop
self
.
use_tf_function
=
use_tf_function
self
.
train_dataset
=
None
self
.
train_iter
=
None
self
.
train_loop_fn
=
None
@
abc
.
abstractmethod
def
build_train_dataset
(
self
):
"""Builds the training datasets.
Returns:
A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
"""
pass
def
train
(
self
,
num_steps
:
Optional
[
tf
.
Tensor
])
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
"""See base class."""
if
self
.
train_dataset
is
None
:
# Build train input dataset
self
.
train_dataset
=
self
.
build_train_dataset
()
self
.
train_iter
=
tf
.
nest
.
map_structure
(
iter
,
self
.
train_dataset
)
if
self
.
train_loop_fn
is
None
:
train_fn
=
self
.
train_step
if
self
.
use_tf_while_loop
:
self
.
train_loop_fn
=
utils
.
create_tf_while_loop_fn
(
train_fn
)
else
:
if
self
.
use_tf_function
:
train_fn
=
tf
.
function
(
train_fn
)
self
.
train_loop_fn
=
utils
.
create_loop_fn
(
train_fn
)
self
.
train_loop_begin
()
self
.
train_loop_fn
(
self
.
train_iter
,
num_steps
)
return
self
.
train_loop_end
()
def
train_loop_begin
(
self
):
"""Called once at the beginning of the training loop.
This is a good place to reset metrics that accumulate values over multiple
steps of training.
"""
pass
@
abc
.
abstractmethod
def
train_step
(
self
,
iterator
):
"""Implements one step of training.
What a "step" consists of is up to the implementer. If using distribution
strategies, the call to this method should take place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls
to `strategy.run`.
Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or
DistributedIterator.
"""
pass
def
train_loop_end
(
self
)
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
"""Called at the end of the training loop.
This is a good place to get metric results. The value returned from this
function will be returned as-is from the train() method.
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries.
"""
pass
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
StandardEvaluable
(
runnable
.
AbstractEvaluable
):
"""Implements the standard functionality of AbstractEvaluable APIs."""
def
__init__
(
self
,
use_tf_function
=
True
):
self
.
eval_use_tf_function
=
use_tf_function
self
.
eval_dataset
=
None
self
.
eval_loop_fn
=
None
@
abc
.
abstractmethod
def
build_eval_dataset
(
self
):
"""Builds the evaluation datasets.
Returns:
A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
"""
pass
def
evaluate
(
self
,
num_steps
:
Optional
[
tf
.
Tensor
])
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
"""See base class."""
if
self
.
eval_dataset
is
None
:
# Build train input dataset
self
.
eval_dataset
=
self
.
build_eval_dataset
()
if
self
.
eval_loop_fn
is
None
:
eval_fn
=
self
.
eval_step
if
self
.
eval_use_tf_function
:
eval_fn
=
tf
.
function
(
eval_fn
)
self
.
eval_loop_fn
=
utils
.
create_loop_fn
(
eval_fn
)
eval_iter
=
tf
.
nest
.
map_structure
(
iter
,
self
.
eval_dataset
)
self
.
eval_begin
()
self
.
eval_loop_fn
(
eval_iter
,
num_steps
)
return
self
.
eval_end
()
def
eval_begin
(
self
):
"""Called once at the beginning of the evaluation.
This is a good place to reset metrics that accumulate values over the entire
evaluation.
"""
pass
@
abc
.
abstractmethod
def
eval_step
(
self
,
iterator
):
"""Implements one step of evaluation.
What a "step" consists of is up to the implementer. If using distribution
strategies, the call to this method should take place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls
to `strategy.run`.
Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or
DistributedIterator.
"""
pass
def
eval_end
(
self
)
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
"""Called at the end of the evaluation.
This is a good place to get metric results. The value returned from this
function will be returned as-is from the evaluate() method.
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries.
"""
pass
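# --- Illustrative sketch (not part of the original file) ---
# A minimal concrete StandardTrainable might look like this; `my_model`,
# `my_optimizer` and `my_dataset_fn` are hypothetical names. The real-world
# usage is the ResnetRunnable class further down in this commit.
class MySimpleTrainable(StandardTrainable):

  def build_train_dataset(self):
    return utils.make_distributed_dataset(
        tf.distribute.get_strategy(), my_dataset_fn)

  def train_step(self, iterator):

    def step_fn(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        loss = tf.math.reduce_sum(my_model(features) - labels)
      grads = tape.gradient(loss, my_model.trainable_variables)
      my_optimizer.apply_gradients(zip(grads, my_model.trainable_variables))

    tf.distribute.get_strategy().run(step_fn, args=(next(iterator),))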
official/staging/training/utils.py
deleted
100644 → 0
View file @
ea550ca9
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
import
inspect
import
six
import
tensorflow.compat.v2
as
tf
def
create_loop_fn
(
step_fn
):
"""Creates a multiple steps function driven by the python while loop.
Args:
step_fn: A function which takes `iterator` as input.
Returns:
A callable defined as the `loop_fn` defination below.
"""
def
loop_fn
(
iterator
,
num_steps
,
state
=
None
,
reduce_fn
=
None
):
"""A loop function with multiple steps.
Args:
iterator: A nested structure of tf.data `Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. If `num_steps==-1`, will
iterate until exausting the iterator.
state: An optional initial state before running the loop.
reduce_fn: a callable defined as `def reduce_fn(state, value)`, where
`value` is the outputs from `step_fn`.
Returns:
The updated state.
"""
try
:
step
=
0
# To make sure the OutOfRangeError exception can be handled well with
# async remote eager, we need to wrap the loop body in a `async_scope`.
with
tf
.
experimental
.
async_scope
():
while
(
num_steps
==
-
1
or
step
<
num_steps
):
outputs
=
step_fn
(
iterator
)
if
reduce_fn
is
not
None
:
state
=
reduce_fn
(
state
,
outputs
)
step
+=
1
return
state
except
(
StopIteration
,
tf
.
errors
.
OutOfRangeError
):
tf
.
experimental
.
async_clear_error
()
return
state
return
loop_fn
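# --- Usage sketch (illustrative, not part of the original file) ---
# Assumes a step function `my_step_fn` that returns a scalar loss tensor and a
# tf.data.Dataset `my_dataset`; both names are hypothetical.
loop_fn = create_loop_fn(my_step_fn)

def sum_outputs(state, value):
  # Accumulates step outputs; `state` is None on the first step.
  return value if state is None else state + value

total_loss = loop_fn(
    iter(my_dataset), num_steps=100, state=None, reduce_fn=sum_outputs)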
def create_tf_while_loop_fn(step_fn):
  """Creates a multiple-steps function driven by tf.while_loop on the host.

  Args:
    step_fn: A function which takes `iterator` as input.

  Returns:
    A callable defined as the `loop_fn` definition below.
  """

  @tf.function
  def loop_fn(iterator, num_steps):
    """A loop function with multiple steps.

    Args:
      iterator: A nested structure of tf.data `Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Must be a tf.Tensor.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError("`num_steps` should be a `tf.Tensor`. A Python object "
                       "may cause retracing.")

    for _ in tf.range(num_steps):
      step_fn(iterator)

  return loop_fn
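# --- Usage sketch (illustrative, not part of the original file) ---
# Passing `num_steps` as a tf.Tensor is what prevents the tf.function from
# retracing when the step count changes; `my_step_fn` and `my_dataset` are
# hypothetical names.
tf_loop_fn = create_tf_while_loop_fn(my_step_fn)
tf_loop_fn(iter(my_dataset), num_steps=tf.constant(100, dtype=tf.int32))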
def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
  """A helper function to create a distributed dataset.

  Args:
    strategy: An instance of `tf.distribute.Strategy`.
    dataset_or_fn: An instance of `tf.data.Dataset` or a function which takes a
      `tf.distribute.InputContext` as input and returns a `tf.data.Dataset`. If
      it is a function, it could optionally have an argument named
      `input_context` which is of `tf.distribute.InputContext` argument type.
    *args: The list of arguments to be passed to dataset_or_fn.
    **kwargs: Any keyword arguments to be passed.

  Returns:
    A distributed Dataset.
  """
  if strategy is None:
    strategy = tf.distribute.get_strategy()

  if isinstance(dataset_or_fn, tf.data.Dataset):
    return strategy.experimental_distribute_dataset(dataset_or_fn)

  if not callable(dataset_or_fn):
    raise ValueError("`dataset_or_fn` should be either callable or an instance "
                     "of `tf.data.Dataset`")

  def dataset_fn(ctx):
    """Wrapped dataset function for creating a distributed dataset."""
    # If `dataset_or_fn` is a function and has `input_context` among its
    # argument names, pass `ctx` as the value of `input_context` when calling
    # `dataset_or_fn`. Otherwise `ctx` will not be used when calling
    # `dataset_or_fn`.
    if six.PY3:
      argspec = inspect.getfullargspec(dataset_or_fn)
    else:
      argspec = inspect.getargspec(dataset_or_fn)

    args_names = argspec.args

    if "input_context" in args_names:
      kwargs["input_context"] = ctx
    ds = dataset_or_fn(*args, **kwargs)
    return ds

  return strategy.experimental_distribute_datasets_from_function(dataset_fn)
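# --- Usage sketch (illustrative, not part of the original file) ---
# Because this hypothetical dataset function declares an `input_context`
# argument, make_distributed_dataset passes the per-worker context through,
# which allows sharding across input pipelines. `my_strategy` is assumed to be
# an existing tf.distribute.Strategy.
def my_dataset_fn(batch_size, input_context=None):
  ds = tf.data.Dataset.range(1024)
  if input_context is not None:
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
  return ds.batch(batch_size)

dist_ds = make_distributed_dataset(my_strategy, my_dataset_fn, batch_size=32)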
class SummaryManager(object):
  """A class that manages writing summaries."""

  def __init__(self,
               summary_writer,
               summary_fn,
               global_step=None,
               summary_interval=None):
    """Constructs a summary manager object.

    Args:
      summary_writer: A `tf.summary.SummaryWriter` instance for writing
        summaries.
      summary_fn: A callable defined as `def summary_fn(name, tensor,
        step=None)`, which describes the summary operation.
      global_step: A `tf.Variable` instance for checking the current global
        step value, in case users want to save summaries every N steps.
      summary_interval: An integer indicating the minimum step interval between
        two summaries.
    """
    if summary_writer is not None:
      self._summary_writer = summary_writer
      self._enabled = True
    else:
      self._summary_writer = tf.summary.create_noop_writer()
      self._enabled = False
    self._summary_fn = summary_fn

    if global_step is None:
      self._global_step = tf.summary.experimental.get_step()
    else:
      self._global_step = global_step

    if summary_interval is not None:
      if self._global_step is None:
        raise ValueError("`summary_interval` is not None, but no "
                         "`global_step` can be obtained")
      self._last_summary_step = self._global_step.numpy()
    self._summary_interval = summary_interval

  @property
  def summary_interval(self):
    return self._summary_interval

  @property
  def summary_writer(self):
    """Returns the underlying summary writer."""
    return self._summary_writer

  def flush(self):
    """Flushes the underlying summary writer."""
    if self._enabled:
      tf.summary.flush(self._summary_writer)

  def write_summaries(self, items, always_write=True):
    """Writes a bulk of summaries.

    Args:
      items: a dictionary of `Tensors` for writing summaries.
      always_write: An optional boolean. If `True`, the manager will always
        write summaries unless the summaries have been written for the same
        step. Otherwise the manager will only write the summaries if the
        interval between summaries is larger than `summary_interval`.

    Returns:
      A boolean indicating whether the summaries are written or not.
    """
    # TODO(rxsang): Support writing summaries with nested structure, so users
    # can split the summaries into different directories for nicer
    # visualization in Tensorboard, like train and eval metrics.
    if not self._enabled:
      return False

    if self._summary_interval is not None:
      current_step = self._global_step.numpy()
      if current_step == self._last_summary_step:
        return False
      if not always_write and current_step < (self._last_summary_step +
                                              self._summary_interval):
        return False
      self._last_summary_step = current_step

    with self._summary_writer.as_default():
      for name, tensor in items.items():
        self._summary_fn(name, tensor, step=self._global_step)
    return True
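# --- Usage sketch (illustrative, not part of the original file) ---
# Writes scalar summaries at most every 100 steps; `my_global_step` and
# `my_loss` are hypothetical names for an existing step counter and metric.
manager = SummaryManager(
    summary_writer=tf.summary.create_file_writer("/tmp/summaries"),
    summary_fn=tf.summary.scalar,
    global_step=my_global_step,
    summary_interval=100)
manager.write_summaries({"loss": my_loss}, always_write=False)
manager.flush()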
@six.add_metaclass(abc.ABCMeta)
class Trigger(object):
  """An abstract class representing a "trigger" for some event."""

  @abc.abstractmethod
  def __call__(self, value: float, force_trigger=False):
    """Maybe trigger the event based on the given value.

    Args:
      value: the value for triggering.
      force_trigger: Whether the trigger is forcibly triggered.

    Returns:
      `True` if the trigger is triggered on the given `value`, and
      `False` otherwise.
    """

  @abc.abstractmethod
  def reset(self):
    """Resets states in the trigger."""


class IntervalTrigger(Trigger):
  """Triggers on every fixed interval."""

  def __init__(self, interval, start=0):
    """Constructs the IntervalTrigger.

    Args:
      interval: The triggering interval.
      start: An initial value for the trigger.
    """
    self._interval = interval
    self._last_trigger_value = start

  def __call__(self, value, force_trigger=False):
    """Maybe trigger the event based on the given value.

    Args:
      value: the value for triggering.
      force_trigger: If True, the trigger will be forcibly triggered unless the
        last trigger value is equal to `value`.

    Returns:
      `True` if the trigger is triggered on the given `value`, and
      `False` otherwise.
    """
    if force_trigger and value != self._last_trigger_value:
      self._last_trigger_value = value
      return True

    if self._interval and self._interval > 0:
      if value >= self._last_trigger_value + self._interval:
        self._last_trigger_value = value
        return True
    return False

  def reset(self):
    """See base class."""
    self._last_trigger_value = 0
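# --- Usage sketch (illustrative, not part of the original file) ---
# Gates an action so it fires at most once every 1000 steps; `my_global_step`
# and `run_evaluation` are hypothetical names.
eval_trigger = IntervalTrigger(interval=1000)
if eval_trigger(my_global_step.numpy()):
  run_evaluation()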
class EpochHelper(object):
  """A helper class to handle epochs in customized training loops."""

  def __init__(self, epoch_steps, global_step):
    """Constructs the EpochHelper.

    Args:
      epoch_steps: An integer indicating how many steps are in an epoch.
      global_step: A `tf.Variable` instance indicating the current global step.
    """
    self._epoch_steps = epoch_steps
    self._global_step = global_step
    self._current_epoch = None
    self._epoch_start_step = None
    self._in_epoch = False

  def epoch_begin(self):
    """Returns whether a new epoch should begin."""
    if self._in_epoch:
      return False
    current_step = self._global_step.numpy()
    self._epoch_start_step = current_step
    self._current_epoch = current_step // self._epoch_steps
    self._in_epoch = True
    return True

  def epoch_end(self):
    """Returns whether the current epoch should end."""
    if not self._in_epoch:
      raise ValueError("`epoch_end` can only be called inside an epoch")
    current_step = self._global_step.numpy()
    epoch = current_step // self._epoch_steps

    if epoch > self._current_epoch:
      self._in_epoch = False
      return True
    return False

  @property
  def batch_index(self):
    """Index of the next batch within the current epoch."""
    return self._global_step.numpy() - self._epoch_start_step

  @property
  def current_epoch(self):
    return self._current_epoch
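# --- Usage sketch (illustrative, not part of the original file) ---
# Tracks epoch boundaries inside a custom loop; `my_global_step`,
# `total_steps` and `train_some_steps` are hypothetical names, and
# `train_some_steps` is assumed to advance the global step.
helper = EpochHelper(epoch_steps=500, global_step=my_global_step)
while my_global_step.numpy() < total_steps:
  if helper.epoch_begin():
    print("Starting epoch", helper.current_epoch)
  train_some_steps()
  if helper.epoch_end():
    print("Finished epoch", helper.current_epoch)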
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
View file @
d56d51d0
...
@@ -14,18 +14,16 @@
 # ==============================================================================
 """Runs a ResNet model on the ImageNet dataset using custom training loops."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import math
+import os
 
 from absl import app
 from absl import flags
 from absl import logging
+import orbit
 import tensorflow as tf
 
 from official.modeling import performance
-from official.staging.training import controller
 from official.utils.flags import core as flags_core
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
...
@@ -87,15 +85,6 @@ def get_num_train_iterations(flags_obj):
   return train_steps, train_epochs, eval_steps
 
 
-def _steps_to_run(steps_in_current_epoch, steps_per_epoch, steps_per_loop):
-  """Calculates steps to run on device."""
-  if steps_per_loop <= 0:
-    raise ValueError('steps_per_loop should be positive integer.')
-  if steps_per_loop == 1:
-    return steps_per_loop
-  return min(steps_per_loop, steps_per_epoch - steps_in_current_epoch)
-
-
 def run(flags_obj):
   """Run ResNet ImageNet training and eval loop using custom training loops.
...
@@ -121,7 +110,6 @@ def run(flags_obj):
       datasets_num_private_threads=flags_obj.datasets_num_private_threads)
   common.set_cudnn_batchnorm_mode()
 
-  # TODO(anj-s): Set data_format without using Keras.
   data_format = flags_obj.data_format
   if data_format is None:
     data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
...
@@ -137,7 +125,14 @@ def run(flags_obj):
   per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
       flags_obj)
-  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
+  if flags_obj.steps_per_loop is None:
+    steps_per_loop = per_epoch_steps
+  elif flags_obj.steps_per_loop > per_epoch_steps:
+    steps_per_loop = per_epoch_steps
+    logging.warn('Setting steps_per_loop to %d to respect epoch boundary.',
+                 steps_per_loop)
+  else:
+    steps_per_loop = flags_obj.steps_per_loop
 
   logging.info(
       'Training %d epochs, each epoch has %d steps, '
...
@@ -154,8 +149,8 @@ def run(flags_obj):
   eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
   checkpoint_interval = (
-      per_epoch_steps if flags_obj.enable_checkpoint_and_export else None)
-  summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None
+      steps_per_loop * 5 if flags_obj.enable_checkpoint_and_export else None)
+  summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None
 
   checkpoint_manager = tf.train.CheckpointManager(
       runnable.checkpoint,
...
@@ -164,20 +159,24 @@ def run(flags_obj):
       step_counter=runnable.global_step,
       checkpoint_interval=checkpoint_interval)
 
-  resnet_controller = controller.Controller(
+  resnet_controller = orbit.Controller(
       strategy,
-      runnable.train,
-      runnable.evaluate if not flags_obj.skip_eval else None,
+      runnable,
+      runnable if not flags_obj.skip_eval else None,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
-      train_steps=per_epoch_steps * train_epochs,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
-      eval_steps=eval_steps,
-      eval_interval=eval_interval)
+      eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))
 
   time_callback.on_train_begin()
-  resnet_controller.train(evaluate=not flags_obj.skip_eval)
+  if not flags_obj.skip_eval:
+    resnet_controller.train_and_evaluate(
+        train_steps=per_epoch_steps * train_epochs,
+        eval_steps=eval_steps,
+        eval_interval=eval_interval)
+  else:
+    resnet_controller.train(steps=per_epoch_steps * train_epochs)
   time_callback.on_train_end()
 
   stats = build_stats(runnable, time_callback)
...
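In short, this file's migration swaps the staging `controller.Controller` (which took bare `train_fn`/`eval_fn` callables and fixed step counts at construction time) for `orbit.Controller`, which takes trainer/evaluator objects and moves the step counts onto the call itself. A hedged sketch of the new call pattern, assuming the names from the diff above (with `total_train_steps` standing in for `per_epoch_steps * train_epochs`):

    resnet_controller = orbit.Controller(
        strategy,
        runnable,                      # trainer object
        runnable,                      # same object also acts as evaluator
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        checkpoint_manager=checkpoint_manager)
    resnet_controller.train_and_evaluate(
        train_steps=total_train_steps,
        eval_steps=eval_steps,
        eval_interval=eval_interval)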
official/vision/image_classification/resnet/resnet_runnable.py
View file @
d56d51d0
...
@@ -14,33 +14,21 @@
 # ==============================================================================
 """Runs a ResNet model on the ImageNet dataset using custom training loops."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
+import orbit
 import tensorflow as tf
 
 from official.modeling import performance
 from official.staging.training import grad_utils
-from official.staging.training import standard_runnable
-from official.staging.training import utils
 from official.utils.flags import core as flags_core
 from official.vision.image_classification.resnet import common
 from official.vision.image_classification.resnet import imagenet_preprocessing
 from official.vision.image_classification.resnet import resnet_model
 
 
-class ResnetRunnable(standard_runnable.StandardTrainable,
-                     standard_runnable.StandardEvaluable):
+class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
   """Implements the training and evaluation APIs for Resnet model."""
 
   def __init__(self, flags_obj, time_callback, epoch_steps):
-    standard_runnable.StandardTrainable.__init__(self,
-                                                 flags_obj.use_tf_while_loop,
-                                                 flags_obj.use_tf_function)
-    standard_runnable.StandardEvaluable.__init__(self,
-                                                 flags_obj.use_tf_function)
     self.strategy = tf.distribute.get_strategy()
     self.flags_obj = flags_obj
     self.dtype = flags_core.get_tf_dtype(flags_obj)
...
@@ -107,11 +95,8 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
     # Handling epochs.
     self.epoch_steps = epoch_steps
-    self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
+    self.epoch_helper = orbit.utils.EpochHelper(epoch_steps, self.global_step)
 
-  def build_train_dataset(self):
-    """See base class."""
-    return utils.make_distributed_dataset(
+    train_dataset = orbit.utils.make_distributed_dataset(
         self.strategy,
         self.input_fn,
         is_training=True,
...
@@ -122,17 +107,20 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
             .datasets_num_private_threads,
         dtype=self.dtype,
         drop_remainder=True)
+    orbit.StandardTrainer.__init__(self, train_dataset,
+                                   flags_obj.use_tf_while_loop,
+                                   flags_obj.use_tf_function)
 
-  def build_eval_dataset(self):
-    """See base class."""
-    return utils.make_distributed_dataset(
-        self.strategy,
-        self.input_fn,
-        is_training=False,
-        data_dir=self.flags_obj.data_dir,
-        batch_size=self.batch_size,
-        parse_record_fn=imagenet_preprocessing.parse_record,
-        dtype=self.dtype)
+    if not flags_obj.skip_eval:
+      eval_dataset = orbit.utils.make_distributed_dataset(
+          self.strategy,
+          self.input_fn,
+          is_training=False,
+          data_dir=self.flags_obj.data_dir,
+          batch_size=self.batch_size,
+          parse_record_fn=imagenet_preprocessing.parse_record,
+          dtype=self.dtype)
+    orbit.StandardEvaluator.__init__(self, eval_dataset,
+                                     flags_obj.use_tf_function)
 
   def train_loop_begin(self):
     """See base class."""
...
orbit/controller.py
View file @
d56d51d0
...
@@ -151,8 +151,10 @@ class Controller(object):
           checkpoint_interval, steps_per_loop, interval_name="checkpoint")
 
     model_restored = self.restore_checkpoint()
-    if not model_restored and checkpoint_interval:
-      # If the model is not restored from a checkpoint, save an initial
+    if not model_restored and (checkpoint_interval and
+                               self.trainer is not None):
+      # If the model is not restored from a checkpoint, and
+      # `checkpoint_interval` is enabled for training, save an initial
       # checkpoint.
       self.save_checkpoint()
...
research/object_detection/README.md
View file @
d56d51d0
...
@@ -54,16 +54,21 @@ Note: The models we provide in [TF2 Zoo](g3doc/tf2_detection_zoo.md) and
 [TF1 Zoo](g3doc/tf1_detection_zoo.md) are specific to the TensorFlow major
 version and are not interoperable.
 
-Please select one of the two links below for TensorFlow version specific
+Please select one of the links below for TensorFlow version-specific
 documentation of the Object Detection API:
 
 <!-- mdlint off(WHITESPACE_LINE_LENGTH) -->
-### Tensorflow 2.x
-*   <a href='g3doc/tf2.md'>Object Detection API TensorFlow 2</a><br>
-*   <a href='g3doc/tf2_detection_zoo.md'>TensorFlow 2 Model Zoo</a><br>
-### Tensorflow 1.x
-*   <a href='g3doc/tf1.md'>Object Detection API TensorFlow 1</a><br>
-*   <a href='g3doc/tf1_detection_zoo.md'>TensorFlow 1 Model Zoo</a><br>
+| [](g3doc/tf2.md) | [](g3doc/tf2_detection_zoo.md) |
+|---|---|
+| [](g3doc/tf1.md) | [](g3doc/tf1_detection_zoo.md) |
 <!-- mdlint on -->
 
 ## Whats New
...
research/object_detection/g3doc/tf1.md
View file @
d56d51d0
...
@@ -73,6 +73,8 @@ the [Model Zoo](tf1_detection_zoo.md).
     Supported object detection evaluation protocols</a><br>
 *   <a href='tpu_compatibility.md'>
     TPU compatible detection pipelines</a><br>
+*   <a href='tf1_training_and_evaluation.md'>
+    Training and evaluation guide (CPU, GPU, or TPU)</a><br>
 
 ## Extras:
...
research/object_detection/g3doc/tf2.md
View file @
d56d51d0
...
@@ -80,3 +80,5 @@ We provide a large collection of models that are trained on COCO 2017 in the
     Supported object detection evaluation protocols</a><br>
 *   <a href='tpu_compatibility.md'>
     TPU compatible detection pipelines</a><br>
+*   <a href='tf2_training_and_evaluation.md'>
+    Training and evaluation guide (CPU, GPU, or TPU)</a><br>
\ No newline at end of file