Commit 1180f37e, authored Mar 16, 2021 by Frederick Liu, committed by A. Unique TensorFlower on Mar 16, 2021

[nlp][translation] Remove seq2seq model _dtype argument and break transformer utils dependency.

PiperOrigin-RevId: 363121226
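In practical terms, callers no longer pass a dtype to the model; its compute dtype now follows the Keras dtype policy. The sketch below is an illustrative aside, not part of this commit: the constructor argument names are taken from the diff (vocab_size, embedding_width, beam_size, alpha, encoder_layer, decoder_layer) and may not be the full signature, and the values are placeholders.

import tensorflow as tf
from official.nlp.modeling.models import seq2seq_transformer

# Compute dtype now comes from the global Keras mixed-precision policy
# rather than a `dtype=...` constructor argument (requires TF 2.4+).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

encoder = seq2seq_transformer.TransformerEncoder(
    num_layers=6, num_attention_heads=8, intermediate_size=2048)
decoder = seq2seq_transformer.TransformerDecoder(
    num_layers=6, num_attention_heads=8, intermediate_size=2048)

# Placeholder hyperparameters; consult the class for the complete argument list.
model = seq2seq_transformer.Seq2SeqTransformer(
    vocab_size=32000,
    embedding_width=512,
    beam_size=4,
    alpha=0.6,
    encoder_layer=encoder,
    decoder_layer=decoder)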
Parent: f06dc1a6
Changes: 1 changed file with 76 additions and 85 deletions (+76, -85)
official/nlp/modeling/models/seq2seq_transformer.py @ 1180f37e
@@ -23,10 +23,8 @@ from official.modeling import tf_utils
 from official.nlp import keras_nlp
 from official.nlp.modeling import layers
 from official.nlp.modeling.ops import beam_search
-from official.nlp.transformer import model_utils
 EOS_ID = 1
 # pylint: disable=g-classes-have-attributes
 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -52,7 +50,6 @@ class Seq2SeqTransformer(tf.keras.Model):
                alpha=0.6,
                encoder_layer=None,
                decoder_layer=None,
-               dtype=tf.float32,
                eos_id=EOS_ID,
                **kwargs):
     """Initialize layers to build Transformer model.
@@ -69,7 +66,6 @@ class Seq2SeqTransformer(tf.keras.Model):
       alpha: The strength of length normalization for beam search.
       encoder_layer: An initialized encoder layer.
       decoder_layer: An initialized decoder layer.
-      dtype: float dtype.
       eos_id: Id of end of sentence token.
       **kwargs: other keyword arguments.
     """
@@ -82,7 +78,6 @@ class Seq2SeqTransformer(tf.keras.Model):
     self._extra_decode_length = extra_decode_length
     self._beam_size = beam_size
     self._alpha = alpha
-    self._dtype = dtype
     self._eos_id = eos_id
     self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
         vocab_size=self._vocab_size,
@@ -104,7 +99,6 @@ class Seq2SeqTransformer(tf.keras.Model):
         "dropout_rate": self._dropout_rate,
         "padded_decode": self._padded_decode,
         "decode_max_length": self._decode_max_length,
-        "dtype": self._dtype,
         "eos_id": self._eos_id,
         "extra_decode_length": self._extra_decode_length,
         "beam_size": self._beam_size,
@@ -123,10 +117,7 @@ class Seq2SeqTransformer(tf.keras.Model):
     vocab_size = tf.shape(embedding_matrix)[0]
     x = tf.reshape(x, [-1, hidden_size])
-    logits = tf.matmul(
-        tf.cast(x, dtype=self._dtype),
-        tf.cast(embedding_matrix, self._dtype),
-        transpose_b=True)
+    logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)
     return tf.reshape(logits, [batch_size, length, vocab_size])
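The rewritten logit projection no longer casts to a stored _dtype; the matmul simply follows the dtype of its input activations. As an illustrative aside (not part of the commit), a standalone sketch of that weight-tied projection with toy shapes:

import tensorflow as tf

batch_size, length, hidden_size, vocab_size = 2, 5, 8, 100
x = tf.random.normal([batch_size, length, hidden_size])         # decoder outputs
embedding_matrix = tf.random.normal([vocab_size, hidden_size])  # shared embedding table

x_flat = tf.reshape(x, [-1, hidden_size])
# Project back onto the vocabulary with the transposed embedding matrix,
# in the dtype of the activations rather than a fixed model dtype.
logits = tf.matmul(x_flat, tf.cast(embedding_matrix, x_flat.dtype), transpose_b=True)
logits = tf.reshape(logits, [batch_size, length, vocab_size])
print(logits.shape)  # (2, 5, 100)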
@@ -154,14 +145,10 @@ class Seq2SeqTransformer(tf.keras.Model):
     sources = inputs["inputs"]
     targets = inputs.get("targets", None)
-    attention_bias = model_utils.get_padding_bias(sources)
-    attention_bias = tf.cast(attention_bias, self._dtype)
     # Prepare inputs to the layer stack by adding positional encodings and
     # applying dropout.
     embedded_inputs = self.embedding_lookup(sources)
-    embedding_mask = tf.cast(tf.not_equal(sources, 0),
-                             self.embedding_lookup.embeddings.dtype)
-    embedded_inputs = tf.cast(embedded_inputs, self._dtype)
+    embedding_mask = tf.cast(tf.not_equal(sources, 0), embedded_inputs.dtype)
     embedded_inputs *= tf.expand_dims(embedding_mask, -1)
     # Attention_mask generation.
     input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
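The hunk above derives the padding mask directly from the source token ids instead of calling model_utils.get_padding_bias. A tiny illustration with toy ids, assuming 0 is the padding id as in the diff (this example is not from the commit):

import tensorflow as tf

sources = tf.constant([[5, 9, 3, 0, 0],
                       [7, 2, 0, 0, 0]])  # 0 = padding
# 1.0 where a real token is present, 0.0 at padded positions.
embedding_mask = tf.cast(tf.not_equal(sources, 0), tf.float32)
print(embedding_mask.numpy())
# [[1. 1. 1. 0. 0.]
#  [1. 1. 0. 0. 0.]]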
@@ -173,8 +160,8 @@ class Seq2SeqTransformer(tf.keras.Model):
         shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
     attention_mask = broadcast_ones * attention_mask
-    pos_encoding = self.position_embedding(inputs=embedded_inputs)
-    pos_encoding = tf.cast(pos_encoding, self._dtype)
+    pos_encoding = self.position_embedding(embedded_inputs)
+    pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
     encoder_inputs = embedded_inputs + pos_encoding
     encoder_inputs = self.encoder_dropout(encoder_inputs)
@@ -183,15 +170,11 @@ class Seq2SeqTransformer(tf.keras.Model):
         encoder_inputs, attention_mask=attention_mask)
     if targets is None:
-      encoder_decoder_attention_bias = attention_bias
-      encoder_outputs = tf.cast(encoder_outputs, self._dtype)
       if self._padded_decode:
         max_decode_length = self._decode_max_length
       else:
         max_decode_length = self._decode_max_length or (
             tf.shape(encoder_outputs)[1] + self._extra_decode_length)
-      encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
-                                               self._dtype)
       symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
       batch_size = tf.shape(encoder_outputs)[0]
@@ -199,28 +182,35 @@ class Seq2SeqTransformer(tf.keras.Model):
       initial_ids = tf.zeros([batch_size], dtype=tf.int32)
       # Create cache storing decoder attention values for each layer.
-      # pylint: disable=g-complex-comprehension
       init_decode_length = (max_decode_length if self._padded_decode else 0)
       num_heads = self.decoder_layer.num_attention_heads
       dim_per_head = self._embedding_width // num_heads
+      # Cache dtype needs to match beam_search dtype.
+      # pylint: disable=g-complex-comprehension
       cache = {
           str(layer): {
               "key":
                   tf.zeros([
                       batch_size, init_decode_length, num_heads, dim_per_head
-                  ], dtype=self._dtype),
+                  ], dtype=self.compute_dtype),
               "value":
                   tf.zeros([
                       batch_size, init_decode_length, num_heads, dim_per_head
-                  ], dtype=self._dtype)
+                  ], dtype=self.compute_dtype)
           } for layer in range(self.decoder_layer.num_layers)
       }
       # pylint: enable=g-complex-comprehension
       # Add encoder output and attention bias to the cache.
+      encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
+      attention_mask = tf.cast(
+          tf.reshape(
+              tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
+          dtype=self.compute_dtype)
       cache["encoder_outputs"] = encoder_outputs
-      cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
+      cache["encoder_decoder_attention_mask"] = attention_mask
       # Use beam search to find the top beam_size sequences and scores.
       decoded_ids, scores = beam_search.sequence_beam_search(
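The decode cache (and the beam search call below) now uses self.compute_dtype, which a tf.keras.Model inherits from its dtype policy, so the cache automatically matches the activation dtype under mixed precision. A small illustration of that attribute, assuming TF 2.4+ and shown on a generic layer rather than this model:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
layer = tf.keras.layers.Dense(4)
print(layer.compute_dtype)   # float16 -- dtype of activations (and of the decode cache here)
print(layer.variable_dtype)  # float32 -- dtype of the weights
tf.keras.mixed_precision.set_global_policy("float32")  # reset the global policy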
@@ -233,7 +223,7 @@ class Seq2SeqTransformer(tf.keras.Model):
           max_decode_length=max_decode_length,
           eos_id=self._eos_id,
           padded_decode=self._padded_decode,
-          dtype=self._dtype)
+          dtype=self.compute_dtype)
       # Get the top sequence for each batch element
       top_decoded_ids = decoded_ids[:, 0, 1:]
@@ -242,15 +232,13 @@ class Seq2SeqTransformer(tf.keras.Model):
       return {"outputs": top_decoded_ids, "scores": top_scores}
     decoder_inputs = self.embedding_lookup(targets)
-    embedding_mask = tf.cast(tf.not_equal(targets, 0),
-                             self.embedding_lookup.embeddings.dtype)
-    decoder_inputs = tf.cast(decoder_inputs, self._dtype)
+    embedding_mask = tf.cast(tf.not_equal(targets, 0), decoder_inputs.dtype)
     decoder_inputs *= tf.expand_dims(embedding_mask, -1)
     # Shift targets to the right, and remove the last element
     decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
     length = tf.shape(decoder_inputs)[1]
     pos_encoding = self.position_embedding(decoder_inputs)
-    pos_encoding = tf.cast(pos_encoding, self._dtype)
+    pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
     decoder_inputs += pos_encoding
     decoder_inputs = self.decoder_dropout(decoder_inputs)
@@ -259,8 +247,7 @@ class Seq2SeqTransformer(tf.keras.Model):
     batch_size = decoder_shape[0]
     decoder_length = decoder_shape[1]
-    self_attention_mask = tf.linalg.band_part(
-        tf.ones([length, length], dtype=tf.float32), -1, 0)
+    self_attention_mask = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
     self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
     self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
@@ -274,6 +261,8 @@ class Seq2SeqTransformer(tf.keras.Model):
         memory_mask=self_attention_mask,
         target_mask=attention_mask)
     logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
+    # Model outputs should be float32 to avoid numeric issues.
+    # https://www.tensorflow.org/guide/mixed_precision#building_the_model
+    logits = tf.cast(logits, tf.float32)
     return logits
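The cast added above follows the pattern recommended in the linked mixed-precision guide: compute in float16 where the policy says so, but return float32 outputs to keep losses numerically stable. A generic Keras sketch of the same idea, not specific to this model:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
inputs = tf.keras.Input(shape=(16,))
x = tf.keras.layers.Dense(32, activation="relu")(inputs)  # computed in float16
x = tf.keras.layers.Dense(10)(x)                          # float16 logits
# Force the final outputs back to float32, mirroring the tf.cast added above.
outputs = tf.keras.layers.Activation("linear", dtype="float32")(x)
model = tf.keras.Model(inputs, outputs)
print(model.output.dtype)  # float32
tf.keras.mixed_precision.set_global_policy("float32")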
@@ -281,9 +270,12 @@ class Seq2SeqTransformer(tf.keras.Model):
     """Returns a decoding function that calculates logits of the next tokens."""
     timing_signal = self.position_embedding(
         inputs=None, length=max_decode_length + 1)
-    timing_signal = tf.cast(timing_signal, self._dtype)
-    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
-        max_decode_length, dtype=self._dtype)
+    timing_signal = tf.cast(timing_signal, dtype=self.compute_dtype)
+    decoder_self_attention_mask = tf.linalg.band_part(
+        tf.ones([max_decode_length, max_decode_length],
+                dtype=self.compute_dtype), -1, 0)
+    decoder_self_attention_mask = tf.reshape(
+        decoder_self_attention_mask, [1, max_decode_length, max_decode_length])
     def symbols_to_logits_fn(ids, i, cache):
       """Generate logits for next potential IDs.
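The model_utils.get_decoder_self_attention_bias dependency is replaced above by a causal mask built with tf.linalg.band_part. For a toy length, the mask it produces looks like this (illustrative only, not from the commit):

import tensorflow as tf

max_decode_length = 4
# Keep all sub-diagonals (-1) and no super-diagonals (0): lower-triangular mask.
mask = tf.linalg.band_part(tf.ones([max_decode_length, max_decode_length]), -1, 0)
print(mask.numpy())
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]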
@@ -308,33 +300,24 @@ class Seq2SeqTransformer(tf.keras.Model):
       source_decoder_input = decoder_input
       decoder_input = self.embedding_lookup(decoder_input)
       embedding_mask = tf.cast(
-          tf.not_equal(source_decoder_input, 0),
-          self.embedding_lookup.embeddings.dtype)
+          tf.not_equal(source_decoder_input, 0), decoder_input.dtype)
       decoder_input *= tf.expand_dims(embedding_mask, -1)
       decoder_input += timing_signal[i]
       if self._padded_decode:
-        bias_shape = decoder_self_attention_bias.shape.as_list()
-        self_attention_bias = tf.slice(
-            decoder_self_attention_bias, [0, 0, i, 0],
-            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
         # indexing does not work on TPU.
+        bias_shape = decoder_self_attention_mask.shape.as_list()
+        self_attention_mask = tf.slice(
+            decoder_self_attention_mask, [0, i, 0],
+            [bias_shape[0], 1, bias_shape[2]])
       else:
-        self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
+        self_attention_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]
       decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
       batch_size = decoder_shape[0]
       decoder_length = decoder_shape[1]
-      attention_bias = cache.get("encoder_decoder_attention_bias")
-      attention_bias = tf.where(attention_bias < 0,
-                                tf.zeros_like(attention_bias),
-                                tf.ones_like(attention_bias))
-      attention_bias = tf.squeeze(attention_bias, axis=[1])
-      attention_mask = tf.tile(attention_bias, [1, decoder_length, 1])
-      self_attention_bias = tf.where(self_attention_bias < 0,
-                                     tf.zeros_like(self_attention_bias),
-                                     tf.ones_like(self_attention_bias))
-      self_attention_bias = tf.squeeze(self_attention_bias, axis=[1])
-      self_attention_mask = tf.tile(self_attention_bias, [batch_size, 1, 1])
+      self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
+      attention_mask = cache.get("encoder_decoder_attention_mask")
+      attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
       decoder_outputs = self.decoder_layer(
           decoder_input,
@@ -344,6 +327,7 @@ class Seq2SeqTransformer(tf.keras.Model):
           cache=cache,
           decode_loop_step=i if self._padded_decode else None)
+      decoder_outputs = tf.cast(decoder_outputs, dtype=self.compute_dtype)
       logits = self._embedding_linear(self.embedding_lookup.embeddings,
                                       decoder_outputs)
       logits = tf.squeeze(logits, axis=[1])
@@ -359,21 +343,6 @@ class TransformerEncoder(tf.keras.layers.Layer):
   of the sublayers:
     1. Self-attention layer
     2. Feedforward network (which is 2 fully-connected layers)
-
-  Args:
-    num_layers: Number of layers.
-    num_attention_heads: Number of attention heads.
-    intermediate_size: Size of the intermediate (Feedforward) layer.
-    activation: Activation for the intermediate layer.
-    dropout_rate: Dropout probability.
-    attention_dropout_rate: Dropout probability for attention layers.
-    use_bias: Whether to enable use_bias in attention layer. If set False,
-      use_bias in attention layer is disabled.
-    norm_first: Whether to normalize inputs to attention and intermediate dense
-      layers. If set False, output of attention and intermediate dense layers is
-      normalized.
-    norm_epsilon: Epsilon value to initialize normalization layers.
-    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
   """

   def __init__(self,
@@ -388,6 +357,25 @@ class TransformerEncoder(tf.keras.layers.Layer):
                norm_epsilon=1e-6,
                intermediate_dropout=0.0,
                **kwargs):
+    """Initialize a Transformer encoder.
+
+    Args:
+      num_layers: Number of layers.
+      num_attention_heads: Number of attention heads.
+      intermediate_size: Size of the intermediate (Feedforward) layer.
+      activation: Activation for the intermediate layer.
+      dropout_rate: Dropout probability.
+      attention_dropout_rate: Dropout probability for attention layers.
+      use_bias: Whether to enable use_bias in attention layer. If set False,
+        use_bias in attention layer is disabled.
+      norm_first: Whether to normalize inputs to attention and intermediate
+        dense layers. If set False, output of attention and intermediate dense
+        layers is normalized.
+      norm_epsilon: Epsilon value to initialize normalization layers.
+      intermediate_dropout: Dropout probability for intermediate_dropout_layer.
+      **kwargs: key word arguemnts passed to tf.keras.layers.Layer.
+    """
     super(TransformerEncoder, self).__init__(**kwargs)
     self.num_layers = num_layers
     self.num_attention_heads = num_attention_heads
@@ -469,21 +457,6 @@ class TransformerDecoder(tf.keras.layers.Layer):
     2. Multi-headed attention layer combining encoder outputs with results from
        the previous self-attention layer.
     3. Feedforward network (2 fully-connected layers)
-
-  Args:
-    num_layers: Number of layers.
-    num_attention_heads: Number of attention heads.
-    intermediate_size: Size of the intermediate (Feedforward) layer.
-    activation: Activation for the intermediate layer.
-    dropout_rate: Dropout probability.
-    attention_dropout_rate: Dropout probability for attention layers.
-    use_bias: Whether to enable use_bias in attention layer. If set `False`,
-      use_bias in attention layer is disabled.
-    norm_first: Whether to normalize inputs to attention and intermediate dense
-      layers. If set `False`, output of attention and intermediate dense layers
-      is normalized.
-    norm_epsilon: Epsilon value to initialize normalization layers.
-    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
   """

   def __init__(self,
@@ -498,6 +471,24 @@ class TransformerDecoder(tf.keras.layers.Layer):
                norm_epsilon=1e-6,
                intermediate_dropout=0.0,
                **kwargs):
+    """Initialize a Transformer decoder.
+
+    Args:
+      num_layers: Number of layers.
+      num_attention_heads: Number of attention heads.
+      intermediate_size: Size of the intermediate (Feedforward) layer.
+      activation: Activation for the intermediate layer.
+      dropout_rate: Dropout probability.
+      attention_dropout_rate: Dropout probability for attention layers.
+      use_bias: Whether to enable use_bias in attention layer. If set `False`,
+        use_bias in attention layer is disabled.
+      norm_first: Whether to normalize inputs to attention and intermediate
+        dense layers. If set `False`, output of attention and intermediate
+        dense layers is normalized.
+      norm_epsilon: Epsilon value to initialize normalization layers.
+      intermediate_dropout: Dropout probability for intermediate_dropout_layer.
+      **kwargs: key word arguemnts passed to tf.keras.layers.Layer.
+    """
     super(TransformerDecoder, self).__init__(**kwargs)
     self.num_layers = num_layers
     self.num_attention_heads = num_attention_heads