Unverified Commit ce21a318 authored by AllentDan's avatar AllentDan Committed by GitHub
Browse files

Fix core dump when the session length is exceeded in chat and generate (#366)

parent 71ade772
...@@ -112,7 +112,7 @@ class AsyncEngine: ...@@ -112,7 +112,7 @@ class AsyncEngine:
prompt = self.model.messages2prompt(messages, sequence_start) prompt = self.model.messages2prompt(messages, sequence_start)
input_ids = self.tokenizer.encode(prompt) input_ids = self.tokenizer.encode(prompt)
finish_reason = 'stop' if stop else None finish_reason = 'stop' if stop else None
if not sequence_end and self.steps[str(session_id)] + len( if self.steps[str(session_id)] + len(
input_ids) >= self.tm_model.session_len: input_ids) >= self.tm_model.session_len:
finish_reason = 'length' finish_reason = 'length'
yield GenOut('', self.steps[str(session_id)], len(input_ids), 0, yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
......
...@@ -74,12 +74,12 @@ def main(model_path, ...@@ -74,12 +74,12 @@ def main(model_path,
seed = random.getrandbits(64) seed = random.getrandbits(64)
else: else:
print(f'session {session_id}') print(f'session {session_id}')
if step >= tm_model.session_len: prompt = model.get_prompt(prompt, nth_round == 1)
input_ids = tokenizer.encode(prompt)
if step + len(input_ids) >= tm_model.session_len:
print('WARNING: exceed session max length.' print('WARNING: exceed session max length.'
' Please end the session.') ' Please end the session.')
continue continue
prompt = model.get_prompt(prompt, nth_round == 1)
input_ids = tokenizer.encode(prompt)
print(f'{prompt} ', end='', flush=True) print(f'{prompt} ', end='', flush=True)
response_size = 0 response_size = 0
for outputs in generator.stream_infer( for outputs in generator.stream_infer(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment