Unverified Commit 135791e8 authored by Funtowicz Morgan, committed by GitHub

Add pad_to_multiple_of on tokenizers (reimport) (#5054)



* Add new parameter `pad_to_multiple_of` on tokenizers.

* Add a unit test for pad_to_multiple_of.

* Add .name when logging enum.

* Fix missing .items() on dict in tests.

* Add a special check and warning if the tokenizer doesn't have a proper pad_token.

* Use the correct logger format specifier.

* Ensure tokenizers with no pad_token do not modify the underlying padding strategy.

* Skip test if tokenizer doesn't have pad_token

* Fix RobertaTokenizer on empty input

* Format.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix and update to a simpler API
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
parent 7cc15bdd
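For context, a minimal usage sketch of the new parameter (not part of the commit itself; it assumes a checkpoint with a pad token, e.g. `bert-base-uncased`, and a transformers version that includes this change):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Pad the encoded sequence up to the next multiple of 8 so tensor shapes
# line up with what NVIDIA Tensor Cores prefer.
encoded = tokenizer("This is a sample input", padding=True, pad_to_multiple_of=8)

assert len(encoded["input_ids"]) % 8 == 0
assert len(encoded["attention_mask"]) % 8 == 0
```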
......@@ -244,7 +244,7 @@ class RobertaTokenizer(GPT2Tokenizer):
def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
if (is_pretokenized or add_prefix_space) and text:
if (is_pretokenized or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
text = " " + text
return (text, kwargs)
......
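The RobertaTokenizer change above only prepends a prefix space when the text is non-empty and does not already start with whitespace, which is what fixes the empty-input case. A standalone sketch of that guard (a hypothetical helper, not the actual method):

```python
def maybe_add_prefix_space(text: str, add_prefix_space: bool = True) -> str:
    # Mirrors the new condition: skip empty strings and strings that
    # already begin with whitespace, so empty input passes through unchanged.
    if add_prefix_space and (len(text) > 0 and not text[0].isspace()):
        text = " " + text
    return text

assert maybe_add_prefix_space("") == ""
assert maybe_add_prefix_space("hello") == " hello"
assert maybe_add_prefix_space(" hello") == " hello"
```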
......@@ -409,6 +409,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
max_length: Optional[int] = None,
stride: int = 0,
is_pretokenized: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
......@@ -461,6 +462,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy=truncation_strategy,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
prepend_batch_axis=True,
return_attention_mask=return_attention_mask,
......@@ -487,6 +489,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
max_length: Optional[int] = None,
stride: int = 0,
is_pretokenized: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
......@@ -541,6 +544,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy=truncation_strategy,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
......@@ -561,6 +565,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
......@@ -587,6 +592,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy=truncation_strategy,
max_length=max_length,
stride=stride,
pad_to_multiple_of=None, # we pad in batch afterward
return_attention_mask=False, # we pad in batch afterward
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
......@@ -606,6 +612,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
batch_outputs,
padding=padding_strategy.value,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
......@@ -623,6 +630,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[str] = None,
prepend_batch_axis: bool = False,
return_token_type_ids: Optional[bool] = None,
......@@ -654,8 +662,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
encoded_inputs = {}
# Truncation: Handle max sequence length
# Compute the total size of the returned encodings
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
# Truncation: Handle max sequence length
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
ids,
......@@ -700,6 +710,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
encoded_inputs,
max_length=max_length,
padding=padding_strategy.value,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
......
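In the slow (pure Python) path above, each example is first encoded without padding (`pad_to_multiple_of=None`, "we pad in batch afterward") and the whole batch is padded in one pass. A simplified sketch of that batch padding step, right-padding only (not the library code):

```python
from typing import List, Optional

def pad_batch(sequences: List[List[int]], pad_id: int,
              pad_to_multiple_of: Optional[int] = None) -> List[List[int]]:
    # Pad every sequence to the longest one in the batch, optionally
    # rounding the target length up to a multiple of pad_to_multiple_of.
    target = max(len(seq) for seq in sequences)
    if pad_to_multiple_of is not None and target % pad_to_multiple_of != 0:
        target = ((target // pad_to_multiple_of) + 1) * pad_to_multiple_of
    return [seq + [pad_id] * (target - len(seq)) for seq in sequences]

print(pad_batch([[101, 2023, 102], [101, 102]], pad_id=0, pad_to_multiple_of=8))
# [[101, 2023, 102, 0, 0, 0, 0, 0], [101, 102, 0, 0, 0, 0, 0, 0]]
```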
......@@ -960,6 +960,9 @@ ENCODE_KWARGS_DOCSTRING = r"""
The value of this argument defines the number of overlapping tokens.
is_pretokenized (:obj:`bool`, defaults to :obj:`False`):
Set to True to indicate the input is already tokenized.
pad_to_multiple_of: (optional) Integer; if set, will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
>= 7.0 (Volta).
return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
PyTorch :obj:`torch.Tensor` or Numpy :obj:`np.ndarray` instead of a list of python integers.
......@@ -1427,7 +1430,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
raise NotImplementedError
def _get_padding_truncation_strategies(
self, padding=False, truncation=False, max_length=None, verbose=True, **kwargs
self, padding=False, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
):
""" Find the correct padding/truncation strategy with backward compatibility
for old arguments (truncation_strategy and pad_to_max_length) and behaviors.
......@@ -1527,6 +1530,19 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
"or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
)
# Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
if (
truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
and padding_strategy != PaddingStrategy.DO_NOT_PAD
and pad_to_multiple_of is not None
and max_length is not None
and (max_length % pad_to_multiple_of != 0)
):
raise ValueError(
f"Truncation and padding are both activated but "
f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
)
return padding_strategy, truncation_strategy, max_length, kwargs
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
......@@ -1540,6 +1556,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
max_length: Optional[int] = None,
stride: int = 0,
is_pretokenized: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
......@@ -1581,6 +1598,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
max_length=max_length,
stride=stride,
is_pretokenized=is_pretokenized,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
......@@ -1601,6 +1619,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
max_length=max_length,
stride=stride,
is_pretokenized=is_pretokenized,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
......@@ -1623,6 +1642,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
max_length: Optional[int] = None,
stride: int = 0,
is_pretokenized: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
......@@ -1650,7 +1670,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
padding, truncation, max_length, verbose, **kwargs
padding=padding,
truncation=truncation,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
verbose=verbose,
**kwargs,
)
return self._encode_plus(
......@@ -1662,6 +1687,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
max_length=max_length,
stride=stride,
is_pretokenized=is_pretokenized,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
......@@ -1683,6 +1709,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
max_length: Optional[int] = None,
stride: int = 0,
is_pretokenized: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
......@@ -1712,6 +1739,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
max_length: Optional[int] = None,
stride: int = 0,
is_pretokenized: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
......@@ -1738,7 +1766,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
padding, truncation, max_length, verbose, **kwargs
padding=padding,
truncation=truncation,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
verbose=verbose,
**kwargs,
)
return self._batch_encode_plus(
......@@ -1749,6 +1782,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
max_length=max_length,
stride=stride,
is_pretokenized=is_pretokenized,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
......@@ -1776,6 +1810,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
max_length: Optional[int] = None,
stride: int = 0,
is_pretokenized: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
......@@ -1799,6 +1834,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
],
padding: Union[bool, str] = True,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
verbose: bool = True,
......@@ -1820,6 +1856,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
- 'do_not_pad' (or `False`): Do not pad
max_length: maximum length of the returned list and optionally padding length (see below).
Will truncate by taking into account the special tokens.
pad_to_multiple_of: (optional) Integer; if set, will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
>= 7.0 (Volta).
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
......@@ -1842,7 +1881,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
encoded_inputs["attention_mask"] = []
return encoded_inputs
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
# Convert padding_strategy in PaddingStrategy
padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
padding=padding, max_length=max_length, verbose=verbose
)
......@@ -1852,6 +1891,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
encoded_inputs,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
......@@ -1872,6 +1912,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
inputs,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
......@@ -1887,6 +1928,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
""" Pad encoded inputs (on left/right and up to predefined legnth or max length in the batch)
......@@ -1902,6 +1944,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
The tokenizer padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer; if set, will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
>= 7.0 (Volta).
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
# Load from model defaults
......@@ -1911,6 +1956,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(encoded_inputs["input_ids"])
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_padded = (
padding_strategy != PaddingStrategy.DO_NOT_PAD and len(encoded_inputs["input_ids"]) != max_length
)
......
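The new check in `_get_padding_truncation_strategies` rejects combinations where truncation and padding cannot both be honoured. A minimal standalone reproduction (not the actual implementation):

```python
def check_multiple_of(max_length, pad_to_multiple_of, truncating=True, padding=True):
    # Truncating to max_length and padding to a multiple of pad_to_multiple_of
    # only agree when max_length is itself a multiple of pad_to_multiple_of.
    if (
        truncating
        and padding
        and pad_to_multiple_of is not None
        and max_length is not None
        and max_length % pad_to_multiple_of != 0
    ):
        raise ValueError(
            f"Truncation and padding are both activated but truncation length "
            f"({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
        )

check_multiple_of(16, 8)   # fine
# check_multiple_of(12, 8) # raises ValueError, matching the unit test below
```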
......@@ -241,25 +241,26 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return self._tokenizer.encode(text, pair, add_special_tokens=add_special_tokens).tokens
def set_truncation_and_padding(
self, padding_strategy: PaddingStrategy, truncation_strategy: TruncationStrategy, max_length: int, stride: int,
self,
padding_strategy: PaddingStrategy,
truncation_strategy: TruncationStrategy,
max_length: int,
stride: int,
pad_to_multiple_of: Optional[int],
):
""" This contextmanager is in charge of defining the truncation and the padding strategies for fast tokenizers
""" Define the truncation and the padding strategies for fast tokenizers
(provided by HuggingFace tokenizers library) and restore the tokenizer settings afterwards.
This contextmanager assumes the provider tokenizer has no padding / truncation strategy
The provided tokenizer has no padding / truncation strategy
before the managed section. If your tokenizer set a padding / truncation strategy before,
then it will be reset to no padding/truncation when exiting the managed section.
Args:
tokenizer (BaseTokenizerFast): The tokenizer which will be used
max_length (int): The maximum size of the sequence
stride (int): The stride to use when handling overflow
strategy (str): Overflowing logic to use
pad_to_max_length (bool): Boolean indicating if the output needs to be padded up to max_length
padding_side (str): "left" or "right" indicating the direction the output sequence will be padded
pad_token_id (int): The integer representation of the padding token to use
pad_token_type_id (int): The integer representation of the padding token type to use
pad_token (str): The string representation of the padding token to use
padding_strategy (:obj:`PaddingStrategy`): The kind of padding that will be applied to the input
truncation_strategy (:obj:`TruncationStrategy`): The kind of truncation that will be applied to the input
max_length (:obj:`int`): The maximum size of the sequence
stride (:obj:`int`): The stride to use when handling overflow
pad_to_multiple_of (:obj:`int`, `optional`, defaults to `None`): If set, pad the sequence to a multiple of the provided value
"""
# Set truncation and padding on the backend tokenizer
......@@ -275,6 +276,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
pad_id=self.pad_token_id,
pad_type_id=self.pad_token_type_id,
pad_token=self.pad_token,
pad_to_multiple_of=pad_to_multiple_of,
)
else:
self._tokenizer.no_padding()
......@@ -290,6 +292,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
max_length: Optional[int] = None,
stride: int = 0,
is_pretokenized: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
......@@ -315,6 +318,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
truncation_strategy=truncation_strategy,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
)
# Avoid thread overhead if only one example.
......@@ -383,6 +387,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
max_length: Optional[int] = None,
stride: int = 0,
is_pretokenized: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[bool] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
......@@ -403,6 +408,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
truncation_strategy=truncation_strategy,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
......
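For fast tokenizers the flag is simply forwarded to the Rust backend's padding configuration, so the public API is unchanged. A usage sketch (assuming `bert-base-uncased` and a `tokenizers` backend recent enough to support multiple-of padding):

```python
from transformers import BertTokenizerFast

fast_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

batch = fast_tokenizer(
    ["short", "a somewhat longer example sentence"],
    padding=True,
    pad_to_multiple_of=8,
)

# All sequences are padded to the batch maximum, rounded up to a multiple of 8.
for ids in batch["input_ids"]:
    assert len(ids) % 8 == 0
```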
......@@ -883,6 +883,40 @@ class TokenizerTesterMixin:
assert sequence_length == padded_sequence_right_length
assert encoded_sequence == padded_sequence_right
def test_padding_to_multiple_of(self):
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
if tokenizer.pad_token is None:
self.skipTest("No padding token.")
else:
with self.subTest(f"{tokenizer.__class__.__name__}"):
empty_tokens = tokenizer("", padding=True, pad_to_multiple_of=8)
normal_tokens = tokenizer("This is a sample input", padding=True, pad_to_multiple_of=8)
for key, value in empty_tokens.items():
self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not a multiple of 8".format(key))
for key, value in normal_tokens.items():
self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not a multiple of 8".format(key))
normal_tokens = tokenizer("This", pad_to_multiple_of=8)
for key, value in normal_tokens.items():
self.assertNotEqual(len(value) % 8, 0, "BatchEncoding.{} is unexpectedly a multiple of 8".format(key))
# Should also work with truncation
normal_tokens = tokenizer("This", padding=True, truncation=True, pad_to_multiple_of=8)
for key, value in normal_tokens.items():
self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not a multiple of 8".format(key))
# Truncation to a length that is not a multiple of pad_to_multiple_of raises an error
self.assertRaises(
ValueError,
tokenizer.__call__,
"This",
padding=True,
truncation=True,
max_length=12,
pad_to_multiple_of=8,
)
def test_encode_plus_with_padding(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
......
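The behaviour the test pins down, shown through the public API (again a sketch assuming `bert-base-uncased`):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")

# Without padding enabled, pad_to_multiple_of does not force extra tokens.
short = tok("This", pad_to_multiple_of=8)
print(len(short["input_ids"]))  # 3 for BERT ([CLS], 'this', [SEP]) -- not a multiple of 8

# Truncating to a length that is not a multiple of 8 while padding is rejected.
try:
    tok("This", padding=True, truncation=True, max_length=12, pad_to_multiple_of=8)
except ValueError as err:
    print("rejected:", err)
```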