"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "90e0a0dd08159e1c95f4f9d99bb8687f327d36c3"
Commit 31ca3b97 authored by Kaushik Shivakumar

resolve merge conflicts

parents 3e9d886d 7fcd7cba
......@@ -88,7 +88,6 @@ def is_special_none_tensor(tensor):
return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
# TODO(hongkuny): consider moving custom string-map lookup to keras api.
def get_activation(identifier):
"""Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
......
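A quick, hedged illustration of the lookup `get_activation` performs (sketch only; whether the returned object is `tf.nn.relu` or the equivalent Keras activation is an implementation detail):
```
import tensorflow as tf

relu_fn = get_activation("relu")  # assumes the module above has been imported
print(relu_fn(tf.constant([-1.0, 2.0])).numpy())  # [0. 2.]
```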
......@@ -14,23 +14,61 @@
# ==============================================================================
"""ALBERT classification finetuning runner in tf2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import bert_models
from official.nlp.bert import run_classifier as run_classifier_bert
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
def predict(strategy, albert_config, input_meta_data, predict_input_fn):
"""Function outputs both the ground truth predictions as .tsv files."""
with strategy.scope():
classifier_model = bert_models.classifier_model(
albert_config, input_meta_data['num_labels'])[0]
checkpoint = tf.train.Checkpoint(model=classifier_model)
latest_checkpoint_file = (
FLAGS.predict_checkpoint_path or
tf.train.latest_checkpoint(FLAGS.model_dir))
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched()
preds, ground_truth = run_classifier_bert.get_predictions_and_labels(
strategy, classifier_model, predict_input_fn, return_probs=True)
output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
logging.info('***** Predict results *****')
for probabilities in preds:
output_line = '\t'.join(
str(class_probability)
for class_probability in probabilities) + '\n'
writer.write(output_line)
ground_truth_labels_file = os.path.join(FLAGS.model_dir,
'output_labels.tsv')
with tf.io.gfile.GFile(ground_truth_labels_file, 'w') as writer:
logging.info('***** Ground truth results *****')
for label in ground_truth:
output_line = str(label) + '\n'
writer.write(output_line)
return
def main(_):
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
......@@ -56,9 +94,14 @@ def main(_):
albert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
train_input_fn, eval_input_fn)
if FLAGS.mode == 'train_and_eval':
run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
train_input_fn, eval_input_fn)
elif FLAGS.mode == 'predict':
predict(strategy, albert_config, input_meta_data, eval_input_fn)
else:
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
return
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
......
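For context, a hypothetical sketch of consuming the two TSV files written by the new `predict()` path above (file names match the code; the accuracy computation is purely illustrative):
```
import tensorflow as tf

def read_tsv_rows(path):
  with tf.io.gfile.GFile(path, 'r') as f:
    return [line.rstrip('\n').split('\t') for line in f if line.strip()]

# Each row of test_results.tsv holds per-class probabilities for one example;
# output_labels.tsv holds the matching ground-truth label ids.
probs = [[float(p) for p in row] for row in read_tsv_rows('test_results.tsv')]
labels = [int(row[0]) for row in read_tsv_rows('output_labels.tsv')]
predictions = [max(range(len(p)), key=p.__getitem__) for p in probs]
accuracy = sum(int(p == y) for p, y in zip(predictions, labels)) / len(labels)
```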
......@@ -79,7 +79,7 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
do_lower_case, vocab_file)
core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.restore(model_checkpoint_path).assert_consumed()
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
......
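The switch from `assert_consumed()` to `assert_existing_objects_matched()` above matters when the training checkpoint contains more than the encoder (e.g., optimizer slots). A minimal sketch of the difference with throwaway variables:
```
import tensorflow as tf

# Save a checkpoint holding a "model" variable plus an unrelated extra object.
saved = tf.train.Checkpoint(model=tf.Variable(1.0), extra=tf.Variable(2.0))
path = saved.save('/tmp/restore_sketch')

# Restore into an object graph that tracks only the model.
target = tf.train.Checkpoint(model=tf.Variable(0.0))
status = target.restore(path)
status.assert_existing_objects_matched()  # passes: every tracked object matched
# status.assert_consumed() would raise, because 'extra' in the checkpoint has
# no counterpart in `target`.
```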
......@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text,
raise ValueError('model must be a tf.keras.Model object.')
if checkpoint_dir:
# Keras compile/fit() was used to save checkpoint using
# model.save_weights().
if restore_model_using_load_weights:
model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
assert tf.io.gfile.exists(model_weight_path)
model.load_weights(model_weight_path)
# tf.train.Checkpoint API was used via custom training loop logic.
else:
checkpoint = tf.train.Checkpoint(model=model)
......
......@@ -99,7 +99,9 @@ def write_txt_summary(training_summary, summary_dir):
@deprecation.deprecated(
None, 'This function is deprecated. Please use Keras compile/fit instead.')
None, 'This function is deprecated and we do not expect to add new '
'functionality to it. Please do not have your code depend '
'on this library.')
def run_customized_training_loop(
# pylint: disable=invalid-name
_sentinel=None,
......@@ -557,7 +559,6 @@ def run_customized_training_loop(
for metric in model.metrics:
training_summary[metric.name] = _float_metric_value(metric)
if eval_metrics:
# TODO(hongkuny): Cleans up summary reporting in text.
training_summary['last_train_metrics'] = _float_metric_value(
train_metrics[0])
training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
......
......@@ -343,7 +343,10 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
classifier_model = bert_models.classifier_model(
bert_config, input_meta_data.get('num_labels', 1))[0]
bert_config,
input_meta_data.get('num_labels', 1),
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=False)[0]
model_saving_utils.export_bert_model(
model_export_path, model=classifier_model, checkpoint_dir=model_dir)
......
......@@ -24,7 +24,6 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import bert_pretrainer
......@@ -43,7 +42,6 @@ class ClsHeadConfig(base_config.Config):
@dataclasses.dataclass
class BertPretrainerConfig(base_config.Config):
"""BERT encoder configuration."""
num_masked_tokens: int = 76
encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
......@@ -56,96 +54,18 @@ def instantiate_classification_heads_from_cfgs(
] if cls_head_configs else []
def instantiate_bertpretrainer_from_cfg(
def instantiate_pretrainer_from_cfg(
config: BertPretrainerConfig,
encoder_network: Optional[tf.keras.Model] = None
) -> bert_pretrainer.BertPretrainerV2:
) -> bert_pretrainer.BertPretrainerV2:
"""Instantiates a BertPretrainer from the config."""
encoder_cfg = config.encoder
if encoder_network is None:
encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
return bert_pretrainer.BertPretrainerV2(
config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
encoder_network=encoder_network,
classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads))
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
"""Data config for BERT pretraining task (tasks/masked_lm)."""
input_path: str = ""
global_batch_size: int = 512
is_training: bool = True
seq_length: int = 512
max_predictions_per_seq: int = 76
use_next_sentence_label: bool = True
use_position_id: bool = False
@dataclasses.dataclass
class BertPretrainEvalDataConfig(BertPretrainDataConfig):
"""Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
input_path: str = ""
global_batch_size: int = 512
is_training: bool = False
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
input_path: str = ""
global_batch_size: int = 32
is_training: bool = True
seq_length: int = 128
@dataclasses.dataclass
class SentencePredictionDevDataConfig(cfg.DataConfig):
"""Dev Data config for sentence prediction (tasks/sentence_prediction)."""
input_path: str = ""
global_batch_size: int = 32
is_training: bool = False
seq_length: int = 128
drop_remainder: bool = False
@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
"""Data config for question answering task (tasks/question_answering)."""
input_path: str = ""
global_batch_size: int = 48
is_training: bool = True
seq_length: int = 384
@dataclasses.dataclass
class QADevDataConfig(cfg.DataConfig):
"""Dev Data config for queston answering (tasks/question_answering)."""
input_path: str = ""
global_batch_size: int = 48
is_training: bool = False
seq_length: int = 384
drop_remainder: bool = False
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
"""Data config for tagging (tasks/tagging)."""
input_path: str = ""
global_batch_size: int = 48
is_training: bool = True
seq_length: int = 384
@dataclasses.dataclass
class TaggingDevDataConfig(cfg.DataConfig):
"""Dev Data config for tagging (tasks/tagging)."""
input_path: str = ""
global_batch_size: int = 48
is_training: bool = False
seq_length: int = 384
drop_remainder: bool = False
......@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
_ = bert.instantiate_bertpretrainer_from_cfg(config)
_ = bert.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = bert.BertPretrainerConfig(
......@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_bertpretrainer_from_cfg(config)
_ = bert.instantiate_pretrainer_from_cfg(config)
with self.assertRaises(ValueError):
config = bert.BertPretrainerConfig(
......@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_bertpretrainer_from_cfg(config)
_ = bert.instantiate_pretrainer_from_cfg(config)
def test_checkpoint_items(self):
config = bert.BertPretrainerConfig(
......@@ -56,9 +56,10 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
encoder = bert.instantiate_bertpretrainer_from_cfg(config)
self.assertSameElements(encoder.checkpoint_items.keys(),
["encoder", "next_sentence.pooler_dense"])
encoder = bert.instantiate_pretrainer_from_cfg(config)
self.assertSameElements(
encoder.checkpoint_items.keys(),
["encoder", "masked_lm", "next_sentence.pooler_dense"])
if __name__ == "__main__":
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer
@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
"""ELECTRA pretrainer configuration."""
num_masked_tokens: int = 76
sequence_length: int = 512
num_classes: int = 2
discriminator_loss_weight: float = 50.0
tie_embeddings: bool = True
disallow_correct: bool = False
generator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
discriminator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_classification_heads_from_cfgs(
cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
if cls_head_configs:
return [
layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
]
else:
return []
def instantiate_pretrainer_from_cfg(
config: ELECTRAPretrainerConfig,
generator_network: Optional[tf.keras.Model] = None,
discriminator_network: Optional[tf.keras.Model] = None,
) -> electra_pretrainer.ElectraPretrainer:
"""Instantiates ElectraPretrainer from the config."""
generator_encoder_cfg = config.generator_encoder
discriminator_encoder_cfg = config.discriminator_encoder
# Copy discriminator's embeddings to generator for easier model serialization.
if discriminator_network is None:
discriminator_network = encoders.instantiate_encoder_from_cfg(
discriminator_encoder_cfg)
if generator_network is None:
if config.tie_embeddings:
embedding_layer = discriminator_network.get_embedding_layer()
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg, embedding_layer=embedding_layer)
else:
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg)
return electra_pretrainer.ElectraPretrainer(
generator_network=generator_network,
discriminator_network=discriminator_network,
vocab_size=config.generator_encoder.vocab_size,
num_classes=config.num_classes,
sequence_length=config.sequence_length,
num_token_predictions=config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(
generator_encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=generator_encoder_cfg.initializer_range),
classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads),
disallow_correct=config.disallow_correct)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
class ELECTRAModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
)
_ = electra.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = electra.instantiate_pretrainer_from_cfg(config)
if __name__ == "__main__":
tf.test.main()
......@@ -17,12 +17,13 @@
Includes configurations and instantiation methods.
"""
from typing import Optional
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.modeling import layers
from official.nlp.modeling import networks
......@@ -40,12 +41,47 @@ class TransformerEncoderConfig(base_config.Config):
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_size: Optional[int] = None
def instantiate_encoder_from_cfg(
config: TransformerEncoderConfig) -> networks.TransformerEncoder:
config: TransformerEncoderConfig,
encoder_cls=networks.TransformerEncoder,
embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
encoder_network = networks.TransformerEncoder(
if encoder_cls.__name__ == "EncoderScaffold":
embedding_cfg = dict(
vocab_size=config.vocab_size,
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
seq_length=None,
max_seq_length=config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
dropout_rate=config.dropout_rate,
)
hidden_cfg = dict(
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
intermediate_activation=tf_utils.get_activation(
config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=config.num_layers,
pooled_output_dim=config.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
return encoder_cls(**kwargs)
if encoder_cls.__name__ != "TransformerEncoder":
raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
encoder_network = encoder_cls(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
num_layers=config.num_layers,
......@@ -58,5 +94,7 @@ def instantiate_encoder_from_cfg(
max_sequence_length=config.max_position_embeddings,
type_vocab_size=config.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
stddev=config.initializer_range),
embedding_width=config.embedding_size,
embedding_layer=embedding_layer)
return encoder_network
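A hypothetical usage sketch for the updated factory, including the new `embedding_size` field (values are illustrative; unspecified fields keep their dataclass defaults):
```
from official.nlp.configs import encoders

cfg = encoders.TransformerEncoderConfig(
    vocab_size=30522, num_layers=2, embedding_size=128)
encoder = encoders.instantiate_encoder_from_cfg(cfg)  # a networks.TransformerEncoder
```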
......@@ -32,14 +32,16 @@ from official.nlp.data import sentence_retrieval_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib
FLAGS = flags.FLAGS
# TODO(chendouble): consider moving each task to its own binary.
flags.DEFINE_enum(
"fine_tuning_task_type", "classification",
["classification", "regression", "squad", "retrieval"],
["classification", "regression", "squad", "retrieval", "tagging"],
"The name of the BERT fine tuning task for which data "
"will be generated..")
"will be generated.")
# BERT classification specific flags.
flags.DEFINE_string(
......@@ -48,30 +50,41 @@ flags.DEFINE_string(
"for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
"PAWS-X", "XTREME-XNLI", "XTREME-PAWS-X"],
["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI",
"XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
"The name of sentence retrieval task for scoring")
# MNLI task-specific flag.
flags.DEFINE_enum(
"mnli_type", "matched", ["matched", "mismatched"],
"The type of MNLI dataset.")
# XNLI task specific flag.
# XNLI task-specific flag.
flags.DEFINE_string(
"xnli_language", "en",
"Language of training data for XNIL task. If the value is 'all', the data "
"Language of training data for XNLI task. If the value is 'all', the data "
"of all languages will be used for training.")
# PAWS-X task specific flag.
# PAWS-X task-specific flag.
flags.DEFINE_string(
"pawsx_language", "en",
"Language of trainig data for PAWS-X task. If the value is 'all', the data "
"Language of training data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training.")
# BERT Squad task specific flags.
# Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
"The name of sentence retrieval task for scoring")
# Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
"The name of BERT tagging (token classification) task.")
# BERT Squad task-specific flags.
flags.DEFINE_string(
"squad_data_file", None,
"The input data file in for generating training data for BERT squad task.")
......@@ -171,7 +184,8 @@ def generate_classifier_dataset():
"cola":
classifier_data_lib.ColaProcessor,
"mnli":
classifier_data_lib.MnliProcessor,
functools.partial(classifier_data_lib.MnliProcessor,
mnli_type=FLAGS.mnli_type),
"mrpc":
classifier_data_lib.MrpcProcessor,
"qnli":
......@@ -180,6 +194,8 @@ def generate_classifier_dataset():
"rte": classifier_data_lib.RteProcessor,
"sst-2":
classifier_data_lib.SstProcessor,
"sts-b":
classifier_data_lib.StsBProcessor,
"xnli":
functools.partial(classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language),
......@@ -284,6 +300,34 @@ def generate_retrieval_dataset():
FLAGS.max_seq_length)
def generate_tagging_dataset():
"""Generates tagging dataset."""
processors = {
"panx": tagging_data_lib.PanxProcessor,
"udpos": tagging_data_lib.UdposProcessor,
}
task_name = FLAGS.tagging_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % task_name)
if FLAGS.tokenizer_impl == "word_piece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
elif FLAGS.tokenizer_impl == "sentence_piece":
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
else:
raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
processor = processors[task_name]()
return tagging_data_lib.generate_tf_record_from_data_file(
processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
FLAGS.test_data_output_path, processor_text_fn)
def main(_):
if FLAGS.tokenizer_impl == "word_piece":
if not FLAGS.vocab_file:
......@@ -304,8 +348,11 @@ def main(_):
input_meta_data = generate_regression_dataset()
elif FLAGS.fine_tuning_task_type == "retrieval":
input_meta_data = generate_retrieval_dataset()
else:
elif FLAGS.fine_tuning_task_type == "squad":
input_meta_data = generate_squad_dataset()
else:
assert FLAGS.fine_tuning_task_type == "tagging"
input_meta_data = generate_tagging_dataset()
tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
......
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
import itertools
import random
from absl import app
......@@ -48,6 +49,12 @@ flags.DEFINE_bool(
"do_whole_word_mask", False,
"Whether to use whole word masking rather than per-WordPiece masking.")
flags.DEFINE_integer(
"max_ngram_size", None,
"Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
"weighting scheme to favor shorter n-grams. "
"Note: `--do_whole_word_mask=True` must also be set when using n-gram masking.")
flags.DEFINE_bool(
"gzip_compress", False,
"Whether to use `GZIP` compress option to get compressed TFRecord files.")
......@@ -192,7 +199,8 @@ def create_training_instances(input_files,
masked_lm_prob,
max_predictions_per_seq,
rng,
do_whole_word_mask=False):
do_whole_word_mask=False,
max_ngram_size=None):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
......@@ -229,7 +237,7 @@ def create_training_instances(input_files,
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask))
do_whole_word_mask, max_ngram_size))
rng.shuffle(instances)
return instances
......@@ -238,7 +246,8 @@ def create_training_instances(input_files,
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask=False):
do_whole_word_mask=False,
max_ngram_size=None):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
......@@ -337,7 +346,7 @@ def create_instances_from_document(
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask)
do_whole_word_mask, max_ngram_size)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
......@@ -355,72 +364,238 @@ def create_instances_from_document(
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
# A _Gram is a [half-open) interval of token indices which form a word.
# E.g.,
# words: ["The", "doghouse"]
# tokens: ["The", "dog", "##house"]
# grams: [(0,1), (1,3)]
_Gram = collections.namedtuple("_Gram", ["begin", "end"])
def _window(iterable, size):
"""Helper to create a sliding window iterator with a given size.
E.g.,
input = [1, 2, 3, 4]
_window(input, 1) => [1], [2], [3], [4]
_window(input, 2) => [1, 2], [2, 3], [3, 4]
_window(input, 3) => [1, 2, 3], [2, 3, 4]
_window(input, 4) => [1, 2, 3, 4]
_window(input, 5) => (no output; input is shorter than the window size)
Arguments:
iterable: elements to iterate over.
size: size of the window.
Yields:
Elements of `iterable` batched into a sliding window of length `size`.
"""
i = iter(iterable)
window = []
try:
for e in range(0, size):
window.append(next(i))
yield window
except StopIteration:
# handle the case where iterable's length is less than the window size.
return
for e in i:
window = window[1:] + [e]
yield window
def _contiguous(sorted_grams):
"""Test whether a sequence of grams is contiguous.
Arguments:
sorted_grams: _Grams which are sorted in increasing order.
Returns:
True if `sorted_grams` are touching each other.
E.g.,
_contiguous([(1, 4), (4, 5), (5, 10)]) == True
_contiguous([(1, 2), (4, 5)]) == False
"""
for a, b in _window(sorted_grams, 2):
if a.end != b.begin:
return False
return True
def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
"""Create a list of masking {1, ..., n}-grams from a list of one-grams.
This is an extension of 'whole word masking' that masks multiple contiguous
words (e.g., "the red boat").
Each input gram represents the token indices of a single word,
words: ["the", "red", "boat"]
tokens: ["the", "red", "boa", "##t"]
grams: [(0,1), (1,2), (2,4)]
For a `max_ngram_size` of three, possible outputs masks include:
1-grams: (0,1), (1,2), (2,4)
2-grams: (0,2), (1,4)
3-grams: (0,4)
Output masks will not overlap and will cover at most `max_masked_tokens` total
tokens. E.g., for the example above with `max_masked_tokens` of three,
valid outputs are:
[(0,1), (1,2)] # "the", "red" covering two tokens
[(1,2), (2,4)] # "red", "boa", "##t" covering three tokens
The length of the selected n-gram follows a zipf weighting to
favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
Arguments:
grams: List of one-grams.
max_ngram_size: Maximum number of contiguous one-grams combined to create
an n-gram.
max_masked_tokens: Maximum total number of tokens to be masked.
rng: `random.Random` generator.
Returns:
A list of n-grams to be used as masks.
"""
if not grams:
return None
grams = sorted(grams)
num_tokens = grams[-1].end
# Ensure our grams are valid (i.e., they don't overlap).
for a, b in _window(grams, 2):
if a.end > b.begin:
raise ValueError("overlapping grams: {}".format(grams))
# Build map from n-gram length to list of n-grams.
ngrams = {i: [] for i in range(1, max_ngram_size+1)}
for gram_size in range(1, max_ngram_size+1):
for g in _window(grams, gram_size):
if _contiguous(g):
# Add an n-gram which spans these one-grams.
ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
# Shuffle each list of n-grams.
for v in ngrams.values():
rng.shuffle(v)
# Create the weighting for n-gram length selection.
# Stored cumulatively for `random.choices` below.
cummulative_weights = list(
itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))
output_ngrams = []
# Keep a bitmask of which tokens have been masked.
masked_tokens = [False] * num_tokens
# Loop until we have enough masked tokens or there are no more candidate
# n-grams of any length.
# Each code path should ensure one or more elements from `ngrams` are removed
# to guarantee this loop terminates.
while (sum(masked_tokens) < max_masked_tokens and
sum(len(s) for s in ngrams.values())):
# Pick an n-gram size based on our weights.
sz = random.choices(range(1, max_ngram_size+1),
cum_weights=cummulative_weights)[0]
# Ensure this size doesn't result in too many masked tokens.
# E.g., a two-gram contains _at least_ two tokens.
if sum(masked_tokens) + sz > max_masked_tokens:
# All n-grams of this length are too long and can be removed from
# consideration.
ngrams[sz].clear()
continue
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask):
"""Creates the predictions for the masked LM objective."""
# All of the n-grams of this size have been used.
if not ngrams[sz]:
continue
# Choose a random n-gram of the given size.
gram = ngrams[sz].pop()
num_gram_tokens = gram.end-gram.begin
# Check if this would add too many tokens.
if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
continue
# Check if any of the tokens in this gram have already been masked.
if sum(masked_tokens[gram.begin:gram.end]):
continue
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
# Found a usable n-gram! Mark its tokens as masked and add it to the output.
masked_tokens[gram.begin:gram.end] = [True] * (gram.end-gram.begin)
output_ngrams.append(gram)
return output_ngrams
def _wordpieces_to_grams(tokens):
"""Reconstitue grams (words) from `tokens`.
E.g.,
tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
grams: [(1, 2), (2, 4), (4, 5), (5, 6)]
Arguments:
tokens: list of wordpieces
Returns:
List of _Grams representing spans of whole words
(without "[CLS]" and "[SEP]").
"""
grams = []
gram_start_pos = None
for i, token in enumerate(tokens):
if gram_start_pos is not None and token.startswith("##"):
continue
# Whole Word Masking means that we mask all of the wordpieces
# corresponding to an original word. When a word has been split into
# WordPieces, the first token does not have any marker and any subsequent
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
if gram_start_pos is not None:
grams.append(_Gram(gram_start_pos, i))
if token not in ["[CLS]", "[SEP]"]:
gram_start_pos = i
else:
cand_indexes.append([i])
gram_start_pos = None
if gram_start_pos is not None:
grams.append(_Gram(gram_start_pos, len(tokens)))
return grams
rng.shuffle(cand_indexes)
output_tokens = list(tokens)
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask,
max_ngram_size=None):
"""Creates the predictions for the masked LM objective."""
if do_whole_word_mask:
grams = _wordpieces_to_grams(tokens)
else:
# Here we consider each token to be a word to allow for sub-word masking.
if max_ngram_size:
raise ValueError("cannot use ngram masking without whole word masking")
grams = [_Gram(i, i+1) for i in range(0, len(tokens))
if tokens[i] not in ["[CLS]", "[SEP]"]]
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
# Generate masks. If `max_ngram_size` is 0 or None, we are doing whole word
# masking or token-level masking. Both of these can be treated
# as the `max_ngram_size=1` case.
masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
num_to_predict, rng)
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
output_tokens = list(tokens)
for gram in masked_grams:
# 80% of the time, replace all n-gram tokens with [MASK]
if rng.random() < 0.8:
replacement_action = lambda idx: "[MASK]"
else:
# 10% of the time, keep all the original n-gram tokens.
if rng.random() < 0.5:
replacement_action = lambda idx: tokens[idx]
# 10% of the time, replace each n-gram token with a random word.
else:
# 10% of the time, keep original
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
replacement_action = lambda idx: rng.choice(vocab_words)
output_tokens[index] = masked_token
for idx in range(gram.begin, gram.end):
output_tokens[idx] = replacement_action(idx)
masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
......@@ -467,7 +642,7 @@ def main(_):
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng, FLAGS.do_whole_word_mask)
rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
output_files = FLAGS.output_file.split(",")
logging.info("*** Writing to output files ***")
......
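A small worked sketch of the n-gram size weighting used by `_masking_ngrams` (assuming `max_ngram_size=3`): the per-size weights 1, 1/2, 1/3 give selection probabilities of roughly 6/11, 3/11 and 2/11 for 1-, 2- and 3-grams.
```
import itertools
import random

max_ngram_size = 3
cum_weights = list(
    itertools.accumulate(1. / n for n in range(1, max_ngram_size + 1)))
random.seed(12345)
sizes = [random.choices(range(1, max_ngram_size + 1), cum_weights=cum_weights)[0]
         for _ in range(10000)]
print({n: sizes.count(n) / len(sizes) for n in range(1, max_ngram_size + 1)})
# Approximately {1: 0.55, 2: 0.27, 3: 0.18}, favoring shorter n-grams.
```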
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to access NLP registered data loaders."""
from official.utils import registry
_REGISTERED_DATA_LOADER_CLS = {}
def register_data_loader_cls(data_config_cls):
"""Decorates a factory of DataLoader for lookup by a subclass of DataConfig.
This decorator supports registration of data loaders as follows:
```
@dataclasses.dataclass
class MyDataConfig(DataConfig):
# Add fields here.
pass
@register_data_loader_cls(MyDataConfig)
class MyDataLoader:
# Inherits def __init__(self, data_config).
pass
my_data_config = MyDataConfig()
# Returns MyDataLoader(my_data_config).
my_loader = get_data_loader(my_data_config)
```
Args:
data_config_cls: a subclass of DataConfig (*not* an instance
of DataConfig).
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of data_config_cls.
"""
return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)
def get_data_loader(data_config):
"""Creates a data_loader from data_config."""
return registry.lookup(_REGISTERED_DATA_LOADER_CLS, data_config.__class__)(
data_config)
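For readers unfamiliar with the registry helpers imported above, a minimal stand-in sketch of the register/lookup pattern, assuming a plain dict keyed by the config class (the real `official.utils.registry` may differ in details such as error handling):
```
def register(registered_map, key):
  """Returns a class decorator that records the class under `key`."""
  def decorator(cls):
    if key in registered_map:
      raise KeyError('%s is already registered.' % key)
    registered_map[key] = cls
    return cls
  return decorator

def lookup(registered_map, key):
  """Returns the class previously registered under `key`."""
  if key not in registered_map:
    raise LookupError('%s is not registered.' % key)
  return registered_map[key]
```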
......@@ -16,11 +16,27 @@
"""Loads dataset for the BERT pretraining task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
"""Data config for BERT pretraining task (tasks/masked_lm)."""
input_path: str = ''
global_batch_size: int = 512
is_training: bool = True
seq_length: int = 512
max_predictions_per_seq: int = 76
use_next_sentence_label: bool = True
use_position_id: bool = False
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader:
"""A class to load dataset for bert pretraining task."""
......@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params,
decoder_fn=self._decode,
parser_fn=self._parse)
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
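A hypothetical end-to-end use of the factory with the pretraining loader above (the module name `pretrain_dataloader` and the input path are assumptions):
```
from official.nlp.data import data_loader_factory
from official.nlp.data import pretrain_dataloader  # assumed module name

config = pretrain_dataloader.BertPretrainDataConfig(
    input_path='/tmp/pretrain-*.tfrecord',  # hypothetical glob
    global_batch_size=32,
    seq_length=128,
    max_predictions_per_seq=20)
loader = data_loader_factory.get_data_loader(config)  # a BertPretrainDataLoader
dataset = loader.load()  # tf.data.Dataset of model-ready feature dicts
```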
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the question answering (e.g, SQuAD) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
"""Data config for question answering task (tasks/question_answering)."""
input_path: str = ''
global_batch_size: int = 48
is_training: bool = True
seq_length: int = 384
# Settings below are question answering specific.
version_2_with_negative: bool = False
# Settings below are only used for eval mode.
input_preprocessed_data_path: str = ''
doc_stride: int = 128
query_length: int = 64
vocab_file: str = ''
tokenization: str = 'WordPiece' # WordPiece or SentencePiece
do_lower_case: bool = True
@data_loader_factory.register_data_loader_cls(QADataConfig)
class QuestionAnsweringDataLoader:
"""A class to load dataset for sentence prediction (classification) task."""
def __init__(self, params):
self._params = params
self._seq_length = params.seq_length
self._is_training = params.is_training
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
if self._is_training:
name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
else:
name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x, y = {}, {}
for name, tensor in record.items():
if name in ('start_positions', 'end_positions'):
y[name] = tensor
elif name == 'input_ids':
x['input_word_ids'] = tensor
elif name == 'segment_ids':
x['input_type_ids'] = tensor
else:
x[name] = tensor
return (x, y)
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
......@@ -15,11 +15,28 @@
# ==============================================================================
"""Loads dataset for the sentence prediction (classification) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
input_path: str = ''
global_batch_size: int = 32
is_training: bool = True
seq_length: int = 128
label_type: str = 'int'
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
class SentencePredictionDataLoader:
"""A class to load dataset for sentence prediction (classification) task."""
......@@ -29,11 +46,12 @@ class SentencePredictionDataLoader:
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
label_type = LABEL_TYPES_MAP[self._params.label_type]
name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64),
'label_ids': tf.io.FixedLenFeature([], label_type),
}
example = tf.io.parse_single_example(record, name_to_features)
......
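As a hedged illustration of the new `label_type` field: selecting 'float' makes the loader decode regression targets (e.g., STS-B scores) as float32 rather than int64 class ids (the module name below is an assumption):
```
from official.nlp.data import sentence_prediction_dataloader  # assumed module name

reg_config = sentence_prediction_dataloader.SentencePredictionDataConfig(
    input_path='/tmp/sts_b_train.tfrecord',  # hypothetical path
    label_type='float')
# LABEL_TYPES_MAP['float'] is tf.float32, so 'label_ids' is parsed as float32.
```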
......@@ -15,17 +15,30 @@
# ==============================================================================
"""Loads dataset for the tagging (e.g., NER/POS) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
"""Data config for tagging (tasks/tagging)."""
is_training: bool = True
seq_length: int = 128
include_sentence_id: bool = False
@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
class TaggingDataLoader:
"""A class to load dataset for tagging (e.g., NER and POS) task."""
def __init__(self, params):
def __init__(self, params: TaggingDataConfig):
self._params = params
self._seq_length = params.seq_length
self._include_sentence_id = params.include_sentence_id
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
......@@ -35,6 +48,9 @@ class TaggingDataLoader:
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
if self._include_sentence_id:
name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
......@@ -54,6 +70,8 @@ class TaggingDataLoader:
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids']
}
if self._include_sentence_id:
x['sentence_id'] = record['sentence_id']
y = record['label_ids']
return (x, y)
......
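A brief sketch of what the new `include_sentence_id` switch adds to the parsed features (module name and path are assumptions):
```
from official.nlp.data import tagging_dataloader  # assumed module name

config = tagging_dataloader.TaggingDataConfig(
    input_path='/tmp/panx_dev.tfrecord',  # hypothetical path
    is_training=False,
    include_sentence_id=True)
# Each parsed example now carries x['sentence_id'] alongside input_word_ids,
# input_mask and input_type_ids, with label_ids returned as y.
```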