Unverified commit 9208618b authored by SangBin Cho, committed by GitHub

[Core] in batch prefix caching by delay scheduling (#2442)

parent 864bf2ba
......@@ -55,6 +55,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/flush_cache",
api_key=self.api_key,
verify=self.verify,
method="POST",
)
self._assert_success(res)
......
......@@ -256,6 +256,7 @@ class Req:
# Prefix info
self.prefix_indices = []
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
self.extend_input_len = 0
self.last_node = None
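A tiny worked example of what these fields track (the numbers are invented, and the diff does not show where extend_input_len is assigned; it only documents the relation input_tokens - shared_prefix_tokens):

```python
# Illustrative only: with 1,000 input tokens of which 600 are already cached
# as a shared prefix, only the remaining tokens need to run through prefill.
input_tokens = 1000
shared_prefix_tokens = 600
extend_input_len = input_tokens - shared_prefix_tokens
print(extend_input_len)  # 400
```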
......@@ -316,6 +317,7 @@ class Req:
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None:
# tree_cache is None if the prefix is not computed with the tree cache.
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
......
......@@ -20,9 +20,11 @@ from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict, List, Optional
import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import TreeNode
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
# Clip the estimation of max_new_tokens for requests whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
......@@ -32,6 +34,13 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)
# The threshold for applying in-batch prefix caching.
# If the value is too small, in-batch prefix caching cannot be used; e.g.,
# imagine a prefix as short and common as "the".
IN_BATCH_PREFIX_CACHING_THRESHOLD = int(
os.environ.get("SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD", "32")
)
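The constant is read once at import time, so a different threshold has to be set in the environment before sglang is imported or the server is launched. A minimal sketch (the value 64 is purely illustrative):

```python
# Must run before the scheduler module is imported; 64 is an arbitrary example.
import os

os.environ["SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD"] = "64"
```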
class SchedulePolicy:
def __init__(self, policy: str, tree_cache: BasePrefixCache):
......@@ -51,18 +60,50 @@ class SchedulePolicy:
# Compute matched prefix length
prefix_computed = False
# Mapping from rid to Req for requests to deprioritize in the current run.
temporary_deprioritized = {}
if policy == "lpm" or policy == "dfs-weight":
# It is used to find the matching prefix for in-batch prefix caching.
temp_radix = RadixCache(None, None, False)
for r in waiting_queue:
prefix_ids = r.adjust_max_prefix_ids()
# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=r.adjust_max_prefix_ids()
rid=r.rid, key=prefix_ids
)
# NOTE(sang): This logic is for in-batch prefix caching.
# If more than one request has only a small matching prefix against the
# existing cache, but those requests share the same prefix with each other,
# we prefer to schedule only one of them so that we can increase the cache hit rate.
# We prefer IN_BATCH_PREFIX_CACHING_THRESHOLD > 0 because a threshold that is
# too small means in-batch prefix caching cannot be used for short prefixes,
# which are common when the engine has been running for a long time (e.g., "the").
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_THRESHOLD:
in_batch_matching_prefixes, _ = temp_radix.match_prefix(
rid=r.rid, key=prefix_ids
)
if (
len(in_batch_matching_prefixes)
>= IN_BATCH_PREFIX_CACHING_THRESHOLD
):
temporary_deprioritized[r.rid] = r
else:
temp_radix.insert(prefix_ids, torch.tensor(prefix_ids))
prefix_computed = True
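A hedged, self-contained sketch of the decision made in the loop above, with a plain-Python trie standing in for the throwaway RadixCache and invented request tuples; it is not the sglang implementation, only the same idea:

```python
# Each request is (rid, prefix_ids, cached_prefix_len). Requests with a short
# match against the persistent cache are checked against a per-batch trie;
# later requests sharing a long enough in-batch prefix are deprioritized.
IN_BATCH_PREFIX_CACHING_THRESHOLD = 32


def longest_in_batch_match(trie: dict, key: list) -> int:
    """Length of the longest prefix of `key` already inserted into the trie."""
    node, depth = trie, 0
    for tok in key:
        if tok not in node:
            break
        node = node[tok]
        depth += 1
    return depth


def insert_into_trie(trie: dict, key: list) -> None:
    node = trie
    for tok in key:
        node = node.setdefault(tok, {})


def pick_deprioritized(requests) -> set:
    temp_trie: dict = {}
    deprioritized = set()
    for rid, prefix_ids, cached_prefix_len in requests:
        if cached_prefix_len <= IN_BATCH_PREFIX_CACHING_THRESHOLD:
            match_len = longest_in_batch_match(temp_trie, prefix_ids)
            if match_len >= IN_BATCH_PREFIX_CACHING_THRESHOLD:
                deprioritized.add(rid)  # delay: let the first request warm the cache
            else:
                insert_into_trie(temp_trie, prefix_ids)
    return deprioritized


# r2 shares a 40-token prefix with r1 inside the batch, so it is delayed.
shared = list(range(40))
reqs = [("r1", shared + [100, 101], 0), ("r2", shared + [200, 201], 0)]
assert pick_deprioritized(reqs) == {"r2"}
```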
if policy == "lpm":
# Longest Prefix Match
waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
def get_priority(r: Req):
score = 0
if r.rid in temporary_deprioritized:
score = float("inf")
else:
score = -len(r.prefix_indices)
return score
waiting_queue.sort(key=get_priority)
elif policy == "fcfs":
# first come first serve
pass
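For the lpm branch, a standalone illustration of the resulting order (the request tuples are hypothetical): requests with longer cached matches run first, and temporarily deprioritized requests sink to the end of the queue even if their match is the longest.

```python
# Sketch of the longest-prefix-match ordering with temporary deprioritization.
def sort_lpm(requests, temporary_deprioritized):
    """requests: list of (rid, matched_prefix_len) pairs."""

    def priority(r):
        rid, matched_prefix_len = r
        if rid in temporary_deprioritized:
            return float("inf")  # pushed to the back for this round
        return -matched_prefix_len  # longer matches sort first

    return sorted(requests, key=priority)


reqs = [("a", 5), ("b", 50), ("c", 120)]
print(sort_lpm(reqs, temporary_deprioritized={"c"}))
# [('b', 50), ('a', 5), ('c', 120)] -- 'c' is delayed despite the longest match
```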
......@@ -76,6 +117,7 @@ class SchedulePolicy:
for req in waiting_queue:
last_node_to_reqs[req.last_node].append(req)
# node -> # of requests for that node.
node_to_weight = defaultdict(int)
for node in last_node_to_reqs:
node_to_weight[node] = len(last_node_to_reqs[node])
......@@ -87,7 +129,9 @@ class SchedulePolicy:
node_to_weight,
last_node_to_reqs,
waiting_queue,
temporary_deprioritized,
)
waiting_queue.extend(temporary_deprioritized.values())
else:
raise ValueError(f"Unknown schedule_policy: {policy=}")
......@@ -101,15 +145,22 @@ class SchedulePolicy:
def get_dfs_priority(
self,
cur_node: TreeNode,
node_to_priority: Dict,
last_node_to_reqs: Dict,
node_to_priority: Dict[TreeNode, int],
last_node_to_reqs: Dict[TreeNode, List[Req]],
q: List,
temporary_deprioritized: Dict[str, Req],
):
childs = [child for child in cur_node.children.values()]
childs.sort(key=lambda x: -node_to_priority[x])
for child in childs:
self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
q.extend(last_node_to_reqs[cur_node])
self.get_dfs_priority(
child, node_to_priority, last_node_to_reqs, q, temporary_deprioritized
)
for req in last_node_to_reqs[cur_node]:
if req.rid in temporary_deprioritized:
continue
q.append(req)
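A toy, self-contained illustration of the dfs-weight traversal above; the node names, weights, and request IDs are invented, and a small dataclass stands in for the radix-tree TreeNode:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass(eq=False)  # keep identity hashing so nodes can be dict keys
class ToyNode:
    name: str
    children: Dict[str, "ToyNode"] = field(default_factory=dict)


def dfs_priority(node, node_to_weight, node_to_reqs, out, deprioritized):
    # Visit heavier subtrees first, then emit this node's own requests,
    # skipping those temporarily deprioritized for this scheduling round.
    children = sorted(node.children.values(), key=lambda c: -node_to_weight[c])
    for child in children:
        dfs_priority(child, node_to_weight, node_to_reqs, out, deprioritized)
    for rid in node_to_reqs.get(node.name, []):
        if rid in deprioritized:
            continue
        out.append(rid)


root = ToyNode("root", {"a": ToyNode("a"), "b": ToyNode("b")})
weights = {root.children["a"]: 1, root.children["b"]: 3}
reqs = {"a": ["r1"], "b": ["r2", "r3", "r4"]}
order: List[str] = []
dfs_priority(root, weights, reqs, order, deprioritized={"r3"})
print(order)  # ['r2', 'r4', 'r1'] -- subtree b is heavier, and r3 is delayed
```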
class AddReqResult(Enum):
......
......@@ -713,7 +713,7 @@ class Scheduler:
if crash_on_warnings():
raise ValueError(msg)
def get_next_batch_to_run(self):
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.being_chunked_req:
......
from abc import ABC, abstractmethod
from typing import Callable
from typing import Callable, List, Tuple
class BasePrefixCache(ABC):
......@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
pass
@abstractmethod
def match_prefix(self, **kwargs):
def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
pass
@abstractmethod
......
......@@ -2,7 +2,7 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
......@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
def reset(self):
self.entries = {}
def match_prefix(self, rid: int, key: List[int]):
def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
if rid not in self.entries:
return [], None
......
......@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
import torch
......@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
self.root_node.lock_ref = 1
self.evictable_size_ = 0
def match_prefix(self, key: List, **kwargs):
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
"""Find the matching prefix from the radix tree.
Args:
key: A list of token IDs to find a matching prefix.
Returns:
A tuple of a tensor of matching prefix token IDs and
the last node that contains the prefix values. Note that
this API can modify the internal state of the radix tree:
the last node creates a new child if the matched prefix is shorter
than the last node's value.
"""
if self.disable:
return [], self.root_node
......
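A hedged usage sketch of match_prefix, mirroring how the scheduler change above builds a throwaway tree with RadixCache(None, None, False); exact return types and values may differ across sglang versions:

```python
import torch
from sglang.srt.mem_cache.radix_cache import RadixCache

tree = RadixCache(None, None, False)
key = [1, 2, 3, 4, 5, 6]
tree.insert(key, torch.tensor(key))

# Matching a key that diverges inside an existing node splits that node,
# which is why match_prefix can modify the internal state of the tree.
matched, last_node = tree.match_prefix(key=[1, 2, 3, 9])
print(len(matched))  # expected: 3, the shared [1, 2, 3] prefix
```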
......@@ -79,7 +79,14 @@ class HttpResponse:
return self.resp.status
def http_request(url, json=None, stream=False, api_key=None, verify=None):
def http_request(
url,
json=None,
stream=False,
api_key=None,
verify=None,
method: Optional[str] = None,
):
"""A faster version of requests.post with low-level urllib API."""
headers = {"Content-Type": "application/json; charset=utf-8"}
......@@ -90,7 +97,7 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None):
if stream:
return requests.post(url, json=json, stream=True, headers=headers)
else:
req = urllib.request.Request(url, headers=headers)
req = urllib.request.Request(url, headers=headers, method=method)
if json is None:
data = None
else:
......
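The new method argument matters because urllib infers the HTTP verb from whether a request body is present: with no body it defaults to GET, which is presumably why the /flush_cache call at the top of this diff now passes method="POST" explicitly. A quick standalone check (the URL is illustrative; constructing a Request does not open a connection):

```python
import urllib.request

no_body = urllib.request.Request("http://localhost:30000/flush_cache")
print(no_body.get_method())  # GET -- no data, so urllib would default to GET

forced = urllib.request.Request("http://localhost:30000/flush_cache", method="POST")
print(forced.get_method())  # POST -- explicit method overrides the default
```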