Unverified commit 5e65d6b2, authored by Ayush Satyam, committed by GitHub
Browse files

fix[DP][v1]: Prevent hangs from mismatched worker configurations (#26218)


Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
parent 0d4f48fa
...@@ -336,6 +336,9 @@ class ParallelConfig: ...@@ -336,6 +336,9 @@ class ParallelConfig:
graph from input ids/embeddings to the final hidden states, graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after excluding anything before input ids/embeddings and after
the final hidden states. the final hidden states.
This hash is also used for DP worker configuration validation
to prevent hangs from mismatched collective communication patterns.
""" """
factors: list[Any] = [] factors: list[Any] = []
factors.append(self.pipeline_parallel_size) factors.append(self.pipeline_parallel_size)
...@@ -343,6 +346,12 @@ class ParallelConfig: ...@@ -343,6 +346,12 @@ class ParallelConfig:
factors.append(self.enable_expert_parallel) factors.append(self.enable_expert_parallel)
factors.append(self.data_parallel_size) factors.append(self.data_parallel_size)
factors.append(envs.VLLM_ALL2ALL_BACKEND) factors.append(envs.VLLM_ALL2ALL_BACKEND)
factors.append(self.enable_eplb)
if self.enable_eplb:
factors.append(self.eplb_config.log_balancedness)
factors.append(self.eplb_config.window_size)
factors.append(self.eplb_config.step_interval)
factors.append(self.eplb_config.num_redundant_experts)
return hashlib.sha256(str(factors).encode()).hexdigest() return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None: def __post_init__(self) -> None:
......
...@@ -681,17 +681,21 @@ class EngineCoreProc(EngineCore): ...@@ -681,17 +681,21 @@ class EngineCoreProc(EngineCore):
# external LB case for our colocated front-end to use (coordinator # external LB case for our colocated front-end to use (coordinator
# only runs with rank 0). # only runs with rank 0).
dp_stats_address = self.frontend_stats_publish_address dp_stats_address = self.frontend_stats_publish_address
handshake_socket.send(
msgspec.msgpack.encode( # Include config hash for DP configuration validation
{ ready_msg = {
"status": "READY", "status": "READY",
"local": local_client, "local": local_client,
"headless": headless, "headless": headless,
"num_gpu_blocks": num_gpu_blocks, "num_gpu_blocks": num_gpu_blocks,
"dp_stats_address": dp_stats_address, "dp_stats_address": dp_stats_address,
} }
if vllm_config.parallel_config.data_parallel_size > 1:
ready_msg["parallel_config_hash"] = (
vllm_config.parallel_config.compute_hash()
) )
)
handshake_socket.send(msgspec.msgpack.encode(ready_msg))
@staticmethod @staticmethod
def startup_handshake( def startup_handshake(
......
...@@ -73,6 +73,7 @@ class EngineHandshakeMetadata: ...@@ -73,6 +73,7 @@ class EngineHandshakeMetadata:
addresses: EngineZmqAddresses addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str, list[int]]] parallel_config: dict[str, Union[int, str, list[int]]]
parallel_config_hash: Optional[str] = None
class CoreEngineProcManager: class CoreEngineProcManager:
...@@ -867,7 +868,8 @@ def wait_for_engine_startup( ...@@ -867,7 +868,8 @@ def wait_for_engine_startup(
) )
if status == "HELLO" and engine.state == CoreEngineState.NEW: if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info. # Send init message with DP config info and config hash.
# The config hash ensures all DP workers have compatible configs.
init_message = msgspec.msgpack.encode( init_message = msgspec.msgpack.encode(
EngineHandshakeMetadata( EngineHandshakeMetadata(
addresses=addresses, addresses=addresses,
...@@ -880,6 +882,9 @@ def wait_for_engine_startup( ...@@ -880,6 +882,9 @@ def wait_for_engine_startup(
"data_parallel_size", "data_parallel_size",
) )
}, },
parallel_config_hash=parallel_config.compute_hash()
if parallel_config.data_parallel_size > 1
else None,
) )
) )
handshake_socket.send_multipart((eng_identity, init_message), copy=False) handshake_socket.send_multipart((eng_identity, init_message), copy=False)
...@@ -900,6 +905,23 @@ def wait_for_engine_startup( ...@@ -900,6 +905,23 @@ def wait_for_engine_startup(
if addresses.frontend_stats_publish_address is None: if addresses.frontend_stats_publish_address is None:
addresses.frontend_stats_publish_address = msg.get("dp_stats_address") addresses.frontend_stats_publish_address = msg.get("dp_stats_address")
# Validate config hash consistency across DP workers
if parallel_config.data_parallel_size > 1:
worker_config_hash = msg.get("parallel_config_hash")
expected_hash = parallel_config.compute_hash()
if worker_config_hash != expected_hash:
raise RuntimeError(
f"Configuration mismatch detected for engine "
f"{eng_index}. All DP workers must have identical "
f"configurations for parameters that affect collective "
f"communication (e.g., enable_eplb, "
f"eplb_config.log_balancedness). "
f"Worker hash: {worker_config_hash}, "
f"Expected hash: {expected_hash}. "
f"Please ensure all workers are started with the same "
f"command-line arguments."
)
start_pending[0 if local else 1] -= 1 start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY engine.state = CoreEngineState.READY
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment