Unverified Commit 9e033649 authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

Ability to pickle/unpickle BatchEncoding pickle (reimport) (#5039)

* Added is_fast property on BatchEncoding to indicate if the object comes from a Fast Tokenizer.

* Added __get_state__() & __set_state__() to be pickable.

* Correct tokens() return type from List[int] to List[str]

* Added unittest for BatchEncoding pickle/unpickle

* Added unittest for BatchEncoding is_fast

* More careful checking on BatchEncoding unpickle tests.

* Formatting.

* is_fast should assertTrue on Rust tokenizers.

* Ensure tensorflow has correct way of checking array_equal

* More formatting.
parent f9f8a531
......@@ -155,6 +155,14 @@ class BatchEncoding(UserDict):
self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
@property
def is_fast(self):
"""
Indicate if this BatchEncoding was generated from the result of a PreTrainedTokenizerFast
Returns: True if generated from subclasses of PreTrainedTokenizerFast, else otherwise
"""
return self._encodings is not None
def __getitem__(self, item: Union[int, str]) -> EncodingFast:
""" If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...)
If the key is an integer, get the EncodingFast for batch item with index `key`
......@@ -175,6 +183,16 @@ class BatchEncoding(UserDict):
except KeyError:
raise AttributeError
def __getstate__(self):
return {"data": self.data, "encodings": self._encodings}
def __setstate__(self, state):
if "data" in state:
self.data = state["data"]
if "encodings" in state:
self._encodings = state["encodings"]
def keys(self):
return self.data.keys()
......@@ -197,7 +215,7 @@ class BatchEncoding(UserDict):
"""
return self._encodings
def tokens(self, batch_index: int = 0) -> List[int]:
def tokens(self, batch_index: int = 0) -> List[str]:
if not self._encodings:
raise ValueError("tokens() is not available when using Python based tokenizers")
return self._encodings[batch_index].tokens
......
......@@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import unittest
from typing import Callable, Optional
from transformers import PreTrainedTokenizer, TensorType
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType
from transformers.tokenization_gpt2 import GPT2Tokenizer
from .utils import slow
from .utils import require_tf, require_torch, slow
class TokenizerUtilsTest(unittest.TestCase):
......@@ -36,11 +36,103 @@ class TokenizerUtilsTest(unittest.TestCase):
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
self.assertIsInstance(special_tok_id, int)
def assert_dump_and_restore(self, be_original: BatchEncoding, equal_op: Optional[Callable] = None):
batch_encoding_str = pickle.dumps(be_original)
self.assertIsNotNone(batch_encoding_str)
be_restored = pickle.loads(batch_encoding_str)
# Ensure is_fast is correctly restored
self.assertEqual(be_restored.is_fast, be_original.is_fast)
# Ensure encodings are potentially correctly restored
if be_original.is_fast:
self.assertIsNotNone(be_restored.encodings)
else:
self.assertIsNone(be_restored.encodings)
# Ensure the keys are the same
for original_v, restored_v in zip(be_original.values(), be_restored.values()):
if equal_op:
self.assertTrue(equal_op(restored_v, original_v))
else:
self.assertEqual(restored_v, original_v)
@slow
def test_pretrained_tokenizers(self):
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
def check_tensor_type_from_str(self):
def test_tensor_type_from_str(self):
self.assertEqual(TensorType("tf"), TensorType.TENSORFLOW)
self.assertEqual(TensorType("pt"), TensorType.PYTORCH)
self.assertEqual(TensorType("np"), TensorType.NUMPY)
def test_batch_encoding_pickle(self):
import numpy as np
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
# Python no tensor
with self.subTest("BatchEncoding (Python, return_tensors=None)"):
self.assert_dump_and_restore(tokenizer_p("Small example to encode"))
with self.subTest("BatchEncoding (Python, return_tensors=NUMPY)"):
self.assert_dump_and_restore(
tokenizer_p("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal
)
with self.subTest("BatchEncoding (Rust, return_tensors=None)"):
self.assert_dump_and_restore(tokenizer_r("Small example to encode"))
with self.subTest("BatchEncoding (Rust, return_tensors=NUMPY)"):
self.assert_dump_and_restore(
tokenizer_r("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal
)
@require_tf
def test_batch_encoding_pickle_tf(self):
import tensorflow as tf
def tf_array_equals(t1, t2):
return tf.reduce_all(tf.equal(t1, t2))
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
with self.subTest("BatchEncoding (Python, return_tensors=TENSORFLOW)"):
self.assert_dump_and_restore(
tokenizer_p("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals
)
with self.subTest("BatchEncoding (Rust, return_tensors=TENSORFLOW)"):
self.assert_dump_and_restore(
tokenizer_r("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals
)
@require_torch
def test_batch_encoding_pickle_pt(self):
import torch
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
with self.subTest("BatchEncoding (Python, return_tensors=PYTORCH)"):
self.assert_dump_and_restore(
tokenizer_p("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal
)
with self.subTest("BatchEncoding (Rust, return_tensors=PYTORCH)"):
self.assert_dump_and_restore(
tokenizer_r("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal
)
def test_batch_encoding_is_fast(self):
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
with self.subTest("Python Tokenizer"):
self.assertFalse(tokenizer_p("Small example to_encode").is_fast)
with self.subTest("Rust Tokenizer"):
self.assertTrue(tokenizer_r("Small example to_encode").is_fast)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment