Unverified Commit 0549f21c authored by Mick's avatar Mick Committed by GitHub
Browse files

fix: fix max_new_tokens uninitialized error (#9343)

parent b354e3c9
......@@ -1181,6 +1181,16 @@ class Scheduler(
else:
self.send_to_tokenizer.send_pyobj(output)
def init_req_max_new_tokens(self, req):
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_len - len(req.origin_input_ids) - 1,
)
def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
......@@ -1244,6 +1254,7 @@ class Scheduler(
req.set_finish_with_abort(
f"Invalid request: session id {recv_req.session_params.id} does not exist"
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
else:
......@@ -1251,6 +1262,7 @@ class Scheduler(
session = self.sessions[recv_req.session_params.id]
req = session.create_req(recv_req, self.tokenizer)
if isinstance(req.finished_reason, FINISH_ABORT):
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
......@@ -1270,9 +1282,13 @@ class Scheduler(
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
)
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
# initialize before returning
self.init_req_max_new_tokens(req)
# Validate prompt length
error_msg = validate_input_length(
req,
......@@ -1306,15 +1322,6 @@ class Scheduler(
self._add_request_to_queue(req)
return
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_len - len(req.origin_input_ids) - 1,
)
# Init grammar cache for this request
add_to_grammar_queue = False
if (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment