Unverified Commit 7623091d authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

RadixCache method adjust (#977)

parent f724f1f1
...@@ -124,7 +124,7 @@ class Req: ...@@ -124,7 +124,7 @@ class Req:
# For vision input # For vision input
self.pixel_values = None self.pixel_values = None
self.image_size = None self.image_size = None
self.image_offset = 0 self.image_offset = None
self.pad_value = None self.pad_value = None
# Prefix info # Prefix info
...@@ -162,6 +162,13 @@ class Req: ...@@ -162,6 +162,13 @@ class Req:
def finished(self) -> bool: def finished(self) -> bool:
return self.finished_reason is not None return self.finished_reason is not None
def adjust_max_prefix_ids(self):
max_prefix_ids = self.input_ids
if self.return_logprob:
max_prefix_ids = self.input_ids[: self.logprob_start_len]
return max_prefix_ids
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313 # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
def init_incremental_detokenize(self): def init_incremental_detokenize(self):
first_iter = self.surr_offset is None or self.read_offset is None first_iter = self.surr_offset is None or self.read_offset is None
...@@ -444,7 +451,8 @@ class ScheduleBatch: ...@@ -444,7 +451,8 @@ class ScheduleBatch:
self.pixel_values = [r.pixel_values for r in reqs] self.pixel_values = [r.pixel_values for r in reqs]
self.image_sizes = [r.image_size for r in reqs] self.image_sizes = [r.image_size for r in reqs]
self.image_offsets = [ self.image_offsets = [
r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens) (r.image_offset - p_len) if r.image_offset is not None else 0
for r, p_len in zip(reqs, prefix_lens)
] ]
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device) self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
self.extend_num_tokens = extend_num_tokens self.extend_num_tokens = extend_num_tokens
...@@ -596,15 +604,7 @@ class ScheduleBatch: ...@@ -596,15 +604,7 @@ class ScheduleBatch:
req.vid += 1 req.vid += 1
# insert the old request into tree_cache # insert the old request into tree_cache
self.tree_cache.cache_req( self.tree_cache.cache_finished_req(req, cur_all_ids)
rid=req.rid,
token_ids=cur_all_ids,
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req.req_pool_idx,
)
# unlock the last node
self.tree_cache.dec_lock_ref(req.last_node)
# re-applying image padding # re-applying image padding
if req.pixel_values is not None: if req.pixel_values is not None:
...@@ -621,8 +621,7 @@ class ScheduleBatch: ...@@ -621,8 +621,7 @@ class ScheduleBatch:
jump_forward_reqs.append(req) jump_forward_reqs.append(req)
filter_indices.remove(i) filter_indices.remove(i)
if len(filter_indices) < len(self.reqs): self.filter_batch(filter_indices)
self.filter_batch(filter_indices)
return jump_forward_reqs return jump_forward_reqs
...@@ -644,6 +643,15 @@ class ScheduleBatch: ...@@ -644,6 +643,15 @@ class ScheduleBatch:
] = self.out_cache_loc ] = self.out_cache_loc
def filter_batch(self, unfinished_indices: List[int]): def filter_batch(self, unfinished_indices: List[int]):
if unfinished_indices is None or len(unfinished_indices) == 0:
# Filter out all requests
self.reqs = []
return
if len(unfinished_indices) == len(self.reqs):
# No need to filter
return
self.reqs = [self.reqs[i] for i in unfinished_indices] self.reqs = [self.reqs[i] for i in unfinished_indices]
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda") new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
self.seq_lens = self.seq_lens[new_indices] self.seq_lens = self.seq_lens[new_indices]
...@@ -711,6 +719,7 @@ class ScheduleBatch: ...@@ -711,6 +719,7 @@ class ScheduleBatch:
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias]) self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
def sample(self, logits: torch.Tensor): def sample(self, logits: torch.Tensor):
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
# Post process logits # Post process logits
logits = logits.contiguous() logits = logits.contiguous()
logits.div_(self.temperatures) logits.div_(self.temperatures)
......
...@@ -232,8 +232,6 @@ class ModelTpServer: ...@@ -232,8 +232,6 @@ class ModelTpServer:
if new_batch is not None: if new_batch is not None:
# Run a new prefill batch # Run a new prefill batch
self.forward_prefill_batch(new_batch) self.forward_prefill_batch(new_batch)
self.cache_filled_batch(new_batch)
self.filter_out_inflight(new_batch)
if not new_batch.is_empty(): if not new_batch.is_empty():
if self.running_batch is None: if self.running_batch is None:
...@@ -353,26 +351,20 @@ class ModelTpServer: ...@@ -353,26 +351,20 @@ class ModelTpServer:
self.waiting_queue.append(req) self.waiting_queue.append(req)
def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
# TODO(lsyin): organize this function
running_bs = ( running_bs = (
len(self.running_batch.reqs) if self.running_batch is not None else 0 len(self.running_batch.reqs) if self.running_batch is not None else 0
) )
if running_bs >= self.max_running_requests: if running_bs >= self.max_running_requests:
return return None
# Compute matched prefix length # Compute matched prefix length
for req in self.waiting_queue: for req in self.waiting_queue:
req.input_ids = req.origin_input_ids + req.output_ids req.input_ids = req.origin_input_ids + req.output_ids
try_match_ids = req.input_ids
if req.return_logprob:
try_match_ids = req.input_ids[: req.logprob_start_len]
# NOTE: the prefix_indices must always be aligned with last_node # NOTE: the prefix_indices must always be aligned with last_node
prefix_indices, last_node = self.tree_cache.match_prefix( req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
rid=req.rid, key=try_match_ids rid=req.rid, key=req.adjust_max_prefix_ids()
) )
req.extend_input_len = len(req.input_ids) - len(prefix_indices) req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
req.prefix_indices = prefix_indices
req.last_node = last_node
# Get priority queue # Get priority queue
self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue) self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
...@@ -394,6 +386,24 @@ class ModelTpServer: ...@@ -394,6 +386,24 @@ class ModelTpServer:
) )
for req in self.waiting_queue: for req in self.waiting_queue:
# FIXME: Move this code into adjust_max_prefix_len
if req.return_logprob and req.normalized_prompt_logprob is None:
# Need at least two tokens to compute normalized logprob
if req.extend_input_len < 2:
delta = 2 - req.extend_input_len
req.extend_input_len += delta
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
req.image_offset += delta
if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
# Need at least one token to compute logits
req.extend_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None:
req.image_offset += 1
res = adder.add_one_req(req) res = adder.add_one_req(req)
if ( if (
not res not res
...@@ -470,10 +480,20 @@ class ModelTpServer: ...@@ -470,10 +480,20 @@ class ModelTpServer:
pt = 0 pt = 0
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
if req is not self.current_inflight_req: if req is not self.current_inflight_req:
# Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1 req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i]) req.output_ids.append(next_token_ids[i])
req.check_finished() req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
if req is self.current_inflight_req:
# Inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
if req.return_logprob: if req.return_logprob:
self.add_logprob_return_values(i, req, pt, next_token_ids, output) self.add_logprob_return_values(i, req, pt, next_token_ids, output)
pt += req.extend_input_len pt += req.extend_input_len
...@@ -529,22 +549,6 @@ class ModelTpServer: ...@@ -529,22 +549,6 @@ class ModelTpServer:
) )
req.output_top_logprobs.append(output.output_top_logprobs[i]) req.output_top_logprobs.append(output.output_top_logprobs[i])
def cache_filled_batch(self, batch: ScheduleBatch):
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.input_ids),
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req.req_pool_idx,
del_in_memory_pool=False,
old_last_node=req.last_node,
)
req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
if req is self.current_inflight_req:
# inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
def forward_decode_batch(self, batch: ScheduleBatch): def forward_decode_batch(self, batch: ScheduleBatch):
# Check if decode out of memory # Check if decode out of memory
if not batch.check_decode_mem(): if not batch.check_decode_mem():
...@@ -595,6 +599,9 @@ class ModelTpServer: ...@@ -595,6 +599,9 @@ class ModelTpServer:
req.output_ids.append(next_token_id) req.output_ids.append(next_token_id)
req.check_finished() req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
if req.return_logprob: if req.return_logprob:
req.output_token_logprobs.append( req.output_token_logprobs.append(
(next_token_logprobs[i], next_token_id) (next_token_logprobs[i], next_token_id)
...@@ -614,12 +621,9 @@ class ModelTpServer: ...@@ -614,12 +621,9 @@ class ModelTpServer:
output_spaces_between_special_tokens = [] output_spaces_between_special_tokens = []
output_meta_info = [] output_meta_info = []
output_finished_reason: List[BaseFinishReason] = [] output_finished_reason: List[BaseFinishReason] = []
finished_indices = []
unfinished_indices = [] unfinished_indices = []
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
if req.finished(): if not req.finished() and req is not self.current_inflight_req:
finished_indices.append(i)
else:
unfinished_indices.append(i) unfinished_indices.append(i)
if req.finished() or ( if req.finished() or (
...@@ -683,34 +687,7 @@ class ModelTpServer: ...@@ -683,34 +687,7 @@ class ModelTpServer:
) )
) )
# Remove finished reqs # Remove finished reqs: update batch tensors
if finished_indices:
# Update radix cache
for i in finished_indices:
req = batch.reqs[i]
self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req.req_pool_idx,
)
self.tree_cache.dec_lock_ref(req.last_node)
# Update batch tensors
if unfinished_indices:
batch.filter_batch(unfinished_indices)
else:
batch.reqs = []
def filter_out_inflight(self, batch: ScheduleBatch):
# TODO(lsyin): reduce the overhead, make a special version for this
if self.current_inflight_req is None:
return
to_remove = batch.reqs.index(self.current_inflight_req)
unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
batch.filter_batch(unfinished_indices) batch.filter_batch(unfinished_indices)
def flush_cache(self): def flush_cache(self):
......
...@@ -17,7 +17,11 @@ class BasePrefixCache(ABC): ...@@ -17,7 +17,11 @@ class BasePrefixCache(ABC):
pass pass
@abstractmethod @abstractmethod
def cache_req(self, **kwargs): def cache_finished_req(self, **kwargs):
pass
@abstractmethod
def cache_unfinished_req(self, **kwargs):
pass pass
@abstractmethod @abstractmethod
......
"""Cache for chunked prefill, used when RadixCache is disabled.""" """Cache for chunked prefill, used when RadixCache is disabled."""
from sglang.srt.mem_cache.base_cache import BasePrefixCache from typing import TYPE_CHECKING
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class ChunkCacheEntry: class ChunkCacheEntry:
...@@ -27,22 +32,31 @@ class ChunkCache(BasePrefixCache): ...@@ -27,22 +32,31 @@ class ChunkCache(BasePrefixCache):
entry = self.entries[rid] entry = self.entries[rid]
return entry.value, entry return entry.value, entry
def cache_req( def cache_finished_req(self, req: "Req", token_ids=None):
self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs if token_ids is None:
): token_ids = (req.input_ids + req.output_ids)[:-1]
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
if del_in_memory_pool:
assert rid in self.entries
self.req_to_token_pool.free(req_pool_idx)
self.token_to_kv_pool.free(indices)
return
if rid not in self.entries: kv_indices = self.req_to_token_pool.req_to_token[
self.entries[rid] = ChunkCacheEntry(rid, indices) req.req_pool_idx, : len(token_ids)
]
assert req.rid in self.entries
self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool.free(kv_indices)
entry = self.entries[rid] def cache_unfinished_req(self, req: "Req", token_ids=None):
entry.value = indices if token_ids is None:
return indices, entry token_ids = req.input_ids
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if req.rid not in self.entries:
self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
entry = self.entries[req.rid]
entry.value = kv_indices
return kv_indices, entry
def insert(self): def insert(self):
raise NotImplementedError raise NotImplementedError
......
...@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache. ...@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache.
import heapq import heapq
import time import time
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING
import torch import torch
from sglang.srt.mem_cache.base_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class TreeNode: class TreeNode:
...@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache): ...@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache):
value = [x for x in key] value = [x for x in key]
return self._insert_helper(self.root_node, key, value) return self._insert_helper(self.root_node, key, value)
def cache_req( def cache_finished_req(self, req: "Req", token_ids=None):
self, """Cache request when it finishes."""
token_ids, if token_ids is None:
last_uncached_pos, token_ids = (req.input_ids + req.output_ids)[:-1]
req_pool_idx, kv_indices = self.req_to_token_pool.req_to_token[
del_in_memory_pool=True, req.req_pool_idx, : len(token_ids)
old_last_node=None, ]
**kwargs,
):
# Insert the request into radix cache
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
new_prefix_len = self.insert(token_ids, indices.clone())
if self.disable: if self.disable:
if del_in_memory_pool: self.token_to_kv_pool.free(kv_indices)
self.token_to_kv_pool.free(indices) self.req_to_token_pool.free(req.req_pool_idx)
else: return
return torch.tensor([], dtype=torch.int32), self.root_node
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len]) new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
if del_in_memory_pool: # Remove req slot release the cache lock
self.req_to_token_pool.free(req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
else: self.dec_lock_ref(req.last_node)
cached_indices, new_last_node = self.match_prefix(token_ids)
assert len(cached_indices) == len(token_ids) def cache_unfinished_req(self, req: "Req", token_ids=None):
"""Cache request when it is unfinished."""
self.req_to_token_pool.req_to_token[ if self.disable:
req_pool_idx, last_uncached_pos : len(cached_indices) return
] = cached_indices[last_uncached_pos:]
self.dec_lock_ref(old_last_node) if token_ids is None:
self.inc_lock_ref(new_last_node) token_ids = req.input_ids
return cached_indices, new_last_node
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
# The prefix indices could be updated, reuse it
new_indices, new_last_node = self.match_prefix(token_ids)
assert len(new_indices) == len(token_ids)
self.req_to_token_pool.req_to_token[
req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
] = new_indices[len(req.prefix_indices) :]
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
req.prefix_indices = new_indices
req.last_node = new_last_node
def pretty_print(self): def pretty_print(self):
self._print_helper(self.root_node, 0) self._print_helper(self.root_node, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment