Unverified Commit 384f8ab5 authored by Shangming Cai, committed by GitHub

[PD] Support PD disaggregation with Prefill PP (#8846)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: root <huzhiyuan@xiaohongshu.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com>
Co-authored-by: zitto <zhjc1124@gmail.com>
parent 6a9d6ca3
...@@ -30,6 +30,7 @@ class KVArgs: ...@@ -30,6 +30,7 @@ class KVArgs:
# for pp prefill # for pp prefill
prefill_pp_size: int prefill_pp_size: int
pp_rank: int pp_rank: int
prefill_start_layer: int
# for system dp # for system dp
system_dp_rank: int system_dp_rank: int
......
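The new prefill_start_layer field tells the KV transfer engine where this prefill PP stage's layer slice begins within the decode side's full-model layer range. A minimal sketch of how the slice bounds can be derived from these fields, assuming the MLA layout in which one KV buffer is registered per local layer (the helper name is hypothetical, not part of this patch):

def local_layer_range(kv_args: KVArgs) -> tuple[int, int]:
    # The prefill stage only holds layers [start, end) of the full model, while
    # the decode side still registers buffers for every layer, so destination
    # indices must be offset by prefill_start_layer.
    start = kv_args.prefill_start_layer
    end = start + len(kv_args.kv_data_ptrs)  # one registered buffer per local layer (MLA)
    return start, end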
...@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.common.utils import ( ...@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.common.utils import (
) )
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
get_attention_dp_rank, get_attention_dp_rank,
get_attention_dp_size, get_attention_dp_size,
...@@ -180,6 +181,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -180,6 +181,7 @@ class MooncakeKVManager(BaseKVManager):
self.session_failures = defaultdict(int) self.session_failures = defaultdict(int)
self.failed_sessions = set() self.failed_sessions = set()
self.session_lock = threading.Lock() self.session_lock = threading.Lock()
self.pp_group = get_pp_group()
# Determine the number of threads to use for kv sender # Determine the number of threads to use for kv sender
cpu_count = os.cpu_count() cpu_count = os.cpu_count()
transfer_thread_pool_size = get_int_env_var( transfer_thread_pool_size = get_int_env_var(
...@@ -313,11 +315,11 @@ class MooncakeKVManager(BaseKVManager): ...@@ -313,11 +315,11 @@ class MooncakeKVManager(BaseKVManager):
layers_params = None layers_params = None
# pp is not supported on the decode side yet # pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
if self.is_mla_backend: if self.is_mla_backend:
src_kv_ptrs = self.kv_args.kv_data_ptrs src_kv_ptrs = self.kv_args.kv_data_ptrs
layers_per_pp_stage = len(src_kv_ptrs) layers_per_pp_stage = len(src_kv_ptrs)
start_layer = self.pp_rank * layers_per_pp_stage
end_layer = start_layer + layers_per_pp_stage
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer] dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
kv_item_len = self.kv_args.kv_item_lens[0] kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [ layers_params = [
...@@ -330,17 +332,15 @@ class MooncakeKVManager(BaseKVManager): ...@@ -330,17 +332,15 @@ class MooncakeKVManager(BaseKVManager):
] ]
else: else:
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
dst_num_total_layers = num_kv_layers * self.pp_size
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers] src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:] src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs) layers_per_pp_stage = len(src_k_ptrs)
start_layer = self.pp_rank * layers_per_pp_stage
end_layer = start_layer + layers_per_pp_stage
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer] dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[ dst_v_ptrs = dst_kv_ptrs[
num_kv_layers + start_layer : num_kv_layers + end_layer dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
] ]
kv_item_len = self.kv_args.kv_item_lens[0] kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [ layers_params = [
( (
src_k_ptrs[layer_id], src_k_ptrs[layer_id],
...@@ -452,6 +452,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -452,6 +452,7 @@ class MooncakeKVManager(BaseKVManager):
# pp is not supported on the decode side yet # pp is not supported on the decode side yet
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
dst_num_total_layers = num_kv_layers * self.pp_size
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers] src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:] src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs) layers_per_pp_stage = len(src_k_ptrs)
...@@ -459,7 +460,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -459,7 +460,7 @@ class MooncakeKVManager(BaseKVManager):
end_layer = start_layer + layers_per_pp_stage end_layer = start_layer + layers_per_pp_stage
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer] dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[ dst_v_ptrs = dst_kv_ptrs[
num_kv_layers + start_layer : num_kv_layers + end_layer dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
] ]
# Calculate precise byte offset and length for the sub-slice within the token # Calculate precise byte offset and length for the sub-slice within the token
...@@ -612,7 +613,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -612,7 +613,7 @@ class MooncakeKVManager(BaseKVManager):
) )
polls = [] polls = []
dst_ranks_infos = [] dst_ranks_infos = []
local_rank = self.kv_args.engine_rank local_rank = self.attn_tp_rank * self.pp_size + self.pp_rank
for req in reqs_to_be_processed: for req in reqs_to_be_processed:
if not req.is_dummy: if not req.is_dummy:
# Early exit if the request has failed # Early exit if the request has failed
...@@ -695,13 +696,14 @@ class MooncakeKVManager(BaseKVManager): ...@@ -695,13 +696,14 @@ class MooncakeKVManager(BaseKVManager):
break break
if kv_chunk.is_last: if kv_chunk.is_last:
# Only the last chunk we need to send the aux data if self.pp_group.is_last_rank:
ret = self.send_aux( # Only the last chunk we need to send the aux data
req.mooncake_session_id, ret = self.send_aux(
kv_chunk.prefill_aux_index, req.mooncake_session_id,
target_rank_registration_info.dst_aux_ptrs, kv_chunk.prefill_aux_index,
req.dst_aux_index, target_rank_registration_info.dst_aux_ptrs,
) req.dst_aux_index,
)
polls.append(True if ret == 0 else False) polls.append(True if ret == 0 else False)
dst_ranks_infos.append( dst_ranks_infos.append(
(req.endpoint, req.dst_port, req.room) (req.endpoint, req.dst_port, req.room)
...@@ -798,10 +800,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -798,10 +800,7 @@ class MooncakeKVManager(BaseKVManager):
arrived_response_num = len( arrived_response_num = len(
self.prefill_response_tracker[bootstrap_room] self.prefill_response_tracker[bootstrap_room]
) )
if ( if arrived_response_num == expected_response_num:
self.is_mla_backend
or arrived_response_num == expected_response_num
):
self.update_status(bootstrap_room, KVPoll.Success) self.update_status(bootstrap_room, KVPoll.Success)
elif status == KVPoll.Failed: elif status == KVPoll.Failed:
self.record_failure( self.record_failure(
...@@ -1183,7 +1182,9 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1183,7 +1182,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
) )
self.required_dst_info_num = 1 self.required_dst_info_num = 1
self.required_prefill_response_num = 1 self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank] self.target_tp_ranks = [self.target_tp_rank]
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size: elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
if not self.kv_mgr.is_mla_backend: if not self.kv_mgr.is_mla_backend:
...@@ -1196,7 +1197,9 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1196,7 +1197,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.required_dst_info_num = ( self.required_dst_info_num = (
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
) )
self.required_prefill_response_num = 1 self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank] self.target_tp_ranks = [self.target_tp_rank]
else: else:
if not self.kv_mgr.is_mla_backend: if not self.kv_mgr.is_mla_backend:
...@@ -1219,9 +1222,14 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1219,9 +1222,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
# or the KVPoll will never be set correctly # or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0] self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1 self.required_dst_info_num = 1
self.required_prefill_response_num = ( if self.kv_mgr.is_mla_backend:
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size self.required_prefill_response_num = (
) self.prefill_pp_size // self.kv_mgr.pp_size
)
else:
self.required_prefill_response_num = (
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
if self.data_parallel_rank is not None: if self.data_parallel_rank is not None:
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}") logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
...@@ -1530,7 +1538,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -1530,7 +1538,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
"rank_port": rank_port, "rank_port": rank_port,
} }
logger.debug( logger.debug(
f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
) )
return web.Response(text="OK", status=200) return web.Response(text="OK", status=200)
......
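With prefill PP, the decode-side receiver now has to collect one KV response from every prefill PP stage (each stage transfers only its own layer slice), on top of the existing TP fan-in. A hedged sketch of the expected-response arithmetic used above; decode_pp_size is assumed to be 1 because PP is not yet supported on the decode side, and the helper name is illustrative only:

def expected_prefill_responses(
    prefill_attn_tp_size: int,
    decode_attn_tp_size: int,
    prefill_pp_size: int,
    decode_pp_size: int = 1,  # PP is not supported on the decode side yet
    is_mla: bool = False,
) -> int:
    # Each prefill PP stage sends its own layer slice, so the count scales with
    # the PP ratio. For MLA, KV is shared across attention TP ranks, so one TP
    # response per stage suffices; otherwise every covering prefill TP rank
    # must respond.
    pp_factor = prefill_pp_size // decode_pp_size
    if is_mla or prefill_attn_tp_size <= decode_attn_tp_size:
        tp_factor = 1
    else:
        tp_factor = prefill_attn_tp_size // decode_attn_tp_size
    return tp_factor * pp_factor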
...@@ -43,8 +43,13 @@ from sglang.srt.disaggregation.utils import ( ...@@ -43,8 +43,13 @@ from sglang.srt.disaggregation.utils import (
prepare_abort, prepare_abort,
) )
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import require_mlp_sync from sglang.srt.utils import (
DynamicGradMode,
broadcast_pyobj,
point_to_point_pyobj,
require_mlp_sync,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -107,6 +112,7 @@ class PrefillBootstrapQueue: ...@@ -107,6 +112,7 @@ class PrefillBootstrapQueue:
kv_args.system_dp_rank = self.scheduler.dp_rank kv_args.system_dp_rank = self.scheduler.dp_rank
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
kv_args.prefill_pp_size = self.pp_size kv_args.prefill_pp_size = self.pp_size
kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
kv_data_ptrs, kv_data_lens, kv_item_lens = ( kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos() self.token_to_kv_pool.get_contiguous_buf_infos()
) )
...@@ -208,8 +214,8 @@ class PrefillBootstrapQueue: ...@@ -208,8 +214,8 @@ class PrefillBootstrapQueue:
polls = poll_and_all_reduce( polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.queue], self.gloo_group [req.disagg_kv_sender for req in self.queue], self.gloo_group
) )
for i, (req, poll) in enumerate(zip(self.queue, polls)):
for i, (req, poll) in enumerate(zip(self.queue, polls)):
if rids_to_check is not None: if rids_to_check is not None:
# if req not in reqs_info_to_check, skip # if req not in reqs_info_to_check, skip
if req.rid not in rids_to_check: if req.rid not in rids_to_check:
...@@ -395,7 +401,10 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -395,7 +401,10 @@ class SchedulerDisaggregationPrefillMixin:
req.output_ids.append(next_token_id) req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock self.tree_cache.cache_unfinished_req(req) # update the tree and lock
self.disagg_prefill_inflight_queue.append(req) self.disagg_prefill_inflight_queue.append(req)
if logits_output.hidden_states is not None: if (
logits_output is not None
and logits_output.hidden_states is not None
):
last_hidden_index = ( last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1 hidden_state_offset + extend_input_len_per_req[i] - 1
) )
...@@ -603,3 +612,250 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -603,3 +612,250 @@ class SchedulerDisaggregationPrefillMixin:
) )
return return
req.disagg_kv_sender.send(page_indices) req.disagg_kv_sender.send(page_indices)
# PP
@DynamicGradMode()
def event_loop_pp_disagg_prefill(self: Scheduler):
"""
An event loop for the prefill server in pipeline parallelism.
Rules:
1. Each stage runs in the same order and is notified by the previous stage.
2. Each send/recv operation is blocking and matched by the neighboring stage.
Regular Schedule:
====================================================================
Stage i | Stage i+1
send ith req | recv ith req
send ith proxy | recv ith proxy
send prev (i+1)th carry | recv prev (i+1)th carry
====================================================================
Prefill Server Schedule:
====================================================================
Stage i | Stage i+1
send ith req | recv ith req
send ith bootstrap req | recv ith bootstrap req
send ith transferred req | recv ith transferred req
send ith proxy | recv ith proxy
send prev (i+1)th carry | recv prev (i+1)th carry
send prev (i+1)th release req | recv prev (i+1)th release req
====================================================================
There are two additional elements compared to the regular schedule:
1. Bootstrap Requests:
a. Instead of polling the status on the current workers, each stage waits for a notification from the previous stage to avoid desynchronization.
b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
c. If the first stage polls successfully, the other ranks have necessarily succeeded as well, because they performed the handshake together.
2. Transferred Requests + Release Requests:
a. The first stage polls its finished-transfer requests; each subsequent stage intersects the incoming set with its own finished requests and propagates the result down to the last stage.
b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
"""
from sglang.srt.managers.scheduler import GenerationBatchResult
mbs = [None] * self.pp_size
last_mbs = [None] * self.pp_size
self.running_mbs = [
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
]
bids = [None] * self.pp_size
pp_outputs: Optional[PPProxyTensors] = None
# Either succeeded or failed
bootstrapped_rids: List[str] = []
transferred_rids: List[str] = []
release_rids: Optional[List[str]] = None
# transferred microbatch
tmbs = [None] * self.pp_size
ENABLE_RELEASE = True # For debug
while True:
server_is_idle = True
for mb_id in range(self.pp_size):
self.running_batch = self.running_mbs[mb_id]
self.last_batch = last_mbs[mb_id]
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
if self.pp_group.is_first_rank:
# First rank, pop the bootstrap reqs from the bootstrap queue
bootstrapped_reqs, failed_reqs = (
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
return_failed_reqs=True
)
)
bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
req.rid for req in failed_reqs
]
self.waiting_queue.extend(bootstrapped_reqs)
else:
# Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
bootstrapped_rids = self.recv_pyobj_from_prev_stage()
bootstrapped_reqs = (
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
rids_to_check=bootstrapped_rids
)
)
self.waiting_queue.extend(bootstrapped_reqs)
if self.pp_group.is_first_rank:
transferred_rids = self.get_transferred_rids()
# if other ranks,
else:
# 1. recv previous stage's transferred reqs info
prev_transferred_rids = self.recv_pyobj_from_prev_stage()
# 2. get the current stage's transferred reqs info
curr_transferred_rids = self.get_transferred_rids()
# 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
transferred_rids = list(
set(prev_transferred_rids) & set(curr_transferred_rids)
)
tmbs[mb_id] = transferred_rids
self.process_prefill_chunk()
mbs[mb_id] = self.get_new_batch_prefill()
self.running_mbs[mb_id] = self.running_batch
self.cur_batch = mbs[mb_id]
if self.cur_batch:
server_is_idle = False
result = self.run_batch(self.cur_batch)
# send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
result.next_token_ids,
result.bid,
)
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
}
)
# send the output from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
if ENABLE_RELEASE:
if self.pp_group.is_last_rank:
# At the last stage, all stages have reached consensus to release memory for transferred_rids
release_rids = transferred_rids
# send to the first rank
self.send_pyobj_to_next_stage(release_rids)
# receive outputs and post-process (filter finished reqs) the coming microbatch
next_mb_id = (mb_id + 1) % self.pp_size
next_pp_outputs = None
next_release_rids = None
if mbs[next_mb_id] is not None:
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
self.pp_group.recv_tensor_dict(
all_gather_group=self.attn_tp_group
)
)
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
output_result = GenerationBatchResult(
logits_output=None,
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=None,
extend_logprob_start_len_per_req=None,
bid=bids[next_mb_id],
can_run_cuda_graph=result.can_run_cuda_graph,
)
self.process_batch_result_disagg_prefill(
mbs[next_mb_id], output_result
)
last_mbs[next_mb_id] = mbs[next_mb_id]
if ENABLE_RELEASE:
if tmbs[next_mb_id] is not None:
# recv consensus rids from the previous rank
next_release_rids = self.recv_pyobj_from_prev_stage()
self.process_disagg_prefill_inflight_queue(next_release_rids)
# carry the outputs to the next stage
if not self.pp_group.is_last_rank:
if self.cur_batch:
bids[mb_id] = result.bid
if pp_outputs:
# send the outputs from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
if ENABLE_RELEASE:
if release_rids is not None:
self.send_pyobj_to_next_stage(release_rids)
if not self.pp_group.is_last_rank:
# send out reqs to the next stage
self.send_pyobj_to_next_stage(recv_reqs)
self.send_pyobj_to_next_stage(bootstrapped_rids)
self.send_pyobj_to_next_stage(transferred_rids)
# send out proxy tensors to the next stage
if self.cur_batch:
self.pp_group.send_tensor_dict(
result.pp_hidden_states_proxy_tensors,
all_gather_group=self.attn_tp_group,
)
pp_outputs = next_pp_outputs
release_rids = next_release_rids
self.running_batch.batch_is_full = False
if not ENABLE_RELEASE:
if len(self.disagg_prefill_inflight_queue) > 0:
self.process_disagg_prefill_inflight_queue()
# When the server is idle, self-check and re-init some states
if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
def send_pyobj_to_next_stage(self, data):
if self.attn_tp_rank == 0:
dp_offset = self.attn_dp_rank * self.attn_tp_size
point_to_point_pyobj(
data,
self.pp_rank * self.tp_size + dp_offset,
self.world_group.device_group,
self.pp_rank * self.tp_size + dp_offset,
((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset,
)
def recv_pyobj_from_prev_stage(self):
if self.attn_tp_rank == 0:
dp_offset = self.attn_dp_rank * self.attn_tp_size
data = point_to_point_pyobj(
[],
self.pp_rank * self.tp_size + dp_offset,
self.world_group.device_group,
((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset,
self.pp_rank * self.tp_size + dp_offset,
)
else:
data = None
if self.tp_size != 1:
data = broadcast_pyobj(
data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
)
return data
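The two helpers above exchange python objects only between the attn_tp_rank == 0 workers of neighboring PP stages, with global ranks laid out PP-major. A small sketch of that rank arithmetic (hypothetical helper, mirroring the offsets passed to point_to_point_pyobj above):

def pp_neighbor_ranks(pp_rank, pp_size, tp_size, attn_dp_rank, attn_tp_size):
    # Global ranks are laid out stage by stage: all tp_size ranks of stage 0,
    # then stage 1, and so on. Within a stage, each attention-DP group starts
    # at attn_dp_rank * attn_tp_size.
    dp_offset = attn_dp_rank * attn_tp_size
    me = pp_rank * tp_size + dp_offset
    prev_stage = ((pp_rank - 1) % pp_size) * tp_size + dp_offset
    next_stage = ((pp_rank + 1) % pp_size) * tp_size + dp_offset
    return me, prev_stage, next_stage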
...@@ -2579,7 +2579,10 @@ def run_scheduler_process( ...@@ -2579,7 +2579,10 @@ def run_scheduler_process(
if scheduler.enable_overlap: if scheduler.enable_overlap:
scheduler.event_loop_overlap_disagg_prefill() scheduler.event_loop_overlap_disagg_prefill()
else: else:
scheduler.event_loop_normal_disagg_prefill() if server_args.pp_size > 1:
scheduler.event_loop_pp_disagg_prefill()
else:
scheduler.event_loop_normal_disagg_prefill()
elif disaggregation_mode == DisaggregationMode.DECODE: elif disaggregation_mode == DisaggregationMode.DECODE:
if scheduler.enable_overlap: if scheduler.enable_overlap:
......
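The release path in the PP prefill event loop hinges on every stage agreeing that a request's KV transfer has finished before the first stage frees its memory. A condensed sketch of that consensus step from the loop above, written as a hypothetical standalone method (get_transferred_rids and recv_pyobj_from_prev_stage are assumed to behave as shown earlier):

def consensus_transferred_rids(self) -> list:
    # Stage 0 seeds the set with its own finished transfers; every later stage
    # intersects the incoming set with its local finished set, so whatever
    # reaches the last stage has finished on all stages.
    if self.pp_group.is_first_rank:
        return self.get_transferred_rids()
    prev_rids = self.recv_pyobj_from_prev_stage()
    curr_rids = self.get_transferred_rids()
    return list(set(prev_rids) & set(curr_rids))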
from __future__ import annotations
import logging import logging
import multiprocessing as mp import multiprocessing as mp
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import GenerationBatchResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -41,6 +48,57 @@ def validate_input_length( ...@@ -41,6 +48,57 @@ def validate_input_length(
return None return None
def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict:
logits_output = result.logits_output
assert logits_output is not None
return {
"extend_input_len_per_req": result.extend_input_len_per_req,
"extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
"next_token_logprobs": result.logits_output.next_token_logprobs,
"next_token_top_logprobs_val": result.logits_output.next_token_top_logprobs_val,
"next_token_top_logprobs_idx": result.logits_output.next_token_top_logprobs_idx,
"next_token_token_ids_logprobs_val": result.logits_output.next_token_token_ids_logprobs_val,
"next_token_token_ids_logprobs_idx": result.logits_output.next_token_token_ids_logprobs_idx,
"input_token_logprobs": result.logits_output.input_token_logprobs,
"input_top_logprobs_val": result.logits_output.input_top_logprobs_val,
"input_top_logprobs_idx": result.logits_output.input_top_logprobs_idx,
"input_token_ids_logprobs_val": result.logits_output.input_token_ids_logprobs_val,
"input_token_ids_logprobs_idx": result.logits_output.input_token_ids_logprobs_idx,
}
def get_logprob_from_pp_outputs(
next_pp_outputs: PPProxyTensors,
) -> tuple[LogitsProcessorOutput, list[int], list[int]]:
logits_output = LogitsProcessorOutput(
# Do not send logits and hidden states because they are large
next_token_logits=None,
hidden_states=None,
next_token_logprobs=next_pp_outputs["next_token_logprobs"],
next_token_top_logprobs_val=next_pp_outputs["next_token_top_logprobs_val"],
next_token_top_logprobs_idx=next_pp_outputs["next_token_top_logprobs_idx"],
next_token_token_ids_logprobs_val=next_pp_outputs[
"next_token_token_ids_logprobs_val"
],
next_token_token_ids_logprobs_idx=next_pp_outputs[
"next_token_token_ids_logprobs_idx"
],
input_token_logprobs=next_pp_outputs["input_token_logprobs"],
input_top_logprobs_val=next_pp_outputs["input_top_logprobs_val"],
input_top_logprobs_idx=next_pp_outputs["input_top_logprobs_idx"],
input_token_ids_logprobs_val=next_pp_outputs["input_token_ids_logprobs_val"],
input_token_ids_logprobs_idx=next_pp_outputs["input_token_ids_logprobs_idx"],
)
extend_input_len_per_req = next_pp_outputs["extend_input_len_per_req"]
extend_logprob_start_len_per_req = next_pp_outputs[
"extend_logprob_start_len_per_req"
]
return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
class DPBalanceMeta: class DPBalanceMeta:
""" """
This class will be use in scheduler and dp controller This class will be use in scheduler and dp controller
......
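The two helpers above are symmetric: the last PP stage flattens its LogitsProcessorOutput into plain entries that can travel inside PPProxyTensors, and the stage that post-processes the microbatch rebuilds a LogitsProcessorOutput (minus the large logits and hidden states) via get_logprob_from_pp_outputs. A hedged sketch of the packing side; build_pp_logprob_proxy is a hypothetical name, not part of this patch:

def build_pp_logprob_proxy(result: GenerationBatchResult) -> dict:
    # Flatten the logprob fields next to next_token_ids so everything can be
    # shipped in a single PPProxyTensors payload; the receiving rank reverses
    # this with get_logprob_from_pp_outputs().
    proxy = {"next_token_ids": result.next_token_ids}
    proxy.update(get_logprob_dict_from_result(result))
    return proxy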
...@@ -849,7 +849,7 @@ class MLATokenToKVPool(KVCache): ...@@ -849,7 +849,7 @@ class MLATokenToKVPool(KVCache):
cache_k_rope = cache_k_rope.view(self.store_dtype) cache_k_rope = cache_k_rope.view(self.store_dtype)
set_mla_kv_buffer_triton( set_mla_kv_buffer_triton(
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
) )
def get_cpu_copy(self, indices): def get_cpu_copy(self, indices):
......
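This one-line change matters under PP: layer_id arrives as a global layer index, while self.kv_buffer only holds this stage's layers, so the buffer must be indexed relative to start_layer. A trivial illustration (hypothetical helper):

def local_kv_buffer_index(layer_id: int, start_layer: int) -> int:
    # e.g. global layer 45 on a stage that owns layers 30-59 maps to local index 15
    return layer_id - start_layer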
...@@ -307,8 +307,13 @@ class ModelRunner: ...@@ -307,8 +307,13 @@ class ModelRunner:
self.start_layer = getattr(self.model, "start_layer", 0) self.start_layer = getattr(self.model, "start_layer", 0)
self.end_layer = getattr(self.model, "end_layer", model_num_layers) self.end_layer = getattr(self.model, "end_layer", model_num_layers)
self.num_effective_layers = self.end_layer - self.start_layer self.num_effective_layers = self.end_layer - self.start_layer
assert (not model_has_mtp_layers) or ( assert (
self.num_effective_layers == model_num_layers (not model_has_mtp_layers)
or (self.spec_algorithm.is_none())
or (
(not self.spec_algorithm.is_none())
and (self.num_effective_layers == model_num_layers)
)
), "PP is not compatible with MTP models." ), "PP is not compatible with MTP models."
# Apply torchao quantization # Apply torchao quantization
...@@ -1048,8 +1053,6 @@ class ModelRunner: ...@@ -1048,8 +1053,6 @@ class ModelRunner:
else: else:
num_layers = self.num_effective_layers num_layers = self.num_effective_layers
if self.use_mla_backend: if self.use_mla_backend:
# FIXME: pipeline parallelism is not compatible with mla backend
assert self.pp_size == 1
cell_size = ( cell_size = (
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
* num_layers * num_layers
......
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
...@@ -135,6 +135,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): ...@@ -135,6 +135,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
self.config = config self.config = config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config self.quant_config = quant_config
# if not set, model load will be broken in DeepseekV3ForCausalLM load_weights()
self.pp_group = get_pp_group()
self.determine_num_fused_shared_experts("DeepseekV3ForCausalLMNextN") self.determine_num_fused_shared_experts("DeepseekV3ForCausalLMNextN")
self.model = DeepseekModelNextN( self.model = DeepseekModelNextN(
......
...@@ -20,7 +20,7 @@ import concurrent.futures ...@@ -20,7 +20,7 @@ import concurrent.futures
import logging import logging
import os import os
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -30,6 +30,7 @@ from transformers import PretrainedConfig ...@@ -30,6 +30,7 @@ from transformers import PretrainedConfig
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_moe_expert_parallel_world_size, get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
parallel_state, parallel_state,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
...@@ -87,13 +88,13 @@ from sglang.srt.layers.quantization.int8_utils import ( ...@@ -87,13 +88,13 @@ from sglang.srt.layers.quantization.int8_utils import (
) )
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.two_batch_overlap import ( from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher, MaybeTboDeepEPDispatcher,
...@@ -114,6 +115,7 @@ from sglang.srt.utils import ( ...@@ -114,6 +115,7 @@ from sglang.srt.utils import (
is_hip, is_hip,
is_non_idle_and_non_empty, is_non_idle_and_non_empty,
log_info_on_rank0, log_info_on_rank0,
make_layers,
use_intel_amx_backend, use_intel_amx_backend,
) )
...@@ -2029,26 +2031,35 @@ class DeepseekV2Model(nn.Module): ...@@ -2029,26 +2031,35 @@ class DeepseekV2Model(nn.Module):
self.padding_id = config.pad_token_id self.padding_id = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.first_k_dense_replace = config.first_k_dense_replace self.first_k_dense_replace = config.first_k_dense_replace
self.pp_group = get_pp_group()
if self.pp_group.is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not is_dp_attention_enabled(),
)
else:
self.embed_tokens = PPMissingLayer()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not is_dp_attention_enabled(),
)
self.alt_stream = torch.cuda.Stream() if _is_cuda else None self.alt_stream = torch.cuda.Stream() if _is_cuda else None
self.layers = nn.ModuleList( self.layers, self.start_layer, self.end_layer = make_layers(
[ config.num_hidden_layers,
DeepseekV2DecoderLayer( lambda idx, prefix: DeepseekV2DecoderLayer(
config, config=config,
layer_id, layer_id=idx,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix(f"layers.{layer_id}", prefix), prefix=prefix,
alt_stream=self.alt_stream, alt_stream=self.alt_stream,
) ),
for layer_id in range(config.num_hidden_layers) pp_rank=self.pp_group.rank_in_group,
] pp_size=self.pp_group.world_size,
prefix=add_prefix("layers", prefix),
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)
def get_input_embeddings(self) -> torch.Tensor: def get_input_embeddings(self) -> torch.Tensor:
return self.embed_tokens return self.embed_tokens
...@@ -2059,8 +2070,9 @@ class DeepseekV2Model(nn.Module): ...@@ -2059,8 +2070,9 @@ class DeepseekV2Model(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: pp_proxy_tensors: Optional[PPProxyTensors] = None,
total_num_layers = len(self.layers) ) -> Union[torch.Tensor, PPProxyTensors]:
total_num_layers = self.end_layer - self.start_layer
device = input_embeds.device if input_embeds is not None else input_ids.device device = input_embeds.device if input_embeds is not None else input_ids.device
zero_allocator = BumpAllocator( zero_allocator = BumpAllocator(
buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1), buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
...@@ -2068,44 +2080,62 @@ class DeepseekV2Model(nn.Module): ...@@ -2068,44 +2080,62 @@ class DeepseekV2Model(nn.Module):
device=device, device=device,
) )
if input_embeds is None: if self.pp_group.is_first_rank:
hidden_states = self.embed_tokens(input_ids) if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else: else:
hidden_states = input_embeds assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
residual = None normal_start_layer = self.start_layer
normal_end_layer = self.end_layer
if forward_batch.can_run_tbo:
if (
self.first_k_dense_replace > normal_start_layer
and self.first_k_dense_replace < normal_end_layer
):
normal_end_layer = self.first_k_dense_replace
elif self.first_k_dense_replace < normal_start_layer:
normal_end_layer = normal_start_layer = 0
normal_num_layers = ( for i in range(normal_start_layer, normal_end_layer):
self.first_k_dense_replace
if forward_batch.can_run_tbo
else total_num_layers
)
for i in range(normal_num_layers):
with get_global_expert_distribution_recorder().with_current_layer(i): with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual, zero_allocator positions, hidden_states, forward_batch, residual, zero_allocator
) )
if normal_num_layers != total_num_layers: if normal_end_layer != self.end_layer:
hidden_states, residual = model_forward_maybe_tbo( hidden_states, residual = model_forward_maybe_tbo(
layers=self.layers[normal_num_layers:], layers=self.layers[normal_end_layer : self.end_layer],
enable_tbo=True, enable_tbo=True,
positions=positions, positions=positions,
forward_batch=forward_batch, forward_batch=forward_batch,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
input_data_scatter_mode=self.layers[ input_data_scatter_mode=self.layers[
normal_num_layers - 1 normal_end_layer - 1
].layer_scatter_modes.layer_output_mode, ].layer_scatter_modes.layer_output_mode,
zero_allocator=zero_allocator, zero_allocator=zero_allocator,
) )
if not forward_batch.forward_mode.is_idle(): if not self.pp_group.is_last_rank:
if residual is None: return PPProxyTensors(
hidden_states = self.norm(hidden_states) {
else: "hidden_states": hidden_states,
hidden_states, _ = self.norm(hidden_states, residual) "residual": residual,
}
)
else:
if not forward_batch.forward_mode.is_idle():
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -2132,6 +2162,7 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -2132,6 +2162,7 @@ class DeepseekV2ForCausalLM(nn.Module):
"kv_a_proj_with_mqa", "kv_a_proj_with_mqa",
] ]
self.pp_group = get_pp_group()
self.config = config self.config = config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config self.quant_config = quant_config
...@@ -2201,13 +2232,27 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -2201,13 +2232,27 @@ class DeepseekV2ForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) hidden_states = self.model(
input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
) )
if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return hidden_states
@property
def start_layer(self):
return self.model.start_layer
@property
def end_layer(self):
return self.model.end_layer
def post_load_weights(self, is_nextn=False, weight_names=None): def post_load_weights(self, is_nextn=False, weight_names=None):
# Perform post-processing after loading weights # Perform post-processing after loading weights
...@@ -2215,7 +2260,7 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -2215,7 +2260,7 @@ class DeepseekV2ForCausalLM(nn.Module):
layer_ids = [self.config.num_hidden_layers] layer_ids = [self.config.num_hidden_layers]
else: else:
if weight_names is None: if weight_names is None:
layer_ids = range(self.config.num_hidden_layers) layer_ids = range(self.model.start_layer, self.model.end_layer)
else: else:
layer_ids = set() layer_ids = set()
for name in weight_names: for name in weight_names:
...@@ -2497,6 +2542,16 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -2497,6 +2542,16 @@ class DeepseekV2ForCausalLM(nn.Module):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
weight_names = [] weight_names = []
for name, loaded_weight in weights: for name, loaded_weight in weights:
layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue
if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name: if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
name = name.replace( name = name.replace(
"mlp.shared_experts", "mlp.shared_experts",
...@@ -2581,6 +2636,12 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -2581,6 +2636,12 @@ class DeepseekV2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Skip loading embed_tokens if not first rank in pipeline parallelism
if ".embed_tokens." in name and not self.pp_group.is_first_rank:
continue
# Skip loading norm if not last rank in pipeline parallelism
if ".norm." in name and not self.pp_group.is_last_rank:
continue
if fuse_qkv_a_proj and ( if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name "q_a_proj" in name or "kv_a_proj_with_mqa" in name
): ):
......
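Under PP, only the first stage embeds tokens and only the last stage applies the final norm and logits processor; middle stages consume hidden_states/residual from the previous stage's PPProxyTensors, run their own layer range, and re-wrap the pair for the next stage. A stripped-down sketch of that control flow; the layer call and norm signatures are simplified stand-ins, not the real DeepseekV2 interfaces:

from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

def pp_stage_forward(stage, input_ids, pp_proxy_tensors=None):
    # `stage` stands in for DeepseekV2Model on one PP rank.
    if stage.pp_group.is_first_rank:
        hidden_states, residual = stage.embed_tokens(input_ids), None
    else:
        hidden_states = pp_proxy_tensors["hidden_states"]
        residual = pp_proxy_tensors["residual"]
    for i in range(stage.start_layer, stage.end_layer):
        hidden_states, residual = stage.layers[i](hidden_states, residual)
    if not stage.pp_group.is_last_rank:
        # Hand the pair to the next stage instead of producing logits.
        return PPProxyTensors({"hidden_states": hidden_states, "residual": residual})
    hidden_states, _ = stage.norm(hidden_states, residual)
    return hidden_states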
import json
import os
import random
import subprocess
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from typing import List, Optional
from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestPDPPAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
cls.start_decode()
# Block until both
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang.srt.disaggregation.mini_lb",
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = subprocess.Popen(
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp-size",
"2",
"--pp-size",
"2",
"--disaggregation-ib-device",
"mlx5_roce0",
"--disable-overlap-schedule",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"1",
"--base-gpu-id",
"1",
"--disaggregation-ib-device",
"mlx5_roce1",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
    kill_process_tree(process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.lb_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.24)
# Wait a little bit so that the memory check happens.
time.sleep(5)
if __name__ == "__main__":
unittest.main()
...@@ -9,6 +9,8 @@ import time ...@@ -9,6 +9,8 @@ import time
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
import requests
from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
...@@ -62,6 +64,29 @@ class TestPPAccuracy(unittest.TestCase): ...@@ -62,6 +64,29 @@ class TestPPAccuracy(unittest.TestCase):
# Wait a little bit so that the memory check happens. # Wait a little bit so that the memory check happens.
time.sleep(4) time.sleep(4)
def test_logprob(self):
response = requests.post(
f"{self.base_url}/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
"return_logprob": True,
"top_logprobs_num": 5,
"logprob_start_len": 0,
},
)
response_json = response.json()
input_token_logprobs = response_json["meta_info"]["input_token_logprobs"]
output_token_logprobs = response_json["meta_info"]["output_token_logprobs"]
output_top_logprobs = response_json["meta_info"]["output_top_logprobs"]
assert len(input_token_logprobs) == 6
assert len(output_token_logprobs) == 16
assert len(output_top_logprobs) == 16
class TestQwenPPAccuracy(unittest.TestCase): class TestQwenPPAccuracy(unittest.TestCase):
@classmethod @classmethod
......