Unverified Commit 0cceabfc authored by Yiming Shi's avatar Yiming Shi Committed by GitHub
Browse files

Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor

parents 17821c0d 39ee0ac9
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer
@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
  """ELECTRA pretrainer configuration.

  Bundles the generator/discriminator encoder configs plus the pretraining
  head settings consumed by `instantiate_pretrainer_from_cfg`.
  """
  # Number of token positions the generator predicts per sequence
  # (forwarded as `num_token_predictions` to ElectraPretrainer).
  num_masked_tokens: int = 76
  # Input sequence length fed to the pretrainer.
  sequence_length: int = 512
  # Number of classes for the pretrainer's classification output.
  num_classes: int = 2
  # NOTE(review): not consumed anywhere in this file — presumably scales the
  # replaced-token-detection loss inside the training loop; confirm at caller.
  discriminator_loss_weight: float = 50.0
  # If True, the generator reuses the discriminator's embedding layer.
  tie_embeddings: bool = True
  # Forwarded verbatim to ElectraPretrainer; semantics defined there.
  disallow_correct: bool = False
  # Encoder config for the generator network.
  # NOTE(review): class-level Config() defaults are shared across instances
  # unless base_config.Config deep-copies them — verify before mutating.
  generator_encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  # Encoder config for the discriminator network.
  discriminator_encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  # Optional classification head configs (e.g. next-sentence prediction).
  cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_classification_heads_from_cfgs(
    cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
  """Builds one ClassificationHead per config; empty list when none given."""
  if not cls_head_configs:
    return []
  return [
      layers.ClassificationHead(**head_cfg.as_dict())
      for head_cfg in cls_head_configs
  ]
def instantiate_pretrainer_from_cfg(
    config: ELECTRAPretrainerConfig,
    generator_network: Optional[tf.keras.Model] = None,
    discriminator_network: Optional[tf.keras.Model] = None,
) -> electra_pretrainer.ElectraPretrainer:
  """Instantiates ElectraPretrainer from the config.

  Args:
    config: Pretrainer configuration holding the generator/discriminator
      encoder configs and head settings.
    generator_network: Optional pre-built generator encoder. Built from
      `config.generator_encoder` when None.
    discriminator_network: Optional pre-built discriminator encoder. Built
      from `config.discriminator_encoder` when None.

  Returns:
    An `electra_pretrainer.ElectraPretrainer` model.
  """
  gen_cfg = config.generator_encoder
  disc_cfg = config.discriminator_encoder
  if discriminator_network is None:
    discriminator_network = encoders.instantiate_encoder_from_cfg(disc_cfg)
  if generator_network is None:
    if not config.tie_embeddings:
      generator_network = encoders.instantiate_encoder_from_cfg(gen_cfg)
    else:
      # Copy discriminator's embeddings to generator for easier model
      # serialization.
      generator_network = encoders.instantiate_encoder_from_cfg(
          gen_cfg,
          embedding_layer=discriminator_network.get_embedding_layer())
  mlm_init = tf.keras.initializers.TruncatedNormal(
      stddev=gen_cfg.initializer_range)
  return electra_pretrainer.ElectraPretrainer(
      generator_network=generator_network,
      discriminator_network=discriminator_network,
      vocab_size=gen_cfg.vocab_size,
      num_classes=config.num_classes,
      sequence_length=config.sequence_length,
      num_token_predictions=config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(gen_cfg.hidden_activation),
      mlm_initializer=mlm_init,
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads),
      disallow_correct=config.disallow_correct)
# Copyright 2017 The TensorFlow Authors All Rights Reserved. # Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,53 +13,37 @@ ...@@ -12,53 +13,37 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
"""Tests of the block operators."""
import numpy as np
import tensorflow as tf import tensorflow as tf
import block_base from official.nlp.configs import bert
import blocks_operator from official.nlp.configs import electra
from official.nlp.configs import encoders
class AddOneBlock(block_base.BlockBase):
class ELECTRAModelsTest(tf.test.TestCase):
def __init__(self, name=None):
super(AddOneBlock, self).__init__(name) def test_network_invocation(self):
config = electra.ELECTRAPretrainerConfig(
def _Apply(self, x): generator_encoder=encoders.TransformerEncoderConfig(
return x + 1.0 vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
class SquareBlock(block_base.BlockBase): )
_ = electra.instantiate_pretrainer_from_cfg(config)
def __init__(self, name=None):
super(SquareBlock, self).__init__(name) # Invokes with classification heads.
config = electra.ELECTRAPretrainerConfig(
def _Apply(self, x): generator_encoder=encoders.TransformerEncoderConfig(
return x * x vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
class BlocksOperatorTest(tf.test.TestCase): cls_heads=[
bert.ClsHeadConfig(
def testComposition(self): inner_dim=10, num_classes=2, name="next_sentence")
x_value = np.array([[1.0, 2.0, 3.0], ])
[-1.0, -2.0, -3.0]]) _ = electra.instantiate_pretrainer_from_cfg(config)
y_expected_value = np.array([[4.0, 9.0, 16.0],
[0.0, 1.0, 4.0]]) if __name__ == "__main__":
x = tf.placeholder(dtype=tf.float32, shape=[2, 3])
complex_block = blocks_operator.CompositionOperator(
[AddOneBlock(),
SquareBlock()])
y = complex_block(x)
with self.test_session():
y_value = y.eval(feed_dict={x: x_value})
self.assertAllClose(y_expected_value, y_value)
if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -13,11 +13,18 @@ ...@@ -13,11 +13,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Configurations for Encoders.""" """Transformer Encoders.
Includes configurations and instantiation methods.
"""
from typing import Optional
import dataclasses import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@dataclasses.dataclass @dataclasses.dataclass
...@@ -28,9 +35,64 @@ class TransformerEncoderConfig(base_config.Config): ...@@ -28,9 +35,64 @@ class TransformerEncoderConfig(base_config.Config):
num_layers: int = 12 num_layers: int = 12
num_attention_heads: int = 12 num_attention_heads: int = 12
hidden_activation: str = "gelu" hidden_activation: str = "gelu"
intermediate_size: int = 3076 intermediate_size: int = 3072
dropout_rate: float = 0.1 dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1 attention_dropout_rate: float = 0.1
max_position_embeddings: int = 512 max_position_embeddings: int = 512
type_vocab_size: int = 2 type_vocab_size: int = 2
initializer_range: float = 0.02 initializer_range: float = 0.02
embedding_size: Optional[int] = None
def instantiate_encoder_from_cfg(
config: TransformerEncoderConfig,
encoder_cls=networks.TransformerEncoder,
embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
if encoder_cls.__name__ == "EncoderScaffold":
embedding_cfg = dict(
vocab_size=config.vocab_size,
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
max_seq_length=config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
dropout_rate=config.dropout_rate,
)
hidden_cfg = dict(
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
intermediate_activation=tf_utils.get_activation(
config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=config.num_layers,
pooled_output_dim=config.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
return encoder_cls(**kwargs)
if encoder_cls.__name__ != "TransformerEncoder":
raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
encoder_network = encoder_cls(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
num_layers=config.num_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
activation=tf_utils.get_activation(config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
max_sequence_length=config.max_position_embeddings,
type_vocab_size=config.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
embedding_width=config.embedding_size,
embedding_layer=embedding_layer)
return encoder_network
...@@ -31,9 +31,15 @@ from official.nlp.bert import tokenization ...@@ -31,9 +31,15 @@ from official.nlp.bert import tokenization
class InputExample(object): class InputExample(object):
"""A single training/test example for simple sequence classification.""" """A single training/test example for simple seq regression/classification."""
def __init__(self, guid, text_a, text_b=None, label=None, weight=None): def __init__(self,
guid,
text_a,
text_b=None,
label=None,
weight=None,
int_iden=None):
"""Constructs a InputExample. """Constructs a InputExample.
Args: Args:
...@@ -42,16 +48,20 @@ class InputExample(object): ...@@ -42,16 +48,20 @@ class InputExample(object):
sequence tasks, only this sequence must be specified. sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence. text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks. Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be label: (Optional) string for classification, float for regression. The
specified for train and dev examples, but not for test examples. label of the example. This should be specified for train and dev
examples, but not for test examples.
weight: (Optional) float. The weight of the example to be used during weight: (Optional) float. The weight of the example to be used during
training. training.
int_iden: (Optional) int. The int identification number of example in the
corpus.
""" """
self.guid = guid self.guid = guid
self.text_a = text_a self.text_a = text_a
self.text_b = text_b self.text_b = text_b
self.label = label self.label = label
self.weight = weight self.weight = weight
self.int_iden = int_iden
class InputFeatures(object): class InputFeatures(object):
...@@ -63,20 +73,24 @@ class InputFeatures(object): ...@@ -63,20 +73,24 @@ class InputFeatures(object):
segment_ids, segment_ids,
label_id, label_id,
is_real_example=True, is_real_example=True,
weight=None): weight=None,
int_iden=None):
self.input_ids = input_ids self.input_ids = input_ids
self.input_mask = input_mask self.input_mask = input_mask
self.segment_ids = segment_ids self.segment_ids = segment_ids
self.label_id = label_id self.label_id = label_id
self.is_real_example = is_real_example self.is_real_example = is_real_example
self.weight = weight self.weight = weight
self.int_iden = int_iden
class DataProcessor(object): class DataProcessor(object):
"""Base class for data converters for sequence classification data sets.""" """Base class for converters for seq regression/classification datasets."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode): def __init__(self, process_text_fn=tokenization.convert_to_unicode):
self.process_text_fn = process_text_fn self.process_text_fn = process_text_fn
self.is_regression = False
self.label_type = None
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set.""" """Gets a collection of `InputExample`s for the train set."""
...@@ -110,92 +124,163 @@ class DataProcessor(object): ...@@ -110,92 +124,163 @@ class DataProcessor(object):
return lines return lines
class XnliProcessor(DataProcessor): class ColaProcessor(DataProcessor):
"""Processor for the XNLI data set.""" """Processor for the CoLA data set (GLUE version)."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def __init__(self,
language="en",
process_text_fn=tokenization.convert_to_unicode):
super(XnliProcessor, self).__init__(process_text_fn)
if language == "all":
self.languages = XnliProcessor.supported_languages
elif language not in XnliProcessor.supported_languages:
raise ValueError("language %s is not supported for XNLI task." % language)
else:
self.languages = [language]
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
lines = [] return self._create_examples(
for language in self.languages: self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
# Skips the header.
lines.extend( def get_dev_examples(self, data_dir):
self._read_tsv( """See base class."""
os.path.join(data_dir, "multinli", return self._create_examples(
"multinli.train.%s.tsv" % language))[1:]) self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "COLA"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "train-%d" % i # Only the test set has a header.
text_a = self.process_text_fn(line[0]) if set_type == "test" and i == 0:
text_b = self.process_text_fn(line[1]) continue
label = self.process_text_fn(line[2]) guid = "%s-%s" % (set_type, i)
if label == self.process_text_fn("contradictory"): if set_type == "test":
label = self.process_text_fn("contradiction") text_a = self.process_text_fn(line[1])
label = "0"
else:
text_a = self.process_text_fn(line[3])
label = self.process_text_fn(line[1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples return examples
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def __init__(self,
mnli_type="matched",
process_text_fn=tokenization.convert_to_unicode):
super(MnliProcessor, self).__init__(process_text_fn)
if mnli_type not in ("matched", "mismatched"):
raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
self.mnli_type = mnli_type
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) if self.mnli_type == "matched":
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
"dev_mismatched")
def get_test_examples(self, data_dir):
"""See base class."""
if self.mnli_type == "matched":
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "dev-%d" % i guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
text_a = self.process_text_fn(line[6]) text_a = self.process_text_fn(line[8])
text_b = self.process_text_fn(line[7]) text_b = self.process_text_fn(line[9])
label = self.process_text_fn(line[1]) if set_type == "test":
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv")) return self._create_examples(
examples_by_lang = {k: [] for k in XnliProcessor.supported_languages} self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % i
language = self.process_text_fn(line[0])
text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1])
examples_by_lang[language].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["contradiction", "entailment", "neutral"] return ["0", "1"]
@staticmethod @staticmethod
def get_processor_name(): def get_processor_name():
"""See base class.""" """See base class."""
return "XNLI" return "MRPC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line[3])
text_b = self.process_text_fn(line[4])
if set_type == "test":
label = "0"
else:
label = self.process_text_fn(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class PawsxProcessor(DataProcessor): class PawsxProcessor(DataProcessor):
"""Processor for the PAWS-X data set.""" """Processor for the PAWS-X data set."""
supported_languages = [ supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
"de", "en", "es", "fr", "ja", "ko", "zh"
]
def __init__(self, def __init__(self,
language="en", language="en",
...@@ -219,11 +304,10 @@ class PawsxProcessor(DataProcessor): ...@@ -219,11 +304,10 @@ class PawsxProcessor(DataProcessor):
train_tsv = "translated_train.tsv" train_tsv = "translated_train.tsv"
# Skips the header. # Skips the header.
lines.extend( lines.extend(
self._read_tsv( self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
os.path.join(data_dir, language, train_tsv))[1:])
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "train-%d" % i guid = "train-%d" % i
text_a = self.process_text_fn(line[1]) text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2]) text_b = self.process_text_fn(line[2])
...@@ -235,13 +319,12 @@ class PawsxProcessor(DataProcessor): ...@@ -235,13 +319,12 @@ class PawsxProcessor(DataProcessor):
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
lines = [] lines = []
for language in PawsxProcessor.supported_languages: for lang in PawsxProcessor.supported_languages:
# Skips the header.
lines.extend( lines.extend(
self._read_tsv(os.path.join(data_dir, language, "dev_2k.tsv"))[1:]) self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "dev-%d" % i guid = "dev-%d" % i
text_a = self.process_text_fn(line[1]) text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2]) text_b = self.process_text_fn(line[2])
...@@ -252,17 +335,15 @@ class PawsxProcessor(DataProcessor): ...@@ -252,17 +335,15 @@ class PawsxProcessor(DataProcessor):
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
examples_by_lang = {k: [] for k in PawsxProcessor.supported_languages} examples_by_lang = {k: [] for k in self.supported_languages}
for language in PawsxProcessor.supported_languages: for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, language, "test_2k.tsv")) lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % i guid = "test-%d" % i
text_a = self.process_text_fn(line[1]) text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2]) text_b = self.process_text_fn(line[2])
label = self.process_text_fn(line[3]) label = self.process_text_fn(line[3])
examples_by_lang[language].append( examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang return examples_by_lang
...@@ -273,57 +354,11 @@ class PawsxProcessor(DataProcessor): ...@@ -273,57 +354,11 @@ class PawsxProcessor(DataProcessor):
@staticmethod @staticmethod
def get_processor_name(): def get_processor_name():
"""See base class.""" """See base class."""
return "PAWS-X" return "XTREME-PAWS-X"
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_test_examples(self, data_dir): class QnliProcessor(DataProcessor):
"""See base class.""" """Processor for the QNLI data set (GLUE version)."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
text_a = self.process_text_fn(line[8])
text_b = self.process_text_fn(line[9])
if set_type == "test":
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -333,7 +368,7 @@ class MrpcProcessor(DataProcessor): ...@@ -333,7 +368,7 @@ class MrpcProcessor(DataProcessor):
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -342,26 +377,28 @@ class MrpcProcessor(DataProcessor): ...@@ -342,26 +377,28 @@ class MrpcProcessor(DataProcessor):
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["0", "1"] return ["entailment", "not_entailment"]
@staticmethod @staticmethod
def get_processor_name(): def get_processor_name():
"""See base class.""" """See base class."""
return "MRPC" return "QNLI"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, 1)
text_a = self.process_text_fn(line[3])
text_b = self.process_text_fn(line[4])
if set_type == "test": if set_type == "test":
label = "0" text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = "entailment"
else: else:
label = self.process_text_fn(line[0]) text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = tokenization.convert_to_unicode(line[-1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -395,9 +432,9 @@ class QqpProcessor(DataProcessor): ...@@ -395,9 +432,9 @@ class QqpProcessor(DataProcessor):
return "QQP" return "QQP"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, line[0]) guid = "%s-%s" % (set_type, line[0])
...@@ -407,13 +444,13 @@ class QqpProcessor(DataProcessor): ...@@ -407,13 +444,13 @@ class QqpProcessor(DataProcessor):
label = line[5] label = line[5]
except IndexError: except IndexError:
continue continue
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, examples.append(
label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
class ColaProcessor(DataProcessor): class RteProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version).""" """Processor for the RTE data set (GLUE version)."""
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -432,29 +469,30 @@ class ColaProcessor(DataProcessor): ...@@ -432,29 +469,30 @@ class ColaProcessor(DataProcessor):
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["0", "1"] # All datasets are converted to 2-class split, where for 3-class datasets we
# collapse neutral and contradiction into not_entailment.
return ["entailment", "not_entailment"]
@staticmethod @staticmethod
def get_processor_name(): def get_processor_name():
"""See base class.""" """See base class."""
return "COLA" return "RTE"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
# Only the test set has a header if i == 0:
if set_type == "test" and i == 0:
continue continue
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
if set_type == "test": if set_type == "test":
text_a = self.process_text_fn(line[1]) label = "entailment"
label = "0"
else: else:
text_a = self.process_text_fn(line[3]) label = tokenization.convert_to_unicode(line[3])
label = self.process_text_fn(line[1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -486,9 +524,9 @@ class SstProcessor(DataProcessor): ...@@ -486,9 +524,9 @@ class SstProcessor(DataProcessor):
return "SST-2" return "SST-2"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -503,8 +541,14 @@ class SstProcessor(DataProcessor): ...@@ -503,8 +541,14 @@ class SstProcessor(DataProcessor):
return examples return examples
class QnliProcessor(DataProcessor): class StsBProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version).""" """Processor for the STS-B data set (GLUE version)."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
super(StsBProcessor, self).__init__(process_text_fn=process_text_fn)
self.is_regression = True
self.label_type = float
self._labels = None
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -514,7 +558,7 @@ class QnliProcessor(DataProcessor): ...@@ -514,7 +558,7 @@ class QnliProcessor(DataProcessor):
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched") self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -523,28 +567,26 @@ class QnliProcessor(DataProcessor): ...@@ -523,28 +567,26 @@ class QnliProcessor(DataProcessor):
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["entailment", "not_entailment"] return self._labels
@staticmethod @staticmethod
def get_processor_name(): def get_processor_name():
"""See base class.""" """See base class."""
return "QNLI" return "STS-B"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, 1) guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[7])
text_b = tokenization.convert_to_unicode(line[8])
if set_type == "test": if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1]) label = 0.0
text_b = tokenization.convert_to_unicode(line[2])
label = "entailment"
else: else:
text_a = tokenization.convert_to_unicode(line[1]) label = self.label_type(tokenization.convert_to_unicode(line[9]))
text_b = tokenization.convert_to_unicode(line[2])
label = tokenization.convert_to_unicode(line[-1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -564,6 +606,8 @@ class TfdsProcessor(DataProcessor): ...@@ -564,6 +606,8 @@ class TfdsProcessor(DataProcessor):
tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2" tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2," tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2,"
"is_regression=true,label_type=float" "is_regression=true,label_type=float"
tfds_params="dataset=snli,text_key=premise,text_b_key=hypothesis,"
"skip_label=-1"
Possible parameters (please refer to the documentation of Tensorflow Datasets Possible parameters (please refer to the documentation of Tensorflow Datasets
(TFDS) for the meaning of individual parameters): (TFDS) for the meaning of individual parameters):
dataset: Required dataset name (potentially with subset and version number). dataset: Required dataset name (potentially with subset and version number).
...@@ -581,17 +625,19 @@ class TfdsProcessor(DataProcessor): ...@@ -581,17 +625,19 @@ class TfdsProcessor(DataProcessor):
label_type: Type of the label key (defaults to `int`). label_type: Type of the label key (defaults to `int`).
weight_key: Key of the float sample weight (is not used if not provided). weight_key: Key of the float sample weight (is not used if not provided).
is_regression: Whether the task is a regression problem (defaults to False). is_regression: Whether the task is a regression problem (defaults to False).
skip_label: Skip examples with given label (defaults to None).
""" """
def __init__(self, tfds_params, def __init__(self,
tfds_params,
process_text_fn=tokenization.convert_to_unicode): process_text_fn=tokenization.convert_to_unicode):
super(TfdsProcessor, self).__init__(process_text_fn) super(TfdsProcessor, self).__init__(process_text_fn)
self._process_tfds_params_str(tfds_params) self._process_tfds_params_str(tfds_params)
if self.module_import: if self.module_import:
importlib.import_module(self.module_import) importlib.import_module(self.module_import)
self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir, self.dataset, info = tfds.load(
with_info=True) self.dataset_name, data_dir=self.data_dir, with_info=True)
if self.is_regression: if self.is_regression:
self._labels = None self._labels = None
else: else:
...@@ -619,6 +665,9 @@ class TfdsProcessor(DataProcessor): ...@@ -619,6 +665,9 @@ class TfdsProcessor(DataProcessor):
self.label_type = dtype_map[d.get("label_type", "int")] self.label_type = dtype_map[d.get("label_type", "int")]
self.is_regression = cast_str_to_bool(d.get("is_regression", "False")) self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
self.weight_key = d.get("weight_key", None) self.weight_key = d.get("weight_key", None)
self.skip_label = d.get("skip_label", None)
if self.skip_label is not None:
self.skip_label = self.label_type(self.skip_label)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
assert data_dir is None assert data_dir is None
...@@ -639,7 +688,7 @@ class TfdsProcessor(DataProcessor): ...@@ -639,7 +688,7 @@ class TfdsProcessor(DataProcessor):
return "TFDS_" + self.dataset_name return "TFDS_" + self.dataset_name
def _create_examples(self, split_name, set_type): def _create_examples(self, split_name, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
if split_name not in self.dataset: if split_name not in self.dataset:
raise ValueError("Split {} not available.".format(split_name)) raise ValueError("Split {} not available.".format(split_name))
dataset = self.dataset[split_name].as_numpy_iterator() dataset = self.dataset[split_name].as_numpy_iterator()
...@@ -657,13 +706,258 @@ class TfdsProcessor(DataProcessor): ...@@ -657,13 +706,258 @@ class TfdsProcessor(DataProcessor):
if self.text_b_key: if self.text_b_key:
text_b = self.process_text_fn(example[self.text_b_key]) text_b = self.process_text_fn(example[self.text_b_key])
label = self.label_type(example[self.label_key]) label = self.label_type(example[self.label_key])
if self.skip_label is not None and label == self.skip_label:
continue
if self.weight_key: if self.weight_key:
weight = float(example[self.weight_key]) weight = float(example[self.weight_key])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, InputExample(
weight=weight)) guid=guid,
text_a=text_a,
text_b=text_b,
label=label,
weight=weight))
return examples
class WnliProcessor(DataProcessor):
  """Processor for the WNLI data set (GLUE version)."""

  def get_train_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "train.tsv"))
    return self._create_examples(lines, "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
    return self._create_examples(lines, "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "test.tsv"))
    return self._create_examples(lines, "test")

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "WNLI"

  def _create_examples(self, lines, set_type):
    """Builds InputExamples from parsed TSV rows for train/dev/test splits."""
    examples = []
    for idx, row in enumerate(lines):
      if idx == 0:
        continue  # First row is the column header.
      if set_type == "test":
        # The test split ships without gold labels; use a placeholder.
        label = "0"
      else:
        label = tokenization.convert_to_unicode(row[3])
      examples.append(
          InputExample(
              guid="%s-%s" % (set_type, idx),
              text_a=tokenization.convert_to_unicode(row[1]),
              text_b=tokenization.convert_to_unicode(row[2]),
              label=label))
    return examples
class XnliProcessor(DataProcessor):
  """Processor for the XNLI data set.

  Training data is read from per-language translated MultiNLI tsv files;
  dev and test data come from the official combined XNLI tsv files. When
  constructed with language="all", training examples from every supported
  language are concatenated.
  """
  supported_languages = [
      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
      "ur", "vi", "zh"
  ]

  def __init__(self,
               language="en",
               process_text_fn=tokenization.convert_to_unicode):
    super(XnliProcessor, self).__init__(process_text_fn)
    if language == "all":
      self.languages = XnliProcessor.supported_languages
    elif language not in XnliProcessor.supported_languages:
      raise ValueError("language %s is not supported for XNLI task." % language)
    else:
      self.languages = [language]

  def get_train_examples(self, data_dir):
    """See base class."""
    lines = []
    for language in self.languages:
      # Skips the header.
      lines.extend(
          self._read_tsv(
              os.path.join(data_dir, "multinli",
                           "multinli.train.%s.tsv" % language))[1:])

    examples = []
    for i, line in enumerate(lines):
      guid = "train-%d" % i
      text_a = self.process_text_fn(line[0])
      text_b = self.process_text_fn(line[1])
      label = self.process_text_fn(line[2])
      # The MultiNLI translations use "contradictory" where XNLI uses
      # "contradiction"; normalize so labels match get_labels().
      if label == self.process_text_fn("contradictory"):
        label = self.process_text_fn("contradiction")
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_dev_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
    examples = []
    for i, line in enumerate(lines):
      if i == 0:
        continue  # Skip the header row.
      guid = "dev-%d" % i
      text_a = self.process_text_fn(line[6])
      text_b = self.process_text_fn(line[7])
      label = self.process_text_fn(line[1])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_test_examples(self, data_dir):
    """See base class.

    Returns a dict mapping language code to a list of InputExamples, covering
    all supported languages regardless of the `language` constructor arg.
    """
    lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
    examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
    for i, line in enumerate(lines):
      if i == 0:
        continue  # Skip the header row.
      guid = "test-%d" % i
      language = self.process_text_fn(line[0])
      text_a = self.process_text_fn(line[6])
      text_b = self.process_text_fn(line[7])
      label = self.process_text_fn(line[1])
      examples_by_lang[language].append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples_by_lang

  def get_labels(self):
    """See base class."""
    return ["contradiction", "entailment", "neutral"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "XNLI"
class XtremePawsxProcessor(DataProcessor):
  """Processor for the XTREME PAWS-X data set."""
  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]

  def _english_split_examples(self, data_dir, file_name, prefix):
    """Reads one English-only tsv and converts every row to an InputExample."""
    examples = []
    rows = self._read_tsv(os.path.join(data_dir, file_name))
    for idx, row in enumerate(rows):
      examples.append(
          InputExample(
              guid="%s-%d" % (prefix, idx),
              text_a=self.process_text_fn(row[0]),
              text_b=self.process_text_fn(row[1]),
              label=self.process_text_fn(row[2])))
    return examples

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._english_split_examples(data_dir, "train-en.tsv", "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._english_split_examples(data_dir, "dev-en.tsv", "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    # Test data exists per language; labels are unavailable, so a fixed
    # placeholder label is assigned.
    examples_by_lang = {lang: [] for lang in self.supported_languages}
    for lang in self.supported_languages:
      rows = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
      for idx, row in enumerate(rows):
        examples_by_lang[lang].append(
            InputExample(
                guid="test-%d" % idx,
                text_a=self.process_text_fn(row[0]),
                text_b=self.process_text_fn(row[1]),
                label="0"))
    return examples_by_lang

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "XTREME-PAWS-X"
class XtremeXnliProcessor(DataProcessor):
  """Processor for the XTREME XNLI data set."""
  supported_languages = [
      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
      "ur", "vi", "zh"
  ]

  def _english_split_examples(self, data_dir, file_name, prefix):
    """Reads one English-only tsv and converts every row to an InputExample."""
    examples = []
    rows = self._read_tsv(os.path.join(data_dir, file_name))
    for idx, row in enumerate(rows):
      examples.append(
          InputExample(
              guid="%s-%d" % (prefix, idx),
              text_a=self.process_text_fn(row[0]),
              text_b=self.process_text_fn(row[1]),
              label=self.process_text_fn(row[2])))
    return examples

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._english_split_examples(data_dir, "train-en.tsv", "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._english_split_examples(data_dir, "dev-en.tsv", "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    # Test data exists per language; gold labels are unavailable, so a
    # fixed placeholder label is assigned.
    examples_by_lang = {lang: [] for lang in self.supported_languages}
    for lang in self.supported_languages:
      rows = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
      for idx, row in enumerate(rows):
        examples_by_lang[lang].append(
            InputExample(
                guid=f"test-{idx}",
                text_a=self.process_text_fn(row[0]),
                text_b=self.process_text_fn(row[1]),
                label="contradiction"))
    return examples_by_lang

  def get_labels(self):
    """See base class."""
    return ["contradiction", "entailment", "neutral"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "XTREME-XNLI"
def convert_single_example(ex_index, example, label_list, max_seq_length, def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer): tokenizer):
...@@ -748,8 +1042,9 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -748,8 +1042,9 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids])) logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask])) logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("label: %s (id = %d)", example.label, label_id) logging.info("label: %s (id = %s)", example.label, str(label_id))
logging.info("weight: %s", example.weight) logging.info("weight: %s", example.weight)
logging.info("int_iden: %s", str(example.int_iden))
feature = InputFeatures( feature = InputFeatures(
input_ids=input_ids, input_ids=input_ids,
...@@ -757,19 +1052,24 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -757,19 +1052,24 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
segment_ids=segment_ids, segment_ids=segment_ids,
label_id=label_id, label_id=label_id,
is_real_example=True, is_real_example=True,
weight=example.weight) weight=example.weight,
int_iden=example.int_iden)
return feature return feature
def file_based_convert_examples_to_features(examples, label_list, def file_based_convert_examples_to_features(examples,
max_seq_length, tokenizer, label_list,
output_file, label_type=None): max_seq_length,
tokenizer,
output_file,
label_type=None):
"""Convert a set of `InputExample`s to a TFRecord file.""" """Convert a set of `InputExample`s to a TFRecord file."""
tf.io.gfile.makedirs(os.path.dirname(output_file)) tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file) writer = tf.io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples): for ex_index, example in enumerate(examples):
if ex_index % 10000 == 0: if ex_index % 10000 == 0:
logging.info("Writing example %d of %d", ex_index, len(examples)) logging.info("Writing example %d of %d", ex_index, len(examples))
...@@ -779,6 +1079,7 @@ def file_based_convert_examples_to_features(examples, label_list, ...@@ -779,6 +1079,7 @@ def file_based_convert_examples_to_features(examples, label_list,
def create_int_feature(values): def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f return f
def create_float_feature(values): def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f return f
...@@ -789,12 +1090,14 @@ def file_based_convert_examples_to_features(examples, label_list, ...@@ -789,12 +1090,14 @@ def file_based_convert_examples_to_features(examples, label_list,
features["segment_ids"] = create_int_feature(feature.segment_ids) features["segment_ids"] = create_int_feature(feature.segment_ids)
if label_type is not None and label_type == float: if label_type is not None and label_type == float:
features["label_ids"] = create_float_feature([feature.label_id]) features["label_ids"] = create_float_feature([feature.label_id])
else: elif feature.label_id is not None:
features["label_ids"] = create_int_feature([feature.label_id]) features["label_ids"] = create_int_feature([feature.label_id])
features["is_real_example"] = create_int_feature( features["is_real_example"] = create_int_feature(
[int(feature.is_real_example)]) [int(feature.is_real_example)])
if feature.weight is not None: if feature.weight is not None:
features["weight"] = create_float_feature([feature.weight]) features["weight"] = create_float_feature([feature.weight])
if feature.int_iden is not None:
features["int_iden"] = create_int_feature([feature.int_iden])
tf_example = tf.train.Example(features=tf.train.Features(feature=features)) tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString()) writer.write(tf_example.SerializeToString())
...@@ -830,8 +1133,7 @@ def generate_tf_record_from_data_file(processor, ...@@ -830,8 +1133,7 @@ def generate_tf_record_from_data_file(processor,
Arguments: Arguments:
processor: Input processor object to be used for generating data. Subclass processor: Input processor object to be used for generating data. Subclass
of `DataProcessor`. of `DataProcessor`.
data_dir: Directory that contains train/eval data to process. Data files data_dir: Directory that contains train/eval/test data to process.
should be in from "dev.tsv", "test.tsv", or "train.tsv".
tokenizer: The tokenizer to be applied on the data. tokenizer: The tokenizer to be applied on the data.
train_data_output_path: Output to which processed tf record for training train_data_output_path: Output to which processed tf record for training
will be saved. will be saved.
...@@ -857,8 +1159,7 @@ def generate_tf_record_from_data_file(processor, ...@@ -857,8 +1159,7 @@ def generate_tf_record_from_data_file(processor,
train_input_data_examples = processor.get_train_examples(data_dir) train_input_data_examples = processor.get_train_examples(data_dir)
file_based_convert_examples_to_features(train_input_data_examples, label_list, file_based_convert_examples_to_features(train_input_data_examples, label_list,
max_seq_length, tokenizer, max_seq_length, tokenizer,
train_data_output_path, train_data_output_path, label_type)
label_type)
num_training_data = len(train_input_data_examples) num_training_data = len(train_input_data_examples)
if eval_data_output_path: if eval_data_output_path:
...@@ -868,26 +1169,27 @@ def generate_tf_record_from_data_file(processor, ...@@ -868,26 +1169,27 @@ def generate_tf_record_from_data_file(processor,
tokenizer, eval_data_output_path, tokenizer, eval_data_output_path,
label_type) label_type)
meta_data = {
"processor_type": processor.get_processor_name(),
"train_data_size": num_training_data,
"max_seq_length": max_seq_length,
}
if test_data_output_path: if test_data_output_path:
test_input_data_examples = processor.get_test_examples(data_dir) test_input_data_examples = processor.get_test_examples(data_dir)
if isinstance(test_input_data_examples, dict): if isinstance(test_input_data_examples, dict):
for language, examples in test_input_data_examples.items(): for language, examples in test_input_data_examples.items():
file_based_convert_examples_to_features( file_based_convert_examples_to_features(
examples, examples, label_list, max_seq_length, tokenizer,
label_list, max_seq_length, test_data_output_path.format(language), label_type)
tokenizer, test_data_output_path.format(language), meta_data["test_{}_data_size".format(language)] = len(examples)
label_type)
else: else:
file_based_convert_examples_to_features(test_input_data_examples, file_based_convert_examples_to_features(test_input_data_examples,
label_list, max_seq_length, label_list, max_seq_length,
tokenizer, test_data_output_path, tokenizer, test_data_output_path,
label_type) label_type)
meta_data["test_data_size"] = len(test_input_data_examples)
meta_data = {
"processor_type": processor.get_processor_name(),
"train_data_size": num_training_data,
"max_seq_length": max_seq_length,
}
if is_regression: if is_regression:
meta_data["task_type"] = "bert_regression" meta_data["task_type"] = "bert_regression"
meta_data["label_type"] = {int: "int", float: "float"}[label_type] meta_data["label_type"] = {int: "int", float: "float"}[label_type]
...@@ -900,12 +1202,4 @@ def generate_tf_record_from_data_file(processor, ...@@ -900,12 +1202,4 @@ def generate_tf_record_from_data_file(processor,
if eval_data_output_path: if eval_data_output_path:
meta_data["eval_data_size"] = len(eval_input_data_examples) meta_data["eval_data_size"] = len(eval_input_data_examples)
if test_data_output_path:
test_input_data_examples = processor.get_test_examples(data_dir)
if isinstance(test_input_data_examples, dict):
for language, examples in test_input_data_examples.items():
meta_data["test_{}_data_size".format(language)] = len(examples)
else:
meta_data["test_data_size"] = len(test_input_data_examples)
return meta_data return meta_data
...@@ -27,18 +27,21 @@ from absl import flags ...@@ -27,18 +27,21 @@ from absl import flags
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib # word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib # sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
# TODO(chendouble): consider moving each task to its own binary.
flags.DEFINE_enum( flags.DEFINE_enum(
"fine_tuning_task_type", "classification", "fine_tuning_task_type", "classification",
["classification", "regression", "squad"], ["classification", "regression", "squad", "retrieval", "tagging"],
"The name of the BERT fine tuning task for which data " "The name of the BERT fine tuning task for which data "
"will be generated..") "will be generated.")
# BERT classification specific flags. # BERT classification specific flags.
flags.DEFINE_string( flags.DEFINE_string(
...@@ -47,23 +50,41 @@ flags.DEFINE_string( ...@@ -47,23 +50,41 @@ flags.DEFINE_string(
"for the task.") "for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI", flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI", ["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"PAWS-X"], "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI",
"The name of the task to train BERT classifier.") "XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
# MNLI task-specific flag.
flags.DEFINE_enum(
"mnli_type", "matched", ["matched", "mismatched"],
"The type of MNLI dataset.")
# XNLI task specific flag. # XNLI task-specific flag.
flags.DEFINE_string( flags.DEFINE_string(
"xnli_language", "en", "xnli_language", "en",
"Language of training data for XNIL task. If the value is 'all', the data " "Language of training data for XNLI task. If the value is 'all', the data "
"of all languages will be used for training.") "of all languages will be used for training.")
# PAWS-X task specific flag. # PAWS-X task-specific flag.
flags.DEFINE_string( flags.DEFINE_string(
"pawsx_language", "en", "pawsx_language", "en",
"Language of trainig data for PAWS-X task. If the value is 'all', the data " "Language of training data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training.") "of all languages will be used for training.")
# BERT Squad task specific flags. # Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
"The name of sentence retrieval task for scoring")
# Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
"The name of BERT tagging (token classification) task.")
# BERT Squad task-specific flags.
flags.DEFINE_string( flags.DEFINE_string(
"squad_data_file", None, "squad_data_file", None,
"The input data file in for generating training data for BERT squad task.") "The input data file in for generating training data for BERT squad task.")
...@@ -163,20 +184,29 @@ def generate_classifier_dataset(): ...@@ -163,20 +184,29 @@ def generate_classifier_dataset():
"cola": "cola":
classifier_data_lib.ColaProcessor, classifier_data_lib.ColaProcessor,
"mnli": "mnli":
classifier_data_lib.MnliProcessor, functools.partial(classifier_data_lib.MnliProcessor,
mnli_type=FLAGS.mnli_type),
"mrpc": "mrpc":
classifier_data_lib.MrpcProcessor, classifier_data_lib.MrpcProcessor,
"qnli": "qnli":
classifier_data_lib.QnliProcessor, classifier_data_lib.QnliProcessor,
"qqp": classifier_data_lib.QqpProcessor, "qqp": classifier_data_lib.QqpProcessor,
"rte": classifier_data_lib.RteProcessor,
"sst-2": "sst-2":
classifier_data_lib.SstProcessor, classifier_data_lib.SstProcessor,
"sts-b":
classifier_data_lib.StsBProcessor,
"xnli": "xnli":
functools.partial(classifier_data_lib.XnliProcessor, functools.partial(classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language), language=FLAGS.xnli_language),
"paws-x": "paws-x":
functools.partial(classifier_data_lib.PawsxProcessor, functools.partial(classifier_data_lib.PawsxProcessor,
language=FLAGS.pawsx_language) language=FLAGS.pawsx_language),
"wnli": classifier_data_lib.WnliProcessor,
"xtreme-xnli":
functools.partial(classifier_data_lib.XtremeXnliProcessor),
"xtreme-paws-x":
functools.partial(classifier_data_lib.XtremePawsxProcessor)
} }
task_name = FLAGS.classification_task_name.lower() task_name = FLAGS.classification_task_name.lower()
if task_name not in processors: if task_name not in processors:
...@@ -237,6 +267,67 @@ def generate_squad_dataset(): ...@@ -237,6 +267,67 @@ def generate_squad_dataset():
FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative) FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)
def generate_retrieval_dataset():
  """Generates retrieval dev/test data and returns the input meta data."""
  assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)

  task = FLAGS.retrieval_task_name.lower()
  processor_map = {
      "bucc": sentence_retrieval_lib.BuccProcessor,
      "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
  }
  if task not in processor_map:
    raise ValueError("Task not found: %s" % task)

  # Build the tokenizer and matching raw-text preprocessing function.
  if FLAGS.tokenizer_impl == "word_piece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenizer_impl == "sentence_piece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  processor = processor_map[task](process_text_fn=text_fn)

  # NOTE(review): "retrevial" is a typo in the library function's public name;
  # it must be spelled this way to match sentence_retrieval_lib.
  return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, FLAGS.max_seq_length)
def generate_tagging_dataset():
  """Generates tagging dataset."""
  processor_map = {
      "panx": tagging_data_lib.PanxProcessor,
      "udpos": tagging_data_lib.UdposProcessor,
  }
  task = FLAGS.tagging_task_name.lower()
  if task not in processor_map:
    raise ValueError("Task not found: %s" % task)

  # Build the tokenizer and matching raw-text preprocessing function.
  impl = FLAGS.tokenizer_impl
  if impl == "word_piece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    text_fn = tokenization.convert_to_unicode
  elif impl == "sentence_piece":
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
  else:
    raise ValueError("Unsupported tokenizer_impl: %s" % impl)

  processor = processor_map[task]()
  return tagging_data_lib.generate_tf_record_from_data_file(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
      FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, text_fn)
def main(_): def main(_):
if FLAGS.tokenizer_impl == "word_piece": if FLAGS.tokenizer_impl == "word_piece":
if not FLAGS.vocab_file: if not FLAGS.vocab_file:
...@@ -248,12 +339,20 @@ def main(_): ...@@ -248,12 +339,20 @@ def main(_):
raise ValueError( raise ValueError(
"FLAG sp_model_file for sentence-piece tokenizer is not specified.") "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
if FLAGS.fine_tuning_task_type != "retrieval":
flags.mark_flag_as_required("train_data_output_path")
if FLAGS.fine_tuning_task_type == "classification": if FLAGS.fine_tuning_task_type == "classification":
input_meta_data = generate_classifier_dataset() input_meta_data = generate_classifier_dataset()
elif FLAGS.fine_tuning_task_type == "regression": elif FLAGS.fine_tuning_task_type == "regression":
input_meta_data = generate_regression_dataset() input_meta_data = generate_regression_dataset()
else: elif FLAGS.fine_tuning_task_type == "retrieval":
input_meta_data = generate_retrieval_dataset()
elif FLAGS.fine_tuning_task_type == "squad":
input_meta_data = generate_squad_dataset() input_meta_data = generate_squad_dataset()
else:
assert FLAGS.fine_tuning_task_type == "tagging"
input_meta_data = generate_tagging_dataset()
tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path)) tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer: with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
...@@ -261,6 +360,5 @@ def main(_): ...@@ -261,6 +360,5 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
flags.mark_flag_as_required("train_data_output_path")
flags.mark_flag_as_required("meta_data_file_path") flags.mark_flag_as_required("meta_data_file_path")
app.run(main) app.run(main)
...@@ -18,6 +18,7 @@ from __future__ import division ...@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import itertools
import random import random
from absl import app from absl import app
...@@ -48,6 +49,12 @@ flags.DEFINE_bool( ...@@ -48,6 +49,12 @@ flags.DEFINE_bool(
"do_whole_word_mask", False, "do_whole_word_mask", False,
"Whether to use whole word masking rather than per-WordPiece masking.") "Whether to use whole word masking rather than per-WordPiece masking.")
flags.DEFINE_integer(
"max_ngram_size", None,
"Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
"weighting scheme to favor shorter n-grams. "
"Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
flags.DEFINE_bool( flags.DEFINE_bool(
"gzip_compress", False, "gzip_compress", False,
"Whether to use `GZIP` compress option to get compressed TFRecord files.") "Whether to use `GZIP` compress option to get compressed TFRecord files.")
...@@ -192,7 +199,8 @@ def create_training_instances(input_files, ...@@ -192,7 +199,8 @@ def create_training_instances(input_files,
masked_lm_prob, masked_lm_prob,
max_predictions_per_seq, max_predictions_per_seq,
rng, rng,
do_whole_word_mask=False): do_whole_word_mask=False,
max_ngram_size=None):
"""Create `TrainingInstance`s from raw text.""" """Create `TrainingInstance`s from raw text."""
all_documents = [[]] all_documents = [[]]
...@@ -229,7 +237,7 @@ def create_training_instances(input_files, ...@@ -229,7 +237,7 @@ def create_training_instances(input_files,
create_instances_from_document( create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob, all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask)) do_whole_word_mask, max_ngram_size))
rng.shuffle(instances) rng.shuffle(instances)
return instances return instances
...@@ -238,7 +246,8 @@ def create_training_instances(input_files, ...@@ -238,7 +246,8 @@ def create_training_instances(input_files,
def create_instances_from_document( def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob, all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask=False): do_whole_word_mask=False,
max_ngram_size=None):
"""Creates `TrainingInstance`s for a single document.""" """Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index] document = all_documents[document_index]
...@@ -337,7 +346,7 @@ def create_instances_from_document( ...@@ -337,7 +346,7 @@ def create_instances_from_document(
(tokens, masked_lm_positions, (tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions( masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng, tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask) do_whole_word_mask, max_ngram_size)
instance = TrainingInstance( instance = TrainingInstance(
tokens=tokens, tokens=tokens,
segment_ids=segment_ids, segment_ids=segment_ids,
...@@ -355,72 +364,238 @@ def create_instances_from_document( ...@@ -355,72 +364,238 @@ def create_instances_from_document(
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"]) ["index", "label"])
# A _Gram is a [half-open) interval of token indices which form a word:
# `begin` is inclusive, `end` is exclusive.
# E.g.,
#   words:  ["The", "doghouse"]
#   tokens: ["The", "dog", "##house"]
#   grams:  [(0,1), (1,3)]
_Gram = collections.namedtuple("_Gram", ["begin", "end"])
def _window(iterable, size):
"""Helper to create a sliding window iterator with a given size.
E.g.,
input = [1, 2, 3, 4]
_window(input, 1) => [1], [2], [3], [4]
_window(input, 2) => [1, 2], [2, 3], [3, 4]
_window(input, 3) => [1, 2, 3], [2, 3, 4]
_window(input, 4) => [1, 2, 3, 4]
_window(input, 5) => None
Arguments:
iterable: elements to iterate over.
size: size of the window.
Yields:
Elements of `iterable` batched into a sliding window of length `size`.
"""
i = iter(iterable)
window = []
try:
for e in range(0, size):
window.append(next(i))
yield window
except StopIteration:
# handle the case where iterable's length is less than the window size.
return
for e in i:
window = window[1:] + [e]
yield window
def _contiguous(sorted_grams):
  """Tests whether a sequence of grams is contiguous.

  E.g.,
    _contiguous([(1, 4), (4, 5), (5, 10)]) == True
    _contiguous([(1, 2), (4, 5)]) == False

  Arguments:
    sorted_grams: _Grams which are sorted in increasing order.

  Returns:
    True if every gram starts exactly where the previous gram ends
    (vacuously True for fewer than two grams).
  """
  previous = None
  for gram in sorted_grams:
    if previous is not None and previous.end != gram.begin:
      return False
    previous = gram
  return True
def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
  """Create a list of masking {1, ..., n}-grams from a list of one-grams.

  This is an extension of 'whole word masking' to mask multiple, contiguous
  words (e.g., "the red boat").

  Each input gram represents the token indices of a single word,
     words:  ["the", "red", "boat"]
     tokens: ["the", "red", "boa", "##t"]
     grams:  [(0,1), (1,2), (2,4)]

  For a `max_ngram_size` of three, possible outputs masks include:
    1-grams: (0,1), (1,2), (2,4)
    2-grams: (0,2), (1,4)
    3-grams: (0,4)

  Output masks will not overlap and contain less than `max_masked_tokens`
  total tokens. E.g., for the example above with `max_masked_tokens` as three,
  valid outputs are,
       [(0,1), (1,2)]  # "the", "red" covering two tokens
       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens

  The length of the selected n-gram follows a zipf weighting to favor shorter
  n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).

  Arguments:
    grams: List of one-grams.
    max_ngram_size: Maximum number of contiguous one-grams combined to create
      an n-gram.
    max_masked_tokens: Maximum total number of tokens to be masked.
    rng: `random.Random` generator; the sole source of randomness here, so a
      seeded generator yields reproducible masks.

  Returns:
    A list of n-grams to be used as masks, or None if `grams` is empty.

  Raises:
    ValueError: if any input grams overlap.
  """
  if not grams:
    return None

  grams = sorted(grams)
  num_tokens = grams[-1].end

  # Ensure our grams are valid (i.e., they don't overlap).
  for a, b in _window(grams, 2):
    if a.end > b.begin:
      raise ValueError("overlapping grams: {}".format(grams))

  # Build map from n-gram length to list of n-grams.
  ngrams = {i: [] for i in range(1, max_ngram_size + 1)}
  for gram_size in range(1, max_ngram_size + 1):
    for g in _window(grams, gram_size):
      if _contiguous(g):
        # Add an n-gram which spans these one-grams.
        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))

  # Shuffle each list of n-grams so that `pop()` below draws them in a
  # random order.
  for v in ngrams.values():
    rng.shuffle(v)

  # Create the weighting for n-gram length selection.
  # Stored cumulatively for `Random.choices` below.
  cumulative_weights = list(
      itertools.accumulate([1. / n for n in range(1, max_ngram_size + 1)]))

  output_ngrams = []
  # Keep a bitmask of which tokens have been masked.
  masked_tokens = [False] * num_tokens
  # Loop until we have enough masked tokens or there are no more candidate
  # n-grams of any length.
  # Each code path should ensure one or more elements from `ngrams` are removed
  # to guarantee this loop terminates.
  while (sum(masked_tokens) < max_masked_tokens and
         sum(len(s) for s in ngrams.values())):
    # Pick an n-gram size based on our weights.
    # BUGFIX: previously this called the module-level `random.choices`, which
    # ignored the caller-provided `rng` and broke reproducibility for a given
    # seed. Every other random draw here already uses `rng`.
    sz = rng.choices(range(1, max_ngram_size + 1),
                     cum_weights=cumulative_weights)[0]
    # Ensure this size doesn't result in too many masked tokens.
    # E.g., a two-gram contains _at least_ two tokens.
    if sum(masked_tokens) + sz > max_masked_tokens:
      # All n-grams of this length are too long and can be removed from
      # consideration.
      ngrams[sz].clear()
      continue

    # All of the n-grams of this size have been used.
    if not ngrams[sz]:
      continue

    # Choose a random n-gram of the given size.
    gram = ngrams[sz].pop()
    num_gram_tokens = gram.end - gram.begin

    # Check if this would add too many tokens.
    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
      continue

    # Check if any of the tokens in this gram have already been masked.
    if sum(masked_tokens[gram.begin:gram.end]):
      continue

    # Found a usable n-gram! Mark its tokens as masked and add it to return.
    masked_tokens[gram.begin:gram.end] = [True] * num_gram_tokens
    output_ngrams.append(gram)

  return output_ngrams
def _wordpieces_to_grams(tokens):
  """Reconstitute grams (words) from `tokens`.

  E.g.,
     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
     grams:  [          [1,2), [2,         4), [4,5) , [5,       6)]

  Arguments:
    tokens: list of wordpieces

  Returns:
    List of _Grams representing spans of whole words
    (without "[CLS]" and "[SEP]").
  """
  grams = []
  gram_start_pos = None
  for i, token in enumerate(tokens):
    # A "##" continuation piece extends the word currently being built.
    if gram_start_pos is not None and token.startswith("##"):
      continue
    # Any other token terminates the word in progress (if any).
    if gram_start_pos is not None:
      grams.append(_Gram(gram_start_pos, i))
    # Non-special tokens open a new word; [CLS]/[SEP] never belong to one.
    if token not in ["[CLS]", "[SEP]"]:
      gram_start_pos = i
    else:
      gram_start_pos = None
  # Close the final word when the sequence does not end with a special token.
  if gram_start_pos is not None:
    grams.append(_Gram(gram_start_pos, len(tokens)))
  return grams
rng.shuffle(cand_indexes)
output_tokens = list(tokens) def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask,
max_ngram_size=None):
"""Creates the predictions for the masked LM objective."""
if do_whole_word_mask:
grams = _wordpieces_to_grams(tokens)
else:
# Here we consider each token to be a word to allow for sub-word masking.
if max_ngram_size:
raise ValueError("cannot use ngram masking without whole word masking")
grams = [_Gram(i, i+1) for i in range(0, len(tokens))
if tokens[i] not in ["[CLS]", "[SEP]"]]
num_to_predict = min(max_predictions_per_seq, num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob)))) max(1, int(round(len(tokens) * masked_lm_prob))))
# Generate masks. If `max_ngram_size` in [0, None] it means we're doing
# whole word masking or token level masking. Both of these can be treated
# as the `max_ngram_size=1` case.
masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
num_to_predict, rng)
masked_lms = [] masked_lms = []
covered_indexes = set() output_tokens = list(tokens)
for index_set in cand_indexes: for gram in masked_grams:
if len(masked_lms) >= num_to_predict: # 80% of the time, replace all n-gram tokens with [MASK]
break if rng.random() < 0.8:
# If adding a whole-word mask would exceed the maximum number of replacement_action = lambda idx: "[MASK]"
# predictions, then just skip this candidate. else:
if len(masked_lms) + len(index_set) > num_to_predict: # 10% of the time, keep all the original n-gram tokens.
continue if rng.random() < 0.5:
is_any_index_covered = False replacement_action = lambda idx: tokens[idx]
for index in index_set: # 10% of the time, replace each n-gram token with a random word.
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
else: else:
# 10% of the time, keep original replacement_action = lambda idx: rng.choice(vocab_words)
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
output_tokens[index] = masked_token for idx in range(gram.begin, gram.end):
output_tokens[idx] = replacement_action(idx)
masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lms = sorted(masked_lms, key=lambda x: x.index)
...@@ -467,7 +642,7 @@ def main(_): ...@@ -467,7 +642,7 @@ def main(_):
instances = create_training_instances( instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng, FLAGS.do_whole_word_mask) rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
output_files = FLAGS.output_file.split(",") output_files = FLAGS.output_file.split(",")
logging.info("*** Writing to output files ***") logging.info("*** Writing to output files ***")
......
# Copyright 2018 The TensorFlow Authors All Rights Reserved. # Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,23 +13,47 @@ ...@@ -12,23 +13,47 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A global factory to access NLP registered data loaders."""
"""Define the abstract class for contextual bandit algorithms.""" from official.utils import registry
from __future__ import absolute_import _REGISTERED_DATA_LOADER_CLS = {}
from __future__ import division
from __future__ import print_function
def register_data_loader_cls(data_config_cls):
  """Decorates a factory of DataLoader for lookup by a subclass of DataConfig.

  This decorator supports registration of data loaders as follows:

  ```
  @dataclasses.dataclass
  class MyDataConfig(DataConfig):
    # Add fields here.
    pass

  @register_data_loader_cls(MyDataConfig)
  class MyDataLoader:
    # Inherits def __init__(self, data_config).
    pass

  my_data_config = MyDataConfig()

  # Returns MyDataLoader(my_data_config).
  my_loader = get_data_loader(my_data_config)
  ```

  Args:
    data_config_cls: a subclass of DataConfig (*not* an instance
      of DataConfig).

  Returns:
    A callable for use as class decorator that registers the decorated class
    for creation from an instance of data_config_cls.
  """
  return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)
def get_data_loader(data_config):
  """Instantiates the data loader registered for `data_config`'s class."""
  loader_cls = registry.lookup(_REGISTERED_DATA_LOADER_CLS,
                               data_config.__class__)
  return loader_cls(data_config)
...@@ -16,11 +16,27 @@ ...@@ -16,11 +16,27 @@
"""Loads dataset for the BERT pretraining task.""" """Loads dataset for the BERT pretraining task."""
from typing import Mapping, Optional from typing import Mapping, Optional
import dataclasses
import tensorflow as tf import tensorflow as tf
from official.core import input_reader from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm)."""
  # Path of the input data files; consumed by input_reader — TODO confirm
  # whether globs/comma-separated lists are supported there.
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  # Fixed token-sequence length of each example.
  seq_length: int = 512
  # Maximum number of masked-LM predictions per sequence.
  max_predictions_per_seq: int = 76
  # Whether examples carry a next-sentence-prediction label.
  use_next_sentence_label: bool = True
  # Whether examples carry explicit position ids.
  use_position_id: bool = False
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader: class BertPretrainDataLoader:
"""A class to load dataset for bert pretraining task.""" """A class to load dataset for bert pretraining task."""
...@@ -91,7 +107,5 @@ class BertPretrainDataLoader: ...@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
def load(self, input_context: Optional[tf.distribute.InputContext] = None): def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset.""" """Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader( reader = input_reader.InputReader(
params=self._params, params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
decoder_fn=self._decode,
parser_fn=self._parse)
return reader.read(input_context) return reader.read(input_context)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the question answering (e.g, SQuAD) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
  """Data config for question answering task (tasks/question_answering)."""
  input_path: str = ''
  global_batch_size: int = 48
  is_training: bool = True
  # Fixed token-sequence length of each example.
  seq_length: int = 384
  # Settings below are question answering specific.
  # Presumably marks SQuAD-v2-style data with unanswerable questions —
  # verify against the consuming task.
  version_2_with_negative: bool = False
  # Settings below are only used for eval mode.
  input_preprocessed_data_path: str = ''
  doc_stride: int = 128
  query_length: int = 64
  vocab_file: str = ''
  tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
  do_lower_case: bool = True
@data_loader_factory.register_data_loader_cls(QADataConfig)
class QuestionAnsweringDataLoader:
  """Loads the dataset for the question answering (e.g., SQuAD) task."""

  def __init__(self, params):
    self._params = params
    self._seq_length = params.seq_length
    self._is_training = params.is_training

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example into a dict of int32 tensors."""
    sequence_feature = tf.io.FixedLenFeature([self._seq_length], tf.int64)
    scalar_feature = tf.io.FixedLenFeature([], tf.int64)
    name_to_features = {
        'input_ids': sequence_feature,
        'input_mask': sequence_feature,
        'segment_ids': sequence_feature,
    }
    if self._is_training:
      name_to_features['start_positions'] = scalar_feature
      name_to_features['end_positions'] = scalar_feature
    else:
      name_to_features['unique_ids'] = scalar_feature
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
    # so cast every int64 feature down to int32.
    return {
        name: tf.cast(tensor, tf.int32) if tensor.dtype == tf.int64 else tensor
        for name, tensor in example.items()
    }

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into (features, labels) consumed by the model."""
    # Keys renamed to match the model's expected input names.
    renames = {'input_ids': 'input_word_ids', 'segment_ids': 'input_type_ids'}
    features, labels = {}, {}
    for name, tensor in record.items():
      if name in ('start_positions', 'end_positions'):
        labels[name] = tensor
      else:
        features[renames.get(name, name)] = tensor
    return (features, labels)

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)
...@@ -15,11 +15,28 @@ ...@@ -15,11 +15,28 @@
# ============================================================================== # ==============================================================================
"""Loads dataset for the sentence prediction (classification) task.""" """Loads dataset for the sentence prediction (classification) task."""
from typing import Mapping, Optional from typing import Mapping, Optional
import dataclasses
import tensorflow as tf import tensorflow as tf
from official.core import input_reader from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
# Maps the `label_type` config string to the tf dtype used when parsing
# the `label_ids` feature.
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task (tasks/sentence_prediction)."""
  input_path: str = ''
  global_batch_size: int = 32
  is_training: bool = True
  # Fixed token-sequence length of each example.
  seq_length: int = 128
  # Dtype of the label feature: 'int' or 'float' (see LABEL_TYPES_MAP).
  label_type: str = 'int'
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
class SentencePredictionDataLoader: class SentencePredictionDataLoader:
"""A class to load dataset for sentence prediction (classification) task.""" """A class to load dataset for sentence prediction (classification) task."""
...@@ -29,11 +46,12 @@ class SentencePredictionDataLoader: ...@@ -29,11 +46,12 @@ class SentencePredictionDataLoader:
def _decode(self, record: tf.Tensor): def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example.""" """Decodes a serialized tf.Example."""
label_type = LABEL_TYPES_MAP[self._params.label_type]
name_to_features = { name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64), 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64), 'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64), 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64), 'label_ids': tf.io.FixedLenFeature([], label_type),
} }
example = tf.io.parse_single_example(record, name_to_features) example = tf.io.parse_single_example(record, name_to_features)
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT library to process data for cross lingual sentence retrieval task."""
import os
from absl import logging
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
class BuccProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Xtreme BUCC data set."""
  # Languages with BUCC data paired against English.
  supported_languages = ["de", "fr", "ru", "zh"]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode):
    super(BuccProcessor, self).__init__(process_text_fn)
    self.languages = BuccProcessor.supported_languages

  def get_dev_examples(self, data_dir, file_pattern):
    """Reads dev examples from `file_pattern` with "dev" substituted in."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
        "sample")

  def get_test_examples(self, data_dir, file_pattern):
    """Reads test examples from `file_pattern` with "test" substituted in."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
        "test")

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "BUCC"

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      # Assumes line[0] has the form "<prefix>-<int>"; the numeric suffix is
      # kept as the example's integer identifier — confirm with the data files.
      int_iden = int(line[0].split("-")[1])
      text_a = self.process_text_fn(line[1])
      examples.append(
          classifier_data_lib.InputExample(
              guid=guid, text_a=text_a, int_iden=int_iden))
    return examples
class TatoebaProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Xtreme Tatoeba data set."""

  supported_languages = [
      "af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
      "he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
      "nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
  ]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode):
    super(TatoebaProcessor, self).__init__(process_text_fn)
    self.languages = TatoebaProcessor.supported_languages

  def get_test_examples(self, data_dir, file_path):
    """Reads test examples from `file_path` under `data_dir`."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_path)), "test")

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "TATOEBA"

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    # Tatoeba rows carry no explicit id, so the row index doubles as the
    # integer identifier.
    return [
        classifier_data_lib.InputExample(
            guid="%s-%s" % (set_type, row_index),
            text_a=self.process_text_fn(row[0]),
            int_iden=row_index) for row_index, row in enumerate(lines)
    ]
def generate_sentence_retrevial_tf_record(processor,
                                          data_dir,
                                          tokenizer,
                                          eval_data_output_path=None,
                                          test_data_output_path=None,
                                          max_seq_length=128):
  """Generates the tf records for retrieval tasks.

  Args:
    processor: Input processor object to be used for generating data. Subclass
      of `DataProcessor`.
    data_dir: Directory that contains train/eval data to process. Data files
      should follow the per-language naming pattern expected by `processor`.
    tokenizer: The tokenizer to be applied on the data.
    eval_data_output_path: Output to which processed tf record for evaluation
      will be saved.
    test_data_output_path: Output to which processed tf record for testing
      will be saved. Must be a pattern template with {} if processor has
      language specific test data.
    max_seq_length: Maximum sequence length of the to be generated
      training/eval data.

  Returns:
    A dictionary containing input meta data.

  Raises:
    ValueError: if `processor` is neither the BUCC nor the TATOEBA processor.
  """
  assert eval_data_output_path or test_data_output_path

  # Resolve the file-name pattern for this processor up front.
  # BUGFIX: previously two independent `if`s were used, so an unknown
  # processor left `path_pattern` unbound and crashed later with a confusing
  # NameError; fail fast with a clear error instead.
  processor_name = processor.get_processor_name()
  if processor_name == "BUCC":
    path_pattern = "{}-en.{{}}.{}"
  elif processor_name == "TATOEBA":
    path_pattern = "{}-en.{}"
  else:
    raise ValueError("Unsupported processor type: %s" % processor_name)

  meta_data = {
      "processor_type": processor_name,
      "max_seq_length": max_seq_length,
      "number_eval_data": {},
      "number_test_data": {},
  }
  logging.info("Start to process %s task data", processor_name)

  for lang_a in processor.languages:
    for lang_b in [lang_a, "en"]:
      if eval_data_output_path:
        eval_input_data_examples = processor.get_dev_examples(
            data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))

        num_eval_data = len(eval_input_data_examples)
        logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
                     lang_a, lang_b)
        output_file = os.path.join(
            eval_data_output_path,
            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
        classifier_data_lib.file_based_convert_examples_to_features(
            eval_input_data_examples, None, max_seq_length, tokenizer,
            output_file, None)
        meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data

      if test_data_output_path:
        test_input_data_examples = processor.get_test_examples(
            data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))

        num_test_data = len(test_input_data_examples)
        logging.info("Processing %d test examples of %s-en.%s", num_test_data,
                     lang_a, lang_b)
        output_file = os.path.join(
            test_data_output_path,
            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
        classifier_data_lib.file_based_convert_examples_to_features(
            test_input_data_examples, None, max_seq_length, tokenizer,
            output_file, None)
        meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data

  return meta_data
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to process data for tagging task such as NER/POS."""
import collections
import os
from absl import logging
import tensorflow as tf
from official.nlp.data import classifier_data_lib
# A negative label id for the padding label, which will not contribute
# to loss/metrics in training. Used for [CLS]/[SEP], non-leading subwords,
# and padding positions.
_PADDING_LABEL_ID = -1

# The special unknown token, used to substitute a word which has too many
# subwords after tokenization.
_UNK_TOKEN = "[UNK]"
class InputExample(object):
  """A single training/test example for token classification."""

  def __init__(self, sentence_id, words=None, label_ids=None):
    """Constructs an InputExample.

    Args:
      sentence_id: id of the sentence this example belongs to.
      words: optional list of (sub)word tokens; stored by reference.
      label_ids: optional list of label ids, parallel to `words`.
    """
    self.sentence_id = sentence_id
    # Falsy arguments (None or empty) are replaced with a fresh list, so
    # instances never share a mutable default container.
    self.words = words or []
    self.label_ids = label_ids or []

  def add_word_and_label_id(self, word, label_id):
    """Appends one (word, label_id) pair to the example."""
    self.words.append(word)
    self.label_ids.append(label_id)
def _read_one_file(file_name, label_list):
  """Reads one file and returns a list of `InputExample` instances.

  Each non-empty line is "<token>(TAB<label>)": train/dev lines carry a label,
  test lines carry only the token. Blank lines separate sentences.

  Args:
    file_name: path of the tsv file to read.
    label_list: list of valid labels; a label's id is its index in this list.

  Returns:
    A list of `InputExample`s, one per non-empty sentence, with consecutive
    `sentence_id`s starting at 0.
  """
  # BUGFIX: use a context manager so the file handle is closed promptly
  # instead of being leaked until garbage collection.
  with tf.io.gfile.GFile(file_name, "r") as reader:
    lines = reader.readlines()
  examples = []
  label_id_map = {label: i for i, label in enumerate(label_list)}
  sentence_id = 0
  example = InputExample(sentence_id=0)
  for line in lines:
    line = line.strip("\n")
    if line:
      # The format is: <token>\t<label> for train/dev set and <token> for test.
      items = line.split("\t")
      assert len(items) in (1, 2)
      token = items[0].strip()

      # Assign a dummy label_id for test set
      label_id = label_id_map[items[1].strip()] if len(items) == 2 else 0
      example.add_word_and_label_id(token, label_id)
    else:
      # Empty line indicates a new sentence.
      if example.words:
        examples.append(example)
        sentence_id += 1
        example = InputExample(sentence_id=sentence_id)

  # Flush the trailing sentence if the file does not end with a blank line.
  if example.words:
    examples.append(example)
  return examples
class PanxProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Panx data set."""
  # Languages with test splits; train/dev data below are English only.
  supported_languages = [
      "ar", "he", "vi", "id", "jv", "ms", "tl", "eu", "ml", "ta", "te", "af",
      "nl", "en", "de", "el", "bn", "hi", "mr", "ur", "fa", "fr", "it", "pt",
      "es", "bg", "ru", "ja", "ka", "ko", "th", "sw", "yo", "my", "zh", "kk",
      "tr", "et", "fi", "hu"
  ]

  def get_train_examples(self, data_dir):
    """Reads English training examples from `train-en.tsv` in `data_dir`."""
    return _read_one_file(
        os.path.join(data_dir, "train-en.tsv"), self.get_labels())

  def get_dev_examples(self, data_dir):
    """Reads English dev examples from `dev-en.tsv` in `data_dir`."""
    return _read_one_file(
        os.path.join(data_dir, "dev-en.tsv"), self.get_labels())

  def get_test_examples(self, data_dir):
    """Reads test examples for every supported language, keyed by language."""
    examples_dict = {}
    for language in self.supported_languages:
      examples_dict[language] = _read_one_file(
          os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
    return examples_dict

  def get_labels(self):
    """Returns the NER label set (IOB2 tags for PER/LOC/ORG)."""
    return ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"]

  @staticmethod
  def get_processor_name():
    """Returns the canonical processor name."""
    return "panx"
class UdposProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Udpos data set."""

  supported_languages = [
      "af", "ar", "bg", "de", "el", "en", "es", "et", "eu", "fa", "fi", "fr",
      "he", "hi", "hu", "id", "it", "ja", "kk", "ko", "mr", "nl", "pt", "ru",
      "ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
  ]

  def get_train_examples(self, data_dir):
    """Reads English training examples from `train-en.tsv` in `data_dir`."""
    train_path = os.path.join(data_dir, "train-en.tsv")
    return _read_one_file(train_path, self.get_labels())

  def get_dev_examples(self, data_dir):
    """Reads English dev examples from `dev-en.tsv` in `data_dir`."""
    dev_path = os.path.join(data_dir, "dev-en.tsv")
    return _read_one_file(dev_path, self.get_labels())

  def get_test_examples(self, data_dir):
    """Reads test examples for every supported language, keyed by language."""
    return {
        language: _read_one_file(
            os.path.join(data_dir, "test-%s.tsv" % language),
            self.get_labels()) for language in self.supported_languages
    }

  def get_labels(self):
    """Returns the universal POS tag set used as labels."""
    return [
        "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
        "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X"
    ]

  @staticmethod
  def get_processor_name():
    """Returns the canonical processor name."""
    return "udpos"
def _tokenize_example(example, max_length, tokenizer, text_preprocessing=None):
  """Tokenizes words and breaks long example into short ones.

  Args:
    example: an `InputExample` whose `words` are whole words.
    max_length: maximum number of tokens per output example, including the
      room reserved for the [CLS] and [SEP] tokens added downstream.
    tokenizer: tokenizer exposing `tokenize(word)` returning subword pieces.
    text_preprocessing: optional callable applied to each word before
      tokenization.

  Returns:
    A list of `InputExample`s whose words are subwords, each holding at most
    `max_length - 2` tokens.

  Raises:
    ValueError: if `example` carries a negative label id.
  """
  # Needs additional [CLS] and [SEP] tokens.
  max_length = max_length - 2

  # Validate once up front. BUGFIX: this check is loop-invariant; previously
  # it rescanned all label_ids on every word, i.e. redundant O(n^2) work.
  if any(x < 0 for x in example.label_ids):
    raise ValueError("Unexpected negative label_id: %s" % example.label_ids)

  new_examples = []
  new_example = InputExample(sentence_id=example.sentence_id)
  for i, word in enumerate(example.words):
    if text_preprocessing:
      word = text_preprocessing(word)
    subwords = tokenizer.tokenize(word)
    # Substitute [UNK] when tokenization of a non-empty word fails or
    # produces more pieces than could ever fit.
    if (not subwords or len(subwords) > max_length) and word:
      subwords = [_UNK_TOKEN]
    if len(subwords) + len(new_example.words) > max_length:
      # Start a new example.
      new_examples.append(new_example)
      new_example = InputExample(sentence_id=example.sentence_id)
    for j, subword in enumerate(subwords):
      # Use the real label for the first subword, and pad label for
      # the remainings.
      subword_label = example.label_ids[i] if j == 0 else _PADDING_LABEL_ID
      new_example.add_word_and_label_id(subword, subword_label)

  if new_example.words:
    new_examples.append(new_example)
  return new_examples
def _convert_single_example(example, max_seq_length, tokenizer):
  """Converts an `InputExample` instance to a `tf.train.Example` instance."""
  # Wrap the word pieces with [CLS]/[SEP]; the special tokens carry the
  # padding label so downstream code can ignore them.
  tokens = ["[CLS]"] + list(example.words) + ["[SEP]"]
  input_ids = tokenizer.convert_tokens_to_ids(tokens)
  label_ids = ([_PADDING_LABEL_ID] + list(example.label_ids) +
               [_PADDING_LABEL_ID])
  segment_ids = [0] * len(input_ids)
  input_mask = [1] * len(input_ids)

  # Pad every aligned feature out to the fixed sequence length.
  pad_length = max_seq_length - len(input_ids)
  input_ids += [0] * pad_length
  input_mask += [0] * pad_length
  segment_ids += [0] * pad_length
  label_ids += [_PADDING_LABEL_ID] * pad_length

  def _int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

  features = collections.OrderedDict()
  features["input_ids"] = _int64_feature(input_ids)
  features["input_mask"] = _int64_feature(input_mask)
  features["segment_ids"] = _int64_feature(segment_ids)
  features["label_ids"] = _int64_feature(label_ids)
  features["sentence_id"] = _int64_feature([example.sentence_id])
  return tf.train.Example(features=tf.train.Features(feature=features))
def write_example_to_file(examples,
                          tokenizer,
                          max_seq_length,
                          output_file,
                          text_preprocessing=None):
  """Writes `InputExample`s into a tfrecord file with `tf.train.Example` protos.

  Note that the words inside each example will be tokenized and be applied by
  `text_preprocessing` if available. Also, if the length of sentence (plus
  special [CLS] and [SEP] tokens) exceeds `max_seq_length`, the long sentence
  will be broken into multiple short examples. For example:

  Example (text_preprocessing=lowercase, max_seq_length=5)
  words:        ["What", "a", "great", "weekend"]
  labels:       [     7,   5,       9,        10]
  sentence_id:  0
  preprocessed: ["what", "a", "great", "weekend"]
  tokenized:    ["what", "a", "great", "week", "##end"]

  will result in two tf.example protos:

  tokens:      ["[CLS]", "what", "a", "great", "[SEP]"]
  label_ids:   [-1,          7,   5,       9,      -1]
  input_mask:  [ 1,          1,   1,       1,       1]
  segment_ids: [ 0,          0,   0,       0,       0]
  input_ids:   [ tokenizer.convert_tokens_to_ids(tokens) ]
  sentence_id: 0

  tokens:      ["[CLS]", "week", "##end", "[SEP]", "[PAD]"]
  label_ids:   [-1,         10,      -1,      -1,      -1]
  input_mask:  [ 1,          1,       1,       0,       0]
  segment_ids: [ 0,          0,       0,       0,       0]
  input_ids:   [ tokenizer.convert_tokens_to_ids(tokens) ]
  sentence_id: 0

  Note the use of -1 in `label_ids` to indicate that a token should not be
  considered for classification (e.g., trailing ## wordpieces or special
  token). Token classification models should accordingly ignore these when
  calculating loss, metrics, etc...

  Args:
    examples: A list of `InputExample` instances.
    tokenizer: The tokenizer to be applied on the data.
    max_seq_length: Maximum length of generated sequences.
    output_file: The name of the output tfrecord file.
    text_preprocessing: optional preprocessing run on each word prior to
      tokenization.

  Returns:
    The total number of tf.train.Example proto written to file.
  """
  tf.io.gfile.makedirs(os.path.dirname(output_file))
  num_tokenized_examples = 0
  # Use the writer as a context manager so the record file is flushed and
  # closed even when tokenization or serialization raises part-way through;
  # the previous code leaked the writer on any exception.
  with tf.io.TFRecordWriter(output_file) as writer:
    for ex_index, example in enumerate(examples):
      if ex_index % 10000 == 0:
        logging.info("Writing example %d of %d to %s", ex_index, len(examples),
                     output_file)
      tokenized_examples = _tokenize_example(example, max_seq_length, tokenizer,
                                             text_preprocessing)
      num_tokenized_examples += len(tokenized_examples)
      for per_tokenized_example in tokenized_examples:
        tf_example = _convert_single_example(per_tokenized_example,
                                             max_seq_length, tokenizer)
        writer.write(tf_example.SerializeToString())
  return num_tokenized_examples
def token_classification_meta_data(train_data_size,
                                   max_seq_length,
                                   num_labels,
                                   eval_data_size=None,
                                   test_data_size=None,
                                   label_list=None,
                                   processor_type=None):
  """Creates metadata for tagging (token classification) datasets.

  Args:
    train_data_size: Number of training examples.
    max_seq_length: Maximum sequence length of the generated features.
    num_labels: Number of distinct token labels.
    eval_data_size: Optional number of evaluation examples.
    test_data_size: Optional test-set size (may be a per-language mapping).
    label_list: Optional list of label names.
    processor_type: Optional dataset/processor identifier.

  Returns:
    A metadata dict; the optional entries are included only when truthy.
  """
  meta_data = {
      "train_data_size": train_data_size,
      "max_seq_length": max_seq_length,
      "num_labels": num_labels,
      "task_type": "tagging",
      "label_type": "int",
      "label_shape": [max_seq_length],
  }
  optional_entries = {
      "eval_data_size": eval_data_size,
      "test_data_size": test_data_size,
      "label_list": label_list,
      "processor_type": processor_type,
  }
  # Mirror the original truthiness check: falsy values are omitted entirely.
  meta_data.update({k: v for k, v in optional_entries.items() if v})
  return meta_data
def generate_tf_record_from_data_file(processor,
                                      data_dir,
                                      tokenizer,
                                      max_seq_length,
                                      train_data_output_path,
                                      eval_data_output_path,
                                      test_data_output_path,
                                      text_preprocessing):
  """Generates tfrecord files from the raw data.

  Writes the train and dev splits to the given output paths, one tfrecord
  per test language (the test path acts as a `str.format` template with a
  slot for the language code), and returns the dataset metadata.

  Args:
    processor: Data processor providing the train/dev/test examples and
      label list.
    data_dir: Directory containing the raw data files.
    tokenizer: The tokenizer to be applied on the data.
    max_seq_length: Maximum length of generated sequences.
    train_data_output_path: Output path for the train tfrecord.
    eval_data_output_path: Output path for the dev tfrecord.
    test_data_output_path: Per-language output path template for test data.
    text_preprocessing: Preprocessing run on each word prior to tokenization.

  Returns:
    A metadata dict describing the generated dataset.
  """
  shared_kwargs = dict(
      tokenizer=tokenizer,
      max_seq_length=max_seq_length,
      text_preprocessing=text_preprocessing)

  train_data_size = write_example_to_file(
      processor.get_train_examples(data_dir),
      output_file=train_data_output_path,
      **shared_kwargs)
  eval_data_size = write_example_to_file(
      processor.get_dev_examples(data_dir),
      output_file=eval_data_output_path,
      **shared_kwargs)
  test_data_size = {
      language: write_example_to_file(
          language_examples,
          output_file=test_data_output_path.format(language),
          **shared_kwargs)
      for language, language_examples in
      processor.get_test_examples(data_dir).items()
  }

  labels = processor.get_labels()
  return token_classification_meta_data(
      train_data_size,
      max_seq_length,
      len(labels),
      eval_data_size,
      test_data_size,
      label_list=labels,
      processor_type=processor.get_processor_name())
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the tagging (e.g., NER/POS) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
  """Data config for tagging (tasks/tagging)."""
  # Whether this pipeline feeds training (passed through to the reader via
  # the base cfg.DataConfig — exact semantics live there).
  is_training: bool = True
  # Fixed length every serialized per-token feature was padded to.
  seq_length: int = 128
  # If True, additionally parse the scalar `sentence_id` feature.
  include_sentence_id: bool = False
@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
class TaggingDataLoader:
  """A class to load dataset for tagging (e.g., NER and POS) task."""

  def __init__(self, params: TaggingDataConfig):
    self._params = params
    self._seq_length = params.seq_length
    self._include_sentence_id = params.include_sentence_id

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    # All per-token features share the same fixed-length int64 spec.
    token_feature = tf.io.FixedLenFeature([self._seq_length], tf.int64)
    name_to_features = {
        'input_ids': token_feature,
        'input_mask': token_feature,
        'segment_ids': token_feature,
        'label_ids': token_feature,
    }
    if self._include_sentence_id:
      name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
    example = tf.io.parse_single_example(record, name_to_features)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
    # so downcast every int64 feature.
    return {
        name: tf.cast(tensor, tf.int32) if tensor.dtype == tf.int64 else tensor
        for name, tensor in example.items()
    }

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    features = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids'],
    }
    if self._include_sentence_id:
      features['sentence_id'] = record['sentence_id']
    return features, record['label_ids']

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=self._decode,
        parser_fn=self._parse)
    return reader.read(input_context)
# NLP Modeling Library # NLP Modeling Library
This libary provides a set of Keras primitives (Layers, Networks, and Models) This library provides a set of Keras primitives (Layers, Networks, and Models)
that can be assembled into transformer-based models. They are that can be assembled into transformer-based models. They are
flexible, validated, interoperable, and both TF1 and TF2 compatible. flexible, validated, interoperable, and both TF1 and TF2 compatible.
...@@ -16,6 +16,11 @@ standardized configuration. ...@@ -16,6 +16,11 @@ standardized configuration.
* [`losses`](losses) contains common loss computation used in NLP tasks. * [`losses`](losses) contains common loss computation used in NLP tasks.
Please see the colab
[nlp_modeling_library_intro.ipynb]
(https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb)
for how to build transformer-based NLP models using above primitives.
Besides the pre-defined primitives, it also provides scaffold classes to allow Besides the pre-defined primitives, it also provides scaffold classes to allow
easy experimentation with noval achitectures, e.g., you don’t need to fork a whole Transformer object to try a different kind of attention primitive, for instance. easy experimentation with noval achitectures, e.g., you don’t need to fork a whole Transformer object to try a different kind of attention primitive, for instance.
...@@ -33,11 +38,9 @@ embedding subnetwork (which will replace the standard embedding logic) and/or a ...@@ -33,11 +38,9 @@ embedding subnetwork (which will replace the standard embedding logic) and/or a
custom hidden layer (which will replace the Transformer instantiation in the custom hidden layer (which will replace the Transformer instantiation in the
encoder). encoder).
BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder. Please see the colab
[customize_encoder.ipynb]
(https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb)
for how to use scaffold classes to build noval achitectures.
BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder.
...@@ -3,19 +3,18 @@ ...@@ -3,19 +3,18 @@
Layers are the fundamental building blocks for NLP models. They can be used to Layers are the fundamental building blocks for NLP models. They can be used to
assemble new layers, networks, or models. assemble new layers, networks, or models.
* [DenseEinsum](dense_einsum.py) implements a feedforward network using
tf.einsum. This layer contains the einsum op, the associated weight, and the
logic required to generate the einsum expression for the given
initialization parameters.
* [MultiHeadAttention](attention.py) implements an optionally masked attention * [MultiHeadAttention](attention.py) implements an optionally masked attention
between two tensors, from_tensor and to_tensor, as described in between query, key, value tensors as described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
`from_tensor` and `to_tensor` are the same, then this is self-attention. `from_tensor` and `to_tensor` are the same, then this is self-attention.
* [CachedAttention](attention.py) implements an attention layer with cache * [CachedAttention](attention.py) implements an attention layer with cache
used for auto-agressive decoding. used for auto-agressive decoding.
* [MultiChannelAttention](multi_channel_attention.py) implements an variant of
multi-head attention which can be used to merge multiple streams for
cross-attentions.
* [TalkingHeadsAttention](talking_heads_attention.py) implements the talking * [TalkingHeadsAttention](talking_heads_attention.py) implements the talking
heads attention, as decribed in heads attention, as decribed in
["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436). ["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436).
...@@ -24,6 +23,10 @@ assemble new layers, networks, or models. ...@@ -24,6 +23,10 @@ assemble new layers, networks, or models.
described in described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [TransformerDecoderLayer](transformer.py) TransformerDecoderLayer is made up
of self multi-head attention, cross multi-head attention and
feedforward network.
* [ReZeroTransformer](rezero_transformer.py) implements Transformer with * [ReZeroTransformer](rezero_transformer.py) implements Transformer with
ReZero described in ReZero described in
["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887). ["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).
...@@ -45,8 +48,8 @@ assemble new layers, networks, or models. ...@@ -45,8 +48,8 @@ assemble new layers, networks, or models.
should be masked), the output will have masked positions set to should be masked), the output will have masked positions set to
approximately zero. approximately zero.
* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes the * [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes
embedding table variable is passed to it. the embedding table variable is passed to it.
* [ClassificationHead](cls_head.py) A pooling head over a sequence of * [ClassificationHead](cls_head.py) A pooling head over a sequence of
embeddings, commonly used by classification tasks. embeddings, commonly used by classification tasks.
......
...@@ -20,10 +20,12 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum ...@@ -20,10 +20,12 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.masked_lm import MaskedLM from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.transformer import Transformer from official.nlp.modeling.layers.transformer import *
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
...@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense ...@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX = string.ascii_lowercase _CHR_IDX = string.ascii_lowercase
def _build_attention_equation(qkv_rank, attn_axes): def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation. """Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as: Query, key, value inputs after projection are expected to have the shape as:
...@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes): ...@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes):
<query attention dims>, num_heads, channels) <query attention dims>, num_heads, channels)
Args: Args:
qkv_rank: the rank of query, key, value tensors. rank: the rank of query, key, value tensors.
attn_axes: a list/tuple of axes, [1, rank), that will do attention. attn_axes: a list/tuple of axes, [1, rank), that will do attention.
Returns: Returns:
Einsum equations. Einsum equations.
""" """
target_notation = _CHR_IDX[:qkv_rank] target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim. # `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,))) batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = qkv_rank letter_offset = rank
source_notation = "" source_notation = ""
for i in range(qkv_rank): for i in range(rank):
if i in batch_dims or i == qkv_rank - 1: if i in batch_dims or i == rank - 1:
source_notation += target_notation[i] source_notation += target_notation[i]
else: else:
source_notation += _CHR_IDX[letter_offset] source_notation += _CHR_IDX[letter_offset]
...@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
sequence dims. If not specified, projects back to the key feature dim. sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features. attention over all axes, but batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head return_attention_scores: bool, if `True`, returns the multi-head attention
attention scores as an additional output argument. scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels. kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases. bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels. kernel_regularizer: Regularizer for dense layer kernels.
...@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity. activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels. kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels. bias_constraint: Constraint for dense layer kernels.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
""" """
def __init__(self, def __init__(self,
...@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
self._attention_axes = (attention_axes,) self._attention_axes = (attention_axes,)
else: else:
self._attention_axes = attention_axes self._attention_axes = attention_axes
self._built_from_signature = False
def get_config(self): def get_config(self):
config = { config = {
...@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
base_config = super(MultiHeadAttention, self).get_config() base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape): def _build_from_signature(self, query, value, key=None):
inputs_len = len(input_shape) """Builds layers and variables.
if inputs_len > 3 or inputs_len < 2:
raise ValueError( Once the method is called, self._built_from_signature will be set to True.
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. " Args:
"Given length: %d" % inputs_len) query: query tensor or TensorShape.
tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape) value: value tensor or TensorShape.
query_shape = tensor_shapes[0] key: key tensor or TensorShape.
value_shape = tensor_shapes[1] """
key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape self._built_from_signature = True
if hasattr(query, "shape"):
query_shape = tf.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tf.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict( common_kwargs = dict(
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
...@@ -271,84 +293,79 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -271,84 +293,79 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint) bias_constraint=self._bias_constraint)
with tf.init_scope():
free_dims = query_shape.rank - 1 free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2) free_dims, bound_dims=1, output_dims=2)
self._query_dense = EinsumDense( self._query_dense = EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]), [self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="query", name="query",
**common_kwargs) **common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2) key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = EinsumDense( self._key_dense = EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]), [self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="key", name="key",
**common_kwargs) **common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_shape.rank - 1, bound_dims=1, output_dims=2) value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = EinsumDense( self._value_dense = EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_size]), [self._num_heads, self._value_size]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="value", name="value",
**common_kwargs) **common_kwargs)
# Builds the attention computations for multi-head dot product attention. # Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once it # These computations could be wrapped into the keras attention layer once
# support mult-head einsum computations. # it support mult-head einsum computations.
self._build_attention(output_rank) self.build_attention(output_rank)
if self._output_shape: if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized): if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape] output_shape = [self._output_shape]
else:
output_shape = self._output_shape
else: else:
output_shape = self._output_shape output_shape = [query_shape[-1]]
else: einsum_equation, bias_axes, output_rank = _build_proj_equation(
output_shape = [query_shape[-1]] free_dims, bound_dims=2, output_dims=len(output_shape))
einsum_equation, bias_axes, output_rank = _build_proj_equation( self._output_dense = EinsumDense(
free_dims, bound_dims=2, output_dims=len(output_shape)) einsum_equation,
self._output_dense = EinsumDense( output_shape=_get_output_shape(output_rank - 1, output_shape),
einsum_equation, bias_axes=bias_axes if self._use_bias else None,
output_shape=_get_output_shape(output_rank - 1, output_shape), name="attention_output",
bias_axes=bias_axes if self._use_bias else None, **common_kwargs)
name="attention_output",
**common_kwargs) def build_attention(self, rank):
super(MultiHeadAttention, self).build(input_shape)
def _build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations. """Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to This function builds attributes necessary for `compute_attention` to
costomize attention computation to replace the default dot-product costomize attention computation to replace the default dot-product
attention. attention.
Args: Args:
qkv_rank: the rank of query, key, value tensors. rank: the rank of query, key, value tensors.
""" """
if self._attention_axes is None: if self._attention_axes is None:
self._attention_axes = tuple(range(1, qkv_rank - 2)) self._attention_axes = tuple(range(1, rank - 2))
else: else:
self._attention_axes = tuple(self._attention_axes) self._attention_axes = tuple(self._attention_axes)
self._dot_product_equation, self._combine_equation, attn_scores_rank = ( self._dot_product_equation, self._combine_equation, attn_scores_rank = (
_build_attention_equation(qkv_rank, attn_axes=self._attention_axes)) _build_attention_equation(rank, attn_axes=self._attention_axes))
norm_axes = tuple( norm_axes = tuple(
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank)) range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
self._masked_softmax = masked_softmax.MaskedSoftmax( self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=norm_axes) mask_expansion_axes=[1], normalization_axes=norm_axes)
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout) self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
def _compute_attention(self, def compute_attention(self, query, key, value, attention_mask=None):
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors. """Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected This function defines the computation inside `call` with projected
...@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention implementation. attention implementation.
Args: Args:
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`. query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`. key: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`. value: Projected value `Tensor` of shape `[B, T, N, value_size]`.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions. attention to certain positions.
...@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention_output: Multi-headed outputs of attention computation. attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights. attention_scores: Multi-headed attention weights.
""" """
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor, attention_scores = tf.einsum(self._dot_product_equation, key, query)
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, T, S] # `attention_scores` = [B, N, T, S]
...@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
# `context_layer` = [B, T, N, H] # `context_layer` = [B, T, N, H]
attention_output = tf.einsum(self._combine_equation, attention_output = tf.einsum(self._combine_equation,
attention_scores_dropout, value_tensor) attention_scores_dropout, value)
return attention_output, attention_scores return attention_output, attention_scores
def call(self, inputs, attention_mask=None): def call(self, query, value, key=None, attention_mask=None):
"""Implements the forward pass. """Implements the forward pass.
Size glossary: Size glossary:
...@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
* Value (source) attention axes shape (S), the rank must match the target. * Value (source) attention axes shape (S), the rank must match the target.
Args: Args:
inputs: List of the following tensors: query: Query `Tensor` of shape `[B, T, dim]`.
* query: Query `Tensor` of shape `[B, T, dim]`. value: Value `Tensor` of shape `[B, S, dim]`.
* value: Value `Tensor` of shape `[B, S, dim]`. key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
* key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will `value` for both `key` and `value`, which is the most common case.
use `value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions. attention to certain positions.
...@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention attention
axes. axes.
""" """
inputs_len = len(inputs) if not self._built_from_signature:
if inputs_len > 3 or inputs_len < 2: self._build_from_signature(query=query, value=value, key=key)
raise ValueError( if key is None:
"Expects inputs list of length 2 or 3, namely [query, value] or " key = value
"[query, value, key]. "
"Given length: %d" % inputs_len)
query = inputs[0]
value = inputs[1]
key = inputs[2] if inputs_len == 3 else value
# N = `num_attention_heads` # N = `num_attention_heads`
# H = `size_per_head` # H = `size_per_head`
# `query_tensor` = [B, T, N ,H] # `query` = [B, T, N ,H]
query_tensor = self._query_dense(query) query = self._query_dense(query)
# `key_tensor` = [B, S, N, H] # `key` = [B, S, N, H]
key_tensor = self._key_dense(key) key = self._key_dense(key)
# `value_tensor` = [B, S, N, H] # `value` = [B, S, N, H]
value_tensor = self._value_dense(value) value = self._value_dense(value)
attention_output, attention_scores = self._compute_attention( attention_output, attention_scores = self.compute_attention(
query_tensor, key_tensor, value_tensor, attention_mask) query, key, value, attention_mask)
attention_output = self._output_dense(attention_output) attention_output = self._output_dense(attention_output)
if self._return_attention_scores: if self._return_attention_scores:
...@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention): ...@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention):
Arguments are the same as `MultiHeadAttention` layer. Arguments are the same as `MultiHeadAttention` layer.
""" """
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step): def _update_cache(self, key, value, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors.""" """Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values. # Combines cached keys and values with new keys and values.
if decode_loop_step is not None: if decode_loop_step is not None:
# TPU special case. # TPU special case.
key_seq_dim = cache["key"].shape.as_list()[1] key_seq_dim = cache["key"].shape.as_list()[1]
indices = tf.reshape( indices = tf.reshape(
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype), tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
[1, key_seq_dim, 1, 1]) [1, key_seq_dim, 1, 1])
key_tensor = cache["key"] + key_tensor * indices key = cache["key"] + key * indices
value_seq_dim = cache["value"].shape.as_list()[1] value_seq_dim = cache["value"].shape.as_list()[1]
indices = tf.reshape( indices = tf.reshape(
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype), tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
[1, value_seq_dim, 1, 1]) [1, value_seq_dim, 1, 1])
value_tensor = cache["value"] + value_tensor * indices value = cache["value"] + value * indices
else: else:
key_tensor = tf.concat( key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
[tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1) value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)
value_tensor = tf.concat(
[tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
# Update cache # Update cache
cache["key"] = key_tensor cache["key"] = key
cache["value"] = value_tensor cache["value"] = value
return key_tensor, value_tensor return key, value
def call(self, def call(self,
inputs, query,
value,
key=None,
attention_mask=None, attention_mask=None,
cache=None, cache=None,
decode_loop_step=None): decode_loop_step=None):
from_tensor = inputs[0] if not self._built_from_signature:
to_tensor = inputs[1] self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here: # Scalar dimensions referenced here:
# B = batch size (number of sequences) # B = batch size (number of sequences)
...@@ -494,25 +509,23 @@ class CachedAttention(MultiHeadAttention): ...@@ -494,25 +509,23 @@ class CachedAttention(MultiHeadAttention):
# T = `to_tensor` sequence length # T = `to_tensor` sequence length
# N = `num_attention_heads` # N = `num_attention_heads`
# H = `size_per_head` # H = `size_per_head`
# `query_tensor` = [B, F, N ,H] # `query` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor) query = self._query_dense(query)
# `key_tensor` = [B, T, N, H] # `key` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor) key = self._key_dense(key)
# `value_tensor` = [B, T, N, H] # `value` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor) value = self._value_dense(value)
if cache: if cache:
key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor, key, value = self._update_cache(key, value, cache, decode_loop_step)
cache, decode_loop_step)
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor, attention_scores = tf.einsum(self._dot_product_equation, key, query)
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, F, T] # `attention_scores` = [B, N, F, T]
...@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention): ...@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention):
attention_scores = self._dropout_layer(attention_scores) attention_scores = self._dropout_layer(attention_scores)
# `context_layer` = [B, F, N, H] # `context_layer` = [B, F, N, H]
attention_output = tf.einsum(self._combine_equation, attention_scores, attention_output = tf.einsum(self._combine_equation, attention_scores,
value_tensor) value)
attention_output = self._output_dense(attention_output) attention_output = self._output_dense(attention_output)
if self._return_attention_scores: if self._return_attention_scores:
return attention_output, attention_scores, cache return attention_output, attention_scores, cache
......
...@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80)) value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value]) output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims) self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self): def test_non_masked_self_attention(self):
...@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64) test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self): def test_attention_scores(self):
...@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True) num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query]) output, coef = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40]) self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
...@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8)) query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8)) value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2)) mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor) output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer. # Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output) model = tf.keras.Model([query, value, mask_tensor], output)
...@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V. # Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8)) key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor) output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output) model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data]) masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
...@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)) kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters( @parameterized.named_parameters(
...@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least # Invoke the data with a random set of mask data. This should mask at least
# one element. # one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool") mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data) output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked). # Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape) null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data) unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the # Because one data is masked and one is not, the outputs should not be the
# same. # same.
self.assertNotAllClose(output, unmasked_output) self.assertNotAllClose(output, unmasked_output)
...@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase): ...@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
key_size=64) key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
...@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase): ...@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# one element. # one element.
mask_data = np.random.randint( mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length)) 2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data], mask_data, cache) masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
self.assertEqual(masked_output_data.shape, (3, 4, 8)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2)) self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
# Tests inputs without cache. # Tests inputs without cache.
masked_output_data, cache = layer([from_data, from_data, mask_data]) masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data)
self.assertEqual(masked_output_data.shape, (3, 4, 8)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertIsNone(cache) self.assertIsNone(cache)
...@@ -243,10 +246,12 @@ class CachedAttentionTest(keras_parameterized.TestCase): ...@@ -243,10 +246,12 @@ class CachedAttentionTest(keras_parameterized.TestCase):
mask_data = np.random.randint( mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32) 2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
# Testing the invocation directly as Keras cannot consume inputs correctly. # Testing the invocation directly as Keras cannot consume inputs correctly.
masked_output_data, cache = layer([from_data, from_data], masked_output_data, cache = layer(
mask_data, query=from_data,
cache, value=from_data,
decode_loop_step=decode_loop_step) attention_mask=mask_data,
cache=cache,
decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 8)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2)) self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
......
...@@ -21,6 +21,8 @@ from __future__ import print_function ...@@ -21,6 +21,8 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.util import deprecation
_CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"] _CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
...@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer): ...@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer):
`(batch_size, units)`. `(batch_size, units)`.
""" """
@deprecation.deprecated(
None, "DenseEinsum is deprecated. Please use "
"tf.keras.experimental.EinsumDense layer instead.")
def __init__(self, def __init__(self,
output_shape, output_shape,
num_summed_dimensions=1, num_summed_dimensions=1,
......
...@@ -34,7 +34,7 @@ class MaskedLM(tf.keras.layers.Layer): ...@@ -34,7 +34,7 @@ class MaskedLM(tf.keras.layers.Layer):
Arguments: Arguments:
embedding_table: The embedding table of the targets. embedding_table: The embedding table of the targets.
activation: The activation, if any, for the dense layer. activation: The activation, if any, for the dense layer.
initializer: The intializer for the dense layer. Defaults to a Glorot initializer: The initializer for the dense layer. Defaults to a Glorot
uniform initializer. uniform initializer.
output: The output style for this network. Can be either 'logits' or output: The output style for this network. Can be either 'logits' or
'predictions'. 'predictions'.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment