Unverified Commit 5e323017 authored by Anthony MOI, committed by GitHub

Fix BatchEncoding.word_to_tokens for removed tokens (#7939)

parent 4acfd1a8
@@ -364,7 +364,7 @@ class BatchEncoding(UserDict):
             token_index = self._seq_len + token_index
         return self._encodings[batch_index].token_to_word(token_index)
 
-    def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> TokenSpan:
+    def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> Optional[TokenSpan]:
         """
         Get the encoded token span corresponding to a word in the sequence of the batch.
 
@@ -391,8 +391,9 @@ class BatchEncoding(UserDict):
                 of the word in the sequence.
 
         Returns:
-            :class:`~transformers.tokenization_utils_base.TokenSpan`
-                Span of tokens in the encoded sequence.
+            Optional :class:`~transformers.tokenization_utils_base.TokenSpan`
+                Span of tokens in the encoded sequence. Returns :obj:`None` if no tokens correspond
+                to the word.
         """
 
         if not self._encodings:
@@ -406,7 +407,8 @@ class BatchEncoding(UserDict):
             batch_index = self._batch_size + batch_index
         if word_index < 0:
             word_index = self._seq_len + word_index
-        return TokenSpan(*(self._encodings[batch_index].word_to_tokens(word_index)))
+        span = self._encodings[batch_index].word_to_tokens(word_index)
+        return TokenSpan(*span) if span is not None else None
 
     def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
         """
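For context on why the guard is needed: the underlying fast-tokenizer `Encoding.word_to_tokens` returns `None` for a word that contributed no tokens to the encoding, and the old `TokenSpan(*(...))` unpacked that `None` with `*`, raising a `TypeError` instead of signalling the missing span. A minimal sketch of the fixed behavior, mirroring the new test below (assumes the `bert-base-cased` checkpoint is available):

```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

# The soft hyphen "\xad" is stripped by BERT's normalizer, so the second
# word leaves no tokens behind in the encoding.
encoded = tokenizer(["Test", "\xad", "test"], is_split_into_words=True)

print(encoded.word_to_tokens(0))  # TokenSpan(start=1, end=2)
print(encoded.word_to_tokens(1))  # None (previously raised TypeError)
print(encoded.word_to_tokens(2))  # TokenSpan(start=2, end=3)
```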
@@ -18,7 +18,7 @@ from typing import Callable, Optional
 
 import numpy as np
 
-from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType
+from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType, TokenSpan
 from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow
 from transformers.tokenization_gpt2 import GPT2Tokenizer
 
@@ -142,6 +142,15 @@ class TokenizerUtilsTest(unittest.TestCase):
         with self.subTest("Rust Tokenizer"):
             self.assertTrue(tokenizer_r("Small example to_encode").is_fast)
 
+    @require_tokenizers
+    def test_batch_encoding_word_to_tokens(self):
+        tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
+        encoded = tokenizer_r(["Test", "\xad", "test"], is_split_into_words=True)
+
+        self.assertEqual(encoded.word_to_tokens(0), TokenSpan(start=1, end=2))
+        self.assertEqual(encoded.word_to_tokens(1), None)
+        self.assertEqual(encoded.word_to_tokens(2), TokenSpan(start=2, end=3))
+
     def test_batch_encoding_with_labels(self):
         batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
         tensor_batch = batch.convert_to_tensors(tensor_type="np")
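A hypothetical sketch (not part of this commit) of how downstream code might rely on the new `Optional` contract, e.g. when aligning word-level labels to token spans:

```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
words = ["Test", "\xad", "test"]
encoded = tokenizer(words, is_split_into_words=True)

# With the Optional return, callers can skip words whose tokens were
# removed during normalization instead of catching a TypeError.
for word_idx, word in enumerate(words):
    span = encoded.word_to_tokens(word_idx)
    if span is None:
        continue  # word vanished during normalization
    token_ids = encoded["input_ids"][span.start : span.end]
    print(word, token_ids)
```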