Unverified commit 51794bf2 authored by Arthur, committed by GitHub

[`SPM`] Patch `spm` Llama and T5 (#25656)

* hot fix

* only encode with string prefix if starts with prefix

* styling

* add a new test

* fixup
parent 57943630
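In non-legacy mode, these slow tokenizers avoid sentencepiece inserting an extra dummy `▁` at the start of a chunk by prepending `unk_token` before encoding (so the dummy prefix lands on the unk pieces) and then slicing those pieces back off. This patch makes the trick conditional: it only fires when the text actually starts with a prefix space or `SPIECE_UNDERLINE`, and the slice is bounds-checked. A minimal sketch of the idea using raw sentencepiece, assuming a local model file (the `tokenizer.model` path is a placeholder):

```python
import sentencepiece as spm

SPIECE_UNDERLINE = "▁"

sp = spm.SentencePieceProcessor(model_file="tokenizer.model")  # placeholder path

def tokenize_without_dummy_prefix(text, unk_token="<unk>"):
    tokens = sp.encode(text, out_type=str)
    # Only apply the workaround when the text really starts with a prefix
    # space / SPIECE_UNDERLINE; otherwise the plain encoding is already correct.
    if not text.startswith((SPIECE_UNDERLINE, " ")):
        return tokens
    # 1. Prepend unk_token so sentencepiece's dummy prefix lands on it, e.g. "<unk> Hey".
    unk_len = len(sp.encode(unk_token, out_type=str))
    tokens = sp.encode(unk_token + text, out_type=str)
    # 2. Slice the unk pieces (e.g. ['<', 'unk', '>']) back off, staying in
    #    bounds in case pieces merged across the boundary.
    return tokens[unk_len:] if len(tokens) >= unk_len else tokens
```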
@@ -220,13 +220,14 @@ class LlamaTokenizer(PreTrainedTokenizer):
         `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
         `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
         """
-        if self.legacy:
-            return self.sp_model.encode(text, out_type=str)
-
-        unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
-        text = self.unk_token + text
         tokens = self.sp_model.encode(text, out_type=str)
-        return tokens[unk_token_length:]
+        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
+            return tokens
+
+        # 1. Encode string + prefix ex: "<unk> Hey"
+        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
 
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
...
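Read against the docstring above, the slicing works like this (a hedged illustration, assuming an spm model where `"<unk>"` encodes to 4 pieces; `tokenizer.model` is again a placeholder path):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="tokenizer.model")  # placeholder path
unk_token_length = len(sp.encode("<unk>", out_type=str))  # 4 in the docstring's example
pieces = sp.encode("<unk> Hey", out_type=str)             # e.g. ['▁', '<', 'unk', '>', '▁Hey']
print(pieces[unk_token_length:])                          # e.g. ['▁Hey']
```

The added `len(tokens) >= self.unk_token_length` guard keeps the slice in bounds if sentencepiece happens to merge pieces across the unk/text boundary.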
@@ -363,6 +363,10 @@ class T5Tokenizer(PreTrainedTokenizer):
             tokens = tokens[1:]
         return tokens
 
+    @property
+    def unk_token_length(self):
+        return len(self.sp_model.encode(str(self.unk_token)))
+
     def _tokenize(self, text, **kwargs):
         """
         Returns a tokenized string.
@@ -373,13 +377,14 @@ class T5Tokenizer(PreTrainedTokenizer):
         `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
         `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
         """
-        if self.legacy:
-            return self.sp_model.encode(text, out_type=str)
-
-        unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
-        text = self.unk_token + text
         tokens = self.sp_model.encode(text, out_type=str)
-        return tokens[unk_token_length:]
+        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
+            return tokens
+
+        # 1. Encode string + prefix ex: "<unk> Hey"
+        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
 
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
...
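The new `unk_token_length` property (previously a local variable) exists because the unk string can span several sentencepiece pieces, so its length is computed from the live `unk_token` rather than hard-coded; the `str(...)` call presumably normalizes an `AddedToken`-like object to a plain string. A standalone check, assuming a local T5-style model file (placeholder path):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="spiece.model")  # placeholder path
print(sp.encode("<unk>", out_type=str))       # e.g. ['▁', '<', 'unk', '>']
print(len(sp.encode("<unk>", out_type=str)))  # e.g. 4, matching the docstring example
```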
@@ -546,6 +546,15 @@ class LlamaIntegrationTest(unittest.TestCase):
         decoded_tokens = tokenizer.decode(input_ids)
         self.assertEqual(decoded_tokens, " <s> Hello<s> how")
 
+    def test_some_edge_cases(self):
+        tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
+
+        sp_tokens = tokenizer.sp_model.encode("<s>>", out_type=str)
+        self.assertEqual(sp_tokens, ["<", "s", ">>"])
+        tokens = tokenizer.tokenize("<s>>")
+        self.assertNotEqual(sp_tokens, tokens)
+        self.assertEqual(tokens, ["<s>", ">"])
+
 
 @require_sentencepiece
 @require_tokenizers
...
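The new test pins down the non-legacy edge case: raw sentencepiece never sees added tokens and merges greedily, whereas the tokenizer splits the special token `<s>` out first, so only `>` reaches `_tokenize`. The same check can be run outside the test harness (this downloads the `huggyllama/llama-7b` tokenizer, so network access is assumed):

```python
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)

# Raw sentencepiece is unaware of added/special tokens and merges greedily:
assert tokenizer.sp_model.encode("<s>>", out_type=str) == ["<", "s", ">>"]
# The tokenizer extracts the special token first, then tokenizes the rest,
# so no unk-prefix pieces leak into the output:
assert tokenizer.tokenize("<s>>") == ["<s>", ">"]
```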