Unverified commit a502831d, authored by Cyrus Leung, committed by GitHub
Browse files

[Chore] Remove redundant input parsing methods (#33542)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent ba871fb7
...@@ -10,7 +10,6 @@ import pybase64 ...@@ -10,7 +10,6 @@ import pybase64
import pytest import pytest
import torch import torch
from vllm.inputs.data import is_embeds_prompt
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config from vllm.tokenizers.registry import tokenizer_args_from_config
...@@ -320,7 +319,6 @@ class TestRenderEmbedPrompt: ...@@ -320,7 +319,6 @@ class TestRenderEmbedPrompt:
) )
assert len(results) == 1 assert len(results) == 1
assert is_embeds_prompt(results[0])
assert torch.allclose(results[0]["prompt_embeds"], test_tensor) assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -342,7 +340,6 @@ class TestRenderEmbedPrompt: ...@@ -342,7 +340,6 @@ class TestRenderEmbedPrompt:
assert len(results) == 2 assert len(results) == 2
for i, result in enumerate(results): for i, result in enumerate(results):
assert is_embeds_prompt(result)
assert torch.allclose(result["prompt_embeds"], test_tensors[i]) assert torch.allclose(result["prompt_embeds"], test_tensors[i])
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -420,7 +417,7 @@ class TestRenderEmbedPrompt: ...@@ -420,7 +417,7 @@ class TestRenderEmbedPrompt:
assert len(results) == 2 assert len(results) == 2
# First should be embed prompt # First should be embed prompt
assert is_embeds_prompt(results[0]) assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
# Second should be tokens prompt # Second should be tokens prompt
assert "prompt_token_ids" in results[1] assert "prompt_token_ids" in results[1]
assert results[1]["prompt_token_ids"] == [101, 102, 103] assert results[1]["prompt_token_ids"] == [101, 102, 103]
...@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
...@@ -359,7 +358,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -359,7 +358,7 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text = engine_prompt.get("prompt")
# If we are creating sub requests for multiple prompts, ensure that they # If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids. # have unique request ids.
......
...@@ -34,8 +34,7 @@ from vllm.entrypoints.openai.engine.serving import ( ...@@ -34,8 +34,7 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -161,7 +160,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -161,7 +160,7 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text = engine_prompt.get("prompt")
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
...@@ -278,11 +277,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -278,11 +277,7 @@ class OpenAIServingCompletion(OpenAIServing):
# with the inputs token IDs # with the inputs token IDs
if final_res.prompt is None: if final_res.prompt is None:
engine_prompt = engine_prompts[i] engine_prompt = engine_prompts[i]
final_res.prompt = ( final_res.prompt = engine_prompt.get("prompt")
None
if is_embeds_prompt(engine_prompt)
else engine_prompt.get("prompt")
)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch) final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
...@@ -352,11 +347,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -352,11 +347,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = res.prompt prompt_text = res.prompt
if prompt_text is None: if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx] engine_prompt = engine_prompts[prompt_idx]
prompt_text = ( prompt_text = engine_prompt.get("prompt")
None
if is_embeds_prompt(engine_prompt)
else engine_prompt.get("prompt")
)
# Prompt details are excluded from later streamed outputs # Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None: if prompt_token_ids is not None:
......
...@@ -1116,7 +1116,7 @@ class OpenAIServing: ...@@ -1116,7 +1116,7 @@ class OpenAIServing:
priority: int = 0, priority: int = 0,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
): ):
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text = engine_prompt.get("prompt")
orig_priority = priority orig_priority = priority
sub_request = 0 sub_request = 0
...@@ -1186,7 +1186,7 @@ class OpenAIServing: ...@@ -1186,7 +1186,7 @@ class OpenAIServing:
context.chat_template_content_format, context.chat_template_content_format,
) )
engine_prompt = engine_prompts[0] engine_prompt = engine_prompts[0]
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text = engine_prompt.get("prompt")
sampling_params.max_tokens = get_max_tokens( sampling_params.max_tokens = get_max_tokens(
self.max_model_len, self.max_model_len,
......
...@@ -5,7 +5,7 @@ from dataclasses import dataclass ...@@ -5,7 +5,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
import torch import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar from typing_extensions import NotRequired, TypedDict, TypeVar
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -77,6 +77,9 @@ class EmbedsPrompt(_CommonKeys): ...@@ -77,6 +77,9 @@ class EmbedsPrompt(_CommonKeys):
prompt_embeds: torch.Tensor prompt_embeds: torch.Tensor
"""The embeddings of the prompt.""" """The embeddings of the prompt."""
prompt: NotRequired[str]
"""The prompt text corresponding to the token embeddings, if available."""
class DataPrompt(_CommonKeys): class DataPrompt(_CommonKeys):
"""Represents generic inputs handled by IO processor plugins.""" """Represents generic inputs handled by IO processor plugins."""
...@@ -113,22 +116,6 @@ more than one prompt, i.e. ...@@ -113,22 +116,6 @@ more than one prompt, i.e.
""" """
def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]:
return (
isinstance(prompt, dict)
and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt
)
def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]:
return (
isinstance(prompt, dict)
and "prompt_token_ids" not in prompt
and "prompt_embeds" in prompt
)
_T1_co = TypeVar( _T1_co = TypeVar(
"_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True "_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
) )
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import time import time
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from copy import copy from copy import copy
from typing import Any, cast from typing import Any
import torch.nn as nn import torch.nn as nn
from typing_extensions import TypeVar from typing_extensions import TypeVar
...@@ -32,6 +32,7 @@ from vllm.v1.engine.core_client import EngineCoreClient ...@@ -32,6 +32,7 @@ from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.utils import get_prompt_text
from vllm.v1.executor import Executor from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
...@@ -245,10 +246,7 @@ class LLMEngine: ...@@ -245,10 +246,7 @@ class LLMEngine:
trace_headers, trace_headers,
priority, priority,
) )
if isinstance(prompt, str): prompt_text = get_prompt_text(prompt)
prompt_text = prompt
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
self.input_processor.assign_request_id(request) self.input_processor.assign_request_id(request)
......
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
import contextlib import contextlib
import os import os
import weakref import weakref
from collections.abc import Callable, Iterator, Mapping from collections.abc import Callable, Iterator
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from multiprocessing import Process, connection from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess from multiprocessing.process import BaseProcess
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING
from unittest.mock import patch from unittest.mock import patch
import msgspec import msgspec
...@@ -17,6 +17,8 @@ import zmq ...@@ -17,6 +17,8 @@ import zmq
from vllm import envs from vllm import envs
from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.inputs import PromptType
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy from vllm.ray.ray_env import get_env_vars_to_copy
...@@ -224,12 +226,8 @@ def get_device_indices( ...@@ -224,12 +226,8 @@ def get_device_indices(
return value return value
def get_prompt_text(prompt: Any) -> str | None: def get_prompt_text(prompt: PromptType) -> str | None:
if isinstance(prompt, str): return get_prompt_components(prompt)[0]
return prompt
if isinstance(prompt, Mapping):
return cast(str | None, prompt.get("prompt"))
return None
class CoreEngineActorManager: class CoreEngineActorManager:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment