Unverified commit a721a5ee authored by Muktan, committed by GitHub

[Wav2vec2] Fixed tokenization mistakes while adding single-char tokens to tokenizer (#11538)



* Fixed tokenization mistakes while adding single-char tokens to tokenizer

* Added tests and removed unnecessary comments.

* finalize wav2vec2 tok

* add more aggressive tests

* Apply suggestions from code review

* fix useless import
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent f3cf8ae7
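For context, here is a minimal usage sketch of the behaviour this commit fixes, mirroring the new tests added below (it assumes the facebook/wav2vec2-base-960h checkpoint; the ids are the ones asserted in the tests):

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")

# After this fix, a newly added single-character token tokenizes like any other
# character of the vocabulary instead of being handled incorrectly.
tokenizer.add_tokens("x")
print(tokenizer("C x A").input_ids)  # per the new tests: [19, 4, 32, 4, 7]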
@@ -24,8 +24,8 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
-from ...tokenization_utils import PreTrainedTokenizer
-from ...tokenization_utils_base import BatchEncoding
+from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
+from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...utils import logging
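For readers unfamiliar with the newly imported helper: `_insert_one_token_to_ordered_list` (defined in `tokenization_utils.py`) inserts a token into an already-sorted list while skipping duplicates; it is used further down to keep `unique_no_split_tokens` ordered. A rough sketch of such a helper, not the verbatim implementation:

import bisect
from typing import List

def insert_one_token_to_ordered_list_sketch(token_list: List[str], new_token: str) -> None:
    # Find the insertion point that keeps token_list sorted.
    insertion_idx = bisect.bisect_left(token_list, new_token)
    # Skip the insert if the token is already present.
    if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:
        return
    token_list.insert(insertion_idx, new_token)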
@@ -277,6 +277,61 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
        return (vocab_file,)

    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added
        to it with indices starting from the length of the current vocabulary.

        Args:
            new_tokens (:obj:`List[str]` or :obj:`List[tokenizers.AddedToken]`):
                Token(s) to add to the vocabulary. A token is only added if it is not already in the vocabulary
                (tested by checking if the tokenizer assigns the index of the ``unk_token`` to it).
            special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the tokens should be added as special tokens.

        Returns:
            :obj:`int`: The number of tokens actually added to the vocabulary.

        Examples::

            # Let's see how to increase the vocabulary of the Wav2Vec2 model and tokenizer
            tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
            model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')

            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
            print('We have added', num_added_toks, 'tokens')
            # Note: resize_token_embeddings expects to receive the full size of the new vocabulary,
            # i.e. the length of the tokenizer.
            model.resize_token_embeddings(len(tokenizer))
        """
        new_tokens = [str(tok) for tok in new_tokens]

        tokens_to_add = []
        for token in new_tokens:
            assert isinstance(token, str)
            if not special_tokens and hasattr(self, "do_lower_case") and self.do_lower_case:
                token = token.lower()
            if (
                token != self.unk_token
                and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
                and token not in tokens_to_add
            ):
                tokens_to_add.append(token)
                if self.verbose:
                    logger.info(f"Adding {token} to the vocabulary")

        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
        added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
        self.added_tokens_encoder.update(added_tok_encoder)
        self.added_tokens_decoder.update(added_tok_decoder)

        # Make sure we don't split on any special tokens (even if they were already in the vocab before).
        # Only multi-character tokens are registered as no-split tokens; single characters behave like
        # regular vocabulary entries.
        for token in tokens_to_add:
            if len(token) > 1:
                self._additional_special_tokens.append(AddedToken(token))
                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)

        return len(tokens_to_add)
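The behavioural change sits in the final loop above: only tokens longer than one character are registered as no-split tokens, while single characters simply join the vocabulary. A quick way to observe this, assuming the facebook/wav2vec2-base-960h checkpoint is available:

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
tokenizer.add_tokens(["x", "xxx"])

# The single-character token only enters the added-tokens vocabulary ...
assert "x" not in tokenizer.unique_no_split_tokens
# ... while the multi-character token is also protected from being split apart.
assert "xxx" in tokenizer.unique_no_split_tokens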
class Wav2Vec2Tokenizer(PreTrainedTokenizer):
"""
@@ -375,6 +375,38 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
        kwargs.update(self.special_tokens_map)
        return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def test_tokenizer_add_token_chars(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")

        # check adding a single token; "x" gets the next free id (32), 4 is the word delimiter "|"
        tokenizer.add_tokens("x")
        token_ids = tokenizer("C x A").input_ids
        self.assertEqual(token_ids, [19, 4, 32, 4, 7])

        # check adding multiple tokens at once
        tokenizer.add_tokens(["a", "b", "c"])
        token_ids = tokenizer("C a A c").input_ids
        self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35])

        # re-adding the same tokens is a no-op; single-char tokens are also recognized without spaces
        tokenizer.add_tokens(["a", "b", "c"])
        token_ids = tokenizer("CaA c").input_ids
        self.assertEqual(token_ids, [19, 33, 7, 4, 35])

    def test_tokenizer_add_token_words(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")

        # check adding a single multi-character token
        tokenizer.add_tokens("xxx")
        token_ids = tokenizer("C xxx A B").input_ids
        self.assertEqual(token_ids, [19, 4, 32, 4, 7, 4, 24])

        # check adding multiple tokens at once
        tokenizer.add_tokens(["aaa", "bbb", "ccc"])
        token_ids = tokenizer("C aaa A ccc B B").input_ids
        self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35, 4, 24, 4, 24])

        # re-adding the same tokens is a no-op; multi-char tokens are no-split, so "aaa" is found inside "CaaaA"
        tokenizer.add_tokens(["aaa", "bbb", "ccc"])
        token_ids = tokenizer("CaaaA ccc B B").input_ids
        self.assertEqual(token_ids, [19, 33, 7, 4, 35, 4, 24, 4, 24])

    def test_tokenizer_decode(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
@@ -470,3 +502,55 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
    def test_pretrained_model_lists(self):
        # Wav2Vec2Model has no max model length => no testing
        pass

    # overwrite from test_tokenization_common
    def test_add_tokens_tokenizer(self):
        tokenizers = self.get_tokenizers(do_lower_case=False)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                vocab_size = tokenizer.vocab_size
                all_size = len(tokenizer)

                self.assertNotEqual(vocab_size, 0)

                # We usually have added tokens from the start in tests because our vocab fixtures are
                # smaller than the original vocabs - let's not assert this
                # self.assertEqual(vocab_size, all_size)

                new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
                added_toks = tokenizer.add_tokens(new_toks)
                vocab_size_2 = tokenizer.vocab_size
                all_size_2 = len(tokenizer)

                self.assertNotEqual(vocab_size_2, 0)
                self.assertEqual(vocab_size, vocab_size_2)
                self.assertEqual(added_toks, len(new_toks))
                self.assertEqual(all_size_2, all_size + len(new_toks))

                tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)

                self.assertGreaterEqual(len(tokens), 4)
                self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
                self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)

                new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
                added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
                vocab_size_3 = tokenizer.vocab_size
                all_size_3 = len(tokenizer)

                self.assertNotEqual(vocab_size_3, 0)
                self.assertEqual(vocab_size, vocab_size_3)
                self.assertEqual(added_toks_2, len(new_toks_2))
                self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

                tokens = tokenizer.encode(
                    ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
                )

                self.assertGreaterEqual(len(tokens), 6)
                self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
                self.assertGreater(tokens[0], tokens[1])
                self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
                self.assertGreater(tokens[-3], tokens[-4])
                self.assertEqual(tokens[0], tokenizer.eos_token_id)
                self.assertEqual(tokens[-3], tokenizer.pad_token_id)