Unverified Commit 898efca7 authored by Matt, committed by GitHub

Fix to removing ESM special tokens (#22870)

Fix to make sure the EOS token doesn't come back
parent a8aad0ec
@@ -54,16 +54,25 @@ class EsmTokenizer(PreTrainedTokenizer):
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     model_input_names = ["input_ids", "attention_mask"]
 
-    def __init__(self, vocab_file, **kwargs):
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",
+        cls_token="<cls>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        eos_token="<eos>",
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.all_tokens = load_vocab_file(vocab_file)
         self._id_to_token = dict(enumerate(self.all_tokens))
         self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
-        self.unk_token = "<unk>"
-        self.cls_token = "<cls>"
-        self.pad_token = "<pad>"
-        self.mask_token = "<mask>"
-        self.eos_token = "<eos>"
+        self.unk_token = unk_token
+        self.cls_token = cls_token
+        self.pad_token = pad_token
+        self.mask_token = mask_token
+        self.eos_token = eos_token
         self.unique_no_split_tokens = self.all_tokens
         self._create_trie(self.unique_no_split_tokens)
...
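With the special tokens exposed as __init__ arguments, a user-supplied value is no longer clobbered by the hard-coded assignments in the constructor body. A minimal sketch of the intended effect, assuming the override is forwarded to __init__ as usual (the checkpoint name is only illustrative):

from transformers import EsmTokenizer

# Before this change the constructor unconditionally ran `self.eos_token = "<eos>"`,
# so a removed or overridden EOS token "came back" after loading. With the tokens
# as keyword arguments, the override below is respected.
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D", eos_token=None)
print(tokenizer.eos_token)  # None: the EOS token stays removed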