Unverified commit 07640a3a, authored by AllentDan, committed by GitHub

Fix Tokenizer encode (#645)

* Same encode behavior as HF

* sequence_start -> add_bos

* Complementary fixes
parent c02e281f
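
In short: encode() now behaves like HF's, adding BOS by default, and callers pass add_bos (derived from sequence_start) instead of splicing a '<BOS>' sentinel into the prompt string. A minimal sketch of the new contract, assuming tokenizer is an instance of the Tokenizer class edited below:

    # session start: BOS is prepended, matching HF defaults
    first_ids = tokenizer.encode(prompt, add_bos=True)
    # later rounds of the same session: suppress BOS
    later_ids = tokenizer.encode(prompt, add_bos=False)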
@@ -210,7 +210,7 @@ class InternLMChat7B(BaseModel):
         assert self.capability == 'chat', \
             f'{type(self).__name__} has no capability of {self.capability}'
         if sequence_start:
-            return f'<BOS>{self.system}{self.meta_instruction}{self.eosys}' \
+            return f'{self.system}{self.meta_instruction}{self.eosys}' \
                    f'{self.user}{prompt}{self.eoh}' \
                    f'{self.assistant}'
         else:
@@ -230,7 +230,7 @@ class InternLMChat7B(BaseModel):
         if isinstance(messages, str):
             return self.get_prompt(messages, sequence_start)
         eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys)
-        ret = '<BOS>'
+        ret = ''
         if self.meta_instruction:
             ret += f'{self.system}:{self.meta_instruction}{self.eosys}'
@@ -355,7 +355,7 @@ class Puyu(BaseModel):
         assert self.capability == 'chat', \
             f'{type(self).__name__} has no capability of {self.capability}'
         if sequence_start:
-            return f'<BOS>{self.system}{self.meta_instruction}{self.eosys}' \
+            return f'{self.system}{self.meta_instruction}{self.eosys}' \
                    f'{self.user}{prompt}{self.eoh}' \
                    f'{self.assistant}'
         else:
@@ -374,7 +374,7 @@ class Puyu(BaseModel):
         if isinstance(messages, str):
             return self.get_prompt(messages, sequence_start)
         eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys)
-        ret = '<BOS>'
+        ret = ''
         if self.meta_instruction:
             ret += f'{self.system}{self.meta_instruction}{self.eosys}'
@@ -424,7 +424,7 @@ If a question does not make any sense, or is not factually coherent, explain why
         assert self.capability == 'chat', \
             f'{type(self).__name__} has no capability of {self.capability}'
         if sequence_start:
-            return f'<BOS>{self.b_inst} ' \
+            return f'{self.b_inst} ' \
                    f'{self.b_sys} {self.default_sys_prompt} {self.e_sys}' \
                    f'{prompt} {self.e_inst} '
@@ -443,7 +443,7 @@ If a question does not make any sense, or is not factually coherent, explain why
             return self.get_prompt(messages, sequence_start)
         system, users, assistants = self._translate_messages(messages)
         system = self.default_sys_prompt if not system else system
-        ret = f'<BOS>{self.b_inst} {self.b_sys} {system} {self.e_sys}'
+        ret = f'{self.b_inst} {self.b_sys} {system} {self.e_sys}'
         for i, (user, assistant) in enumerate(zip(users, assistants)):
             if i != 0:
                 ret += f'{self.b_inst} '
@@ -559,16 +559,16 @@ class CodeLlama(Llama2):
         prefix, suffix = prompt.split('<FILL>')
         if self.suffix_first:
             # format as "<PRE> <SUF>{suf} <MID> {pre}"
-            prompt = f'<BOS><PRE> <SUF>{suffix} <MID> {prefix}'
+            prompt = f'<PRE> <SUF>{suffix} <MID> {prefix}'
         else:
             # format as "<PRE> {pre} <SUF>{suf} <MID>"
-            prompt = f'<BOS><PRE> {prefix} <SUF>{suffix} <MID>'
+            prompt = f'<PRE> {prefix} <SUF>{suffix} <MID>'
         return prompt

     def _get_prompt(self, prompt, sequence_start):
         prompt = prompt.strip()
         if sequence_start:
-            return f'<BOS>{self.b_inst} ' \
+            return f'{self.b_inst} ' \
                    f'{self.b_sys}{self.default_sys_prompt}{self.e_sys}' \
                    f'{prompt} {self.e_inst}'
...
@@ -204,7 +204,7 @@ class AsyncEngine:
         prompt = messages
         if do_preprocess:
             prompt = self.model.messages2prompt(prompt, sequence_start)
-        input_ids = self.tokenizer.encode(prompt)
+        input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
         finish_reason = 'stop' if stop else None
         if self.steps[str(session_id)] + len(
                 input_ids) + request_output_len >= self.tm_model.session_len:
...
@@ -459,6 +459,10 @@ class Chatbot:
         session.sequence_length = 0
         input_ids, input_lengths = self.preprocess(prompt)
+        # got input_ids with default add_bos == True
+        if not sequence_start and input_ids[0][0] == self.bos_id:
+            input_ids = input_ids[:, 1:]
+            input_lengths = input_lengths - 1
         # will crash if last_token_id == eos_id and send empty input_ids
         if sequence_end and request_output_len == 0:
             input_ids = np.array([[self.bos_id]], dtype=np.uint32)
...
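A self-contained sketch of the BOS-stripping step added above, assuming preprocess() returns a batch-of-one uint32 array as in this file, and a hypothetical bos_id of 1:

    import numpy as np

    bos_id = 1  # assumed value; the real one comes from the model's tokenizer
    input_ids = np.array([[bos_id, 450, 1234]], dtype=np.uint32)
    input_lengths = np.array([3], dtype=np.uint32)
    # mid-session: drop the BOS that preprocess() added by default
    if input_ids[0][0] == bos_id:
        input_ids = input_ids[:, 1:]
        input_lengths = input_lengths - 1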
...@@ -53,7 +53,7 @@ class SentencePieceTokenizer: ...@@ -53,7 +53,7 @@ class SentencePieceTokenizer:
else: else:
return decoded return decoded
def encode(self, s: str): def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt. """Tokenize a prompt.
Args: Args:
...@@ -61,15 +61,7 @@ class SentencePieceTokenizer: ...@@ -61,15 +61,7 @@ class SentencePieceTokenizer:
Returns: Returns:
list[int]: token ids list[int]: token ids
""" """
add_bos = False return self.model.Encode(s, add_bos=add_bos, **kwargs)
add_eos = False
if s.find('<BOS>') != -1:
s = s.replace('<BOS>', '')
add_bos = True
if s == '<EOS>':
s = ''
add_eos = True
return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
def decode(self, t: Sequence[int], offset: Optional[int] = None): def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""De-tokenize. """De-tokenize.
@@ -175,7 +167,7 @@ class HuggingFaceTokenizer:
         else:
             return decoded

-    def encode(self, s: str):
+    def encode(self, s: str, add_bos: bool = True, **kwargs):
         """Tokenize a prompt.

         Args:
@@ -183,14 +175,12 @@ class HuggingFaceTokenizer:
         Returns:
             list[int]: token ids
         """
-        add_special_tokens = False
-        if s.find('<BOS>') != -1:
-            s = s.replace('<BOS>', '<s>')
-        if s == '<EOS>':
-            s = '</s>'
-        if len(s) == 0:
-            add_special_tokens = True
-        return self.model.encode(s, add_special_tokens=add_special_tokens)
+        encoded = self.model.encode(s, **kwargs)
+        if not add_bos:
+            # in the middle of a session
+            if len(encoded) and encoded[0] == self.bos_token_id:
+                encoded = encoded[1:]
+        return encoded

     def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
@@ -261,7 +251,7 @@ class Tokenizer:
         """end of the sentence token id."""
         return self.model.eos_token_id

-    def encode(self, s: str):
+    def encode(self, s: str, add_bos: bool = True, **kwargs):
         """Tokenize a prompt.

         Args:
@@ -269,7 +259,7 @@ class Tokenizer:
         Returns:
             list[int]: token ids
         """
-        return self.model.encode(s)
+        return self.model.encode(s, add_bos, **kwargs)

     def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
...
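HF's generic encode() exposes add_special_tokens rather than a per-call add_bos flag, hence the strip-after-encoding approach in HuggingFaceTokenizer above. A sketch of the same logic against transformers directly; the model id is illustrative only:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('huggyllama/llama-7b')
    encoded = tok.encode('hello')  # llama-style tokenizers prepend BOS by default
    # what add_bos=False now does inside HuggingFaceTokenizer.encode
    if len(encoded) and encoded[0] == tok.bos_token_id:
        encoded = encoded[1:]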
@@ -122,7 +122,7 @@ def main(model_path,
             seed = random.getrandbits(64)
         else:
             prompt = model.get_prompt(prompt, nth_round == 1)
-            input_ids = tokenizer.encode(prompt)
+            input_ids = tokenizer.encode(prompt, nth_round == 1)
             if step + len(
                     input_ids) + request_output_len >= tm_model.session_len:
                 print('WARNING: exceed session max length.'
...
@@ -30,7 +30,9 @@ def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
     assert isinstance(stop_words, List) and \
            all(isinstance(elem, str) for elem in stop_words), \
            f'stop_words must be a list but got {type(stop_words)}'
-    stop_words = [tokenizer.encode(stop_word)[-1] for stop_word in stop_words]
+    stop_words = [
+        tokenizer.encode(stop_word, False)[-1] for stop_word in stop_words
+    ]
     assert isinstance(stop_words, List) and all(
         isinstance(elem, int) for elem in stop_words), 'invalid stop_words'
     # each id in stop_words represents a stop word
...
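The stop-word lookup passes add_bos=False because each stop word is encoded in isolation only to recover its trailing token id; a leading BOS would be noise there. A minimal sketch with an assumed stop word:

    # '<eoa>' is an assumed stop word; keep only the trailing id, since the
    # comment above says each id stands for one stop word
    stop_ids = [tokenizer.encode(w, False)[-1] for w in ['<eoa>']]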