Commit e3a000ad authored by Allen Wang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 341642152
parent d0b78926
@@ -695,19 +695,13 @@ def postprocess_output(all_examples,
     null_start_logit = result.start_logits[0]
     null_end_logit = result.end_logits[0]
-    start_indexes_and_logits = _get_best_indexes_and_logits(
-        result=result,
-        n_best_size=n_best_size,
-        start=True,
-        xlnet_format=xlnet_format)
-    end_indexes_and_logits = _get_best_indexes_and_logits(
-        result=result,
-        n_best_size=n_best_size,
-        start=False,
-        xlnet_format=xlnet_format)
-    doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1
-    for start_index, start_logit in start_indexes_and_logits:
-      for end_index, end_logit in end_indexes_and_logits:
+    doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1
+
+    for (start_index, start_logit,
+         end_index, end_logit) in _get_best_indexes_and_logits(
+             result=result,
+             n_best_size=n_best_size,
+             xlnet_format=xlnet_format):
       # We could hypothetically create invalid predictions, e.g., predict
       # that the start of the span is in the question. We throw out all
       # invalid predictions.
@@ -752,7 +746,7 @@ def postprocess_output(all_examples,
       if len(nbest) >= n_best_size:
         break
       feature = features[pred.feature_index]
-      if pred.start_index >= 0:  # this is a non-null prediction
+      if pred.start_index >= 0 or xlnet_format:  # this is a non-null prediction
         tok_start_to_orig_index = feature.tok_start_to_orig_index
         tok_end_to_orig_index = feature.tok_end_to_orig_index
         start_orig_pos = tok_start_to_orig_index[pred.start_index]
@@ -774,7 +768,7 @@ def postprocess_output(all_examples,
             start_logit=pred.start_logit,
             end_logit=pred.end_logit))
-    # if we didn't inlude the empty option in the n-best, inlcude it
+    # if we didn't include the empty option in the n-best, include it
     if version_2_with_negative and not xlnet_format:
       if "" not in seen_predictions:
         nbest.append(
@@ -814,6 +808,11 @@ def postprocess_output(all_examples,
       all_predictions[example.qas_id] = nbest_json[0]["text"]
     else:
       assert best_non_null_entry is not None
-      # predict "" iff the null score - the score of best non-null > threshold
-      score_diff = score_null - best_non_null_entry.start_logit - (
-          best_non_null_entry.end_logit)
+      if xlnet_format:
+        score_diff = score_null
+        scores_diff_json[example.qas_id] = score_diff
+        all_predictions[example.qas_id] = best_non_null_entry.text
+      else:
+        # predict "" iff the null score - the score of best non-null > threshold
+        score_diff = score_null - best_non_null_entry.start_logit - (
+            best_non_null_entry.end_logit)
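
For context on the branch above: SQuAD 2.0 models emit the empty answer iff the null score beats the best non-null span by a tunable threshold, while the new XLNet path uses the answerability score directly. A minimal sketch of the BERT-style rule (all values hypothetical):

    # Hypothetical logits for one example; the threshold is tuned on dev data.
    score_null = 1.5                              # null-span score
    best_start_logit, best_end_logit = 0.9, 0.7   # best non-null span logits
    null_score_diff_threshold = 0.0

    score_diff = score_null - best_start_logit - best_end_logit
    prediction = "" if score_diff > null_score_diff_threshold else "best span"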
@@ -835,28 +834,27 @@ def write_to_json_files(json_records, json_file):
 def _get_best_indexes_and_logits(result,
                                  n_best_size,
-                                 start=False,
                                  xlnet_format=False):
   """Generates the n-best indexes and logits from a list."""
   if xlnet_format:
     for i in range(n_best_size):
       for j in range(n_best_size):
         j_index = i * n_best_size + j
-        if start:
-          yield result.start_indexes[i], result.start_logits[i]
-        else:
-          yield result.end_indexes[j_index], result.end_logits[j_index]
+        yield (result.start_indexes[i], result.start_logits[i],
+               result.end_indexes[j_index], result.end_logits[j_index])
   else:
-    if start:
-      logits = result.start_logits
-    else:
-      logits = result.end_logits
-    index_and_score = sorted(enumerate(logits),
-                             key=lambda x: x[1], reverse=True)
-    for i in range(len(index_and_score)):
-      if i >= n_best_size:
-        break
-      yield index_and_score[i]
+    start_index_and_score = sorted(enumerate(result.start_logits),
+                                   key=lambda x: x[1], reverse=True)
+    end_index_and_score = sorted(enumerate(result.end_logits),
+                                 key=lambda x: x[1], reverse=True)
+    for i in range(len(start_index_and_score)):
+      if i >= n_best_size:
+        break
+      for j in range(len(end_index_and_score)):
+        if j >= n_best_size:
+          break
+        yield (start_index_and_score[i][0], start_index_and_score[i][1],
+               end_index_and_score[j][0], end_index_and_score[j][1])
 
 
 def _compute_softmax(scores):
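
The refactored generator above yields joint (start, end) candidates rather than two separate index streams. A standalone sketch of the non-XLNet path; `FakeResult` and the logit values are hypothetical:

    import collections

    FakeResult = collections.namedtuple('FakeResult',
                                        ['start_logits', 'end_logits'])

    def best_span_candidates(result, n_best_size):
      """Mirrors the non-XLNet branch: all pairs of n-best starts and ends."""
      starts = sorted(enumerate(result.start_logits),
                      key=lambda x: x[1], reverse=True)[:n_best_size]
      ends = sorted(enumerate(result.end_logits),
                    key=lambda x: x[1], reverse=True)[:n_best_size]
      for start_index, start_logit in starts:
        for end_index, end_logit in ends:
          yield start_index, start_logit, end_index, end_logit

    result = FakeResult(start_logits=[0.1, 0.9, 0.3], end_logits=[0.2, 0.1, 0.8])
    print(list(best_span_candidates(result, n_best_size=2)))
    # 4 (= n_best_size**2) candidates, starting with (1, 0.9, 2, 0.8).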
@@ -885,13 +883,12 @@ def _compute_softmax(scores):
 class FeatureWriter(object):
   """Writes InputFeature to TF example file."""
 
-  def __init__(self, filename, is_training, xlnet_format=False):
+  def __init__(self, filename, is_training):
     self.filename = filename
     self.is_training = is_training
     self.num_features = 0
     tf.io.gfile.makedirs(os.path.dirname(filename))
     self._writer = tf.io.TFRecordWriter(filename)
-    self._xlnet_format = xlnet_format
 
   def process_feature(self, feature):
     """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
@@ -907,8 +904,9 @@ class FeatureWriter(object):
     features["input_ids"] = create_int_feature(feature.input_ids)
     features["input_mask"] = create_int_feature(feature.input_mask)
     features["segment_ids"] = create_int_feature(feature.segment_ids)
-    if self._xlnet_format:
+    if feature.paragraph_mask:
       features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
+    if feature.class_index:
       features["class_index"] = create_int_feature([feature.class_index])
 
     if self.is_training:
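
The change keys the optional fields off the feature itself rather than a writer-level flag, so one writer serves both BERT- and XLNet-style features. `create_int_feature` is defined elsewhere in this file; presumably it wraps integer lists for `tf.train.Example` along these lines (a sketch, not the file's verbatim helper):

    import tensorflow as tf

    def create_int_feature(values):
      # Wraps an iterable of ints as an int64 Feature for tf.train.Example.
      return tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(values)))

    feature = create_int_feature([1, 2, 3])
    example = tf.train.Example(
        features=tf.train.Features(feature={'input_ids': feature}))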
@@ -943,7 +941,7 @@ def generate_tf_record_from_json_file(input_file_path,
   tokenizer = tokenization.FullSentencePieceTokenizer(
       sp_model_file=sp_model_file)
   train_writer = FeatureWriter(
-      filename=output_path, is_training=True, xlnet_format=xlnet_format)
+      filename=output_path, is_training=True)
   number_of_examples = convert_examples_to_features(
       examples=train_examples,
       tokenizer=tokenizer,
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""XLNet cls-token classifier."""
+"""XLNet models."""
 # pylint: disable=g-classes-have-attributes
 from typing import Any, Mapping, Union
@@ -127,7 +127,7 @@ class XLNetSpanLabeler(tf.keras.Model):
     start_n_top: Beam size for span start.
     end_n_top: Beam size for span end.
     dropout_rate: The dropout rate for the span labeling layer.
-    span_labeling_activation
+    span_labeling_activation: The activation for the span labeling head.
     initializer: The initializer (if any) to use in the span labeling network.
       Defaults to a Glorot uniform initializer.
   """
@@ -135,9 +135,9 @@ class XLNetSpanLabeler(tf.keras.Model):
   def __init__(
       self,
       network: Union[tf.keras.layers.Layer, tf.keras.Model],
-      start_n_top: int,
-      end_n_top: int,
-      dropout_rate: float,
+      start_n_top: int = 5,
+      end_n_top: int = 5,
+      dropout_rate: float = 0.1,
       span_labeling_activation: tf.keras.initializers.Initializer = 'tanh',
       initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
       **kwargs):
@@ -165,24 +165,27 @@ class XLNetSpanLabeler(tf.keras.Model):
         initializer=self._initializer)
 
   def call(self, inputs: Mapping[str, Any]):
-    input_ids = inputs['input_ids']
-    segment_ids = inputs['segment_ids']
+    input_ids = inputs['input_word_ids']
+    segment_ids = inputs['input_type_ids']
     input_mask = inputs['input_mask']
+    class_index = inputs['class_index']
+    paragraph_mask = inputs['paragraph_mask']
+    start_positions = inputs.get('start_positions', None)
 
-    class_index = tf.reshape(inputs['class_index'], [-1])
-    position_mask = inputs['position_mask']
-    start_positions = inputs['start_positions']
-    attention_output, new_states = self._network(
+    attention_output, _ = self._network(
         input_ids=input_ids,
         segment_ids=segment_ids,
         input_mask=input_mask)
 
     outputs = self.span_labeling(
         sequence_data=attention_output,
         class_index=class_index,
-        position_mask=position_mask,
+        paragraph_mask=paragraph_mask,
         start_positions=start_positions)
-    return outputs, new_states
+    return outputs
+
+  @property
+  def checkpoint_items(self):
+    return dict(encoder=self._network)
 
   def get_config(self):
     return self._config
...
@@ -137,9 +137,9 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
 @keras_parameterized.run_all_keras_modes
 class XLNetSpanLabelerTest(keras_parameterized.TestCase):
 
-  @parameterized.parameters(1, 2)
-  def test_xlnet_trainer(self, top_n):
+  def test_xlnet_trainer(self):
     """Validate that the Keras object can be created."""
+    top_n = 2
     seq_length = 4
     # Build a simple XLNet based network to use with the XLNet trainer.
     xlnet_base = _get_xlnet_base()
@@ -153,46 +153,50 @@ class XLNetSpanLabelerTest(keras_parameterized.TestCase):
         span_labeling_activation='tanh',
         dropout_rate=0.1)
     inputs = dict(
-        input_ids=tf.keras.layers.Input(
+        input_word_ids=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
-        segment_ids=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
+        input_type_ids=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
         input_mask=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.float32, name='input_mask'),
-        position_mask=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.float32, name='position_mask'),
+        paragraph_mask=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.float32, name='paragraph_mask'),
         class_index=tf.keras.layers.Input(
             shape=(), dtype=tf.int32, name='class_index'),
         start_positions=tf.keras.layers.Input(
             shape=(), dtype=tf.int32, name='start_positions'))
-    outputs, _ = xlnet_trainer_model(inputs)
+    outputs = xlnet_trainer_model(inputs)
     self.assertIsInstance(outputs, dict)
 
     # Test tensor value calls for the created model.
     batch_size = 2
     sequence_shape = (batch_size, seq_length)
     inputs = dict(
-        input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
-        segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
+        input_word_ids=np.random.randint(
+            10, size=sequence_shape, dtype='int32'),
+        input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
         input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
-        position_mask=np.random.randint(
+        paragraph_mask=np.random.randint(
            1, size=(sequence_shape)).astype('float32'),
         class_index=np.random.randint(1, size=(batch_size)).astype('uint8'),
         start_positions=tf.random.uniform(
             shape=(batch_size,), maxval=5, dtype=tf.int32))
-    outputs, _ = xlnet_trainer_model(inputs)
-    expected_inference_keys = {
-        'start_top_log_probs', 'end_top_log_probs', 'class_logits',
-        'start_top_index', 'end_top_index',
+
+    common_keys = {
+        'start_logits', 'end_logits', 'start_predictions', 'end_predictions',
+        'class_logits',
     }
-    self.assertSetEqual(expected_inference_keys, set(outputs.keys()))
+    inference_keys = {
+        'start_top_predictions', 'end_top_predictions', 'start_top_index',
+        'end_top_index',
+    }
+    outputs = xlnet_trainer_model(inputs)
+    self.assertSetEqual(common_keys | inference_keys, set(outputs.keys()))
 
-    outputs, _ = xlnet_trainer_model(inputs, training=True)
+    outputs = xlnet_trainer_model(inputs, training=True)
     self.assertIsInstance(outputs, dict)
-    expected_train_keys = {
-        'start_log_probs', 'end_log_probs', 'class_logits'
-    }
-    self.assertSetEqual(expected_train_keys, set(outputs.keys()))
+    self.assertSetEqual(common_keys, set(outputs.keys()))
     self.assertIsInstance(outputs, dict)
 
   def test_serialize_deserialize(self):
...
@@ -18,11 +18,9 @@ import collections
 import tensorflow as tf
 
 
-def _apply_position_mask(logits, position_mask):
+def _apply_paragraph_mask(logits, paragraph_mask):
   """Applies a position mask to calculated logits."""
-  if tf.rank(logits) != tf.rank(position_mask):
-    position_mask = position_mask[:, None, :]
-  masked_logits = logits * (1 - position_mask) - 1e30 * position_mask
+  masked_logits = logits * (paragraph_mask) - 1e30 * (1 - paragraph_mask)
   return tf.nn.log_softmax(masked_logits, -1), masked_logits
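
Note the mask convention flips in this helper: `paragraph_mask` is now 1 at valid paragraph tokens and 0 at query/special tokens, where the old `position_mask` was the inverse. A small numeric sketch of the new formula (values hypothetical):

    import numpy as np

    logits = np.array([2.0, 1.0, 3.0], dtype=np.float32)
    paragraph_mask = np.array([1.0, 1.0, 0.0], dtype=np.float32)  # token 2 invalid

    masked = logits * paragraph_mask - 1e30 * (1.0 - paragraph_mask)
    # masked == [2.0, 1.0, -1e30]; log-softmax then assigns token 2
    # essentially zero probability.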
@@ -137,8 +135,8 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
   def __init__(self,
                input_width,
-               start_n_top,
-               end_n_top,
+               start_n_top=5,
+               end_n_top=5,
                activation='tanh',
                dropout_rate=0.,
                initializer='glorot_uniform',
@@ -152,6 +150,8 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
         'end_n_top': end_n_top,
         'dropout_rate': dropout_rate,
     }
+    if start_n_top <= 1:
+      raise ValueError('`start_n_top` must be greater than 1.')
     self._start_n_top = start_n_top
     self._end_n_top = end_n_top
     self.start_logits_dense = tf.keras.layers.Dense(
@@ -210,16 +210,12 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
     end_logits = self.end_logits_layer_norm(end_logits)
     end_logits = self.end_logits_output_dense(end_logits)
     end_logits = tf.squeeze(end_logits)
-    if tf.rank(end_logits) > 2:
-      # shape = [B, S, K] -> [B, K, S]
-      end_logits = tf.transpose(end_logits, [0, 2, 1])
     return end_logits
 
   def call(self,
            sequence_data,
            class_index,
-           position_mask=None,
+           paragraph_mask=None,
            start_positions=None,
            training=False):
     """Implements call().
@@ -234,31 +230,35 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
       sequence_data: The input sequence data of shape
         (batch_size, seq_length, input_width).
       class_index: The class indices of the inputs of shape (batch_size,).
-      position_mask: Invalid position mask such as query and special symbols
+      paragraph_mask: Invalid position mask such as query and special symbols
         (e.g. PAD, SEP, CLS) of shape (batch_size,).
       start_positions: The start positions of each example of shape
         (batch_size,).
       training: Whether or not this is the training phase.
 
     Returns:
-      A dictionary with the keys 'cls_logits' and
-      - (if training) 'start_log_probs', 'end_log_probs'.
-      - (if inference/beam search) 'start_top_log_probs', 'start_top_index',
-      'end_top_log_probs', 'end_top_index'.
+      A dictionary with the keys 'start_predictions', 'end_predictions',
+      'start_logits', 'end_logits'.
+
+      If inference, then 'start_top_predictions', 'start_top_index',
+      'end_top_predictions', 'end_top_index' are also included.
     """
+    paragraph_mask = tf.cast(paragraph_mask, dtype=sequence_data.dtype)
+    class_index = tf.reshape(class_index, [-1])
     seq_length = tf.shape(sequence_data)[1]
     start_logits = self.start_logits_dense(sequence_data)
     start_logits = tf.squeeze(start_logits, -1)
-    start_log_probs, masked_start_logits = _apply_position_mask(
-        start_logits, position_mask)
+    start_predictions, masked_start_logits = _apply_paragraph_mask(
+        start_logits, paragraph_mask)
 
     compute_with_beam_search = not training or start_positions is None
 
     if compute_with_beam_search:
       # Compute end logits using beam search.
-      start_top_log_probs, start_top_index = tf.nn.top_k(
-          start_log_probs, k=self._start_n_top)
+      start_top_predictions, start_top_index = tf.nn.top_k(
+          start_predictions, k=self._start_n_top)
       start_index = tf.one_hot(
           start_top_index, depth=seq_length, axis=-1, dtype=tf.float32)
       # start_index: [batch_size, end_n_top, seq_length]
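
A shape sketch of the beam-search start selection in the branch above (sizes are illustrative):

    import tensorflow as tf

    batch_size, seq_length, start_n_top = 2, 8, 5
    start_predictions = tf.random.uniform((batch_size, seq_length))

    # Pick the start_n_top most likely start positions per example.
    top_predictions, top_index = tf.nn.top_k(start_predictions, k=start_n_top)
    # One-hot over the sequence axis so start features can be gathered
    # (e.g. via einsum/matmul against the sequence data).
    start_index = tf.one_hot(
        top_index, depth=seq_length, axis=-1, dtype=tf.float32)
    print(start_index.shape)  # (2, 5, 8): [batch_size, start_n_top, seq_length]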
@@ -272,8 +272,13 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
           [1, 1, self._start_n_top, 1])
       end_input = tf.concat([end_input, start_features], axis=-1)
       # end_input: [batch_size, seq_length, end_n_top, 2*input_width]
+      paragraph_mask = paragraph_mask[:, None, :]
+      end_logits = self.end_logits(end_input)
+
+      # Note: this will fail if start_n_top is not >= 1.
+      end_logits = tf.transpose(end_logits, [0, 2, 1])
     else:
-      start_positions = tf.reshape(start_positions, -1)
+      start_positions = tf.reshape(start_positions, [-1])
       start_index = tf.one_hot(
           start_positions, depth=seq_length, axis=-1, dtype=tf.float32)
       # start_index: [batch_size, seq_length]
@@ -285,24 +290,28 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
       end_input = tf.concat([sequence_data, start_features],
                             axis=-1)
       # end_input: [batch_size, seq_length, 2*input_width]
-    end_logits = self.end_logits(end_input)
-    end_log_probs, _ = _apply_position_mask(end_logits, position_mask)
+      end_logits = self.end_logits(end_input)
+
+    end_predictions, masked_end_logits = _apply_paragraph_mask(
+        end_logits, paragraph_mask)
 
-    output_dict = {}
-    if training:
-      output_dict['start_log_probs'] = start_log_probs
-      output_dict['end_log_probs'] = end_log_probs
-    else:
-      end_top_log_probs, end_top_index = tf.nn.top_k(
-          end_log_probs, k=self._end_n_top)
-      end_top_log_probs = tf.reshape(end_top_log_probs,
-                                     [-1, self._start_n_top * self._end_n_top])
-      end_top_index = tf.reshape(end_top_index,
-                                 [-1, self._start_n_top * self._end_n_top])
-      output_dict['start_top_log_probs'] = start_top_log_probs
+    output_dict = dict(
+        start_predictions=start_predictions,
+        end_predictions=end_predictions,
+        start_logits=masked_start_logits,
+        end_logits=masked_end_logits)
+
+    if not training:
+      end_top_predictions, end_top_index = tf.nn.top_k(
+          end_predictions, k=self._end_n_top)
+      end_top_predictions = tf.reshape(
+          end_top_predictions,
+          [-1, self._start_n_top * self._end_n_top])
+      end_top_index = tf.reshape(
+          end_top_index,
+          [-1, self._start_n_top * self._end_n_top])
+      output_dict['start_top_predictions'] = start_top_predictions
       output_dict['start_top_index'] = start_top_index
-      output_dict['end_top_log_probs'] = end_top_log_probs
+      output_dict['end_top_predictions'] = end_top_predictions
       output_dict['end_top_index'] = end_top_index
 
     # get the representation of CLS
...
@@ -13,13 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for span_labeling network."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
 import numpy as np
 import tensorflow as tf
@@ -181,39 +174,38 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
     hidden_size = 4
     sequence_data = np.random.uniform(
         size=(batch_size, seq_length, hidden_size)).astype('float32')
-    position_mask = np.random.uniform(
+    paragraph_mask = np.random.uniform(
         size=(batch_size, seq_length)).astype('float32')
     class_index = np.random.uniform(size=(batch_size)).astype('uint8')
     start_positions = np.zeros(shape=(batch_size)).astype('uint8')
     layer = span_labeling.XLNetSpanLabeling(
         input_width=hidden_size,
-        start_n_top=1,
-        end_n_top=1,
+        start_n_top=2,
+        end_n_top=2,
         activation='tanh',
         dropout_rate=0.,
         initializer='glorot_uniform')
     output = layer(sequence_data=sequence_data,
                    class_index=class_index,
-                   position_mask=position_mask,
+                   paragraph_mask=paragraph_mask,
                    start_positions=start_positions,
                    training=True)
     expected_keys = {
-        'start_log_probs', 'end_log_probs', 'class_logits',
+        'start_logits', 'end_logits', 'class_logits', 'start_predictions',
+        'end_predictions',
     }
     self.assertSetEqual(expected_keys, set(output.keys()))
 
-  @parameterized.named_parameters(
-      ('top_1', 1),
-      ('top_n', 5))
-  def test_basic_invocation_beam_search(self, top_n):
+  def test_basic_invocation_beam_search(self):
     batch_size = 2
     seq_length = 8
     hidden_size = 4
+    top_n = 5
     sequence_data = np.random.uniform(
         size=(batch_size, seq_length, hidden_size)).astype('float32')
-    position_mask = np.random.uniform(
+    paragraph_mask = np.random.uniform(
         size=(batch_size, seq_length)).astype('float32')
     class_index = np.random.uniform(size=(batch_size)).astype('uint8')
@@ -226,11 +218,12 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
         initializer='glorot_uniform')
     output = layer(sequence_data=sequence_data,
                    class_index=class_index,
-                   position_mask=position_mask,
+                   paragraph_mask=paragraph_mask,
                    training=False)
     expected_keys = {
-        'start_top_log_probs', 'end_top_log_probs', 'class_logits',
-        'start_top_index', 'end_top_index',
+        'start_top_predictions', 'end_top_predictions', 'class_logits',
+        'start_top_index', 'end_top_index', 'start_logits',
+        'end_logits', 'start_predictions', 'end_predictions'
     }
     self.assertSetEqual(expected_keys, set(output.keys()))
@@ -243,7 +236,7 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
     sequence_data = tf.keras.Input(shape=(seq_length, hidden_size),
                                    dtype=tf.float32)
     class_index = tf.keras.Input(shape=(), dtype=tf.uint8)
-    position_mask = tf.keras.Input(shape=(seq_length), dtype=tf.float32)
+    paragraph_mask = tf.keras.Input(shape=(seq_length), dtype=tf.float32)
     start_positions = tf.keras.Input(shape=(), dtype=tf.int32)
 
     layer = span_labeling.XLNetSpanLabeling(
@@ -256,27 +249,27 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
     output = layer(sequence_data=sequence_data,
                    class_index=class_index,
-                   position_mask=position_mask,
+                   paragraph_mask=paragraph_mask,
                    start_positions=start_positions)
     model = tf.keras.Model(
         inputs={
             'sequence_data': sequence_data,
             'class_index': class_index,
-            'position_mask': position_mask,
+            'paragraph_mask': paragraph_mask,
             'start_positions': start_positions,
         },
         outputs=output)
 
     sequence_data = tf.random.uniform(
         shape=(batch_size, seq_length, hidden_size), dtype=tf.float32)
-    position_mask = tf.random.uniform(
+    paragraph_mask = tf.random.uniform(
         shape=(batch_size, seq_length), dtype=tf.float32)
     class_index = tf.ones(shape=(batch_size,), dtype=tf.uint8)
     start_positions = tf.random.uniform(
         shape=(batch_size,), maxval=5, dtype=tf.int32)
     inputs = dict(sequence_data=sequence_data,
-                  position_mask=position_mask,
+                  paragraph_mask=paragraph_mask,
                   class_index=class_index,
                   start_positions=start_positions)
...
@@ -629,6 +629,7 @@ class XLNetBase(tf.keras.layers.Layer):
                          "enabled. Please enable `two_stream` to enable two "
                          "stream attention.")
 
+    dtype = input_mask.dtype if input_mask is not None else tf.float32
     query_attention_mask, content_attention_mask = _compute_attention_mask(
         input_mask=input_mask,
         permutation_mask=permutation_mask,
@@ -636,7 +637,7 @@ class XLNetBase(tf.keras.layers.Layer):
         seq_length=seq_length,
         memory_length=memory_length,
         batch_size=batch_size,
-        dtype=tf.float32)
+        dtype=dtype)
     relative_position_encoding = _compute_positional_encoding(
         attention_type=self._attention_type,
         position_encoding_layer=self.position_encoding,
...
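
The `XLNetBase` change above lets the attention mask inherit its dtype from `input_mask` instead of hard-coding float32, presumably so the mask matches the compute dtype (e.g. float16 under mixed precision; the motivation is an assumption). A minimal sketch of the selection logic:

    import tensorflow as tf

    input_mask = tf.ones((2, 8), dtype=tf.float16)
    dtype = input_mask.dtype if input_mask is not None else tf.float32
    assert dtype == tf.float16  # the mask now follows the input's dtype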
@@ -14,9 +14,9 @@
 # limitations under the License.
 # ==============================================================================
 """Question answering task."""
-import collections
 import json
 import os
+from typing import List, Optional
 
 from absl import logging
 import dataclasses
@@ -58,6 +58,17 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
   validation_data: cfg.DataConfig = cfg.DataConfig()
 
 
+@dataclasses.dataclass
+class RawAggregatedResult:
+  """Raw representation for SQuAD predictions."""
+  unique_id: int
+  start_logits: List[float]
+  end_logits: List[float]
+  start_indexes: Optional[List[int]] = None
+  end_indexes: Optional[List[int]] = None
+  class_logits: Optional[float] = None
+
+
 @task_factory.register_task_cls(QuestionAnsweringConfig)
 class QuestionAnsweringTask(base_task.Task):
   """Task object for question answering."""
@@ -91,7 +102,6 @@ class QuestionAnsweringTask(base_task.Task):
     else:
       encoder_network = encoders.build_encoder(self.task_config.model.encoder)
     encoder_cfg = self.task_config.model.encoder.get()
-    # Currently, we only supports bert-style question answering finetuning.
     return models.BertSpanLabeler(
         network=encoder_network,
         initializer=tf.keras.initializers.TruncatedNormal(
@@ -147,6 +157,7 @@ class QuestionAnsweringTask(base_task.Task):
       kwargs['do_lower_case'] = params.do_lower_case
       kwargs['tokenizer'] = tokenization.FullSentencePieceTokenizer(
           sp_model_file=params.vocab_file)
+      kwargs['xlnet_format'] = self.task_config.model.encoder.type == 'xlnet'
     elif params.tokenization == 'WordPiece':
       kwargs['tokenizer'] = tokenization.FullTokenizer(
           vocab_file=params.vocab_file, do_lower_case=params.do_lower_case)
@@ -176,7 +187,8 @@ class QuestionAnsweringTask(base_task.Task):
           input_type_ids=dummy_ids)
       y = dict(
           start_positions=tf.constant(0, dtype=tf.int32),
-          end_positions=tf.constant(1, dtype=tf.int32))
+          end_positions=tf.constant(1, dtype=tf.int32),
+          is_impossible=tf.constant(0, dtype=tf.int32))
       return (x, y)
 
     dataset = tf.data.Dataset.range(1)
@@ -235,25 +247,22 @@ class QuestionAnsweringTask(base_task.Task):
     }
     return logs
 
-  raw_aggregated_result = collections.namedtuple(
-      'RawResult', ['unique_id', 'start_logits', 'end_logits'])
-
   def aggregate_logs(self, state=None, step_outputs=None):
     assert step_outputs is not None, 'Got no logs from self.validation_step.'
     if state is None:
       state = []
 
-    for unique_ids, start_logits, end_logits in zip(
-        step_outputs['unique_ids'], step_outputs['start_logits'],
-        step_outputs['end_logits']):
-      u_ids, s_logits, e_logits = (unique_ids.numpy(), start_logits.numpy(),
-                                   end_logits.numpy())
-      for values in zip(u_ids, s_logits, e_logits):
-        state.append(
-            self.raw_aggregated_result(
-                unique_id=values[0],
-                start_logits=values[1].tolist(),
-                end_logits=values[2].tolist()))
+    for outputs in zip(step_outputs['unique_ids'],
+                       step_outputs['start_logits'],
+                       step_outputs['end_logits']):
+      numpy_values = [
+          output.numpy() for output in outputs if output is not None]
+
+      for values in zip(*numpy_values):
+        state.append(RawAggregatedResult(
+            unique_id=values[0],
+            start_logits=values[1],
+            end_logits=values[2]))
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs):
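
The aggregation relies on `zip(*...)` to transpose batched step outputs into per-example tuples; a standalone sketch with dummy arrays:

    import numpy as np

    # Each step output is batched along axis 0.
    unique_ids = np.array([0, 1])
    start_logits = np.array([[0.1, 0.9], [0.4, 0.6]])
    end_logits = np.array([[0.2, 0.8], [0.3, 0.7]])

    numpy_values = [unique_ids, start_logits, end_logits]
    for unique_id, s_logits, e_logits in zip(*numpy_values):
      print(unique_id, s_logits.tolist(), e_logits.tolist())
    # 0 [0.1, 0.9] [0.2, 0.8]
    # 1 [0.4, 0.6] [0.3, 0.7]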
@@ -299,6 +308,127 @@ class QuestionAnsweringTask(base_task.Task):
     return eval_metrics
 
 
+@dataclasses.dataclass
+class XLNetQuestionAnsweringConfig(QuestionAnsweringConfig):
+  """The config for the XLNet variation of QuestionAnswering."""
+  pass
+
+
+@task_factory.register_task_cls(XLNetQuestionAnsweringConfig)
+class XLNetQuestionAnsweringTask(QuestionAnsweringTask):
+  """XLNet variant of the Question Answering Task.
+
+  The main differences include:
+    - The encoder is an `XLNetBase` class.
+    - The `SpanLabeling` head is an instance of `XLNetSpanLabeling` which
+      predicts start/end positions and impossibility score. During inference,
+      it predicts the top N scores and indexes.
+  """
+
+  def build_model(self):
+    if self.task_config.hub_module_url and self.task_config.init_checkpoint:
+      raise ValueError('At most one of `hub_module_url` and '
+                       '`init_checkpoint` can be specified.')
+    if self.task_config.hub_module_url:
+      encoder_network = utils.get_encoder_from_hub(
+          self.task_config.hub_module_url)
+    else:
+      encoder_network = encoders.build_encoder(self.task_config.model.encoder)
+    encoder_cfg = self.task_config.model.encoder.get()
+    return models.XLNetSpanLabeler(
+        network=encoder_network,
+        start_n_top=self.task_config.n_best_size,
+        end_n_top=self.task_config.n_best_size,
+        initializer=tf.keras.initializers.RandomNormal(
+            stddev=encoder_cfg.initializer_range))
+
+  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+    start_positions = labels['start_positions']
+    end_positions = labels['end_positions']
+    is_impossible = labels['is_impossible']
+    is_impossible = tf.cast(tf.reshape(is_impossible, [-1]), tf.float32)
+
+    start_logits = model_outputs['start_logits']
+    end_logits = model_outputs['end_logits']
+    class_logits = model_outputs['class_logits']
+
+    start_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        start_positions, start_logits)
+    end_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        end_positions, end_logits)
+    is_impossible_loss = tf.keras.losses.binary_crossentropy(
+        is_impossible, class_logits, from_logits=True)
+
+    loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
+    loss += tf.reduce_mean(is_impossible_loss) / 2
+    return loss
+
+  def process_metrics(self, metrics, labels, model_outputs):
+    metrics = dict([(metric.name, metric) for metric in metrics])
+    start_logits = model_outputs['start_logits']
+    end_logits = model_outputs['end_logits']
+    metrics['start_position_accuracy'].update_state(labels['start_positions'],
+                                                    start_logits)
+    metrics['end_position_accuracy'].update_state(labels['end_positions'],
+                                                  end_logits)
+
+  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
+    start_logits = model_outputs['start_logits']
+    end_logits = model_outputs['end_logits']
+    compiled_metrics.update_state(
+        y_true=labels,  # labels has keys 'start_positions' and 'end_positions'.
+        y_pred={
+            'start_positions': start_logits,
+            'end_positions': end_logits,
+        })
+
+  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
+    features, _ = inputs
+    unique_ids = features.pop('unique_ids')
+    model_outputs = self.inference_step(features, model)
+    start_top_predictions = model_outputs['start_top_predictions']
+    end_top_predictions = model_outputs['end_top_predictions']
+    start_indexes = model_outputs['start_top_index']
+    end_indexes = model_outputs['end_top_index']
+    class_logits = model_outputs['class_logits']
+    logs = {
+        self.loss: 0.0,  # TODO(lehou): compute the real validation loss.
+        'unique_ids': unique_ids,
+        'start_top_predictions': start_top_predictions,
+        'end_top_predictions': end_top_predictions,
+        'start_indexes': start_indexes,
+        'end_indexes': end_indexes,
+        'class_logits': class_logits,
+    }
+    return logs
+
+  def aggregate_logs(self, state=None, step_outputs=None):
+    assert step_outputs is not None, 'Got no logs from self.validation_step.'
+    if state is None:
+      state = []
+
+    for outputs in zip(step_outputs['unique_ids'],
+                       step_outputs['start_top_predictions'],
+                       step_outputs['end_top_predictions'],
+                       step_outputs['start_indexes'],
+                       step_outputs['end_indexes'],
+                       step_outputs['class_logits']):
+      numpy_values = [output.numpy() for output in outputs]
+
+      for (unique_id, start_top_predictions, end_top_predictions,
+           start_indexes, end_indexes, class_logits) in zip(*numpy_values):
+        state.append(RawAggregatedResult(
+            unique_id=unique_id,
+            start_logits=start_top_predictions.tolist(),
+            end_logits=end_top_predictions.tolist(),
+            start_indexes=start_indexes.tolist(),
+            end_indexes=end_indexes.tolist(),
+            class_logits=class_logits))
+    return state
+
+
 def predict(task: QuestionAnsweringTask, params: cfg.DataConfig,
             model: tf.keras.Model):
   """Predicts on the input data.
...
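
For reference, the combined XLNet QA loss introduced above weights the averaged span terms and the answerability term equally; a standalone sketch with dummy tensors (batch and sequence sizes are assumptions):

    import tensorflow as tf

    batch_size, seq_length = 2, 8
    start_positions = tf.constant([1, 3])
    end_positions = tf.constant([2, 5])
    is_impossible = tf.constant([0.0, 1.0])
    start_logits = tf.random.uniform((batch_size, seq_length))
    end_logits = tf.random.uniform((batch_size, seq_length))
    class_logits = tf.random.uniform((batch_size,))

    start_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=start_positions, logits=start_logits)
    end_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=end_positions, logits=end_logits)
    class_loss = tf.keras.losses.binary_crossentropy(
        is_impossible, class_logits, from_logits=True)

    # Span terms average to a weight of 1/2; answerability adds the other 1/2.
    loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
    loss += tf.reduce_mean(class_loss) / 2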