Unverified commit b56de8f9 authored by kyle-pena-kuzco, committed by GitHub

OpenAI API hidden states (#6716)

parent ce5ee3bd
...@@ -135,6 +135,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `download_dir` | Overrides the default Hugging Face cache directory for model weights. | None |
| `base_gpu_id` | Sets the first GPU to use when distributing the model across multiple GPUs. | `0` |
| `allow_auto_truncate`| Automatically truncate requests that exceed the maximum input length. | `False` |
| `enable_return_hidden_states` | Enables returning hidden states to the user. | `False` |
## Logging
...
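For reference, a minimal sketch of how a client might exercise the new flag once the server is launched with `--enable-return-hidden-states`; the port, API key, and model name below are placeholders, and `extra_body` carries the per-request option added by this commit:

```python
import openai

# Assumes a server started with something like:
#   python -m sglang.launch_server --model-path <model> --enable-return-hidden-states
client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")

response = client.chat.completions.create(
    model="<model>",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    temperature=0,
    # Per-request opt-in; the server rejects it unless the flag above was set.
    extra_body={"return_hidden_states": True},
)

# When requested, the last token's hidden state is attached to each choice.
print(getattr(response.choices[0], "hidden_states", None) is not None)
```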
...@@ -22,6 +22,7 @@ def main():
# Create an LLM.
llm = sgl.Engine(
model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
enable_return_hidden_states=True,
)
sampling_params = {
...
...@@ -23,7 +23,7 @@ else:
def main():
# Launch the server
server_process, port = launch_server_cmd(
-"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --host 0.0.0.0"
+"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --enable-return-hidden-states --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
...
...@@ -99,7 +99,7 @@ class GenerateReqInput:
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
# Whether to return hidden states
-return_hidden_states: bool = False
+return_hidden_states: Union[List[bool], bool] = False
# For disaggregated inference
bootstrap_host: Optional[Union[List[str], str]] = None
...@@ -409,7 +409,11 @@ class GenerateReqInput:
if self.custom_logit_processor is not None
else None
),
-return_hidden_states=self.return_hidden_states,
+return_hidden_states=(
+self.return_hidden_states[i]
+if isinstance(self.return_hidden_states, list)
+else self.return_hidden_states
+),
# if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
bootstrap_host=(
self.bootstrap_host[i] if self.bootstrap_host is not None else None
...
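A small illustrative helper (not part of the commit) paraphrasing the indexing above: in a batched request, `return_hidden_states` may be a per-prompt list or a single flag applied to every prompt:

```python
from typing import List, Union

def return_hidden_states_for_item(value: Union[List[bool], bool], i: int) -> bool:
    # Mirrors GenerateReqInput.__getitem__: a list is indexed per prompt,
    # a plain bool applies to the whole batch.
    return value[i] if isinstance(value, list) else value

assert return_hidden_states_for_item([False, True], 1) is True
assert return_hidden_states_for_item(False, 1) is False
```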
...@@ -418,6 +418,20 @@ class TokenizerManager:
obj.normalize_batch_and_arguments()
if isinstance(obj, GenerateReqInput):
return_hidden_states = obj.return_hidden_states
has_return_hidden_states = return_hidden_states == True or (
isinstance(return_hidden_states, list) and any(return_hidden_states)
)
if (
not self.server_args.enable_return_hidden_states
and has_return_hidden_states
):
raise ValueError(
"return_hidden_states=True requires the server to be started "
"with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
)
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(
...
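The guard above reduces to a small predicate; a standalone sketch (helper name is illustrative) of when a request counts as asking for hidden states:

```python
from typing import List, Union

def wants_hidden_states(value: Union[List[bool], bool]) -> bool:
    # A bare True, or any True entry in a per-prompt list, triggers the check
    # against ServerArgs.enable_return_hidden_states.
    return value is True or (isinstance(value, list) and any(value))

assert wants_hidden_states(True)
assert wants_hidden_states([False, True])
assert not wants_hidden_states(False)
assert not wants_hidden_states([False, False])
```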
...@@ -235,6 +235,10 @@ class CudaGraphRunner:
self.model_runner.server_args.speculative_num_draft_tokens
)
# If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
if model_runner.server_args.enable_return_hidden_states:
self.capture_hidden_mode = CaptureHiddenMode.FULL
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
...@@ -342,11 +346,29 @@ class CudaGraphRunner:
else True
)
requested_capture_hidden_mode = max(
forward_batch.capture_hidden_mode,
(
forward_batch.spec_info.capture_hidden_mode
if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
is not None
else CaptureHiddenMode.NULL
),
)
capture_hidden_mode_matches = (
requested_capture_hidden_mode == CaptureHiddenMode.NULL
or requested_capture_hidden_mode == self.capture_hidden_mode
)
is_tbo_supported = (
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
)
-return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
+return (
+is_bs_supported
+and is_encoder_lens_supported
+and is_tbo_supported
+and capture_hidden_mode_matches
+)
def capture(self) -> None:
profile_context = empty_context()
...@@ -541,21 +563,34 @@
return graph, out
def recapture_if_needed(self, forward_batch: ForwardBatch):
-# If the capture_hidden_mode changes, we need to recapture the graph
-hidden_mode_from_spec_info = getattr(
+# If the required capture_hidden_mode changes, we need to recapture the graph
+# These are the different factors that can influence the capture_hidden_mode
+capture_hidden_mode_required_by_forward_batch = (
+forward_batch.capture_hidden_mode
+)
+capture_hidden_mode_required_by_spec_info = getattr(
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
)
-if (
-forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
-and self.capture_hidden_mode != CaptureHiddenMode.FULL
-):
-self.capture_hidden_mode = CaptureHiddenMode.FULL
-self.capture()
-elif (
-forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
-and self.capture_hidden_mode != hidden_mode_from_spec_info
-):
-self.capture_hidden_mode = hidden_mode_from_spec_info
+capture_hidden_mode_required_for_returning_hidden_states = (
+CaptureHiddenMode.FULL
+if self.model_runner.server_args.enable_return_hidden_states
+else CaptureHiddenMode.NULL
+)
+# Determine the highest capture_hidden_mode required
+# (If we have FULL, we can emulate LAST or NULL)
+# (If we have LAST, we can emulate NULL)
+required_capture_hidden_mode = max(
+capture_hidden_mode_required_by_forward_batch,
+capture_hidden_mode_required_by_spec_info,
+capture_hidden_mode_required_for_returning_hidden_states,
+)
+# If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
+if self.capture_hidden_mode != required_capture_hidden_mode:
+self.capture_hidden_mode = required_capture_hidden_mode
self.capture()
def replay_prepare(
...
...@@ -31,6 +31,7 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum, auto
from functools import total_ordering
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
...@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
return self == ForwardMode.DECODE or self == ForwardMode.IDLE
@total_ordering
class CaptureHiddenMode(IntEnum):
# Do not capture anything.
-NULL = auto()
-# Capture hidden states of all tokens.
-FULL = auto()
+NULL = 0
# Capture a hidden state of the last token.
-LAST = auto()
+LAST = 1
+# Capture hidden states of all tokens.
+FULL = 2
def need_capture(self):
return self != CaptureHiddenMode.NULL
...@@ -134,6 +136,9 @@ class CaptureHiddenMode(IntEnum):
def is_last(self):
return self == CaptureHiddenMode.LAST
def __lt__(self, other):
return self.value < other.value
@dataclass
class ForwardBatch:
...
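With explicit values the modes order as NULL < LAST < FULL, so the `max(...)` calls in the CUDA graph runner above always keep the most demanding capture requirement. A standalone sketch of that behaviour:

```python
from enum import IntEnum
from functools import total_ordering

@total_ordering
class CaptureHiddenMode(IntEnum):
    NULL = 0  # capture nothing
    LAST = 1  # capture the last token's hidden state
    FULL = 2  # capture hidden states for all tokens

    def __lt__(self, other):
        return self.value < other.value

# The graph runner keeps whichever requirement is strongest.
required = max(CaptureHiddenMode.LAST, CaptureHiddenMode.NULL, CaptureHiddenMode.FULL)
assert required is CaptureHiddenMode.FULL
```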
...@@ -542,6 +542,7 @@ def v1_generate_request(
logprob_start_lens = []
top_logprobs_nums = []
lora_paths = []
return_hidden_states = []
for request in all_requests:
# NOTE: with openai API, the prompt's logprobs are always not computed
...@@ -588,6 +589,7 @@
top_logprobs_nums.append(
request.logprobs if request.logprobs is not None else 0
)
return_hidden_states.append(request.return_hidden_states)
if len(all_requests) == 1:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
...@@ -599,6 +601,7 @@
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
lora_paths = lora_paths[0]
return_hidden_states = return_hidden_states[0]
else:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts}
...@@ -615,6 +618,7 @@
stream=all_requests[0].stream,
rid=request_ids,
lora_path=lora_paths,
return_hidden_states=return_hidden_states,
bootstrap_host=all_requests[0].bootstrap_host,
bootstrap_port=all_requests[0].bootstrap_port,
bootstrap_room=all_requests[0].bootstrap_room,
...@@ -683,6 +687,16 @@ def v1_generate_response(
else:
logprobs = None
hidden_states = None
if isinstance(request, list) and request[idx].return_hidden_states:
hidden_states = ret_item["meta_info"].get("hidden_states", None)
elif (not isinstance(request, list)) and request.return_hidden_states:
hidden_states = ret_item["meta_info"].get("hidden_states", None)
if hidden_states is not None:
hidden_states = (
hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
)
finish_reason = ret_item["meta_info"]["finish_reason"] finish_reason = ret_item["meta_info"]["finish_reason"]
if to_file: if to_file:
...@@ -698,6 +712,8 @@ def v1_generate_response( ...@@ -698,6 +712,8 @@ def v1_generate_response(
else None else None
), ),
} }
if hidden_states is not None:
choice_data["hidden_states"] = hidden_states
else: else:
choice_data = CompletionResponseChoice( choice_data = CompletionResponseChoice(
index=idx, index=idx,
...@@ -709,6 +725,7 @@ def v1_generate_response( ...@@ -709,6 +725,7 @@ def v1_generate_response(
if finish_reason and "matched" in finish_reason if finish_reason and "matched" in finish_reason
else None else None
), ),
hidden_states=hidden_states,
) )
choices.append(choice_data) choices.append(choice_data)
...@@ -777,6 +794,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request): ...@@ -777,6 +794,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
prompt_tokens = {} prompt_tokens = {}
completion_tokens = {} completion_tokens = {}
cached_tokens = {} cached_tokens = {}
hidden_states = {}
try: try:
async for content in tokenizer_manager.generate_request( async for content in tokenizer_manager.generate_request(
...@@ -791,6 +809,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request): ...@@ -791,6 +809,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
prompt_tokens[index] = content["meta_info"]["prompt_tokens"] prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"] completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0) cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get(
"hidden_states", None
) or hidden_states.get(index)
if not stream_buffer: # The first chunk if not stream_buffer: # The first chunk
if request.echo: if request.echo:
...@@ -873,6 +894,27 @@ async def v1_completions(tokenizer_manager, raw_request: Request): ...@@ -873,6 +894,27 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
n_prev_tokens[index] = n_prev_token n_prev_tokens[index] = n_prev_token
yield f"data: {chunk.model_dump_json()}\n\n" yield f"data: {chunk.model_dump_json()}\n\n"
if request.return_hidden_states and hidden_states:
for index, choice_hidden_states in hidden_states.items():
last_token_hidden_states = (
choice_hidden_states[-1]
if choice_hidden_states and len(choice_hidden_states) > 1
else []
)
hidden_states_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[
CompletionResponseStreamChoice(
text="",
index=index,
hidden_states=last_token_hidden_states,
finish_reason=None,
)
],
model=request.model,
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
if request.stream_options and request.stream_options.include_usage:
total_prompt_tokens = sum(
tokens
...@@ -973,6 +1015,7 @@ def v1_chat_generate_request(
top_logprobs_nums = []
modalities_list = []
lora_paths = []
return_hidden_states = []
# NOTE: with openai API, the prompt's logprobs are always not computed
...@@ -1215,6 +1258,7 @@
image_data_list.append(image_data)
audio_data_list.append(audio_data)
modalities_list.append(modalities)
return_hidden_states.append(request.return_hidden_states)
if len(all_requests) == 1:
if is_multimodal:
# processor will need text input
...@@ -1233,6 +1277,7 @@
modalities_list = modalities_list[0]
lora_paths = lora_paths[0]
request_ids = request_ids[0]
return_hidden_states = return_hidden_states[0]
else:
if tokenizer_manager.model_config.is_multimodal:
# processor will need text input
...@@ -1259,6 +1304,7 @@
bootstrap_host=all_requests[0].bootstrap_host,
bootstrap_port=all_requests[0].bootstrap_port,
bootstrap_room=all_requests[0].bootstrap_room,
return_hidden_states=return_hidden_states,
)
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
...@@ -1319,6 +1365,20 @@ def v1_chat_generate_response(
else:
choice_logprobs = None
if isinstance(request, list) and request[idx].return_hidden_states:
include_hidden_states = True
elif not isinstance(request, list) and request.return_hidden_states:
include_hidden_states = True
else:
include_hidden_states = False
if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
hidden_states = ret_item["meta_info"]["hidden_states"]
hidden_states = (
hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
)
else:
hidden_states = None
finish_reason = ret_item["meta_info"]["finish_reason"] finish_reason = ret_item["meta_info"]["finish_reason"]
tool_calls = None tool_calls = None
...@@ -1391,6 +1451,8 @@ def v1_chat_generate_response( ...@@ -1391,6 +1451,8 @@ def v1_chat_generate_response(
else None else None
), ),
} }
if hidden_states is not None:
choice_data["hidden_states"] = hidden_states
else: else:
choice_data = ChatCompletionResponseChoice( choice_data = ChatCompletionResponseChoice(
index=idx, index=idx,
...@@ -1407,6 +1469,7 @@ def v1_chat_generate_response( ...@@ -1407,6 +1469,7 @@ def v1_chat_generate_response(
if finish_reason and "matched" in finish_reason if finish_reason and "matched" in finish_reason
else None else None
), ),
hidden_states=hidden_states,
) )
choices.append(choice_data) choices.append(choice_data)
...@@ -1486,12 +1549,16 @@ async def v1_chat_completions( ...@@ -1486,12 +1549,16 @@ async def v1_chat_completions(
prompt_tokens = {} prompt_tokens = {}
completion_tokens = {} completion_tokens = {}
cached_tokens = {} cached_tokens = {}
hidden_states = {}
try: try:
async for content in tokenizer_manager.generate_request( async for content in tokenizer_manager.generate_request(
adapted_request, raw_request adapted_request, raw_request
): ):
index = content.get("index", 0) index = content.get("index", 0)
text = content["text"] text = content["text"]
hidden_states[index] = content["meta_info"].get(
"hidden_states", None
) or hidden_states.get(index)
is_first = is_firsts.get(index, True)
stream_buffer = stream_buffers.get(index, "")
...@@ -1613,6 +1680,7 @@ async def v1_chat_completions(
if (delta and len(delta) == 0) or not delta:
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
continue
if request.tool_choice != "none" and request.tools:
...@@ -1702,6 +1770,7 @@
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
else:
# No tool calls => just treat this as normal text
...@@ -1734,6 +1803,7 @@
yield f"data: {chunk.model_dump_json()}\n\n"
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
if finish_reason_type == "stop" and request.tool_choice != "none":
parser = FunctionCallParser(
tools=request.tools,
...@@ -1769,6 +1839,28 @@
else:
usage = None
if request.return_hidden_states and hidden_states:
for index, choice_hidden_states in hidden_states.items():
last_token_hidden_states = (
choice_hidden_states[-1]
if choice_hidden_states and len(choice_hidden_states) > 1
else []
)
hidden_states_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(
hidden_states=last_token_hidden_states
),
finish_reason=finish_reason_type,
)
],
model=request.model,
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
final_usage_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
...
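On the client side, the extra chunk emitted above arrives as a trailing delta that carries only `hidden_states`; a sketch of consuming it in streaming mode, modelled on the tests added by this commit (URL, key, and model name are placeholders):

```python
import openai

client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")

stream = client.chat.completions.create(
    model="<model>",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    temperature=0,
    stream=True,
    extra_body={"return_hidden_states": True},
)

last_token_hidden_states = None
for chunk in stream:
    for choice in chunk.choices:
        # Regular deltas omit the field; only the dedicated trailing chunk carries it.
        if getattr(choice.delta, "hidden_states", None) is not None:
            last_token_hidden_states = choice.delta.hidden_states

print(last_token_hidden_states is not None)
```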
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import time import time
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator from pydantic import BaseModel, Field, model_serializer, root_validator
from typing_extensions import Literal from typing_extensions import Literal
...@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel): ...@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel):
skip_special_tokens: bool = True skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None session_params: Optional[Dict] = None
return_hidden_states: Optional[bool] = False
# For PD disaggregation # For PD disaggregation
bootstrap_host: Optional[str] = None bootstrap_host: Optional[str] = None
...@@ -195,6 +196,11 @@ class CompletionResponseChoice(BaseModel): ...@@ -195,6 +196,11 @@ class CompletionResponseChoice(BaseModel):
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
finish_reason: Literal["stop", "length", "content_filter", "abort"] finish_reason: Literal["stop", "length", "content_filter", "abort"]
matched_stop: Union[None, int, str] = None matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class CompletionResponse(BaseModel):
...@@ -212,6 +218,11 @@ class CompletionResponseStreamChoice(BaseModel):
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class CompletionStreamResponse(BaseModel):
...@@ -405,6 +416,9 @@ class ChatCompletionRequest(BaseModel):
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
# Hidden States
return_hidden_states: Optional[bool] = False
class ChatMessage(BaseModel):
role: Optional[str] = None
...@@ -421,6 +435,11 @@ class ChatCompletionResponseChoice(BaseModel):
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class ChatCompletionResponse(BaseModel):
...@@ -437,6 +456,11 @@ class DeltaMessage(BaseModel):
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class ChatCompletionResponseStreamChoice(BaseModel):
...@@ -513,3 +537,8 @@ class ScoringResponse(BaseModel):
model: str
usage: Optional[UsageInfo] = None
object: str = "scoring"
def exclude_if_none(obj, field_names: List[str]):
omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
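The `exclude_if_none` helper added above keeps `hidden_states` out of serialized responses when it was never populated. A minimal sketch of the same pattern on a hypothetical model, assuming pydantic v2 (which provides `model_serializer`):

```python
from typing import List, Optional
from pydantic import BaseModel, model_serializer

def exclude_if_none(obj, field_names: List[str]):
    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}

class ExampleChoice(BaseModel):
    text: str
    hidden_states: Optional[object] = None

    @model_serializer
    def _serialize(self):
        return exclude_if_none(self, ["hidden_states"])

print(ExampleChoice(text="hi").model_dump())                       # {'text': 'hi'}
print(ExampleChoice(text="hi", hidden_states=[0.1]).model_dump())  # field included
```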
...@@ -215,6 +215,7 @@ class ServerArgs:
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
warmups: Optional[str] = None
enable_return_hidden_states: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
...@@ -1456,6 +1457,12 @@
default=ServerArgs.debug_tensor_dump_inject,
help="Inject the outputs from jax as the input of every layer.",
)
parser.add_argument(
"--enable-return-hidden-states",
action="store_true",
help="Enable returning hidden states with responses.",
)
parser.add_argument(
"--debug-tensor-dump-prefill-only",
action="store_true",
...
...@@ -117,9 +117,7 @@ class EAGLEDraftCudaGraphRunner:
hidden_states = self.hidden_states[:num_seqs]
spec_info = EagleDraftInput(
-topk_p=topk_p,
-topk_index=topk_index,
-hidden_states=hidden_states,
+topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
)
# Forward batch
...
...@@ -290,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
A tuple of the final logit output of the target model, next tokens accepted,
the batch id (used for overlap schedule), and number of accepted tokens.
"""
if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
...@@ -431,10 +432,10 @@ class EAGLEWorker(TpModelWorker):
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
-# Get forward batch
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
...@@ -547,11 +548,13 @@
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch, self.page_size)
batch.return_hidden_states = False
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
)
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
if batch.has_grammar:
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
...@@ -687,15 +690,18 @@
hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
"""
# Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
)
batch.return_hidden_states = False
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
...@@ -718,7 +724,9 @@
batch,
self.speculative_num_steps,
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
...
...@@ -59,6 +59,7 @@ suites = {
TestFile("test_openai_adapter.py", 1),
TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149),
TestFile("test_openai_server_hidden_states.py", 240),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
TestFile("test_pytorch_sampling_backend.py", 66),
...
...@@ -23,6 +23,7 @@ class TestHiddenState(CustomTestCase):
model_path=model_path,
random_seed=42,
skip_tokenizer_init=True,
enable_return_hidden_states=True,
)
outputs = engine.generate(
input_ids=input_ids,
...@@ -96,6 +97,7 @@ class TestHiddenState(CustomTestCase):
model_path=model_path,
random_seed=42,
skip_tokenizer_init=True,
enable_return_hidden_states=True,
)
outputs_completion_first_round = engine.generate(
input_ids=input_ids,
...
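For the offline path, the engine-level flag pairs with a per-call opt-in; a sketch based on the engine example and the test above (the `return_hidden_states` keyword on `generate` is assumed to mirror `GenerateReqInput`, and the model path is copied from the example):

```python
import sglang as sgl

llm = sgl.Engine(
    model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
    enable_return_hidden_states=True,  # server-side switch added by this change
)
outputs = llm.generate(
    ["The capital of France is"],
    sampling_params={"temperature": 0, "max_new_tokens": 8},
    return_hidden_states=True,  # per-request opt-in, assumed to match GenerateReqInput
)
# When requested, the captured hidden states ride along in meta_info.
print("hidden_states" in outputs[0]["meta_info"])
llm.shutdown()
```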
...@@ -381,12 +381,14 @@ class TestGenerateReqInputNormalization(CustomTestCase):
logprob_start_len=[10, 5],
top_logprobs_num=[5, 3],
token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
return_hidden_states=[False, False, True],
)
req.normalize_batch_and_arguments()
self.assertEqual(req.return_logprob, [True, False])
self.assertEqual(req.logprob_start_len, [10, 5])
self.assertEqual(req.top_logprobs_num, [5, 3])
self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [4, 5, 6]])
self.assertEqual(req.return_hidden_states, [False, False, True])
def test_custom_logit_processor_normalization(self):
"""Test normalization of custom_logit_processor."""
...
""" """
python3 -m unittest test_openai_server.TestOpenAIServer.test_batch python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
""" """
import json import json
...@@ -9,6 +11,7 @@ import re ...@@ -9,6 +11,7 @@ import re
import time import time
import unittest import unittest
import numpy as np
import openai import openai
import requests import requests
...@@ -137,27 +140,29 @@ class TestOpenAIServer(CustomTestCase): ...@@ -137,27 +140,29 @@ class TestOpenAIServer(CustomTestCase):
for response in generator: for response in generator:
usage = response.usage usage = response.usage
if usage is not None: if usage is not None:
assert usage.prompt_tokens > 0 assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
assert usage.completion_tokens > 0 assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
assert usage.total_tokens > 0 assert usage.total_tokens > 0, f"usage.total_tokens was zero"
continue continue
index = response.choices[0].index index = response.choices[0].index
is_first = is_firsts.get(index, True) is_first = is_firsts.get(index, True)
if logprobs: if logprobs:
assert response.choices[0].logprobs assert response.choices[0].logprobs, f"no logprobs in response"
assert isinstance(response.choices[0].logprobs.tokens[0], str) assert isinstance(
response.choices[0].logprobs.tokens[0], str
), f"{response.choices[0].logprobs.tokens[0]} is not a string"
if not (is_first and echo): if not (is_first and echo):
assert isinstance( assert isinstance(
response.choices[0].logprobs.top_logprobs[0], dict response.choices[0].logprobs.top_logprobs[0], dict
) ), f"top_logprobs was not a dictionary"
ret_num_top_logprobs = len( ret_num_top_logprobs = len(
response.choices[0].logprobs.top_logprobs[0] response.choices[0].logprobs.top_logprobs[0]
) )
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
assert ret_num_top_logprobs > 0 assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
if is_first: if is_first:
if echo: if echo:
...@@ -165,8 +170,8 @@ class TestOpenAIServer(CustomTestCase): ...@@ -165,8 +170,8 @@ class TestOpenAIServer(CustomTestCase):
prompt prompt
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}" ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
is_firsts[index] = False is_firsts[index] = False
assert response.id assert response.id, f"no id in response"
assert response.created assert response.created, f"no created in response"
for index in [i for i in range(parallel_sample_num * num_choices)]: for index in [i for i in range(parallel_sample_num * num_choices)]:
assert not is_firsts.get( assert not is_firsts.get(
...@@ -231,27 +236,29 @@ class TestOpenAIServer(CustomTestCase): ...@@ -231,27 +236,29 @@ class TestOpenAIServer(CustomTestCase):
for response in generator: for response in generator:
usage = response.usage usage = response.usage
if usage is not None: if usage is not None:
assert usage.prompt_tokens > 0 assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
assert usage.completion_tokens > 0 assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
assert usage.total_tokens > 0 assert usage.total_tokens > 0, f"usage.total_tokens was zero"
continue continue
index = response.choices[0].index index = response.choices[0].index
data = response.choices[0].delta data = response.choices[0].delta
if is_firsts.get(index, True): if is_firsts.get(index, True):
assert data.role == "assistant" assert (
data.role == "assistant"
), f"data.role was not 'assistant' for first chunk"
is_firsts[index] = False is_firsts[index] = False
continue continue
if logprobs: if logprobs:
assert response.choices[0].logprobs assert response.choices[0].logprobs, f"logprobs was not returned"
assert isinstance( assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str response.choices[0].logprobs.content[0].top_logprobs[0].token, str
) ), f"top_logprobs token was not a string"
assert isinstance( assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs, list response.choices[0].logprobs.content[0].top_logprobs, list
) ), f"top_logprobs was not a list"
ret_num_top_logprobs = len( ret_num_top_logprobs = len(
response.choices[0].logprobs.content[0].top_logprobs response.choices[0].logprobs.content[0].top_logprobs
) )
......
import json
import re
import time
import unittest
from abc import ABC
import numpy as np
import openai
import torch
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class BaseTestOpenAIServerWithHiddenStates(ABC):
@classmethod
def setUpClass(cls):
cls.return_hidden_states = [False, True]
cls.use_list_input = [True, False]
cls.parallel_sample_nums = [1, 2]
def test_completion(self):
for return_hidden_states in self.return_hidden_states:
for use_list_input in self.use_list_input:
for parallel_sample_num in self.parallel_sample_nums:
self.run_completion(
use_list_input,
parallel_sample_num,
return_hidden_states,
)
def test_completion_stream(self):
# parallel sampling and list input are not supported in streaming mode
for return_hidden_states in self.return_hidden_states:
for use_list_input in self.use_list_input:
for parallel_sample_num in self.parallel_sample_nums:
self.run_completion_stream(
use_list_input,
parallel_sample_num,
return_hidden_states,
)
def test_chat_completion(self):
for return_hidden_states in self.return_hidden_states:
for (
parallel_sample_num
) in (
self.parallel_sample_nums
): # parallel sample num 2 breaks in the adapter with a 400 for EAGLE
self.run_chat_completion(parallel_sample_num, return_hidden_states)
def test_chat_completion_stream(self):
for return_hidden_states in self.return_hidden_states:
for (
parallel_sample_num
) in (
self.parallel_sample_nums
): # parallel sample num > 1 breaks in the adapter with a 400 for EAGLE
self.run_chat_completion_stream(
parallel_sample_num, return_hidden_states
)
def run_completion(
self,
use_list_input,
parallel_sample_num,
return_hidden_states,
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
prompt_input = prompt
if use_list_input:
prompt_arg = [prompt_input, prompt_input]
num_choices = len(prompt_arg)
else:
prompt_arg = prompt_input
num_choices = 1
response = client.completions.create(
model=self.model,
prompt=prompt_arg,
temperature=0,
max_tokens=32,
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
for choice in response.choices:
assert hasattr(choice, "hidden_states") == return_hidden_states
if return_hidden_states:
assert choice.hidden_states is not None, "hidden_states was None"
def run_completion_stream(
self,
use_list_input,
parallel_sample_num,
return_hidden_states,
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
prompt_input = prompt
num_prompt_tokens = len(self.tokenizer.encode(prompt))
if use_list_input:
prompt_arg = [prompt_input, prompt_input]
num_choices = len(prompt_arg)
num_prompt_tokens *= 2
else:
prompt_arg = prompt_input
num_choices = 1
generator = client.completions.create(
model=self.model,
prompt=prompt_arg,
temperature=0,
max_tokens=32,
stream=True,
stream_options={"include_usage": True},
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
hidden_states_list = []
for response in generator:
usage = response.usage
for choice in response.choices:
if hasattr(choice, "hidden_states"):
assert return_hidden_states
assert choice.hidden_states is not None
hidden_states_list.append(choice.hidden_states)
if return_hidden_states:
assert (
len(hidden_states_list) == parallel_sample_num * num_choices
), f"Expected {parallel_sample_num * num_choices} hidden states, got {len(hidden_states_list)}"
else:
assert (
hidden_states_list == []
), "hidden_states were returned and should not have been"
def run_chat_completion(self, parallel_sample_num, return_hidden_states):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": "What is the capital of France? Answer in a few words.",
},
],
temperature=0,
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
for choice in response.choices:
assert hasattr(choice, "hidden_states") == return_hidden_states
if return_hidden_states:
assert choice.hidden_states is not None, "hidden_states was None"
def run_chat_completion_stream(
self, parallel_sample_num=1, return_hidden_states=False
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
generator = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "What is the capital of France?"},
],
temperature=0,
stream=True,
stream_options={"include_usage": True},
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
is_firsts = {}
hidden_states_list = []
for response in generator:
for choice in response.choices:
if hasattr(choice.delta, "hidden_states"):
assert return_hidden_states
assert choice.delta.hidden_states is not None
hidden_states_list.append(choice.delta.hidden_states)
if return_hidden_states:
assert (
len(hidden_states_list) == parallel_sample_num
), f"Expected {parallel_sample_num} hidden states, got {len(hidden_states_list)}"
else:
assert (
hidden_states_list == []
), "hidden_states were returned and should not have been"
class TestOpenAIServerWithHiddenStatesEnabled(
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=["--enable-return-hidden-states"],
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
cls.return_hidden_states = [False, True]
cls.use_list_input = [True, False]
cls.parallel_sample_nums = [1, 2]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
class TestOpenAIServerWithHiddenStatesEnabledAndCUDAGraphDisabled(
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=["--enable-return-hidden-states", "--disable-cuda-graph"],
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
cls.return_hidden_states = [False, True]
cls.use_list_input = [True, False]
cls.parallel_sample_nums = [1]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
class TestOpenAIServerWithEAGLEAndHiddenStatesEnabled(
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.speculative_draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
cls.speculative_algorithm = "EAGLE"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
8,
"--speculative-num-draft-tokens",
64,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--enable-return-hidden-states",
],
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
cls.return_hidden_states = [False, True]
cls.use_list_input = [True, False]
cls.parallel_sample_nums = [1]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
class TestOpenAIServerWithEAGLE3AndHiddenStatesEnabled(
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
@classmethod
def setUpClass(cls):
cls.model = "meta-llama/Llama-3.1-8B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.speculative_algorithm = "EAGLE3"
cls.speculative_draft_model = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
cls.speculative_algorithm,
"--speculative-draft-model-path",
cls.speculative_draft_model,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
16,
"--speculative-num-draft-tokens",
64,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--dtype",
"float16",
"--enable-return-hidden-states",
],
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
cls.return_hidden_states = [False, True]
cls.use_list_input = [True, False]
cls.parallel_sample_nums = [1]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
if __name__ == "__main__":
unittest.main()