Unverified Commit 9bfe03c6 authored by AllentDan, committed by GitHub

Decode generated token_ids incrementally (#309)

* add incremental decoding for turbomind

* update TIS

* fix triton post processing

* update doc

* fix typo

* SentencePieceTokenizer incremental decode, add qwen message prompt

* docstring

* update bot
parent 22e8b2ca
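
Note: the substance of this change is to decode only the newly generated token ids past an offset, instead of decoding the whole output every step and slicing the resulting string by character count (which drifts whenever one token maps to several characters). Below is a minimal sketch of the consumer-loop pattern the diff adopts; `generator` and `tokenizer` are placeholders, and only the decode(ids, offset=...) signature comes from this commit.

def stream_response(generator, tokenizer):
    """Illustrative consumer loop for offset-based incremental decoding.

    The real loops live in AsyncEngine and the turbomind chat script;
    the names here are placeholders, not the engine code."""
    response_size = 0                  # token ids decoded so far
    for outputs in generator:          # each step yields [(token_ids, n_tokens), ...]
        res, tokens = outputs[0]
        # Decode only the ids past `offset` instead of re-decoding the
        # whole sequence and slicing the string.
        delta = tokenizer.decode(res.tolist(), offset=response_size)
        print(delta, end='', flush=True)
        response_size = tokens         # advance by token count, not len(str)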
@@ -90,8 +90,8 @@ Generate:
 curl http://{server_ip}:{server_port}/generate \
 -H "Content-Type: application/json" \
 -d '{
-  "model": "internlm-chat-7b",
-  "prompt": "Hello! Ho are you?",
+  "prompt": "Hello! How are you?",
+  "instance_id": 1,
   "sequence_start": true,
   "sequence_end": true
 }'
...
@@ -92,8 +92,8 @@ curl http://{server_ip}:{server_port}/v1/models
 curl http://{server_ip}:{server_port}/generate \
 -H "Content-Type: application/json" \
 -d '{
-  "model": "internlm-chat-7b",
-  "prompt": "Hello! Ho are you?",
+  "prompt": "Hello! How are you?",
+  "instance_id": 1,
   "sequence_start": true,
   "sequence_end": true
 }'
...
@@ -256,6 +256,29 @@ class Puyu(BaseModel):
         else:
             return f'\n{self.user}{prompt}{self.eoh}\n{self.assistant}'
 
+    def messages2prompt(self, messages, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            messages (str | List): user's input prompt
+            sequence_start (bool): flag to start the sequence
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        system, users, assistants = self._translate_messages(messages)
+        system = self.system if not system else system
+        ret = f'<BOS>{system}{self.meta_instruction}{self.eosys}'
+        for user, assistant in zip(users, assistants):
+            if assistant:
+                ret += f'\n{self.user}{user}{self.eoh}\n{self.assistant}' \
+                       f'{assistant}'
+            else:
+                ret += f'\n{self.user}{user}{self.eoh}\n{self.assistant}'
+        return ret
+
     @property
     def stop_words(self):
         """Return the stop-words' token ids."""
@@ -360,6 +383,29 @@ class Qwen7BChat(BaseModel):
         return f'\n{self.im_start}user\n{prompt}{self.im_end}' \
                f'\n{self.im_start}assistant\n'
 
+    def messages2prompt(self, messages, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            messages (str | List): user's input prompt
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        system, users, assistants = self._translate_messages(messages)
+        system = self.system if not system else system
+        ret = f'{self.im_start}system\n{system}{self.im_end}'
+        for user, assistant in zip(users, assistants):
+            if assistant:
+                ret += f'\n{self.im_start}user\n{user}{self.im_end}' \
+                       f'\n{self.im_start}assistant\n{assistant}'
+            else:
+                ret += f'\n{self.im_start}user\n{user}{self.im_end}' \
+                       f'\n{self.im_start}assistant\n'
+        return ret
+
     @property
     def stop_words(self):
         """Return the stop-words' token ids."""
...
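
For reference, a rough example of what the new messages2prompt methods consume and produce. The role/content dict layout and the padding of a missing final assistant reply with None are assumptions about _translate_messages; the message contents and the rendered output below are illustrative only, using the Qwen template with its <|im_start|>/<|im_end|> markers.

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hi'},
    {'role': 'assistant', 'content': 'Hello! How can I help?'},
    {'role': 'user', 'content': 'Tell me a joke.'},
]

# Qwen7BChat().messages2prompt(messages) would then render approximately:
#
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant
# Hello! How can I help?
# <|im_start|>user
# Tell me a joke.<|im_end|>
# <|im_start|>assistant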
@@ -138,12 +138,13 @@ class AsyncEngine:
                 random_seed=seed if sequence_start else None):
             res, tokens = outputs[0]
             # decode res
-            response = self.tokenizer.decode(res)[response_size:]
+            response = self.tokenizer.decode(res.tolist(),
+                                             offset=response_size)
             # response, history token len,
             # input token len, gen token len
             yield GenOut(response, self.steps[str(session_id)],
                          len(input_ids), tokens, finish_reason)
-            response_size += len(response)
+            response_size = tokens
 
         # update step
         self.steps[str(session_id)] += len(input_ids) + tokens
@@ -229,7 +230,8 @@ class AsyncEngine:
                 random_seed=seed if sequence_start else None):
             res, tokens = outputs[0]
             # decode res
-            response = self.tokenizer.decode(res[response_size:])
+            response = self.tokenizer.decode(res.tolist(),
+                                             offset=response_size)
             # response, history token len, input token len, gen token len
             yield GenOut(response, self.steps[str(session_id)],
                          len(input_ids), tokens, finish_reason)
...
@@ -599,14 +599,12 @@ class Chatbot:
         Yields:
             tuple: status, text, generated token number
         """
-        offset = n_input_token + preseq_length
         status, res, n_token = None, '', 0
         while True:
             result = res_queue.get()
             if result is None:
                 status = StatusCode.TRITON_STREAM_END
                 res = session.response
-                n_token = session.sequence_length - offset
                 session.status = StatusCode.TRITON_STREAM_END
                 break
             if 'errcode' in result:
@@ -629,30 +627,29 @@ class Chatbot:
                 output_ids = result.as_numpy('output_ids')
 
                 session.sequence_length = sequence_length.squeeze()
-                sequence_length = sequence_length - offset
-                last_token_id = output_ids[-1][-1][session.sequence_length - 1]
+                output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
+                output_ids = output_ids[:, :, n_input_token +
+                                        preseq_length:sequence_length.squeeze(
+                                        )]
+                last_token_id = output_ids[-1, -1, -1]
                 if last_token_id == eos_id:
                     session.sequence_length = session.sequence_length - 1
-                    sequence_length = sequence_length - 1
-
-                output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
-                sequence_length = sequence_length.reshape(
-                    (1, sequence_length.shape[-1]))
+                    output_ids = output_ids[:, :, :-1]
 
                 if profile_generation:
                     yield (StatusCode.TRITON_STREAM_ING,
                            'postprocessing is ignored during profiling '
-                           'token generation', sequence_length.squeeze())
+                           'token generation', output_ids.shape[-1])
                     continue
-                output_str = postprocess(output_ids[:, :, offset:],
-                                         sequence_length)
+                output_str = postprocess(
+                    output_ids, np.array([[n_token]], dtype=np.uint32))
+                n_token = output_ids.shape[-1]
                 text = output_str[0].decode()
                 if display:
-                    new_text = text[len(session.response):]
-                    print(new_text, end='', flush=True)
-                session.response = text
+                    print(text, end='', flush=True)
+                session.response += text
                 yield (StatusCode.TRITON_STREAM_ING, session.response,
-                       sequence_length.squeeze())
+                       output_ids.shape[-1])
            except Exception as e:
                logger.error(f'catch exception: {e}')
...
@@ -123,7 +123,7 @@ class TritonPythonModel:
         outputs = []
         for beam_tokens, beam_len in zip(tokens_batch, sequence_length):
             for tokens, _len in zip(beam_tokens, beam_len):
-                output = self.tokenizer.decode(tokens[:_len])
+                output = self.tokenizer.decode(tokens, _len)
                 output = output.encode('utf8')
                 outputs.append(output)
         return outputs
@@ -99,10 +99,10 @@ def main(model_path,
                 random_seed=seed if nth_round == 1 else None):
             res, tokens = outputs[0]
             # decode res
-            response = tokenizer.decode(res)[response_size:]
+            response = tokenizer.decode(res.tolist(), offset=response_size)
             response = valid_str(response)
             print(f'{response}', end='', flush=True)
-            response_size += len(response)
+            response_size = tokens
 
         # update step
         step += len(input_ids) + tokens
...
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
 import os.path as osp
-from typing import Sequence, Union
+from typing import Optional, Sequence, Union
 
 import torch
@@ -16,6 +16,7 @@ class SentencePieceTokenizer:
     def __init__(self, model_file: str):
         from sentencepiece import SentencePieceProcessor
         self.model = SentencePieceProcessor(model_file=model_file)
+        self._no_prefix_space_tokens = None
 
     @property
     def vocab_size(self):
@@ -32,6 +33,24 @@ class SentencePieceTokenizer:
         """end of the sentence token id."""
         return self.model.eos_id()
 
+    @property
+    def no_prefix_space_tokens(self):
+        """tokens without prefix space."""
+        if self._no_prefix_space_tokens is None:
+            vocab = self.model.IdToPiece(list(range(self.vocab_size)))
+            self._no_prefix_space_tokens = {
+                i
+                for i, tok in enumerate(vocab) if not tok.startswith('▁')
+            }
+        return self._no_prefix_space_tokens
+
+    def _maybe_add_prefix_space(self, tokens, decoded):
+        """maybe add prefix space for incremental decoding."""
+        if len(tokens) and tokens[0] not in self.no_prefix_space_tokens:
+            return ' ' + decoded
+        else:
+            return decoded
+
     def encode(self, s: str):
         """Tokenize a prompt.
@@ -50,17 +69,23 @@ class SentencePieceTokenizer:
             add_eos = True
         return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
 
-    def decode(self, t: Sequence[int]):
+    def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
 
         Args:
             t (List[int]): a list of token ids
+            offset (int): for incrementally decoding. Default to None, which
+                means not applied.
         Returns:
             str: text of decoding tokens
         """
         if isinstance(t, torch.Tensor):
             t = t.tolist()
-        return self.model.Decode(t)
+        t = t[offset:]
+        out_string = self.model.Decode(t)
+        if offset:
+            out_string = self._maybe_add_prefix_space(t, out_string)
+        return out_string
 
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
@@ -86,7 +111,7 @@ class HuggingFaceTokenizer:
     """
 
     def __init__(self, model_dir: str):
-        from transformers import AutoTokenizer
+        from transformers import AutoTokenizer, LlamaTokenizerFast
         model_file = osp.join(model_dir, 'tokenizer.model')
         backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
         model_file_exists = osp.exists(model_file)
@@ -95,6 +120,8 @@ class HuggingFaceTokenizer:
                 'It may take long time to initialize the tokenizer.')
         self.model = AutoTokenizer.from_pretrained(model_dir,
                                                    trust_remote_code=True)
+        self.need_padding = isinstance(self.model, LlamaTokenizerFast)
+        self._no_prefix_space_tokens = None
         # save tokenizer.json to reuse
         if not osp.exists(backend_tokenizer_file) and model_file_exists:
             if hasattr(self.model, 'backend_tokenizer'):
@@ -122,6 +149,26 @@ class HuggingFaceTokenizer:
         """end of the sentence token id."""
         return self.model.eos_token_id
 
+    @property
+    def no_prefix_space_tokens(self):
+        """tokens without prefix space."""
+        if self._no_prefix_space_tokens is None:
+            vocab = self.model.convert_ids_to_tokens(
+                list(range(self.vocab_size)))
+            self._no_prefix_space_tokens = {
+                i
+                for i, tok in enumerate(vocab) if not tok.startswith('▁')
+            }
+        return self._no_prefix_space_tokens
+
+    def _maybe_add_prefix_space(self, tokens, decoded):
+        """maybe add prefix space for incremental decoding."""
+        if self.need_padding and len(
+                tokens) and tokens[0] not in self.no_prefix_space_tokens:
+            return ' ' + decoded
+        else:
+            return decoded
+
     def encode(self, s: str):
         """Tokenize a prompt.
@@ -139,16 +186,23 @@ class HuggingFaceTokenizer:
             add_special_tokens = True
         return self.model.encode(s, add_special_tokens=add_special_tokens)
 
-    def decode(self, t: Sequence[int]):
+    def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
 
         Args:
             t (List[int]): a list of token ids
+            offset (int): for incrementally decoding. Default to None, which
+                means not applied.
         Returns:
             str: text of decoding tokens
         """
         skip_special_tokens = True
-        return self.model.decode(t, skip_special_tokens=skip_special_tokens)
+        t = t[offset:]
+        out_string = self.model.decode(t,
+                                       skip_special_tokens=skip_special_tokens)
+        if offset:
+            out_string = self._maybe_add_prefix_space(t, out_string)
+        return out_string
 
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
@@ -211,15 +265,17 @@ class Tokenizer:
         """
         return self.model.encode(s)
 
-    def decode(self, t: Sequence[int]):
+    def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
 
         Args:
             t (List[int]): a list of token ids
+            offset (int): for incrementally decoding. Default to None, which
+                means not applied.
         Returns:
             str: text of decoding tokens
         """
-        return self.model.decode(t)
+        return self.model.decode(t, offset)
 
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
...
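
The _maybe_add_prefix_space helpers above address a SentencePiece quirk: word-initial pieces are marked with '▁', and decoding only a suffix of token ids drops the leading space that marker represents. Below is a standalone sketch of the rule, with invented pieces and ids; the real implementations look ids up via IdToPiece / convert_ids_to_tokens instead.

# Hypothetical vocabulary: id -> piece.
pieces = {0: '▁Hello', 1: '▁world', 2: '!', 3: 'ing'}
# Ids whose piece does NOT start a new word, i.e. never need a space re-added.
no_prefix_space_tokens = {i for i, tok in pieces.items()
                          if not tok.startswith('▁')}


def maybe_add_prefix_space(token_ids, decoded):
    """Re-insert the space SentencePiece drops when a word-initial piece is
    decoded without its preceding context."""
    if token_ids and token_ids[0] not in no_prefix_space_tokens:
        return ' ' + decoded
    return decoded


# Full text "Hello world!" decoded incrementally: ids [0] first, then [1, 2].
# Decoding [1, 2] alone gives "world!"; the helper restores the leading space.
print('Hello' + maybe_add_prefix_space([1, 2], 'world!'))  # -> Hello world!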