Unverified Commit c6c62640 authored by ybyang's avatar ybyang Committed by GitHub
Browse files

[PD] support pd fake transfer for warmup (#5726)

parent 92ab0a20
...@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup ...@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
FakeBootstrapHost,
KVClassType, KVClassType,
ReqToMetadataIdxAllocator, ReqToMetadataIdxAllocator,
TransferBackend, TransferBackend,
...@@ -133,8 +134,13 @@ class DecodePreallocQueue: ...@@ -133,8 +134,13 @@ class DecodePreallocQueue:
def add(self, req: Req) -> None: def add(self, req: Req) -> None:
"""Add a request to the pending queue.""" """Add a request to the pending queue."""
if req.bootstrap_host == FakeBootstrapHost:
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER) # Fake transfer for warmup reqs
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
else:
kv_receiver_class = get_kv_class(
self.transfer_backend, KVClassType.RECEIVER
)
kv_receiver = kv_receiver_class( kv_receiver = kv_receiver_class(
mgr=self.kv_manager, mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
......
from .conn import FakeKVReceiver, FakeKVSender
import logging
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
from sglang.srt.disaggregation.base.conn import (
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
logger = logging.getLogger(__name__)
# Warmup requests do not perform a real KV transfer; they use the fake sender and receiver below.
class FakeKVSender(BaseKVSender):
    """KV sender that moves no data and only tracks a ``has_sent`` flag.

    ``poll`` reports ``KVPoll.WaitingForInput`` until ``send`` has been called
    with ``is_last=True``; after that it reports ``KVPoll.Success``.
    """

    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
        """Create the sender; the three arguments are accepted but not used."""
        self.has_sent = False

    def poll(self) -> KVPoll:
        """Return the current fake transfer state.

        Returns:
            ``KVPoll.Success`` once the last chunk has been "sent",
            ``KVPoll.WaitingForInput`` otherwise.
        """
        if self.has_sent:
            # Assume transfer completed instantly
            logger.info("FakeKVSender poll success")
            return KVPoll.Success
        # Assume handshake completed instantly
        return KVPoll.WaitingForInput

    def init(
        self,
        kv_indices: list[int],
        aux_index: Optional[int] = None,
        dest_ranks: Optional[list[int]] = None,
    ):
        """Log the arguments; no state is changed."""
        message = f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
        logger.info(message)

    def send(
        self,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
    ):
        """Pretend to send one chunk and record whether it was the last one.

        Args:
            kv_indices: Only logged.
            index_slice: Only logged.
            is_last: When truthy, marks the fake transfer as finished so that
                ``poll`` starts returning ``KVPoll.Success``; otherwise the
                flag is reset to ``False``.
        """
        message = f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
        logger.info(message)
        if not is_last:
            self.has_sent = False
            logger.info(f"FakeKVSender send fake transfering")
            return
        self.has_sent = True
        logger.info(f"FakeKVSender send success")

    def failure_exception(self):
        """Always raise a generic ``Exception`` describing a fake sender failure."""
        raise Exception("Fake KVSender Exception")
class FakeKVReceiver(BaseKVReceiver):
    """KV receiver that receives no data and only tracks a ``has_init`` flag.

    ``poll`` reports ``KVPoll.WaitingForInput`` until ``init`` has been called;
    after that it reports ``KVPoll.Success``.
    """

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        """Create the receiver; the three arguments are accepted but not used."""
        self.has_init = False

    def poll(self) -> KVPoll:
        """Return the current fake transfer state.

        Returns:
            ``KVPoll.Success`` once ``init`` has been called,
            ``KVPoll.WaitingForInput`` otherwise.
        """
        if self.has_init:
            # Assume transfer completed instantly
            logger.info("FakeKVReceiver poll success")
            return KVPoll.Success
        # Assume handshake completed instantly
        return KVPoll.WaitingForInput

    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
        """Mark the receiver as initialised and log the arguments.

        Args:
            kv_indices: Only logged.
            aux_index: Only logged.
        """
        # Flip the flag before logging, in the same order as the original.
        self.has_init = True
        message = (
            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
        )
        logger.info(message)

    def failure_exception(self):
        """Always raise a generic ``Exception`` describing a fake receiver failure."""
        raise Exception("Fake KVReceiver Exception")
...@@ -29,6 +29,7 @@ import torch ...@@ -29,6 +29,7 @@ import torch
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
FakeBootstrapHost,
KVClassType, KVClassType,
ReqToMetadataIdxAllocator, ReqToMetadataIdxAllocator,
TransferBackend, TransferBackend,
...@@ -116,7 +117,11 @@ class PrefillBootstrapQueue: ...@@ -116,7 +117,11 @@ class PrefillBootstrapQueue:
return kv_manager return kv_manager
def add(self, req: Req) -> None: def add(self, req: Req) -> None:
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER) if req.bootstrap_host == FakeBootstrapHost:
# Fake transfer for warmup reqs
kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
else:
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
req.disagg_kv_sender = kv_sender_class( req.disagg_kv_sender = kv_sender_class(
mgr=self.kv_manager, mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
......
...@@ -15,6 +15,9 @@ class DisaggregationMode(Enum): ...@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
DECODE = "decode" DECODE = "decode"
# Sentinel bootstrap host. The server warmup request sets `bootstrap_host` to
# this value, and the prefill/decode queues compare against it to select the
# TransferBackend.FAKE sender/receiver instead of the configured backend.
FakeBootstrapHost = "2.2.2.2"
def poll_and_all_reduce(pollers, gloo_group): def poll_and_all_reduce(pollers, gloo_group):
polls = [int(poller.poll()) for poller in pollers] polls = [int(poller.poll()) for poller in pollers]
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu") tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
...@@ -59,6 +62,8 @@ class KVClassType(Enum): ...@@ -59,6 +62,8 @@ class KVClassType(Enum):
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
if transfer_backend == TransferBackend.MOONCAKE: if transfer_backend == TransferBackend.MOONCAKE:
from sglang.srt.disaggregation.mooncake import ( from sglang.srt.disaggregation.mooncake import (
MooncakeKVBootstrapServer, MooncakeKVBootstrapServer,
...@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): ...@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
class_mapping = { class_mapping = {
KVClassType.MANAGER: MooncakeKVManager, KVClassType.MANAGER: MooncakeKVManager,
KVClassType.SENDER: MooncakeKVSender, KVClassType.SENDER: MooncakeKVSender,
KVClassType.RECEIVER: MooncakeKVReceiver, KVClassType.RECEIVER: (MooncakeKVReceiver),
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
} }
return class_mapping.get(class_type) return class_mapping.get(class_type)
...@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): ...@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
class_mapping = { class_mapping = {
KVClassType.MANAGER: NixlKVManager, KVClassType.MANAGER: NixlKVManager,
KVClassType.SENDER: NixlKVSender, KVClassType.SENDER: NixlKVSender,
KVClassType.RECEIVER: NixlKVReceiver, KVClassType.RECEIVER: (NixlKVReceiver),
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer, KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
} }
return class_mapping.get(class_type) return class_mapping.get(class_type)
if transfer_backend == TransferBackend.FAKE:
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
class_mapping = {
KVClassType.SENDER: FakeKVSender,
KVClassType.RECEIVER: (FakeKVReceiver),
}
return class_mapping.get(class_type)
raise ValueError(f"Unsupported transfer backend: {transfer_backend}") raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
......
...@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile ...@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import FakeBootstrapHost
from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
...@@ -821,8 +822,32 @@ def _wait_and_warmup( ...@@ -821,8 +822,32 @@ def _wait_and_warmup(
) )
assert res.status_code == 200, f"{res}" assert res.status_code == 200, f"{res}"
else: else:
# Warmup request currently hangs in disaggregation mode, so we skip it. logger.info(f"Start of prefill warmup ...")
logger.info("Skipping warmup request in disaggregation mode") json_data = {
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": True,
},
"bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
# This is a hack to ensure fake transfer is enabled during prefill warmup
# ensure each dp rank has a unique bootstrap_room during prefill warmup
"bootstrap_room": [
i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
for i in range(server_args.dp_size)
],
"input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
}
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=1800, # because of deep gemm precache is very long if not precache.
)
logger.info(
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
)
except Exception: except Exception:
last_traceback = get_exception_traceback() last_traceback = get_exception_traceback()
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment