Unverified Commit 5d0ba403 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Refine the add request reasons to avoid corner cases. (#1574)

parent 04b262cd
......@@ -19,6 +19,7 @@ import os
import random
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict, List, Optional
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
......@@ -104,6 +105,12 @@ class SchedulePolicy:
q.extend(last_node_to_reqs[cur_node])
class AddReqResult(Enum):
    """Outcome of an attempt to add a request, as returned by PrefillAdder.

    Callers compare the result against ``CONTINUE`` to decide whether to keep
    adding requests, and against ``NO_TOKEN`` to tell why adding stopped.
    """

    CONTINUE = auto()  # Continue to add requests
    NO_TOKEN = auto()  # No token left
    OTHER = auto()  # Other reasons to stop adding requests
class PrefillAdder:
def __init__(
self,
......@@ -145,17 +152,16 @@ class PrefillAdder:
]
)
def no_remaining_tokens(self):
    """Return True when any of the tracked token budgets is used up.

    Checks, in order, ``rem_total_tokens``, ``rem_input_tokens``,
    ``rem_chunk_tokens`` and ``cur_rem_tokens``; a budget counts as used up
    once it is zero or negative.

    Returns:
        bool: True if at least one budget is <= 0, otherwise False.
    """
    return (
        self.rem_total_tokens <= 0
        or self.rem_input_tokens <= 0
        # rem_chunk_tokens may be None; in that case this check is skipped.
        or (self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0)
        or self.cur_rem_tokens <= 0
    )
def budget_state(self):
    """Classify the current token budgets as an ``AddReqResult``.

    The checks are ordered, so ``NO_TOKEN`` takes precedence over ``OTHER``
    when both conditions hold.

    Returns:
        AddReqResult.NO_TOKEN: ``rem_total_tokens`` or ``cur_rem_tokens``
            is <= 0.
        AddReqResult.OTHER: ``rem_input_tokens`` is <= 0, or
            ``rem_chunk_tokens`` is set (not None) and <= 0.
        AddReqResult.CONTINUE: none of the budgets above is used up.
    """
    if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
        return AddReqResult.NO_TOKEN

    # rem_chunk_tokens may be None; it is only compared when it is set.
    if self.rem_input_tokens <= 0 or (
        self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
    ):
        return AddReqResult.OTHER

    return AddReqResult.CONTINUE
def _prefill_one_req(
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
......@@ -239,7 +245,7 @@ class PrefillAdder:
)
bs = len(self.req_states) - i
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
return False
return AddReqResult.NO_TOKEN
tokens_freed += tokens_occupied
if req.extend_input_len <= self.rem_chunk_tokens:
......@@ -258,7 +264,7 @@ class PrefillAdder:
self.new_inflight_req = req
self._prefill_one_req(0, trunc_len, 0)
return True
return self.budget_state()
def add_one_req(self, req: Req):
if req.sampling_params.ignore_eos and self.tree_cache.disable:
......@@ -271,14 +277,14 @@ class PrefillAdder:
prefix_len = len(req.prefix_indices)
if total_tokens >= self.rem_total_tokens:
return False
return AddReqResult.NO_TOKEN
if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
return False
return AddReqResult.OTHER
with self._lock_node(req.last_node):
if total_tokens > self.rem_total_tokens:
return False
return AddReqResult.NO_TOKEN
if (
self.rem_chunk_tokens is None
......@@ -297,7 +303,7 @@ class PrefillAdder:
# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return False
return AddReqResult.OTHER
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
......@@ -306,4 +312,4 @@ class PrefillAdder:
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)
return True and not self.no_remaining_tokens()
return self.budget_state()
......@@ -50,7 +50,11 @@ from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch,
)
from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
from sglang.srt.managers.schedule_policy import (
AddReqResult,
PrefillAdder,
SchedulePolicy,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
......@@ -493,16 +497,15 @@ class Scheduler:
self.batch_is_full = True
break
if adder.no_remaining_tokens():
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
self.batch_is_full = True
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req)
if (
not res
or running_bs + len(adder.can_run_list) >= self.max_running_requests
):
self.batch_is_full = True
if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN:
self.batch_is_full = True
break
can_run_list = adder.can_run_list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment