"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "839efa4443961bc9ce195d756a750d951b3640c2"
Unverified Commit 5b322a36 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #1811 from huggingface/special-tokens

Fix special tokens addition in decoder #1807
parents 1a237d7f 74d0bcb6
......@@ -160,6 +160,26 @@ class CommonTestCases:
self.assertEqual(tokens[0], tokenizer.eos_token_id)
self.assertEqual(tokens[-2], tokenizer.pad_token_id)
def test_add_special_tokens(self):
tokenizer = self.get_tokenizer()
input_text, output_text = self.get_input_output_texts()
special_token = "[SPECIAL TOKEN]"
tokenizer.add_special_tokens({"cls_token": special_token})
encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
assert len(encoded_special_token) == 1
text = " ".join([input_text, special_token, output_text])
encoded = tokenizer.encode(text, add_special_tokens=False)
input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
output_encoded = tokenizer.encode(output_text, add_special_tokens=False)
special_token_id = tokenizer.encode(special_token, add_special_tokens=False)
assert encoded == input_encoded + special_token_id + output_encoded
decoded = tokenizer.decode(encoded, skip_special_tokens=True)
assert special_token not in decoded
def test_required_methods_tokenizer(self):
tokenizer = self.get_tokenizer()
......
......@@ -1057,7 +1057,7 @@ class PreTrainedTokenizer(object):
class attributes (cls_token, unk_token...).
"""
all_toks = self.all_special_tokens
all_ids = list(self._convert_token_to_id(t) for t in all_toks)
all_ids = self.convert_tokens_to_ids(all_toks)
return all_ids
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment