Unverified commit 327deaee authored by AllentDan, committed by GitHub

expose stop words and filter eoa (#352)

* expose stop words

* support string

* fix

* remove eoa from chatbot

* remove eoa of turbomind

* fix ut

* suffix wheel and fix InternLM no system bug
parent 0cc667e1
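The diff below moves stop words from hard-coded token-id properties onto the chat templates as plain strings, and strips them from the final streamed response. A minimal usage sketch for review context only (not part of the commit; `internlm-chat-7b-8k` is one of the registered names touched below, and the override value is arbitrary):

```python
from lmdeploy.model import MODELS

# Stop words are now plain strings carried by the chat template,
# replacing the per-model `stop_words` properties that returned token ids.
model = MODELS.get('internlm-chat-7b-8k')()
print(model.stop_words)  # ['<eoa>'] with the new default

# They can be overridden when the template is instantiated.
model = MODELS.get('internlm-chat-7b-8k')(stop_words=['<eoa>', 'Observation:'])
```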
@@ -29,12 +29,14 @@ class BaseModel:
                  temperature=0.8,
                  repetition_penalty=1.0,
                  capability='chat',
+                 stop_words=None,
                  **kwargs):
         self.session_len = session_len
         self.top_p = top_p
         self.top_k = top_k
         self.temperature = temperature
         self.repetition_penalty = repetition_penalty
+        self.stop_words = stop_words
         self.capability = capability

     def get_prompt(self, prompt, sequence_start=True):
@@ -101,11 +103,6 @@ class BaseModel:
             return self.get_prompt(messages)
         # chat history processing in derived classes

-    @property
-    def stop_words(self):
-        """Return the stop-words' token ids."""
-        return None
-
     @property
     def sampling_param(self):
         return SamplingParam(top_p=self.top_p,
@@ -185,6 +182,7 @@ class InternLMChat7B(BaseModel):
                  eoh='',
                  eoa='<eoa>',
                  assistant='<|Bot|>',
+                 stop_words=['<eoa>'],
                  **kwargs):
         super().__init__(**kwargs)
         self.system = system
@@ -193,6 +191,7 @@ class InternLMChat7B(BaseModel):
         self.eoh = eoh
         self.eoa = eoa
         self.assistant = assistant
+        self.stop_words = stop_words

     def decorate_prompt(self, prompt, sequence_start=True):
         """Return the prompt that is concatenated with other elements in the
@@ -227,7 +226,8 @@ class InternLMChat7B(BaseModel):
         if isinstance(messages, str):
             return self.get_prompt(messages, sequence_start)
         system, users, assistants = self._translate_messages(messages)
-        ret = '<BOS>'
+        system = self.meta_instruction if not system else system
+        ret = f'<BOS>{self.system}:{system}\n'
         for user, assistant in zip(users, assistants):
             if assistant:
                 ret += f'{self.user}:{user}{self.eoh}\n{self.assistant}:' \
@@ -236,11 +236,6 @@ class InternLMChat7B(BaseModel):
             ret += f'{self.user}:{user}{self.eoh}\n{self.assistant}:'
         return ret

-    @property
-    def stop_words(self):
-        """Return the stop-words' token ids."""
-        return [103028]
-

 @MODELS.register_module(name='internlm-chat-20b')
 @MODELS.register_module(name='internlm-chat-7b-8k')
@@ -339,12 +334,14 @@ class Puyu(BaseModel):
                  eoh='',
                  assistant='',
                  eoa='',
+                 stop_words=None,
                  **kwargs):
         super().__init__(**kwargs)
         self.meta_instruction = meta_instruction
         self.system = system
         self.user = user
         self.assistant = assistant
+        self.stop_words = stop_words
         self.eosys = eosys
         self.eoh = eoh
         self.eoa = eoa
@@ -382,11 +379,6 @@ class Puyu(BaseModel):
             ret += f'{self.user}{user}{self.eoh}{self.assistant}'
         return ret

-    @property
-    def stop_words(self):
-        """Return the stop-words' token ids."""
-        return [45623]
-

 @MODELS.register_module(name='llama2')
 class Llama2(BaseModel):
@@ -468,6 +460,7 @@ class Qwen7BChat(BaseModel):
                  im_start='<|im_start|>',
                  im_end='<|im_end|>',
                  system='You are a helpful assistant.',
+                 stop_words=['<|im_end|>'],
                  **kwargs):
         super().__init__(**kwargs)
         self.session_len = session_len
@@ -478,6 +471,7 @@ class Qwen7BChat(BaseModel):
         self.im_start = im_start
         self.im_end = im_end
         self.system = system
+        self.stop_words = stop_words

     def decorate_prompt(self, prompt, sequence_start=True):
         assert self.capability == 'chat', \
@@ -513,11 +507,6 @@ class Qwen7BChat(BaseModel):
                   f'\n{self.im_start}assistant\n'
         return ret

-    @property
-    def stop_words(self):
-        """Return the stop-words' token ids."""
-        return [151645]  # <|im_end|>
-

 @MODELS.register_module(name='codellama')
 class CodeLlama(Llama2):
@@ -526,6 +515,7 @@ class CodeLlama(Llama2):
                  system='',
                  session_len=4096,
                  suffix_first=False,
+                 stop_words=None,
                  **kwargs):
         super().__init__(**kwargs)
         caps = ['completion', 'infilling', 'chat', 'python']
@@ -535,6 +525,7 @@ class CodeLlama(Llama2):
         self.default_sys_prompt = system
         self.session_len = session_len
         self.suffix_first = suffix_first
+        self.stop_words = stop_words

         # The following sampling parameters refers to https://github.com/facebookresearch/codellama  # noqa: E501
         if self.capability == 'completion' or self.capability == 'python':
@@ -546,6 +537,8 @@ class CodeLlama(Llama2):
         elif self.capability == 'infilling':
             self.top_p = kwargs.get('top_p', 0.9)
             self.temperature = kwargs.get('temperature', 0.0)
+            if self.stop_words is None:
+                self.stop_words = ['<EOT>']

     def decorate_prompt(self, prompt, sequence_start=True):
         if self.capability == 'infilling':
@@ -574,14 +567,6 @@ class CodeLlama(Llama2):
         return f'{self.b_inst} {prompt} {self.e_inst}'

-    @property
-    def stop_words(self):
-        if self.capability == 'infilling':
-            # EOT ID
-            return [32010]
-        else:
-            return None
-
     def messages2prompt(self, messages, sequence_start=True):
         assert self.capability == 'chat', \
             f'codellama message2prompt only supports chat mode ' \
......
@@ -18,6 +18,7 @@ from tritonclient.grpc.service_pb2 import ModelInferResponse

 from lmdeploy.model import MODELS
 from lmdeploy.serve.turbomind.utils import (Postprocessor, Preprocessor,
                                             prepare_tensor)
+from lmdeploy.utils import filter_suffix


 @dataclass
@@ -157,6 +158,8 @@ class Chatbot:
                                    request_output_len,
                                    sequence_start,
                                    sequence_end):
+                if status == StatusCode.TRITON_STREAM_END:  # remove stop_words
+                    res = filter_suffix(res, self.model.stop_words)
                 if status.value < 0:
                     break
                 else:
@@ -346,6 +349,8 @@ class Chatbot:
                                    sequence_end):
                 if status.value < 0:
                     break
+                if status == StatusCode.TRITON_STREAM_END:  # remove stop_words
+                    res = filter_suffix(res, self.model.stop_words)
             if status.value == 0:
                 self._session.histories = \
                     self._session.histories + self._session.prompt + \
@@ -386,16 +391,23 @@ class Chatbot:
         token_ids, _ = self.preprocess('<EOS>')
         return token_ids[0][0]

-    def _stop_words(self, stop_words: List[int]):
+    def _stop_words(self, stop_words: List[str]):
         """return stop-words' token ids."""
         if stop_words is None:
             return None
         assert isinstance(stop_words, List) and \
-            all(isinstance(elem, int) for elem in stop_words), \
+            all(isinstance(elem, str) for elem in stop_words), \
             f'stop_words must be a list but got {type(stop_words)}'
         # each id in stop_words represents a stop word
         # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
         # detailed explanation about turbomind's stop_words
+        stop_words = [
+            int(self.preprocess(stop_word)[0][0][-1])
+            for stop_word in stop_words
+        ]
+        assert isinstance(stop_words, List) and \
+            all(isinstance(elem, int) for elem in stop_words), \
+            'invalid stop_words'
         stop_word_offsets = range(1, len(stop_words) + 1)
         stop_words = np.array([[stop_words,
                                 stop_word_offsets]]).astype(np.int32)
......
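Both the gRPC chatbot above and the turbomind wrapper below derive each stop word's token id from the last token of its encoding before handing it to the engine. A standalone sketch of that conversion, assuming a generic `encode` callable as a stand-in for the project's tokenizer/preprocessor (illustrative, not part of the commit):

```python
from typing import Callable, List, Optional

import numpy as np


def stop_words_to_tensor(stop_words: Optional[List[str]],
                         encode: Callable[[str], List[int]]) -> Optional[np.ndarray]:
    """Convert stop-word strings to the [1, 2, n] layout used in the diff."""
    if stop_words is None:
        return None
    # take the last token id of each encoded stop word, as the diff does
    token_ids = [int(encode(word)[-1]) for word in stop_words]
    # the second row holds the end offset of each stop word in the flattened id list
    offsets = list(range(1, len(token_ids) + 1))
    return np.array([[token_ids, offsets]]).astype(np.int32)


# dummy encoder for illustration; the real code goes through the tokenizer
print(stop_words_to_tensor(['<eoa>'], lambda w: [7, 42]))
# -> array of shape (1, 2, 1) containing [[42], [1]]
```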
@@ -14,6 +14,7 @@ from torch.nn.utils.rnn import pad_sequence

 import lmdeploy
 from lmdeploy.model import MODELS
+from lmdeploy.turbomind import Tokenizer
 from lmdeploy.utils import get_logger

 # TODO: find another way import _turbomind
@@ -22,14 +23,16 @@ sys.path.append(osp.join(lmdeploy_dir, 'lib'))
 import _turbomind as _tm  # noqa: E402


-def _stop_words(stop_words: List[int]):
+def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
     """return list of stop-words to numpy.ndarray."""
     if stop_words is None:
         return None
     assert isinstance(stop_words, List) and \
-        all(isinstance(elem, int) for elem in stop_words), \
+        all(isinstance(elem, str) for elem in stop_words), \
         f'stop_words must be a list but got {type(stop_words)}'
+    stop_words = [tokenizer.encode(stop_word)[-1] for stop_word in stop_words]
+    assert isinstance(stop_words, List) and all(
+        isinstance(elem, int) for elem in stop_words), 'invalid stop_words'
     # each id in stop_words represents a stop word
     # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
     # detailed explanation about fastertransformer's stop_words
@@ -106,7 +109,10 @@ class TurboMind:
         self.model_name = parser.get(section_name, 'model_name')
         data_type = parser.get(section_name, 'weight_type')
         model = MODELS.get(self.model_name)()
-        self.stop_words = _stop_words(model.stop_words)
+        tokenizer_model_path = osp.join(model_path, 'triton_models',
+                                        'tokenizer')
+        tokenizer = Tokenizer(tokenizer_model_path)
+        self.stop_words = _stop_words(model.stop_words, tokenizer)

         # params
         self.node_id = node_id
@@ -162,6 +168,8 @@ class TurboMindInstance:
         self.gpu_count = tm_model.gpu_count

         self.stop_words = tm_model.stop_words
+        self.stop_tokens = [] if self.stop_words is None else \
+            self.stop_words.flatten().tolist()
         self.eos_id = tm_model.eos_id
         self.session_len = tm_model.session_len
@@ -346,6 +354,8 @@ class TurboMindInstance:
             output, len_ = output, len_.item()
             if len(output) > 0 and output[-1].item() == self.eos_id:
                 outputs.append((output[:-1], len_ - 1))
+            elif len(output) > 0 and output[-1].item() in self.stop_tokens:
+                outputs.append((output[:-1], len_))
             else:
                 outputs.append((output, len_))
......
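The new `stop_tokens` bookkeeping trims a trailing stop token much like a trailing EOS, except that the reported length is left unchanged. A small self-contained sketch of that rule (names are illustrative, not the module's API):

```python
from typing import List, Tuple


def trim_output(output: List[int], length: int, eos_id: int,
                stop_tokens: List[int]) -> Tuple[List[int], int]:
    """Mirror of the trailing-token handling added in the diff above."""
    if output and output[-1] == eos_id:
        # drop EOS and shorten the reported length, as before this change
        return output[:-1], length - 1
    if output and output[-1] in stop_tokens:
        # drop the stop token from the ids but keep the length, as in the diff
        return output[:-1], length
    return output, length


print(trim_output([5, 9, 2], 3, eos_id=2, stop_tokens=[7]))  # ([5, 9], 2)
print(trim_output([5, 9, 7], 3, eos_id=2, stop_tokens=[7]))  # ([5, 9], 3)
```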
 # Copyright (c) OpenMMLab. All rights reserved.
 import logging
-from typing import Optional
+from typing import List, Optional

 logger_initialized = {}
@@ -77,3 +77,21 @@ def get_logger(name: str,
     logger_initialized[name] = True

     return logger
+
+
+def filter_suffix(response: str, suffixes: Optional[List[str]] = None) -> str:
+    """Filter response with suffixes.
+
+    Args:
+        response (str): generated response by LLMs.
+        suffixes (List[str], optional): a list of suffixes to be deleted.
+
+    Return:
+        str: a clean response.
+    """
+    if suffixes is None:
+        return response
+    for item in suffixes:
+        if response.endswith(item):
+            response = response[:len(response) - len(item)]
+    return response
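A quick sanity check of the new helper, assuming it behaves exactly as written above (not part of the commit):

```python
from lmdeploy.utils import filter_suffix

# a trailing stop word is stripped from the final response
assert filter_suffix('Hello there<eoa>', ['<eoa>']) == 'Hello there'
# text without a matching suffix is returned unchanged
assert filter_suffix('No stop word here', ['<eoa>']) == 'No stop word here'
# passing None leaves the response untouched
assert filter_suffix('untouched', None) == 'untouched'
```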
@@ -133,7 +133,7 @@ def test_codellama_infilling():
     '''
     _prompt = model.get_prompt(prompt)
     assert _prompt.find('<FILL>') == -1
-    assert model.stop_words == [32010]
+    assert model.stop_words == ['<EOT>']
     model = MODELS.get('codellama')(capability='infilling', suffix_first=True)
     _prompt = model.get_prompt(prompt)
......