Commit 8a802023 authored by hepj987

Initialize repository
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run masked LM/next sentence pre-training for BERT in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_training_utils
from official.utils.misc import distribution_utils
flags.DEFINE_string('input_files', None,
'File path to retrieve training data for pre-training.')
# Model training specific flags.
flags.DEFINE_integer(
'max_seq_length', 128,
'The maximum total input sequence length after WordPiece tokenization. '
'Sequences longer than this will be truncated, and sequences shorter '
'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
'Maximum number of masked LM predictions per sequence.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
'Total number of training steps to run per epoch.')
flags.DEFINE_integer('warmup_steps', 10000,
'Warmup steps for Adam weight decay optimizer.')
flags.DEFINE_bool('use_next_sentence_label', True,
'Whether to use next sentence label to compute final loss.')
flags.DEFINE_integer('train_summary_interval', 0, 'Step interval for training '
'summaries. If the value is a negative number, '
'then training summaries are not enabled.')
common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
def get_pretrain_dataset_fn(input_file_pattern, seq_length,
max_predictions_per_seq, global_batch_size,
use_next_sentence_label=True):
"""Returns input dataset from input file string."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
input_patterns = input_file_pattern.split(',')
batch_size = ctx.get_per_replica_batch_size(global_batch_size)
train_dataset = input_pipeline.create_pretrain_dataset(
input_patterns,
seq_length,
max_predictions_per_seq,
batch_size,
is_training=True,
input_pipeline_context=ctx,
use_next_sentence_label=use_next_sentence_label)
return train_dataset
return _dataset_fn
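# Usage sketch (hypothetical paths; assumes `strategy` is a
# tf.distribute.Strategy). The returned closure receives a
# tf.distribute.InputContext when handed to the strategy:
#   dataset_fn = get_pretrain_dataset_fn('/data/part-*.tfrecord', 128, 20, 32)
#   dist_dataset = strategy.experimental_distribute_datasets_from_function(
#       dataset_fn)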
def get_loss_fn():
"""Returns loss function for BERT pretraining."""
def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
return tf.reduce_mean(losses)
return _bert_pretrain_loss_fn
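# Example (hypothetical values): the pretraining model emits its per-example
# losses as outputs, so the loss function is just their mean:
#   loss_fn = get_loss_fn()
#   loss_fn(None, tf.constant([0.5, 1.5]))  # -> 1.0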
def run_customized_training(strategy,
bert_config,
init_checkpoint,
max_seq_length,
max_predictions_per_seq,
model_dir,
steps_per_epoch,
steps_per_loop,
epochs,
initial_lr,
warmup_steps,
end_lr,
optimizer_type,
input_files,
train_batch_size,
use_next_sentence_label=True,
train_summary_interval=0,
custom_callbacks=None):
"""Run BERT pretrain model training using low-level API."""
train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
max_predictions_per_seq,
train_batch_size,
use_next_sentence_label)
def _get_pretrain_model():
"""Gets a pretraining model."""
pretrain_model, core_model = bert_models.pretrain_model(
bert_config, max_seq_length, max_predictions_per_seq,
use_next_sentence_label=use_next_sentence_label)
optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps,
end_lr, optimizer_type)
pretrain_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
use_graph_rewrite=common_flags.use_graph_rewrite())
return pretrain_model, core_model
trained_model = model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_pretrain_model,
loss_fn=get_loss_fn(),
scale_loss=FLAGS.scale_loss,
model_dir=model_dir,
init_checkpoint=init_checkpoint,
train_input_fn=train_input_fn,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
epochs=epochs,
sub_model_export_name='pretrained/bert_model',
train_summary_interval=train_summary_interval,
custom_callbacks=custom_callbacks)
return trained_model
def run_bert_pretrain(strategy, custom_callbacks=None):
"""Runs BERT pre-training."""
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
if not strategy:
raise ValueError('Distribution strategy is not specified.')
# Runs customized training loop.
logging.info('Training using customized training loop TF 2.0 with '
'distributed strategy.')
performance.set_mixed_precision_policy(common_flags.dtype())
return run_customized_training(
strategy,
bert_config,
FLAGS.init_checkpoint, # Used to initialize only the BERT submodel.
FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq,
FLAGS.model_dir,
FLAGS.num_steps_per_epoch,
FLAGS.steps_per_loop,
FLAGS.num_train_epochs,
FLAGS.learning_rate,
FLAGS.warmup_steps,
FLAGS.end_lr,
FLAGS.optimizer_type,
FLAGS.input_files,
FLAGS.train_batch_size,
FLAGS.use_next_sentence_label,
FLAGS.train_summary_interval,
custom_callbacks=custom_callbacks)
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
if strategy:
print('***** Number of cores used: %d' % strategy.num_replicas_in_sync)
run_bert_pretrain(strategy)
if __name__ == '__main__':
app.run(main)
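# Example invocation (hypothetical paths; the script file name is assumed):
#   python3 run_pretraining.py \
#     --input_files=/data/pretrain/part-*.tfrecord \
#     --bert_config_file=/data/bert_config.json \
#     --model_dir=/tmp/bert20/ \
#     --train_batch_size=32 \
#     --num_train_epochs=2 \
#     --distribution_strategy=mirrored --num_gpus=4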
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import time
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_string('vocab_file', None,
'The vocabulary file that the BERT model was trained on.')
# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()
FLAGS = flags.FLAGS
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False,
init_checkpoint=None,
sub_model_export_name=None):
"""Run bert squad training."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly, init_checkpoint,
sub_model_export_name=sub_model_export_name)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
run_squad_helper.predict_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
def eval_squad(strategy, input_meta_data):
"""Evaluate on the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
eval_metrics = run_squad_helper.eval_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
return eval_metrics
def export_squad(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
ValueError: If the export path is not specified (empty string or None).
"""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if FLAGS.mode == 'export_only':
export_squad(FLAGS.model_export_path, input_meta_data)
return
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if 'train' in FLAGS.mode:
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
train_squad(
strategy,
input_meta_data,
custom_callbacks=custom_callbacks,
run_eagerly=FLAGS.run_eagerly,
sub_model_export_name=FLAGS.sub_model_export_name,
)
if 'predict' in FLAGS.mode:
predict_squad(strategy, input_meta_data)
if 'eval' in FLAGS.mode:
eval_metrics = eval_squad(strategy, input_meta_data)
f1_score = eval_metrics['final_f1']
logging.info('SQuAD eval F1-score: %f', f1_score)
summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
summary_writer = tf.summary.create_file_writer(summary_dir)
with summary_writer.as_default():
# TODO(lehou): write to the correct step number.
tf.summary.scalar('F1-score', f1_score, step=0)
summary_writer.flush()
# Also write eval_metrics to json file.
squad_lib_wp.write_to_json_files(
eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
time.sleep(60)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('model_dir')
app.run(main)
module rm compiler/rocm/2.9
module load compiler/rocm/3.9.1
module load mathlib/miopen/2.10_rocm3.9/gcc
HOME_PATH=/public/home/xuanbaby/
export PATH=$HOME_PATH/rocm3.9-python3.6.8-tf1.15/bin:$PATH
which python3
USER_HOME=/public/home/xuanbaby
WORK_HOME=`pwd`
BERT_DIR=${WORK_HOME}/pre_tf2x
SQUAD_VERSION=v1.1
OUTPUT_DIR=${WORK_HOME}/SQuAD_v1.1_tf_v1.2
#python3 ../data/create_finetuning_data.py \
# --squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
# --vocab_file=${BERT_DIR}/vocab.txt \
# --train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
# --meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
# --eval_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_eval.tf_record \
# --fine_tuning_task_type=squad --max_seq_length=384
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
#HSA_FORCE_FINE_GRAIN_PCIE=1 numactl --cpunodebind=4,5,6,7 --membind=4,5,6,7 python3 run_squad_xuan.py \
numactl --cpunodebind=0,1,2,3 --membind=0,1,2,3 python3 run_squad_xuan.py \
--vocab_file=${BERT_DIR}/vocab.txt \
--bert_config_file=${BERT_DIR}/bert_config.json \
--mode='train_and_eval' \
--input_meta_data_path=${WORK_HOME}/SQuAD_v1.1_tf/squad_v1.1_meta_data \
--train_data_path=${WORK_HOME}/SQuAD_v1.1_tf/squad_v1.1_train.tf_record \
--train_batch_size=128 \
--predict_batch_size=4 \
--learning_rate=2e-5 \
--log_steps=1 \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--num_gpus=4 \
--distribution_strategy=mirrored \
--model_dir=${WORK_HOME}/model_squad_v2 \
--run_eagerly=False \
--predict_file=${WORK_HOME}/SQuAD_v1.1_tf/dev-v1.1.json
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for running BERT family models on SQuAD 1.1/2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import os
from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.bert import model_training_utils
from official.nlp.bert import squad_evaluate_v1_1
from official.nlp.bert import squad_evaluate_v2_0
from official.nlp.data import squad_lib_sp
from official.utils.misc import keras_utils
def define_common_squad_flags():
"""Defines common flags used by SQuAD tasks."""
flags.DEFINE_enum(
'mode', 'train_and_eval',
['train_and_eval', 'train_and_predict',
'train', 'eval', 'predict', 'export_only'],
'One of {"train_and_eval", "train_and_predict", '
'"train", "eval", "predict", "export_only"}. '
'`train_and_eval`: train & predict to json files & compute eval metrics. '
'`train_and_predict`: train & predict to json files. '
'`train`: only trains the model. '
'`eval`: predict answers from squad json file & compute eval metrics. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.')
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
'SQuAD prediction json file path. '
'`predict` mode supports multiple files: one can use '
'wildcard to specify multiple files and it can also be '
'multiple file patterns separated by comma. Note that '
'`eval` mode only supports a single predict file.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
'models and False for cased models.')
flags.DEFINE_float(
'null_score_diff_threshold', 0.0,
'If null_score - best_non_null is greater than the threshold, '
'predict null. This is only used for SQuAD v2.')
flags.DEFINE_bool(
'verbose_logging', False,
'If true, all of the warnings related to data processing will be '
'printed. A number of warnings are expected for a normal SQuAD '
'evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
'Total batch size for prediction.')
flags.DEFINE_integer(
'n_best_size', 20,
'The total number of n-best predictions to generate in the '
'nbest_predictions.json output file.')
flags.DEFINE_integer(
'max_answer_length', 30,
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one '
'another.')
common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
return total_loss
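# Worked example (hypothetical shapes: batch of 2, sequence length 4):
#   start_positions = tf.constant([1, 2])
#   end_positions = tf.constant([2, 3])
#   start_logits = tf.random.normal([2, 4])
#   end_logits = tf.random.normal([2, 4])
#   loss = squad_loss_fn(start_positions, end_positions,
#                        start_logits, end_logits)
# `loss` is a scalar: the mean crossentropy of the start positions and the
# mean crossentropy of the end positions, averaged together.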
def get_loss_fn():
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
start_positions = labels['start_positions']
end_positions = labels['end_positions']
start_logits, end_logits = model_outputs
return squad_loss_fn(
start_positions,
end_positions,
start_logits,
end_logits)
return _loss_fn
RawResult = collections.namedtuple('RawResult',
['unique_id', 'start_logits', 'end_logits'])
def get_raw_results(predictions):
"""Converts multi-replica predictions to RawResult."""
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
predictions['start_logits'],
predictions['end_logits']):
for values in zip(unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy()):
yield RawResult(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist())
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset.."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_squad_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
input_meta_data):
"""Gets a squad model to make predictions."""
with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(
bert_config,
input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url)
if checkpoint_path is None:
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()
return squad_model
def predict_squad_customized(strategy,
input_meta_data,
predict_tfrecord_path,
num_steps,
squad_model):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
unique_ids = x.pop('unique_ids')
start_logits, end_logits = squad_model(x, training=False)
return dict(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.run(_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
all_results = []
for _ in range(num_steps):
predictions = predict_step(predict_iterator)
for result in get_raw_results(predictions):
all_results.append(result)
if len(all_results) % 100 == 0:
logging.info('Made predictions for %d records.', len(all_results))
return all_results
def train_squad(strategy,
input_meta_data,
bert_config,
custom_callbacks=None,
run_eagerly=False,
init_checkpoint=None,
sub_model_export_name=None):
"""Run bert squad training."""
if strategy:
logging.info('Training using customized training loop with distribution'
' strategy.')
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_session_config(FLAGS.enable_xla)
performance.set_mixed_precision_policy(common_flags.dtype())
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
def _get_squad_model():
"""Get Squad model and optimizer."""
squad_model, core_model = bert_models.squad_model(
bert_config,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable)
optimizer = optimization.create_optimizer(FLAGS.learning_rate,
steps_per_epoch * epochs,
warmup_steps,
FLAGS.end_lr,
FLAGS.optimizer_type)
squad_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
use_graph_rewrite=common_flags.use_graph_rewrite())
return squad_model, core_model
# If explicit_allreduce = True, apply_gradients() no longer implicitly
# allreduce gradients, users manually allreduce gradient and pass the
# allreduced grads_and_vars to apply_gradients(). clip_by_global_norm will be
# applied to allreduced gradients.
def clip_by_global_norm_callback(grads_and_vars):
grads, variables = zip(*grads_and_vars)
(clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
return zip(clipped_grads, variables)
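# Numeric sketch: tf.clip_by_global_norm rescales every gradient by
# clip_norm / global_norm when global_norm exceeds clip_norm, preserving
# their relative magnitudes. For example:
#   grads = [tf.constant([3.0]), tf.constant([4.0])]  # global norm = 5.0
#   clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
#   # clipped ~ [[0.6], [0.8]]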
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
loss_fn=get_loss_fn(),
model_dir=FLAGS.model_dir,
steps_per_epoch=steps_per_epoch,
steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
sub_model_export_name=sub_model_export_name,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks,
explicit_allreduce=False,
post_allreduce_callbacks=[clip_by_global_norm_callback])
def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
predict_file, squad_model):
"""Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
# Whether data should be in Ver 2.0 format.
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
eval_writer = squad_lib.FeatureWriter(
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
is_training=False)
eval_features = []
def _append_feature(feature, is_padding):
if not is_padding:
eval_features.append(feature)
eval_writer.process_feature(feature)
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
kwargs = dict(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=input_meta_data['max_seq_length'],
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=False,
output_fn=_append_feature,
batch_size=FLAGS.predict_batch_size)
# squad_lib_sp requires one more argument 'do_lower_case'.
if squad_lib == squad_lib_sp:
kwargs['do_lower_case'] = FLAGS.do_lower_case
dataset_size = squad_lib.convert_examples_to_features(**kwargs)
eval_writer.close()
logging.info('***** Running predictions *****')
logging.info(' Num orig examples = %d', len(eval_examples))
logging.info(' Num split examples = %d', len(eval_features))
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(
strategy, input_meta_data, eval_writer.filename, num_steps, squad_model)
all_predictions, all_nbest_json, scores_diff_json = (
squad_lib.postprocess_output(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
verbose=FLAGS.verbose_logging))
return all_predictions, all_nbest_json, scores_diff_json
def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib, version_2_with_negative, file_prefix=''):
"""Save output to json files."""
output_prediction_file = os.path.join(FLAGS.model_dir,
'%spredictions.json' % file_prefix)
output_nbest_file = os.path.join(FLAGS.model_dir,
'%snbest_predictions.json' % file_prefix)
output_null_log_odds_file = os.path.join(FLAGS.model_dir,
'%snull_odds.json' % file_prefix)
logging.info('Writing predictions to: %s', (output_prediction_file))
logging.info('Writing nbest to: %s', (output_nbest_file))
squad_lib.write_to_json_files(all_predictions, output_prediction_file)
squad_lib.write_to_json_files(all_nbest_json, output_nbest_file)
if version_2_with_negative:
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
def _get_matched_files(input_path):
"""Returns all files that matches the input_path."""
input_patterns = input_path.strip().split(',')
all_matched_files = []
for input_pattern in input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
matched_files = tf.io.gfile.glob(input_pattern)
if not matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
else:
all_matched_files.extend(matched_files)
return sorted(all_matched_files)
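# Example (hypothetical patterns): _get_matched_files('sq/xquad.*.json,dev.json')
# globs each comma-separated pattern with tf.io.gfile.glob and returns the
# sorted union of matches, raising ValueError for a pattern with no match.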
def predict_squad(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None):
"""Get prediction results and evaluate them to hard drive."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predict_files = _get_matched_files(FLAGS.predict_file)
squad_model = get_squad_model_to_predict(strategy, bert_config,
init_checkpoint, input_meta_data)
for idx, predict_file in enumerate(all_predict_files):
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, squad_lib, predict_file,
squad_model)
if len(all_predict_files) == 1:
file_prefix = ''
else:
# if predict_file is /path/xquad.ar.json, the `file_prefix` will be
# "xquad.ar-"
file_prefix = '%s-' % os.path.splitext(
os.path.basename(all_predict_files[idx]))[0]
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False),
file_prefix)
def eval_squad(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None):
"""Get prediction results and evaluate them against ground truth."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predict_files = _get_matched_files(FLAGS.predict_file)
if len(all_predict_files) != 1:
raise ValueError('`eval_squad` only supports one predict file, '
'but got %s' % all_predict_files)
squad_model = get_squad_model_to_predict(strategy, bert_config,
init_checkpoint, input_meta_data)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, squad_lib, all_predict_files[0],
squad_model)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
dataset_json = json.load(reader)
pred_dataset = dataset_json['data']
if input_meta_data.get('version_2_with_negative', False):
eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset,
all_predictions,
scores_diff_json)
else:
eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
return eval_metrics
def export_squad(model_export_path, input_meta_data, bert_config):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
bert_config: Bert configuration file to define core bert layers.
Raises:
ValueError: If the export path is not specified (empty string or None).
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(bert_config,
input_meta_data['max_seq_length'])
model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import time
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
import sys
sys.path.append("/public/home/xuanbaby/DL-TensorFlow/models_r2.3.0")
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_string('vocab_file', None,
'The vocabulary file that the BERT model was trained on.')
# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()
FLAGS = flags.FLAGS
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False,
init_checkpoint=None,
sub_model_export_name=None):
"""Run bert squad training."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly, init_checkpoint,
sub_model_export_name=sub_model_export_name)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
run_squad_helper.predict_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
def eval_squad(strategy, input_meta_data):
"""Evaluate on the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
eval_metrics = run_squad_helper.eval_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
return eval_metrics
def export_squad(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
ValueError: If the export path is not specified (empty string or None).
"""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
def main(_):
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
for device in physical_devices:
tf.config.experimental.set_memory_growth(device, True)
print('{} memory growth: {}'.format(device, tf.config.experimental.get_memory_growth(device)))
else:
print("Not enough GPU hardware devices available")
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if FLAGS.mode == 'export_only':
export_squad(FLAGS.model_export_path, input_meta_data)
return
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if 'train' in FLAGS.mode:
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
train_squad(
strategy,
input_meta_data,
custom_callbacks=custom_callbacks,
run_eagerly=FLAGS.run_eagerly,
sub_model_export_name=FLAGS.sub_model_export_name,
)
if 'predict' in FLAGS.mode:
predict_squad(strategy, input_meta_data)
if 'eval' in FLAGS.mode:
eval_metrics = eval_squad(strategy, input_meta_data)
f1_score = eval_metrics['final_f1']
logging.info('SQuAD eval F1-score: %f', f1_score)
summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
summary_writer = tf.summary.create_file_writer(summary_dir)
with summary_writer.as_default():
# TODO(lehou): write to the correct step number.
tf.summary.scalar('F1-score', f1_score, step=0)
summary_writer.flush()
# Also write eval_metrics to json file.
squad_lib_wp.write_to_json_files(
eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
time.sleep(60)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('model_dir')
app.run(main)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Examples of SavedModel export for tf-serving."""
from absl import app
from absl import flags
import tensorflow as tf
from official.nlp.bert import bert_models
from official.nlp.bert import configs
flags.DEFINE_integer("sequence_length", None,
"Sequence length to parse the tf.Example. If "
"sequence_length > 0, add a signature for serialized "
"tf.Example and define the parsing specification by the "
"sequence_length.")
flags.DEFINE_string("bert_config_file", None,
"Bert configuration file to define core bert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None,
"Destination folder to export the serving SavedModel.")
FLAGS = flags.FLAGS
class BertServing(tf.keras.Model):
"""Bert transformer encoder model for serving."""
def __init__(self, bert_config, name_to_features=None, name="serving_model"):
super(BertServing, self).__init__(name=name)
self.encoder = bert_models.get_transformer_encoder(
bert_config, sequence_length=None)
self.name_to_features = name_to_features
def call(self, inputs):
input_word_ids = inputs["input_ids"]
input_mask = inputs["input_mask"]
input_type_ids = inputs["segment_ids"]
encoder_outputs, _ = self.encoder(
[input_word_ids, input_mask, input_type_ids])
return encoder_outputs
def serve_body(self, input_ids, input_mask=None, segment_ids=None):
if segment_ids is None:
# Requires that the CLS token is the first token of the inputs.
segment_ids = tf.zeros_like(input_ids)
if input_mask is None:
# The mask has 1 for real tokens and 0 for padding tokens.
input_mask = tf.where(
tf.equal(input_ids, 0), tf.zeros_like(input_ids),
tf.ones_like(input_ids))
inputs = dict(
input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)
return self.call(inputs)
@tf.function
def serve(self, input_ids, input_mask=None, segment_ids=None):
outputs = self.serve_body(input_ids, input_mask, segment_ids)
# Returns a dictionary to control SignatureDef output signature.
return {"outputs": outputs[-1]}
@tf.function
def serve_examples(self, inputs):
features = tf.io.parse_example(inputs, self.name_to_features)
for key in list(features.keys()):
t = features[key]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
features[key] = t
return self.serve(
features["input_ids"],
input_mask=features["input_mask"] if "input_mask" in features else None,
segment_ids=features["segment_ids"]
if "segment_ids" in features else None)
@classmethod
def export(cls, model, export_dir):
if not isinstance(model, cls):
raise ValueError("Invalid model instance: %s, it should be a %s" %
(model, cls))
signatures = {
"serving_default":
model.serve.get_concrete_function(
input_ids=tf.TensorSpec(
shape=[None, None], dtype=tf.int32, name="inputs")),
}
if model.name_to_features:
signatures[
"serving_examples"] = model.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
tf.saved_model.save(model, export_dir=export_dir, signatures=signatures)
def main(_):
sequence_length = FLAGS.sequence_length
if sequence_length is not None and sequence_length > 0:
name_to_features = {
"input_ids": tf.io.FixedLenFeature([sequence_length], tf.int64),
"input_mask": tf.io.FixedLenFeature([sequence_length], tf.int64),
"segment_ids": tf.io.FixedLenFeature([sequence_length], tf.int64),
}
else:
name_to_features = None
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
serving_model = BertServing(
bert_config=bert_config, name_to_features=name_to_features)
checkpoint = tf.train.Checkpoint(model=serving_model.encoder)
checkpoint.restore(FLAGS.model_checkpoint_path
).assert_existing_objects_matched().run_restore_ops()
BertServing.export(serving_model, FLAGS.export_path)
if __name__ == "__main__":
flags.mark_flag_as_required("bert_config_file")
flags.mark_flag_as_required("model_checkpoint_path")
flags.mark_flag_as_required("export_path")
app.run(main)
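# Example invocation (hypothetical paths; the script file name is assumed):
#   python3 bert_serving_export.py \
#     --bert_config_file=/data/bert_config.json \
#     --model_checkpoint_path=/data/bert_model.ckpt \
#     --sequence_length=128 \
#     --export_path=/tmp/bert_serving
# The exported SavedModel exposes a `serving_default` signature that takes
# int32 token ids and, because sequence_length > 0 here, a `serving_examples`
# signature that takes serialized tf.Example protos.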
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation of SQuAD predictions (version 1.1).
The functions are copied from
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
The SQuAD dataset is described in this paper:
SQuAD: 100,000+ Questions for Machine Comprehension of Text
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import string
# pylint: disable=g-bad-import-order
from absl import logging
# pylint: enable=g-bad-import-order
def _normalize_answer(s):
"""Lowers text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
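# Example: _normalize_answer('The  Quick, Brown Fox!') -> 'quick brown fox'
# (lowercased, punctuation and the article "the" removed, whitespace
# collapsed).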
def _f1_score(prediction, ground_truth):
"""Computes F1 score by comparing prediction to ground truth."""
prediction_tokens = _normalize_answer(prediction).split()
ground_truth_tokens = _normalize_answer(ground_truth).split()
prediction_counter = collections.Counter(prediction_tokens)
ground_truth_counter = collections.Counter(ground_truth_tokens)
common = prediction_counter & ground_truth_counter
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
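# Worked example: prediction 'the cat sat' vs. ground truth 'cat sat down'
# normalizes to ['cat', 'sat'] and ['cat', 'sat', 'down']; num_same = 2, so
# precision = 2/2, recall = 2/3, and f1 = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8.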
def _exact_match_score(prediction, ground_truth):
"""Checks if predicted answer exactly matches ground truth answer."""
return _normalize_answer(prediction) == _normalize_answer(ground_truth)
def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""Computes the max over all metric scores."""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def evaluate(dataset, predictions):
"""Evaluates predictions for a dataset."""
f1 = exact_match = total = 0
for article in dataset:
for paragraph in article["paragraphs"]:
for qa in paragraph["qas"]:
total += 1
if qa["id"] not in predictions:
message = "Unanswered question " + qa["id"] + " will receive score 0."
logging.error(message)
continue
ground_truths = [entry["text"] for entry in qa["answers"]]
prediction = predictions[qa["id"]]
exact_match += _metric_max_over_ground_truths(_exact_match_score,
prediction, ground_truths)
f1 += _metric_max_over_ground_truths(_f1_score, prediction,
ground_truths)
exact_match = exact_match / total
f1 = f1 / total
return {"exact_match": exact_match, "final_f1": f1}
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation script for SQuAD version 2.0.
The functions are copied and modified from
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import string
from absl import logging
def _make_qid_to_has_ans(dataset):
qid_to_has_ans = {}
for article in dataset:
for p in article['paragraphs']:
for qa in p['qas']:
qid_to_has_ans[qa['id']] = bool(qa['answers'])
return qid_to_has_ans
def _normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
return re.sub(regex, ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _get_tokens(s):
if not s: return []
return _normalize_answer(s).split()
def _compute_exact(a_gold, a_pred):
return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))
def _compute_f1(a_gold, a_pred):
"""Compute F1-score."""
gold_toks = _get_tokens(a_gold)
pred_toks = _get_tokens(a_pred)
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
if not gold_toks or not pred_toks:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
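# Note the no-answer edge case above: for unanswerable questions both token
# lists are empty, so the function returns int([] == []) == 1 when prediction
# and gold agree on "no answer", and 0 when only one of them is empty.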
def _get_raw_scores(dataset, predictions):
"""Compute raw scores."""
exact_scores = {}
f1_scores = {}
for article in dataset:
for p in article['paragraphs']:
for qa in p['qas']:
qid = qa['id']
gold_answers = [a['text'] for a in qa['answers']
if _normalize_answer(a['text'])]
if not gold_answers:
# For unanswerable questions, the only correct answer is the empty string.
gold_answers = ['']
if qid not in predictions:
logging.error('Missing prediction for %s', qid)
continue
a_pred = predictions[qid]
# Take max over all gold answers
exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
return exact_scores, f1_scores
def _apply_no_ans_threshold(
scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
new_scores = {}
for qid, s in scores.items():
pred_na = na_probs[qid] > na_prob_thresh
if pred_na:
new_scores[qid] = float(not qid_to_has_ans[qid])
else:
new_scores[qid] = s
return new_scores
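# Example (hypothetical numbers): with na_prob_thresh=0.5, a question whose
# predicted no-answer probability is 0.9 is treated as answered "no answer":
# its score becomes 1.0 if the question truly has no answer and 0.0
# otherwise, replacing the raw span score.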
def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
"""Make evaluation result dictionary."""
if not qid_list:
total = len(exact_scores)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores.values()) / total),
('f1', 100.0 * sum(f1_scores.values()) / total),
('total', total),
])
else:
total = len(qid_list)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
('total', total),
])
def _merge_eval(main_eval, new_eval, prefix):
for k in new_eval:
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
"""Make evaluation dictionary containing average recision recall."""
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
true_pos = 0.0
cur_p = 1.0
cur_r = 0.0
precisions = [1.0]
recalls = [0.0]
avg_prec = 0.0
for i, qid in enumerate(qid_list):
if qid_to_has_ans[qid]:
true_pos += scores[qid]
cur_p = true_pos / float(i+1)
cur_r = true_pos / float(num_true_pos)
if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
# i.e., if we can put a threshold after this point
avg_prec += cur_p * (cur_r - recalls[-1])
precisions.append(cur_p)
recalls.append(cur_r)
return {'ap': 100.0 * avg_prec}
def _run_precision_recall_analysis(
main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
"""Run precision recall analysis and return result dictionary."""
num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
if num_true_pos == 0:
return
pr_exact = _make_precision_recall_eval(
exact_raw, na_probs, num_true_pos, qid_to_has_ans)
pr_f1 = _make_precision_recall_eval(
f1_raw, na_probs, num_true_pos, qid_to_has_ans)
oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
pr_oracle = _make_precision_recall_eval(
oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
_merge_eval(main_eval, pr_exact, 'pr_exact')
_merge_eval(main_eval, pr_f1, 'pr_f1')
_merge_eval(main_eval, pr_oracle, 'pr_oracle')
def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
"""Find the best threshold for no answer probability."""
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for qid in qid_list:
if qid not in scores: continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if predictions[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
return 100.0 * best_score / len(scores), best_thresh
def _find_all_best_thresh(
main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh = _find_best_thresh(
predictions, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh = _find_best_thresh(
predictions, f1_raw, na_probs, qid_to_has_ans)
main_eval['final_exact'] = best_exact
main_eval['final_exact_thresh'] = exact_thresh
main_eval['final_f1'] = best_f1
main_eval['final_f1_thresh'] = f1_thresh
def evaluate(dataset, predictions, na_probs=None):
"""Evaluate prediction results."""
new_orig_data = []
for article in dataset:
for p in article['paragraphs']:
for qa in p['qas']:
if qa['id'] in predictions:
new_para = {'qas': [qa]}
new_article = {'paragraphs': [new_para]}
new_orig_data.append(new_article)
dataset = new_orig_data
if na_probs is None:
na_probs = {k: 0.0 for k in predictions}
qid_to_has_ans = _make_qid_to_has_ans(dataset) # maps qid to True/False
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
out_eval = _make_eval_dict(exact_thresh, f1_thresh)
if has_ans_qids:
has_ans_eval = _make_eval_dict(
exact_thresh, f1_thresh, qid_list=has_ans_qids)
_merge_eval(out_eval, has_ans_eval, 'HasAns')
if no_ans_qids:
no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
_merge_eval(out_eval, no_ans_eval, 'NoAns')
_find_all_best_thresh(
out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
_run_precision_recall_analysis(
out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
return out_eval
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow.compat.v1 as tf # TF 1.x
# Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern.
BERT_NAME_REPLACEMENTS = (
("bert", "bert_model"),
("embeddings/word_embeddings", "word_embeddings/embeddings"),
("embeddings/token_type_embeddings",
"embedding_postprocessor/type_embeddings"),
("embeddings/position_embeddings",
"embedding_postprocessor/position_embeddings"),
("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
("attention/self", "self_attention"),
("attention/output/dense", "self_attention_output"),
("attention/output/LayerNorm", "self_attention_layer_norm"),
("intermediate/dense", "intermediate"),
("output/dense", "output"),
("output/LayerNorm", "output_layer_norm"),
("pooler/dense", "pooler_transform"),
)
BERT_V2_NAME_REPLACEMENTS = (
("bert/", ""),
("encoder", "transformer"),
("embeddings/word_embeddings", "word_embeddings/embeddings"),
("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
("embeddings/position_embeddings", "position_embedding/embeddings"),
("embeddings/LayerNorm", "embeddings/layer_norm"),
("attention/self", "self_attention"),
("attention/output/dense", "self_attention/attention_output"),
("attention/output/LayerNorm", "self_attention_layer_norm"),
("intermediate/dense", "intermediate"),
("output/dense", "output"),
("output/LayerNorm", "output_layer_norm"),
("pooler/dense", "pooler_transform"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
("cls/seq_relationship/output_weights",
"predictions/transform/logits/kernel"),
)
BERT_PERMUTATIONS = ()
BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)
def _bert_name_replacement(var_name, name_replacements):
"""Gets the variable name replacement."""
for src_pattern, tgt_pattern in name_replacements:
if src_pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(src_pattern, tgt_pattern)
tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
return var_name
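# Example: the replacements apply in order, so with BERT_NAME_REPLACEMENTS
# 'bert/embeddings/word_embeddings' first becomes
# 'bert_model/embeddings/word_embeddings' and then
# 'bert_model/word_embeddings/embeddings'.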
def _has_exclude_patterns(name, exclude_patterns):
"""Checks if a string contains substrings that match patterns to exclude."""
for p in exclude_patterns:
if p in name:
return True
return False
def _get_permutation(name, permutations):
"""Checks whether a variable requires transposition by pattern matching."""
for src_pattern, permutation in permutations:
if src_pattern in name:
tf.logging.info("Permuted: %s --> %s", name, permutation)
return permutation
return None
def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching."""
if "self_attention/attention_output/kernel" in name:
return tuple([num_heads, shape[0] // num_heads, shape[1]])
if "self_attention/attention_output/bias" in name:
return shape
patterns = [
"self_attention/query", "self_attention/value", "self_attention/key"
]
for pattern in patterns:
if pattern in name:
if "kernel" in name:
return tuple([shape[0], num_heads, shape[1] // num_heads])
if "bias" in name:
return tuple([num_heads, shape[0] // num_heads])
return None
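# Worked example (hidden size 768, num_heads=12, so head size 64):
#   _get_new_shape('.../self_attention/query/kernel', (768, 768), 12)
#     -> (768, 12, 64)
#   _get_new_shape('.../self_attention/query/bias', (768,), 12)
#     -> (12, 64)
#   _get_new_shape('.../self_attention/attention_output/kernel', (768, 768), 12)
#     -> (12, 64, 768)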
def create_v2_checkpoint(model, src_checkpoint, output_path):
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
# Uses streaming-restore in eager mode to read V1 name-based checkpoints.
model.load_weights(src_checkpoint).assert_existing_objects_matched()
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.save(output_path)
def convert(checkpoint_from_path,
checkpoint_to_path,
num_heads,
name_replacements,
permutations,
exclude_patterns=None):
"""Migrates the names of variables within a checkpoint.
Args:
checkpoint_from_path: Path to source checkpoint to be read in.
checkpoint_to_path: Path to checkpoint to be written out.
num_heads: The number of heads of the model.
name_replacements: A list of tuples of the form (match_str, replace_str)
describing variable names to adjust.
permutations: A list of tuples of the form (match_str, permutation)
describing permutations to apply to given variables. Note that match_str
should match the original variable name, not the replaced one.
exclude_patterns: A list of string patterns to exclude variables from
checkpoint conversion.
"""
with tf.Graph().as_default():
tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
reader = tf.train.NewCheckpointReader(checkpoint_from_path)
name_shape_map = reader.get_variable_to_shape_map()
new_variable_map = {}
conversion_map = {}
for var_name in name_shape_map:
if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
continue
# Get the original tensor data.
tensor = reader.get_tensor(var_name)
# Look up the new variable name, if any.
new_var_name = _bert_name_replacement(var_name, name_replacements)
# See if we need to reshape the underlying tensor.
new_shape = None
if num_heads > 0:
new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
if new_shape:
tf.logging.info("Veriable %s has a shape change from %s to %s",
var_name, tensor.shape, new_shape)
tensor = np.reshape(tensor, new_shape)
# See if we need to permute the underlying tensor.
permutation = _get_permutation(var_name, permutations)
if permutation:
tensor = np.transpose(tensor, permutation)
# Create a new variable with the possibly-reshaped or transposed tensor.
var = tf.Variable(tensor, name=var_name)
# Save the variable into the new variable map.
new_variable_map[new_var_name] = var
# Keep a record of converted variable names for sanity checking.
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
saver = tf.train.Saver(new_variable_map)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
saver.save(sess, checkpoint_to_path, write_meta_graph=False)
tf.logging.info("Summary:")
tf.logging.info(" Converted %d variable name(s).", len(new_variable_map))
tf.logging.info(" Converted: %s", str(conversion_map))
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.
The conversion will yield an object-oriented checkpoint that can be used
to restore a TransformerEncoder object.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.modeling import activations
from official.nlp.bert import configs
from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import networks
FLAGS = flags.FLAGS
flags.DEFINE_string("bert_config_file", None,
"Bert configuration file to define core bert layers.")
flags.DEFINE_string(
"checkpoint_to_convert", None,
"Initial checkpoint from a pretrained BERT model core (that is, only the "
"BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None,
"Name for the created object-based V2 checkpoint.")
def _create_bert_model(cfg):
"""Creates a BERT keras core model from BERT configuration.
Args:
cfg: A `BertConfig` to create the core model.
Returns:
A TransformerEncoder network.
"""
bert_encoder = networks.TransformerEncoder(
vocab_size=cfg.vocab_size,
hidden_size=cfg.hidden_size,
num_layers=cfg.num_hidden_layers,
num_attention_heads=cfg.num_attention_heads,
intermediate_size=cfg.intermediate_size,
activation=activations.gelu,
dropout_rate=cfg.hidden_dropout_prob,
attention_dropout_rate=cfg.attention_probs_dropout_prob,
sequence_length=cfg.max_position_embeddings,
type_vocab_size=cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range),
embedding_width=cfg.embedding_size)
return bert_encoder
def convert_checkpoint(bert_config, output_path, v1_checkpoint):
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
output_dir, _ = os.path.split(output_path)
# Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
tf1_checkpoint_converter_lib.convert(
checkpoint_from_path=v1_checkpoint,
checkpoint_to_path=temporary_checkpoint,
num_heads=bert_config.num_attention_heads,
name_replacements=tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS,
permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
exclude_patterns=["adam", "Adam"])
# Create a V2 checkpoint from the temporary checkpoint.
model = _create_bert_model(bert_config)
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
output_path)
# Clean up the temporary checkpoint, if it exists.
try:
tf.io.gfile.rmtree(temporary_checkpoint_dir)
except tf.errors.OpError:
# If it doesn't exist, we don't need to clean it up; continue.
pass
def main(_):
output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
convert_checkpoint(bert_config, output_path, v1_checkpoint)
if __name__ == "__main__":
app.run(main)
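# Example invocation (all paths are hypothetical placeholders):
#   python tf2_encoder_checkpoint_converter.py \
#     --bert_config_file=/tmp/uncased_L-12_H-768_A-12/bert_config.json \
#     --checkpoint_to_convert=/tmp/uncased_L-12_H-768_A-12/bert_model.ckpt \
#     --converted_checkpoint_path=/tmp/bert_v2/bert_model_v2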
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tokenization classes implementation.
The file is forked from:
https://github.com/google-research/bert/blob/master/tokenization.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
import tensorflow as tf
import sentencepiece as spm
SPIECE_UNDERLINE = "▁"
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." %
(actual_flag, init_checkpoint, model_name, case_name, opposite_flag))
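# A minimal sketch of the heuristic above: the model name is parsed from the
# checkpoint path, so a mismatched casing flag fails fast. The path below is a
# hypothetical example.
def _example_casing_check():
  # Consistent: an uncased model with do_lower_case=True passes silently.
  validate_case_matches_checkpoint(
      True, "uncased_L-12_H-768_A-12/bert_model.ckpt")
  # Inconsistent: the same model with do_lower_case=False raises ValueError.
  try:
    validate_case_matches_checkpoint(
        False, "uncased_L-12_H-768_A-12/bert_model.ckpt")
  except ValueError:
    pass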
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with tf.io.gfile.GFile(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
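# A minimal sketch tying the helpers above together: a vocab maps tokens to
# contiguous ids in insertion order, and convert_by_vocab does the lookups.
def _example_vocab_lookup():
  vocab = collections.OrderedDict([("[UNK]", 0), ("hello", 1), ("world", 2)])
  tokens = whitespace_tokenize("  hello   world ")  # -> ["hello", "world"]
  return convert_by_vocab(vocab, tokens)  # -> [1, 2]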
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True, split_on_punc=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case, split_on_punc=split_on_punc)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
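# A minimal usage sketch for FullTokenizer. The vocab path is a hypothetical
# placeholder; it must point to a WordPiece vocab file, one token per line.
def _example_full_tokenizer():
  tokenizer = FullTokenizer("/tmp/vocab.txt", do_lower_case=True)
  tokens = tokenizer.tokenize("UNwanted running")
  # Each token id comes from the vocab loaded above.
  return tokenizer.convert_tokens_to_ids(tokens)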
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True, split_on_punc=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
split_on_punc: Whether to apply split on punctuations. By default BERT
starts a new token for punctuations. This makes detokenization difficult
for tasks like seq2seq decoding.
"""
self.do_lower_case = do_lower_case
self.split_on_punc = split_on_punc
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because the English Wikipedia does contain
# some Chinese words).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
if self.split_on_punc:
split_tokens.extend(self._run_split_on_punc(token))
else:
split_tokens.append(token)
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=400):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
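# A minimal sketch of the greedy longest-match-first loop above: at each
# position the longest matching vocab entry wins, and every non-initial piece
# is prefixed with "##".
def _example_wordpiece_greedy():
  vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
  tokenizer = WordpieceTokenizer(vocab=vocab)
  return tokenizer.tokenize("unaffable")  # -> ["un", "##aff", "##able"]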
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
def preprocess_text(inputs, remove_space=True, lower=False):
"""Preprocesses data by removing extra space and normalize data.
This method is used together with sentence piece tokenizer and is forked from:
https://github.com/google-research/google-research/blob/master/albert/tokenization.py
Args:
inputs: The input text.
remove_space: Whether to remove the extra space.
lower: Whether to lowercase the text.
Returns:
The preprocessed text.
"""
outputs = inputs
if remove_space:
outputs = " ".join(inputs.strip().split())
if six.PY2 and isinstance(outputs, str):
try:
outputs = six.ensure_text(outputs, "utf-8")
except UnicodeDecodeError:
outputs = six.ensure_text(outputs, "latin-1")
outputs = unicodedata.normalize("NFKD", outputs)
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
if lower:
outputs = outputs.lower()
return outputs
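# A minimal sketch: preprocess_text collapses runs of whitespace and strips
# combining marks via NFKD, so accented input normalizes cleanly.
def _example_preprocess_text():
  return preprocess_text("  Caf\u00E9   au   lait ", lower=True)  # -> "cafe au lait"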
def encode_pieces(sp_model, text, sample=False):
"""Segements text into pieces.
This method is used together with sentence piece tokenizer and is forked from:
https://github.com/google-research/google-research/blob/master/albert/tokenization.py
Args:
sp_model: A spm.SentencePieceProcessor object.
text: The input text to be segmented.
sample: Whether to randomly sample a segmentation output or return a
deterministic one.
Returns:
A list of token pieces.
"""
if six.PY2 and isinstance(text, six.text_type):
text = six.ensure_binary(text, "utf-8")
if not sample:
pieces = sp_model.EncodeAsPieces(text)
else:
pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
new_pieces = []
for piece in pieces:
piece = printable_text(piece)
if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(
SPIECE_UNDERLINE, ""))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
if len(cur_pieces[0]) == 1:
cur_pieces = cur_pieces[1:]
else:
cur_pieces[0] = cur_pieces[0][1:]
cur_pieces.append(piece[-1])
new_pieces.extend(cur_pieces)
else:
new_pieces.append(piece)
return new_pieces
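# A minimal sketch of the sample flag above (the model path is a hypothetical
# placeholder): sampling yields a stochastic segmentation, in the spirit of
# subword regularization, while the default call is deterministic.
def _example_encode_pieces():
  sp_model = spm.SentencePieceProcessor()
  sp_model.Load("/tmp/sp.model")  # hypothetical path to a trained model
  deterministic = encode_pieces(sp_model, "Hello world")
  sampled = encode_pieces(sp_model, "Hello world", sample=True)
  return deterministic, sampled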
def encode_ids(sp_model, text, sample=False):
"""Segments text and return token ids.
This method is used together with sentence piece tokenizer and is forked from:
https://github.com/google-research/google-research/blob/master/albert/tokenization.py
Args:
sp_model: A spm.SentencePieceProcessor object.
text: The input text to be segmented.
sample: Whether to randomly sample a segmentation output or return a
deterministic one.
Returns:
A list of token ids.
"""
pieces = encode_pieces(sp_model, text, sample=sample)
ids = [sp_model.PieceToId(piece) for piece in pieces]
return ids
class FullSentencePieceTokenizer(object):
"""Runs end-to-end sentence piece tokenization.
The interface of this class is intended to be the same as the
`FullTokenizer` class above, for easier usage.
"""
def __init__(self, sp_model_file):
"""Inits FullSentencePieceTokenizer.
Args:
sp_model_file: The path to the sentence piece model file.
"""
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(sp_model_file)
self.vocab = {
self.sp_model.IdToPiece(i): i
for i in six.moves.range(self.sp_model.GetPieceSize())
}
def tokenize(self, text):
"""Tokenizes text into pieces."""
return encode_pieces(self.sp_model, text)
def convert_tokens_to_ids(self, tokens):
"""Converts a list of tokens to a list of ids."""
return [self.sp_model.PieceToId(printable_text(token)) for token in tokens]
def convert_ids_to_tokens(self, ids):
"""Converts a list of ids ot a list of tokens."""
return [self.sp_model.IdToPiece(id_) for id_ in ids]
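# A minimal usage sketch for FullSentencePieceTokenizer; the model path is a
# hypothetical placeholder for a trained SentencePiece model file.
def _example_sentencepiece_tokenizer():
  tokenizer = FullSentencePieceTokenizer("/tmp/sp.model")
  pieces = tokenizer.tokenize("Hello world")
  return tokenizer.convert_tokens_to_ids(pieces)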
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import six
import tensorflow as tf
from official.nlp.bert import tokenization
class TokenizationTest(tf.test.TestCase):
"""Tokenization test.
The implementation is forked from
https://github.com/google-research/bert/blob/master/tokenization_test.py.
"""
def test_full_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", ","
]
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
if six.PY2:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
else:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens
]).encode("utf-8"))
vocab_file = vocab_writer.name
tokenizer = tokenization.FullTokenizer(vocab_file)
os.unlink(vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertAllEqual(
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_chinese(self):
tokenizer = tokenization.BasicTokenizer()
self.assertAllEqual(
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
[u"ah", u"\u535A", u"\u63A8", u"zz"])
def test_basic_tokenizer_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
self.assertAllEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["hello", "!", "how", "are", "you", "?"])
self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
def test_basic_tokenizer_no_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
self.assertAllEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["HeLLo", "!", "how", "Are", "yoU", "?"])
def test_basic_tokenizer_no_split_on_punc(self):
tokenizer = tokenization.BasicTokenizer(
do_lower_case=True, split_on_punc=False)
self.assertAllEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["hello!how", "are", "you?"])
def test_wordpiece_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", "##!", "!"
]
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
self.assertAllEqual(tokenizer.tokenize(""), [])
self.assertAllEqual(
tokenizer.tokenize("unwanted running"),
["un", "##want", "##ed", "runn", "##ing"])
self.assertAllEqual(
tokenizer.tokenize("unwanted running !"),
["un", "##want", "##ed", "runn", "##ing", "!"])
self.assertAllEqual(
tokenizer.tokenize("unwanted running!"),
["un", "##want", "##ed", "runn", "##ing", "##!"])
self.assertAllEqual(
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
def test_convert_tokens_to_ids(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing"
]
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
self.assertAllEqual(
tokenization.convert_tokens_to_ids(
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
def test_is_whitespace(self):
self.assertTrue(tokenization._is_whitespace(u" "))
self.assertTrue(tokenization._is_whitespace(u"\t"))
self.assertTrue(tokenization._is_whitespace(u"\r"))
self.assertTrue(tokenization._is_whitespace(u"\n"))
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
self.assertFalse(tokenization._is_whitespace(u"A"))
self.assertFalse(tokenization._is_whitespace(u"-"))
def test_is_control(self):
self.assertTrue(tokenization._is_control(u"\u0005"))
self.assertFalse(tokenization._is_control(u"A"))
self.assertFalse(tokenization._is_control(u" "))
self.assertFalse(tokenization._is_control(u"\t"))
self.assertFalse(tokenization._is_control(u"\r"))
self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
def test_is_punctuation(self):
self.assertTrue(tokenization._is_punctuation(u"-"))
self.assertTrue(tokenization._is_punctuation(u"$"))
self.assertTrue(tokenization._is_punctuation(u"`"))
self.assertTrue(tokenization._is_punctuation(u"."))
self.assertFalse(tokenization._is_punctuation(u"A"))
self.assertFalse(tokenization._is_punctuation(u" "))
if __name__ == "__main__":
tf.test.main()