Commit 5a2cf36f authored by Kaushik Shivakumar

Merge remote-tracking branch 'upstream/master' into newavarecords

parents 258ddfc3 a829e648
@@ -88,7 +88,6 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32

-# TODO(hongkuny): consider moving custom string-map lookup to keras api.
 def get_activation(identifier):
   """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
...
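For reference, `get_activation` resolves string identifiers to activation callables, as the docstring above describes. A minimal usage sketch (illustrative only; the exact string map lives in this module):

  import tensorflow as tf
  from official.modeling import tf_utils

  # "relu" resolves to a callable such as tf.nn.relu, per the docstring above.
  relu_fn = tf_utils.get_activation("relu")
  print(relu_fn(tf.constant([-1.0, 2.0])))  # [0. 2.]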
...
@@ -14,23 +14,61 @@
 # ==============================================================================
 """ALBERT classification finetuning runner in tf2.x."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import json
+import os

 from absl import app
 from absl import flags
+from absl import logging
 import tensorflow as tf

 from official.nlp.albert import configs as albert_configs
+from official.nlp.bert import bert_models
 from official.nlp.bert import run_classifier as run_classifier_bert
 from official.utils.misc import distribution_utils

 FLAGS = flags.FLAGS

+def predict(strategy, albert_config, input_meta_data, predict_input_fn):
+  """Writes both the predictions and the ground-truth labels as .tsv files."""
+  with strategy.scope():
+    classifier_model = bert_models.classifier_model(
+        albert_config, input_meta_data['num_labels'])[0]
+    checkpoint = tf.train.Checkpoint(model=classifier_model)
+    latest_checkpoint_file = (
+        FLAGS.predict_checkpoint_path or
+        tf.train.latest_checkpoint(FLAGS.model_dir))
+    assert latest_checkpoint_file
+    logging.info('Checkpoint file %s found and restoring from '
+                 'checkpoint', latest_checkpoint_file)
+    checkpoint.restore(
+        latest_checkpoint_file).assert_existing_objects_matched()
+    preds, ground_truth = run_classifier_bert.get_predictions_and_labels(
+        strategy, classifier_model, predict_input_fn, return_probs=True)
+  output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
+  with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
+    logging.info('***** Predict results *****')
+    for probabilities in preds:
+      output_line = '\t'.join(
+          str(class_probability)
+          for class_probability in probabilities) + '\n'
+      writer.write(output_line)
+  ground_truth_labels_file = os.path.join(FLAGS.model_dir,
+                                          'output_labels.tsv')
+  with tf.io.gfile.GFile(ground_truth_labels_file, 'w') as writer:
+    logging.info('***** Ground truth results *****')
+    for label in ground_truth:
+      # Write one label per line. (Tab-joining str(label) would split a
+      # multi-digit label into tab-separated characters.)
+      output_line = str(label) + '\n'
+      writer.write(output_line)
+  return

 def main(_):
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
@@ -56,9 +94,14 @@ def main(_):
   albert_config = albert_configs.AlbertConfig.from_json_file(
       FLAGS.bert_config_file)
-  run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
-                               train_input_fn, eval_input_fn)
+  if FLAGS.mode == 'train_and_eval':
+    run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
+                                 train_input_fn, eval_input_fn)
+  elif FLAGS.mode == 'predict':
+    predict(strategy, albert_config, input_meta_data, eval_input_fn)
+  else:
+    raise ValueError('Unsupported mode: %s' % FLAGS.mode)
+  return

 if __name__ == '__main__':
   flags.mark_flag_as_required('bert_config_file')
...
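The new `predict()` above writes one tab-separated row of class probabilities per example to `test_results.tsv`, and one ground-truth label per line to `output_labels.tsv`. A self-contained sketch of that output format (the values are made up):

  # Each row: one probability per class, tab-separated, one example per line.
  probabilities = [[0.91, 0.09], [0.20, 0.80]]
  with open("/tmp/test_results.tsv", "w") as writer:
      for probs in probabilities:
          writer.write("\t".join(str(p) for p in probs) + "\n")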
@@ -79,7 +79,7 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
       do_lower_case, vocab_file)
   core_model, encoder = create_bert_model(bert_config)
   checkpoint = tf.train.Checkpoint(model=encoder)
-  checkpoint.restore(model_checkpoint_path).assert_consumed()
+  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
   core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
   core_model.save(hub_destination, include_optimizer=False, save_format="tf")
...
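The switch from `assert_consumed()` to `assert_existing_objects_matched()` relaxes the restore check: the former also fails when the checkpoint contains values nothing restored (e.g. optimizer slots), while the latter only requires every object in the restoring model to find a match. A minimal sketch of the difference:

  import tensorflow as tf

  model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(4,))])
  path = tf.train.Checkpoint(model=model).save("/tmp/ckpt_demo")

  clone = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(4,))])
  status = tf.train.Checkpoint(model=clone).restore(path)
  status.assert_existing_objects_matched()  # the check used in this commit
  # status.assert_consumed() would additionally fail if the checkpoint held
  # values (such as optimizer slots) that nothing in `clone` consumed.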
@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text,
     raise ValueError('model must be a tf.keras.Model object.')

   if checkpoint_dir:
-    # Keras compile/fit() was used to save checkpoint using
-    # model.save_weights().
     if restore_model_using_load_weights:
       model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
       assert tf.io.gfile.exists(model_weight_path)
       model.load_weights(model_weight_path)
-    # tf.train.Checkpoint API was used via custom training loop logic.
     else:
       checkpoint = tf.train.Checkpoint(model=model)
...
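For context, the two branches in `export_bert_model` correspond to the two ways the checkpoint may have been written, as the deleted comments explained. A sketch of the distinction (the helper name here is hypothetical):

  import os
  import tensorflow as tf

  def restore_for_export(model, checkpoint_dir, restore_model_using_load_weights):
    if restore_model_using_load_weights:
      # Keras compile/fit() path: checkpoint written via model.save_weights().
      model.load_weights(os.path.join(checkpoint_dir, 'checkpoint'))
    else:
      # Custom-training-loop path: checkpoint written via tf.train.Checkpoint.
      checkpoint = tf.train.Checkpoint(model=model)
      checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))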
@@ -99,7 +99,9 @@ def write_txt_summary(training_summary, summary_dir):
 @deprecation.deprecated(
-    None, 'This function is deprecated. Please use Keras compile/fit instead.')
+    None, 'This function is deprecated and no new functionality is '
+    'expected to be added. Please do not make new code depend '
+    'on this library.')
 def run_customized_training_loop(
     # pylint: disable=invalid-name
     _sentinel=None,
@@ -557,7 +559,6 @@ def run_customized_training_loop(
       for metric in model.metrics:
         training_summary[metric.name] = _float_metric_value(metric)
       if eval_metrics:
-        # TODO(hongkuny): Cleans up summary reporting in text.
         training_summary['last_train_metrics'] = _float_metric_value(
             train_metrics[0])
         training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
...
@@ -343,7 +343,10 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
   # Export uses float32 for now, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   classifier_model = bert_models.classifier_model(
-      bert_config, input_meta_data.get('num_labels', 1))[0]
+      bert_config,
+      input_meta_data.get('num_labels', 1),
+      hub_module_url=FLAGS.hub_module_url,
+      hub_module_trainable=False)[0]

   model_saving_utils.export_bert_model(
       model_export_path, model=classifier_model, checkpoint_dir=model_dir)
...
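A hedged sketch of how the patched export path builds the classifier from a TF-Hub module; the URL is a placeholder, and `classifier_model()` returns a (classifier, encoder) pair, hence the `[0]`:

  from official.nlp.bert import bert_models

  # bert_config and input_meta_data are assumed to be in scope, as above.
  classifier_model = bert_models.classifier_model(
      bert_config,
      input_meta_data.get('num_labels', 1),
      hub_module_url='https://tfhub.dev/<placeholder>',  # hypothetical URL
      hub_module_trainable=False)[0]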
@@ -24,7 +24,6 @@ import tensorflow as tf
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
-from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import encoders
 from official.nlp.modeling import layers
 from official.nlp.modeling.models import bert_pretrainer
@@ -43,7 +42,6 @@ class ClsHeadConfig(base_config.Config):
 @dataclasses.dataclass
 class BertPretrainerConfig(base_config.Config):
   """BERT pretrainer configuration."""
-  num_masked_tokens: int = 76
   encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
@@ -56,45 +54,18 @@ def instantiate_classification_heads_from_cfgs(
 ] if cls_head_configs else []

-def instantiate_bertpretrainer_from_cfg(
+def instantiate_pretrainer_from_cfg(
     config: BertPretrainerConfig,
     encoder_network: Optional[tf.keras.Model] = None
 ) -> bert_pretrainer.BertPretrainerV2:
   """Instantiates a BertPretrainer from the config."""
   encoder_cfg = config.encoder
   if encoder_network is None:
     encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
   return bert_pretrainer.BertPretrainerV2(
-      config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=encoder_cfg.initializer_range),
       encoder_network=encoder_network,
       classification_heads=instantiate_classification_heads_from_cfgs(
           config.cls_heads))
-
-@dataclasses.dataclass
-class QADataConfig(cfg.DataConfig):
-  """Data config for question answering task (tasks/question_answering)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-@dataclasses.dataclass
-class QADevDataConfig(cfg.DataConfig):
-  """Dev Data config for queston answering (tasks/question_answering)."""
-  input_path: str = ""
-  input_preprocessed_data_path: str = ""
-  version_2_with_negative: bool = False
-  doc_stride: int = 128
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  query_length: int = 64
-  drop_remainder: bool = False
-  vocab_file: str = ""
-  tokenization: str = "WordPiece"  # WordPiece or SentencePiece
-  do_lower_case: bool = True
@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
   def test_network_invocation(self):
     config = bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

     # Invokes with classification heads.
     config = bert.BertPretrainerConfig(
@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=2, name="next_sentence")
     ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

     with self.assertRaises(ValueError):
       config = bert.BertPretrainerConfig(
@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=2, name="next_sentence")
     ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

   def test_checkpoint_items(self):
     config = bert.BertPretrainerConfig(
@@ -56,9 +56,10 @@ class BertModelsTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=2, name="next_sentence")
     ])
-    encoder = bert.instantiate_bertpretrainer_from_cfg(config)
-    self.assertSameElements(encoder.checkpoint_items.keys(),
-                            ["encoder", "next_sentence.pooler_dense"])
+    encoder = bert.instantiate_pretrainer_from_cfg(config)
+    self.assertSameElements(
+        encoder.checkpoint_items.keys(),
+        ["encoder", "masked_lm", "next_sentence.pooler_dense"])

 if __name__ == "__main__":
...
@@ -34,6 +34,8 @@ class ELECTRAPretrainerConfig(base_config.Config):
   sequence_length: int = 512
   num_classes: int = 2
   discriminator_loss_weight: float = 50.0
+  tie_embeddings: bool = True
+  disallow_correct: bool = False
   generator_encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   discriminator_encoder: encoders.TransformerEncoderConfig = (
@@ -60,23 +62,30 @@ def instantiate_pretrainer_from_cfg(
   """Instantiates ElectraPretrainer from the config."""
   generator_encoder_cfg = config.generator_encoder
   discriminator_encoder_cfg = config.discriminator_encoder
-  if generator_network is None:
-    generator_network = encoders.instantiate_encoder_from_cfg(
-        generator_encoder_cfg)
+  # Copy discriminator's embeddings to generator for easier model serialization.
   if discriminator_network is None:
     discriminator_network = encoders.instantiate_encoder_from_cfg(
         discriminator_encoder_cfg)
+  if generator_network is None:
+    if config.tie_embeddings:
+      embedding_layer = discriminator_network.get_embedding_layer()
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg, embedding_layer=embedding_layer)
+    else:
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg)

   return electra_pretrainer.ElectraPretrainer(
       generator_network=generator_network,
       discriminator_network=discriminator_network,
       vocab_size=config.generator_encoder.vocab_size,
       num_classes=config.num_classes,
       sequence_length=config.sequence_length,
-      last_hidden_dim=config.generator_encoder.hidden_size,
       num_token_predictions=config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(
           generator_encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=generator_encoder_cfg.initializer_range),
       classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads))
+          config.cls_heads),
+      disallow_correct=config.disallow_correct)
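With `tie_embeddings=True` (the new default), the generator is built on the discriminator's embedding layer, so the two networks share one embedding table; `disallow_correct` is forwarded straight through to `ElectraPretrainer`. A configuration sketch, assuming this module's names (the field values are illustrative):

  config = ELECTRAPretrainerConfig(
      num_masked_tokens=76,      # existing field; value illustrative
      sequence_length=128,
      tie_embeddings=True,       # generator reuses discriminator embeddings
      disallow_correct=False)    # new flag, passed through above
  pretrainer = instantiate_pretrainer_from_cfg(config)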
@@ -17,12 +17,13 @@
 Includes configurations and instantiation methods.
 """
+from typing import Optional

 import dataclasses
 import tensorflow as tf

 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
@@ -40,12 +41,47 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+  embedding_size: Optional[int] = None

 def instantiate_encoder_from_cfg(
-    config: TransformerEncoderConfig) -> networks.TransformerEncoder:
+    config: TransformerEncoderConfig,
+    encoder_cls=networks.TransformerEncoder,
+    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
-  encoder_network = networks.TransformerEncoder(
+  if encoder_cls.__name__ == "EncoderScaffold":
+    embedding_cfg = dict(
+        vocab_size=config.vocab_size,
+        type_vocab_size=config.type_vocab_size,
+        hidden_size=config.hidden_size,
+        seq_length=None,
+        max_seq_length=config.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range),
+        dropout_rate=config.dropout_rate,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=config.num_attention_heads,
+        intermediate_size=config.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(
+            config.hidden_activation),
+        dropout_rate=config.dropout_rate,
+        attention_dropout_rate=config.attention_dropout_rate,
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range),
+    )
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=config.num_layers,
+        pooled_output_dim=config.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range))
+    return encoder_cls(**kwargs)
+
+  if encoder_cls.__name__ != "TransformerEncoder":
+    raise ValueError("Unknown encoder network class: %s" % str(encoder_cls))
+
+  encoder_network = encoder_cls(
       vocab_size=config.vocab_size,
       hidden_size=config.hidden_size,
       num_layers=config.num_layers,
@@ -58,5 +94,7 @@ def instantiate_encoder_from_cfg(
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range))
+          stddev=config.initializer_range),
+      embedding_width=config.embedding_size,
+      embedding_layer=embedding_layer)
   return encoder_network
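The factory now dispatches on the encoder class's `__name__`, so callers can request either the canonical `TransformerEncoder` or an `EncoderScaffold` built from the same config. A sketch, assuming the names exported by this module and by `official.nlp.modeling.networks`:

  from official.nlp.configs import encoders
  from official.nlp.modeling import networks

  cfg = encoders.TransformerEncoderConfig(vocab_size=100, num_layers=2)

  # Default path: builds a networks.TransformerEncoder (embedding_size and
  # embedding_layer may both be left as None).
  encoder = encoders.instantiate_encoder_from_cfg(cfg)

  # Scaffold path: selected purely by the __name__ check above.
  scaffold = encoders.instantiate_encoder_from_cfg(
      cfg, encoder_cls=networks.EncoderScaffold)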
@@ -31,7 +31,7 @@ from official.nlp.bert import tokenization
 class InputExample(object):
-  """A single training/test example for simple sequence classification."""
+  """A single training/test example for simple seq regression/classification."""

   def __init__(self,
                guid,
@@ -48,8 +49,9 @@ class InputExample(object):
         sequence tasks, only this sequence must be specified.
       text_b: (Optional) string. The untokenized text of the second sequence.
         Must be specified only for sequence pair tasks.
-      label: (Optional) string. The label of the example. This should be
-        specified for train and dev examples, but not for test examples.
+      label: (Optional) string for classification, float for regression. The
+        label of the example. This should be specified for train and dev
+        examples, but not for test examples.
       weight: (Optional) float. The weight of the example to be used during
         training.
       int_iden: (Optional) int. The int identification number of example in the
@@ -84,10 +85,12 @@ class InputFeatures(object):
 class DataProcessor(object):
-  """Base class for data converters for sequence classification data sets."""
+  """Base class for converters for seq regression/classification datasets."""

   def __init__(self, process_text_fn=tokenization.convert_to_unicode):
     self.process_text_fn = process_text_fn
+    self.is_regression = False
+    self.label_type = None

   def get_train_examples(self, data_dir):
     """Gets a collection of `InputExample`s for the train set."""
@@ -121,143 +124,158 @@ class DataProcessor(object):
     return lines

-class XnliProcessor(DataProcessor):
-  """Processor for the XNLI data set."""
-  supported_languages = [
-      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
-      "ur", "vi", "zh"
-  ]
-
-  def __init__(self,
-               language="en",
-               process_text_fn=tokenization.convert_to_unicode):
-    super(XnliProcessor, self).__init__(process_text_fn)
-    if language == "all":
-      self.languages = XnliProcessor.supported_languages
-    elif language not in XnliProcessor.supported_languages:
-      raise ValueError("language %s is not supported for XNLI task." % language)
-    else:
-      self.languages = [language]
+class ColaProcessor(DataProcessor):
+  """Processor for the CoLA data set (GLUE version)."""

   def get_train_examples(self, data_dir):
     """See base class."""
-    lines = []
-    for language in self.languages:
-      # Skips the header.
-      lines.extend(
-          self._read_tsv(
-              os.path.join(data_dir, "multinli",
-                           "multinli.train.%s.tsv" % language))[1:])
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      if label == self.process_text_fn("contradictory"):
-        label = self.process_text_fn("contradiction")
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

   def get_dev_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[6])
-      text_b = self.process_text_fn(line[7])
-      label = self.process_text_fn(line[1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

   def get_test_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
-    examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "test-%d" % i
-      language = self.process_text_fn(line[0])
-      text_a = self.process_text_fn(line[6])
-      text_b = self.process_text_fn(line[7])
-      label = self.process_text_fn(line[1])
-      examples_by_lang[language].append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

   def get_labels(self):
     """See base class."""
-    return ["contradiction", "entailment", "neutral"]
+    return ["0", "1"]

   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "XNLI"
+    return "COLA"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    for i, line in enumerate(lines):
+      # Only the test set has a header.
+      if set_type == "test" and i == 0:
+        continue
+      guid = "%s-%s" % (set_type, i)
+      if set_type == "test":
+        text_a = self.process_text_fn(line[1])
+        label = "0"
+      else:
+        text_a = self.process_text_fn(line[3])
+        label = self.process_text_fn(line[1])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+    return examples

-class XtremeXnliProcessor(DataProcessor):
-  """Processor for the XTREME XNLI data set."""
-  supported_languages = [
-      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
-      "ur", "vi", "zh"
-  ]
+class MnliProcessor(DataProcessor):
+  """Processor for the MultiNLI data set (GLUE version)."""
+
+  def __init__(self,
+               mnli_type="matched",
+               process_text_fn=tokenization.convert_to_unicode):
+    super(MnliProcessor, self).__init__(process_text_fn)
+    if mnli_type not in ("matched", "mismatched"):
+      raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
+    self.mnli_type = mnli_type

   def get_train_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
+          "dev_matched")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
+          "dev_mismatched")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["contradiction", "entailment", "neutral"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "MNLI"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    for i, line in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
+      text_a = self.process_text_fn(line[8])
+      text_b = self.process_text_fn(line[9])
+      if set_type == "test":
+        label = "contradiction"
+      else:
+        label = self.process_text_fn(line[-1])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+class MrpcProcessor(DataProcessor):
+  """Processor for the MRPC data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

   def get_dev_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

   def get_test_examples(self, data_dir):
     """See base class."""
-    examples_by_lang = {k: [] for k in self.supported_languages}
-    for lang in self.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
-        guid = f"test-{i}"
-        text_a = self.process_text_fn(line[0])
-        text_b = self.process_text_fn(line[1])
-        label = "contradiction"
-        examples_by_lang[lang].append(
-            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

   def get_labels(self):
     """See base class."""
-    return ["contradiction", "entailment", "neutral"]
+    return ["0", "1"]

   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "XTREME-XNLI"
+    return "MRPC"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    for i, line in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "%s-%s" % (set_type, i)
+      text_a = self.process_text_fn(line[3])
+      text_b = self.process_text_fn(line[4])
+      if set_type == "test":
+        label = "0"
+      else:
+        label = self.process_text_fn(line[0])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples

 class PawsxProcessor(DataProcessor):
@@ -289,7 +307,7 @@ class PawsxProcessor(DataProcessor):
           self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "train-%d" % i
       text_a = self.process_text_fn(line[1])
       text_b = self.process_text_fn(line[2])
@@ -306,7 +324,7 @@ class PawsxProcessor(DataProcessor):
           self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "dev-%d" % i
       text_a = self.process_text_fn(line[1])
       text_b = self.process_text_fn(line[2])
@@ -320,7 +338,7 @@ class PawsxProcessor(DataProcessor):
     examples_by_lang = {k: [] for k in self.supported_languages}
     for lang in self.supported_languages:
       lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
-      for (i, line) in enumerate(lines):
+      for i, line in enumerate(lines):
        guid = "test-%d" % i
        text_a = self.process_text_fn(line[1])
        text_b = self.process_text_fn(line[2])
@@ -339,109 +357,8 @@ class PawsxProcessor(DataProcessor):
     return "XTREME-PAWS-X"

-class XtremePawsxProcessor(DataProcessor):
-  """Processor for the XTREME PAWS-X data set."""
-  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    examples_by_lang = {k: [] for k in self.supported_languages}
-    for lang in self.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
-        guid = "test-%d" % i
-        text_a = self.process_text_fn(line[0])
-        text_b = self.process_text_fn(line[1])
-        label = "0"
-        examples_by_lang[lang].append(
-            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
-
-  def get_labels(self):
-    """See base class."""
-    return ["0", "1"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "XTREME-PAWS-X"
-
-class MnliProcessor(DataProcessor):
-  """Processor for the MultiNLI data set (GLUE version)."""
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
-        "dev_matched")
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
-
-  def get_labels(self):
-    """See base class."""
-    return ["contradiction", "entailment", "neutral"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "MNLI"
-
-  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
-    examples = []
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
-      text_a = self.process_text_fn(line[8])
-      text_b = self.process_text_fn(line[9])
-      if set_type == "test":
-        label = "contradiction"
-      else:
-        label = self.process_text_fn(line[-1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-class MrpcProcessor(DataProcessor):
-  """Processor for the MRPC data set (GLUE version)."""
+class QnliProcessor(DataProcessor):
+  """Processor for the QNLI data set (GLUE version)."""

   def get_train_examples(self, data_dir):
     """See base class."""
@@ -451,7 +368,7 @@ class MrpcProcessor(DataProcessor):
   def get_dev_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")

   def get_test_examples(self, data_dir):
     """See base class."""
@@ -460,26 +377,28 @@ class MrpcProcessor(DataProcessor):
   def get_labels(self):
     """See base class."""
-    return ["0", "1"]
+    return ["entailment", "not_entailment"]

   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "MRPC"
+    return "QNLI"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, i)
-      text_a = self.process_text_fn(line[3])
-      text_b = self.process_text_fn(line[4])
+      guid = "%s-%s" % (set_type, 1)
       if set_type == "test":
-        label = "0"
+        text_a = tokenization.convert_to_unicode(line[1])
+        text_b = tokenization.convert_to_unicode(line[2])
+        label = "entailment"
       else:
-        label = self.process_text_fn(line[0])
+        text_a = tokenization.convert_to_unicode(line[1])
+        text_b = tokenization.convert_to_unicode(line[2])
+        label = tokenization.convert_to_unicode(line[-1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -513,9 +432,9 @@ class QqpProcessor(DataProcessor):
     return "QQP"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, line[0])
@@ -530,52 +449,6 @@ class QqpProcessor(DataProcessor):
     return examples

-class ColaProcessor(DataProcessor):
-  """Processor for the CoLA data set (GLUE version)."""
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
-
-  def get_labels(self):
-    """See base class."""
-    return ["0", "1"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "COLA"
-
-  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
-    examples = []
-    for (i, line) in enumerate(lines):
-      # Only the test set has a header
-      if set_type == "test" and i == 0:
-        continue
-      guid = "%s-%s" % (set_type, i)
-      if set_type == "test":
-        text_a = self.process_text_fn(line[1])
-        label = "0"
-      else:
-        text_a = self.process_text_fn(line[3])
-        label = self.process_text_fn(line[1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
-    return examples

 class RteProcessor(DataProcessor):
   """Processor for the RTE data set (GLUE version)."""
@@ -606,7 +479,7 @@ class RteProcessor(DataProcessor):
     return "RTE"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
     for i, line in enumerate(lines):
       if i == 0:
@@ -651,9 +524,9 @@ class SstProcessor(DataProcessor):
     return "SST-2"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
@@ -668,8 +541,14 @@ class SstProcessor(DataProcessor):
     return examples

-class QnliProcessor(DataProcessor):
-  """Processor for the QNLI data set (GLUE version)."""
+class StsBProcessor(DataProcessor):
+  """Processor for the STS-B data set (GLUE version)."""
+
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    super(StsBProcessor, self).__init__(process_text_fn=process_text_fn)
+    self.is_regression = True
+    self.label_type = float
+    self._labels = None

   def get_train_examples(self, data_dir):
     """See base class."""
@@ -679,7 +558,7 @@ class QnliProcessor(DataProcessor):
   def get_dev_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

   def get_test_examples(self, data_dir):
     """See base class."""
@@ -688,28 +567,26 @@ class QnliProcessor(DataProcessor):
   def get_labels(self):
     """See base class."""
-    return ["entailment", "not_entailment"]
+    return self._labels

   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "QNLI"
+    return "STS-B"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, 1)
+      guid = "%s-%s" % (set_type, i)
+      text_a = tokenization.convert_to_unicode(line[7])
+      text_b = tokenization.convert_to_unicode(line[8])
       if set_type == "test":
-        text_a = tokenization.convert_to_unicode(line[1])
-        text_b = tokenization.convert_to_unicode(line[2])
-        label = "entailment"
+        label = 0.0
       else:
-        text_a = tokenization.convert_to_unicode(line[1])
-        text_b = tokenization.convert_to_unicode(line[2])
-        label = tokenization.convert_to_unicode(line[-1])
+        label = self.label_type(tokenization.convert_to_unicode(line[9]))
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
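Unlike the classification processors, `StsBProcessor` marks itself as a regression task: labels are floats parsed from column 9, and `get_labels()` returns `self._labels` (None) because there is no label vocabulary. A sketch of how a caller can branch on the new base-class attributes, assuming the usual module path:

  from official.nlp.data import classifier_data_lib

  processor = classifier_data_lib.StsBProcessor()
  if processor.is_regression:
    label = processor.label_type("3.8")  # float labels, no label vocabulary
  else:
    label_list = processor.get_labels()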
@@ -729,6 +606,8 @@ class TfdsProcessor(DataProcessor):
   tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
   tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2,"
               "is_regression=true,label_type=float"
+  tfds_params="dataset=snli,text_key=premise,text_b_key=hypothesis,"
+              "skip_label=-1"

   Possible parameters (please refer to the documentation of Tensorflow Datasets
   (TFDS) for the meaning of individual parameters):
   dataset: Required dataset name (potentially with subset and version number).
@@ -746,6 +625,7 @@ class TfdsProcessor(DataProcessor):
   label_type: Type of the label key (defaults to `int`).
   weight_key: Key of the float sample weight (is not used if not provided).
   is_regression: Whether the task is a regression problem (defaults to False).
+  skip_label: Skip examples with given label (defaults to None).
   """

   def __init__(self,
@@ -785,6 +665,9 @@ class TfdsProcessor(DataProcessor):
     self.label_type = dtype_map[d.get("label_type", "int")]
     self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
     self.weight_key = d.get("weight_key", None)
+    self.skip_label = d.get("skip_label", None)
+    if self.skip_label is not None:
+      self.skip_label = self.label_type(self.skip_label)

   def get_train_examples(self, data_dir):
     assert data_dir is None
@@ -805,7 +688,7 @@ class TfdsProcessor(DataProcessor):
     return "TFDS_" + self.dataset_name

   def _create_examples(self, split_name, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     if split_name not in self.dataset:
       raise ValueError("Split {} not available.".format(split_name))
     dataset = self.dataset[split_name].as_numpy_iterator()
@@ -823,6 +706,8 @@ class TfdsProcessor(DataProcessor):
       if self.text_b_key:
         text_b = self.process_text_fn(example[self.text_b_key])
       label = self.label_type(example[self.label_key])
+      if self.skip_label is not None and label == self.skip_label:
+        continue
       if self.weight_key:
         weight = float(example[self.weight_key])
       examples.append(
@@ -863,7 +748,7 @@ class WnliProcessor(DataProcessor):
     return "WNLI"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
    for i, line in enumerate(lines):
      if i == 0:
@@ -880,6 +765,200 @@ class WnliProcessor(DataProcessor):
     return examples

+class XnliProcessor(DataProcessor):
+  """Processor for the XNLI data set."""
+  supported_languages = [
+      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+      "ur", "vi", "zh"
+  ]
+
+  def __init__(self,
+               language="en",
+               process_text_fn=tokenization.convert_to_unicode):
+    super(XnliProcessor, self).__init__(process_text_fn)
+    if language == "all":
+      self.languages = XnliProcessor.supported_languages
+    elif language not in XnliProcessor.supported_languages:
+      raise ValueError("language %s is not supported for XNLI task." % language)
+    else:
+      self.languages = [language]
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = []
+    for language in self.languages:
+      # Skips the header.
+      lines.extend(
+          self._read_tsv(
+              os.path.join(data_dir, "multinli",
+                           "multinli.train.%s.tsv" % language))[1:])
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      if label == self.process_text_fn("contradictory"):
+        label = self.process_text_fn("contradiction")
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
+    examples = []
+    for i, line in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
+    examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
+    for i, line in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "test-%d" % i
+      language = self.process_text_fn(line[0])
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
+      examples_by_lang[language].append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["contradiction", "entailment", "neutral"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XNLI"
+
+class XtremePawsxProcessor(DataProcessor):
+  """Processor for the XTREME PAWS-X data set."""
+  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    examples_by_lang = {k: [] for k in self.supported_languages}
+    for lang in self.supported_languages:
+      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+      for i, line in enumerate(lines):
+        guid = "test-%d" % i
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = "0"
+        examples_by_lang[lang].append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["0", "1"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XTREME-PAWS-X"
+
+class XtremeXnliProcessor(DataProcessor):
+  """Processor for the XTREME XNLI data set."""
+  supported_languages = [
+      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+      "ur", "vi", "zh"
+  ]
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    examples_by_lang = {k: [] for k in self.supported_languages}
+    for lang in self.supported_languages:
+      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+      for i, line in enumerate(lines):
+        guid = f"test-{i}"
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = "contradiction"
+        examples_by_lang[lang].append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["contradiction", "entailment", "neutral"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XTREME-XNLI"

 def convert_single_example(ex_index, example, label_list, max_seq_length,
                            tokenizer):
   """Converts a single `InputExample` into a single `InputFeatures`."""
@@ -990,7 +1069,7 @@ def file_based_convert_examples_to_features(examples,
   tf.io.gfile.makedirs(os.path.dirname(output_file))
   writer = tf.io.TFRecordWriter(output_file)

-  for (ex_index, example) in enumerate(examples):
+  for ex_index, example in enumerate(examples):
     if ex_index % 10000 == 0:
       logging.info("Writing example %d of %d", ex_index, len(examples))
...
@@ -50,35 +50,41 @@ flags.DEFINE_string(
     "for the task.")

 flags.DEFINE_enum("classification_task_name", "MNLI",
-                  ["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
-                   "PAWS-X", "XTREME-XNLI", "XTREME-PAWS-X"],
+                  ["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
+                   "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI",
+                   "XTREME-PAWS-X"],
                   "The name of the task to train the BERT classifier. The "
                   "difference between XTREME-XNLI and XNLI is: 1. the format "
                   "of input tsv files; 2. the dev set for XTREME is English "
                   "only and for XNLI is all languages combined. Same for "
                   "PAWS-X.")

-# XNLI task specific flag.
+# MNLI task-specific flag.
+flags.DEFINE_enum(
+    "mnli_type", "matched", ["matched", "mismatched"],
+    "The type of MNLI dataset.")
+
+# XNLI task-specific flag.
 flags.DEFINE_string(
     "xnli_language", "en",
-    "Language of training data for XNIL task. If the value is 'all', the data "
+    "Language of training data for XNLI task. If the value is 'all', the data "
     "of all languages will be used for training.")

-# PAWS-X task specific flag.
+# PAWS-X task-specific flag.
 flags.DEFINE_string(
     "pawsx_language", "en",
-    "Language of trainig data for PAWS-X task. If the value is 'all', the data "
+    "Language of training data for PAWS-X task. If the value is 'all', the data "
     "of all languages will be used for training.")

-# Retrieva task specific flags
+# Retrieval task-specific flags.
 flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
                   "The name of the sentence retrieval task for scoring.")

-# Tagging task specific flags
+# Tagging task-specific flags.
 flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
                   "The name of BERT tagging (token classification) task.")

-# BERT Squad task specific flags.
+# BERT SQuAD task-specific flags.
 flags.DEFINE_string(
     "squad_data_file", None,
     "The input data file for generating training data for the BERT SQuAD task.")
...@@ -178,7 +184,8 @@ def generate_classifier_dataset():
      "cola":
          classifier_data_lib.ColaProcessor,
      "mnli":
          functools.partial(classifier_data_lib.MnliProcessor,
                            mnli_type=FLAGS.mnli_type),
      "mrpc":
          classifier_data_lib.MrpcProcessor,
      "qnli":
...@@ -187,6 +194,8 @@ def generate_classifier_dataset():
      "rte": classifier_data_lib.RteProcessor,
      "sst-2":
          classifier_data_lib.SstProcessor,
      "sts-b":
          classifier_data_lib.StsBProcessor,
      "xnli":
          functools.partial(classifier_data_lib.XnliProcessor,
                            language=FLAGS.xnli_language),
...
...@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function

import collections
import itertools
import random

from absl import app
...@@ -48,6 +49,12 @@ flags.DEFINE_bool(
    "do_whole_word_mask", False,
    "Whether to use whole word masking rather than per-WordPiece masking.")
flags.DEFINE_integer(
"max_ngram_size", None,
"Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
"weighting scheme to favor shorter n-grams. "
"Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
flags.DEFINE_bool(
    "gzip_compress", False,
    "Whether to use `GZIP` compress option to get compressed TFRecord files.")
...@@ -192,7 +199,8 @@ def create_training_instances(input_files,
                              masked_lm_prob,
                              max_predictions_per_seq,
                              rng,
                              do_whole_word_mask=False,
                              max_ngram_size=None):
  """Create `TrainingInstance`s from raw text."""
  all_documents = [[]]
...@@ -229,7 +237,7 @@ def create_training_instances(input_files,
          create_instances_from_document(
              all_documents, document_index, max_seq_length, short_seq_prob,
              masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
              do_whole_word_mask, max_ngram_size))

  rng.shuffle(instances)
  return instances
...@@ -238,7 +246,8 @@ def create_training_instances(input_files,
def create_instances_from_document(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
    do_whole_word_mask=False,
    max_ngram_size=None):
  """Creates `TrainingInstance`s for a single document."""
  document = all_documents[document_index]
...@@ -337,7 +346,7 @@ def create_instances_from_document(
      (tokens, masked_lm_positions,
       masked_lm_labels) = create_masked_lm_predictions(
           tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
           do_whole_word_mask, max_ngram_size)
      instance = TrainingInstance(
          tokens=tokens,
          segment_ids=segment_ids,
...@@ -355,72 +364,238 @@ def create_instances_from_document(

MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])
# A _Gram is a [half-open) interval of token indices which form a word.
# E.g.,
# words: ["The", "doghouse"]
# tokens: ["The", "dog", "##house"]
# grams: [(0,1), (1,3)]
_Gram = collections.namedtuple("_Gram", ["begin", "end"])
def _window(iterable, size):
"""Helper to create a sliding window iterator with a given size.
E.g.,
input = [1, 2, 3, 4]
_window(input, 1) => [1], [2], [3], [4]
_window(input, 2) => [1, 2], [2, 3], [3, 4]
_window(input, 3) => [1, 2, 3], [2, 3, 4]
_window(input, 4) => [1, 2, 3, 4]
  _window(input, 5) => (no windows: the input is shorter than the window)
Arguments:
iterable: elements to iterate over.
size: size of the window.
Yields:
Elements of `iterable` batched into a sliding window of length `size`.
"""
i = iter(iterable)
window = []
try:
for e in range(0, size):
window.append(next(i))
yield window
except StopIteration:
# handle the case where iterable's length is less than the window size.
return
for e in i:
window = window[1:] + [e]
yield window
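
A quick check of `_window`'s behavior, matching the docstring examples (each yielded window is a fresh list, so materializing the iterator is safe):

```python
print(list(_window([1, 2, 3, 4], 2)))  # [[1, 2], [2, 3], [3, 4]]
print(list(_window([1, 2], 3)))        # [] -- input shorter than the window
```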
def _contiguous(sorted_grams):
"""Test whether a sequence of grams is contiguous.
Arguments:
sorted_grams: _Grams which are sorted in increasing order.
Returns:
True if `sorted_grams` are touching each other.
E.g.,
_contiguous([(1, 4), (4, 5), (5, 10)]) == True
_contiguous([(1, 2), (4, 5)]) == False
"""
for a, b in _window(sorted_grams, 2):
if a.end != b.begin:
return False
return True
def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
"""Create a list of masking {1, ..., n}-grams from a list of one-grams.
  This is an extension of 'whole word masking' to mask multiple contiguous
  words (e.g., "the red boat").
Each input gram represents the token indices of a single word,
words: ["the", "red", "boat"]
tokens: ["the", "red", "boa", "##t"]
grams: [(0,1), (1,2), (2,4)]
For a `max_ngram_size` of three, possible outputs masks include:
1-grams: (0,1), (1,2), (2,4)
2-grams: (0,2), (1,4)
     3-grams: (0,4)

  Output masks will not overlap and will cover at most `max_masked_tokens`
  total tokens. E.g., for the example above with `max_masked_tokens` of three,
  valid outputs are:
[(0,1), (1,2)] # "the", "red" covering two tokens
[(1,2), (2,4)] # "red", "boa", "##t" covering three tokens
The length of the selected n-gram follows a zipf weighting to
favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
Arguments:
grams: List of one-grams.
max_ngram_size: Maximum number of contiguous one-grams combined to create
an n-gram.
max_masked_tokens: Maximum total number of tokens to be masked.
rng: `random.Random` generator.
Returns:
A list of n-grams to be used as masks.
"""
if not grams:
return None
grams = sorted(grams)
num_tokens = grams[-1].end
# Ensure our grams are valid (i.e., they don't overlap).
for a, b in _window(grams, 2):
if a.end > b.begin:
raise ValueError("overlapping grams: {}".format(grams))
# Build map from n-gram length to list of n-grams.
ngrams = {i: [] for i in range(1, max_ngram_size+1)}
for gram_size in range(1, max_ngram_size+1):
for g in _window(grams, gram_size):
if _contiguous(g):
# Add an n-gram which spans these one-grams.
ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
# Shuffle each list of n-grams.
for v in ngrams.values():
rng.shuffle(v)
  # Create the weighting for n-gram length selection.
  # Stored cumulatively for `choices` below.
  cumulative_weights = list(
      itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))
output_ngrams = []
# Keep a bitmask of which tokens have been masked.
masked_tokens = [False] * num_tokens
# Loop until we have enough masked tokens or there are no more candidate
# n-grams of any length.
# Each code path should ensure one or more elements from `ngrams` are removed
  # to guarantee this loop terminates.
while (sum(masked_tokens) < max_masked_tokens and
sum(len(s) for s in ngrams.values())):
# Pick an n-gram size based on our weights.
    sz = rng.choices(range(1, max_ngram_size+1),
                     cum_weights=cumulative_weights)[0]
# Ensure this size doesn't result in too many masked tokens.
# E.g., a two-gram contains _at least_ two tokens.
if sum(masked_tokens) + sz > max_masked_tokens:
# All n-grams of this length are too long and can be removed from
# consideration.
ngrams[sz].clear()
continue
    # All of the n-grams of this size have been used.
    if not ngrams[sz]:
      continue

    # Choose a random n-gram of the given size.
    gram = ngrams[sz].pop()
    num_gram_tokens = gram.end - gram.begin

    # Check if this would add too many tokens.
    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
      continue

    # Check if any of the tokens in this gram have already been masked.
    if sum(masked_tokens[gram.begin:gram.end]):
      continue

    # Found a usable n-gram! Mark its tokens as masked and add it to return.
    masked_tokens[gram.begin:gram.end] = [True] * (gram.end - gram.begin)
    output_ngrams.append(gram)
  return output_ngrams
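
A small seeded demo of `_masking_ngrams` under the weighting described above (size n is drawn with weight 1/n, so shorter spans are favored). The gram values mirror the docstring's "the red boat" example:

```python
import random

# "the red boat" tokenized as ["the", "red", "boa", "##t"].
grams = [_Gram(0, 1), _Gram(1, 2), _Gram(2, 4)]
masks = _masking_ngrams(grams, max_ngram_size=2, max_masked_tokens=3,
                        rng=random.Random(1234))
# One possible result: [_Gram(begin=0, end=2), ...] -- non-overlapping spans
# covering at most three tokens in total.
print(masks)
```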
def _wordpieces_to_grams(tokens):
"""Reconstitue grams (words) from `tokens`.
E.g.,
tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
grams: [ [1,2), [2, 4), [4,5) , [5, 6)]
Arguments:
tokens: list of wordpieces
Returns:
List of _Grams representing spans of whole words
(without "[CLS]" and "[SEP]").
"""
grams = []
gram_start_pos = None
  for i, token in enumerate(tokens):
    if gram_start_pos is not None and token.startswith("##"):
      continue
    if gram_start_pos is not None:
      grams.append(_Gram(gram_start_pos, i))
    if token not in ["[CLS]", "[SEP]"]:
      gram_start_pos = i
    else:
      gram_start_pos = None
  if gram_start_pos is not None:
    grams.append(_Gram(gram_start_pos, len(tokens)))
  return grams

def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng,
                                 do_whole_word_mask,
                                 max_ngram_size=None):
  """Creates the predictions for the masked LM objective."""
  if do_whole_word_mask:
    grams = _wordpieces_to_grams(tokens)
  else:
    # Here we consider each token to be a word to allow for sub-word masking.
    if max_ngram_size:
      raise ValueError("cannot use ngram masking without whole word masking")
    grams = [_Gram(i, i+1) for i in range(0, len(tokens))
             if tokens[i] not in ["[CLS]", "[SEP]"]]

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))
  # Generate masks. If `max_ngram_size` in [0, None] it means we're doing
  # whole word masking or token level masking. Both of these can be treated
  # as the `max_ngram_size=1` case.
  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
                                 num_to_predict, rng)
  masked_lms = []
  output_tokens = list(tokens)
  for gram in masked_grams:
    # 80% of the time, replace all n-gram tokens with [MASK].
    if rng.random() < 0.8:
      replacement_action = lambda idx: "[MASK]"
    else:
      # 10% of the time, keep all the original n-gram tokens.
      if rng.random() < 0.5:
        replacement_action = lambda idx: tokens[idx]
      # 10% of the time, replace each n-gram token with a random word.
      else:
        replacement_action = lambda idx: rng.choice(vocab_words)

    for idx in range(gram.begin, gram.end):
      output_tokens[idx] = replacement_action(idx)
      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))

  assert len(masked_lms) <= num_to_predict
  masked_lms = sorted(masked_lms, key=lambda x: x.index)
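
End to end, a hedged sketch of n-gram masking on a toy sequence. The three-tuple return matches the call site in `create_instances_from_document`; the vocabulary and probabilities are made up:

```python
import random

tokens = ["[CLS]", "the", "red", "boa", "##t", "[SEP]"]
out_tokens, positions, labels = create_masked_lm_predictions(
    tokens, masked_lm_prob=0.5, max_predictions_per_seq=3,
    vocab_words=["the", "red", "boat", "blue"], rng=random.Random(0),
    do_whole_word_mask=True, max_ngram_size=2)
# `positions`/`labels` record where masking happened and the original tokens;
# each masked span is [MASK]-ed, kept, or randomized per the 80/10/10 rule.
```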
...@@ -467,7 +642,7 @@ def main(_):
  instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)

  output_files = FLAGS.output_file.split(",")
  logging.info("*** Writing to output files ***")
...
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the question answering (e.g, SQuAD) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
"""Data config for question answering task (tasks/question_answering)."""
input_path: str = ''
global_batch_size: int = 48
is_training: bool = True
seq_length: int = 384
# Settings below are question answering specific.
version_2_with_negative: bool = False
# Settings below are only used for eval mode.
input_preprocessed_data_path: str = ''
doc_stride: int = 128
query_length: int = 64
vocab_file: str = ''
tokenization: str = 'WordPiece' # WordPiece or SentencePiece
do_lower_case: bool = True
@data_loader_factory.register_data_loader_cls(QADataConfig)
class QuestionAnsweringDataLoader:
"""A class to load dataset for sentence prediction (classification) task."""
def __init__(self, params):
self._params = params
self._seq_length = params.seq_length
self._is_training = params.is_training
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
if self._is_training:
name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
else:
name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x, y = {}, {}
for name, tensor in record.items():
if name in ('start_positions', 'end_positions'):
y[name] = tensor
elif name == 'input_ids':
x['input_word_ids'] = tensor
elif name == 'segment_ids':
x['input_type_ids'] = tensor
else:
x[name] = tensor
return (x, y)
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
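
A hedged usage sketch of this loader (the path is a placeholder, and direct instantiation is one option; the config is also registered with `data_loader_factory` above):

```python
params = QADataConfig(
    input_path='/tmp/squad_train.tf_record',  # placeholder path
    global_batch_size=32,
    is_training=True,
    seq_length=384)
dataset = QuestionAnsweringDataLoader(params).load()
# Training elements are (x, y): x holds input_word_ids / input_mask /
# input_type_ids; y holds start_positions and end_positions.
```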
...@@ -23,6 +23,9 @@ from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory

LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}


@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task (tasks/sentence_prediction)."""
...@@ -30,6 +33,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
  global_batch_size: int = 32
  is_training: bool = True
  seq_length: int = 128
  label_type: str = 'int'


@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
...@@ -42,11 +46,12 @@ class SentencePredictionDataLoader:

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    label_type = LABEL_TYPES_MAP[self._params.label_type]
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'label_ids': tf.io.FixedLenFeature([], label_type),
    }
    example = tf.io.parse_single_example(record, name_to_features)
...
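
A hedged sketch of what the new `label_type` field enables: decoding float labels for a regression task such as STS-B instead of int64 class ids (path is a placeholder, and the loader's constructor is assumed to mirror the other loaders in this file):

```python
params = SentencePredictionDataConfig(
    input_path='/tmp/stsb_train.tf_record',  # placeholder path
    seq_length=128,
    label_type='float')  # 'label_ids' is parsed as tf.float32 via LABEL_TYPES_MAP
dataset = SentencePredictionDataLoader(params).load()
```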
...@@ -28,6 +28,7 @@ class TaggingDataConfig(cfg.DataConfig):
  """Data config for tagging (tasks/tagging)."""
  is_training: bool = True
  seq_length: int = 128
  include_sentence_id: bool = False


@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
...@@ -37,6 +38,7 @@ class TaggingDataLoader:

  def __init__(self, params: TaggingDataConfig):
    self._params = params
    self._seq_length = params.seq_length
    self._include_sentence_id = params.include_sentence_id

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
...@@ -46,6 +48,9 @@ class TaggingDataLoader:
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
    }
    if self._include_sentence_id:
      name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)

    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
...@@ -65,6 +70,8 @@ class TaggingDataLoader:
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids']
    }
    if self._include_sentence_id:
      x['sentence_id'] = record['sentence_id']

    y = record['label_ids']
    return (x, y)
...
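
A hedged sketch of `include_sentence_id=True`, which lets eval code regroup per-token predictions by sentence (e.g. for XTREME tagging). The path is a placeholder, and `input_path` is assumed to come from the base `cfg.DataConfig`:

```python
params = TaggingDataConfig(
    input_path='/tmp/panx_dev.tf_record',  # placeholder path
    is_training=False,
    include_sentence_id=True)
for x, y in TaggingDataLoader(params).load().take(1):
  print(x['sentence_id'])  # one scalar id per example in the batch
```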
# NLP Modeling Library

This library provides a set of Keras primitives (Layers, Networks, and Models)
that can be assembled into transformer-based models. They are
flexible, validated, interoperable, and both TF1 and TF2 compatible.
...@@ -16,6 +16,11 @@ standardized configuration.

* [`losses`](losses) contains common loss computation used in NLP tasks.

Please see the colab
[nlp_modeling_library_intro.ipynb](https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb)
for how to build transformer-based NLP models using the above primitives.

Besides the pre-defined primitives, it also provides scaffold classes to allow
easy experimentation with novel architectures, e.g., you don’t need to fork a
whole Transformer object to try a different kind of attention primitive.

Please see the colab
[customize_encoder.ipynb](https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb)
for how to use scaffold classes to build novel architectures.

BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder.
...@@ -3,11 +3,6 @@

Layers are the fundamental building blocks for NLP models. They can be used to
assemble new layers, networks, or models.
* [DenseEinsum](dense_einsum.py) implements a feedforward network using
tf.einsum. This layer contains the einsum op, the associated weight, and the
logic required to generate the einsum expression for the given
initialization parameters.
* [MultiHeadAttention](attention.py) implements an optionally masked attention
  between query, key, value tensors as described in
  ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
...
...@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX = string.ascii_lowercase


def _build_attention_equation(rank, attn_axes):
  """Builds einsum equations for the attention computation.

  Query, key, value inputs after projection are expected to have the shape as:
...@@ -50,19 +50,19 @@ def _build_attention_equation(rank, attn_axes):
  <query attention dims>, num_heads, channels)

  Args:
    rank: the rank of query, key, value tensors.
    attn_axes: a list/tuple of axes, [1, rank), that will do attention.

  Returns:
    Einsum equations.
  """
  target_notation = _CHR_IDX[:rank]
  # `batch_dims` includes the head dim.
  batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
  letter_offset = rank
  source_notation = ""
  for i in range(rank):
    if i in batch_dims or i == rank - 1:
      source_notation += target_notation[i]
    else:
      source_notation += _CHR_IDX[letter_offset]
...@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
      sequence dims. If not specified, projects back to the key feature dim.
    attention_axes: axes over which the attention is applied. `None` means
      attention over all axes, but batch, heads, and features.
    return_attention_scores: bool, if `True`, returns the multi-head attention
      scores as an additional output argument.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
...@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
""" """
def __init__(self, def __init__(self,
...@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
self._attention_axes = (attention_axes,) self._attention_axes = (attention_axes,)
else: else:
self._attention_axes = attention_axes self._attention_axes = attention_axes
self._built_from_signature = False
def get_config(self): def get_config(self):
config = { config = {
...@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
base_config = super(MultiHeadAttention, self).get_config() base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
  def _build_from_signature(self, query, value, key=None):
    """Builds layers and variables.

    Once the method is called, self._built_from_signature will be set to True.

    Args:
      query: query tensor or TensorShape.
      value: value tensor or TensorShape.
      key: key tensor or TensorShape.
    """
    self._built_from_signature = True
    if hasattr(query, "shape"):
      query_shape = tf.TensorShape(query.shape)
    else:
      query_shape = query
    if hasattr(value, "shape"):
      value_shape = tf.TensorShape(value.shape)
    else:
      value_shape = value
    if key is None:
      key_shape = value_shape
    elif hasattr(key, "shape"):
      key_shape = tf.TensorShape(key.shape)
    else:
      key_shape = key
    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
...@@ -271,84 +293,79 @@ class MultiHeadAttention(tf.keras.layers.Layer):
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    with tf.init_scope():
      free_dims = query_shape.rank - 1
      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          free_dims, bound_dims=1, output_dims=2)
      self._query_dense = EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1,
                                         [self._num_heads, self._key_size]),
          bias_axes=bias_axes if self._use_bias else None,
          name="query",
          **common_kwargs)
      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          key_shape.rank - 1, bound_dims=1, output_dims=2)
      self._key_dense = EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1,
                                         [self._num_heads, self._key_size]),
          bias_axes=bias_axes if self._use_bias else None,
          name="key",
          **common_kwargs)
      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          value_shape.rank - 1, bound_dims=1, output_dims=2)
      self._value_dense = EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1,
                                         [self._num_heads, self._value_size]),
          bias_axes=bias_axes if self._use_bias else None,
          name="value",
          **common_kwargs)

      # Builds the attention computations for multi-head dot product attention.
      # These computations could be wrapped into the keras attention layer once
      # it supports multi-head einsum computations.
      self.build_attention(output_rank)
      if self._output_shape:
        if not isinstance(self._output_shape, collections.abc.Sized):
          output_shape = [self._output_shape]
        else:
          output_shape = self._output_shape
      else:
        output_shape = [query_shape[-1]]
      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          free_dims, bound_dims=2, output_dims=len(output_shape))
      self._output_dense = EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1, output_shape),
          bias_axes=bias_axes if self._use_bias else None,
          name="attention_output",
          **common_kwargs)

  def build_attention(self, rank):
    """Builds multi-head dot-product attention computations.

    This function builds attributes necessary for `compute_attention` to
    customize attention computation to replace the default dot-product
    attention.

    Args:
      rank: the rank of query, key, value tensors.
    """
    if self._attention_axes is None:
      self._attention_axes = tuple(range(1, rank - 2))
    else:
      self._attention_axes = tuple(self._attention_axes)
    self._dot_product_equation, self._combine_equation, attn_scores_rank = (
        _build_attention_equation(rank, attn_axes=self._attention_axes))
    norm_axes = tuple(
        range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
    self._masked_softmax = masked_softmax.MaskedSoftmax(
        mask_expansion_axes=[1], normalization_axes=norm_axes)
    self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
  def compute_attention(self, query, key, value, attention_mask=None):
    """Applies dot-product attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
...@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    attention implementation.

    Args:
      query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
      key: Projected key `Tensor` of shape `[B, S, N, key_size]`.
      value: Projected value `Tensor` of shape `[B, S, N, value_size]`.
      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
        attention to certain positions.
...@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer):
      attention_output: Multi-headed outputs of attention computation.
      attention_scores: Multi-headed attention weights.
    """
    # Note: Applying scalar multiply at the smaller end of einsum improves
    # XLA performance, but may introduce slight numeric differences in
    # the Transformer attention head.
    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum(self._dot_product_equation, key, query)

    # Normalize the attention scores to probabilities.
    # `attention_scores` = [B, N, T, S]
...@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    # `context_layer` = [B, T, N, H]
    attention_output = tf.einsum(self._combine_equation,
                                 attention_scores_dropout, value)
    return attention_output, attention_scores
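
Since `compute_attention` is now a public override point, a subclass can swap in a different scoring scheme while reusing the projections and einsum equations built by the parent. A minimal sketch (illustrative only; the cosine variant is made up, not part of this change):

```python
class CosineAttention(MultiHeadAttention):
  """Hypothetical variant: cosine similarity instead of raw dot product."""

  def compute_attention(self, query, key, value, attention_mask=None):
    # L2-normalize the projected query/key, then reuse the parent's
    # einsum + masked-softmax + dropout pipeline unchanged.
    query = tf.math.l2_normalize(query, axis=-1)
    key = tf.math.l2_normalize(key, axis=-1)
    return super(CosineAttention, self).compute_attention(
        query, key, value, attention_mask)
```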
  def call(self, query, value, key=None, attention_mask=None):
    """Implements the forward pass.

    Size glossary:
...@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
      * Value (source) attention axes shape (S), the rank must match the target.

    Args:
      query: Query `Tensor` of shape `[B, T, dim]`.
      value: Value `Tensor` of shape `[B, S, dim]`.
      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
        `value` for both `key` and `value`, which is the most common case.
      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
        attention to certain positions.
...@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer):
      attention
      axes.
    """
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    # N = `num_attention_heads`
    # H = `size_per_head`
    # `query` = [B, T, N, H]
    query = self._query_dense(query)

    # `key` = [B, S, N, H]
    key = self._key_dense(key)

    # `value` = [B, S, N, H]
    value = self._value_dense(value)

    attention_output, attention_scores = self.compute_attention(
        query, key, value, attention_mask)
    attention_output = self._output_dense(attention_output)

    if self._return_attention_scores:
...@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention):

  Arguments are the same as `MultiHeadAttention` layer.
  """
  def _update_cache(self, key, value, cache, decode_loop_step):
    """Updates cache states and gets full-length key/value tensors."""
    # Combines cached keys and values with new keys and values.
    if decode_loop_step is not None:
      # TPU special case.
      key_seq_dim = cache["key"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
          [1, key_seq_dim, 1, 1])
      key = cache["key"] + key * indices
      value_seq_dim = cache["value"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
          [1, value_seq_dim, 1, 1])
      value = cache["value"] + value * indices
    else:
      key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
      value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)

    # Update cache
    cache["key"] = key
    cache["value"] = value
    return key, value
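
For context, a hedged sketch of the cache layout `_update_cache` operates on: pre-allocated `[batch, max_decode_length, num_heads, size_per_head]` buffers (the shapes mirror the cached-attention test further below):

```python
batch_size, max_decode_length, num_heads, head_size = 3, 4, 2, 2
cache = {
    "key": tf.zeros([batch_size, max_decode_length, num_heads, head_size]),
    "value": tf.zeros([batch_size, max_decode_length, num_heads, head_size]),
}
# With `decode_loop_step` set (TPU path), the step-t key/value are scattered
# into the buffer via one_hot; with it None, they are concatenated on axis 1.
```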
  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           cache=None,
           decode_loop_step=None):
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    # Scalar dimensions referenced here:
    # B = batch size (number of sequences)
...@@ -494,23 +509,21 @@ class CachedAttention(MultiHeadAttention):
    # T = `to_tensor` sequence length
    # N = `num_attention_heads`
    # H = `size_per_head`
    # `query` = [B, F, N, H]
    query = self._query_dense(query)

    # `key` = [B, T, N, H]
    key = self._key_dense(key)

    # `value` = [B, T, N, H]
    value = self._value_dense(value)

    if cache:
      key, value = self._update_cache(key, value, cache, decode_loop_step)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum(self._dot_product_equation, key, query)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))
...@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention):
    attention_scores = self._dropout_layer(attention_scores)
    # `context_layer` = [B, F, N, H]
    attention_output = tf.einsum(self._combine_equation, attention_scores,
                                 value)
    attention_output = self._output_dense(attention_output)
    if self._return_attention_scores:
      return attention_output, attention_scores, cache
...
...@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(40, 80))
    value = tf.keras.Input(shape=(20, 80))
    output = test_layer(query=query, value=value)
    self.assertEqual(output.shape.as_list(), [None] + output_dims)

  def test_non_masked_self_attention(self):
...@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
    test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(40, 80))
    output = test_layer(query, query)
    self.assertEqual(output.shape.as_list(), [None, 40, 80])

  def test_attention_scores(self):
...@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
        num_heads=12, key_size=64, return_attention_scores=True)
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(40, 80))
    output, coef = test_layer(query, query)
    self.assertEqual(output.shape.as_list(), [None, 40, 80])
    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
...@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
    query = tf.keras.Input(shape=(4, 8))
    value = tf.keras.Input(shape=(2, 8))
    mask_tensor = tf.keras.Input(shape=(4, 2))
    output = test_layer(query=query, value=value, attention_mask=mask_tensor)

    # Create a model containing the test layer.
    model = tf.keras.Model([query, value, mask_tensor], output)
...@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
    # Tests the layer with three inputs: Q, K, V.
    key = tf.keras.Input(shape=(2, 8))
    output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
    model = tf.keras.Model([query, value, key, mask_tensor], output)

    masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
...@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(40, 80))
    output = test_layer(query, query)
    self.assertEqual(output.shape.as_list(), [None, 40, 80])

  @parameterized.named_parameters(
...@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
    # Invoke the data with a random set of mask data. This should mask at least
    # one element.
    mask_data = np.random.randint(2, size=mask_shape).astype("bool")
    output = test_layer(query=query, value=value, attention_mask=mask_data)

    # Invoke the same data, but with a null mask (where no elements are masked).
    null_mask_data = np.ones(mask_shape)
    unmasked_output = test_layer(
        query=query, value=value, attention_mask=null_mask_data)
    # Because one data is masked and one is not, the outputs should not be the
    # same.
    self.assertNotAllClose(output, unmasked_output)
...@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
        key_size=64)
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(40, 80))
    output = test_layer(query, query)
    self.assertEqual(output.shape.as_list(), [None, 40, 80])
...@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase):
    # one element.
    mask_data = np.random.randint(
        2, size=(batch_size, from_seq_length, from_seq_length))
    masked_output_data, cache = layer(
        query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
    self.assertEqual(masked_output_data.shape, (3, 4, 8))
    self.assertEqual(cache["value"].shape, (3, 4, 2, 2))

    # Tests inputs without cache.
    masked_output_data, cache = layer(
        query=from_data, value=from_data, attention_mask=mask_data)
    self.assertEqual(masked_output_data.shape, (3, 4, 8))
    self.assertIsNone(cache)
...@@ -243,10 +246,12 @@ class CachedAttentionTest(keras_parameterized.TestCase):
    mask_data = np.random.randint(
        2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
    # Testing the invocation directly as Keras cannot consume inputs correctly.
    masked_output_data, cache = layer(
        query=from_data,
        value=from_data,
        attention_mask=mask_data,
        cache=cache,
        decode_loop_step=decode_loop_step)
    self.assertEqual(masked_output_data.shape, (3, 4, 8))
    self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
...
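
Summarizing the API change these tests cover, a minimal sketch of the new keyword-style call (the module path is assumed to be the usual one for this library):

```python
import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import attention

layer = attention.MultiHeadAttention(num_heads=2, key_size=4)
query = tf.ones([1, 5, 8])             # [B, T, dim]
value = tf.ones([1, 6, 8])             # [B, S, dim]
mask = np.ones([1, 5, 6], dtype=bool)  # [B, T, S]

# Old style was layer([query, value], mask); the new style uses explicit
# keywords, and the layer builds lazily on first call via _build_from_signature.
output = layer(query=query, value=value, attention_mask=mask)
print(output.shape)  # (1, 5, 8) -- projected back to the query feature dim
```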