Unverified commit 7de60345, authored by Liangsheng Yin, committed by GitHub
Browse files

Fix the prefix indices (#1037)

parent d84c5e70
...@@ -43,11 +43,14 @@ class PolicyScheduler: ...@@ -43,11 +43,14 @@ class PolicyScheduler:
def calc_priority(self, waiting_queue: List[Req]): def calc_priority(self, waiting_queue: List[Req]):
# Compute matched prefix length # Compute matched prefix length
for r in waiting_queue: prefix_computed = False
# NOTE: the prefix_indices must always be aligned with last_node if self.policy in ["lpm", "dfs-weight"]:
r.prefix_indices, r.last_node = self.tree_cache.match_prefix( for r in waiting_queue:
rid=r.rid, key=r.adjust_max_prefix_ids() # NOTE: the prefix_indices must always be aligned with last_node
) r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=r.adjust_max_prefix_ids()
)
prefix_computed = True
if self.policy == "lpm": if self.policy == "lpm":
# Longest Prefix Match # Longest Prefix Match
...@@ -80,6 +83,8 @@ class PolicyScheduler: ...@@ -80,6 +83,8 @@ class PolicyScheduler:
else: else:
raise ValueError(f"Unknown schedule_policy: {self.policy}") raise ValueError(f"Unknown schedule_policy: {self.policy}")
return prefix_computed
def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict): def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
for child in cur_node.children.values(): for child in cur_node.children.values():
self.calc_weight(child, node_to_weight) self.calc_weight(child, node_to_weight)
......
...@@ -18,9 +18,8 @@ limitations under the License. ...@@ -18,9 +18,8 @@ limitations under the License.
import logging import logging
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union from typing import List, Optional, Union
import numpy as np
import torch import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs from flashinfer.sampling import top_k_top_p_sampling_from_probs
...@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib ...@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
...@@ -164,8 +163,12 @@ class Req: ...@@ -164,8 +163,12 @@ class Req:
def finished(self) -> bool: def finished(self) -> bool:
return self.finished_reason is not None return self.finished_reason is not None
def init_next_round_input(self): def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
self.fill_ids = self.origin_input_ids + self.output_ids self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None:
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self): def adjust_max_prefix_ids(self):
...@@ -312,7 +315,7 @@ class ScheduleBatch: ...@@ -312,7 +315,7 @@ class ScheduleBatch:
reqs: List[Req] reqs: List[Req]
req_to_token_pool: ReqToTokenPool req_to_token_pool: ReqToTokenPool
token_to_kv_pool: BaseTokenToKVPool token_to_kv_pool: BaseTokenToKVPool
tree_cache: RadixCache tree_cache: BasePrefixCache
# Batched arguments to model runner # Batched arguments to model runner
input_ids: torch.Tensor = None input_ids: torch.Tensor = None
...@@ -534,7 +537,7 @@ class ScheduleBatch: ...@@ -534,7 +537,7 @@ class ScheduleBatch:
residual_size = max(0, residual_size) residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free) self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
req.prefix_indices = None req.prefix_indices = []
req.last_node = None req.last_node = None
req.extend_input_len = 0 req.extend_input_len = 0
......
...@@ -369,7 +369,7 @@ class ModelTpServer: ...@@ -369,7 +369,7 @@ class ModelTpServer:
return None return None
# Get priority queue # Get priority queue
self.scheduler.calc_priority(self.waiting_queue) prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
adder = PrefillAdder( adder = PrefillAdder(
self.tree_cache, self.tree_cache,
...@@ -383,13 +383,15 @@ class ModelTpServer: ...@@ -383,13 +383,15 @@ class ModelTpServer:
has_inflight = self.current_inflight_req is not None has_inflight = self.current_inflight_req is not None
if self.current_inflight_req is not None: if self.current_inflight_req is not None:
self.current_inflight_req.init_next_round_input() self.current_inflight_req.init_next_round_input(
None if prefix_computed else self.tree_cache
)
self.current_inflight_req = adder.add_inflight_req( self.current_inflight_req = adder.add_inflight_req(
self.current_inflight_req self.current_inflight_req
) )
for req in self.waiting_queue: for req in self.waiting_queue:
req.init_next_round_input() req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req) res = adder.add_one_req(req)
if ( if (
not res not res
......
...@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache. ...@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
import heapq import heapq
import time import time
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Callable, List, Optional
import torch import torch
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment