Commit 15069a36 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Open source text_layers.py

PiperOrigin-RevId: 351309008
parent 9fb74a40
......@@ -105,6 +105,12 @@ pip will install all models and dependencies automatically.
pip install tf-models-official
```
If you are using nlp packages, please also install **tensorflow-text**:
```shell
pip install tensorflow-text
```
Please check out our [example](colab/fine_tuning_bert.ipynb)
to learn how to use a PIP package.
......@@ -143,6 +149,13 @@ os.environ['PYTHONPATH'] += ":/path/to/models"
pip3 install --user -r official/requirements.txt
```
Finally, if you are using nlp packages, please also install
**tensorflow-text-nightly**:
```shell
pip3 install tensorflow-text-nightly
```
## Contributions
If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
......
......@@ -89,3 +89,8 @@ assemble new layers, networks, or models.
[MobileBertTransformer](mobile_bert_layers.py) implement the embedding layer
and also transformer layer proposed in the
[MobileBERT paper](https://arxiv.org/pdf/2004.02984.pdf).
* [BertPackInputs](text_layers.py) and
[BertTokenizer](text_layers.py) and [SentencepieceTokenizer](text_layers.py)
implements the layer to tokenize raw text and pack them into the inputs for
BERT models.
......@@ -32,6 +32,9 @@ from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAtt
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.text_layers import BertPackInputs
from official.nlp.modeling.layers.text_layers import BertTokenizer
from official.nlp.modeling.layers.text_layers import SentencepieceTokenizer
from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
from official.nlp.modeling.layers.transformer import *
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras Layers for BERT-specific preprocessing."""
from typing import Any, Dict, List, Optional, Union
from absl import logging
import tensorflow as tf
try:
import tensorflow_text as text # pylint: disable=g-import-not-at-top
except ImportError:
text = None
def _check_if_tf_text_installed():
if text is None:
raise ImportError("import tensorflow_text failed, please install "
"'tensorflow-text-nightly'.")
def round_robin_truncate_inputs(
inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
limit: Union[int, tf.Tensor],
) -> Union[tf.RaggedTensor, List[tf.RaggedTensor]]:
"""Truncates a list of batched segments to fit a per-example length limit.
Available space is assigned one token at a time in a round-robin fashion
to the inputs that still need some, until the limit is reached.
(Or equivalently: the longest input is truncated by one token until the total
length of inputs fits the limit.) Examples that fit the limit as passed in
remain unchanged.
Args:
inputs: A list of rank-2 RaggedTensors. The i-th example is given by
the i-th row in each list element, that is, `inputs[:][i, :]`.
limit: The largest permissible number of tokens in total across one example.
Returns:
A list of rank-2 RaggedTensors at corresponding indices with the inputs,
in which the rows of each RaggedTensor have been truncated such that
the total number of tokens in each example does not exceed the `limit`.
"""
if not isinstance(inputs, (list, tuple)):
return round_robin_truncate_inputs([inputs], limit)[0]
limit = tf.cast(limit, tf.int64)
if not all(rt.shape.rank == 2 for rt in inputs):
raise ValueError("All inputs must have shape [batch_size, (items)]")
if len(inputs) == 1:
return [_truncate_row_lengths(inputs[0], limit)]
elif len(inputs) == 2:
size_a, size_b = [rt.row_lengths() for rt in inputs]
# Here's a brain-twister: This does round-robin assignment of quota
# to both inputs until the limit is reached. Hint: consider separately
# the cases of zero, one, or two inputs exceeding half the limit.
floor_half = limit // 2
ceil_half = limit - floor_half
quota_a = tf.minimum(size_a, ceil_half + tf.nn.relu(floor_half - size_b))
quota_b = tf.minimum(size_b, floor_half + tf.nn.relu(ceil_half - size_a))
return [_truncate_row_lengths(inputs[0], quota_a),
_truncate_row_lengths(inputs[1], quota_b)]
else:
raise ValueError("Must pass 1 or 2 inputs")
def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
new_lengths: tf.Tensor) -> tf.RaggedTensor:
"""Truncates the rows of `ragged_tensor` to the given row lengths."""
new_lengths = tf.broadcast_to(new_lengths,
ragged_tensor.bounding_shape()[0:1])
def fn(x):
row, new_length = x
return row[0:new_length]
fn_dtype = tf.RaggedTensorSpec(dtype=ragged_tensor.dtype,
ragged_rank=ragged_tensor.ragged_rank - 1)
result = tf.map_fn(fn, (ragged_tensor, new_lengths), dtype=fn_dtype)
# Work around broken shape propagation: without this, result has unknown rank.
flat_values_shape = [None] * ragged_tensor.flat_values.shape.rank
result = result.with_flat_values(
tf.ensure_shape(result.flat_values, flat_values_shape))
return result
class BertTokenizer(tf.keras.layers.Layer):
"""Wraps BertTokenizer with pre-defined vocab as a Keras Layer.
Attributes:
tokenize_with_offsets: If true, calls BertTokenizer.tokenize_with_offsets()
instead of plain .tokenize() and outputs a triple of
(tokens, start_offsets, limit_offsets).
raw_table_access: An object with methods .lookup(keys) and .size()
that operate on the raw lookup table of tokens. It can be used to
look up special token synbols like [MASK].
"""
def __init__(self, *,
vocab_file: str,
lower_case: bool,
tokenize_with_offsets: bool = False,
**kwargs):
"""Initialize a BertTokenizer layer.
Args:
vocab_file: A Python string with the path of the vocabulary file.
This is a text file with newline-separated wordpiece tokens.
This layer initializes a lookup table from it that gets used with
text.BertTokenizer.
lower_case: A Python boolean forwarded to text.BertTokenizer.
If true, input text is converted to lower case (where applicable)
before tokenization. This must be set to match the way in which
the vocab_file was created.
tokenize_with_offsets: A Python boolean. If true, this layer calls
BertTokenizer.tokenize_with_offsets() instead of plain .tokenize()
and outputs a triple of (tokens, start_offsets, limit_offsets)
insead of just tokens.
**kwargs: standard arguments to Layer().
Raises:
ImportError: if importing tensorflow_text failed.
"""
_check_if_tf_text_installed()
self.tokenize_with_offsets = tokenize_with_offsets
self._vocab_table = self._create_vocab_table(vocab_file)
self._special_tokens_dict = self._create_special_tokens_dict(
self._vocab_table, vocab_file)
super().__init__(**kwargs)
self._bert_tokenizer = text.BertTokenizer(
self._vocab_table, lower_case=lower_case)
@property
def vocab_size(self):
return self._vocab_table.size()
def _create_vocab_table(self, vocab_file):
vocab_initializer = tf.lookup.TextFileInitializer(
vocab_file,
key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
return tf.lookup.StaticHashTable(vocab_initializer, default_value=-1)
def call(self, inputs: tf.Tensor):
"""Calls text.BertTokenizer on inputs.
Args:
inputs: A string Tensor of shape [batch_size].
Returns:
One or three of RaggedTensors if tokenize_with_offsets is False or True,
respectively. These are
tokens: A RaggedTensor of shape [batch_size, (words), (pieces_per_word)]
and type int32. tokens[i,j,k] contains the k-th wordpiece of the
j-th word in the i-th input.
start_offsets, limit_offsets: If tokenize_with_offsets is True,
RaggedTensors of type int64 with the same indices as tokens.
Element [i,j,k] contains the byte offset at the start, or past the
end, resp., for the k-th wordpiece of the j-th word in the i-th input.
"""
# Prepare to reshape the result to work around broken shape inference.
batch_size = tf.shape(inputs)[0]
def _reshape(rt):
values = rt.values
row_splits = rt.row_splits
row_splits = tf.reshape(row_splits, [batch_size + 1])
return tf.RaggedTensor.from_row_splits(values, row_splits)
# Call the tokenizer.
if self.tokenize_with_offsets:
tokens, start_offsets, limit_offsets = (
self._bert_tokenizer.tokenize_with_offsets(inputs))
tokens = tf.cast(tokens, dtype=tf.int32)
return _reshape(tokens), _reshape(start_offsets), _reshape(limit_offsets)
else:
tokens = self._bert_tokenizer.tokenize(inputs)
tokens = tf.cast(tokens, dtype=tf.int32)
return _reshape(tokens)
def get_config(self):
# Skip in tf.saved_model.save(); fail if called direcly.
# TODO(arnoegw): Implement when switching to MutableHashTable, which gets
# initialized from the checkpoint and not from a vocab file.
# We cannot just put the original, user-supplied vocab file name into
# the config, because the path has to change as the SavedModel is copied
# around.
raise NotImplementedError("Not implemented yet.")
def get_special_tokens_dict(self):
"""Returns dict of token ids, keyed by standard names for their purpose.
Returns:
A dict from Python strings to Python integers. Each key is a standard
name for a special token describing its use. (For example, "padding_id"
is what BERT traditionally calls "[PAD]" but others may call "<pad>".)
The corresponding value is the integer token id. If a special token
is not found, its entry is omitted from the dict.
The supported keys and tokens are:
* start_of_sequence_id: looked up from "[CLS]"
* end_of_segment_id: looked up from "[SEP]"
* padding_id: looked up form "[PAD]"
* mask_id: looked up from "[MASK]"
"""
return self._special_tokens_dict
def _create_special_tokens_dict(self, vocab_table, vocab_file):
special_tokens = dict(start_of_sequence_id="[CLS]",
end_of_segment_id="[SEP]",
padding_id="[PAD]",
mask_id="[MASK]")
with tf.init_scope():
if tf.executing_eagerly():
special_token_ids = vocab_table.lookup(
tf.constant(list(special_tokens.values()), tf.string))
else:
# A blast from the past: non-eager init context while building Model.
# This can happen with Estimator or tf.compat.v1.disable_v2_behavior().
logging.warning(
"Non-eager init context; computing "
"BertTokenizer's special_tokens_dict in tf.compat.v1.Session")
with tf.Graph().as_default():
local_vocab_table = self._create_vocab_table(vocab_file)
special_token_ids_tensor = local_vocab_table.lookup(
tf.constant(list(special_tokens.values()), tf.string))
init_ops = [tf.compat.v1.initialize_all_tables()]
with tf.compat.v1.Session() as sess:
sess.run(init_ops)
special_token_ids = sess.run(special_token_ids_tensor)
result = dict()
for k, v in zip(special_tokens, special_token_ids):
v = int(v) # Numpy to Python.
if v >= 0:
result[k] = v
else:
logging.warning("Could not find %s as token \"%s\" in vocab file %s",
k, special_tokens[k], vocab_file)
return result
class SentencepieceTokenizer(tf.keras.layers.Layer):
"""Wraps tf_text.SentencepieceTokenizer as a Keras Layer.
Attributes:
tokenize_with_offsets: If true, calls
SentencepieceTokenizer.tokenize_with_offsets()
instead of plain .tokenize() and outputs a triple of
(tokens, start_offsets, limit_offsets).
"""
def __init__(self,
*,
lower_case: bool,
model_file_path: Optional[str] = None,
model_serialized_proto: Optional[str] = None,
tokenize_with_offsets: bool = False,
nbest_size: int = 0,
alpha: float = 1.0,
strip_diacritics: bool = False,
**kwargs):
"""Initializes a SentencepieceTokenizer layer.
Args:
lower_case: A Python boolean indicating whether to lowercase the string
before tokenization. NOTE: New models are encouraged to build `*_cf`
(case folding) normalization into the Sentencepiece model itself and
avoid this extra step.
model_file_path: A Python string with the path of the sentencepiece model.
Exactly one of `model_file_path` and `model_serialized_proto` can be
specified. In either case, the Keras model config for this layer will
store the actual proto (not a filename passed here).
model_serialized_proto: The sentencepiece model serialized proto string.
tokenize_with_offsets: A Python boolean. If true, this layer calls
SentencepieceTokenizer.tokenize_with_offsets() instead of
plain .tokenize() and outputs a triple of
(tokens, start_offsets, limit_offsets) insead of just tokens.
Note that when following `strip_diacritics` is set to True, returning
offsets is not supported now.
nbest_size: A scalar for sampling:
nbest_size = {0,1}: No sampling is performed. (default)
nbest_size > 1: samples from the nbest_size results.
nbest_size < 0: assuming that nbest_size is infinite and samples
from the all hypothesis (lattice) using
forward-filtering-and-backward-sampling algorithm.
alpha: A scalar for a smoothing parameter. Inverse temperature for
probability rescaling.
strip_diacritics: Whether to strip diacritics or not. Note that stripping
diacritics requires additional text normalization and dropping bytes,
which makes it impossible to keep track of the offsets now. Hence
when `strip_diacritics` is set to True, we don't yet support
`tokenize_with_offsets`. NOTE: New models are encouraged to put this
into custom normalization rules for the Sentencepiece model itself to
avoid this extra step and the limitation regarding offsets.
**kwargs: standard arguments to Layer().
Raises:
ImportError: if importing tensorflow_text failed.
"""
_check_if_tf_text_installed()
super().__init__(**kwargs)
if bool(model_file_path) == bool(model_serialized_proto):
raise ValueError("Exact one of `model_file_path` and "
"`model_serialized_proto` can be specified.")
# TODO(chendouble): After b/149576200 is resolved, support
# tokenize_with_offsets when strip_diacritics is True,
if tokenize_with_offsets and strip_diacritics:
raise ValueError("`tokenize_with_offsets` is not supported when "
"`strip_diacritics` is set to True.")
if model_file_path:
self._model_serialized_proto = tf.io.gfile.GFile(model_file_path,
"rb").read()
else:
self._model_serialized_proto = model_serialized_proto
self._lower_case = lower_case
self.tokenize_with_offsets = tokenize_with_offsets
self._nbest_size = nbest_size
self._alpha = alpha
self._strip_diacritics = strip_diacritics
self._tokenizer = self._create_tokenizer()
self._special_tokens_dict = self._create_special_tokens_dict()
def _create_tokenizer(self):
return text.SentencepieceTokenizer(
model=self._model_serialized_proto,
out_type=tf.int32,
nbest_size=self._nbest_size,
alpha=self._alpha)
@property
def vocab_size(self):
return self._tokenizer.vocab_size()
def call(self, inputs: tf.Tensor):
"""Calls text.SentencepieceTokenizer on inputs.
Args:
inputs: A string Tensor of shape [batch_size].
Returns:
One or three of RaggedTensors if tokenize_with_offsets is False or True,
respectively. These are
tokens: A RaggedTensor of shape [batch_size, (pieces)] and type int32.
tokens[i,j] contains the j-th piece in the i-th input.
start_offsets, limit_offsets: If tokenize_with_offsets is True,
RaggedTensors of type int64 with the same indices as tokens.
Element [i,j] contains the byte offset at the start, or past the
end, resp., for the j-th piece in the i-th input.
"""
if self._strip_diacritics:
if self.tokenize_with_offsets:
raise ValueError("`tokenize_with_offsets` is not supported yet due to "
"b/149576200, when `strip_diacritics` is set to True.")
inputs = text.normalize_utf8(inputs, "NFD")
inputs = tf.strings.regex_replace(inputs, r"\p{Mn}", "")
if self._lower_case:
inputs = text.case_fold_utf8(inputs)
# Prepare to reshape the result to work around broken shape inference.
batch_size = tf.shape(inputs)[0]
def _reshape(rt):
values = rt.values
row_splits = rt.row_splits
row_splits = tf.reshape(row_splits, [batch_size + 1])
return tf.RaggedTensor.from_row_splits(values, row_splits)
# Call the tokenizer.
if self.tokenize_with_offsets:
tokens, start_offsets, limit_offsets = (
self._tokenizer.tokenize_with_offsets(inputs))
return _reshape(tokens), _reshape(start_offsets), _reshape(limit_offsets)
else:
tokens = self._tokenizer.tokenize(inputs)
return _reshape(tokens)
def get_config(self):
raise NotImplementedError("b/170480226")
# TODO(b/170480226): Uncomment and improve to fix the bug.
# config = {
# "model_serialized_proto": self._model_serialized_proto,
# "lower_case": self._lower_case,
# "tokenize_with_offsets": self.tokenize_with_offsets,
# "nbest_size": self._nbest_size,
# "alpha": self._alpha,
# "strip_diacritics": self._strip_diacritics,
# }
# base_config = super(SentencepieceTokenizer, self).get_config()
# base_config.update(config)
# return base_config
def get_special_tokens_dict(self):
"""Returns dict of token ids, keyed by standard names for their purpose.
Returns:
A dict from Python strings to Python integers. Each key is a standard
name for a special token describing its use. (For example, "padding_id"
is what Sentencepiece calls "<pad>" but others may call "[PAD]".)
The corresponding value is the integer token id. If a special token
is not found, its entry is omitted from the dict.
The supported keys and tokens are:
* start_of_sequence_id: looked up from "[CLS]"
* end_of_segment_id: looked up from "[SEP]"
* padding_id: looked up from "<pad>"
* mask_id: looked up from "[MASK]"
"""
return self._special_tokens_dict
def _create_special_tokens_dict(self):
special_tokens = dict(
start_of_sequence_id=b"[CLS]",
end_of_segment_id=b"[SEP]",
padding_id=b"<pad>",
mask_id=b"[MASK]")
with tf.init_scope():
if tf.executing_eagerly():
special_token_ids = self._tokenizer.string_to_id(
tf.constant(list(special_tokens.values()), tf.string))
inverse_tokens = self._tokenizer.id_to_string(special_token_ids)
else:
# A blast from the past: non-eager init context while building Model.
# This can happen with Estimator or tf.compat.v1.disable_v2_behavior().
logging.warning(
"Non-eager init context; computing SentencepieceTokenizer's "
"special_tokens_dict in tf.compat.v1.Session")
with tf.Graph().as_default():
local_tokenizer = self._create_tokenizer()
special_token_ids_tensor = local_tokenizer.string_to_id(
tf.constant(list(special_tokens.values()), tf.string))
inverse_tokens_tensor = local_tokenizer.id_to_string(
special_token_ids_tensor)
with tf.compat.v1.Session() as sess:
special_token_ids, inverse_tokens = sess.run(
[special_token_ids_tensor, inverse_tokens_tensor])
result = dict()
for name, token_id, inverse_token in zip(special_tokens,
special_token_ids,
inverse_tokens):
if special_tokens[name] == inverse_token:
result[name] = int(token_id) # Numpy to Python.
else:
logging.warning(
"Could not find %s as token \"%s\" in sentencepiece model, "
"got \"%s\"", name, special_tokens[name], inverse_token)
return result
class BertPackInputs(tf.keras.layers.Layer):
"""Packs tokens into model inputs for BERT."""
def __init__(self,
seq_length,
*,
start_of_sequence_id=None,
end_of_segment_id=None,
padding_id=None,
special_tokens_dict=None,
truncator="round_robin",
**kwargs):
"""Initializes with a target seq_length, relevant token ids and truncator.
Args:
seq_length: The desired output length. Must not exceed the max_seq_length
that was fixed at training time for the BERT model receiving the inputs.
start_of_sequence_id: The numeric id of the token that is to be placed
at the start of each sequence (called "[CLS]" for BERT).
end_of_segment_id: The numeric id of the token that is to be placed
at the end of each input segment (called "[SEP]" for BERT).
padding_id: The numeric id of the token that is to be placed into the
unused positions after the last segment in the sequence
(called "[PAD]" for BERT).
special_tokens_dict: Optionally, a dict from Python strings to Python
integers that contains values for start_of_sequence_id,
end_of_segment_id and padding_id. (Further values in the dict are
silenty ignored.) If this is passed, separate *_id arguments must be
omitted.
truncator: The algorithm to truncate a list of batched segments to fit a
per-example length limit. The value can be either "round_robin" or
"waterfall":
(1) For "round_robin" algorithm, available space is assigned
one token at a time in a round-robin fashion to the inputs that still
need some, until the limit is reached. It currently only supports
one or two segments.
(2) For "waterfall" algorithm, the allocation of the budget is done
using a "waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run out of
budget. It support arbitrary number of segments.
**kwargs: standard arguments to Layer().
Raises:
ImportError: if importing tensorflow_text failed.
"""
_check_if_tf_text_installed()
super().__init__(**kwargs)
self.seq_length = seq_length
if truncator not in ("round_robin", "waterfall"):
raise ValueError("Only 'round_robin' and 'waterfall' algorithms are "
"supported, but got %s" % truncator)
self.truncator = truncator
self._init_token_ids(
start_of_sequence_id=start_of_sequence_id,
end_of_segment_id=end_of_segment_id,
padding_id=padding_id,
special_tokens_dict=special_tokens_dict)
def _init_token_ids(
self, *,
start_of_sequence_id,
end_of_segment_id,
padding_id,
special_tokens_dict):
usage = ("Must pass either all of start_of_sequence_id, end_of_segment_id, "
"padding_id as arguments, or else a special_tokens_dict "
"with those keys.")
special_tokens_args = [start_of_sequence_id, end_of_segment_id, padding_id]
if special_tokens_dict is None:
if any(x is None for x in special_tokens_args):
return ValueError(usage)
self.start_of_sequence_id = int(start_of_sequence_id)
self.end_of_segment_id = int(end_of_segment_id)
self.padding_id = int(padding_id)
else:
if any(x is not None for x in special_tokens_args):
return ValueError(usage)
self.start_of_sequence_id = int(
special_tokens_dict["start_of_sequence_id"])
self.end_of_segment_id = int(special_tokens_dict["end_of_segment_id"])
self.padding_id = int(special_tokens_dict["padding_id"])
def get_config(self) -> Dict[str, Any]:
config = super().get_config()
config["seq_length"] = self.seq_length
config["start_of_sequence_id"] = self.start_of_sequence_id
config["end_of_segment_id"] = self.end_of_segment_id
config["padding_id"] = self.padding_id
config["truncator"] = self.truncator
return config
def call(self, inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]]):
"""Adds special tokens to pack a list of segments into BERT input Tensors.
Args:
inputs: A Python list of one or two RaggedTensors, each with the batched
values one input segment. The j-th segment of the i-th input example
consists of slice `inputs[j][i, ...]`.
Returns:
A nest of Tensors for use as input to the BERT TransformerEncoder.
"""
# BertPackInputsSavedModelWrapper relies on only calling bert_pack_inputs()
return BertPackInputs.bert_pack_inputs(
inputs, self.seq_length,
start_of_sequence_id=self.start_of_sequence_id,
end_of_segment_id=self.end_of_segment_id,
padding_id=self.padding_id,
truncator=self.truncator)
@staticmethod
def bert_pack_inputs(inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
seq_length: Union[int, tf.Tensor],
start_of_sequence_id: Union[int, tf.Tensor],
end_of_segment_id: Union[int, tf.Tensor],
padding_id: Union[int, tf.Tensor],
truncator="round_robin"):
"""Freestanding equivalent of the BertPackInputs layer."""
_check_if_tf_text_installed()
# Sanitize inputs.
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
if not inputs:
raise ValueError("At least one input is required for packing")
input_ranks = [rt.shape.rank for rt in inputs]
if None in input_ranks or len(set(input_ranks)) > 1:
raise ValueError("All inputs for packing must have the same known rank, "
"found ranks " + ",".join(input_ranks))
# Flatten inputs to [batch_size, (tokens)].
if input_ranks[0] > 2:
inputs = [rt.merge_dims(1, -1) for rt in inputs]
# In case inputs weren't truncated (as they should have been),
# fall back to some ad-hoc truncation.
num_special_tokens = len(inputs) + 1
if truncator == "round_robin":
trimmed_segments = round_robin_truncate_inputs(
inputs, seq_length - num_special_tokens)
elif truncator == "waterfall":
trimmed_segments = text.WaterfallTrimmer(
seq_length - num_special_tokens).trim(inputs)
else:
raise ValueError("Unsupported truncator: %s" % truncator)
# Combine segments.
segments_combined, segment_ids = text.combine_segments(
trimmed_segments,
start_of_sequence_id=start_of_sequence_id,
end_of_segment_id=end_of_segment_id)
# Pad to dense Tensors.
input_word_ids, _ = text.pad_model_inputs(segments_combined, seq_length,
pad_value=padding_id)
input_type_ids, input_mask = text.pad_model_inputs(segment_ids, seq_length,
pad_value=0)
# Work around broken shape inference.
output_shape = tf.stack([
inputs[0].nrows(out_type=tf.int32), # batch_size
tf.cast(seq_length, dtype=tf.int32)])
def _reshape(t):
return tf.reshape(t, output_shape)
# Assemble nest of input tensors as expected by BERT TransformerEncoder.
return dict(input_word_ids=_reshape(input_word_ids),
input_mask=_reshape(input_mask),
input_type_ids=_reshape(input_type_ids))
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests bert.text_layers."""
import os
import tempfile
import numpy as np
import tensorflow as tf
from sentencepiece import SentencePieceTrainer
from official.nlp.modeling.layers import text_layers
class RoundRobinTruncatorTest(tf.test.TestCase):
def test_correct_outputs(self):
def test_input(start, lengths):
return tf.ragged.constant([[start + 10*j + i for i in range(length)]
for j, length in enumerate(lengths)],
dtype=tf.int32)
# Single segment.
single_input = test_input(11, [4, 5, 6])
expected_single_output = tf.ragged.constant(
[[11, 12, 13, 14],
[21, 22, 23, 24, 25],
[31, 32, 33, 34, 35], # Truncated.
])
self.assertAllEqual(
expected_single_output,
text_layers.round_robin_truncate_inputs(single_input, limit=5))
# Test wrapping in a singleton list.
actual_single_list_output = text_layers.round_robin_truncate_inputs(
[single_input], limit=5)
self.assertIsInstance(actual_single_list_output, list)
self.assertAllEqual(expected_single_output, actual_single_list_output[0])
# Two segments.
input_a = test_input(111, [1, 2, 2, 3, 4, 5])
input_b = test_input(211, [1, 3, 4, 2, 2, 5])
expected_a = tf.ragged.constant(
[[111],
[121, 122],
[131, 132],
[141, 142, 143],
[151, 152, 153], # Truncated.
[161, 162, 163], # Truncated.
])
expected_b = tf.ragged.constant(
[[211],
[221, 222, 223],
[231, 232, 233], # Truncated.
[241, 242],
[251, 252],
[261, 262], # Truncated.
])
actual_a, actual_b = text_layers.round_robin_truncate_inputs(
[input_a, input_b], limit=5)
self.assertAllEqual(expected_a, actual_a)
self.assertAllEqual(expected_b, actual_b)
# This test covers the in-process behavior of a BertTokenizer layer.
# For saving, restoring, and the restored behavior (incl. shape inference),
# see export_tfub_test.py.
class BertTokenizerTest(tf.test.TestCase):
def _make_vocab_file(self, vocab, filename="vocab.txt"):
path = os.path.join(
tempfile.mkdtemp(dir=self.get_temp_dir()), # New subdir each time.
filename)
with tf.io.gfile.GFile(path, "w") as f:
f.write("\n".join(vocab + [""]))
return path
def test_uncased(self):
vocab_file = self._make_vocab_file(
["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
bert_tokenize = text_layers.BertTokenizer(
vocab_file=vocab_file, lower_case=True)
inputs = tf.constant(["abc def", "ABC DEF d"])
token_ids = bert_tokenize(inputs)
self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
[[6], [4, 5], [4]]]))
bert_tokenize.tokenize_with_offsets = True
token_ids_2, start_offsets, limit_offsets = bert_tokenize(inputs)
self.assertAllEqual(token_ids, token_ids_2)
self.assertAllEqual(start_offsets, tf.ragged.constant([[[0], [4, 5]],
[[0], [4, 5], [8]]]))
self.assertAllEqual(limit_offsets, tf.ragged.constant([[[3], [5, 7]],
[[3], [5, 7], [9]]]))
self.assertEqual(bert_tokenize.vocab_size.numpy(), 8)
# Repeat the above and test that case matters with lower_case=False.
def test_cased(self):
vocab_file = self._make_vocab_file(
["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "ABC"])
bert_tokenize = text_layers.BertTokenizer(
vocab_file=vocab_file, lower_case=False, tokenize_with_offsets=True)
inputs = tf.constant(["abc def", "ABC DEF"])
token_ids, start_offsets, limit_offsets = bert_tokenize(inputs)
self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
[[7], [1]]]))
self.assertAllEqual(start_offsets, tf.ragged.constant([[[0], [4, 5]],
[[0], [4]]]))
self.assertAllEqual(limit_offsets, tf.ragged.constant([[[3], [5, 7]],
[[3], [7]]]))
def test_special_tokens_complete(self):
vocab_file = self._make_vocab_file(
["foo", "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "xy"])
bert_tokenize = text_layers.BertTokenizer(
vocab_file=vocab_file, lower_case=True)
self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
dict(padding_id=1,
start_of_sequence_id=3,
end_of_segment_id=4,
mask_id=5))
def test_special_tokens_partial(self):
vocab_file = self._make_vocab_file(
["[PAD]", "[CLS]", "[SEP]"])
bert_tokenize = text_layers.BertTokenizer(
vocab_file=vocab_file, lower_case=True)
self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
dict(padding_id=0,
start_of_sequence_id=1,
end_of_segment_id=2)) # No mask_id,
def test_special_tokens_in_estimator(self):
"""Tests getting special tokens without an Eager init context."""
vocab_file = self._make_vocab_file(
["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
def input_fn():
with tf.init_scope():
self.assertFalse(tf.executing_eagerly())
# Build a preprocessing Model.
sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
bert_tokenizer = text_layers.BertTokenizer(
vocab_file=vocab_file, lower_case=True)
special_tokens_dict = bert_tokenizer.get_special_tokens_dict()
for k, v in special_tokens_dict.items():
self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
tokens = bert_tokenizer(sentences)
packed_inputs = text_layers.BertPackInputs(
4, special_tokens_dict=special_tokens_dict)(tokens)
preprocessing = tf.keras.Model(sentences, packed_inputs)
# Map the dataset.
ds = tf.data.Dataset.from_tensors(
(tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
ds = ds.map(lambda features, labels: (preprocessing(features), labels))
return ds
def model_fn(features, labels, mode):
del labels # Unused.
return tf.estimator.EstimatorSpec(mode=mode,
predictions=features["input_word_ids"])
estimator = tf.estimator.Estimator(model_fn=model_fn)
outputs = list(estimator.predict(input_fn))
self.assertAllEqual(outputs, np.array([[2, 6, 3, 0],
[2, 4, 5, 3]]))
# This test covers the in-process behavior of a SentencepieceTokenizer layer.
class SentencepieceTokenizerTest(tf.test.TestCase):
def setUp(self):
super().setUp()
# Make a sentencepiece model.
tmp_dir = self.get_temp_dir()
tempfile.mkdtemp(dir=tmp_dir)
vocab = ["a", "b", "c", "d", "e", "abc", "def", "ABC", "DEF"]
model_prefix = os.path.join(tmp_dir, "spm_model")
input_text_file_path = os.path.join(tmp_dir, "train_input.txt")
with tf.io.gfile.GFile(input_text_file_path, "w") as f:
f.write(" ".join(vocab + ["\n"]))
# Add 7 more tokens: <pad>, <unk>, [CLS], [SEP], [MASK], <s>, </s>.
full_vocab_size = len(vocab) + 7
flags = dict(
model_prefix=model_prefix,
model_type="word",
input=input_text_file_path,
pad_id=0, unk_id=1, control_symbols="[CLS],[SEP],[MASK]",
vocab_size=full_vocab_size,
bos_id=full_vocab_size-2, eos_id=full_vocab_size-1)
SentencePieceTrainer.Train(
" ".join(["--{}={}".format(k, v) for k, v in flags.items()]))
self._spm_path = model_prefix + ".model"
def test_uncased(self):
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path, lower_case=True, nbest_size=0)
inputs = tf.constant(["abc def", "ABC DEF d"])
token_ids = sentencepiece_tokenizer(inputs)
self.assertAllEqual(
token_ids,
tf.ragged.constant([[8, 12], [8, 12, 11]]))
sentencepiece_tokenizer.tokenize_with_offsets = True
token_ids_2, start_offsets, limit_offsets = sentencepiece_tokenizer(inputs)
self.assertAllEqual(token_ids, token_ids_2)
self.assertAllEqual(
start_offsets, tf.ragged.constant([[0, 3], [0, 3, 7]]))
self.assertAllEqual(
limit_offsets, tf.ragged.constant([[3, 7], [3, 7, 9]]))
self.assertEqual(sentencepiece_tokenizer.vocab_size.numpy(), 16)
# Repeat the above and test that case matters with lower_case=False.
def test_cased(self):
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path,
lower_case=False,
nbest_size=0,
tokenize_with_offsets=False)
inputs = tf.constant(["abc def", "ABC DEF d"])
token_ids = sentencepiece_tokenizer(inputs)
self.assertAllEqual(
token_ids,
tf.ragged.constant([[8, 12], [5, 6, 11]]))
sentencepiece_tokenizer.tokenize_with_offsets = True
token_ids_2, start_offsets, limit_offsets = sentencepiece_tokenizer(inputs)
self.assertAllEqual(token_ids, token_ids_2)
self.assertAllEqual(
start_offsets,
tf.ragged.constant([[0, 3], [0, 3, 7]]))
self.assertAllEqual(
limit_offsets,
tf.ragged.constant([[3, 7], [3, 7, 9]]))
def test_special_tokens(self):
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path, lower_case=True, nbest_size=0)
self.assertDictEqual(sentencepiece_tokenizer.get_special_tokens_dict(),
dict(padding_id=0,
start_of_sequence_id=2,
end_of_segment_id=3,
mask_id=4))
def test_special_tokens_in_estimator(self):
"""Tests getting special tokens without an Eager init context."""
def input_fn():
with tf.init_scope():
self.assertFalse(tf.executing_eagerly())
# Build a preprocessing Model.
sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path, lower_case=True, nbest_size=0)
special_tokens_dict = sentencepiece_tokenizer.get_special_tokens_dict()
for k, v in special_tokens_dict.items():
self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
tokens = sentencepiece_tokenizer(sentences)
packed_inputs = text_layers.BertPackInputs(
4, special_tokens_dict=special_tokens_dict)(tokens)
preprocessing = tf.keras.Model(sentences, packed_inputs)
# Map the dataset.
ds = tf.data.Dataset.from_tensors(
(tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
ds = ds.map(lambda features, labels: (preprocessing(features), labels))
return ds
def model_fn(features, labels, mode):
del labels # Unused.
return tf.estimator.EstimatorSpec(mode=mode,
predictions=features["input_word_ids"])
estimator = tf.estimator.Estimator(model_fn=model_fn)
outputs = list(estimator.predict(input_fn))
self.assertAllEqual(outputs, np.array([[2, 8, 3, 0],
[2, 12, 3, 0]]))
def test_strip_diacritics(self):
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path,
lower_case=True,
nbest_size=0,
strip_diacritics=True)
inputs = tf.constant(["a b c d e", "ă ḅ č ḓ é"])
token_ids = sentencepiece_tokenizer(inputs)
self.assertAllEqual(
token_ids,
tf.ragged.constant([[7, 9, 10, 11, 13], [7, 9, 10, 11, 13]]))
def test_fail_on_tokenize_with_offsets_and_strip_diacritics(self):
# Raise an error in init().
with self.assertRaises(ValueError):
text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path,
tokenize_with_offsets=True,
lower_case=True,
nbest_size=0,
strip_diacritics=True)
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path,
lower_case=True,
nbest_size=0,
strip_diacritics=True)
sentencepiece_tokenizer.tokenize_with_offsets = True
# Raise an error in call():
inputs = tf.constant(["abc def", "ABC DEF d", "Äffin"])
with self.assertRaises(ValueError):
sentencepiece_tokenizer(inputs)
def test_serialize_deserialize(self):
self.skipTest("b/170480226")
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path,
lower_case=False,
nbest_size=0,
tokenize_with_offsets=False,
name="sentencepiece_tokenizer_layer")
config = sentencepiece_tokenizer.get_config()
new_tokenizer = text_layers.SentencepieceTokenizer.from_config(config)
self.assertEqual(config, new_tokenizer.get_config())
inputs = tf.constant(["abc def", "ABC DEF d"])
token_ids = sentencepiece_tokenizer(inputs)
token_ids_2 = new_tokenizer(inputs)
self.assertAllEqual(token_ids, token_ids_2)
# TODO(b/170480226): Remove once tf_hub_export_lib_test.py covers saving.
def test_saving(self):
sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
model_file_path=self._spm_path, lower_case=True, nbest_size=0)
inputs = tf.keras.layers.Input([], dtype=tf.string)
outputs = sentencepiece_tokenizer(inputs)
model = tf.keras.Model(inputs, outputs)
export_path = tempfile.mkdtemp(dir=self.get_temp_dir())
model.save(export_path, signatures={})
class BertPackInputsTest(tf.test.TestCase):
def test_round_robin_correct_outputs(self):
bpi = text_layers.BertPackInputs(
10,
start_of_sequence_id=1001,
end_of_segment_id=1002,
padding_id=999,
truncator="round_robin")
# Single input, rank 2.
bert_inputs = bpi(
tf.ragged.constant([[11, 12, 13],
[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
[1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
self.assertAllEqual(
bert_inputs["input_mask"],
tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
self.assertAllEqual(
bert_inputs["input_type_ids"],
tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
# Two inputs, rank 3. Truncation does not respect word boundaries.
bert_inputs = bpi([
tf.ragged.constant([[[111], [112, 113]],
[[121, 122, 123], [124, 125, 126], [127, 128]]]),
tf.ragged.constant([[[211, 212], [213]],
[[221, 222], [223, 224, 225], [226, 227, 228]]])
])
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 213, 1002, 999],
[1001, 121, 122, 123, 124, 1002, 221, 222, 223, 1002]]))
self.assertAllEqual(
bert_inputs["input_mask"],
tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
self.assertAllEqual(
bert_inputs["input_type_ids"],
tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]))
# Three inputs has not been supported for round_robin so far.
with self.assertRaisesRegex(ValueError, "Must pass 1 or 2 inputs"):
bert_inputs = bpi([
tf.ragged.constant([[[111], [112, 113]],
[[121, 122, 123], [124, 125, 126], [127, 128]]]),
tf.ragged.constant([[[211, 212], [213]],
[[221, 222], [223, 224, 225], [226, 227, 228]]]),
tf.ragged.constant([[[311, 312], [313]],
[[321, 322], [323, 324, 325], [326, 327, 328]]])
])
def test_waterfall_correct_outputs(self):
bpi = text_layers.BertPackInputs(
10,
start_of_sequence_id=1001,
end_of_segment_id=1002,
padding_id=999,
truncator="waterfall")
# Single input, rank 2.
bert_inputs = bpi(
tf.ragged.constant([[11, 12, 13],
[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
[1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
self.assertAllEqual(
bert_inputs["input_mask"],
tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
self.assertAllEqual(
bert_inputs["input_type_ids"],
tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
# Two inputs, rank 3. Truncation does not respect word boundaries.
bert_inputs = bpi([
tf.ragged.constant([[[111], [112, 113]],
[[121, 122, 123], [124, 125, 126], [127, 128]]]),
tf.ragged.constant([[[211, 212], [213]],
[[221, 222], [223, 224, 225], [226, 227, 228]]])
])
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 213, 1002, 999],
[1001, 121, 122, 123, 124, 125, 126, 127, 1002, 1002]]))
self.assertAllEqual(
bert_inputs["input_mask"],
tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
self.assertAllEqual(
bert_inputs["input_type_ids"],
tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]))
# Three inputs, rank 3. Truncation does not respect word boundaries.
bert_inputs = bpi([
tf.ragged.constant([[[111], [112, 113]],
[[121, 122, 123], [124, 125, 126], [127, 128]]]),
tf.ragged.constant([[[211], [212]],
[[221, 222], [223, 224, 225], [226, 227, 228]]]),
tf.ragged.constant([[[311, 312], [313]],
[[321, 322], [323, 324, 325], [326, 327]]])
])
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 1002, 311, 1002],
[1001, 121, 122, 123, 124, 125, 126, 1002, 1002, 1002]]))
self.assertAllEqual(
bert_inputs["input_mask"],
tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
self.assertAllEqual(
bert_inputs["input_type_ids"],
tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 2, 2],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 2]]))
def test_special_tokens_dict(self):
special_tokens_dict = dict(start_of_sequence_id=1001,
end_of_segment_id=1002,
padding_id=999,
extraneous_key=666)
bpi = text_layers.BertPackInputs(10,
special_tokens_dict=special_tokens_dict)
bert_inputs = bpi(
tf.ragged.constant([[11, 12, 13],
[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
[1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment