Unverified commit 85362f02, authored by Jiaxin Shan and committed by GitHub
Browse files

[Misc][LoRA] Ensure Lora Adapter requests return adapter name (#11094)


Signed-off-by: Jiaxin Shan <seedjeffwan@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 62de37a3
...@@ -9,6 +9,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ...@@ -9,6 +9,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest, LoadLoraAdapterRequest,
UnloadLoraAdapterRequest) UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.lora.request import LoRARequest
MODEL_NAME = "meta-llama/Llama-2-7b" MODEL_NAME = "meta-llama/Llama-2-7b"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
...@@ -33,6 +34,16 @@ async def _async_serving_engine_init(): ...@@ -33,6 +34,16 @@ async def _async_serving_engine_init():
return serving_engine return serving_engine
@pytest.mark.asyncio
async def test_serving_model_name():
    """Check which name ``_get_model_name`` reports with and without LoRA.

    With no LoRA request the result must equal ``MODEL_NAME``; with a
    LoRA request it must equal that request's ``lora_name``.
    """
    engine = await _async_serving_engine_init()

    # No adapter attached to the request.
    base_name = engine._get_model_name(None)
    assert base_name == MODEL_NAME

    # Adapter attached to the request.
    adapter = LoRARequest(
        lora_name="adapter",
        lora_path="/path/to/adapter2",
        lora_int_id=1,
    )
    reported_name = engine._get_model_name(adapter)
    assert reported_name == adapter.lora_name
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_lora_adapter_success(): async def test_load_lora_adapter_success():
serving_engine = await _async_serving_engine_init() serving_engine = await _async_serving_engine_init()
......
...@@ -123,6 +123,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -123,6 +123,8 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapter_request, prompt_adapter_request,
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
model_name = self._get_model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
tool_parser = self.tool_parser tool_parser = self.tool_parser
...@@ -238,13 +240,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -238,13 +240,13 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response # Streaming response
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer, request, result_generator, request_id, model_name,
request_metadata) conversation, tokenizer, request_metadata)
try: try:
return await self.chat_completion_full_generator( return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer, request, result_generator, request_id, model_name,
request_metadata) conversation, tokenizer, request_metadata)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -259,11 +261,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -259,11 +261,11 @@ class OpenAIServingChat(OpenAIServing):
request: ChatCompletionRequest, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], result_generator: AsyncIterator[RequestOutput],
request_id: str, request_id: str,
model_name: str,
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
model_name = self.base_model_paths[0].name
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True first_iteration = True
...@@ -604,12 +606,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -604,12 +606,12 @@ class OpenAIServingChat(OpenAIServing):
request: ChatCompletionRequest, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], result_generator: AsyncIterator[RequestOutput],
request_id: str, request_id: str,
model_name: str,
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> Union[ErrorResponse, ChatCompletionResponse]: ) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.base_model_paths[0].name
created_time = int(time.time()) created_time = int(time.time())
final_res: Optional[RequestOutput] = None final_res: Optional[RequestOutput] = None
......
...@@ -85,7 +85,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -85,7 +85,6 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response( return self.create_error_response(
"suffix is not currently supported") "suffix is not currently supported")
model_name = self.base_model_paths[0].name
request_id = f"cmpl-{self._base_request_id(raw_request)}" request_id = f"cmpl-{self._base_request_id(raw_request)}"
created_time = int(time.time()) created_time = int(time.time())
...@@ -162,6 +161,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -162,6 +161,7 @@ class OpenAIServingCompletion(OpenAIServing):
result_generator = merge_async_iterators( result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected) *generators, is_cancelled=raw_request.is_disconnected)
model_name = self._get_model_name(lora_request)
num_prompts = len(engine_prompts) num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the # Similar to the OpenAI API, when n != best_of, we do not stream the
......
...@@ -661,3 +661,16 @@ class OpenAIServing: ...@@ -661,3 +661,16 @@ class OpenAIServing:
def _is_model_supported(self, model_name): def _is_model_supported(self, model_name):
return any(model.name == model_name for model in self.base_model_paths) return any(model.name == model_name for model in self.base_model_paths)
def _get_model_name(self, lora: Optional[LoRARequest]):
"""
Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora is not None:
return lora.lora_name
return self.base_model_paths[0].name
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment