Unverified Commit 8233cc10 authored by Byron Hsu, committed by GitHub

[PD] Support logprob & Add failure test (#6558)

parent 1b2e8f76
@@ -36,6 +36,7 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
@@ -78,8 +79,7 @@ class DecodePreallocQueue:
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         transfer_queue: DecodeTransferQueue,
         tree_cache: BasePrefixCache,
@@ -94,7 +94,6 @@ class DecodePreallocQueue:
         self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
         self.draft_token_to_kv_pool = draft_token_to_kv_pool
         self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
-        self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.scheduler = scheduler
@@ -133,15 +132,9 @@ class DecodePreallocQueue:
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
-        kv_args.aux_data_ptrs = [
-            output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
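For reference, the consolidated get_buf_infos() call replaces the three per-tensor list comprehensions that previously computed pointers and sizes. A minimal sketch of the same bookkeeping on one toy CPU buffer (variable names are illustrative, not the real API):

    import torch

    # One toy metadata buffer: (num_slots, padded_width), resident on CPU.
    buf = torch.zeros((8, 16), dtype=torch.int32, device="cpu")

    aux_data_ptr = buf.data_ptr()  # base address registered with the KV manager
    aux_data_len = buf.nbytes      # total bytes of the whole buffer
    aux_item_len = buf[0].nbytes   # bytes per slot, i.e. the per-request stride

    print(aux_data_len, aux_item_len)  # 512 64 (8 * 16 * 4 and 16 * 4 bytes)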
@@ -211,7 +204,18 @@ class DecodePreallocQueue:
         indices_to_remove = set()
         allocatable_tokens = self._allocatable_tokens()

+        # First, remove all failed requests from the queue
         for i, decode_req in enumerate(self.queue):
+            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
+                self.scheduler.stream_output(
+                    [decode_req.req], decode_req.req.return_logprob
+                )
+                indices_to_remove.add(i)
+
+        for i, decode_req in enumerate(self.queue):
+            if i in indices_to_remove:
+                continue
             if not decode_req.waiting_for_input:
                 continue
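The pre-filter above uses a mark-then-skip pattern: aborted requests are streamed back to the client first, and the allocation pass then ignores their indices. A standalone sketch of the pattern on toy data (hypothetical names):

    # Mark-then-skip: collect indices of failed items first, then process the rest.
    queue = ["ok-1", "FAILED", "ok-2"]
    indices_to_remove = {i for i, item in enumerate(queue) if item == "FAILED"}

    for i, item in enumerate(queue):
        if i in indices_to_remove:
            continue  # failed entries were already streamed back, not allocated
        print("allocate", item)

    # Finally compact the queue, mirroring how the real queue is rebuilt.
    queue = [item for i, item in enumerate(queue) if i not in indices_to_remove]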
@@ -331,7 +335,7 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: torch.Tensor,
+        metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         tree_cache: BasePrefixCache,
     ):
@@ -342,11 +346,11 @@ class DecodeTransferQueue:
         self.scheduler = scheduler
         self.tree_cache = tree_cache

-    def add(self, req_conn: DecodeRequest) -> None:
-        self.queue.append(req_conn)
+    def add(self, decode_req: DecodeRequest) -> None:
+        self.queue.append(decode_req)

-    def extend(self, req_conns) -> None:
-        self.queue.extend(req_conns)
+    def extend(self, decode_reqs: List[DecodeRequest]) -> None:
+        self.queue.extend(decode_reqs)

     def pop_transferred(self) -> List[DecodeRequest]:
         if not self.queue:
@@ -356,14 +360,6 @@ class DecodeTransferQueue:
             [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
         )

-        # First, remove all failed requests from the queue
-        for i, decode_req in enumerate(self.queue):
-            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
-                self.scheduler.stream_output(
-                    [decode_req.req], decode_req.req.return_logprob
-                )
-                indices_to_remove.add(i)
-
         transferred_reqs = []
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
@@ -387,16 +383,37 @@ class DecodeTransferQueue:
                 indices_to_remove.add(i)
                 continue
             elif poll == KVPoll.Success:
-                # pop and push it to waiting queue
                 idx = decode_req.metadata_buffer_index
-                assert len(decode_req.req.output_ids) == 0
-                output_id_buffer = self.metadata_buffers[0]
-                # the last dimension is padded by the same values.
-                output_id = output_id_buffer[idx][0].item()
-                assert len(decode_req.req.output_ids) == 0
-                assert decode_req.req.transferred_output_id is None
-                decode_req.req.transferred_output_id = output_id
-                transferred_reqs.append(decode_req)
+                (
+                    output_id,
+                    output_token_logprobs_val,
+                    output_token_logprobs_idx,
+                    output_top_logprobs_val,
+                    output_top_logprobs_idx,
+                ) = self.metadata_buffers.get_buf(idx)
+
+                decode_req.req.output_ids.append(output_id[0].item())
+
+                if decode_req.req.return_logprob:
+                    decode_req.req.output_token_logprobs_val.append(
+                        output_token_logprobs_val[0].item()
+                    )
+                    decode_req.req.output_token_logprobs_idx.append(
+                        output_token_logprobs_idx[0].item()
+                    )
+                    decode_req.req.output_top_logprobs_val.append(
+                        output_top_logprobs_val[
+                            : decode_req.req.top_logprobs_num
+                        ].tolist()
+                    )
+                    decode_req.req.output_top_logprobs_idx.append(
+                        output_top_logprobs_idx[
+                            : decode_req.req.top_logprobs_num
+                        ].tolist()
+                    )
+
+                transferred_reqs.append(decode_req.req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
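get_buf(idx) hands back one row per metadata tensor; only element 0 of each 16-wide padded row carries the scalar, while the top-logprob rows are sliced to top_logprobs_num. A sketch with toy sizes and made-up values:

    import torch

    output_ids = torch.zeros((4, 16), dtype=torch.int32)  # padded to >= 64 bytes
    output_top_logprobs_val = torch.zeros((4, 128), dtype=torch.float32)

    idx, top_k = 2, 3
    output_ids[idx][0] = 1234  # the prefill side wrote the first token here
    output_top_logprobs_val[idx][:top_k] = torch.tensor([-0.1, -1.5, -2.0])

    first_token = output_ids[idx][0].item()                   # 1234
    top_vals = output_top_logprobs_val[idx][:top_k].tolist()  # approx. [-0.1, -1.5, -2.0]
    print(first_token, top_vals)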
@@ -451,7 +468,9 @@ class SchedulerDisaggregationDecodeMixin:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
                     # Note: Logprobs should be handled on the prefill engine.
-                    self.stream_output(batch.reqs, False)
+                    self.stream_output(
+                        batch.reqs, any(req.return_logprob for req in batch.reqs)
+                    )
                     if prepare_dp_attn_flag:
                         self._prepare_idle_batch_and_run(None)
                 else:
@@ -497,7 +516,9 @@ class SchedulerDisaggregationDecodeMixin:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
                     # Note: Logprobs should be handled on the prefill engine.
-                    self.stream_output(batch.reqs, False)
+                    self.stream_output(
+                        batch.reqs, any(req.return_logprob for req in batch.reqs)
+                    )
                     if prepare_dp_attn_flag:
                         batch_, result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
@@ -618,15 +639,8 @@ class SchedulerDisaggregationDecodeMixin:
     def process_decode_queue(self: Scheduler):
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()

-        def _num_pre_alloc(req):
-            return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
-
-        self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
-        self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
-        self.waiting_queue.extend([req.req for req in alloc_reqs])
+        self.waiting_queue.extend(alloc_reqs)
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING

 import torch

-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

 logger = logging.getLogger(__name__)
@@ -76,6 +76,11 @@ class ScheduleBatchDisaggregationDecodeMixin:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
+
+        if self.return_logprob:
+            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
+
         self.extend_num_tokens = extend_num_tokens
         self.prefix_lens = [len(r.prefix_indices) for r in reqs]
         self.extend_lens = [r.extend_input_len for r in reqs]
@@ -94,12 +99,41 @@ class ScheduleBatchDisaggregationDecodeMixin:
         """Assign the buffered last input id to schedule batch"""
         self.output_ids = []
         for req in self.reqs:
-            if req.output_ids and len(req.output_ids) > 0:
-                # resumed retracted req
-                self.output_ids.append(req.output_ids[-1])
-            else:
-                assert req.transferred_output_id is not None
-                req.output_ids.append(req.transferred_output_id)
-                self.output_ids.append(req.transferred_output_id)
+            self.output_ids.append(req.output_ids[-1])
             self.tree_cache.cache_unfinished_req(req)
         self.output_ids = torch.tensor(self.output_ids, device=self.device)
+
+        # Simulate the EAGLE run. We add mock data to the hidden states for ease
+        # of implementation for now, meaning the first token will have an
+        # acceptance rate of 0.
+        if not self.spec_algorithm.is_none():
+            b = len(self.reqs)
+            topk_p = torch.arange(
+                b * server_args.speculative_eagle_topk,
+                0,
+                -1,
+                device=self.device,
+                dtype=torch.float32,
+            )
+            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
+            topk_p /= b * server_args.speculative_eagle_topk
+            topk_index = torch.arange(
+                b * server_args.speculative_eagle_topk, device=self.device
+            )
+            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
+
+            # Local import to avoid a circular import.
+            from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+            spec_info = EagleDraftInput(
+                topk_p=topk_p,
+                topk_index=topk_index,
+                hidden_states=torch.ones(
+                    (b, model_config.hidden_size), device=self.device
+                ),
+                verified_id=self.output_ids,
+            )
+            spec_info.prepare_for_extend(self)
+            spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+            self.spec_info = spec_info
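The mock topk_p above is a strictly decreasing ramp normalized by b * topk, so every probability is positive and at most 1. A worked sketch with b=2 and topk=4 (shapes assumed to match what EagleDraftInput expects):

    import torch

    b, topk = 2, 4
    topk_p = torch.arange(b * topk, 0, -1, dtype=torch.float32).reshape(b, topk) / (b * topk)
    # tensor([[1.0000, 0.8750, 0.7500, 0.6250],
    #         [0.5000, 0.3750, 0.2500, 0.1250]])
    topk_index = torch.arange(b * topk).reshape(b, topk)
    print(topk_p, topk_index, sep="\n")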
@@ -73,11 +73,27 @@ class MiniLoadBalancer:
                 session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                 session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
             # Wait for both responses to complete. Prefill should end first.
-            _, decode_response = await asyncio.gather(*tasks)
+            prefill_response, decode_response = await asyncio.gather(*tasks)
+
+            if "return_logprob" in modified_request:
+                prefill_json = await prefill_response.json()
+                ret_json = await decode_response.json()
+
+                # Merge `meta_info.input_token_logprobs` from prefill into decode.
+                if "meta_info" in ret_json:
+                    if "input_token_logprobs" in ret_json["meta_info"]:
+                        ret_json["meta_info"]["input_token_logprobs"] = (
+                            prefill_json["meta_info"]["input_token_logprobs"]
+                            + ret_json["meta_info"]["input_token_logprobs"]
+                        )
+            else:
+                ret_json = await decode_response.json()

             return ORJSONResponse(
-                content=await decode_response.json(),
+                content=ret_json,
                 status_code=decode_response.status,
             )
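The merge simply prepends the prefill engine's input_token_logprobs (which cover the prompt) to the decode engine's. A minimal sketch on plain dicts (the payload layout here is illustrative, not the exact server schema):

    prefill_json = {"meta_info": {"input_token_logprobs": [[-0.3, 11, None], [-1.2, 42, None]]}}
    decode_json = {"meta_info": {"input_token_logprobs": [[-0.7, 99, None]]}}

    # Prefill saw the original prompt, so its entries come first.
    decode_json["meta_info"]["input_token_logprobs"] = (
        prefill_json["meta_info"]["input_token_logprobs"]
        + decode_json["meta_info"]["input_token_logprobs"]
    )
    print(decode_json["meta_info"]["input_token_logprobs"])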
@@ -92,30 +108,47 @@ class MiniLoadBalancer:
                     total=3600
                 )  # Add timeout for request reliability
             ) as session:
-                try:
-                    # Create the tasks for both prefill and decode requests
-                    tasks = [
-                        session.post(
-                            f"{prefill_server}/{endpoint}", json=modified_request
-                        ),
-                        session.post(
-                            f"{decode_server}/{endpoint}", json=modified_request
-                        ),
-                    ]
-                    # Wait for both responses to complete. Since this is streaming, they return immediately.
-                    prefill_response, decode_response = await asyncio.gather(*tasks)
+                # Create the tasks for both prefill and decode requests
+                tasks = [
+                    session.post(f"{prefill_server}/generate", json=modified_request),
+                    session.post(f"{decode_server}/generate", json=modified_request),
+                ]
+                # Wait for both responses to complete. Since this is streaming, they return immediately.
+                prefill_response, decode_response = await asyncio.gather(*tasks)
+
+                if modified_request.get("return_logprob", False):
+                    prefill_chunks = []
+                    async for chunk in prefill_response.content:
+                        prefill_chunks.append(chunk)
+
+                    first_prefill_chunk = (
+                        prefill_chunks[0].decode("utf-8")[5:].strip("\n")
+                    )
+                    first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
+
+                    async for chunk in decode_response.content:
+                        # Note: This is inefficient.
+                        # Merge prefill input_token_logprobs into each decode chunk.
+                        decoded_chunk = chunk.decode("utf-8")
+                        if (
+                            decoded_chunk
+                            and decoded_chunk.startswith("data:")
+                            and "[DONE]" not in decoded_chunk
+                        ):
+                            ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
+                            ret_json["meta_info"]["input_token_logprobs"] = (
+                                first_prefill_chunk_json["meta_info"][
+                                    "input_token_logprobs"
+                                ]
+                                + ret_json["meta_info"]["input_token_logprobs"]
+                            )
+                            yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
+                        else:
+                            yield chunk
+                else:
                     async for chunk in decode_response.content:
                         yield chunk
-                except Exception as e:
-                    error_msg = {
-                        "error": {"message": f"Stream processing error: {str(e)}"}
-                    }
-                    yield b"data: " + orjson.dumps(
-                        error_msg, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-                finally:
-                    if prefill_response is not None:
-                        await prefill_response.release()

             return StreamingResponse(
                 stream_results(),
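Each streamed chunk is a server-sent event of the form b"data: {...}\n\n", so the code peels the "data:" prefix off before JSON-decoding. A sketch of that framing with orjson (hypothetical payload):

    import orjson

    chunk = b'data: {"meta_info": {"input_token_logprobs": [[-0.7, 99, null]]}}\n\n'
    decoded = chunk.decode("utf-8")

    if decoded.startswith("data:") and "[DONE]" not in decoded:
        payload = orjson.loads(decoded[5:].strip("\n"))  # drop the "data:" prefix
        # Prepend a prefill-side entry, mirroring the merge above.
        payload["meta_info"]["input_token_logprobs"].insert(0, [-0.3, 11, None])
        print(b"data: " + orjson.dumps(payload) + b"\n\n")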
...
@@ -32,6 +32,7 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
@@ -63,8 +64,7 @@ class PrefillBootstrapQueue:
         token_to_kv_pool: KVCache,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
@@ -76,7 +76,6 @@ class PrefillBootstrapQueue:
         self.draft_token_to_kv_pool = draft_token_to_kv_pool
         self.is_mla_backend = is_mla_backend(token_to_kv_pool)
-        self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -116,15 +115,9 @@ class PrefillBootstrapQueue:
         kv_args.kv_item_lens = kv_item_lens
         # Define req -> input ids buffer
-        kv_args.aux_data_ptrs = [
-            metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -299,10 +292,9 @@ class SchedulerDisaggregationPrefillMixin:
         launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add them to the disagg_prefill_inflight_queue
         Adapted from process_batch_result_prefill
         """
-
         (
             logits_output,
             next_token_ids,
@@ -315,27 +307,78 @@ class SchedulerDisaggregationPrefillMixin:
             result.extend_logprob_start_len_per_req,
         )

+        logprob_pt = 0
         # Transfer kv for prefill completed requests and add them to the disagg_prefill_inflight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
+            logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
         else:
             next_token_ids = result.next_token_ids.tolist()

+        if batch.return_logprob:
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )

-        for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
+        for i, (req, next_token_id) in enumerate(
+            zip(batch.reqs, next_token_ids, strict=True)
+        ):
             req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
-                self.send_kv_chunk(req, token_id=next_token_id)
                 self.disagg_prefill_inflight_queue.append(req)
+                if req.return_logprob:
+                    assert extend_logprob_start_len_per_req is not None
+                    assert extend_input_len_per_req is not None
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    num_input_logprobs = extend_input_len - extend_logprob_start_len
+                    self.add_logprob_return_values(
+                        i,
+                        req,
+                        logprob_pt,
+                        next_token_ids,
+                        num_input_logprobs,
+                        logits_output,
+                    )
+                    logprob_pt += num_input_logprobs
+                self.send_kv_chunk(req, last_chunk=True)
+
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_id)
+                    req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
+                if req.return_logprob:
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    if extend_logprob_start_len < extend_input_len:
+                        # Update input logprobs.
+                        num_input_logprobs = extend_input_len - extend_logprob_start_len
+                        self.add_input_logprob_return_values(
+                            i,
+                            req,
+                            logits_output,
+                            logprob_pt,
+                            num_input_logprobs,
+                            last_prefill_chunk=False,
+                        )
+                        logprob_pt += num_input_logprobs
+
                 if self.enable_overlap:
-                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
+                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
+
+        # We need to remove the sync in the following function for overlap schedule.
+        self.set_next_batch_sampling_info_done(batch)
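logprob_pt walks a flat, batch-wide buffer of input logprobs; each request consumes extend_input_len - extend_logprob_start_len entries and advances the pointer. A sketch of the pointer arithmetic with toy numbers:

    # Flat buffer holding input logprobs for all requests in the batch, back to back.
    flat_input_logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6]
    extend_input_lens = [4, 3]           # extended tokens per request
    extend_logprob_start_lens = [1, 0]   # per-request offset where logprobs start

    logprob_pt = 0
    for extend_input_len, start_len in zip(extend_input_lens, extend_logprob_start_lens):
        num_input_logprobs = extend_input_len - start_len
        chunk = flat_input_logprobs[logprob_pt : logprob_pt + num_input_logprobs]
        print(chunk)
        logprob_pt += num_input_logprobs  # the next request's slice starts here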
@@ -379,7 +422,11 @@ def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
             )

             # Stream requests which have finished transfer
-            self.stream_output(done_reqs, False, None)
+            self.stream_output(
+                done_reqs,
+                any(req.return_logprob for req in done_reqs),
+                None,
+            )
         self.disagg_prefill_inflight_queue = undone_reqs
@@ -405,7 +452,7 @@ class SchedulerDisaggregationPrefillMixin:
     def send_kv_chunk(
         self: Scheduler,
         req: Req,
-        token_id: Optional[int] = None,
+        last_chunk: bool = False,
         end_idx: Optional[int] = None,
     ) -> None:

@@ -413,37 +460,28 @@ def send_kv_chunk(
         """
         page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
-        # the resolved length is not the same as fill_ids's length
         end_idx = (
             end_idx
             if end_idx is not None
             else min(len(req.fill_ids), len(req.origin_input_ids))
         )
-        last_chunk = token_id is not None

         if not last_chunk:
             # if not the last chunk and the last page is partial, delay the last partial page to the next send
             end_idx = end_idx - end_idx % page_size
-
-        # Update next start_send_idx
-        req.start_send_idx = end_idx

         kv_indices = (
             self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
-        if last_chunk is True:
-            self.disagg_prefill_bootstrap_queue.store_prefill_results(
-                req.metadata_buffer_index, token_id
-            )
+        req.start_send_idx = end_idx
+        if last_chunk:
+            self.disagg_metadata_buffers.set_buf(req)
         page_indices = kv_to_page_indices(kv_indices, page_size)
         if len(page_indices) == 0:
             logger.info(
                 f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
             )
             return
         req.disagg_kv_sender.send(page_indices)
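For every chunk except the last, the send is truncated to a page boundary; the held-back partial page is covered by the next call because start_send_idx resumes from the truncated end. A sketch of the index math:

    page_size = 4
    start_idx, fill_len = 0, 10

    # Not the last chunk: hold back the trailing partial page (10 -> 8).
    end_idx = fill_len - fill_len % page_size
    print(start_idx, end_idx)  # 0 8 -> tokens 8..9 go with the next send

    # The next call resumes from the stored start_send_idx.
    start_idx = end_idx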
@@ -6,7 +6,7 @@ import random
 import warnings
 from collections import deque
 from enum import Enum
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import requests

@@ -15,6 +15,9 @@ import torch.distributed as dist
 from sglang.srt.utils import get_ip

+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
 FakeBootstrapHost = "2.2.2.2"

 # env var for testing failure, convert to float explicitly
@@ -196,3 +199,83 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
     req.input_top_logprobs_idx = []
     req.input_token_ids_logprobs_val = []
     req.input_token_ids_logprobs_idx = []
+
+
+class MetadataBuffers:
+    def __init__(self, size: int, max_top_logprobs_num: int = 128):
+        # TODO: abort top_logprobs_num > 128 in PD
+
+        # We transfer the metadata of the first output token to decode.
+        # The minimum transfer size for RDMA is 64 bytes, so we pad each slot to > 64 bytes.
+        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
+        self.output_token_logprobs_val = torch.zeros(
+            (size, 16), dtype=torch.float32, device="cpu"
+        )
+        self.output_token_logprobs_idx = torch.zeros(
+            (size, 16), dtype=torch.int32, device="cpu"
+        )
+        self.output_top_logprobs_val = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
+        )
+        self.output_top_logprobs_idx = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
+        )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not None or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not None or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+            if req.output_top_logprobs_val:  # not None or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not None or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
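End to end, the prefill scheduler serializes first-token metadata with set_buf and the decode side reads it back with get_buf. A round-trip sketch using the MetadataBuffers class above and a minimal stand-in Req carrying only the fields set_buf actually touches (the stand-in is hypothetical):

    from types import SimpleNamespace

    bufs = MetadataBuffers(size=8)  # assumes the class above is in scope

    req = SimpleNamespace(
        metadata_buffer_index=3,
        output_ids=[1234],
        return_logprob=True,
        output_token_logprobs_val=[-0.25],
        output_token_logprobs_idx=[1234],
        output_top_logprobs_val=[[-0.25, -1.9]],
        output_top_logprobs_idx=[[1234, 777]],
    )

    bufs.set_buf(req)  # prefill side: write first-token metadata into slot 3
    out_id, tl_val, tl_idx, top_val, top_idx = bufs.get_buf(3)  # decode side
    print(out_id[0].item(), tl_val[0].item(), top_idx[:2].tolist())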
@@ -607,9 +607,6 @@ class Req:
         self.tmp_end_idx: int = -1
         self.metadata_buffer_index: int = -1

-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None
-
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
...
@@ -48,6 +48,7 @@ from sglang.srt.disaggregation.prefill import (
 )
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     prepare_abort,
@@ -569,20 +570,13 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
+                metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
             )
@@ -597,8 +591,7 @@ class Scheduler(
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
-                aux_dtype=aux_dtype,
+                metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
                 tree_cache=self.tree_cache,
@@ -618,14 +611,7 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),

@@ -635,8 +621,7 @@ class Scheduler(
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
-                aux_dtype=aux_dtype,
+                metadata_buffers=self.disagg_metadata_buffers,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
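Both modes now share a single scheduler-owned MetadataBuffers, with ReqToMetadataIdxAllocator handing out slot indices per request. A toy stand-in for the allocator's presumed free-list behavior (the real class's API may differ):

    from collections import deque

    class ToyIdxAllocator:
        """Minimal stand-in for ReqToMetadataIdxAllocator: a pool of free slot indices."""

        def __init__(self, size: int):
            self.free = deque(range(size))

        def alloc(self):
            return self.free.popleft() if self.free else None

        def free_slot(self, idx: int):
            self.free.append(idx)

    alloc = ToyIdxAllocator(size=4)
    idx = alloc.alloc()   # reserve a metadata slot when a request enters the queue
    print(idx)            # 0
    alloc.free_slot(idx)  # release it once the transfer completes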
...
@@ -485,7 +485,6 @@ def popen_launch_pd_server(
     api_key: Optional[str] = None,
     other_args: list[str] = (),
     env: Optional[dict] = None,
-    return_stdout_stderr: Optional[tuple] = None,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]

@@ -515,42 +514,9 @@ def popen_launch_pd_server(
     print(f"command={' '.join(command)}")

-    if return_stdout_stderr:
-        process = subprocess.Popen(
-            command,
-            stdout=return_stdout_stderr[0],
-            stderr=return_stdout_stderr[1],
-            env=env,
-            text=True,
-        )
-    else:
-        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
-
-    start_time = time.perf_counter()
-    with requests.Session() as session:
-        while time.perf_counter() - start_time < timeout:
-            try:
-                headers = {
-                    "Content-Type": "application/json; charset=utf-8",
-                    "Authorization": f"Bearer {api_key}",
-                }
-                response = session.get(
-                    f"{base_url}/health",
-                    headers=headers,
-                )
-                if response.status_code == 200:
-                    return process
-            except requests.RequestException:
-                pass
-
-            return_code = process.poll()
-            if return_code is not None:
-                raise Exception(f"Server unexpectedly exits ({return_code=}).")
-
-            time.sleep(10)
-
-    kill_process_tree(process.pid)
-    raise TimeoutError("Server failed to start within the timeout period.")
+    process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
+    return process

 def run_with_timeout(
...
+prompt = "The capital of taiwan is "
+
+import json
+
+import requests
+
+response = requests.post(
+    "http://0.0.0.0:8000/generate",
+    json={
+        "text": prompt,
+        "sampling_params": {"temperature": 0},
+        "return_logprob": True,
+        "return_input_logprob": True,
+        "logprob_start_len": 0,
+    },
+)
+j = response.json()
+input_logprobs = j["meta_info"]["input_token_logprobs"]
+output_logprobs = j["meta_info"]["output_token_logprobs"]
+print(len(input_logprobs), len(output_logprobs))
+import os
 import subprocess
 import time
 import unittest
 from types import SimpleNamespace
+from urllib.parse import urlparse

 import requests
@@ -25,15 +27,22 @@ class TestDisaggregationAccuracy(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_host = "127.0.0.1"
-        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
-        cls.lb_url = DEFAULT_URL_FOR_TEST
-        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
-        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

-        run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
-        run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
+        # Non-blocking server starts
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both are healthy
         cls.wait_server_ready(cls.prefill_url + "/health")
         cls.wait_server_ready(cls.decode_url + "/health")
@@ -48,7 +57,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--host",
             cls.base_host,
             "--port",
-            str(cls.base_port),
+            cls.lb_port,
         ]

         print("Starting load balancer:", " ".join(lb_command))
@@ -63,14 +72,10 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "prefill",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 100),
             "--tp",
-            "4",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
         ]
         cls.process_prefill = popen_launch_pd_server(
             cls.model,
"--trust-remote-code", "--trust-remote-code",
"--disaggregation-mode", "--disaggregation-mode",
"decode", "decode",
"--tp",
"1",
"--base-gpu-id",
"1",
"--disaggregation-ib-device",
"mlx5_roce1",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=60):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Evaluation metrics: {metrics}")
self.assertGreater(metrics["accuracy"], 0.62)
def test_logprob(self):
prompt = "The capital of taiwan is "
response = requests.post(
self.lb_url + "/generate",
json={
"text": prompt,
"sampling_params": {"temperature": 0},
"return_logprob": True,
"return_input_logprob": True,
"logprob_start_len": 0,
},
)
j = response.json()
completion_tokens = j["meta_info"]["completion_tokens"]
input_logprobs = j["meta_info"]["input_token_logprobs"]
output_logprobs = j["meta_info"]["output_token_logprobs"]
assert (
len(output_logprobs) == completion_tokens
), f"output_logprobs and completion_tokens should have the same length, but got {len(output_logprobs)} and {completion_tokens}"
assert (
len(input_logprobs) > 0
), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"
class TestDisaggregationMooncakeFailure(CustomTestCase):
@classmethod
def setUpClass(cls):
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
cls.start_decode()
# Block until both
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang.srt.disaggregation.mini_lb",
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host", "--host",
cls.base_host, cls.base_host,
"--port", "--port",
str(cls.base_port + 200), cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = subprocess.Popen(
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp", "--tp",
"4", "1",
"--disaggregation-ib-device",
"mlx5_roce0",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"1",
"--base-gpu-id", "--base-gpu-id",
"4", "1",
# "--disaggregation-ib-device", "--disaggregation-ib-device",
# "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", "mlx5_roce1",
] ]
cls.process_decode = popen_launch_pd_server( cls.process_decode = popen_launch_pd_server(
cls.model, cls.model,
@@ -121,6 +275,8 @@ class TestDisaggregationAccuracy(CustomTestCase):
     @classmethod
     def tearDownClass(cls):
+        # Unset DISAGGREGATION_TEST_FAILURE_PROB
+        os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
         for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
             if process:
                 try:

@@ -128,6 +284,9 @@ class TestDisaggregationAccuracy(CustomTestCase):
                 except Exception as e:
                     print(f"Error killing process {process.pid}: {e}")
+        # wait for 5 seconds
+        time.sleep(5)

     def test_gsm8k(self):
         args = SimpleNamespace(
             num_shots=5,

@@ -135,27 +294,29 @@ class TestDisaggregationAccuracy(CustomTestCase):
             num_questions=200,
             max_new_tokens=512,
             parallel=128,
-            host="http://127.0.0.1",
-            port=int(self.lb_url.split(":")[-1]),
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
         )
         metrics = run_eval_few_shot_gsm8k(args)
         print(f"Evaluation metrics: {metrics}")
+        # Expect lots of failures, but the server must not crash.
+        self.assertGreater(metrics["accuracy"], 0.62)
-class TestDisaggregationSpecAccuracy(CustomTestCase):
+class TestDisaggregationMooncakeSpec(CustomTestCase):
     @classmethod
     def setUpClass(cls):
+        super().setUpClass()
         cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
         cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
-        cls.base_host = "127.0.0.1"
-        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
-        cls.lb_url = DEFAULT_URL_FOR_TEST
-        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
-        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"

         cls.spec_args = [
             "--speculative-algorithm",
             "EAGLE",
@@ -170,10 +331,13 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             "--cuda-graph-max-bs",
             "8",
         ]
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

-        run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
-        run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
+        # Non-blocking server starts
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both are healthy
         cls.wait_server_ready(cls.prefill_url + "/health")
         cls.wait_server_ready(cls.decode_url + "/health")
@@ -188,7 +352,7 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             "--host",
             cls.base_host,
             "--port",
-            str(cls.base_port),
+            cls.lb_port,
         ]

         print("Starting load balancer:", " ".join(lb_command))
@@ -215,21 +379,15 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
     @classmethod
     def start_prefill(cls):
         prefill_args = [
             "--trust-remote-code",
             "--disaggregation-mode",
             "prefill",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 100),
             "--tp",
-            "4",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
+            "2",
+            "--disaggregation-ib-device",
+            "mlx5_roce0,mlx5_roce1",
         ] + cls.spec_args
         cls.process_prefill = popen_launch_pd_server(
             cls.model,
             cls.prefill_url,
@@ -243,16 +401,12 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "decode",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 200),
             "--tp",
-            "4",
+            "2",
             "--base-gpu-id",
-            "4",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
+            "2",
+            "--disaggregation-ib-device",
+            "mlx5_roce2,mlx5_roce3",
         ] + cls.spec_args
         cls.process_decode = popen_launch_pd_server(
             cls.model,
@@ -261,15 +415,43 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             other_args=decode_args,
         )

+    @classmethod
+    def wait_server_ready(cls, url, timeout=60):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+        # wait for 5 seconds
+        time.sleep(5)
+
     def test_gsm8k(self):
         args = SimpleNamespace(
             num_shots=5,
             data_path=None,
             num_questions=200,
             max_new_tokens=512,
-            parallel=4,  # TODO: 128 crashes the decode
-            host="http://127.0.0.1",
-            port=int(self.lb_url.split(":")[-1]),
+            parallel=2,
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
         )
         metrics = run_eval_few_shot_gsm8k(args)
         print(f"Evaluation metrics: {metrics}")
...