Unverified Commit 07640a3a authored by AllentDan's avatar AllentDan Committed by GitHub
Browse files

Fix Tokenizer encode (#645)

* same encode with HF

* sequence_start -> add_bos

* complement
parent c02e281f
......@@ -210,7 +210,7 @@ class InternLMChat7B(BaseModel):
assert self.capability == 'chat', \
f'{type(self).__name__} has no capability of {self.capability}'
if sequence_start:
return f'<BOS>{self.system}{self.meta_instruction}{self.eosys}' \
return f'{self.system}{self.meta_instruction}{self.eosys}' \
f'{self.user}{prompt}{self.eoh}' \
f'{self.assistant}'
else:
......@@ -230,7 +230,7 @@ class InternLMChat7B(BaseModel):
if isinstance(messages, str):
return self.get_prompt(messages, sequence_start)
eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys)
ret = '<BOS>'
ret = ''
if self.meta_instruction:
ret += f'{self.system}:{self.meta_instruction}{self.eosys}'
......@@ -355,7 +355,7 @@ class Puyu(BaseModel):
assert self.capability == 'chat', \
f'{type(self).__name__} has no capability of {self.capability}'
if sequence_start:
return f'<BOS>{self.system}{self.meta_instruction}{self.eosys}' \
return f'{self.system}{self.meta_instruction}{self.eosys}' \
f'{self.user}{prompt}{self.eoh}' \
f'{self.assistant}'
else:
......@@ -374,7 +374,7 @@ class Puyu(BaseModel):
if isinstance(messages, str):
return self.get_prompt(messages, sequence_start)
eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys)
ret = '<BOS>'
ret = ''
if self.meta_instruction:
ret += f'{self.system}{self.meta_instruction}{self.eosys}'
......@@ -424,7 +424,7 @@ If a question does not make any sense, or is not factually coherent, explain why
assert self.capability == 'chat', \
f'{type(self).__name__} has no capability of {self.capability}'
if sequence_start:
return f'<BOS>{self.b_inst} ' \
return f'{self.b_inst} ' \
f'{self.b_sys} {self.default_sys_prompt} {self.e_sys}' \
f'{prompt} {self.e_inst} '
......@@ -443,7 +443,7 @@ If a question does not make any sense, or is not factually coherent, explain why
return self.get_prompt(messages, sequence_start)
system, users, assistants = self._translate_messages(messages)
system = self.default_sys_prompt if not system else system
ret = f'<BOS>{self.b_inst} {self.b_sys} {system} {self.e_sys}'
ret = f'{self.b_inst} {self.b_sys} {system} {self.e_sys}'
for i, (user, assistant) in enumerate(zip(users, assistants)):
if i != 0:
ret += f'{self.b_inst} '
......@@ -559,16 +559,16 @@ class CodeLlama(Llama2):
prefix, suffix = prompt.split('<FILL>')
if self.suffix_first:
# format as "<PRE> <SUF>{suf} <MID> {pre}"
prompt = f'<BOS><PRE> <SUF>{suffix} <MID> {prefix}'
prompt = f'<PRE> <SUF>{suffix} <MID> {prefix}'
else:
# format as "<PRE> {pre} <SUF>{suf} <MID>"
prompt = f'<BOS><PRE> {prefix} <SUF>{suffix} <MID>'
prompt = f'<PRE> {prefix} <SUF>{suffix} <MID>'
return prompt
def _get_prompt(self, prompt, sequence_start):
prompt = prompt.strip()
if sequence_start:
return f'<BOS>{self.b_inst} ' \
return f'{self.b_inst} ' \
f'{self.b_sys}{self.default_sys_prompt}{self.e_sys}' \
f'{prompt} {self.e_inst}'
......
......@@ -204,7 +204,7 @@ class AsyncEngine:
prompt = messages
if do_preprocess:
prompt = self.model.messages2prompt(prompt, sequence_start)
input_ids = self.tokenizer.encode(prompt)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
finish_reason = 'stop' if stop else None
if self.steps[str(session_id)] + len(
input_ids) + request_output_len >= self.tm_model.session_len:
......
......@@ -459,6 +459,10 @@ class Chatbot:
session.sequence_length = 0
input_ids, input_lengths = self.preprocess(prompt)
# got input_ids with default add_bos == True
if not sequence_start and input_ids[0][0] == self.bos_id:
input_ids = input_ids[:, 1:]
input_lengths = input_lengths - 1
# will crash if last_token_id == eos_id and send empty input_ids
if sequence_end and request_output_len == 0:
input_ids = np.array([[self.bos_id]], dtype=np.uint32)
......
......@@ -53,7 +53,7 @@ class SentencePieceTokenizer:
else:
return decoded
def encode(self, s: str):
def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
Args:
......@@ -61,15 +61,7 @@ class SentencePieceTokenizer:
Returns:
list[int]: token ids
"""
add_bos = False
add_eos = False
if s.find('<BOS>') != -1:
s = s.replace('<BOS>', '')
add_bos = True
if s == '<EOS>':
s = ''
add_eos = True
return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
return self.model.Encode(s, add_bos=add_bos, **kwargs)
def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""De-tokenize.
......@@ -175,7 +167,7 @@ class HuggingFaceTokenizer:
else:
return decoded
def encode(self, s: str):
def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
Args:
......@@ -183,14 +175,12 @@ class HuggingFaceTokenizer:
Returns:
list[int]: token ids
"""
add_special_tokens = False
if s.find('<BOS>') != -1:
s = s.replace('<BOS>', '<s>')
if s == '<EOS>':
s = '</s>'
if len(s) == 0:
add_special_tokens = True
return self.model.encode(s, add_special_tokens=add_special_tokens)
encoded = self.model.encode(s, **kwargs)
if not add_bos:
# in the middle of a session
if len(encoded) and encoded[0] == self.bos_token_id:
encoded = encoded[1:]
return encoded
def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""De-tokenize.
......@@ -261,7 +251,7 @@ class Tokenizer:
"""end of the sentence token id."""
return self.model.eos_token_id
def encode(self, s: str):
def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
Args:
......@@ -269,7 +259,7 @@ class Tokenizer:
Returns:
list[int]: token ids
"""
return self.model.encode(s)
return self.model.encode(s, add_bos, **kwargs)
def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""De-tokenize.
......
......@@ -122,7 +122,7 @@ def main(model_path,
seed = random.getrandbits(64)
else:
prompt = model.get_prompt(prompt, nth_round == 1)
input_ids = tokenizer.encode(prompt)
input_ids = tokenizer.encode(prompt, nth_round == 1)
if step + len(
input_ids) + request_output_len >= tm_model.session_len:
print('WARNING: exceed session max length.'
......
......@@ -30,7 +30,9 @@ def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
assert isinstance(stop_words, List) and \
all(isinstance(elem, str) for elem in stop_words), \
f'stop_words must be a list but got {type(stop_words)}'
stop_words = [tokenizer.encode(stop_word)[-1] for stop_word in stop_words]
stop_words = [
tokenizer.encode(stop_word, False)[-1] for stop_word in stop_words
]
assert isinstance(stop_words, List) and all(
isinstance(elem, int) for elem in stop_words), 'invalid stop_words'
# each id in stop_words represents a stop word
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment