Unverified commit d7854120, authored by Lianmin Zheng, committed by GitHub

Fix the case when max_new_tokens is too large (#1025)

parent 7b6a5332
@@ -18,12 +18,16 @@ limitations under the License.
 import random
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import Dict, List
+from typing import Dict, List, Optional

 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.radix_cache import TreeNode

+# Clip the max new tokens for the request whose max_new_tokens is very large.
+# This can prevent the server from being too conservative.
+CLIP_MAX_NEW_TOKENS = 4096
+

 class PolicyScheduler:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
@@ -98,7 +102,7 @@ class PrefillAdder:
         tree_cache: BasePrefixCache,
         rem_total_tokens: int,
         rem_input_tokens: int,
-        rem_chunk_tokens: int,
+        rem_chunk_tokens: Optional[int],
     ):
         self.tree_cache = tree_cache
         self.rem_total_tokens = rem_total_tokens
@@ -126,7 +130,11 @@
     ):
         self.rem_total_tokens -= sum(
             [
-                (r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio
+                min(
+                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                    CLIP_MAX_NEW_TOKENS,
+                )
+                * new_token_ratio
                 for r in running_batch.reqs
             ]
         )
@@ -151,7 +159,11 @@
        self._prefill_one_req(
            len(req.prefix_indices),
            req.extend_input_len,
-           req.sampling_params.max_new_tokens if not truncated else 0,
+           (
+               min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+               if not truncated
+               else 0
+           ),
        )

        # Return if chunked prefill not finished
@@ -168,7 +180,9 @@
         self.rem_total_tokens += delta

     def add_one_req(self, req: Req):
-        total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
+        total_tokens = req.extend_input_len + min(
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+        )
         input_tokens = req.extend_input_len
         prefix_len = len(req.prefix_indices)
@@ -191,7 +205,9 @@
            self.can_run_list.append(req)
            self.tree_cache.inc_lock_ref(req.last_node)
            self._prefill_one_req(
-               prefix_len, input_tokens, req.sampling_params.max_new_tokens
+               prefix_len,
+               input_tokens,
+               min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
            )
        else:
            # Chunked prefill
...
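The effect of the clip on admission: before this patch, a request with a very large max_new_tokens could reserve more tokens than the scheduler's whole remaining budget, making admission overly conservative. The sketch below illustrates the arithmetic only; reserved_tokens() and the budget numbers are hypothetical, and only the CLIP_MAX_NEW_TOKENS constant and the min() clipping come from the patch (mirroring the total_tokens computation in add_one_req).

# Illustrative sketch, not sglang code: reserved_tokens() and the numbers
# below are made up; CLIP_MAX_NEW_TOKENS = 4096 is the constant the patch adds.
CLIP_MAX_NEW_TOKENS = 4096

def reserved_tokens(extend_input_len: int, max_new_tokens: int, clip: bool) -> int:
    """Tokens the scheduler must budget before admitting one request."""
    if clip:
        max_new_tokens = min(max_new_tokens, CLIP_MAX_NEW_TOKENS)
    return extend_input_len + max_new_tokens

rem_total_tokens = 16_384  # hypothetical remaining KV-cache budget

# A request that leaves max_new_tokens at a huge value (e.g. a default sentinel).
unclipped = reserved_tokens(extend_input_len=512, max_new_tokens=1_000_000, clip=False)
clipped = reserved_tokens(extend_input_len=512, max_new_tokens=1_000_000, clip=True)

print(unclipped > rem_total_tokens)   # True: the request can never fit the budget
print(clipped <= rem_total_tokens)    # True: 512 + 4096 fits, so it can be admitted

Clipping only affects the reservation used for scheduling; the request can still generate beyond the clipped count if capacity frees up, since the actual decode loop is bounded by the real max_new_tokens.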