GitLab · ModelZoo / ResNet50_tensorflow · Commits

Commit 2c98b4b0, authored Sep 18, 2020 by Allen Wang; committed by A. Unique TensorFlower on Sep 18, 2020
Parent: c7e31961

Implement Transformer XL

PiperOrigin-RevId: 332486600

Showing 4 changed files with 419 additions and 24 deletions:
*   official/nlp/modeling/layers/README.md (+21, −0)
*   official/nlp/modeling/layers/__init__.py (+4, −0)
*   official/nlp/modeling/layers/transformer_xl.py (+269, −0)
*   official/nlp/modeling/layers/transformer_xl_test.py (+125, −24)
official/nlp/modeling/layers/README.md (view file @ 2c98b4b0)
```diff
@@ -63,3 +63,24 @@ assemble new layers, networks, or models.
 *   [GatedFeedforward](gated_feedforward.py) implements the gated linear layer
     feedforward as described in
     ["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202).
+
+*   [MultiHeadRelativeAttention](relative_attention.py) implements a variant
+    of multi-head attention with support for relative position encodings as
+    described in "Transformer-XL: Attentive Language Models Beyond a
+    Fixed-Length Context"(https://arxiv.org/abs/1901.02860). This also has
+    extended support for segment-based attention, a re-parameterization
+    introduced in "XLNet: Generalized Autoregressive Pretraining for Language
+    Understanding" (https://arxiv.org/abs/1906.08237).
+
+*   [TwoStreamRelativeAttention](relative_attention.py) implements a variant
+    of multi-head relative attention as described in "XLNet: Generalized
+    Autoregressive Pretraining for Language Understanding"
+    (https://arxiv.org/abs/1906.08237). This takes in a query and content
+    stream and applies self attention.
+
+*   [TransformerXL](transformer_xl.py) implements Transformer XL introduced in
+    "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    (https://arxiv.org/abs/1901.02860). This contains `TransformerXLBlock`, a
+    block containing either one or two stream relative self-attention as well
+    as subsequent feedforward networks. It also contains `TransformerXL`, which
+    contains attention biases as well as multiple `TransformerXLBlocks`.
```
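For readers of the README entries above, here is a hedged, illustrative sketch (not part of this commit) of constructing and calling the single-stream `TransformerXLBlock`. The constructor arguments mirror those that `TransformerXL.__init__` passes to each block in `transformer_xl.py` below, the dimensions follow this commit's unit test, and it assumes the block's optional segment, state, and mask arguments default to `None`:

```python
import tensorflow as tf
from official.nlp.modeling.layers import transformer_xl

batch_size, seq_length = 2, 8
num_heads, head_size, hidden_size, inner_size = 12, 64, 24, 12

# Constructor arguments mirror those used inside TransformerXL.__init__ below.
block = transformer_xl.TransformerXLBlock(
    vocab_size=32000,
    hidden_size=head_size * num_heads,
    num_attention_heads=num_heads,
    head_size=head_size,
    inner_size=inner_size,
    dropout_rate=0.,
    attention_dropout_rate=0.,
    norm_epsilon=1e-12,
    inner_activation="relu",
    two_stream=False,
    kernel_initializer="variance_scaling")

# Single-stream call: a content stream plus relative position encoding and the
# content/positional attention biases, as in TransformerXL.call below.
outputs = block(
    content_stream=tf.random.normal([batch_size, seq_length, hidden_size]),
    relative_position_encoding=tf.random.normal(
        [batch_size, seq_length * 2, hidden_size]),
    content_attention_bias=tf.random.normal([num_heads, head_size]),
    positional_attention_bias=tf.random.normal([num_heads, head_size]))

# The block returns a dict; the single-stream output is keyed by
# "content_attention" and is expected to match the content stream's shape.
content_attention = outputs["content_attention"]
```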
official/nlp/modeling/layers/__init__.py (view file @ 2c98b4b0)
```diff
@@ -24,9 +24,13 @@ from official.nlp.modeling.layers.mat_mul_with_margin import MatMulWithMargin
 from official.nlp.modeling.layers.multi_channel_attention import *
 from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
 from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
+from official.nlp.modeling.layers.relative_attention import MultiHeadRelativeAttention
+from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAttention
 from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
 from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
 from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
 from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
 from official.nlp.modeling.layers.transformer import *
 from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
+from official.nlp.modeling.layers.transformer_xl import TransformerXL
+from official.nlp.modeling.layers.transformer_xl import TransformerXLBlock
```
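For context (not part of the diff): once these imports land, the Transformer-XL symbols are re-exported from the `layers` package, so downstream code can reference them without the full module path. A minimal check, assuming the Model Garden package is importable:

```python
from official.nlp.modeling import layers

# The four symbols added to __init__.py above are now part of the public
# layers API surface.
print(layers.MultiHeadRelativeAttention)
print(layers.TwoStreamRelativeAttention)
print(layers.TransformerXL)
print(layers.TransformerXLBlock)
```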
official/nlp/modeling/layers/transformer_xl.py (view file @ 2c98b4b0; all changed lines in this file are additions)
@@ -22,6 +22,35 @@ import tensorflow as tf

```python
from official.nlp.modeling.layers import relative_attention


def _cache_memory(current_state, previous_state, memory_length,
                  reuse_length=0):
  """Caches hidden states into memory.

  Arguments:
    current_state: `Tensor`, the current state.
    previous_state: `Tensor`, the previous state.
    memory_length: `int`, the number of tokens to cache.
    reuse_length: `int`, the number of tokens in the current batch to be cached
      and reused in the future.

  Returns:
    A `Tensor`, representing the cached state with stopped gradients.
  """
  if memory_length is None or memory_length == 0:
    return None
  else:
    if reuse_length > 0:
      current_state = current_state[:, :reuse_length, :]

    if previous_state is None:
      new_mem = current_state[:, -memory_length:, :]
    else:
      new_mem = tf.concat(
          [previous_state, current_state], 1)[:, -memory_length:, :]

  return tf.stop_gradient(new_mem)


@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerXLBlock(tf.keras.layers.Layer):
  """Transformer XL block.
```
@@ -290,3 +319,243 @@ class TransformerXLBlock(tf.keras.layers.Layer):

```python
    return attention_output


class TransformerXL(tf.keras.layers.Layer):
  """Transformer XL.

  This layer combines multiple Transformer XL blocks from "Transformer-XL:
  Attentive Language Models Beyond a Fixed-Length Context"
  (https://arxiv.org/abs/1901.02860).

  This layer handles the attention biases as well as memory caching and reuse
  as in Transformer XL and XLNet.

  Attributes:
    vocab_size: The number of tokens in vocabulary.
    num_layers: The number of layers.
    hidden_size: The hidden size.
    num_attention_heads: The number of attention heads.
    head_size: The dimension size of each attention head.
    inner_size: The hidden size in feed-forward layers.
    dropout_rate: Dropout rate used in each Transformer XL block.
    attention_dropout_rate: Dropout rate on attention probabilities.
    two_stream: Whether or not to use `TwoStreamRelativeAttention` used in the
      XLNet pretrainer. If `False`, then it will use
      `MultiHeadRelativeAttention` as in Transformer XL.
    initializer: The initializer to use for attention biases.
    tie_attention_biases: Whether or not to tie biases together. If `True`,
      then each Transformer XL block shares the same trainable attention bias.
      If `False`, then each block has its own attention bias. This is usually
      set to `True`.
    memory_length: The number of tokens to cache.
    reuse_length: The number of tokens in the current batch to be cached and
      reused in the future.
    inner_activation: The activation to use in the inner layers for Transformer
      XL blocks. Typically "relu" or "gelu".
  """

  def __init__(self,
               vocab_size,
               num_layers,
               hidden_size,
               num_attention_heads,
               head_size,
               inner_size,
               dropout_rate,
               attention_dropout_rate,
               initializer,
               two_stream=False,
               tie_attention_biases=True,
               memory_length=None,
               reuse_length=None,
               inner_activation="relu",
               **kwargs):
    """Initializes TransformerXL."""
    super(TransformerXL, self).__init__(**kwargs)

    self._vocab_size = vocab_size
    self._initializer = initializer
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._num_attention_heads = num_attention_heads
    self._head_size = head_size
    self._inner_size = inner_size
    self._inner_activation = inner_activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._tie_attention_biases = tie_attention_biases
    self._two_stream = two_stream
    self._memory_length = memory_length
    self._reuse_length = reuse_length

    if self._tie_attention_biases:
      attention_bias_shape = [self._num_attention_heads, self._head_size]
    else:
      attention_bias_shape = [
          self._num_layers, self._num_attention_heads, self._head_size
      ]

    self.content_attention_bias = self.add_weight(
        "content_attention_bias",
        shape=attention_bias_shape,
        dtype=tf.float32,
        initializer=self._initializer)
    self.positional_attention_bias = self.add_weight(
        "positional_attention_bias",
        shape=attention_bias_shape,
        dtype=tf.float32,
        initializer=self._initializer)
    self.segment_attention_bias = self.add_weight(
        "segment_attention_bias",
        shape=attention_bias_shape,
        dtype=tf.float32,
        initializer=self._initializer)

    self.transformer_xl_layers = []
    for i in range(self._num_layers):
      self.transformer_xl_layers.append(
          TransformerXLBlock(
              vocab_size=self._vocab_size,
              hidden_size=self._head_size * self._num_attention_heads,
              num_attention_heads=self._num_attention_heads,
              head_size=self._head_size,
              inner_size=self._inner_size,
              dropout_rate=self._dropout_rate,
              attention_dropout_rate=self._attention_dropout_rate,
              norm_epsilon=1e-12,
              inner_activation=self._inner_activation,
              two_stream=self._two_stream,
              kernel_initializer="variance_scaling",
              name="layer_%d" % i))

    self.output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

  def get_config(self):
    config = {
        "vocab_size": self._vocab_size,
        "num_layers": self._num_layers,
        "hidden_size": self._hidden_size,
        "num_attention_heads": self._num_attention_heads,
        "head_size": self._head_size,
        "inner_size": self._inner_size,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "initializer": self._initializer,
        "two_stream": self._two_stream,
        "tie_attention_biases": self._tie_attention_biases,
        "memory_length": self._memory_length,
        "reuse_length": self._reuse_length,
        "inner_activation": self._inner_activation,
    }
    base_config = super(TransformerXL, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self,
           content_stream,
           relative_position_encoding,
           segment_matrix=None,
           segment_embedding=None,
           state=None,
           content_attention_mask=None,
           query_stream=None,
           query_attention_mask=None,
           target_mapping=None):
    """Implements call() for the layer.

    Arguments:
      content_stream: `Tensor`, the input content stream. This is the standard
        input to Transformer XL and is commonly referred to as `h` in XLNet.
      relative_position_encoding: Relative positional encoding `Tensor` of
        shape `[B, L, dim]`.
      segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in
        XLNet, but not in Transformer XL.
      segment_embedding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used
        in XLNet, but not in Transformer XL.
      state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of
        the state or memory. If passed, this is also attended over as in
        Transformer XL.
      content_attention_mask: Optional `Tensor` representing the mask that is
        added to content attention logits. If state is not None, the mask
        source sequence dimension should extend M.
      query_stream: Optional `Tensor`, the query stream. This is introduced in
        `TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if
        `two_stream` is `False`.
      query_attention_mask: Optional `Tensor` representing the mask that is
        added to query attention logits. If state is not None, the mask source
        sequence dimension should extend M.
      target_mapping: Optional `Tensor` representing the target mapping when
        calculating query attention.

    Returns:
      A tuple consisting of the attention output and the list of cached memory
      states.
      The attention output is `content_attention` if `two_stream` is `False`,
      otherwise it is `query_attention`.
    """
    new_mems = []

    if state is None:
      state = [None] * self._num_layers
    for i in range(self._num_layers):
      # cache new mems
      new_mems.append(
          _cache_memory(content_stream, state[i], self._memory_length,
                        self._reuse_length))

      # segment bias
      if segment_matrix is None:
        segment_attention_bias = None
        segment_encoding = None
      else:
        segment_attention_bias = (
            self.segment_attention_bias
            if self._tie_attention_biases else self.segment_attention_bias[i])
        segment_encoding = segment_embedding[i]

      content_attention_bias = (
          self.content_attention_bias
          if self._tie_attention_biases else self.content_attention_bias[i])
      positional_attention_bias = (
          self.positional_attention_bias
          if self._tie_attention_biases else self.positional_attention_bias[i])
      transformer_xl_layer = self.transformer_xl_layers[i]
      transformer_xl_output = transformer_xl_layer(
          content_stream=content_stream,
          content_attention_bias=content_attention_bias,
          positional_attention_bias=positional_attention_bias,
          relative_position_encoding=relative_position_encoding,
          segment_matrix=segment_matrix,
          segment_encoding=segment_encoding,
          segment_attention_bias=segment_attention_bias,
          state=state[i],
          content_attention_mask=content_attention_mask,
          query_attention_mask=query_attention_mask,
          query_stream=query_stream,
          target_mapping=target_mapping)
      content_stream = transformer_xl_output["content_attention"]
      if self._two_stream:
        query_stream = transformer_xl_output["query_attention"]
      else:
        query_stream = None

    if self._two_stream:
      output_stream = query_stream
    else:
      output_stream = content_stream

    return output_stream, new_mems
```
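A hedged usage sketch (not part of the commit) of the layer defined above: run `TransformerXL` over two consecutive segments and feed the cached memories from the first call back in as `state`. The hyperparameters follow the unit test below; `reuse_length` is set to 0 explicitly because `_cache_memory` compares it against 0:

```python
import tensorflow as tf
from official.nlp.modeling.layers import transformer_xl

batch_size, seq_length, memory_length = 2, 8, 4
num_heads, head_size, hidden_size, inner_size = 12, 64, 24, 12

layer = transformer_xl.TransformerXL(
    vocab_size=32000,
    num_layers=2,
    hidden_size=hidden_size,
    num_attention_heads=num_heads,
    head_size=head_size,
    inner_size=inner_size,
    dropout_rate=0.,
    attention_dropout_rate=0.,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
    memory_length=memory_length,
    reuse_length=0,
    inner_activation="relu")

def run_segment(content_stream, state=None):
  # The mock data in the test uses a relative position encoding of length
  # 2 * seq_length with the same feature size as the content stream.
  relative_position_encoding = tf.random.normal(
      [batch_size, seq_length * 2, hidden_size])
  return layer(
      content_stream=content_stream,
      relative_position_encoding=relative_position_encoding,
      state=state)

segment_1 = tf.random.normal([batch_size, seq_length, hidden_size])
segment_2 = tf.random.normal([batch_size, seq_length, hidden_size])

output_1, memories = run_segment(segment_1)           # one cached memory per layer
output_2, _ = run_segment(segment_2, state=memories)  # segment 2 attends over segment 1's cache
```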
official/nlp/modeling/layers/transformer_xl_test.py (view file @ 2c98b4b0)
```diff
@@ -32,6 +32,8 @@ def create_mock_transformer_xl_data(
     memory_length=0,
     num_predictions=2,
     two_stream=False,
+    num_layers=1,
+    include_biases=True,
     include_state=False,
     include_mask=False,
     include_segment=False):
@@ -47,6 +49,8 @@ def create_mock_transformer_xl_data(
     num_predictions: `int`, the number of predictions used in two stream
       attention.
     two_stream: `bool`, whether or not to generate two stream data.
+    num_layers: `int`, the number of Transformer XL blocks.
+    include_biases: optional `bool`, whether or not to include attention biases.
     include_state: optional `bool`, whether or not to include state data.
     include_mask: optional `bool`, whether or not to include mask data.
     include_segment: optional `bool`, whether or not to include segment data.
@@ -55,27 +59,34 @@ def create_mock_transformer_xl_data(
     A dictionary with `str` as keys and `Tensor` as values.
   """
   encoding_shape = (batch_size, seq_length * 2, hidden_size)
-  attention_bias_shape = (num_heads, head_size)
 
   data = dict(
-      content_stream=tf.random.normal(
-          shape=(batch_size, seq_length, hidden_size)),
       relative_position_encoding=tf.random.normal(shape=encoding_shape),
-      content_attention_bias=tf.random.normal(shape=attention_bias_shape),
-      positional_attention_bias=tf.random.normal(shape=attention_bias_shape))
+      content_stream=tf.random.normal(
+          shape=(batch_size, seq_length, hidden_size)))
+
+  if include_biases:
+    attention_bias_shape = (num_heads, head_size)
+    data.update(dict(
+        content_attention_bias=tf.random.normal(shape=attention_bias_shape),
+        segment_attention_bias=tf.random.normal(shape=attention_bias_shape),
+        positional_attention_bias=tf.random.normal(
+            shape=attention_bias_shape)))
 
   if two_stream:
-    two_stream_data = dict(
-        query_stream=tf.random.normal(
-            shape=(batch_size, num_predictions, hidden_size)),
-        target_mapping=tf.random.normal(
-            shape=(batch_size, num_predictions, seq_length)))
-    data.update(two_stream_data)
+    data.update(dict(
+        query_stream=tf.random.normal(
+            shape=(batch_size, num_predictions, hidden_size)),
+        target_mapping=tf.random.normal(
+            shape=(batch_size, num_predictions, seq_length))))
 
   if include_state:
     total_seq_length = seq_length + memory_length
-    data["state"] = tf.random.normal(
-        shape=(batch_size, memory_length, hidden_size))
+    if num_layers > 1:
+      state_shape = (num_layers, batch_size, memory_length, hidden_size)
+    else:
+      state_shape = (batch_size, memory_length, hidden_size)
+    data.update(dict(state=tf.random.normal(shape=state_shape)))
   else:
     total_seq_length = seq_length
@@ -87,15 +98,19 @@ def create_mock_transformer_xl_data(
     data["query_attention_mask"] = mask_data
 
   if include_segment:
-    segment_encoding_shape = (2, num_heads, head_size)
+    # A transformer XL block takes an individual segment "encoding" from the
+    # entirety of the Transformer XL segment "embedding".
+    if num_layers > 1:
+      segment_encoding_shape = (num_layers, 2, num_heads, head_size)
+      segment_encoding_name = "segment_embedding"
+    else:
+      segment_encoding_shape = (2, num_heads, head_size)
+      segment_encoding_name = "segment_encoding"
     segment_matrix = np.random.randint(
         2, size=(batch_size, seq_length, total_seq_length))
-    segment_matrix = tf.math.equal(segment_matrix, 1)
-    segment_data = dict(
-        segment_attention_bias=tf.random.normal(shape=attention_bias_shape),
-        segment_encoding=tf.random.normal(shape=segment_encoding_shape),
-        segment_matrix=segment_matrix)
-    data.update(segment_data)
+    data["segment_matrix"] = tf.math.equal(segment_matrix, 1)
+    data[segment_encoding_name] = tf.random.normal(
+        shape=segment_encoding_shape)
 
   return data
@@ -109,17 +124,19 @@ class TransformerXLBlockTest(keras_parameterized.TestCase):
           state=[True, False],
           mask=[True, False],
           segment=[True, False]))
-  def test_transformer_xl(self,
-                          two_stream,
-                          memory_length,
-                          state,
-                          mask,
-                          segment):
-    """Tests combinations of Transformer XL calculations."""
+  def test_transformer_xl_block(
+      self,
+      two_stream,
+      memory_length,
+      state,
+      mask,
+      segment):
+    """Tests combinations of Transformer XL block calculations."""
     batch_size, num_heads, head_size, seq_length = 2, 12, 64, 8
     hidden_size, num_predictions, inner_size = 24, 8, 12
 
     data = create_mock_transformer_xl_data(
+        include_biases=True,
         num_heads=num_heads,
         head_size=head_size,
         hidden_size=hidden_size,
@@ -169,6 +186,90 @@ class TransformerXLBlockTest(keras_parameterized.TestCase):
     self.assertEqual(transformer_xl_block_config, new_block.get_config())
 
 
+@keras_parameterized.run_all_keras_modes
+class TransformerXLTest(keras_parameterized.TestCase):
+
+  @combinations.generate(
+      combinations.combine(
+          two_stream=[True, False],
+          memory_length=[0, 4],
+          reuse_length=[0, 4],
+          tie_attention_biases=[True, False],
+          state=[True, False],
+          mask=[True, False],
+          segment=[True, False]))
+  def test_transformer_xl(self,
+                          two_stream,
+                          memory_length,
+                          reuse_length,
+                          tie_attention_biases,
+                          state,
+                          mask,
+                          segment):
+    batch_size, num_heads, head_size, seq_length = 2, 12, 64, 8
+    hidden_size, num_predictions, inner_size = 24, 8, 12
+    num_layers = 3
+
+    data = create_mock_transformer_xl_data(
+        include_biases=False,
+        num_heads=num_heads,
+        head_size=head_size,
+        hidden_size=hidden_size,
+        seq_length=seq_length,
+        batch_size=batch_size,
+        memory_length=memory_length,
+        num_predictions=num_predictions,
+        two_stream=two_stream,
+        num_layers=num_layers,
+        include_state=state,
+        include_mask=mask,
+        include_segment=segment)
+
+    transformer_xl_layer = transformer_xl.TransformerXL(
+        vocab_size=32000,
+        num_layers=num_layers,
+        head_size=head_size,
+        hidden_size=hidden_size,
+        num_attention_heads=num_heads,
+        inner_size=inner_size,
+        dropout_rate=0.,
+        attention_dropout_rate=0.,
+        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
+        two_stream=two_stream,
+        tie_attention_biases=tie_attention_biases,
+        memory_length=memory_length,
+        reuse_length=reuse_length,
+        inner_activation="relu")
+    attention_output, cached_memory_states = transformer_xl_layer(**data)
+
+    if two_stream:
+      self.assertEqual(attention_output.shape,
+                       [batch_size, num_predictions, hidden_size])
+    else:
+      self.assertEqual(attention_output.shape,
+                       [batch_size, seq_length, hidden_size])
+    self.assertEqual(len(cached_memory_states), num_layers)
+
+  def test_get_config(self):
+    transformer_xl_layer = transformer_xl.TransformerXL(
+        vocab_size=32000,
+        num_layers=12,
+        hidden_size=36,
+        head_size=12,
+        num_attention_heads=12,
+        inner_size=12,
+        dropout_rate=0.,
+        attention_dropout_rate=0.,
+        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
+        two_stream=False,
+        tie_attention_biases=True,
+        memory_length=0,
+        reuse_length=0,
+        inner_activation="relu")
+    transformer_xl_config = transformer_xl_layer.get_config()
+    new_transformer_xl = transformer_xl.TransformerXL.from_config(
+        transformer_xl_config)
+    self.assertEqual(transformer_xl_config, new_transformer_xl.get_config())
+
+
 if __name__ == "__main__":
   np.random.seed(0)
   tf.random.set_seed(0)
```