Unverified Commit 327deaee authored by AllentDan, committed by GitHub

expose stop words and filter eoa (#352)

* expose stop words

* support string

* fix

* remove eoa from chatbot

* remove eoa of turbomind

* fix ut

* suffix wheel and fix InternLM no system bug
parent 0cc667e1
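With this change, chat templates carry their stop words as plain strings and the serving code converts them to token ids on demand, instead of each template hard-coding ids. A minimal sketch of the new surface, assuming the registry name 'internlm-chat-7b-8k' resolves to an InternLMChat7B variant as the decorators in the diff below suggest:

```python
from lmdeploy.model import MODELS

# Assumption: this registered name maps to an InternLMChat7B variant
# (see the @MODELS.register_module decorators further down in model.py).
model = MODELS.get('internlm-chat-7b-8k')()
print(model.stop_words)  # ['<eoa>'] rather than the old hard-coded [103028]

# stop_words is now an ordinary __init__ argument on BaseModel, so it can be
# overridden per instance:
custom = MODELS.get('internlm-chat-7b-8k')(stop_words=['<eoa>'])
```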
......@@ -29,12 +29,14 @@ class BaseModel:
temperature=0.8,
repetition_penalty=1.0,
capability='chat',
stop_words=None,
**kwargs):
self.session_len = session_len
self.top_p = top_p
self.top_k = top_k
self.temperature = temperature
self.repetition_penalty = repetition_penalty
self.stop_words = stop_words
self.capability = capability
def get_prompt(self, prompt, sequence_start=True):
......@@ -101,11 +103,6 @@ class BaseModel:
return self.get_prompt(messages)
# chat history processing in derived classes
@property
def stop_words(self):
"""Return the stop-words' token ids."""
return None
@property
def sampling_param(self):
return SamplingParam(top_p=self.top_p,
......@@ -185,6 +182,7 @@ class InternLMChat7B(BaseModel):
eoh='',
eoa='<eoa>',
assistant='<|Bot|>',
stop_words=['<eoa>'],
**kwargs):
super().__init__(**kwargs)
self.system = system
......@@ -193,6 +191,7 @@ class InternLMChat7B(BaseModel):
self.eoh = eoh
self.eoa = eoa
self.assistant = assistant
self.stop_words = stop_words
def decorate_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
......@@ -227,7 +226,8 @@ class InternLMChat7B(BaseModel):
if isinstance(messages, str):
return self.get_prompt(messages, sequence_start)
system, users, assistants = self._translate_messages(messages)
ret = '<BOS>'
system = self.meta_instruction if not system else system
ret = f'<BOS>{self.system}:{system}\n'
for user, assistant in zip(users, assistants):
if assistant:
ret += f'{self.user}:{user}{self.eoh}\n{self.assistant}:' \
......@@ -236,11 +236,6 @@ class InternLMChat7B(BaseModel):
ret += f'{self.user}:{user}{self.eoh}\n{self.assistant}:'
return ret
@property
def stop_words(self):
"""Return the stop-words' token ids."""
return [103028]
@MODELS.register_module(name='internlm-chat-20b')
@MODELS.register_module(name='internlm-chat-7b-8k')
......@@ -339,12 +334,14 @@ class Puyu(BaseModel):
eoh='',
assistant='',
eoa='',
stop_words=None,
**kwargs):
super().__init__(**kwargs)
self.meta_instruction = meta_instruction
self.system = system
self.user = user
self.assistant = assistant
self.stop_words = stop_words
self.eosys = eosys
self.eoh = eoh
self.eoa = eoa
......@@ -382,11 +379,6 @@ class Puyu(BaseModel):
ret += f'{self.user}{user}{self.eoh}{self.assistant}'
return ret
@property
def stop_words(self):
"""Return the stop-words' token ids."""
return [45623]
@MODELS.register_module(name='llama2')
class Llama2(BaseModel):
......@@ -468,6 +460,7 @@ class Qwen7BChat(BaseModel):
im_start='<|im_start|>',
im_end='<|im_end|>',
system='You are a helpful assistant.',
stop_words=['<|im_end|>'],
**kwargs):
super().__init__(**kwargs)
self.session_len = session_len
......@@ -478,6 +471,7 @@ class Qwen7BChat(BaseModel):
self.im_start = im_start
self.im_end = im_end
self.system = system
self.stop_words = stop_words
def decorate_prompt(self, prompt, sequence_start=True):
assert self.capability == 'chat', \
......@@ -513,11 +507,6 @@ class Qwen7BChat(BaseModel):
f'\n{self.im_start}assistant\n'
return ret
@property
def stop_words(self):
"""Return the stop-words' token ids."""
return [151645] # <|im_end|>
@MODELS.register_module(name='codellama')
class CodeLlama(Llama2):
......@@ -526,6 +515,7 @@ class CodeLlama(Llama2):
system='',
session_len=4096,
suffix_first=False,
stop_words=None,
**kwargs):
super().__init__(**kwargs)
caps = ['completion', 'infilling', 'chat', 'python']
......@@ -535,6 +525,7 @@ class CodeLlama(Llama2):
self.default_sys_prompt = system
self.session_len = session_len
self.suffix_first = suffix_first
self.stop_words = stop_words
# The following sampling parameters refer to https://github.com/facebookresearch/codellama # noqa: E501
if self.capability == 'completion' or self.capability == 'python':
......@@ -546,6 +537,8 @@ class CodeLlama(Llama2):
elif self.capability == 'infilling':
self.top_p = kwargs.get('top_p', 0.9)
self.temperature = kwargs.get('temperature', 0.0)
if self.stop_words is None:
self.stop_words = ['<EOT>']
def decorate_prompt(self, prompt, sequence_start=True):
if self.capability == 'infilling':
......@@ -574,14 +567,6 @@ class CodeLlama(Llama2):
return f'{self.b_inst} {prompt} {self.e_inst}'
@property
def stop_words(self):
if self.capability == 'infilling':
# EOT ID
return [32010]
else:
return None
def messages2prompt(self, messages, sequence_start=True):
assert self.capability == 'chat', \
f'codellama message2prompt only supports chat mode ' \
......
......@@ -18,6 +18,7 @@ from tritonclient.grpc.service_pb2 import ModelInferResponse
from lmdeploy.model import MODELS
from lmdeploy.serve.turbomind.utils import (Postprocessor, Preprocessor,
prepare_tensor)
from lmdeploy.utils import filter_suffix
@dataclass
......@@ -157,6 +158,8 @@ class Chatbot:
request_output_len,
sequence_start,
sequence_end):
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
res = filter_suffix(res, self.model.stop_words)
if status.value < 0:
break
else:
......@@ -346,6 +349,8 @@ class Chatbot:
sequence_end):
if status.value < 0:
break
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
res = filter_suffix(res, self.model.stop_words)
if status.value == 0:
self._session.histories = \
self._session.histories + self._session.prompt + \
......@@ -386,16 +391,23 @@ class Chatbot:
token_ids, _ = self.preprocess('<EOS>')
return token_ids[0][0]
def _stop_words(self, stop_words: List[int]):
def _stop_words(self, stop_words: List[str]):
"""return stop-words' token ids."""
if stop_words is None:
return None
assert isinstance(stop_words, List) and \
all(isinstance(elem, int) for elem in stop_words), \
all(isinstance(elem, str) for elem in stop_words), \
f'stop_words must be a list of str but got {type(stop_words)}'
# each id in stop_words represents a stop word
# refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
# detailed explanation about turbomind's stop_words
stop_words = [
int(self.preprocess(stop_word)[0][0][-1])
for stop_word in stop_words
]
assert isinstance(stop_words, List) and \
all(isinstance(elem, int) for elem in stop_words), \
'invalid stop_words'
stop_word_offsets = range(1, len(stop_words) + 1)
stop_words = np.array([[stop_words,
stop_word_offsets]]).astype(np.int32)
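For reference, the array built above packs each stop word's token id together with a cumulative offset, following the turbomind/fastertransformer convention linked in the comment. A standalone sketch of the layout, using the two ids that were previously hard-coded for '<eoa>' and '<|im_end|>':

```python
import numpy as np

token_ids = [103028, 151645]                  # ids formerly hard-coded for '<eoa>' and '<|im_end|>'
offsets = list(range(1, len(token_ids) + 1))  # [1, 2]: each stop word is a single token here

stop_words_arr = np.array([[token_ids, offsets]]).astype(np.int32)
print(stop_words_arr.shape)                   # (1, 2, 2): batch, [ids | offsets], number of stop words
```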
......
......@@ -14,6 +14,7 @@ from torch.nn.utils.rnn import pad_sequence
import lmdeploy
from lmdeploy.model import MODELS
from lmdeploy.turbomind import Tokenizer
from lmdeploy.utils import get_logger
# TODO: find another way import _turbomind
......@@ -22,14 +23,16 @@ sys.path.append(osp.join(lmdeploy_dir, 'lib'))
import _turbomind as _tm # noqa: E402
def _stop_words(stop_words: List[int]):
def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
"""return list of stop-words to numpy.ndarray."""
if stop_words is None:
return None
assert isinstance(stop_words, List) and \
all(isinstance(elem, int) for elem in stop_words), \
all(isinstance(elem, str) for elem in stop_words), \
f'stop_words must be a list of str but got {type(stop_words)}'
stop_words = [tokenizer.encode(stop_word)[-1] for stop_word in stop_words]
assert isinstance(stop_words, List) and all(
isinstance(elem, int) for elem in stop_words), 'invalid stop_words'
# each id in stop_words represents a stop word
# refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
# detailed explanation about fastertransformer's stop_words
......@@ -106,7 +109,10 @@ class TurboMind:
self.model_name = parser.get(section_name, 'model_name')
data_type = parser.get(section_name, 'weight_type')
model = MODELS.get(self.model_name)()
self.stop_words = _stop_words(model.stop_words)
tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
self.stop_words = _stop_words(model.stop_words, tokenizer)
# params
self.node_id = node_id
......@@ -162,6 +168,8 @@ class TurboMindInstance:
self.gpu_count = tm_model.gpu_count
self.stop_words = tm_model.stop_words
self.stop_tokens = [] if self.stop_words is None else \
self.stop_words.flatten().tolist()
self.eos_id = tm_model.eos_id
self.session_len = tm_model.session_len
......@@ -346,6 +354,8 @@ class TurboMindInstance:
output, len_ = output, len_.item()
if len(output) > 0 and output[-1].item() == self.eos_id:
outputs.append((output[:-1], len_ - 1))
elif len(output) > 0 and output[-1].item() in self.stop_tokens:
outputs.append((output[:-1], len_))
else:
outputs.append((output, len_))
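The new branch above drops a trailing stop token from the decoded ids while, unlike the eos branch, keeping the reported length. A simplified illustration with plain lists (the token ids are only illustrative):

```python
def trim_output(output, length, eos_id, stop_tokens):
    """Simplified mirror of the eos/stop-token handling above."""
    if output and output[-1] == eos_id:
        return output[:-1], length - 1   # eos: drop the token and shorten the length
    if output and output[-1] in stop_tokens:
        return output[:-1], length       # stop word: drop the token, keep the length
    return output, length

# With eos_id=2 and the old '<eoa>' id as the only stop token:
print(trim_output([5, 6, 103028], 3, eos_id=2, stop_tokens=[103028]))  # ([5, 6], 3)
```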
......
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional
from typing import List, Optional
logger_initialized = {}
......@@ -77,3 +77,21 @@ def get_logger(name: str,
logger_initialized[name] = True
return logger
def filter_suffix(response: str, suffixes: Optional[List[str]] = None) -> str:
"""Filter response with suffixes.
Args:
response (str): generated response by LLMs.
suffixes (str): a list of suffixes to be deleted.
Return:
str: a clean response.
"""
if suffixes is None:
return response
for item in suffixes:
if response.endswith(item):
response = response[:len(response) - len(item)]
return response
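filter_suffix is what the chatbot paths earlier in this diff call at TRITON_STREAM_END to strip a trailing stop word from the decoded text, for example:

```python
from lmdeploy.utils import filter_suffix

print(filter_suffix('Hello!<eoa>', ['<eoa>']))  # 'Hello!'
print(filter_suffix('Hello!', ['<eoa>']))       # 'Hello!' (nothing to strip)
print(filter_suffix('Hello!', None))            # 'Hello!' (None disables filtering)
```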
......@@ -133,7 +133,7 @@ def test_codellama_infilling():
'''
_prompt = model.get_prompt(prompt)
assert _prompt.find('<FILL>') == -1
assert model.stop_words == [32010]
assert model.stop_words == ['<EOT>']
model = MODELS.get('codellama')(capability='infilling', suffix_first=True)
_prompt = model.get_prompt(prompt)
......