Unverified Commit a2ba0bc3 authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

Tiny clean up for PD module and doc (#11747)


Signed-off-by: default avatarShangming Cai <csmthu@gmail.com>
parent 6d2d0ce2
...@@ -41,6 +41,7 @@ uv pip install mooncake-transfer-engine ...@@ -41,6 +41,7 @@ uv pip install mooncake-transfer-engine
python -m sglang.launch_server \ python -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8B-Instruct \ --model-path meta-llama/Llama-3.1-8B-Instruct \
--disaggregation-mode prefill \ --disaggregation-mode prefill \
--port 30000 \
--disaggregation-ib-device mlx5_roce0 --disaggregation-ib-device mlx5_roce0
python -m sglang.launch_server \ python -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8B-Instruct \ --model-path meta-llama/Llama-3.1-8B-Instruct \
...@@ -179,6 +180,7 @@ pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx" ...@@ -179,6 +180,7 @@ pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
python -m sglang.launch_server \ python -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8B-Instruct \ --model-path meta-llama/Llama-3.1-8B-Instruct \
--disaggregation-mode prefill \ --disaggregation-mode prefill \
--port 30000 \
--disaggregation-transfer-backend nixl --disaggregation-transfer-backend nixl
python -m sglang.launch_server \ python -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8B-Instruct \ --model-path meta-llama/Llama-3.1-8B-Instruct \
...@@ -282,6 +284,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true ...@@ -282,6 +284,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
python -m sglang.launch_server \ python -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8B-Instruct \ --model-path meta-llama/Llama-3.1-8B-Instruct \
--disaggregation-mode prefill \ --disaggregation-mode prefill \
--port 30000 \
--disaggregation-transfer-backend ascend --disaggregation-transfer-backend ascend
python -m sglang.launch_server \ python -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8B-Instruct \ --model-path meta-llama/Llama-3.1-8B-Instruct \
......
...@@ -246,6 +246,7 @@ class CommonKVReceiver(BaseKVReceiver): ...@@ -246,6 +246,7 @@ class CommonKVReceiver(BaseKVReceiver):
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}", f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
) )
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
self.bootstrap_infos = None
return return
else: else:
logger.debug( logger.debug(
......
...@@ -174,7 +174,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -174,7 +174,7 @@ class MooncakeKVManager(CommonKVManager):
cpu_count = os.cpu_count() cpu_count = os.cpu_count()
transfer_thread_pool_size = get_int_env_var( transfer_thread_pool_size = get_int_env_var(
"SGLANG_DISAGGREGATION_THREAD_POOL_SIZE", "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
min(max(4, int(0.75 * cpu_count) // 8), 12), min(max(4, int(0.5 * cpu_count) // 8), 12),
) )
transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4) transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4)
self.transfer_queues: List[FastQueue] = [ self.transfer_queues: List[FastQueue] = [
...@@ -190,9 +190,6 @@ class MooncakeKVManager(CommonKVManager): ...@@ -190,9 +190,6 @@ class MooncakeKVManager(CommonKVManager):
) )
for _ in range(transfer_queue_size) for _ in range(transfer_queue_size)
] ]
self.state_executors = concurrent.futures.ThreadPoolExecutor(
transfer_thread_pool_size // transfer_queue_size
)
for queue, executor in zip(self.transfer_queues, self.executors): for queue, executor in zip(self.transfer_queues, self.executors):
threading.Thread( threading.Thread(
target=self.transfer_worker, args=(queue, executor), daemon=True target=self.transfer_worker, args=(queue, executor), daemon=True
...@@ -641,6 +638,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -641,6 +638,7 @@ class MooncakeKVManager(CommonKVManager):
req: TransferInfo, req: TransferInfo,
prefill_state_indices: list[int], prefill_state_indices: list[int],
dst_state_data_ptrs: list[int], dst_state_data_ptrs: list[int],
executor: concurrent.futures.ThreadPoolExecutor,
): ):
"""Send state or extra pool data with type-specific handling.""" """Send state or extra pool data with type-specific handling."""
state_type = getattr(self.kv_args, "state_type", "none") state_type = getattr(self.kv_args, "state_type", "none")
...@@ -662,7 +660,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -662,7 +660,7 @@ class MooncakeKVManager(CommonKVManager):
item_lens=self.kv_args.state_item_lens, item_lens=self.kv_args.state_item_lens,
prefill_data_indices=prefill_state_indices, prefill_data_indices=prefill_state_indices,
dst_data_indices=dst_state_indices, dst_data_indices=dst_state_indices,
executor=self.state_executors, executor=executor,
) )
else: else:
return 0 return 0
...@@ -810,6 +808,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -810,6 +808,7 @@ class MooncakeKVManager(CommonKVManager):
req, req,
kv_chunk.state_indices, kv_chunk.state_indices,
target_rank_registration_info.dst_state_data_ptrs, target_rank_registration_info.dst_state_data_ptrs,
executor,
) )
if self.pp_group.is_last_rank: if self.pp_group.is_last_rank:
...@@ -1257,6 +1256,14 @@ class MooncakeKVReceiver(CommonKVReceiver): ...@@ -1257,6 +1256,14 @@ class MooncakeKVReceiver(CommonKVReceiver):
aux_index: Optional[int] = None, aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None, state_indices: Optional[List[int]] = None,
): ):
if self.bootstrap_infos is None:
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
return
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
sock, lock = self._connect_to_bootstrap_server(bootstrap_info) sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
is_dummy = bootstrap_info["is_dummy"] is_dummy = bootstrap_info["is_dummy"]
......
...@@ -762,6 +762,13 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -762,6 +762,13 @@ class NixlKVReceiver(CommonKVReceiver):
aux_index: Optional[int] = None, aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None, state_indices: Optional[List[int]] = None,
): ):
if self.bootstrap_infos is None:
logger.error(
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
return
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
logger.debug( logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment