Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5a2cf36f
Commit
5a2cf36f
authored
Jul 23, 2020
by
Kaushik Shivakumar
Browse files
Merge remote-tracking branch 'upstream/master' into newavarecords
parents
258ddfc3
a829e648
Changes
330
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
141 additions
and
839 deletions
+141
-839
official/nlp/transformer/beam_search.py
official/nlp/transformer/beam_search.py
+0
-132
official/nlp/transformer/beam_search_v1.py
official/nlp/transformer/beam_search_v1.py
+6
-602
official/nlp/transformer/embedding_layer.py
official/nlp/transformer/embedding_layer.py
+1
-0
official/nlp/transformer/transformer.py
official/nlp/transformer/transformer.py
+1
-3
official/nlp/transformer/transformer_main.py
official/nlp/transformer/transformer_main.py
+0
-3
official/nlp/transformer/utils/metrics.py
official/nlp/transformer/utils/metrics.py
+20
-19
official/recommendation/ncf_common.py
official/recommendation/ncf_common.py
+1
-1
official/recommendation/ncf_input_pipeline.py
official/recommendation/ncf_input_pipeline.py
+14
-25
official/recommendation/ncf_keras_main.py
official/recommendation/ncf_keras_main.py
+6
-6
official/requirements.txt
official/requirements.txt
+0
-1
official/vision/detection/modeling/architecture/fpn.py
official/vision/detection/modeling/architecture/fpn.py
+2
-2
official/vision/detection/modeling/architecture/heads.py
official/vision/detection/modeling/architecture/heads.py
+11
-8
official/vision/detection/modeling/architecture/keras_utils.py
...ial/vision/detection/modeling/architecture/keras_utils.py
+43
-0
official/vision/detection/modeling/architecture/resnet.py
official/vision/detection/modeling/architecture/resnet.py
+2
-2
official/vision/detection/modeling/architecture/spinenet.py
official/vision/detection/modeling/architecture/spinenet.py
+2
-2
official/vision/detection/modeling/maskrcnn_model.py
official/vision/detection/modeling/maskrcnn_model.py
+2
-2
official/vision/detection/modeling/retinanet_model.py
official/vision/detection/modeling/retinanet_model.py
+3
-3
official/vision/detection/modeling/shapemask_model.py
official/vision/detection/modeling/shapemask_model.py
+2
-2
official/vision/image_classification/resnet/common.py
official/vision/image_classification/resnet/common.py
+1
-1
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
...n/image_classification/resnet/resnet_ctl_imagenet_main.py
+24
-25
No files found.
official/nlp/transformer/beam_search.py
deleted
100644 → 0
View file @
258ddfc3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search in TF v2."""
import
tensorflow
as
tf
from
official.nlp.transformer
import
beam_search_v1
as
v1
_StateKeys
=
v1
.
_StateKeys
# pylint: disable=protected-access
class SequenceBeamSearchV2(v1.SequenceBeamSearch):
  """Implementation of beam search loop in v2."""

  def search(self, initial_ids, initial_cache):
    """Beam search for sequences with highest scores.

    Args:
      initial_ids: int tensor of starting token ids, one per batch item.
      initial_cache: dict of decoder state threaded through the search loop.

    Returns:
      Tuple of (decoded sequences, sequence scores).
    """
    loop_state, loop_invariants = self._create_initial_state(
        initial_ids, initial_cache)

    # Run the search loop. Gradients are stopped on every output tensor
    # because the search itself is not differentiable.
    loop_outputs = tf.while_loop(
        self._continue_search,
        self._search_step,
        loop_vars=[loop_state],
        shape_invariants=[loop_invariants],
        parallel_iterations=1)
    final_state = tf.nest.map_structure(tf.stop_gradient, loop_outputs)[0]

    alive_seq = final_state[_StateKeys.ALIVE_SEQ]
    alive_log_probs = final_state[_StateKeys.ALIVE_LOG_PROBS]
    finished_seq = final_state[_StateKeys.FINISHED_SEQ]
    finished_scores = final_state[_StateKeys.FINISHED_SCORES]
    finished_flags = final_state[_StateKeys.FINISHED_FLAGS]

    # 2.0 changes tf.where behavior. Should make parameters broadcastable:
    # expand the per-batch "has any finished sequence" mask to the rank of
    # each operand before selecting.
    finished_cond = tf.reduce_any(finished_flags, 1, name="finished_cond")
    seq_mask = _expand_to_same_rank(finished_cond, finished_seq)
    score_mask = _expand_to_same_rank(finished_cond, finished_scores)

    # Account for corner case where there are no finished sequences for a
    # particular batch item. In that case, return alive sequences for that
    # batch item.
    out_seq = tf.where(seq_mask, finished_seq, alive_seq)
    out_scores = tf.where(score_mask, finished_scores, alive_log_probs)
    return out_seq, out_scores
def sequence_beam_search(symbols_to_logits_fn,
                         initial_ids,
                         initial_cache,
                         vocab_size,
                         beam_size,
                         alpha,
                         max_decode_length,
                         eos_id,
                         padded_decode=False,
                         dtype="float32"):
  """Search for sequence of subtoken ids with the largest probability.

  Args:
    symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
        ids -> A tensor with shape [batch_size * beam_size, index].
        index -> A scalar.
        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
      The function must return a tuple of logits and new cache:
        logits -> A tensor with shape [batch * beam_size, vocab_size].
        new cache -> A nested dictionary with the same shape/structure as the
          inputted cache.
    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
      each batch item.
    initial_cache: A dictionary, containing starting decoder variables
      information.
    vocab_size: An integer, the size of tokens.
    beam_size: An integer, the number of beams.
    alpha: A float, defining the strength of length normalization.
    max_decode_length: An integer, the maximum length to decode a sequence.
    eos_id: An integer, ID of eos token, used to determine when a sequence has
      finished.
    padded_decode: A bool, indicating if max_sequence_length padding is used
      for beam search.
    dtype: A tensorflow data type used for score computation. The default is
      tf.float32.

  Returns:
    Top decoded sequences [batch_size, beam_size, max_decode_length]
    sequence scores [batch_size, beam_size]
  """
  # With padded decode the batch dimension must be statically known; in the
  # dynamic case it is read from the input tensor at run time.
  if padded_decode:
    batch_size = initial_ids.shape.as_list()[0]
  else:
    batch_size = tf.shape(initial_ids)[0]

  searcher = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
                                  beam_size, alpha, max_decode_length, eos_id,
                                  padded_decode, dtype)
  return searcher.search(initial_ids, initial_cache)
def _expand_to_same_rank(tensor, target):
  """Expands a given tensor to target's rank to be broadcastable.

  Args:
    tensor: input tensor to tile. Shape: [b, d1, ..., da]
    target: target tensor. Shape: [b, d1, ..., da, ..., dn]

  Returns:
    Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with same rank as target.

  Raises:
    ValueError: if the shape rank of tensor/target is None.
  """
  if tensor.shape.rank is None:
    raise ValueError("Expect rank for tensor shape, but got None.")
  if target.shape.rank is None:
    raise ValueError("Expect rank for target shape, but got None.")

  with tf.name_scope("expand_rank"):
    expanded = tensor
    # Append size-1 trailing axes until the ranks match.
    for _ in range(target.shape.rank - tensor.shape.rank):
      expanded = tf.expand_dims(expanded, -1)
    return expanded
official/nlp/transformer/beam_search_v1.py
View file @
5a2cf36f
...
@@ -13,126 +13,18 @@
...
@@ -13,126 +13,18 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Beam search to find the translated sequence with the highest probability.
"""Beam search to find the translated sequence with the highest probability.
Source implementation from Tensor2Tensor:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/beam_search.py
"""
"""
import
numpy
as
np
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v1
as
tf
from
tensorflow.python.util
import
nest
from
official.nlp.modeling.ops
import
beam_search
def inf(dtype):
  """Returns a value close to infinity, but is still finite in `dtype`.

  This is useful to get a very large value that is still zero when multiplied
  by zero. The floating-point "Inf" value is NaN when multiplied by zero.

  Args:
    dtype: A dtype. The returned value will be finite when casted to this
      dtype.

  Returns:
    A very large value.
  """
  if dtype in ("float32", "bfloat16"):
    return 1e7
  if dtype == "float16":
    # float16 cannot represent 1e7, so use its largest finite value instead.
    # Disable no-member lint error, as the linter thinks np.float16 does not
    # exist for some reason.
    return np.finfo(np.float16).max  # pylint: disable=no-member
  raise AssertionError('Invalid dtype: %s' % dtype)
class
_StateKeys
(
object
):
"""Keys to dictionary storing the state of the beam search loop."""
# Variable storing the loop index.
CUR_INDEX
=
"CUR_INDEX"
# Top sequences that are alive for each batch item. Alive sequences are ones
_StateKeys
=
beam_search
.
_StateKeys
# pylint: disable=protected-access
# that have not generated an EOS token. Sequences that reach EOS are marked as
# finished and moved to the FINISHED_SEQ tensor.
# Has shape [batch_size, beam_size, CUR_INDEX + 1]
ALIVE_SEQ
=
"ALIVE_SEQ"
# Log probabilities of each alive sequence. Shape [batch_size, beam_size]
ALIVE_LOG_PROBS
=
"ALIVE_LOG_PROBS"
# Dictionary of cached values for each alive sequence. The cache stores
# the encoder output, attention bias, and the decoder attention output from
# the previous iteration.
ALIVE_CACHE
=
"ALIVE_CACHE"
# Top finished sequences for each batch item.
# Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
# shorter than CUR_INDEX + 1 are padded with 0s.
FINISHED_SEQ
=
"FINISHED_SEQ"
# Scores for each finished sequence. Score = log probability / length norm
# Shape [batch_size, beam_size]
FINISHED_SCORES
=
"FINISHED_SCORES"
# Flags indicating which sequences in the finished sequences are finished.
# At the beginning, all of the sequences in FINISHED_SEQ are filler values.
# True -> finished sequence, False -> filler. Shape [batch_size, beam_size]
FINISHED_FLAGS
=
"FINISHED_FLAGS"
class
SequenceBeamSearch
(
beam_search
.
SequenceBeamSearch
):
class
SequenceBeamSearch
(
object
):
"""Implementation of beam search loop."""
"""Implementation of beam search loop."""
def
__init__
(
self
,
def
_process_finished_state
(
self
,
finished_state
):
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
padded_decode
,
dtype
=
tf
.
float32
):
"""Initialize sequence beam search.
Args:
symbols_to_logits_fn: A function to provide logits, which is the
interface to the Transformer model. The passed in arguments are:
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and the updated cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
updated cache -> A nested dictionary with the same structure as the
input cache.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
batch_size: An integer, the decode batch size.
beam_size: An integer, number of beams for beam search.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum number of steps to decode
a sequence.
eos_id: An integer. ID of end of sentence token.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
"""
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
vocab_size
=
vocab_size
self
.
batch_size
=
batch_size
self
.
beam_size
=
beam_size
self
.
alpha
=
alpha
self
.
max_decode_length
=
max_decode_length
self
.
eos_id
=
eos_id
self
.
padded_decode
=
padded_decode
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
def
search
(
self
,
initial_ids
,
initial_cache
):
"""Beam search for sequences with highest scores."""
state
,
state_shapes
=
self
.
_create_initial_state
(
initial_ids
,
initial_cache
)
finished_state
=
tf
.
while_loop
(
self
.
_continue_search
,
self
.
_search_step
,
loop_vars
=
[
state
],
shape_invariants
=
[
state_shapes
],
parallel_iterations
=
1
,
back_prop
=
False
)
finished_state
=
finished_state
[
0
]
alive_seq
=
finished_state
[
_StateKeys
.
ALIVE_SEQ
]
alive_seq
=
finished_state
[
_StateKeys
.
ALIVE_SEQ
]
alive_log_probs
=
finished_state
[
_StateKeys
.
ALIVE_LOG_PROBS
]
alive_log_probs
=
finished_state
[
_StateKeys
.
ALIVE_LOG_PROBS
]
finished_seq
=
finished_state
[
_StateKeys
.
FINISHED_SEQ
]
finished_seq
=
finished_state
[
_StateKeys
.
FINISHED_SEQ
]
...
@@ -148,360 +40,6 @@ class SequenceBeamSearch(object):
...
@@ -148,360 +40,6 @@ class SequenceBeamSearch(object):
tf
.
reduce_any
(
finished_flags
,
1
),
finished_scores
,
alive_log_probs
)
tf
.
reduce_any
(
finished_flags
,
1
),
finished_scores
,
alive_log_probs
)
return
finished_seq
,
finished_scores
return
finished_seq
,
finished_scores
def _create_initial_state(self, initial_ids, initial_cache):
  """Return initial state dictionary and its shape invariants.

  Args:
    initial_ids: initial ids to pass into the symbols_to_logits_fn.
      int tensor with shape [batch_size, 1]
    initial_cache: dictionary storing values to be passed into the
      symbols_to_logits_fn.

  Returns:
    state and shape invariant dictionaries with keys from _StateKeys

  Raises:
    TypeError: if any leaf tensor in initial_cache does not match self.dtype.
  """
  # Validate every leaf of the (possibly nested) cache up front so a dtype
  # mismatch fails fast with a descriptive message instead of deep inside
  # the while loop.
  for key, value in initial_cache.items():
    for inner_value in nest.flatten(value):
      if inner_value.dtype != self.dtype:
        raise TypeError(
            "initial_cache element for key '%s' has dtype %s that does not "
            "match SequenceBeamSearch's dtype of %s. Value: %s" %
            (key, value.dtype.name, self.dtype.name, inner_value))

  # Current loop index (starts at 0)
  cur_index = tf.constant(0)

  # Create alive sequence with shape [batch_size, beam_size, 1]
  alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
  alive_seq = tf.expand_dims(alive_seq, axis=2)
  if self.padded_decode:
    # With padded decode the sequence tensor keeps a fixed static length of
    # max_decode_length + 1 for the whole loop.
    alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])

  # Create tensor for storing initial log probabilities.
  # Assume initial_ids are prob 1.0; only beam 0 starts at log-prob 0, the
  # remaining beams start at -inf so the first step expands a single beam.
  initial_log_probs = tf.constant(
      [[0.] + [-float("inf")] * (self.beam_size - 1)], dtype=self.dtype)
  alive_log_probs = tf.tile(initial_log_probs, [self.batch_size, 1])

  # Expand all values stored in the dictionary to the beam size, so that each
  # beam has a separate cache.
  alive_cache = nest.map_structure(
      lambda t: _expand_to_beam_size(t, self.beam_size), initial_cache)

  # Initialize tensor storing finished sequences with filler values.
  finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)

  # Set scores of the initial finished seqs to negative infinity.
  finished_scores = tf.ones([self.batch_size, self.beam_size],
                            dtype=self.dtype) * -inf(self.dtype)

  # Initialize finished flags with all False values.
  finished_flags = tf.zeros([self.batch_size, self.beam_size], tf.bool)

  # Create state dictionary
  state = {
      _StateKeys.CUR_INDEX: cur_index,
      _StateKeys.ALIVE_SEQ: alive_seq,
      _StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
      _StateKeys.ALIVE_CACHE: alive_cache,
      _StateKeys.FINISHED_SEQ: finished_seq,
      _StateKeys.FINISHED_SCORES: finished_scores,
      _StateKeys.FINISHED_FLAGS: finished_flags
  }

  # Create state invariants for each value in the state dictionary. Each
  # dimension must be a constant or None. A None dimension means either:
  #   1) the dimension's value is a tensor that remains the same but may
  #      depend on the input sequence to the model (e.g. batch size).
  #   2) the dimension may have different values on different iterations.
  if self.padded_decode:
    state_shape_invariants = {
        _StateKeys.CUR_INDEX:
            tf.TensorShape([]),
        _StateKeys.ALIVE_SEQ:
            tf.TensorShape(
                [self.batch_size, self.beam_size,
                 self.max_decode_length + 1]),
        _StateKeys.ALIVE_LOG_PROBS:
            tf.TensorShape([self.batch_size, self.beam_size]),
        _StateKeys.ALIVE_CACHE:
            nest.map_structure(_get_shape, alive_cache),
        _StateKeys.FINISHED_SEQ:
            tf.TensorShape(
                [self.batch_size, self.beam_size,
                 self.max_decode_length + 1]),
        _StateKeys.FINISHED_SCORES:
            tf.TensorShape([self.batch_size, self.beam_size]),
        _StateKeys.FINISHED_FLAGS:
            tf.TensorShape([self.batch_size, self.beam_size])
    }
  else:
    state_shape_invariants = {
        _StateKeys.CUR_INDEX:
            tf.TensorShape([]),
        _StateKeys.ALIVE_SEQ:
            tf.TensorShape([None, self.beam_size, None]),
        _StateKeys.ALIVE_LOG_PROBS:
            tf.TensorShape([None, self.beam_size]),
        _StateKeys.ALIVE_CACHE:
            nest.map_structure(_get_shape_keep_last_dim, alive_cache),
        _StateKeys.FINISHED_SEQ:
            tf.TensorShape([None, self.beam_size, None]),
        _StateKeys.FINISHED_SCORES:
            tf.TensorShape([None, self.beam_size]),
        _StateKeys.FINISHED_FLAGS:
            tf.TensorShape([None, self.beam_size])
    }

  return state, state_shape_invariants
def _continue_search(self, state):
  """Return whether to continue the search loop.

  The loops should terminate when
    1) when decode length has been reached, or
    2) when the worst score in the finished sequences is better than the best
       score in the alive sequences (i.e. the finished sequences are provably
       unchanging)

  Args:
    state: A dictionary with the current loop state.

  Returns:
    Bool tensor with value True if loop should continue, False if loop should
    terminate.
  """
  i = state[_StateKeys.CUR_INDEX]
  alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
  finished_scores = state[_StateKeys.FINISHED_SCORES]
  finished_flags = state[_StateKeys.FINISHED_FLAGS]

  not_at_max_decode_length = tf.less(i, self.max_decode_length)

  # Calculate largest length penalty (the larger penalty, the better score).
  max_length_norm = _length_normalization(self.alpha, self.max_decode_length,
                                          dtype=self.dtype)
  # Get the best possible scores from alive sequences. Beam 0 holds the
  # highest log prob, and dividing by the maximum length norm gives the best
  # score any alive sequence could still achieve.
  best_alive_scores = alive_log_probs[:, 0] / max_length_norm

  # Compute worst score in finished sequences for each batch element.
  # Multiplying by the flags zeroes out filler entries so they do not win
  # the reduce_min below.
  finished_scores *= tf.cast(finished_flags, self.dtype)
  # set filler scores to zero
  lowest_finished_scores = tf.reduce_min(finished_scores, axis=1)

  # If there are no finished sequences in a batch element, then set the lowest
  # finished score to -INF for that element.
  finished_batches = tf.reduce_any(finished_flags, 1)
  lowest_finished_scores += ((1.0 - tf.cast(finished_batches, self.dtype)) *
                             -inf(self.dtype))

  # Early stop only when this holds for every batch element.
  worst_finished_score_better_than_best_alive_score = tf.reduce_all(
      tf.greater(lowest_finished_scores, best_alive_scores))

  return tf.logical_and(
      not_at_max_decode_length,
      tf.logical_not(worst_finished_score_better_than_best_alive_score))
def _search_step(self, state):
  """Beam search loop body.

  Grow alive sequences by a single ID. Sequences that have reached the EOS
  token are marked as finished. The alive and finished sequences with the
  highest log probabilities and scores are returned.

  A sequence's finished score is calculated by dividing the log probability
  by the length normalization factor. Without length normalization, the
  search is more likely to return shorter sequences.

  Args:
    state: A dictionary with the current loop state.

  Returns:
    new state dictionary.
  """
  # Grow alive sequences by one token.
  grown_seq, grown_log_probs, grown_ids, grown_cache = self._grow_alive_seq(
      state)
  # A candidate is finished the moment it emits the EOS id.
  eos_flags = tf.equal(grown_ids, self.eos_id)

  # Collect top beam_size alive sequences.
  alive_state = self._get_new_alive_state(grown_seq, grown_log_probs,
                                          eos_flags, grown_cache)

  # Combine newly finished sequences with existing finished sequences, and
  # collect the top k scoring sequences.
  finished_state = self._get_new_finished_state(state, grown_seq,
                                                grown_log_probs, eos_flags)

  # Assemble the next loop state: merged alive/finished entries plus the
  # incremented loop index.
  next_state = dict(alive_state)
  next_state.update(finished_state)
  next_state[_StateKeys.CUR_INDEX] = state[_StateKeys.CUR_INDEX] + 1
  return [next_state]
def _grow_alive_seq(self, state):
  """Grow alive sequences by one token, and collect top 2*beam_size sequences.

  2*beam_size sequences are collected because some sequences may have reached
  the EOS token. 2*beam_size ensures that at least beam_size sequences are
  still alive.

  Args:
    state: A dictionary with the current loop state.

  Returns:
    Tuple of
    (Top 2*beam_size sequences [batch_size, 2 * beam_size, cur_index + 1],
     Scores of returned sequences [batch_size, 2 * beam_size],
     New alive cache, for each of the 2 * beam_size sequences)
  """
  i = state[_StateKeys.CUR_INDEX]
  alive_seq = state[_StateKeys.ALIVE_SEQ]
  alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
  alive_cache = state[_StateKeys.ALIVE_CACHE]

  beams_to_keep = 2 * self.beam_size

  # Get logits for the next candidate IDs for the alive sequences. Get the
  # new cache values at the same time.
  if self.padded_decode:
    # In padded mode only the token at position i is fed to the model; the
    # sequence tensor has static length max_decode_length + 1.
    flat_ids = tf.reshape(
        tf.slice(alive_seq, [0, 0, i],
                 [self.batch_size, self.beam_size, 1]),
        [self.batch_size * self.beam_size, -1])
  else:
    flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
  flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)

  flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)

  # Unflatten logits to shape [batch_size, beam_size, vocab_size]
  logits = _unflatten_beam_dim(flat_logits, self.batch_size, self.beam_size)
  new_cache = nest.map_structure(
      lambda t: _unflatten_beam_dim(t, self.batch_size, self.beam_size),
      flat_cache)

  # Convert logits to normalized log probs
  candidate_log_probs = _log_prob_from_logits(logits)

  # Calculate new log probabilities if each of the alive sequences were
  # extended by the candidate IDs.
  # Shape [batch_size, beam_size, vocab_size]
  log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

  # Each batch item has beam_size * vocab_size candidate sequences. For each
  # batch item, get the k candidates with the highest log probabilities.
  flat_log_probs = tf.reshape(log_probs,
                              [-1, self.beam_size * self.vocab_size])
  topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep)

  # Extract the alive sequences that generate the highest log probabilities
  # after being extended. Flat indices factor into (beam index, token id).
  topk_beam_indices = topk_indices // self.vocab_size
  topk_seq, new_cache = _gather_beams([alive_seq, new_cache],
                                      topk_beam_indices, self.batch_size,
                                      beams_to_keep)

  # Append the most probable IDs to the topk sequences
  topk_ids = topk_indices % self.vocab_size
  if self.padded_decode:
    # Write the new ids into the fixed-length sequence at position i + 1 via
    # scatter on the (transposed) time axis.
    topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
    # TODO(b/145533236, hongkuny): Reverts once TF fix the validation.
    topk_seq = tf.tensor_scatter_nd_update(topk_seq, [[i + 1]],
                                           tf.expand_dims(topk_ids, axis=0))
    topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
  else:
    topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
  return topk_seq, topk_log_probs, topk_ids, new_cache
def _get_new_alive_state(self, new_seq, new_log_probs, new_finished_flags,
                         new_cache):
  """Gather the top k sequences that are still alive.

  Args:
    new_seq: New sequences generated by growing the current alive sequences
      int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1]
    new_log_probs: Log probabilities of new sequences float32 tensor with
      shape [batch_size, beam_size]
    new_finished_flags: A boolean Tensor indicates which sequences are live
      inside the beam.
    new_cache: Dict of cached values for each sequence.

  Returns:
    Dictionary with alive keys from _StateKeys:
      {Top beam_size sequences that are still alive (don't end with eos_id)
       Log probabilities of top alive sequences
       Dict cache storing decoder states for top alive sequences}
  """
  # Push finished candidates to -inf so top_k below never selects them.
  masked_log_probs = new_log_probs + (
      tf.cast(new_finished_flags, self.dtype) * -inf(self.dtype))

  best_seq, best_log_probs, best_cache = _gather_topk_beams(
      [new_seq, masked_log_probs, new_cache], masked_log_probs,
      self.batch_size, self.beam_size)

  return {
      _StateKeys.ALIVE_SEQ: best_seq,
      _StateKeys.ALIVE_LOG_PROBS: best_log_probs,
      _StateKeys.ALIVE_CACHE: best_cache
  }
def _get_new_finished_state(self, state, new_seq, new_log_probs,
                            new_finished_flags):
  """Combine new and old finished sequences, and gather the top k sequences.

  Args:
    state: A dictionary with the current loop state.
    new_seq: New sequences generated by growing the current alive sequences
      int32 tensor with shape [batch_size, beam_size, i + 1]
    new_log_probs: Log probabilities of new sequences float32 tensor with
      shape [batch_size, beam_size]
    new_finished_flags: A boolean Tensor indicates which sequences are live
      inside the beam.

  Returns:
    Dictionary with finished keys from _StateKeys:
      {Top beam_size finished sequences based on score,
       Scores of finished sequences,
       Finished flags of finished sequences}
  """
  i = state[_StateKeys.CUR_INDEX]
  finished_seq = state[_StateKeys.FINISHED_SEQ]
  finished_scores = state[_StateKeys.FINISHED_SCORES]
  finished_flags = state[_StateKeys.FINISHED_FLAGS]

  # First append a column of 0-ids to finished_seq to increment the length.
  # New shape of finished_seq: [batch_size, beam_size, i + 1]
  # (In padded mode the sequence length is already fixed, so no append.)
  if not self.padded_decode:
    finished_seq = tf.concat(
        [finished_seq,
         tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)],
        axis=2)

  # Calculate new seq scores from log probabilities.
  length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)
  new_scores = new_log_probs / length_norm

  # Set the scores of the still-alive seq in new_seq to large negative values.
  new_scores += ((1. - tf.cast(new_finished_flags, self.dtype)) *
                 -inf(self.dtype))

  # Combine sequences, scores, and flags along the beam axis so old and new
  # candidates compete in a single top_k below.
  finished_seq = tf.concat([finished_seq, new_seq], axis=1)
  finished_scores = tf.concat([finished_scores, new_scores], axis=1)
  finished_flags = tf.concat([finished_flags, new_finished_flags], axis=1)

  # Return the finished sequences with the best scores.
  top_finished_seq, top_finished_scores, top_finished_flags = (
      _gather_topk_beams([finished_seq, finished_scores, finished_flags],
                         finished_scores, self.batch_size, self.beam_size))

  return {
      _StateKeys.FINISHED_SEQ: top_finished_seq,
      _StateKeys.FINISHED_SCORES: top_finished_scores,
      _StateKeys.FINISHED_FLAGS: top_finished_flags
  }
def
sequence_beam_search
(
def
sequence_beam_search
(
symbols_to_logits_fn
,
initial_ids
,
initial_cache
,
vocab_size
,
beam_size
,
symbols_to_logits_fn
,
initial_ids
,
initial_cache
,
vocab_size
,
beam_size
,
...
@@ -536,140 +74,6 @@ def sequence_beam_search(
...
@@ -536,140 +74,6 @@ def sequence_beam_search(
Top decoded sequences [batch_size, beam_size, max_decode_length]
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
sequence scores [batch_size, beam_size]
"""
"""
batch_size
=
(
sbs
=
SequenceBeamSearch
(
symbols_to_logits_fn
,
vocab_size
,
beam_size
,
alpha
,
initial_ids
.
shape
.
as_list
()[
0
]
if
padded_decode
else
max_decode_length
,
eos_id
,
padded_decode
)
tf
.
shape
(
initial_ids
)[
0
])
sbs
=
SequenceBeamSearch
(
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
padded_decode
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
def _log_prob_from_logits(logits):
  """Normalize logits into log probabilities along the vocab axis."""
  # log p = logit - log(sum over vocab of exp(logit)).
  log_partition = tf.reduce_logsumexp(logits, axis=2, keepdims=True)
  return logits - log_partition
def _length_normalization(alpha, length, dtype=tf.float32):
  """Return length normalization factor."""
  # ((5 + length) / 6) ** alpha, as used by GNMT-style length penalties.
  scaled_length = (5. + tf.cast(length, dtype)) / 6.
  return tf.pow(scaled_length, alpha)
def _expand_to_beam_size(tensor, beam_size):
  """Tiles a given tensor by beam_size.

  Args:
    tensor: tensor to tile [batch_size, ...]
    beam_size: How much to tile the tensor by.

  Returns:
    Tiled tensor [batch_size, beam_size, ...]
  """
  # Insert a new beam axis after the batch axis, then repeat along it.
  expanded = tf.expand_dims(tensor, axis=1)
  multiples = [1] * expanded.shape.ndims
  multiples[1] = beam_size
  return tf.tile(expanded, multiples)
def _shape_list(tensor):
  """Return a list of the tensor's shape, and ensure no None values in list.

  Args:
    tensor: tensor whose shape is queried.

  Returns:
    List with one entry per dimension: the static size where known, otherwise
    the corresponding dynamic (tensor-valued) size.
  """
  # Get statically known shape (may contain None's for unknown dimensions).
  shape = tensor.get_shape().as_list()

  # Replace each unknown static dimension with its dynamic value.
  dynamic_shape = tf.shape(tensor)
  for i, dim in enumerate(shape):
    if dim is None:
      shape[i] = dynamic_shape[i]
  return shape
def _get_shape_keep_last_dim(tensor):
  """Shape invariant that keeps only the static last dimension."""
  dims = _shape_list(tensor)

  # Every leading dimension may change between loop iterations.
  dims[:-1] = [None] * (len(dims) - 1)

  # A dynamic (tensor-valued) last dimension is unknown as well.
  if isinstance(dims[-1], tf.Tensor):
    dims[-1] = None

  return tf.TensorShape(dims)
def _get_shape(tensor):
  """Return the shape of the input tensor."""
  dims = _shape_list(tensor)
  return tf.TensorShape(dims)
def _flatten_beam_dim(tensor):
  """Reshapes first two dimensions in to single dimension.

  Args:
    tensor: Tensor to reshape of shape [A, B, ...]

  Returns:
    Reshaped tensor of shape [A*B, ...]
  """
  dims = _shape_list(tensor)
  # Fold the beam dimension into the batch dimension.
  merged = [dims[0] * dims[1]] + dims[2:]
  return tf.reshape(tensor, merged)
def _unflatten_beam_dim(tensor, batch_size, beam_size):
  """Reshapes first dimension back to [batch_size, beam_size].

  Args:
    tensor: Tensor to reshape of shape [batch_size*beam_size, ...]
    batch_size: Tensor, original batch size.
    beam_size: int, original beam size.

  Returns:
    Reshaped tensor of shape [batch_size, beam_size, ...]
  """
  trailing = _shape_list(tensor)[1:]
  return tf.reshape(tensor, [batch_size, beam_size] + trailing)
def _gather_beams(nested, beam_indices, batch_size, new_beam_size):
  """Gather beams from nested structure of tensors.

  Each tensor in nested represents a batch of beams, where beam refers to a
  single search state (beam search involves searching through multiple states
  in parallel).

  This function is used to gather the top beams, specified by
  beam_indices, from the nested tensors.

  Args:
    nested: Nested structure (tensor, list, tuple or dict) containing tensors
      with shape [batch_size, beam_size, ...].
    beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each
      value in beam_indices must be between [0, beam_size), and are not
      necessarily unique.
    batch_size: int size of batch
    new_beam_size: int number of beams to be pulled from the nested tensors.

  Returns:
    Nested structure containing tensors with shape
      [batch_size, new_beam_size, ...]
  """
  # Build the batch coordinate for each gathered beam: row i repeats the
  # batch index i new_beam_size times, e.g. [[0,0,...],[1,1,...],...].
  row_ids = tf.reshape(
      tf.range(batch_size * new_beam_size) // new_beam_size,
      [batch_size, new_beam_size])

  # Pair batch and beam indices into (i, j) coordinates for tf.gather_nd,
  # producing a [batch_size, new_beam_size, 2] coordinate tensor.
  gather_coords = tf.stack([row_ids, beam_indices], axis=2)

  def _gather(tensor):
    return tf.gather_nd(tensor, gather_coords)

  return nest.map_structure(_gather, nested)
def _gather_topk_beams(nested, score_or_log_prob, batch_size, beam_size):
  """Gather top beams from nested structure."""
  # top_k yields, per batch item, the indices of the beam_size best scores.
  _, best_indices = tf.nn.top_k(score_or_log_prob, k=beam_size)
  return _gather_beams(nested, best_indices, batch_size, beam_size)
official/nlp/transformer/embedding_layer.py
View file @
5a2cf36f
...
@@ -43,6 +43,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
...
@@ -43,6 +43,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
self
.
shared_weights
=
self
.
add_weight
(
self
.
shared_weights
=
self
.
add_weight
(
"weights"
,
"weights"
,
shape
=
[
self
.
vocab_size
,
self
.
hidden_size
],
shape
=
[
self
.
vocab_size
,
self
.
hidden_size
],
dtype
=
tf
.
float32
,
initializer
=
tf
.
random_normal_initializer
(
initializer
=
tf
.
random_normal_initializer
(
mean
=
0.
,
stddev
=
self
.
hidden_size
**-
0.5
))
mean
=
0.
,
stddev
=
self
.
hidden_size
**-
0.5
))
super
(
EmbeddingSharedWeights
,
self
).
build
(
input_shape
)
super
(
EmbeddingSharedWeights
,
self
).
build
(
input_shape
)
...
...
official/nlp/transformer/transformer.py
View file @
5a2cf36f
...
@@ -23,8 +23,8 @@ from __future__ import print_function
...
@@ -23,8 +23,8 @@ from __future__ import print_function
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
position_embedding
from
official.nlp.modeling.layers
import
position_embedding
from
official.nlp.modeling.ops
import
beam_search
from
official.nlp.transformer
import
attention_layer
from
official.nlp.transformer
import
attention_layer
from
official.nlp.transformer
import
beam_search
from
official.nlp.transformer
import
embedding_layer
from
official.nlp.transformer
import
embedding_layer
from
official.nlp.transformer
import
ffn_layer
from
official.nlp.transformer
import
ffn_layer
from
official.nlp.transformer
import
metrics
from
official.nlp.transformer
import
metrics
...
@@ -52,7 +52,6 @@ def create_model(params, is_train):
...
@@ -52,7 +52,6 @@ def create_model(params, is_train):
logits
=
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
x
,
name
=
"logits"
,
logits
=
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
x
,
name
=
"logits"
,
dtype
=
tf
.
float32
)(
logits
)
dtype
=
tf
.
float32
)(
logits
)
model
=
tf
.
keras
.
Model
([
inputs
,
targets
],
logits
)
model
=
tf
.
keras
.
Model
([
inputs
,
targets
],
logits
)
# TODO(reedwm): Can we do this loss in float16 instead of float32?
loss
=
metrics
.
transformer_loss
(
loss
=
metrics
.
transformer_loss
(
logits
,
targets
,
label_smoothing
,
vocab_size
)
logits
,
targets
,
label_smoothing
,
vocab_size
)
model
.
add_loss
(
loss
)
model
.
add_loss
(
loss
)
...
@@ -238,7 +237,6 @@ class Transformer(tf.keras.Model):
...
@@ -238,7 +237,6 @@ class Transformer(tf.keras.Model):
decoder_self_attention_bias
=
model_utils
.
get_decoder_self_attention_bias
(
decoder_self_attention_bias
=
model_utils
.
get_decoder_self_attention_bias
(
max_decode_length
,
dtype
=
self
.
params
[
"dtype"
])
max_decode_length
,
dtype
=
self
.
params
[
"dtype"
])
# TODO(b/139770046): Refactor code with better naming of i.
def
symbols_to_logits_fn
(
ids
,
i
,
cache
):
def
symbols_to_logits_fn
(
ids
,
i
,
cache
):
"""Generate logits for next potential IDs.
"""Generate logits for next potential IDs.
...
...
official/nlp/transformer/transformer_main.py
View file @
5a2cf36f
...
@@ -248,7 +248,6 @@ class TransformerTask(object):
...
@@ -248,7 +248,6 @@ class TransformerTask(object):
callbacks
=
[
cb
for
cb
in
callbacks
callbacks
=
[
cb
for
cb
in
callbacks
if
isinstance
(
cb
,
keras_utils
.
TimeHistory
)]
if
isinstance
(
cb
,
keras_utils
.
TimeHistory
)]
# TODO(b/139418525): Refactor the custom training loop logic.
@
tf
.
function
@
tf
.
function
def
train_steps
(
iterator
,
steps
):
def
train_steps
(
iterator
,
steps
):
"""Training steps function for TPU runs.
"""Training steps function for TPU runs.
...
@@ -422,8 +421,6 @@ class TransformerTask(object):
...
@@ -422,8 +421,6 @@ class TransformerTask(object):
"""Loads model weights when it is provided."""
"""Loads model weights when it is provided."""
if
init_weight_path
:
if
init_weight_path
:
logging
.
info
(
"Load weights: {}"
.
format
(
init_weight_path
))
logging
.
info
(
"Load weights: {}"
.
format
(
init_weight_path
))
# TODO(b/139414977): Having the same variable restoring method for both
# TPU and GPU.
if
self
.
use_tpu
:
if
self
.
use_tpu
:
checkpoint
=
tf
.
train
.
Checkpoint
(
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
,
optimizer
=
self
.
_create_optimizer
())
model
=
model
,
optimizer
=
self
.
_create_optimizer
())
...
...
official/nlp/transformer/utils/metrics.py
View file @
5a2cf36f
...
@@ -67,7 +67,7 @@ def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
...
@@ -67,7 +67,7 @@ def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
# Calculate smoothing cross entropy
# Calculate smoothing cross entropy
with
tf
.
name_scope
(
"smoothing_cross_entropy"
,
values
=
[
logits
,
labels
]):
with
tf
.
name_scope
(
"smoothing_cross_entropy"
,
values
=
[
logits
,
labels
]):
confidence
=
1.0
-
smoothing
confidence
=
1.0
-
smoothing
low_confidence
=
(
1.0
-
confidence
)
/
tf
.
to_floa
t
(
vocab_size
-
1
)
low_confidence
=
(
1.0
-
confidence
)
/
tf
.
cas
t
(
vocab_size
-
1
,
tf
.
float32
)
soft_targets
=
tf
.
one_hot
(
soft_targets
=
tf
.
one_hot
(
tf
.
cast
(
labels
,
tf
.
int32
),
tf
.
cast
(
labels
,
tf
.
int32
),
depth
=
vocab_size
,
depth
=
vocab_size
,
...
@@ -79,11 +79,11 @@ def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
...
@@ -79,11 +79,11 @@ def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
# Calculate the best (lowest) possible value of cross entropy, and
# Calculate the best (lowest) possible value of cross entropy, and
# subtract from the cross entropy loss.
# subtract from the cross entropy loss.
normalizing_constant
=
-
(
normalizing_constant
=
-
(
confidence
*
tf
.
log
(
confidence
)
+
tf
.
to_floa
t
(
vocab_size
-
1
)
*
confidence
*
tf
.
log
(
confidence
)
+
tf
.
cas
t
(
vocab_size
-
1
,
tf
.
float32
)
low_confidence
*
tf
.
log
(
low_confidence
+
1e-20
))
*
low_confidence
*
tf
.
log
(
low_confidence
+
1e-20
))
xentropy
-=
normalizing_constant
xentropy
-=
normalizing_constant
weights
=
tf
.
to_floa
t
(
tf
.
not_equal
(
labels
,
0
))
weights
=
tf
.
cas
t
(
tf
.
not_equal
(
labels
,
0
)
,
tf
.
float32
)
return
xentropy
*
weights
,
weights
return
xentropy
*
weights
,
weights
...
@@ -142,24 +142,24 @@ def padded_accuracy(logits, labels):
...
@@ -142,24 +142,24 @@ def padded_accuracy(logits, labels):
"""Percentage of times that predictions matches labels on non-0s."""
"""Percentage of times that predictions matches labels on non-0s."""
with
tf
.
variable_scope
(
"padded_accuracy"
,
values
=
[
logits
,
labels
]):
with
tf
.
variable_scope
(
"padded_accuracy"
,
values
=
[
logits
,
labels
]):
logits
,
labels
=
_pad_tensors_to_same_length
(
logits
,
labels
)
logits
,
labels
=
_pad_tensors_to_same_length
(
logits
,
labels
)
weights
=
tf
.
to_floa
t
(
tf
.
not_equal
(
labels
,
0
))
weights
=
tf
.
cas
t
(
tf
.
not_equal
(
labels
,
0
)
,
tf
.
float32
)
outputs
=
tf
.
to_int32
(
tf
.
argmax
(
logits
,
axis
=-
1
))
outputs
=
tf
.
cast
(
tf
.
argmax
(
logits
,
axis
=-
1
)
,
tf
.
int32
)
padded_labels
=
tf
.
to_int32
(
labels
)
padded_labels
=
tf
.
cast
(
labels
,
tf
.
int32
)
return
tf
.
to_floa
t
(
tf
.
equal
(
outputs
,
padded_labels
)),
weights
return
tf
.
cas
t
(
tf
.
equal
(
outputs
,
padded_labels
)
,
tf
.
float32
),
weights
def
padded_accuracy_topk
(
logits
,
labels
,
k
):
def
padded_accuracy_topk
(
logits
,
labels
,
k
):
"""Percentage of times that top-k predictions matches labels on non-0s."""
"""Percentage of times that top-k predictions matches labels on non-0s."""
with
tf
.
variable_scope
(
"padded_accuracy_topk"
,
values
=
[
logits
,
labels
]):
with
tf
.
variable_scope
(
"padded_accuracy_topk"
,
values
=
[
logits
,
labels
]):
logits
,
labels
=
_pad_tensors_to_same_length
(
logits
,
labels
)
logits
,
labels
=
_pad_tensors_to_same_length
(
logits
,
labels
)
weights
=
tf
.
to_floa
t
(
tf
.
not_equal
(
labels
,
0
))
weights
=
tf
.
cas
t
(
tf
.
not_equal
(
labels
,
0
)
,
tf
.
float32
)
effective_k
=
tf
.
minimum
(
k
,
tf
.
shape
(
logits
)[
-
1
])
effective_k
=
tf
.
minimum
(
k
,
tf
.
shape
(
logits
)[
-
1
])
_
,
outputs
=
tf
.
nn
.
top_k
(
logits
,
k
=
effective_k
)
_
,
outputs
=
tf
.
nn
.
top_k
(
logits
,
k
=
effective_k
)
outputs
=
tf
.
to_int32
(
outputs
)
outputs
=
tf
.
cast
(
outputs
,
tf
.
int32
)
padded_labels
=
tf
.
to_int32
(
labels
)
padded_labels
=
tf
.
cast
(
labels
,
tf
.
int32
)
padded_labels
=
tf
.
expand_dims
(
padded_labels
,
axis
=-
1
)
padded_labels
=
tf
.
expand_dims
(
padded_labels
,
axis
=-
1
)
padded_labels
+=
tf
.
zeros_like
(
outputs
)
# Pad to same shape.
padded_labels
+=
tf
.
zeros_like
(
outputs
)
# Pad to same shape.
same
=
tf
.
to_floa
t
(
tf
.
equal
(
outputs
,
padded_labels
))
same
=
tf
.
cas
t
(
tf
.
equal
(
outputs
,
padded_labels
)
,
tf
.
float32
)
same_topk
=
tf
.
reduce_sum
(
same
,
axis
=-
1
)
same_topk
=
tf
.
reduce_sum
(
same
,
axis
=-
1
)
return
same_topk
,
weights
return
same_topk
,
weights
...
@@ -172,10 +172,11 @@ def padded_sequence_accuracy(logits, labels):
...
@@ -172,10 +172,11 @@ def padded_sequence_accuracy(logits, labels):
"""Percentage of times that predictions matches labels everywhere (non-0)."""
"""Percentage of times that predictions matches labels everywhere (non-0)."""
with
tf
.
variable_scope
(
"padded_sequence_accuracy"
,
values
=
[
logits
,
labels
]):
with
tf
.
variable_scope
(
"padded_sequence_accuracy"
,
values
=
[
logits
,
labels
]):
logits
,
labels
=
_pad_tensors_to_same_length
(
logits
,
labels
)
logits
,
labels
=
_pad_tensors_to_same_length
(
logits
,
labels
)
weights
=
tf
.
to_float
(
tf
.
not_equal
(
labels
,
0
))
weights
=
tf
.
cast
(
tf
.
not_equal
(
labels
,
0
),
tf
.
float32
)
outputs
=
tf
.
to_int32
(
tf
.
argmax
(
logits
,
axis
=-
1
))
outputs
=
tf
.
cast
(
tf
.
argmax
(
logits
,
axis
=-
1
),
tf
.
int32
)
padded_labels
=
tf
.
to_int32
(
labels
)
padded_labels
=
tf
.
cast
(
labels
,
tf
.
int32
)
not_correct
=
tf
.
to_float
(
tf
.
not_equal
(
outputs
,
padded_labels
))
*
weights
not_correct
=
(
tf
.
cast
(
tf
.
not_equal
(
outputs
,
padded_labels
),
tf
.
float32
)
*
weights
)
axis
=
list
(
range
(
1
,
len
(
outputs
.
get_shape
())))
axis
=
list
(
range
(
1
,
len
(
outputs
.
get_shape
())))
correct_seq
=
1.0
-
tf
.
minimum
(
1.0
,
tf
.
reduce_sum
(
not_correct
,
axis
=
axis
))
correct_seq
=
1.0
-
tf
.
minimum
(
1.0
,
tf
.
reduce_sum
(
not_correct
,
axis
=
axis
))
return
correct_seq
,
tf
.
constant
(
1.0
)
return
correct_seq
,
tf
.
constant
(
1.0
)
...
@@ -201,7 +202,7 @@ def bleu_score(logits, labels):
...
@@ -201,7 +202,7 @@ def bleu_score(logits, labels):
Returns:
Returns:
bleu: int, approx bleu score
bleu: int, approx bleu score
"""
"""
predictions
=
tf
.
to_int32
(
tf
.
argmax
(
logits
,
axis
=-
1
))
predictions
=
tf
.
cast
(
tf
.
argmax
(
logits
,
axis
=-
1
)
,
tf
.
int32
)
# TODO: Look into removing use of py_func
# TODO: Look into removing use of py_func
bleu
=
tf
.
py_func
(
compute_bleu
,
(
labels
,
predictions
),
tf
.
float32
)
bleu
=
tf
.
py_func
(
compute_bleu
,
(
labels
,
predictions
),
tf
.
float32
)
return
bleu
,
tf
.
constant
(
1.0
)
return
bleu
,
tf
.
constant
(
1.0
)
...
@@ -306,7 +307,7 @@ def rouge_2_fscore(logits, labels):
...
@@ -306,7 +307,7 @@ def rouge_2_fscore(logits, labels):
Returns:
Returns:
rouge2_fscore: approx rouge-2 f1 score.
rouge2_fscore: approx rouge-2 f1 score.
"""
"""
predictions
=
tf
.
to_int32
(
tf
.
argmax
(
logits
,
axis
=-
1
))
predictions
=
tf
.
cast
(
tf
.
argmax
(
logits
,
axis
=-
1
)
,
tf
.
int32
)
# TODO: Look into removing use of py_func
# TODO: Look into removing use of py_func
rouge_2_f_score
=
tf
.
py_func
(
rouge_n
,
(
predictions
,
labels
),
tf
.
float32
)
rouge_2_f_score
=
tf
.
py_func
(
rouge_n
,
(
predictions
,
labels
),
tf
.
float32
)
return
rouge_2_f_score
,
tf
.
constant
(
1.0
)
return
rouge_2_f_score
,
tf
.
constant
(
1.0
)
...
@@ -383,7 +384,7 @@ def rouge_l_fscore(predictions, labels):
...
@@ -383,7 +384,7 @@ def rouge_l_fscore(predictions, labels):
Returns:
Returns:
rouge_l_fscore: approx rouge-l f1 score.
rouge_l_fscore: approx rouge-l f1 score.
"""
"""
outputs
=
tf
.
to_int32
(
tf
.
argmax
(
predictions
,
axis
=-
1
))
outputs
=
tf
.
cast
(
tf
.
argmax
(
predictions
,
axis
=-
1
)
,
tf
.
int32
)
rouge_l_f_score
=
tf
.
py_func
(
rouge_l_sentence_level
,
(
outputs
,
labels
),
rouge_l_f_score
=
tf
.
py_func
(
rouge_l_sentence_level
,
(
outputs
,
labels
),
tf
.
float32
)
tf
.
float32
)
return
rouge_l_f_score
,
tf
.
constant
(
1.0
)
return
rouge_l_f_score
,
tf
.
constant
(
1.0
)
...
...
official/recommendation/ncf_common.py
View file @
5a2cf36f
...
@@ -94,7 +94,7 @@ def parse_flags(flags_obj):
...
@@ -94,7 +94,7 @@ def parse_flags(flags_obj):
"beta2"
:
flags_obj
.
beta2
,
"beta2"
:
flags_obj
.
beta2
,
"epsilon"
:
flags_obj
.
epsilon
,
"epsilon"
:
flags_obj
.
epsilon
,
"match_mlperf"
:
flags_obj
.
ml_perf
,
"match_mlperf"
:
flags_obj
.
ml_perf
,
"epochs_between_evals"
:
FLAGS
.
epochs_between_evals
,
"epochs_between_evals"
:
flags_obj
.
epochs_between_evals
,
"keras_use_ctl"
:
flags_obj
.
keras_use_ctl
,
"keras_use_ctl"
:
flags_obj
.
keras_use_ctl
,
"hr_threshold"
:
flags_obj
.
hr_threshold
,
"hr_threshold"
:
flags_obj
.
hr_threshold
,
"stream_files"
:
flags_obj
.
tpu
is
not
None
,
"stream_files"
:
flags_obj
.
tpu
is
not
None
,
...
...
official/recommendation/ncf_input_pipeline.py
View file @
5a2cf36f
...
@@ -25,10 +25,8 @@ import tensorflow.compat.v2 as tf
...
@@ -25,10 +25,8 @@ import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
# pylint: enable=g-bad-import-order
from
official.recommendation
import
constants
as
rconst
from
official.recommendation
import
constants
as
rconst
from
official.recommendation
import
movielens
from
official.recommendation
import
data_pipeline
from
official.recommendation
import
data_pipeline
from
official.recommendation
import
movielens
NUM_SHARDS
=
16
def
create_dataset_from_tf_record_files
(
input_file_pattern
,
def
create_dataset_from_tf_record_files
(
input_file_pattern
,
...
@@ -36,32 +34,23 @@ def create_dataset_from_tf_record_files(input_file_pattern,
...
@@ -36,32 +34,23 @@ def create_dataset_from_tf_record_files(input_file_pattern,
batch_size
,
batch_size
,
is_training
=
True
):
is_training
=
True
):
"""Creates dataset from (tf)records files for training/evaluation."""
"""Creates dataset from (tf)records files for training/evaluation."""
if
pre_batch_size
!=
batch_size
:
raise
ValueError
(
"Pre-batch ({}) size is not equal to batch "
"size ({})"
.
format
(
pre_batch_size
,
batch_size
))
files
=
tf
.
data
.
Dataset
.
list_files
(
input_file_pattern
,
shuffle
=
is_training
)
files
=
tf
.
data
.
Dataset
.
list_files
(
input_file_pattern
,
shuffle
=
is_training
)
def
make_dataset
(
files_dataset
,
shard_index
):
dataset
=
files
.
interleave
(
"""Returns dataset for sharded tf record files."""
tf
.
data
.
TFRecordDataset
,
if
pre_batch_size
!=
batch_size
:
cycle_length
=
16
,
raise
ValueError
(
"Pre-batch ({}) size is not equal to batch "
"size ({})"
.
format
(
pre_batch_size
,
batch_size
))
files_dataset
=
files_dataset
.
shard
(
NUM_SHARDS
,
shard_index
)
dataset
=
files_dataset
.
interleave
(
tf
.
data
.
TFRecordDataset
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
decode_fn
=
functools
.
partial
(
data_pipeline
.
DatasetManager
.
deserialize
,
batch_size
=
pre_batch_size
,
is_training
=
is_training
)
dataset
=
dataset
.
map
(
decode_fn
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
dataset
=
tf
.
data
.
Dataset
.
range
(
NUM_SHARDS
)
map_fn
=
functools
.
partial
(
make_dataset
,
files
)
dataset
=
dataset
.
interleave
(
map_fn
,
cycle_length
=
NUM_SHARDS
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
decode_fn
=
functools
.
partial
(
data_pipeline
.
DatasetManager
.
deserialize
,
batch_size
=
pre_batch_size
,
is_training
=
is_training
)
dataset
=
dataset
.
map
(
decode_fn
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
dataset
=
dataset
.
prefetch
(
tf
.
data
.
experimental
.
AUTOTUNE
)
dataset
=
dataset
.
prefetch
(
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
return
dataset
...
...
official/recommendation/ncf_keras_main.py
View file @
5a2cf36f
...
@@ -488,19 +488,19 @@ def run_ncf_custom_training(params,
...
@@ -488,19 +488,19 @@ def run_ncf_custom_training(params,
c
.
on_batch_end
(
current_step
)
c
.
on_batch_end
(
current_step
)
train_loss
/=
num_train_steps
train_loss
/=
num_train_steps
logging
.
info
(
"Done training epoch %s, epoch loss=%
s
."
,
epoch
+
1
,
logging
.
info
(
"Done training epoch %s, epoch loss=%.
3f
"
,
epoch
+
1
,
train_loss
)
train_loss
)
eval_input_iterator
=
iter
(
eval_input_iterator
=
iter
(
eval_input_dataset
)
strategy
.
experimental_distribute_dataset
(
eval_input_dataset
))
hr_sum
=
0
hr_sum
=
0
.0
hr_count
=
0
hr_count
=
0
.0
for
_
in
range
(
num_eval_steps
):
for
_
in
range
(
num_eval_steps
):
step_hr_sum
,
step_hr_count
=
eval_step
(
eval_input_iterator
)
step_hr_sum
,
step_hr_count
=
eval_step
(
eval_input_iterator
)
hr_sum
+=
step_hr_sum
hr_sum
+=
step_hr_sum
hr_count
+=
step_hr_count
hr_count
+=
step_hr_count
logging
.
info
(
"Done eval epoch %s, hit_rate=%
s
."
,
epoch
+
1
,
logging
.
info
(
"Done eval epoch %s, hit_rate=%.
3f
"
,
epoch
+
1
,
hr_sum
/
hr_count
)
hr_sum
/
hr_count
)
if
eval_summary_writer
:
if
eval_summary_writer
:
with
eval_summary_writer
.
as_default
():
with
eval_summary_writer
.
as_default
():
...
...
official/requirements.txt
View file @
5a2cf36f
...
@@ -15,7 +15,6 @@ tensorflow-addons
...
@@ -15,7 +15,6 @@ tensorflow-addons
dataclasses
dataclasses
gin-config
gin-config
tf_slim>=1.1.0
tf_slim>=1.1.0
typing
Cython
Cython
matplotlib
matplotlib
pyyaml
pyyaml
...
...
official/vision/detection/modeling/architecture/fpn.py
View file @
5a2cf36f
...
@@ -28,7 +28,7 @@ import functools
...
@@ -28,7 +28,7 @@ import functools
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
from
official.vision.detection.modeling.architecture
import
keras_utils
from
official.vision.detection.modeling.architecture
import
nn_ops
from
official.vision.detection.modeling.architecture
import
nn_ops
from
official.vision.detection.ops
import
spatial_transform_ops
from
official.vision.detection.ops
import
spatial_transform_ops
...
@@ -120,7 +120,7 @@ class Fpn(object):
...
@@ -120,7 +120,7 @@ class Fpn(object):
'The minimum backbone level %d should be '
%
(
min
(
input_levels
))
+
'The minimum backbone level %d should be '
%
(
min
(
input_levels
))
+
'less or equal to FPN minimum level %d.:'
%
(
self
.
_min_level
))
'less or equal to FPN minimum level %d.:'
%
(
self
.
_min_level
))
backbone_max_level
=
min
(
max
(
input_levels
),
self
.
_max_level
)
backbone_max_level
=
min
(
max
(
input_levels
),
self
.
_max_level
)
with
backend
.
get
_graph
()
.
as_default
()
,
tf
.
name_scope
(
'fpn'
):
with
keras_utils
.
maybe_enter_
backend_graph
(),
tf
.
name_scope
(
'fpn'
):
# Adds lateral connections.
# Adds lateral connections.
feats_lateral
=
{}
feats_lateral
=
{}
for
level
in
range
(
self
.
_min_level
,
backbone_max_level
+
1
):
for
level
in
range
(
self
.
_min_level
,
backbone_max_level
+
1
):
...
...
official/vision/detection/modeling/architecture/heads.py
View file @
5a2cf36f
...
@@ -22,7 +22,8 @@ import functools
...
@@ -22,7 +22,8 @@ import functools
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
from
official.vision.detection.modeling.architecture
import
keras_utils
from
official.vision.detection.modeling.architecture
import
nn_ops
from
official.vision.detection.modeling.architecture
import
nn_ops
from
official.vision.detection.ops
import
spatial_transform_ops
from
official.vision.detection.ops
import
spatial_transform_ops
...
@@ -127,7 +128,7 @@ class RpnHead(tf.keras.layers.Layer):
...
@@ -127,7 +128,7 @@ class RpnHead(tf.keras.layers.Layer):
scores_outputs
=
{}
scores_outputs
=
{}
box_outputs
=
{}
box_outputs
=
{}
with
backend
.
get
_graph
()
.
as_default
()
,
tf
.
name_scope
(
'rpn_head'
):
with
keras_utils
.
maybe_enter_
backend_graph
(),
tf
.
name_scope
(
'rpn_head'
):
for
level
in
range
(
self
.
_min_level
,
self
.
_max_level
+
1
):
for
level
in
range
(
self
.
_min_level
,
self
.
_max_level
+
1
):
scores_output
,
box_output
=
self
.
_shared_rpn_heads
(
scores_output
,
box_output
=
self
.
_shared_rpn_heads
(
features
[
level
],
self
.
_anchors_per_location
,
level
,
is_training
)
features
[
level
],
self
.
_anchors_per_location
,
level
,
is_training
)
...
@@ -249,7 +250,8 @@ class FastrcnnHead(tf.keras.layers.Layer):
...
@@ -249,7 +250,8 @@ class FastrcnnHead(tf.keras.layers.Layer):
predictions.
predictions.
"""
"""
with
backend
.
get_graph
().
as_default
(),
tf
.
name_scope
(
'fast_rcnn_head'
):
with
keras_utils
.
maybe_enter_backend_graph
(),
tf
.
name_scope
(
'fast_rcnn_head'
):
# reshape inputs before FC.
# reshape inputs before FC.
_
,
num_rois
,
height
,
width
,
filters
=
roi_features
.
get_shape
().
as_list
()
_
,
num_rois
,
height
,
width
,
filters
=
roi_features
.
get_shape
().
as_list
()
...
@@ -368,7 +370,7 @@ class MaskrcnnHead(tf.keras.layers.Layer):
...
@@ -368,7 +370,7 @@ class MaskrcnnHead(tf.keras.layers.Layer):
boxes is not 4.
boxes is not 4.
"""
"""
with
backend
.
get
_graph
()
.
as_default
()
:
with
keras_utils
.
maybe_enter_
backend_graph
():
with
tf
.
name_scope
(
'mask_head'
):
with
tf
.
name_scope
(
'mask_head'
):
_
,
num_rois
,
height
,
width
,
filters
=
roi_features
.
get_shape
().
as_list
()
_
,
num_rois
,
height
,
width
,
filters
=
roi_features
.
get_shape
().
as_list
()
net
=
tf
.
reshape
(
roi_features
,
[
-
1
,
height
,
width
,
filters
])
net
=
tf
.
reshape
(
roi_features
,
[
-
1
,
height
,
width
,
filters
])
...
@@ -552,7 +554,8 @@ class RetinanetHead(object):
...
@@ -552,7 +554,8 @@ class RetinanetHead(object):
"""Returns outputs of RetinaNet head."""
"""Returns outputs of RetinaNet head."""
class_outputs
=
{}
class_outputs
=
{}
box_outputs
=
{}
box_outputs
=
{}
with
backend
.
get_graph
().
as_default
(),
tf
.
name_scope
(
'retinanet_head'
):
with
keras_utils
.
maybe_enter_backend_graph
(),
tf
.
name_scope
(
'retinanet_head'
):
for
level
in
range
(
self
.
_min_level
,
self
.
_max_level
+
1
):
for
level
in
range
(
self
.
_min_level
,
self
.
_max_level
+
1
):
features
=
fpn_features
[
level
]
features
=
fpn_features
[
level
]
...
@@ -644,7 +647,7 @@ class ShapemaskPriorHead(object):
...
@@ -644,7 +647,7 @@ class ShapemaskPriorHead(object):
detection_priors: A float Tensor of shape [batch_size * num_instances,
detection_priors: A float Tensor of shape [batch_size * num_instances,
mask_size, mask_size, 1].
mask_size, mask_size, 1].
"""
"""
with
backend
.
get
_graph
()
.
as_default
()
,
tf
.
name_scope
(
'prior_mask'
):
with
keras_utils
.
maybe_enter_
backend_graph
(),
tf
.
name_scope
(
'prior_mask'
):
batch_size
,
num_instances
,
_
=
boxes
.
get_shape
().
as_list
()
batch_size
,
num_instances
,
_
=
boxes
.
get_shape
().
as_list
()
outer_boxes
=
tf
.
cast
(
outer_boxes
,
tf
.
float32
)
outer_boxes
=
tf
.
cast
(
outer_boxes
,
tf
.
float32
)
boxes
=
tf
.
cast
(
boxes
,
tf
.
float32
)
boxes
=
tf
.
cast
(
boxes
,
tf
.
float32
)
...
@@ -807,7 +810,7 @@ class ShapemaskCoarsemaskHead(object):
...
@@ -807,7 +810,7 @@ class ShapemaskCoarsemaskHead(object):
mask_outputs: instance mask prediction as a float Tensor of shape
mask_outputs: instance mask prediction as a float Tensor of shape
[batch_size, num_instances, mask_size, mask_size].
[batch_size, num_instances, mask_size, mask_size].
"""
"""
with
backend
.
get
_graph
()
.
as_default
()
,
tf
.
name_scope
(
'coarse_mask'
):
with
keras_utils
.
maybe_enter_
backend_graph
(),
tf
.
name_scope
(
'coarse_mask'
):
# Transform detection priors to have the same dimension as features.
# Transform detection priors to have the same dimension as features.
detection_priors
=
tf
.
expand_dims
(
detection_priors
,
axis
=-
1
)
detection_priors
=
tf
.
expand_dims
(
detection_priors
,
axis
=-
1
)
detection_priors
=
self
.
_coarse_mask_fc
(
detection_priors
)
detection_priors
=
self
.
_coarse_mask_fc
(
detection_priors
)
...
@@ -939,7 +942,7 @@ class ShapemaskFinemaskHead(object):
...
@@ -939,7 +942,7 @@ class ShapemaskFinemaskHead(object):
"""
"""
# Extract the foreground mean features
# Extract the foreground mean features
# with tf.variable_scope('fine_mask', reuse=tf.AUTO_REUSE):
# with tf.variable_scope('fine_mask', reuse=tf.AUTO_REUSE):
with
backend
.
get
_graph
()
.
as_default
()
,
tf
.
name_scope
(
'fine_mask'
):
with
keras_utils
.
maybe_enter_
backend_graph
(),
tf
.
name_scope
(
'fine_mask'
):
mask_probs
=
tf
.
nn
.
sigmoid
(
mask_logits
)
mask_probs
=
tf
.
nn
.
sigmoid
(
mask_logits
)
# Compute instance embedding for hard average.
# Compute instance embedding for hard average.
binary_mask
=
tf
.
cast
(
tf
.
greater
(
mask_probs
,
0.5
),
features
.
dtype
)
binary_mask
=
tf
.
cast
(
tf
.
greater
(
mask_probs
,
0.5
),
features
.
dtype
)
...
...
official/vision/detection/modeling/architecture/keras_utils.py
0 → 100644
View file @
5a2cf36f
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util functions to integrate with Keras internals."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
tensorflow.python.keras
import
backend
try
:
from
tensorflow.python.keras.engine
import
keras_tensor
# pylint: disable=g-import-not-at-top,unused-import
except
ImportError
:
keras_tensor
=
None
class
NoOpContextManager
(
object
):
def
__enter__
(
self
):
pass
def
__exit__
(
self
,
*
args
):
pass
def
maybe_enter_backend_graph
():
if
(
keras_tensor
is
not
None
)
and
keras_tensor
.
keras_tensors_enabled
():
return
NoOpContextManager
()
else
:
return
backend
.
get_graph
().
as_default
()
official/vision/detection/modeling/architecture/resnet.py
View file @
5a2cf36f
...
@@ -25,7 +25,7 @@ from __future__ import print_function
...
@@ -25,7 +25,7 @@ from __future__ import print_function
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
from
official.vision.detection.modeling.architecture
import
keras_utils
from
official.vision.detection.modeling.architecture
import
nn_ops
from
official.vision.detection.modeling.architecture
import
nn_ops
# TODO(b/140112644): Refactor the code with Keras style, i.e. build and call.
# TODO(b/140112644): Refactor the code with Keras style, i.e. build and call.
...
@@ -90,7 +90,7 @@ class Resnet(object):
...
@@ -90,7 +90,7 @@ class Resnet(object):
The values are corresponding feature hierarchy in ResNet with shape
The values are corresponding feature hierarchy in ResNet with shape
[batch_size, height_l, width_l, num_filters].
[batch_size, height_l, width_l, num_filters].
"""
"""
with
backend
.
get
_graph
()
.
as_default
()
:
with
keras_utils
.
maybe_enter_
backend_graph
():
with
tf
.
name_scope
(
'resnet%s'
%
self
.
_resnet_depth
):
with
tf
.
name_scope
(
'resnet%s'
%
self
.
_resnet_depth
):
return
self
.
_resnet_fn
(
inputs
,
is_training
)
return
self
.
_resnet_fn
(
inputs
,
is_training
)
...
...
official/vision/detection/modeling/architecture/spinenet.py
View file @
5a2cf36f
...
@@ -24,8 +24,8 @@ import math
...
@@ -24,8 +24,8 @@ import math
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.vision.detection.modeling.architecture
import
keras_utils
from
official.vision.detection.modeling.architecture
import
nn_blocks
from
official.vision.detection.modeling.architecture
import
nn_blocks
layers
=
tf
.
keras
.
layers
layers
=
tf
.
keras
.
layers
...
@@ -486,7 +486,7 @@ class SpineNetBuilder(object):
...
@@ -486,7 +486,7 @@ class SpineNetBuilder(object):
self
.
_norm_epsilon
=
norm_epsilon
self
.
_norm_epsilon
=
norm_epsilon
def
__call__
(
self
,
inputs
,
is_training
=
None
):
def
__call__
(
self
,
inputs
,
is_training
=
None
):
with
backend
.
get
_graph
()
.
as_default
()
:
with
keras_utils
.
maybe_enter_
backend_graph
():
model
=
SpineNet
(
model
=
SpineNet
(
input_specs
=
self
.
_input_specs
,
input_specs
=
self
.
_input_specs
,
min_level
=
self
.
_min_level
,
min_level
=
self
.
_min_level
,
...
...
official/vision/detection/modeling/maskrcnn_model.py
View file @
5a2cf36f
...
@@ -20,13 +20,13 @@ from __future__ import print_function
...
@@ -20,13 +20,13 @@ from __future__ import print_function
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
from
official.vision.detection.dataloader
import
anchor
from
official.vision.detection.dataloader
import
anchor
from
official.vision.detection.dataloader
import
mode_keys
from
official.vision.detection.dataloader
import
mode_keys
from
official.vision.detection.evaluation
import
factory
as
eval_factory
from
official.vision.detection.evaluation
import
factory
as
eval_factory
from
official.vision.detection.modeling
import
base_model
from
official.vision.detection.modeling
import
base_model
from
official.vision.detection.modeling
import
losses
from
official.vision.detection.modeling
import
losses
from
official.vision.detection.modeling.architecture
import
factory
from
official.vision.detection.modeling.architecture
import
factory
from
official.vision.detection.modeling.architecture
import
keras_utils
from
official.vision.detection.ops
import
postprocess_ops
from
official.vision.detection.ops
import
postprocess_ops
from
official.vision.detection.ops
import
roi_ops
from
official.vision.detection.ops
import
roi_ops
from
official.vision.detection.ops
import
spatial_transform_ops
from
official.vision.detection.ops
import
spatial_transform_ops
...
@@ -297,7 +297,7 @@ class MaskrcnnModel(base_model.Model):
...
@@ -297,7 +297,7 @@ class MaskrcnnModel(base_model.Model):
def
build_model
(
self
,
params
,
mode
):
def
build_model
(
self
,
params
,
mode
):
if
self
.
_keras_model
is
None
:
if
self
.
_keras_model
is
None
:
input_layers
=
self
.
build_input_layers
(
self
.
_params
,
mode
)
input_layers
=
self
.
build_input_layers
(
self
.
_params
,
mode
)
with
backend
.
get
_graph
()
.
as_default
()
:
with
keras_utils
.
maybe_enter_
backend_graph
():
outputs
=
self
.
model_outputs
(
input_layers
,
mode
)
outputs
=
self
.
model_outputs
(
input_layers
,
mode
)
model
=
tf
.
keras
.
models
.
Model
(
model
=
tf
.
keras
.
models
.
Model
(
...
...
official/vision/detection/modeling/retinanet_model.py
View file @
5a2cf36f
...
@@ -20,12 +20,12 @@ from __future__ import print_function
...
@@ -20,12 +20,12 @@ from __future__ import print_function
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
from
official.vision.detection.dataloader
import
mode_keys
from
official.vision.detection.dataloader
import
mode_keys
from
official.vision.detection.evaluation
import
factory
as
eval_factory
from
official.vision.detection.evaluation
import
factory
as
eval_factory
from
official.vision.detection.modeling
import
base_model
from
official.vision.detection.modeling
import
base_model
from
official.vision.detection.modeling
import
losses
from
official.vision.detection.modeling
import
losses
from
official.vision.detection.modeling.architecture
import
factory
from
official.vision.detection.modeling.architecture
import
factory
from
official.vision.detection.modeling.architecture
import
keras_utils
from
official.vision.detection.ops
import
postprocess_ops
from
official.vision.detection.ops
import
postprocess_ops
...
@@ -57,7 +57,7 @@ class RetinanetModel(base_model.Model):
...
@@ -57,7 +57,7 @@ class RetinanetModel(base_model.Model):
params
.
postprocess
)
params
.
postprocess
)
self
.
_transpose_input
=
params
.
train
.
transpose_input
self
.
_transpose_input
=
params
.
train
.
transpose_input
assert
not
self
.
_transpose_input
,
'Transpose input is not support
t
ed.'
assert
not
self
.
_transpose_input
,
'Transpose input is not supported.'
# Input layer.
# Input layer.
input_shape
=
(
input_shape
=
(
params
.
retinanet_parser
.
output_size
+
params
.
retinanet_parser
.
output_size
+
...
@@ -120,7 +120,7 @@ class RetinanetModel(base_model.Model):
...
@@ -120,7 +120,7 @@ class RetinanetModel(base_model.Model):
def
build_model
(
self
,
params
,
mode
=
None
):
def
build_model
(
self
,
params
,
mode
=
None
):
if
self
.
_keras_model
is
None
:
if
self
.
_keras_model
is
None
:
with
backend
.
get
_graph
()
.
as_default
()
:
with
keras_utils
.
maybe_enter_
backend_graph
():
outputs
=
self
.
model_outputs
(
self
.
_input_layer
,
mode
)
outputs
=
self
.
model_outputs
(
self
.
_input_layer
,
mode
)
model
=
tf
.
keras
.
models
.
Model
(
model
=
tf
.
keras
.
models
.
Model
(
...
...
official/vision/detection/modeling/shapemask_model.py
View file @
5a2cf36f
...
@@ -20,13 +20,13 @@ from __future__ import print_function
...
@@ -20,13 +20,13 @@ from __future__ import print_function
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
from
official.vision.detection.dataloader
import
anchor
from
official.vision.detection.dataloader
import
anchor
from
official.vision.detection.dataloader
import
mode_keys
from
official.vision.detection.dataloader
import
mode_keys
from
official.vision.detection.evaluation
import
factory
as
eval_factory
from
official.vision.detection.evaluation
import
factory
as
eval_factory
from
official.vision.detection.modeling
import
base_model
from
official.vision.detection.modeling
import
base_model
from
official.vision.detection.modeling
import
losses
from
official.vision.detection.modeling
import
losses
from
official.vision.detection.modeling.architecture
import
factory
from
official.vision.detection.modeling.architecture
import
factory
from
official.vision.detection.modeling.architecture
import
keras_utils
from
official.vision.detection.ops
import
postprocess_ops
from
official.vision.detection.ops
import
postprocess_ops
from
official.vision.detection.utils
import
box_utils
from
official.vision.detection.utils
import
box_utils
...
@@ -265,7 +265,7 @@ class ShapeMaskModel(base_model.Model):
...
@@ -265,7 +265,7 @@ class ShapeMaskModel(base_model.Model):
def
build_model
(
self
,
params
,
mode
):
def
build_model
(
self
,
params
,
mode
):
if
self
.
_keras_model
is
None
:
if
self
.
_keras_model
is
None
:
input_layers
=
self
.
build_input_layers
(
self
.
_params
,
mode
)
input_layers
=
self
.
build_input_layers
(
self
.
_params
,
mode
)
with
backend
.
get
_graph
()
.
as_default
()
:
with
keras_utils
.
maybe_enter_
backend_graph
():
outputs
=
self
.
model_outputs
(
input_layers
,
mode
)
outputs
=
self
.
model_outputs
(
input_layers
,
mode
)
model
=
tf
.
keras
.
models
.
Model
(
model
=
tf
.
keras
.
models
.
Model
(
...
...
official/vision/image_classification/resnet/common.py
View file @
5a2cf36f
...
@@ -255,7 +255,7 @@ def define_keras_flags(
...
@@ -255,7 +255,7 @@ def define_keras_flags(
name
=
'tpu'
,
default
=
''
,
help
=
'TPU address to connect to.'
)
name
=
'tpu'
,
default
=
''
,
help
=
'TPU address to connect to.'
)
flags
.
DEFINE_integer
(
flags
.
DEFINE_integer
(
name
=
'steps_per_loop'
,
name
=
'steps_per_loop'
,
default
=
500
,
default
=
None
,
help
=
'Number of steps per training loop. Only training step happens '
help
=
'Number of steps per training loop. Only training step happens '
'inside the loop. Callbacks will not be called inside. Will be capped at '
'inside the loop. Callbacks will not be called inside. Will be capped at '
'steps per epoch.'
)
'steps per epoch.'
)
...
...
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
View file @
5a2cf36f
...
@@ -14,18 +14,16 @@
...
@@ -14,18 +14,16 @@
# ==============================================================================
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset using custom training loops."""
"""Runs a ResNet model on the ImageNet dataset using custom training loops."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
math
import
os
from
absl
import
app
from
absl
import
app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl
import
logging
import
orbit
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.staging.training
import
controller
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
...
@@ -87,15 +85,6 @@ def get_num_train_iterations(flags_obj):
...
@@ -87,15 +85,6 @@ def get_num_train_iterations(flags_obj):
return
train_steps
,
train_epochs
,
eval_steps
return
train_steps
,
train_epochs
,
eval_steps
def
_steps_to_run
(
steps_in_current_epoch
,
steps_per_epoch
,
steps_per_loop
):
"""Calculates steps to run on device."""
if
steps_per_loop
<=
0
:
raise
ValueError
(
'steps_per_loop should be positive integer.'
)
if
steps_per_loop
==
1
:
return
steps_per_loop
return
min
(
steps_per_loop
,
steps_per_epoch
-
steps_in_current_epoch
)
def
run
(
flags_obj
):
def
run
(
flags_obj
):
"""Run ResNet ImageNet training and eval loop using custom training loops.
"""Run ResNet ImageNet training and eval loop using custom training loops.
...
@@ -121,7 +110,6 @@ def run(flags_obj):
...
@@ -121,7 +110,6 @@ def run(flags_obj):
datasets_num_private_threads
=
flags_obj
.
datasets_num_private_threads
)
datasets_num_private_threads
=
flags_obj
.
datasets_num_private_threads
)
common
.
set_cudnn_batchnorm_mode
()
common
.
set_cudnn_batchnorm_mode
()
# TODO(anj-s): Set data_format without using Keras.
data_format
=
flags_obj
.
data_format
data_format
=
flags_obj
.
data_format
if
data_format
is
None
:
if
data_format
is
None
:
data_format
=
(
'channels_first'
if
tf
.
config
.
list_physical_devices
(
'GPU'
)
data_format
=
(
'channels_first'
if
tf
.
config
.
list_physical_devices
(
'GPU'
)
...
@@ -137,7 +125,14 @@ def run(flags_obj):
...
@@ -137,7 +125,14 @@ def run(flags_obj):
per_epoch_steps
,
train_epochs
,
eval_steps
=
get_num_train_iterations
(
per_epoch_steps
,
train_epochs
,
eval_steps
=
get_num_train_iterations
(
flags_obj
)
flags_obj
)
steps_per_loop
=
min
(
flags_obj
.
steps_per_loop
,
per_epoch_steps
)
if
flags_obj
.
steps_per_loop
is
None
:
steps_per_loop
=
per_epoch_steps
elif
flags_obj
.
steps_per_loop
>
per_epoch_steps
:
steps_per_loop
=
per_epoch_steps
logging
.
warn
(
'Setting steps_per_loop to %d to respect epoch boundary.'
,
steps_per_loop
)
else
:
steps_per_loop
=
flags_obj
.
steps_per_loop
logging
.
info
(
logging
.
info
(
'Training %d epochs, each epoch has %d steps, '
'Training %d epochs, each epoch has %d steps, '
...
@@ -154,8 +149,8 @@ def run(flags_obj):
...
@@ -154,8 +149,8 @@ def run(flags_obj):
eval_interval
=
flags_obj
.
epochs_between_evals
*
per_epoch_steps
eval_interval
=
flags_obj
.
epochs_between_evals
*
per_epoch_steps
checkpoint_interval
=
(
checkpoint_interval
=
(
per_epoch_steps
if
flags_obj
.
enable_checkpoint_and_export
else
None
)
steps_per_loop
*
5
if
flags_obj
.
enable_checkpoint_and_export
else
None
)
summary_interval
=
per_epoch_steps
if
flags_obj
.
enable_tensorboard
else
None
summary_interval
=
steps_per_loop
if
flags_obj
.
enable_tensorboard
else
None
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
runnable
.
checkpoint
,
runnable
.
checkpoint
,
...
@@ -164,20 +159,24 @@ def run(flags_obj):
...
@@ -164,20 +159,24 @@ def run(flags_obj):
step_counter
=
runnable
.
global_step
,
step_counter
=
runnable
.
global_step
,
checkpoint_interval
=
checkpoint_interval
)
checkpoint_interval
=
checkpoint_interval
)
resnet_controller
=
controller
.
Controller
(
resnet_controller
=
orbit
.
Controller
(
strategy
,
strategy
,
runnable
.
train
,
runnable
,
runnable
.
evaluate
if
not
flags_obj
.
skip_eval
else
None
,
runnable
if
not
flags_obj
.
skip_eval
else
None
,
global_step
=
runnable
.
global_step
,
global_step
=
runnable
.
global_step
,
steps_per_loop
=
steps_per_loop
,
steps_per_loop
=
steps_per_loop
,
train_steps
=
per_epoch_steps
*
train_epochs
,
checkpoint_manager
=
checkpoint_manager
,
checkpoint_manager
=
checkpoint_manager
,
summary_interval
=
summary_interval
,
summary_interval
=
summary_interval
,
eval_steps
=
eval_steps
,
eval_summary_dir
=
os
.
path
.
join
(
flags_obj
.
model_dir
,
'eval'
))
eval_interval
=
eval_interval
)
time_callback
.
on_train_begin
()
time_callback
.
on_train_begin
()
resnet_controller
.
train
(
evaluate
=
not
flags_obj
.
skip_eval
)
if
not
flags_obj
.
skip_eval
:
resnet_controller
.
train_and_evaluate
(
train_steps
=
per_epoch_steps
*
train_epochs
,
eval_steps
=
eval_steps
,
eval_interval
=
eval_interval
)
else
:
resnet_controller
.
train
(
steps
=
per_epoch_steps
*
train_epochs
)
time_callback
.
on_train_end
()
time_callback
.
on_train_end
()
stats
=
build_stats
(
runnable
,
time_callback
)
stats
=
build_stats
(
runnable
,
time_callback
)
...
...
Prev
1
2
3
4
5
6
7
8
9
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment