Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
fcb43c38
Commit
fcb43c38
authored
Oct 12, 2020
by
Allen Wang
Committed by
A. Unique TensorFlower
Oct 12, 2020
Browse files
Implement SpanLabeler for XLNet.
PiperOrigin-RevId: 336709640
parent
a26d77c4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
341 additions
and
0 deletions
+341
-0
official/nlp/modeling/networks/__init__.py
official/nlp/modeling/networks/__init__.py
+1
-0
official/nlp/modeling/networks/span_labeling.py
official/nlp/modeling/networks/span_labeling.py
+202
-0
official/nlp/modeling/networks/span_labeling_test.py
official/nlp/modeling/networks/span_labeling_test.py
+138
-0
No files found.
official/nlp/modeling/networks/__init__.py
View file @
fcb43c38
...
...
@@ -19,6 +19,7 @@ from official.nlp.modeling.networks.classification import Classification
from
official.nlp.modeling.networks.encoder_scaffold
import
EncoderScaffold
from
official.nlp.modeling.networks.mobile_bert_encoder
import
MobileBERTEncoder
from
official.nlp.modeling.networks.span_labeling
import
SpanLabeling
from
official.nlp.modeling.networks.span_labeling
import
XLNetSpanLabeling
from
official.nlp.modeling.networks.xlnet_base
import
XLNetBase
# Backward compatibility. The modules are deprecated.
TransformerEncoder
=
BertEncoder
official/nlp/modeling/networks/span_labeling.py
View file @
fcb43c38
...
...
@@ -22,6 +22,14 @@ from __future__ import print_function
import
tensorflow
as
tf
def _apply_position_mask(logits, position_mask):
  """Masks out invalid positions before computing log-probabilities.

  Args:
    logits: Span logits; either the same rank as `position_mask` or with one
      extra (beam) dimension in the middle.
    position_mask: Mask tensor in which a value of 1 marks an invalid
      position (e.g. PAD/SEP/CLS) and 0 marks a valid one.

  Returns:
    A tuple of (log-softmax over the masked logits, the masked logits).
  """
  if tf.rank(logits) != tf.rank(position_mask):
    # Insert a broadcast dimension so a [batch, seq] mask applies across the
    # beam dimension of [batch, k, seq] logits.
    position_mask = position_mask[:, None, :]
  # Zero out masked entries and push them toward -inf so softmax assigns them
  # effectively zero probability.
  masked = logits * (1 - position_mask) - 1e30 * position_mask
  return tf.nn.log_softmax(masked, -1), masked
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
SpanLabeling
(
tf
.
keras
.
Model
):
"""Span labeling network head for BERT modeling.
...
...
@@ -92,3 +100,197 @@ class SpanLabeling(tf.keras.Model):
  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the network from a `get_config()` dictionary.

    `custom_objects` is accepted for compatibility with the Keras
    deserialization API but is not used here.
    """
    return cls(**config)
class XLNetSpanLabeling(tf.keras.layers.Layer):
  """Span labeling network head for XLNet on SQuAD2.0.

  This network implements a span-labeler based on dense layers and question
  possibility classification. This is the complex version seen in the original
  XLNet implementation.

  This applies a dense layer to the input sequence data to predict the start
  positions, and then uses either the true start positions (if training) or
  beam search to predict the end positions.

  Arguments:
    input_width: The innermost dimension of the input tensor to this network.
    start_n_top: Beam size for span start.
    end_n_top: Beam size for span end.
    activation: The activation, if any, for the dense layer in this network.
    dropout_rate: The dropout rate used for answer classification.
    initializer: The initializer for the dense layer in this network. Defaults
      to a Glorot uniform initializer.
  """

  def __init__(self,
               input_width,
               start_n_top,
               end_n_top,
               activation='tanh',
               dropout_rate=0.,
               initializer='glorot_uniform',
               **kwargs):
    super().__init__(**kwargs)
    # Stored verbatim so get_config()/from_config() round-trips the layer.
    self._config = {
        'input_width': input_width,
        'activation': activation,
        'initializer': initializer,
        'start_n_top': start_n_top,
        'end_n_top': end_n_top,
        'dropout_rate': dropout_rate,
    }
    self._start_n_top = start_n_top
    self._end_n_top = end_n_top
    # Projects each sequence position to a single start logit.
    self.start_logits_dense = tf.keras.layers.Dense(
        units=1,
        kernel_initializer=initializer,
        name='predictions/transform/start_logits')
    # End-logit tower: inner projection -> layer norm -> scalar output.
    self.end_logits_inner_dense = tf.keras.layers.Dense(
        units=input_width,
        kernel_initializer=initializer,
        name='predictions/transform/end_logits/inner')
    self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1,
        epsilon=1e-12,
        name='predictions/transform/end_logits/layernorm')
    self.end_logits_output_dense = tf.keras.layers.Dense(
        units=1,
        kernel_initializer=initializer,
        name='predictions/transform/end_logits/output')
    # Answerability ("is there an answer?") classifier head.
    self.answer_logits_inner = tf.keras.layers.Dense(
        units=input_width,
        kernel_initializer=initializer,
        activation=activation,
        name='predictions/transform/answer_logits/inner')
    self.answer_logits_dropout = tf.keras.layers.Dropout(rate=dropout_rate)
    self.answer_logits_output = tf.keras.layers.Dense(
        units=1,
        kernel_initializer=initializer,
        use_bias=False,
        name='predictions/transform/answer_logits/output')

  def end_logits(self, inputs):
    """Computes the end logits."""
    end_logits = self.end_logits_inner_dense(inputs)
    end_logits = self.end_logits_layer_norm(end_logits)
    end_logits = self.end_logits_output_dense(end_logits)
    # NOTE(review): `tf.squeeze` with no axis drops *all* unit dimensions;
    # with batch_size == 1 this would also remove the batch dimension —
    # confirm callers never use a batch of one, or pass an explicit axis.
    end_logits = tf.squeeze(end_logits)
    # NOTE(review): `tf.rank` returns a tensor; this Python-level comparison
    # relies on eager execution (or autograph) to resolve — confirm.
    if tf.rank(end_logits) > 2:
      # shape = [batch_size, seq_length, start_n_top]
      end_logits = tf.transpose(end_logits, [0, 2, 1])
    return end_logits

  def call(self,
           sequence_data,
           class_index,
           position_mask=None,
           start_positions=None,
           training=False):
    """Implements call().

    Einsum glossary:
      - b: the batch size.
      - l: the sequence length.
      - h: the hidden size, or input width.
      - k: the start/end top n.

    Args:
      sequence_data: The input sequence data of shape
        (batch_size, seq_length, input_width).
      class_index: The class indices of the inputs of shape (batch_size,).
      position_mask: Invalid position mask such as query and special symbols
        (e.g. PAD, SEP, CLS) of shape (batch_size, seq_length).
        NOTE(review): the default is None, but `_apply_position_mask` does
        arithmetic on it unconditionally — confirm callers always supply it.
      start_positions: The start positions of each example of shape
        (batch_size,).
      training: Whether or not this is the training phase.

    Returns:
      A dictionary with the keys 'class_logits' and
      - (if training) 'start_log_probs', 'end_log_probs'.
      - (if inference/beam search) 'start_top_log_probs', 'start_top_index',
        'end_top_log_probs', 'end_top_index'.
    """
    seq_length = tf.shape(sequence_data)[1]
    start_logits = self.start_logits_dense(sequence_data)
    start_logits = tf.squeeze(start_logits, -1)
    start_log_probs, masked_start_logits = _apply_position_mask(
        start_logits, position_mask)

    # Beam search is used whenever gold start positions are unavailable
    # (i.e. at inference time, or if the caller did not pass them).
    compute_with_beam_search = not training or start_positions is None

    if compute_with_beam_search:
      # Compute end logits using beam search.
      start_top_log_probs, start_top_index = tf.nn.top_k(
          start_log_probs, k=self._start_n_top)
      start_index = tf.one_hot(
          start_top_index, depth=seq_length, axis=-1, dtype=tf.float32)
      # start_index: [batch_size, start_n_top, seq_length]
      # Gather the hidden states at each of the top-k start positions.
      start_features = tf.einsum('blh,bkl->bkh', sequence_data, start_index)
      start_features = tf.tile(start_features[:, None, :, :],
                               [1, seq_length, 1, 1])
      # start_features: [batch_size, seq_length, start_n_top, input_width]
      end_input = tf.tile(sequence_data[:, :, None],
                          [1, 1, self._start_n_top, 1])
      end_input = tf.concat([end_input, start_features], axis=-1)
      # end_input: [batch_size, seq_length, start_n_top, 2*input_width]
    else:
      start_positions = tf.reshape(start_positions, -1)
      start_index = tf.one_hot(
          start_positions, depth=seq_length, axis=-1, dtype=tf.float32)
      # start_index: [batch_size, seq_length]
      # Gather the hidden state at the gold start position of each example.
      start_features = tf.einsum('blh,bl->bh', sequence_data, start_index)
      start_features = tf.tile(start_features[:, None, :], [1, seq_length, 1])
      # start_features: [batch_size, seq_length, input_width]
      end_input = tf.concat([sequence_data, start_features], axis=-1)
      # end_input: [batch_size, seq_length, 2*input_width]
    end_logits = self.end_logits(end_input)
    end_log_probs, _ = _apply_position_mask(end_logits, position_mask)

    output_dict = {}
    if training:
      output_dict['start_log_probs'] = start_log_probs
      output_dict['end_log_probs'] = end_log_probs
    else:
      # Flatten the (start beam x end beam) grid into one axis per example.
      end_top_log_probs, end_top_index = tf.nn.top_k(
          end_log_probs, k=self._end_n_top)
      end_top_log_probs = tf.reshape(
          end_top_log_probs,
          [-1, self._start_n_top * self._end_n_top])
      end_top_index = tf.reshape(
          end_top_index,
          [-1, self._start_n_top * self._end_n_top])
      output_dict['start_top_log_probs'] = start_top_log_probs
      output_dict['start_top_index'] = start_top_index
      output_dict['end_top_log_probs'] = end_top_log_probs
      output_dict['end_top_index'] = end_top_index

    # get the representation of CLS
    class_index = tf.one_hot(class_index, seq_length, axis=-1,
                             dtype=tf.float32)
    class_feature = tf.einsum('blh,bl->bh', sequence_data, class_index)

    # get the representation of START: a probability-weighted average of the
    # sequence hidden states under the masked start distribution.
    start_p = tf.nn.softmax(masked_start_logits, axis=-1)
    start_feature = tf.einsum('blh,bl->bh', sequence_data, start_p)

    answer_feature = tf.concat([start_feature, class_feature], -1)
    answer_feature = self.answer_logits_inner(answer_feature)
    answer_feature = self.answer_logits_dropout(answer_feature)
    class_logits = self.answer_logits_output(answer_feature)
    class_logits = tf.squeeze(class_logits, -1)
    output_dict['class_logits'] = class_logits
    return output_dict

  def get_config(self):
    # Returns the constructor kwargs captured in __init__.
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the layer from a `get_config()` dictionary.

    `custom_objects` is accepted for compatibility with the Keras
    deserialization API but is not used here.
    """
    return cls(**config)
official/nlp/modeling/networks/span_labeling_test.py
View file @
fcb43c38
...
...
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
...
...
@@ -170,5 +172,141 @@ class SpanLabelingTest(keras_parameterized.TestCase):
_
=
span_labeling
.
SpanLabeling
(
input_width
=
10
,
output
=
'bad'
)
@keras_parameterized.run_all_keras_modes
class XLNetSpanLabelingTest(keras_parameterized.TestCase):
  """Tests for the XLNetSpanLabeling layer."""

  def test_basic_invocation_train(self):
    """Training-mode call returns log-probs plus answerability logits."""
    batch_size = 2
    seq_length = 8
    hidden_size = 4
    sequence_data = np.random.uniform(
        size=(batch_size, seq_length, hidden_size)).astype('float32')
    position_mask = np.random.uniform(
        size=(batch_size, seq_length)).astype('float32')
    # NOTE: size=(batch_size) is just the int batch_size, not a tuple —
    # numpy accepts either and produces a 1-D array here.
    class_index = np.random.uniform(size=(batch_size)).astype('uint8')
    start_positions = np.zeros(shape=(batch_size)).astype('uint8')
    layer = span_labeling.XLNetSpanLabeling(
        input_width=hidden_size,
        start_n_top=1,
        end_n_top=1,
        activation='tanh',
        dropout_rate=0.,
        initializer='glorot_uniform')
    # Supplying start_positions with training=True exercises the
    # non-beam-search (teacher-forced) branch.
    output = layer(sequence_data=sequence_data,
                   class_index=class_index,
                   position_mask=position_mask,
                   start_positions=start_positions,
                   training=True)
    expected_keys = {
        'start_log_probs',
        'end_log_probs',
        'class_logits',
    }
    self.assertSetEqual(expected_keys, set(output.keys()))

  @parameterized.named_parameters(('top_1', 1),
                                  ('top_n', 5))
  def test_basic_invocation_beam_search(self, top_n):
    """Inference-mode call (no start positions) uses the beam-search branch."""
    batch_size = 2
    seq_length = 8
    hidden_size = 4
    sequence_data = np.random.uniform(
        size=(batch_size, seq_length, hidden_size)).astype('float32')
    position_mask = np.random.uniform(
        size=(batch_size, seq_length)).astype('float32')
    class_index = np.random.uniform(size=(batch_size)).astype('uint8')
    layer = span_labeling.XLNetSpanLabeling(
        input_width=hidden_size,
        start_n_top=top_n,
        end_n_top=top_n,
        activation='tanh',
        dropout_rate=0.,
        initializer='glorot_uniform')
    output = layer(sequence_data=sequence_data,
                   class_index=class_index,
                   position_mask=position_mask,
                   training=False)
    expected_keys = {
        'start_top_log_probs',
        'end_top_log_probs',
        'class_logits',
        'start_top_index',
        'end_top_index',
    }
    self.assertSetEqual(expected_keys, set(output.keys()))

  def test_functional_model_invocation(self):
    """Tests basic invocation of this layer wrapped by a Functional model."""
    seq_length = 8
    hidden_size = 4
    batch_size = 2
    sequence_data = tf.keras.Input(shape=(seq_length, hidden_size),
                                   dtype=tf.float32)
    class_index = tf.keras.Input(shape=(), dtype=tf.uint8)
    position_mask = tf.keras.Input(shape=(seq_length), dtype=tf.float32)
    start_positions = tf.keras.Input(shape=(), dtype=tf.float32)
    layer = span_labeling.XLNetSpanLabeling(
        input_width=hidden_size,
        start_n_top=5,
        end_n_top=5,
        activation='tanh',
        dropout_rate=0.,
        initializer='glorot_uniform')
    output = layer(sequence_data=sequence_data,
                   class_index=class_index,
                   position_mask=position_mask,
                   start_positions=start_positions)
    model = tf.keras.Model(
        inputs={
            'sequence_data': sequence_data,
            'class_index': class_index,
            'position_mask': position_mask,
            'start_positions': start_positions,
        },
        outputs=output)
    # Concrete inputs for invoking the built model.
    sequence_data = tf.random.uniform(
        shape=(batch_size, seq_length, hidden_size), dtype=tf.float32)
    position_mask = tf.random.uniform(
        shape=(batch_size, seq_length), dtype=tf.float32)
    class_index = tf.ones(shape=(batch_size,), dtype=tf.uint8)
    start_positions = tf.random.uniform(
        shape=(batch_size,), dtype=tf.float32)
    inputs = dict(sequence_data=sequence_data,
                  position_mask=position_mask,
                  class_index=class_index,
                  start_positions=start_positions)
    output = model(inputs)
    self.assertIsInstance(output, dict)
    # Test `call` with training flag.
    output = model.call(inputs, training=True)
    self.assertIsInstance(output, dict)
    # Test `call` without training flag.
    output = model.call(inputs, training=False)
    self.assertIsInstance(output, dict)

  def test_serialize_deserialize(self):
    """Config round-trip through get_config()/from_config()."""
    # Create a network object that sets all of its config options.
    network = span_labeling.XLNetSpanLabeling(
        input_width=128,
        start_n_top=5,
        end_n_top=1,
        activation='tanh',
        dropout_rate=0.34,
        initializer='zeros')
    # Create another network object from the first object's config.
    new_network = span_labeling.XLNetSpanLabeling.from_config(
        network.get_config())
    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())
if __name__ == '__main__':
  # Run all TensorFlow test cases in this module when executed as a script.
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment