Unverified commit 0cceabfc authored by Yiming Shi, committed by GitHub

Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor

parents 17821c0d 39ee0ac9
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer
@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
"""ELECTRA pretrainer configuration."""
num_masked_tokens: int = 76
sequence_length: int = 512
num_classes: int = 2
discriminator_loss_weight: float = 50.0
tie_embeddings: bool = True
disallow_correct: bool = False
generator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
discriminator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_classification_heads_from_cfgs(
cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
if cls_head_configs:
return [
layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
]
else:
return []
def instantiate_pretrainer_from_cfg(
config: ELECTRAPretrainerConfig,
generator_network: Optional[tf.keras.Model] = None,
discriminator_network: Optional[tf.keras.Model] = None,
) -> electra_pretrainer.ElectraPretrainer:
"""Instantiates ElectraPretrainer from the config."""
generator_encoder_cfg = config.generator_encoder
discriminator_encoder_cfg = config.discriminator_encoder
# Copy discriminator's embeddings to generator for easier model serialization.
if discriminator_network is None:
discriminator_network = encoders.instantiate_encoder_from_cfg(
discriminator_encoder_cfg)
if generator_network is None:
if config.tie_embeddings:
embedding_layer = discriminator_network.get_embedding_layer()
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg, embedding_layer=embedding_layer)
else:
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg)
return electra_pretrainer.ElectraPretrainer(
generator_network=generator_network,
discriminator_network=discriminator_network,
vocab_size=config.generator_encoder.vocab_size,
num_classes=config.num_classes,
sequence_length=config.sequence_length,
num_token_predictions=config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(
generator_encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=generator_encoder_cfg.initializer_range),
classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads),
disallow_correct=config.disallow_correct)
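

# Illustrative sketch (not part of the original module): building a tiny
# ELECTRA pretrainer from the configs above, mirroring the unit test added in
# this change. The small vocab/layer sizes are arbitrary and only keep the
# example cheap.
def _example_build_electra_pretrainer():
  config = ELECTRAPretrainerConfig(
      generator_encoder=encoders.TransformerEncoderConfig(
          vocab_size=10, num_layers=1),
      discriminator_encoder=encoders.TransformerEncoderConfig(
          vocab_size=10, num_layers=2),
      cls_heads=[
          bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
      ])
  return instantiate_pretrainer_from_cfg(config)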
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,53 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
"""Tests of the block operators."""
import numpy as np
import tensorflow as tf
import block_base
import blocks_operator
class AddOneBlock(block_base.BlockBase):
def __init__(self, name=None):
super(AddOneBlock, self).__init__(name)
def _Apply(self, x):
return x + 1.0
class SquareBlock(block_base.BlockBase):
def __init__(self, name=None):
super(SquareBlock, self).__init__(name)
def _Apply(self, x):
return x * x
class BlocksOperatorTest(tf.test.TestCase):
def testComposition(self):
x_value = np.array([[1.0, 2.0, 3.0],
[-1.0, -2.0, -3.0]])
y_expected_value = np.array([[4.0, 9.0, 16.0],
[0.0, 1.0, 4.0]])
x = tf.placeholder(dtype=tf.float32, shape=[2, 3])
complex_block = blocks_operator.CompositionOperator(
[AddOneBlock(),
SquareBlock()])
y = complex_block(x)
with self.test_session():
y_value = y.eval(feed_dict={x: x_value})
self.assertAllClose(y_expected_value, y_value)
if __name__ == '__main__':
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
class ELECTRAModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
)
_ = electra.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = electra.instantiate_pretrainer_from_cfg(config)
if __name__ == "__main__":
tf.test.main()
......@@ -13,11 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configurations for Encoders."""
"""Transformer Encoders.
Includes configurations and instantiation methods.
"""
from typing import Optional
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@dataclasses.dataclass
......@@ -28,9 +35,64 @@ class TransformerEncoderConfig(base_config.Config):
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3076
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_size: Optional[int] = None
def instantiate_encoder_from_cfg(
config: TransformerEncoderConfig,
encoder_cls=networks.TransformerEncoder,
embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
if encoder_cls.__name__ == "EncoderScaffold":
embedding_cfg = dict(
vocab_size=config.vocab_size,
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
max_seq_length=config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
dropout_rate=config.dropout_rate,
)
hidden_cfg = dict(
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
intermediate_activation=tf_utils.get_activation(
config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=config.num_layers,
pooled_output_dim=config.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
return encoder_cls(**kwargs)
if encoder_cls.__name__ != "TransformerEncoder":
raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
encoder_network = encoder_cls(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
num_layers=config.num_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
activation=tf_utils.get_activation(config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
max_sequence_length=config.max_position_embeddings,
type_vocab_size=config.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
embedding_width=config.embedding_size,
embedding_layer=embedding_layer)
return encoder_network
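

# Illustrative sketch (not part of the original module): instantiating a small
# default TransformerEncoder from a config. The sizes here are arbitrary.
def _example_build_encoder():
  config = TransformerEncoderConfig(vocab_size=100, num_layers=2)
  return instantiate_encoder_from_cfg(config)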
......@@ -31,9 +31,15 @@ from official.nlp.bert import tokenization
class InputExample(object):
"""A single training/test example for simple sequence classification."""
"""A single training/test example for simple seq regression/classification."""
def __init__(self, guid, text_a, text_b=None, label=None, weight=None):
def __init__(self,
guid,
text_a,
text_b=None,
label=None,
weight=None,
int_iden=None):
"""Constructs a InputExample.
Args:
......@@ -42,16 +48,20 @@ class InputExample(object):
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
label: (Optional) string for classification, float for regression. The
label of the example. This should be specified for train and dev
examples, but not for test examples.
weight: (Optional) float. The weight of the example to be used during
training.
      int_iden: (Optional) int. The integer identification number of the
        example in the corpus.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
self.weight = weight
self.int_iden = int_iden
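
# Illustrative sketch (not part of the original file): constructing a single
# sentence-pair example the way the processors below do. All field values are
# made up.
#
#   example = InputExample(
#       guid="train-0",
#       text_a="He ate the apple.",
#       text_b="The apple was eaten.",
#       label="entailment",
#       weight=1.0,
#       int_iden=0)
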
class InputFeatures(object):
......@@ -63,20 +73,24 @@ class InputFeatures(object):
segment_ids,
label_id,
is_real_example=True,
weight=None):
weight=None,
int_iden=None):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.is_real_example = is_real_example
self.weight = weight
self.int_iden = int_iden
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
"""Base class for converters for seq regression/classification datasets."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
self.process_text_fn = process_text_fn
self.is_regression = False
self.label_type = None
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
......@@ -110,92 +124,163 @@ class DataProcessor(object):
return lines
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def __init__(self,
language="en",
process_text_fn=tokenization.convert_to_unicode):
super(XnliProcessor, self).__init__(process_text_fn)
if language == "all":
self.languages = XnliProcessor.supported_languages
elif language not in XnliProcessor.supported_languages:
raise ValueError("language %s is not supported for XNLI task." % language)
else:
self.languages = [language]
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
lines = []
for language in self.languages:
# Skips the header.
lines.extend(
self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % language))[1:])
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "COLA"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
for i, line in enumerate(lines):
# Only the test set has a header.
if set_type == "test" and i == 0:
continue
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = self.process_text_fn(line[1])
label = "0"
else:
text_a = self.process_text_fn(line[3])
label = self.process_text_fn(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def __init__(self,
mnli_type="matched",
process_text_fn=tokenization.convert_to_unicode):
super(MnliProcessor, self).__init__(process_text_fn)
if mnli_type not in ("matched", "mismatched"):
raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
self.mnli_type = mnli_type
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
if self.mnli_type == "matched":
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
"dev_mismatched")
def get_test_examples(self, data_dir):
"""See base class."""
if self.mnli_type == "matched":
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % i
text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1])
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
text_a = self.process_text_fn(line[8])
text_b = self.process_text_fn(line[9])
if set_type == "test":
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % i
language = self.process_text_fn(line[0])
text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1])
examples_by_lang[language].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XNLI"
return "MRPC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line[3])
text_b = self.process_text_fn(line[4])
if set_type == "test":
label = "0"
else:
label = self.process_text_fn(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class PawsxProcessor(DataProcessor):
"""Processor for the PAWS-X data set."""
supported_languages = [
"de", "en", "es", "fr", "ja", "ko", "zh"
]
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
def __init__(self,
language="en",
......@@ -219,11 +304,10 @@ class PawsxProcessor(DataProcessor):
train_tsv = "translated_train.tsv"
# Skips the header.
lines.extend(
self._read_tsv(
os.path.join(data_dir, language, train_tsv))[1:])
self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
......@@ -235,13 +319,12 @@ class PawsxProcessor(DataProcessor):
def get_dev_examples(self, data_dir):
"""See base class."""
lines = []
for language in PawsxProcessor.supported_languages:
# Skips the header.
for lang in PawsxProcessor.supported_languages:
lines.extend(
self._read_tsv(os.path.join(data_dir, language, "dev_2k.tsv"))[1:])
self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
......@@ -252,17 +335,15 @@ class PawsxProcessor(DataProcessor):
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in PawsxProcessor.supported_languages}
for language in PawsxProcessor.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, language, "test_2k.tsv"))
for (i, line) in enumerate(lines):
if i == 0:
continue
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
for i, line in enumerate(lines):
guid = "test-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
label = self.process_text_fn(line[3])
examples_by_lang[language].append(
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
......@@ -273,57 +354,11 @@ class PawsxProcessor(DataProcessor):
@staticmethod
def get_processor_name():
"""See base class."""
return "PAWS-X"
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return "XTREME-PAWS-X"
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
text_a = self.process_text_fn(line[8])
text_b = self.process_text_fn(line[9])
if set_type == "test":
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
......@@ -333,7 +368,7 @@ class MrpcProcessor(DataProcessor):
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
......@@ -342,26 +377,28 @@ class MrpcProcessor(DataProcessor):
def get_labels(self):
"""See base class."""
return ["0", "1"]
return ["entailment", "not_entailment"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MRPC"
return "QNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line[3])
text_b = self.process_text_fn(line[4])
guid = "%s-%s" % (set_type, 1)
if set_type == "test":
label = "0"
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = "entailment"
else:
label = self.process_text_fn(line[0])
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = tokenization.convert_to_unicode(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -395,9 +432,9 @@ class QqpProcessor(DataProcessor):
return "QQP"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
......@@ -407,13 +444,13 @@ class QqpProcessor(DataProcessor):
label = line[5]
except IndexError:
continue
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b,
label=label))
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
class RteProcessor(DataProcessor):
"""Processor for the RTE data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
......@@ -432,29 +469,30 @@ class ColaProcessor(DataProcessor):
def get_labels(self):
"""See base class."""
return ["0", "1"]
# All datasets are converted to 2-class split, where for 3-class datasets we
# collapse neutral and contradiction into not_entailment.
return ["entailment", "not_entailment"]
@staticmethod
def get_processor_name():
"""See base class."""
return "COLA"
return "RTE"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
# Only the test set has a header
if set_type == "test" and i == 0:
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
if set_type == "test":
text_a = self.process_text_fn(line[1])
label = "0"
label = "entailment"
else:
text_a = self.process_text_fn(line[3])
label = self.process_text_fn(line[1])
label = tokenization.convert_to_unicode(line[3])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -486,9 +524,9 @@ class SstProcessor(DataProcessor):
return "SST-2"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
......@@ -503,8 +541,14 @@ class SstProcessor(DataProcessor):
return examples
class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version)."""
class StsBProcessor(DataProcessor):
"""Processor for the STS-B data set (GLUE version)."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
super(StsBProcessor, self).__init__(process_text_fn=process_text_fn)
self.is_regression = True
self.label_type = float
self._labels = None
def get_train_examples(self, data_dir):
"""See base class."""
......@@ -514,7 +558,7 @@ class QnliProcessor(DataProcessor):
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
......@@ -523,28 +567,26 @@ class QnliProcessor(DataProcessor):
def get_labels(self):
"""See base class."""
return ["entailment", "not_entailment"]
return self._labels
@staticmethod
def get_processor_name():
"""See base class."""
return "QNLI"
return "STS-B"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, 1)
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[7])
text_b = tokenization.convert_to_unicode(line[8])
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = "entailment"
label = 0.0
else:
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = tokenization.convert_to_unicode(line[-1])
label = self.label_type(tokenization.convert_to_unicode(line[9]))
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -564,6 +606,8 @@ class TfdsProcessor(DataProcessor):
tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2,"
"is_regression=true,label_type=float"
tfds_params="dataset=snli,text_key=premise,text_b_key=hypothesis,"
"skip_label=-1"
Possible parameters (please refer to the documentation of Tensorflow Datasets
(TFDS) for the meaning of individual parameters):
dataset: Required dataset name (potentially with subset and version number).
......@@ -581,17 +625,19 @@ class TfdsProcessor(DataProcessor):
label_type: Type of the label key (defaults to `int`).
weight_key: Key of the float sample weight (is not used if not provided).
is_regression: Whether the task is a regression problem (defaults to False).
skip_label: Skip examples with given label (defaults to None).
"""
def __init__(self, tfds_params,
def __init__(self,
tfds_params,
process_text_fn=tokenization.convert_to_unicode):
super(TfdsProcessor, self).__init__(process_text_fn)
self._process_tfds_params_str(tfds_params)
if self.module_import:
importlib.import_module(self.module_import)
self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir,
with_info=True)
self.dataset, info = tfds.load(
self.dataset_name, data_dir=self.data_dir, with_info=True)
if self.is_regression:
self._labels = None
else:
......@@ -619,6 +665,9 @@ class TfdsProcessor(DataProcessor):
self.label_type = dtype_map[d.get("label_type", "int")]
self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
self.weight_key = d.get("weight_key", None)
self.skip_label = d.get("skip_label", None)
if self.skip_label is not None:
self.skip_label = self.label_type(self.skip_label)
def get_train_examples(self, data_dir):
assert data_dir is None
......@@ -639,7 +688,7 @@ class TfdsProcessor(DataProcessor):
return "TFDS_" + self.dataset_name
def _create_examples(self, split_name, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
if split_name not in self.dataset:
raise ValueError("Split {} not available.".format(split_name))
dataset = self.dataset[split_name].as_numpy_iterator()
......@@ -657,13 +706,258 @@ class TfdsProcessor(DataProcessor):
if self.text_b_key:
text_b = self.process_text_fn(example[self.text_b_key])
label = self.label_type(example[self.label_key])
if self.skip_label is not None and label == self.skip_label:
continue
if self.weight_key:
weight = float(example[self.weight_key])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label,
weight=weight))
InputExample(
guid=guid,
text_a=text_a,
text_b=text_b,
label=label,
weight=weight))
return examples
class WnliProcessor(DataProcessor):
"""Processor for the WNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "WNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
if set_type == "test":
label = "0"
else:
label = tokenization.convert_to_unicode(line[3])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def __init__(self,
language="en",
process_text_fn=tokenization.convert_to_unicode):
super(XnliProcessor, self).__init__(process_text_fn)
if language == "all":
self.languages = XnliProcessor.supported_languages
elif language not in XnliProcessor.supported_languages:
raise ValueError("language %s is not supported for XNLI task." % language)
else:
self.languages = [language]
def get_train_examples(self, data_dir):
"""See base class."""
lines = []
for language in self.languages:
# Skips the header.
lines.extend(
self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % language))[1:])
examples = []
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % i
text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
for i, line in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % i
language = self.process_text_fn(line[0])
text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1])
examples_by_lang[language].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XNLI"
class XtremePawsxProcessor(DataProcessor):
"""Processor for the XTREME PAWS-X data set."""
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for i, line in enumerate(lines):
guid = "test-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "0"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XTREME-PAWS-X"
class XtremeXnliProcessor(DataProcessor):
"""Processor for the XTREME XNLI data set."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for i, line in enumerate(lines):
guid = f"test-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "contradiction"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XTREME-XNLI"
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer):
......@@ -748,8 +1042,9 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("label: %s (id = %d)", example.label, label_id)
logging.info("label: %s (id = %s)", example.label, str(label_id))
logging.info("weight: %s", example.weight)
logging.info("int_iden: %s", str(example.int_iden))
feature = InputFeatures(
input_ids=input_ids,
......@@ -757,19 +1052,24 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
segment_ids=segment_ids,
label_id=label_id,
is_real_example=True,
weight=example.weight)
weight=example.weight,
int_iden=example.int_iden)
return feature
def file_based_convert_examples_to_features(examples, label_list,
max_seq_length, tokenizer,
output_file, label_type=None):
def file_based_convert_examples_to_features(examples,
label_list,
max_seq_length,
tokenizer,
output_file,
label_type=None):
"""Convert a set of `InputExample`s to a TFRecord file."""
tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples):
for ex_index, example in enumerate(examples):
if ex_index % 10000 == 0:
logging.info("Writing example %d of %d", ex_index, len(examples))
......@@ -779,6 +1079,7 @@ def file_based_convert_examples_to_features(examples, label_list,
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
......@@ -789,12 +1090,14 @@ def file_based_convert_examples_to_features(examples, label_list,
features["segment_ids"] = create_int_feature(feature.segment_ids)
if label_type is not None and label_type == float:
features["label_ids"] = create_float_feature([feature.label_id])
else:
elif feature.label_id is not None:
features["label_ids"] = create_int_feature([feature.label_id])
features["is_real_example"] = create_int_feature(
[int(feature.is_real_example)])
if feature.weight is not None:
features["weight"] = create_float_feature([feature.weight])
if feature.int_iden is not None:
features["int_iden"] = create_int_feature([feature.int_iden])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
......@@ -830,8 +1133,7 @@ def generate_tf_record_from_data_file(processor,
Arguments:
processor: Input processor object to be used for generating data. Subclass
of `DataProcessor`.
data_dir: Directory that contains train/eval data to process. Data files
should be in from "dev.tsv", "test.tsv", or "train.tsv".
data_dir: Directory that contains train/eval/test data to process.
tokenizer: The tokenizer to be applied on the data.
train_data_output_path: Output to which processed tf record for training
will be saved.
......@@ -857,8 +1159,7 @@ def generate_tf_record_from_data_file(processor,
train_input_data_examples = processor.get_train_examples(data_dir)
file_based_convert_examples_to_features(train_input_data_examples, label_list,
max_seq_length, tokenizer,
train_data_output_path,
label_type)
train_data_output_path, label_type)
num_training_data = len(train_input_data_examples)
if eval_data_output_path:
......@@ -868,26 +1169,27 @@ def generate_tf_record_from_data_file(processor,
tokenizer, eval_data_output_path,
label_type)
meta_data = {
"processor_type": processor.get_processor_name(),
"train_data_size": num_training_data,
"max_seq_length": max_seq_length,
}
if test_data_output_path:
test_input_data_examples = processor.get_test_examples(data_dir)
if isinstance(test_input_data_examples, dict):
for language, examples in test_input_data_examples.items():
file_based_convert_examples_to_features(
examples,
label_list, max_seq_length,
tokenizer, test_data_output_path.format(language),
label_type)
examples, label_list, max_seq_length, tokenizer,
test_data_output_path.format(language), label_type)
meta_data["test_{}_data_size".format(language)] = len(examples)
else:
file_based_convert_examples_to_features(test_input_data_examples,
label_list, max_seq_length,
tokenizer, test_data_output_path,
label_type)
meta_data["test_data_size"] = len(test_input_data_examples)
meta_data = {
"processor_type": processor.get_processor_name(),
"train_data_size": num_training_data,
"max_seq_length": max_seq_length,
}
if is_regression:
meta_data["task_type"] = "bert_regression"
meta_data["label_type"] = {int: "int", float: "float"}[label_type]
......@@ -900,12 +1202,4 @@ def generate_tf_record_from_data_file(processor,
if eval_data_output_path:
meta_data["eval_data_size"] = len(eval_input_data_examples)
if test_data_output_path:
test_input_data_examples = processor.get_test_examples(data_dir)
if isinstance(test_input_data_examples, dict):
for language, examples in test_input_data_examples.items():
meta_data["test_{}_data_size".format(language)] = len(examples)
else:
meta_data["test_data_size"] = len(test_input_data_examples)
return meta_data
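
# Illustrative sketch (not part of the original file): generating MRPC
# TFRecords with the helper above. The paths, the tokenizer construction, and
# the keyword-argument spelling are assumptions made for this example only.
#
#   processor = MrpcProcessor()
#   tokenizer = tokenization.FullTokenizer(
#       vocab_file="/path/to/vocab.txt", do_lower_case=True)
#   meta_data = generate_tf_record_from_data_file(
#       processor,
#       "/path/to/MRPC",
#       tokenizer,
#       train_data_output_path="/tmp/mrpc_train.tf_record",
#       eval_data_output_path="/tmp/mrpc_eval.tf_record",
#       max_seq_length=128)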
......@@ -27,18 +27,21 @@ from absl import flags
import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib
FLAGS = flags.FLAGS
# TODO(chendouble): consider moving each task to its own binary.
flags.DEFINE_enum(
"fine_tuning_task_type", "classification",
["classification", "regression", "squad"],
["classification", "regression", "squad", "retrieval", "tagging"],
"The name of the BERT fine tuning task for which data "
"will be generated..")
"will be generated.")
# BERT classification specific flags.
flags.DEFINE_string(
......@@ -47,23 +50,41 @@ flags.DEFINE_string(
"for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
"PAWS-X"],
"The name of the task to train BERT classifier.")
["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI",
"XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
# MNLI task-specific flag.
flags.DEFINE_enum(
"mnli_type", "matched", ["matched", "mismatched"],
"The type of MNLI dataset.")
# XNLI task specific flag.
# XNLI task-specific flag.
flags.DEFINE_string(
"xnli_language", "en",
"Language of training data for XNIL task. If the value is 'all', the data "
"Language of training data for XNLI task. If the value is 'all', the data "
"of all languages will be used for training.")
# PAWS-X task specific flag.
# PAWS-X task-specific flag.
flags.DEFINE_string(
"pawsx_language", "en",
"Language of trainig data for PAWS-X task. If the value is 'all', the data "
"Language of training data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training.")
# BERT Squad task specific flags.
# Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
"The name of sentence retrieval task for scoring")
# Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
"The name of BERT tagging (token classification) task.")
# BERT Squad task-specific flags.
flags.DEFINE_string(
"squad_data_file", None,
"The input data file in for generating training data for BERT squad task.")
......@@ -163,20 +184,29 @@ def generate_classifier_dataset():
"cola":
classifier_data_lib.ColaProcessor,
"mnli":
classifier_data_lib.MnliProcessor,
functools.partial(classifier_data_lib.MnliProcessor,
mnli_type=FLAGS.mnli_type),
"mrpc":
classifier_data_lib.MrpcProcessor,
"qnli":
classifier_data_lib.QnliProcessor,
"qqp": classifier_data_lib.QqpProcessor,
"rte": classifier_data_lib.RteProcessor,
"sst-2":
classifier_data_lib.SstProcessor,
"sts-b":
classifier_data_lib.StsBProcessor,
"xnli":
functools.partial(classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language),
"paws-x":
functools.partial(classifier_data_lib.PawsxProcessor,
language=FLAGS.pawsx_language)
language=FLAGS.pawsx_language),
"wnli": classifier_data_lib.WnliProcessor,
"xtreme-xnli":
functools.partial(classifier_data_lib.XtremeXnliProcessor),
"xtreme-paws-x":
functools.partial(classifier_data_lib.XtremePawsxProcessor)
}
task_name = FLAGS.classification_task_name.lower()
if task_name not in processors:
......@@ -237,6 +267,67 @@ def generate_squad_dataset():
FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)
def generate_retrieval_dataset():
"""Generate retrieval test and dev dataset and returns input meta data."""
assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
if FLAGS.tokenizer_impl == "word_piece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
processors = {
"bucc": sentence_retrieval_lib.BuccProcessor,
"tatoeba": sentence_retrieval_lib.TatoebaProcessor,
}
task_name = FLAGS.retrieval_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % task_name)
processor = processors[task_name](process_text_fn=processor_text_fn)
return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
processor,
FLAGS.input_data_dir,
tokenizer,
FLAGS.eval_data_output_path,
FLAGS.test_data_output_path,
FLAGS.max_seq_length)
def generate_tagging_dataset():
"""Generates tagging dataset."""
processors = {
"panx": tagging_data_lib.PanxProcessor,
"udpos": tagging_data_lib.UdposProcessor,
}
task_name = FLAGS.tagging_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % task_name)
if FLAGS.tokenizer_impl == "word_piece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
elif FLAGS.tokenizer_impl == "sentence_piece":
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
else:
raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
processor = processors[task_name]()
return tagging_data_lib.generate_tf_record_from_data_file(
processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
FLAGS.test_data_output_path, processor_text_fn)
def main(_):
if FLAGS.tokenizer_impl == "word_piece":
if not FLAGS.vocab_file:
......@@ -248,12 +339,20 @@ def main(_):
raise ValueError(
"FLAG sp_model_file for sentence-piece tokenizer is not specified.")
if FLAGS.fine_tuning_task_type != "retrieval":
flags.mark_flag_as_required("train_data_output_path")
if FLAGS.fine_tuning_task_type == "classification":
input_meta_data = generate_classifier_dataset()
elif FLAGS.fine_tuning_task_type == "regression":
input_meta_data = generate_regression_dataset()
else:
elif FLAGS.fine_tuning_task_type == "retrieval":
input_meta_data = generate_retrieval_dataset()
elif FLAGS.fine_tuning_task_type == "squad":
input_meta_data = generate_squad_dataset()
else:
assert FLAGS.fine_tuning_task_type == "tagging"
input_meta_data = generate_tagging_dataset()
tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
......@@ -261,6 +360,5 @@ def main(_):
if __name__ == "__main__":
flags.mark_flag_as_required("train_data_output_path")
flags.mark_flag_as_required("meta_data_file_path")
app.run(main)
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
import itertools
import random
from absl import app
......@@ -48,6 +49,12 @@ flags.DEFINE_bool(
"do_whole_word_mask", False,
"Whether to use whole word masking rather than per-WordPiece masking.")
flags.DEFINE_integer(
"max_ngram_size", None,
"Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
"weighting scheme to favor shorter n-grams. "
"Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
flags.DEFINE_bool(
"gzip_compress", False,
"Whether to use `GZIP` compress option to get compressed TFRecord files.")
......@@ -192,7 +199,8 @@ def create_training_instances(input_files,
masked_lm_prob,
max_predictions_per_seq,
rng,
do_whole_word_mask=False):
do_whole_word_mask=False,
max_ngram_size=None):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
......@@ -229,7 +237,7 @@ def create_training_instances(input_files,
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask))
do_whole_word_mask, max_ngram_size))
rng.shuffle(instances)
return instances
......@@ -238,7 +246,8 @@ def create_training_instances(input_files,
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask=False):
do_whole_word_mask=False,
max_ngram_size=None):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
......@@ -337,7 +346,7 @@ def create_instances_from_document(
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask)
do_whole_word_mask, max_ngram_size)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
......@@ -355,72 +364,238 @@ def create_instances_from_document(
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
# A _Gram is a [half-open) interval of token indices which form a word.
# E.g.,
# words: ["The", "doghouse"]
# tokens: ["The", "dog", "##house"]
# grams: [(0,1), (1,3)]
_Gram = collections.namedtuple("_Gram", ["begin", "end"])
def _window(iterable, size):
"""Helper to create a sliding window iterator with a given size.
E.g.,
input = [1, 2, 3, 4]
_window(input, 1) => [1], [2], [3], [4]
_window(input, 2) => [1, 2], [2, 3], [3, 4]
_window(input, 3) => [1, 2, 3], [2, 3, 4]
_window(input, 4) => [1, 2, 3, 4]
_window(input, 5) => None
Arguments:
iterable: elements to iterate over.
size: size of the window.
Yields:
Elements of `iterable` batched into a sliding window of length `size`.
"""
i = iter(iterable)
window = []
try:
for e in range(0, size):
window.append(next(i))
yield window
except StopIteration:
# handle the case where iterable's length is less than the window size.
return
for e in i:
window = window[1:] + [e]
yield window
def _contiguous(sorted_grams):
"""Test whether a sequence of grams is contiguous.
Arguments:
sorted_grams: _Grams which are sorted in increasing order.
Returns:
True if `sorted_grams` are touching each other.
E.g.,
_contiguous([(1, 4), (4, 5), (5, 10)]) == True
_contiguous([(1, 2), (4, 5)]) == False
"""
for a, b in _window(sorted_grams, 2):
if a.end != b.begin:
return False
return True
def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
"""Create a list of masking {1, ..., n}-grams from a list of one-grams.
  This is an extension of 'whole word masking' to mask multiple contiguous
  words (e.g., "the red boat").
Each input gram represents the token indices of a single word,
words: ["the", "red", "boat"]
tokens: ["the", "red", "boa", "##t"]
grams: [(0,1), (1,2), (2,4)]
  For a `max_ngram_size` of three, possible output masks include:
    1-grams: (0,1), (1,2), (2,4)
    2-grams: (0,2), (1,4)
    3-grams: (0,4)
  Output masks will not overlap and will contain no more than
  `max_masked_tokens` total tokens. E.g., for the example above with
  `max_masked_tokens` of three, valid outputs are:
[(0,1), (1,2)] # "the", "red" covering two tokens
[(1,2), (2,4)] # "red", "boa", "##t" covering three tokens
  The length of each selected n-gram follows a Zipf weighting to favor
  shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
Arguments:
grams: List of one-grams.
max_ngram_size: Maximum number of contiguous one-grams combined to create
an n-gram.
max_masked_tokens: Maximum total number of tokens to be masked.
rng: `random.Random` generator.
Returns:
A list of n-grams to be used as masks.
"""
if not grams:
return None
grams = sorted(grams)
num_tokens = grams[-1].end
# Ensure our grams are valid (i.e., they don't overlap).
for a, b in _window(grams, 2):
if a.end > b.begin:
raise ValueError("overlapping grams: {}".format(grams))
# Build map from n-gram length to list of n-grams.
ngrams = {i: [] for i in range(1, max_ngram_size+1)}
for gram_size in range(1, max_ngram_size+1):
for g in _window(grams, gram_size):
if _contiguous(g):
# Add an n-gram which spans these one-grams.
ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
# Shuffle each list of n-grams.
for v in ngrams.values():
rng.shuffle(v)
# Create the weighting for n-gram length selection.
  # Stored cumulatively for `random.choices` below.
cummulative_weights = list(
itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))
output_ngrams = []
# Keep a bitmask of which tokens have been masked.
masked_tokens = [False] * num_tokens
# Loop until we have enough masked tokens or there are no more candidate
# n-grams of any length.
# Each code path should ensure one or more elements from `ngrams` are removed
  # to guarantee this loop terminates.
while (sum(masked_tokens) < max_masked_tokens and
sum(len(s) for s in ngrams.values())):
# Pick an n-gram size based on our weights.
sz = random.choices(range(1, max_ngram_size+1),
cum_weights=cummulative_weights)[0]
# Ensure this size doesn't result in too many masked tokens.
# E.g., a two-gram contains _at least_ two tokens.
if sum(masked_tokens) + sz > max_masked_tokens:
# All n-grams of this length are too long and can be removed from
# consideration.
ngrams[sz].clear()
continue
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask):
"""Creates the predictions for the masked LM objective."""
# All of the n-grams of this size have been used.
if not ngrams[sz]:
continue
# Choose a random n-gram of the given size.
gram = ngrams[sz].pop()
num_gram_tokens = gram.end-gram.begin
# Check if this would add too many tokens.
if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
continue
# Check if any of the tokens in this gram have already been masked.
if sum(masked_tokens[gram.begin:gram.end]):
continue
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
# Found a usable n-gram! Mark its tokens as masked and add it to return.
masked_tokens[gram.begin:gram.end] = [True] * (gram.end-gram.begin)
output_ngrams.append(gram)
return output_ngrams
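
# Illustrative sketch (not part of the original file): masking the docstring's
# "the red boat" example above, assuming a seeded `random.Random` instance.
#
#   grams = [_Gram(0, 1), _Gram(1, 2), _Gram(2, 4)]  # "the", "red", "boa ##t"
#   masks = _masking_ngrams(grams, max_ngram_size=3, max_masked_tokens=3,
#                           rng=random.Random(12345))
#   # One possible result: [_Gram(1, 4)] -- the 2-gram "red boat", covering
#   # three tokens. Returned n-grams never overlap and never cover more than
#   # `max_masked_tokens` tokens in total.
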
def _wordpieces_to_grams(tokens):
"""Reconstitue grams (words) from `tokens`.
E.g.,
tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
    grams:  [ [1,2), [2,4), [4,5), [5,7) ]
Arguments:
tokens: list of wordpieces
Returns:
List of _Grams representing spans of whole words
(without "[CLS]" and "[SEP]").
"""
grams = []
gram_start_pos = None
for i, token in enumerate(tokens):
if gram_start_pos is not None and token.startswith("##"):
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word. When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
if gram_start_pos is not None:
grams.append(_Gram(gram_start_pos, i))
if token not in ["[CLS]", "[SEP]"]:
gram_start_pos = i
else:
cand_indexes.append([i])
gram_start_pos = None
if gram_start_pos is not None:
grams.append(_Gram(gram_start_pos, len(tokens)))
return grams
rng.shuffle(cand_indexes)
output_tokens = list(tokens)
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask,
max_ngram_size=None):
"""Creates the predictions for the masked LM objective."""
if do_whole_word_mask:
grams = _wordpieces_to_grams(tokens)
else:
# Here we consider each token to be a word to allow for sub-word masking.
if max_ngram_size:
raise ValueError("cannot use ngram masking without whole word masking")
grams = [_Gram(i, i+1) for i in range(0, len(tokens))
if tokens[i] not in ["[CLS]", "[SEP]"]]
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
# Generate masks. If `max_ngram_size` is in [0, None], it means we're doing
# whole word masking or token level masking. Both of these can be treated
# as the `max_ngram_size=1` case.
masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
num_to_predict, rng)
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
output_tokens = list(tokens)
for gram in masked_grams:
# 80% of the time, replace all n-gram tokens with [MASK]
if rng.random() < 0.8:
replacement_action = lambda idx: "[MASK]"
else:
# 10% of the time, keep all the original n-gram tokens.
if rng.random() < 0.5:
replacement_action = lambda idx: tokens[idx]
# 10% of the time, replace each n-gram token with a random word.
else:
# 10% of the time, keep original
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
replacement_action = lambda idx: rng.choice(vocab_words)
output_tokens[index] = masked_token
for idx in range(gram.begin, gram.end):
output_tokens[idx] = replacement_action(idx)
masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
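For reference, the per-gram replacement logic above follows BERT's 80/10/10 split, applied to every token of a chosen n-gram at once. A minimal sketch of that selection step (the helper name is illustrative and not part of the script):

```python
import random

def _choose_replacement_action(tokens, vocab_words, rng):
  """Illustrative helper: decides how one masked n-gram is rewritten."""
  if rng.random() < 0.8:
    return lambda idx: "[MASK]"                # 80%: mask every token in the gram
  if rng.random() < 0.5:
    return lambda idx: tokens[idx]             # 10%: keep the original tokens
  return lambda idx: rng.choice(vocab_words)   # 10%: random word per token

# Example: pick an action with a seeded generator, then apply it per index.
rng = random.Random(12345)
action = _choose_replacement_action(["the", "cat"], ["dog", "tree"], rng)
```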
......@@ -467,7 +642,7 @@ def main(_):
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng, FLAGS.do_whole_word_mask)
rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
output_files = FLAGS.output_file.split(",")
logging.info("*** Writing to output files ***")
......
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,23 +13,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to access NLP registered data loaders."""
"""Define the abstract class for contextual bandit algorithms."""
from official.utils import registry
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
_REGISTERED_DATA_LOADER_CLS = {}
class BanditAlgorithm(object):
"""A bandit algorithm must be able to do two basic operations.
def register_data_loader_cls(data_config_cls):
"""Decorates a factory of DataLoader for lookup by a subclass of DataConfig.
1. Choose an action given a context.
2. Update its internal model given a triple (context, played action, reward).
"""
This decorator supports registration of data loaders as follows:
def action(self, context):
```
@dataclasses.dataclass
class MyDataConfig(DataConfig):
# Add fields here.
pass
def update(self, context, action, reward):
@register_data_loader_cls(MyDataConfig)
class MyDataLoader:
# Inherits def __init__(self, data_config).
pass
my_data_config = MyDataConfig()
# Returns MyDataLoader(my_data_config).
my_loader = get_data_loader(my_data_config)
```
Args:
data_config_cls: a subclass of DataConfig (*not* an instance
of DataConfig).
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of data_config_cls.
"""
return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)
def get_data_loader(data_config):
"""Creates a data_loader from data_config."""
return registry.lookup(_REGISTERED_DATA_LOADER_CLS, data_config.__class__)(
data_config)
......@@ -16,11 +16,27 @@
"""Loads dataset for the BERT pretraining task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
"""Data config for BERT pretraining task (tasks/masked_lm)."""
input_path: str = ''
global_batch_size: int = 512
is_training: bool = True
seq_length: int = 512
max_predictions_per_seq: int = 76
use_next_sentence_label: bool = True
use_position_id: bool = False
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader:
"""A class to load dataset for bert pretraining task."""
......@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params,
decoder_fn=self._decode,
parser_fn=self._parse)
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
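Taken together, the registry and the config class above let a training job obtain the loader without referencing the loader class directly. A minimal sketch, assuming the defining modules are importable (the input path is a placeholder):

```python
from official.nlp.data import data_loader_factory
# Also import the module that defines BertPretrainDataConfig so the
# register_data_loader_cls decorator runs (module path not shown in this diff).

config = BertPretrainDataConfig(
    input_path="/path/to/pretrain*.tfrecord",  # placeholder path
    global_batch_size=512,
    seq_length=512,
    max_predictions_per_seq=76)

# The decorator registered BertPretrainDataLoader for this config type,
# so the factory returns BertPretrainDataLoader(config).
loader = data_loader_factory.get_data_loader(config)
dataset = loader.load()  # a tf.data.Dataset of (features, labels)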
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the question answering (e.g, SQuAD) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
"""Data config for question answering task (tasks/question_answering)."""
input_path: str = ''
global_batch_size: int = 48
is_training: bool = True
seq_length: int = 384
# Settings below are question answering specific.
version_2_with_negative: bool = False
# Settings below are only used for eval mode.
input_preprocessed_data_path: str = ''
doc_stride: int = 128
query_length: int = 64
vocab_file: str = ''
tokenization: str = 'WordPiece' # WordPiece or SentencePiece
do_lower_case: bool = True
@data_loader_factory.register_data_loader_cls(QADataConfig)
class QuestionAnsweringDataLoader:
"""A class to load dataset for sentence prediction (classification) task."""
def __init__(self, params):
self._params = params
self._seq_length = params.seq_length
self._is_training = params.is_training
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
if self._is_training:
name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
else:
name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x, y = {}, {}
for name, tensor in record.items():
if name in ('start_positions', 'end_positions'):
y[name] = tensor
elif name == 'input_ids':
x['input_word_ids'] = tensor
elif name == 'segment_ids':
x['input_type_ids'] = tensor
else:
x[name] = tensor
return (x, y)
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
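As with the pretraining loader, a QA dataset is obtained through the factory; `_parse` above maps the TF Example features to the input names the model expects. A sketch with placeholder paths:

```python
config = QADataConfig(
    input_path="/path/to/squad_train.tfrecord",  # placeholder path
    is_training=True,
    seq_length=384)
loader = data_loader_factory.get_data_loader(config)  # QuestionAnsweringDataLoader
dataset = loader.load()
# Each element is (x, y) where, in training mode,
#   x has keys 'input_word_ids', 'input_mask', 'input_type_ids'
#   y has keys 'start_positions', 'end_positions'
```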
......@@ -15,11 +15,28 @@
# ==============================================================================
"""Loads dataset for the sentence prediction (classification) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
input_path: str = ''
global_batch_size: int = 32
is_training: bool = True
seq_length: int = 128
label_type: str = 'int'
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
class SentencePredictionDataLoader:
"""A class to load dataset for sentence prediction (classification) task."""
......@@ -29,11 +46,12 @@ class SentencePredictionDataLoader:
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
label_type = LABEL_TYPES_MAP[self._params.label_type]
name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64),
'label_ids': tf.io.FixedLenFeature([], label_type),
}
example = tf.io.parse_single_example(record, name_to_features)
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT library to process data for cross lingual sentence retrieval task."""
import os
from absl import logging
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
class BuccProcessor(classifier_data_lib.DataProcessor):
"""Procssor for Xtreme BUCC data set."""
supported_languages = ["de", "fr", "ru", "zh"]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode):
super(BuccProcessor, self).__init__(process_text_fn)
self.languages = BuccProcessor.supported_languages
def get_dev_examples(self, data_dir, file_pattern):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
"sample")
def get_test_examples(self, data_dir, file_pattern):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
"test")
@staticmethod
def get_processor_name():
"""See base class."""
return "BUCC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
int_iden = int(line[0].split("-")[1])
text_a = self.process_text_fn(line[1])
examples.append(
classifier_data_lib.InputExample(
guid=guid, text_a=text_a, int_iden=int_iden))
return examples
class TatoebaProcessor(classifier_data_lib.DataProcessor):
"""Procssor for Xtreme Tatoeba data set."""
supported_languages = [
"af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
"he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
"nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode):
super(TatoebaProcessor, self).__init__(process_text_fn)
self.languages = TatoebaProcessor.supported_languages
def get_test_examples(self, data_dir, file_path):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, file_path)), "test")
@staticmethod
def get_processor_name():
"""See base class."""
return "TATOEBA"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line[0])
examples.append(
classifier_data_lib.InputExample(
guid=guid, text_a=text_a, int_iden=i))
return examples
def generate_sentence_retrevial_tf_record(processor,
data_dir,
tokenizer,
eval_data_output_path=None,
test_data_output_path=None,
max_seq_length=128):
"""Generates the tf records for retrieval tasks.
Args:
processor: Input processor object to be used for generating data. Subclass
of `DataProcessor`.
data_dir: Directory that contains train/eval data to process. Data files
should follow the naming pattern expected by the processor.
tokenizer: The tokenizer to be applied on the data.
eval_data_output_path: Output to which processed tf record for evaluation
will be saved.
test_data_output_path: Output to which processed tf record for testing
will be saved. Must be a pattern template with {} if processor has
language specific test data.
max_seq_length: Maximum sequence length of the training/eval data to be
generated.
Returns:
A dictionary containing input meta data.
"""
assert eval_data_output_path or test_data_output_path
if processor.get_processor_name() == "BUCC":
path_pattern = "{}-en.{{}}.{}"
if processor.get_processor_name() == "TATOEBA":
path_pattern = "{}-en.{}"
meta_data = {
"processor_type": processor.get_processor_name(),
"max_seq_length": max_seq_length,
"number_eval_data": {},
"number_test_data": {},
}
logging.info("Start to process %s task data", processor.get_processor_name())
for lang_a in processor.languages:
for lang_b in [lang_a, "en"]:
if eval_data_output_path:
eval_input_data_examples = processor.get_dev_examples(
data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
num_eval_data = len(eval_input_data_examples)
logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
lang_a, lang_b)
output_file = os.path.join(
eval_data_output_path,
"{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
classifier_data_lib.file_based_convert_examples_to_features(
eval_input_data_examples, None, max_seq_length, tokenizer,
output_file, None)
meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data
if test_data_output_path:
test_input_data_examples = processor.get_test_examples(
data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
num_test_data = len(test_input_data_examples)
logging.info("Processing %d test examples of %s-en.%s", num_test_data,
lang_a, lang_b)
output_file = os.path.join(
test_data_output_path,
"{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
classifier_data_lib.file_based_convert_examples_to_features(
test_input_data_examples, None, max_seq_length, tokenizer,
output_file, None)
meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data
return meta_data
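A hedged usage sketch for the function above; the data directory, vocab file, and output paths are placeholders, and `FullTokenizer` is the standard WordPiece tokenizer from `official.nlp.bert.tokenization`:

```python
tokenizer = tokenization.FullTokenizer(
    vocab_file="/path/to/vocab.txt", do_lower_case=True)
processor = BuccProcessor()
meta_data = generate_sentence_retrevial_tf_record(
    processor=processor,
    data_dir="/path/to/bucc",
    tokenizer=tokenizer,
    eval_data_output_path="/path/to/output/eval",
    test_data_output_path="/path/to/output/test",
    max_seq_length=128)
# meta_data["number_eval_data"] maps "<lang>-en.<lang>" keys to example counts.
```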
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to process data for tagging task such as NER/POS."""
import collections
import os
from absl import logging
import tensorflow as tf
from official.nlp.data import classifier_data_lib
# A negative label id for the padding label, which will not contribute
# to loss/metrics in training.
_PADDING_LABEL_ID = -1
# The special unknown token, used to substitute a word which has too many
# subwords after tokenization.
_UNK_TOKEN = "[UNK]"
class InputExample(object):
"""A single training/test example for token classification."""
def __init__(self, sentence_id, words=None, label_ids=None):
"""Constructs an InputExample."""
self.sentence_id = sentence_id
self.words = words if words else []
self.label_ids = label_ids if label_ids else []
def add_word_and_label_id(self, word, label_id):
"""Adds word and label_id pair in the example."""
self.words.append(word)
self.label_ids.append(label_id)
def _read_one_file(file_name, label_list):
"""Reads one file and returns a list of `InputExample` instances."""
lines = tf.io.gfile.GFile(file_name, "r").readlines()
examples = []
label_id_map = {label: i for i, label in enumerate(label_list)}
sentence_id = 0
example = InputExample(sentence_id=0)
for line in lines:
line = line.strip("\n")
if line:
# The format is: <token>\t<label> for train/dev set and <token> for test.
items = line.split("\t")
assert len(items) == 2 or len(items) == 1
token = items[0].strip()
# Assign a dummy label_id for test set
label_id = label_id_map[items[1].strip()] if len(items) == 2 else 0
example.add_word_and_label_id(token, label_id)
else:
# Empty line indicates a new sentence.
if example.words:
examples.append(example)
sentence_id += 1
example = InputExample(sentence_id=sentence_id)
if example.words:
examples.append(example)
return examples
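To make the expected file layout concrete, a small hypothetical sketch of what `_read_one_file` does with a two-sentence file, using the NER label list returned by `PanxProcessor.get_labels()` below (the temp-file setup is illustrative only):

```python
import tempfile

# Hypothetical file: token<TAB>label per line, blank line between sentences.
content = "John\tB-PER\nlives\tO\nin\tO\nParis\tB-LOC\n\nBonjour\tO\n"
with tempfile.NamedTemporaryFile("w", suffix=".tsv", delete=False) as f:
  f.write(content)

labels = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"]
examples = _read_one_file(f.name, labels)
# examples[0].words     == ["John", "lives", "in", "Paris"]
# examples[0].label_ids == [1, 0, 0, 3]    # indexes into `labels`
# examples[1].words     == ["Bonjour"], examples[1].label_ids == [0]
```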
class PanxProcessor(classifier_data_lib.DataProcessor):
"""Processor for the Panx data set."""
supported_languages = [
"ar", "he", "vi", "id", "jv", "ms", "tl", "eu", "ml", "ta", "te", "af",
"nl", "en", "de", "el", "bn", "hi", "mr", "ur", "fa", "fr", "it", "pt",
"es", "bg", "ru", "ja", "ka", "ko", "th", "sw", "yo", "my", "zh", "kk",
"tr", "et", "fi", "hu"
]
def get_train_examples(self, data_dir):
return _read_one_file(
os.path.join(data_dir, "train-en.tsv"), self.get_labels())
def get_dev_examples(self, data_dir):
return _read_one_file(
os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
def get_test_examples(self, data_dir):
examples_dict = {}
for language in self.supported_languages:
examples_dict[language] = _read_one_file(
os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
return examples_dict
def get_labels(self):
return ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"]
@staticmethod
def get_processor_name():
return "panx"
class UdposProcessor(classifier_data_lib.DataProcessor):
"""Processor for the Udpos data set."""
supported_languages = [
"af", "ar", "bg", "de", "el", "en", "es", "et", "eu", "fa", "fi", "fr",
"he", "hi", "hu", "id", "it", "ja", "kk", "ko", "mr", "nl", "pt", "ru",
"ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
]
def get_train_examples(self, data_dir):
return _read_one_file(
os.path.join(data_dir, "train-en.tsv"), self.get_labels())
def get_dev_examples(self, data_dir):
return _read_one_file(
os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
def get_test_examples(self, data_dir):
examples_dict = {}
for language in self.supported_languages:
examples_dict[language] = _read_one_file(
os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
return examples_dict
def get_labels(self):
return [
"ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
"PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X"
]
@staticmethod
def get_processor_name():
return "udpos"
def _tokenize_example(example, max_length, tokenizer, text_preprocessing=None):
"""Tokenizes words and breaks long example into short ones."""
# Needs additional [CLS] and [SEP] tokens.
max_length = max_length - 2
new_examples = []
new_example = InputExample(sentence_id=example.sentence_id)
for i, word in enumerate(example.words):
if any([x < 0 for x in example.label_ids]):
raise ValueError("Unexpected negative label_id: %s" % example.label_ids)
if text_preprocessing:
word = text_preprocessing(word)
subwords = tokenizer.tokenize(word)
if (not subwords or len(subwords) > max_length) and word:
subwords = [_UNK_TOKEN]
if len(subwords) + len(new_example.words) > max_length:
# Start a new example.
new_examples.append(new_example)
new_example = InputExample(sentence_id=example.sentence_id)
for j, subword in enumerate(subwords):
# Use the real label for the first subword, and the padding label for
# the remaining subwords.
subword_label = example.label_ids[i] if j == 0 else _PADDING_LABEL_ID
new_example.add_word_and_label_id(subword, subword_label)
if new_example.words:
new_examples.append(new_example)
return new_examples
def _convert_single_example(example, max_seq_length, tokenizer):
"""Converts an `InputExample` instance to a `tf.train.Example` instance."""
tokens = ["[CLS]"]
tokens.extend(example.words)
tokens.append("[SEP]")
input_ids = tokenizer.convert_tokens_to_ids(tokens)
label_ids = [_PADDING_LABEL_ID]
label_ids.extend(example.label_ids)
label_ids.append(_PADDING_LABEL_ID)
segment_ids = [0] * len(input_ids)
input_mask = [1] * len(input_ids)
# Pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
label_ids.append(_PADDING_LABEL_ID)
def create_int_feature(values):
return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(input_ids)
features["input_mask"] = create_int_feature(input_mask)
features["segment_ids"] = create_int_feature(segment_ids)
features["label_ids"] = create_int_feature(label_ids)
features["sentence_id"] = create_int_feature([example.sentence_id])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
return tf_example
def write_example_to_file(examples,
tokenizer,
max_seq_length,
output_file,
text_preprocessing=None):
"""Writes `InputExample`s into a tfrecord file with `tf.train.Example` protos.
Note that the words inside each example will be preprocessed by
`text_preprocessing` (if provided) and then tokenized. Also, if the length of a
sentence (plus the special [CLS] and [SEP] tokens) exceeds `max_seq_length`,
the long sentence will be broken into multiple short examples. For example:
Example (text_preprocessing=lowercase, max_seq_length=5)
words: ["What", "a", "great", "weekend"]
labels: [ 7, 5, 9, 10]
sentence_id: 0
preprocessed: ["what", "a", "great", "weekend"]
tokenized: ["what", "a", "great", "week", "##end"]
will result in two tf.example protos:
tokens: ["[CLS]", "what", "a", "great", "[SEP]"]
label_ids: [-1, 7, 5, 9, -1]
input_mask: [ 1, 1, 1, 1, 1]
segment_ids: [ 0, 0, 0, 0, 0]
input_ids: [ tokenizer.convert_tokens_to_ids(tokens) ]
sentence_id: 0
tokens: ["[CLS]", "week", "##end", "[SEP]", "[PAD]"]
label_ids: [-1, 10, -1, -1, -1]
input_mask: [ 1, 1, 1, 0, 0]
segment_ids: [ 0, 0, 0, 0, 0]
input_ids: [ tokenizer.convert_tokens_to_ids(tokens) ]
sentence_id: 0
Note the use of -1 in `label_ids` to indicate that a token should not be
considered for classification (e.g., trailing ## wordpieces or special
tokens). Token classification models should accordingly ignore these when
calculating loss, metrics, etc.
Args:
examples: A list of `InputExample` instances.
tokenizer: The tokenizer to be applied on the data.
max_seq_length: Maximum length of generated sequences.
output_file: The name of the output tfrecord file.
text_preprocessing: optional preprocessing run on each word prior to
tokenization.
Returns:
The total number of tf.train.Example protos written to the file.
"""
tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file)
num_tokenized_examples = 0
for (ex_index, example) in enumerate(examples):
if ex_index % 10000 == 0:
logging.info("Writing example %d of %d to %s", ex_index, len(examples),
output_file)
tokenized_examples = _tokenize_example(example, max_seq_length,
tokenizer, text_preprocessing)
num_tokenized_examples += len(tokenized_examples)
for per_tokenized_example in tokenized_examples:
tf_example = _convert_single_example(
per_tokenized_example, max_seq_length, tokenizer)
writer.write(tf_example.SerializeToString())
writer.close()
return num_tokenized_examples
def token_classification_meta_data(train_data_size,
max_seq_length,
num_labels,
eval_data_size=None,
test_data_size=None,
label_list=None,
processor_type=None):
"""Creates metadata for tagging (token classification) datasets."""
meta_data = {
"train_data_size": train_data_size,
"max_seq_length": max_seq_length,
"num_labels": num_labels,
"task_type": "tagging",
"label_type": "int",
"label_shape": [max_seq_length],
}
if eval_data_size:
meta_data["eval_data_size"] = eval_data_size
if test_data_size:
meta_data["test_data_size"] = test_data_size
if label_list:
meta_data["label_list"] = label_list
if processor_type:
meta_data["processor_type"] = processor_type
return meta_data
def generate_tf_record_from_data_file(processor,
data_dir,
tokenizer,
max_seq_length,
train_data_output_path,
eval_data_output_path,
test_data_output_path,
text_preprocessing):
"""Generates tfrecord files from the raw data."""
common_kwargs = dict(tokenizer=tokenizer, max_seq_length=max_seq_length,
text_preprocessing=text_preprocessing)
train_examples = processor.get_train_examples(data_dir)
train_data_size = write_example_to_file(
train_examples, output_file=train_data_output_path, **common_kwargs)
eval_examples = processor.get_dev_examples(data_dir)
eval_data_size = write_example_to_file(
eval_examples, output_file=eval_data_output_path, **common_kwargs)
test_input_data_examples = processor.get_test_examples(data_dir)
test_data_size = {}
for language, examples in test_input_data_examples.items():
test_data_size[language] = write_example_to_file(
examples,
output_file=test_data_output_path.format(language),
**common_kwargs)
labels = processor.get_labels()
meta_data = token_classification_meta_data(
train_data_size,
max_seq_length,
len(labels),
eval_data_size,
test_data_size,
label_list=labels,
processor_type=processor.get_processor_name())
return meta_data
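A hedged end-to-end sketch for the driver function above, using the Panx processor; paths and the vocab file are placeholders, and the test output path must contain a `{}` slot for the language code:

```python
from official.nlp.bert import tokenization

tokenizer = tokenization.FullTokenizer(
    vocab_file="/path/to/vocab.txt", do_lower_case=True)
meta_data = generate_tf_record_from_data_file(
    processor=PanxProcessor(),
    data_dir="/path/to/panx",
    tokenizer=tokenizer,
    max_seq_length=128,
    train_data_output_path="/path/to/out/train.tfrecord",
    eval_data_output_path="/path/to/out/eval.tfrecord",
    test_data_output_path="/path/to/out/test_{}.tfrecord",
    text_preprocessing=str.lower)  # optional callable applied to each word
```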
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the tagging (e.g., NER/POS) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
"""Data config for tagging (tasks/tagging)."""
is_training: bool = True
seq_length: int = 128
include_sentence_id: bool = False
@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
class TaggingDataLoader:
"""A class to load dataset for tagging (e.g., NER and POS) task."""
def __init__(self, params: TaggingDataConfig):
self._params = params
self._seq_length = params.seq_length
self._include_sentence_id = params.include_sentence_id
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
if self._include_sentence_id:
name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids']
}
if self._include_sentence_id:
x['sentence_id'] = record['sentence_id']
y = record['label_ids']
return (x, y)
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
# NLP Modeling Library
This libary provides a set of Keras primitives (Layers, Networks, and Models)
This library provides a set of Keras primitives (Layers, Networks, and Models)
that can be assembled into transformer-based models. They are
flexible, validated, interoperable, and both TF1 and TF2 compatible.
......@@ -16,6 +16,11 @@ standardized configuration.
* [`losses`](losses) contains common loss computation used in NLP tasks.
Please see the colab
[nlp_modeling_library_intro.ipynb]
(https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb)
for how to build transformer-based NLP models using the above primitives.
Besides the pre-defined primitives, it also provides scaffold classes to allow
easy experimentation with novel architectures; for example, you don't need to fork a whole Transformer object to try a different kind of attention primitive.
......@@ -33,11 +38,9 @@ embedding subnetwork (which will replace the standard embedding logic) and/or a
custom hidden layer (which will replace the Transformer instantiation in the
encoder).
BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder.
Please see the colab
[customize_encoder.ipynb]
(https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb)
for how to use scaffold classes to build novel architectures.
BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder.
......@@ -3,19 +3,18 @@
Layers are the fundamental building blocks for NLP models. They can be used to
assemble new layers, networks, or models.
* [DenseEinsum](dense_einsum.py) implements a feedforward network using
tf.einsum. This layer contains the einsum op, the associated weight, and the
logic required to generate the einsum expression for the given
initialization parameters.
* [MultiHeadAttention](attention.py) implements an optionally masked attention
between two tensors, from_tensor and to_tensor, as described in
between query, key, value tensors as described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
`from_tensor` and `to_tensor` are the same, then this is self-attention.
* [CachedAttention](attention.py) implements an attention layer with cache
used for auto-regressive decoding.
* [MultiChannelAttention](multi_channel_attention.py) implements a variant of
multi-head attention which can be used to merge multiple streams for
cross-attentions.
* [TalkingHeadsAttention](talking_heads_attention.py) implements the talking
heads attention, as described in
["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436).
......@@ -24,6 +23,10 @@ assemble new layers, networks, or models.
described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [TransformerDecoderLayer](transformer.py) implements a single Transformer
decoder layer, made up of multi-head self-attention, multi-head
cross-attention, and a feedforward network.
* [ReZeroTransformer](rezero_transformer.py) implements Transformer with
ReZero described in
["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).
......@@ -45,8 +48,8 @@ assemble new layers, networks, or models.
should be masked), the output will have masked positions set to
approximately zero.
* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes the
embedding table variable is passed to it.
* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes
the embedding table variable is passed to it.
* [ClassificationHead](cls_head.py) A pooling head over a sequence of
embeddings, commonly used by classification tasks.
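The attention layers listed above now take query, value, and key as separate call arguments (see the `MultiHeadAttention.call` changes later in this diff). A minimal cross-attention sketch, with illustrative shapes:

```python
import tensorflow as tf
from official.nlp.modeling.layers import attention

layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
query = tf.keras.Input(shape=(40, 80))    # [batch, T, dim]
value = tf.keras.Input(shape=(20, 80))    # [batch, S, dim]
output = layer(query=query, value=value)  # key defaults to value
# output.shape == (None, 40, 80): projected back to the query feature dim
```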
......
......@@ -20,10 +20,12 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.transformer import Transformer
from official.nlp.modeling.layers.transformer import *
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
......@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(qkv_rank, attn_axes):
def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
......@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes):
<query attention dims>, num_heads, channels)
Args:
qkv_rank: the rank of query, key, value tensors.
rank: the rank of query, key, value tensors.
attn_axes: a list/tuple of axes, [1, rank), that will do attention.
Returns:
Einsum equations.
"""
target_notation = _CHR_IDX[:qkv_rank]
target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,)))
letter_offset = qkv_rank
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
for i in range(qkv_rank):
if i in batch_dims or i == qkv_rank - 1:
for i in range(rank):
if i in batch_dims or i == rank - 1:
source_notation += target_notation[i]
else:
source_notation += _CHR_IDX[letter_offset]
......@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head
attention scores as an additional output argument.
return_attention_scores: bool, if `True`, returns the multi-head attention
scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
......@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def __init__(self,
......@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
self._attention_axes = (attention_axes,)
else:
self._attention_axes = attention_axes
self._built_from_signature = False
def get_config(self):
config = {
......@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
inputs_len = len(input_shape)
if inputs_len > 3 or inputs_len < 2:
raise ValueError(
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. "
"Given length: %d" % inputs_len)
tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape)
query_shape = tensor_shapes[0]
value_shape = tensor_shapes[1]
key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape
def _build_from_signature(self, query, value, key=None):
"""Builds layers and variables.
Once the method is called, self._built_from_signature will be set to True.
Args:
query: query tensor or TensorShape.
value: value tensor or TensorShape.
key: key tensor or TensorShape.
"""
self._built_from_signature = True
if hasattr(query, "shape"):
query_shape = tf.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tf.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
......@@ -271,84 +293,79 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
self._query_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_size]),
bias_axes=bias_axes if self._use_bias else None,
name="value",
**common_kwargs)
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once it
# support mult-head einsum computations.
self._build_attention(output_rank)
if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape]
with tf.init_scope():
free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
self._query_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_size]),
bias_axes=bias_axes if self._use_bias else None,
name="value",
**common_kwargs)
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once
# it supports multi-head einsum computations.
self.build_attention(output_rank)
if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape]
else:
output_shape = self._output_shape
else:
output_shape = self._output_shape
else:
output_shape = [query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
self._output_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if self._use_bias else None,
name="attention_output",
**common_kwargs)
super(MultiHeadAttention, self).build(input_shape)
def _build_attention(self, qkv_rank):
output_shape = [query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
self._output_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if self._use_bias else None,
name="attention_output",
**common_kwargs)
def build_attention(self, rank):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
This function builds attributes necessary for `compute_attention` to
customize the attention computation, replacing the default dot-product
attention.
Args:
qkv_rank: the rank of query, key, value tensors.
rank: the rank of query, key, value tensors.
"""
if self._attention_axes is None:
self._attention_axes = tuple(range(1, qkv_rank - 2))
self._attention_axes = tuple(range(1, rank - 2))
else:
self._attention_axes = tuple(self._attention_axes)
self._dot_product_equation, self._combine_equation, attn_scores_rank = (
_build_attention_equation(qkv_rank, attn_axes=self._attention_axes))
_build_attention_equation(rank, attn_axes=self._attention_axes))
norm_axes = tuple(
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=norm_axes)
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
def _compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
def compute_attention(self, query, key, value, attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
......@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention implementation.
Args:
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value: Projected value `Tensor` of shape `[B, T, N, value_size]`.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
......@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights.
"""
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
attention_scores = tf.einsum(self._dot_product_equation, key, query)
# Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, T, S]
......@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
# `context_layer` = [B, T, N, H]
attention_output = tf.einsum(self._combine_equation,
attention_scores_dropout, value_tensor)
attention_scores_dropout, value)
return attention_output, attention_scores
def call(self, inputs, attention_mask=None):
def call(self, query, value, key=None, attention_mask=None):
"""Implements the forward pass.
Size glossary:
......@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
* Value (source) attention axes shape (S), the rank must match the target.
Args:
inputs: List of the following tensors:
* query: Query `Tensor` of shape `[B, T, dim]`.
* value: Value `Tensor` of shape `[B, S, dim]`.
* key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
use `value` for both `key` and `value`, which is the most common case.
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
......@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention
axes.
"""
inputs_len = len(inputs)
if inputs_len > 3 or inputs_len < 2:
raise ValueError(
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. "
"Given length: %d" % inputs_len)
query = inputs[0]
value = inputs[1]
key = inputs[2] if inputs_len == 3 else value
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, T, N ,H]
query_tensor = self._query_dense(query)
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key_tensor` = [B, S, N, H]
key_tensor = self._key_dense(key)
# `key` = [B, S, N, H]
key = self._key_dense(key)
# `value_tensor` = [B, S, N, H]
value_tensor = self._value_dense(value)
# `value` = [B, S, N, H]
value = self._value_dense(value)
attention_output, attention_scores = self._compute_attention(
query_tensor, key_tensor, value_tensor, attention_mask)
attention_output, attention_scores = self.compute_attention(
query, key, value, attention_mask)
attention_output = self._output_dense(attention_output)
if self._return_attention_scores:
......@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention):
Arguments are the same as `MultiHeadAttention` layer.
"""
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
def _update_cache(self, key, value, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values.
if decode_loop_step is not None:
# TPU special case.
key_seq_dim = cache["key"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
[1, key_seq_dim, 1, 1])
key_tensor = cache["key"] + key_tensor * indices
key = cache["key"] + key * indices
value_seq_dim = cache["value"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
[1, value_seq_dim, 1, 1])
value_tensor = cache["value"] + value_tensor * indices
value = cache["value"] + value * indices
else:
key_tensor = tf.concat(
[tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
value_tensor = tf.concat(
[tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)
# Update cache
cache["key"] = key_tensor
cache["value"] = value_tensor
cache["key"] = key
cache["value"] = value
return key_tensor, value_tensor
return key, value
def call(self,
inputs,
query,
value,
key=None,
attention_mask=None,
cache=None,
decode_loop_step=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
......@@ -494,25 +509,23 @@ class CachedAttention(MultiHeadAttention):
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
# `query` = [B, F, N ,H]
query = self._query_dense(query)
# `key_tensor` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor)
# `key` = [B, T, N, H]
key = self._key_dense(key)
# `value_tensor` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor)
# `value` = [B, T, N, H]
value = self._value_dense(value)
if cache:
key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
cache, decode_loop_step)
key, value = self._update_cache(key, value, cache, decode_loop_step)
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
attention_scores = tf.einsum(self._dot_product_equation, key, query)
# Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, F, T]
......@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention):
attention_scores = self._dropout_layer(attention_scores)
# `context_layer` = [B, F, N, H]
attention_output = tf.einsum(self._combine_equation, attention_scores,
value_tensor)
value)
attention_output = self._output_dense(attention_output)
if self._return_attention_scores:
return attention_output, attention_scores, cache
......
......@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value])
output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
......@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
......@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query])
output, coef = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
......@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor)
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
......@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
......@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
......@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data)
output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data)
unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(output, unmasked_output)
......@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
......@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# one element.
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data], mask_data, cache)
masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
# Tests inputs without cache.
masked_output_data, cache = layer([from_data, from_data, mask_data])
masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertIsNone(cache)
......@@ -243,10 +246,12 @@ class CachedAttentionTest(keras_parameterized.TestCase):
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
# Testing the invocation directly as Keras cannot consume inputs correctly.
masked_output_data, cache = layer([from_data, from_data],
mask_data,
cache,
decode_loop_step=decode_loop_step)
masked_output_data, cache = layer(
query=from_data,
value=from_data,
attention_mask=mask_data,
cache=cache,
decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
......
......@@ -21,6 +21,8 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.util import deprecation
_CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
......@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer):
`(batch_size, units)`.
"""
@deprecation.deprecated(
None, "DenseEinsum is deprecated. Please use "
"tf.keras.experimental.EinsumDense layer instead.")
def __init__(self,
output_shape,
num_summed_dimensions=1,
......
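Since `DenseEinsum` is now deprecated in favor of `tf.keras.layers.experimental.EinsumDense`, a rough equivalent of a single-summed-dimension projection looks like the following sketch (the equation and shapes are illustrative):

```python
import tensorflow as tf

# Project the last axis of a [batch, seq, hidden] tensor down to 64 units.
layer = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd", output_shape=(None, 64), bias_axes="d")
x = tf.keras.Input(shape=(128, 768))
y = layer(x)  # shape: (None, 128, 64)
```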
......@@ -34,7 +34,7 @@ class MaskedLM(tf.keras.layers.Layer):
Arguments:
embedding_table: The embedding table of the targets.
activation: The activation, if any, for the dense layer.
initializer: The intializer for the dense layer. Defaults to a Glorot
initializer: The initializer for the dense layer. Defaults to a Glorot
uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
......