Unverified Commit f3764c26 authored by cctry, committed by GitHub

Clean match_prefix and prepare_for_extend for mem cache V2 (#11200)

parent 7ba3de0e
@@ -204,7 +204,6 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
             origin_input_ids=tmp_input_ids,
             sampling_params=sampling_params,
         )
-        req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         req.logprob_start_len = len(req.origin_input_ids) - 1
@@ -248,7 +247,6 @@ def prepare_synthetic_inputs_for_latency_test(
             origin_input_ids=list(input_ids[i]),
             sampling_params=sampling_params,
         )
-        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         req.logprob_start_len = len(req.origin_input_ids) - 1
...
@@ -539,7 +539,7 @@ class Req:
         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices: torch.Tensor = []
+        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
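Switching the default from a Python list to a typed empty tensor keeps every consumer of `prefix_indices` on a single type: `len()`, tail slicing, and `torch.cat` behave uniformly whether or not a prefix was matched. A minimal sketch of the invariant this buys (values invented, not from the commit):

```python
import torch

prefix_indices = torch.empty((0,), dtype=torch.int64)  # the new "no prefix" default
assert len(prefix_indices) == 0          # same len() contract as the old []
assert prefix_indices[-1:].numel() == 0  # tail slice stays safe for last-location logic
# Concatenating with real cache indices needs no list/tensor special-casing:
merged = torch.cat([prefix_indices, torch.tensor([3, 4], dtype=torch.int64)])
assert merged.tolist() == [3, 4]
```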
@@ -691,11 +691,16 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None

-    def init_next_round_input(
-        self,
-        tree_cache: Optional[BasePrefixCache] = None,
-    ):
+    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
+        input_len = len(self.fill_ids)
+        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
+        max_prefix_len = input_len - 1
+        if self.return_logprob:
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+        max_prefix_len = max(max_prefix_len, 0)
+        token_ids = self.fill_ids[:max_prefix_len]
+
         if tree_cache is not None:
             (
                 self.prefix_indices,
@@ -703,31 +708,11 @@ class Req:
                 self.last_host_node,
                 self.host_hit_length,
             ) = tree_cache.match_prefix(
-                key=RadixKey(
-                    token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
-                ),
+                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
             )
             self.last_matched_prefix_len = len(self.prefix_indices)
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

-    def adjust_max_prefix_ids(self):
-        self.fill_ids = self.origin_input_ids + self.output_ids
-        input_len = len(self.fill_ids)
-
-        # FIXME: To work around some bugs in logprob computation, we need to ensure each
-        # request has at least one token. Later, we can relax this requirement and use `input_len`.
-        max_prefix_len = input_len - 1
-
-        if self.sampling_params.max_new_tokens > 0:
-            # Need at least one token to compute logits
-            max_prefix_len = min(max_prefix_len, input_len - 1)
-
-        if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
-        max_prefix_len = max(max_prefix_len, 0)
-        return self.fill_ids[:max_prefix_len]
-
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
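With `adjust_max_prefix_ids` deleted, the prefix cap is computed inline in `init_next_round_input`: at most `input_len - 1` tokens may be served from cache (so at least one token is left to produce logits), and a logprob request further caps the match at `logprob_start_len`. A worked example with invented numbers:

```python
# Hypothetical request: 8 prompt tokens + 2 generated, logprobs wanted from index 5.
fill_ids = list(range(10))          # origin_input_ids + output_ids
return_logprob, logprob_start_len = True, 5

max_prefix_len = len(fill_ids) - 1  # keep at least one token for logit computation
if return_logprob:
    max_prefix_len = min(max_prefix_len, logprob_start_len)
max_prefix_len = max(max_prefix_len, 0)

token_ids = fill_ids[:max_prefix_len]  # the key handed to tree_cache.match_prefix
assert token_ids == [0, 1, 2, 3, 4]
```

Note that the old helper's `max_new_tokens > 0` branch was a no-op (it re-applied the same `input_len - 1` bound), which is why the inlined version drops it.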
@@ -808,7 +793,7 @@ class Req:
             return

     def reset_for_retract(self):
-        self.prefix_indices = []
+        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.last_node = None
         self.swa_uuid_for_lock = None
         self.extend_input_len = 0
@@ -1124,6 +1109,47 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             return out_cache_loc

+    def write_cache_indices(
+        self,
+        req_pool_indices: List[int],
+        prefix_lens: List[int],
+        seq_lens: List[int],
+        extend_lens: List[int],
+        out_cache_loc: torch.Tensor,
+        req_pool_indices_tensor: torch.Tensor,
+        prefix_lens_tensor: torch.Tensor,
+        seq_lens_tensor: torch.Tensor,
+        extend_lens_tensor: torch.Tensor,
+        prefix_tensors: list[torch.Tensor],
+    ):
+        if support_triton(global_server_args_dict.get("attention_backend")):
+            prefix_pointers = torch.tensor(
+                [t.data_ptr() for t in prefix_tensors], device=self.device
+            )
+            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+            write_req_to_token_pool_triton[(len(req_pool_indices),)](
+                self.req_to_token_pool.req_to_token,
+                req_pool_indices_tensor,
+                prefix_pointers,
+                prefix_lens_tensor,
+                seq_lens_tensor,
+                extend_lens_tensor,
+                out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(len(req_pool_indices)):
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(0, prefix_lens[i])),
+                    prefix_tensors[i],
+                )
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
+                    out_cache_loc[pt : pt + extend_lens[i]],
+                )
+                pt += extend_lens[i]
+
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
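The new `write_cache_indices` keeps one Triton launch for the whole batch even though each request's prefix tensor has a different length: the host builds a table of raw device addresses, and the kernel reinterprets each entry as an `int64` pointer. A standalone sketch of the host-side trick (assumes a CUDA device; names illustrative):

```python
import torch

# Three requests with prefix lengths 3, 0, and 5.
prefix_tensors = [
    torch.arange(n, dtype=torch.int64, device="cuda") for n in (3, 0, 5)
]
# One integer address per request. The source tensors must stay alive until
# the kernel that dereferences these addresses has finished executing.
prefix_pointers = torch.tensor(
    [t.data_ptr() for t in prefix_tensors], device="cuda"
)
```

The non-Triton fallback avoids pointer arithmetic entirely and issues two `req_to_token_pool.write` calls per request: one for the cached prefix, one for the freshly allocated slots.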
@@ -1201,10 +1227,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND

-        # Allocate req slots
-        bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
-
         # Init tensors
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -1218,9 +1240,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             r.token_type_ids for r in reqs if r.token_type_ids is not None
         ]

-        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
         input_ids_tensor = torch.tensor(
             list(chain.from_iterable(input_ids)), dtype=torch.int64
         ).to(self.device, non_blocking=True)
@@ -1244,7 +1263,49 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

-        # Copy prefix and do some basic check
+        # Allocate req slots
+        bs = len(self.reqs)
+        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
+        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        else:
+            last_loc = [
+                (
+                    r.prefix_indices[-1:]
+                    if len(r.prefix_indices) > 0
+                    else torch.tensor([-1], device=self.device)
+                )
+                for r in self.reqs
+            ]
+            out_cache_loc = self.alloc_paged_token_slots_extend(
+                prefix_lens_tensor,
+                prefix_lens_cpu_tensor,
+                seq_lens_tensor,
+                seq_lens_cpu,
+                torch.cat(last_loc),
+                extend_num_tokens,
+            )
+
+        # Write allocated tokens to req_to_token_pool
+        self.write_cache_indices(
+            req_pool_indices,
+            prefix_lens,
+            seq_lens,
+            extend_lens,
+            out_cache_loc,
+            req_pool_indices_tensor,
+            prefix_lens_tensor,
+            seq_lens_tensor,
+            extend_lens_tensor,
+            [r.prefix_indices for r in reqs],
+        )
+
+        # Set fields
         input_embeds = []
         extend_input_logprob_token_ids = []
         multimodal_inputs = []
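Because allocation now happens before anything is written into `req_to_token_pool`, the last allocated location of each request can no longer be read back from the pool via `get_last_loc`; it is derived from each request's own `prefix_indices` instead, with `-1` marking an empty prefix. A worked sketch of the per-request expression (tensors invented):

```python
import torch

cached = torch.tensor([7, 8, 9], dtype=torch.int64)  # request with a matched prefix
fresh = torch.empty((0,), dtype=torch.int64)         # request with no prefix

last_loc = [t[-1:] if len(t) > 0 else torch.tensor([-1]) for t in (cached, fresh)]
assert torch.cat(last_loc).tolist() == [9, -1]
```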
@@ -1254,9 +1315,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.write(
-                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
-                )
                 if isinstance(self.tree_cache, SWAChunkCache):
                     self.tree_cache.evict_swa(
                         req, pre_len, self.model_config.attention_chunk_size
@@ -1351,25 +1409,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             extend_input_logprob_token_ids = None

-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-        else:
-            last_loc = get_last_loc(
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-            )
-            out_cache_loc = self.alloc_paged_token_slots_extend(
-                prefix_lens_tensor,
-                prefix_lens_cpu_tensor,
-                seq_lens_tensor,
-                seq_lens_cpu,
-                last_loc,
-                extend_num_tokens,
-            )
-
-        # Set fields
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
@@ -1402,28 +1441,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

-        # Write to req_to_token_pool
-        if support_triton(global_server_args_dict.get("attention_backend")):
-            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-            write_req_to_token_pool_triton[(bs,)](
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-                seq_lens_tensor,
-                extend_lens_tensor,
-                out_cache_loc,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
-        else:
-            pt = 0
-            for i in range(bs):
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                    out_cache_loc[pt : pt + extend_lens[i]],
-                )
-                pt += extend_lens[i]
-
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
@@ -2024,6 +2041,7 @@ class ModelWorkerBatch:
 def write_req_to_token_pool_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices,
+    prefix_tensors,
     pre_lens,
     seq_lens,
     extend_lens,
@@ -2036,6 +2054,19 @@ def write_req_to_token_pool_triton(
     req_pool_index = tl.load(req_pool_indices + pid)
     pre_len = tl.load(pre_lens + pid)
     seq_len = tl.load(seq_lens + pid)
+    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
+
+    # write prefix
+    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < pre_len
+        value = tl.load(prefix_tensor + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
+            value,
+            mask=mask,
+        )

     # NOTE: This can be slow for large bs
     cumsum_start = tl.cast(0, tl.int64)
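On the device side, each program id loads its request's address from the pointer table, casts it back to a typed pointer with `tl.pointer_type`, and copies the prefix in `BLOCK_SIZE` chunks. A self-contained kernel in the same style (simplified from the diff; `gather_rows` and its arguments are illustrative, not part of the commit):

```python
import triton
import triton.language as tl


@triton.jit
def gather_rows(ptr_table, lens, out_ptr, out_stride, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    row_len = tl.load(lens + pid)
    # Reinterpret the stored integer address as an int64 pointer for this row.
    row_ptr = tl.load(ptr_table + pid).to(tl.pointer_type(tl.int64))
    for i in range(tl.cdiv(row_len, BLOCK_SIZE)):
        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offset < row_len
        value = tl.load(row_ptr + offset, mask=mask)
        tl.store(out_ptr + pid * out_stride + offset, value, mask=mask)
```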
...
@@ -174,7 +174,7 @@ class SchedulePolicy:
         self.waiting_queue_radix_tree.reset()

         for r in waiting_queue:
-            prefix_ids = r.adjust_max_prefix_ids()
+            prefix_ids = r.origin_input_ids + r.output_ids
             extra_key = r.extra_key

             # NOTE: the prefix_indices must always be aligned with last_node
...
@@ -60,7 +60,7 @@ class ChunkCache(BasePrefixCache):
         ]

         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
-        req.prefix_indices = kv_indices
+        req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)

     def evict(self, num_tokens: int):
         pass
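The `copy=True` here is presumably defensive: `kv_indices` can be a view into pool-owned storage, and `req.prefix_indices` is consumed later (in `PrefillAdder::add_chunked_req`), so an owned int64 copy cannot change underneath the request. A sketch of the aliasing hazard being avoided (buffers invented):

```python
import torch

pool = torch.arange(8, dtype=torch.int32)
view = pool[2:5]                               # aliases the pool's storage
owned = view.to(dtype=torch.int64, copy=True)  # detached copy, canonical dtype

pool[2] = -1                                   # later mutation of the pool
assert view[0] == -1                           # the view changed underneath us
assert owned[0] == 2                           # the copy is unaffected
```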
...
@@ -90,7 +90,6 @@ class TestForwardSplitPrefill(CustomTestCase):
                 origin_input_ids=list(input_ids[i]),
                 sampling_params=sampling_params,
             )
-            req.prefix_indices = []
             req.fill_ids = req.origin_input_ids
             req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
             req.logprob_start_len = len(req.origin_input_ids) - 1
...