Unverified Commit 0cceabfc authored by Yiming Shi's avatar Yiming Shi Committed by GitHub
Browse files

Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor

parents 17821c0d 39ee0ac9
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer
@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
"""ELECTRA pretrainer configuration."""
num_masked_tokens: int = 76
sequence_length: int = 512
num_classes: int = 2
discriminator_loss_weight: float = 50.0
tie_embeddings: bool = True
disallow_correct: bool = False
generator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
discriminator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_classification_heads_from_cfgs(
cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
if cls_head_configs:
return [
layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
]
else:
return []
def instantiate_pretrainer_from_cfg(
config: ELECTRAPretrainerConfig,
generator_network: Optional[tf.keras.Model] = None,
discriminator_network: Optional[tf.keras.Model] = None,
) -> electra_pretrainer.ElectraPretrainer:
"""Instantiates ElectraPretrainer from the config."""
generator_encoder_cfg = config.generator_encoder
discriminator_encoder_cfg = config.discriminator_encoder
# Copy discriminator's embeddings to generator for easier model serialization.
if discriminator_network is None:
discriminator_network = encoders.instantiate_encoder_from_cfg(
discriminator_encoder_cfg)
if generator_network is None:
if config.tie_embeddings:
embedding_layer = discriminator_network.get_embedding_layer()
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg, embedding_layer=embedding_layer)
else:
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg)
return electra_pretrainer.ElectraPretrainer(
generator_network=generator_network,
discriminator_network=discriminator_network,
vocab_size=config.generator_encoder.vocab_size,
num_classes=config.num_classes,
sequence_length=config.sequence_length,
num_token_predictions=config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(
generator_encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=generator_encoder_cfg.initializer_range),
classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads),
disallow_correct=config.disallow_correct)
# Copyright 2017 The TensorFlow Authors All Rights Reserved. # Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,53 +13,37 @@ ...@@ -12,53 +13,37 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
"""Tests of the block operators."""
import numpy as np
import tensorflow as tf import tensorflow as tf
import block_base from official.nlp.configs import bert
import blocks_operator from official.nlp.configs import electra
from official.nlp.configs import encoders
class AddOneBlock(block_base.BlockBase):
class ELECTRAModelsTest(tf.test.TestCase):
def __init__(self, name=None):
super(AddOneBlock, self).__init__(name) def test_network_invocation(self):
config = electra.ELECTRAPretrainerConfig(
def _Apply(self, x): generator_encoder=encoders.TransformerEncoderConfig(
return x + 1.0 vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
class SquareBlock(block_base.BlockBase): )
_ = electra.instantiate_pretrainer_from_cfg(config)
def __init__(self, name=None):
super(SquareBlock, self).__init__(name) # Invokes with classification heads.
config = electra.ELECTRAPretrainerConfig(
def _Apply(self, x): generator_encoder=encoders.TransformerEncoderConfig(
return x * x vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
class BlocksOperatorTest(tf.test.TestCase): cls_heads=[
bert.ClsHeadConfig(
def testComposition(self): inner_dim=10, num_classes=2, name="next_sentence")
x_value = np.array([[1.0, 2.0, 3.0], ])
[-1.0, -2.0, -3.0]]) _ = electra.instantiate_pretrainer_from_cfg(config)
y_expected_value = np.array([[4.0, 9.0, 16.0],
[0.0, 1.0, 4.0]]) if __name__ == "__main__":
x = tf.placeholder(dtype=tf.float32, shape=[2, 3])
complex_block = blocks_operator.CompositionOperator(
[AddOneBlock(),
SquareBlock()])
y = complex_block(x)
with self.test_session():
y_value = y.eval(feed_dict={x: x_value})
self.assertAllClose(y_expected_value, y_value)
if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -13,11 +13,18 @@ ...@@ -13,11 +13,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Configurations for Encoders.""" """Transformer Encoders.
Includes configurations and instantiation methods.
"""
from typing import Optional
import dataclasses import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@dataclasses.dataclass @dataclasses.dataclass
...@@ -28,9 +35,64 @@ class TransformerEncoderConfig(base_config.Config): ...@@ -28,9 +35,64 @@ class TransformerEncoderConfig(base_config.Config):
num_layers: int = 12 num_layers: int = 12
num_attention_heads: int = 12 num_attention_heads: int = 12
hidden_activation: str = "gelu" hidden_activation: str = "gelu"
intermediate_size: int = 3076 intermediate_size: int = 3072
dropout_rate: float = 0.1 dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1 attention_dropout_rate: float = 0.1
max_position_embeddings: int = 512 max_position_embeddings: int = 512
type_vocab_size: int = 2 type_vocab_size: int = 2
initializer_range: float = 0.02 initializer_range: float = 0.02
embedding_size: Optional[int] = None
def instantiate_encoder_from_cfg(
config: TransformerEncoderConfig,
encoder_cls=networks.TransformerEncoder,
embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
if encoder_cls.__name__ == "EncoderScaffold":
embedding_cfg = dict(
vocab_size=config.vocab_size,
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
max_seq_length=config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
dropout_rate=config.dropout_rate,
)
hidden_cfg = dict(
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
intermediate_activation=tf_utils.get_activation(
config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=config.num_layers,
pooled_output_dim=config.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
return encoder_cls(**kwargs)
if encoder_cls.__name__ != "TransformerEncoder":
raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
encoder_network = encoder_cls(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
num_layers=config.num_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
activation=tf_utils.get_activation(config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
max_sequence_length=config.max_position_embeddings,
type_vocab_size=config.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
embedding_width=config.embedding_size,
embedding_layer=embedding_layer)
return encoder_network
This diff is collapsed.
...@@ -27,18 +27,21 @@ from absl import flags ...@@ -27,18 +27,21 @@ from absl import flags
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib # word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib # sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
# TODO(chendouble): consider moving each task to its own binary.
flags.DEFINE_enum( flags.DEFINE_enum(
"fine_tuning_task_type", "classification", "fine_tuning_task_type", "classification",
["classification", "regression", "squad"], ["classification", "regression", "squad", "retrieval", "tagging"],
"The name of the BERT fine tuning task for which data " "The name of the BERT fine tuning task for which data "
"will be generated..") "will be generated.")
# BERT classification specific flags. # BERT classification specific flags.
flags.DEFINE_string( flags.DEFINE_string(
...@@ -47,23 +50,41 @@ flags.DEFINE_string( ...@@ -47,23 +50,41 @@ flags.DEFINE_string(
"for the task.") "for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI", flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI", ["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"PAWS-X"], "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI",
"The name of the task to train BERT classifier.") "XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
# MNLI task-specific flag.
flags.DEFINE_enum(
"mnli_type", "matched", ["matched", "mismatched"],
"The type of MNLI dataset.")
# XNLI task specific flag. # XNLI task-specific flag.
flags.DEFINE_string( flags.DEFINE_string(
"xnli_language", "en", "xnli_language", "en",
"Language of training data for XNIL task. If the value is 'all', the data " "Language of training data for XNLI task. If the value is 'all', the data "
"of all languages will be used for training.") "of all languages will be used for training.")
# PAWS-X task specific flag. # PAWS-X task-specific flag.
flags.DEFINE_string( flags.DEFINE_string(
"pawsx_language", "en", "pawsx_language", "en",
"Language of trainig data for PAWS-X task. If the value is 'all', the data " "Language of training data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training.") "of all languages will be used for training.")
# BERT Squad task specific flags. # Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
"The name of sentence retrieval task for scoring")
# Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
"The name of BERT tagging (token classification) task.")
# BERT Squad task-specific flags.
flags.DEFINE_string( flags.DEFINE_string(
"squad_data_file", None, "squad_data_file", None,
"The input data file in for generating training data for BERT squad task.") "The input data file in for generating training data for BERT squad task.")
...@@ -163,20 +184,29 @@ def generate_classifier_dataset(): ...@@ -163,20 +184,29 @@ def generate_classifier_dataset():
"cola": "cola":
classifier_data_lib.ColaProcessor, classifier_data_lib.ColaProcessor,
"mnli": "mnli":
classifier_data_lib.MnliProcessor, functools.partial(classifier_data_lib.MnliProcessor,
mnli_type=FLAGS.mnli_type),
"mrpc": "mrpc":
classifier_data_lib.MrpcProcessor, classifier_data_lib.MrpcProcessor,
"qnli": "qnli":
classifier_data_lib.QnliProcessor, classifier_data_lib.QnliProcessor,
"qqp": classifier_data_lib.QqpProcessor, "qqp": classifier_data_lib.QqpProcessor,
"rte": classifier_data_lib.RteProcessor,
"sst-2": "sst-2":
classifier_data_lib.SstProcessor, classifier_data_lib.SstProcessor,
"sts-b":
classifier_data_lib.StsBProcessor,
"xnli": "xnli":
functools.partial(classifier_data_lib.XnliProcessor, functools.partial(classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language), language=FLAGS.xnli_language),
"paws-x": "paws-x":
functools.partial(classifier_data_lib.PawsxProcessor, functools.partial(classifier_data_lib.PawsxProcessor,
language=FLAGS.pawsx_language) language=FLAGS.pawsx_language),
"wnli": classifier_data_lib.WnliProcessor,
"xtreme-xnli":
functools.partial(classifier_data_lib.XtremeXnliProcessor),
"xtreme-paws-x":
functools.partial(classifier_data_lib.XtremePawsxProcessor)
} }
task_name = FLAGS.classification_task_name.lower() task_name = FLAGS.classification_task_name.lower()
if task_name not in processors: if task_name not in processors:
...@@ -237,6 +267,67 @@ def generate_squad_dataset(): ...@@ -237,6 +267,67 @@ def generate_squad_dataset():
FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative) FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)
def generate_retrieval_dataset():
"""Generate retrieval test and dev dataset and returns input meta data."""
assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
if FLAGS.tokenizer_impl == "word_piece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
processors = {
"bucc": sentence_retrieval_lib.BuccProcessor,
"tatoeba": sentence_retrieval_lib.TatoebaProcessor,
}
task_name = FLAGS.retrieval_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % task_name)
processor = processors[task_name](process_text_fn=processor_text_fn)
return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
processor,
FLAGS.input_data_dir,
tokenizer,
FLAGS.eval_data_output_path,
FLAGS.test_data_output_path,
FLAGS.max_seq_length)
def generate_tagging_dataset():
"""Generates tagging dataset."""
processors = {
"panx": tagging_data_lib.PanxProcessor,
"udpos": tagging_data_lib.UdposProcessor,
}
task_name = FLAGS.tagging_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % task_name)
if FLAGS.tokenizer_impl == "word_piece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
elif FLAGS.tokenizer_impl == "sentence_piece":
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
else:
raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
processor = processors[task_name]()
return tagging_data_lib.generate_tf_record_from_data_file(
processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
FLAGS.test_data_output_path, processor_text_fn)
def main(_): def main(_):
if FLAGS.tokenizer_impl == "word_piece": if FLAGS.tokenizer_impl == "word_piece":
if not FLAGS.vocab_file: if not FLAGS.vocab_file:
...@@ -248,12 +339,20 @@ def main(_): ...@@ -248,12 +339,20 @@ def main(_):
raise ValueError( raise ValueError(
"FLAG sp_model_file for sentence-piece tokenizer is not specified.") "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
if FLAGS.fine_tuning_task_type != "retrieval":
flags.mark_flag_as_required("train_data_output_path")
if FLAGS.fine_tuning_task_type == "classification": if FLAGS.fine_tuning_task_type == "classification":
input_meta_data = generate_classifier_dataset() input_meta_data = generate_classifier_dataset()
elif FLAGS.fine_tuning_task_type == "regression": elif FLAGS.fine_tuning_task_type == "regression":
input_meta_data = generate_regression_dataset() input_meta_data = generate_regression_dataset()
else: elif FLAGS.fine_tuning_task_type == "retrieval":
input_meta_data = generate_retrieval_dataset()
elif FLAGS.fine_tuning_task_type == "squad":
input_meta_data = generate_squad_dataset() input_meta_data = generate_squad_dataset()
else:
assert FLAGS.fine_tuning_task_type == "tagging"
input_meta_data = generate_tagging_dataset()
tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path)) tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer: with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
...@@ -261,6 +360,5 @@ def main(_): ...@@ -261,6 +360,5 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
flags.mark_flag_as_required("train_data_output_path")
flags.mark_flag_as_required("meta_data_file_path") flags.mark_flag_as_required("meta_data_file_path")
app.run(main) app.run(main)
# Copyright 2018 The TensorFlow Authors All Rights Reserved. # Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,23 +13,47 @@ ...@@ -12,23 +13,47 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A global factory to access NLP registered data loaders."""
"""Define the abstract class for contextual bandit algorithms.""" from official.utils import registry
from __future__ import absolute_import _REGISTERED_DATA_LOADER_CLS = {}
from __future__ import division
from __future__ import print_function
class BanditAlgorithm(object): def register_data_loader_cls(data_config_cls):
"""A bandit algorithm must be able to do two basic operations. """Decorates a factory of DataLoader for lookup by a subclass of DataConfig.
1. Choose an action given a context. This decorator supports registration of data loaders as follows:
2. Update its internal model given a triple (context, played action, reward).
"""
def action(self, context): ```
@dataclasses.dataclass
class MyDataConfig(DataConfig):
# Add fields here.
pass pass
def update(self, context, action, reward): @register_data_loader_cls(MyDataConfig)
class MyDataLoader:
# Inherits def __init__(self, data_config).
pass pass
my_data_config = MyDataConfig()
# Returns MyDataLoader(my_data_config).
my_loader = get_data_loader(my_data_config)
```
Args:
data_config_cls: a subclass of DataConfig (*not* an instance
of DataConfig).
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of data_config_cls.
"""
return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)
def get_data_loader(data_config):
"""Creates a data_loader from data_config."""
return registry.lookup(_REGISTERED_DATA_LOADER_CLS, data_config.__class__)(
data_config)
...@@ -16,11 +16,27 @@ ...@@ -16,11 +16,27 @@
"""Loads dataset for the BERT pretraining task.""" """Loads dataset for the BERT pretraining task."""
from typing import Mapping, Optional from typing import Mapping, Optional
import dataclasses
import tensorflow as tf import tensorflow as tf
from official.core import input_reader from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
"""Data config for BERT pretraining task (tasks/masked_lm)."""
input_path: str = ''
global_batch_size: int = 512
is_training: bool = True
seq_length: int = 512
max_predictions_per_seq: int = 76
use_next_sentence_label: bool = True
use_position_id: bool = False
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader: class BertPretrainDataLoader:
"""A class to load dataset for bert pretraining task.""" """A class to load dataset for bert pretraining task."""
...@@ -91,7 +107,5 @@ class BertPretrainDataLoader: ...@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
def load(self, input_context: Optional[tf.distribute.InputContext] = None): def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset.""" """Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader( reader = input_reader.InputReader(
params=self._params, params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
decoder_fn=self._decode,
parser_fn=self._parse)
return reader.read(input_context) return reader.read(input_context)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the question answering (e.g, SQuAD) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
"""Data config for question answering task (tasks/question_answering)."""
input_path: str = ''
global_batch_size: int = 48
is_training: bool = True
seq_length: int = 384
# Settings below are question answering specific.
version_2_with_negative: bool = False
# Settings below are only used for eval mode.
input_preprocessed_data_path: str = ''
doc_stride: int = 128
query_length: int = 64
vocab_file: str = ''
tokenization: str = 'WordPiece' # WordPiece or SentencePiece
do_lower_case: bool = True
@data_loader_factory.register_data_loader_cls(QADataConfig)
class QuestionAnsweringDataLoader:
"""A class to load dataset for sentence prediction (classification) task."""
def __init__(self, params):
self._params = params
self._seq_length = params.seq_length
self._is_training = params.is_training
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
if self._is_training:
name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
else:
name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x, y = {}, {}
for name, tensor in record.items():
if name in ('start_positions', 'end_positions'):
y[name] = tensor
elif name == 'input_ids':
x['input_word_ids'] = tensor
elif name == 'segment_ids':
x['input_type_ids'] = tensor
else:
x[name] = tensor
return (x, y)
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
...@@ -15,11 +15,28 @@ ...@@ -15,11 +15,28 @@
# ============================================================================== # ==============================================================================
"""Loads dataset for the sentence prediction (classification) task.""" """Loads dataset for the sentence prediction (classification) task."""
from typing import Mapping, Optional from typing import Mapping, Optional
import dataclasses
import tensorflow as tf import tensorflow as tf
from official.core import input_reader from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
input_path: str = ''
global_batch_size: int = 32
is_training: bool = True
seq_length: int = 128
label_type: str = 'int'
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
class SentencePredictionDataLoader: class SentencePredictionDataLoader:
"""A class to load dataset for sentence prediction (classification) task.""" """A class to load dataset for sentence prediction (classification) task."""
...@@ -29,11 +46,12 @@ class SentencePredictionDataLoader: ...@@ -29,11 +46,12 @@ class SentencePredictionDataLoader:
def _decode(self, record: tf.Tensor): def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example.""" """Decodes a serialized tf.Example."""
label_type = LABEL_TYPES_MAP[self._params.label_type]
name_to_features = { name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64), 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64), 'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64), 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64), 'label_ids': tf.io.FixedLenFeature([], label_type),
} }
example = tf.io.parse_single_example(record, name_to_features) example = tf.io.parse_single_example(record, name_to_features)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -20,10 +20,12 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum ...@@ -20,10 +20,12 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.masked_lm import MaskedLM from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.transformer import Transformer from official.nlp.modeling.layers.transformer import *
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment