Unverified Commit eb3b4dc9 authored by AllentDan's avatar AllentDan Committed by GitHub
Browse files

avoid split chinese characters during decoding (#566)

parent 9c3634ec
...@@ -156,6 +156,11 @@ class AsyncEngine: ...@@ -156,6 +156,11 @@ class AsyncEngine:
# decode res # decode res
response = self.tokenizer.decode(res.tolist(), response = self.tokenizer.decode(res.tolist(),
offset=response_size) offset=response_size)
# utf-8 char at the end means it's a potential unfinished
# byte sequence, continue to concatenate it with the next
# sequence and decode them together
if response.endswith('�'):
continue
# response, history token len, # response, history token len,
# input token len, gen token len # input token len, gen token len
yield GenOut(response, self.steps[str(session_id)], yield GenOut(response, self.steps[str(session_id)],
...@@ -249,6 +254,11 @@ class AsyncEngine: ...@@ -249,6 +254,11 @@ class AsyncEngine:
# decode res # decode res
response = self.tokenizer.decode(res.tolist(), response = self.tokenizer.decode(res.tolist(),
offset=response_size) offset=response_size)
# utf-8 char at the end means it's a potential unfinished
# byte sequence, continue to concatenate it with the next
# sequence and decode them together
if response.endswith('�'):
continue
# response, history len, input len, generation len # response, history len, input len, generation len
yield GenOut(response, self.steps[str(session_id)], yield GenOut(response, self.steps[str(session_id)],
len(input_ids), tokens, finish_reason) len(input_ids), tokens, finish_reason)
......
...@@ -657,8 +657,13 @@ class Chatbot: ...@@ -657,8 +657,13 @@ class Chatbot:
continue continue
output_str = postprocess( output_str = postprocess(
output_ids, np.array([[n_token]], dtype=np.uint32)) output_ids, np.array([[n_token]], dtype=np.uint32))
n_token = output_ids.shape[-1]
text = output_str[0].decode() text = output_str[0].decode()
# utf-8 char at the end means it's a potential unfinished
# byte sequence, continue to concatenate it with the next
# sequence and decode them together
if text.endswith('�'):
continue
n_token = output_ids.shape[-1]
if display: if display:
print(text, end='', flush=True) print(text, end='', flush=True)
session.response += text session.response += text
......
...@@ -145,6 +145,11 @@ def main(model_path, ...@@ -145,6 +145,11 @@ def main(model_path,
res, tokens = outputs[0] res, tokens = outputs[0]
# decode res # decode res
response = tokenizer.decode(res.tolist(), offset=response_size) response = tokenizer.decode(res.tolist(), offset=response_size)
# utf-8 char at the end means it's a potential unfinished
# byte sequence, continue to concatenate it with the next
# sequence and decode them together
if response.endswith('�'):
continue
response = valid_str(response) response = valid_str(response)
print(f'{response}', end='', flush=True) print(f'{response}', end='', flush=True)
response_size = tokens response_size = tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment