Unverified commit 880154d2
Authored Apr 22, 2021 by Patrick von Platen; committed by GitHub on Apr 22, 2021
[Wav2Vec2] Fix special tokens for Wav2Vec2 tokenizer (#11349)
* fix wav2vec2 tok
* up
Parent: 6f14eab5
Showing 2 changed files with 28 additions and 1 deletion:

src/transformers/models/wav2vec2/tokenization_wav2vec2.py (+7, -0)
tests/test_tokenization_wav2vec2.py (+21, -1)
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
@@ -142,6 +142,12 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}

        # make sure that tokens made of several
        # characters are not split at tokenization
        for token in self.encoder.keys():
            if len(token) > 1:
                self.unique_no_split_tokens.append(token)

    @property
    def word_delimiter_token(self) -> str:
        """
@@ -366,6 +372,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}

    @property
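The core of the fix is the loop added to Wav2Vec2CTCTokenizer.__init__ above: any vocabulary entry longer than one character (for example a phoneme such as "ʈʰ") is appended to unique_no_split_tokens, so the otherwise character-level tokenizer treats it as a single unit. The sketch below is only a toy illustration of why that registration matters; it is not the transformers code path, and both helper functions are invented for the example.

# Toy illustration, not the transformers implementation: a character-level
# tokenizer breaks a multi-character vocab entry apart unless that entry is
# treated as an atomic, "no-split" token.

def naive_char_tokenize(text):
    # what happens when multi-character entries are NOT protected
    return [ch for ch in text if ch != " "]

def no_split_tokenize(text, no_split):
    # greedy left-to-right match that tries protected tokens first
    tokens, i = [], 0
    protected = sorted(no_split, key=len, reverse=True)
    while i < len(text):
        for tok in protected:
            if text.startswith(tok, i):
                tokens.append(tok)
                i += len(tok)
                break
        else:
            if text[i] != " ":
                tokens.append(text[i])
            i += 1
    return tokens

print(naive_char_tokenize("ʈʰ æ"))             # ['ʈ', 'ʰ', 'æ'] -- the entry is split
print(no_split_tokenize("ʈʰ æ", ["ʈʰ", "æ"]))  # ['ʈʰ', 'æ']     -- the entry stays intact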
tests/test_tokenization_wav2vec2.py
@@ -447,6 +447,26 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
        self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])

    def test_special_characters_in_vocab(self):
        sent = "ʈʰ æ æ̃ ˧ kʰ"

        vocab_dict = {k: v for v, k in enumerate({phoneme for phoneme in sent.split()})}
        vocab_file = os.path.join(self.tmpdirname, "vocab_special.json")

        with open(vocab_file, "w") as f:
            json.dump(vocab_dict, f)

        tokenizer = Wav2Vec2CTCTokenizer(vocab_file)

        expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
        self.assertEqual(sent, expected_sent)

        tokenizer.save_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))

        expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
        self.assertEqual(sent, expected_sent)

    def test_pretrained_model_lists(self):
        # Wav2Vec2Model has no max model length => no testing
        pass
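The new test checks the round trip twice: once with a freshly constructed Wav2Vec2CTCTokenizer and once after save_pretrained / from_pretrained, so multi-character vocabulary entries are verified to survive serialization as well. A typical way to run just this test from the repository root would be `pytest tests/test_tokenization_wav2vec2.py -k test_special_characters_in_vocab`, assuming a development install of transformers with the test dependencies available.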