Unverified Commit 7623091d authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

RadixCache method adjust (#977)

parent f724f1f1
......@@ -124,7 +124,7 @@ class Req:
# For vision input
self.pixel_values = None
self.image_size = None
self.image_offset = 0
self.image_offset = None
self.pad_value = None
# Prefix info
......@@ -162,6 +162,13 @@ class Req:
def finished(self) -> bool:
return self.finished_reason is not None
def adjust_max_prefix_ids(self):
max_prefix_ids = self.input_ids
if self.return_logprob:
max_prefix_ids = self.input_ids[: self.logprob_start_len]
return max_prefix_ids
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
def init_incremental_detokenize(self):
first_iter = self.surr_offset is None or self.read_offset is None
......@@ -444,7 +451,8 @@ class ScheduleBatch:
self.pixel_values = [r.pixel_values for r in reqs]
self.image_sizes = [r.image_size for r in reqs]
self.image_offsets = [
r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
(r.image_offset - p_len) if r.image_offset is not None else 0
for r, p_len in zip(reqs, prefix_lens)
]
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
self.extend_num_tokens = extend_num_tokens
......@@ -596,15 +604,7 @@ class ScheduleBatch:
req.vid += 1
# insert the old request into tree_cache
self.tree_cache.cache_req(
rid=req.rid,
token_ids=cur_all_ids,
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req.req_pool_idx,
)
# unlock the last node
self.tree_cache.dec_lock_ref(req.last_node)
self.tree_cache.cache_finished_req(req, cur_all_ids)
# re-applying image padding
if req.pixel_values is not None:
......@@ -621,7 +621,6 @@ class ScheduleBatch:
jump_forward_reqs.append(req)
filter_indices.remove(i)
if len(filter_indices) < len(self.reqs):
self.filter_batch(filter_indices)
return jump_forward_reqs
......@@ -644,6 +643,15 @@ class ScheduleBatch:
] = self.out_cache_loc
def filter_batch(self, unfinished_indices: List[int]):
if unfinished_indices is None or len(unfinished_indices) == 0:
# Filter out all requests
self.reqs = []
return
if len(unfinished_indices) == len(self.reqs):
# No need to filter
return
self.reqs = [self.reqs[i] for i in unfinished_indices]
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
self.seq_lens = self.seq_lens[new_indices]
......@@ -711,6 +719,7 @@ class ScheduleBatch:
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
def sample(self, logits: torch.Tensor):
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
# Post process logits
logits = logits.contiguous()
logits.div_(self.temperatures)
......
......@@ -232,8 +232,6 @@ class ModelTpServer:
if new_batch is not None:
# Run a new prefill batch
self.forward_prefill_batch(new_batch)
self.cache_filled_batch(new_batch)
self.filter_out_inflight(new_batch)
if not new_batch.is_empty():
if self.running_batch is None:
......@@ -353,26 +351,20 @@ class ModelTpServer:
self.waiting_queue.append(req)
def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
# TODO(lsyin): organize this function
running_bs = (
len(self.running_batch.reqs) if self.running_batch is not None else 0
)
if running_bs >= self.max_running_requests:
return
return None
# Compute matched prefix length
for req in self.waiting_queue:
req.input_ids = req.origin_input_ids + req.output_ids
try_match_ids = req.input_ids
if req.return_logprob:
try_match_ids = req.input_ids[: req.logprob_start_len]
# NOTE: the prefix_indices must always be aligned with last_node
prefix_indices, last_node = self.tree_cache.match_prefix(
rid=req.rid, key=try_match_ids
req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
rid=req.rid, key=req.adjust_max_prefix_ids()
)
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
req.prefix_indices = prefix_indices
req.last_node = last_node
req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
# Get priority queue
self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
......@@ -394,6 +386,24 @@ class ModelTpServer:
)
for req in self.waiting_queue:
# FIXME: Move this code into adjust_max_prefix_len
if req.return_logprob and req.normalized_prompt_logprob is None:
# Need at least two tokens to compute normalized logprob
if req.extend_input_len < 2:
delta = 2 - req.extend_input_len
req.extend_input_len += delta
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
req.image_offset += delta
if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
# Need at least one token to compute logits
req.extend_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None:
req.image_offset += 1
res = adder.add_one_req(req)
if (
not res
......@@ -470,10 +480,20 @@ class ModelTpServer:
pt = 0
for i, req in enumerate(batch.reqs):
if req is not self.current_inflight_req:
# Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
if req is self.current_inflight_req:
# Inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
if req.return_logprob:
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
pt += req.extend_input_len
......@@ -529,22 +549,6 @@ class ModelTpServer:
)
req.output_top_logprobs.append(output.output_top_logprobs[i])
def cache_filled_batch(self, batch: ScheduleBatch):
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.input_ids),
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req.req_pool_idx,
del_in_memory_pool=False,
old_last_node=req.last_node,
)
req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
if req is self.current_inflight_req:
# inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
def forward_decode_batch(self, batch: ScheduleBatch):
# Check if decode out of memory
if not batch.check_decode_mem():
......@@ -595,6 +599,9 @@ class ModelTpServer:
req.output_ids.append(next_token_id)
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
if req.return_logprob:
req.output_token_logprobs.append(
(next_token_logprobs[i], next_token_id)
......@@ -614,12 +621,9 @@ class ModelTpServer:
output_spaces_between_special_tokens = []
output_meta_info = []
output_finished_reason: List[BaseFinishReason] = []
finished_indices = []
unfinished_indices = []
for i, req in enumerate(batch.reqs):
if req.finished():
finished_indices.append(i)
else:
if not req.finished() and req is not self.current_inflight_req:
unfinished_indices.append(i)
if req.finished() or (
......@@ -683,34 +687,7 @@ class ModelTpServer:
)
)
# Remove finished reqs
if finished_indices:
# Update radix cache
for i in finished_indices:
req = batch.reqs[i]
self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req.req_pool_idx,
)
self.tree_cache.dec_lock_ref(req.last_node)
# Update batch tensors
if unfinished_indices:
batch.filter_batch(unfinished_indices)
else:
batch.reqs = []
def filter_out_inflight(self, batch: ScheduleBatch):
# TODO(lsyin): reduce the overhead, make a special version for this
if self.current_inflight_req is None:
return
to_remove = batch.reqs.index(self.current_inflight_req)
unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
# Remove finished reqs: update batch tensors
batch.filter_batch(unfinished_indices)
def flush_cache(self):
......
......@@ -17,7 +17,11 @@ class BasePrefixCache(ABC):
pass
@abstractmethod
def cache_req(self, **kwargs):
def cache_finished_req(self, **kwargs):
pass
@abstractmethod
def cache_unfinished_req(self, **kwargs):
pass
@abstractmethod
......
"""Cache for chunked prefill, used when RadixCache is disabled."""
from sglang.srt.mem_cache.base_cache import BasePrefixCache
from typing import TYPE_CHECKING
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class ChunkCacheEntry:
......@@ -27,22 +32,31 @@ class ChunkCache(BasePrefixCache):
entry = self.entries[rid]
return entry.value, entry
def cache_req(
self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
):
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
if del_in_memory_pool:
assert rid in self.entries
self.req_to_token_pool.free(req_pool_idx)
self.token_to_kv_pool.free(indices)
return
def cache_finished_req(self, req: "Req", token_ids=None):
if token_ids is None:
token_ids = (req.input_ids + req.output_ids)[:-1]
if rid not in self.entries:
self.entries[rid] = ChunkCacheEntry(rid, indices)
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
assert req.rid in self.entries
self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool.free(kv_indices)
entry = self.entries[rid]
entry.value = indices
return indices, entry
def cache_unfinished_req(self, req: "Req", token_ids=None):
if token_ids is None:
token_ids = req.input_ids
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if req.rid not in self.entries:
self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
entry = self.entries[req.rid]
entry.value = kv_indices
return kv_indices, entry
def insert(self):
raise NotImplementedError
......
......@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING
import torch
from sglang.srt.mem_cache.base_cache import BasePrefixCache
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class TreeNode:
......@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache):
value = [x for x in key]
return self._insert_helper(self.root_node, key, value)
def cache_req(
self,
token_ids,
last_uncached_pos,
req_pool_idx,
del_in_memory_pool=True,
old_last_node=None,
**kwargs,
):
# Insert the request into radix cache
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
new_prefix_len = self.insert(token_ids, indices.clone())
def cache_finished_req(self, req: "Req", token_ids=None):
"""Cache request when it finishes."""
if token_ids is None:
token_ids = (req.input_ids + req.output_ids)[:-1]
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if self.disable:
if del_in_memory_pool:
self.token_to_kv_pool.free(indices)
else:
return torch.tensor([], dtype=torch.int32), self.root_node
self.token_to_kv_pool.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return
# Radix Cache takes one ref in memory pool
self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
if del_in_memory_pool:
self.req_to_token_pool.free(req_pool_idx)
else:
cached_indices, new_last_node = self.match_prefix(token_ids)
assert len(cached_indices) == len(token_ids)
# Remove req slot release the cache lock
self.req_to_token_pool.free(req.req_pool_idx)
self.dec_lock_ref(req.last_node)
def cache_unfinished_req(self, req: "Req", token_ids=None):
"""Cache request when it is unfinished."""
if self.disable:
return
if token_ids is None:
token_ids = req.input_ids
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
# The prefix indices could be updated, reuse it
new_indices, new_last_node = self.match_prefix(token_ids)
assert len(new_indices) == len(token_ids)
self.req_to_token_pool.req_to_token[
req_pool_idx, last_uncached_pos : len(cached_indices)
] = cached_indices[last_uncached_pos:]
self.dec_lock_ref(old_last_node)
req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
] = new_indices[len(req.prefix_indices) :]
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
return cached_indices, new_last_node
req.prefix_indices = new_indices
req.last_node = new_last_node
def pretty_print(self):
self._print_helper(self.root_node, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment