"docs/source/vscode:/vscode.git/clone" did not exist on "146c521235ca057570cac4c1fc3f884ac464e580"
Unverified Commit 5bf9afbf authored by Funtowicz Morgan, committed by GitHub

Introduce a new tensor type for return_tensors on tokenizer for NumPy (#4585)

* Refactor tensor creation in tokenizers.

* Make sure to convert string to TensorType

* Refactor convert_to_tensors_

* Introduce numpy tensor creation

* Format

* Add unittest for TensorType creation from str

* sorting imports

* Added unittests for numpy tensor conversion.

* Do not use in-place version for squeeze as numpy doesn't provide such a feature.

* Added extra parameter prepend_batch_axis: bool on prepare_for_model.

* Ensure test_np_encode_plus_sent_to_model is not executed if encoder/decoder model.

* style.

* numpy tests require_torch for now while flax not merged.

* Hopefully will make flake8 happy.

* One more time 🎶
parent efae1549
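A minimal usage sketch of the NumPy return type introduced by this commit (the checkpoint name and input text are illustrative, not part of the diff):

```python
import numpy as np

from transformers import BertTokenizer

# Illustrative checkpoint; any pretrained tokenizer behaves the same way.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# "np" is the new alias added here; "pt" and "tf" keep working as before.
encoded = tokenizer.encode_plus("Hello world", return_tensors="np")

# Each value is a 2D numpy array with a prepended batch axis.
assert isinstance(encoded["input_ids"], np.ndarray)
assert encoded["input_ids"].ndim == 2
```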
@@ -132,7 +132,7 @@ from .tokenization_reformer import ReformerTokenizer
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils import PreTrainedTokenizer, TensorType
from .tokenization_xlm import XLMTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
@@ -25,8 +25,10 @@ import re
import warnings
from collections import UserDict, defaultdict
from contextlib import contextmanager
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from enum import Enum
from typing import Any, Dict, List, MutableMapping, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
from tokenizers import AddedToken as AddedTokenFast
from tokenizers import Encoding as EncodingFast
from tokenizers.decoders import Decoder as DecoderFast
@@ -42,6 +44,18 @@ if is_torch_available():
logger = logging.getLogger(__name__)
NO_PAD_TOKEN_FOR_BATCH_MSG = (
"No padding token is set for this model, therefore no batch can be made with uneven "
"sequences. Set a padding token or adjust the lengths of the sequences building the "
"batch so that every sequence is of the same length."
)
UNEVEN_SEQUENCES_FOR_BATCH_MSG = (
"The sequences building the batch are not of the same size, no tensor "
"can be built. Set `pad_to_max_length=True` to pad the smaller sequences"
"up to the larger sequence's length."
)
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
@@ -58,6 +72,12 @@ PreTokenizedInputPair = Tuple[List[str], List[str]]
EncodedInputPair = Tuple[List[int], List[int]]
class TensorType(Enum):
PYTORCH = "pt"
TENSORFLOW = "tf"
NUMPY = "np"
class CharSpan(NamedTuple):
""" Character span in the original string
@@ -161,6 +181,51 @@ def truncate_and_pad(
tokenizer.no_padding()
def convert_to_tensors(
batch_outputs: MutableMapping, return_tensors: Union[str, TensorType], prepend_batch_axis: bool = False
) -> MutableMapping:
# Convert to TensorType
if not isinstance(return_tensors, TensorType):
return_tensors = TensorType(return_tensors)
# Get a function reference for the correct framework
if return_tensors == TensorType.TENSORFLOW and is_tf_available():
as_tensor = tf.constant
elif return_tensors == TensorType.PYTORCH and is_torch_available():
as_tensor = torch.tensor
elif return_tensors == TensorType.NUMPY:
as_tensor = np.asarray
else:
raise ImportError(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors
)
)
# Do the tensor conversion in batch
for key, value in batch_outputs.items():
try:
if prepend_batch_axis:
value = [value]
tensor = as_tensor(value)
# at-least2d
if tensor.ndim > 2:
tensor = tensor.squeeze(0)
elif tensor.ndim < 2:
tensor = tensor[None, :]
batch_outputs[key] = tensor
except ValueError:
if None in [item for sequence in value for item in sequence]:
raise ValueError(NO_PAD_TOKEN_FOR_BATCH_MSG)
else:
raise ValueError(UNEVEN_SEQUENCES_FOR_BATCH_MSG)
return batch_outputs
class BatchEncoding(UserDict):
""" BatchEncoding hold the output of the encode and batch_encode methods (tokens, attention_masks, etc).
This class is derived from a python Dictionary and can be used as a dictionary.
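The module-level convert_to_tensors helper above replaces the per-framework branches further down in this diff; a small sketch of its NumPy behaviour (the token ids below are made up for illustration):

```python
import numpy as np

from transformers.tokenization_utils import TensorType, convert_to_tensors

# A batch of two already-padded sequences converts to a (2, 3) array.
batch_outputs = {"input_ids": [[101, 7592, 102], [101, 2088, 102]]}
batch_outputs = convert_to_tensors(batch_outputs, TensorType.NUMPY)
assert batch_outputs["input_ids"].shape == (2, 3)

# A single sequence plus prepend_batch_axis=True gains a leading batch dim.
single = convert_to_tensors({"input_ids": [101, 7592, 102]}, "np", prepend_batch_axis=True)
assert single["input_ids"].shape == (1, 3)
```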
@@ -755,18 +820,6 @@ class PreTrainedTokenizer(SpecialTokensMixin):
padding_side: str = "right"
NO_PAD_TOKEN_FOR_BATCH_MSG = (
"No padding token is set for this model, therefore no batch can be made with uneven "
"sequences. Set a padding token or adjust the lengths of the sequences building the "
"batch so that every sequence is of the same length."
)
UNEVEN_SEQUENCES_FOR_BATCH_MSG = (
"The sequences building the batch are not of the same size, no tensor "
"can be built. Set `pad_to_max_length=True` to pad the smaller sequences"
"up to the larger sequence's length."
)
@property
def vocab_size(self) -> int:
""" Size of the base vocabulary (without the added tokens) """
@@ -1373,7 +1426,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
stride: int = 0,
truncation_strategy: str = "longest_first",
pad_to_max_length: bool = False,
return_tensors: Optional[str] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs
):
"""
@@ -1447,7 +1500,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
truncation_strategy: str = "longest_first",
pad_to_max_length: bool = False,
is_pretokenized: bool = False,
return_tensors: Optional[str] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
@@ -1590,6 +1643,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
prepend_batch_axis=return_tensors is not None,
)
def batch_encode_plus(
@@ -1608,7 +1662,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
truncation_strategy: str = "longest_first",
pad_to_max_length: bool = False,
is_pretokenized: bool = False,
return_tensors: Optional[str] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_masks: Optional[bool] = None,
return_overflowing_tokens: bool = False,
@@ -1783,39 +1837,10 @@ class PreTrainedTokenizer(SpecialTokensMixin):
batch_outputs[key].append(value)
if return_tensors is not None:
convert_to_tensors(batch_outputs, return_tensors)
self.convert_to_tensors_(batch_outputs, return_tensors)
return BatchEncoding(batch_outputs)
def convert_to_tensors_(self, batch_outputs: dict, return_tensors: str) -> None:
# Do the tensor conversion in batch
for key, value in batch_outputs.items():
if return_tensors == "tf" and is_tf_available():
try:
batch_outputs[key] = tf.constant(value)
except ValueError:
if None in [item for sequence in value for item in sequence]:
raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG)
else:
raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG)
elif return_tensors == "pt" and is_torch_available():
try:
batch_outputs[key] = torch.tensor(value)
except ValueError:
raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG)
except RuntimeError:
if None in [item for sequence in value for item in sequence]:
raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG)
else:
raise
elif return_tensors is not None:
logger.warning(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors
)
)
def prepare_for_model(
self,
ids: List[int],
@@ -1825,12 +1850,13 @@ class PreTrainedTokenizer(SpecialTokensMixin):
stride: int = 0,
truncation_strategy: str = "longest_first",
pad_to_max_length: bool = False,
return_tensors: Optional[str] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_lengths: bool = False,
prepend_batch_axis: bool = False,
) -> BatchEncoding:
""" Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
@@ -1866,6 +1892,9 @@ class PreTrainedTokenizer(SpecialTokensMixin):
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
return_lengths (:obj:`bool`, `optional`, defaults to :obj:`False`):
If set the resulting dictionary will include the length of each encoded inputs
prepend_batch_axis (:obj:`bool`, `optional`, defaults to :obj:`False`):
If set the resulting object will feature an extra dim at position 0.
This can be seen as an unsqueezing operator.
Return:
A Dictionary of shape::
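To make the prepend_batch_axis description concrete, the equivalent NumPy operation looks like this (values illustrative, not taken from the diff):

```python
import numpy as np

ids = np.asarray([101, 7592, 102])     # shape (3,)
batched = np.expand_dims(ids, axis=0)  # shape (1, 3): the "unsqueeze" at dim 0

assert batched.shape == (1, 3)
```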
@@ -1990,29 +2019,8 @@ class PreTrainedTokenizer(SpecialTokensMixin):
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
# Prepare model inputs as tensors if asked
if return_tensors == "tf" and is_tf_available():
encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]])
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = tf.constant([encoded_inputs["token_type_ids"]])
if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = tf.constant([encoded_inputs["attention_mask"]])
elif return_tensors == "pt" and is_torch_available():
encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]])
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = torch.tensor([encoded_inputs["token_type_ids"]])
if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = torch.tensor([encoded_inputs["attention_mask"]])
elif return_tensors is not None:
logger.warning(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors
)
)
if return_tensors is not None:
convert_to_tensors(encoded_inputs, return_tensors, prepend_batch_axis)
return BatchEncoding(encoded_inputs)
@@ -2305,7 +2313,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
def _convert_encoding(
self,
encoding: EncodingFast,
return_tensors: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
@@ -2345,16 +2353,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
encoding_dict["offset_mapping"].append(e.offsets)
if return_tensors is not None:
for key, value in encoding_dict.items():
if return_tensors == "tf" and is_tf_available():
encoding_dict[key] = tf.constant(value)
elif return_tensors == "pt" and is_torch_available():
encoding_dict[key] = torch.tensor(value)
elif return_tensors is not None:
logger.warning(
"Unable to convert output to tensors format {}, "
"PyTorch or TensorFlow is not available.".format(return_tensors)
)
encoding_dict = convert_to_tensors(encoding_dict, return_tensors)
return encoding_dict
@@ -2438,7 +2437,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
truncation_strategy: str = "longest_first",
pad_to_max_length: bool = False,
is_pretokenized: bool = False,
return_tensors: Optional[str] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
@@ -2575,7 +2574,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
stride: int = 0,
truncation_strategy: str = "longest_first",
is_pretokenized: bool = False,
return_tensors: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
@@ -832,3 +832,47 @@ class TokenizerTesterMixin:
# This should not fail
model(encoded_sequence_fast)
model(batch_encoded_sequence_fast)
# TODO: Check if require_torch is the best to test for numpy here ... Maybe move to require_flax when available
@require_torch
def test_np_encode_plus_sent_to_model(self):
from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
tokenizer = self.get_tokenizer()
if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
return
config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
config = config_class()
if config.is_encoder_decoder or config.pad_token_id is None:
return
# Build sequence
first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
sequence = " ".join(first_ten_tokens)
encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="np")
batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="np")
# TODO: add forward through JAX/Flax when PR is merged
# This is currently here to make flake8 happy !
if encoded_sequence is None:
raise ValueError("Cannot convert list to numpy tensor on encode_plus()")
if batch_encoded_sequence is None:
raise ValueError("Cannot convert list to numpy tensor on batch_encode_plus()")
if self.test_rust_tokenizer:
fast_tokenizer = self.get_rust_tokenizer()
encoded_sequence_fast = fast_tokenizer.encode_plus(sequence, return_tensors="np")
batch_encoded_sequence_fast = fast_tokenizer.batch_encode_plus([sequence, sequence], return_tensors="np")
# TODO: add forward through JAX/Flax when PR is merged
# This is currently here to make flake8 happy !
if encoded_sequence_fast is None:
raise ValueError("Cannot convert list to numpy tensor on encode_plus() (fast)")
if batch_encoded_sequence_fast is None:
raise ValueError("Cannot convert list to numpy tensor on batch_encode_plus() (fast)")
@@ -16,7 +16,7 @@
import unittest
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, TensorType
from transformers.tokenization_gpt2 import GPT2Tokenizer
from .utils import slow
@@ -39,3 +39,8 @@ class TokenizerUtilsTest(unittest.TestCase):
@slow
def test_pretrained_tokenizers(self):
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
def check_tensor_type_from_str(self):
self.assertEqual(TensorType("tf"), TensorType.TENSORFLOW)
self.assertEqual(TensorType("pt"), TensorType.PYTORCH)
self.assertEqual(TensorType("np"), TensorType.NUMPY)