"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "aa03da16684a8f6c1aea6bb7c94b2b89de75b896"
Commit fcb43c38 authored by Allen Wang, committed by A. Unique TensorFlower

Implement SpanLabeler for XLNet.

PiperOrigin-RevId: 336709640
parent a26d77c4
@@ -19,6 +19,7 @@ from official.nlp.modeling.networks.classification import Classification
from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
from official.nlp.modeling.networks.mobile_bert_encoder import MobileBERTEncoder
from official.nlp.modeling.networks.span_labeling import SpanLabeling
from official.nlp.modeling.networks.span_labeling import XLNetSpanLabeling
from official.nlp.modeling.networks.xlnet_base import XLNetBase
# Backward compatibility. The modules are deprecated.
TransformerEncoder = BertEncoder
@@ -22,6 +22,14 @@ from __future__ import print_function
import tensorflow as tf
def _apply_position_mask(logits, position_mask):
  """Applies a position mask to calculated logits.

  Positions marked with 1.0 in `position_mask` are invalid: they are pushed
  to a large negative value so that they receive effectively zero probability
  after the softmax.
  """
  if logits.shape.rank != position_mask.shape.rank:
    # Broadcast the [batch_size, seq_length] mask across the beam dimension.
    position_mask = position_mask[:, None, :]
  masked_logits = logits * (1 - position_mask) - 1e30 * position_mask
  return tf.nn.log_softmax(masked_logits, -1), masked_logits
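# A minimal sketch of the masking behavior (illustrative, not part of this
# change):
#
#   logits = tf.constant([[1.0, 2.0, 3.0]])
#   mask = tf.constant([[0.0, 0.0, 1.0]])  # 1.0 marks an invalid position.
#   log_probs, masked = _apply_position_mask(logits, mask)
#   # masked == [[1.0, 2.0, -1e30]]; log_probs renormalizes over the two
#   # valid positions only.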
@tf.keras.utils.register_keras_serializable(package='Text')
class SpanLabeling(tf.keras.Model):
  """Span labeling network head for BERT modeling.
@@ -92,3 +100,197 @@ class SpanLabeling(tf.keras.Model):
  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
class XLNetSpanLabeling(tf.keras.layers.Layer):
"""Span labeling network head for XLNet on SQuAD2.0.
This networks implements a span-labeler based on dense layers and question
possibility classification. This is the complex version seen in the original
XLNet implementation.
This applies a dense layer to the input sequence data to predict the start
positions, and then uses either the true start positions (if training) or
beam search to predict the end positions.
Arguments:
input_width: The innermost dimension of the input tensor to this network.
start_n_top: Beam size for span start.
end_n_top: Beam size for span end.
activation: The activation, if any, for the dense layer in this network.
dropout_rate: The dropout rate used for answer classification.
initializer: The initializer for the dense layer in this network. Defaults
to a Glorot uniform initializer.
"""
def __init__(self,
input_width,
start_n_top,
end_n_top,
activation='tanh',
dropout_rate=0.,
initializer='glorot_uniform',
**kwargs):
super().__init__(**kwargs)
self._config = {
'input_width': input_width,
'activation': activation,
'initializer': initializer,
'start_n_top': start_n_top,
'end_n_top': end_n_top,
'dropout_rate': dropout_rate,
}
self._start_n_top = start_n_top
self._end_n_top = end_n_top
self.start_logits_dense = tf.keras.layers.Dense(
units=1,
kernel_initializer=initializer,
name='predictions/transform/start_logits')
self.end_logits_inner_dense = tf.keras.layers.Dense(
units=input_width,
kernel_initializer=initializer,
name='predictions/transform/end_logits/inner')
self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12,
name='predictions/transform/end_logits/layernorm')
self.end_logits_output_dense = tf.keras.layers.Dense(
units=1,
kernel_initializer=initializer,
name='predictions/transform/end_logits/output')
self.answer_logits_inner = tf.keras.layers.Dense(
units=input_width,
kernel_initializer=initializer,
activation=activation,
name='predictions/transform/answer_logits/inner')
self.answer_logits_dropout = tf.keras.layers.Dropout(rate=dropout_rate)
self.answer_logits_output = tf.keras.layers.Dense(
units=1,
kernel_initializer=initializer,
use_bias=False,
name='predictions/transform/answer_logits/output')
  def end_logits(self, inputs):
    """Computes the end logits.

    `inputs` has shape [batch_size, seq_length, 2*input_width] when training
    with true start positions, or [batch_size, seq_length, start_n_top,
    2*input_width] when running beam search.
    """
    end_logits = self.end_logits_inner_dense(inputs)
    end_logits = self.end_logits_layer_norm(end_logits)
    end_logits = self.end_logits_output_dense(end_logits)
    # Squeeze only the final singleton dimension so that a batch size of 1
    # is not accidentally squeezed away.
    end_logits = tf.squeeze(end_logits, -1)
    if end_logits.shape.rank > 2:
      # Reorder [batch_size, seq_length, start_n_top] to
      # [batch_size, start_n_top, seq_length].
      end_logits = tf.transpose(end_logits, [0, 2, 1])
    return end_logits
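  # Shape walk-through for `end_logits` (illustrative note; b=batch_size,
  # l=seq_length, h=input_width, k=start_n_top): the beam-search path feeds
  # a [b, l, k, 2h] tensor, the inner dense maps it to [b, l, k, h], the
  # output dense to [b, l, k, 1], and the squeeze plus transpose yield
  # [b, k, l]. The training path feeds [b, l, 2h] and returns [b, l].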
def call(self,
sequence_data,
class_index,
position_mask=None,
start_positions=None,
training=False):
"""Implements call().
Einsum glossary:
- b: the batch size.
- l: the sequence length.
- h: the hidden size, or input width.
- k: the start/end top n.
Args:
sequence_data: The input sequence data of shape
(batch_size, seq_length, input_width).
class_index: The class indices of the inputs of shape (batch_size,).
position_mask: Invalid position mask such as query and special symbols
(e.g. PAD, SEP, CLS) of shape (batch_size,).
start_positions: The start positions of each example of shape
(batch_size,).
training: Whether or not this is the training phase.
Returns:
A dictionary with the keys 'cls_logits' and
- (if training) 'start_log_probs', 'end_log_probs'.
- (if inference/beam search) 'start_top_log_probs', 'start_top_index',
'end_top_log_probs', 'end_top_index'.
"""
seq_length = tf.shape(sequence_data)[1]
start_logits = self.start_logits_dense(sequence_data)
start_logits = tf.squeeze(start_logits, -1)
start_log_probs, masked_start_logits = _apply_position_mask(
start_logits, position_mask)
compute_with_beam_search = not training or start_positions is None
if compute_with_beam_search:
# Compute end logits using beam search.
start_top_log_probs, start_top_index = tf.nn.top_k(
start_log_probs, k=self._start_n_top)
start_index = tf.one_hot(
start_top_index, depth=seq_length, axis=-1, dtype=tf.float32)
      # start_index: [batch_size, start_n_top, seq_length]
start_features = tf.einsum('blh,bkl->bkh', sequence_data, start_index)
start_features = tf.tile(start_features[:, None, :, :],
[1, seq_length, 1, 1])
      # start_features: [batch_size, seq_length, start_n_top, input_width]
end_input = tf.tile(sequence_data[:, :, None],
[1, 1, self._start_n_top, 1])
end_input = tf.concat([end_input, start_features], axis=-1)
      # end_input: [batch_size, seq_length, start_n_top, 2*input_width]
else:
start_positions = tf.reshape(start_positions, -1)
start_index = tf.one_hot(
start_positions, depth=seq_length, axis=-1, dtype=tf.float32)
# start_index: [batch_size, seq_length]
start_features = tf.einsum('blh,bl->bh', sequence_data, start_index)
start_features = tf.tile(start_features[:, None, :], [1, seq_length, 1])
# start_features: [batch_size, seq_length, input_width]
end_input = tf.concat([sequence_data, start_features],
axis=-1)
# end_input: [batch_size, seq_length, 2*input_width]
end_logits = self.end_logits(end_input)
end_log_probs, _ = _apply_position_mask(end_logits, position_mask)
output_dict = {}
if training:
output_dict['start_log_probs'] = start_log_probs
output_dict['end_log_probs'] = end_log_probs
else:
end_top_log_probs, end_top_index = tf.nn.top_k(
end_log_probs, k=self._end_n_top)
end_top_log_probs = tf.reshape(end_top_log_probs,
[-1, self._start_n_top * self._end_n_top])
end_top_index = tf.reshape(end_top_index,
[-1, self._start_n_top * self._end_n_top])
output_dict['start_top_log_probs'] = start_top_log_probs
output_dict['start_top_index'] = start_top_index
output_dict['end_top_log_probs'] = end_top_log_probs
output_dict['end_top_index'] = end_top_index
    # Get the representation of the CLS token.
class_index = tf.one_hot(class_index, seq_length, axis=-1, dtype=tf.float32)
class_feature = tf.einsum('blh,bl->bh', sequence_data, class_index)
    # Get the representation of START: an expectation of the sequence
    # features under the masked start-position distribution.
    start_p = tf.nn.softmax(masked_start_logits, axis=-1)
    start_feature = tf.einsum('blh,bl->bh', sequence_data, start_p)
answer_feature = tf.concat([start_feature, class_feature], -1)
answer_feature = self.answer_logits_inner(answer_feature)
answer_feature = self.answer_logits_dropout(answer_feature)
class_logits = self.answer_logits_output(answer_feature)
class_logits = tf.squeeze(class_logits, -1)
output_dict['class_logits'] = class_logits
return output_dict
def get_config(self):
return self._config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
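# A minimal usage sketch (illustrative, with assumed toy shapes; not part of
# this change):
#
#   layer = XLNetSpanLabeling(
#       input_width=4, start_n_top=2, end_n_top=2)
#   outputs = layer(
#       sequence_data=tf.random.uniform((2, 8, 4)),
#       class_index=tf.zeros((2,), dtype=tf.int32),
#       position_mask=tf.zeros((2, 8)),  # 1.0 would mark invalid positions.
#       training=False)
#   # Beam-search mode returns 'start_top_log_probs', 'start_top_index',
#   # 'end_top_log_probs', 'end_top_index', and 'class_logits'.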
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
@@ -170,5 +172,141 @@ class SpanLabelingTest(keras_parameterized.TestCase):
    _ = span_labeling.SpanLabeling(input_width=10, output='bad')
@keras_parameterized.run_all_keras_modes
class XLNetSpanLabelingTest(keras_parameterized.TestCase):
def test_basic_invocation_train(self):
batch_size = 2
seq_length = 8
hidden_size = 4
sequence_data = np.random.uniform(
size=(batch_size, seq_length, hidden_size)).astype('float32')
position_mask = np.random.uniform(
size=(batch_size, seq_length)).astype('float32')
class_index = np.random.uniform(size=(batch_size)).astype('uint8')
start_positions = np.zeros(shape=(batch_size)).astype('uint8')
layer = span_labeling.XLNetSpanLabeling(
input_width=hidden_size,
start_n_top=1,
end_n_top=1,
activation='tanh',
dropout_rate=0.,
initializer='glorot_uniform')
output = layer(sequence_data=sequence_data,
class_index=class_index,
position_mask=position_mask,
start_positions=start_positions,
training=True)
expected_keys = {
'start_log_probs', 'end_log_probs', 'class_logits',
}
self.assertSetEqual(expected_keys, set(output.keys()))
@parameterized.named_parameters(
('top_1', 1),
('top_n', 5))
def test_basic_invocation_beam_search(self, top_n):
batch_size = 2
seq_length = 8
hidden_size = 4
sequence_data = np.random.uniform(
size=(batch_size, seq_length, hidden_size)).astype('float32')
position_mask = np.random.uniform(
size=(batch_size, seq_length)).astype('float32')
class_index = np.random.uniform(size=(batch_size)).astype('uint8')
layer = span_labeling.XLNetSpanLabeling(
input_width=hidden_size,
start_n_top=top_n,
end_n_top=top_n,
activation='tanh',
dropout_rate=0.,
initializer='glorot_uniform')
output = layer(sequence_data=sequence_data,
class_index=class_index,
position_mask=position_mask,
training=False)
expected_keys = {
'start_top_log_probs', 'end_top_log_probs', 'class_logits',
'start_top_index', 'end_top_index',
}
self.assertSetEqual(expected_keys, set(output.keys()))
def test_functional_model_invocation(self):
"""Tests basic invocation of this layer wrapped by a Functional model."""
seq_length = 8
hidden_size = 4
batch_size = 2
sequence_data = tf.keras.Input(shape=(seq_length, hidden_size),
dtype=tf.float32)
class_index = tf.keras.Input(shape=(), dtype=tf.uint8)
    position_mask = tf.keras.Input(shape=(seq_length,), dtype=tf.float32)
start_positions = tf.keras.Input(shape=(), dtype=tf.float32)
layer = span_labeling.XLNetSpanLabeling(
input_width=hidden_size,
start_n_top=5,
end_n_top=5,
activation='tanh',
dropout_rate=0.,
initializer='glorot_uniform')
output = layer(sequence_data=sequence_data,
class_index=class_index,
position_mask=position_mask,
start_positions=start_positions)
model = tf.keras.Model(
inputs={
'sequence_data': sequence_data,
'class_index': class_index,
'position_mask': position_mask,
'start_positions': start_positions,
},
outputs=output)
sequence_data = tf.random.uniform(
shape=(batch_size, seq_length, hidden_size), dtype=tf.float32)
position_mask = tf.random.uniform(
shape=(batch_size, seq_length), dtype=tf.float32)
class_index = tf.ones(shape=(batch_size,), dtype=tf.uint8)
start_positions = tf.random.uniform(shape=(batch_size,), dtype=tf.float32)
inputs = dict(sequence_data=sequence_data,
position_mask=position_mask,
class_index=class_index,
start_positions=start_positions)
output = model(inputs)
self.assertIsInstance(output, dict)
# Test `call` with training flag.
output = model.call(inputs, training=True)
self.assertIsInstance(output, dict)
# Test `call` without training flag.
output = model.call(inputs, training=False)
self.assertIsInstance(output, dict)
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
network = span_labeling.XLNetSpanLabeling(
input_width=128,
start_n_top=5,
end_n_top=1,
activation='tanh',
dropout_rate=0.34,
initializer='zeros')
# Create another network object from the first object's config.
new_network = span_labeling.XLNetSpanLabeling.from_config(
network.get_config())
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(network.get_config(), new_network.get_config())
if __name__ == '__main__':
  tf.test.main()