Unverified Commit 186bfd2e authored by AllentDan, committed by GitHub

robust incremental decode for leading space (#581)

* robust incremental decode for leading space

* speed up lookup as prefix_space_tokens is shorter than no_prefix_space_tokens

* add UT and fix qwen stuff
parent 70a5c63a
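The problem this patch targets: when tokens are decoded incrementally (a few at a time), a SentencePiece-style detokenizer strips the leading space that '▁' encodes on the first token of each chunk, so the concatenated stream loses word boundaries. Below is a minimal toy sketch of the issue and of the check the patch applies; the vocab and toy_decode helper are made up for illustration and are not lmdeploy's API.

# Toy illustration only: 'vocab' and 'toy_decode' stand in for a SentencePiece
# tokenizer, where '▁' marks a leading space on a piece.
vocab = {0: '▁Hello', 1: '▁world', 2: '!'}

def toy_decode(token_ids):
    # join pieces, map '▁' to ' ', and strip the leading space as SentencePiece does
    text = ''.join(vocab[i] for i in token_ids).replace('▁', ' ')
    return text[1:] if text.startswith(' ') else text

ids = [0, 1, 2]
print(toy_decode(ids))  # 'Hello world!'

# naive incremental decoding: one token per step, concatenating the per-step results
print(''.join(toy_decode(ids[i:i + 1]) for i in range(len(ids))))  # 'Helloworld!'

# the patched check: if the chunk's first token is a prefix-space token and the
# decoded text does not already start with a space, prepend one
fixed = ''
for i in range(len(ids)):
    chunk = toy_decode(ids[i:i + 1])
    if vocab[ids[i]].startswith('▁') and not chunk.startswith(' '):
        chunk = ' ' + chunk
    fixed += chunk
print(fixed)  # ' Hello world!' (spaces between words are restored)

The hunks below apply exactly that check in both tokenizer classes.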
@@ -16,7 +16,7 @@ class SentencePieceTokenizer:
     def __init__(self, model_file: str):
         from sentencepiece import SentencePieceProcessor
         self.model = SentencePieceProcessor(model_file=model_file)
-        self._no_prefix_space_tokens = None
+        self._prefix_space_tokens = None

     @property
     def vocab_size(self):
@@ -34,19 +34,20 @@ class SentencePieceTokenizer:
         return self.model.eos_id()

     @property
-    def no_prefix_space_tokens(self):
-        """tokens without prefix space."""
-        if self._no_prefix_space_tokens is None:
+    def prefix_space_tokens(self):
+        """tokens with prefix space."""
+        if self._prefix_space_tokens is None:
             vocab = self.model.IdToPiece(list(range(self.vocab_size)))
-            self._no_prefix_space_tokens = {
+            self._prefix_space_tokens = {
                 i
-                for i, tok in enumerate(vocab) if not tok.startswith('▁')
+                for i, tok in enumerate(vocab) if tok.startswith('▁')
             }
-        return self._no_prefix_space_tokens
+        return self._prefix_space_tokens

     def _maybe_add_prefix_space(self, tokens, decoded):
         """maybe add prefix space for incremental decoding."""
-        if len(tokens) and tokens[0] not in self.no_prefix_space_tokens:
+        if len(tokens) and not decoded.startswith(' ') and\
+                tokens[0] in self.prefix_space_tokens:
             return ' ' + decoded
         else:
             return decoded
@@ -111,8 +112,7 @@ class HuggingFaceTokenizer:
     """

     def __init__(self, model_dir: str):
-        from transformers import (AutoTokenizer, CodeLlamaTokenizerFast,
-                                  LlamaTokenizerFast)
+        from transformers import AutoTokenizer
         model_file = osp.join(model_dir, 'tokenizer.model')
         backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
         model_file_exists = osp.exists(model_file)
@@ -121,9 +121,7 @@ class HuggingFaceTokenizer:
                   'It may take long time to initialize the tokenizer.')
         self.model = AutoTokenizer.from_pretrained(model_dir,
                                                    trust_remote_code=True)
-        self.need_padding = isinstance(self.model, LlamaTokenizerFast) \
-            or isinstance(self.model, CodeLlamaTokenizerFast)
-        self._no_prefix_space_tokens = None
+        self._prefix_space_tokens = None
         # save tokenizer.json to reuse
         if not osp.exists(backend_tokenizer_file) and model_file_exists:
             if hasattr(self.model, 'backend_tokenizer'):
@@ -132,9 +130,12 @@ class HuggingFaceTokenizer:
         if self.model.eos_token_id is None:
             generation_config_file = osp.join(model_dir,
                                               'generation_config.json')
-            with open(generation_config_file, 'r') as f:
-                cfg = json.load(f)
-                self.model.eos_token_id = cfg['eos_token_id']
+            if osp.exists(generation_config_file):
+                with open(generation_config_file, 'r') as f:
+                    cfg = json.load(f)
+                    self.model.eos_token_id = cfg['eos_token_id']
+            elif hasattr(self.model, 'eod_id'):  # Qwen remote
+                self.model.eos_token_id = self.model.eod_id

     @property
     def vocab_size(self):
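Restating the new eos_token_id fallback above as standalone logic, purely for illustration (the helper name resolve_eos_token_id is mine; only the order of checks comes from the patch): prefer the tokenizer's own eos_token_id, then generation_config.json when the file exists, then a Qwen-style eod_id attribute exposed by the remote tokenizer code.

import json
import os.path as osp

def resolve_eos_token_id(tokenizer, model_dir):
    # illustrative helper mirroring the patched fallback order
    if tokenizer.eos_token_id is not None:
        return tokenizer.eos_token_id
    cfg_file = osp.join(model_dir, 'generation_config.json')
    if osp.exists(cfg_file):
        with open(cfg_file, 'r') as f:
            return json.load(f)['eos_token_id']
    if hasattr(tokenizer, 'eod_id'):  # Qwen's remote tokenizer defines eod_id
        return tokenizer.eod_id
    return None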
@@ -152,21 +153,22 @@ class HuggingFaceTokenizer:
         return self.model.eos_token_id

     @property
-    def no_prefix_space_tokens(self):
-        """tokens without prefix space."""
-        if self._no_prefix_space_tokens is None:
+    def prefix_space_tokens(self):
+        """tokens with prefix space."""
+        if self._prefix_space_tokens is None:
             vocab = self.model.convert_ids_to_tokens(
                 list(range(self.vocab_size)))
-            self._no_prefix_space_tokens = {
+            self._prefix_space_tokens = {
                 i
-                for i, tok in enumerate(vocab) if not tok.startswith('▁')
+                for i, tok in enumerate(vocab)
+                if tok.startswith('▁' if isinstance(tok, str) else b' ')
             }
-        return self._no_prefix_space_tokens
+        return self._prefix_space_tokens

     def _maybe_add_prefix_space(self, tokens, decoded):
         """maybe add prefix space for incremental decoding."""
-        if self.need_padding and len(
-                tokens) and tokens[0] not in self.no_prefix_space_tokens:
+        if len(tokens) and not decoded.startswith(' ') and\
+                tokens[0] in self.prefix_space_tokens:
             return ' ' + decoded
         else:
             return decoded
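Why the HuggingFace variant branches on the token type: SentencePiece-backed tokenizers return pieces as str with the '▁' space marker, while byte-level vocabularies (e.g. Qwen's remote tokenizer) can return raw bytes in which a leading space is simply b' '. A small self-contained check with a made-up piece list:

# 'pieces' is a made-up sample, not a real vocabulary
pieces = ['▁world', 'world', b' world', b'world']
prefix_space = [
    tok for tok in pieces
    if tok.startswith('▁' if isinstance(tok, str) else b' ')
]
print(prefix_space)  # ['▁world', b' world']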
The new unit test added by the commit ("add UT" in the message) exercises incremental decoding across several models:

import pytest

from lmdeploy.tokenizer import HuggingFaceTokenizer


@pytest.mark.parametrize('model_path', [
    'internlm/internlm-chat-7b', 'Qwen/Qwen-7B-Chat',
    'baichuan-inc/Baichuan-7B', 'codellama/CodeLlama-7b-hf',
    'upstage/SOLAR-0-70b-16bit'
])
@pytest.mark.parametrize(
    'input', ['hi, this is a test 😆😆! ' * 5, '為什麼我還在用繁體字 😆😆 gg! ' * 5])
def test_tokenizer(model_path, input):
    tokenizer = HuggingFaceTokenizer(model_path)
    encoded = tokenizer.encode(input)
    output = ''
    offset = 0
    for i in range(1, len(encoded) + 1):
        decoded = tokenizer.decode(encoded[:i], offset)
        if decoded.endswith('�'):
            continue
        output += decoded
        offset = i
    assert input == output, 'input string should equal output after enc-dec'
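The decoded.endswith('�') guard skips steps where the newest token stops midway through a multi-byte UTF-8 character (such as the emoji in the test inputs): decoding the incomplete prefix yields the replacement character, so the test waits until the character is complete before appending. A quick standalone illustration of the same effect with raw bytes:

emoji_bytes = '😆'.encode('utf-8')  # 4 bytes in UTF-8
partial = emoji_bytes[:2].decode('utf-8', errors='replace')
print(partial.endswith('�'))  # True: the sequence is still incomplete
print(emoji_bytes.decode('utf-8'))  # '😆' once every byte has arrived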