Unverified commit 5e1558f1, authored by Liangsheng Yin and committed by GitHub
Browse files

Update `max_req_len` and `max_req_input_len` (#1748)

parent 94cde109
...@@ -165,6 +165,7 @@ class Scheduler: ...@@ -165,6 +165,7 @@ class Scheduler:
self.max_total_num_tokens, self.max_total_num_tokens,
self.max_prefill_tokens, self.max_prefill_tokens,
self.max_running_requests, self.max_running_requests,
self.max_req_len,
self.max_req_input_len, self.max_req_input_len,
self.random_seed, self.random_seed,
self.device, self.device,
...@@ -421,13 +422,14 @@ class Scheduler: ...@@ -421,13 +422,14 @@ class Scheduler:
"the max context length. Truncated!!!" "the max context length. Truncated!!!"
) )
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
req.sampling_params.max_new_tokens = min( req.sampling_params.max_new_tokens = min(
( (
req.sampling_params.max_new_tokens req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None if req.sampling_params.max_new_tokens is not None
else 1 << 30 else 1 << 30
), ),
self.max_req_input_len - len(req.origin_input_ids), self.max_req_len - len(req.origin_input_ids) - 1,
) )
self.waiting_queue.append(req) self.waiting_queue.append(req)
......
...@@ -90,10 +90,14 @@ class TpModelWorker: ...@@ -90,10 +90,14 @@ class TpModelWorker:
), ),
self.model_runner.req_to_token_pool.size, self.model_runner.req_to_token_pool.size,
) )
self.max_req_input_len = min( self.max_req_len = min(
self.model_config.context_len - 1, self.model_config.context_len - 1,
self.max_total_num_tokens - 1, self.max_total_num_tokens - 1,
) )
self.max_req_input_len = self.max_req_len - 5
assert (
self.max_req_len > 0 and self.max_req_input_len > 0
), "Memory pool size is too small"
# Sync random seed across TP workers # Sync random seed across TP workers
self.random_seed = broadcast_pyobj( self.random_seed = broadcast_pyobj(
...@@ -108,6 +112,7 @@ class TpModelWorker: ...@@ -108,6 +112,7 @@ class TpModelWorker:
self.max_total_num_tokens, self.max_total_num_tokens,
self.max_prefill_tokens, self.max_prefill_tokens,
self.max_running_requests, self.max_running_requests,
self.max_req_len,
self.max_req_input_len, self.max_req_input_len,
self.random_seed, self.random_seed,
self.device, self.device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment