Commit e53be6f0 (unverified), authored by Chales Xu, committed by GitHub
Browse files

[Misc] Add type assertion of request_id for LLMEngine.add_request (#19700)


Signed-off-by: n2ptr <xuzhanchaomail@163.com>
parent c329ceca
......@@ -66,7 +66,7 @@ async def test_evil_forward(tmp_socket):
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
request_id=str(uuid.uuid4())):
pass
assert client.errored
......@@ -115,7 +115,7 @@ async def test_failed_health_check(tmp_socket):
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
request_id=str(uuid.uuid4())):
pass
client.close()
......@@ -157,7 +157,7 @@ async def test_failed_abort(tmp_socket):
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()):
request_id=str(uuid.uuid4())):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored
......@@ -189,7 +189,7 @@ async def test_batch_error(tmp_socket):
params = SamplingParams(min_tokens=2048, max_tokens=2048)
async for _ in client.generate(prompt="Hello my name is",
sampling_params=params,
request_id=uuid.uuid4()):
request_id=str(uuid.uuid4())):
pass
tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)]
......@@ -289,7 +289,7 @@ async def test_engine_process_death(tmp_socket):
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
request_id=str(uuid.uuid4())):
pass
# And the health check should show the engine is dead
......
......@@ -687,6 +687,10 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if not isinstance(request_id, str):
raise TypeError(
f"request_id must be a string, got {type(request_id)}")
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
......
......@@ -192,6 +192,11 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
# Validate the request_id type.
if not isinstance(request_id, str):
raise TypeError(
f"request_id must be a string, got {type(request_id)}")
# Process raw inputs into the request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment