ModelZoo / ResNet50_tensorflow, commit 801ac678

Authored Oct 26, 2020 by Allen Wang
Committed by A. Unique TensorFlower on Oct 26, 2020
Parent: b70019f0

Internal change

PiperOrigin-RevId: 339095008

Showing 9 changed files with 210 additions and 118 deletions.
official/nlp/configs/encoders.py                    +50   -0
official/nlp/modeling/layers/relative_attention.py   +3  -29
official/nlp/modeling/layers/transformer_xl.py       +0   -1
official/nlp/modeling/models/xlnet.py               +17   -6
official/nlp/modeling/models/xlnet_test.py           +7   -6
official/nlp/modeling/networks/xlnet_base.py        +46  -19
official/nlp/modeling/networks/xlnet_base_test.py   +44  -44
official/nlp/tasks/sentence_prediction.py           +13   -7
official/nlp/xlnet/xlnet_modeling.py                +30   -6
official/nlp/configs/encoders.py

@@ -136,6 +136,31 @@ class BigBirdEncoderConfig(hyperparams.Config):
   embedding_size: Optional[int] = None


+@dataclasses.dataclass
+class XLNetEncoderConfig(hyperparams.Config):
+  """XLNet encoder configuration."""
+  vocab_size: int = 32000
+  num_layers: int = 24
+  hidden_size: int = 1024
+  num_attention_heads: int = 16
+  head_size: int = 64
+  inner_size: int = 4096
+  inner_activation: str = "gelu"
+  dropout_rate: float = 0.1
+  attention_dropout_rate: float = 0.1
+  attention_type: str = "bi"
+  bi_data: bool = False
+  tie_attention_biases: bool = False
+  memory_length: int = 0
+  same_length: bool = False
+  clamp_length: int = -1
+  reuse_length: int = 0
+  use_cls_mask: bool = False
+  embedding_width: int = 1024
+  initializer_range: float = 0.02
+  two_stream: bool = False
+
+
 @dataclasses.dataclass
 class EncoderConfig(hyperparams.OneOfConfig):
   """Encoder configuration."""

@@ -144,6 +169,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
   bert: BertEncoderConfig = BertEncoderConfig()
   bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
   mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
+  xlnet: XLNetEncoderConfig = XLNetEncoderConfig()


 ENCODER_CLS = {

@@ -151,6 +177,7 @@ ENCODER_CLS = {
     "mobilebert": networks.MobileBERTEncoder,
     "albert": networks.AlbertEncoder,
     "bigbird": bigbird_encoder.BigBirdEncoder,
+    "xlnet": networks.XLNetBase,
 }

@@ -266,6 +293,29 @@ def build_encoder(
             stddev=encoder_cfg.initializer_range),
         embedding_width=encoder_cfg.embedding_size)

+  if encoder_type == "xlnet":
+    return encoder_cls(
+        vocab_size=encoder_cfg.vocab_size,
+        num_layers=encoder_cfg.num_layers,
+        hidden_size=encoder_cfg.hidden_size,
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        head_size=encoder_cfg.head_size,
+        inner_size=encoder_cfg.inner_size,
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        attention_type=encoder_cfg.attention_type,
+        bi_data=encoder_cfg.bi_data,
+        two_stream=encoder_cfg.two_stream,
+        tie_attention_biases=encoder_cfg.tie_attention_biases,
+        memory_length=encoder_cfg.memory_length,
+        clamp_length=encoder_cfg.clamp_length,
+        reuse_length=encoder_cfg.reuse_length,
+        inner_activation=encoder_cfg.inner_activation,
+        use_cls_mask=encoder_cfg.use_cls_mask,
+        embedding_width=encoder_cfg.embedding_width,
+        initializer=tf.keras.initializers.RandomNormal(
+            stddev=encoder_cfg.initializer_range))
+
   # Uses the default BERTEncoder configuration schema to create the encoder.
   # If it does not match, please add a switch branch by the encoder type.
   return encoder_cls(
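For orientation, here is a minimal sketch of how the new config branch might be exercised end to end. The usage below is an assumption pieced together from the hunks above (the commit itself ships no such example), and the small layer sizes are illustrative only:

    from official.nlp.configs import encoders

    # Pick the new "xlnet" member of the OneOfConfig. Defaults correspond to
    # the XLNet-Large values in XLNetEncoderConfig; tiny sizes here for demo.
    cfg = encoders.EncoderConfig(
        type="xlnet",
        xlnet=encoders.XLNetEncoderConfig(
            num_layers=2, hidden_size=64, num_attention_heads=2,
            head_size=32, inner_size=128, embedding_width=64))

    # After this commit, build_encoder dispatches on cfg.type and returns a
    # networks.XLNetBase instance for "xlnet".
    xlnet_encoder = encoders.build_encoder(cfg)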
official/nlp/modeling/layers/relative_attention.py

@@ -54,23 +54,6 @@ def _get_output_shape(output_rank, known_last_dims):
   return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


-def _large_compatible_negative(tensor_type):
-  """Large negative number as Tensor.
-
-  This function is necessary because the standard value for epsilon
-  in this module (-1e9) cannot be represented using tf.float16.
-
-  Args:
-    tensor_type: a dtype to determine the type.
-
-  Returns:
-    a large negative number.
-  """
-  if tensor_type == tf.float16:
-    return tf.float16.min
-  return -1e9
-
-
 def _rel_shift(x, klen=-1):
   """Performs relative shift to form the relative attention score."""

@@ -101,13 +84,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
   **Note: This layer is currently experimental.

   Attributes:
-    num_heads: The number of attention heads.
-    key_dim: Size of each attention head for query and key.
-    value_dim: Size of attention head for value.
-    dropout: Dropout probability for attention.
-    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
-    kernel_initializer: Initializer for dense layer kernels.
-    bias_initializer: Initializer for dense layer biases.
+    kernel_initializer: The kernel initializer. Defaults to variance_scaling.

   Call args:
     query: Query `Tensor` of shape `[B, T, dim]`.
     value: Value `Tensor` of shape `[B, S, dim]`.

@@ -242,12 +220,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
     attention_scores = tf.multiply(
         attention_sum, 1.0 / math.sqrt(float(self._key_dim)))

     # `attention_scores`: `[B, N, S, S + M]`
-    if attention_mask is not None:
-      attention_scores += (_large_compatible_negative(attention_scores.dtype)
-                           * attention_mask)
-    attention_scores = tf.nn.softmax(attention_scores, 3)
+    attention_scores = self._masked_softmax(attention_scores, attention_mask)
     attention_output = self._dropout_layer(attention_scores)

     attention_output = tf.einsum(self._combine_equation,
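The substantive change here is mask polarity: the removed code treated a 1 in `attention_mask` as a padded position to be suppressed with a large negative additive bias, while `_masked_softmax` (inherited from `tf.keras.layers.MultiHeadAttention`) treats 1 as a position that may be attended to. A self-contained sketch of the equivalence, illustrative only:

    import numpy as np
    import tensorflow as tf

    scores = tf.constant([[2.0, 1.0, 0.5]])
    pad_mask = tf.constant([[0.0, 0.0, 1.0]])   # old convention: 1 = padding

    # Removed path: push padded positions toward -inf before the softmax.
    old_probs = tf.nn.softmax(scores + (-1e9) * pad_mask, axis=-1)

    # New path: a "keep" mask where 1 = attendable, applied as an additive
    # penalty on the zeros, which is roughly what the inherited Keras masked
    # softmax does internally.
    keep_mask = 1.0 - pad_mask
    new_probs = tf.nn.softmax(scores + (keep_mask - 1.0) * 1e9, axis=-1)

    np.testing.assert_allclose(old_probs.numpy(), new_probs.numpy())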
official/nlp/modeling/layers/transformer_xl.py

@@ -85,7 +85,6 @@ class TransformerXLBlock(tf.keras.layers.Layer):
     kernel_initializer: Initializer for dense layer kernels.
     inner_dropout: Dropout probability for the inner dropout
       layer.
-
   """

   def __init__(self,
official/nlp/modeling/models/xlnet.py

@@ -31,6 +31,9 @@ class XLNetClassifier(tf.keras.Model):
   Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
   Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

+  Note: This model does not utilize the memory mechanism used in the
+  original XLNet Classifier.
+
   Arguments:
     network: An XLNet/Transformer-XL based network. This network should output a
       sequence output and list of `state` tensors.

@@ -70,7 +73,7 @@ class XLNetClassifier(tf.keras.Model):
       raise ValueError('Invalid summary type provided: %s.' % summary_type)

     self.classifier = layers.ClassificationHead(
-        inner_dim=network.get_config()['inner_size'],
+        inner_dim=network.get_config()['hidden_size'],
         num_classes=num_classes,
         initializer=initializer,
         dropout_rate=dropout_rate,

@@ -78,12 +81,12 @@ class XLNetClassifier(tf.keras.Model):
         name='sentence_prediction')

   def call(self, inputs: Mapping[str, Any]):
-    input_ids = inputs['input_ids']
-    segment_ids = inputs['segment_ids']
-    input_mask = inputs['input_mask']
+    input_ids = inputs['input_word_ids']
+    segment_ids = inputs['input_type_ids']
+    input_mask = tf.cast(inputs['input_mask'], tf.float32)
     state = inputs.get('mems', None)

-    attention_output, new_states = self._network(
+    attention_output, _ = self._network(
         input_ids=input_ids,
         segment_ids=segment_ids,
         input_mask=input_mask,

@@ -91,7 +94,7 @@ class XLNetClassifier(tf.keras.Model):
     logits = self.classifier(attention_output)

-    return logits, new_states
+    return logits

   def get_config(self):
     return self._config

@@ -100,6 +103,14 @@ class XLNetClassifier(tf.keras.Model):
   def from_config(cls, config, custom_objects=None):
     return cls(**config)

+  @property
+  def checkpoint_items(self):
+    items = dict(encoder=self._network)
+    if hasattr(self.classifier, 'checkpoint_items'):
+      for key, item in self.classifier.checkpoint_items.items():
+        items['.'.join([self.classifier.name, key])] = item
+    return items
+

 @tf.keras.utils.register_keras_serializable(package='Text')
 class XLNetSpanLabeler(tf.keras.Model):
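Callers now feed BERT-style feature names and receive logits alone. A hedged sketch of the new calling convention follows; `model` is assumed to be an `XLNetClassifier` built as in xlnet_test.py below, and only the feature names come from this diff:

    import numpy as np

    batch_size, seq_length = 2, 8
    inputs = dict(
        input_word_ids=np.random.randint(          # was inputs['input_ids']
            32000, size=(batch_size, seq_length), dtype='int32'),
        input_type_ids=np.random.randint(          # was inputs['segment_ids']
            2, size=(batch_size, seq_length), dtype='int32'),
        input_mask=np.ones(                        # cast to float32 in call()
            (batch_size, seq_length), dtype='int32'))

    logits = model(inputs)   # was: logits, new_states = model(inputs)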
official/nlp/modeling/models/xlnet_test.py

@@ -64,10 +64,10 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
         summary_type='last',
         dropout_rate=0.1)
     inputs = dict(
-        input_ids=tf.keras.layers.Input(
+        input_word_ids=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
-        segment_ids=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
+        input_type_ids=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
         input_mask=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.float32, name='input_mask'),
         permutation_mask=tf.keras.layers.Input(

@@ -76,7 +76,7 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
         masked_tokens=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.float32, name='masked_tokens'))

-    logits, _ = xlnet_trainer_model(inputs)
+    logits = xlnet_trainer_model(inputs)
     expected_classification_shape = [None, num_classes]
     self.assertAllEqual(expected_classification_shape,
                         logits.shape.as_list())

@@ -99,8 +99,9 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
     sequence_shape = (batch_size, seq_length)
     inputs = dict(
-        input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
-        segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
+        input_word_ids=np.random.randint(
+            10, size=sequence_shape, dtype='int32'),
+        input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
         input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
         permutation_mask=np.random.randint(
             2, size=(batch_size, seq_length, seq_length)).astype('float32'),
official/nlp/modeling/networks/xlnet_base.py

@@ -49,6 +49,9 @@ def _create_causal_attention_mask(
   concatenating 0s (representing memory positions) with a strictly upper
   triangular matrix of 1s.

+  We then flip the matrix values in order to match the representation where
+  real values are 1s.
+
   Arguments:
     seq_length: int, The length of each sequence.
     memory_length: int, The length of memory blocks.

@@ -59,10 +62,10 @@ def _create_causal_attention_mask(
     A unidirectional attention mask of shape
     `[seq_length, seq_length + memory_length]`. E.g.:
-    [[0. 0. 0. 1. 1. 1.]
-     [0. 0. 0. 0. 1. 1.]
-     [0. 0. 0. 0. 0. 1.]
-     [0. 0. 0. 0. 0. 0.]]
+    [[1. 1. 1. 0. 0. 0.]
+     [1. 1. 1. 1. 0. 0.]
+     [1. 1. 1. 1. 1. 0.]
+     [1. 1. 1. 1. 1. 1.]]
   """
   ones_matrix = tf.ones([seq_length, seq_length], dtype=dtype)
   upper_triangular = tf.linalg.band_part(ones_matrix, 0, -1)

@@ -78,7 +81,32 @@ def _create_causal_attention_mask(
         [causal_attention_mask[:, :seq_length] + strictly_lower_triangular,
          causal_attention_mask[:, seq_length:]], 1)

-  return causal_attention_mask
+  return 1 - causal_attention_mask
+
+
+def _combine_masks(mask1, mask2, dtype, how="and"):
+  """Combines two masks.
+
+  Use "and" if trying to combine two existing masks.
+  Use "or" if trying to flip a few positions to "real".
+
+  Args:
+    mask1: tf.Tensor, input mask 1
+    mask2: tf.Tensor, input mask 2
+    dtype: tf.dtype
+    how: Which logical operation should run.
+
+  Returns:
+    The combined input masks.
+  """
+  if how == "and":
+    operator = tf.math.logical_and
+  else:
+    operator = tf.math.logical_or
+  return tf.cast(
+      operator(tf.cast(mask1, tf.bool), tf.cast(mask2, tf.bool)), dtype=dtype)


 def _compute_attention_mask(

@@ -140,8 +168,7 @@ def _compute_attention_mask(
   # input_mask: [B, S]
   # permutation_mask: [B, S, S]
   if input_mask is not None and permutation_mask is not None:
-    data_mask = input_mask[:, None, :] + permutation_mask
+    data_mask = _combine_masks(input_mask[:, None, :], permutation_mask, dtype)
   elif input_mask is not None and permutation_mask is None:
     data_mask = input_mask[:, None, :]
   elif input_mask is None and permutation_mask is not None:

@@ -153,28 +180,28 @@ def _compute_attention_mask(
   if data_mask is not None:
     # All positions within state can be attended to.
-    state_mask = tf.zeros([batch_size, tf.shape(data_mask)[1], memory_length],
-                          dtype=dtype)
+    state_mask = tf.ones([batch_size, tf.shape(data_mask)[1], memory_length],
+                         dtype=dtype)
     # state_mask: [B, 1, M] or [B, S, M]
     data_mask = tf.concat([state_mask, data_mask], 2)
     # data_mask: [B, 1, S + M] or [B, S, S + M]
     if attention_type == "uni":
-      attention_mask = causal_attention_mask + data_mask[:, None, :, :]
+      attention_mask = _combine_masks(causal_attention_mask,
+                                      data_mask[:, None, :, :], dtype=dtype)
     else:
       attention_mask = data_mask[:, None, :, :]

   # Construct the content attention mask.
   if attention_mask is not None:
-    attention_mask = tf.cast(attention_mask > 0, dtype=dtype)
-
-    non_tgt_mask = -tf.eye(seq_length, dtype=dtype)
-    non_tgt_mask = tf.concat(
-        [tf.zeros([seq_length, memory_length], dtype=dtype), non_tgt_mask],
-        axis=-1)
-    content_attention_mask = tf.cast(
-        (attention_mask + non_tgt_mask[None, None, :, :]) > 0, dtype=dtype)
+    # This ensures that the mask allows the model to attend to positions in
+    # content positions (e.g. the content diagonal).
+    non_target_mask = tf.concat(
+        [tf.zeros([seq_length, memory_length], dtype=dtype),
+         tf.eye(seq_length, dtype=dtype)],
+        axis=-1)
+    content_attention_mask = _combine_masks(
+        attention_mask, non_target_mask, how="or", dtype=dtype)
   else:
     content_attention_mask = None
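To see the new helper's algebra in isolation, here is a plain-TensorFlow restatement of `_combine_masks` (a mirror written for illustration, not an import from the module):

    import tensorflow as tf

    def combine_masks(mask1, mask2, dtype, how="and"):
      # Same algebra as the new helper: with 1 marking "real" positions,
      # "and" intersects two attendability masks, while "or" forces chosen
      # positions (such as the content diagonal) back to attendable.
      op = tf.math.logical_and if how == "and" else tf.math.logical_or
      return tf.cast(op(tf.cast(mask1, tf.bool), tf.cast(mask2, tf.bool)),
                     dtype=dtype)

    input_mask = tf.constant([[1., 1., 0., 0.]])   # 1 = real token
    attention = input_mask * tf.ones([4, 1])       # broadcast to [4, 4]
    content = combine_masks(attention, tf.eye(4), tf.float32, how="or")
    # content rows: [1,1,0,0], [1,1,0,0], [1,1,1,0], [1,1,0,1] -- the same
    # pattern asserted by expected_content_mask in xlnet_base_test.py below.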
official/nlp/modeling/networks/xlnet_base_test.py

@@ -85,9 +85,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
         seq_length=seq_length,
         memory_length=memory_length)
-    expected_output = np.array([[0, 1, 1],
-                                [0, 0, 1],
-                                [0, 0, 0]])
+    expected_output = np.array([[1, 0, 0],
+                                [1, 1, 0],
+                                [1, 1, 1]])
     self.assertAllClose(causal_attention_mask, expected_output)

   def test_casual_attention_mask_with_memory(self):

@@ -96,9 +96,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
         seq_length=seq_length,
         memory_length=memory_length)
-    expected_output = np.array([[0, 0, 0, 1, 1],
-                                [0, 0, 0, 0, 1],
-                                [0, 0, 0, 0, 0]])
+    expected_output = np.array([[1, 1, 1, 0, 0],
+                                [1, 1, 1, 1, 0],
+                                [1, 1, 1, 1, 1]])
     self.assertAllClose(causal_attention_mask, expected_output)

   def test_causal_attention_mask_with_same_length(self):

@@ -108,9 +108,9 @@ class CausalAttentionMaskTests(tf.test.TestCase):
         memory_length=memory_length,
         same_length=True)
-    expected_output = np.array([[0, 0, 0, 1, 1],
-                                [1, 0, 0, 0, 1],
-                                [1, 1, 0, 0, 0]])
+    expected_output = np.array([[1, 1, 1, 0, 0],
+                                [0, 1, 1, 1, 0],
+                                [0, 0, 1, 1, 1]])
     self.assertAllClose(causal_attention_mask, expected_output)

@@ -179,15 +179,15 @@ class MaskComputationTests(keras_parameterized.TestCase):
     batch_size = 1
     memory_length = 0
-    input_mask = np.array([[0, 0, 1, 1]])
+    input_mask = np.array([[1, 1, 0, 0]])
     permutation_mask = None
     expected_query_mask = input_mask[None, None, :, :]
     expected_content_mask = np.array([[[
-        [0, 0, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1],
-        [0, 0, 1, 0]]]])
+        [1, 1, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0],
+        [1, 1, 0, 1]]]])
     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,

@@ -209,14 +209,14 @@ class MaskComputationTests(keras_parameterized.TestCase):
     input_mask = None
-    permutation_mask = np.array([
-        [[0, 1],
-         [0, 1]],
-    ])
+    permutation_mask = np.array([
+        [[1, 0],
+         [1, 0]],
+    ])
     expected_query_mask = permutation_mask[:, None, :, :]
     expected_content_mask = np.array([[[
-        [0, 1],
-        [0, 0]]]])
+        [1, 0],
+        [1, 1]]]])
     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,

@@ -236,24 +236,24 @@ class MaskComputationTests(keras_parameterized.TestCase):
     batch_size = 1
     memory_length = 0
-    input_mask = np.array([[0, 0, 1, 1]])
+    input_mask = np.array([[1, 1, 0, 0]])
-    permutation_mask = np.array([[
-        [1, 0, 0, 0],
-        [0, 1, 0, 0],
-        [0, 0, 1, 0],
-        [0, 0, 0, 1],
-    ]])
+    permutation_mask = np.array([[
+        [0, 1, 1, 1],
+        [1, 0, 1, 1],
+        [1, 1, 0, 1],
+        [1, 1, 1, 0],
+    ]])
-    expected_query_mask = np.array([[[
-        [1, 0, 1, 1],
-        [0, 1, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 1, 1]]]])
+    expected_query_mask = np.array([[[
+        [0, 1, 0, 0],
+        [1, 0, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 0, 0]]]])
-    expected_content_mask = np.array([[[
-        [0, 0, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1],
-        [0, 0, 1, 0]]]])
+    expected_content_mask = np.array([[[
+        [1, 1, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0],
+        [1, 1, 0, 1]]]])
     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,
         permutation_mask=permutation_mask,

@@ -272,24 +272,24 @@ class MaskComputationTests(keras_parameterized.TestCase):
     batch_size = 1
     memory_length = 0
-    input_mask = np.array([[0, 0, 0, 1]])
+    input_mask = np.array([[1, 1, 1, 0]])
-    permutation_mask = np.array([[
-        [1, 0, 0, 0],
-        [0, 1, 0, 0],
-        [0, 0, 1, 0],
-        [0, 0, 0, 1],
-    ]])
+    permutation_mask = np.array([[
+        [0, 1, 1, 1],
+        [1, 0, 1, 1],
+        [1, 1, 0, 1],
+        [1, 1, 1, 0],
+    ]])
-    expected_query_mask = np.array([[[
-        [1, 1, 1, 1],
-        [0, 1, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1]]]])
+    expected_query_mask = np.array([[[
+        [0, 0, 0, 0],
+        [1, 0, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0]]]])
-    expected_content_mask = np.array([[[
-        [0, 1, 1, 1],
-        [0, 0, 1, 1],
-        [0, 0, 0, 1],
-        [0, 0, 0, 0]]]])
+    expected_content_mask = np.array([[[
+        [1, 0, 0, 0],
+        [1, 1, 0, 0],
+        [1, 1, 1, 0],
+        [1, 1, 1, 1]]]])
     query_mask, content_mask = xlnet_base._compute_attention_mask(
         input_mask=input_mask,
         permutation_mask=permutation_mask,
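The flipped expectations can be sanity-checked with a few lines of NumPy: under the new convention a causal mask is all of memory plus the non-strict lower triangle. This is an illustrative restatement, not the module's implementation:

    import numpy as np

    def causal_mask(seq_length, memory_length):
      mem = np.ones((seq_length, memory_length))        # memory: attendable
      tri = np.tril(np.ones((seq_length, seq_length)))  # self and past only
      return np.concatenate([mem, tri], axis=1)

    print(causal_mask(3, 2))
    # [[1. 1. 1. 0. 0.]
    #  [1. 1. 1. 1. 0.]
    #  [1. 1. 1. 1. 1.]]
    # Identical to expected_output in test_casual_attention_mask_with_memory.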
official/nlp/tasks/sentence_prediction.py

@@ -81,13 +81,19 @@ class SentencePredictionTask(base_task.Task):
     else:
       encoder_network = encoders.build_encoder(self.task_config.model.encoder)
     encoder_cfg = self.task_config.model.encoder.get()
-    # Currently, we only support bert-style sentence prediction finetuning.
-    return models.BertClassifier(
-        network=encoder_network,
-        num_classes=self.task_config.model.num_classes,
-        initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=encoder_cfg.initializer_range),
-        use_encoder_pooler=self.task_config.model.use_encoder_pooler)
+    if self.task_config.model.encoder.type == 'xlnet':
+      return models.XLNetClassifier(
+          network=encoder_network,
+          num_classes=self.task_config.model.num_classes,
+          initializer=tf.keras.initializers.RandomNormal(
+              stddev=encoder_cfg.initializer_range))
+    else:
+      return models.BertClassifier(
+          network=encoder_network,
+          num_classes=self.task_config.model.num_classes,
+          initializer=tf.keras.initializers.TruncatedNormal(
+              stddev=encoder_cfg.initializer_range),
+          use_encoder_pooler=self.task_config.model.use_encoder_pooler)

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     if self.task_config.model.num_classes == 1:
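With this branch in place, selecting XLNet for sentence-prediction fine-tuning becomes a pure config change. A hedged sketch follows; only the encoder dispatch is shown by this diff, so the surrounding task-config field names are assumptions:

    from official.nlp.configs import encoders

    # Hypothetical wiring around the new branch.
    encoder_cfg = encoders.EncoderConfig(type='xlnet')
    # task_config.model.encoder = encoder_cfg
    # SentencePredictionTask(task_config).build_model()
    #   -> returns models.XLNetClassifier after this change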
official/nlp/xlnet/xlnet_modeling.py

@@ -15,6 +15,7 @@
 """Keras layers of XLNet model in TF 2.0."""

 import copy
+import warnings

 import tensorflow as tf

@@ -416,6 +417,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
     """
     super(TransformerXLModel, self).__init__(**kwargs)

+    warnings.warn(
+        "`TransformerXLModel` is deprecated, please use `XLNetBase` instead",
+        DeprecationWarning, stacklevel=2)
     self.n_token = n_token
     self.initializer = initializer

@@ -745,11 +749,13 @@ class PretrainingXLNetModel(tf.keras.Model):
   """

-  def __init__(self, use_proj, xlnet_config, run_config, **kwargs):
+  def __init__(self, use_proj, xlnet_config, run_config,
+               use_legacy_mask=True, **kwargs):
     super(PretrainingXLNetModel, self).__init__(**kwargs)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
     self.xlnet_config = copy.deepcopy(xlnet_config)
+    self._use_legacy_mask = use_legacy_mask

     self.xlnet_model = networks.XLNetBase(
         vocab_size=self.xlnet_config.n_token,

@@ -788,7 +794,10 @@ class PretrainingXLNetModel(tf.keras.Model):
     input_ids = features["input_ids"]
     masked_tokens = features["input_q"]
     seg_ids = features["seg_id"]
-    perm_mask = features["perm_mask"]
+    if self._use_legacy_mask:
+      perm_mask = 1 - features["perm_mask"]
+    else:
+      perm_mask = features["perm_mask"]
     target_mapping = features["target_mapping"]

     # target for LM loss

@@ -823,11 +832,16 @@ class ClassificationXLNetModel(tf.keras.Model):
   """

-  def __init__(self, xlnet_config, run_config, n_class, summary_type,
-               **kwargs):
+  def __init__(self, xlnet_config, run_config, n_class, summary_type,
+               use_legacy_mask=True, **kwargs):
     super(ClassificationXLNetModel, self).__init__(**kwargs)
+    warnings.warn(
+        "`ClassificationXLNetModel` is deprecated, please use "
+        "`XLNetClassifier` instead.",
+        DeprecationWarning, stacklevel=2)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
     self.xlnet_config = copy.deepcopy(xlnet_config)
+    self._use_legacy_mask = use_legacy_mask

     self.xlnet_model = networks.XLNetBase(
         vocab_size=self.xlnet_config.n_token,

@@ -870,7 +884,10 @@ class ClassificationXLNetModel(tf.keras.Model):
     input_ids = features["input_ids"]
     segment_ids = features["segment_ids"]

-    input_mask = features["input_mask"]
+    if self._use_legacy_mask:
+      input_mask = 1 - features["input_mask"]
+    else:
+      input_mask = features["input_mask"]

     label = tf.reshape(features["label_ids"], [batch_size_per_core])

@@ -1068,11 +1085,15 @@ class QAXLNetModel(tf.keras.Model):
   """

-  def __init__(self, xlnet_config, run_config, start_n_top, end_n_top,
-               **kwargs):
+  def __init__(self, xlnet_config, run_config, start_n_top, end_n_top,
+               use_legacy_mask=True, **kwargs):
     super(QAXLNetModel, self).__init__(**kwargs)
+    warnings.warn(
+        "`QAXLNetModel` is deprecated, please use `XLNetSpanLabeler` instead.",
+        DeprecationWarning, stacklevel=2)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
     self.xlnet_config = copy.deepcopy(xlnet_config)
+    self._use_legacy_mask = use_legacy_mask

     self.xlnet_model = networks.XLNetBase(
         vocab_size=self.xlnet_config.n_token,

@@ -1108,7 +1129,10 @@ class QAXLNetModel(tf.keras.Model):
     input_ids = features["input_ids"]
     segment_ids = features["segment_ids"]
-    input_mask = features["input_mask"]
+    if self._use_legacy_mask:
+      input_mask = 1 - features["input_mask"]
+    else:
+      input_mask = features["input_mask"]

     cls_index = tf.reshape(features["cls_index"], [-1])
     p_mask = features["p_mask"]
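`use_legacy_mask` defaults to True so that existing input pipelines keep working: legacy XLNet features mark padding with 1, whereas the refactored `XLNetBase` now expects 1 to mean a real position, so the flag simply inverts the incoming mask. Illustrative only:

    import numpy as np

    legacy_input_mask = np.array([[0, 0, 1, 1]], dtype=np.float32)  # 1 = pad
    input_mask = 1 - legacy_input_mask                              # 1 = real
    print(input_mask)  # [[1. 1. 0. 0.]]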