"tests/python/vscode:/vscode.git/clone" did not exist on "661f8177ceb885fd534e6c73b759da01f8937431"
Unverified Commit fb7421db authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

minor: some potential bugs (#1044)

parent 14b64930
...@@ -2,7 +2,7 @@ from __future__ import annotations ...@@ -2,7 +2,7 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled.""" """Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Callable, List, Optional
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
...@@ -30,12 +30,13 @@ class ChunkCache(BasePrefixCache): ...@@ -30,12 +30,13 @@ class ChunkCache(BasePrefixCache):
def reset(self):
    """Clear the cache, discarding every per-request entry."""
    self.entries = {}
def match_prefix(self, rid: int, key: List[int]):
    """Look up the cached prefix for request *rid*.

    Args:
        rid: Request id used as the key into ``self.entries``.
        key: Token ids being matched; only up to ``len(key)`` cached
            tokens may count as a prefix.

    Returns:
        A ``(prefix_tokens, entry)`` pair: the cached value clipped to
        ``len(key)`` and the cache entry itself, or ``([], None)`` when
        *rid* has no entry.
    """
    if rid not in self.entries:
        return [], None
    entry = self.entries[rid]
    # Bug fix: the previous version returned the full cached value,
    # ignoring the query length. The matched prefix can never be longer
    # than the key being matched, so clip to len(key).
    max_prefix_len = len(key)
    return entry.value[:max_prefix_len], entry
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None): def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
if token_ids is None: if token_ids is None:
......
...@@ -140,13 +140,13 @@ class InputMetadata: ...@@ -140,13 +140,13 @@ class InputMetadata:
if self.forward_mode == ForwardMode.DECODE: if self.forward_mode == ForwardMode.DECODE:
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
else: else:
prefix_lens_cpu = [ extend_lens_cpu = [
len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
] ]
self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda") self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
self.extend_start_loc = torch.zeros_like(self.seq_lens) self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu) self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
@classmethod @classmethod
def from_schedule_batch( def from_schedule_batch(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment