Unverified commit c7c7dbeb authored by Byron Hsu, committed by GitHub

[PD] Release initial code (#4654)


Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: makro
Co-authored-by: dhou-xai
parent 417fc72f
from __future__ import annotations
import logging
from enum import Enum
from typing import Optional
import numpy as np
import numpy.typing as npt
logger = logging.getLogger(__name__)
class KVArgs:
engine_rank: int
kv_data_ptrs: list[int]
kv_data_lens: list[int]
kv_item_lens: list[int]
aux_data_ptrs: list[int]
aux_data_lens: list[int]
aux_item_lens: list[int]
ib_device: str
class KVManager:
def __init__(self, args: KVArgs): ...
class KVPoll:
Failed = 0
Bootstrapping = 1
WaitingForInput = 2
Transferring = 3
Success = 4
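# Note: the numeric values are ordered so that a MIN reduction across TP ranks
# (see poll_and_all_reduce in sglang.srt.disaggregation.utils) yields the most
# conservative state. The typical progression is Bootstrapping -> WaitingForInput
# -> Transferring -> Success, with Failed (0) as the error state.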
class KVSender:
def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
self.has_sent = False
def init(self, num_kv_indices: int, aux_index: Optional[int] = None): ...
def send(self, kv_indices: npt.NDArray[np.int32]):
self.has_sent = True
def poll(self) -> KVPoll:
if not self.has_sent:
# Assume handshake completed instantly
return KVPoll.WaitingForInput
else:
# Assume transfer completed instantly
return KVPoll.Success
def failure_exception(self):
raise Exception("Fake KVSender Exception")
class KVReceiver:
def __init__(
self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
):
self.has_init = False
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
self.has_init = True
def poll(self) -> KVPoll:
if not self.has_init:
# Assume handshake completed instantly
return KVPoll.WaitingForInput
else:
# Assume transfer completed instantly
return KVPoll.Success
def failure_exception(self):
raise Exception("Fake KVReceiver Exception")
class KVBootstrapServer:
def __init__(self, port: int): ...
def poll(self) -> KVPoll: ...
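# Minimal usage sketch of the mock API above. The bootstrap address, room id,
# and index values are arbitrary example values; a real backend would pass
# through the Bootstrapping/Transferring states instead of completing instantly.
if __name__ == "__main__":
    mgr = KVManager(KVArgs())
    sender = KVSender(mgr, bootstrap_addr="127.0.0.1:8998", bootstrap_room=0)
    sender.init(num_kv_indices=4, aux_index=0)
    assert sender.poll() == KVPoll.WaitingForInput  # mock handshake finishes instantly
    sender.send(np.arange(4, dtype=np.int32))
    assert sender.poll() == KVPoll.Success  # mock transfer finishes instantly
    receiver = KVReceiver(mgr, bootstrap_addr="127.0.0.1:8998", bootstrap_room=0)
    assert receiver.poll() == KVPoll.WaitingForInput
    receiver.init(np.arange(4, dtype=np.int32), aux_index=0)
    assert receiver.poll() == KVPoll.Success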
"""
Life cycle of a request in the decode server
1. PreallocQueue:
a. Initialize a receiver for each request
b. The request handshakes first; kv cache is pre-allocated once enough kv space is available.
c. Move the request to TransferQueue.
2. TransferQueue:
a. Poll the receiver to check the transfer state
b. Once the transfer has finished, move the request to the waiting queue
3. WaitingQueue:
a. Use the requests in the queue to construct a PrebuiltExtendBatch
b. Skip the prefill forward and only populate metadata
4. RunningBatch:
a. Merge the resolved PrebuiltExtendBatch into the running batch to run decoding
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
from torch.distributed import ProcessGroup
from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
from sglang.srt.disaggregation.utils import (
ReqToMetadataIdxAllocator,
poll_and_all_reduce,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.server_args import ServerArgs
@dataclass
class DecodeRequest:
req: Req
kv_receiver: KVReceiver
waiting_for_input: bool = False
metadata_buffer_index: int = -1
class DecodePreallocQueue:
"""
Store the requests that are preallocating.
"""
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: List[torch.Tensor],
aux_dtype: torch.dtype,
scheduler: Scheduler,
transfer_queue: DecodeTransferQueue,
tree_cache: BasePrefixCache,
gloo_group: ProcessGroup,
tp_rank: int,
tp_size: int,
bootstrap_port: int,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
self.aux_dtype = aux_dtype
self.metadata_buffers = metadata_buffers
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
self.scheduler = scheduler
self.transfer_queue = transfer_queue
self.tree_cache = tree_cache # this is always a chunk cache
self.gloo_group = gloo_group
self.tp_rank = tp_rank
self.tp_size = tp_size
self.bootstrap_port = bootstrap_port
self.num_reserved_decode_tokens = 512
# Queue for requests pending pre-allocation
self.queue: List[DecodeRequest] = []
self.kv_manager = self._init_kv_manager()
def _init_kv_manager(self) -> KVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
kv_args.kv_data_ptrs = kv_data_ptrs
kv_args.kv_data_lens = kv_data_lens
kv_args.kv_item_lens = kv_item_lens
kv_args.aux_data_ptrs = [
output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
]
kv_args.aux_data_lens = [
metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.aux_item_lens = [
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.ib_device = "mock-ib-device"
kv_manager = KVManager(kv_args)
return kv_manager
def add(self, req: Req) -> None:
"""Add a request to the pending queue."""
kv_receiver = KVReceiver(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
)
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
def extend(self, reqs: List[Req]) -> None:
"""Add a request to the pending queue."""
for req in reqs:
self.add(req)
def _update_handshake_waiters(self) -> None:
if not self.queue:
return
if all(decode_req.waiting_for_input for decode_req in self.queue):
return
polls = poll_and_all_reduce(
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
)
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if poll == KVPoll.Bootstrapping:
pass
elif poll == KVPoll.WaitingForInput:
decode_req.waiting_for_input = True
elif poll == KVPoll.Failed:
raise Exception("Handshake failed")
def pop_preallocated(self) -> List[DecodeRequest]:
"""Pop the preallocated requests from the pending queue (FIFO)."""
self._update_handshake_waiters()
preallocated_reqs = []
indices_to_remove = set()
allocatable_tokens = self._allocatable_tokens()
for i, decode_req in enumerate(self.queue):
if not decode_req.waiting_for_input:
continue
if self.req_to_token_pool.available_size() <= 0:
break
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
break
required_tokens_for_request = (
len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
)
if required_tokens_for_request > allocatable_tokens:
break
allocatable_tokens -= required_tokens_for_request
self._pre_alloc(decode_req.req)
kv_indices = (
self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
: len(decode_req.req.origin_input_ids)
]
.cpu()
.numpy()
)
decode_req.metadata_buffer_index = (
self.req_to_metadata_buffer_idx_allocator.alloc()
)
assert decode_req.metadata_buffer_index is not None
decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
preallocated_reqs.append(decode_req)
indices_to_remove.add(i)
self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
]
return preallocated_reqs
def _allocatable_tokens(self) -> int:
allocatable_tokens = (
self.token_to_kv_pool_allocator.available_size()
- self.num_reserved_decode_tokens
* (
len(self.scheduler.running_batch.reqs)
+ len(self.transfer_queue.queue)
+ len(self.scheduler.waiting_queue)
)
)
# Note: if the last fake extend batch has just finished and we enter `pop_preallocated` immediately
# in the next iteration, that batch is not in any queue yet, so we explicitly reserve its token slots here
if (
self.scheduler.last_batch
and self.scheduler.last_batch.forward_mode.is_extend()
):
allocatable_tokens -= self.num_reserved_decode_tokens * len(
self.scheduler.last_batch.reqs
)
return allocatable_tokens
def _pre_alloc(self, req: Req) -> torch.Tensor:
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
req_pool_indices = self.req_to_token_pool.alloc(1)
assert req_pool_indices is not None
req.req_pool_idx = req_pool_indices[0]
kv_loc = self.token_to_kv_pool_allocator.alloc(
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
)
assert kv_loc is not None
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
# populate metadata
req.fill_ids = req.origin_input_ids + req.output_ids
req.extend_input_len = len(req.origin_input_ids)
return kv_loc
class DecodeTransferQueue:
"""
Store the requests that are polling kv
"""
def __init__(
self,
gloo_group: ProcessGroup,
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: List[torch.Tensor],
):
self.queue: List[DecodeRequest] = []
self.gloo_group = gloo_group
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
self.metadata_buffers = metadata_buffers
def add(self, req_conn: DecodeRequest) -> None:
self.queue.append(req_conn)
def extend(self, req_conns) -> None:
self.queue.extend(req_conns)
def pop_transferred(self) -> List[Req]:
if not self.queue:
return []
polls = poll_and_all_reduce(
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
)
transferred_reqs = []
indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if poll == KVPoll.Failed:
raise Exception("Transfer failed")
elif poll == KVPoll.Success:
# pop and push it to waiting queue
idx = decode_req.metadata_buffer_index
assert len(decode_req.req.output_ids) == 0
output_id_buffer = self.metadata_buffers[0]
# the last dimension is padded with the same value
output_id = output_id_buffer[idx][0].item()
assert decode_req.req.transferred_output_id is None
decode_req.req.transferred_output_id = output_id
transferred_reqs.append(decode_req.req)
indices_to_remove.add(i)
elif poll in [
KVPoll.Bootstrapping,
KVPoll.WaitingForInput,
KVPoll.Transferring,
]:
pass
else:
raise ValueError(f"Unexpected poll case: {poll}")
for i in indices_to_remove:
idx = self.queue[i].metadata_buffer_index
assert idx != -1
self.req_to_metadata_buffer_idx_allocator.free(idx)
self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
]
return transferred_reqs
class ScheduleBatchDisaggregationDecodeMixin:
def prepare_for_prebuilt_extend(self: ScheduleBatch):
"""
Prepare a prebuilt extend batch by populating metadata.
Adapted from .prepare_for_extend().
"""
self.forward_mode = ForwardMode.EXTEND
reqs = self.reqs
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
extend_num_tokens = sum(len(ids) for ids in input_ids)
seq_lens = []
pre_lens = []
req_pool_indices = []
# Pre-calculate total size
total_size = sum(req.extend_input_len for req in reqs)
out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
# Fill the tensor in one pass
offset = 0
for i, req in enumerate(reqs):
req_pool_indices.append(req.req_pool_idx)
chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
: req.extend_input_len
]
assert (
offset + req.extend_input_len <= total_size
), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
out_cache_loc[offset : offset + req.extend_input_len] = chunk
offset += req.extend_input_len
pre_len = len(req.prefix_indices)
seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
seq_lens.append(seq_len)
if len(req.output_ids) == 0:
assert (
seq_len - pre_len == req.extend_input_len
), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
req.cached_tokens += pre_len - req.already_computed
req.already_computed = seq_len
req.is_retracted = False
pre_lens.append(pre_len)
req.extend_logprob_start_len = 0
extend_input_logprob_token_ids = None
# Set fields
self.input_ids = torch.tensor(
sum(input_ids, []), dtype=torch.int32, device=self.device
)
self.req_pool_indices = torch.tensor(
req_pool_indices, dtype=torch.int64, device=self.device
)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
self.out_cache_loc = out_cache_loc
self.seq_lens_sum = sum(seq_lens)
self.extend_num_tokens = extend_num_tokens
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
# Build sampling info
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
)
def process_prebuilt_extend(
self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
):
"""Assign the buffered last input id to schedule batch"""
self.output_ids = []
for req in self.reqs:
if req.output_ids and len(req.output_ids) > 0:
# resumed retracted req
self.output_ids.append(req.output_ids[-1])
else:
assert req.transferred_output_id is not None
req.output_ids.append(req.transferred_output_id)
self.output_ids.append(req.transferred_output_id)
self.tree_cache.cache_unfinished_req(req)
self.output_ids = torch.tensor(self.output_ids, device=self.device)
class SchedulerDisaggregationDecodeMixin:
def get_next_disagg_decode_batch_to_run(
self: Scheduler,
) -> Optional[Tuple[ScheduleBatch, bool]]:
"""Create fake completed prefill if possible and merge with running batch"""
# Merge the prefill batch into the running batch
last_batch = self.last_batch
if last_batch and last_batch.forward_mode.is_extend():
# chunked prefill doesn't happen in decode instance.
assert self.chunked_req is None
# Filter finished batches.
last_batch.filter_batch()
if not last_batch.is_empty():
if self.running_batch.is_empty():
self.running_batch = last_batch
else:
# merge running_batch with prefill batch
self.running_batch.merge_batch(last_batch)
new_prebuilt_batch = self.get_new_prebuilt_batch()
ret: Optional[ScheduleBatch] = None
if new_prebuilt_batch:
ret = new_prebuilt_batch
else:
if self.running_batch.is_empty():
ret = None
else:
self.running_batch = self.update_running_batch(self.running_batch)
ret = self.running_batch if not self.running_batch.is_empty() else None
return ret
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
"""Create a schedulebatch for fake completed prefill"""
if len(self.waiting_queue) == 0:
return None
curr_batch_size = self.running_batch.batch_size()
batch_size = min(self.req_to_token_pool.size, self.max_running_requests)
num_not_used_batch = batch_size - curr_batch_size
# pop req from waiting queue
can_run_list: List[Req] = []
waiting_queue: List[Req] = []
for i in range(len(self.waiting_queue)):
req = self.waiting_queue[i]
# we can add at most `num_not_used_batch` new requests to the running batch
if i < num_not_used_batch:
can_run_list.append(req)
req.init_next_round_input(self.tree_cache)
else:
waiting_queue.append(req)
self.waiting_queue = waiting_queue
if len(can_run_list) == 0:
return None
# local import to avoid circular import
from sglang.srt.managers.schedule_batch import ScheduleBatch
# construct a schedule batch with those requests and mark as decode
new_batch = ScheduleBatch.init_new(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool_allocator,
self.tree_cache,
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
)
# construct fake completed prefill
new_batch.prepare_for_prebuilt_extend()
new_batch.process_prebuilt_extend(self.server_args, self.model_config)
return new_batch
def process_decode_queue(self: Scheduler):
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
self.disagg_decode_transfer_queue.extend(req_conns)
alloc_reqs = (
self.disagg_decode_transfer_queue.pop_transferred()
) # the requests whose kv cache has arrived
self.waiting_queue.extend(alloc_reqs)
"""
Minimal HTTP load balancer for prefill and decode servers, for testing purposes.
"""
import asyncio
import random
import urllib
from itertools import chain
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
class MiniLoadBalancer:
def __init__(self, prefill_servers, decode_servers):
self.prefill_servers = prefill_servers
self.decode_servers = decode_servers
def select_pair(self):
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
async def generate_request(self, request_data):
prefill_server, decode_server = self.select_pair()
# Parse and transform prefill_server
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
bootstrap_host = hostname
modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": bootstrap_host,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
)
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = [
session.post(f"{prefill_server}/generate", json=modified_request),
session.post(f"{decode_server}/generate", json=modified_request),
]
prefill_response = None
decode_response = None
# Process responses as they arrive
for i, response in enumerate(asyncio.as_completed(tasks)):
response = await response
# Determine whether this is the prefill or the decode response by its URL
if i == 0: # First completed task
if str(response.url).startswith(prefill_server):
prefill_response = response
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Prefill server error: Status {response.status} Details: {await response.text()}",
)
else:
decode_response = response
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Decode server error: Status {response.status} Details: {await response.text()}",
)
else: # Second completed task
if str(response.url).startswith(prefill_server):
prefill_response = response
else:
decode_response = response
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}",
)
return await decode_response.json()
app = FastAPI()
load_balancer = None
@app.get("/health")
async def health_check():
return Response(status_code=200)
@app.get("/health_generate")
async def health_generate():
prefill_servers, decode_servers = (
load_balancer.prefill_servers,
load_balancer.decode_servers,
)
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = []
for server in chain(prefill_servers, decode_servers):
tasks.append(session.post(f"{server}/health_generate"))
for i, response in enumerate(asyncio.as_completed(tasks)):
await response
return Response(status_code=200)
@app.post("/flush_cache")
async def flush_cache():
prefill_servers, decode_servers = (
load_balancer.prefill_servers,
load_balancer.decode_servers,
)
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = []
for server in chain(prefill_servers, decode_servers):
tasks.append(session.post(f"{server}/flush_cache"))
for i, response in enumerate(asyncio.as_completed(tasks)):
await response
return Response(status_code=200)
@app.get("/get_server_info")
async def get_server_info():
prefill_servers, decode_servers = (
load_balancer.prefill_servers,
load_balancer.decode_servers,
)
prefill_infos = []
decode_infos = []
async with aiohttp.ClientSession() as session:
for server in prefill_servers:
server_info = await session.get(f"{server}/get_server_info")
prefill_infos.append(await server_info.json())
for server in decode_servers:
server_info = await session.get(f"{server}/get_server_info")
decode_infos.append(await server_info.json())
return {"prefill": prefill_infos, "decode": decode_infos}
@app.get("/get_model_info")
async def get_model_info():
# Dummy model information
model_info = {
"model_path": "/path/to/dummy/model",
"tokenizer_path": "/path/to/dummy/tokenizer",
"is_generation": True,
"preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
}
return ORJSONResponse(content=model_info)
@app.post("/generate")
async def handle_generate_request(request_data: dict):
prefill_server, decode_server = load_balancer.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
)
# Check if streaming is requested
if request_data.get("stream", False):
async def stream_results():
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=3600)
) as session:
try:
# Create the tasks
tasks = [
session.post(
f"{prefill_server}/generate", json=modified_request
),
session.post(
f"{decode_server}/generate", json=modified_request
),
]
prefill_response = None
decode_response = None
# Process responses as they arrive
for i, response_task in enumerate(asyncio.as_completed(tasks)):
response = await response_task
# Check the response immediately
if str(response.url).startswith(prefill_server):
prefill_response = response
if response.status != 200:
error_msg = {
"error": {
"message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
return
else:
decode_response = response
if response.status != 200:
error_msg = {
"error": {
"message": f"Decode server error: Status {response.status}"
}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
return
# Stream successful decode server response
async for line in decode_response.content:
yield line
yield b"data: [DONE]\n\n"
except Exception as e:
error_msg = {
"error": {"message": f"Stream processing error: {str(e)}"}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
)
# Non-streaming case
result = await load_balancer.generate_request(request_data)
return ORJSONResponse(content=result)
@app.get("/v1/models")
async def get_models():
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
async with aiohttp.ClientSession() as session:
try:
response = await session.get(f"{prefill_server}/v1/models")
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Prefill server error: Status {response.status}",
)
return ORJSONResponse(content=await response.json())
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def run(prefill_addrs, decode_addrs, host, port):
global load_balancer
load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
parser.add_argument(
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
)
parser.add_argument(
"--decode", required=True, help="Comma-separated URLs for decode servers"
)
parser.add_argument(
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
)
args = parser.parse_args()
run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
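# Example invocation (illustrative values; the module filename is assumed to be
# mini_lb.py, and the two backend servers are assumed to have been launched with
# --disaggregation-mode prefill and --disaggregation-mode decode respectively):
#   python mini_lb.py --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --port 8000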
"""
Life cycle of a request in the prefill server
1. Bootstrap Queue
a. Initialize a sender for each request
b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished
c. Poll senders to check bootstrap state
d. Once bootstrap is complete, move the request to the Waiting Queue
2. Waiting Queue
a. Use PrefillAdder to pop requests
b. Run forward
c. Add the request to the Infight Queue
3. Infight Queue
a. Poll (non-blocking) the sender of the request
b. Once the transfer has finished, return the request
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional
import torch
from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
from sglang.srt.disaggregation.utils import (
ReqToMetadataIdxAllocator,
poll_and_all_reduce,
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
from sglang.srt.mem_cache.memory_pool import KVCache
logger = logging.getLogger(__name__)
class PrefillBootstrapQueue:
"""
Store the requests that are bootstrapping
"""
def __init__(
self,
token_to_kv_pool: KVCache,
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: List[torch.Tensor],
aux_dtype: torch.dtype,
tp_rank: int,
tp_size: int,
bootstrap_port: int,
gloo_group: ProcessGroup,
):
self.token_to_kv_pool = token_to_kv_pool
self.aux_dtype = aux_dtype
self.metadata_buffers = metadata_buffers
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
self.tp_rank = tp_rank
self.tp_size = tp_size
self.kv_manager = self._init_kv_manager()
self.queue: List[Req] = []
self.gloo_group = gloo_group
self.bootstrap_port = bootstrap_port
def allocate_token_id(self, idx: int, token_id: int):
assert token_id >= 0, f"token_id: {token_id} is negative"
output_id_buffer = self.metadata_buffers[0]
output_id_buffer[idx] = token_id
def _init_kv_manager(self) -> KVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
kv_args.kv_data_ptrs = kv_data_ptrs
kv_args.kv_data_lens = kv_data_lens
kv_args.kv_item_lens = kv_item_lens
# Define req -> input ids buffer
kv_args.aux_data_ptrs = [
metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
]
kv_args.aux_data_lens = [
metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.aux_item_lens = [
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.ib_device = "mock-ib-device"
kv_manager = KVManager(kv_args)
return kv_manager
def add(self, req: Req) -> None:
req.disagg_kv_sender = KVSender(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
)
self._process_req(req)
self.queue.append(req)
def _process_req(self, req: Req) -> None:
"""
Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
"""
req.sampling_params.max_new_tokens = 1
def pop_bootstrapped(self) -> List[Req]:
"""pop the reqs which has finished bootstrapping"""
bootstrapped_reqs = []
indices_to_remove = set()
if len(self.queue) == 0:
return []
polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.queue], self.gloo_group
)
for i, (req, poll) in enumerate(zip(self.queue, polls)):
if poll == KVPoll.Bootstrapping:
continue
elif poll == KVPoll.Failed:
raise Exception("Bootstrap failed")
# KVPoll.WaitingForInput: init the sender here
num_kv_indices = len(req.origin_input_ids)
if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
break
req.metadata_buffer_index = (
self.req_to_metadata_buffer_idx_allocator.alloc()
)
assert req.metadata_buffer_index is not None
req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
bootstrapped_reqs.append(req)
indices_to_remove.add(i)
self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
]
return bootstrapped_reqs
class SchedulerDisaggregationPrefillMixin:
"""
Mixin for Scheduler to handle disaggregation prefill
"""
def process_batch_result_disagg_prefill(
self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
) -> None:
"""
Transfer kv for completed prefill requests and add them to disagg_prefill_infight_queue.
Adapted from process_batch_result_prefill
"""
next_token_ids = result.next_token_ids.tolist()
for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
req: Req
if req.is_chunked <= 0:
# There is no output_ids for prefill
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
self.send_kv_chunk(req, token_id=next_token_id)
self.disagg_prefill_infight_queue.append(req)
else:
# The chunked request's prefill is not finished yet
req.is_chunked -= 1
# TODO: Not sure if this is necessary
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
# We need to remove this for overlap schedule.
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def process_disagg_prefill_infight_queue(self: Scheduler) -> None:
"""
Poll the requests in the middle of transfer. If done, return the request.
"""
assert len(self.disagg_prefill_infight_queue) > 0
done_reqs = []
polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.disagg_prefill_infight_queue],
self.tp_worker.get_tp_cpu_group(),
)
undone_reqs: List[Req] = []
# Check .poll() for the reqs in disagg_prefill_infight_queue. If Success, respond to the client and remove it from the queue
for req, poll in zip(self.disagg_prefill_infight_queue, polls):
if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
undone_reqs.append(req)
elif poll == KVPoll.Success: # transfer done
self.tree_cache.cache_finished_req(req) # unlock the tree
req.finished_reason = FINISH_LENGTH(length=0)
done_reqs.append(req)
elif poll == KVPoll.Failed:
raise Exception("Transferring failed")
# Stream requests which have finished transfer
self.stream_output(done_reqs, False, None)
self.disagg_prefill_infight_queue = undone_reqs
def process_prefill_chunk(self: Scheduler) -> None:
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.chunked_req:
# Move the chunked request out of the batch so that we can merge
# only finished requests to running_batch.
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req)
self.send_kv_chunk(self.chunked_req)
# chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
self.running_batch.batch_is_full = False
def send_kv_chunk(
self: Scheduler, req: Req, token_id: Optional[int] = None
) -> None:
"""
Send a prefilled chunk to the decode server
"""
start_idx = req.start_send_idx
end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
kv_indices = (
self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
.cpu()
.numpy()
)
req.start_send_idx = end_idx
if token_id is not None:
self.disagg_prefill_pending_queue.allocate_token_id(
req.metadata_buffer_index, token_id
)
req.disagg_kv_sender.send(kv_indices)
from __future__ import annotations
from collections import deque
from enum import Enum
from typing import List, Optional
import torch
import torch.distributed as dist
class DisaggregationMode(Enum):
NULL = "null"
PREFILL = "prefill"
DECODE = "decode"
def poll_and_all_reduce(pollers, gloo_group):
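    """Poll each sender/receiver locally, then all-reduce the states with MIN over
    the CPU gloo group. Because KVPoll values are ordered (Failed=0 ... Success=4),
    the MIN gives the most conservative state across TP ranks, so every rank
    advances a request only once all ranks have reached at least that state."""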
polls = [int(poller.poll()) for poller in pollers]
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
return tensor_to_reduce.tolist()
class ReqToMetadataIdxAllocator:
"""A memory pool that maps a request to its first output token location."""
def __init__(
self,
size: int,
):
self.size = size
self.free_slots = deque(list(range(size)))
def available_size(self):
return len(self.free_slots)
def alloc(self) -> Optional[int]:
if len(self.free_slots) == 0:
return None
return self.free_slots.popleft()
def free(self, free_index: int):
self.free_slots.append(free_index)
@@ -42,6 +42,8 @@ import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.conn import KVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
@@ -396,6 +398,24 @@ class Req:
self.spec_verify_ct = 0
self.lora_path = lora_path
# For disaggregation
self.bootstrap_host: str = "0.0.0.0"
self.bootstrap_room: Optional[int] = None
self.disagg_kv_sender: Optional[KVSender] = None
# used for warmup because we don't have a pair yet when init
self.skip_kv_transfer: bool = False
# the start index of the sent kv cache
# We want to send it chunk by chunk for chunked prefill.
# After every chunk forward, we do the following:
# kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
# start_send_idx = len(req.fill_ids)
self.start_send_idx: int = 0
self.metadata_buffer_index: int = -1
# The first output_id transferred from prefill instance.
self.transferred_output_id: Optional[int] = None
@property
def seqlen(self):
return len(self.origin_input_ids) + len(self.output_ids)
@@ -531,7 +551,7 @@ bid = 0
@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
"""Store all information of a batch on the scheduler."""
# Request, memory pool, and cache
@@ -37,6 +37,19 @@ from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.disaggregation.decode import (
DecodePreallocQueue,
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
ReqToMetadataIdxAllocator,
)
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -137,7 +150,11 @@ class EmbeddingBatchResult:
bid: int
class Scheduler(
SchedulerOutputProcessorMixin,
SchedulerDisaggregationDecodeMixin,
SchedulerDisaggregationPrefillMixin,
):
"""A scheduler that manages a tensor parallel GPU worker."""
def __init__(
@@ -389,6 +406,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
]
)
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.init_disaggregation()
def init_tokenizer(self):
server_args = self.server_args
@@ -489,6 +511,73 @@ class Scheduler(SchedulerOutputProcessorMixin):
},
)
def init_disaggregation(self):
if self.disaggregation_mode == DisaggregationMode.DECODE:
# *2 for the headroom.
buffer_size = self.req_to_token_pool.size * 2
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
aux_dtype = torch.int32
# A list of metadata buffers. The shape is (b, metadata_size), where
# b corresponds to the max number of running requests. The last dim * dtype.itemsize
# should be at least 64 bytes to work with RDMA, so we pad it.
output_id_buffer = torch.zeros(
(buffer_size, 16), dtype=aux_dtype, device="cpu"
)
metadata_buffers = [output_id_buffer]
# The decode requests polling kv cache
self.disagg_decode_transfer_queue = DecodeTransferQueue(
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
metadata_buffers=metadata_buffers,
)
# The decode requests pending for pre-allocation
self.disagg_decode_prealloc_queue = DecodePreallocQueue(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
metadata_buffers=metadata_buffers,
aux_dtype=aux_dtype,
scheduler=self,
transfer_queue=self.disagg_decode_transfer_queue,
tree_cache=self.tree_cache,
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
tp_rank=self.tp_rank,
tp_size=self.tp_size,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
)
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
# *2 for the headroom.
buffer_size = self.max_running_requests * 2
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
aux_dtype = torch.int32
# A list of metadata buffers. The shape is (b, metadata_size), where
# b corresponds to the max number of running requests. The last dim * dtype.itemsize
# should be at least 64 bytes to work with RDMA, so we pad it.
output_id_buffer = torch.zeros(
(buffer_size, 16), dtype=aux_dtype, device="cpu"
)
metadata_buffers = [output_id_buffer]
self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
metadata_buffers=metadata_buffers,
aux_dtype=aux_dtype,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
)
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_infight_queue: List[Req] = []
@DynamicGradMode()
def event_loop_normal(self):
"""A normal scheduler loop."""
@@ -549,6 +638,70 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.last_batch = batch
@torch.no_grad()
def event_loop_normal_disagg_prefill(self):
"""A normal scheduler loop for prefill worker in disaggregation mode."""
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
self.waiting_queue.extend(
self.disagg_prefill_pending_queue.pop_bootstrapped()
)
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
self.process_batch_result_disagg_prefill(batch, result)
if len(self.disagg_prefill_infight_queue) > 0:
self.process_disagg_prefill_infight_queue()
if batch is None and len(self.disagg_prefill_infight_queue) == 0:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.last_batch = batch
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
# Otherwise, it hangs under high concurrency
self.running_batch.batch_is_full = False
@torch.no_grad()
def event_loop_normal_disagg_decode(self):
"""A normal scheduler loop for decode worker in disaggregation mode."""
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
# polling and allocating kv cache
self.process_decode_queue()
batch = self.get_next_disagg_decode_batch_to_run()
self.cur_batch = batch
if batch:
# Generate fake extend output.
if batch.forward_mode.is_extend():
# Note: Logprobs should be handled on the prefill engine.
self.stream_output(
batch.reqs, [False for _ in range(len(batch.reqs))]
)
else:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
if batch is None and (
len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.last_batch = batch
def recv_requests(self) -> List[Req]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
if self.attn_tp_rank == 0:
@@ -778,10 +931,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
self._add_request_to_queue(req)
def _add_request_to_queue(self, req: Req):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_pending_queue.add(req)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
self.waiting_queue.append(req)
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.extend(reqs)
else:
self.waiting_queue.extend(reqs)
def handle_embedding_request(
self,
@@ -1814,10 +1977,18 @@ def run_scheduler_process(
"max_req_input_len": scheduler.max_req_input_len,
}
)
disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
if disaggregation_mode == DisaggregationMode.NULL:
if scheduler.enable_overlap:
scheduler.event_loop_overlap()
else:
scheduler.event_loop_normal()
elif disaggregation_mode == DisaggregationMode.PREFILL:
scheduler.event_loop_normal_disagg_prefill()
elif disaggregation_mode == DisaggregationMode.DECODE:
scheduler.event_loop_normal_disagg_decode()
except Exception:
traceback = get_exception_traceback()
logger.error(f"Scheduler hit an exception: {traceback}")
@@ -49,6 +49,8 @@ from fastapi import BackgroundTasks
from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.conn import KVBootstrapServer
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
get_dummy_image_processor,
@@ -313,6 +315,16 @@ class TokenizerManager:
]
)
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
# For disaggregation, start the kv bootstrap server only on the prefill tokenizer manager
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.bootstrap_server = KVBootstrapServer(
self.server_args.disaggregation_bootstrap_port
)
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -271,6 +271,19 @@ class MHATokenToKVPool(KVCache):
v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
return k_size_bytes, v_size_bytes
# for disagg
def get_contiguous_buf_infos(self):
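    # The three returned lists are index-aligned and hold 2 * layer_num entries each:
    # all per-layer key buffers first, then all per-layer value buffers. They are
    # consumed as KVArgs.kv_data_ptrs / kv_data_lens / kv_item_lens by the KV manager.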
kv_data_ptrs = [
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
kv_data_lens = [
self.get_key_buffer(i).nbytes for i in range(self.layer_num)
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
kv_item_lens = [
self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
return kv_data_ptrs, kv_data_lens, kv_item_lens
# Todo: different memory layout
def get_flat_data(self, indices):
# prepare a large chunk of contiguous data for efficient transfer
@@ -185,6 +185,10 @@ class ServerArgs:
debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode: str = "null"
disaggregation_bootstrap_port: int = 8998
def __post_init__(self):
# Set missing default values
if self.tokenizer_path is None:
@@ -325,6 +329,18 @@ class ServerArgs:
if is_hip():
self.triton_attention_num_kv_splits = 16
# PD disaggregation
if self.disaggregation_mode == "prefill":
self.disable_cuda_graph = True
logger.warning("Cuda graph is disabled for prefill server")
self.disable_overlap_schedule = True
logger.warning("Overlap scheduler is disabled for prefill server")
elif self.disaggregation_mode == "decode":
self.disable_radix_cache = True
logger.warning("KV cache is forced as chunk cache for decode server")
self.disable_overlap_schedule = True
logger.warning("Overlap scheduler is disabled for decode server")
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args
@@ -1063,6 +1079,21 @@ class ServerArgs:
help="Inject the outputs from jax as the input of every layer.",
)
# Disaggregation
parser.add_argument(
"--disaggregation-mode",
type=str,
default="null",
choices=["null", "prefill", "decode"],
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
)
parser.add_argument(
"--disaggregation-bootstrap-port",
type=int,
default=ServerArgs.disaggregation_bootstrap_port,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size