Commit 81fa82ef
Authored Sep 16, 2020 by Allen Wang; committed by A. Unique TensorFlower, Sep 16, 2020

Implement `TransformerXLBlock`.

PiperOrigin-RevId: 332055561
Parent: 6e8f1284

Changes: 3 changed files with 471 additions and 4 deletions.

  official/nlp/modeling/layers/relative_attention.py    +4    -4
  official/nlp/modeling/layers/transformer_xl.py        +292  -0
  official/nlp/modeling/layers/transformer_xl_test.py   +175  -0
official/nlp/modeling/layers/relative_attention.py

@@ -113,8 +113,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
     value: Value `Tensor` of shape `[B, S, dim]`.
     content_attention_bias: Bias `Tensor` for content based attention of shape
       `[num_heads, dim]`.
-    position_attention_bias: Bias `Tensor` for position based attention of shape
-      `[num_heads, dim]`.
+    positional_attention_bias: Bias `Tensor` for position based attention of
+      shape `[num_heads, dim]`.
     key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
       `value` for both `key` and `value`, which is the most common case.
     relative_position_encoding: Relative positional encoding `Tensor` of shape

@@ -368,8 +368,8 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
     content_stream: `Tensor` of shape `[B, T, dim]`.
     content_attention_bias: Bias `Tensor` for content based attention of shape
       `[num_heads, dim]`.
-    position_attention_bias: Bias `Tensor` for position based attention of shape
-      `[num_heads, dim]`.
+    positional_attention_bias: Bias `Tensor` for position based attention of
+      shape `[num_heads, dim]`.
     query_stream: `Tensor` of shape `[B, P, dim]`.
     target_mapping: `Tensor` of shape `[B, P, S]`.
     relative_position_encoding: Relative positional encoding `Tensor` of shape
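The only functional effect of these two hunks is the keyword rename: callers must now pass `positional_attention_bias=` instead of `position_attention_bias=`. Below is a minimal single-stream sketch of a direct call; it is an illustration, not part of the commit. The constructor arguments mirror those used in `TransformerXLBlock.build` and the tensor shapes follow the docstrings above and the mock data in `transformer_xl_test.py`, while the remaining optional inputs (segment tensors, `state`, attention mask) are assumed to default to `None`.

import tensorflow as tf

from official.nlp.modeling.layers import relative_attention

batch_size, seq_length, hidden_size = 2, 8, 24
num_heads, head_size = 12, 64

# Shapes follow the docstrings: streams are `[B, S, dim]`, attention biases
# are `[num_heads, head_size]`, and the relative position encoding spans
# twice the sequence length, as in the test's mock data.
content = tf.random.normal((batch_size, seq_length, hidden_size))
content_bias = tf.random.normal((num_heads, head_size))
positional_bias = tf.random.normal((num_heads, head_size))
rel_encoding = tf.random.normal((batch_size, 2 * seq_length, hidden_size))

attention_layer = relative_attention.MultiHeadRelativeAttention(
    num_heads=num_heads,
    key_dim=head_size,
    value_dim=head_size,
    use_bias=False,
    name="rel_attn")

# After this commit the keyword is `positional_attention_bias`
# (previously `position_attention_bias`).
output = attention_layer(
    query=content,
    value=content,
    key=content,
    content_attention_bias=content_bias,
    positional_attention_bias=positional_bias,
    relative_position_encoding=rel_encoding)
print(output.shape)  # Expected to match the query shape: (2, 8, 24).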
official/nlp/modeling/layers/transformer_xl.py · new file (0 → 100644)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based Transformer XL layer."""
from absl import logging
import tensorflow as tf

from official.nlp.modeling.layers import relative_attention


@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerXLBlock(tf.keras.layers.Layer):
  """Transformer XL block.

  This implements a Transformer XL block from "Transformer-XL: Attentive
  Language Models Beyond a Fixed-Length Context"
  (https://arxiv.org/abs/1901.02860).

  This block is further extended to allow for the Transformer-XL
  re-parameterization in "XLNet: Generalized Autoregressive Pretraining for
  Language Understanding" (https://arxiv.org/abs/1906.08237).

  Given an input stream, this block computes attention, applies dropouts and
  layer norms and feeds into the FFN network.

  **Note: This layer is currently experimental.

  Attributes:
    vocab_size: The size of the token vocabulary.
    hidden_size: The size of the transformer hidden layers.
    num_attention_heads: The number of attention heads.
    head_size: The dimension size of each attention head.
    inner_size: The inner size for the transformer layers.
    dropout_rate: Dropout rate for the output of this layer.
    attention_dropout_rate: Dropout rate on attention probabilities.
    two_stream: Whether or not to use `TwoStreamRelativeAttention` used in the
      XLNet pretrainer. If `False`, then it will use
      `MultiHeadRelativeAttention` as in Transformer XL.
    norm_epsilon: Epsilon value to initialize normalization layers.
    inner_activation: The activation to use for the inner FFN layers.
    kernel_initializer: Initializer for dense layer kernels.
    inner_dropout: Dropout probability for the inner dropout layer.
  """

  def __init__(self,
               vocab_size,
               hidden_size,
               num_attention_heads,
               head_size,
               inner_size,
               dropout_rate,
               attention_dropout_rate,
               two_stream=False,
               norm_epsilon=1e-12,
               inner_activation="relu",
               kernel_initializer="variance_scaling",
               inner_dropout=0.0,
               **kwargs):
    """Initializes TransformerXLBlock layer."""
    super(TransformerXLBlock, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self._num_heads = num_attention_heads
    self._head_size = head_size
    self._hidden_size = hidden_size
    self._inner_size = inner_size
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._inner_activation = inner_activation
    self._norm_epsilon = norm_epsilon
    self._kernel_initializer = kernel_initializer
    self._inner_dropout = inner_dropout
    self._two_stream = two_stream
    if two_stream:
      self._attention_layer_type = relative_attention.TwoStreamRelativeAttention
    else:
      self._attention_layer_type = relative_attention.MultiHeadRelativeAttention

  def build(self, input_shape):
    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
    input_tensor_shape = tf.TensorShape(input_tensor)
    if len(input_tensor_shape.as_list()) != 3:
      raise ValueError("TransformerLayer expects a three-dimensional input of "
                       "shape [batch, sequence, width].")
    batch_size, sequence_length, hidden_size = input_tensor_shape

    if len(input_shape) == 2:
      mask_tensor_shape = tf.TensorShape(input_shape[1])
      expected_mask_tensor_shape = tf.TensorShape(
          [batch_size, sequence_length, sequence_length])
      if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
        raise ValueError("When passing a mask tensor to TransformerXLBlock, "
                         "the mask tensor must be of shape [batch, "
                         "sequence_length, sequence_length] (here %s). Got a "
                         "mask tensor of shape %s." %
                         (expected_mask_tensor_shape, mask_tensor_shape))
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_layer = self._attention_layer_type(
        num_heads=self._num_heads,
        key_dim=self._head_size,
        value_dim=self._head_size,
        dropout=self._attention_dropout_rate,
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        name="rel_attn")
    self._attention_dropout = tf.keras.layers.Dropout(
        rate=self._attention_dropout_rate)
    self._attention_layer_norm = tf.keras.layers.LayerNormalization(
        name="self_attention_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype=tf.float32)
    self._inner_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, self._inner_size),
        bias_axes="d",
        kernel_initializer=self._kernel_initializer,
        name="inner")
    self._inner_activation_layer = tf.keras.layers.Activation(
        self._inner_activation)
    self._inner_dropout_layer = tf.keras.layers.Dropout(
        rate=self._inner_dropout)
    self._output_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
        name="output",
        kernel_initializer=self._kernel_initializer)
    self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon)

    super(TransformerXLBlock, self).build(input_shape)

  def get_config(self):
    config = {
        "vocab_size": self._vocab_size,
        "hidden_size": self._hidden_size,
        "num_attention_heads": self._num_heads,
        "head_size": self._head_size,
        "inner_size": self._inner_size,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "two_stream": self._two_stream,
        "norm_epsilon": self._norm_epsilon,
        "inner_activation": self._inner_activation,
        "kernel_initializer": self._kernel_initializer,
        "inner_dropout": self._inner_dropout,
    }
    base_config = super(TransformerXLBlock, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self,
           content_stream,
           content_attention_bias,
           positional_attention_bias,
           relative_position_encoding=None,
           segment_matrix=None,
           segment_encoding=None,
           segment_attention_bias=None,
           state=None,
           content_attention_mask=None,
           query_stream=None,
           query_attention_mask=None,
           target_mapping=None):
    """Implements `call` for the Layer.

    Arguments:
      content_stream: `Tensor`, the input content stream. This is the standard
        input to Transformer XL and is commonly referred to as `h` in XLNet.
      content_attention_bias: Bias `Tensor` for content based attention of shape
        `[num_heads, dim]`.
      positional_attention_bias: Bias `Tensor` for position based attention of
        shape `[num_heads, dim]`.
      relative_position_encoding: Relative positional encoding `Tensor` of shape
        `[B, L, dim]`.
      segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet,
        but not in Transformer XL.
      segment_encoding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used
        in XLNet, but not in Transformer XL.
      segment_attention_bias: Optional bias `Tensor` for segment based attention
        of shape `[num_heads, dim]`.
      state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of
        the state or memory. If passed, this is also attended over as in
        Transformer XL.
      content_attention_mask: Optional `Tensor` representing the mask that is
        added to content attention logits. If state is not None, the mask source
        sequence dimension should extend M.
      query_stream: Optional `Tensor`, the query stream. This is introduced in
        `TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if
        `two_stream` is `False`.
      query_attention_mask: Optional `Tensor` representing the mask that is
        added to query attention logits. If state is not None, the mask source
        sequence dimension should extend M.
      target_mapping: Optional `Tensor` representing the target mapping when
        calculating query attention.

    Returns:
      A `dict` object, containing the key value pairs for `content_attention`
      and (if `two_stream` is `True`) `query_attention`.
    """
    if not self._two_stream and query_stream is not None:
      logging.warning("`query_stream` was provided but two stream attention is "
                      "disabled. `query_stream` will be ignored.")
    if self._two_stream:
      attention_kwargs = dict(
          content_stream=content_stream,
          query_stream=query_stream,
          query_attention_mask=query_attention_mask,
          target_mapping=target_mapping,
          content_attention_mask=content_attention_mask)
    else:
      attention_kwargs = dict(
          query=content_stream,
          value=content_stream,
          key=content_stream,
          attention_mask=content_attention_mask)

    common_attention_kwargs = dict(
        content_attention_bias=content_attention_bias,
        relative_position_encoding=relative_position_encoding,
        positional_attention_bias=positional_attention_bias,
        segment_matrix=segment_matrix,
        segment_encoding=segment_encoding,
        segment_attention_bias=segment_attention_bias,
        state=state)

    attention_kwargs.update(common_attention_kwargs)
    attention_output = self._attention_layer(**attention_kwargs)

    if self._two_stream:
      attention_streams = attention_output
      input_streams = [content_stream, query_stream]
    else:
      attention_streams = [attention_output]
      input_streams = [content_stream]

    attention_keys = ["content_attention", "query_attention"]
    attention_output = {}
    for attention_stream, input_stream, attention_key in zip(
        attention_streams, input_streams, attention_keys):
      attention_stream = self._attention_dropout(attention_stream)
      attention_stream = self._attention_layer_norm(
          attention_stream + input_stream)
      inner_output = self._inner_dense(attention_stream)
      inner_output = self._inner_activation_layer(inner_output)
      inner_output = self._inner_dropout_layer(inner_output)
      layer_output = self._output_dense(inner_output)
      layer_output = self._output_dropout(layer_output)
      layer_output = self._output_layer_norm(layer_output + attention_stream)
      attention_output[attention_key] = layer_output

    return attention_output
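For orientation, here is a minimal single-stream usage sketch of the new layer. It is not part of the commit; the construction arguments and tensor shapes mirror the mock data used in `transformer_xl_test.py` below, with the two-stream, memory (`state`), segment, and mask inputs simply omitted.

import tensorflow as tf

from official.nlp.modeling.layers import transformer_xl

batch_size, seq_length, hidden_size = 2, 8, 24
num_heads, head_size, inner_size = 12, 64, 12

block = transformer_xl.TransformerXLBlock(
    vocab_size=32000,
    hidden_size=hidden_size,
    num_attention_heads=num_heads,
    head_size=head_size,
    inner_size=inner_size,
    dropout_rate=0.,
    attention_dropout_rate=0.,
    two_stream=False)

outputs = block(
    content_stream=tf.random.normal((batch_size, seq_length, hidden_size)),
    content_attention_bias=tf.random.normal((num_heads, head_size)),
    positional_attention_bias=tf.random.normal((num_heads, head_size)),
    relative_position_encoding=tf.random.normal(
        (batch_size, 2 * seq_length, hidden_size)))

# With `two_stream=False` the layer returns only the content stream output.
print(outputs["content_attention"].shape)  # (2, 8, 24)

With `two_stream=True`, the same call additionally takes `query_stream` and `target_mapping` and returns a `query_attention` entry, as exercised by the test below.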
official/nlp/modeling/layers/transformer_xl_test.py · new file (0 → 100644)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Transformer XL."""
import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import transformer_xl


def create_mock_transformer_xl_data(batch_size,
                                    num_heads,
                                    head_size,
                                    hidden_size,
                                    seq_length,
                                    memory_length=0,
                                    num_predictions=2,
                                    two_stream=False,
                                    include_state=False,
                                    include_mask=False,
                                    include_segment=False):
  """Creates mock testing data.

  Args:
    batch_size: `int`, the batch size.
    num_heads: `int`, number of attention heads.
    head_size: `int`, the size of each attention head.
    hidden_size: `int`, the layer's hidden size.
    seq_length: `int`, Sequence length of the input.
    memory_length: optional `int`, the length of the state. Defaults to 0.
    num_predictions: `int`, the number of predictions used in two stream
      attention.
    two_stream: `bool`, whether or not to generate two stream data.
    include_state: optional `bool`, whether or not to include state data.
    include_mask: optional `bool`, whether or not to include mask data.
    include_segment: optional `bool`, whether or not to include segment data.

  Returns:
    A dictionary with `str` as keys and `Tensor` as values.
  """
  encoding_shape = (batch_size, seq_length * 2, hidden_size)
  attention_bias_shape = (num_heads, head_size)

  data = dict(
      content_stream=tf.random.normal(
          shape=(batch_size, seq_length, hidden_size)),
      relative_position_encoding=tf.random.normal(shape=encoding_shape),
      content_attention_bias=tf.random.normal(shape=attention_bias_shape),
      positional_attention_bias=tf.random.normal(shape=attention_bias_shape))

  if two_stream:
    two_stream_data = dict(
        query_stream=tf.random.normal(
            shape=(batch_size, num_predictions, hidden_size)),
        target_mapping=tf.random.normal(
            shape=(batch_size, num_predictions, seq_length)))
    data.update(two_stream_data)

  if include_state:
    total_seq_length = seq_length + memory_length
    data["state"] = tf.random.normal(
        shape=(batch_size, memory_length, hidden_size))
  else:
    total_seq_length = seq_length

  if include_mask:
    mask_shape = (batch_size, num_heads, seq_length, total_seq_length)
    mask_data = np.random.randint(2, size=mask_shape).astype("float32")
    data["content_attention_mask"] = mask_data
    if two_stream:
      data["query_attention_mask"] = mask_data

  if include_segment:
    segment_encoding_shape = (2, num_heads, head_size)
    segment_matrix = np.random.randint(
        2, size=(batch_size, seq_length, total_seq_length))
    segment_matrix = tf.math.equal(segment_matrix, 1)
    segment_data = dict(
        segment_attention_bias=tf.random.normal(shape=attention_bias_shape),
        segment_encoding=tf.random.normal(shape=segment_encoding_shape),
        segment_matrix=segment_matrix)
    data.update(segment_data)

  return data


@keras_parameterized.run_all_keras_modes
class TransformerXLBlockTest(keras_parameterized.TestCase):

  @combinations.generate(combinations.combine(
      memory_length=[0, 4],
      two_stream=[True, False],
      state=[True, False],
      mask=[True, False],
      segment=[True, False]))
  def test_transformer_xl(self, two_stream, memory_length, state, mask,
                          segment):
    """Tests combinations of Transformer XL calculations."""
    batch_size, num_heads, head_size, seq_length = 2, 12, 64, 8
    hidden_size, num_predictions, inner_size = 24, 8, 12

    data = create_mock_transformer_xl_data(
        num_heads=num_heads,
        head_size=head_size,
        hidden_size=hidden_size,
        seq_length=seq_length,
        batch_size=batch_size,
        memory_length=memory_length,
        num_predictions=num_predictions,
        two_stream=two_stream,
        include_state=state,
        include_mask=mask,
        include_segment=segment)
    test_layer = transformer_xl.TransformerXLBlock(
        vocab_size=32000,
        hidden_size=hidden_size,
        num_attention_heads=num_heads,
        head_size=head_size,
        inner_size=inner_size,
        dropout_rate=0.,
        attention_dropout_rate=0.,
        two_stream=two_stream)
    output = test_layer(**data)
    content_attention = output["content_attention"]
    self.assertEqual(content_attention.shape,
                     [batch_size, seq_length, hidden_size])

    if two_stream:
      self.assertIn("query_attention", output)
      self.assertEqual(output["query_attention"].shape,
                       [batch_size, num_predictions, hidden_size])
    else:
      self.assertNotIn("query_attention", output)

  def test_get_config(self):
    transformer_xl_block = transformer_xl.TransformerXLBlock(
        vocab_size=32000,
        head_size=64,
        num_attention_heads=2,
        hidden_size=10,
        inner_size=50,
        dropout_rate=0.,
        attention_dropout_rate=0.,
        two_stream=False)
    transformer_xl_block_config = transformer_xl_block.get_config()
    new_block = transformer_xl.TransformerXLBlock.from_config(
        transformer_xl_block_config)
    self.assertEqual(transformer_xl_block_config, new_block.get_config())


if __name__ == "__main__":
  np.random.seed(0)
  tf.random.set_seed(0)
  tf.test.main()