ModelZoo / ResNet50_tensorflow

Commit ef800b03, authored Aug 08, 2020 by xinliupitt

no predict func

Parent: 0490e860
Showing 1 changed file with 0 additions and 65 deletions.

official/nlp/modeling/models/seq2seq_transformer.py  (+0, -65)
official/nlp/modeling/models/seq2seq_transformer.py @ ef800b03
@@ -172,7 +172,6 @@ class Seq2SeqTransformer(tf.keras.Model):
           attention_mask=attention_mask)
 
     if targets is None:
-      # return self.predict(encoder_outputs, attention_bias, training)
       encoder_decoder_attention_bias = attention_bias
       encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
       if self.params["padded_decode"]:
...
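A note (not part of the commit) on the padded_decode branch that closes this hunk: padded decoding (e.g. on TPU) needs shapes that are known statically, so batch size and input length are read from the tensor's static shape, while the non-padded path uses dynamic tf.shape values. A minimal standalone sketch of that choice, with a hypothetical helper name decode_dims that does not exist in this repository:

import tensorflow as tf

# Hypothetical helper (illustration only) mirroring the shape logic in the diff.
def decode_dims(encoder_outputs, padded_decode):
  if padded_decode:
    # Padded (e.g. TPU) decoding requires static shapes at graph-build time.
    batch_size = encoder_outputs.shape.as_list()[0]
    input_length = encoder_outputs.shape.as_list()[1]
  else:
    # Otherwise dynamic shapes are fine.
    batch_size = tf.shape(encoder_outputs)[0]
    input_length = tf.shape(encoder_outputs)[1]
  return batch_size, input_length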
...
@@ -369,70 +368,6 @@ class Seq2SeqTransformer(tf.keras.Model):
 
     return symbols_to_logits_fn
 
-  def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
-    """Return predicted sequence."""
-    encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
-    if self.params["padded_decode"]:
-      batch_size = encoder_outputs.shape.as_list()[0]
-      input_length = encoder_outputs.shape.as_list()[1]
-    else:
-      batch_size = tf.shape(encoder_outputs)[0]
-      input_length = tf.shape(encoder_outputs)[1]
-    max_decode_length = input_length + self.params["extra_decode_length"]
-    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
-                                             self.params["dtype"])
-
-    symbols_to_logits_fn = self._get_symbols_to_logits_fn(
-        max_decode_length, training)
-
-    # Create initial set of IDs that will be passed into symbols_to_logits_fn.
-    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
-
-    # Create cache storing decoder attention values for each layer.
-    # pylint: disable=g-complex-comprehension
-    init_decode_length = (
-        max_decode_length if self.params["padded_decode"] else 0)
-    num_heads = self.params["num_heads"]
-    dim_per_head = self.params["hidden_size"] // num_heads
-    cache = {
-        str(layer): {
-            "key":
-                tf.zeros(
-                    [batch_size, init_decode_length, num_heads, dim_per_head],
-                    dtype=self.params["dtype"]),
-            "value":
-                tf.zeros(
-                    [batch_size, init_decode_length, num_heads, dim_per_head],
-                    dtype=self.params["dtype"])
-        } for layer in range(self.params["num_hidden_layers"])
-    }
-    # pylint: enable=g-complex-comprehension
-
-    # Add encoder output and attention bias to the cache.
-    cache["encoder_outputs"] = encoder_outputs
-    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
-
-    # Use beam search to find the top beam_size sequences and scores.
-    decoded_ids, scores = beam_search.sequence_beam_search(
-        symbols_to_logits_fn=symbols_to_logits_fn,
-        initial_ids=initial_ids,
-        initial_cache=cache,
-        vocab_size=self.params["vocab_size"],
-        beam_size=self.params["beam_size"],
-        alpha=self.params["alpha"],
-        max_decode_length=max_decode_length,
-        eos_id=EOS_ID,
-        padded_decode=self.params["padded_decode"],
-        dtype=self.params["dtype"])
-
-    # Get the top sequence for each batch element
-    top_decoded_ids = decoded_ids[:, 0, 1:]
-    top_scores = scores[:, 0]
-
-    return {"outputs": top_decoded_ids, "scores": top_scores}
-
 
 class TransformerEncoder(tf.keras.layers.Layer):
   """Transformer decoder stack.
 
   Like the encoder stack, the decoder stack is made up of N identical layers.
...
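A note (not part of the commit) on the removed method: predict built a per-layer key/value cache, ran beam_search.sequence_beam_search over it, and returned the top beam as a dict with "outputs" and "scores". A minimal sketch of the cache it constructed, using made-up parameter values in place of self.params (the numbers below are assumptions, not the repository's configuration):

import tensorflow as tf

# Assumed example values; in the model these come from self.params.
batch_size, num_hidden_layers, num_heads, hidden_size = 8, 6, 8, 512
dim_per_head = hidden_size // num_heads
init_decode_length = 0  # would be max_decode_length when padded_decode=True

# One key/value slot per decoder layer, grown (or filled in place when padded)
# as decoding proceeds.
cache = {
    str(layer): {
        "key": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head]),
        "value": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head]),
    }
    for layer in range(num_hidden_layers)
}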