Unverified commit c2b16795 authored by Byron Hsu, committed by GitHub

Add decode req pool (#6980)

parent f6ebba53
@@ -25,7 +25,7 @@
 import os
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 import numpy as np
 import torch
@@ -49,6 +49,7 @@
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

 logger = logging.getLogger(__name__)

@@ -57,6 +58,67 @@
 if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import Scheduler

+class DecodeReqToTokenPool:
+    """
+    The difference between DecodeReqToTokenPool and ReqToTokenPool is that
+    DecodeReqToTokenPool reserves memory for pre-allocated requests.
+
+    In ReqToTokenPool, if `--max-running-requests` is 8, then
+    #pre-allocated + #transfer + #running <= 8, even though there is in fact
+    enough free memory to hold more pre-allocated requests.
+
+    In DecodeReqToTokenPool, if `--max-running-requests` is 8, then
+    #running <= 8 and #pre-allocated + #transfer <= pre_alloc_size, so we can
+    use the free memory to pre-allocate requests and unblock prefill.
+    """
+
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        pre_alloc_size: int,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        self.size = size
+        self.max_context_len = max_context_len
+        self.device = device
+        self.pre_alloc_size = pre_alloc_size
+
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size + pre_alloc_size, max_context_len),
+                dtype=torch.int32,
+                device=device,
+            )
+
+        self.free_slots = list(range(size + pre_alloc_size))
+
+    def write(self, indices, values):
+        self.req_to_token[indices] = values
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def alloc(self, need_size: int) -> Optional[List[int]]:
+        if need_size > len(self.free_slots):
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+        return select_index
+
+    def free(self, free_index: Union[int, List[int]]):
+        if isinstance(free_index, int):
+            self.free_slots.append(free_index)
+        else:
+            self.free_slots.extend(free_index)
+
+    def clear(self):
+        self.free_slots = list(range(self.size + self.pre_alloc_size))
+
+
 @dataclass
 class DecodeRequest:
     req: Req
...
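The slot accounting described in the docstring is easiest to see with concrete numbers. Below is a minimal, dependency-free sketch of the same free-list logic (the torch tensor and TorchMemorySaverAdapter are omitted); the class name FreeListPool and all sizes are illustrative, not part of the commit.

from typing import List, Optional, Union

class FreeListPool:
    def __init__(self, size: int, pre_alloc_size: int):
        # `size` models --max-running-requests; `pre_alloc_size` models the
        # extra headroom for pre-allocated/transferring requests.
        self.size = size
        self.pre_alloc_size = pre_alloc_size
        # Running and pre-allocated requests draw from the same free list.
        self.free_slots = list(range(size + pre_alloc_size))

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None  # not enough slots; the caller must retry later
        selected = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return selected

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

pool = FreeListPool(size=8, pre_alloc_size=16)  # 8 running + 16 pre-allocated
running = pool.alloc(8)        # fills all "running" capacity
prealloc = pool.alloc(16)      # still succeeds: headroom unblocks prefill
assert pool.alloc(1) is None   # all 24 slots are now in use
pool.free(running)             # finished requests return their slots
assert len(pool.free_slots) == 8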
@@ -916,12 +916,26 @@ class ModelRunner:
         )

         if self.req_to_token_pool is None:
-            self.req_to_token_pool = ReqToTokenPool(
-                size=max_num_reqs,
-                max_context_len=self.model_config.context_len + 4,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-            )
+            if self.server_args.disaggregation_mode == "decode":
+                from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
+
+                # Reserve memory for pre-allocated requests:
+                # if max_num_reqs <= 32, we pre-allocate 2x requests.
+                pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
+
+                self.req_to_token_pool = DecodeReqToTokenPool(
+                    size=max_num_reqs,
+                    max_context_len=self.model_config.context_len + 4,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    pre_alloc_size=pre_alloc_size,
+                )
+            else:
+                self.req_to_token_pool = ReqToTokenPool(
+                    size=max_num_reqs,
+                    max_context_len=self.model_config.context_len + 4,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                )
         else:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
...
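As a back-of-the-envelope check on the pre_alloc_size heuristic above: the only extra cost is more rows in the int32 req_to_token table, which stays small. The context length and request count below are illustrative assumptions, not values from the commit.

# Cost of the enlarged req_to_token table under assumed settings.
max_num_reqs = 32                        # assume --max-running-requests == 32
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0   # -> 64
context_len = 8192                       # assumed model context length
rows = max_num_reqs + pre_alloc_size     # 96 rows instead of 32
row_bytes = (context_len + 4) * 4        # int32 entries, same +4 slack as above
print(f"req_to_token table: {rows * row_bytes / 2**20:.1f} MiB")  # ~3.0 MiB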