Unverified commit a502831d, authored by Cyrus Leung and committed by GitHub
Browse files

[Chore] Remove redundant input parsing methods (#33542)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent ba871fb7
......@@ -10,7 +10,6 @@ import pybase64
import pytest
import torch
from vllm.inputs.data import is_embeds_prompt
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
......@@ -320,7 +319,6 @@ class TestRenderEmbedPrompt:
)
assert len(results) == 1
assert is_embeds_prompt(results[0])
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
@pytest.mark.asyncio
......@@ -342,7 +340,6 @@ class TestRenderEmbedPrompt:
assert len(results) == 2
for i, result in enumerate(results):
assert is_embeds_prompt(result)
assert torch.allclose(result["prompt_embeds"], test_tensors[i])
@pytest.mark.asyncio
......@@ -420,7 +417,7 @@ class TestRenderEmbedPrompt:
assert len(results) == 2
# First should be embed prompt
assert is_embeds_prompt(results[0])
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
# Second should be tokens prompt
assert "prompt_token_ids" in results[1]
assert results[1]["prompt_token_ids"] == [101, 102, 103]
......@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
......@@ -359,7 +358,7 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = get_prompt_components(engine_prompt)
prompt_text = engine_prompt.get("prompt")
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
......
......@@ -34,8 +34,7 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.inputs.parse import get_prompt_components
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
......@@ -161,7 +160,7 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = get_prompt_components(engine_prompt)
prompt_text = engine_prompt.get("prompt")
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
......@@ -278,11 +277,7 @@ class OpenAIServingCompletion(OpenAIServing):
# with the inputs token IDs
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = (
None
if is_embeds_prompt(engine_prompt)
else engine_prompt.get("prompt")
)
final_res.prompt = engine_prompt.get("prompt")
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
......@@ -352,11 +347,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = (
None
if is_embeds_prompt(engine_prompt)
else engine_prompt.get("prompt")
)
prompt_text = engine_prompt.get("prompt")
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:
......
......@@ -1116,7 +1116,7 @@ class OpenAIServing:
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
):
prompt_text, _, _ = get_prompt_components(engine_prompt)
prompt_text = engine_prompt.get("prompt")
orig_priority = priority
sub_request = 0
......@@ -1186,7 +1186,7 @@ class OpenAIServing:
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
prompt_text, _, _ = get_prompt_components(engine_prompt)
prompt_text = engine_prompt.get("prompt")
sampling_params.max_tokens = get_max_tokens(
self.max_model_len,
......
......@@ -5,7 +5,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
from typing_extensions import NotRequired, TypedDict, TypeVar
from vllm.sampling_params import SamplingParams
......@@ -77,6 +77,9 @@ class EmbedsPrompt(_CommonKeys):
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
prompt: NotRequired[str]
"""The prompt text corresponding to the token embeddings, if available."""
class DataPrompt(_CommonKeys):
"""Represents generic inputs handled by IO processor plugins."""
......@@ -113,22 +116,6 @@ more than one prompt, i.e.
"""
def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]:
    """Return whether ``prompt`` is a dict carrying token IDs but no embeddings.

    The check is purely structural: ``prompt`` must be a ``dict`` that has a
    ``"prompt_token_ids"`` key and does not have a ``"prompt_embeds"`` key.
    """
    # Anything that is not a dict cannot be a tokens prompt.
    if not isinstance(prompt, dict):
        return False

    # Token IDs must be present ...
    if "prompt_token_ids" not in prompt:
        return False

    # ... and embeddings must be absent.
    return "prompt_embeds" not in prompt
def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]:
    """Return whether ``prompt`` is a dict carrying embeddings but no token IDs.

    The check is purely structural: ``prompt`` must be a ``dict`` that does
    not have a ``"prompt_token_ids"`` key and does have a ``"prompt_embeds"``
    key.
    """
    # Anything that is not a dict cannot be an embeds prompt.
    if not isinstance(prompt, dict):
        return False

    # Token IDs must be absent ...
    if "prompt_token_ids" in prompt:
        return False

    # ... and embeddings must be present.
    return "prompt_embeds" in prompt
_T1_co = TypeVar(
"_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
)
......
......@@ -4,7 +4,7 @@
import time
from collections.abc import Callable, Mapping
from copy import copy
from typing import Any, cast
from typing import Any
import torch.nn as nn
from typing_extensions import TypeVar
......@@ -32,6 +32,7 @@ from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.utils import get_prompt_text
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
......@@ -245,10 +246,7 @@ class LLMEngine:
trace_headers,
priority,
)
if isinstance(prompt, str):
prompt_text = prompt
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
prompt_text = get_prompt_text(prompt)
self.input_processor.assign_request_id(request)
......
......@@ -4,12 +4,12 @@
import contextlib
import os
import weakref
from collections.abc import Callable, Iterator, Mapping
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING
from unittest.mock import patch
import msgspec
......@@ -17,6 +17,8 @@ import zmq
from vllm import envs
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.inputs import PromptType
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
......@@ -224,12 +226,8 @@ def get_device_indices(
return value
def get_prompt_text(prompt: Any) -> str | None:
if isinstance(prompt, str):
return prompt
if isinstance(prompt, Mapping):
return cast(str | None, prompt.get("prompt"))
return None
def get_prompt_text(prompt: PromptType) -> str | None:
    """Return the first element of ``get_prompt_components(prompt)``.

    The function is annotated to return ``str | None``; the first component
    is presumably the prompt text -- confirm against
    ``vllm.inputs.parse.get_prompt_components``.
    """
    components = get_prompt_components(prompt)
    return components[0]
class CoreEngineActorManager:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment