"src/diffusers/quantizers/__init__.py" did not exist on "d849816659539eb4c3807f80a865f754dc76d586"
Unverified commit 8233cc10, authored by Byron Hsu, committed by GitHub

[PD] Support logprob & Add failure test (#6558)

parent 1b2e8f76
@@ -36,6 +36,7 @@ from sglang.srt.disaggregation.utils import (
DisaggregationMode,
FakeBootstrapHost,
KVClassType,
MetadataBuffers,
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
@@ -78,8 +79,7 @@ class DecodePreallocQueue:
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
draft_token_to_kv_pool: Optional[KVCache],
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: List[torch.Tensor],
aux_dtype: torch.dtype,
metadata_buffers: MetadataBuffers,
scheduler: Scheduler,
transfer_queue: DecodeTransferQueue,
tree_cache: BasePrefixCache,
@@ -94,7 +94,6 @@ class DecodePreallocQueue:
self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
self.draft_token_to_kv_pool = draft_token_to_kv_pool
self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
self.aux_dtype = aux_dtype
self.metadata_buffers = metadata_buffers
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
self.scheduler = scheduler
@@ -133,15 +132,9 @@ class DecodePreallocQueue:
kv_args.kv_data_lens = kv_data_lens
kv_args.kv_item_lens = kv_item_lens
kv_args.aux_data_ptrs = [
output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
]
kv_args.aux_data_lens = [
metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.aux_item_lens = [
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
self.metadata_buffers.get_buf_infos()
)
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -211,7 +204,18 @@ class DecodePreallocQueue:
indices_to_remove = set()
allocatable_tokens = self._allocatable_tokens()
# First, remove all failed requests from the queue
for i, decode_req in enumerate(self.queue):
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob
)
indices_to_remove.add(i)
for i, decode_req in enumerate(self.queue):
if i in indices_to_remove:
continue
if not decode_req.waiting_for_input:
continue
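The abort check now runs as a dedicated first pass, so aborted requests are streamed back (with their return_logprob flag preserved) before any KV preallocation is attempted. Below is a minimal, self-contained sketch of this mark-then-skip pattern; the FinishAbort/FakeReq types are hypothetical stand-ins for the real FINISH_ABORT and Req classes:

# Hypothetical stand-ins, for illustration only.
class FinishAbort:
    pass

class FakeReq:
    def __init__(self, rid, aborted=False, return_logprob=False):
        self.rid = rid
        self.finished_reason = FinishAbort() if aborted else None
        self.return_logprob = return_logprob

queue = [FakeReq("a"), FakeReq("b", aborted=True, return_logprob=True), FakeReq("c")]
indices_to_remove = set()

# Pass 1: mark aborted requests and stream their output immediately.
for i, req in enumerate(queue):
    if isinstance(req.finished_reason, FinishAbort):
        print(f"stream_output([{req.rid}], return_logprob={req.return_logprob})")
        indices_to_remove.add(i)

# Pass 2: only the surviving requests reach the preallocation logic.
for i, req in enumerate(queue):
    if i in indices_to_remove:
        continue
    print(f"preallocate {req.rid}")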
@@ -331,7 +335,7 @@ class DecodeTransferQueue:
self,
gloo_group: ProcessGroup,
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: torch.Tensor,
metadata_buffers: MetadataBuffers,
scheduler: Scheduler,
tree_cache: BasePrefixCache,
):
@@ -342,11 +346,11 @@ class DecodeTransferQueue:
self.scheduler = scheduler
self.tree_cache = tree_cache
def add(self, req_conn: DecodeRequest) -> None:
self.queue.append(req_conn)
def add(self, decode_req: DecodeRequest) -> None:
self.queue.append(decode_req)
def extend(self, req_conns) -> None:
self.queue.extend(req_conns)
def extend(self, decode_reqs: List[DecodeRequest]) -> None:
self.queue.extend(decode_reqs)
def pop_transferred(self) -> List[DecodeRequest]:
if not self.queue:
@@ -356,14 +360,6 @@ class DecodeTransferQueue:
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
)
# First, remove all failed requests from the queue
for i, decode_req in enumerate(self.queue):
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob
)
indices_to_remove.add(i)
transferred_reqs = []
indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
@@ -387,16 +383,37 @@ class DecodeTransferQueue:
indices_to_remove.add(i)
continue
elif poll == KVPoll.Success:
# pop and push it to waiting queue
idx = decode_req.metadata_buffer_index
assert len(decode_req.req.output_ids) == 0
output_id_buffer = self.metadata_buffers[0]
# the last dimension is padded by the same values.
output_id = output_id_buffer[idx][0].item()
assert len(decode_req.req.output_ids) == 0
assert decode_req.req.transferred_output_id is None
decode_req.req.transferred_output_id = output_id
transferred_reqs.append(decode_req)
(
output_id,
output_token_logprobs_val,
output_token_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
) = self.metadata_buffers.get_buf(idx)
decode_req.req.output_ids.append(output_id[0].item())
if decode_req.req.return_logprob:
decode_req.req.output_token_logprobs_val.append(
output_token_logprobs_val[0].item()
)
decode_req.req.output_token_logprobs_idx.append(
output_token_logprobs_idx[0].item()
)
decode_req.req.output_top_logprobs_val.append(
output_top_logprobs_val[
: decode_req.req.top_logprobs_num
].tolist()
)
decode_req.req.output_top_logprobs_idx.append(
output_top_logprobs_idx[
: decode_req.req.top_logprobs_num
].tolist()
)
transferred_reqs.append(decode_req.req)
indices_to_remove.add(i)
elif poll in [
KVPoll.Bootstrapping,
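In the KVPoll.Success branch above, the decode side now pulls all five metadata tensors for the request's buffer slot in a single get_buf call instead of reading one output-id buffer. A minimal sketch of the unpacking, assuming tensors shaped like the MetadataBuffers class added in utils.py later in this commit:

import torch

size, max_topk = 4, 128
bufs = (
    torch.zeros((size, 16), dtype=torch.int32),          # output ids
    torch.zeros((size, 16), dtype=torch.float32),        # token logprob values
    torch.zeros((size, 16), dtype=torch.int32),          # token logprob indices
    torch.zeros((size, max_topk), dtype=torch.float32),  # top-k logprob values
    torch.zeros((size, max_topk), dtype=torch.int32),    # top-k logprob indices
)

idx, top_logprobs_num = 0, 5
output_id, tok_val, tok_idx, top_val, top_idx = (b[idx] for b in bufs)

# Position 0 of each row holds the first transferred token's metadata.
first_token = output_id[0].item()
first_logprob = tok_val[0].item()
# Only the first top_logprobs_num entries of the top-k rows are meaningful.
top_vals = top_val[:top_logprobs_num].tolist()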
@@ -451,7 +468,9 @@ class SchedulerDisaggregationDecodeMixin:
# Generate fake extend output.
if batch.forward_mode.is_extend():
# Note: Logprobs should be handled on the prefill engine.
self.stream_output(batch.reqs, False)
self.stream_output(
batch.reqs, any(req.return_logprob for req in batch.reqs)
)
if prepare_dp_attn_flag:
self._prepare_idle_batch_and_run(None)
else:
@@ -497,7 +516,9 @@ class SchedulerDisaggregationDecodeMixin:
# Generate fake extend output.
if batch.forward_mode.is_extend():
# Note: Logprobs should be handled on the prefill engine.
self.stream_output(batch.reqs, False)
self.stream_output(
batch.reqs, any(req.return_logprob for req in batch.reqs)
)
if prepare_dp_attn_flag:
batch_, result = self._prepare_idle_batch_and_run(
None, delay_process=True
@@ -618,15 +639,8 @@ class SchedulerDisaggregationDecodeMixin:
def process_decode_queue(self: Scheduler):
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
def _num_pre_alloc(req):
return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
self.disagg_decode_transfer_queue.extend(req_conns)
alloc_reqs = (
self.disagg_decode_transfer_queue.pop_transferred()
) # the requests which kv has arrived
self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
self.waiting_queue.extend([req.req for req in alloc_reqs])
self.waiting_queue.extend(alloc_reqs)
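The _num_pre_alloc helper defines how many KV slots stay reserved while a request sits in the transfer queue: the full prompt plus every generated token except the newest one. A small worked example under that reading, using a hypothetical helper of the same shape:

def _num_pre_alloc(origin_input_ids, output_ids):
    # Full prompt plus all generated tokens except the latest one.
    return len(origin_input_ids) + max(len(output_ids) - 1, 0)

# Fresh request: 8 prompt tokens, nothing generated yet -> 8 slots.
assert _num_pre_alloc(list(range(8)), []) == 8
# Resumed (retracted) request: 8 prompt tokens, 3 outputs -> 8 + 2 = 10 slots.
assert _num_pre_alloc(list(range(8)), [11, 12, 13]) == 10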
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
import torch
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
logger = logging.getLogger(__name__)
@@ -76,6 +76,11 @@ class ScheduleBatchDisaggregationDecodeMixin:
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
self.out_cache_loc = out_cache_loc
self.seq_lens_sum = sum(seq_lens)
if self.return_logprob:
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
self.extend_num_tokens = extend_num_tokens
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
self.extend_lens = [r.extend_input_len for r in reqs]
@@ -94,12 +99,41 @@ class ScheduleBatchDisaggregationDecodeMixin:
"""Assign the buffered last input id to schedule batch"""
self.output_ids = []
for req in self.reqs:
if req.output_ids and len(req.output_ids) > 0:
# resumed retracted req
self.output_ids.append(req.output_ids[-1])
else:
assert req.transferred_output_id is not None
req.output_ids.append(req.transferred_output_id)
self.output_ids.append(req.transferred_output_id)
self.output_ids.append(req.output_ids[-1])
self.tree_cache.cache_unfinished_req(req)
self.output_ids = torch.tensor(self.output_ids, device=self.device)
# Simulate the EAGLE run. For now we fill the hidden states with mock data
# for ease of implementation, which means the first token will have an
# acceptance rate of 0.
if not self.spec_algorithm.is_none():
b = len(self.reqs)
topk_p = torch.arange(
b * server_args.speculative_eagle_topk,
0,
-1,
device=self.device,
dtype=torch.float32,
)
topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
topk_p /= b * server_args.speculative_eagle_topk
topk_index = torch.arange(
b * server_args.speculative_eagle_topk, device=self.device
)
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
# local import to avoid circular import
from sglang.srt.speculative.eagle_utils import EagleDraftInput
spec_info = EagleDraftInput(
topk_p=topk_p,
topk_index=topk_index,
hidden_states=torch.ones(
(b, model_config.hidden_size), device=self.device
),
verified_id=self.output_ids,
)
spec_info.prepare_for_extend(self)
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
self.spec_info = spec_info
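The mock draft distribution only has to be well formed, not meaningful: each row must be positive and strictly decreasing so the verifier's top-k handling behaves. A worked instance of the construction for b=2 requests and speculative_eagle_topk=2:

import torch

b, topk = 2, 2
topk_p = torch.arange(b * topk, 0, -1, dtype=torch.float32)  # tensor([4., 3., 2., 1.])
topk_p = topk_p.reshape(b, topk) / (b * topk)
# tensor([[1.0000, 0.7500],
#         [0.5000, 0.2500]])  -- strictly decreasing within each row
topk_index = torch.arange(b * topk).reshape(b, topk)
# tensor([[0, 1],
#         [2, 3]])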
@@ -73,11 +73,27 @@ class MiniLoadBalancer:
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
session.post(f"{decode_server}/{endpoint}", json=modified_request),
]
# Wait for both responses to complete. Prefill should end first.
_, decode_response = await asyncio.gather(*tasks)
prefill_response, decode_response = await asyncio.gather(*tasks)
if "return_logprob" in modified_request:
prefill_json = await prefill_response.json()
ret_json = await decode_response.json()
# merge `meta_info.input_token_logprobs` from prefill to decode
if "meta_info" in ret_json:
if "input_token_logprobs" in ret_json["meta_info"]:
ret_json["meta_info"]["input_token_logprobs"] = (
prefill_json["meta_info"]["input_token_logprobs"]
+ ret_json["meta_info"]["input_token_logprobs"]
)
else:
ret_json = await decode_response.json()
return ORJSONResponse(
content=await decode_response.json(),
content=ret_json,
status_code=decode_response.status,
)
@@ -92,30 +108,47 @@ class MiniLoadBalancer:
total=3600
) # Add timeout for request reliability
) as session:
try:
# Create the tasks for both prefill and decode requests
tasks = [
session.post(
f"{prefill_server}/{endpoint}", json=modified_request
),
session.post(
f"{decode_server}/{endpoint}", json=modified_request
),
]
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response, decode_response = await asyncio.gather(*tasks)
# Create the tasks for both prefill and decode requests
tasks = [
session.post(f"{prefill_server}/generate", json=modified_request),
session.post(f"{decode_server}/generate", json=modified_request),
]
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response, decode_response = await asyncio.gather(*tasks)
if modified_request.get("return_logprob", False):
prefill_chunks = []
async for chunk in prefill_response.content:
prefill_chunks.append(chunk)
first_prefill_chunk = (
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
)
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
async for chunk in decode_response.content:
# Note: re-parsing and re-serializing every chunk is inefficient.
# Merge the prefill engine's input_token_logprobs into each decode chunk.
decoded_chunk = chunk.decode("utf-8")
if (
decoded_chunk
and decoded_chunk.startswith("data:")
and "[DONE]" not in decoded_chunk
):
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
ret_json["meta_info"]["input_token_logprobs"] = (
first_prefill_chunk_json["meta_info"][
"input_token_logprobs"
]
+ ret_json["meta_info"]["input_token_logprobs"]
)
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
else:
yield chunk
else:
async for chunk in decode_response.content:
yield chunk
except Exception as e:
error_msg = {
"error": {"message": f"Stream processing error: {str(e)}"}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
finally:
if prefill_response is not None:
await prefill_response.release()
return StreamingResponse(
stream_results(),
......
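Each streamed chunk is a server-sent-events line of the form data: {json}\n\n, so the merge strips the 5-character data: prefix, patches the parsed JSON, and re-serializes. A round-trip sketch on a single chunk, assuming orjson and a hypothetical payload:

import orjson

first_prefill_chunk_json = {"meta_info": {"input_token_logprobs": [[-1.2, 101]]}}
chunk = b'data: {"meta_info": {"input_token_logprobs": [[-0.9, 103]]}}\n\n'

decoded_chunk = chunk.decode("utf-8")
if decoded_chunk.startswith("data:") and "[DONE]" not in decoded_chunk:
    ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
    ret_json["meta_info"]["input_token_logprobs"] = (
        first_prefill_chunk_json["meta_info"]["input_token_logprobs"]
        + ret_json["meta_info"]["input_token_logprobs"]
    )
    merged = b"data: " + orjson.dumps(ret_json) + b"\n\n"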
@@ -32,6 +32,7 @@ from sglang.srt.disaggregation.utils import (
DisaggregationMode,
FakeBootstrapHost,
KVClassType,
MetadataBuffers,
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
@@ -63,8 +64,7 @@ class PrefillBootstrapQueue:
token_to_kv_pool: KVCache,
draft_token_to_kv_pool: Optional[KVCache],
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: List[torch.Tensor],
aux_dtype: torch.dtype,
metadata_buffers: MetadataBuffers,
tp_rank: int,
tp_size: int,
bootstrap_port: int,
@@ -76,7 +76,6 @@ class PrefillBootstrapQueue:
self.draft_token_to_kv_pool = draft_token_to_kv_pool
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
self.aux_dtype = aux_dtype
self.metadata_buffers = metadata_buffers
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -116,15 +115,9 @@ class PrefillBootstrapQueue:
kv_args.kv_item_lens = kv_item_lens
# Define req -> input ids buffer
kv_args.aux_data_ptrs = [
metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
]
kv_args.aux_data_lens = [
metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.aux_item_lens = [
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
self.metadata_buffers.get_buf_infos()
)
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -299,10 +292,9 @@ class SchedulerDisaggregationPrefillMixin:
launch_done: Optional[threading.Event] = None,
) -> None:
"""
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
Adapted from process_batch_result_prefill
"""
(
logits_output,
next_token_ids,
@@ -315,27 +307,78 @@ class SchedulerDisaggregationPrefillMixin:
result.extend_logprob_start_len_per_req,
)
logprob_pt = 0
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
if self.enable_overlap:
# wait
_, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
launch_done
)
else:
next_token_ids = result.next_token_ids.tolist()
for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
if batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
for i, (req, next_token_id) in enumerate(
zip(batch.reqs, next_token_ids, strict=True)
):
req: Req
if req.is_chunked <= 0:
# There is no output_ids for prefill
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
self.send_kv_chunk(req, token_id=next_token_id)
self.disagg_prefill_inflight_queue.append(req)
if req.return_logprob:
assert extend_logprob_start_len_per_req is not None
assert extend_input_len_per_req is not None
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
num_input_logprobs = extend_input_len - extend_logprob_start_len
self.add_logprob_return_values(
i,
req,
logprob_pt,
next_token_ids,
num_input_logprobs,
logits_output,
)
logprob_pt += num_input_logprobs
self.send_kv_chunk(req, last_chunk=True)
if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
else:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
if req.return_logprob:
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
if extend_logprob_start_len < extend_input_len:
# Update input logprobs.
num_input_logprobs = extend_input_len - extend_logprob_start_len
self.add_input_logprob_return_values(
i,
req,
logits_output,
logprob_pt,
num_input_logprobs,
last_prefill_chunk=False,
)
logprob_pt += num_input_logprobs
if self.enable_overlap:
self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
# We need to remove the sync in the following function for overlap schedule.
self.set_next_batch_sampling_info_done(batch)
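logprob_pt walks one flat, batch-wide array of input logprobs; each request consumes extend_input_len - extend_logprob_start_len entries, whether its prefill finished this step or is still chunked. A sketch of the pointer arithmetic with hypothetical per-request lengths:

# Flat array holding all requests' input logprobs back to back (hypothetical values).
input_token_logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5]
extend_input_len_per_req = [4, 2]
extend_logprob_start_len_per_req = [1, 0]  # req 0 skips its first position

logprob_pt = 0
for i in range(len(extend_input_len_per_req)):
    num_input_logprobs = (
        extend_input_len_per_req[i] - extend_logprob_start_len_per_req[i]
    )
    window = input_token_logprobs[logprob_pt : logprob_pt + num_input_logprobs]
    print(f"req {i}: {window}")  # req 0: first 3 entries, req 1: last 2
    logprob_pt += num_input_logprobs  # advance past this request's slice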
def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
"""
@@ -379,7 +422,11 @@ class SchedulerDisaggregationPrefillMixin:
)
# Stream requests which have finished transfer
self.stream_output(done_reqs, False, None)
self.stream_output(
done_reqs,
any(req.return_logprob for req in done_reqs),
None,
)
self.disagg_prefill_inflight_queue = undone_reqs
@@ -405,7 +452,7 @@ class SchedulerDisaggregationPrefillMixin:
def send_kv_chunk(
self: Scheduler,
req: Req,
token_id: Optional[int] = None,
last_chunk: bool = False,
end_idx: Optional[int] = None,
) -> None:
"""
@@ -413,37 +460,28 @@ class SchedulerDisaggregationPrefillMixin:
"""
page_size = self.token_to_kv_pool_allocator.page_size
start_idx = req.start_send_idx
# If end_idx is specified, use it as the end index of the kv chunk: in the
# overlap schedule, the resolved length is not the same as the length of fill_ids.
end_idx = (
end_idx
if end_idx is not None
else min(len(req.fill_ids), len(req.origin_input_ids))
)
last_chunk = token_id is not None
if not last_chunk:
# if not the last chunk and the last page is partial, delay the last partial page to the next send
end_idx = end_idx - end_idx % page_size
# Update next start_send_idx
req.start_send_idx = end_idx
kv_indices = (
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
.cpu()
.numpy()
)
if last_chunk is True:
self.disagg_prefill_bootstrap_queue.store_prefill_results(
req.metadata_buffer_index, token_id
)
req.start_send_idx = end_idx
if last_chunk:
self.disagg_metadata_buffers.set_buf(req)
page_indices = kv_to_page_indices(kv_indices, page_size)
if len(page_indices) == 0:
logger.info(
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
)
return
req.disagg_kv_sender.send(page_indices)
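Truncating a non-final chunk to a page boundary keeps partial KV pages out of the transfer; the held-back tokens are simply re-sent with the next chunk, because start_send_idx only advances to the truncated end. A worked example with a hypothetical page_size of 4:

page_size = 4
start_send_idx = 0
end_idx = 10          # tokens ready so far; prefill not finished yet
last_chunk = False

if not last_chunk:
    # Hold back the trailing partial page (tokens 8..9) for the next send.
    end_idx = end_idx - end_idx % page_size  # 10 -> 8

assert end_idx == 8
start_send_idx = end_idx  # the next chunk resumes at token 8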
@@ -6,7 +6,7 @@ import random
import warnings
from collections import deque
from enum import Enum
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
import numpy as np
import requests
@@ -15,6 +15,9 @@ import torch.distributed as dist
from sglang.srt.utils import get_ip
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
FakeBootstrapHost = "2.2.2.2"
# env var for testing failure, convert to float explicitly
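The code that reads this variable is elided above. One plausible shape for the probabilistic failure injection, written as an assumption rather than the commit's exact implementation:

import os
import random

# Assumed: parse the env var once; 0.0 disables injection.
FAILURE_PROB = float(os.environ.get("DISAGGREGATION_TEST_FAILURE_PROB", 0.0))

def maybe_fail(op_name: str) -> None:
    # Randomly raise to simulate a KV-transfer failure during tests.
    if random.random() < FAILURE_PROB:
        raise RuntimeError(f"injected failure in {op_name}")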
@@ -196,3 +199,83 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
req.input_top_logprobs_idx = []
req.input_token_ids_logprobs_val = []
req.input_token_ids_logprobs_idx = []
class MetadataBuffers:
def __init__(self, size: int, max_top_logprobs_num: int = 128):
# TODO: abort top_logprobs_num > 128 in PD
# We transfer the metadata of the first output token to the decode instance.
# The minimal RDMA transfer size is 64 bytes, so each row is padded to at least 64 bytes.
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device="cpu"
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16), dtype=torch.int32, device="cpu"
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
)
def get_buf_infos(self):
ptrs = [
self.output_ids.data_ptr(),
self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(),
self.output_top_logprobs_idx.data_ptr(),
]
data_lens = [
self.output_ids.nbytes,
self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes,
self.output_top_logprobs_idx.nbytes,
]
item_lens = [
self.output_ids[0].nbytes,
self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes,
self.output_top_logprobs_idx[0].nbytes,
]
return ptrs, data_lens, item_lens
def get_buf(self, idx: int):
return (
self.output_ids[idx],
self.output_token_logprobs_val[idx],
self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx],
self.output_top_logprobs_idx[idx],
)
def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
req.output_token_logprobs_val[0]
)
if req.output_token_logprobs_idx: # not none or empty list
self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
req.output_token_logprobs_idx[0]
)
if req.output_top_logprobs_val: # not none or empty list
self.output_top_logprobs_val[req.metadata_buffer_index][
: len(req.output_top_logprobs_val[0])
] = torch.tensor(
req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
)
if req.output_top_logprobs_idx: # not none or empty list
self.output_top_logprobs_idx[req.metadata_buffer_index][
: len(req.output_top_logprobs_idx[0])
] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
)
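End to end, the prefill scheduler stages the first token's metadata with set_buf(req) once it is sampled, and the decode side reads it back with get_buf(idx) after the transfer completes. A round-trip sketch assuming the MetadataBuffers class above and a hypothetical minimal request object:

import types

buffers = MetadataBuffers(size=4)

req = types.SimpleNamespace(
    metadata_buffer_index=2,
    return_logprob=True,
    output_ids=[42],
    output_token_logprobs_val=[-0.25],
    output_token_logprobs_idx=[42],
    output_top_logprobs_val=[[-0.25, -1.5]],
    output_top_logprobs_idx=[[42, 7]],
)

buffers.set_buf(req)  # prefill side: stage the first token's metadata
out_id, tok_val, tok_idx, top_val, top_idx = buffers.get_buf(2)  # decode side
assert out_id[0].item() == 42
assert top_idx[:2].tolist() == [42, 7]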
@@ -607,9 +607,6 @@ class Req:
self.tmp_end_idx: int = -1
self.metadata_buffer_index: int = -1
# The first output_id transferred from prefill instance.
self.transferred_output_id: Optional[int] = None
@property
def seqlen(self):
return len(self.origin_input_ids) + len(self.output_ids)
......
@@ -48,6 +48,7 @@ from sglang.srt.disaggregation.prefill import (
)
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
MetadataBuffers,
ReqToMetadataIdxAllocator,
TransferBackend,
prepare_abort,
@@ -569,20 +570,13 @@ class Scheduler(
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
aux_dtype = torch.int32
# A list of metadata buffers. The shape is (b, metadata_size) where
# b corresponds to a max running requests. The last shape * dtype.itemsize
# should be larger than 64 bytes to work with RDMA, so we pad it.
output_id_buffer = torch.zeros(
(buffer_size, 16), dtype=aux_dtype, device="cpu"
)
metadata_buffers = [output_id_buffer]
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
# The decode requests polling kv cache
self.disagg_decode_transfer_queue = DecodeTransferQueue(
gloo_group=self.attn_tp_cpu_group,
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
metadata_buffers=metadata_buffers,
metadata_buffers=self.disagg_metadata_buffers,
scheduler=self,
tree_cache=self.tree_cache,
)
@@ -597,8 +591,7 @@ class Scheduler(
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
metadata_buffers=metadata_buffers,
aux_dtype=aux_dtype,
metadata_buffers=self.disagg_metadata_buffers,
scheduler=self,
transfer_queue=self.disagg_decode_transfer_queue,
tree_cache=self.tree_cache,
@@ -618,14 +611,7 @@ class Scheduler(
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
aux_dtype = torch.int32
# A list of metadata buffers. The shape is (b, metadata_size) where
# b corresponds to a max running requests. The last shape * dtype.itemsize
# should be larger than 64 bytes to work with RDMA, so we pad it.
output_id_buffer = torch.zeros(
(buffer_size, 16), dtype=aux_dtype, device="cpu"
)
metadata_buffers = [output_id_buffer]
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
@@ -635,8 +621,7 @@ class Scheduler(
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
metadata_buffers=metadata_buffers,
aux_dtype=aux_dtype,
metadata_buffers=self.disagg_metadata_buffers,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
......
@@ -485,7 +485,6 @@ def popen_launch_pd_server(
api_key: Optional[str] = None,
other_args: list[str] = (),
env: Optional[dict] = None,
return_stdout_stderr: Optional[tuple] = None,
):
_, host, port = base_url.split(":")
host = host[2:]
@@ -515,42 +514,9 @@ def popen_launch_pd_server(
print(f"command={' '.join(command)}")
if return_stdout_stderr:
process = subprocess.Popen(
command,
stdout=return_stdout_stderr[0],
stderr=return_stdout_stderr[1],
env=env,
text=True,
)
else:
process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
start_time = time.perf_counter()
with requests.Session() as session:
while time.perf_counter() - start_time < timeout:
try:
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {api_key}",
}
response = session.get(
f"{base_url}/health",
headers=headers,
)
if response.status_code == 200:
return process
except requests.RequestException:
pass
return_code = process.poll()
if return_code is not None:
raise Exception(f"Server unexpectedly exits ({return_code=}).")
time.sleep(10)
process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
kill_process_tree(process.pid)
raise TimeoutError("Server failed to start within the timeout period.")
return process
def run_with_timeout(
......
prompt = "The capital of taiwan is "
import json
import requests
response = requests.post(
"http://0.0.0.0:8000/generate",
json={
"text": prompt,
"sampling_params": {"temperature": 0},
"return_logprob": True,
"return_input_logprob": True,
"logprob_start_len": 0,
},
)
j = response.json()
input_logprobs = j["meta_info"]["input_token_logprobs"]
output_logprobs = j["meta_info"]["output_token_logprobs"]
print(len(input_logprobs), len(output_logprobs))
import os
import subprocess
import time
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
import requests
@@ -25,15 +27,22 @@ class TestDisaggregationAccuracy(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_host = "127.0.0.1"
cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
cls.lb_url = DEFAULT_URL_FOR_TEST
cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Start the servers without blocking
cls.start_prefill()
cls.start_decode()
# Block until both servers are ready
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
@@ -48,7 +57,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
"--host",
cls.base_host,
"--port",
str(cls.base_port),
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
@@ -63,14 +72,10 @@ class TestDisaggregationAccuracy(CustomTestCase):
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--host",
cls.base_host,
"--port",
str(cls.base_port + 100),
"--tp",
"4",
# "--disaggregation-ib-device",
# "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
"1",
"--disaggregation-ib-device",
"mlx5_roce0",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
@@ -85,16 +90,165 @@ class TestDisaggregationAccuracy(CustomTestCase):
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"1",
"--base-gpu-id",
"1",
"--disaggregation-ib-device",
"mlx5_roce1",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=60):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Evaluation metrics: {metrics}")
self.assertGreater(metrics["accuracy"], 0.62)
def test_logprob(self):
prompt = "The capital of taiwan is "
response = requests.post(
self.lb_url + "/generate",
json={
"text": prompt,
"sampling_params": {"temperature": 0},
"return_logprob": True,
"return_input_logprob": True,
"logprob_start_len": 0,
},
)
j = response.json()
completion_tokens = j["meta_info"]["completion_tokens"]
input_logprobs = j["meta_info"]["input_token_logprobs"]
output_logprobs = j["meta_info"]["output_token_logprobs"]
assert (
len(output_logprobs) == completion_tokens
), f"output_logprobs and completion_tokens should have the same length, but got {len(output_logprobs)} and {completion_tokens}"
assert (
len(input_logprobs) > 0
), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"
class TestDisaggregationMooncakeFailure(CustomTestCase):
@classmethod
def setUpClass(cls):
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Start the servers without blocking
cls.start_prefill()
cls.start_decode()
# Block until both servers are ready
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang.srt.disaggregation.mini_lb",
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
str(cls.base_port + 200),
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = subprocess.Popen(
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp",
"4",
"1",
"--disaggregation-ib-device",
"mlx5_roce0",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"1",
"--base-gpu-id",
"4",
# "--disaggregation-ib-device",
# "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"1",
"--disaggregation-ib-device",
"mlx5_roce1",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
@@ -121,6 +275,8 @@ class TestDisaggregationAccuracy(CustomTestCase):
@classmethod
def tearDownClass(cls):
# unset DISAGGREGATION_TEST_FAILURE_PROB
os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
@@ -128,6 +284,9 @@ class TestDisaggregationAccuracy(CustomTestCase):
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
@@ -135,27 +294,29 @@ class TestDisaggregationAccuracy(CustomTestCase):
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.lb_url.split(":")[-1]),
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Evaluation metrics: {metrics}")
self.assertGreater(metrics["accuracy"], 0.62)
# Expect many injected failures, but the server must not crash
class TestDisaggregationSpecAccuracy(CustomTestCase):
class TestDisaggregationMooncakeSpec(CustomTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
cls.base_host = "127.0.0.1"
cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
cls.lb_url = DEFAULT_URL_FOR_TEST
cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
cls.spec_args = [
"--speculative-algorithm",
"EAGLE",
@@ -170,10 +331,13 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
"--cuda-graph-max-bs",
"8",
]
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
# Start the servers without blocking
cls.start_prefill()
cls.start_decode()
# Block until both servers are ready
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
@@ -188,7 +352,7 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
"--host",
cls.base_host,
"--port",
str(cls.base_port),
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
@@ -215,21 +379,15 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--host",
cls.base_host,
"--port",
str(cls.base_port + 100),
"--tp",
"4",
# "--disaggregation-ib-device",
# "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
"2",
"--disaggregation-ib-device",
"mlx5_roce0,mlx5_roce1",
] + cls.spec_args
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
@@ -243,16 +401,12 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--host",
cls.base_host,
"--port",
str(cls.base_port + 200),
"--tp",
"4",
"2",
"--base-gpu-id",
"4",
# "--disaggregation-ib-device",
# "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"2",
"--disaggregation-ib-device",
"mlx5_roce2,mlx5_roce3",
] + cls.spec_args
cls.process_decode = popen_launch_pd_server(
cls.model,
@@ -261,15 +415,43 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=60):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=4, # TODO: 128 crashes the decode
host="http://127.0.0.1",
port=int(self.lb_url.split(":")[-1]),
parallel=2,
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Evaluation metrics: {metrics}")
......