Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5e323017
Unverified
Commit
5e323017
authored
Oct 23, 2020
by
Anthony MOI
Committed by
GitHub
Oct 23, 2020
Browse files
Fix BatchEncoding.word_to_tokens for removed tokens (#7939)
parent
4acfd1a8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
5 deletions
+16
-5
src/transformers/tokenization_utils_base.py
src/transformers/tokenization_utils_base.py
+6
-4
tests/test_tokenization_utils.py
tests/test_tokenization_utils.py
+10
-1
No files found.
src/transformers/tokenization_utils_base.py
View file @
5e323017
...
...
@@ -364,7 +364,7 @@ class BatchEncoding(UserDict):
token_index
=
self
.
_seq_len
+
token_index
return
self
.
_encodings
[
batch_index
].
token_to_word
(
token_index
)
def
word_to_tokens
(
self
,
batch_or_word_index
:
int
,
word_index
:
Optional
[
int
]
=
None
)
->
TokenSpan
:
def
word_to_tokens
(
self
,
batch_or_word_index
:
int
,
word_index
:
Optional
[
int
]
=
None
)
->
Optional
[
TokenSpan
]
:
"""
Get the encoded token span corresponding to a word in the sequence of the batch.
...
...
@@ -391,8 +391,9 @@ class BatchEncoding(UserDict):
of the word in the sequence.
Returns:
:class:`~transformers.tokenization_utils_base.TokenSpan`
Span of tokens in the encoded sequence.
Optional :class:`~transformers.tokenization_utils_base.TokenSpan`
Span of tokens in the encoded sequence. Returns :obj:`None` if no tokens correspond
to the word.
"""
if
not
self
.
_encodings
:
...
...
@@ -406,7 +407,8 @@ class BatchEncoding(UserDict):
batch_index
=
self
.
_batch_size
+
batch_index
if
word_index
<
0
:
word_index
=
self
.
_seq_len
+
word_index
return
TokenSpan
(
*
(
self
.
_encodings
[
batch_index
].
word_to_tokens
(
word_index
)))
span
=
self
.
_encodings
[
batch_index
].
word_to_tokens
(
word_index
)
return
TokenSpan
(
*
span
)
if
span
is
not
None
else
None
def
token_to_chars
(
self
,
batch_or_token_index
:
int
,
token_index
:
Optional
[
int
]
=
None
)
->
CharSpan
:
"""
...
...
tests/test_tokenization_utils.py
View file @
5e323017
...
...
@@ -18,7 +18,7 @@ from typing import Callable, Optional
import
numpy
as
np
from
transformers
import
BatchEncoding
,
BertTokenizer
,
BertTokenizerFast
,
PreTrainedTokenizer
,
TensorType
from
transformers
import
BatchEncoding
,
BertTokenizer
,
BertTokenizerFast
,
PreTrainedTokenizer
,
TensorType
,
TokenSpan
from
transformers.testing_utils
import
require_tf
,
require_tokenizers
,
require_torch
,
slow
from
transformers.tokenization_gpt2
import
GPT2Tokenizer
...
...
@@ -142,6 +142,15 @@ class TokenizerUtilsTest(unittest.TestCase):
with
self
.
subTest
(
"Rust Tokenizer"
):
self
.
assertTrue
(
tokenizer_r
(
"Small example to_encode"
).
is_fast
)
@
require_tokenizers
def
test_batch_encoding_word_to_tokens
(
self
):
tokenizer_r
=
BertTokenizerFast
.
from_pretrained
(
"bert-base-cased"
)
encoded
=
tokenizer_r
([
"Test"
,
"
\xad
"
,
"test"
],
is_split_into_words
=
True
)
self
.
assertEqual
(
encoded
.
word_to_tokens
(
0
),
TokenSpan
(
start
=
1
,
end
=
2
))
self
.
assertEqual
(
encoded
.
word_to_tokens
(
1
),
None
)
self
.
assertEqual
(
encoded
.
word_to_tokens
(
2
),
TokenSpan
(
start
=
2
,
end
=
3
))
def
test_batch_encoding_with_labels
(
self
):
batch
=
BatchEncoding
({
"inputs"
:
[[
1
,
2
,
3
],
[
4
,
5
,
6
]],
"labels"
:
[
0
,
1
]})
tensor_batch
=
batch
.
convert_to_tensors
(
tensor_type
=
"np"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment