Unverified Commit ab4b5606 authored by Byron Hsu, committed by GitHub

[PD] Support page size > 1 (#5561)

parent 20f1c8e3
@@ -35,6 +35,7 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
+    kv_to_page_indices,
     poll_and_all_reduce,
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -205,7 +206,10 @@ class DecodePreallocQueue:
             self.req_to_metadata_buffer_idx_allocator.alloc()
         )
         assert decode_req.metadata_buffer_index is not None
-        decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
+        page_indices = kv_to_page_indices(
+            kv_indices, self.token_to_kv_pool_allocator.page_size
+        )
+        decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
         preallocated_reqs.append(decode_req)
         indices_to_remove.add(i)
@@ -245,10 +249,30 @@ class DecodePreallocQueue:
         assert req_pool_indices is not None
         req.req_pool_idx = req_pool_indices[0]
-        kv_loc = self.token_to_kv_pool_allocator.alloc(
-            len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
-        )
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            kv_loc = self.token_to_kv_pool_allocator.alloc(
+                len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
+            )
+        else:
+            num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
+            kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
+                prefix_lens=torch.tensor(
+                    [0],
+                    dtype=torch.int64,
+                    device=self.token_to_kv_pool_allocator.device,
+                ),
+                seq_lens=torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int64,
+                    device=self.token_to_kv_pool_allocator.device,
+                ),
+                last_loc=torch.tensor(
+                    [-1],
+                    dtype=torch.int64,
+                    device=self.token_to_kv_pool_allocator.device,
+                ),
+                extend_num_tokens=num_tokens,
+            )
         assert kv_loc is not None
         self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
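When page_size > 1, the decode side can no longer hand out KV slots one token at a time, so the branch above switches to alloc_extend with an empty prefix (prefix_lens=[0], last_loc=[-1]) and lets the allocator lay the request out page by page. A minimal standalone sketch of the resulting page-count arithmetic, using a hypothetical pages_needed helper that is not part of this patch:

# Hypothetical helper (not from the patch): how many whole KV pages a fresh
# decode request occupies once its tokens are laid out page by page.
def pages_needed(num_input_ids: int, num_output_ids: int, page_size: int) -> int:
    # Same token count that the allocation above computes.
    num_tokens = num_input_ids + max(num_output_ids - 1, 0)
    # Ceiling division: the last page may be only partially filled.
    return -(-num_tokens // page_size)

# Example: 13 prompt tokens and 1 generated token -> 13 KV slots -> 4 pages of size 4.
print(pages_needed(num_input_ids=13, num_output_ids=1, page_size=4))  # 4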
@@ -31,6 +31,8 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
+    kv_to_page_indices,
+    kv_to_page_num,
     poll_and_all_reduce,
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
@@ -154,7 +156,8 @@ class PrefillBootstrapQueue:
             self.req_to_metadata_buffer_idx_allocator.alloc()
         )
         assert req.metadata_buffer_index is not None
-        req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
+        num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
+        req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
         bootstrapped_reqs.append(req)
         indices_to_remove.add(i)
@@ -300,4 +303,7 @@ class SchedulerDisaggregationPrefillMixin:
             req.metadata_buffer_index, token_id
         )
         is_last = token_id is not None
-        req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last)
+        page_indices = kv_to_page_indices(
+            kv_indices, self.token_to_kv_pool_allocator.page_size
+        )
+        req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
@@ -4,6 +4,7 @@ from collections import deque
 from enum import Enum
 from typing import List

+import numpy as np
 import torch
 import torch.distributed as dist
@@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         }
         return class_mapping.get(class_type)
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
+
+
+def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
+    # 1. Every page is guaranteed to be full except the last page.
+    # 2. page index = kv_index // page_size
+    # The returned vector is kv_indices[::page_size] // page_size
+    if page_size == 1:  # shortcut
+        return kv_indices
+    return kv_indices[::page_size] // page_size
+
+
+def kv_to_page_num(num_kv_indices: int, page_size: int):
+    # ceil(num_kv_indices / page_size)
+    return (num_kv_indices + page_size - 1) // page_size
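The two helpers above rely on the invariant stated in the comment: pages are contiguous and every page is full except possibly the last, so the first KV index of each page is enough to recover the page index. A small self-contained illustration (the functions are restated here only so the snippet runs on its own):

import numpy as np

def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
    if page_size == 1:
        return kv_indices
    # Take the first KV index of each page, then map it to its page index.
    return kv_indices[::page_size] // page_size

def kv_to_page_num(num_kv_indices: int, page_size: int):
    # ceil(num_kv_indices / page_size) without floating point
    return (num_kv_indices + page_size - 1) // page_size

# 10 tokens stored in pages of size 4: KV slots 8..17 live on pages 2, 3, 4.
kv_indices = np.arange(8, 18)
print(kv_to_page_indices(kv_indices, page_size=4))   # [2 3 4]
print(kv_to_page_num(len(kv_indices), page_size=4))  # 3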
@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
             self.get_key_buffer(i).nbytes for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+            self.get_key_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ] + [
+            self.get_value_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
         # Todo: different memory layout
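Multiplying by self.page_size makes each transfer item cover a whole page of KV data instead of a single token's slice, which matches the page-granular indices exchanged by the sender and receiver above. A rough back-of-the-envelope check with made-up sizes (head_num, head_dim, dtype, and page_size here are illustrative, not values from this diff):

# Illustrative numbers only: per-token KV slice vs. per-page transfer item.
head_num, head_dim, dtype_bytes, page_size = 8, 128, 2, 16
per_token_bytes = head_num * head_dim * dtype_bytes   # analogous to get_key_buffer(i)[0].nbytes
per_page_bytes = per_token_bytes * page_size          # analogous to the new kv_item_lens entry
print(per_token_bytes, per_page_bytes)  # 2048 32768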