Unverified Commit 30d14a96 authored by saberkun's avatar saberkun Committed by GitHub

Merged commit includes the following changes: (#6847)

249377254  by hongkuny<hongkuny@google.com>:

    Internal change

249373328  by hongkuny<hongkuny@google.com>:

    Clean up tf import

--
249333938  by hongkuny<hongkuny@google.com>:

    Fix tf1 import

--
249325089  by hongkuny<hongkuny@google.com>:

    BERT 2.0

--
249173564  by hongkuny<hongkuny@google.com>:

    Internal change

PiperOrigin-RevId: 249377254
parent 1529b82c
# BERT in TensorFlow
Note: Please do not create pull requests. This model is still under development
and testing.
The academic paper which describes BERT in detail and provides full results on a
number of tasks can be found here: https://arxiv.org/abs/1810.04805.
This repository contains a TensorFlow 2.0 implementation of BERT.
Since the repository is under active development at this moment, the source of
truth is the TensorFlow team's internal repository. The repo is not officially
released, as it is not yet stable and requires extensive testing.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
from official.bert import modeling
def gather_indexes(sequence_tensor, positions):
"""Gathers the vectors at the specific positions.
Args:
sequence_tensor: Sequence output of `BertModel` layer of shape
(batch_size, seq_length, num_hidden), where num_hidden is the number of
hidden units of the `BertModel` layer.
positions: Position ids of the tokens to mask for pretraining, with shape
(batch_size, max_predictions_per_seq), where `max_predictions_per_seq` is
the maximum number of tokens to mask out and predict per sequence.
Returns:
Gathered sequence tensor of shape (batch_size * max_predictions_per_seq,
num_hidden).
"""
sequence_shape = modeling.get_shape_list(
sequence_tensor, name='sequence_output_tensor')
batch_size = sequence_shape[0]
seq_length = sequence_shape[1]
width = sequence_shape[2]
flat_offsets = tf.keras.backend.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
flat_sequence_tensor = tf.keras.backend.reshape(
sequence_tensor, [batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor
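# Illustrative sketch (not part of the original module): a toy run of
# `gather_indexes` with batch_size=2, seq_length=3 and num_hidden=2. The
# positions pick, per example, which token vectors feed the masked LM head.
def _example_gather_indexes():
  sequence = tf.reshape(tf.range(12, dtype=tf.float32), [2, 3, 2])
  positions = tf.constant([[0, 2], [1, 1]], dtype=tf.int32)
  gathered = gather_indexes(sequence, positions)
  # Shape (4, 2): rows are tokens (0,0), (0,2), (1,1) and (1,1) again.
  return gathered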
class BertPretrainLayer(tf.keras.layers.Layer):
"""Wrapper layer for pre-training a BERT model.
This layer wraps an existing `bert_layer` which is a Keras Layer.
It outputs `lm_output` (masked LM log-probabilities computed from the
Transformer sequence output) and `sentence_output` (next-sentence
log-probabilities), which are suitable for feeding into a
`BertPretrainLossAndMetricLayer`. This layer can be used along with an
unsupervised input to pre-train the embeddings for `bert_layer`.
"""
def __init__(self,
config,
bert_layer,
initializer=None,
float_type=tf.float32,
**kwargs):
super(BertPretrainLayer, self).__init__(**kwargs)
self.config = copy.deepcopy(config)
self.float_type = float_type
self.embedding_table = bert_layer.embedding_lookup.embeddings
self.num_next_sentence_label = 2
if initializer:
self.initializer = initializer
else:
self.initializer = tf.keras.initializers.TruncatedNormal(
stddev=self.config.initializer_range)
def build(self, unused_input_shapes):
self.lm_dense = tf.keras.layers.Dense(
self.config.hidden_size,
activation=modeling.get_activation(self.config.hidden_act),
kernel_initializer=self.initializer)
self.lm_bias = self.add_weight(
shape=[self.config.vocab_size],
name='lm_bias',
initializer=tf.keras.initializers.Zeros())
self.lm_layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12)
self.next_sentence_dense = tf.keras.layers.Dense(
self.num_next_sentence_label, kernel_initializer=self.initializer)
super(BertPretrainLayer, self).build(unused_input_shapes)
def __call__(self,
pooled_output,
sequence_output=None,
masked_lm_positions=None):
inputs = modeling.pack_inputs(
[pooled_output, sequence_output, masked_lm_positions])
return super(BertPretrainLayer, self).__call__(inputs)
def call(self, inputs):
unpacked_inputs = modeling.unpack_inputs(inputs)
pooled_output = unpacked_inputs[0]
sequence_output = unpacked_inputs[1]
masked_lm_positions = unpacked_inputs[2]
mask_lm_input_tensor = gather_indexes(
sequence_output, masked_lm_positions)
lm_output = self.lm_dense(mask_lm_input_tensor)
lm_output = self.lm_layer_norm(lm_output)
lm_output = tf.keras.backend.dot(
lm_output, tf.keras.backend.transpose(self.embedding_table))
lm_output = tf.keras.backend.bias_add(lm_output, self.lm_bias)
lm_output = tf.keras.backend.softmax(lm_output)
lm_output = tf.keras.backend.log(lm_output)
sentence_output = self.next_sentence_dense(pooled_output)
sentence_output = tf.keras.backend.softmax(sentence_output)
sentence_output = tf.keras.backend.log(sentence_output)
return (lm_output, sentence_output)
class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
"""Returns layer that computes custom loss and metrics for pretraining."""
def __init__(self, bert_config, **kwargs):
super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
self.config = copy.deepcopy(bert_config)
def __call__(self,
lm_output,
sentence_output=None,
lm_label_ids=None,
lm_label_weights=None,
sentence_labels=None):
inputs = modeling.pack_inputs([
lm_output, sentence_output, lm_label_ids, lm_label_weights,
sentence_labels
])
return super(BertPretrainLossAndMetricLayer, self).__call__(inputs)
def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
lm_per_example_loss, sentence_output, sentence_labels,
sentence_per_example_loss):
masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
lm_labels, lm_output)
masked_lm_accuracy = tf.reduce_mean(masked_lm_accuracy * lm_label_weights)
self.add_metric(
masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')
lm_example_loss = tf.reshape(lm_per_example_loss, [-1])
lm_example_loss = tf.reduce_mean(lm_example_loss * lm_label_weights)
self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
sentence_labels, sentence_output)
self.add_metric(
next_sentence_accuracy,
name='next_sentence_accuracy',
aggregation='mean')
next_sentence_mean_loss = tf.reduce_mean(sentence_per_example_loss)
self.add_metric(
next_sentence_mean_loss, name='next_sentence_loss', aggregation='mean')
def call(self, inputs):
unpacked_inputs = modeling.unpack_inputs(inputs)
lm_output = unpacked_inputs[0]
sentence_output = unpacked_inputs[1]
lm_label_ids = tf.keras.backend.cast(unpacked_inputs[2], tf.int32)
lm_label_ids = tf.keras.backend.reshape(lm_label_ids, [-1])
lm_label_ids_one_hot = tf.keras.backend.one_hot(lm_label_ids,
self.config.vocab_size)
lm_label_weights = tf.keras.backend.cast(unpacked_inputs[3], tf.float32)
lm_label_weights = tf.keras.backend.reshape(lm_label_weights, [-1])
lm_per_example_loss = -tf.keras.backend.sum(
lm_output * lm_label_ids_one_hot, axis=[-1])
numerator = tf.keras.backend.sum(lm_label_weights * lm_per_example_loss)
denominator = tf.keras.backend.sum(lm_label_weights) + 1e-5
mask_label_loss = numerator / denominator
sentence_labels = tf.keras.backend.cast(unpacked_inputs[4], dtype=tf.int32)
sentence_labels = tf.keras.backend.reshape(sentence_labels, [-1])
sentence_label_one_hot = tf.keras.backend.one_hot(sentence_labels, 2)
per_example_loss_sentence = -tf.keras.backend.sum(
sentence_label_one_hot * sentence_output, axis=-1)
sentence_loss = tf.keras.backend.mean(per_example_loss_sentence)
loss = mask_label_loss + sentence_loss
final_loss = tf.fill(
tf.keras.backend.shape(per_example_loss_sentence), loss)
self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
lm_per_example_loss, sentence_output, sentence_labels,
per_example_loss_sentence)
return final_loss
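# Illustrative sketch (not part of the original module): the masked LM term
# above is a weighted mean of per-position negative log-likelihoods, where
# `lm_label_weights` zeros out padded prediction slots.
def _example_masked_lm_loss():
  per_example_loss = tf.constant([2.0, 1.0, 4.0, 3.0])  # -log p(true id)
  label_weights = tf.constant([1.0, 1.0, 0.0, 0.0])  # last two slots padded
  numerator = tf.reduce_sum(label_weights * per_example_loss)  # 3.0
  denominator = tf.reduce_sum(label_weights) + 1e-5
  return numerator / denominator  # ~1.5, the mean loss over real positions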
def pretrain_model(bert_config,
seq_length,
max_predictions_per_seq,
initializer=None):
"""Returns model to be used for pre-training.
Args:
bert_config: Configuration that defines the core BERT model.
seq_length: Maximum sequence length of the training data.
max_predictions_per_seq: Maximum number of tokens in sequence to mask out
and use for pretraining.
initializer: Initializer for weights in BertPretrainLayer.
Returns:
Pretraining model as well as core BERT submodel from which to save
weights after pretraining.
"""
input_word_ids = tf.keras.layers.Input(
shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
input_mask = tf.keras.layers.Input(
shape=(seq_length,), name='input_mask', dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(
shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
masked_lm_positions = tf.keras.layers.Input(
shape=(max_predictions_per_seq,),
name='masked_lm_positions',
dtype=tf.int32)
masked_lm_weights = tf.keras.layers.Input(
shape=(max_predictions_per_seq,),
name='masked_lm_weights',
dtype=tf.int32)
next_sentence_labels = tf.keras.layers.Input(
shape=(1,), name='next_sentence_labels', dtype=tf.int32)
masked_lm_ids = tf.keras.layers.Input(
shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
bert_submodel_name = 'bert_core_layer'
bert_submodel = modeling.get_bert_model(
input_word_ids,
input_mask,
input_type_ids,
name=bert_submodel_name,
config=bert_config)
pooled_output = bert_submodel.outputs[0]
sequence_output = bert_submodel.outputs[1]
pretrain_layer = BertPretrainLayer(
bert_config,
bert_submodel.get_layer(bert_submodel_name),
initializer=initializer)
lm_output, sentence_output = pretrain_layer(pooled_output, sequence_output,
masked_lm_positions)
pretrain_loss_layer = BertPretrainLossAndMetricLayer(bert_config)
output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
masked_lm_weights, next_sentence_labels)
return tf.keras.Model(
inputs={
'input_word_ids': input_word_ids,
'input_mask': input_mask,
'input_type_ids': input_type_ids,
'masked_lm_positions': masked_lm_positions,
'masked_lm_ids': masked_lm_ids,
'masked_lm_weights': masked_lm_weights,
'next_sentence_labels': next_sentence_labels,
},
outputs=output_loss), bert_submodel
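# Illustrative usage sketch (not part of the original module). The
# `modeling.BertConfig(vocab_size=...)` constructor call is an assumption
# about the companion modeling library.
def _example_pretrain_model():
  bert_config = modeling.BertConfig(vocab_size=30522)  # assumed constructor
  model, core_model = pretrain_model(
      bert_config, seq_length=128, max_predictions_per_seq=20)
  # `model` consumes the input dict named above and emits the combined
  # pretraining loss; `core_model` holds the reusable encoder weights.
  return model, core_model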
class BertSquadLogitsLayer(tf.keras.layers.Layer):
"""Returns a layer that computes custom logits for BERT squad model."""
def __init__(self, initializer=None, float_type=tf.float32, **kwargs):
super(BertSquadLogitsLayer, self).__init__(**kwargs)
self.initializer = initializer
self.float_type = float_type
def build(self, unused_input_shapes):
self.final_dense = tf.keras.layers.Dense(
units=2, kernel_initializer=self.initializer, name='final_dense')
super(BertSquadLogitsLayer, self).build(unused_input_shapes)
def call(self, inputs):
sequence_output = inputs
input_shape = sequence_output.shape.as_list()
sequence_length = input_shape[1]
num_hidden_units = input_shape[2]
final_hidden_input = tf.keras.backend.reshape(sequence_output,
[-1, num_hidden_units])
logits = self.final_dense(final_hidden_input)
logits = tf.keras.backend.reshape(logits, [-1, sequence_length, 2])
logits = tf.transpose(logits, [2, 0, 1])
unstacked_logits = tf.unstack(logits, axis=0)
return unstacked_logits[0], unstacked_logits[1]
def squad_model(bert_config, max_seq_length, float_type, initializer=None):
"""Returns BERT Squad model along with core BERT model to import weights.
Args:
bert_config: BertConfig, the config defines the core Bert model.
max_seq_length: integer, the maximum input sequence length.
float_type: tf.dtype, tf.float32 or tf.bfloat16.
initializer: Initializer for weights in BertSquadLogitsLayer.
Returns:
The SQuAD model, which outputs unique ids, start logits and end logits, and
the core BERT model from which to load pre-trained weights.
"""
unique_ids = tf.keras.layers.Input(
shape=(1,), dtype=tf.int32, name='unique_ids')
input_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_ids')
input_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='segment_ids')
core_model = modeling.get_bert_model(
input_word_ids,
input_mask,
input_type_ids,
config=bert_config,
name='bert_model',
float_type=float_type)
# `BertSquadModel` only uses the sequence_output, which
# has dimensionality (batch_size, sequence_length, num_hidden).
sequence_output = core_model.outputs[1]
if initializer is None:
initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range)
squad_logits_layer = BertSquadLogitsLayer(
initializer=initializer, float_type=float_type, name='squad_logits')
start_logits, end_logits = squad_logits_layer(sequence_output)
squad = tf.keras.Model(
inputs={
'unique_ids': unique_ids,
'input_ids': input_word_ids,
'input_mask': input_mask,
'segment_ids': input_type_ids,
},
outputs=[unique_ids, start_logits, end_logits],
name='squad_model')
return squad, core_model
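# Illustrative usage sketch (not part of the original module); the
# `modeling.BertConfig` constructor call is an assumption.
def _example_squad_model():
  bert_config = modeling.BertConfig(vocab_size=30522)  # assumed constructor
  squad, core_model = squad_model(
      bert_config, max_seq_length=384, float_type=tf.float32)
  # `squad` maps the feature dict to [unique_ids, start_logits, end_logits];
  # pre-trained weights should be restored into `core_model` before training.
  return squad, core_model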
def classifier_model(bert_config,
float_type,
num_labels,
max_seq_length,
final_layer_initializer=None):
"""BERT classifier model in functional API style.
Constructs a Keras model for predicting `num_labels` outputs from an input
with maximum sequence length `max_seq_length`.
Args:
bert_config: BertConfig, the config defines the core BERT model.
float_type: dtype, tf.float32 or tf.bfloat16.
num_labels: integer, the number of classes.
max_seq_length: integer, the maximum input sequence length.
final_layer_initializer: Initializer for the final dense layer. Defaults to
a TruncatedNormal initializer.
Returns:
Combined prediction model (words, mask, type) -> (classification logits)
BERT sub-model (words, mask, type) -> (bert_outputs)
"""
input_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
bert_model = modeling.get_bert_model(
input_word_ids,
input_mask,
input_type_ids,
config=bert_config,
float_type=float_type)
pooled_output = bert_model.outputs[0]
if final_layer_initializer is not None:
initializer = final_layer_initializer
else:
initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range)
output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
pooled_output)
output = tf.keras.layers.Dense(
num_labels,
kernel_initializer=initializer,
name='output',
dtype=float_type)(
output)
return tf.keras.Model(
inputs={
'input_word_ids': input_word_ids,
'input_mask': input_mask,
'input_type_ids': input_type_ids
},
outputs=output), bert_model
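# Illustrative usage sketch (not part of the original module): the final
# dense layer has no activation, so its outputs are logits. The
# `modeling.BertConfig` constructor call is an assumption.
def _example_classifier_model():
  bert_config = modeling.BertConfig(vocab_size=30522)  # assumed constructor
  model, _ = classifier_model(
      bert_config, tf.float32, num_labels=2, max_seq_length=128)
  model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
  return model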
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT library to process data for classification task."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import csv
import os
from absl import logging
import tensorflow as tf
from official.bert import tokenization
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Must be specified only for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
input_ids,
input_mask,
segment_ids,
label_id,
is_real_example=True):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.is_real_example = is_real_example
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@staticmethod
def get_processor_name():
"""Gets the string identifier of the processor."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with tf.io.gfile.GFile(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
def __init__(self):
self.language = "zh"
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % self.language))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "train-%d" % (i)
text_a = tokenization.convert_to_unicode(line[0])
text_b = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[2])
if label == tokenization.convert_to_unicode("contradictory"):
label = tokenization.convert_to_unicode("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % (i)
language = tokenization.convert_to_unicode(line[0])
if language != tokenization.convert_to_unicode(self.language):
continue
text_a = tokenization.convert_to_unicode(line[6])
text_b = tokenization.convert_to_unicode(line[7])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XNLI"
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
text_a = tokenization.convert_to_unicode(line[8])
text_b = tokenization.convert_to_unicode(line[9])
if set_type == "test":
label = "contradiction"
else:
label = tokenization.convert_to_unicode(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MRPC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[3])
text_b = tokenization.convert_to_unicode(line[4])
if set_type == "test":
label = "0"
else:
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "COLA"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
# Only the test set has a header
if set_type == "test" and i == 0:
continue
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
label = "0"
else:
text_a = tokenization.convert_to_unicode(line[3])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
tokens_a = tokenizer.tokenize(example.text_a)
tokens_b = None
if example.text_b:
tokens_b = tokenizer.tokenize(example.text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
if tokens_b:
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
label_id = label_map[example.label]
if ex_index < 5:
logging.info("*** Example ***")
logging.info("guid: %s" % (example.guid))
logging.info("tokens: %s" %
" ".join([tokenization.printable_text(x) for x in tokens]))
logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
logging.info("label: %s (id = %d)" % (example.label, label_id))
feature = InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id,
is_real_example=True)
return feature
def file_based_convert_examples_to_features(examples, label_list,
max_seq_length, tokenizer,
output_file):
"""Convert a set of `InputExample`s to a TFRecord file."""
writer = tf.io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples):
if ex_index % 10000 == 0:
logging.info("Writing example %d of %d" % (ex_index, len(examples)))
feature = convert_single_example(ex_index, example, label_list,
max_seq_length, tokenizer)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
features["label_ids"] = create_int_feature([feature.label_id])
features["is_real_example"] = create_int_feature(
[int(feature.is_real_example)])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
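# Illustrative sketch (not part of the original module): the longer sequence
# loses tokens from its tail, one at a time, until the pair fits.
def _example_truncate_seq_pair():
  tokens_a = ["the", "quick", "brown", "fox", "jumps"]
  tokens_b = ["hello", "world"]
  _truncate_seq_pair(tokens_a, tokens_b, max_length=5)
  # tokens_a is now ["the", "quick", "brown"]; tokens_b is unchanged.
  return tokens_a, tokens_b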
def generate_tf_record_from_data_file(processor,
data_dir,
vocab_file,
train_data_output_path=None,
eval_data_output_path=None,
max_seq_length=128,
do_lower_case=True):
"""Generates and saves training data into a tf record file.
Arguments:
processor: Input processor object to be used for generating data. Subclass
of `DataProcessor`.
data_dir: Directory that contains train/eval data to process. Data files
should be named "dev.tsv", "test.tsv", or "train.tsv".
vocab_file: Text file with words to be used for training/evaluation.
train_data_output_path: Output to which processed tf record for training
will be saved.
eval_data_output_path: Output to which processed tf record for evaluation
will be saved.
max_seq_length: Maximum sequence length of the training/eval data to be
generated.
do_lower_case: Whether to lower case input text.
Returns:
A dictionary containing input meta data.
"""
assert train_data_output_path or eval_data_output_path
label_list = processor.get_labels()
tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_file, do_lower_case=do_lower_case)
assert train_data_output_path
train_input_data_examples = processor.get_train_examples(data_dir)
file_based_convert_examples_to_features(train_input_data_examples, label_list,
max_seq_length, tokenizer,
train_data_output_path)
num_training_data = len(train_input_data_examples)
if eval_data_output_path:
eval_input_data_examples = processor.get_dev_examples(data_dir)
file_based_convert_examples_to_features(eval_input_data_examples,
label_list, max_seq_length,
tokenizer, eval_data_output_path)
meta_data = {
"task_type": "bert_classification",
"processor_type": processor.get_processor_name(),
"num_labels": len(processor.get_labels()),
"train_data_size": num_training_data,
"max_seq_length": max_seq_length,
}
if eval_data_output_path:
meta_data["eval_data_size"] = len(eval_input_data_examples)
return meta_data
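# Illustrative usage sketch (not part of the original module); all paths are
# hypothetical, and `data_dir` must contain the GLUE MRPC train.tsv/dev.tsv.
def _example_generate_mrpc_records():
  meta_data = generate_tf_record_from_data_file(
      MrpcProcessor(),
      data_dir="/tmp/mrpc_data",
      vocab_file="/tmp/uncased_vocab.txt",
      train_data_output_path="/tmp/mrpc_train.tf_record",
      eval_data_output_path="/tmp/mrpc_eval.tf_record",
      max_seq_length=128)
  return meta_data  # e.g. includes "train_data_size" and "num_labels"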
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT finetuning task dataset generator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from absl import app
from absl import flags
import tensorflow as tf
from official.bert import classifier_data_lib
from official.bert import squad_lib
FLAGS = flags.FLAGS
# BERT classification specific flags.
flags.DEFINE_enum(
"fine_tuning_task_type", "classification", ["classification", "squad"],
"The name of the BERT fine tuning task for which data "
"will be generated..")
flags.DEFINE_string(
"input_data_dir", None,
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task.")
flags.DEFINE_string("classification_task_name", None,
"The name of the task to train BERT classifier.")
# BERT Squad task specific flags.
flags.DEFINE_string(
"squad_data_file", None,
"The input data file in for generating training data for BERT squad task.")
flags.DEFINE_integer(
"doc_stride", 128,
"When splitting up a long document into chunks, how much stride to "
"take between chunks.")
flags.DEFINE_integer(
"max_query_length", 64,
"The maximum number of tokens for the question. Questions longer than "
"this will be truncated to this length.")
# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_string(
"train_data_output_path", None,
"The path in which generated training input data will be written as tf records."
)
flags.DEFINE_string(
"eval_data_output_path", None,
"The path in which generated training input data will be written as tf records."
)
flags.DEFINE_string("meta_data_file_path", None,
"The path in which input meta data will be written.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded.")
flags.DEFINE_bool(
"version_2_with_negative", False,
"If true, the SQuAD examples contain some that do not have an answer.")
def generate_classifier_dataset():
"""Generates classifier dataset and returns input meta data."""
assert FLAGS.input_data_dir and FLAGS.classification_task_name
processors = {
"cola": classifier_data_lib.ColaProcessor,
"mnli": classifier_data_lib.MnliProcessor,
"mrpc": classifier_data_lib.MrpcProcessor,
"xnli": classifier_data_lib.XnliProcessor,
}
task_name = FLAGS.classification_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
return classifier_data_lib.generate_tf_record_from_data_file(
processor,
FLAGS.input_data_dir,
FLAGS.vocab_file,
train_data_output_path=FLAGS.train_data_output_path,
eval_data_output_path=FLAGS.eval_data_output_path,
max_seq_length=FLAGS.max_seq_length,
do_lower_case=FLAGS.do_lower_case)
def generate_squad_dataset():
"""Generates squad training dataset and returns input meta data."""
assert FLAGS.squad_data_file
return squad_lib.generate_tf_record_from_json_file(
FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
FLAGS.doc_stride, FLAGS.version_2_with_negative)
def main(_):
if FLAGS.fine_tuning_task_type == "classification":
input_meta_data = generate_classifier_dataset()
else:
input_meta_data = generate_squad_dataset()
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
writer.write(json.dumps(input_meta_data, indent=4) + "\n")
if __name__ == "__main__":
flags.mark_flag_as_required("vocab_file")
flags.mark_flag_as_required("train_data_output_path")
flags.mark_flag_as_required("meta_data_file_path")
app.run(main)
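# Illustrative invocation (not part of the original module); the script name
# and all paths below are hypothetical:
#
#   python create_finetuning_data.py \
#     --fine_tuning_task_type=classification \
#     --classification_task_name=MRPC \
#     --input_data_dir=/tmp/mrpc_data \
#     --vocab_file=/tmp/uncased_vocab.txt \
#     --train_data_output_path=/tmp/mrpc_train.tf_record \
#     --eval_data_output_path=/tmp/mrpc_eval.tf_record \
#     --meta_data_file_path=/tmp/mrpc_meta_data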
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT model input pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def file_based_input_fn_builder(input_file, name_to_features):
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
def _decode_record(record, name_to_features):
"""Decodes a record to a TensorFlow example."""
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def input_fn():
"""Returns dataset for training/evaluation."""
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file)
d = d.map(lambda record: _decode_record(record, name_to_features))
# When `input_file` is a path to a single file or a list
# containing a single path, disable auto sharding so that
# the same input file is sent to all workers.
if isinstance(input_file, str) or len(input_file) == 1:
options = tf.data.Options()
options.experimental_distribute.auto_shard = False
d = d.with_options(options)
return d
return input_fn
def create_pretrain_dataset(file_path,
seq_length,
max_predictions_per_seq,
batch_size,
is_training=True):
"""Creates input dataset from (tf)records files for pretraining."""
name_to_features = {
'input_ids':
tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask':
tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids':
tf.io.FixedLenFeature([seq_length], tf.int64),
'masked_lm_positions':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
'masked_lm_ids':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
'masked_lm_weights':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
'next_sentence_labels':
tf.io.FixedLenFeature([1], tf.int64),
}
input_fn = file_based_input_fn_builder(file_path, name_to_features)
dataset = input_fn()
def _select_data_from_record(record):
"""Filter out features to use for pretraining."""
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids'],
'masked_lm_positions': record['masked_lm_positions'],
'masked_lm_ids': record['masked_lm_ids'],
'masked_lm_weights': record['masked_lm_weights'],
'next_sentence_labels': record['next_sentence_labels'],
}
y = record['masked_lm_weights']
return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(1024)
return dataset
def create_classifier_dataset(file_path,
seq_length,
batch_size,
is_training=True,
drop_remainder=True):
"""Creates input dataset from (tf)records files for train/eval."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64),
'is_real_example': tf.io.FixedLenFeature([], tf.int64),
}
input_fn = file_based_input_fn_builder(file_path, name_to_features)
dataset = input_fn()
def _select_data_from_record(record):
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids']
}
y = record['label_ids']
return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
dataset = dataset.prefetch(1024)
return dataset
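# Illustrative usage sketch (not part of the original module): the returned
# dataset yields `(features, label_ids)` batches ready for Keras training.
# The record path is hypothetical.
def _example_classifier_input():
  dataset = create_classifier_dataset(
      '/tmp/mrpc_train.tf_record',
      seq_length=128,
      batch_size=32,
      is_training=True)
  for features, labels in dataset.take(1):
    print(features['input_word_ids'].shape, labels.shape)  # (32, 128) (32,)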
def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
"""Creates input dataset from (tf)records files for train/eval."""
name_to_features = {
'unique_ids': tf.io.FixedLenFeature([], tf.int64),
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
}
if is_training:
name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
input_fn = file_based_input_fn_builder(file_path, name_to_features)
dataset = input_fn()
def _select_data_from_record(record):
x, y = {}, {}
for name, tensor in record.items():
if name in ('start_positions', 'end_positions'):
y[name] = tensor
else:
x[name] = tensor
return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(1024)
return dataset
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to save models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
import tensorflow as tf
try:
import h5py as _ # pylint: disable=g-import-not-at-top
HAS_H5PY = True
except ImportError:
logging.warning('`h5py` is not installed. Please consider installing it '
'to save weights for long-running training.')
HAS_H5PY = False
def save_model(model, model_dir, weights_file):
"""Saves the model weights."""
weights_file_path = os.path.join(model_dir, weights_file)
del model_dir, weights_file  # avoid accidental usage.
if not HAS_H5PY:
logging.warning('`h5py` is not installed. Skip saving model weights.')
return
logging.info('Saving weights and optimizer states into %s', weights_file_path)
logging.info('This might take a while...')
model.save(weights_file_path, overwrite=True, include_optimizer=True)
def export_bert_model(model_export_path,
model=None,
model_fn=None,
checkpoint_dir=None):
"""Export BERT model for serving.
Arguments:
model_export_path: Path to which exported model will be saved.
model: Keras model object to export. If None, a new model is created via
`model_fn`.
model_fn: Function that returns a BERT model. Used when `model` is not
provided.
checkpoint_dir: Path from which model weights will be loaded.
"""
if model:
tf.keras.experimental.export_saved_model(model, model_export_path)
return
assert model_fn and checkpoint_dir
model_to_export = model_fn()
checkpoint = tf.train.Checkpoint(model=model_to_export)
latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()
tf.keras.experimental.export_saved_model(model_to_export, model_export_path)
class BertModelCheckpoint(tf.keras.callbacks.Callback):
"""Keras callback that saves model at the end of every epoch."""
def __init__(self, checkpoint_dir, checkpoint):
"""Initializes BertModelCheckpoint.
Arguments:
checkpoint_dir: Directory where the checkpoint file will be saved.
checkpoint: tf.train.Checkpoint object.
"""
super(BertModelCheckpoint, self).__init__()
self.checkpoint_file_name = os.path.join(
checkpoint_dir, 'bert_training_checkpoint_step_{global_step}.ckpt')
assert isinstance(checkpoint, tf.train.Checkpoint)
self.checkpoint = checkpoint
def on_epoch_end(self, epoch, logs=None):
global_step = tf.keras.backend.get_value(self.model.optimizer.iterations)
formatted_file_name = self.checkpoint_file_name.format(
global_step=global_step)
saved_path = self.checkpoint.save(formatted_file_name)
logging.info('Saving model TF checkpoint to : %s', saved_path)
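# Illustrative usage sketch (not part of the original module): wiring the
# callback into a standard `model.fit` loop. The checkpoint directory is
# hypothetical, and `model` must already have an optimizer attached.
def _example_checkpoint_callback(model, dataset):
  checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
  callback = BertModelCheckpoint('/tmp/bert_ckpts', checkpoint)
  model.fit(dataset, epochs=3, callbacks=[callback])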
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to train BERT models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
import tensorflow as tf
def get_primary_cpu_task(use_remote_tpu=False):
"""Returns primary CPU task to which input pipeline Ops are put."""
# Remote Eager Borg job configures the TPU worker with job name 'worker'.
return '/job:worker' if use_remote_tpu else ''
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
"""Saves model to with provided checkpoint prefix."""
checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
saved_path = checkpoint.save(checkpoint_path)
logging.info('Saving model as TF checkpoint: %s', saved_path)
return
def run_customized_training_loop(
# pylint: disable=invalid-name
_sentinel=None,
# pylint: enable=invalid-name
strategy=None,
model_fn=None,
loss_fn=None,
model_dir=None,
train_input_fn=None,
steps_per_epoch=None,
epochs=1,
eval_input_fn=None,
eval_steps=None,
metric_fn=None,
init_checkpoint=None,
use_remote_tpu=False):
"""Run BERT pretrain model training using low-level API.
Arguments:
_sentinel: Used to prevent positional parameters. Internal, do not use.
strategy: Distribution strategy on which to run low level training loop.
model_fn: Function that returns a tuple (model, sub_model). Callers should
add an optimizer to `model`, either by calling the `model.compile()` API
or by manually setting the `model.optimizer` attribute. The second element
of the returned tuple (sub_model) is an optional sub model used for
loading the initial checkpoint -- if provided.
loss_fn: Function with signature func(labels, logits) and returns a loss
tensor.
model_dir: Model directory used during training for restoring/saving model
weights.
train_input_fn: Function that returns a tf.data.Dataset used for training.
steps_per_epoch: Number of steps to run per epoch.
epochs: Number of epochs to train.
eval_input_fn: Function that returns evaluation dataset. If none,
evaluation is skipped.
eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
is not none.
metric_fn: A metrics function that returns a Keras Metric object to record
evaluation result using evaluation dataset or with training dataset
after every epoch.
init_checkpoint: Optional checkpoint to load to `sub_model` returned by
`model_fn`.
use_remote_tpu: If true, input pipeline ops are placed in TPU worker host
as an optimization.
Returns:
Trained model.
Raises:
ValueError: (1) When model returned by `model_fn` does not have optimizer
attribute or when required parameters are set to none. (2) eval args are
not specified correctly. (3) metric_fn must be a callable if specified.
"""
if _sentinel is not None:
raise ValueError('only call `run_customized_training_loop()` '
'with named arguments.')
required_arguments = [
strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
]
if [arg for arg in required_arguments if arg is None]:
raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
'`steps_per_epoch` and `train_input_fn` are required parameters.')
assert tf.executing_eagerly()
if eval_input_fn and (eval_steps is None or metric_fn is None):
raise ValueError(
'`eval_steps` and `metric_fn` are required when `eval_input_fn` '
'is not none.')
if metric_fn and not callable(metric_fn):
raise ValueError(
'if `metric_fn` is specified, metric_fn must be a callable.')
# To reduce unnecessary send/receive input pipeline operations, we place
# input pipeline ops on the worker task.
with tf.device(get_primary_cpu_task(use_remote_tpu)):
train_iterator = strategy.make_dataset_iterator(train_input_fn())
with strategy.scope():
total_training_steps = steps_per_epoch * epochs
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model, sub_model = model_fn()
if not hasattr(model, 'optimizer'):
raise ValueError('User should set optimizer attribute to model '
'inside `model_fn`.')
optimizer = model.optimizer
if init_checkpoint:
sub_model.load_weights(init_checkpoint)
metric = metric_fn() if metric_fn else None
# If evaluation is required, make a copy of metric as it will be used by
# both train and evaluation.
train_metric = (
metric.__class__.from_config(metric.get_config())
if metric else None)
@tf.function
def train_step(iterator):
"""Performs a distributed training step."""
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
model_outputs = model(inputs)
loss = loss_fn(labels, model_outputs)
if train_metric:
train_metric.update_state(labels, model_outputs)
tvars = model.trainable_variables
grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(zip(grads, tvars))
return loss
per_replica_losses = strategy.experimental_run(_replicated_step,
iterator)
# For reporting, we return the mean of losses.
loss = strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return loss
@tf.function
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
metric.update_state(labels, model_outputs)
strategy.experimental_run(_test_step_fn, iterator)
def _run_evaluation(current_training_step, test_iterator):
"""Runs validation steps and aggregate metrics."""
for _ in range(eval_steps):
test_step(test_iterator)
logging.info('Step: [%d] Validation metric = %f', current_training_step,
metric.result())
# Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file:
logging.info(
'Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched()
logging.info('Loading from checkpoint file completed')
current_step = optimizer.iterations.numpy()
checkpoint_name = 'ctl_step_{step}.ckpt'
while current_step < total_training_steps:
loss = train_step(train_iterator)
current_step += 1
if train_metric:
logging.info(
'Train Step: %d/%d / loss = %s / training metric = %s',
current_step, total_training_steps, loss.numpy(),
train_metric.result())
else:
logging.info('Train Step: %d/%d / loss = %s', current_step,
total_training_steps, loss.numpy())
# Saves model checkpoints and run validation steps at every epoch end.
if current_step % steps_per_epoch == 0:
# To avoid repeated model saving, we do not save after the last
# step of training.
if current_step < total_training_steps:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn:
logging.info('Running evaluation after step: %s.', current_step)
_run_evaluation(current_step,
strategy.make_dataset_iterator(eval_input_fn()))
# Re-initialize evaluation metric, except the last step.
if metric and current_step < total_training_steps:
metric.reset_states()
train_metric.reset_states()
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn:
logging.info('Running final evaluation after training is complete.')
_run_evaluation(current_step,
strategy.make_dataset_iterator(eval_input_fn()))
return model
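# Illustrative usage sketch (not part of the original module): all arguments
# are keyword-only, `model_fn` must attach an optimizer to the returned
# model, and `model_dir` is hypothetical.
def _example_training_loop(strategy, model_fn, train_input_fn):
  return run_customized_training_loop(
      strategy=strategy,
      model_fn=model_fn,  # returns (model_with_optimizer, sub_model)
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      model_dir='/tmp/bert_model_dir',
      train_input_fn=train_input_fn,
      steps_per_epoch=1000,
      epochs=3)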
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import tensorflow as tf
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applys a warmup schedule on a given learning rate decay schedule."""
def __init__(
self,
initial_learning_rate,
decay_schedule_fn,
warmup_steps,
power=1.0,
name=None):
super(WarmUp, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
self.power = power
self.decay_schedule_fn = decay_schedule_fn
self.name = name
def __call__(self, step):
with tf.name_scope(self.name or 'WarmUp') as name:
# Implements polynomial warmup: if global_step < warmup_steps, the
# learning rate will be `global_step/num_warmup_steps * init_lr`.
global_step_float = tf.cast(step, tf.float32)
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
warmup_percent_done = global_step_float / warmup_steps_float
warmup_learning_rate = (
self.initial_learning_rate *
tf.math.pow(warmup_percent_done, self.power))
return tf.cond(global_step_float < warmup_steps_float,
lambda: warmup_learning_rate,
lambda: self.decay_schedule_fn(step),
name=name)
def get_config(self):
return {
'initial_learning_rate': self.initial_learning_rate,
'decay_schedule_fn': self.decay_schedule_fn,
'warmup_steps': self.warmup_steps,
'power': self.power,
'name': self.name
}
def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
"""Creates an optimizer with learning rate schedule."""
# Implements linear decay of the learning rate.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=init_lr,
decay_steps=num_train_steps,
end_learning_rate=0.0)
if num_warmup_steps:
learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate_fn,
warmup_steps=num_warmup_steps)
optimizer = AdamWeightDecay(
learning_rate=learning_rate_fn,
weight_decay_rate=0.01,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
exclude_from_weight_decay=['layer_norm', 'bias'])
return optimizer
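# Illustrative sketch (not part of the original module): the warmup wrapper
# scales the rate linearly up to `warmup_steps`, then hands off to the decay
# schedule, so the rate peaks at `init_lr` exactly at the warmup boundary.
def _example_learning_rate_schedule():
  decay = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=5e-5, decay_steps=10000, end_learning_rate=0.0)
  schedule = WarmUp(
      initial_learning_rate=5e-5, decay_schedule_fn=decay, warmup_steps=1000)
  # At step 500 the rate is 0.5 * 5e-5; from step 1000 on it follows `decay`.
  return [schedule(step).numpy() for step in (100, 500, 1000, 5000)]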
class AdamWeightDecay(tf.keras.optimizers.Adam):
"""Adam enables L2 weight decay and clip_by_global_norm on gradients.
Just adding the square of the weights to the loss function is *not* the
correct way of using L2 regularization/weight decay with Adam, since that will
interact with the m and v parameters in strange ways.
Instead we want to decay the weights in a manner that doesn't interact with
the m/v parameters. This is equivalent to adding the square of the weights to
the loss with plain (non-momentum) SGD.
"""
def __init__(self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
amsgrad=False,
weight_decay_rate=0.0,
exclude_from_weight_decay=None,
name='AdamWeightDecay',
**kwargs):
super(AdamWeightDecay, self).__init__(
learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
self._set_hyper('weight_decay_rate', weight_decay_rate)
self._exclude_from_weight_decay = exclude_from_weight_decay
@classmethod
def from_config(cls, config):
"""Creates an optimizer from its config with WarmUp custom object."""
custom_objects = {'WarmUp': WarmUp}
return super(AdamWeightDecay, cls).from_config(
config, custom_objects=custom_objects)
def _decay_weights_op(self, var, learning_rate):
do_decay = self._do_use_weight_decay(var.name)
if do_decay:
return var.assign_sub(
learning_rate * var *
self._get_hyper('weight_decay_rate'),
use_locking=self._use_locking)
return tf.no_op()
def apply_gradients(self, grads_and_vars, name=None):
grads, tvars = list(zip(*grads_and_vars))
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
def _resource_apply_dense(self, grad, var):
var_dtype = var.dtype.base_dtype
lr_t = self._decayed_lr(var_dtype)
with tf.control_dependencies([self._decay_weights_op(var, lr_t)]):
return super(AdamWeightDecay, self)._resource_apply_dense(
grad, var)
def _resource_apply_sparse(self, grad, var, indices):
var_dtype = var.dtype.base_dtype
lr_t = self._decayed_lr(var_dtype)
with tf.control_dependencies([self._decay_weights_op(var, lr_t)]):
return super(AdamWeightDecay, self)._resource_apply_sparse(
grad, var, indices)
def get_config(self):
config = super(AdamWeightDecay, self).get_config()
config.update({
'weight_decay_rate':
self._serialize_hyperparameter('weight_decay_rate'),
})
return config
def _do_use_weight_decay(self, param_name):
"""Whether to use L2 weight decay for `param_name`."""
if self.weight_decay_rate == 0:
return False
if self._exclude_from_weight_decay:
for r in self._exclude_from_weight_decay:
if re.search(r, param_name) is not None:
return False
return True
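# A small illustration (hypothetical variable and values) of the decoupled
# weight decay applied in `_decay_weights_op`: before each Adam update, the
# variable shrinks by learning_rate * weight_decay_rate * var, and any
# variable whose name matches a pattern in `exclude_from_weight_decay`
# (e.g. 'layer_norm' or 'bias') is left untouched.
#
#   opt = AdamWeightDecay(learning_rate=1e-3, weight_decay_rate=0.01,
#                         exclude_from_weight_decay=['layer_norm', 'bias'])
#   var = tf.Variable([[1.0, 2.0]], name='dense/kernel')
#   with tf.GradientTape() as tape:
#     loss = tf.reduce_sum(var)
#   grads = tape.gradient(loss, [var])
#   opt.apply_gradients(zip(grads, [var]))  # weight decay, then Adam step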
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT classification finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import math
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.bert import bert_models
from official.bert import input_pipeline
from official.bert import model_saving_utils
from official.bert import model_training_utils
from official.bert import modeling
from official.bert import optimization
flags.DEFINE_enum(
'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
'One of {"train_and_eval", "export_only"}. `train_and_eval`: '
'trains the model and evaluates in the meantime. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('bert_config_file', None,
'Bert configuration file to define core bert layers.')
flags.DEFINE_string(
'model_dir', None,
('The directory where the model weights and training/evaluation summaries '
'are stored. If not specified, save to /tmp/bert20/.'))
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
flags.DEFINE_string('train_data_path', None,
'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
'Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
'init_checkpoint', None,
'Initial checkpoint (usually from a pre-trained BERT model).')
flags.DEFINE_string(
'model_export_path', None,
    'Path to the directory where the trained model will be '
'exported.')
flags.DEFINE_enum(
'strategy_type',
'mirror',
['tpu', 'mirror'],
'Distribution Strategy type to use for training. `tpu` uses '
'TPUStrategy for running on TPUs, `mirror` uses GPUs with '
'single host.')
# Model training specific flags.
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 8, 'Batch size for evaluation.')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer('steps_per_run', 200,
'Number of steps running on TPU devices.')
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
FLAGS = flags.FLAGS
def write_eval_results(output_dir, results):
"""Writes and prints evaluation results.
Args:
output_dir: string, the path to the output directory.
results: a dictionary of evaluation metrics.
"""
output_eval_file = os.path.join(output_dir, 'eval_results.txt')
with tf.io.gfile.GFile(output_eval_file, 'w') as writer:
logging.info('***** Eval results *****')
for key, val in results.items():
logging.info(' %s = %s', key, str(val))
writer.write('%s = %s\n' % (key, str(val)))
def get_loss_fn(num_classes, loss_scale=1.0):
"""Gets the classification loss function."""
def classification_loss_fn(labels, logits):
"""Classification loss."""
labels = tf.squeeze(labels)
log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot_labels = tf.one_hot(
tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
per_example_loss = -tf.reduce_sum(
tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
loss = tf.reduce_mean(per_example_loss)
loss *= loss_scale
return loss
return classification_loss_fn
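# A hand-checkable sketch (hypothetical inputs) of the loss above: for two
# classes with logits [2.0, 0.0] and label 0, log_softmax yields roughly
# [-0.127, -2.127], so the per-example loss is about 0.127.
#
#   loss_fn = get_loss_fn(num_classes=2)
#   loss = loss_fn(labels=tf.constant([[0]]),
#                  logits=tf.constant([[2.0, 0.0]]))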
def run_customized_training(strategy,
bert_config,
input_meta_data,
model_dir,
epochs,
steps_per_epoch,
eval_steps,
warmup_steps,
initial_lr,
init_checkpoint,
use_remote_tpu=False):
"""Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels']
train_input_fn = functools.partial(
input_pipeline.create_classifier_dataset,
FLAGS.train_data_path,
seq_length=max_seq_length,
batch_size=FLAGS.train_batch_size)
eval_input_fn = functools.partial(
input_pipeline.create_classifier_dataset,
FLAGS.eval_data_path,
seq_length=max_seq_length,
batch_size=FLAGS.eval_batch_size,
is_training=False,
drop_remainder=False)
def _get_classifier_model():
classifier_model, core_model = (
bert_models.classifier_model(bert_config, tf.float32, num_classes,
max_seq_length))
classifier_model.optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps)
return classifier_model, core_model
loss_fn = get_loss_fn(num_classes, loss_scale=1.0)
# Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope.
def metric_fn():
return tf.keras.metrics.SparseCategoricalAccuracy(
'test_accuracy', dtype=tf.float32)
return model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_classifier_model,
loss_fn=loss_fn,
model_dir=model_dir,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
eval_steps=eval_steps,
init_checkpoint=init_checkpoint,
metric_fn=metric_fn,
use_remote_tpu=use_remote_tpu)
def export_classifier(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
  Raises:
    ValueError: Export path is not specified, got an empty string or None.
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
def _model_fn():
return bert_models.classifier_model(bert_config, tf.float32,
input_meta_data['num_labels'],
input_meta_data['max_seq_length'])[0]
model_saving_utils.export_bert_model(
model_export_path, model_fn=_model_fn, checkpoint_dir=FLAGS.model_dir)
def run_bert(strategy, input_meta_data):
"""Run BERT training."""
if FLAGS.mode == 'export_only':
export_classifier(FLAGS.model_export_path, input_meta_data)
return
if FLAGS.mode != 'train_and_eval':
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
epochs = FLAGS.num_train_epochs
train_data_size = input_meta_data['train_data_size']
steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
eval_steps = int(
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
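  # For example (hypothetical sizes): with train_data_size=100000,
  # train_batch_size=32 and num_train_epochs=3, steps_per_epoch is 3125 and
  # warmup_steps is int(3 * 100000 * 0.1 / 32) = 937, i.e. warmup covers
  # roughly 10% of all training steps.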
if not strategy:
raise ValueError('Distribution strategy has not been specified.')
# Runs customized training loop.
  logging.info('Training using customized training loop TF 2.0 with '
               'distributed strategy.')
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
trained_model = run_customized_training(
strategy,
bert_config,
input_meta_data,
FLAGS.model_dir,
epochs,
steps_per_epoch,
eval_steps,
warmup_steps,
FLAGS.learning_rate,
FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu)
if FLAGS.model_export_path:
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = None
if FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == 'tpu':
logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else '')
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=FLAGS.tpu)
tf.config.experimental_connect_to_host(cluster_resolver.master()) # pylint: disable=line-too-long
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.experimental.TPUStrategy(
cluster_resolver, steps_per_run=FLAGS.steps_per_run)
run_bert(strategy, input_meta_data)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('input_meta_data_path')
app.run(main)
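# Example invocation (hypothetical script name and data paths):
#   python run_classifier.py \
#     --bert_config_file=/tmp/bert/bert_config.json \
#     --input_meta_data_path=/tmp/classifier_meta_data \
#     --train_data_path=/tmp/train.tf_record \
#     --eval_data_path=/tmp/eval.tf_record \
#     --init_checkpoint=/tmp/bert/bert_model.ckpt \
#     --strategy_type=mirror \
#     --model_dir=/tmp/bert20/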
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run masked LM/next sentence masked_lm pre-training for BERT in tf2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.bert import bert_models
from official.bert import input_pipeline
from official.bert import model_training_utils
from official.bert import modeling
from official.bert import optimization
flags.DEFINE_string('input_files', None,
'File path to retrieve training data for pre-training.')
flags.DEFINE_string('bert_config_file', None,
'Bert configuration file to define core bert layers.')
flags.DEFINE_string(
'model_dir', None,
('The directory where the model weights and training/evaluation summaries '
'are stored. If not specified, save to /tmp/bert20/.'))
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
flags.DEFINE_enum(
'strategy_type',
'mirror',
['tpu', 'mirror'],
'Distribution Strategy type to use for training. `tpu` uses '
'TPUStrategy for running on TPUs, `mirror` uses GPUs with '
'single host.')
# Model training specific flags.
flags.DEFINE_integer(
'max_seq_length', 128,
'The maximum total input sequence length after WordPiece tokenization. '
'Sequences longer than this will be truncated, and sequences shorter '
'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
                     'Maximum number of masked LM predictions per sequence.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer(
'steps_per_run', 1000,
'Number of steps to run in TPU worker before returning to host.')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
'Total number of training steps to run per epoch.')
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
flags.DEFINE_float('warmup_steps', 10000,
'Warmup steps for Adam weight decay optimizer.')
FLAGS = flags.FLAGS
def get_pretrain_input_data(input_file_pattern, seq_length,
max_predictions_per_seq, batch_size):
"""Returns input dataset from input file string."""
input_files = []
for input_pattern in input_file_pattern.split(','):
input_files.extend(tf.io.gfile.glob(input_pattern))
train_dataset = input_pipeline.create_pretrain_dataset(
input_files, seq_length, max_predictions_per_seq, batch_size)
return train_dataset
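# For example (hypothetical paths), `input_file_pattern` may be a
# comma-separated list of globs such as
# '/tmp/tf_examples_a.tfrecord*,/tmp/tf_examples_b.tfrecord*'; each pattern
# is expanded with tf.io.gfile.glob and the matches are concatenated.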
def get_loss_fn(loss_scale=1.0):
"""Returns loss function for BERT pretraining."""
def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
return tf.keras.backend.mean(losses) * loss_scale
return _bert_pretrain_loss_fn
def run_customized_training(strategy,
bert_config,
max_seq_length,
max_predictions_per_seq,
model_dir,
steps_per_epoch,
epochs,
initial_lr,
warmup_steps,
input_files,
train_batch_size,
use_remote_tpu=False):
"""Run BERT pretrain model training using low-level API."""
train_input_fn = functools.partial(get_pretrain_input_data, input_files,
max_seq_length, max_predictions_per_seq,
train_batch_size)
def _get_pretrain_model():
pretrain_model, core_model = bert_models.pretrain_model(
bert_config, max_seq_length, max_predictions_per_seq)
pretrain_model.optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps)
return pretrain_model, core_model
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_pretrain_model,
loss_fn=get_loss_fn(),
model_dir=model_dir,
train_input_fn=train_input_fn,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
use_remote_tpu=use_remote_tpu)
def run_bert_pretrain(strategy):
"""Runs BERT pre-training."""
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
if not strategy:
raise ValueError('Distribution strategy is not specified.')
# Runs customized training loop.
  logging.info('Training using customized training loop TF 2.0 with '
               'distributed strategy.')
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
return run_customized_training(
strategy,
bert_config,
FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq,
FLAGS.model_dir,
FLAGS.num_steps_per_epoch,
FLAGS.num_train_epochs,
FLAGS.learning_rate,
FLAGS.warmup_steps,
FLAGS.input_files,
FLAGS.train_batch_size,
use_remote_tpu=use_remote_tpu)
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = None
if FLAGS.strategy_type == 'tpu':
logging.info('Use TPU at %s',
FLAGS.tpu if FLAGS.tpu is not None else '')
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=FLAGS.tpu)
tf.config.experimental_connect_to_host(cluster_resolver.master()) # pylint: disable=line-too-long
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    strategy = tf.distribute.experimental.TPUStrategy(
        cluster_resolver, steps_per_run=FLAGS.steps_per_run)
elif FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
if strategy:
    logging.info('***** Number of cores used: %d',
                 strategy.num_replicas_in_sync)
run_bert_pretrain(strategy)
if __name__ == '__main__':
app.run(main)
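# Example invocation (hypothetical script name and data paths):
#   python run_pretraining.py \
#     --bert_config_file=/tmp/bert/bert_config.json \
#     --input_files=/tmp/tf_examples.tfrecord* \
#     --max_seq_length=128 \
#     --max_predictions_per_seq=20 \
#     --num_steps_per_epoch=1000 \
#     --num_train_epochs=3 \
#     --warmup_steps=10000 \
#     --strategy_type=mirror \
#     --model_dir=/tmp/bert20/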
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in tf2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.bert import bert_models
from official.bert import input_pipeline
from official.bert import model_training_utils
from official.bert import modeling
from official.bert import optimization
from official.bert import squad_lib
from official.bert import tokenization
flags.DEFINE_bool('do_train', False, 'Whether to run training.')
flags.DEFINE_bool('do_predict', False, 'Whether to run eval on the dev set.')
flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.')
flags.DEFINE_string('bert_config_file', None,
'Bert configuration file to define core bert layers.')
flags.DEFINE_string(
'model_dir', None,
('The directory where the model weights and training/evaluation summaries '
'are stored.'))
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
flags.DEFINE_string(
'init_checkpoint', None,
'Initial checkpoint (usually from a pre-trained BERT model).')
flags.DEFINE_enum(
    'strategy_type',
    'mirror',
    ['tpu', 'mirror', 'multi_worker_mirror'],
    'Distribution Strategy type to use for training. `tpu` uses TPUStrategy '
    'for running on TPUs, `mirror` uses GPUs with a single host, and '
    '`multi_worker_mirror` uses MultiWorkerMirroredStrategy with multiple '
    'hosts.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer('steps_per_run', 200,
'Number of steps running on TPU devices.')
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
                    'Prediction data path, a SQuAD json file such as '
                    'dev-v1.1.json.')
flags.DEFINE_string('vocab_file', None,
'The vocabulary file that the BERT model was trained on.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
'models and False for cased models.')
flags.DEFINE_bool(
'verbose_logging', False,
'If true, all of the warnings related to data processing will be printed. '
'A number of warnings are expected for a normal SQuAD evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
'Total batch size for prediction.')
flags.DEFINE_integer(
'n_best_size', 20,
'The total number of n-best predictions to generate in the '
'nbest_predictions.json output file.')
flags.DEFINE_integer(
'max_answer_length', 30,
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one another.')
FLAGS = flags.FLAGS
def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits,
loss_scale=1.0):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.backend.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
end_loss = tf.keras.backend.sparse_categorical_crossentropy(
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
total_loss *= loss_scale
return total_loss
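# A minimal check (hypothetical logits) of the loss above: with a batch of
# one example over a sequence of length 4, each head contributes a sparse
# categorical crossentropy term and the two terms are averaged.
#
#   loss = squad_loss_fn(
#       start_positions=tf.constant([1]), end_positions=tf.constant([2]),
#       start_logits=tf.constant([[0.0, 5.0, 0.0, 0.0]]),
#       end_logits=tf.constant([[0.0, 0.0, 5.0, 0.0]]))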
def get_loss_fn(loss_scale=1.0):
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
start_positions = labels['start_positions']
end_positions = labels['end_positions']
_, start_logits, end_logits = model_outputs
return squad_loss_fn(
start_positions,
end_positions,
start_logits,
end_logits,
loss_scale=loss_scale)
return _loss_fn
def get_raw_results(predictions):
"""Converts multi-replica predictions to RawResult."""
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
predictions['start_logits'],
predictions['end_logits']):
for values in zip(unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy()):
yield squad_lib.RawResult(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist())
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
"""Make predictions using a Bert-based squad model."""
primary_cpu_task = '/job:worker' if FLAGS.tpu else ''
with tf.device(primary_cpu_task):
predict_dataset = input_pipeline.create_squad_dataset(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = strategy.make_dataset_iterator(predict_dataset)
with strategy.scope():
squad_model, _ = bert_models.squad_model(
bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path)
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
unique_ids, start_logits, end_logits = squad_model(x, training=False)
return dict(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.experimental_run(replicated_step, iterator)
return tf.nest.map_structure(strategy.unwrap, outputs)
all_results = []
for _ in range(num_steps):
predictions = predict_step(predict_iterator)
for result in get_raw_results(predictions):
all_results.append(result)
if len(all_results) % 100 == 0:
logging.info('Made predictions for %d records.', len(all_results))
return all_results
def train_squad(strategy, input_meta_data):
"""Run bert squad training."""
if not strategy:
raise ValueError('Distribution strategy cannot be None.')
logging.info('Training using customized training loop with distribution'
' strategy.')
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = functools.partial(
input_pipeline.create_squad_dataset,
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
def _get_squad_model():
squad_model, core_model = bert_models.squad_model(
bert_config, max_seq_length, float_type=tf.float32)
squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
return squad_model, core_model
# The original BERT model does not scale the loss by
# 1/num_replicas_in_sync. It could be an accident. So, in order to use
# the same hyper parameter, we do the same thing here by keeping each
# replica loss as it is.
loss_fn = get_loss_fn(loss_scale=1.0)
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
loss_fn=loss_fn,
model_dir=FLAGS.model_dir,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for a squad dataset."""
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
# Whether data should be in Ver 2.0 format.
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
eval_writer = squad_lib.FeatureWriter(
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
is_training=False)
eval_features = []
def _append_feature(feature, is_padding):
if not is_padding:
eval_features.append(feature)
eval_writer.process_feature(feature)
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
dataset_size = squad_lib.convert_examples_to_features(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=input_meta_data['max_seq_length'],
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=False,
output_fn=_append_feature,
batch_size=FLAGS.predict_batch_size)
eval_writer.close()
logging.info('***** Running predictions *****')
logging.info(' Num orig examples = %d', len(eval_examples))
logging.info(' Num split examples = %d', len(eval_features))
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
squad_lib.write_predictions(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
verbose=FLAGS.verbose_logging)
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
strategy = None
if FLAGS.strategy_type == 'tpu':
logging.info('Use TPU at %s',
FLAGS.tpu if FLAGS.tpu is not None else '')
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=FLAGS.tpu)
tf.config.experimental_connect_to_host(cluster_resolver.master()) # pylint: disable=line-too-long
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.experimental.TPUStrategy(
cluster_resolver, steps_per_run=FLAGS.steps_per_run)
elif FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == 'multi_worker_mirror':
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
else:
raise ValueError('The distribution strategy type is not supported: %s' %
FLAGS.strategy_type)
if FLAGS.do_train:
train_squad(strategy, input_meta_data)
if FLAGS.do_predict:
predict_squad(strategy, input_meta_data)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('model_dir')
app.run(main)
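# Example invocation (hypothetical script name and data paths):
#   python run_squad.py \
#     --do_train --do_predict \
#     --bert_config_file=/tmp/bert/bert_config.json \
#     --input_meta_data_path=/tmp/squad_meta_data \
#     --train_data_path=/tmp/squad_train.tf_record \
#     --predict_file=/tmp/dev-v1.1.json \
#     --vocab_file=/tmp/bert/vocab.txt \
#     --init_checkpoint=/tmp/bert/bert_model.ckpt \
#     --strategy_type=mirror \
#     --model_dir=/tmp/bert20/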
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tokenization classes implementation.
The file is forked from:
https://github.com/google-research/bert/blob/master/tokenization.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
import tensorflow as tf
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." %
(actual_flag, init_checkpoint, model_name, case_name, opposite_flag))
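# For example, passing
# --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt together with
# --do_lower_case=False raises the ValueError above, since the uncased
# checkpoints expect lowercased input.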
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with tf.io.gfile.GFile(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
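# A short usage sketch (assuming a vocab file containing the needed pieces,
# as in the unit tests below):
#
#   tokenizer = FullTokenizer(vocab_file='vocab.txt', do_lower_case=True)
#   tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
#   # -> ["un", "##want", "##ed", ",", "runn", "##ing"]
#   ids = tokenizer.convert_tokens_to_ids(tokens)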
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
    # like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
      text: A single token or whitespace separated tokens. This should have
        already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
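# For example, tokenizing "unaffable" against a vocabulary containing "un",
# "##aff" and "##able": the greedy scan first matches the longest prefix
# "un", then retries from the remainder with the "##" continuation marker,
# yielding ["un", "##aff", "##able"]; if some position has no match at all,
# the entire input token collapses to `unk_token`.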
def _is_whitespace(char):
  """Checks whether `char` is a whitespace character."""
  # \t, \n, and \r are technically control characters but we treat them
  # as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import six
import tensorflow as tf
from official.bert import tokenization
class TokenizationTest(tf.test.TestCase):
"""Tokenization test.
The implementation is forked from
https://github.com/google-research/bert/blob/master/tokenization_test.py."
"""
def test_full_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", ","
]
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
if six.PY2:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
else:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens
]).encode("utf-8"))
vocab_file = vocab_writer.name
tokenizer = tokenization.FullTokenizer(vocab_file)
os.unlink(vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertAllEqual(
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_chinese(self):
tokenizer = tokenization.BasicTokenizer()
self.assertAllEqual(
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
[u"ah", u"\u535A", u"\u63A8", u"zz"])
def test_basic_tokenizer_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
self.assertAllEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["hello", "!", "how", "are", "you", "?"])
self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
def test_basic_tokenizer_no_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
self.assertAllEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["HeLLo", "!", "how", "Are", "yoU", "?"])
def test_wordpiece_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing"
]
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
self.assertAllEqual(tokenizer.tokenize(""), [])
self.assertAllEqual(
tokenizer.tokenize("unwanted running"),
["un", "##want", "##ed", "runn", "##ing"])
self.assertAllEqual(
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
def test_convert_tokens_to_ids(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing"
]
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
self.assertAllEqual(
tokenization.convert_tokens_to_ids(
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
def test_is_whitespace(self):
self.assertTrue(tokenization._is_whitespace(u" "))
self.assertTrue(tokenization._is_whitespace(u"\t"))
self.assertTrue(tokenization._is_whitespace(u"\r"))
self.assertTrue(tokenization._is_whitespace(u"\n"))
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
self.assertFalse(tokenization._is_whitespace(u"A"))
self.assertFalse(tokenization._is_whitespace(u"-"))
def test_is_control(self):
self.assertTrue(tokenization._is_control(u"\u0005"))
self.assertFalse(tokenization._is_control(u"A"))
self.assertFalse(tokenization._is_control(u" "))
self.assertFalse(tokenization._is_control(u"\t"))
self.assertFalse(tokenization._is_control(u"\r"))
self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
def test_is_punctuation(self):
self.assertTrue(tokenization._is_punctuation(u"-"))
self.assertTrue(tokenization._is_punctuation(u"$"))
self.assertTrue(tokenization._is_punctuation(u"`"))
self.assertTrue(tokenization._is_punctuation(u"."))
self.assertFalse(tokenization._is_punctuation(u"A"))
self.assertFalse(tokenization._is_punctuation(u" "))
if __name__ == "__main__":
tf.test.main()