ModelZoo / ResNet50_tensorflow · Commit 801ac678
Authored Oct 26, 2020 by Allen Wang; committed by A. Unique TensorFlower, Oct 26, 2020

Internal change

PiperOrigin-RevId: 339095008
Parent: b70019f0
Showing 9 changed files with 210 additions and 118 deletions.
official/nlp/configs/encoders.py                      +50   -0
official/nlp/modeling/layers/relative_attention.py     +3  -29
official/nlp/modeling/layers/transformer_xl.py         +0   -1
official/nlp/modeling/models/xlnet.py                 +17   -6
official/nlp/modeling/models/xlnet_test.py             +7   -6
official/nlp/modeling/networks/xlnet_base.py          +46  -19
official/nlp/modeling/networks/xlnet_base_test.py     +44  -44
official/nlp/tasks/sentence_prediction.py             +13   -7
official/nlp/xlnet/xlnet_modeling.py                  +30   -6
official/nlp/configs/encoders.py

@@ -136,6 +136,31 @@ class BigBirdEncoderConfig(hyperparams.Config):
   embedding_size: Optional[int] = None


+@dataclasses.dataclass
+class XLNetEncoderConfig(hyperparams.Config):
+  """XLNet encoder configuration."""
+  vocab_size: int = 32000
+  num_layers: int = 24
+  hidden_size: int = 1024
+  num_attention_heads: int = 16
+  head_size: int = 64
+  inner_size: int = 4096
+  inner_activation: str = "gelu"
+  dropout_rate: float = 0.1
+  attention_dropout_rate: float = 0.1
+  attention_type: str = "bi"
+  bi_data: bool = False
+  tie_attention_biases: bool = False
+  memory_length: int = 0
+  same_length: bool = False
+  clamp_length: int = -1
+  reuse_length: int = 0
+  use_cls_mask: bool = False
+  embedding_width: int = 1024
+  initializer_range: float = 0.02
+  two_stream: bool = False
+
+
 @dataclasses.dataclass
 class EncoderConfig(hyperparams.OneOfConfig):
   """Encoder configuration."""

@@ -144,6 +169,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
   bert: BertEncoderConfig = BertEncoderConfig()
   bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
   mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
+  xlnet: XLNetEncoderConfig = XLNetEncoderConfig()


 ENCODER_CLS = {

@@ -151,6 +177,7 @@ ENCODER_CLS = {
     "mobilebert": networks.MobileBERTEncoder,
     "albert": networks.AlbertEncoder,
     "bigbird": bigbird_encoder.BigBirdEncoder,
+    "xlnet": networks.XLNetBase,
 }

@@ -266,6 +293,29 @@ def build_encoder(
             stddev=encoder_cfg.initializer_range),
         embedding_width=encoder_cfg.embedding_size)

+  if encoder_type == "xlnet":
+    return encoder_cls(
+        vocab_size=encoder_cfg.vocab_size,
+        num_layers=encoder_cfg.num_layers,
+        hidden_size=encoder_cfg.hidden_size,
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        head_size=encoder_cfg.head_size,
+        inner_size=encoder_cfg.inner_size,
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        attention_type=encoder_cfg.attention_type,
+        bi_data=encoder_cfg.bi_data,
+        two_stream=encoder_cfg.two_stream,
+        tie_attention_biases=encoder_cfg.tie_attention_biases,
+        memory_length=encoder_cfg.memory_length,
+        clamp_length=encoder_cfg.clamp_length,
+        reuse_length=encoder_cfg.reuse_length,
+        inner_activation=encoder_cfg.inner_activation,
+        use_cls_mask=encoder_cfg.use_cls_mask,
+        embedding_width=encoder_cfg.embedding_width,
+        initializer=tf.keras.initializers.RandomNormal(
+            stddev=encoder_cfg.initializer_range))
+
   # Uses the default BERTEncoder configuration schema to create the encoder.
   # If it does not match, please add a switch branch by the encoder type.
   return encoder_cls(
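In effect, the XLNet encoder becomes one more branch of the one-of encoder config. A minimal usage sketch, assuming the Model Garden import path and a `build_encoder(config)` call signature (both assumptions; only the `XLNetEncoderConfig` fields and the `"xlnet"` registry key come from this diff):

    from official.nlp.configs import encoders

    # Pick the new "xlnet" branch of the one-of config and override a few fields.
    config = encoders.EncoderConfig(
        type="xlnet",
        xlnet=encoders.XLNetEncoderConfig(
            num_layers=12, hidden_size=768, num_attention_heads=12))

    # Dispatches on config.type ("xlnet") and returns a networks.XLNetBase.
    encoder = encoders.build_encoder(config)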
official/nlp/modeling/layers/relative_attention.py

@@ -54,23 +54,6 @@ def _get_output_shape(output_rank, known_last_dims):
   return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


-def _large_compatible_negative(tensor_type):
-  """Large negative number as Tensor.
-
-  This function is necessary because the standard value for epsilon
-  in this module (-1e9) cannot be represented using tf.float16.
-
-  Args:
-    tensor_type: a dtype to determine the type.
-
-  Returns:
-    a large negative number.
-  """
-  if tensor_type == tf.float16:
-    return tf.float16.min
-  return -1e9
-
-
 def _rel_shift(x, klen=-1):
   """Performs relative shift to form the relative attention score."""

@@ -101,13 +84,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
   **Note: This layer is currently experimental.

   Attributes:
-    num_heads: The number of attention heads.
-    key_dim: Size of each attention head for query and key.
-    value_dim: Size of attention head for value.
-    dropout: Dropout probability for attention.
-    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
-    kernel_initializer: Initializer for dense layer kernels.
-    bias_initializer: Initializer for dense layer biases.
+    kernel_initializer: The kernel initializer. Defaults to variance_scaling.

   Call args:
     query: Query `Tensor` of shape `[B, T, dim]`.
     value: Value `Tensor` of shape `[B, S, dim]`.

@@ -242,12 +220,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
     attention_scores = tf.multiply(
         attention_sum, 1.0 / math.sqrt(float(self._key_dim)))

-    # `attention_scores`: `[B, N, S, S + M]`
-    if attention_mask is not None:
-      attention_scores += (
-          _large_compatible_negative(attention_scores.dtype) * attention_mask)
-    attention_scores = tf.nn.softmax(attention_scores, 3)
+    attention_scores = self._masked_softmax(attention_scores, attention_mask)

     attention_output = self._dropout_layer(attention_scores)

     attention_output = tf.einsum(self._combine_equation,
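The removed helper existed because additive masking needs a "minus infinity" that is still finite in half precision; the `_masked_softmax` inherited from `tf.keras.layers.MultiHeadAttention` now covers this case. A quick sketch of the underlying numeric issue (illustrative only):

    import tensorflow as tf

    # -1e9 overflows the float16 range, so adding it to logits yields -inf,
    # which can propagate NaNs through softmax.
    print(tf.cast(-1e9, tf.float16))  # -inf
    print(tf.float16.min)             # -65504.0, the finite stand-in used above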
official/nlp/modeling/layers/transformer_xl.py

@@ -85,7 +85,6 @@ class TransformerXLBlock(tf.keras.layers.Layer):
     kernel_initializer: Initializer for dense layer kernels.
     inner_dropout: Dropout probability for the inner dropout
       layer.
   """
-
   def __init__(self,
official/nlp/modeling/models/xlnet.py

@@ -31,6 +31,9 @@ class XLNetClassifier(tf.keras.Model):
   Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
   Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

+  Note: This model does not utilize the memory mechanism used in the
+  original XLNet Classifier.
+
   Arguments:
     network: An XLNet/Transformer-XL based network. This network should output a
       sequence output and list of `state` tensors.

@@ -70,7 +73,7 @@ class XLNetClassifier(tf.keras.Model):
       raise ValueError('Invalid summary type provided: %s.' % summary_type)

     self.classifier = layers.ClassificationHead(
-        inner_dim=network.get_config()['inner_size'],
+        inner_dim=network.get_config()['hidden_size'],
         num_classes=num_classes,
         initializer=initializer,
         dropout_rate=dropout_rate,

@@ -78,12 +81,12 @@ class XLNetClassifier(tf.keras.Model):
         name='sentence_prediction')

   def call(self, inputs: Mapping[str, Any]):
-    input_ids = inputs['input_ids']
-    segment_ids = inputs['segment_ids']
-    input_mask = inputs['input_mask']
+    input_ids = inputs['input_word_ids']
+    segment_ids = inputs['input_type_ids']
+    input_mask = tf.cast(inputs['input_mask'], tf.float32)
     state = inputs.get('mems', None)

-    attention_output, new_states = self._network(
+    attention_output, _ = self._network(
         input_ids=input_ids,
         segment_ids=segment_ids,
         input_mask=input_mask,

@@ -91,7 +94,7 @@ class XLNetClassifier(tf.keras.Model):
     logits = self.classifier(attention_output)

-    return logits, new_states
+    return logits

   def get_config(self):
     return self._config

@@ -100,6 +103,14 @@ class XLNetClassifier(tf.keras.Model):
   def from_config(cls, config, custom_objects=None):
     return cls(**config)

+  @property
+  def checkpoint_items(self):
+    items = dict(encoder=self._network)
+    if hasattr(self.classifier, 'checkpoint_items'):
+      for key, item in self.classifier.checkpoint_items.items():
+        items['.'.join([self.classifier.name, key])] = item
+    return items
+

 @tf.keras.utils.register_keras_serializable(package='Text')
 class XLNetSpanLabeler(tf.keras.Model):
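After this change `XLNetClassifier` consumes the same input names as the BERT-style models and returns logits only. A calling sketch, assuming `classifier` is an already-built `XLNetClassifier` and using made-up shapes (only the key names and the single return value come from the diff):

    import numpy as np

    inputs = dict(
        input_word_ids=np.random.randint(32000, size=(2, 8), dtype='int32'),
        input_type_ids=np.zeros((2, 8), dtype='int32'),
        input_mask=np.ones((2, 8), dtype='int32'))  # cast to float32 internally

    logits = classifier(inputs)  # previously: logits, new_states = classifier(inputs)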
official/nlp/modeling/models/xlnet_test.py

@@ -64,10 +64,10 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
         summary_type='last',
         dropout_rate=0.1)
     inputs = dict(
-        input_ids=tf.keras.layers.Input(
+        input_word_ids=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
-        segment_ids=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
+        input_type_ids=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
         input_mask=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.float32, name='input_mask'),
         permutation_mask=tf.keras.layers.Input(

@@ -76,7 +76,7 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
         masked_tokens=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.float32, name='masked_tokens'))

-    logits, _ = xlnet_trainer_model(inputs)
+    logits = xlnet_trainer_model(inputs)
     expected_classification_shape = [None, num_classes]
     self.assertAllEqual(expected_classification_shape, logits.shape.as_list())

@@ -99,8 +99,9 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
     sequence_shape = (batch_size, seq_length)
     inputs = dict(
-        input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
-        segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
+        input_word_ids=np.random.randint(
+            10, size=sequence_shape, dtype='int32'),
+        input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
         input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
         permutation_mask=np.random.randint(
             2, size=(batch_size, seq_length, seq_length)).astype('float32'),
official/nlp/modeling/networks/xlnet_base.py

@@ -49,6 +49,9 @@ def _create_causal_attention_mask(
   concatenating 0s (representing memory positions) with a strictly upper
   triangular matrix of 1s.

+  We then flip the matrix values in order to match the representation where
+  real values are 1s.
+
   Arguments:
     seq_length: int, The length of each sequence.
     memory_length: int, The length of memory blocks.

@@ -59,10 +62,10 @@ def _create_causal_attention_mask(
     A unidirectional attention mask of shape
     `[seq_length, seq_length + memory_length]`. E.g.:
-    [[0. 0. 0. 1. 1. 1.]
-     [0. 0. 0. 0. 1. 1.]
-     [0. 0. 0. 0. 0. 1.]
-     [0. 0. 0. 0. 0. 0.]]
+    [[1. 1. 1. 0. 0. 0.]
+     [1. 1. 1. 1. 0. 0.]
+     [1. 1. 1. 1. 1. 0.]
+     [1. 1. 1. 1. 1. 1.]]
   """
   ones_matrix = tf.ones([seq_length, seq_length], dtype=dtype)
   upper_triangular = tf.linalg.band_part(ones_matrix, 0, -1)

@@ -78,7 +81,32 @@ def _create_causal_attention_mask(
         [causal_attention_mask[:, :seq_length] + strictly_lower_triangular,
          causal_attention_mask[:, seq_length:]], 1)

-  return causal_attention_mask
+  return 1 - causal_attention_mask
+
+
+def _combine_masks(mask1, mask2, dtype, how="and"):
+  """Combines two masks.
+
+  Use "and" if trying to combine two existing masks.
+  Use "or" if trying to flip a few positions to "real".
+
+  Args:
+    mask1: tf.Tensor, input mask 1
+    mask2: tf.Tensor, input mask 2
+    dtype: tf.dtype
+    how: Which logical operation should run.
+
+  Returns:
+    The combined input masks.
+  """
+  if how == "and":
+    operator = tf.math.logical_and
+  else:
+    operator = tf.math.logical_or
+  return tf.cast(
+      operator(tf.cast(mask1, tf.bool), tf.cast(mask2, tf.bool)), dtype=dtype)


 def _compute_attention_mask(

@@ -140,8 +168,7 @@ def _compute_attention_mask(
   # input_mask: [B, S]
   # permutation_mask: [B, S, S]
   if input_mask is not None and permutation_mask is not None:
-    data_mask = input_mask[:, None, :] + permutation_mask
+    data_mask = _combine_masks(input_mask[:, None, :], permutation_mask, dtype)
   elif input_mask is not None and permutation_mask is None:
     data_mask = input_mask[:, None, :]
   elif input_mask is None and permutation_mask is not None:

@@ -153,28 +180,28 @@ def _compute_attention_mask(
   if data_mask is not None:
     # All positions within state can be attended to.
-    state_mask = tf.zeros([batch_size, tf.shape(data_mask)[1], memory_length],
+    state_mask = tf.ones([batch_size, tf.shape(data_mask)[1], memory_length],
                          dtype=dtype)
     # state_mask: [B, 1, M] or [B, S, M]
     data_mask = tf.concat([state_mask, data_mask], 2)
     # data_mask: [B, 1, S + M] or [B, S, S + M]
     if attention_type == "uni":
-      attention_mask = causal_attention_mask + data_mask[:, None, :, :]
+      attention_mask = _combine_masks(
+          causal_attention_mask, data_mask[:, None, :, :], dtype=dtype)
     else:
       attention_mask = data_mask[:, None, :, :]

   # Construct the content attention mask.
+  # This ensures that the mask allows the model to attend to positions in
+  # content positions (e.g. the content diagonal).
   if attention_mask is not None:
-    attention_mask = tf.cast(attention_mask > 0, dtype=dtype)
-    non_tgt_mask = -tf.eye(seq_length, dtype=dtype)
-    non_tgt_mask = tf.concat(
+    non_target_mask = tf.concat(
         [tf.zeros([seq_length, memory_length], dtype=dtype),
-         non_tgt_mask], axis=-1)
-    content_attention_mask = tf.cast(
-        (attention_mask + non_tgt_mask[None, None, :, :]) > 0,
-        dtype=dtype)
+         tf.eye(seq_length, dtype=dtype)], axis=-1)
+    content_attention_mask = _combine_masks(
+        attention_mask, non_target_mask, how="or", dtype=dtype)
   else:
     content_attention_mask = None
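The flip means a 1 now marks a position that may be attended to, so masks compose with logical AND (or OR, to force a few positions back to "real") instead of addition followed by a > 0 threshold. A minimal sketch mirroring the behavior of `_combine_masks` from the diff:

    import tensorflow as tf

    mask1 = tf.constant([[1., 1., 0.]])  # 1 = attendable under the new convention
    mask2 = tf.constant([[1., 0., 1.]])

    # "and": attendable only where both masks allow it -> [[1., 0., 0.]]
    combined = tf.cast(
        tf.math.logical_and(tf.cast(mask1, tf.bool), tf.cast(mask2, tf.bool)),
        tf.float32)

    # "or": used above to re-open the content diagonal -> [[1., 1., 1.]]
    reopened = tf.cast(
        tf.math.logical_or(tf.cast(mask1, tf.bool), tf.cast(mask2, tf.bool)),
        tf.float32)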
official/nlp/modeling/networks/xlnet_base_test.py

@@ -85,9 +85,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
         seq_length=seq_length,
         memory_length=memory_length)
-    expected_output = np.array([[0, 1, 1],
-                                [0, 0, 1],
-                                [0, 0, 0]])
+    expected_output = np.array([[1, 0, 0],
+                                [1, 1, 0],
+                                [1, 1, 1]])
     self.assertAllClose(causal_attention_mask, expected_output)

   def test_casual_attention_mask_with_memory(self):

@@ -96,9 +96,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
         seq_length=seq_length,
         memory_length=memory_length)
-    expected_output = np.array([[0, 0, 0, 1, 1],
-                                [0, 0, 0, 0, 1],
-                                [0, 0, 0, 0, 0]])
+    expected_output = np.array([[1, 1, 1, 0, 0],
+                                [1, 1, 1, 1, 0],
+                                [1, 1, 1, 1, 1]])
     self.assertAllClose(causal_attention_mask, expected_output)

   def test_causal_attention_mask_with_same_length(self):

@@ -108,9 +108,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
         memory_length=memory_length,
         same_length=True)
-    expected_output = np.array([[0, 0, 0, 1, 1],
-                                [1, 0, 0, 0, 1],
-                                [1, 1, 0, 0, 0]])
+    expected_output = np.array([[1, 1, 1, 0, 0],
+                                [0, 1, 1, 1, 0],
+                                [0, 0, 1, 1, 1]])
     self.assertAllClose(causal_attention_mask, expected_output)

@@ -179,15 +179,15 @@ class MaskComputationTests(keras_parameterized.TestCase):
     batch_size = 1
     memory_length = 0
-    input_mask = np.array([[0, 0, 1, 1]])
+    input_mask = np.array([[1, 1, 0, 0]])
     permutation_mask = None
     expected_query_mask = input_mask[None, None, :, :]
-    expected_content_mask = np.array([[[
-        [0, 0, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1],
-        [0, 0, 1, 0]]]])
+    expected_content_mask = np.array([[[
+        [1, 1, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0],
+        [1, 1, 0, 1]]]])
     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,

@@ -209,14 +209,14 @@ class MaskComputationTests(keras_parameterized.TestCase):
     input_mask = None
-    permutation_mask = np.array([
-        [[0, 1],
-         [0, 1]],
-    ])
+    permutation_mask = np.array([
+        [[1, 0],
+         [1, 0]],
+    ])
     expected_query_mask = permutation_mask[:, None, :, :]
-    expected_content_mask = np.array([[[
-        [0, 1],
-        [0, 0]]]])
+    expected_content_mask = np.array([[[
+        [1, 0],
+        [1, 1]]]])
     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,

@@ -236,24 +236,24 @@ class MaskComputationTests(keras_parameterized.TestCase):
     batch_size = 1
     memory_length = 0
-    input_mask = np.array([[0, 0, 1, 1]])
-    permutation_mask = np.array([[
-        [1, 0, 0, 0],
-        [0, 1, 0, 0],
-        [0, 0, 1, 0],
-        [0, 0, 0, 1],
-    ]])
-    expected_query_mask = np.array([[[
-        [1, 0, 1, 1],
-        [0, 1, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 1, 1]]]])
-    expected_content_mask = np.array([[[
-        [0, 0, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1],
-        [0, 0, 1, 0]]]])
+    input_mask = np.array([[1, 1, 0, 0]])
+    permutation_mask = np.array([[
+        [0, 1, 1, 1],
+        [1, 0, 1, 1],
+        [1, 1, 0, 1],
+        [1, 1, 1, 0],
+    ]])
+    expected_query_mask = np.array([[[
+        [0, 1, 0, 0],
+        [1, 0, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 0, 0]]]])
+    expected_content_mask = np.array([[[
+        [1, 1, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0],
+        [1, 1, 0, 1]]]])
     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,
         permutation_mask=permutation_mask,

@@ -272,24 +272,24 @@ class MaskComputationTests(keras_parameterized.TestCase):
     batch_size = 1
     memory_length = 0
-    input_mask = np.array([[0, 0, 0, 1]])
-    permutation_mask = np.array([[
-        [1, 0, 0, 0],
-        [0, 1, 0, 0],
-        [0, 0, 1, 0],
-        [0, 0, 0, 1],
-    ]])
-    expected_query_mask = np.array([[[
-        [1, 1, 1, 1],
-        [0, 1, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1]]]])
-    expected_content_mask = np.array([[[
-        [0, 1, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1],
-        [0, 0, 0, 0]]]])
+    input_mask = np.array([[1, 1, 1, 0]])
+    permutation_mask = np.array([[
+        [0, 1, 1, 1],
+        [1, 0, 1, 1],
+        [1, 1, 0, 1],
+        [1, 1, 1, 0],
+    ]])
+    expected_query_mask = np.array([[[
+        [0, 0, 0, 0],
+        [1, 0, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0]]]])
+    expected_content_mask = np.array([[[
+        [1, 0, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0],
+        [1, 1, 1, 1]]]])
     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,
         permutation_mask=permutation_mask,
official/nlp/tasks/sentence_prediction.py

@@ -81,7 +81,13 @@ class SentencePredictionTask(base_task.Task):
     else:
       encoder_network = encoders.build_encoder(self.task_config.model.encoder)
     encoder_cfg = self.task_config.model.encoder.get()
-    # Currently, we only support bert-style sentence prediction finetuning.
+    if self.task_config.model.encoder.type == 'xlnet':
+      return models.XLNetClassifier(
+          network=encoder_network,
+          num_classes=self.task_config.model.num_classes,
+          initializer=tf.keras.initializers.RandomNormal(
+              stddev=encoder_cfg.initializer_range))
+
     return models.BertClassifier(
         network=encoder_network,
         num_classes=self.task_config.model.num_classes,
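With this branch in place, selecting the XLNet encoder in the task config is enough for `build_model` to return an `XLNetClassifier`. A configuration sketch, assuming the surrounding `SentencePredictionConfig`/`ModelConfig` classes follow the usual Model Garden layout (names beyond `encoder.type == 'xlnet'` are assumptions):

    from official.nlp.configs import encoders
    from official.nlp.tasks import sentence_prediction

    task_config = sentence_prediction.SentencePredictionConfig(
        model=sentence_prediction.ModelConfig(
            num_classes=2,
            encoder=encoders.EncoderConfig(type='xlnet')))

    task = sentence_prediction.SentencePredictionTask(task_config)
    model = task.build_model()  # now an XLNetClassifier, not a BertClassifier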
official/nlp/xlnet/xlnet_modeling.py

@@ -15,6 +15,7 @@
 """Keras layers of XLNet model in TF 2.0."""

 import copy
+import warnings

 import tensorflow as tf

@@ -416,6 +417,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
     """
     super(TransformerXLModel, self).__init__(**kwargs)
+    warnings.warn(
+        "`TransformerXLModel` is deprecated, please use `XLNetBase` instead",
+        DeprecationWarning, stacklevel=2)

     self.n_token = n_token
     self.initializer = initializer

@@ -745,11 +749,13 @@ class PretrainingXLNetModel(tf.keras.Model):
   """

-  def __init__(self, use_proj, xlnet_config, run_config, **kwargs):
+  def __init__(self, use_proj, xlnet_config, run_config,
+               use_legacy_mask=True, **kwargs):
     super(PretrainingXLNetModel, self).__init__(**kwargs)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
     self.xlnet_config = copy.deepcopy(xlnet_config)
+    self._use_legacy_mask = use_legacy_mask

     self.xlnet_model = networks.XLNetBase(
         vocab_size=self.xlnet_config.n_token,

@@ -788,6 +794,9 @@ class PretrainingXLNetModel(tf.keras.Model):
     input_ids = features["input_ids"]
     masked_tokens = features["input_q"]
     seg_ids = features["seg_id"]
-    perm_mask = features["perm_mask"]
+    if self._use_legacy_mask:
+      perm_mask = 1 - features["perm_mask"]
+    else:
+      perm_mask = features["perm_mask"]

     target_mapping = features["target_mapping"]

@@ -823,11 +832,16 @@ class ClassificationXLNetModel(tf.keras.Model):
   """

-  def __init__(self, xlnet_config, run_config, n_class, summary_type,
-               **kwargs):
+  def __init__(self, xlnet_config, run_config, n_class, summary_type,
+               use_legacy_mask=True, **kwargs):
     super(ClassificationXLNetModel, self).__init__(**kwargs)
+    warnings.warn(
+        "`ClassificationXLNetModel` is deprecated, please use "
+        "`XLNetClassifier` instead.",
+        DeprecationWarning, stacklevel=2)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
     self.xlnet_config = copy.deepcopy(xlnet_config)
+    self._use_legacy_mask = use_legacy_mask

     self.xlnet_model = networks.XLNetBase(
         vocab_size=self.xlnet_config.n_token,

@@ -870,6 +884,9 @@ class ClassificationXLNetModel(tf.keras.Model):
     input_ids = features["input_ids"]
     segment_ids = features["segment_ids"]
-    input_mask = features["input_mask"]
+    if self._use_legacy_mask:
+      input_mask = 1 - features["input_mask"]
+    else:
+      input_mask = features["input_mask"]

     label = tf.reshape(features["label_ids"], [batch_size_per_core])

@@ -1068,11 +1085,15 @@ class QAXLNetModel(tf.keras.Model):
   """

-  def __init__(self, xlnet_config, run_config, start_n_top, end_n_top,
-               **kwargs):
+  def __init__(self, xlnet_config, run_config, start_n_top, end_n_top,
+               use_legacy_mask=True, **kwargs):
     super(QAXLNetModel, self).__init__(**kwargs)
+    warnings.warn(
+        "`QAXLNetModel` is deprecated, please use `XLNetSpanLabeler` instead.",
+        DeprecationWarning, stacklevel=2)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
     self.xlnet_config = copy.deepcopy(xlnet_config)
+    self._use_legacy_mask = use_legacy_mask

     self.xlnet_model = networks.XLNetBase(
         vocab_size=self.xlnet_config.n_token,

@@ -1108,6 +1129,9 @@ class QAXLNetModel(tf.keras.Model):
     input_ids = features["input_ids"]
     segment_ids = features["segment_ids"]
-    input_mask = features["input_mask"]
+    if self._use_legacy_mask:
+      input_mask = 1 - features["input_mask"]
+    else:
+      input_mask = features["input_mask"]

     cls_index = tf.reshape(features["cls_index"], [-1])
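The `use_legacy_mask=True` default keeps older input pipelines, where 1 marked a padded or masked position, working against the flipped convention in which 1 marks a real token. A minimal sketch of the conversion the flag enables (feature values are made up):

    import numpy as np

    legacy_input_mask = np.array([[0, 0, 1, 1]], dtype='float32')  # 1 = padding
    new_input_mask = 1 - legacy_input_mask  # 1 = real token
    # -> [[1., 1., 0., 0.]], the representation XLNetBase now expects;
    # pass use_legacy_mask=False once a pipeline already emits this form.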