Unverified Commit 9e033649 authored by Funtowicz Morgan, committed by GitHub

Ability to pickle/unpickle BatchEncoding (#5039)

* Added is_fast property on BatchEncoding to indicate if the object comes from a Fast Tokenizer.

* Added __getstate__() & __setstate__() to make BatchEncoding picklable.

* Correct tokens() return type from List[int] to List[str]

* Added unittest for BatchEncoding pickle/unpickle

* Added unittest for BatchEncoding is_fast

* More careful checking on BatchEncoding unpickle tests.

* Formatting.

* is_fast should assertTrue on Rust tokenizers.

* Ensure tensorflow has correct way of checking array_equal

* More formatting.
parent f9f8a531
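As a quick illustration of what this change enables, here is a minimal sketch that round-trips a BatchEncoding through pickle; it assumes the bert-base-cased checkpoint used by the new tests below is available:

import pickle

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
encoding = tokenizer("Small example to encode")

# Round-trip through pickle: __getstate__ captures the dict data plus the
# fast-tokenizer encodings, and __setstate__ restores them on the new object.
restored = pickle.loads(pickle.dumps(encoding))

assert restored.is_fast                                # fast-tokenizer provenance survives
assert restored["input_ids"] == encoding["input_ids"]  # values round-trip intact
print(restored.tokens())                               # tokens() now returns List[str]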
@@ -155,6 +155,14 @@ class BatchEncoding(UserDict):

        self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
    @property
    def is_fast(self):
        """
        Indicate if this BatchEncoding was generated from the result of a PreTrainedTokenizerFast.
        Returns: True if generated from subclasses of PreTrainedTokenizerFast, False otherwise.
        """
        return self._encodings is not None
    def __getitem__(self, item: Union[int, str]) -> EncodingFast:
        """ If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...)
        If the key is an integer, get the EncodingFast for batch item with index `key`
@@ -175,6 +183,16 @@ class BatchEncoding(UserDict):

        except KeyError:
            raise AttributeError
    def __getstate__(self):
        # State handed to pickle: the underlying dict plus the fast-tokenizer encodings
        return {"data": self.data, "encodings": self._encodings}

    def __setstate__(self, state):
        if "data" in state:
            self.data = state["data"]

        if "encodings" in state:
            self._encodings = state["encodings"]
    def keys(self):
        return self.data.keys()
@@ -197,7 +215,7 @@ class BatchEncoding(UserDict):

        """
        return self._encodings
    def tokens(self, batch_index: int = 0) -> List[str]:
        if not self._encodings:
            raise ValueError("tokens() is not available when using Python based tokenizers")
        return self._encodings[batch_index].tokens
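For context, Python's pickle machinery is what drives these two hooks: dumps() calls __getstate__ to obtain a serializable dict, and loads() creates a bare instance and passes that dict to __setstate__. A standalone sketch of the same pattern (the Example class is hypothetical, not part of this diff):

import pickle

class Example:
    def __init__(self):
        self.data = {"input_ids": [1, 2, 3]}
        self._encodings = None

    def __getstate__(self):
        # Called by pickle.dumps() to get the state to serialize
        return {"data": self.data, "encodings": self._encodings}

    def __setstate__(self, state):
        # Called by pickle.loads() on a fresh instance; guard against missing keys
        if "data" in state:
            self.data = state["data"]
        if "encodings" in state:
            self._encodings = state["encodings"]

restored = pickle.loads(pickle.dumps(Example()))
assert restored.data == {"input_ids": [1, 2, 3]}

The "if key in state" guards mirror the diff above and keep unpickling tolerant of states produced by other versions of the class.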
@@ -12,14 +12,14 @@

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import unittest
from typing import Callable, Optional

from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType
from transformers.tokenization_gpt2 import GPT2Tokenizer

from .utils import require_tf, require_torch, slow
class TokenizerUtilsTest(unittest.TestCase):

@@ -36,11 +36,103 @@ class TokenizerUtilsTest(unittest.TestCase):

        special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
        self.assertIsInstance(special_tok_id, int)
    def assert_dump_and_restore(self, be_original: BatchEncoding, equal_op: Optional[Callable] = None):
        batch_encoding_str = pickle.dumps(be_original)
        self.assertIsNotNone(batch_encoding_str)

        be_restored = pickle.loads(batch_encoding_str)

        # Ensure is_fast is correctly restored
        self.assertEqual(be_restored.is_fast, be_original.is_fast)

        # Ensure encodings are correctly restored (present for fast tokenizers, absent otherwise)
        if be_original.is_fast:
            self.assertIsNotNone(be_restored.encodings)
        else:
            self.assertIsNone(be_restored.encodings)

        # Ensure the values are the same
        for original_v, restored_v in zip(be_original.values(), be_restored.values()):
            if equal_op:
                self.assertTrue(equal_op(restored_v, original_v))
            else:
                self.assertEqual(restored_v, original_v)
    @slow
    def test_pretrained_tokenizers(self):
        self.check_tokenizer_from_pretrained(GPT2Tokenizer)

    def test_tensor_type_from_str(self):
        self.assertEqual(TensorType("tf"), TensorType.TENSORFLOW)
        self.assertEqual(TensorType("pt"), TensorType.PYTORCH)
        self.assertEqual(TensorType("np"), TensorType.NUMPY)
    def test_batch_encoding_pickle(self):
        import numpy as np

        tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
        tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")

        # Python no tensor
        with self.subTest("BatchEncoding (Python, return_tensors=None)"):
            self.assert_dump_and_restore(tokenizer_p("Small example to encode"))

        with self.subTest("BatchEncoding (Python, return_tensors=NUMPY)"):
            self.assert_dump_and_restore(
                tokenizer_p("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal
            )

        with self.subTest("BatchEncoding (Rust, return_tensors=None)"):
            self.assert_dump_and_restore(tokenizer_r("Small example to encode"))

        with self.subTest("BatchEncoding (Rust, return_tensors=NUMPY)"):
            self.assert_dump_and_restore(
                tokenizer_r("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal
            )
    @require_tf
    def test_batch_encoding_pickle_tf(self):
        import tensorflow as tf

        def tf_array_equals(t1, t2):
            return tf.reduce_all(tf.equal(t1, t2))

        tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
        tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")

        with self.subTest("BatchEncoding (Python, return_tensors=TENSORFLOW)"):
            self.assert_dump_and_restore(
                tokenizer_p("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals
            )

        with self.subTest("BatchEncoding (Rust, return_tensors=TENSORFLOW)"):
            self.assert_dump_and_restore(
                tokenizer_r("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals
            )
    @require_torch
    def test_batch_encoding_pickle_pt(self):
        import torch

        tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
        tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")

        with self.subTest("BatchEncoding (Python, return_tensors=PYTORCH)"):
            self.assert_dump_and_restore(
                tokenizer_p("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal
            )

        with self.subTest("BatchEncoding (Rust, return_tensors=PYTORCH)"):
            self.assert_dump_and_restore(
                tokenizer_r("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal
            )
    def test_batch_encoding_is_fast(self):
        tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
        tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")

        with self.subTest("Python Tokenizer"):
            self.assertFalse(tokenizer_p("Small example to_encode").is_fast)

        with self.subTest("Rust Tokenizer"):
            self.assertTrue(tokenizer_r("Small example to_encode").is_fast)