"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "10b01b45908c2640d97b88b8a3024a56262a353d"
Unverified commit f57fd72f, authored by Aryan Bagade and committed by GitHub
Browse files

feat: Add logprobs support to vLLM backend (#4683) (#4697)


Signed-off-by: Aryan Bagade <aryan@aryanbagade.com>
Signed-off-by: Aryan Bagade <73382554+AryanBagade@users.noreply.github.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 6d091bf3
...@@ -73,7 +73,8 @@ def build_sampling_params( ...@@ -73,7 +73,8 @@ def build_sampling_params(
Build SamplingParams from a PreprocessedRequest. Build SamplingParams from a PreprocessedRequest.
Args: Args:
request: The PreprocessedRequest dict with 'sampling_options' and 'stop_conditions' request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions',
and 'output_options'
default_sampling_params: Default sampling parameters to initialize with default_sampling_params: Default sampling parameters to initialize with
Returns: Returns:
...@@ -116,6 +117,41 @@ def build_sampling_params( ...@@ -116,6 +117,41 @@ def build_sampling_params(
existing = sampling_params.stop_token_ids or [] existing = sampling_params.stop_token_ids or []
sampling_params.stop_token_ids = list(set(existing).union(value)) sampling_params.stop_token_ids = list(set(existing).union(value))
# Apply output_options (logprobs, prompt_logprobs, etc.)
output_options = request.get("output_options", {})
if output_options:
# Handle logprobs - vLLM expects this as an integer or None
logprobs_value = output_options.get("logprobs")
if logprobs_value is not None and logprobs_value != "":
try:
parsed_logprobs = int(logprobs_value)
if parsed_logprobs < 0:
logger.warning(
f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring"
)
else:
sampling_params.logprobs = parsed_logprobs
except (ValueError, TypeError):
logger.warning(
f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring"
)
# Handle prompt_logprobs - vLLM expects this as an integer or None
prompt_logprobs_value = output_options.get("prompt_logprobs")
if prompt_logprobs_value is not None and prompt_logprobs_value != "":
try:
parsed_prompt_logprobs = int(prompt_logprobs_value)
if parsed_prompt_logprobs < 0:
logger.warning(
f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be non-negative), ignoring"
)
else:
sampling_params.prompt_logprobs = parsed_prompt_logprobs
except (ValueError, TypeError):
logger.warning(
f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be integer), ignoring"
)
# If max_tokens wasn't provided (None or missing), compute a dynamic default # If max_tokens wasn't provided (None or missing), compute a dynamic default
provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None) provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
token_ids = request.get("token_ids", []) token_ids = request.get("token_ids", [])
...@@ -577,6 +613,66 @@ class BaseWorkerHandler(ABC): ...@@ -577,6 +613,66 @@ class BaseWorkerHandler(ABC):
), ),
} }
@staticmethod
def _extract_logprobs(
output, num_output_tokens_so_far: int
) -> tuple[list[float] | None, list[list[dict]] | None]:
"""
Extract logprobs from vLLM CompletionOutput for new tokens.
Args:
output: vLLM CompletionOutput object
num_output_tokens_so_far: Number of tokens already processed
Returns:
Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
- log_probs: List of log probabilities for each new token
- top_logprobs: List of top logprobs dicts for each new token
"""
if output.logprobs is None:
return None, None
# Get logprobs for new tokens only
new_logprobs = output.logprobs[num_output_tokens_so_far:]
if not new_logprobs:
return None, None
log_probs = []
top_logprobs = []
for token_idx, token_logprobs_dict in enumerate(new_logprobs):
if token_logprobs_dict is None:
continue
# Get the actual token_id that was generated at this position
actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx]
# Extract log probability for the selected token
# vLLM guarantees the selected token is always in the logprobs dict
selected_logprob = token_logprobs_dict[actual_token_id]
log_probs.append(float(selected_logprob.logprob))
# Build top_logprobs list for this token position
token_top_logprobs = []
for tok_id, logprob_info in token_logprobs_dict.items():
token_top_logprobs.append(
{
"rank": (
logprob_info.rank if hasattr(logprob_info, "rank") else 0
),
"token_id": tok_id,
"token": (
logprob_info.decoded_token
if hasattr(logprob_info, "decoded_token")
else None
),
"logprob": float(logprob_info.logprob),
}
)
top_logprobs.append(token_top_logprobs)
return log_probs if log_probs else None, top_logprobs if top_logprobs else None
async def generate_tokens( async def generate_tokens(
self, self,
prompt, prompt,
...@@ -622,6 +718,16 @@ class BaseWorkerHandler(ABC): ...@@ -622,6 +718,16 @@ class BaseWorkerHandler(ABC):
output = res.outputs[0] output = res.outputs[0]
next_total_toks = len(output.token_ids) next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
# Extract logprobs for new tokens if available
log_probs, top_logprobs = self._extract_logprobs(
output, num_output_tokens_so_far
)
if log_probs is not None:
out["log_probs"] = log_probs
if top_logprobs is not None:
out["top_logprobs"] = top_logprobs
if output.finish_reason: if output.finish_reason:
out["finish_reason"] = output.finish_reason out["finish_reason"] = output.finish_reason
out[ out[
......
...@@ -19,7 +19,9 @@ from tests.utils.engine_process import EngineConfig ...@@ -19,7 +19,9 @@ from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import ( from tests.utils.payload_builder import (
chat_payload, chat_payload,
chat_payload_default, chat_payload_default,
chat_payload_with_logprobs,
completion_payload_default, completion_payload_default,
completion_payload_with_logprobs,
metric_payload_default, metric_payload_default,
) )
from tests.utils.payloads import ToolCallingChatPayload from tests.utils.payloads import ToolCallingChatPayload
...@@ -59,6 +61,29 @@ vllm_configs = { ...@@ -59,6 +61,29 @@ vllm_configs = {
metric_payload_default(min_num_requests=6, backend="vllm"), metric_payload_default(min_num_requests=6, backend="vllm"),
], ],
), ),
"aggregated_logprobs": VLLMConfig(
name="aggregated_logprobs",
directory=vllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload_with_logprobs(
repeat_count=2,
expected_response=["AI", "knock", "joke"],
max_tokens=30,
temperature=0.0,
top_logprobs=3,
),
completion_payload_with_logprobs(
repeat_count=2,
expected_response=["AI", "knock", "joke"],
max_tokens=30,
temperature=0.0,
logprobs=5,
),
],
),
"aggregated_lmcache": VLLMConfig( "aggregated_lmcache": VLLMConfig(
name="aggregated_lmcache", name="aggregated_lmcache",
directory=vllm_dir, directory=vllm_dir,
......
...@@ -153,6 +153,10 @@ def chat_payload( ...@@ -153,6 +153,10 @@ def chat_payload(
} }
if temperature is not None: if temperature is not None:
body["temperature"] = temperature body["temperature"] = temperature
if logprobs is not None:
body["logprobs"] = logprobs
if top_logprobs is not None:
body["top_logprobs"] = top_logprobs
if top_logprobs is not None: if top_logprobs is not None:
body["top_logprobs"] = top_logprobs body["top_logprobs"] = top_logprobs
...@@ -307,3 +311,83 @@ def make_completions_health_check(port: int, model: str): ...@@ -307,3 +311,83 @@ def make_completions_health_check(port: int, model: str):
return False return False
return _check_completions_endpoint return _check_completions_endpoint
def chat_payload_with_logprobs(
    content: Union[str, List[Dict[str, Any]]] = TEXT_PROMPT,
    repeat_count: int = 1,
    expected_response: Optional[List[str]] = None,
    max_tokens: int = 50,
    temperature: float = 0.0,
    top_logprobs: int = 3,
) -> ChatPayloadWithLogprobs:
    """
    Build a chat request payload that asks for logprobs in the response.

    Args:
        content: Content of the single user message (plain text or a
            structured content list)
        repeat_count: Number of times to repeat the request
        expected_response: Strings expected in the response text; when None
            or empty, ["AI", "knock", "joke"] is used
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature
        top_logprobs: Number of top logprobs to request per token

    Returns:
        ChatPayloadWithLogprobs wrapping the request body
    """
    user_turn = {"role": "user", "content": content}

    # Request body; "logprobs" is always enabled for this payload type.
    body: Dict[str, Any] = {"messages": [user_turn]}
    body["max_tokens"] = max_tokens
    body["temperature"] = temperature
    body["logprobs"] = True
    body["top_logprobs"] = top_logprobs

    # A falsy expected_response (None or empty list) selects the defaults.
    if not expected_response:
        expected_response = ["AI", "knock", "joke"]

    return ChatPayloadWithLogprobs(
        body=body,
        repeat_count=repeat_count,
        expected_log=[],
        expected_response=expected_response,
    )
def completion_payload_with_logprobs(
    prompt: str = TEXT_PROMPT,
    repeat_count: int = 1,
    expected_response: Optional[List[str]] = None,
    max_tokens: int = 50,
    temperature: float = 0.0,
    logprobs: int = 5,
) -> CompletionPayloadWithLogprobs:
    """
    Build a completion request payload that asks for logprobs in the response.

    Args:
        prompt: Text prompt
        repeat_count: Number of times to repeat the request
        expected_response: Strings expected in the response text; when None
            or empty, ["AI", "knock", "joke"] is used
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature
        logprobs: Number of logprobs to request per token

    Returns:
        CompletionPayloadWithLogprobs wrapping the request body
    """
    # Request body, assembled key by key in the same order as the fields
    # are documented above.
    body: Dict[str, Any] = {"prompt": prompt}
    body["max_tokens"] = max_tokens
    body["temperature"] = temperature
    body["logprobs"] = logprobs

    # A falsy expected_response (None or empty list) selects the defaults.
    if not expected_response:
        expected_response = ["AI", "knock", "joke"]

    return CompletionPayloadWithLogprobs(
        body=body,
        repeat_count=repeat_count,
        expected_log=[],
        expected_response=expected_response,
    )
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import math
import re import re
import time import time
from copy import deepcopy from copy import deepcopy
...@@ -183,6 +184,14 @@ class ChatPayloadWithLogprobs(ChatPayload): ...@@ -183,6 +184,14 @@ class ChatPayloadWithLogprobs(ChatPayload):
"top_logprobs" in item "top_logprobs" in item
), "Missing 'top_logprobs' in logprobs content" ), "Missing 'top_logprobs' in logprobs content"
# Sanity check: logprob should be valid (not nan/inf/positive)
logprob_val = item["logprob"]
assert not math.isnan(logprob_val), "logprob is NaN"
assert not math.isinf(logprob_val), "logprob is infinite"
assert (
logprob_val <= 0
), f"logprob should be <= 0, got {logprob_val}"
logger.info( logger.info(
f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs" f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs"
) )
...@@ -281,6 +290,20 @@ class CompletionPayloadWithLogprobs(CompletionPayload): ...@@ -281,6 +290,20 @@ class CompletionPayloadWithLogprobs(CompletionPayload):
assert len(token_logprobs) == len( assert len(token_logprobs) == len(
tokens tokens
), "Mismatch between token_logprobs and tokens length" ), "Mismatch between token_logprobs and tokens length"
# Sanity check: each logprob should be valid (not nan/inf/positive)
for i, logprob_val in enumerate(token_logprobs):
if logprob_val is not None: # First token can be None
assert not math.isnan(
logprob_val
), f"logprob at index {i} is NaN"
assert not math.isinf(
logprob_val
), f"logprob at index {i} is infinite"
assert (
logprob_val <= 0
), f"logprob at index {i} should be <= 0, got {logprob_val}"
logger.info( logger.info(
f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs" f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment