Unverified Commit 7353492a authored by jmswen's avatar jmswen Committed by GitHub
Browse files

[Core] Raise when non-multi-instance DP clients target a DP rank (#19227)


Signed-off-by: default avatarJon Swenson <jmswen@gmail.com>
parent 7661e92e
...@@ -384,3 +384,25 @@ async def test_delayed_generator(async_engine, stop): ...@@ -384,3 +384,25 @@ async def test_delayed_generator(async_engine, stop):
assert final_output is not None assert final_output is not None
assert len(final_output.outputs[0].token_ids) == 10 assert len(final_output.outputs[0].token_ids) == 10
assert final_output.finished assert final_output.finished
@pytest.mark.asyncio(scope="module")
async def test_invalid_argument(async_engine):
scheduler_config = await async_engine.get_scheduler_config()
if scheduler_config.num_scheduler_steps != 1:
pytest.skip("no need to test this one with multistep")
sampling_params = SamplingParams(
temperature=0,
min_tokens=10,
max_tokens=10,
)
# Targeting specific DP rank only supported in v1 multi-instance DP
with pytest.raises(ValueError):
async for _ in async_engine.generate("test",
sampling_params,
request_id=uid(),
data_parallel_rank=0):
pass
...@@ -250,3 +250,32 @@ async def test_customize_loggers(monkeypatch): ...@@ -250,3 +250,32 @@ async def test_customize_loggers(monkeypatch):
assert len(engine.stat_loggers) == 1 assert len(engine.stat_loggers) == 1
assert len(engine.stat_loggers[0]) == 1 assert len(engine.stat_loggers[0]) == 1
engine.stat_loggers[0][0].log.assert_called_once() engine.stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
sampling_params = SamplingParams(max_tokens=100,
output_kind=RequestOutputKind.DELTA,
temperature=1.0,
seed=33)
# Test with valid DP rank.
async for _ in engine.generate(request_id="request-34",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
data_parallel_rank=0):
pass
# Test with out-of-range DP rank.
with pytest.raises(ValueError):
async for _ in engine.generate(request_id="request-35",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
data_parallel_rank=1):
pass
...@@ -29,12 +29,14 @@ if not current_platform.supports_v1(engine_args.create_model_config()): ...@@ -29,12 +29,14 @@ if not current_platform.supports_v1(engine_args.create_model_config()):
allow_module_level=True) allow_module_level=True)
async def generate(engine: AsyncLLM, async def generate(
request_id: str, engine: AsyncLLM,
prompt: PromptType, request_id: str,
output_kind: RequestOutputKind, prompt: PromptType,
max_tokens: int, output_kind: RequestOutputKind,
prompt_logprobs: Optional[int] = None) -> tuple[int, str]: max_tokens: int,
prompt_logprobs: Optional[int] = None,
data_parallel_rank: Optional[int] = None) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test. # Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
...@@ -46,7 +48,8 @@ async def generate(engine: AsyncLLM, ...@@ -46,7 +48,8 @@ async def generate(engine: AsyncLLM,
prompt_logprobs=prompt_logprobs) prompt_logprobs=prompt_logprobs)
async for out in engine.generate(request_id=request_id, async for out in engine.generate(request_id=request_id,
prompt=prompt, prompt=prompt,
sampling_params=sampling_params): sampling_params=sampling_params,
data_parallel_rank=data_parallel_rank):
num_tokens = len(out.outputs[0].token_ids) num_tokens = len(out.outputs[0].token_ids)
if output_kind == RequestOutputKind.DELTA: if output_kind == RequestOutputKind.DELTA:
...@@ -89,8 +92,12 @@ async def test_load(output_kind: RequestOutputKind, ...@@ -89,8 +92,12 @@ async def test_load(output_kind: RequestOutputKind,
for request_id in request_ids: for request_id in request_ids:
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
generate(engine, request_id, prompt, output_kind, generate(engine,
NUM_EXPECTED_TOKENS))) request_id,
prompt,
output_kind,
NUM_EXPECTED_TOKENS,
data_parallel_rank=0)))
# Confirm that we got all the EXPECTED tokens from the requests. # Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks, done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION) return_when=asyncio.FIRST_EXCEPTION)
......
...@@ -494,6 +494,10 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -494,6 +494,10 @@ class _AsyncLLMEngine(LLMEngine):
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
if data_parallel_rank is not None:
raise ValueError("Targeting data_parallel_rank only supported "
"in v1 client.")
if (isinstance(prompt, dict) if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)): and not prompt.get("prompt_token_ids", None)):
......
...@@ -1000,9 +1000,6 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -1000,9 +1000,6 @@ class DPAsyncMPClient(AsyncMPClient):
) -> CoreEngine: ) -> CoreEngine:
if dp_rank is not None: if dp_rank is not None:
# engines are already in rank order # engines are already in rank order
if dp_rank < 0 or dp_rank >= len(self.core_engines):
raise ValueError(f"Requested DP rank {dp_rank} is out of "
f"range [0, {len(self.core_engines)})")
return self.core_engines[dp_rank] return self.core_engines[dp_rank]
if not self.lb_engines: if not self.lb_engines:
......
...@@ -226,6 +226,12 @@ class Processor: ...@@ -226,6 +226,12 @@ class Processor:
if prompt_adapter_request is not None: if prompt_adapter_request is not None:
raise ValueError("V1 does not support prompt_adapter_request.") raise ValueError("V1 does not support prompt_adapter_request.")
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
data_parallel_size):
raise ValueError(f"data_parallel_rank {data_parallel_rank} "
f"is out of range [0, {data_parallel_size}).")
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment