Commit 9c9aec17 authored by Chen Chen, committed by A. Unique TensorFlower

Support running ALBERT on the SQuAD task.

PiperOrigin-RevId: 286637307
parent 553a4f41
@@ -25,7 +25,10 @@ from absl import flags
 import tensorflow as tf
 from official.nlp.bert import classifier_data_lib
-from official.nlp.bert import squad_lib
+# word-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib as squad_lib_wp
+# sentence-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib_sp

 FLAGS = flags.FLAGS
@@ -70,14 +73,12 @@ flags.DEFINE_string("vocab_file", None,
 flags.DEFINE_string(
     "train_data_output_path", None,
     "The path in which generated training input data will be written as tf"
-    " records."
-)
+    " records.")

 flags.DEFINE_string(
     "eval_data_output_path", None,
     "The path in which generated evaluation input data will be written as tf"
-    " records."
-)
+    " records.")

 flags.DEFINE_string("meta_data_file_path", None,
                     "The path in which input meta data will be written.")
@@ -93,6 +94,15 @@ flags.DEFINE_integer(
     "Sequences longer than this will be truncated, and sequences shorter "
     "than this will be padded.")

+flags.DEFINE_string("sp_model_file", "",
+                    "The path to the model used by sentence piece tokenizer.")
+
+flags.DEFINE_enum(
+    "tokenizer_impl", "word_piece", ["word_piece", "sentence_piece"],
+    "Specifies the tokenizer implementation, i.e., whether to use word_piece "
+    "or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
+    "while ALBERT uses sentence_piece tokenizer.")
+

 def generate_classifier_dataset():
   """Generates classifier dataset and returns input meta data."""
@@ -124,13 +134,30 @@ def generate_classifier_dataset():

 def generate_squad_dataset():
   """Generates squad training dataset and returns input meta data."""
   assert FLAGS.squad_data_file
-  return squad_lib.generate_tf_record_from_json_file(
-      FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
-      FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
-      FLAGS.doc_stride, FLAGS.version_2_with_negative)
+  if FLAGS.tokenizer_impl == "word_piece":
+    return squad_lib_wp.generate_tf_record_from_json_file(
+        FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
+        FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
+        FLAGS.doc_stride, FLAGS.version_2_with_negative)
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    return squad_lib_sp.generate_tf_record_from_json_file(
+        FLAGS.squad_data_file, FLAGS.sp_model_file,
+        FLAGS.train_data_output_path, FLAGS.max_seq_length, FLAGS.do_lower_case,
+        FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)


 def main(_):
+  if FLAGS.tokenizer_impl == "word_piece":
+    if not FLAGS.vocab_file:
+      raise ValueError(
+          "FLAG vocab_file for word-piece tokenizer is not specified.")
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    if not FLAGS.sp_model_file:
+      raise ValueError(
+          "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
+
   if FLAGS.fine_tuning_task_type == "classification":
     input_meta_data = generate_classifier_dataset()
   else:
@@ -141,7 +168,6 @@ def main(_):

 if __name__ == "__main__":
-  flags.mark_flag_as_required("vocab_file")
   flags.mark_flag_as_required("train_data_output_path")
   flags.mark_flag_as_required("meta_data_file_path")
   app.run(main)
@@ -34,7 +34,10 @@ from official.nlp import optimization
 from official.nlp.bert import common_flags
 from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
-from official.nlp.bert import squad_lib
+# word-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib as squad_lib_wp
+# sentence-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib_sp
 from official.nlp.bert import tokenization
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
@@ -80,11 +83,22 @@ flags.DEFINE_integer(
     'max_answer_length', 30,
     'The maximum length of an answer that can be generated. This is needed '
     'because the start and end predictions are not conditioned on one another.')
+flags.DEFINE_string(
+    'sp_model_file', None,
+    'The path to the sentence piece model. Used by sentence piece tokenizer '
+    'employed by ALBERT.')

 common_flags.define_common_bert_flags()

 FLAGS = flags.FLAGS

+MODEL_CLASSES = {
+    'bert': (modeling.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
+    'albert': (modeling.AlbertConfig, squad_lib_sp,
+               tokenization.FullSentencePieceTokenizer),
+}
+

 def squad_loss_fn(start_positions,
                   end_positions,
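Note: MODEL_CLASSES keeps the rest of run_squad model-agnostic; each entry bundles a config class, the matching squad_lib preprocessing module, and a tokenizer class. A toy, self-contained sketch of the lookup pattern (stand-in values, not the real classes):

    class BertConfig:
      pass

    class AlbertConfig:
      pass

    # Second and third slots stand in for the preprocessing module and
    # tokenizer class used in the real registry above.
    MODEL_CLASSES = {
        'bert': (BertConfig, 'squad_lib_wp', 'FullTokenizer'),
        'albert': (AlbertConfig, 'squad_lib_sp', 'FullSentencePieceTokenizer'),
    }

    config_cls, squad_lib_name, tokenizer_name = MODEL_CLASSES['albert']
    print(config_cls.__name__, squad_lib_name, tokenizer_name)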
@@ -121,6 +135,7 @@ def get_loss_fn(loss_factor=1.0):

 def get_raw_results(predictions):
   """Converts multi-replica predictions to RawResult."""
+  squad_lib = MODEL_CLASSES[FLAGS.model_type][1]
   for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
                                                   predictions['start_logits'],
                                                   predictions['end_logits']):
@@ -167,9 +182,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   # Prediction always uses float32, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   squad_model, _ = bert_models.squad_model(
-      bert_config,
-      input_meta_data['max_seq_length'],
-      float_type=tf.float32)
+      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)

   checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
   logging.info('Restoring checkpoints from %s', checkpoint_path)
@@ -219,7 +232,8 @@ def train_squad(strategy,
   if use_float16:
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
+      FLAGS.bert_config_file)
   epochs = FLAGS.num_train_epochs
   num_train_examples = input_meta_data['train_data_size']
   max_seq_length = input_meta_data['max_seq_length']
@@ -281,7 +295,14 @@ def train_squad(strategy,

 def predict_squad(strategy, input_meta_data):
   """Makes predictions for a squad dataset."""
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  config_cls, squad_lib, tokenizer_cls = MODEL_CLASSES[FLAGS.model_type]
+  bert_config = config_cls.from_json_file(FLAGS.bert_config_file)
+  if tokenizer_cls == tokenization.FullTokenizer:
+    tokenizer = tokenizer_cls(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+  else:
+    assert tokenizer_cls == tokenization.FullSentencePieceTokenizer
+    tokenizer = tokenizer_cls(sp_model_file=FLAGS.sp_model_file)
   doc_stride = input_meta_data['doc_stride']
   max_query_length = input_meta_data['max_query_length']
   # Whether data should be in Ver 2.0 format.
@@ -292,9 +313,6 @@ def predict_squad(strategy, input_meta_data):
       is_training=False,
       version_2_with_negative=version_2_with_negative)

-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
-
   eval_writer = squad_lib.FeatureWriter(
       filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
       is_training=False)
@@ -309,7 +327,7 @@ def predict_squad(strategy, input_meta_data):
   # of examples must be a multiple of the batch size, or else examples
   # will get dropped. So we pad with fake examples which are ignored
   # later on.
-  dataset_size = squad_lib.convert_examples_to_features(
+  kwargs = dict(
       examples=eval_examples,
       tokenizer=tokenizer,
       max_seq_length=input_meta_data['max_seq_length'],
@@ -318,6 +336,11 @@ def predict_squad(strategy, input_meta_data):
       is_training=False,
       output_fn=_append_feature,
       batch_size=FLAGS.predict_batch_size)
+
+  # squad_lib_sp requires one more argument 'do_lower_case'.
+  if squad_lib == squad_lib_sp:
+    kwargs['do_lower_case'] = FLAGS.do_lower_case
+  dataset_size = squad_lib.convert_examples_to_features(**kwargs)
   eval_writer.close()

   logging.info('***** Running predictions *****')
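Note: building kwargs as a dict and conditionally adding 'do_lower_case' lets one call site serve two functions whose signatures differ by a single argument. A minimal sketch with toy stand-ins for the two convert_examples_to_features variants:

    def convert_wp(examples, max_seq_length):
      # Stand-in for the WordPiece-based variant.
      return len(examples) * max_seq_length

    def convert_sp(examples, max_seq_length, do_lower_case):
      # Stand-in for the SentencePiece-based variant; takes one extra argument.
      return len(examples) * max_seq_length

    use_sentence_piece = True
    kwargs = dict(examples=['q1', 'q2'], max_seq_length=384)
    if use_sentence_piece:
      kwargs['do_lower_case'] = True  # only the SentencePiece variant takes this
    convert = convert_sp if use_sentence_piece else convert_wp
    print(convert(**kwargs))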
@@ -358,12 +381,10 @@ def export_squad(model_export_path, input_meta_data):
   """
   if not model_export_path:
     raise ValueError('Export path is not specified: %s' % model_export_path)
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
+      FLAGS.bert_config_file)
   squad_model, _ = bert_models.squad_model(
-      bert_config,
-      input_meta_data['max_seq_length'],
-      float_type=tf.float32)
+      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
   model_saving_utils.export_bert_model(
       model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
This diff is collapsed.
@@ -32,7 +32,7 @@ import tensorflow as tf

 import sentencepiece as spm

-SPIECE_UNDERLINE = u"▁".encode("utf-8")
+SPIECE_UNDERLINE = "▁"


 def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
@@ -458,6 +458,9 @@ def encode_pieces(sp_model, text, sample=False):
   Returns:
     A list of token pieces.
   """
+  if six.PY2 and isinstance(text, six.text_type):
+    text = six.ensure_binary(text, "utf-8")
+
   if not sample:
     pieces = sp_model.EncodeAsPieces(text)
   else:
@@ -466,8 +469,8 @@ def encode_pieces(sp_model, text, sample=False):
   for piece in pieces:
     piece = printable_text(piece)
     if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
-      cur_pieces = sp_model.EncodeAsPieces(
-          six.ensure_binary(piece[:-1]).replace(SPIECE_UNDERLINE, b""))
+      cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(
+          SPIECE_UNDERLINE, ""))
       if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
         if len(cur_pieces[0]) == 1:
           cur_pieces = cur_pieces[1:]
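Note: SentencePiece marks word boundaries with U+2581 ('▁'), which is what the SPIECE_UNDERLINE handling above relies on; in Python 3 the processor returns text pieces, so a plain str constant suffices. A minimal usage sketch (model path hypothetical):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.Load("/tmp/albert.model")  # hypothetical path to a trained model
    # Expect pieces like ['▁New', '▁York', '▁is', '▁big', '.'],
    # depending on the model.
    print(sp.EncodeAsPieces("New York is big."))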
@@ -514,21 +517,21 @@ class FullSentencePieceTokenizer(object):
     Args:
       sp_model_file: The path to the sentence piece model file.
     """
-    self._sp_model = spm.SentencePieceProcessor()
-    self._sp_model.Load(sp_model_file)
+    self.sp_model = spm.SentencePieceProcessor()
+    self.sp_model.Load(sp_model_file)
     self.vocab = {
-        self._sp_model.IdToPiece(i): i
-        for i in six.moves.range(self._sp_model.GetPieceSize())
+        self.sp_model.IdToPiece(i): i
+        for i in six.moves.range(self.sp_model.GetPieceSize())
     }

   def tokenize(self, text):
     """Tokenizes text into pieces."""
-    return encode_pieces(self._sp_model, text)
+    return encode_pieces(self.sp_model, text)

   def convert_tokens_to_ids(self, tokens):
     """Converts a list of tokens to a list of ids."""
-    return [self._sp_model.PieceToId(printable_text(token)) for token in tokens]
+    return [self.sp_model.PieceToId(printable_text(token)) for token in tokens]

   def convert_ids_to_tokens(self, ids):
     """Converts a list of ids to a list of tokens."""
-    return [self._sp_model.IdToPiece(id_) for id_ in ids]
+    return [self.sp_model.IdToPiece(id_) for id_ in ids]
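Note: with sp_model now a public attribute, a round trip through the tokenizer's three methods would look like this (model path hypothetical):

    from official.nlp.bert import tokenization

    tokenizer = tokenization.FullSentencePieceTokenizer("/tmp/albert.model")
    tokens = tokenizer.tokenize("hello world")
    ids = tokenizer.convert_tokens_to_ids(tokens)
    assert tokenizer.convert_ids_to_tokens(ids) == tokens
    # tokenizer.sp_model is also usable directly, e.g.
    # tokenizer.sp_model.GetPieceSize() for the vocabulary size.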