Unverified Commit ab4b5606 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[PD] Support page size > 1 (#5561)

parent 20f1c8e3
......@@ -35,6 +35,7 @@ from sglang.srt.disaggregation.utils import (
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
kv_to_page_indices,
poll_and_all_reduce,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
......@@ -205,7 +206,10 @@ class DecodePreallocQueue:
self.req_to_metadata_buffer_idx_allocator.alloc()
)
assert decode_req.metadata_buffer_index is not None
decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
page_indices = kv_to_page_indices(
kv_indices, self.token_to_kv_pool_allocator.page_size
)
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
preallocated_reqs.append(decode_req)
indices_to_remove.add(i)
......@@ -245,10 +249,30 @@ class DecodePreallocQueue:
assert req_pool_indices is not None
req.req_pool_idx = req_pool_indices[0]
kv_loc = self.token_to_kv_pool_allocator.alloc(
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
)
if self.token_to_kv_pool_allocator.page_size == 1:
kv_loc = self.token_to_kv_pool_allocator.alloc(
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
)
else:
num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
prefix_lens=torch.tensor(
[0],
dtype=torch.int64,
device=self.token_to_kv_pool_allocator.device,
),
seq_lens=torch.tensor(
[num_tokens],
dtype=torch.int64,
device=self.token_to_kv_pool_allocator.device,
),
last_loc=torch.tensor(
[-1],
dtype=torch.int64,
device=self.token_to_kv_pool_allocator.device,
),
extend_num_tokens=num_tokens,
)
assert kv_loc is not None
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
......
......@@ -31,6 +31,8 @@ from sglang.srt.disaggregation.utils import (
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
kv_to_page_indices,
kv_to_page_num,
poll_and_all_reduce,
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
......@@ -154,7 +156,8 @@ class PrefillBootstrapQueue:
self.req_to_metadata_buffer_idx_allocator.alloc()
)
assert req.metadata_buffer_index is not None
req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
bootstrapped_reqs.append(req)
indices_to_remove.add(i)
......@@ -300,4 +303,7 @@ class SchedulerDisaggregationPrefillMixin:
req.metadata_buffer_index, token_id
)
is_last = token_id is not None
req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last)
page_indices = kv_to_page_indices(
kv_indices, self.token_to_kv_pool_allocator.page_size
)
req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
......@@ -4,6 +4,7 @@ from collections import deque
from enum import Enum
from typing import List
import numpy as np
import torch
import torch.distributed as dist
......@@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
}
return class_mapping.get(class_type)
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
# 1. The page is guaruanteed to be full except the last page.
# 2. page index = kv_index // page_size
# The return vector is kv_indices[::page_size] // page_size
if page_size == 1: # shortcut
return kv_indices
return kv_indices[::page_size] // page_size
def kv_to_page_num(num_kv_indices: int, page_size: int):
# ceil(num_kv_indices / page_size)
return (num_kv_indices + page_size - 1) // page_size
......@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
self.get_key_buffer(i).nbytes for i in range(self.layer_num)
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
kv_item_lens = [
self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
self.get_key_buffer(i)[0].nbytes * self.page_size
for i in range(self.layer_num)
] + [
self.get_value_buffer(i)[0].nbytes * self.page_size
for i in range(self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
# Todo: different memory layout
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment