Unverified Commit d1578159 authored by Matt, committed by GitHub

Allow add_tokens for ESM (#28535)



* Allow non-special tokens to be added

* Add test, fix token adding code

* Revert changes to id_to_token and token_to_id

* Update the ESM tokenizer to be a bit more standardized

* Update src/transformers/models/esm/tokenization_esm.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 5b7f4bc6
src/transformers/models/esm/tokenization_esm.py
@@ -14,10 +14,9 @@
 # limitations under the License.
 """Tokenization classes for ESM."""
 import os
-from typing import List, Optional, Union
+from typing import List, Optional
 
 from ...tokenization_utils import PreTrainedTokenizer
-from ...tokenization_utils_base import AddedToken
 from ...utils import logging
@@ -91,11 +90,10 @@ class EsmTokenizer(PreTrainedTokenizer):
     def _tokenize(self, text, **kwargs):
         return text.split()
 
-    def get_vocab_size(self, with_added_tokens=False):
-        return len(self._id_to_token)
-
     def get_vocab(self):
-        return {token: i for i, token in enumerate(self.all_tokens)}
+        base_vocab = self._token_to_id.copy()
+        base_vocab.update(self.added_tokens_encoder)
+        return base_vocab
 
     def token_to_id(self, token: str) -> int:
         return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
@@ -156,7 +154,4 @@
 
     @property
     def vocab_size(self) -> int:
-        return self.get_vocab_size(with_added_tokens=False)
-
-    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
-        return super()._add_tokens(new_tokens, special_tokens=True)
+        return len(self.all_tokens)
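The effect of the change above can be sketched as follows. This is an illustrative example only, not part of the commit; the checkpoint name and the <mod1>/<mod2> tokens are assumptions. With the _add_tokens override removed, add_tokens registers ordinary (non-special) added tokens, and the rewritten get_vocab reports them alongside the base vocabulary.

# Illustrative sketch (not part of this commit); assumes the public
# facebook/esm2_t6_8M_UR50D checkpoint and made-up tokens <mod1>/<mod2>.
from transformers import EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
base_size = len(tokenizer)

# add_tokens no longer forces special_tokens=True, so these become regular added tokens
tokenizer.add_tokens(["<mod1>", "<mod2>"])

assert len(tokenizer) == base_size + 2
# get_vocab now returns the base vocabulary plus the added tokens
assert "<mod1>" in tokenizer.get_vocab()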
tests/models/esm/test_tokenization_esm.py
@@ -87,3 +87,25 @@ class ESMTokenizationTest(unittest.TestCase):
         self.assertEqual(len(token_2), 1)
         self.assertEqual(token_1[0], SPECIAL_TOKEN_1)
         self.assertEqual(token_2[0], SPECIAL_TOKEN_2)
+
+    def test_add_tokens(self):
+        tokenizer = self.tokenizer_class(self.vocab_file)
+
+        vocab_size = len(tokenizer)
+        self.assertEqual(tokenizer.add_tokens(""), 0)
+        self.assertEqual(tokenizer.add_tokens("testoken"), 1)
+        self.assertEqual(tokenizer.add_tokens(["testoken1", "testtoken2"]), 2)
+        self.assertEqual(len(tokenizer), vocab_size + 3)
+
+        self.assertEqual(tokenizer.add_special_tokens({}), 0)
+        self.assertEqual(tokenizer.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
+        self.assertRaises(AssertionError, tokenizer.add_special_tokens, {"additional_special_tokens": "<testtoken1>"})
+        self.assertEqual(tokenizer.add_special_tokens({"additional_special_tokens": ["<testtoken2>"]}), 1)
+        self.assertEqual(
+            tokenizer.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
+        )
+        self.assertIn("<testtoken3>", tokenizer.special_tokens_map["additional_special_tokens"])
+        self.assertIsInstance(tokenizer.special_tokens_map["additional_special_tokens"], list)
+        self.assertGreaterEqual(len(tokenizer.special_tokens_map["additional_special_tokens"]), 2)
+
+        self.assertEqual(len(tokenizer), vocab_size + 8)