Unverified Commit c951d312 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[PD] Fix large page size + chunk prefill (#5588)

parent dcb82325
......@@ -231,7 +231,7 @@ class MooncakeKVManager(BaseKVManager):
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
assert len(chunked_dst_kv_indice) == len(
kv_chunk.prefill_kv_indices
)
), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
ret = self.send_kvcache(
req.mooncake_session_id,
......
......@@ -306,4 +306,10 @@ class SchedulerDisaggregationPrefillMixin:
page_indices = kv_to_page_indices(
kv_indices, self.token_to_kv_pool_allocator.page_size
)
req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
page_start_idx = start_idx // self.token_to_kv_pool_allocator.page_size
page_end_idx = page_start_idx + len(page_indices)
req.disagg_kv_sender.send(
page_indices, slice(page_start_idx, page_end_idx), is_last
)
......@@ -76,13 +76,22 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int, is_last: bool = True):
# 1. The page is guaruanteed to be full except the last page.
# 2. page index = kv_index // page_size
# The return vector is kv_indices[::page_size] // page_size
if page_size == 1: # shortcut
return kv_indices
return kv_indices[::page_size] // page_size
# if last chunk, send the last partial page
# if not last chunk, delay the last partial page to the next send
if is_last:
return kv_indices[::page_size] // page_size
else:
if len(kv_indices) % page_size == 0: # no partial page
return kv_indices[::page_size] // page_size
else: # partial page
return kv_indices[::page_size][:-1] // page_size
def kv_to_page_num(num_kv_indices: int, page_size: int):
......
......@@ -446,13 +446,16 @@ class MLATokenToKVPool(KVCache):
]
self.layer_transfer_counter = None
self.page_size = page_size
# for disagg
def get_contiguous_buf_infos(self):
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
kv_item_lens = [
self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def get_key_buffer(self, layer_id: int):
......
prompt = "Hello " * 16000
import json
import requests
response = requests.post(
"http://0.0.0.0:8000/generate",
json={"text": prompt, "sampling_params": {"temperature": 0}},
)
print("Response content (raw):", response.content)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment