Unverified Commit 0e391e75 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Bugfix] Fix RequestOutput miss lora_request (#30636)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 0d0c929f
......@@ -76,6 +76,8 @@ def test_gpt_oss_lora(gptoss20b_lora_files):
enable_lora=True,
max_loras=4,
max_lora_rank=8,
max_num_seqs=2,
max_num_batched_tokens=2048,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
......@@ -94,8 +96,10 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
enable_lora=True,
max_loras=2,
max_lora_rank=8,
max_num_seqs=16,
max_num_seqs=2,
max_num_batched_tokens=2048,
tensor_parallel_size=2,
gpu_memory_utilization=0.8,
fully_sharded_loras=fully_sharded_loras,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
......
......@@ -76,11 +76,18 @@ def do_sample(
if lora_id
else None,
)
# Print the outputs.
lora_request = LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
# The output should include correct lora_request info
if lora_request is not None:
assert output.lora_request.lora_name == lora_request.lora_name
assert output.lora_request.lora_int_id == lora_request.lora_int_id
assert output.lora_request.lora_path == lora_request.lora_path
else:
assert output.lora_request is None
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
......
......@@ -8,6 +8,7 @@ from typing import Any, cast
import torch
from vllm.lora.request import LoRARequest
from vllm.outputs import (
CompletionOutput,
PoolingOutput,
......@@ -93,7 +94,7 @@ class RequestState:
request_id: str,
parent_req: ParentRequest | None,
request_index: int,
lora_name: str | None,
lora_request: LoRARequest | None,
output_kind: RequestOutputKind,
prompt: str | None,
prompt_token_ids: list[int] | None,
......@@ -112,7 +113,8 @@ class RequestState:
self.request_id = request_id
self.parent_req = parent_req
self.request_index = request_index
self.lora_name = lora_name
self.lora_request = lora_request
self.lora_name = lora_request.lora_name if lora_request is not None else None
self.output_kind = output_kind
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
......@@ -178,9 +180,7 @@ class RequestState:
request_id=request.request_id,
parent_req=parent_req,
request_index=request_index,
lora_name=(
request.lora_request.name if request.lora_request is not None else None
),
lora_request=request.lora_request,
output_kind=output_kind,
prompt=prompt,
prompt_token_ids=request.prompt_token_ids,
......@@ -289,6 +289,7 @@ class RequestState:
return RequestOutput(
request_id=request_id,
lora_request=self.lora_request,
prompt=self.prompt,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment