Unverified Commit 8619e715 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix multi-node offline data parallel (#19937)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent c635c5f7
...@@ -615,13 +615,16 @@ steps: ...@@ -615,13 +615,16 @@ steps:
- vllm/executor/ - vllm/executor/
- vllm/model_executor/models/ - vllm/model_executor/models/
- tests/distributed/ - tests/distributed/
- tests/examples/offline_inference/data_parallel.py
commands: commands:
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
- label: Distributed Tests (2 GPUs) # 40min - label: Distributed Tests (2 GPUs) # 40min
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
......
...@@ -1568,6 +1568,8 @@ class LLM: ...@@ -1568,6 +1568,8 @@ class LLM:
pbar.update(n) pbar.update(n)
else: else:
pbar.update(1) pbar.update(1)
if pbar.n == num_requests:
pbar.refresh()
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
......
...@@ -877,12 +877,16 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -877,12 +877,16 @@ class DPEngineCoreProc(EngineCoreProc):
local_unfinished_reqs) local_unfinished_reqs)
if not self.engines_running: if not self.engines_running:
if self.dp_rank == 0: if self.dp_rank == 0 or not self.has_coordinator:
# Notify client that we are pausing the loop. # Notify client that we are pausing the loop.
logger.debug("Wave %d finished, pausing engine loop.", logger.debug("Wave %d finished, pausing engine loop.",
self.current_wave) self.current_wave)
# In the coordinator case, dp rank 0 sends updates to the
# coordinator. Otherwise (offline spmd case), each rank
# sends the update to its colocated front-end process.
client_index = -1 if self.has_coordinator else 0
self.output_queue.put_nowait( self.output_queue.put_nowait(
(-1, (client_index,
EngineCoreOutputs(wave_complete=self.current_wave))) EngineCoreOutputs(wave_complete=self.current_wave)))
self.current_wave += 1 self.current_wave += 1
......
...@@ -155,6 +155,11 @@ class EngineCoreClient(ABC): ...@@ -155,6 +155,11 @@ class EngineCoreClient(ABC):
kwargs: Optional[dict[str, Any]] = None) -> list[_R]: kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError raise NotImplementedError
def dp_engines_running(self) -> bool:
"""Returns True id data parallel engines are collectively in a
running state."""
raise NotImplementedError
async def get_output_async(self) -> EngineCoreOutputs: async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError raise NotImplementedError
...@@ -282,6 +287,9 @@ class InprocClient(EngineCoreClient): ...@@ -282,6 +287,9 @@ class InprocClient(EngineCoreClient):
kwargs: Optional[dict[str, Any]] = None) -> list[_R]: kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs) return self.engine_core.collective_rpc(method, timeout, args, kwargs)
def dp_engines_running(self) -> bool:
return False
@dataclass @dataclass
class BackgroundResources: class BackgroundResources:
...@@ -384,6 +392,9 @@ class MPClient(EngineCoreClient): ...@@ -384,6 +392,9 @@ class MPClient(EngineCoreClient):
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank dp_rank = parallel_config.data_parallel_rank
# State used for data parallel.
self.engines_running = False
# SPMD mode is where there is an LLM instance per DP rank and # SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see # one core engine per LLM, see
# examples/offline_inference/data_parallel.py. # examples/offline_inference/data_parallel.py.
...@@ -539,6 +550,9 @@ class MPClient(EngineCoreClient): ...@@ -539,6 +550,9 @@ class MPClient(EngineCoreClient):
while self.pending_messages and self.pending_messages[-1][0].done: while self.pending_messages and self.pending_messages[-1][0].done:
self.pending_messages.pop() self.pending_messages.pop()
def dp_engines_running(self) -> bool:
return self.engines_running
def _process_utility_output(output: UtilityOutput, def _process_utility_output(output: UtilityOutput,
utility_results: dict[int, AnyFuture]): utility_results: dict[int, AnyFuture]):
...@@ -562,6 +576,7 @@ class SyncMPClient(MPClient): ...@@ -562,6 +576,7 @@ class SyncMPClient(MPClient):
log_stats=log_stats, log_stats=log_stats,
) )
self.is_dp = self.vllm_config.parallel_config.data_parallel_size > 1
self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]() self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()
# Ensure that the outputs socket processing thread does not have # Ensure that the outputs socket processing thread does not have
...@@ -623,6 +638,8 @@ class SyncMPClient(MPClient): ...@@ -623,6 +638,8 @@ class SyncMPClient(MPClient):
outputs = self.outputs_queue.get() outputs = self.outputs_queue.get()
if isinstance(outputs, Exception): if isinstance(outputs, Exception):
raise self._format_exception(outputs) from None raise self._format_exception(outputs) from None
if outputs.wave_complete is not None:
self.engines_running = False
return outputs return outputs
def _send_input(self, request_type: EngineCoreRequestType, request: Any): def _send_input(self, request_type: EngineCoreRequestType, request: Any):
...@@ -650,6 +667,8 @@ class SyncMPClient(MPClient): ...@@ -650,6 +667,8 @@ class SyncMPClient(MPClient):
return future.result() return future.result()
def add_request(self, request: EngineCoreRequest) -> None: def add_request(self, request: EngineCoreRequest) -> None:
if self.is_dp:
self.engines_running = True
self._send_input(EngineCoreRequestType.ADD, request) self._send_input(EngineCoreRequestType.ADD, request)
def abort_requests(self, request_ids: list[str]) -> None: def abort_requests(self, request_ids: list[str]) -> None:
...@@ -911,7 +930,6 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -911,7 +930,6 @@ class DPAsyncMPClient(AsyncMPClient):
client_addresses: Optional[dict[str, str]] = None, client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0): client_index: int = 0):
self.current_wave = 0 self.current_wave = 0
self.engines_running = False
# To route aborts to the correct engine. # To route aborts to the correct engine.
self.reqs_in_flight: dict[str, CoreEngine] = {} self.reqs_in_flight: dict[str, CoreEngine] = {}
......
...@@ -160,7 +160,7 @@ class LLMEngine: ...@@ -160,7 +160,7 @@ class LLMEngine:
def has_unfinished_requests(self) -> bool: def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests() has_unfinished = self.output_processor.has_unfinished_requests()
if self.dp_group is None: if self.dp_group is None:
return has_unfinished return has_unfinished or self.engine_core.dp_engines_running()
return self.has_unfinished_requests_dp(has_unfinished) return self.has_unfinished_requests_dp(has_unfinished)
def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool: def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment