Unverified Commit b56de8f9 authored by kyle-pena-kuzco, committed by GitHub

OpenAI API hidden states (#6716)

parent ce5ee3bd
......@@ -135,6 +135,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `download_dir` | Overrides the default Hugging Face cache directory for model weights. | None |
| `base_gpu_id` | Sets the first GPU to use when distributing the model across multiple GPUs. | `0` |
| `allow_auto_truncate`| Automatically truncate requests that exceed the maximum input length. | `False` |
| `enable_return_hidden_states` | Allows requests to return hidden states when they set `return_hidden_states`. | `False` |
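For example, once the server is launched with `--enable-return-hidden-states`, a client can opt in per request. A minimal sketch using the OpenAI-compatible endpoint (model name, API key, and port are placeholders):

```python
import openai

client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    max_tokens=8,
    # `return_hidden_states` is an SGLang extension, passed through extra_body
    extra_body={"return_hidden_states": True},
)
# When requested, each choice carries the last generated token's hidden state
print(response.choices[0].hidden_states)
```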
## Logging
......
......@@ -22,6 +22,7 @@ def main():
# Create an LLM.
llm = sgl.Engine(
model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
enable_return_hidden_states=True,
)
sampling_params = {
......
......@@ -23,7 +23,7 @@ else:
def main():
# Launch the server
server_process, port = launch_server_cmd(
"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --host 0.0.0.0"
"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --enable-return-hidden-states --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
......
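Once the server is up, hidden states can also be requested through the native `/generate` endpoint. A minimal sketch (payload keys follow `GenerateReqInput`; the exact response layout is an assumption based on the adapter code in this diff):

```python
import requests

response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_hidden_states": True,  # requires --enable-return-hidden-states
    },
)
# Hidden states are reported under meta_info, one entry per forward pass
print(response.json()["meta_info"].get("hidden_states"))
```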
......@@ -99,7 +99,7 @@ class GenerateReqInput:
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
# Whether to return hidden states
return_hidden_states: bool = False
return_hidden_states: Union[List[bool], bool] = False
# For disaggregated inference
bootstrap_host: Optional[Union[List[str], str]] = None
......@@ -409,7 +409,11 @@ class GenerateReqInput:
if self.custom_logit_processor is not None
else None
),
return_hidden_states=self.return_hidden_states,
return_hidden_states=(
self.return_hidden_states[i]
if isinstance(self.return_hidden_states, list)
else self.return_hidden_states
),
# if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
bootstrap_host=(
self.bootstrap_host[i] if self.bootstrap_host is not None else None
......
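Because the field now accepts a list, a batched request can enable hidden states for a subset of its prompts, and `__getitem__` carries the right flag into each per-prompt sub-request. A hypothetical sketch of the semantics (constructor arguments simplified):

```python
req = GenerateReqInput(
    text=["Hello", "World"],
    return_hidden_states=[False, True],  # per-prompt opt-in
)
req.normalize_batch_and_arguments()
assert req[0].return_hidden_states is False
assert req[1].return_hidden_states is True
```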
......@@ -418,6 +418,20 @@ class TokenizerManager:
obj.normalize_batch_and_arguments()
if isinstance(obj, GenerateReqInput):
return_hidden_states = obj.return_hidden_states
has_return_hidden_states = return_hidden_states is True or (
isinstance(return_hidden_states, list) and any(return_hidden_states)
)
if (
not self.server_args.enable_return_hidden_states
and has_return_hidden_states
):
raise ValueError(
"return_hidden_states=True requires the server to be started "
"with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
)
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(
......
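From the client's perspective, this validation makes a hidden-states request against a server launched without the flag fail fast instead of silently returning nothing. A sketch of the failure mode (endpoint and payload shapes assumed as above):

```python
import requests

# Hypothetical: server launched WITHOUT --enable-return-hidden-states
resp = requests.post(
    "http://localhost:30000/generate",
    json={"text": "Hello", "return_hidden_states": True},
)
# The ValueError raised in TokenizerManager surfaces as a failed request
assert not resp.ok
```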
......@@ -235,6 +235,10 @@ class CudaGraphRunner:
self.model_runner.server_args.speculative_num_draft_tokens
)
# If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
if model_runner.server_args.enable_return_hidden_states:
self.capture_hidden_mode = CaptureHiddenMode.FULL
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
......@@ -342,11 +346,29 @@ class CudaGraphRunner:
else True
)
requested_capture_hidden_mode = max(
forward_batch.capture_hidden_mode,
(
forward_batch.spec_info.capture_hidden_mode
if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
is not None
else CaptureHiddenMode.NULL
),
)
capture_hidden_mode_matches = (
requested_capture_hidden_mode == CaptureHiddenMode.NULL
or requested_capture_hidden_mode == self.capture_hidden_mode
)
is_tbo_supported = (
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
)
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
return (
is_bs_supported
and is_encoder_lens_supported
and is_tbo_supported
and capture_hidden_mode_matches
)
def capture(self) -> None:
profile_context = empty_context()
......@@ -541,21 +563,34 @@ class CudaGraphRunner:
return graph, out
def recapture_if_needed(self, forward_batch: ForwardBatch):
# If the capture_hidden_mode changes, we need to recapture the graph
hidden_mode_from_spec_info = getattr(
# If the required capture_hidden_mode changes, we need to recapture the graph
# These are the different factors that can influence the capture_hidden_mode
capture_hidden_mode_required_by_forward_batch = (
forward_batch.capture_hidden_mode
)
capture_hidden_mode_required_by_spec_info = getattr(
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
)
if (
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
and self.capture_hidden_mode != CaptureHiddenMode.FULL
):
self.capture_hidden_mode = CaptureHiddenMode.FULL
self.capture()
elif (
forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
and self.capture_hidden_mode != hidden_mode_from_spec_info
):
self.capture_hidden_mode = hidden_mode_from_spec_info
capture_hidden_mode_required_for_returning_hidden_states = (
CaptureHiddenMode.FULL
if self.model_runner.server_args.enable_return_hidden_states
else CaptureHiddenMode.NULL
)
# Determine the highest capture_hidden_mode required
# (If we have FULL, we can emulate LAST or NULL)
# (If we have LAST, we can emulate NULL)
required_capture_hidden_mode = max(
capture_hidden_mode_required_by_forward_batch,
capture_hidden_mode_required_by_spec_info,
capture_hidden_mode_required_for_returning_hidden_states,
)
# If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
if self.capture_hidden_mode != required_capture_hidden_mode:
self.capture_hidden_mode = required_capture_hidden_mode
self.capture()
def replay_prepare(
......
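The recapture decision reduces to a single `max()` because `CaptureHiddenMode` is now totally ordered (see the enum change below): FULL can emulate LAST, and LAST can emulate NULL. A standalone sketch of the rule (import path assumed from this diff's tree):

```python
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode

current = CaptureHiddenMode.LAST  # mode the graph was captured with
required = max(
    CaptureHiddenMode.NULL,  # the forward batch itself needs nothing
    CaptureHiddenMode.LAST,  # speculative decoding wants the last token
    CaptureHiddenMode.FULL,  # --enable-return-hidden-states wants all tokens
)
assert required == CaptureHiddenMode.FULL  # differs from `current`, so recapture
```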
......@@ -31,6 +31,7 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum, auto
from functools import total_ordering
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
......@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
return self == ForwardMode.DECODE or self == ForwardMode.IDLE
@total_ordering
class CaptureHiddenMode(IntEnum):
# Do not capture anything.
NULL = auto()
# Capture hidden states of all tokens.
FULL = auto()
NULL = 0
# Capture a hidden state of the last token.
LAST = auto()
LAST = 1
# Capture hidden states of all tokens.
FULL = 2
def need_capture(self):
return self != CaptureHiddenMode.NULL
......@@ -134,6 +136,9 @@ class CaptureHiddenMode(IntEnum):
def is_last(self):
return self == CaptureHiddenMode.LAST
def __lt__(self, other):
return self.value < other.value
@dataclass
class ForwardBatch:
......
......@@ -542,6 +542,7 @@ def v1_generate_request(
logprob_start_lens = []
top_logprobs_nums = []
lora_paths = []
return_hidden_states = []
for request in all_requests:
# NOTE: with openai API, the prompt's logprobs are always not computed
......@@ -588,6 +589,7 @@ def v1_generate_request(
top_logprobs_nums.append(
request.logprobs if request.logprobs is not None else 0
)
return_hidden_states.append(request.return_hidden_states)
if len(all_requests) == 1:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
......@@ -599,6 +601,7 @@ def v1_generate_request(
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
lora_paths = lora_paths[0]
return_hidden_states = return_hidden_states[0]
else:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts}
......@@ -615,6 +618,7 @@ def v1_generate_request(
stream=all_requests[0].stream,
rid=request_ids,
lora_path=lora_paths,
return_hidden_states=return_hidden_states,
bootstrap_host=all_requests[0].bootstrap_host,
bootstrap_port=all_requests[0].bootstrap_port,
bootstrap_room=all_requests[0].bootstrap_room,
......@@ -683,6 +687,16 @@ def v1_generate_response(
else:
logprobs = None
hidden_states = None
if isinstance(request, list) and request[idx].return_hidden_states:
hidden_states = ret_item["meta_info"].get("hidden_states", None)
elif (not isinstance(request, list)) and request.return_hidden_states:
hidden_states = ret_item["meta_info"].get("hidden_states", None)
if hidden_states is not None:
hidden_states = (
hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
)
finish_reason = ret_item["meta_info"]["finish_reason"]
if to_file:
......@@ -698,6 +712,8 @@ def v1_generate_response(
else None
),
}
if hidden_states is not None:
choice_data["hidden_states"] = hidden_states
else:
choice_data = CompletionResponseChoice(
index=idx,
......@@ -709,6 +725,7 @@ def v1_generate_response(
if finish_reason and "matched" in finish_reason
else None
),
hidden_states=hidden_states,
)
choices.append(choice_data)
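With these changes, a non-streaming completion choice gains an optional `hidden_states` field holding only the last generated token's hidden state, not the full sequence. A sketch of the resulting choice payload (values are placeholders):

```python
choice_data = {
    "index": 0,
    "text": " Paris",
    "finish_reason": "stop",
    "hidden_states": [0.0123, -0.0456],  # hidden_size floats for the last token
}
```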
......@@ -777,6 +794,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
hidden_states = {}
try:
async for content in tokenizer_manager.generate_request(
......@@ -791,6 +809,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get(
"hidden_states", None
) or hidden_states.get(index)
if not stream_buffer: # The first chunk
if request.echo:
......@@ -873,6 +894,27 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
n_prev_tokens[index] = n_prev_token
yield f"data: {chunk.model_dump_json()}\n\n"
if request.return_hidden_states and hidden_states:
for index, choice_hidden_states in hidden_states.items():
last_token_hidden_states = (
choice_hidden_states[-1]
if choice_hidden_states and len(choice_hidden_states) > 1
else []
)
hidden_states_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[
CompletionResponseStreamChoice(
text="",
index=index,
hidden_states=last_token_hidden_states,
finish_reason=None,
)
],
model=request.model,
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
if request.stream_options and request.stream_options.include_usage:
total_prompt_tokens = sum(
tokens
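On the wire, the stream now ends with one extra SSE frame per choice carrying the last token's hidden state, emitted before the optional usage frame. A sketch of consuming it (mirrors the test file added in this commit; model name and URL are placeholders):

```python
import openai

client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")
stream = client.completions.create(
    model="default",
    prompt="The capital of France is",
    max_tokens=8,
    stream=True,
    extra_body={"return_hidden_states": True},
)
for chunk in stream:
    for choice in chunk.choices:
        # Only the trailing frame carries hidden_states
        if getattr(choice, "hidden_states", None) is not None:
            print(len(choice.hidden_states))
```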
......@@ -973,6 +1015,7 @@ def v1_chat_generate_request(
top_logprobs_nums = []
modalities_list = []
lora_paths = []
return_hidden_states = []
# NOTE: with openai API, the prompt's logprobs are always not computed
......@@ -1215,6 +1258,7 @@ def v1_chat_generate_request(
image_data_list.append(image_data)
audio_data_list.append(audio_data)
modalities_list.append(modalities)
return_hidden_states.append(request.return_hidden_states)
if len(all_requests) == 1:
if is_multimodal:
# processor will need text input
......@@ -1233,6 +1277,7 @@ def v1_chat_generate_request(
modalities_list = modalities_list[0]
lora_paths = lora_paths[0]
request_ids = request_ids[0]
return_hidden_states = return_hidden_states[0]
else:
if tokenizer_manager.model_config.is_multimodal:
# processor will need text input
......@@ -1259,6 +1304,7 @@ def v1_chat_generate_request(
bootstrap_host=all_requests[0].bootstrap_host,
bootstrap_port=all_requests[0].bootstrap_port,
bootstrap_room=all_requests[0].bootstrap_room,
return_hidden_states=return_hidden_states,
)
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
......@@ -1319,6 +1365,20 @@ def v1_chat_generate_response(
else:
choice_logprobs = None
if isinstance(request, list) and request[idx].return_hidden_states:
include_hidden_states = True
elif not isinstance(request, list) and request.return_hidden_states:
include_hidden_states = True
else:
include_hidden_states = False
if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
hidden_states = ret_item["meta_info"]["hidden_states"]
hidden_states = (
hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
)
else:
hidden_states = None
finish_reason = ret_item["meta_info"]["finish_reason"]
tool_calls = None
......@@ -1391,6 +1451,8 @@ def v1_chat_generate_response(
else None
),
}
if hidden_states is not None:
choice_data["hidden_states"] = hidden_states
else:
choice_data = ChatCompletionResponseChoice(
index=idx,
......@@ -1407,6 +1469,7 @@ def v1_chat_generate_response(
if finish_reason and "matched" in finish_reason
else None
),
hidden_states=hidden_states,
)
choices.append(choice_data)
......@@ -1486,12 +1549,16 @@ async def v1_chat_completions(
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
hidden_states = {}
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
text = content["text"]
hidden_states[index] = content["meta_info"].get(
"hidden_states", None
) or hidden_states.get(index)
is_first = is_firsts.get(index, True)
stream_buffer = stream_buffers.get(index, "")
......@@ -1613,6 +1680,7 @@ async def v1_chat_completions(
if (delta and len(delta) == 0) or not delta:
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
continue
if request.tool_choice != "none" and request.tools:
......@@ -1702,6 +1770,7 @@ async def v1_chat_completions(
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
else:
# No tool calls => just treat this as normal text
......@@ -1734,6 +1803,7 @@ async def v1_chat_completions(
yield f"data: {chunk.model_dump_json()}\n\n"
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
if finish_reason_type == "stop" and request.tool_choice != "none":
parser = FunctionCallParser(
tools=request.tools,
......@@ -1769,6 +1839,28 @@ async def v1_chat_completions(
else:
usage = None
if request.return_hidden_states and hidden_states:
for index, choice_hidden_states in hidden_states.items():
last_token_hidden_states = (
choice_hidden_states[-1]
if choice_hidden_states and len(choice_hidden_states) > 1
else []
)
hidden_states_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(
hidden_states=last_token_hidden_states
),
finish_reason=finish_reason_type,
)
],
model=request.model,
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
final_usage_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
......
......@@ -16,7 +16,7 @@
import time
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field, model_serializer, root_validator
from typing_extensions import Literal
......@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel):
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
return_hidden_states: Optional[bool] = False
# For PD disaggregation
bootstrap_host: Optional[str] = None
......@@ -195,6 +196,11 @@ class CompletionResponseChoice(BaseModel):
logprobs: Optional[LogProbs] = None
finish_reason: Literal["stop", "length", "content_filter", "abort"]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class CompletionResponse(BaseModel):
......@@ -212,6 +218,11 @@ class CompletionResponseStreamChoice(BaseModel):
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class CompletionStreamResponse(BaseModel):
......@@ -405,6 +416,9 @@ class ChatCompletionRequest(BaseModel):
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
# Hidden States
return_hidden_states: Optional[bool] = False
class ChatMessage(BaseModel):
role: Optional[str] = None
......@@ -421,6 +435,11 @@ class ChatCompletionResponseChoice(BaseModel):
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class ChatCompletionResponse(BaseModel):
......@@ -437,6 +456,11 @@ class DeltaMessage(BaseModel):
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class ChatCompletionResponseStreamChoice(BaseModel):
......@@ -513,3 +537,8 @@ class ScoringResponse(BaseModel):
model: str
usage: Optional[UsageInfo] = None
object: str = "scoring"
def exclude_if_none(obj, field_names: List[str]):
omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
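This serializer keeps the OpenAI-compatible schema unchanged for clients that never request hidden states: the field is dropped from the payload when it is `None` instead of being emitted as `"hidden_states": null`. A quick sketch using `CompletionResponseChoice` from this module:

```python
choice = CompletionResponseChoice(index=0, text=" Paris", finish_reason="stop")
assert "hidden_states" not in choice.model_dump()

choice.hidden_states = [0.1, 0.2]
assert "hidden_states" in choice.model_dump()
```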
......@@ -215,6 +215,7 @@ class ServerArgs:
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
warmups: Optional[str] = None
enable_return_hidden_states: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
......@@ -1456,6 +1457,12 @@ class ServerArgs:
default=ServerArgs.debug_tensor_dump_inject,
help="Inject the outputs from jax as the input of every layer.",
)
parser.add_argument(
"--enable-return-hidden-states",
action="store_true",
help="Enable returning hidden states with responses.",
)
parser.add_argument(
"--debug-tensor-dump-prefill-only",
action="store_true",
......
......@@ -117,9 +117,7 @@ class EAGLEDraftCudaGraphRunner:
hidden_states = self.hidden_states[:num_seqs]
spec_info = EagleDraftInput(
topk_p=topk_p,
topk_index=topk_index,
hidden_states=hidden_states,
topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
)
# Forward batch
......
......@@ -290,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
A tuple of the final logit output of the target model, next tokens accepted,
the batch id (used for overlap schedule), and number of accepted tokens.
"""
if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
......@@ -431,10 +432,10 @@ class EAGLEWorker(TpModelWorker):
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
# Get forward batch
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
......@@ -547,11 +548,13 @@ class EAGLEWorker(TpModelWorker):
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch, self.page_size)
batch.return_hidden_states = False
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
)
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
if batch.has_grammar:
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
......@@ -687,15 +690,18 @@ class EAGLEWorker(TpModelWorker):
hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
"""
# Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
)
batch.return_hidden_states = False
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
......@@ -718,7 +724,9 @@ class EAGLEWorker(TpModelWorker):
batch,
self.speculative_num_steps,
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
......
......@@ -59,6 +59,7 @@ suites = {
TestFile("test_openai_adapter.py", 1),
TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149),
TestFile("test_openai_server_hidden_states.py", 240),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
TestFile("test_pytorch_sampling_backend.py", 66),
......
......@@ -23,6 +23,7 @@ class TestHiddenState(CustomTestCase):
model_path=model_path,
random_seed=42,
skip_tokenizer_init=True,
enable_return_hidden_states=True,
)
outputs = engine.generate(
input_ids=input_ids,
......@@ -96,6 +97,7 @@ class TestHiddenState(CustomTestCase):
model_path=model_path,
random_seed=42,
skip_tokenizer_init=True,
enable_return_hidden_states=True,
)
outputs_completion_first_round = engine.generate(
input_ids=input_ids,
......
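For reference, the Engine path this flag gates looks roughly like the following sketch (field layout assumed from the adapter code above; model name is a placeholder):

```python
import sglang as sgl

llm = sgl.Engine(
    model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
    enable_return_hidden_states=True,
)
outputs = llm.generate(
    ["The capital of France is"],
    sampling_params={"temperature": 0, "max_new_tokens": 8},
    return_hidden_states=True,
)
# One hidden-states entry is recorded per forward pass
hidden = outputs[0]["meta_info"].get("hidden_states")
```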
......@@ -381,12 +381,14 @@ class TestGenerateReqInputNormalization(CustomTestCase):
logprob_start_len=[10, 5],
top_logprobs_num=[5, 3],
token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
return_hidden_states=[False, False, True],
)
req.normalize_batch_and_arguments()
self.assertEqual(req.return_logprob, [True, False])
self.assertEqual(req.logprob_start_len, [10, 5])
self.assertEqual(req.top_logprobs_num, [5, 3])
self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [4, 5, 6]])
self.assertEqual(req.return_hidden_states, [False, False, True])
def test_custom_logit_processor_normalization(self):
"""Test normalization of custom_logit_processor."""
......
"""
python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
"""
import json
......@@ -9,6 +11,7 @@ import re
import time
import unittest
import numpy as np
import openai
import requests
......@@ -137,27 +140,29 @@ class TestOpenAIServer(CustomTestCase):
for response in generator:
usage = response.usage
if usage is not None:
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens > 0
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
continue
index = response.choices[0].index
is_first = is_firsts.get(index, True)
if logprobs:
assert response.choices[0].logprobs
assert isinstance(response.choices[0].logprobs.tokens[0], str)
assert response.choices[0].logprobs, f"no logprobs in response"
assert isinstance(
response.choices[0].logprobs.tokens[0], str
), f"{response.choices[0].logprobs.tokens[0]} is not a string"
if not (is_first and echo):
assert isinstance(
response.choices[0].logprobs.top_logprobs[0], dict
)
), f"top_logprobs was not a dictionary"
ret_num_top_logprobs = len(
response.choices[0].logprobs.top_logprobs[0]
)
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
assert ret_num_top_logprobs > 0
assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
if is_first:
if echo:
......@@ -165,8 +170,8 @@ class TestOpenAIServer(CustomTestCase):
prompt
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
is_firsts[index] = False
assert response.id
assert response.created
assert response.id, f"no id in response"
assert response.created, f"no created in response"
for index in [i for i in range(parallel_sample_num * num_choices)]:
assert not is_firsts.get(
......@@ -231,27 +236,29 @@ class TestOpenAIServer(CustomTestCase):
for response in generator:
usage = response.usage
if usage is not None:
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens > 0
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
continue
index = response.choices[0].index
data = response.choices[0].delta
if is_firsts.get(index, True):
assert data.role == "assistant"
assert (
data.role == "assistant"
), f"data.role was not 'assistant' for first chunk"
is_firsts[index] = False
continue
if logprobs:
assert response.choices[0].logprobs
assert response.choices[0].logprobs, f"logprobs was not returned"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
)
), f"top_logprobs token was not a string"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs, list
)
), f"top_logprobs was not a list"
ret_num_top_logprobs = len(
response.choices[0].logprobs.content[0].top_logprobs
)
......
import json
import re
import time
import unittest
from abc import ABC
import numpy as np
import openai
import torch
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class BaseTestOpenAIServerWithHiddenStates(ABC):
@classmethod
def setUpClass(cls):
cls.return_hidden_states = [False, True]
cls.use_list_input = [True, False]
cls.parallel_sample_nums = [1, 2]
def test_completion(self):
for return_hidden_states in self.return_hidden_states:
for use_list_input in self.use_list_input:
for parallel_sample_num in self.parallel_sample_nums:
self.run_completion(
use_list_input,
parallel_sample_num,
return_hidden_states,
)
def test_completion_stream(self):
for return_hidden_states in self.return_hidden_states:
for use_list_input in self.use_list_input:
for parallel_sample_num in self.parallel_sample_nums:
self.run_completion_stream(
use_list_input,
parallel_sample_num,
return_hidden_states,
)
def test_chat_completion(self):
for return_hidden_states in self.return_hidden_states:
# parallel sample num 2 breaks in the adapter with a 400 for EAGLE
for parallel_sample_num in self.parallel_sample_nums:
self.run_chat_completion(parallel_sample_num, return_hidden_states)
def test_chat_completion_stream(self):
for return_hidden_states in self.return_hidden_states:
# parallel sample num > 1 breaks in the adapter with a 400 for EAGLE
for parallel_sample_num in self.parallel_sample_nums:
self.run_chat_completion_stream(
parallel_sample_num, return_hidden_states
)
def run_completion(
self,
use_list_input,
parallel_sample_num,
return_hidden_states,
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
prompt_input = prompt
if use_list_input:
prompt_arg = [prompt_input, prompt_input]
num_choices = len(prompt_arg)
else:
prompt_arg = prompt_input
num_choices = 1
response = client.completions.create(
model=self.model,
prompt=prompt_arg,
temperature=0,
max_tokens=32,
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
for choice in response.choices:
assert hasattr(choice, "hidden_states") == return_hidden_states
if return_hidden_states:
assert choice.hidden_states is not None, "hidden_states was None"
def run_completion_stream(
self,
use_list_input,
parallel_sample_num,
return_hidden_states,
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
prompt_input = prompt
num_prompt_tokens = len(self.tokenizer.encode(prompt))
if use_list_input:
prompt_arg = [prompt_input, prompt_input]
num_choices = len(prompt_arg)
num_prompt_tokens *= 2
else:
prompt_arg = prompt_input
num_choices = 1
generator = client.completions.create(
model=self.model,
prompt=prompt_arg,
temperature=0,
max_tokens=32,
stream=True,
stream_options={"include_usage": True},
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
hidden_states_list = []
for response in generator:
usage = response.usage
for choice in response.choices:
if hasattr(choice, "hidden_states"):
assert return_hidden_states
assert choice.hidden_states is not None
hidden_states_list.append(choice.hidden_states)
if return_hidden_states:
assert (
len(hidden_states_list) == parallel_sample_num * num_choices
), f"Expected {parallel_sample_num * num_choices} hidden states, got {len(hidden_states_list)}"
else:
assert (
hidden_states_list == []
), "hidden_states were returned and should not have been"
def run_chat_completion(self, parallel_sample_num, return_hidden_states):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": "What is the capital of France? Answer in a few words.",
},
],
temperature=0,
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
for choice in response.choices:
assert hasattr(choice, "hidden_states") == return_hidden_states
if return_hidden_states:
assert choice.hidden_states is not None, "hidden_states was None"
def run_chat_completion_stream(
self, parallel_sample_num=1, return_hidden_states=False
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
generator = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "What is the capital of France?"},
],
temperature=0,
stream=True,
stream_options={"include_usage": True},
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
is_firsts = {}
hidden_states_list = []
for response in generator:
for choice in response.choices:
if hasattr(choice.delta, "hidden_states"):
assert return_hidden_states
assert choice.delta.hidden_states is not None
hidden_states_list.append(choice.delta.hidden_states)
if return_hidden_states:
assert (
len(hidden_states_list) == parallel_sample_num
), f"Expected {parallel_sample_num} hidden states, got {len(hidden_states_list)}"
else:
assert (
hidden_states_list == []
), "hidden_states were returned and should not have been"
class TestOpenAIServerWithHiddenStatesEnabled(
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=["--enable-return-hidden-states"],
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
cls.return_hidden_states = [False, True]
cls.use_list_input = [True, False]
cls.parallel_sample_nums = [1, 2]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
class TestOpenAIServerWithHiddenStatesEnabledAndCUDAGraphDisabled(
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=["--enable-return-hidden-states", "--disable-cuda-graph"],
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
cls.return_hidden_states = [False, True]
cls.use_list_input = [True, False]
cls.parallel_sample_nums = [1]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
class TestOpenAIServerWithEAGLEAndHiddenStatesEnabled(
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.speculative_draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
cls.speculative_algorithm = "EAGLE"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
8,
"--speculative-num-draft-tokens",
64,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--enable-return-hidden-states",
],
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
cls.return_hidden_states = [False, True]
cls.use_list_input = [True, False]
cls.parallel_sample_nums = [1]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
class TestOpenAIServerWithEAGLE3AndHiddenStatesEnabled(
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
@classmethod
def setUpClass(cls):
cls.model = "meta-llama/Llama-3.1-8B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.speculative_algorithm = "EAGLE3"
cls.speculative_draft_model = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
cls.speculative_algorithm,
"--speculative-draft-model-path",
cls.speculative_draft_model,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
16,
"--speculative-num-draft-tokens",
64,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--dtype",
"float16",
"--enable-return-hidden-states",
],
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
cls.return_hidden_states = [False, True]
cls.use_list_input = [True, False]
cls.parallel_sample_nums = [1]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
if __name__ == "__main__":
unittest.main()