Commit 9d2a48ae authored by zhouxiang
Browse files

解决当 request_output_len 与 input_ids 长度之和大于 session_len 时不产生任何输出的问题：将策略改为自适应调整 request_output_len，使生成可以持续到 session_len 为止。

parent 2e528580
...@@ -264,12 +264,18 @@ class AsyncEngine: ...@@ -264,12 +264,18 @@ class AsyncEngine:
prompt = self.model.messages2prompt(prompt, sequence_start) prompt = self.model.messages2prompt(prompt, sequence_start)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start) input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
finish_reason = None finish_reason = None
request_output_len = min(
request_output_len, self.tm_model.session_len - self.id2step[str(session_id)] -
len(input_ids))
request_output_len = max(0, request_output_len)
if stop is True: if stop is True:
self.stop_session(session_id) self.stop_session(session_id)
yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0, yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
finish_reason) finish_reason)
elif self.id2step[str(session_id)] + len( elif self.id2step[str(session_id)] + len(
input_ids) + request_output_len >= self.tm_model.session_len: input_ids) + request_output_len > self.tm_model.session_len:
finish_reason = 'length' finish_reason = 'length'
yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0, yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
finish_reason) finish_reason)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment