Unverified Commit 9bfe03c6 authored by AllentDan, committed by GitHub

Decode generated token_ids incrementally (#309)

* add incremental decoding for turbomind

* update TIS

* fix triton post processing

* update doc

* fix typo

* SentencePieceTokenizer incremental decode, add qwen message prompt

* docstring

* update bot
parent 22e8b2ca
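
The core change, applied to every frontend below: instead of decoding the full token list each step and slicing the decoded string (`tokenizer.decode(res)[response_size:]`), the tokenizer now receives the whole id list plus a token-count `offset`, decodes only the new suffix, and restores the leading space that SentencePiece folds into a word-initial token. A minimal sketch of the pattern, assuming a HuggingFace-style tokenizer (the helper name and ids are illustrative, not from this commit):

```python
# Minimal sketch of incremental decoding by token offset (illustrative;
# `tokenizer` stands for a HuggingFace AutoTokenizer or similar).
def decode_incremental(tokenizer, token_ids, offset):
    """Decode only token_ids[offset:], keeping word boundaries intact."""
    new_ids = token_ids[offset:]
    text = tokenizer.decode(new_ids, skip_special_tokens=True)
    # SentencePiece marks a word-initial space with '▁' on the token itself;
    # decoding a suffix in isolation drops that space, so restore it by hand.
    first = tokenizer.convert_ids_to_tokens(new_ids[:1])
    if offset and first and first[0].startswith('▁'):
        text = ' ' + text
    return text
```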
......@@ -90,8 +90,8 @@ Generate:
 curl http://{server_ip}:{server_port}/generate \
   -H "Content-Type: application/json" \
   -d '{
     "model": "internlm-chat-7b",
-    "prompt": "Hello! Ho are you?",
+    "prompt": "Hello! How are you?",
     "instance_id": 1,
     "sequence_start": true,
     "sequence_end": true
   }'
......
......@@ -92,8 +92,8 @@ curl http://{server_ip}:{server_port}/v1/models
 curl http://{server_ip}:{server_port}/generate \
   -H "Content-Type: application/json" \
   -d '{
     "model": "internlm-chat-7b",
-    "prompt": "Hello! Ho are you?",
+    "prompt": "Hello! How are you?",
     "instance_id": 1,
     "sequence_start": true,
     "sequence_end": true
   }'
......
......@@ -256,6 +256,29 @@ class Puyu(BaseModel):
         else:
             return f'\n{self.user}{prompt}{self.eoh}\n{self.assistant}'
 
+    def messages2prompt(self, messages, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            messages (str | List): user's input prompt
+            sequence_start (bool): flag to start the sequence
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        system, users, assistants = self._translate_messages(messages)
+        system = self.system if not system else system
+        ret = f'<BOS>{system}{self.meta_instruction}{self.eosys}'
+        for user, assistant in zip(users, assistants):
+            if assistant:
+                ret += f'\n{self.user}{user}{self.eoh}\n{self.assistant}' \
+                       f'{assistant}'
+            else:
+                ret += f'\n{self.user}{user}{self.eoh}\n{self.assistant}'
+        return ret
+
     @property
     def stop_words(self):
         """Return the stop-words' token ids."""
......@@ -360,6 +383,29 @@ class Qwen7BChat(BaseModel):
         return f'\n{self.im_start}user\n{prompt}{self.im_end}' \
                f'\n{self.im_start}assistant\n'
 
+    def messages2prompt(self, messages, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            messages (str | List): user's input prompt
+            sequence_start (bool): flag to start the sequence
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        system, users, assistants = self._translate_messages(messages)
+        system = self.system if not system else system
+        ret = f'{self.im_start}system\n{system}{self.im_end}'
+        for user, assistant in zip(users, assistants):
+            if assistant:
+                ret += f'\n{self.im_start}user\n{user}{self.im_end}' \
+                       f'\n{self.im_start}assistant\n{assistant}'
+            else:
+                ret += f'\n{self.im_start}user\n{user}{self.im_end}' \
+                       f'\n{self.im_start}assistant\n'
+        return ret
+
     @property
     def stop_words(self):
         """Return the stop-words' token ids."""
......
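
Both new `messages2prompt` implementations (Puyu and Qwen7BChat above) follow the same pattern: fall back to `get_prompt` for a plain string, otherwise interleave the translated user/assistant turns into the model's chat template. A hypothetical call, assuming `_translate_messages` pads the trailing user turn with an empty assistant and that `im_start`/`im_end` are the usual `<|im_start|>`/`<|im_end|>` markers:

```python
# Hypothetical usage of the new Qwen7BChat.messages2prompt (values assumed).
messages = [
    {'role': 'user', 'content': 'Hi'},
    {'role': 'assistant', 'content': 'Hello! How are you?'},
    {'role': 'user', 'content': 'Tell me a joke.'},
]
prompt = Qwen7BChat().messages2prompt(messages)
# -> '<|im_start|>system\n{system}<|im_end|>'
#    '\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello! How are you?'
#    '\n<|im_start|>user\nTell me a joke.<|im_end|>\n<|im_start|>assistant\n'
```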
......@@ -138,12 +138,13 @@ class AsyncEngine:
                 random_seed=seed if sequence_start else None):
             res, tokens = outputs[0]
             # decode res
-            response = self.tokenizer.decode(res)[response_size:]
+            response = self.tokenizer.decode(res.tolist(),
+                                             offset=response_size)
             # response, history token len,
             # input token len, gen token len
             yield GenOut(response, self.steps[str(session_id)],
                          len(input_ids), tokens, finish_reason)
-            response_size += len(response)
+            response_size = tokens
         # update step
         self.steps[str(session_id)] += len(input_ids) + tokens
......@@ -229,7 +230,8 @@ class AsyncEngine:
                 random_seed=seed if sequence_start else None):
             res, tokens = outputs[0]
             # decode res
-            response = self.tokenizer.decode(res[response_size:])
+            response = self.tokenizer.decode(res.tolist(),
+                                             offset=response_size)
             # response, history token len, input token len, gen token len
             yield GenOut(response, self.steps[str(session_id)],
                          len(input_ids), tokens, finish_reason)
......
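
Note the semantics change for `response_size` in both streaming paths above: it used to be a character offset into the re-decoded string and is now the number of tokens already decoded, handed to `decode` as `offset`. A simplified view of the resulting loop (the `generate` call is a stand-in for the engine internals):

```python
# Simplified view of the streaming loop after this change (illustrative).
response_size = 0                       # tokens already decoded, not characters
for res, tokens in generate():          # res: all generated token ids so far
    chunk = tokenizer.decode(res.tolist(), offset=response_size)
    yield chunk                         # emit only the newly generated text
    response_size = tokens              # advance by the total token count
```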
......@@ -599,14 +599,12 @@ class Chatbot:
         Yields:
             tuple: status, text, generated token number
         """
-        offset = n_input_token + preseq_length
         status, res, n_token = None, '', 0
         while True:
             result = res_queue.get()
             if result is None:
                 status = StatusCode.TRITON_STREAM_END
                 res = session.response
-                n_token = session.sequence_length - offset
                 session.status = StatusCode.TRITON_STREAM_END
                 break
             if 'errcode' in result:
......@@ -629,30 +627,29 @@ class Chatbot:
                 output_ids = result.as_numpy('output_ids')
 
                 session.sequence_length = sequence_length.squeeze()
-                sequence_length = sequence_length - offset
-                last_token_id = output_ids[-1][-1][session.sequence_length - 1]
+                output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
+                output_ids = output_ids[:, :, n_input_token +
+                                        preseq_length:sequence_length.squeeze(
+                                        )]
+                last_token_id = output_ids[-1, -1, -1]
                 if last_token_id == eos_id:
                     session.sequence_length = session.sequence_length - 1
-                    sequence_length = sequence_length - 1
-                output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
-                sequence_length = sequence_length.reshape(
-                    (1, sequence_length.shape[-1]))
+                    output_ids = output_ids[:, :, :-1]
 
                 if profile_generation:
                     yield (StatusCode.TRITON_STREAM_ING,
                            'postprocessing is ignored during profiling '
-                           'token generation', sequence_length.squeeze())
+                           'token generation', output_ids.shape[-1])
                     continue
-                output_str = postprocess(output_ids[:, :, offset:],
-                                         sequence_length)
+                output_str = postprocess(
+                    output_ids, np.array([[n_token]], dtype=np.uint32))
+                n_token = output_ids.shape[-1]
                 text = output_str[0].decode()
                 if display:
-                    new_text = text[len(session.response):]
-                    print(new_text, end='', flush=True)
-                session.response = text
+                    print(text, end='', flush=True)
+                session.response += text
                 yield (StatusCode.TRITON_STREAM_ING, session.response,
-                       sequence_length.squeeze())
+                       output_ids.shape[-1])
             except Exception as e:
                 logger.error(f'catch exception: {e}')
......
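
On the Triton client side, `output_ids` is now trimmed to the newly generated span before postprocessing, and `session.response` accumulates decoded increments instead of being rebuilt from the full sequence each step. A toy NumPy illustration of the new slicing (shapes follow the diff; the values are invented):

```python
import numpy as np

# Toy illustration of slicing off the prompt and cached prefix (values invented).
output_ids = np.arange(10).reshape(1, 1, 10)  # prompt ids + generated ids
n_input_token, preseq_length = 4, 0           # prompt length, cached prefix
sequence_length = np.array([[8]])             # total valid length so far
new_ids = output_ids[:, :,
                     n_input_token + preseq_length:sequence_length.squeeze()]
print(new_ids.shape)                          # (1, 1, 4): this round's tokens
```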
......@@ -123,7 +123,7 @@ class TritonPythonModel:
         outputs = []
         for beam_tokens, beam_len in zip(tokens_batch, sequence_length):
             for tokens, _len in zip(beam_tokens, beam_len):
-                output = self.tokenizer.decode(tokens[:_len])
+                output = self.tokenizer.decode(tokens, _len)
                 output = output.encode('utf8')
                 outputs.append(output)
         return outputs
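
With the new `Tokenizer.decode` signature, the second argument here is an offset rather than a truncation length: the client now reports how many tokens were decoded in earlier rounds, so `decode(tokens, _len)` detokenizes only the new suffix. A hypothetical round-by-round view (ids invented):

```python
# Hypothetical offset semantics in the postprocessing model (ids invented).
tokens = [5, 8, 13, 21]                # all ids emitted so far for one beam
_len = 2                               # tokens already decoded in prior rounds
text = tokenizer.decode(tokens, _len)  # detokenizes only [13, 21]
```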
......@@ -99,10 +99,10 @@ def main(model_path,
                 random_seed=seed if nth_round == 1 else None):
             res, tokens = outputs[0]
             # decode res
-            response = tokenizer.decode(res)[response_size:]
+            response = tokenizer.decode(res.tolist(), offset=response_size)
             response = valid_str(response)
             print(f'{response}', end='', flush=True)
-            response_size += len(response)
+            response_size = tokens
         # update step
         step += len(input_ids) + tokens
......
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
 import os.path as osp
-from typing import Sequence, Union
+from typing import Optional, Sequence, Union
 
 import torch
......@@ -16,6 +16,7 @@ class SentencePieceTokenizer:
     def __init__(self, model_file: str):
         from sentencepiece import SentencePieceProcessor
         self.model = SentencePieceProcessor(model_file=model_file)
+        self._no_prefix_space_tokens = None
 
     @property
     def vocab_size(self):
......@@ -32,6 +33,24 @@ class SentencePieceTokenizer:
         """end of the sentence token id."""
         return self.model.eos_id()
 
+    @property
+    def no_prefix_space_tokens(self):
+        """tokens without prefix space."""
+        if self._no_prefix_space_tokens is None:
+            vocab = self.model.IdToPiece(list(range(self.vocab_size)))
+            self._no_prefix_space_tokens = {
+                i
+                for i, tok in enumerate(vocab) if not tok.startswith('▁')
+            }
+        return self._no_prefix_space_tokens
+
+    def _maybe_add_prefix_space(self, tokens, decoded):
+        """maybe add prefix space for incremental decoding."""
+        if len(tokens) and tokens[0] not in self.no_prefix_space_tokens:
+            return ' ' + decoded
+        else:
+            return decoded
+
     def encode(self, s: str):
         """Tokenize a prompt.
......@@ -50,17 +69,23 @@ class SentencePieceTokenizer:
             add_eos = True
         return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
 
-    def decode(self, t: Sequence[int]):
+    def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
 
         Args:
             t (List[int]): a list of token ids
+            offset (int): for incrementally decoding. Default to None, which
+                means not applied.
         Returns:
             str: text of decoding tokens
         """
         if isinstance(t, torch.Tensor):
             t = t.tolist()
-        return self.model.Decode(t)
+        t = t[offset:]
+        out_string = self.model.Decode(t)
+        if offset:
+            out_string = self._maybe_add_prefix_space(t, out_string)
+        return out_string
 
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
......@@ -86,7 +111,7 @@ class HuggingFaceTokenizer:
     """
 
     def __init__(self, model_dir: str):
-        from transformers import AutoTokenizer
+        from transformers import AutoTokenizer, LlamaTokenizerFast
         model_file = osp.join(model_dir, 'tokenizer.model')
         backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
         model_file_exists = osp.exists(model_file)
......@@ -95,6 +120,8 @@ class HuggingFaceTokenizer:
                 'It may take long time to initialize the tokenizer.')
         self.model = AutoTokenizer.from_pretrained(model_dir,
                                                    trust_remote_code=True)
+        self.need_padding = isinstance(self.model, LlamaTokenizerFast)
+        self._no_prefix_space_tokens = None
         # save tokenizer.json to reuse
         if not osp.exists(backend_tokenizer_file) and model_file_exists:
             if hasattr(self.model, 'backend_tokenizer'):
......@@ -122,6 +149,26 @@ class HuggingFaceTokenizer:
         """end of the sentence token id."""
         return self.model.eos_token_id
 
+    @property
+    def no_prefix_space_tokens(self):
+        """tokens without prefix space."""
+        if self._no_prefix_space_tokens is None:
+            vocab = self.model.convert_ids_to_tokens(
+                list(range(self.vocab_size)))
+            self._no_prefix_space_tokens = {
+                i
+                for i, tok in enumerate(vocab) if not tok.startswith('▁')
+            }
+        return self._no_prefix_space_tokens
+
+    def _maybe_add_prefix_space(self, tokens, decoded):
+        """maybe add prefix space for incremental decoding."""
+        if self.need_padding and len(
+                tokens) and tokens[0] not in self.no_prefix_space_tokens:
+            return ' ' + decoded
+        else:
+            return decoded
+
     def encode(self, s: str):
         """Tokenize a prompt.
......@@ -139,16 +186,23 @@ class HuggingFaceTokenizer:
             add_special_tokens = True
         return self.model.encode(s, add_special_tokens=add_special_tokens)
 
-    def decode(self, t: Sequence[int]):
+    def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
 
         Args:
             t (List[int]): a list of token ids
+            offset (int): for incrementally decoding. Default to None, which
+                means not applied.
         Returns:
             str: text of decoding tokens
         """
         skip_special_tokens = True
-        return self.model.decode(t, skip_special_tokens=skip_special_tokens)
+        t = t[offset:]
+        out_string = self.model.decode(t,
+                                       skip_special_tokens=skip_special_tokens)
+        if offset:
+            out_string = self._maybe_add_prefix_space(t, out_string)
+        return out_string
 
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
......@@ -211,15 +265,17 @@ class Tokenizer:
         """
         return self.model.encode(s)
 
-    def decode(self, t: Sequence[int]):
+    def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
 
         Args:
             t (List[int]): a list of token ids
+            offset (int): for incrementally decoding. Default to None, which
+                means not applied.
         Returns:
             str: text of decoding tokens
         """
-        return self.model.decode(t)
+        return self.model.decode(t, offset)
 
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
......
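
The `no_prefix_space_tokens` bookkeeping above exists because SentencePiece encodes a word-initial space as '▁' on the token itself, so decoding a suffix in isolation silently drops that space; `_maybe_add_prefix_space` puts it back. A hedged illustration (the model file path is a placeholder):

```python
# Why _maybe_add_prefix_space is needed (model file path is a placeholder).
from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor(model_file='tokenizer.model')
ids = sp.Encode('Hello world')  # the last id is typically the piece '▁world'
print(sp.Decode(ids))           # 'Hello world'
print(sp.Decode(ids[-1:]))      # 'world' -- the leading space is lost
# decode(ids, offset=len(ids) - 1) sees that the first kept token starts with
# '▁' (i.e. it is absent from no_prefix_space_tokens) and prepends ' '.
```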