Unverified Commit d7854120 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix the case when max_new_tokens is too large (#1025)

parent 7b6a5332
......@@ -18,12 +18,16 @@ limitations under the License.
import random
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List
from typing import Dict, List, Optional
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import TreeNode
# Clip the max new tokens for the request whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
CLIP_MAX_NEW_TOKENS = 4096
class PolicyScheduler:
def __init__(self, policy: str, tree_cache: BasePrefixCache):
......@@ -98,7 +102,7 @@ class PrefillAdder:
tree_cache: BasePrefixCache,
rem_total_tokens: int,
rem_input_tokens: int,
rem_chunk_tokens: int,
rem_chunk_tokens: Optional[int],
):
self.tree_cache = tree_cache
self.rem_total_tokens = rem_total_tokens
......@@ -126,7 +130,11 @@ class PrefillAdder:
):
self.rem_total_tokens -= sum(
[
(r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio
min(
(r.sampling_params.max_new_tokens - len(r.output_ids)),
CLIP_MAX_NEW_TOKENS,
)
* new_token_ratio
for r in running_batch.reqs
]
)
......@@ -151,7 +159,11 @@ class PrefillAdder:
self._prefill_one_req(
len(req.prefix_indices),
req.extend_input_len,
req.sampling_params.max_new_tokens if not truncated else 0,
(
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
if not truncated
else 0
),
)
# Return if chunked prefill not finished
......@@ -168,7 +180,9 @@ class PrefillAdder:
self.rem_total_tokens += delta
def add_one_req(self, req: Req):
total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
total_tokens = req.extend_input_len + min(
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
)
input_tokens = req.extend_input_len
prefix_len = len(req.prefix_indices)
......@@ -191,7 +205,9 @@ class PrefillAdder:
self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
prefix_len, input_tokens, req.sampling_params.max_new_tokens
prefix_len,
input_tokens,
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
)
else:
# Chunked prefill
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment