Commit 69fb2164
authored Nov 19, 2020 by Allen Wang, committed by A. Unique TensorFlower on Nov 19, 2020

Implement XLNet Pretrainer model.

PiperOrigin-RevId: 343345144
parent 41f4927a

Showing 4 changed files with 247 additions and 13 deletions (+247 -13)
official/nlp/modeling/models/__init__.py        +1   -0
official/nlp/modeling/models/xlnet.py           +129 -0
official/nlp/modeling/models/xlnet_test.py      +109 -11
official/nlp/modeling/networks/xlnet_base.py    +8   -2
official/nlp/modeling/models/__init__.py

@@ -21,4 +21,5 @@ from official.nlp.modeling.models.dual_encoder import DualEncoder
 from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
 from official.nlp.modeling.models.seq2seq_transformer import *
 from official.nlp.modeling.models.xlnet import XLNetClassifier
+from official.nlp.modeling.models.xlnet import XLNetPretrainer
 from official.nlp.modeling.models.xlnet import XLNetSpanLabeler
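With this one-line export, the new pretrainer becomes reachable through the public `models` namespace. A minimal import check, assuming the TF Model Garden `official` package is installed:

# Minimal sketch: verify the symbol exported by the hunk above is importable.
# Assumes the TF Model Garden `official` package is on the Python path.
from official.nlp.modeling import models

pretrainer_cls = models.XLNetPretrainer
print(pretrainer_cls.__name__)  # -> 'XLNetPretrainer'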
official/nlp/modeling/models/xlnet.py

@@ -23,6 +23,135 @@ from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 
 
+class XLNetMaskedLM(tf.keras.layers.Layer):
+  """XLNet pretraining head."""
+
+  def __init__(self,
+               vocab_size: int,
+               hidden_size: int,
+               initializer: str = 'glorot_uniform',
+               activation: str = 'gelu',
+               name=None,
+               **kwargs):
+    super().__init__(name=name, **kwargs)
+    self._vocab_size = vocab_size
+    self._hidden_size = hidden_size
+    self._initializer = initializer
+    self._activation = activation
+
+  def build(self, input_shape):
+    self.dense = tf.keras.layers.Dense(
+        units=self._hidden_size,
+        activation=self._activation,
+        kernel_initializer=self._initializer,
+        name='transform/dense')
+    self.layer_norm = tf.keras.layers.LayerNormalization(
+        axis=-1, epsilon=1e-12, name='transform/LayerNorm')
+    self.bias = self.add_weight(
+        'output_bias/bias',
+        shape=(self._vocab_size,),
+        initializer='zeros',
+        trainable=True)
+    super().build(input_shape)
+
+  def call(self,
+           sequence_data: tf.Tensor,
+           embedding_table: tf.Tensor):
+    lm_data = self.dense(sequence_data)
+    lm_data = self.layer_norm(lm_data)
+    lm_data = tf.matmul(lm_data, embedding_table, transpose_b=True)
+    logits = tf.nn.bias_add(lm_data, self.bias)
+    return logits
+
+  def get_config(self) -> Mapping[str, Any]:
+    config = {
+        'vocab_size': self._vocab_size,
+        'hidden_size': self._hidden_size,
+        'initializer': self._initializer
+    }
+    base_config = super(XLNetMaskedLM, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class XLNetPretrainer(tf.keras.Model):
+  """XLNet-based pretrainer.
+
+  This is an implementation of the network structure surrounding a
+  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
+  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).
+
+  Arguments:
+    network: An XLNet/Transformer-XL based network. This network should output
+      a sequence output and list of `state` tensors.
+    mlm_activation: The activation (if any) to use in the Masked LM network.
+      If None, then no activation will be used.
+    mlm_initializer: The initializer (if any) to use in the masked LM. Defaults
+      to a Glorot uniform initializer.
+  """
+
+  def __init__(self,
+               network: Union[tf.keras.layers.Layer, tf.keras.Model],
+               mlm_activation=None,
+               mlm_initializer='glorot_uniform',
+               name: str = None,
+               **kwargs):
+    super().__init__(name=name, **kwargs)
+    self._config = {
+        'network': network,
+        'mlm_activation': mlm_activation,
+        'mlm_initializer': mlm_initializer,
+    }
+    self._network = network
+    self._hidden_size = network.get_config()['hidden_size']
+    self._vocab_size = network.get_config()['vocab_size']
+    self._activation = mlm_activation
+    self._initializer = mlm_initializer
+    self._masked_lm = XLNetMaskedLM(
+        vocab_size=self._vocab_size,
+        hidden_size=self._hidden_size,
+        initializer=self._initializer)
+
+  def call(self, inputs: Mapping[str, Any]):
+    input_word_ids = inputs['input_word_ids']
+    input_type_ids = inputs['input_type_ids']
+    masked_tokens = inputs['masked_tokens']
+    permutation_mask = inputs['permutation_mask']
+    target_mapping = inputs['target_mapping']
+    state = inputs.get('state', None)
+
+    attention_output, state = self._network(
+        input_ids=input_word_ids,
+        segment_ids=input_type_ids,
+        input_mask=None,
+        state=state,
+        permutation_mask=permutation_mask,
+        target_mapping=target_mapping,
+        masked_tokens=masked_tokens)
+    embedding_table = self._network.get_embedding_lookup_table()
+    mlm_outputs = self._masked_lm(
+        sequence_data=attention_output,
+        embedding_table=embedding_table)
+    return mlm_outputs, state
+
+  def get_config(self) -> Mapping[str, Any]:
+    return self._config
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    return cls(**config)
+
+  @property
+  def checkpoint_items(self):
+    return dict(encoder=self._network)
+
+
 @tf.keras.utils.register_keras_serializable(package='Text')
 class XLNetClassifier(tf.keras.Model):
   """Classifier model based on XLNet.
...
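The masked-LM head added above ties its output projection to the encoder's embedding table: hidden activations are transformed, layer-normalized, multiplied against the transposed embedding matrix, and offset by a per-token bias. A minimal sketch of that flow, with hypothetical shapes (batch of 2, 4 prediction positions, hidden size 16, vocabulary of 100):

import tensorflow as tf
from official.nlp.modeling.models import xlnet

# Weight-tied masked-LM head from the diff above.
masked_lm = xlnet.XLNetMaskedLM(vocab_size=100, hidden_size=16)

# Hypothetical inputs: [batch, num_predictions, hidden] activations and the
# encoder's [vocab, hidden] embedding lookup table.
sequence_data = tf.random.uniform((2, 4, 16))
embedding_table = tf.random.uniform((100, 16))

logits = masked_lm(sequence_data, embedding_table)
print(logits.shape)  # (2, 4, 100): one score per position and vocabulary entry

Because the projection reuses the embedding table, the head adds only the output bias as new parameters.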
official/nlp/modeling/models/xlnet_test.py

@@ -46,6 +46,104 @@ def _get_xlnet_base() -> tf.keras.layers.Layer:
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
 # guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class XLNetMaskedLMTest(keras_parameterized.TestCase):
+
+  def test_xlnet_masked_lm_head(self):
+    hidden_size = 10
+    seq_length = 8
+    batch_size = 2
+    masked_lm = xlnet.XLNetMaskedLM(
+        vocab_size=10,
+        hidden_size=hidden_size,
+        initializer='glorot_uniform')
+    sequence_data = np.random.uniform(size=(batch_size, seq_length))
+    embedding_table = np.random.uniform(size=(hidden_size, hidden_size))
+    mlm_output = masked_lm(sequence_data, embedding_table)
+    self.assertAllClose(mlm_output.shape, (batch_size, hidden_size))
+
+
+@keras_parameterized.run_all_keras_modes
+class XLNetPretrainerTest(keras_parameterized.TestCase):
+
+  def test_xlnet_trainer(self):
+    """Validates that the Keras object can be created."""
+    seq_length = 4
+    num_predictions = 2
+    # Build a simple XLNet based network to use with the XLNet trainer.
+    xlnet_base = _get_xlnet_base()
+
+    # Create an XLNet trainer with the created network.
+    xlnet_trainer_model = xlnet.XLNetPretrainer(network=xlnet_base)
+    inputs = dict(
+        input_word_ids=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
+        input_type_ids=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
+        input_mask=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='input_mask'),
+        permutation_mask=tf.keras.layers.Input(
+            shape=(seq_length, seq_length,), dtype=tf.int32,
+            name='permutation_mask'),
+        target_mapping=tf.keras.layers.Input(
+            shape=(num_predictions, seq_length), dtype=tf.int32,
+            name='target_mapping'),
+        masked_tokens=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='masked_tokens'))
+    logits, _ = xlnet_trainer_model(inputs)
+
+    # [None, hidden_size, vocab_size]
+    expected_output_shape = [None, 4, 100]
+    self.assertAllEqual(expected_output_shape, logits.shape.as_list())
+
+  def test_xlnet_tensor_call(self):
+    """Validates that the Keras object can be invoked."""
+    seq_length = 4
+    batch_size = 2
+    num_predictions = 2
+    # Build a simple XLNet based network to use with the XLNet trainer.
+    xlnet_base = _get_xlnet_base()
+
+    # Create an XLNet trainer with the created network.
+    xlnet_trainer_model = xlnet.XLNetPretrainer(network=xlnet_base)
+    sequence_shape = (batch_size, seq_length)
+    inputs = dict(
+        input_word_ids=np.random.randint(
+            10, size=sequence_shape, dtype='int32'),
+        input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
+        input_mask=np.random.randint(2, size=sequence_shape).astype('int32'),
+        permutation_mask=np.random.randint(
+            2, size=(batch_size, seq_length, seq_length)).astype('int32'),
+        target_mapping=np.random.randint(
+            10, size=(num_predictions, seq_length), dtype='int32'),
+        masked_tokens=np.random.randint(10, size=sequence_shape, dtype='int32'))
+    xlnet_trainer_model(inputs)
+
+  def test_serialize_deserialize(self):
+    """Validates that the XLNet trainer can be serialized and deserialized."""
+    # Build a simple XLNet based network to use with the XLNet trainer.
+    xlnet_base = _get_xlnet_base()
+
+    # Create an XLNet trainer with the created network.
+    xlnet_trainer_model = xlnet.XLNetPretrainer(
+        network=xlnet_base,
+        mlm_activation='gelu',
+        mlm_initializer='random_normal')
+
+    # Create another XLNet trainer via serialization and deserialization.
+    config = xlnet_trainer_model.get_config()
+    new_xlnet_trainer_model = xlnet.XLNetPretrainer.from_config(config)
+
+    # Validate that the config can be forced to JSON.
+    _ = new_xlnet_trainer_model.to_json()
+
+    # If serialization was successful, then the new config should match the old.
+    self.assertAllEqual(xlnet_trainer_model.get_config(),
+                        new_xlnet_trainer_model.get_config())
+
+
 @keras_parameterized.run_all_keras_modes
 class XLNetClassifierTest(keras_parameterized.TestCase):
...
@@ -69,13 +167,12 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
         input_type_ids=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
-        input_mask=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.float32, name='input_mask'),
-        permutation_mask=tf.keras.layers.Input(
-            shape=(seq_length, seq_length,), dtype=tf.float32,
-            name='permutation_mask'),
-        masked_tokens=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.float32, name='masked_tokens'))
+        input_mask=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='input_mask'),
+        permutation_mask=tf.keras.layers.Input(
+            shape=(seq_length, seq_length,), dtype=tf.int32,
+            name='permutation_mask'),
+        masked_tokens=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='masked_tokens'))
     logits = xlnet_trainer_model(inputs)
     expected_classification_shape = [None, num_classes]
...
@@ -102,10 +199,11 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
         input_word_ids=np.random.randint(
            10, size=sequence_shape, dtype='int32'),
         input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
-        input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
-        permutation_mask=np.random.randint(
-            2, size=(batch_size, seq_length, seq_length)).astype('float32'),
-        masked_tokens=tf.random.uniform(shape=sequence_shape))
+        input_mask=np.random.randint(2, size=sequence_shape).astype('int32'),
+        permutation_mask=np.random.randint(
+            2, size=(batch_size, seq_length, seq_length)).astype('int32'),
+        masked_tokens=np.random.randint(
+            10, size=sequence_shape, dtype='int32'))
     xlnet_trainer_model(inputs)

   def test_serialize_deserialize(self):
...
@@ -158,9 +256,9 @@ class XLNetSpanLabelerTest(keras_parameterized.TestCase):
         input_type_ids=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
-        input_mask=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.float32, name='input_mask'),
-        paragraph_mask=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.float32, name='paragraph_mask'),
+        input_mask=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='input_mask'),
+        paragraph_mask=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='paragraph_mask'),
         class_index=tf.keras.layers.Input(
             shape=(), dtype=tf.int32, name='class_index'),
         start_positions=tf.keras.layers.Input(
...
@@ -175,9 +273,9 @@ class XLNetSpanLabelerTest(keras_parameterized.TestCase):
         input_word_ids=np.random.randint(
             10, size=sequence_shape, dtype='int32'),
         input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
-        input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
-        paragraph_mask=np.random.randint(
-            1, size=(sequence_shape)).astype('float32'),
+        input_mask=np.random.randint(2, size=sequence_shape).astype('int32'),
+        paragraph_mask=np.random.randint(
+            1, size=(sequence_shape)).astype('int32'),
         class_index=np.random.randint(1, size=(batch_size)).astype('uint8'),
         start_positions=tf.random.uniform(
             shape=(batch_size,), maxval=5, dtype=tf.int32))
...
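The pretrainer tests above only exercise shapes and serialization; they do not show a training objective. As a rough, hypothetical sketch of how the returned per-position vocabulary logits could be scored during pretraining (not part of this commit; `pretraining_step` and `mlm_targets` are illustrative names):

import tensorflow as tf


def pretraining_step(pretrainer, inputs, mlm_targets):
  """Hypothetical loss computation for an xlnet.XLNetPretrainer.

  Args:
    pretrainer: an xlnet.XLNetPretrainer built as in XLNetPretrainerTest above.
    inputs: the dict of int32 tensors used in the tests (input_word_ids,
      input_type_ids, masked_tokens, permutation_mask, target_mapping).
    mlm_targets: assumed target token ids aligned with the logits' leading
      dimensions.
  """
  logits, state = pretrainer(inputs)  # per-position vocabulary logits + memory
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  return loss_fn(mlm_targets, logits), state

After pretraining, the encoder alone can be exported via the `checkpoint_items` property added in xlnet.py, e.g. `tf.train.Checkpoint(**pretrainer.checkpoint_items)`.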
official/nlp/modeling/networks/xlnet_base.py

@@ -242,7 +242,8 @@ def _compute_segment_matrix(
   if segment_ids is None:
     return None
 
-  memory_padding = tf.zeros([batch_size, memory_length], dtype=tf.int32)
+  memory_padding = tf.zeros([batch_size, memory_length],
+                            dtype=segment_ids.dtype)
   padded_segment_ids = tf.concat([memory_padding, segment_ids], 1)
   # segment_ids: [B, S]
   # padded_segment_ids: [B, S + M]
...
@@ -629,7 +630,12 @@ class XLNetBase(tf.keras.layers.Layer):
                        "enabled. Please enable `two_stream` to enable two "
                        "stream attention.")
 
-    dtype = input_mask.dtype if input_mask is not None else tf.float32
+    if input_mask is not None:
+      dtype = input_mask.dtype
+    elif permutation_mask is not None:
+      dtype = permutation_mask.dtype
+    else:
+      dtype = tf.int32
     query_attention_mask, content_attention_mask = _compute_attention_mask(
         input_mask=input_mask,
         permutation_mask=permutation_mask,
...
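Both hunks above relax dtype assumptions: the segment padding now inherits `segment_ids.dtype`, and the attention-mask dtype falls back through `input_mask`, then `permutation_mask`, then `tf.int32`, instead of assuming `tf.float32`. A small illustrative check of the padding behavior, with hypothetical values:

import tensorflow as tf

segment_ids = tf.constant([[1, 1, 2, 2]], dtype=tf.int64)  # hypothetical ids
batch_size, memory_length = 1, 3

# Previously the padding was hard-coded to tf.int32, so the concat below would
# fail for int64 (or float) segment ids; matching dtypes avoids that.
memory_padding = tf.zeros([batch_size, memory_length], dtype=segment_ids.dtype)
padded_segment_ids = tf.concat([memory_padding, segment_ids], 1)
print(padded_segment_ids)  # [[0 0 0 1 1 2 2]], same dtype as segment_ids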