"megatron/vscode:/vscode.git/clone" did not exist on "a44360edb23f8853ee70b2204960a90fed4490d0"
Commit 31ca3b97 authored by Kaushik Shivakumar

resolve merge conflicts

parents 3e9d886d 7fcd7cba
@@ -88,7 +88,6 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
 
-# TODO(hongkuny): consider moving custom string-map lookup to keras api.
 def get_activation(identifier):
   """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
...
@@ -14,23 +14,61 @@
 # ==============================================================================
 """ALBERT classification finetuning runner in tf2.x."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import json
+import os
 
 from absl import app
 from absl import flags
+from absl import logging
 import tensorflow as tf
 
 from official.nlp.albert import configs as albert_configs
+from official.nlp.bert import bert_models
 from official.nlp.bert import run_classifier as run_classifier_bert
 from official.utils.misc import distribution_utils
 
 FLAGS = flags.FLAGS
 
 
+def predict(strategy, albert_config, input_meta_data, predict_input_fn):
+  """Outputs both the ground truth and the predictions as .tsv files."""
+  with strategy.scope():
+    classifier_model = bert_models.classifier_model(
+        albert_config, input_meta_data['num_labels'])[0]
+    checkpoint = tf.train.Checkpoint(model=classifier_model)
+    latest_checkpoint_file = (
+        FLAGS.predict_checkpoint_path or
+        tf.train.latest_checkpoint(FLAGS.model_dir))
+    assert latest_checkpoint_file
+    logging.info('Checkpoint file %s found and restoring from '
+                 'checkpoint', latest_checkpoint_file)
+    checkpoint.restore(
+        latest_checkpoint_file).assert_existing_objects_matched()
+    preds, ground_truth = run_classifier_bert.get_predictions_and_labels(
+        strategy, classifier_model, predict_input_fn, return_probs=True)
+    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
+    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
+      logging.info('***** Predict results *****')
+      for probabilities in preds:
+        output_line = '\t'.join(
+            str(class_probability)
+            for class_probability in probabilities) + '\n'
+        writer.write(output_line)
+    ground_truth_labels_file = os.path.join(FLAGS.model_dir,
+                                            'output_labels.tsv')
+    with tf.io.gfile.GFile(ground_truth_labels_file, 'w') as writer:
+      logging.info('***** Ground truth results *****')
+      for label in ground_truth:
+        output_line = '\t'.join(str(label)) + '\n'
+        writer.write(output_line)
+  return
+
+
 def main(_):
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
...
@@ -56,9 +94,14 @@ def main(_):
   albert_config = albert_configs.AlbertConfig.from_json_file(
       FLAGS.bert_config_file)
-  run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
-                               train_input_fn, eval_input_fn)
+  if FLAGS.mode == 'train_and_eval':
+    run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
+                                 train_input_fn, eval_input_fn)
+  elif FLAGS.mode == 'predict':
+    predict(strategy, albert_config, input_meta_data, eval_input_fn)
+  else:
+    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
+  return
 
 
 if __name__ == '__main__':
   flags.mark_flag_as_required('bert_config_file')
...
@@ -79,7 +79,7 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
                                          do_lower_case, vocab_file)
   core_model, encoder = create_bert_model(bert_config)
   checkpoint = tf.train.Checkpoint(model=encoder)
-  checkpoint.restore(model_checkpoint_path).assert_consumed()
+  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
   core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
   core_model.save(hub_destination, include_optimizer=False, save_format="tf")
...
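For context on the relaxed restore assertion in the hunk above, here is a minimal, self-contained sketch (toy model and a placeholder checkpoint path, not the repo's encoder) of how the two checkpoint status checks differ:

```
import tensorflow as tf

# Save a checkpoint that tracks a small model plus one extra variable.
net = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(4,))])
ckpt = tf.train.Checkpoint(model=net, extra=tf.Variable(1.0))
path = ckpt.save('/tmp/demo_ckpt')  # placeholder path

# Restore into a Checkpoint that only tracks the model. The saved `extra`
# variable is never consumed by any restored object.
net2 = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(4,))])
status = tf.train.Checkpoint(model=net2).restore(path)
status.assert_existing_objects_matched()  # passes: all built objects matched
# status.assert_consumed()  # would raise: `extra` in the file went unused
```

`assert_existing_objects_matched()` therefore tolerates checkpoints that carry more state (for example, optimizer slots from training) than the object being restored, which is the usual situation when exporting an encoder out of a full training checkpoint.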
@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text,
     raise ValueError('model must be a tf.keras.Model object.')
 
   if checkpoint_dir:
-    # Keras compile/fit() was used to save checkpoint using
-    # model.save_weights().
     if restore_model_using_load_weights:
       model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
       assert tf.io.gfile.exists(model_weight_path)
       model.load_weights(model_weight_path)
-    # tf.train.Checkpoint API was used via custom training loop logic.
     else:
       checkpoint = tf.train.Checkpoint(model=model)
...
@@ -99,7 +99,9 @@ def write_txt_summary(training_summary, summary_dir):
 @deprecation.deprecated(
-    None, 'This function is deprecated. Please use Keras compile/fit instead.')
+    None, 'This function is deprecated and we do not expect adding new '
+    'functionalities. Please do not have your code depending '
+    'on this library.')
 def run_customized_training_loop(
     # pylint: disable=invalid-name
     _sentinel=None,
...
@@ -557,7 +559,6 @@ def run_customized_training_loop(
       for metric in model.metrics:
         training_summary[metric.name] = _float_metric_value(metric)
       if eval_metrics:
-        # TODO(hongkuny): Cleans up summary reporting in text.
         training_summary['last_train_metrics'] = _float_metric_value(
             train_metrics[0])
         training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
...
@@ -343,7 +343,10 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
   # Export uses float32 for now, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   classifier_model = bert_models.classifier_model(
-      bert_config, input_meta_data.get('num_labels', 1))[0]
+      bert_config,
+      input_meta_data.get('num_labels', 1),
+      hub_module_url=FLAGS.hub_module_url,
+      hub_module_trainable=False)[0]
   model_saving_utils.export_bert_model(
       model_export_path, model=classifier_model, checkpoint_dir=model_dir)
...
@@ -24,7 +24,6 @@ import tensorflow as tf
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
-from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import encoders
 from official.nlp.modeling import layers
 from official.nlp.modeling.models import bert_pretrainer
...
@@ -43,7 +42,6 @@ class ClsHeadConfig(base_config.Config):
 @dataclasses.dataclass
 class BertPretrainerConfig(base_config.Config):
   """BERT encoder configuration."""
-  num_masked_tokens: int = 76
   encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
...
@@ -56,96 +54,18 @@ def instantiate_classification_heads_from_cfgs(
   ] if cls_head_configs else []
 
 
-def instantiate_bertpretrainer_from_cfg(
+def instantiate_pretrainer_from_cfg(
     config: BertPretrainerConfig,
     encoder_network: Optional[tf.keras.Model] = None
 ) -> bert_pretrainer.BertPretrainerV2:
   """Instantiates a BertPretrainer from the config."""
   encoder_cfg = config.encoder
   if encoder_network is None:
     encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
   return bert_pretrainer.BertPretrainerV2(
-      config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=encoder_cfg.initializer_range),
       encoder_network=encoder_network,
       classification_heads=instantiate_classification_heads_from_cfgs(
           config.cls_heads))
-
-
-@dataclasses.dataclass
-class BertPretrainDataConfig(cfg.DataConfig):
-  """Data config for BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ""
-  global_batch_size: int = 512
-  is_training: bool = True
-  seq_length: int = 512
-  max_predictions_per_seq: int = 76
-  use_next_sentence_label: bool = True
-  use_position_id: bool = False
-
-
-@dataclasses.dataclass
-class BertPretrainEvalDataConfig(BertPretrainDataConfig):
-  """Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ""
-  global_batch_size: int = 512
-  is_training: bool = False
-
-
-@dataclasses.dataclass
-class SentencePredictionDataConfig(cfg.DataConfig):
-  """Data config for sentence prediction task (tasks/sentence_prediction)."""
-  input_path: str = ""
-  global_batch_size: int = 32
-  is_training: bool = True
-  seq_length: int = 128
-
-
-@dataclasses.dataclass
-class SentencePredictionDevDataConfig(cfg.DataConfig):
-  """Dev Data config for sentence prediction (tasks/sentence_prediction)."""
-  input_path: str = ""
-  global_batch_size: int = 32
-  is_training: bool = False
-  seq_length: int = 128
-  drop_remainder: bool = False
-
-
-@dataclasses.dataclass
-class QADataConfig(cfg.DataConfig):
-  """Data config for question answering task (tasks/question_answering)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-
-@dataclasses.dataclass
-class QADevDataConfig(cfg.DataConfig):
-  """Dev Data config for queston answering (tasks/question_answering)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  drop_remainder: bool = False
-
-
-@dataclasses.dataclass
-class TaggingDataConfig(cfg.DataConfig):
-  """Data config for tagging (tasks/tagging)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-
-@dataclasses.dataclass
-class TaggingDevDataConfig(cfg.DataConfig):
-  """Dev Data config for tagging (tasks/tagging)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  drop_remainder: bool = False
@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
   def test_network_invocation(self):
     config = bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)
 
     # Invokes with classification heads.
     config = bert.BertPretrainerConfig(
...
@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=2, name="next_sentence")
     ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)
 
     with self.assertRaises(ValueError):
       config = bert.BertPretrainerConfig(
...
@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
           bert.ClsHeadConfig(
              inner_dim=10, num_classes=2, name="next_sentence")
       ])
-      _ = bert.instantiate_bertpretrainer_from_cfg(config)
+      _ = bert.instantiate_pretrainer_from_cfg(config)
 
   def test_checkpoint_items(self):
     config = bert.BertPretrainerConfig(
...
@@ -56,9 +56,10 @@ class BertModelsTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=2, name="next_sentence")
     ])
-    encoder = bert.instantiate_bertpretrainer_from_cfg(config)
-    self.assertSameElements(encoder.checkpoint_items.keys(),
-                            ["encoder", "next_sentence.pooler_dense"])
+    encoder = bert.instantiate_pretrainer_from_cfg(config)
+    self.assertSameElements(
+        encoder.checkpoint_items.keys(),
+        ["encoder", "masked_lm", "next_sentence.pooler_dense"])
 
 
 if __name__ == "__main__":
...
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer
@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
"""ELECTRA pretrainer configuration."""
num_masked_tokens: int = 76
sequence_length: int = 512
num_classes: int = 2
discriminator_loss_weight: float = 50.0
tie_embeddings: bool = True
disallow_correct: bool = False
generator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
discriminator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_classification_heads_from_cfgs(
cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
if cls_head_configs:
return [
layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
]
else:
return []
def instantiate_pretrainer_from_cfg(
config: ELECTRAPretrainerConfig,
generator_network: Optional[tf.keras.Model] = None,
discriminator_network: Optional[tf.keras.Model] = None,
) -> electra_pretrainer.ElectraPretrainer:
"""Instantiates ElectraPretrainer from the config."""
generator_encoder_cfg = config.generator_encoder
discriminator_encoder_cfg = config.discriminator_encoder
# Copy discriminator's embeddings to generator for easier model serialization.
if discriminator_network is None:
discriminator_network = encoders.instantiate_encoder_from_cfg(
discriminator_encoder_cfg)
if generator_network is None:
if config.tie_embeddings:
embedding_layer = discriminator_network.get_embedding_layer()
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg, embedding_layer=embedding_layer)
else:
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg)
return electra_pretrainer.ElectraPretrainer(
generator_network=generator_network,
discriminator_network=discriminator_network,
vocab_size=config.generator_encoder.vocab_size,
num_classes=config.num_classes,
sequence_length=config.sequence_length,
num_token_predictions=config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(
generator_encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=generator_encoder_cfg.initializer_range),
classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads),
disallow_correct=config.disallow_correct)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
class ELECTRAModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
)
_ = electra.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = electra.instantiate_pretrainer_from_cfg(config)
if __name__ == "__main__":
tf.test.main()
@@ -17,12 +17,13 @@
 Includes configurations and instantiation methods.
 """
+from typing import Optional
 
 import dataclasses
 import tensorflow as tf
 
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
...
@@ -40,12 +41,47 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+  embedding_size: Optional[int] = None
 
 
 def instantiate_encoder_from_cfg(
-    config: TransformerEncoderConfig) -> networks.TransformerEncoder:
+    config: TransformerEncoderConfig,
+    encoder_cls=networks.TransformerEncoder,
+    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
-  encoder_network = networks.TransformerEncoder(
+  if encoder_cls.__name__ == "EncoderScaffold":
+    embedding_cfg = dict(
+        vocab_size=config.vocab_size,
+        type_vocab_size=config.type_vocab_size,
+        hidden_size=config.hidden_size,
+        seq_length=None,
+        max_seq_length=config.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range),
+        dropout_rate=config.dropout_rate,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=config.num_attention_heads,
+        intermediate_size=config.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(
+            config.hidden_activation),
+        dropout_rate=config.dropout_rate,
+        attention_dropout_rate=config.attention_dropout_rate,
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range),
+    )
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=config.num_layers,
+        pooled_output_dim=config.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range))
+    return encoder_cls(**kwargs)
+
+  if encoder_cls.__name__ != "TransformerEncoder":
+    raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
+
+  encoder_network = encoder_cls(
       vocab_size=config.vocab_size,
       hidden_size=config.hidden_size,
       num_layers=config.num_layers,
...
@@ -58,5 +94,7 @@ def instantiate_encoder_from_cfg(
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range))
+          stddev=config.initializer_range),
+      embedding_width=config.embedding_size,
+      embedding_layer=embedding_layer)
   return encoder_network
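A brief usage sketch of the extended factory (toy config values; it assumes `networks.EncoderScaffold` is importable from `official.nlp.modeling.networks`, as the `__name__` check in the hunk above implies):

```
from official.nlp.configs import encoders
from official.nlp.modeling import networks

cfg = encoders.TransformerEncoderConfig(vocab_size=100, num_layers=1)

# Default path: a TransformerEncoder, now honoring embedding_size if set.
encoder = encoders.instantiate_encoder_from_cfg(cfg)

# Alternate path: the same config mapped onto an EncoderScaffold.
scaffold = encoders.instantiate_encoder_from_cfg(
    cfg, encoder_cls=networks.EncoderScaffold)
```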
@@ -32,14 +32,16 @@ from official.nlp.data import sentence_retrieval_lib
 from official.nlp.data import squad_lib as squad_lib_wp
 # sentence-piece tokenizer based squad_lib
 from official.nlp.data import squad_lib_sp
+from official.nlp.data import tagging_data_lib
 
 FLAGS = flags.FLAGS
 
+# TODO(chendouble): consider moving each task to its own binary.
 flags.DEFINE_enum(
     "fine_tuning_task_type", "classification",
-    ["classification", "regression", "squad", "retrieval"],
+    ["classification", "regression", "squad", "retrieval", "tagging"],
     "The name of the BERT fine tuning task for which data "
-    "will be generated..")
+    "will be generated.")
 
 # BERT classification specific flags.
 flags.DEFINE_string(
...
@@ -48,30 +50,41 @@ flags.DEFINE_string(
     "for the task.")
 
 flags.DEFINE_enum("classification_task_name", "MNLI",
-                  ["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
-                   "PAWS-X", "XTREME-XNLI", "XTREME-PAWS-X"],
+                  ["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
+                   "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI",
+                   "XTREME-PAWS-X"],
                   "The name of the task to train BERT classifier. The "
                   "difference between XTREME-XNLI and XNLI is: 1. the format "
                   "of input tsv files; 2. the dev set for XTREME is english "
                   "only and for XNLI is all languages combined. Same for "
                   "PAWS-X.")
 
-flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
-                  "The name of sentence retrieval task for scoring")
+# MNLI task-specific flag.
+flags.DEFINE_enum(
+    "mnli_type", "matched", ["matched", "mismatched"],
+    "The type of MNLI dataset.")
 
-# XNLI task specific flag.
+# XNLI task-specific flag.
 flags.DEFINE_string(
     "xnli_language", "en",
-    "Language of training data for XNIL task. If the value is 'all', the data "
+    "Language of training data for XNLI task. If the value is 'all', the data "
     "of all languages will be used for training.")
 
-# PAWS-X task specific flag.
+# PAWS-X task-specific flag.
 flags.DEFINE_string(
     "pawsx_language", "en",
-    "Language of trainig data for PAWS-X task. If the value is 'all', the data "
+    "Language of training data for PAWS-X task. If the value is 'all', the data "
     "of all languages will be used for training.")
 
-# BERT Squad task specific flags.
+# Retrieval task-specific flags.
+flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
+                  "The name of sentence retrieval task for scoring")
+
+# Tagging task-specific flags.
+flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
+                  "The name of BERT tagging (token classification) task.")
+
+# BERT Squad task-specific flags.
 flags.DEFINE_string(
     "squad_data_file", None,
     "The input data file in for generating training data for BERT squad task.")
...
@@ -171,7 +184,8 @@ def generate_classifier_dataset():
       "cola":
           classifier_data_lib.ColaProcessor,
       "mnli":
-          classifier_data_lib.MnliProcessor,
+          functools.partial(classifier_data_lib.MnliProcessor,
+                            mnli_type=FLAGS.mnli_type),
       "mrpc":
           classifier_data_lib.MrpcProcessor,
       "qnli":
...
@@ -180,6 +194,8 @@ def generate_classifier_dataset():
       "rte": classifier_data_lib.RteProcessor,
       "sst-2":
          classifier_data_lib.SstProcessor,
+      "sts-b":
+          classifier_data_lib.StsBProcessor,
       "xnli":
           functools.partial(classifier_data_lib.XnliProcessor,
                             language=FLAGS.xnli_language),
...
@@ -284,6 +300,34 @@ def generate_retrieval_dataset():
                                                      FLAGS.max_seq_length)
 
 
+def generate_tagging_dataset():
+  """Generates tagging dataset."""
+  processors = {
+      "panx": tagging_data_lib.PanxProcessor,
+      "udpos": tagging_data_lib.UdposProcessor,
+  }
+  task_name = FLAGS.tagging_task_name.lower()
+  if task_name not in processors:
+    raise ValueError("Task not found: %s" % task_name)
+
+  if FLAGS.tokenizer_impl == "word_piece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+    processor_text_fn = tokenization.convert_to_unicode
+  elif FLAGS.tokenizer_impl == "sentence_piece":
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = functools.partial(
+        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+  else:
+    raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
+
+  processor = processors[task_name]()
+  return tagging_data_lib.generate_tf_record_from_data_file(
+      processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
+      FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
+      FLAGS.test_data_output_path, processor_text_fn)
 
 
 def main(_):
   if FLAGS.tokenizer_impl == "word_piece":
     if not FLAGS.vocab_file:
...
@@ -304,8 +348,11 @@ def main(_):
     input_meta_data = generate_regression_dataset()
   elif FLAGS.fine_tuning_task_type == "retrieval":
     input_meta_data = generate_retrieval_dataset()
-  else:
+  elif FLAGS.fine_tuning_task_type == "squad":
     input_meta_data = generate_squad_dataset()
+  else:
+    assert FLAGS.fine_tuning_task_type == "tagging"
+    input_meta_data = generate_tagging_dataset()
 
   tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
   with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
...
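The MNLI change above threads a flag value into a processor constructor looked up from a dict. A self-contained sketch of that `functools.partial` pattern (toy Processor class, not `classifier_data_lib`'s):

```
import functools

class MnliProcessor:
  """Stand-in processor whose constructor takes a task-specific option."""

  def __init__(self, mnli_type="matched"):
    self.mnli_type = mnli_type

# The dict maps task names to zero-argument factories; partial pre-binds
# the flag-derived keyword so the call site stays uniform.
processors = {"mnli": functools.partial(MnliProcessor, mnli_type="mismatched")}
processor = processors["mnli"]()
assert processor.mnli_type == "mismatched"
```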
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import itertools
 import random
 
 from absl import app
...
@@ -48,6 +49,12 @@ flags.DEFINE_bool(
     "do_whole_word_mask", False,
     "Whether to use whole word masking rather than per-WordPiece masking.")
 
+flags.DEFINE_integer(
+    "max_ngram_size", None,
+    "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
+    "weighting scheme to favor shorter n-grams. "
+    "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
+
 flags.DEFINE_bool(
     "gzip_compress", False,
     "Whether to use `GZIP` compress option to get compressed TFRecord files.")
...
@@ -192,7 +199,8 @@ def create_training_instances(input_files,
                               masked_lm_prob,
                               max_predictions_per_seq,
                               rng,
-                              do_whole_word_mask=False):
+                              do_whole_word_mask=False,
+                              max_ngram_size=None):
   """Create `TrainingInstance`s from raw text."""
   all_documents = [[]]
...
@@ -229,7 +237,7 @@ def create_training_instances(input_files,
         create_instances_from_document(
             all_documents, document_index, max_seq_length, short_seq_prob,
             masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-            do_whole_word_mask))
+            do_whole_word_mask, max_ngram_size))
 
   rng.shuffle(instances)
   return instances
...
@@ -238,7 +246,8 @@ def create_training_instances(
 def create_instances_from_document(
     all_documents, document_index, max_seq_length, short_seq_prob,
     masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-    do_whole_word_mask=False):
+    do_whole_word_mask=False,
+    max_ngram_size=None):
   """Creates `TrainingInstance`s for a single document."""
   document = all_documents[document_index]
...
@@ -337,7 +346,7 @@ def create_instances_from_document(
         (tokens, masked_lm_positions,
          masked_lm_labels) = create_masked_lm_predictions(
              tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-             do_whole_word_mask)
+             do_whole_word_mask, max_ngram_size)
         instance = TrainingInstance(
             tokens=tokens,
             segment_ids=segment_ids,
...
@@ -355,72 +364,238 @@ def create_instances_from_document(
 MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                           ["index", "label"])
 
+# A _Gram is a [half-open) interval of token indices which form a word.
+# E.g.,
+#   words:  ["The", "doghouse"]
+#   tokens: ["The", "dog", "##house"]
+#   grams:  [(0,1), (1,3)]
+_Gram = collections.namedtuple("_Gram", ["begin", "end"])
+
+
+def _window(iterable, size):
+  """Helper to create a sliding window iterator with a given size.
+
+  E.g.,
+    input = [1, 2, 3, 4]
+    _window(input, 1) => [1], [2], [3], [4]
+    _window(input, 2) => [1, 2], [2, 3], [3, 4]
+    _window(input, 3) => [1, 2, 3], [2, 3, 4]
+    _window(input, 4) => [1, 2, 3, 4]
+    _window(input, 5) => None
+
+  Arguments:
+    iterable: elements to iterate over.
+    size: size of the window.
+
+  Yields:
+    Elements of `iterable` batched into a sliding window of length `size`.
+  """
+  i = iter(iterable)
+  window = []
+  try:
+    for e in range(0, size):
+      window.append(next(i))
+    yield window
+  except StopIteration:
+    # Handle the case where iterable's length is less than the window size.
+    return
+  for e in i:
+    window = window[1:] + [e]
+    yield window
+
+
+def _contiguous(sorted_grams):
+  """Test whether a sequence of grams is contiguous.
+
+  Arguments:
+    sorted_grams: _Grams which are sorted in increasing order.
+
+  Returns:
+    True if `sorted_grams` are touching each other.
+
+  E.g.,
+    _contiguous([(1, 4), (4, 5), (5, 10)]) == True
+    _contiguous([(1, 2), (4, 5)]) == False
+  """
+  for a, b in _window(sorted_grams, 2):
+    if a.end != b.begin:
+      return False
+  return True
+
+
+def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
+  """Create a list of masking {1, ..., n}-grams from a list of one-grams.
+
+  This is an extension of 'whole word masking' to mask multiple contiguous
+  words (e.g., "the red boat").
+
+  Each input gram represents the token indices of a single word,
+     words:  ["the", "red", "boat"]
+     tokens: ["the", "red", "boa", "##t"]
+     grams:  [(0,1), (1,2), (2,4)]
+
+  For a `max_ngram_size` of three, possible output masks include:
+    1-grams: (0,1), (1,2), (2,4)
+    2-grams: (0,2), (1,4)
+    3-grams: (0,4)
+
+  Output masks will not overlap and will contain at most `max_masked_tokens`
+  total tokens. E.g., for the example above with `max_masked_tokens` as three,
+  valid outputs are,
+       [(0,1), (1,2)]  # "the", "red" covering two tokens
+       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens
+
+  The length of the selected n-gram follows a zipf weighting to
+  favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
+
+  Arguments:
+    grams: List of one-grams.
+    max_ngram_size: Maximum number of contiguous one-grams combined to create
+      an n-gram.
+    max_masked_tokens: Maximum total number of tokens to be masked.
+    rng: `random.Random` generator.
+
+  Returns:
+    A list of n-grams to be used as masks.
+  """
+  if not grams:
+    return None
+
+  grams = sorted(grams)
+  num_tokens = grams[-1].end
+
+  # Ensure our grams are valid (i.e., they don't overlap).
+  for a, b in _window(grams, 2):
+    if a.end > b.begin:
+      raise ValueError("overlapping grams: {}".format(grams))
+
+  # Build map from n-gram length to list of n-grams.
+  ngrams = {i: [] for i in range(1, max_ngram_size+1)}
+  for gram_size in range(1, max_ngram_size+1):
+    for g in _window(grams, gram_size):
+      if _contiguous(g):
+        # Add an n-gram which spans these one-grams.
+        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
+
+  # Shuffle each list of n-grams.
+  for v in ngrams.values():
+    rng.shuffle(v)
+
+  # Create the weighting for n-gram length selection.
+  # Stored cumulatively for `rng.choices` below.
+  cumulative_weights = list(
+      itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))
+
+  output_ngrams = []
+  # Keep a bitmask of which tokens have been masked.
+  masked_tokens = [False] * num_tokens
+  # Loop until we have enough masked tokens or there are no more candidate
+  # n-grams of any length.
+  # Each code path should ensure one or more elements from `ngrams` are removed
+  # to guarantee this loop terminates.
+  while (sum(masked_tokens) < max_masked_tokens and
+         sum(len(s) for s in ngrams.values())):
+    # Pick an n-gram size based on our weights, using the seeded generator.
+    sz = rng.choices(range(1, max_ngram_size+1),
+                     cum_weights=cumulative_weights)[0]
+
+    # Ensure this size doesn't result in too many masked tokens.
+    # E.g., a two-gram contains _at least_ two tokens.
+    if sum(masked_tokens) + sz > max_masked_tokens:
+      # All n-grams of this length are too long and can be removed from
+      # consideration.
+      ngrams[sz].clear()
+      continue
+
+    # All of the n-grams of this size have been used.
+    if not ngrams[sz]:
+      continue
+
+    # Choose a random n-gram of the given size.
+    gram = ngrams[sz].pop()
+    num_gram_tokens = gram.end - gram.begin
+
+    # Check if this would add too many tokens.
+    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
+      continue
+
+    # Check if any of the tokens in this gram have already been masked.
+    if sum(masked_tokens[gram.begin:gram.end]):
+      continue
+
+    # Found a usable n-gram! Mark its tokens as masked and add it to return.
+    masked_tokens[gram.begin:gram.end] = [True] * (gram.end - gram.begin)
+    output_ngrams.append(gram)
+  return output_ngrams
+
+
+def _wordpieces_to_grams(tokens):
+  """Reconstitute grams (words) from `tokens`.
+
+  E.g.,
+     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
+     grams:  [ [1,2), [2,4), [4,5), [5,6)]
+
+  Arguments:
+    tokens: list of wordpieces.
+
+  Returns:
+    List of _Grams representing spans of whole words
+    (without "[CLS]" and "[SEP]").
+  """
+  grams = []
+  gram_start_pos = None
+  for i, token in enumerate(tokens):
+    if gram_start_pos is not None and token.startswith("##"):
+      continue
+    if gram_start_pos is not None:
+      grams.append(_Gram(gram_start_pos, i))
+    if token not in ["[CLS]", "[SEP]"]:
+      gram_start_pos = i
+    else:
+      gram_start_pos = None
+  if gram_start_pos is not None:
+    grams.append(_Gram(gram_start_pos, len(tokens)))
+  return grams
+
 
 def create_masked_lm_predictions(tokens, masked_lm_prob,
                                  max_predictions_per_seq, vocab_words, rng,
-                                 do_whole_word_mask):
+                                 do_whole_word_mask,
+                                 max_ngram_size=None):
   """Creates the predictions for the masked LM objective."""
-  cand_indexes = []
-  for (i, token) in enumerate(tokens):
-    if token == "[CLS]" or token == "[SEP]":
-      continue
-    # Whole Word Masking means that if we mask all of the wordpieces
-    # corresponding to an original word. When a word has been split into
-    # WordPieces, the first token does not have any marker and any subsequence
-    # tokens are prefixed with ##. So whenever we see the ## token, we
-    # append it to the previous set of word indexes.
-    #
-    # Note that Whole Word Masking does *not* change the training code
-    # at all -- we still predict each WordPiece independently, softmaxed
-    # over the entire vocabulary.
-    if (do_whole_word_mask and len(cand_indexes) >= 1 and
-        token.startswith("##")):
-      cand_indexes[-1].append(i)
-    else:
-      cand_indexes.append([i])
-
-  rng.shuffle(cand_indexes)
-
-  output_tokens = list(tokens)
+  if do_whole_word_mask:
+    grams = _wordpieces_to_grams(tokens)
+  else:
+    # Here we consider each token to be a word to allow for sub-word masking.
+    if max_ngram_size:
+      raise ValueError("cannot use ngram masking without whole word masking")
+    grams = [_Gram(i, i+1) for i in range(0, len(tokens))
+             if tokens[i] not in ["[CLS]", "[SEP]"]]
 
   num_to_predict = min(max_predictions_per_seq,
                        max(1, int(round(len(tokens) * masked_lm_prob))))
 
+  # Generate masks. If `max_ngram_size` in [0, None] it means we're doing
+  # whole word masking or token level masking. Both of these can be treated
+  # as the `max_ngram_size=1` case.
+  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
+                                 num_to_predict, rng)
   masked_lms = []
-  covered_indexes = set()
-  for index_set in cand_indexes:
-    if len(masked_lms) >= num_to_predict:
-      break
-    # If adding a whole-word mask would exceed the maximum number of
-    # predictions, then just skip this candidate.
-    if len(masked_lms) + len(index_set) > num_to_predict:
-      continue
-    is_any_index_covered = False
-    for index in index_set:
-      if index in covered_indexes:
-        is_any_index_covered = True
-        break
-    if is_any_index_covered:
-      continue
-    for index in index_set:
-      covered_indexes.add(index)
-      masked_token = None
-      # 80% of the time, replace with [MASK]
-      if rng.random() < 0.8:
-        masked_token = "[MASK]"
-      else:
-        # 10% of the time, keep original
-        if rng.random() < 0.5:
-          masked_token = tokens[index]
-        # 10% of the time, replace with random word
-        else:
-          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
-      output_tokens[index] = masked_token
-
-      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+  output_tokens = list(tokens)
+  for gram in masked_grams:
+    # 80% of the time, replace all n-gram tokens with [MASK]
+    if rng.random() < 0.8:
+      replacement_action = lambda idx: "[MASK]"
+    else:
+      # 10% of the time, keep all the original n-gram tokens.
+      if rng.random() < 0.5:
+        replacement_action = lambda idx: tokens[idx]
+      # 10% of the time, replace each n-gram token with a random word.
+      else:
+        replacement_action = lambda idx: rng.choice(vocab_words)
+
+    for idx in range(gram.begin, gram.end):
+      output_tokens[idx] = replacement_action(idx)
+      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
 
   assert len(masked_lms) <= num_to_predict
   masked_lms = sorted(masked_lms, key=lambda x: x.index)
...
@@ -467,7 +642,7 @@ def main(_):
   instances = create_training_instances(
       input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
       FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
-      rng, FLAGS.do_whole_word_mask)
+      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
 
   output_files = FLAGS.output_file.split(",")
   logging.info("*** Writing to output files ***")
...
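To make the zipf weighting in `_masking_ngrams` concrete, here is a small self-contained sketch (illustrative values only) of how cumulative weights bias the n-gram size selection toward shorter spans:

```
import itertools
import random

max_ngram_size = 3
# Weights 1, 1/2, 1/3 stored cumulatively: [1.0, 1.5, 1.8333...]
cumulative_weights = list(
    itertools.accumulate([1. / n for n in range(1, max_ngram_size + 1)]))

rng = random.Random(12345)
counts = {1: 0, 2: 0, 3: 0}
for _ in range(10000):
  sz = rng.choices(range(1, max_ngram_size + 1),
                   cum_weights=cumulative_weights)[0]
  counts[sz] += 1
# Expected proportions ~ 6/11, 3/11, 2/11 (about 55%, 27%, 18%).
print(counts)
```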
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to access NLP registered data loaders."""
from official.utils import registry
_REGISTERED_DATA_LOADER_CLS = {}
def register_data_loader_cls(data_config_cls):
"""Decorates a factory of DataLoader for lookup by a subclass of DataConfig.
This decorator supports registration of data loaders as follows:
```
@dataclasses.dataclass
class MyDataConfig(DataConfig):
# Add fields here.
pass
@register_data_loader_cls(MyDataConfig)
class MyDataLoader:
# Inherits def __init__(self, data_config).
pass
my_data_config = MyDataConfig()
# Returns MyDataLoader(my_data_config).
my_loader = get_data_loader(my_data_config)
```
Args:
data_config_cls: a subclass of DataConfig (*not* an instance
of DataConfig).
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of data_config_cls.
"""
return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)
def get_data_loader(data_config):
"""Creates a data_loader from data_config."""
return registry.lookup(_REGISTERED_DATA_LOADER_CLS, data_config.__class__)(
data_config)
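The `registry` helpers are opaque here; as a rough mental model (an illustrative dict-based reimplementation, not `official.utils.registry`'s actual code), the register/lookup pattern reduces to a dict keyed by config class:

```
_LOADERS = {}

def register(collection, key_cls):
  """Returns a class decorator that records the class under key_cls."""
  def decorator(cls):
    collection[key_cls] = cls
    return cls
  return decorator

class MyDataConfig:  # stand-in for a cfg.DataConfig subclass
  pass

@register(_LOADERS, MyDataConfig)
class MyDataLoader:
  def __init__(self, data_config):
    self.config = data_config

config = MyDataConfig()
loader = _LOADERS[config.__class__](config)  # -> a MyDataLoader instance
```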
@@ -16,11 +16,27 @@
 """Loads dataset for the BERT pretraining task."""
 from typing import Mapping, Optional
 
+import dataclasses
 import tensorflow as tf
 
 from official.core import input_reader
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader_factory
+
+
+@dataclasses.dataclass
+class BertPretrainDataConfig(cfg.DataConfig):
+  """Data config for BERT pretraining task (tasks/masked_lm)."""
+  input_path: str = ''
+  global_batch_size: int = 512
+  is_training: bool = True
+  seq_length: int = 512
+  max_predictions_per_seq: int = 76
+  use_next_sentence_label: bool = True
+  use_position_id: bool = False
+
+
+@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
 class BertPretrainDataLoader:
   """A class to load dataset for bert pretraining task."""
...
@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a tf.dataset.Dataset."""
     reader = input_reader.InputReader(
-        params=self._params,
-        decoder_fn=self._decode,
-        parser_fn=self._parse)
+        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
     return reader.read(input_context)
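With the config now co-located with its loader and registered above, a minimal wiring sketch looks like the following (it assumes the module path `official.nlp.data.pretrain_dataloader` and uses a placeholder input glob):

```
from official.nlp.data import data_loader_factory
from official.nlp.data import pretrain_dataloader

config = pretrain_dataloader.BertPretrainDataConfig(
    input_path='/tmp/pretrain-*.tfrecord',  # placeholder path
    global_batch_size=32,
    seq_length=128,
    max_predictions_per_seq=20)

# get_data_loader dispatches on type(config) via the registry,
# returning BertPretrainDataLoader(config).
loader = data_loader_factory.get_data_loader(config)
dataset = loader.load()
```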
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the question answering (e.g, SQuAD) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
"""Data config for question answering task (tasks/question_answering)."""
input_path: str = ''
global_batch_size: int = 48
is_training: bool = True
seq_length: int = 384
# Settings below are question answering specific.
version_2_with_negative: bool = False
# Settings below are only used for eval mode.
input_preprocessed_data_path: str = ''
doc_stride: int = 128
query_length: int = 64
vocab_file: str = ''
tokenization: str = 'WordPiece' # WordPiece or SentencePiece
do_lower_case: bool = True
@data_loader_factory.register_data_loader_cls(QADataConfig)
class QuestionAnsweringDataLoader:
"""A class to load dataset for sentence prediction (classification) task."""
def __init__(self, params):
self._params = params
self._seq_length = params.seq_length
self._is_training = params.is_training
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
if self._is_training:
name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
else:
name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x, y = {}, {}
for name, tensor in record.items():
if name in ('start_positions', 'end_positions'):
y[name] = tensor
elif name == 'input_ids':
x['input_word_ids'] = tensor
elif name == 'segment_ids':
x['input_type_ids'] = tensor
else:
x[name] = tensor
return (x, y)
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
@@ -15,11 +15,28 @@
 # ==============================================================================
 """Loads dataset for the sentence prediction (classification) task."""
 from typing import Mapping, Optional
 
+import dataclasses
 import tensorflow as tf
 
 from official.core import input_reader
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader_factory
+
+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
+
+
+@dataclasses.dataclass
+class SentencePredictionDataConfig(cfg.DataConfig):
+  """Data config for sentence prediction task (tasks/sentence_prediction)."""
+  input_path: str = ''
+  global_batch_size: int = 32
+  is_training: bool = True
+  seq_length: int = 128
+  label_type: str = 'int'
+
+
+@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
 class SentencePredictionDataLoader:
   """A class to load dataset for sentence prediction (classification) task."""
...
@@ -29,11 +46,12 @@ class SentencePredictionDataLoader:
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
+    label_type = LABEL_TYPES_MAP[self._params.label_type]
     name_to_features = {
         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'label_ids': tf.io.FixedLenFeature([], tf.int64),
+        'label_ids': tf.io.FixedLenFeature([], label_type),
     }
     example = tf.io.parse_single_example(record, name_to_features)
...
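A short sketch of why `label_type` matters: a regression-style task (plausibly the STS-B support added elsewhere in this commit) serializes a float label, so the decoder has to request `tf.float32` for the same feature name. Illustrative values only:

```
import tensorflow as tf

# Serialize a tf.Example carrying a float label, as a regression task would.
example = tf.train.Example(features=tf.train.Features(feature={
    'label_ids': tf.train.Feature(
        float_list=tf.train.FloatList(value=[3.8])),
}))

# Decoding must use the matching dtype; tf.int64 here would fail to parse.
decoded = tf.io.parse_single_example(
    example.SerializeToString(),
    {'label_ids': tf.io.FixedLenFeature([], tf.float32)})
print(decoded['label_ids'].numpy())  # ~3.8
```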
@@ -15,17 +15,30 @@
 # ==============================================================================
 """Loads dataset for the tagging (e.g., NER/POS) task."""
 from typing import Mapping, Optional
 
+import dataclasses
 import tensorflow as tf
 
 from official.core import input_reader
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader_factory
+
+
+@dataclasses.dataclass
+class TaggingDataConfig(cfg.DataConfig):
+  """Data config for tagging (tasks/tagging)."""
+  is_training: bool = True
+  seq_length: int = 128
+  include_sentence_id: bool = False
+
+
+@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
 class TaggingDataLoader:
   """A class to load dataset for tagging (e.g., NER and POS) task."""
 
-  def __init__(self, params):
+  def __init__(self, params: TaggingDataConfig):
     self._params = params
     self._seq_length = params.seq_length
+    self._include_sentence_id = params.include_sentence_id
 
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
...
@@ -35,6 +48,9 @@ class TaggingDataLoader:
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
     }
+    if self._include_sentence_id:
+      name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
+
     example = tf.io.parse_single_example(record, name_to_features)
 
     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
...
@@ -54,6 +70,8 @@ class TaggingDataLoader:
         'input_mask': record['input_mask'],
         'input_type_ids': record['segment_ids']
     }
+    if self._include_sentence_id:
+      x['sentence_id'] = record['sentence_id']
+
     y = record['label_ids']
     return (x, y)
...