Unverified Commit c6c62640 authored by ybyang's avatar ybyang Committed by GitHub
Browse files

[PD] support pd fake transfer for warmup (#5726)

parent 92ab0a20
......@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
FakeBootstrapHost,
KVClassType,
ReqToMetadataIdxAllocator,
TransferBackend,
......@@ -133,8 +134,13 @@ class DecodePreallocQueue:
def add(self, req: Req) -> None:
"""Add a request to the pending queue."""
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
if req.bootstrap_host == FakeBootstrapHost:
# Fake transfer for warmup reqs
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
else:
kv_receiver_class = get_kv_class(
self.transfer_backend, KVClassType.RECEIVER
)
kv_receiver = kv_receiver_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
......
from .conn import FakeKVReceiver, FakeKVSender
import logging
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
from sglang.srt.disaggregation.base.conn import (
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
logger = logging.getLogger(__name__)
# Warmup requests do not perform a real KV transfer; they use the fake sender and receiver below.
class FakeKVSender(BaseKVSender):
    """Stand-in KV sender used for warmup requests.

    No KV data is moved. ``poll`` reports ``KVPoll.WaitingForInput`` until
    ``send`` has been called with a truthy ``is_last``; from then on it
    reports ``KVPoll.Success``.
    """

    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
        # The constructor arguments are accepted but not stored.
        self.has_sent = False

    def poll(self) -> KVPoll:
        """Return the simulated state of the (non-existent) transfer."""
        if self.has_sent is not False:
            # Assume transfer completed instantly
            logger.info("FakeKVSender poll success")
            return KVPoll.Success
        # Assume handshake completed instantly
        return KVPoll.WaitingForInput

    def init(
        self,
        kv_indices: list[int],
        aux_index: Optional[int] = None,
        dest_ranks: Optional[list[int]] = None,
    ):
        """Log the arguments; nothing is initialized for a fake transfer."""
        message = f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
        logger.info(message)

    def send(
        self,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
    ):
        """Log the call and record whether the final chunk has been "sent"."""
        message = f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
        logger.info(message)
        finished = bool(is_last)
        self.has_sent = finished
        if not finished:
            logger.info(f"FakeKVSender send fake transfering")
            return
        logger.info(f"FakeKVSender send success")

    def failure_exception(self):
        """Always raise a generic exception identifying the fake sender."""
        raise Exception("Fake KVSender Exception")
class FakeKVReceiver(BaseKVReceiver):
    """Stand-in KV receiver used for warmup requests.

    No KV data is received. ``poll`` reports ``KVPoll.WaitingForInput``
    until ``init`` has been called; from then on it reports
    ``KVPoll.Success``.
    """

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        # The constructor arguments are accepted but not stored.
        self.has_init = False

    def poll(self) -> KVPoll:
        """Return the simulated state of the (non-existent) transfer."""
        if self.has_init is not False:
            # Assume transfer completed instantly
            logger.info("FakeKVReceiver poll success")
            return KVPoll.Success
        # Assume handshake completed instantly
        return KVPoll.WaitingForInput

    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
        """Mark the receiver as initialized and log the arguments."""
        self.has_init = True
        message = f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
        logger.info(message)

    def failure_exception(self):
        """Always raise a generic exception identifying the fake receiver."""
        raise Exception("Fake KVReceiver Exception")
......@@ -29,6 +29,7 @@ import torch
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
FakeBootstrapHost,
KVClassType,
ReqToMetadataIdxAllocator,
TransferBackend,
......@@ -116,7 +117,11 @@ class PrefillBootstrapQueue:
return kv_manager
def add(self, req: Req) -> None:
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
if req.bootstrap_host == FakeBootstrapHost:
# Fake transfer for warmup reqs
kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
else:
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
req.disagg_kv_sender = kv_sender_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
......
......@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
DECODE = "decode"
FakeBootstrapHost = "2.2.2.2"
def poll_and_all_reduce(pollers, gloo_group):
polls = [int(poller.poll()) for poller in pollers]
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
......@@ -59,6 +62,8 @@ class KVClassType(Enum):
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
if transfer_backend == TransferBackend.MOONCAKE:
from sglang.srt.disaggregation.mooncake import (
MooncakeKVBootstrapServer,
......@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
class_mapping = {
KVClassType.MANAGER: MooncakeKVManager,
KVClassType.SENDER: MooncakeKVSender,
KVClassType.RECEIVER: MooncakeKVReceiver,
KVClassType.RECEIVER: (MooncakeKVReceiver),
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
}
return class_mapping.get(class_type)
......@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
class_mapping = {
KVClassType.MANAGER: NixlKVManager,
KVClassType.SENDER: NixlKVSender,
KVClassType.RECEIVER: NixlKVReceiver,
KVClassType.RECEIVER: (NixlKVReceiver),
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
}
return class_mapping.get(class_type)
if transfer_backend == TransferBackend.FAKE:
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
class_mapping = {
KVClassType.SENDER: FakeKVSender,
KVClassType.RECEIVER: (FakeKVReceiver),
}
return class_mapping.get(class_type)
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
......
......@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import FakeBootstrapHost
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
......@@ -821,8 +822,32 @@ def _wait_and_warmup(
)
assert res.status_code == 200, f"{res}"
else:
# Warmup request currently hangs in disaggregation mode, so we skip it.
logger.info("Skipping warmup request in disaggregation mode")
logger.info(f"Start of prefill warmup ...")
json_data = {
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": True,
},
"bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
# HACK: setting bootstrap_host to FakeBootstrapHost (above) enables the fake transfer during prefill warmup.
# Give each dp rank a unique bootstrap_room during prefill warmup.
"bootstrap_room": [
i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
for i in range(server_args.dp_size)
],
"input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
}
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=1800, # because of deep gemm precache is very long if not precache.
)
logger.info(
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
)
except Exception:
last_traceback = get_exception_traceback()
if pipe_finish_writer is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment