"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "a210ec74d29ee718bca9b3c192e0a93cf86cbf21"
Unverified commit 5d0ba403, authored by Liangsheng Yin and committed by GitHub
Browse files

Refine the add request reasons to avoid corner cases. (#1574)

parent 04b262cd
...@@ -19,6 +19,7 @@ import os ...@@ -19,6 +19,7 @@ import os
import random import random
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict, List, Optional from typing import Dict, List, Optional
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
...@@ -104,6 +105,12 @@ class SchedulePolicy: ...@@ -104,6 +105,12 @@ class SchedulePolicy:
q.extend(last_node_to_reqs[cur_node]) q.extend(last_node_to_reqs[cur_node])
class AddReqResult(Enum):
CONTINUE = auto() # Continue to add requests
NO_TOKEN = auto() # No token left
OTHER = auto() # Other reasons to stop adding requests
class PrefillAdder: class PrefillAdder:
def __init__( def __init__(
self, self,
...@@ -145,17 +152,16 @@ class PrefillAdder: ...@@ -145,17 +152,16 @@ class PrefillAdder:
] ]
) )
def no_remaining_tokens(self): def budget_state(self):
return ( if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
self.rem_total_tokens <= 0 return AddReqResult.NO_TOKEN
or self.rem_input_tokens <= 0
or ( if self.rem_input_tokens <= 0 or (
self.rem_chunk_tokens <= 0 self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
if self.rem_chunk_tokens is not None ):
else False return AddReqResult.OTHER
)
or self.cur_rem_tokens <= 0 return AddReqResult.CONTINUE
)
def _prefill_one_req( def _prefill_one_req(
self, prefix_len: int, extend_input_len: int, max_new_tokens: int self, prefix_len: int, extend_input_len: int, max_new_tokens: int
...@@ -239,7 +245,7 @@ class PrefillAdder: ...@@ -239,7 +245,7 @@ class PrefillAdder:
) )
bs = len(self.req_states) - i bs = len(self.req_states) - i
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0: if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
return False return AddReqResult.NO_TOKEN
tokens_freed += tokens_occupied tokens_freed += tokens_occupied
if req.extend_input_len <= self.rem_chunk_tokens: if req.extend_input_len <= self.rem_chunk_tokens:
...@@ -258,7 +264,7 @@ class PrefillAdder: ...@@ -258,7 +264,7 @@ class PrefillAdder:
self.new_inflight_req = req self.new_inflight_req = req
self._prefill_one_req(0, trunc_len, 0) self._prefill_one_req(0, trunc_len, 0)
return True return self.budget_state()
def add_one_req(self, req: Req): def add_one_req(self, req: Req):
if req.sampling_params.ignore_eos and self.tree_cache.disable: if req.sampling_params.ignore_eos and self.tree_cache.disable:
...@@ -271,14 +277,14 @@ class PrefillAdder: ...@@ -271,14 +277,14 @@ class PrefillAdder:
prefix_len = len(req.prefix_indices) prefix_len = len(req.prefix_indices)
if total_tokens >= self.rem_total_tokens: if total_tokens >= self.rem_total_tokens:
return False return AddReqResult.NO_TOKEN
if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0: if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
return False return AddReqResult.OTHER
with self._lock_node(req.last_node): with self._lock_node(req.last_node):
if total_tokens > self.rem_total_tokens: if total_tokens > self.rem_total_tokens:
return False return AddReqResult.NO_TOKEN
if ( if (
self.rem_chunk_tokens is None self.rem_chunk_tokens is None
...@@ -297,7 +303,7 @@ class PrefillAdder: ...@@ -297,7 +303,7 @@ class PrefillAdder:
# Chunked prefill # Chunked prefill
trunc_len = self.rem_chunk_tokens trunc_len = self.rem_chunk_tokens
if trunc_len == 0: if trunc_len == 0:
return False return AddReqResult.OTHER
req.extend_input_len = trunc_len req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
...@@ -306,4 +312,4 @@ class PrefillAdder: ...@@ -306,4 +312,4 @@ class PrefillAdder:
self.tree_cache.inc_lock_ref(req.last_node) self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0) self._prefill_one_req(prefix_len, trunc_len, 0)
return True and not self.no_remaining_tokens() return self.budget_state()
...@@ -50,7 +50,11 @@ from sglang.srt.managers.schedule_batch import ( ...@@ -50,7 +50,11 @@ from sglang.srt.managers.schedule_batch import (
Req, Req,
ScheduleBatch, ScheduleBatch,
) )
from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy from sglang.srt.managers.schedule_policy import (
AddReqResult,
PrefillAdder,
SchedulePolicy,
)
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.radix_cache import RadixCache
...@@ -493,16 +497,15 @@ class Scheduler: ...@@ -493,16 +497,15 @@ class Scheduler:
self.batch_is_full = True self.batch_is_full = True
break break
if adder.no_remaining_tokens(): if running_bs + len(adder.can_run_list) >= self.max_running_requests:
self.batch_is_full = True self.batch_is_full = True
break break
req.init_next_round_input(None if prefix_computed else self.tree_cache) req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req) res = adder.add_one_req(req)
if ( if res != AddReqResult.CONTINUE:
not res if res == AddReqResult.NO_TOKEN:
or running_bs + len(adder.can_run_list) >= self.max_running_requests self.batch_is_full = True
):
self.batch_is_full = True
break break
can_run_list = adder.can_run_list can_run_list = adder.can_run_list
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment