Unverified Commit b1465557 authored by Yineng Zhang, committed by GitHub

Revert "Implement `return_hidden_states` for the OpenAI API (#6137)" (#6440)

parent b06215da
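
For context, the reverted change let OpenAI-compatible clients opt in to hidden states per request by passing a flag through the SDK's extra_body, as the tests removed below did. A minimal client-side sketch (api_key, base_url, and model are placeholders; it assumes a locally running sglang server and the pre-revert behavior):

import openai

# Placeholder endpoint and model; the flag below is the one removed by this revert.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    max_tokens=16,
    extra_body=dict(return_hidden_states=True),
)
choice = response.choices[0]
# Before the revert, the choice could carry a `hidden_states` attribute holding the
# last token's hidden state; after it, the attribute is absent.
print(choice.text, getattr(choice, "hidden_states", None))
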
@@ -531,7 +531,6 @@ def v1_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     lora_paths = []
-    return_hidden_states = []
     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -578,7 +577,6 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        return_hidden_states.append(request.return_hidden_states)
     if len(all_requests) == 1:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
@@ -590,7 +588,6 @@ def v1_generate_request(
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         lora_paths = lora_paths[0]
-        return_hidden_states = return_hidden_states[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -607,7 +604,6 @@ def v1_generate_request(
         stream=all_requests[0].stream,
         rid=request_ids,
         lora_path=lora_paths,
-        return_hidden_states=return_hidden_states,
     )
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -673,17 +669,6 @@ def v1_generate_response(
         else:
             logprobs = None
-        hidden_states = None
-        if isinstance(request, list) and request[idx].return_hidden_states:
-            hidden_states = ret_item["meta_info"].get("hidden_states", None)
-        elif (not isinstance(request, list)) and request.return_hidden_states:
-            hidden_states = ret_item["meta_info"].get("hidden_states", None)
-        if hidden_states is not None:
-            hidden_states = hidden_states[1:]  # trim off the prefill
-            hidden_states = (
-                hidden_states[-1] if len(hidden_states) > 0 else []
-            )  # slice out the last token
         finish_reason = ret_item["meta_info"]["finish_reason"]
         if to_file:
@@ -710,7 +695,6 @@ def v1_generate_response(
                 if finish_reason and "matched" in finish_reason
                 else None
             ),
-            hidden_states=hidden_states,
         )
         choices.append(choice_data)
@@ -735,7 +719,6 @@ def v1_generate_response(
                     + ret[i]["meta_info"]["completion_tokens"],
                 },
                 "system_fingerprint": None,
-                "hidden_states": hidden_states,
             },
         }
         responses.append(response)
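
The removed response-path code above repeats one post-processing idiom on the hidden states stored in meta_info: drop the prefill entry, then keep only the final decoded token. A standalone paraphrase of that idiom (illustrative helper, not part of the sglang codebase):

from typing import Any, Sequence

def last_token_hidden_state(hidden_states: Sequence[Any]) -> Any:
    # Mirrors the removed code: trim off the prefill, then slice out the last token.
    decoded = hidden_states[1:]
    return decoded[-1] if len(decoded) > 0 else []
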
@@ -780,7 +763,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
-        hidden_states = None
         try:
             async for content in tokenizer_manager.generate_request(
@@ -795,9 +777,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                 completion_tokens[index] = content["meta_info"]["completion_tokens"]
                 cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
-                hidden_states = (
-                    content["meta_info"].get("hidden_states", None) or hidden_states
-                )
                 if not stream_buffer:  # The first chunk
                     if request.echo:
@@ -903,25 +882,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     total_tokens=total_prompt_tokens + total_completion_tokens,
                     prompt_tokens_details=prompt_tokens_details,
                 )
-                if request.return_hidden_states and hidden_states:
-                    hidden_states = hidden_states[1:]  # trim off the prefill
-                    hidden_states = (
-                        hidden_states[-1] if len(hidden_states) > 0 else []
-                    )  # slice out the last token
-                    hidden_states_chunk = CompletionStreamResponse(
-                        id=content["meta_info"]["id"],
-                        created=created,
-                        choices=[
-                            CompletionResponseStreamChoice(
-                                text="",
-                                index=index,
-                                hidden_states=hidden_states,
-                                finish_reason=None,
-                            )
-                        ],
-                        model=request.model,
-                    )
-                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                 final_usage_chunk = CompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     created=created,
@@ -998,7 +959,6 @@ def v1_chat_generate_request(
     top_logprobs_nums = []
     modalities_list = []
     lora_paths = []
-    return_hidden_states = []
     # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -1216,7 +1176,6 @@ def v1_chat_generate_request(
         image_data_list.append(image_data)
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
-        return_hidden_states.append(request.return_hidden_states)
     if len(all_requests) == 1:
         if is_multimodal:
             # processor will need text input
@@ -1235,7 +1194,6 @@ def v1_chat_generate_request(
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
         request_ids = request_ids[0]
-        return_hidden_states = return_hidden_states[0]
     else:
         if tokenizer_manager.model_config.is_multimodal:
             # processor will need text input
@@ -1262,7 +1220,6 @@ def v1_chat_generate_request(
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
-        return_hidden_states=return_hidden_states,
     )
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1323,21 +1280,6 @@ def v1_chat_generate_response(
         else:
             choice_logprobs = None
-        if isinstance(request, list) and request[idx].return_hidden_states:
-            include_hidden_states = True
-        elif not isinstance(request, list) and request.return_hidden_states:
-            include_hidden_states = True
-        else:
-            include_hidden_states = False
-        if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
-            hidden_states = ret_item["meta_info"]["hidden_states"]
-            hidden_states = hidden_states[1:]  # trim off the prefill
-            hidden_states = (
-                hidden_states[-1] if len(hidden_states) > 0 else []
-            )  # slice out the last token
-        else:
-            hidden_states = None
         finish_reason = ret_item["meta_info"]["finish_reason"]
         tool_calls = None
@@ -1402,7 +1344,6 @@ def v1_chat_generate_response(
                     "content": text if text else None,
                     "tool_calls": tool_calls,
                     "reasoning_content": reasoning_text if reasoning_text else None,
-                    "hidden_states": hidden_states,
                 },
                 "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
                 "finish_reason": finish_reason["type"] if finish_reason else None,
@@ -1428,7 +1369,6 @@ def v1_chat_generate_response(
                 if finish_reason and "matched" in finish_reason
                 else None
             ),
-            hidden_states=hidden_states,
         )
         choices.append(choice_data)
@@ -1497,7 +1437,6 @@ async def v1_chat_completions(
     if adapted_request.stream:
         parser_dict = {}
         reasoning_parser_dict = {}
-        hidden_states = None
         async def generate_stream_resp():
             tool_call_first = True
@@ -1507,16 +1446,12 @@ async def v1_chat_completions(
             prompt_tokens = {}
             completion_tokens = {}
            cached_tokens = {}
-            hidden_states = None
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
                     index = content.get("index", 0)
                     text = content["text"]
-                    hidden_states = (
-                        content["meta_info"].get("hidden_states", None) or hidden_states
-                    )
                     is_first = is_firsts.get(index, True)
                     stream_buffer = stream_buffers.get(index, "")
@@ -1638,7 +1573,6 @@ async def v1_chat_completions(
                     if (delta and len(delta) == 0) or not delta:
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
-                        n_prev_tokens[index] = n_prev_token
                         continue
                     if request.tool_choice != "none" and request.tools:
@@ -1727,7 +1661,6 @@ async def v1_chat_completions(
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
-                        n_prev_tokens[index] = n_prev_token
                     else:
                         # No tool calls => just treat this as normal text
@@ -1760,7 +1693,6 @@ async def v1_chat_completions(
                         yield f"data: {chunk.model_dump_json()}\n\n"
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
-                        n_prev_tokens[index] = n_prev_token
                 if finish_reason_type == "stop" and request.tool_choice != "none":
                     parser = FunctionCallParser(
                         tools=request.tools,
@@ -1796,24 +1728,6 @@ async def v1_chat_completions(
                 else:
                     usage = None
-                if request.return_hidden_states and hidden_states:
-                    hidden_states = hidden_states[1:]  # trim off the prefill
-                    hidden_states = (
-                        hidden_states[-1] if len(hidden_states) > 0 else []
-                    )  # slice out the last token
-                    hidden_states_chunk = ChatCompletionStreamResponse(
-                        id=content["meta_info"]["id"],
-                        created=created,
-                        choices=[
-                            ChatCompletionResponseStreamChoice(
-                                index=index,
-                                delta=DeltaMessage(hidden_states=hidden_states),
-                                finish_reason=finish_reason_type,
-                            )
-                        ],
-                        model=request.model,
-                    )
-                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                 final_usage_chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     created=created,
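
In the streaming paths removed above, hidden states were not attached to regular text chunks; instead, one extra SSE chunk carrying only hidden_states was yielded just before the final usage chunk. A sketch of how a client consumed that trailing chunk under the pre-revert behavior (placeholder endpoint and model; no longer applicable after this commit):

import numpy as np
import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    stream=True,
    extra_body=dict(return_hidden_states=True),  # flag removed by this revert
)

hidden_states = None
for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    # The hidden-states chunk carried no text, only the extra attribute.
    if getattr(delta, "hidden_states", None) is not None:
        hidden_states = delta.hidden_states
    elif delta.content:
        print(delta.content, end="")

if hidden_states is not None:
    print("\nlast-token hidden state shape:", np.asarray(hidden_states).shape)
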
@@ -16,7 +16,7 @@
 import time
 from typing import Dict, List, Optional, Union
-from pydantic import BaseModel, Field, model_serializer, root_validator
+from pydantic import BaseModel, Field, root_validator
 from typing_extensions import Literal
@@ -182,7 +182,6 @@ class CompletionRequest(BaseModel):
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
-    return_hidden_states: Optional[bool] = False
 class CompletionResponseChoice(BaseModel):
@@ -191,11 +190,6 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter"]
     matched_stop: Union[None, int, str] = None
-    hidden_states: Optional[object] = None
-    @model_serializer
-    def _serialize(self):
-        return exclude_if_none(self, ["hidden_states"])
 class CompletionResponse(BaseModel):
@@ -213,11 +207,6 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
-    hidden_states: Optional[object] = None
-    @model_serializer
-    def _serialize(self):
-        return exclude_if_none(self, ["hidden_states"])
 class CompletionStreamResponse(BaseModel):
@@ -411,9 +400,6 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
-    # Hidden States
-    return_hidden_states: Optional[bool] = False
 class ChatMessage(BaseModel):
     role: Optional[str] = None
@@ -430,11 +416,6 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call"
     ]
     matched_stop: Union[None, int, str] = None
-    hidden_states: Optional[object] = None
-    @model_serializer
-    def _serialize(self):
-        return exclude_if_none(self, ["hidden_states"])
 class ChatCompletionResponse(BaseModel):
@@ -451,11 +432,6 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
-    hidden_states: Optional[object] = None
-    @model_serializer
-    def _serialize(self):
-        return exclude_if_none(self, ["hidden_states"])
 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -508,8 +484,3 @@ class EmbeddingResponse(BaseModel):
     model: str
     object: str = "list"
     usage: Optional[UsageInfo] = None
-def exclude_if_none(obj, field_names: List[str]):
-    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
-    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
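
The protocol models removed above shared one serialization pattern: a pydantic v2 model_serializer that drops the listed fields when they are None, so hidden_states only appeared in payloads that actually carried it. A self-contained sketch of that pattern with an illustrative model name (the helper mirrors the removed exclude_if_none):

from typing import List, Optional

from pydantic import BaseModel, model_serializer


def exclude_if_none(obj, field_names: List[str]):
    # Drop the named fields from the serialized output when their value is None.
    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}


class ExampleChoice(BaseModel):  # illustrative model, not part of the protocol module
    text: str = ""
    hidden_states: Optional[object] = None

    @model_serializer
    def _serialize(self):
        return exclude_if_none(self, ["hidden_states"])


print(ExampleChoice(text="hi").model_dump())  # {'text': 'hi'} -- no hidden_states key
print(ExampleChoice(text="hi", hidden_states=[0.1]).model_dump())  # key included
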
""" """
python3 -m unittest test_openai_server.TestOpenAIServer.test_batch python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
""" """
import json import json
...@@ -11,7 +9,6 @@ import re ...@@ -11,7 +9,6 @@ import re
import time import time
import unittest import unittest
import numpy as np
import openai import openai
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -46,13 +43,7 @@ class TestOpenAIServer(CustomTestCase):
         kill_process_tree(cls.process.pid)
     def run_completion(
-        self,
-        echo,
-        logprobs,
-        use_list_input,
-        parallel_sample_num,
-        token_input,
-        return_hidden_states,
+        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
     ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
@@ -79,7 +70,6 @@ class TestOpenAIServer(CustomTestCase):
             echo=echo,
             logprobs=logprobs,
             n=parallel_sample_num,
-            extra_body=dict(return_hidden_states=return_hidden_states),
         )
         assert len(response.choices) == num_choices * parallel_sample_num
@@ -110,26 +100,8 @@ class TestOpenAIServer(CustomTestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0
-        if return_hidden_states:
-            hidden_states = response.choices[0].hidden_states
-            assert hidden_states is not None, "hidden_states was none"
-            hidden_states = np.asarray(hidden_states)
-            assert (
-                len(hidden_states.shape) == 1
-            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
-        else:
-            assert not hasattr(
-                response.choices[0], "hidden_states"
-            ), "hidden_states was returned and should not have been"
     def run_completion_stream(
-        self,
-        echo,
-        logprobs,
-        use_list_input,
-        parallel_sample_num,
-        token_input,
-        return_hidden_states,
+        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
     ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
@@ -158,44 +130,33 @@ class TestOpenAIServer(CustomTestCase):
             stream=True,
             stream_options={"include_usage": True},
             n=parallel_sample_num,
-            extra_body=dict(return_hidden_states=return_hidden_states),
         )
         is_firsts = {}
-        hidden_states = None
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
-                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
-                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
-                continue
-            if (
-                hasattr(response.choices[0], "hidden_states")
-                and response.choices[0].hidden_states is not None
-            ):
-                hidden_states = response.choices[0].hidden_states
+                assert usage.prompt_tokens > 0
+                assert usage.completion_tokens > 0
+                assert usage.total_tokens > 0
                 continue
             index = response.choices[0].index
             is_first = is_firsts.get(index, True)
             if logprobs:
-                assert response.choices[0].logprobs, f"no logprobs in response"
-                assert isinstance(
-                    response.choices[0].logprobs.tokens[0], str
-                ), f"{response.choices[0].logprobs.tokens[0]} is not a string"
+                assert response.choices[0].logprobs
+                assert isinstance(response.choices[0].logprobs.tokens[0], str)
                 if not (is_first and echo):
                     assert isinstance(
                         response.choices[0].logprobs.top_logprobs[0], dict
-                    ), f"top_logprobs was not a dictionary"
+                    )
                     ret_num_top_logprobs = len(
                         response.choices[0].logprobs.top_logprobs[0]
                     )
                     # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
                     # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
-                    assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
+                    assert ret_num_top_logprobs > 0
             if is_first:
                 if echo:
@@ -203,29 +164,15 @@ class TestOpenAIServer(CustomTestCase):
                         prompt
                     ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                     is_firsts[index] = False
-            assert response.id, f"no id in response"
-            assert response.created, f"no created in response"
+            assert response.id
+            assert response.created
         for index in [i for i in range(parallel_sample_num * num_choices)]:
             assert not is_firsts.get(
                 index, True
             ), f"index {index} is not found in the response"
-        if return_hidden_states:
-            assert hidden_states is not None, "hidden_states is not returned"
-            try:
-                hidden_states = np.asarray(hidden_states)
-            except Exception as e:
-                raise Exception(f"Failed to convert hidden states to numpy array: {e}")
-            assert (
-                len(hidden_states.shape) == 1
-            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
-        else:
-            assert (
-                hidden_states is None
-            ), "hidden_states was returned and should not have been"
-    def run_chat_completion(self, logprobs, parallel_sample_num, return_hidden_states):
+    def run_chat_completion(self, logprobs, parallel_sample_num):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model,
@@ -240,7 +187,6 @@ class TestOpenAIServer(CustomTestCase):
             logprobs=logprobs is not None and logprobs > 0,
             top_logprobs=logprobs,
             n=parallel_sample_num,
-            extra_body=dict(return_hidden_states=return_hidden_states),
         )
         if logprobs:
@@ -264,21 +210,7 @@ class TestOpenAIServer(CustomTestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0
-        if return_hidden_states:
-            hidden_states = response.choices[0].hidden_states
-            assert hidden_states is not None, "hidden_states is not returned"
-            hidden_states = np.asarray(hidden_states)
-            assert (
-                len(hidden_states.shape) == 1
-            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
-        else:
-            assert not hasattr(
-                response.choices[0], "hidden_states"
-            ), "hidden_states was returned and should not have been"
-    def run_chat_completion_stream(
-        self, logprobs, parallel_sample_num=1, return_hidden_states=False
-    ):
+    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
@@ -292,55 +224,40 @@ class TestOpenAIServer(CustomTestCase):
             stream=True,
             stream_options={"include_usage": True},
             n=parallel_sample_num,
-            extra_body=dict(return_hidden_states=return_hidden_states),
         )
         is_firsts = {}
-        hidden_states = None
-        top_logprob_tokens = []
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
-                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
-                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
-                continue
-            if hasattr(response.choices[0].delta, "hidden_states"):
-                hidden_states = response.choices[0].delta.hidden_states
+                assert usage.prompt_tokens > 0
+                assert usage.completion_tokens > 0
+                assert usage.total_tokens > 0
                 continue
             index = response.choices[0].index
             data = response.choices[0].delta
             if is_firsts.get(index, True):
-                assert (
-                    data.role == "assistant"
-                ), f"data.role was not 'assistant' for first chunk"
+                assert data.role == "assistant"
                 is_firsts[index] = False
                 continue
             if logprobs:
-                assert response.choices[0].logprobs, f"logprobs was not returned"
+                assert response.choices[0].logprobs
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs[0].token, str
-                ), f"top_logprobs token was not a string"
+                )
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs, list
-                ), f"top_logprobs was not a list"
+                )
                 ret_num_top_logprobs = len(
                     response.choices[0].logprobs.content[0].top_logprobs
                 )
                 assert (
                     ret_num_top_logprobs == logprobs
                 ), f"{ret_num_top_logprobs} vs {logprobs}"
-                top_logprob_tokens.append(
-                    response.choices[0].logprobs.content[0].top_logprobs[0].token
-                )
-            assert (
-                len(top_logprob_tokens) <= 2 or len(set(top_logprob_tokens)) > 1
-            ), "Top Logprob tokens should not consistent of the same token repeated"
             assert (
                 isinstance(data.content, str)
                 or isinstance(data.reasoning_content, str)
@@ -355,20 +272,6 @@ class TestOpenAIServer(CustomTestCase):
                 index, True
             ), f"index {index} is not found in the response"
-        if return_hidden_states:
-            assert hidden_states is not None, "hidden_states is not returned"
-            try:
-                hidden_states = np.asarray(hidden_states)
-            except Exception as e:
-                raise Exception(f"Failed to convert hidden states to numpy array: {e}")
-            assert (
-                len(hidden_states.shape) == 1
-            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
-        else:
-            assert (
-                hidden_states is None
-            ), "hidden_states was returned and should not have been"
     def _create_batch(self, mode, client):
         if mode == "completion":
             input_file_path = "complete_input.jsonl"
@@ -516,7 +419,6 @@ class TestOpenAIServer(CustomTestCase):
         assert del_response.deleted
     def test_completion(self):
-        for return_hidden_states in [False, True]:
         for echo in [False, True]:
             for logprobs in [None, 5]:
                 for use_list_input in [True, False]:
@@ -528,12 +430,10 @@ class TestOpenAIServer(CustomTestCase):
                             use_list_input,
                             parallel_sample_num,
                             token_input,
-                            return_hidden_states,
                         )
     def test_completion_stream(self):
         # parallel sampling and list input are not supported in streaming mode
-        for return_hidden_states in [False, True]:
         for echo in [False, True]:
             for logprobs in [None, 5]:
                 for use_list_input in [True, False]:
@@ -545,24 +445,17 @@ class TestOpenAIServer(CustomTestCase):
                             use_list_input,
                             parallel_sample_num,
                             token_input,
-                            return_hidden_states,
                         )
     def test_chat_completion(self):
-        for return_hidden_states in [False, True]:
         for logprobs in [None, 5]:
             for parallel_sample_num in [1, 2]:
-                self.run_chat_completion(
-                    logprobs, parallel_sample_num, return_hidden_states
-                )
+                self.run_chat_completion(logprobs, parallel_sample_num)
     def test_chat_completion_stream(self):
-        for return_hidden_states in [False, True]:
         for logprobs in [None, 5]:
             for parallel_sample_num in [1, 2]:
-                self.run_chat_completion_stream(
-                    logprobs, parallel_sample_num, return_hidden_states
-                )
+                self.run_chat_completion_stream(logprobs, parallel_sample_num)
     def test_batch(self):
         for mode in ["completion", "chat"]: