Unverified commit 9208618b authored by SangBin Cho, committed by GitHub

[Core] in batch prefix caching by delay scheduling (#2442)

parent 864bf2ba
@@ -55,6 +55,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/flush_cache",
             api_key=self.api_key,
             verify=self.verify,
+            method="POST",
         )
         self._assert_success(res)
......
@@ -256,6 +256,7 @@ class Req:
         # Prefix info
         self.prefix_indices = []
+        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
         self.extend_input_len = 0
         self.last_node = None
@@ -316,6 +317,7 @@ class Req:
     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
+            # tree cache is None if the prefix is not computed with tree cache.
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
......
@@ -20,9 +20,11 @@ from contextlib import contextmanager
 from enum import Enum, auto
 from typing import Dict, List, Optional

+import torch
+
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.radix_cache import TreeNode
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
@@ -32,6 +34,13 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )

+# The threshold to apply in-batch prefix caching.
+# If the value is too small, in-batch prefix caching cannot be used. E.g.,
+# imagine the prefix "the".
+IN_BATCH_PREFIX_CACHING_THRESHOLD = int(
+    os.environ.get("SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD", "32")
+)
+

 class SchedulePolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
@@ -51,18 +60,50 @@ class SchedulePolicy:
         # Compute matched prefix length
         prefix_computed = False
+        # rid to deprioritize in the current run.
+        temporary_deprioritized = {}
         if policy == "lpm" or policy == "dfs-weight":
+            # It is used to find the matching prefix for in-batch prefix caching.
+            temp_radix = RadixCache(None, None, False)
             for r in waiting_queue:
+                prefix_ids = r.adjust_max_prefix_ids()
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                    rid=r.rid, key=prefix_ids
                 )
+
+                # NOTE(sang): This logic is for in-batch prefix caching.
+                # If there is more than one request with only a small matching prefix in the
+                # existing cache, but those requests all share the same prefix, we prefer to
+                # schedule only one of them so that we can increase the cache hit rate.
+                # We prefer to set IN_BATCH_PREFIX_CACHING_THRESHOLD > 0 because a threshold
+                # that is too small means we cannot use in-batch prefix caching for short
+                # prefixes, which is common when the engine is long running (e.g., imagine "the").
+                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_THRESHOLD:
+                    in_batch_matching_prefixes, _ = temp_radix.match_prefix(
+                        rid=r.rid, key=prefix_ids
+                    )
+                    if (
+                        len(in_batch_matching_prefixes)
+                        >= IN_BATCH_PREFIX_CACHING_THRESHOLD
+                    ):
+                        temporary_deprioritized[r.rid] = r
+                    else:
+                        temp_radix.insert(prefix_ids, torch.tensor(prefix_ids))
+
             prefix_computed = True

         if policy == "lpm":
             # Longest Prefix Match
-            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            def get_priority(r: Req):
+                score = 0
+                if r.rid in temporary_deprioritized:
+                    score = float("inf")
+                else:
+                    score = -len(r.prefix_indices)
+                return score
+
+            waiting_queue.sort(key=get_priority)
         elif policy == "fcfs":
             # first come first serve
             pass
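The delayed scheduling described in the NOTE above can be summarized outside the diff. A minimal, self-contained sketch follows (this is not sglang code: THRESHOLD, longest_shared_prefix, and schedule_order are made-up names, and the real implementation uses a RadixCache instead of a linear scan): a request whose prefix is already covered by another request in the same batch is pushed to the back of the queue, so that by the time it runs, the shared prefix is already in the cache.

from typing import List, Tuple

THRESHOLD = 4  # stand-in for IN_BATCH_PREFIX_CACHING_THRESHOLD


def longest_shared_prefix(seen: List[List[int]], key: List[int]) -> int:
    """Length of the longest prefix that `key` shares with any already-seen key."""
    best = 0
    for other in seen:
        n = 0
        for a, b in zip(other, key):
            if a != b:
                break
            n += 1
        best = max(best, n)
    return best


def schedule_order(waiting: List[Tuple[str, List[int]]]) -> List[str]:
    """Schedule one request per shared prefix now; delay the duplicates to the back."""
    seen: List[List[int]] = []
    delayed = set()
    for rid, tokens in waiting:
        if longest_shared_prefix(seen, tokens) >= THRESHOLD:
            delayed.add(rid)  # another in-batch request will warm the cache first
        else:
            seen.append(tokens)
    return [rid for rid, _ in waiting if rid not in delayed] + sorted(delayed)


# "a" and "b" share a 6-token prefix, so "b" is delayed behind the unrelated "c".
print(schedule_order([
    ("a", [1, 2, 3, 4, 5, 6, 7]),
    ("b", [1, 2, 3, 4, 5, 6, 9]),
    ("c", [9, 8, 7]),
]))  # -> ['a', 'c', 'b']

In the hunk above, the same effect is obtained by giving the deprioritized rids a priority of float("inf") in get_priority, so they sort to the back of waiting_queue.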
@@ -76,6 +117,7 @@ class SchedulePolicy:
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)

+            # node -> # of requests for that node.
             node_to_weight = defaultdict(int)
             for node in last_node_to_reqs:
                 node_to_weight[node] = len(last_node_to_reqs[node])
@@ -87,7 +129,9 @@ class SchedulePolicy:
                 node_to_weight,
                 last_node_to_reqs,
                 waiting_queue,
+                temporary_deprioritized,
             )
+            waiting_queue.extend(temporary_deprioritized.values())
         else:
             raise ValueError(f"Unknown schedule_policy: {policy=}")
@@ -101,15 +145,22 @@ class SchedulePolicy:
     def get_dfs_priority(
         self,
         cur_node: TreeNode,
-        node_to_priority: Dict,
-        last_node_to_reqs: Dict,
+        node_to_priority: Dict[TreeNode, int],
+        last_node_to_reqs: Dict[TreeNode, List[Req]],
         q: List,
+        temporary_deprioritized: Dict[str, Req],
     ):
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
-            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
-        q.extend(last_node_to_reqs[cur_node])
+            self.get_dfs_priority(
+                child, node_to_priority, last_node_to_reqs, q, temporary_deprioritized
+            )
+        for req in last_node_to_reqs[cur_node]:
+            if req.rid in temporary_deprioritized:
+                continue
+            q.append(req)


 class AddReqResult(Enum):
......
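Stated in one place, the dfs-weight changes above do the following: each tree node is weighted by the number of waiting requests that end at it, the tree is walked depth-first visiting heavier children first, and requests that were temporarily deprioritized are skipped during the walk and appended at the very end via waiting_queue.extend. A rough standalone sketch is below; the Node type is simplified and weights are not propagated to ancestors, since the hunks above do not show that part.

from collections import defaultdict
from typing import Dict, List, Set


class Node:
    """Simplified stand-in for a radix-tree node."""

    def __init__(self, name: str):
        self.name = name
        self.children: List["Node"] = []


def dfs_order(
    node: Node,
    node_to_weight: Dict[Node, int],
    node_to_reqs: Dict[Node, List[str]],
    deprioritized: Set[str],
    out: List[str],
) -> None:
    # Visit heavier children first, then emit this node's own requests,
    # skipping the temporarily deprioritized ones.
    for child in sorted(node.children, key=lambda c: -node_to_weight[c]):
        dfs_order(child, node_to_weight, node_to_reqs, deprioritized, out)
    out.extend(r for r in node_to_reqs[node] if r not in deprioritized)


# Tiny example: the child with more requests ("b") is scheduled first.
root, a, b = Node("root"), Node("a"), Node("b")
root.children = [a, b]
node_to_reqs = defaultdict(list, {a: ["r1"], b: ["r2", "r3"]})
weights = {n: len(node_to_reqs[n]) for n in (root, a, b)}
out: List[str] = []
dfs_order(root, weights, node_to_reqs, set(), out)
print(out)  # ['r2', 'r3', 'r1']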
@@ -713,7 +713,7 @@ class Scheduler:
             if crash_on_warnings():
                 raise ValueError(msg)

-    def get_next_batch_to_run(self):
+    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
......
 from abc import ABC, abstractmethod
-from typing import Callable
+from typing import Callable, List, Tuple


 class BasePrefixCache(ABC):
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def match_prefix(self, **kwargs):
+    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
         pass

     @abstractmethod
......
@@ -2,7 +2,7 @@ from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}

-    def match_prefix(self, rid: int, key: List[int]):
+    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
         if rid not in self.entries:
             return [], None
......
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.

 import heapq
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

 import torch
@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0

-    def match_prefix(self, key: List, **kwargs):
+    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+        """Find the matching prefix from the radix tree.
+
+        Args:
+            key: A list of token IDs to find a matching prefix.
+        Returns:
+            A tuple of a tensor of matching prefix token IDs and
+            the last node that contains the prefix values. Note that
+            this API can modify the internal state of the radix tree:
+            the last node creates a new child if the matched prefix is
+            shorter than the last node's value.
+        """
         if self.disable:
             return [], self.root_node
......
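Based on how the scheduler hunk earlier in this commit uses the class (an in-memory RadixCache(None, None, False) plus insert and match_prefix over token-ID lists), a small usage sketch of the documented API could look like this; it assumes an sglang installation with this commit applied and uses illustrative token values.

import torch

from sglang.srt.mem_cache.radix_cache import RadixCache

# No req/token pools, caching enabled, as in the scheduler hunk above.
radix = RadixCache(None, None, False)
prefix_ids = [1, 2, 3, 4, 5]
radix.insert(prefix_ids, torch.tensor(prefix_ids))

# A later key that shares the first four tokens matches that stored prefix;
# per the docstring, the tree may split the last node to expose the match.
matched, last_node = radix.match_prefix(rid="req-1", key=[1, 2, 3, 4, 9])
print(len(matched))  # expect 4 shared tokens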
@@ -79,7 +79,14 @@ class HttpResponse:
         return self.resp.status


-def http_request(url, json=None, stream=False, api_key=None, verify=None):
+def http_request(
+    url,
+    json=None,
+    stream=False,
+    api_key=None,
+    verify=None,
+    method: Optional[str] = None,
+):
     """A faster version of requests.post with low-level urllib API."""
     headers = {"Content-Type": "application/json; charset=utf-8"}
@@ -90,7 +97,7 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None):
     if stream:
         return requests.post(url, json=json, stream=True, headers=headers)
     else:
-        req = urllib.request.Request(url, headers=headers)
+        req = urllib.request.Request(url, headers=headers, method=method)
         if json is None:
             data = None
         else:
......
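The method argument threaded through http_request above is what lets the flush_cache call in the first hunk of this commit issue a POST even though it sends no JSON body (urllib defaults to GET when data is None). A minimal call mirroring that hunk, with a placeholder URL and an assumed import path:

from sglang.utils import http_request  # assumed location of this helper

res = http_request(
    "http://localhost:30000/flush_cache",  # placeholder for base_url + "/flush_cache"
    api_key=None,
    verify=None,
    method="POST",  # force POST despite the empty body
)
print(res.status)  # HttpResponse exposes .status, per the hunk above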