Unverified Commit 30d14a96 authored by saberkun, committed by GitHub

Merged commit includes the following changes: (#6847)

249377254  by hongkuny<hongkuny@google.com>:

    Internal change

249373328  by hongkuny<hongkuny@google.com>:

    Clean up tf import

--
249333938  by hongkuny<hongkuny@google.com>:

    Fix tf1 import

--
249325089  by hongkuny<hongkuny@google.com>:

    BERT 2.0

--
249173564  by hongkuny<hongkuny@google.com>:

    Internal change

PiperOrigin-RevId: 249377254
parent 1529b82c
# BERT in TensorFlow
Note: Please do not create pull requests. This model is still under development
and testing.
The academic paper which describes BERT in detail and provides full results on a
number of tasks can be found here: https://arxiv.org/abs/1810.04805.
This repository contains a TensorFlow 2.0 implementation of BERT.
Since the repository is under active development, the source of truth is the
TensorFlow team's internal repository. The repository is not officially
released, as it is not yet stable and requires extensive testing.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
from official.bert import modeling
def gather_indexes(sequence_tensor, positions):
"""Gathers the vectors at the specific positions.
Args:
sequence_tensor: Sequence output of `BertModel` layer of shape
(`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
hidden units of the `BertModel` layer.
positions: Position ids of the tokens in the sequence to mask for
pretraining, with dimension (batch_size, max_predictions_per_seq) where
`max_predictions_per_seq` is the maximum number of tokens to mask out and
predict per sequence.
Returns:
Masked out sequence tensor of shape (batch_size * max_predictions_per_seq,
num_hidden).
"""
sequence_shape = modeling.get_shape_list(
sequence_tensor, name='sequence_output_tensor')
batch_size = sequence_shape[0]
seq_length = sequence_shape[1]
width = sequence_shape[2]
flat_offsets = tf.keras.backend.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
flat_sequence_tensor = tf.keras.backend.reshape(
sequence_tensor, [batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor
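# A minimal sketch of how `gather_indexes` behaves, assuming an illustrative
# batch_size=2, seq_length=4, num_hidden=3 and hand-picked positions; the
# helper below is purely illustrative and is not used elsewhere.
def _example_gather_indexes():
  sequence_tensor = tf.reshape(
      tf.range(2 * 4 * 3, dtype=tf.float32), [2, 4, 3])
  positions = tf.constant([[0, 2], [1, 3]], dtype=tf.int32)
  # Rows are [seq0_tok0, seq0_tok2, seq1_tok1, seq1_tok3]; shape (4, 3).
  return gather_indexes(sequence_tensor, positions)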
class BertPretrainLayer(tf.keras.layers.Layer):
"""Wrapper layer for pre-training a BERT model.
This layer wraps an existing `bert_layer` which is a Keras Layer.
It outputs `sequence_output` from TransformerBlock sub-layer and
`sentence_output` which are suitable for feeding into a BertPretrainLoss
layer. This layer can be used along with an unsupervised input to
pre-train the embeddings for `bert_layer`.
"""
def __init__(self,
config,
bert_layer,
initializer=None,
float_type=tf.float32,
**kwargs):
super(BertPretrainLayer, self).__init__(**kwargs)
self.config = copy.deepcopy(config)
self.float_type = float_type
self.embedding_table = bert_layer.embedding_lookup.embeddings
self.num_next_sentence_label = 2
if initializer:
self.initializer = initializer
else:
self.initializer = tf.keras.initializers.TruncatedNormal(
stddev=self.config.initializer_range)
def build(self, unused_input_shapes):
self.lm_dense = tf.keras.layers.Dense(
self.config.hidden_size,
activation=modeling.get_activation(self.config.hidden_act),
kernel_initializer=self.initializer)
self.lm_bias = self.add_weight(
shape=[self.config.vocab_size],
name='lm_bias',
initializer=tf.keras.initializers.Zeros())
self.lm_layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12)
self.next_sentence_dense = tf.keras.layers.Dense(
self.num_next_sentence_label, kernel_initializer=self.initializer)
super(BertPretrainLayer, self).build(unused_input_shapes)
def __call__(self,
pooled_output,
sequence_output=None,
masked_lm_positions=None):
inputs = modeling.pack_inputs(
[pooled_output, sequence_output, masked_lm_positions])
return super(BertPretrainLayer, self).__call__(inputs)
def call(self, inputs):
unpacked_inputs = modeling.unpack_inputs(inputs)
pooled_output = unpacked_inputs[0]
sequence_output = unpacked_inputs[1]
masked_lm_positions = unpacked_inputs[2]
mask_lm_input_tensor = gather_indexes(
sequence_output, masked_lm_positions)
lm_output = self.lm_dense(mask_lm_input_tensor)
lm_output = self.lm_layer_norm(lm_output)
lm_output = tf.keras.backend.dot(
lm_output, tf.keras.backend.transpose(self.embedding_table))
lm_output = tf.keras.backend.bias_add(lm_output, self.lm_bias)
lm_output = tf.keras.backend.softmax(lm_output)
lm_output = tf.keras.backend.log(lm_output)
sentence_output = self.next_sentence_dense(pooled_output)
sentence_output = tf.keras.backend.softmax(sentence_output)
sentence_output = tf.keras.backend.log(sentence_output)
return (lm_output, sentence_output)
class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
"""Returns layer that computes custom loss and metrics for pretraining."""
def __init__(self, bert_config, **kwargs):
super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
self.config = copy.deepcopy(bert_config)
def __call__(self,
lm_output,
sentence_output=None,
lm_label_ids=None,
lm_label_weights=None,
sentence_labels=None):
inputs = modeling.pack_inputs([
lm_output, sentence_output, lm_label_ids, lm_label_weights,
sentence_labels
])
return super(BertPretrainLossAndMetricLayer, self).__call__(inputs)
def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
lm_per_example_loss, sentence_output, sentence_labels,
sentence_per_example_loss):
masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
lm_labels, lm_output)
masked_lm_accuracy = tf.reduce_mean(masked_lm_accuracy * lm_label_weights)
self.add_metric(
masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')
lm_example_loss = tf.reshape(lm_per_example_loss, [-1])
lm_example_loss = tf.reduce_mean(lm_example_loss * lm_label_weights)
self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
sentence_labels, sentence_output)
self.add_metric(
next_sentence_accuracy,
name='next_sentence_accuracy',
aggregation='mean')
next_sentence_mean_loss = tf.reduce_mean(sentence_per_example_loss)
self.add_metric(
next_sentence_mean_loss, name='next_sentence_loss', aggregation='mean')
def call(self, inputs):
unpacked_inputs = modeling.unpack_inputs(inputs)
lm_output = unpacked_inputs[0]
sentence_output = unpacked_inputs[1]
lm_label_ids = tf.keras.backend.cast(unpacked_inputs[2], tf.int32)
lm_label_ids = tf.keras.backend.reshape(lm_label_ids, [-1])
lm_label_ids_one_hot = tf.keras.backend.one_hot(lm_label_ids,
self.config.vocab_size)
lm_label_weights = tf.keras.backend.cast(unpacked_inputs[3], tf.float32)
lm_label_weights = tf.keras.backend.reshape(lm_label_weights, [-1])
lm_per_example_loss = -tf.keras.backend.sum(
lm_output * lm_label_ids_one_hot, axis=[-1])
numerator = tf.keras.backend.sum(lm_label_weights * lm_per_example_loss)
denominator = tf.keras.backend.sum(lm_label_weights) + 1e-5
mask_label_loss = numerator / denominator
sentence_labels = tf.keras.backend.cast(unpacked_inputs[4], dtype=tf.int32)
sentence_labels = tf.keras.backend.reshape(sentence_labels, [-1])
sentence_label_one_hot = tf.keras.backend.one_hot(sentence_labels, 2)
per_example_loss_sentence = -tf.keras.backend.sum(
sentence_label_one_hot * sentence_output, axis=-1)
sentence_loss = tf.keras.backend.mean(per_example_loss_sentence)
loss = mask_label_loss + sentence_loss
final_loss = tf.fill(
tf.keras.backend.shape(per_example_loss_sentence), loss)
self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
lm_per_example_loss, sentence_output, sentence_labels,
per_example_loss_sentence)
return final_loss
def pretrain_model(bert_config,
seq_length,
max_predictions_per_seq,
initializer=None):
"""Returns model to be used for pre-training.
Args:
bert_config: Configuration that defines the core BERT model.
seq_length: Maximum sequence length of the training data.
max_predictions_per_seq: Maximum number of tokens in sequence to mask out
and use for pretraining.
initializer: Initializer for weights in BertPretrainLayer.
Returns:
Pretraining model as well as core BERT submodel from which to save
weights after pretraining.
"""
input_word_ids = tf.keras.layers.Input(
shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
input_mask = tf.keras.layers.Input(
shape=(seq_length,), name='input_mask', dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(
shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
masked_lm_positions = tf.keras.layers.Input(
shape=(max_predictions_per_seq,),
name='masked_lm_positions',
dtype=tf.int32)
masked_lm_weights = tf.keras.layers.Input(
shape=(max_predictions_per_seq,),
name='masked_lm_weights',
dtype=tf.int32)
next_sentence_labels = tf.keras.layers.Input(
shape=(1,), name='next_sentence_labels', dtype=tf.int32)
masked_lm_ids = tf.keras.layers.Input(
shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
bert_submodel_name = 'bert_core_layer'
bert_submodel = modeling.get_bert_model(
input_word_ids,
input_mask,
input_type_ids,
name=bert_submodel_name,
config=bert_config)
pooled_output = bert_submodel.outputs[0]
sequence_output = bert_submodel.outputs[1]
pretrain_layer = BertPretrainLayer(
bert_config,
bert_submodel.get_layer(bert_submodel_name),
initializer=initializer)
lm_output, sentence_output = pretrain_layer(pooled_output, sequence_output,
masked_lm_positions)
pretrain_loss_layer = BertPretrainLossAndMetricLayer(bert_config)
output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
masked_lm_weights, next_sentence_labels)
return tf.keras.Model(
inputs={
'input_word_ids': input_word_ids,
'input_mask': input_mask,
'input_type_ids': input_type_ids,
'masked_lm_positions': masked_lm_positions,
'masked_lm_ids': masked_lm_ids,
'masked_lm_weights': masked_lm_weights,
'next_sentence_labels': next_sentence_labels,
},
outputs=output_loss), bert_submodel
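# Example usage (a sketch, assuming BERT-Base style hyperparameters; the vocab
# size, sequence length and prediction count below are illustrative):
def _example_build_pretrain_model():
  config = modeling.BertConfig(vocab_size=30522)
  model, core_model = pretrain_model(
      config, seq_length=128, max_predictions_per_seq=20)
  # `model` is trained on the pretraining objective; `core_model` holds the
  # BERT weights that are saved for downstream fine-tuning.
  return model, core_model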
class BertSquadLogitsLayer(tf.keras.layers.Layer):
"""Returns a layer that computes custom logits for BERT squad model."""
def __init__(self, initializer=None, float_type=tf.float32, **kwargs):
super(BertSquadLogitsLayer, self).__init__(**kwargs)
self.initializer = initializer
self.float_type = float_type
def build(self, unused_input_shapes):
self.final_dense = tf.keras.layers.Dense(
units=2, kernel_initializer=self.initializer, name='final_dense')
super(BertSquadLogitsLayer, self).build(unused_input_shapes)
def call(self, inputs):
sequence_output = inputs
input_shape = sequence_output.shape.as_list()
sequence_length = input_shape[1]
num_hidden_units = input_shape[2]
final_hidden_input = tf.keras.backend.reshape(sequence_output,
[-1, num_hidden_units])
logits = self.final_dense(final_hidden_input)
logits = tf.keras.backend.reshape(logits, [-1, sequence_length, 2])
logits = tf.transpose(logits, [2, 0, 1])
unstacked_logits = tf.unstack(logits, axis=0)
return unstacked_logits[0], unstacked_logits[1]
def squad_model(bert_config, max_seq_length, float_type, initializer=None):
"""Returns BERT Squad model along with core BERT model to import weights.
Args:
bert_config: BertConfig, the config that defines the core BERT model.
max_seq_length: integer, the maximum input sequence length.
float_type: tf.dtype, tf.float32 or tf.bfloat16.
initializer: Initializer for weights in BertSquadLogitsLayer.
Returns:
A tuple of (SQuAD model, core BERT model). The SQuAD model outputs the
start and end logits, each of shape [batch, sequence length].
"""
unique_ids = tf.keras.layers.Input(
shape=(1,), dtype=tf.int32, name='unique_ids')
input_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_ids')
input_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='segment_ids')
core_model = modeling.get_bert_model(
input_word_ids,
input_mask,
input_type_ids,
config=bert_config,
name='bert_model',
float_type=float_type)
# The SQuAD model only uses the sequence_output, which
# has dimensionality (batch_size, sequence_length, num_hidden).
sequence_output = core_model.outputs[1]
if initializer is None:
initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range)
squad_logits_layer = BertSquadLogitsLayer(
initializer=initializer, float_type=float_type, name='squad_logits')
start_logits, end_logits = squad_logits_layer(sequence_output)
squad = tf.keras.Model(
inputs={
'unique_ids': unique_ids,
'input_ids': input_word_ids,
'input_mask': input_mask,
'segment_ids': input_type_ids,
},
outputs=[unique_ids, start_logits, end_logits],
name='squad_model')
return squad, core_model
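# Example usage (a sketch; the config values and sequence length below are
# illustrative, not the settings used by the real SQuAD scripts):
def _example_build_squad_model():
  config = modeling.BertConfig(vocab_size=30522)
  squad, core_model = squad_model(
      config, max_seq_length=384, float_type=tf.float32)
  return squad, core_model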
def classifier_model(bert_config,
float_type,
num_labels,
max_seq_length,
final_layer_initializer=None):
"""BERT classifier model in functional API style.
Construct a Keras model for predicting `num_labels` outputs from an input with
maximum sequence length `max_seq_length`.
Args:
bert_config: BertConfig, the config defines the core BERT model.
float_type: dtype, tf.float32 or tf.bfloat16.
num_labels: integer, the number of classes.
max_seq_length: integer, the maximum input sequence length.
final_layer_initializer: Initializer for the final dense layer. Defaults to
a TruncatedNormal initializer.
Returns:
Combined prediction model (words, mask, type) -> (one-hot labels)
BERT sub-model (words, mask, type) -> (bert_outputs)
"""
input_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
bert_model = modeling.get_bert_model(
input_word_ids,
input_mask,
input_type_ids,
config=bert_config,
float_type=float_type)
pooled_output = bert_model.outputs[0]
if final_layer_initializer is not None:
initializer = final_layer_initializer
else:
initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range)
output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
pooled_output)
output = tf.keras.layers.Dense(
num_labels,
kernel_initializer=initializer,
name='output',
dtype=float_type)(
output)
return tf.keras.Model(
inputs={
'input_word_ids': input_word_ids,
'input_mask': input_mask,
'input_type_ids': input_type_ids
},
outputs=output), bert_model
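# Example usage (a sketch for a binary classification task such as MRPC; the
# values below are illustrative):
def _example_build_classifier_model():
  config = modeling.BertConfig(vocab_size=30522)
  classifier, core_model = classifier_model(
      config, tf.float32, num_labels=2, max_seq_length=128)
  return classifier, core_model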
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT library to process data for classification task."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import csv
import os
from absl import logging
import tensorflow as tf
from official.bert import tokenization
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Must only be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
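# A minimal sketch of a sentence-pair `InputExample` as used by MRPC-style
# tasks; the guid, texts and label below are illustrative.
_EXAMPLE_INPUT = InputExample(
    guid="train-1",
    text_a="He said the food was good.",
    text_b="The food was delicious, he said.",
    label="1")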
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
input_ids,
input_mask,
segment_ids,
label_id,
is_real_example=True):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.is_real_example = is_real_example
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@staticmethod
def get_processor_name():
"""Gets the string identifier of the processor."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with tf.io.gfile.GFile(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
def __init__(self):
self.language = "zh"
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % self.language))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "train-%d" % (i)
text_a = tokenization.convert_to_unicode(line[0])
text_b = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[2])
if label == tokenization.convert_to_unicode("contradictory"):
label = tokenization.convert_to_unicode("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % (i)
language = tokenization.convert_to_unicode(line[0])
if language != tokenization.convert_to_unicode(self.language):
continue
text_a = tokenization.convert_to_unicode(line[6])
text_b = tokenization.convert_to_unicode(line[7])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XNLI"
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
text_a = tokenization.convert_to_unicode(line[8])
text_b = tokenization.convert_to_unicode(line[9])
if set_type == "test":
label = "contradiction"
else:
label = tokenization.convert_to_unicode(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MRPC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[3])
text_b = tokenization.convert_to_unicode(line[4])
if set_type == "test":
label = "0"
else:
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "COLA"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
# Only the test set has a header
if set_type == "test" and i == 0:
continue
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
label = "0"
else:
text_a = tokenization.convert_to_unicode(line[3])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
tokens_a = tokenizer.tokenize(example.text_a)
tokens_b = None
if example.text_b:
tokens_b = tokenizer.tokenize(example.text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
if tokens_b:
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
label_id = label_map[example.label]
if ex_index < 5:
logging.info("*** Example ***")
logging.info("guid: %s" % (example.guid))
logging.info("tokens: %s" %
" ".join([tokenization.printable_text(x) for x in tokens]))
logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
logging.info("label: %s (id = %d)" % (example.label, label_id))
feature = InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id,
is_real_example=True)
return feature
def file_based_convert_examples_to_features(examples, label_list,
max_seq_length, tokenizer,
output_file):
"""Convert a set of `InputExample`s to a TFRecord file."""
writer = tf.io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples):
if ex_index % 10000 == 0:
logging.info("Writing example %d of %d" % (ex_index, len(examples)))
feature = convert_single_example(ex_index, example, label_list,
max_seq_length, tokenizer)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
features["label_ids"] = create_int_feature([feature.label_id])
features["is_real_example"] = create_int_feature(
[int(feature.is_real_example)])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
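# A small sketch of the truncation heuristic above: tokens are always popped
# from the longer list, so only `tokens_a` shrinks here (values illustrative).
def _example_truncate_seq_pair():
  tokens_a = ["the", "quick", "brown", "fox", "jumps"]
  tokens_b = ["a", "dog"]
  _truncate_seq_pair(tokens_a, tokens_b, max_length=5)
  # tokens_a is now ["the", "quick", "brown"]; tokens_b is unchanged.
  return tokens_a, tokens_b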
def generate_tf_record_from_data_file(processor,
data_dir,
vocab_file,
train_data_output_path=None,
eval_data_output_path=None,
max_seq_length=128,
do_lower_case=True):
"""Generates and saves training data into a tf record file.
Arguments:
processor: Input processor object to be used for generating data. Subclass
of `DataProcessor`.
data_dir: Directory that contains train/eval data to process. Data files
should be named "dev.tsv", "test.tsv", or "train.tsv".
vocab_file: Text file with words to be used for training/evaluation.
train_data_output_path: Output to which processed tf record for training
will be saved.
eval_data_output_path: Output to which processed tf record for evaluation
will be saved.
max_seq_length: Maximum sequence length of the training/eval data to be
generated.
do_lower_case: Whether to lower case input text.
Returns:
A dictionary containing input meta data.
"""
assert train_data_output_path or eval_data_output_path
label_list = processor.get_labels()
tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_file, do_lower_case=do_lower_case)
assert train_data_output_path
train_input_data_examples = processor.get_train_examples(data_dir)
file_based_convert_examples_to_features(train_input_data_examples, label_list,
max_seq_length, tokenizer,
train_data_output_path)
num_training_data = len(train_input_data_examples)
if eval_data_output_path:
eval_input_data_examples = processor.get_dev_examples(data_dir)
file_based_convert_examples_to_features(eval_input_data_examples,
label_list, max_seq_length,
tokenizer, eval_data_output_path)
meta_data = {
"task_type": "bert_classification",
"processor_type": processor.get_processor_name(),
"num_labels": len(processor.get_labels()),
"train_data_size": num_training_data,
"max_seq_length": max_seq_length,
}
if eval_data_output_path:
meta_data["eval_data_size"] = len(eval_input_data_examples)
return meta_data
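# Example usage (a sketch; the data directory, vocab file and output paths are
# placeholders):
def _example_generate_mrpc_records():
  processor = MrpcProcessor()
  return generate_tf_record_from_data_file(
      processor,
      data_dir="/tmp/glue/MRPC",
      vocab_file="/tmp/bert/vocab.txt",
      train_data_output_path="/tmp/mrpc_train.tf_record",
      eval_data_output_path="/tmp/mrpc_eval.tf_record",
      max_seq_length=128,
      do_lower_case=True)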
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT finetuning task dataset generator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from absl import app
from absl import flags
import tensorflow as tf
from official.bert import classifier_data_lib
from official.bert import squad_lib
FLAGS = flags.FLAGS
# BERT classification specific flags.
flags.DEFINE_enum(
"fine_tuning_task_type", "classification", ["classification", "squad"],
"The name of the BERT fine tuning task for which data "
"will be generated..")
flags.DEFINE_string(
"input_data_dir", None,
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task.")
flags.DEFINE_string("classification_task_name", None,
"The name of the task to train BERT classifier.")
# BERT Squad task specific flags.
flags.DEFINE_string(
"squad_data_file", None,
"The input data file in for generating training data for BERT squad task.")
flags.DEFINE_integer(
"doc_stride", 128,
"When splitting up a long document into chunks, how much stride to "
"take between chunks.")
flags.DEFINE_integer(
"max_query_length", 64,
"The maximum number of tokens for the question. Questions longer than "
"this will be truncated to this length.")
# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_string(
"train_data_output_path", None,
"The path in which generated training input data will be written as tf records."
)
flags.DEFINE_string(
"eval_data_output_path", None,
"The path in which generated training input data will be written as tf records."
)
flags.DEFINE_string("meta_data_file_path", None,
"The path in which input meta data will be written.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded.")
flags.DEFINE_bool(
"version_2_with_negative", False,
"If true, the SQuAD examples contain some that do not have an answer.")
def generate_classifier_dataset():
"""Generates classifier dataset and returns input meta data."""
assert FLAGS.input_data_dir and FLAGS.classification_task_name
processors = {
"cola": classifier_data_lib.ColaProcessor,
"mnli": classifier_data_lib.MnliProcessor,
"mrpc": classifier_data_lib.MrpcProcessor,
"xnli": classifier_data_lib.XnliProcessor,
}
task_name = FLAGS.classification_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
return classifier_data_lib.generate_tf_record_from_data_file(
processor,
FLAGS.input_data_dir,
FLAGS.vocab_file,
train_data_output_path=FLAGS.train_data_output_path,
eval_data_output_path=FLAGS.eval_data_output_path,
max_seq_length=FLAGS.max_seq_length,
do_lower_case=FLAGS.do_lower_case)
def generate_squad_dataset():
"""Generates squad training dataset and returns input meta data."""
assert FLAGS.squad_data_file
return squad_lib.generate_tf_record_from_json_file(
FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
FLAGS.doc_stride, FLAGS.version_2_with_negative)
def main(_):
if FLAGS.fine_tuning_task_type == "classification":
input_meta_data = generate_classifier_dataset()
else:
input_meta_data = generate_squad_dataset()
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
writer.write(json.dumps(input_meta_data, indent=4) + "\n")
if __name__ == "__main__":
flags.mark_flag_as_required("vocab_file")
flags.mark_flag_as_required("train_data_output_path")
flags.mark_flag_as_required("meta_data_file_path")
app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT model input pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def file_based_input_fn_builder(input_file, name_to_features):
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
def _decode_record(record, name_to_features):
"""Decodes a record to a TensorFlow example."""
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def input_fn():
"""Returns dataset for training/evaluation."""
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file)
d = d.map(lambda record: _decode_record(record, name_to_features))
# When `input_file` is a path to a single file or a list
# containing a single path, disable auto sharding so that
# same input file is sent to all workers.
if isinstance(input_file, str) or len(input_file) == 1:
options = tf.data.Options()
options.experimental_distribute.auto_shard = False
d = d.with_options(options)
return d
return input_fn
def create_pretrain_dataset(file_path,
seq_length,
max_predictions_per_seq,
batch_size,
is_training=True):
"""Creates input dataset from (tf)records files for pretraining."""
name_to_features = {
'input_ids':
tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask':
tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids':
tf.io.FixedLenFeature([seq_length], tf.int64),
'masked_lm_positions':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
'masked_lm_ids':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
'masked_lm_weights':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
'next_sentence_labels':
tf.io.FixedLenFeature([1], tf.int64),
}
input_fn = file_based_input_fn_builder(file_path, name_to_features)
dataset = input_fn()
def _select_data_from_record(record):
"""Filter out features to use for pretraining."""
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids'],
'masked_lm_positions': record['masked_lm_positions'],
'masked_lm_ids': record['masked_lm_ids'],
'masked_lm_weights': record['masked_lm_weights'],
'next_sentence_labels': record['next_sentence_labels'],
}
y = record['masked_lm_weights']
return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(1024)
return dataset
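# Example usage (a sketch; the record path, lengths and batch size below are
# placeholders):
def _example_pretrain_dataset():
  return create_pretrain_dataset(
      '/tmp/pretrain.tf_record',
      seq_length=128,
      max_predictions_per_seq=20,
      batch_size=32)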
def create_classifier_dataset(file_path,
seq_length,
batch_size,
is_training=True,
drop_remainder=True):
"""Creates input dataset from (tf)records files for train/eval."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64),
'is_real_example': tf.io.FixedLenFeature([], tf.int64),
}
input_fn = file_based_input_fn_builder(file_path, name_to_features)
dataset = input_fn()
def _select_data_from_record(record):
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids']
}
y = record['label_ids']
return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
dataset = dataset.prefetch(1024)
return dataset
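# Example usage (a sketch; the record path is a placeholder matching the
# output of the classifier data generation script):
def _example_classifier_dataset():
  return create_classifier_dataset(
      '/tmp/mrpc_train.tf_record', seq_length=128, batch_size=32)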
def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
"""Creates input dataset from (tf)records files for train/eval."""
name_to_features = {
'unique_ids': tf.io.FixedLenFeature([], tf.int64),
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
}
if is_training:
name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
input_fn = file_based_input_fn_builder(file_path, name_to_features)
dataset = input_fn()
def _select_data_from_record(record):
x, y = {}, {}
for name, tensor in record.items():
if name in ('start_positions', 'end_positions'):
y[name] = tensor
else:
x[name] = tensor
return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(1024)
return dataset
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to save models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
import tensorflow as tf
try:
import h5py as _ # pylint: disable=g-import-not-at-top
HAS_H5PY = True
except ImportError:
logging.warning('`h5py` is not installed. Please consider installing it '
'to save weights for long-running training.')
HAS_H5PY = False
def save_model(model, model_dir, weights_file):
"""Saves the model weights."""
weights_file_path = os.path.join(model_dir, weights_file)
del model_dir, weights_file  # Avoid accidental usage.
if not HAS_H5PY:
logging.warning('`h5py` is not installed. Skip saving model weights.')
return
logging.info('Saving weights and optimizer states into %s', weights_file_path)
logging.info('This might take a while...')
model.save(weights_file_path, overwrite=True, include_optimizer=True)
def export_bert_model(model_export_path,
model=None,
model_fn=None,
checkpoint_dir=None):
"""Export BERT model for serving.
Arguments:
model_export_path: Path to which exported model will be saved.
model: Keras model object to export. If none, new model is created via
`model_fn`.
model_fn: Function that returns a BERT model. Used when `model` is not
provided.
checkpoint_dir: Path from which model weights will be loaded.
"""
if model:
tf.keras.experimental.export_saved_model(model, model_export_path)
return
assert model_fn and checkpoint_dir
model_to_export = model_fn()
checkpoint = tf.train.Checkpoint(model=model_to_export)
latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()
tf.keras.experimental.export_saved_model(model_to_export, model_export_path)
class BertModelCheckpoint(tf.keras.callbacks.Callback):
"""Keras callback that saves model at the end of every epoch."""
def __init__(self, checkpoint_dir, checkpoint):
"""Initializes BertModelCheckpoint.
Arguments:
checkpoint_dir: Directory to which the checkpoint files will be saved.
checkpoint: tf.train.Checkpoint object.
"""
super(BertModelCheckpoint, self).__init__()
self.checkpoint_file_name = os.path.join(
checkpoint_dir, 'bert_training_checkpoint_step_{global_step}.ckpt')
assert isinstance(checkpoint, tf.train.Checkpoint)
self.checkpoint = checkpoint
def on_epoch_end(self, epoch, logs=None):
global_step = tf.keras.backend.get_value(self.model.optimizer.iterations)
formatted_file_name = self.checkpoint_file_name.format(
global_step=global_step)
saved_path = self.checkpoint.save(formatted_file_name)
logging.info('Saving model TF checkpoint to : %s', saved_path)
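# Example usage (a sketch; `model` is assumed to be a compiled Keras model and
# the checkpoint directory is a placeholder):
def _example_checkpoint_callback(model):
  checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
  return BertModelCheckpoint('/tmp/bert_checkpoints', checkpoint)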
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to train BERT models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
import tensorflow as tf
def get_primary_cpu_task(use_remote_tpu=False):
"""Returns primary CPU task to which input pipeline Ops are put."""
# Remote Eager Borg job configures the TPU worker with job name 'worker'.
return '/job:worker' if use_remote_tpu else ''
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
"""Saves model to with provided checkpoint prefix."""
checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
saved_path = checkpoint.save(checkpoint_path)
logging.info('Saving model as TF checkpoint: %s', saved_path)
return
def run_customized_training_loop(
# pylint: disable=invalid-name
_sentinel=None,
# pylint: enable=invalid-name
strategy=None,
model_fn=None,
loss_fn=None,
model_dir=None,
train_input_fn=None,
steps_per_epoch=None,
epochs=1,
eval_input_fn=None,
eval_steps=None,
metric_fn=None,
init_checkpoint=None,
use_remote_tpu=False):
"""Run BERT pretrain model training using low-level API.
Arguments:
_sentinel: Used to prevent positional parameters. Internal, do not use.
strategy: Distribution strategy on which to run low level training loop.
model_fn: Function that returns a tuple (model, sub_model). The caller of
this function should add an optimizer to the `model` by calling the
`model.compile()` API or by manually setting the `model.optimizer`
attribute. The second element of the returned tuple (sub_model) is an
optional sub-model into which the initial checkpoint is loaded, if
provided.
loss_fn: Function with signature func(labels, logits) that returns a loss
tensor.
model_dir: Model directory used during training for restoring/saving model
weights.
train_input_fn: Function that returns a tf.data.Dataset used for training.
steps_per_epoch: Number of steps to run per epoch.
epochs: Number of epochs to train.
eval_input_fn: Function that returns an evaluation dataset. If None,
evaluation is skipped.
eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
is not None.
metric_fn: A metrics function that returns a Keras Metric object to record
evaluation result using evaluation dataset or with training dataset
after every epoch.
init_checkpoint: Optional checkpoint to load to `sub_model` returned by
`model_fn`.
use_remote_tpu: If true, input pipeline ops are placed in TPU worker host
as an optimization.
Returns:
Trained model.
Raises:
ValueError: (1) When model returned by `model_fn` does not have optimizer
attribute or when required parameters are set to none. (2) eval args are
not specified correctly. (3) metric_fn must be a callable if specified.
"""
if _sentinel is not None:
raise ValueError('only call `run_customized_training_loop()` '
'with named arguments.')
required_arguments = [
strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
]
if [arg for arg in required_arguments if arg is None]:
raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
'`steps_per_epoch` and `train_input_fn` are required parameters.')
assert tf.executing_eagerly()
if eval_input_fn and (eval_steps is None or metric_fn is None):
raise ValueError(
'`eval_steps` and `metric_fn` are required when `eval_input_fn` '
'is not none.')
if metric_fn and not callable(metric_fn):
raise ValueError(
'if `metric_fn` is specified, metric_fn must be a callable.')
# To reduce unnecessary send/receive input pipeline operation, we place input
# pipeline ops in worker task.
with tf.device(get_primary_cpu_task(use_remote_tpu)):
train_iterator = strategy.make_dataset_iterator(train_input_fn())
with strategy.scope():
total_training_steps = steps_per_epoch * epochs
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model, sub_model = model_fn()
if not hasattr(model, 'optimizer'):
raise ValueError('User should set optimizer attribute to model '
'inside `model_fn`.')
optimizer = model.optimizer
if init_checkpoint:
sub_model.load_weights(init_checkpoint)
metric = metric_fn() if metric_fn else None
# If evaluation is required, make a copy of metric as it will be used by
# both train and evaluation.
train_metric = (
metric.__class__.from_config(metric.get_config())
if metric else None)
@tf.function
def train_step(iterator):
"""Performs a distributed training step."""
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
model_outputs = model(inputs)
loss = loss_fn(labels, model_outputs)
if train_metric:
train_metric.update_state(labels, model_outputs)
tvars = model.trainable_variables
grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(zip(grads, tvars))
return loss
per_replica_losses = strategy.experimental_run(_replicated_step,
iterator)
# For reporting, we return the mean of the per-replica losses.
loss = strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return loss
@tf.function
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
metric.update_state(labels, model_outputs)
strategy.experimental_run(_test_step_fn, iterator)
def _run_evaluation(current_training_step, test_iterator):
"""Runs validation steps and aggregate metrics."""
for _ in range(eval_steps):
test_step(test_iterator)
logging.info('Step: [%d] Validation metric = %f', current_training_step,
metric.result())
# Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file:
logging.info(
'Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched()
logging.info('Loading from checkpoint file completed')
current_step = optimizer.iterations.numpy()
checkpoint_name = 'ctl_step_{step}.ckpt'
while current_step < total_training_steps:
loss = train_step(train_iterator)
current_step += 1
if train_metric:
logging.info(
'Train Step: %d/%d / loss = %s / training metric = %s',
current_step, total_training_steps, loss.numpy(),
train_metric.result())
else:
logging.info('Train Step: %d/%d / loss = %s', current_step,
total_training_steps, loss.numpy())
# Saves model checkpoints and run validation steps at every epoch end.
if current_step % steps_per_epoch == 0:
# To avoid repeated model saving, we do not save after the last
# step of training.
if current_step < total_training_steps:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn:
logging.info('Running evaluation after step: %s.', current_step)
_run_evaluation(current_step,
strategy.make_dataset_iterator(eval_input_fn()))
# Re-initialize evaluation metric, except the last step.
if metric and current_step < total_training_steps:
metric.reset_states()
train_metric.reset_states()
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn:
logging.info('Running final evaluation after training is complete.')
_run_evaluation(current_step,
strategy.make_dataset_iterator(eval_input_fn()))
return model
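# Example usage (a sketch): minimal wiring of the custom loop for a binary
# classifier. It assumes the model builders above are importable as
# `official.bert.bert_models`; the optimizer, loss, paths and step counts are
# illustrative stand-ins, not the settings used by the real training scripts.
def _example_run_classifier_training(bert_config, strategy, train_input_fn):
  from official.bert import bert_models  # assumed module name

  def _model_fn():
    model, core_model = bert_models.classifier_model(
        bert_config, tf.float32, num_labels=2, max_seq_length=128)
    model.optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)
    return model, core_model

  def _loss_fn(labels, logits):
    return tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits, from_logits=True))

  return run_customized_training_loop(
      strategy=strategy,
      model_fn=_model_fn,
      loss_fn=_loss_fn,
      model_dir='/tmp/bert_classifier',
      train_input_fn=train_input_fn,
      steps_per_epoch=1000,
      epochs=3)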
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The main BERT model and related functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
import math
import six
import tensorflow as tf
class BertConfig(object):
"""Configuration for `BertModel`."""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
backward_compatible=True):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `input_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
backward_compatible: Boolean, whether the variable shapes are compatible
with checkpoints converted from TF 1.x BERT.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.backward_compatible = backward_compatible
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with tf.io.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
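# Example usage (a sketch; the field values below are illustrative):
def _example_bert_config():
  config = BertConfig(vocab_size=30522, hidden_size=768, num_hidden_layers=12)
  # Configs round-trip through plain dicts, so they can also be loaded from
  # JSON files shipped with the TF 1.x BERT release via `from_json_file`.
  return BertConfig.from_dict(config.to_dict())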
def get_bert_model(input_word_ids,
input_mask,
input_type_ids,
config=None,
name=None,
float_type=tf.float32):
"""Wraps the core BERT model as a keras.Model."""
bert_model_layer = BertModel(config=config, float_type=float_type, name=name)
pooled_output, sequence_output = bert_model_layer(input_word_ids, input_mask,
input_type_ids)
bert_model = tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[pooled_output, sequence_output])
return bert_model
class BertModel(tf.keras.layers.Layer):
"""BERT model ("Bidirectional Encoder Representations from Transformers").
Example usage:
```python
# Already been converted into WordPiece token ids
input_word_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
input_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
pooled_output, sequence_output = modeling.BertModel(config=config)(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
...
```
"""
def __init__(self, config, float_type=tf.float32, **kwargs):
super(BertModel, self).__init__(**kwargs)
self.config = (
BertConfig.from_dict(config)
if isinstance(config, dict) else copy.deepcopy(config))
self.float_type = float_type
def build(self, unused_input_shapes):
self.embedding_lookup = EmbeddingLookup(
vocab_size=self.config.vocab_size,
embedding_size=self.config.hidden_size,
initializer_range=self.config.initializer_range,
dtype=self.float_type,
name="word_embeddings")
self.embedding_postprocessor = EmbeddingPostprocessor(
use_type_embeddings=True,
token_type_vocab_size=self.config.type_vocab_size,
use_position_embeddings=True,
max_position_embeddings=self.config.max_position_embeddings,
dropout_prob=self.config.hidden_dropout_prob,
initializer_range=self.config.initializer_range,
name="embedding_postprocessor")
self.encoder = Transformer(
num_hidden_layers=self.config.num_hidden_layers,
hidden_size=self.config.hidden_size,
num_attention_heads=self.config.num_attention_heads,
intermediate_size=self.config.intermediate_size,
intermediate_activation=self.config.hidden_act,
hidden_dropout_prob=self.config.hidden_dropout_prob,
attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
initializer_range=self.config.initializer_range,
backward_compatible=self.config.backward_compatible,
name="encoder")
self.pooler_transform = tf.keras.layers.Dense(
units=self.config.hidden_size,
activation="tanh",
kernel_initializer=get_initializer(self.config.initializer_range),
name="pooler_transform")
super(BertModel, self).build(unused_input_shapes)
def __call__(self,
input_word_ids,
input_mask=None,
input_type_ids=None,
**kwargs):
inputs = pack_inputs([input_word_ids, input_mask, input_type_ids])
return super(BertModel, self).__call__(inputs, **kwargs)
def call(self, inputs):
unpacked_inputs = unpack_inputs(inputs)
input_word_ids = unpacked_inputs[0]
input_mask = unpacked_inputs[1]
input_type_ids = unpacked_inputs[2]
word_embeddings = self.embedding_lookup(input_word_ids)
embedding_tensor = self.embedding_postprocessor(
word_embeddings=word_embeddings, token_type_ids=input_type_ids)
attention_mask = None
if input_mask is not None:
attention_mask = create_attention_mask_from_input_mask(
input_word_ids, input_mask)
sequence_output = self.encoder(embedding_tensor, attention_mask)
first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)
pooled_output = self.pooler_transform(first_token_tensor)
return (pooled_output, sequence_output)
def get_config(self):
config = {"config": self.config.to_dict()}
base_config = super(BertModel, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class EmbeddingLookup(tf.keras.layers.Layer):
"""Looks up words embeddings for id tensor."""
def __init__(self,
vocab_size,
embedding_size=768,
initializer_range=0.02,
**kwargs):
super(EmbeddingLookup, self).__init__(**kwargs)
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.initializer_range = initializer_range
def build(self, unused_input_shapes):
self.embeddings = self.add_weight(
"embeddings",
shape=[self.vocab_size, self.embedding_size],
initializer=get_initializer(self.initializer_range),
dtype=self.dtype)
super(EmbeddingLookup, self).build(unused_input_shapes)
def call(self, inputs):
input_shape = get_shape_list(inputs)
flat_input = tf.reshape(inputs, [-1])
output = tf.gather(self.embeddings, flat_input)
output = tf.reshape(output, input_shape + [self.embedding_size])
return output
class EmbeddingPostprocessor(tf.keras.layers.Layer):
"""Performs various post-processing on a word embedding tensor."""
def __init__(self,
use_type_embeddings=False,
token_type_vocab_size=None,
use_position_embeddings=True,
max_position_embeddings=512,
dropout_prob=0.0,
initializer_range=0.02,
**kwargs):
super(EmbeddingPostprocessor, self).__init__(**kwargs)
self.use_type_embeddings = use_type_embeddings
self.token_type_vocab_size = token_type_vocab_size
self.use_position_embeddings = use_position_embeddings
self.max_position_embeddings = max_position_embeddings
self.dropout_prob = dropout_prob
self.initializer_range = initializer_range
if self.use_type_embeddings and not self.token_type_vocab_size:
raise ValueError("If `use_type_embeddings` is True, then "
"`token_type_vocab_size` must be specified.")
def build(self, input_shapes):
(word_embeddings_shape, _) = input_shapes
width = word_embeddings_shape.as_list()[-1]
self.type_embeddings = None
if self.use_type_embeddings:
self.type_embeddings = self.add_weight(
"type_embeddings",
shape=[self.token_type_vocab_size, width],
initializer=get_initializer(self.initializer_range),
dtype=self.dtype)
self.position_embeddings = None
if self.use_position_embeddings:
self.position_embeddings = self.add_weight(
"position_embeddings",
shape=[self.max_position_embeddings, width],
initializer=get_initializer(self.initializer_range),
dtype=self.dtype)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="layer_norm", axis=-1, epsilon=1e-12)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_prob)
super(EmbeddingPostprocessor, self).build(input_shapes)
def __call__(self, word_embeddings, token_type_ids=None, **kwargs):
inputs = pack_inputs([word_embeddings, token_type_ids])
return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs)
def call(self, inputs):
unpacked_inputs = unpack_inputs(inputs)
word_embeddings = unpacked_inputs[0]
token_type_ids = unpacked_inputs[1]
input_shape = get_shape_list(word_embeddings, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
width = input_shape[2]
output = word_embeddings
if self.use_type_embeddings:
flat_token_type_ids = tf.reshape(token_type_ids, [-1])
one_hot_ids = tf.one_hot(
flat_token_type_ids,
depth=self.token_type_vocab_size,
dtype=self.dtype)
token_type_embeddings = tf.matmul(one_hot_ids, self.type_embeddings)
token_type_embeddings = tf.reshape(token_type_embeddings,
[batch_size, seq_length, width])
output += token_type_embeddings
if self.use_position_embeddings:
position_embeddings = tf.expand_dims(
tf.slice(self.position_embeddings, [0, 0], [seq_length, width]),
axis=0)
output += position_embeddings
output = self.output_layer_norm(output)
output = self.output_dropout(output)
return output
class Attention(tf.keras.layers.Layer):
"""Performs multi-headed attention from `from_tensor` to `to_tensor`.
This is an implementation of multi-headed attention based on "Attention
is all you Need". If `from_tensor` and `to_tensor` are the same, then
this is self-attention. Each timestep in `from_tensor` attends to the
  corresponding sequence in `to_tensor`, and returns a fixed-width vector.
This function first projects `from_tensor` into a "query" tensor and
`to_tensor` into "key" and "value" tensors. These are (effectively) a list
of tensors of length `num_attention_heads`, where each tensor is of shape
[batch_size, seq_length, size_per_head].
Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single
tensor and returned.
  In practice, the multi-headed attention is done with tf.einsum as follows:
Input_tensor: [BFD]
Wq, Wk, Wv: [DNH]
Q:[BFNH] = einsum('BFD,DNH->BFNH', Input_tensor, Wq)
K:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wk)
V:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wv)
    attention_scores:[BNFT] = einsum('BFNH,BTNH->BNFT', Q, K) / sqrt(H)
attention_probs:[BNFT] = softmax(attention_scores)
context_layer:[BFNH] = einsum('BNFT,BTNH->BFNH', attention_probs, V)
Wout:[DNH]
    Output:[BFD] = einsum('BFNH,DNH->BFD', context_layer, Wout)
"""
def __init__(self,
num_attention_heads=12,
size_per_head=64,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
backward_compatible=False,
**kwargs):
super(Attention, self).__init__(**kwargs)
self.num_attention_heads = num_attention_heads
self.size_per_head = size_per_head
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.backward_compatible = backward_compatible
def build(self, unused_input_shapes):
self.query_dense = self._projection_dense_layer("query")
self.key_dense = self._projection_dense_layer("key")
self.value_dense = self._projection_dense_layer("value")
self.attention_probs_dropout = tf.keras.layers.Dropout(
rate=self.attention_probs_dropout_prob)
super(Attention, self).build(unused_input_shapes)
def reshape_to_matrix(self, input_tensor):
"""Reshape N > 2 rank tensor to rank 2 tensor for performance."""
ndims = input_tensor.shape.ndims
if ndims < 2:
raise ValueError("Input tensor must have at least rank 2."
"Shape = %s" % (input_tensor.shape))
if ndims == 2:
return input_tensor
width = input_tensor.shape[-1]
output_tensor = tf.reshape(input_tensor, [-1, width])
return output_tensor
def __call__(self, from_tensor, to_tensor, attention_mask=None, **kwargs):
inputs = pack_inputs([from_tensor, to_tensor, attention_mask])
return super(Attention, self).__call__(inputs, **kwargs)
def call(self, inputs):
(from_tensor, to_tensor, attention_mask) = unpack_inputs(inputs)
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
    # `query_tensor` = [B, F, N, H]
query_tensor = self.query_dense(from_tensor)
# `key_tensor` = [B, T, N, H]
key_tensor = self.key_dense(to_tensor)
# `value_tensor` = [B, T, N, H]
value_tensor = self.value_dense(to_tensor)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BFNH,BTNH->BNFT", query_tensor, key_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self.size_per_head)))
if attention_mask is not None:
# `attention_mask` = [B, 1, F, T]
attention_mask = tf.expand_dims(attention_mask, axis=[1])
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
adder = (1.0 - tf.cast(attention_mask, self.dtype)) * -10000.0
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_scores += adder
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T]
attention_probs = tf.nn.softmax(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.attention_probs_dropout(attention_probs)
# `context_layer` = [B, F, N, H]
context_tensor = tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
return context_tensor
def _projection_dense_layer(self, name):
return Dense3D(
num_attention_heads=self.num_attention_heads,
size_per_head=self.size_per_head,
kernel_initializer=get_initializer(self.initializer_range),
output_projection=False,
backward_compatible=self.backward_compatible,
name=name)
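# Editor's illustrative sketch (not part of the original module): a tiny shape
# check for the einsum attention described in the `Attention` docstring, using
# arbitrary sizes B=2, F=T=4, N=2, H=3.
def _example_attention_shapes():
  batch_size, seq_length, num_heads, size_per_head = 2, 4, 2, 3
  hidden_size = num_heads * size_per_head
  attention_layer = Attention(num_attention_heads=num_heads,
                              size_per_head=size_per_head)
  from_tensor = tf.random.uniform([batch_size, seq_length, hidden_size])
  attention_mask = tf.ones([batch_size, seq_length, seq_length])
  context = attention_layer(from_tensor, from_tensor,
                            attention_mask=attention_mask)
  # The context tensor is [B, F, N, H], matching the docstring's notation.
  assert context.shape.as_list() == [batch_size, seq_length, num_heads,
                                     size_per_head]
  return context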
class Dense3D(tf.keras.layers.Layer):
"""A Dense Layer using 3D kernel with tf.einsum implementation."""
def __init__(self,
num_attention_heads=12,
size_per_head=72,
kernel_initializer=None,
bias_initializer="zeros",
activation=None,
output_projection=False,
backward_compatible=False,
**kwargs):
super(Dense3D, self).__init__(**kwargs)
self.num_attention_heads = num_attention_heads
self.size_per_head = size_per_head
self.hidden_size = num_attention_heads * size_per_head
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
self.activation = activation
self.output_projection = output_projection
self.backward_compatible = backward_compatible
@property
def compatible_kernel_shape(self):
if self.output_projection:
return [self.hidden_size, self.hidden_size]
return [self.last_dim, self.hidden_size]
@property
def compatible_bias_shape(self):
return [self.hidden_size]
@property
def kernel_shape(self):
if self.output_projection:
return [self.num_attention_heads, self.size_per_head, self.hidden_size]
return [self.last_dim, self.num_attention_heads, self.size_per_head]
@property
def bias_shape(self):
if self.output_projection:
return [self.hidden_size]
return [self.num_attention_heads, self.size_per_head]
def build(self, input_shape):
dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx())
if not (dtype.is_floating or dtype.is_complex):
raise TypeError("Unable to build `Dense` layer with non-floating point "
"dtype %s" % (dtype,))
input_shape = tf.TensorShape(input_shape)
if tf.compat.dimension_value(input_shape[-1]) is None:
raise ValueError("The last dimension of the inputs to `Dense` "
"should be defined. Found `None`.")
self.last_dim = tf.compat.dimension_value(input_shape[-1])
self.input_spec = tf.keras.layers.InputSpec(
min_ndim=3, axes={-1: self.last_dim})
# Determines variable shapes.
if self.backward_compatible:
kernel_shape = self.compatible_kernel_shape
bias_shape = self.compatible_bias_shape
else:
kernel_shape = self.kernel_shape
bias_shape = self.bias_shape
self.kernel = self.add_weight(
"kernel",
shape=kernel_shape,
initializer=self.kernel_initializer,
dtype=self.dtype,
trainable=True)
self.bias = self.add_weight(
"bias",
shape=bias_shape,
initializer=self.bias_initializer,
dtype=self.dtype,
trainable=True)
super(Dense3D, self).build(input_shape)
def call(self, inputs):
"""Implements ``call()`` for Dense3D.
Args:
inputs: A float tensor of shape [batch_size, sequence_length, hidden_size]
when output_projection is False, otherwise a float tensor of shape
[batch_size, sequence_length, num_heads, dim_per_head].
Returns:
The projected tensor with shape [batch_size, sequence_length, num_heads,
dim_per_head] when output_projection is False, otherwise [batch_size,
sequence_length, hidden_size].
"""
if self.backward_compatible:
kernel = tf.keras.backend.reshape(self.kernel, self.kernel_shape)
bias = tf.keras.backend.reshape(self.bias, self.bias_shape)
else:
kernel = self.kernel
bias = self.bias
if self.output_projection:
ret = tf.einsum("abcd,cde->abe", inputs, kernel)
else:
ret = tf.einsum("abc,cde->abde", inputs, kernel)
ret += bias
if self.activation is not None:
return self.activation(ret)
return ret
class Dense2DProjection(tf.keras.layers.Layer):
"""A 2D projection layer with tf.einsum implementation."""
def __init__(self,
output_size,
kernel_initializer=None,
bias_initializer="zeros",
activation=None,
**kwargs):
super(Dense2DProjection, self).__init__(**kwargs)
self.output_size = output_size
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
self.activation = activation
def build(self, input_shape):
dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx())
if not (dtype.is_floating or dtype.is_complex):
raise TypeError("Unable to build `Dense` layer with non-floating point "
"dtype %s" % (dtype,))
input_shape = tf.TensorShape(input_shape)
if tf.compat.dimension_value(input_shape[-1]) is None:
raise ValueError("The last dimension of the inputs to `Dense` "
"should be defined. Found `None`.")
last_dim = tf.compat.dimension_value(input_shape[-1])
self.input_spec = tf.keras.layers.InputSpec(min_ndim=3, axes={-1: last_dim})
self.kernel = self.add_weight(
"kernel",
shape=[last_dim, self.output_size],
initializer=self.kernel_initializer,
dtype=self.dtype,
trainable=True)
self.bias = self.add_weight(
"bias",
shape=[self.output_size],
initializer=self.bias_initializer,
dtype=self.dtype,
trainable=True)
super(Dense2DProjection, self).build(input_shape)
def call(self, inputs):
"""Implements call() for Dense2DProjection.
Args:
inputs: float Tensor of shape [batch, from_seq_length,
num_attention_heads, size_per_head].
Returns:
A 3D Tensor.
"""
ret = tf.einsum("abc,cd->abd", inputs, self.kernel)
ret += self.bias
if self.activation is not None:
return self.activation(ret)
return ret
class TransformerBlock(tf.keras.layers.Layer):
"""Single transformer layer.
It has two sub-layers. The first is a multi-head self-attention mechanism, and
the second is a positionwise fully connected feed-forward network.
"""
def __init__(self,
hidden_size=768,
num_attention_heads=12,
intermediate_size=3072,
intermediate_activation="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
backward_compatible=False,
**kwargs):
super(TransformerBlock, self).__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = get_activation(intermediate_activation)
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.backward_compatible = backward_compatible
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (self.hidden_size, self.num_attention_heads))
self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
def build(self, unused_input_shapes):
self.attention_layer = Attention(
num_attention_heads=self.num_attention_heads,
size_per_head=self.attention_head_size,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
initializer_range=self.initializer_range,
backward_compatible=self.backward_compatible,
name="self_attention")
self.attention_output_dense = Dense3D(
num_attention_heads=self.num_attention_heads,
size_per_head=int(self.hidden_size / self.num_attention_heads),
kernel_initializer=get_initializer(self.initializer_range),
output_projection=True,
backward_compatible=self.backward_compatible,
name="self_attention_output")
self.attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
self.attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
self.intermediate_dense = Dense2DProjection(
output_size=self.intermediate_size,
kernel_initializer=get_initializer(self.initializer_range),
activation=self.intermediate_activation,
name="intermediate")
self.output_dense = Dense2DProjection(
output_size=self.hidden_size,
kernel_initializer=get_initializer(self.initializer_range),
name="output")
self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12)
super(TransformerBlock, self).build(unused_input_shapes)
def __call__(self, input_tensor, attention_mask=None):
inputs = pack_inputs([input_tensor, attention_mask])
return super(TransformerBlock, self).__call__(inputs)
def call(self, inputs):
(input_tensor, attention_mask) = unpack_inputs(inputs)
attention_output = self.attention_layer(
from_tensor=input_tensor,
to_tensor=input_tensor,
attention_mask=attention_mask)
attention_output = self.attention_output_dense(attention_output)
attention_output = self.attention_dropout(attention_output)
attention_output = self.attention_layer_norm(input_tensor +
attention_output)
intermediate_output = self.intermediate_dense(attention_output)
layer_output = self.output_dense(intermediate_output)
layer_output = self.output_dropout(layer_output)
layer_output = self.output_layer_norm(layer_output + attention_output)
return layer_output
class Transformer(tf.keras.layers.Layer):
"""Multi-headed, multi-layer Transformer from "Attention is All You Need".
This is almost an exact implementation of the original Transformer encoder.
See the original paper:
https://arxiv.org/abs/1706.03762
Also see:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
"""
def __init__(self,
num_hidden_layers=12,
hidden_size=768,
num_attention_heads=12,
intermediate_size=3072,
intermediate_activation="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
backward_compatible=False,
**kwargs):
super(Transformer, self).__init__(**kwargs)
self.num_hidden_layers = num_hidden_layers
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = get_activation(intermediate_activation)
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.backward_compatible = backward_compatible
def build(self, unused_input_shapes):
self.layers = []
for i in range(self.num_hidden_layers):
self.layers.append(
TransformerBlock(
hidden_size=self.hidden_size,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
intermediate_activation=self.intermediate_activation,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
initializer_range=self.initializer_range,
backward_compatible=self.backward_compatible,
name=("layer_%d" % i)))
super(Transformer, self).build(unused_input_shapes)
# Workaround for Keras bug where layers aren't tracked properly.
for i in range(len(self.layers)):
self.__setattr__("layer%d" % i, self.layers[i])
def __call__(self, input_tensor, attention_mask=None):
inputs = pack_inputs([input_tensor, attention_mask])
return super(Transformer, self).__call__(inputs=inputs)
def call(self, inputs):
unpacked_inputs = unpack_inputs(inputs)
input_tensor = unpacked_inputs[0]
attention_mask = unpacked_inputs[1]
output_tensor = input_tensor
for layer in self.layers:
output_tensor = layer(output_tensor, attention_mask)
return output_tensor
def pack_inputs(inputs):
"""Pack a list of `inputs` tensors to a tuple.
Args:
inputs: a list of tensors.
Returns:
    A tuple of tensors. If any input is None, it is replaced with a special
    constant tensor.
"""
inputs = tf.nest.flatten(inputs)
outputs = []
for x in inputs:
if x is None:
outputs.append(tf.constant(0, shape=[], dtype=tf.int32))
else:
outputs.append(x)
return tuple(outputs)
def unpack_inputs(inputs):
"""unpack a tuple of `inputs` tensors to a tuple.
Args:
inputs: a list of tensors.
Returns:
    A tuple of tensors. If any input is a special constant tensor, it is
    replaced with None.
"""
inputs = tf.nest.flatten(inputs)
outputs = []
for x in inputs:
if is_special_none_tensor(x):
outputs.append(None)
else:
outputs.append(x)
x = tuple(outputs)
  # Trick to avoid triggering the 'unbalanced-tuple-unpacking' pylint check
  # at call sites.
if len(x) == 1:
return x[0]
return tuple(outputs)
def is_special_none_tensor(tensor):
  """Checks if a tensor is the special placeholder for None."""
  return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
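# Editor's illustrative sketch (not part of the original module): how
# `pack_inputs`/`unpack_inputs` shuttle optional tensors through a Keras
# layer's single `inputs` argument.
def _example_pack_unpack_roundtrip():
  word_ids = tf.constant([[1, 2, 3]])
  packed = pack_inputs([word_ids, None])
  # The None slot becomes a scalar int32 placeholder tensor inside the tuple.
  assert is_special_none_tensor(packed[1])
  unpacked = unpack_inputs(packed)
  # Unpacking maps the placeholder back to None.
  assert unpacked[1] is None
  return unpacked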
def gelu(x):
"""Gaussian Error Linear Unit.
  This is a smoother version of the ReLU.
Original paper: https://arxiv.org/abs/1606.08415
Args:
x: float Tensor to perform activation.
Returns:
`x` with the GELU activation applied.
"""
cdf = 0.5 * (1.0 + tf.tanh(
(math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
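# Editor's illustrative sketch (not part of the original module): the tanh
# approximation above closely tracks the exact erf-based GELU,
# 0.5 * x * (1 + erf(x / sqrt(2))), when run eagerly under TF 2.x.
def _example_gelu_vs_exact():
  x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])
  exact = 0.5 * x * (1.0 + tf.math.erf(x / tf.sqrt(2.0)))
  approx = gelu(x)
  # The two forms agree to within about 1e-3 on this range.
  assert float(tf.reduce_max(tf.abs(exact - approx))) < 1e-3
  return approx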
def get_activation(identifier):
"""Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
Args:
identifier: String name of the activation function.
Returns:
A Python function corresponding to the activation function. If
`identifier` is None, empty, or "linear", this will return None.
If `identifier` is not a string, it will return `identifier`.
Raises:
ValueError: The `identifier` does not correspond to a known
activation.
"""
if identifier is None:
return None
elif isinstance(identifier, six.string_types):
name_to_fn = {
"linear": None,
"relu": tf.nn.relu,
"gelu": gelu,
"tanh": tf.nn.tanh,
}
identifier = str(identifier).lower()
if identifier not in name_to_fn:
raise ValueError("Unsupported activation function: %s" % (identifier))
return name_to_fn[identifier]
elif callable(identifier):
return identifier
else:
raise ValueError("Could not interpret activation "
"function identifier: %s" % (identifier))
def get_initializer(initializer_range=0.02):
"""Creates a `tf.initializers.truncated_normal` with the given range.
Args:
initializer_range: float, initializer range for stddev.
Returns:
TruncatedNormal initializer with stddev = `initializer_range`.
"""
return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
def get_shape_list(tensor, expected_rank=None, name=None):
"""Returns a list of the shape of tensor, preferring static dimensions.
Args:
tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, an exception will be
      thrown.
name: Optional name of the tensor for the error message.
Returns:
A list of dimensions of the shape of tensor. All static dimensions will
be returned as python integers, and dynamic dimensions will be returned
as tf.Tensor scalars.
"""
if expected_rank is not None:
assert_rank(tensor, expected_rank, name)
shape = tensor.shape.as_list()
non_static_indexes = []
for (index, dim) in enumerate(shape):
if dim is None:
non_static_indexes.append(index)
if not non_static_indexes:
return shape
dyn_shape = tf.shape(tensor)
for index in non_static_indexes:
shape[index] = dyn_shape[index]
return shape
def assert_rank(tensor, expected_rank, name=None):
"""Raises an exception if the tensor rank is not of the expected rank.
Args:
tensor: A tf.Tensor to check the rank of.
expected_rank: Python integer or list of integers, expected rank.
name: Optional name of the tensor for the error message.
Raises:
ValueError: If the expected shape doesn't match the actual shape.
"""
expected_rank_dict = {}
if isinstance(expected_rank, six.integer_types):
expected_rank_dict[expected_rank] = True
else:
for x in expected_rank:
expected_rank_dict[x] = True
actual_rank = tensor.shape.ndims
if actual_rank not in expected_rank_dict:
raise ValueError(
"For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not "
"equal to the expected tensor rank `%s`" %
(name, actual_rank, str(tensor.shape), str(expected_rank)))
def create_attention_mask_from_input_mask(from_tensor, to_mask):
"""Create 3D attention mask from a 2D tensor mask.
Args:
from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
to_mask: int32 Tensor of shape [batch_size, to_seq_length].
Returns:
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
"""
from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_shape = get_shape_list(to_mask, expected_rank=2)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
dtype=from_tensor.dtype)
  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding
  # tokens), so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones = tf.ones(
shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)
# Here we broadcast along two dimensions to create the mask.
mask = broadcast_ones * to_mask
return mask
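# Editor's illustrative sketch (not part of the original module): the helper
# above broadcasts a [batch, to_seq] padding mask into a
# [batch, from_seq, to_seq] attention mask.
def _example_attention_mask():
  input_word_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
  mask = create_attention_mask_from_input_mask(input_word_ids, input_mask)
  # Every query position in the second example is blocked from attending to
  # the padded third token, so the last column of that example is all zeros.
  assert mask.shape.as_list() == [2, 3, 3]
  return mask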
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import tensorflow as tf
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applys a warmup schedule on a given learning rate decay schedule."""
def __init__(
self,
initial_learning_rate,
decay_schedule_fn,
warmup_steps,
power=1.0,
name=None):
super(WarmUp, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
self.power = power
self.decay_schedule_fn = decay_schedule_fn
self.name = name
def __call__(self, step):
with tf.name_scope(self.name or 'WarmUp') as name:
# Implements polynomial warmup. i.e., if global_step < warmup_steps, the
# learning rate will be `global_step/num_warmup_steps * init_lr`.
global_step_float = tf.cast(step, tf.float32)
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
warmup_percent_done = global_step_float / warmup_steps_float
warmup_learning_rate = (
self.initial_learning_rate *
tf.math.pow(warmup_percent_done, self.power))
return tf.cond(global_step_float < warmup_steps_float,
lambda: warmup_learning_rate,
lambda: self.decay_schedule_fn(step),
name=name)
def get_config(self):
return {
'initial_learning_rate': self.initial_learning_rate,
'decay_schedule_fn': self.decay_schedule_fn,
'warmup_steps': self.warmup_steps,
'power': self.power,
'name': self.name
}
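# Editor's illustrative sketch (not part of the original module): with a linear
# decay schedule wrapped in `WarmUp`, the learning rate ramps from 0 to
# `init_lr` over `warmup_steps` and then follows the wrapped decay. The step
# counts and rates below are arbitrary examples.
def _example_warmup_schedule():
  init_lr, total_steps, warmup_steps = 1e-4, 1000, 100
  decay_fn = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=total_steps,
      end_learning_rate=0.0)
  schedule = WarmUp(initial_learning_rate=init_lr,
                    decay_schedule_fn=decay_fn,
                    warmup_steps=warmup_steps)
  # Halfway through warmup the rate is init_lr / 2; afterwards the wrapped
  # polynomial (here linear) decay takes over.
  return schedule(50), schedule(500)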
def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
"""Creates an optimizer with learning rate schedule."""
# Implements linear decay of the learning rate.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=init_lr,
decay_steps=num_train_steps,
end_learning_rate=0.0)
if num_warmup_steps:
learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate_fn,
warmup_steps=num_warmup_steps)
optimizer = AdamWeightDecay(
learning_rate=learning_rate_fn,
weight_decay_rate=0.01,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
exclude_from_weight_decay=['layer_norm', 'bias'])
return optimizer
class AdamWeightDecay(tf.keras.optimizers.Adam):
"""Adam enables L2 weight decay and clip_by_global_norm on gradients.
Just adding the square of the weights to the loss function is *not* the
correct way of using L2 regularization/weight decay with Adam, since that will
interact with the m and v parameters in strange ways.
  Instead we want to decay the weights in a manner that doesn't interact with
the m/v parameters. This is equivalent to adding the square of the weights to
the loss with plain (non-momentum) SGD.
"""
def __init__(self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
amsgrad=False,
weight_decay_rate=0.0,
exclude_from_weight_decay=None,
name='AdamWeightDecay',
**kwargs):
super(AdamWeightDecay, self).__init__(
learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
self._set_hyper('weight_decay_rate', weight_decay_rate)
self._exclude_from_weight_decay = exclude_from_weight_decay
@classmethod
def from_config(cls, config):
"""Creates an optimizer from its config with WarmUp custom object."""
custom_objects = {'WarmUp': WarmUp}
return super(AdamWeightDecay, cls).from_config(
config, custom_objects=custom_objects)
def _decay_weights_op(self, var, learning_rate):
do_decay = self._do_use_weight_decay(var.name)
if do_decay:
return var.assign_sub(
learning_rate * var *
self._get_hyper('weight_decay_rate'),
use_locking=self._use_locking)
return tf.no_op()
def apply_gradients(self, grads_and_vars, name=None):
grads, tvars = list(zip(*grads_and_vars))
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
def _resource_apply_dense(self, grad, var):
var_dtype = var.dtype.base_dtype
lr_t = self._decayed_lr(var_dtype)
with tf.control_dependencies([self._decay_weights_op(var, lr_t)]):
return super(AdamWeightDecay, self)._resource_apply_dense(
grad, var)
def _resource_apply_sparse(self, grad, var, indices):
var_dtype = var.dtype.base_dtype
lr_t = self._decayed_lr(var_dtype)
with tf.control_dependencies([self._decay_weights_op(var, lr_t)]):
return super(AdamWeightDecay, self)._resource_apply_sparse(
grad, var, indices)
def get_config(self):
config = super(AdamWeightDecay, self).get_config()
config.update({
'weight_decay_rate':
self._serialize_hyperparameter('weight_decay_rate'),
})
return config
def _do_use_weight_decay(self, param_name):
"""Whether to use L2 weight decay for `param_name`."""
if self.weight_decay_rate == 0:
return False
if self._exclude_from_weight_decay:
for r in self._exclude_from_weight_decay:
if re.search(r, param_name) is not None:
return False
return True
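# Editor's illustrative sketch (not part of the original module): how the
# `exclude_from_weight_decay` patterns passed by `create_optimizer` interact
# with typical BERT variable names. The variable names below are hypothetical
# examples, not names this module defines.
def _example_weight_decay_exclusion():
  exclude_patterns = ['layer_norm', 'bias']
  variable_names = [
      'bert_model/encoder/layer_0/self_attention/query/kernel:0',
      'bert_model/encoder/layer_0/self_attention_layer_norm/gamma:0',
      'bert_model/pooler_transform/bias:0',
  ]
  # Decay applies only to names matching none of the exclusion patterns,
  # mirroring `AdamWeightDecay._do_use_weight_decay`.
  decayed = [name for name in variable_names
             if not any(re.search(p, name) for p in exclude_patterns)]
  assert decayed == [variable_names[0]]
  return decayed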
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT classification finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import math
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.bert import bert_models
from official.bert import input_pipeline
from official.bert import model_saving_utils
from official.bert import model_training_utils
from official.bert import modeling
from official.bert import optimization
flags.DEFINE_enum(
'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
'One of {"train_and_eval", "export_only"}. `train_and_eval`: '
'trains the model and evaluates in the meantime. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('bert_config_file', None,
'Bert configuration file to define core bert layers.')
flags.DEFINE_string(
'model_dir', None,
('The directory where the model weights and training/evaluation summaries '
'are stored. If not specified, save to /tmp/bert20/.'))
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
flags.DEFINE_string('train_data_path', None,
'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
'Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
'init_checkpoint', None,
'Initial checkpoint (usually from a pre-trained BERT model).')
flags.DEFINE_string(
'model_export_path', None,
    'Path to the directory where the trained model will be '
    'exported.')
flags.DEFINE_enum(
'strategy_type',
'mirror',
['tpu', 'mirror'],
'Distribution Strategy type to use for training. `tpu` uses '
'TPUStrategy for running on TPUs, `mirror` uses GPUs with '
'single host.')
# Model training specific flags.
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 8, 'Batch size for evaluation.')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer('steps_per_run', 200,
'Number of steps running on TPU devices.')
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
FLAGS = flags.FLAGS
def write_eval_results(output_dir, results):
"""Writes and prints evaluation results.
Args:
output_dir: string, the path to the output directory.
results: a dictionary of evaluation metrics.
"""
output_eval_file = os.path.join(output_dir, 'eval_results.txt')
with tf.io.gfile.GFile(output_eval_file, 'w') as writer:
logging.info('***** Eval results *****')
for key, val in results.items():
logging.info(' %s = %s', key, str(val))
writer.write('%s = %s\n' % (key, str(val)))
def get_loss_fn(num_classes, loss_scale=1.0):
"""Gets the classification loss function."""
def classification_loss_fn(labels, logits):
"""Classification loss."""
labels = tf.squeeze(labels)
log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot_labels = tf.one_hot(
tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
per_example_loss = -tf.reduce_sum(
tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
loss = tf.reduce_mean(per_example_loss)
loss *= loss_scale
return loss
return classification_loss_fn
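# Editor's illustrative sketch (not part of the original script): the loss
# above on a toy batch of two examples and three classes; it is the mean
# per-example cross entropy computed from the log-softmax of the logits.
def _example_classification_loss():
  loss_fn = get_loss_fn(num_classes=3)
  labels = tf.constant([[0], [2]])
  logits = tf.constant([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
  return loss_fn(labels, logits)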
def run_customized_training(strategy,
bert_config,
input_meta_data,
model_dir,
epochs,
steps_per_epoch,
eval_steps,
warmup_steps,
initial_lr,
init_checkpoint,
use_remote_tpu=False):
"""Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels']
train_input_fn = functools.partial(
input_pipeline.create_classifier_dataset,
FLAGS.train_data_path,
seq_length=max_seq_length,
batch_size=FLAGS.train_batch_size)
eval_input_fn = functools.partial(
input_pipeline.create_classifier_dataset,
FLAGS.eval_data_path,
seq_length=max_seq_length,
batch_size=FLAGS.eval_batch_size,
is_training=False,
drop_remainder=False)
def _get_classifier_model():
classifier_model, core_model = (
bert_models.classifier_model(bert_config, tf.float32, num_classes,
max_seq_length))
classifier_model.optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps)
return classifier_model, core_model
loss_fn = get_loss_fn(num_classes, loss_scale=1.0)
# Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope.
def metric_fn():
return tf.keras.metrics.SparseCategoricalAccuracy(
'test_accuracy', dtype=tf.float32)
return model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_classifier_model,
loss_fn=loss_fn,
model_dir=model_dir,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
eval_steps=eval_steps,
init_checkpoint=init_checkpoint,
metric_fn=metric_fn,
use_remote_tpu=use_remote_tpu)
def export_classifier(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
    ValueError: Export path is not specified, got an empty string or None.
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
def _model_fn():
return bert_models.classifier_model(bert_config, tf.float32,
input_meta_data['num_labels'],
input_meta_data['max_seq_length'])[0]
model_saving_utils.export_bert_model(
model_export_path, model_fn=_model_fn, checkpoint_dir=FLAGS.model_dir)
def run_bert(strategy, input_meta_data):
"""Run BERT training."""
if FLAGS.mode == 'export_only':
export_classifier(FLAGS.model_export_path, input_meta_data)
return
if FLAGS.mode != 'train_and_eval':
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
epochs = FLAGS.num_train_epochs
train_data_size = input_meta_data['train_data_size']
steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
eval_steps = int(
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
if not strategy:
raise ValueError('Distribution strategy has not been specified.')
# Runs customized training loop.
  logging.info('Training using customized training loop TF 2.0 with '
               'distributed strategy.')
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
trained_model = run_customized_training(
strategy,
bert_config,
input_meta_data,
FLAGS.model_dir,
epochs,
steps_per_epoch,
eval_steps,
warmup_steps,
FLAGS.learning_rate,
FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu)
if FLAGS.model_export_path:
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = None
if FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == 'tpu':
logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else '')
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=FLAGS.tpu)
tf.config.experimental_connect_to_host(cluster_resolver.master()) # pylint: disable=line-too-long
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.experimental.TPUStrategy(
cluster_resolver, steps_per_run=FLAGS.steps_per_run)
run_bert(strategy, input_meta_data)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('input_meta_data_path')
app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run masked LM/next sentence masked_lm pre-training for BERT in tf2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.bert import bert_models
from official.bert import input_pipeline
from official.bert import model_training_utils
from official.bert import modeling
from official.bert import optimization
flags.DEFINE_string('input_files', None,
'File path to retrieve training data for pre-training.')
flags.DEFINE_string('bert_config_file', None,
'Bert configuration file to define core bert layers.')
flags.DEFINE_string(
'model_dir', None,
('The directory where the model weights and training/evaluation summaries '
'are stored. If not specified, save to /tmp/bert20/.'))
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
flags.DEFINE_enum(
'strategy_type',
'mirror',
['tpu', 'mirror'],
'Distribution Strategy type to use for training. `tpu` uses '
'TPUStrategy for running on TPUs, `mirror` uses GPUs with '
'single host.')
# Model training specific flags.
flags.DEFINE_integer(
'max_seq_length', 128,
'The maximum total input sequence length after WordPiece tokenization. '
'Sequences longer than this will be truncated, and sequences shorter '
'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
                     'Maximum number of masked LM predictions per sequence.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer(
'steps_per_run', 1000,
'Number of steps to run in TPU worker before returning to host.')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
'Total number of training steps to run per epoch.')
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
flags.DEFINE_float('warmup_steps', 10000,
'Warmup steps for Adam weight decay optimizer.')
FLAGS = flags.FLAGS
def get_pretrain_input_data(input_file_pattern, seq_length,
max_predictions_per_seq, batch_size):
"""Returns input dataset from input file string."""
input_files = []
for input_pattern in input_file_pattern.split(','):
input_files.extend(tf.io.gfile.glob(input_pattern))
train_dataset = input_pipeline.create_pretrain_dataset(
input_files, seq_length, max_predictions_per_seq, batch_size)
return train_dataset
def get_loss_fn(loss_scale=1.0):
"""Returns loss function for BERT pretraining."""
def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
return tf.keras.backend.mean(losses) * loss_scale
return _bert_pretrain_loss_fn
def run_customized_training(strategy,
bert_config,
max_seq_length,
max_predictions_per_seq,
model_dir,
steps_per_epoch,
epochs,
initial_lr,
warmup_steps,
input_files,
train_batch_size,
use_remote_tpu=False):
"""Run BERT pretrain model training using low-level API."""
train_input_fn = functools.partial(get_pretrain_input_data, input_files,
max_seq_length, max_predictions_per_seq,
train_batch_size)
def _get_pretrain_model():
pretrain_model, core_model = bert_models.pretrain_model(
bert_config, max_seq_length, max_predictions_per_seq)
pretrain_model.optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps)
return pretrain_model, core_model
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_pretrain_model,
loss_fn=get_loss_fn(),
model_dir=model_dir,
train_input_fn=train_input_fn,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
use_remote_tpu=use_remote_tpu)
def run_bert_pretrain(strategy):
"""Runs BERT pre-training."""
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
if not strategy:
raise ValueError('Distribution strategy is not specified.')
# Runs customized training loop.
  logging.info('Training using customized training loop TF 2.0 with '
               'distributed strategy.')
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
return run_customized_training(
strategy,
bert_config,
FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq,
FLAGS.model_dir,
FLAGS.num_steps_per_epoch,
FLAGS.num_train_epochs,
FLAGS.learning_rate,
FLAGS.warmup_steps,
FLAGS.input_files,
FLAGS.train_batch_size,
use_remote_tpu=use_remote_tpu)
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = None
if FLAGS.strategy_type == 'tpu':
logging.info('Use TPU at %s',
FLAGS.tpu if FLAGS.tpu is not None else '')
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=FLAGS.tpu)
tf.config.experimental_connect_to_host(cluster_resolver.master()) # pylint: disable=line-too-long
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    strategy = tf.distribute.experimental.TPUStrategy(
cluster_resolver, steps_per_run=FLAGS.steps_per_run)
elif FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
if strategy:
print('***** Number of cores used : ', strategy.num_replicas_in_sync)
run_bert_pretrain(strategy)
if __name__ == '__main__':
app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in tf2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.bert import bert_models
from official.bert import input_pipeline
from official.bert import model_training_utils
from official.bert import modeling
from official.bert import optimization
from official.bert import squad_lib
from official.bert import tokenization
flags.DEFINE_bool('do_train', False, 'Whether to run training.')
flags.DEFINE_bool('do_predict', False, 'Whether to run eval on the dev set.')
flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.')
flags.DEFINE_string('bert_config_file', None,
'Bert configuration file to define core bert layers.')
flags.DEFINE_string(
'model_dir', None,
('The directory where the model weights and training/evaluation summaries '
'are stored.'))
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
flags.DEFINE_string(
'init_checkpoint', None,
'Initial checkpoint (usually from a pre-trained BERT model).')
flags.DEFINE_enum(
'strategy_type',
'mirror',
    ['tpu', 'mirror', 'multi_worker_mirror'],
    'Distribution Strategy type to use for training. `tpu` uses '
    'TPUStrategy for running on TPUs, `mirror` uses GPUs with a '
    'single host, and `multi_worker_mirror` uses multiple hosts.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer('steps_per_run', 200,
'Number of steps running on TPU devices.')
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
                    'SQuAD prediction json file path, e.g., dev-v1.1.json.')
flags.DEFINE_string('vocab_file', None,
'The vocabulary file that the BERT model was trained on.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
'models and False for cased models.')
flags.DEFINE_bool(
'verbose_logging', False,
'If true, all of the warnings related to data processing will be printed. '
'A number of warnings are expected for a normal SQuAD evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
'Total batch size for prediction.')
flags.DEFINE_integer(
'n_best_size', 20,
'The total number of n-best predictions to generate in the '
'nbest_predictions.json output file.')
flags.DEFINE_integer(
'max_answer_length', 30,
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one another.')
FLAGS = flags.FLAGS
def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits,
loss_scale=1.0):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.backend.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
end_loss = tf.keras.backend.sparse_categorical_crossentropy(
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
total_loss *= loss_scale
return total_loss
def get_loss_fn(loss_scale=1.0):
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
start_positions = labels['start_positions']
end_positions = labels['end_positions']
_, start_logits, end_logits = model_outputs
return squad_loss_fn(
start_positions,
end_positions,
start_logits,
end_logits,
loss_scale=loss_scale)
return _loss_fn
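# Editor's illustrative sketch (not part of the original script): the SQuAD
# loss averages the sparse categorical cross entropy of the start and end
# logits. The model outputs below are dummies with batch size 2 and sequence
# length 4.
def _example_squad_loss():
  labels = {
      'start_positions': tf.constant([1, 0]),
      'end_positions': tf.constant([2, 3]),
  }
  unique_ids = tf.constant([0, 1])
  start_logits = tf.random.uniform([2, 4])
  end_logits = tf.random.uniform([2, 4])
  loss_fn = get_loss_fn()
  return loss_fn(labels, (unique_ids, start_logits, end_logits))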
def get_raw_results(predictions):
"""Converts multi-replica predictions to RawResult."""
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
predictions['start_logits'],
predictions['end_logits']):
for values in zip(unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy()):
yield squad_lib.RawResult(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist())
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
"""Make predictions using a Bert-based squad model."""
primary_cpu_task = '/job:worker' if FLAGS.tpu else ''
with tf.device(primary_cpu_task):
predict_dataset = input_pipeline.create_squad_dataset(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = strategy.make_dataset_iterator(predict_dataset)
with strategy.scope():
squad_model, _ = bert_models.squad_model(
bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path)
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
unique_ids, start_logits, end_logits = squad_model(x, training=False)
return dict(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.experimental_run(replicated_step, iterator)
return tf.nest.map_structure(strategy.unwrap, outputs)
all_results = []
for _ in range(num_steps):
predictions = predict_step(predict_iterator)
for result in get_raw_results(predictions):
all_results.append(result)
if len(all_results) % 100 == 0:
logging.info('Made predictions for %d records.', len(all_results))
return all_results
def train_squad(strategy, input_meta_data):
"""Run bert squad training."""
if not strategy:
raise ValueError('Distribution strategy cannot be None.')
logging.info('Training using customized training loop with distribution'
' strategy.')
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = functools.partial(
input_pipeline.create_squad_dataset,
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
def _get_squad_model():
squad_model, core_model = bert_models.squad_model(
bert_config, max_seq_length, float_type=tf.float32)
squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
return squad_model, core_model
# The original BERT model does not scale the loss by
# 1/num_replicas_in_sync. It could be an accident. So, in order to use
# the same hyper parameter, we do the same thing here by keeping each
# replica loss as it is.
loss_fn = get_loss_fn(loss_scale=1.0)
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
loss_fn=loss_fn,
model_dir=FLAGS.model_dir,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for a squad dataset."""
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
# Whether data should be in Ver 2.0 format.
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
eval_writer = squad_lib.FeatureWriter(
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
is_training=False)
eval_features = []
def _append_feature(feature, is_padding):
if not is_padding:
eval_features.append(feature)
eval_writer.process_feature(feature)
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
dataset_size = squad_lib.convert_examples_to_features(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=input_meta_data['max_seq_length'],
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=False,
output_fn=_append_feature,
batch_size=FLAGS.predict_batch_size)
eval_writer.close()
logging.info('***** Running predictions *****')
logging.info(' Num orig examples = %d', len(eval_examples))
logging.info(' Num split examples = %d', len(eval_features))
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
squad_lib.write_predictions(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
verbose=FLAGS.verbose_logging)
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
strategy = None
if FLAGS.strategy_type == 'tpu':
logging.info('Use TPU at %s',
FLAGS.tpu if FLAGS.tpu is not None else '')
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=FLAGS.tpu)
tf.config.experimental_connect_to_host(cluster_resolver.master()) # pylint: disable=line-too-long
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.experimental.TPUStrategy(
cluster_resolver, steps_per_run=FLAGS.steps_per_run)
elif FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == 'multi_worker_mirror':
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
else:
raise ValueError('The distribution strategy type is not supported: %s' %
FLAGS.strategy_type)
if FLAGS.do_train:
train_squad(strategy, input_meta_data)
if FLAGS.do_predict:
predict_squad(strategy, input_meta_data)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('model_dir')
app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to process data for SQuAD 1.1 and SQuAD 2.0."""
# pylint: disable=g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import json
import math
import six
from absl import logging
import tensorflow as tf
from official.bert import tokenization
# pylint: enable=g-bad-import-order
class SquadExample(object):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def __init__(self,
qas_id,
question_text,
doc_tokens,
orig_answer_text=None,
start_position=None,
end_position=None,
is_impossible=False):
self.qas_id = qas_id
self.question_text = question_text
self.doc_tokens = doc_tokens
self.orig_answer_text = orig_answer_text
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
s += ", question_text: %s" % (
tokenization.printable_text(self.question_text))
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position:
s += ", start_position: %d" % (self.start_position)
if self.start_position:
s += ", end_position: %d" % (self.end_position)
if self.start_position:
s += ", is_impossible: %r" % (self.is_impossible)
return s
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
unique_id,
example_index,
doc_span_index,
tokens,
token_to_orig_map,
token_is_max_context,
input_ids,
input_mask,
segment_ids,
start_position=None,
end_position=None,
is_impossible=None):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
self.tokens = tokens
self.token_to_orig_map = token_to_orig_map
self.token_is_max_context = token_is_max_context
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
class FeatureWriter(object):
"""Writes InputFeature to TF example file."""
def __init__(self, filename, is_training):
self.filename = filename
self.is_training = is_training
self.num_features = 0
self._writer = tf.io.TFRecordWriter(filename)
def process_feature(self, feature):
"""Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
self.num_features += 1
def create_int_feature(values):
feature = tf.train.Feature(
int64_list=tf.train.Int64List(value=list(values)))
return feature
features = collections.OrderedDict()
features["unique_ids"] = create_int_feature([feature.unique_id])
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
if self.is_training:
features["start_positions"] = create_int_feature([feature.start_position])
features["end_positions"] = create_int_feature([feature.end_position])
impossible = 0
if feature.is_impossible:
impossible = 1
features["is_impossible"] = create_int_feature([impossible])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
self._writer.write(tf_example.SerializeToString())
def close(self):
self._writer.close()
def read_squad_examples(input_file, is_training, version_2_with_negative):
"""Read a SQuAD json file into a list of SquadExample."""
with tf.io.gfile.GFile(input_file, "r") as reader:
input_data = json.load(reader)["data"]
def is_whitespace(c):
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
return True
return False
examples = []
for entry in input_data:
for paragraph in entry["paragraphs"]:
paragraph_text = paragraph["context"]
doc_tokens = []
char_to_word_offset = []
prev_is_whitespace = True
for c in paragraph_text:
if is_whitespace(c):
prev_is_whitespace = True
else:
if prev_is_whitespace:
doc_tokens.append(c)
else:
doc_tokens[-1] += c
prev_is_whitespace = False
char_to_word_offset.append(len(doc_tokens) - 1)
for qa in paragraph["qas"]:
qas_id = qa["id"]
question_text = qa["question"]
start_position = None
end_position = None
orig_answer_text = None
is_impossible = False
if is_training:
if version_2_with_negative:
is_impossible = qa["is_impossible"]
if (len(qa["answers"]) != 1) and (not is_impossible):
raise ValueError(
"For training, each question should have exactly 1 answer.")
if not is_impossible:
answer = qa["answers"][0]
orig_answer_text = answer["text"]
answer_offset = answer["answer_start"]
answer_length = len(orig_answer_text)
start_position = char_to_word_offset[answer_offset]
end_position = char_to_word_offset[answer_offset + answer_length -
1]
# Only add answers where the text can be exactly recovered from the
# document. If it can't, the mismatch is likely due to weird Unicode
# handling, so we just skip the example.
#
# Note that this means that in training mode not every example is
# guaranteed to be preserved.
actual_text = " ".join(
doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = " ".join(
tokenization.whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
logging.warning("Could not find answer: '%s' vs. '%s'",
actual_text, cleaned_answer_text)
continue
else:
start_position = -1
end_position = -1
orig_answer_text = ""
example = SquadExample(
qas_id=qas_id,
question_text=question_text,
doc_tokens=doc_tokens,
orig_answer_text=orig_answer_text,
start_position=start_position,
end_position=end_position,
is_impossible=is_impossible)
examples.append(example)
return examples
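# For reference, a (simplified) sketch of the SQuAD JSON layout consumed
# above; only the keys accessed by `read_squad_examples` are shown, and real
# SQuAD files carry additional fields (e.g. "title", "version") that are
# ignored here:
#
#   {"data": [{"paragraphs": [
#       {"context": "The leader was John Smith (1895-1943).",
#        "qas": [{"id": "q1",
#                 "question": "What year was John Smith born?",
#                 "is_impossible": false,
#                 "answers": [{"text": "1895", "answer_start": 27}]}]}]}]}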
def convert_examples_to_features(examples,
tokenizer,
max_seq_length,
doc_stride,
max_query_length,
is_training,
output_fn,
batch_size=None):
"""Loads a data file into a list of `InputBatch`s."""
base_id = 1000000000
unique_id = base_id
feature = None
for (example_index, example) in enumerate(examples):
query_tokens = tokenizer.tokenize(example.question_text)
if len(query_tokens) > max_query_length:
query_tokens = query_tokens[0:max_query_length]
tok_to_orig_index = []
orig_to_tok_index = []
all_doc_tokens = []
for (i, token) in enumerate(example.doc_tokens):
orig_to_tok_index.append(len(all_doc_tokens))
sub_tokens = tokenizer.tokenize(token)
for sub_token in sub_tokens:
tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token)
tok_start_position = None
tok_end_position = None
if is_training and example.is_impossible:
tok_start_position = -1
tok_end_position = -1
if is_training and not example.is_impossible:
tok_start_position = orig_to_tok_index[example.start_position]
if example.end_position < len(example.doc_tokens) - 1:
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
else:
tok_end_position = len(all_doc_tokens) - 1
(tok_start_position, tok_end_position) = _improve_answer_span(
all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
example.orig_answer_text)
# The -3 accounts for [CLS], [SEP] and [SEP]
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
# We can have documents that are longer than the maximum sequence length.
# To deal with this we do a sliding window approach, where we take chunks
# of up to our max length with a stride of `doc_stride`.
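# Illustrative example (hypothetical sizes): with 300 document tokens,
# max_tokens_for_doc=200 and doc_stride=128, the loop below produces two
# spans, DocSpan(start=0, length=200) and DocSpan(start=128, length=172),
# which overlap by 72 tokens so that every token appears in a span with
# reasonable surrounding context.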
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
"DocSpan", ["start", "length"])
doc_spans = []
start_offset = 0
while start_offset < len(all_doc_tokens):
length = len(all_doc_tokens) - start_offset
if length > max_tokens_for_doc:
length = max_tokens_for_doc
doc_spans.append(_DocSpan(start=start_offset, length=length))
if start_offset + length == len(all_doc_tokens):
break
start_offset += min(length, doc_stride)
for (doc_span_index, doc_span) in enumerate(doc_spans):
tokens = []
token_to_orig_map = {}
token_is_max_context = {}
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in query_tokens:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for i in range(doc_span.length):
split_token_index = doc_span.start + i
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
start_position = None
end_position = None
if is_training and not example.is_impossible:
# For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict.
doc_start = doc_span.start
doc_end = doc_span.start + doc_span.length - 1
out_of_span = False
if not (tok_start_position >= doc_start and
tok_end_position <= doc_end):
out_of_span = True
if out_of_span:
start_position = 0
end_position = 0
else:
doc_offset = len(query_tokens) + 2
start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset
if is_training and example.is_impossible:
start_position = 0
end_position = 0
if example_index < 20:
logging.info("*** Example ***")
logging.info("unique_id: %s", (unique_id))
logging.info("example_index: %s", (example_index))
logging.info("doc_span_index: %s", (doc_span_index))
logging.info("tokens: %s",
" ".join([tokenization.printable_text(x) for x in tokens]))
logging.info(
"token_to_orig_map: %s", " ".join([
"%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)
]))
logging.info(
"token_is_max_context: %s", " ".join([
"%d:%s" % (x, y)
for (x, y) in six.iteritems(token_is_max_context)
]))
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
if is_training and example.is_impossible:
logging.info("impossible example")
if is_training and not example.is_impossible:
answer_text = " ".join(tokens[start_position:(end_position + 1)])
logging.info("start_position: %d", (start_position))
logging.info("end_position: %d", (end_position))
logging.info("answer: %s", tokenization.printable_text(answer_text))
feature = InputFeatures(
unique_id=unique_id,
example_index=example_index,
doc_span_index=doc_span_index,
tokens=tokens,
token_to_orig_map=token_to_orig_map,
token_is_max_context=token_is_max_context,
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
start_position=start_position,
end_position=end_position,
is_impossible=example.is_impossible)
# Run callback
if is_training:
output_fn(feature)
else:
output_fn(feature, is_padding=False)
unique_id += 1
if not is_training and feature:
assert batch_size
num_padding = 0
num_examples = unique_id - base_id
if num_examples % batch_size != 0:
num_padding = batch_size - (num_examples % batch_size)
logging.info("Adding padding examples to make sure no partial batch.")
logging.info("Adds %d padding examples for inference.", num_padding)
dummy_feature = copy.deepcopy(feature)
for _ in range(num_padding):
dummy_feature.unique_id = unique_id
# Run callback
output_fn(dummy_feature, is_padding=True)
unique_id += 1
return unique_id - base_id
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
orig_answer_text):
"""Returns tokenized answer spans that better match the annotated answer."""
# The SQuAD annotations are character based. We first project them to
# whitespace-tokenized words. But then after WordPiece tokenization, we can
# often find a "better match". For example:
#
# Question: What year was John Smith born?
# Context: The leader was John Smith (1895-1943).
# Answer: 1895
#
# The original whitespace-tokenized answer will be "(1895-1943).". However
# after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
# the exact answer, 1895.
#
# However, this is not always possible. Consider the following:
#
# Question: What country is the top exporter of electronics?
# Context: The Japanese electronics industry is the largest in the world.
# Answer: Japan
#
# In this case, the annotator chose "Japan" as a character sub-span of
# the word "Japanese". Since our WordPiece tokenizer does not split
# "Japanese", we just use "Japanese" as the annotation. This is fairly rare
# in SQuAD, but does happen.
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
for new_start in range(input_start, input_end + 1):
for new_end in range(input_end, new_start - 1, -1):
text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
if text_span == tok_answer_text:
return (new_start, new_end)
return (input_start, input_end)
def _check_is_max_context(doc_spans, cur_span_index, position):
"""Check if this is the 'max context' doc span for the token."""
# Because of the sliding window approach taken to scoring documents, a single
# token can appear in multiple documents. E.g.
# Doc: the man went to the store and bought a gallon of milk
# Span A: the man went to the
# Span B: to the store and bought
# Span C: and bought a gallon of
# ...
#
# Now the word 'bought' will have two scores from spans B and C. We only
# want to consider the score with "maximum context", which we define as
# the *minimum* of its left and right context (the *sum* of left and
# right context will always be the same, of course).
#
# In the example the maximum context for 'bought' would be span C since
# it has 1 left context and 3 right context, while span B has 4 left context
# and 0 right context.
best_score = None
best_span_index = None
for (span_index, doc_span) in enumerate(doc_spans):
end = doc_span.start + doc_span.length - 1
if position < doc_span.start:
continue
if position > end:
continue
num_left_context = position - doc_span.start
num_right_context = end - position
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
if best_score is None or score > best_score:
best_score = score
best_span_index = span_index
return cur_span_index == best_span_index
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
def write_predictions(all_examples,
all_features,
all_results,
n_best_size,
max_answer_length,
do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
version_2_with_negative=False,
null_score_diff_threshold=0.0,
verbose=False):
"""Write final predictions to the json file and log-odds of null if needed."""
logging.info("Writing predictions to: %s", (output_prediction_file))
logging.info("Writing nbest to: %s", (output_nbest_file))
example_index_to_features = collections.defaultdict(list)
for feature in all_features:
example_index_to_features[feature.example_index].append(feature)
unique_id_to_result = {}
for result in all_results:
unique_id_to_result[result.unique_id] = result
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction",
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
scores_diff_json = collections.OrderedDict()
for (example_index, example) in enumerate(all_examples):
features = example_index_to_features[example_index]
prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive
min_null_feature_index = 0 # the paragraph slice with min null score
null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features):
result = unique_id_to_result[feature.unique_id]
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant
if version_2_with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if feature_null_score < score_null:
score_null = feature_null_score
min_null_feature_index = feature_index
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indexes:
for end_index in end_indexes:
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index >= len(feature.tokens):
continue
if end_index >= len(feature.tokens):
continue
if start_index not in feature.token_to_orig_map:
continue
if end_index not in feature.token_to_orig_map:
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
if version_2_with_negative:
prelim_predictions.append(
_PrelimPrediction(
feature_index=min_null_feature_index,
start_index=0,
end_index=0,
start_logit=null_start_logit,
end_logit=null_end_logit))
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True)
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_logit", "end_logit"])
seen_predictions = {}
nbest = []
for pred in prelim_predictions:
if len(nbest) >= n_best_size:
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
# Clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(
tok_text, orig_text, do_lower_case, verbose=verbose)
if final_text in seen_predictions:
continue
seen_predictions[final_text] = True
else:
final_text = ""
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
# If we didn't include the empty option in the n-best, include it.
if version_2_with_negative:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="", start_logit=null_start_logit,
end_logit=null_end_logit))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1
total_scores = []
best_non_null_entry = None
for entry in nbest:
total_scores.append(entry.start_logit + entry.end_logit)
if not best_non_null_entry:
if entry.text:
best_non_null_entry = entry
probs = _compute_softmax(total_scores)
nbest_json = []
for (i, entry) in enumerate(nbest):
output = collections.OrderedDict()
output["text"] = entry.text
output["probability"] = probs[i]
output["start_logit"] = entry.start_logit
output["end_logit"] = entry.end_logit
nbest_json.append(output)
assert len(nbest_json) >= 1
if not version_2_with_negative:
all_predictions[example.qas_id] = nbest_json[0]["text"]
else:
# pytype: disable=attribute-error
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
# pytype: enable=attribute-error
all_nbest_json[example.qas_id] = nbest_json
with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if version_2_with_negative:
with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
"""Project the tokenized prediction back to the original text."""
# When we created the data, we kept track of the alignment between original
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
# now `orig_text` contains the span of our original text corresponding to the
# span that we predicted.
#
# However, `orig_text` may contain extra characters that we don't want in
# our prediction.
#
# For example, let's say:
# pred_text = steve smith
# orig_text = Steve Smith's
#
# We don't want to return `orig_text` because it contains the extra "'s".
#
# We don't want to return `pred_text` because it's already been normalized
# (the SQuAD eval script also does punctuation stripping/lower casing but
# our tokenizer does additional normalization like stripping accent
# characters).
#
# What we really want to return is "Steve Smith".
#
# Therefore, we have to apply a semi-complicated alignment heuristic between
# `pred_text` and `orig_text` to get a character-to-character alignment. This
# can fail in certain cases in which case we just return `orig_text`.
def _strip_spaces(text):
ns_chars = []
ns_to_s_map = collections.OrderedDict()
for (i, c) in enumerate(text):
if c == " ":
continue
ns_to_s_map[len(ns_chars)] = i
ns_chars.append(c)
ns_text = "".join(ns_chars)
return (ns_text, ns_to_s_map)
# We first tokenize `orig_text`, strip whitespace from the result
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text))
start_position = tok_text.find(pred_text)
if start_position == -1:
if verbose:
logging.info("Unable to find text: '%s' in '%s'", pred_text, orig_text)
return orig_text
end_position = start_position + len(pred_text) - 1
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
if len(orig_ns_text) != len(tok_ns_text):
if verbose:
logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
orig_ns_text, tok_ns_text)
return orig_text
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map = {}
for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
tok_s_to_ns_map[tok_index] = i
orig_start_position = None
if start_position in tok_s_to_ns_map:
ns_start_position = tok_s_to_ns_map[start_position]
if ns_start_position in orig_ns_to_s_map:
orig_start_position = orig_ns_to_s_map[ns_start_position]
if orig_start_position is None:
if verbose:
logging.info("Couldn't map start position")
return orig_text
orig_end_position = None
if end_position in tok_s_to_ns_map:
ns_end_position = tok_s_to_ns_map[end_position]
if ns_end_position in orig_ns_to_s_map:
orig_end_position = orig_ns_to_s_map[ns_end_position]
if orig_end_position is None:
if verbose:
logging.info("Couldn't map end position")
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
return output_text
def _get_best_indexes(logits, n_best_size):
"""Get the n-best logits from a list."""
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
best_indexes = []
for i in range(len(index_and_score)): # pylint: disable=consider-using-enumerate
if i >= n_best_size:
break
best_indexes.append(index_and_score[i][0])
return best_indexes
def _compute_softmax(scores):
"""Compute softmax probability over raw logits."""
if not scores:
return []
max_score = None
for score in scores:
if max_score is None or score > max_score:
max_score = score
exp_scores = []
total_sum = 0.0
for score in scores:
x = math.exp(score - max_score)
exp_scores.append(x)
total_sum += x
probs = []
for score in exp_scores:
probs.append(score / total_sum)
return probs
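# Worked example: _compute_softmax([1.0, 2.0, 3.0]) first subtracts the max
# (3.0) for numerical stability, giving exponentials of roughly
# [0.135, 0.368, 1.0]; normalizing by their sum (~1.503) yields probabilities
# of approximately [0.090, 0.245, 0.665].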
def generate_tf_record_from_json_file(input_file_path,
vocab_file_path,
output_path,
max_seq_length=384,
do_lower_case=True,
max_query_length=64,
doc_stride=128,
version_2_with_negative=False):
"""Generates and saves training data into a tf record file."""
train_examples = read_squad_examples(
input_file=input_file_path,
is_training=True,
version_2_with_negative=version_2_with_negative)
tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_file_path, do_lower_case=do_lower_case)
train_writer = FeatureWriter(filename=output_path, is_training=True)
number_of_examples = convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=max_seq_length,
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=True,
output_fn=train_writer.process_feature)
train_writer.close()
meta_data = {
"task_type": "bert_squad",
"train_data_size": number_of_examples,
"max_seq_length": max_seq_length,
"max_query_length": max_query_length,
"doc_stride": doc_stride,
"version_2_with_negative": version_2_with_negative,
}
return meta_data
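# A minimal usage sketch for offline data generation (the file paths below
# are hypothetical, and the repository's own data-preparation entry point may
# wrap this call differently):
#
#   meta_data = generate_tf_record_from_json_file(
#       input_file_path='/tmp/train-v1.1.json',
#       vocab_file_path='/tmp/vocab.txt',
#       output_path='/tmp/squad_train.tf_record')
#   with tf.io.gfile.GFile('/tmp/squad_meta_data.json', 'w') as writer:
#     writer.write(json.dumps(meta_data, indent=4) + '\n')
#
# The resulting JSON can then be supplied to the training script through its
# `--input_meta_data_path` flag.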
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tokenization classes implementation.
The file is forked from:
https://github.com/google-research/bert/blob/master/tokenization.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
import tensorflow as tf
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." %
(actual_flag, init_checkpoint, model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with tf.io.gfile.GFile(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because the English Wikipedia does contain
# some Chinese words).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and are handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
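# Trace of the greedy longest-match-first loop above (using the vocabulary
# from the docstring example, assumed to contain "un", "##aff" and "##able"):
# for the token "unaffable", start=0 first matches "un", start=2 then matches
# "##aff", and start=5 matches "##able", producing ["un", "##aff", "##able"].
# If any position had no match in the vocabulary, the whole token would be
# emitted as the single "[UNK]" token instead.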
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import six
import tensorflow as tf
from official.bert import tokenization
class TokenizationTest(tf.test.TestCase):
"""Tokenization test.
The implementation is forked from
https://github.com/google-research/bert/blob/master/tokenization_test.py.
"""
def test_full_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", ","
]
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
if six.PY2:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
else:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens
]).encode("utf-8"))
vocab_file = vocab_writer.name
tokenizer = tokenization.FullTokenizer(vocab_file)
os.unlink(vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertAllEqual(
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_chinese(self):
tokenizer = tokenization.BasicTokenizer()
self.assertAllEqual(
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
[u"ah", u"\u535A", u"\u63A8", u"zz"])
def test_basic_tokenizer_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
self.assertAllEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["hello", "!", "how", "are", "you", "?"])
self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
def test_basic_tokenizer_no_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
self.assertAllEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["HeLLo", "!", "how", "Are", "yoU", "?"])
def test_wordpiece_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing"
]
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
self.assertAllEqual(tokenizer.tokenize(""), [])
self.assertAllEqual(
tokenizer.tokenize("unwanted running"),
["un", "##want", "##ed", "runn", "##ing"])
self.assertAllEqual(
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
def test_convert_tokens_to_ids(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing"
]
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
self.assertAllEqual(
tokenization.convert_tokens_to_ids(
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
def test_is_whitespace(self):
self.assertTrue(tokenization._is_whitespace(u" "))
self.assertTrue(tokenization._is_whitespace(u"\t"))
self.assertTrue(tokenization._is_whitespace(u"\r"))
self.assertTrue(tokenization._is_whitespace(u"\n"))
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
self.assertFalse(tokenization._is_whitespace(u"A"))
self.assertFalse(tokenization._is_whitespace(u"-"))
def test_is_control(self):
self.assertTrue(tokenization._is_control(u"\u0005"))
self.assertFalse(tokenization._is_control(u"A"))
self.assertFalse(tokenization._is_control(u" "))
self.assertFalse(tokenization._is_control(u"\t"))
self.assertFalse(tokenization._is_control(u"\r"))
self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
def test_is_punctuation(self):
self.assertTrue(tokenization._is_punctuation(u"-"))
self.assertTrue(tokenization._is_punctuation(u"$"))
self.assertTrue(tokenization._is_punctuation(u"`"))
self.assertTrue(tokenization._is_punctuation(u"."))
self.assertFalse(tokenization._is_punctuation(u"A"))
self.assertFalse(tokenization._is_punctuation(u" "))
if __name__ == "__main__":
tf.test.main()