Unverified Commit c1eda615 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix model name included in responses (#24663)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 4aa23892
...@@ -12,7 +12,7 @@ import pytest_asyncio ...@@ -12,7 +12,7 @@ import pytest_asyncio
import regex as re import regex as re
import requests import requests
import torch import torch
from openai import BadRequestError, OpenAI from openai import BadRequestError
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
...@@ -968,59 +968,6 @@ async def test_long_seed(client: openai.AsyncOpenAI): ...@@ -968,59 +968,6 @@ async def test_long_seed(client: openai.AsyncOpenAI):
or "less_than_equal" in exc_info.value.message) or "less_than_equal" in exc_info.value.message)
@pytest.mark.asyncio
async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
url = f"http://localhost:{server.port}/v1/chat/completions"
headers = {
"Content-Type": "application/json",
}
data = {
# model_name is avoided here.
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "what is 1+1?"
}],
"max_tokens":
5
}
response = requests.post(url, headers=headers, json=data)
response_data = response.json()
print(response_data)
assert response_data.get("model") == MODEL_NAME
choice = response_data.get("choices")[0]
message = choice.get("message")
assert message is not None
content = message.get("content")
assert content is not None
assert len(content) > 0
@pytest.mark.asyncio
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
openai_api_key = "EMPTY"
openai_api_base = f"http://localhost:{server.port}/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
messages = [
{
"role": "user",
"content": "Hello, vLLM!"
},
]
response = client.chat.completions.create(
model="", # empty string
messages=messages,
)
assert response.model == MODEL_NAME
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer, async def test_invocations(server: RemoteOpenAIServer,
client: openai.AsyncOpenAI): client: openai.AsyncOpenAI):
......
...@@ -213,8 +213,12 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, ...@@ -213,8 +213,12 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}" CHAT_TEMPLATE = "Dummy chat template for testing {}"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT)
]
@dataclass @dataclass
...@@ -270,6 +274,42 @@ def test_async_serving_chat_init(): ...@@ -270,6 +274,42 @@ def test_async_serving_chat_init():
assert serving_completion.chat_template == CHAT_TEMPLATE assert serving_completion.chat_template == CHAT_TEMPLATE
@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
messages = [{"role": "user", "content": "what is 1+1?"}]
async def return_model_name(*args):
return args[3]
serving_chat.chat_completion_full_generator = return_model_name
# Test that full name is returned when short name is requested
req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages)
assert await serving_chat.create_chat_completion(req) == MODEL_NAME
# Test that full name is returned when empty string is specified
req = ChatCompletionRequest(model="", messages=messages)
assert await serving_chat.create_chat_completion(req) == MODEL_NAME
# Test that full name is returned when no model is specified
req = ChatCompletionRequest(messages=messages)
assert await serving_chat.create_chat_completion(req) == MODEL_NAME
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens(): async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine = MagicMock(spec=MQLLMEngineClient)
......
...@@ -186,7 +186,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -186,7 +186,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request = self._maybe_get_adapters( lora_request = self._maybe_get_adapters(
request, supports_default_mm_loras=True) request, supports_default_mm_loras=True)
model_name = self._get_model_name(request.model, lora_request) model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
......
...@@ -146,7 +146,7 @@ class ServingClassification(ClassificationMixin): ...@@ -146,7 +146,7 @@ class ServingClassification(ClassificationMixin):
request: ClassificationRequest, request: ClassificationRequest,
raw_request: Request, raw_request: Request,
) -> Union[ClassificationResponse, ErrorResponse]: ) -> Union[ClassificationResponse, ErrorResponse]:
model_name = self._get_model_name(request.model) model_name = self.models.model_name()
request_id = (f"{self.request_id_prefix}-" request_id = (f"{self.request_id_prefix}-"
f"{self._base_request_id(raw_request)}") f"{self._base_request_id(raw_request)}")
......
...@@ -232,7 +232,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -232,7 +232,7 @@ class OpenAIServingCompletion(OpenAIServing):
result_generator = merge_async_iterators(*generators) result_generator = merge_async_iterators(*generators)
model_name = self._get_model_name(request.model, lora_request) model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts) num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the # Similar to the OpenAI API, when n != best_of, we do not stream the
......
...@@ -599,7 +599,7 @@ class OpenAIServingEmbedding(EmbeddingMixin): ...@@ -599,7 +599,7 @@ class OpenAIServingEmbedding(EmbeddingMixin):
See https://platform.openai.com/docs/api-reference/embeddings/create See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API. for the API specification. This API mimics the OpenAI Embedding API.
""" """
model_name = self._get_model_name(request.model) model_name = self.models.model_name()
request_id = ( request_id = (
f"{self.request_id_prefix}-" f"{self.request_id_prefix}-"
f"{self._base_request_id(raw_request, request.request_id)}") f"{self._base_request_id(raw_request, request.request_id)}")
......
...@@ -980,17 +980,6 @@ class OpenAIServing: ...@@ -980,17 +980,6 @@ class OpenAIServing:
return True return True
return self.models.is_base_model(model_name) return self.models.is_base_model(model_name)
def _get_model_name(
self,
model_name: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
) -> str:
if lora_request:
return lora_request.lora_name
if not model_name:
return self.models.base_model_paths[0].name
return model_name
def clamp_prompt_logprobs( def clamp_prompt_logprobs(
prompt_logprobs: Union[PromptLogprobs, prompt_logprobs: Union[PromptLogprobs,
......
...@@ -91,7 +91,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -91,7 +91,7 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
model_name = self._get_model_name(request.model) model_name = self.models.model_name()
request_id = f"pool-{self._base_request_id(raw_request)}" request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time()) created_time = int(time.time())
......
...@@ -237,7 +237,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -237,7 +237,7 @@ class OpenAIServingResponses(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
model_name = self._get_model_name(request.model, lora_request) model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
if self.use_harmony: if self.use_harmony:
......
...@@ -353,7 +353,7 @@ class ServingScores(OpenAIServing): ...@@ -353,7 +353,7 @@ class ServingScores(OpenAIServing):
final_res_batch, final_res_batch,
request_id, request_id,
created_time, created_time,
self._get_model_name(request.model), self.models.model_name(),
) )
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
...@@ -399,7 +399,7 @@ class ServingScores(OpenAIServing): ...@@ -399,7 +399,7 @@ class ServingScores(OpenAIServing):
return self.request_output_to_rerank_response( return self.request_output_to_rerank_response(
final_res_batch, final_res_batch,
request_id, request_id,
self._get_model_name(request.model), self.models.model_name(),
documents, documents,
top_n, top_n,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment