Unverified commit 8896eb72, authored by Cyrus Leung and committed by GitHub
Browse files

[Deprecation] Remove `prompt_token_ids` arg fallback in `LLM.generate` and `LLM.encode` (#18800)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 19fe1a05
...@@ -38,8 +38,7 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, ...@@ -38,8 +38,7 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
with vllm_runner(model_id) as llm: with vllm_runner(model_id) as llm:
# note: this does not test accuracy, just that we can run through # note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy # see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"], outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10)
max_tokens=10)
print(outputs[0][1]) print(outputs[0][1])
...@@ -90,8 +89,7 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, ...@@ -90,8 +89,7 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
# note: this does not test accuracy, just that we can run through # note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy # see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"], outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10)
max_tokens=10)
print(outputs[0][1]) print(outputs[0][1])
......
...@@ -46,5 +46,5 @@ def test_lm_head( ...@@ -46,5 +46,5 @@ def test_lm_head(
vllm_model.apply_model(check_model) vllm_model.apply_model(check_model)
print( print(
vllm_model.generate_greedy(prompts=["Hello my name is"], vllm_model.generate_greedy(["Hello my name is"],
max_tokens=10)[0][1]) max_tokens=10)[0][1])
...@@ -127,13 +127,15 @@ def test_structured_output( ...@@ -127,13 +127,15 @@ def test_structured_output(
temperature=1.0, temperature=1.0,
max_tokens=4096, max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=sample_json_schema)) guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
(f"Give an example JSON for an employee profile that fits this " prompt = ("Give an example JSON for an employee profile that fits this "
f"schema. Make the response as short as possible. Schema: " "schema. Make the response as short as possible. Schema: "
f"{sample_json_schema}") f"{sample_json_schema}")
] * 2, outputs = llm.generate(
sampling_params=sampling_params, [prompt] * 2,
use_tqdm=True) sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None assert outputs is not None
...@@ -191,20 +193,24 @@ def test_structured_output( ...@@ -191,20 +193,24 @@ def test_structured_output(
with pytest.raises(ValueError, with pytest.raises(ValueError,
match="The provided JSON schema contains features " match="The provided JSON schema contains features "
"not supported by xgrammar."): "not supported by xgrammar."):
prompt = (f"Give an example JSON for an employee profile that "
f"fits this schema: {unsupported_json_schema}. "
f"Make the response as short as possible.")
llm.generate( llm.generate(
prompts=[(f"Give an example JSON for an employee profile that " [prompt] * 2,
f"fits this schema: {unsupported_json_schema}. "
f"Make the response as short as possible.")] * 2,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True,
)
else: else:
outputs = llm.generate(prompts=( prompt = (f"Give an example JSON object for a grade that "
"Give an example JSON object for a grade " f"fits this schema: {unsupported_json_schema}. "
"that fits this schema: " f"Make the response as short as possible.")
f"{unsupported_json_schema}. Make the response as short as " outputs = llm.generate(
"possible."), prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True,
)
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
assert output is not None assert output is not None
...@@ -227,10 +233,9 @@ def test_structured_output( ...@@ -227,10 +233,9 @@ def test_structured_output(
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
outputs = llm.generate( outputs = llm.generate(
prompts=( ("Generate a sql statement that selects col_1 from "
"Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1. Make the response as short as "
"table_1 where it is equal to 1. Make the response as short as " "possible."),
"possible."),
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True, use_tqdm=True,
) )
...@@ -261,10 +266,9 @@ def test_structured_output( ...@@ -261,10 +266,9 @@ def test_structured_output(
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
outputs = llm.generate( outputs = llm.generate(
prompts=( ("Generate a sql statement that selects col_1 from "
"Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1. Make the response as short as "
"table_1 where it is equal to 1. Make the response as short as " "possible."),
"possible."),
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True, use_tqdm=True,
) )
...@@ -301,7 +305,6 @@ def test_structured_output( ...@@ -301,7 +305,6 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(grammar="not a grammar")) guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "): with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate( llm.generate(
prompts=
("Generate a sql statement that selects col_1 from " ("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short " "table_1 where it is equal to 1. Make the response as short "
"as possible."), "as possible."),
...@@ -316,11 +319,11 @@ def test_structured_output( ...@@ -316,11 +319,11 @@ def test_structured_output(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex)) guided_decoding=GuidedDecodingParams(regex=sample_regex))
prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. "
f"Make the response as short as possible.")
outputs = llm.generate( outputs = llm.generate(
prompts=[ [prompt] * 2,
(f"Give an example IPv4 address with this regex: {sample_regex}. "
f"Make the response as short as possible.")
] * 2,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True, use_tqdm=True,
) )
...@@ -343,11 +346,13 @@ def test_structured_output( ...@@ -343,11 +346,13 @@ def test_structured_output(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
outputs = llm.generate( outputs = llm.generate(
prompts=("The best language for type-safe systems programming is " ("The best language for type-safe systems programming is "
"(Make the response as short as possible.) "), "(Make the response as short as possible.) "),
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True,
)
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
assert output is not None assert output is not None
...@@ -367,12 +372,14 @@ def test_structured_output( ...@@ -367,12 +372,14 @@ def test_structured_output(
temperature=1.0, temperature=1.0,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=json_schema)) guided_decoding=GuidedDecodingParams(json=json_schema))
outputs = llm.generate(prompts=(
"Generate a JSON with the brand, model and car_type of the most " outputs = llm.generate(
"iconic car from the 90's. Make the response as short as " ("Generate a JSON with the brand, model and car_type of the most "
"possible."), "iconic car from the 90's. Make the response as short as "
sampling_params=sampling_params, "possible."),
use_tqdm=True) sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None assert outputs is not None
...@@ -411,10 +418,11 @@ def test_structured_output( ...@@ -411,10 +418,11 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(json=json_schema)) guided_decoding=GuidedDecodingParams(json=json_schema))
outputs = llm.generate( outputs = llm.generate(
prompts=("Generate a description of a frog using 50 characters. " ("Generate a description of a frog using 50 characters. "
"Make the response as short as possible."), "Make the response as short as possible."),
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True,
)
assert outputs is not None assert outputs is not None
...@@ -498,7 +506,7 @@ Make the response as short as possible. ...@@ -498,7 +506,7 @@ Make the response as short as possible.
""" """
# Change this once other backends support structural_tag # Change this once other backends support structural_tag
outputs = llm.generate(prompts=prompt, outputs = llm.generate(prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True)
assert outputs is not None assert outputs is not None
...@@ -639,15 +647,13 @@ def test_structured_output_auto_mode( ...@@ -639,15 +647,13 @@ def test_structured_output_auto_mode(
f"{unsupported_json_schema}. Make the response as short as possible.") f"{unsupported_json_schema}. Make the response as short as possible.")
# This would fail with the default of "xgrammar", but in "auto" # This would fail with the default of "xgrammar", but in "auto"
# we will handle fallback automatically. # we will handle fallback automatically.
outputs = llm.generate(prompts=prompts, outputs = llm.generate(prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True)
# Make sure `auto` backend handling doesn't mess up sampling_params # Make sure `auto` backend handling doesn't mess up sampling_params
# and that we can reuse it without error. # and that we can reuse it without error.
outputs.extend( outputs.extend(
llm.generate(prompts=prompts, llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True))
sampling_params=sampling_params,
use_tqdm=True))
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
...@@ -705,7 +711,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): ...@@ -705,7 +711,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
max_tokens=256, max_tokens=256,
guided_decoding=guided_params) guided_decoding=guided_params)
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) outputs = llm.generate(prompt, sampling_params=sampling_params)
assert outputs is not None assert outputs is not None
generated_text = outputs[0].outputs[0].text generated_text = outputs[0].outputs[0].text
assert generated_text is not None assert generated_text is not None
......
...@@ -3,15 +3,13 @@ ...@@ -3,15 +3,13 @@
import itertools import itertools
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
cast, overload)
import cloudpickle import cloudpickle
import torch.nn as nn import torch.nn as nn
from pydantic import ValidationError from pydantic import ValidationError
from tqdm.auto import tqdm from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated from typing_extensions import TypeVar
import vllm.envs as envs import vllm.envs as envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
...@@ -40,7 +38,6 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam, ...@@ -40,7 +38,6 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam,
from vllm.entrypoints.utils import (_validate_truncation_size, from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args) log_non_default_args)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -54,7 +51,7 @@ from vllm.tasks import PoolingTask ...@@ -54,7 +51,7 @@ from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of from vllm.utils import Counter, Device, is_list_of
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -157,18 +154,6 @@ class LLM: ...@@ -157,18 +154,6 @@ class LLM:
serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead. serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
""" """
DEPRECATE_LEGACY: ClassVar[bool] = True
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
cls.DEPRECATE_LEGACY = True
yield
cls.DEPRECATE_LEGACY = False
def __init__( def __init__(
self, self,
model: str, model: str,
...@@ -325,99 +310,14 @@ class LLM: ...@@ -325,99 +310,14 @@ class LLM:
return SamplingParams.from_optional(**self.default_sampling_params) return SamplingParams.from_optional(**self.default_sampling_params)
return SamplingParams() return SamplingParams()
@overload
def generate( def generate(
self, self,
prompts: Union[PromptType, Sequence[PromptType]], prompts: Union[PromptType, Sequence[PromptType]],
/,
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None, Sequence[SamplingParams]]] = None,
*, *,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...
@overload # LEGACY: single (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: str,
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
prompt_token_ids: Optional[list[int]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: list[str],
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: Optional[str] = None,
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
*,
prompt_token_ids: list[int],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: Optional[list[str]] = None,
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
*,
prompt_token_ids: list[list[int]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: None,
sampling_params: None,
prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...
@deprecate_kwargs(
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'prompts' parameter instead.",
)
def generate(
self,
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, list[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
priority: Optional[list[int]] = None, priority: Optional[list[int]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -460,15 +360,6 @@ class LLM: ...@@ -460,15 +360,6 @@ class LLM:
"Try passing `--runner generate` to use the model as a " "Try passing `--runner generate` to use the model as a "
"generative model.") "generative model.")
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, list[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if sampling_params is None: if sampling_params is None:
# Use default sampling params. # Use default sampling params.
sampling_params = self.get_default_sampling_params() sampling_params = self.get_default_sampling_params()
...@@ -483,10 +374,10 @@ class LLM: ...@@ -483,10 +374,10 @@ class LLM:
# Add any modality specific loras to the corresponding prompts # Add any modality specific loras to the corresponding prompts
lora_request = self._get_modality_specific_lora_reqs( lora_request = self._get_modality_specific_lora_reqs(
parsed_prompts, lora_request) prompts, lora_request)
self._validate_and_add_requests( self._validate_and_add_requests(
prompts=parsed_prompts, prompts=prompts,
params=sampling_params, params=sampling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
...@@ -498,7 +389,7 @@ class LLM: ...@@ -498,7 +389,7 @@ class LLM:
return self.engine_class.validate_outputs(outputs, RequestOutput) return self.engine_class.validate_outputs(outputs, RequestOutput)
def _get_modality_specific_lora_reqs( def _get_modality_specific_lora_reqs(
self, parsed_prompts: Union[PromptType, Sequence[PromptType]], self, prompts: Union[PromptType, Sequence[PromptType]],
lora_request: Optional[Union[list[LoRARequest], LoRARequest]]): lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
# Grab the lora config off the vllm config on the engine, # Grab the lora config off the vllm config on the engine,
# since this is the same for both v0 & v1. # since this is the same for both v0 & v1.
...@@ -511,35 +402,33 @@ class LLM: ...@@ -511,35 +402,33 @@ class LLM:
or (lora_config and lora_config.default_mm_loras is None)): or (lora_config and lora_config.default_mm_loras is None)):
return lora_request return lora_request
if not isinstance(parsed_prompts, Sequence): if not isinstance(prompts, Sequence):
parsed_prompts = [parsed_prompts] prompts = [prompts]
optional_loras = ([lora_request] * len(parsed_prompts) optional_loras = ([lora_request] * len(prompts)
if not isinstance(lora_request, Sequence) else if not isinstance(lora_request, Sequence) else
lora_request) lora_request)
return [ return [
self._resolve_single_prompt_mm_lora( self._resolve_single_prompt_mm_lora(
parsed_prompt, prompt,
opt_lora_req, opt_lora_req,
lora_config.default_mm_loras, lora_config.default_mm_loras,
) for parsed_prompt, opt_lora_req in zip(parsed_prompts, ) for prompt, opt_lora_req in zip(prompts, optional_loras)
optional_loras)
] ]
def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType, def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
default_mm_loras: Optional[dict[str, default_mm_loras: Optional[dict[str,
str]]): str]]):
if (not default_mm_loras or not isinstance(parsed_prompt, dict) if (not default_mm_loras or not isinstance(prompt, dict)
or "multi_modal_data" not in parsed_prompt): or "multi_modal_data" not in prompt):
return lora_request return lora_request
parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt) prompt = cast(Union[TextPrompt, TokensPrompt], prompt)
intersection = set( intersection = set(prompt["multi_modal_data"].keys()) \
parsed_prompt["multi_modal_data"].keys()).intersection( .intersection(default_mm_loras.keys())
default_mm_loras.keys())
if not intersection: if not intersection:
return lora_request return lora_request
if len(intersection) > 1: if len(intersection) > 1:
...@@ -933,120 +822,17 @@ class LLM: ...@@ -933,120 +822,17 @@ class LLM:
lora_request=lora_request, lora_request=lora_request,
) )
@overload
def encode( def encode(
self, self,
prompts: Union[PromptType, Sequence[PromptType]], prompts: Union[PromptType, Sequence[PromptType]],
/,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@overload # LEGACY: single (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: str,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[list[int]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: list[str],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[list[list[int]]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: Optional[str] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: list[int],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: Optional[list[str]] = None,
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
*, *,
prompt_token_ids: list[list[int]],
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: None,
pooling_params: None,
prompt_token_ids: Union[list[int], list[list[int]]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@deprecate_kwargs(
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'prompts' parameter instead.",
)
def encode(
self,
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, list[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: Optional[PoolingTask] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
"""Apply pooling to the hidden states corresponding to the input """Apply pooling to the hidden states corresponding to the input
prompts. prompts.
...@@ -1108,15 +894,6 @@ class LLM: ...@@ -1108,15 +894,6 @@ class LLM:
raise ValueError( raise ValueError(
f"pooling_task must be one of {self.supported_tasks}.") f"pooling_task must be one of {self.supported_tasks}.")
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, list[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if pooling_params is None: if pooling_params is None:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
...@@ -1134,7 +911,7 @@ class LLM: ...@@ -1134,7 +911,7 @@ class LLM:
tokenization_kwargs) tokenization_kwargs)
self._validate_and_add_requests( self._validate_and_add_requests(
prompts=parsed_prompts, prompts=prompts,
params=pooling_params, params=pooling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
...@@ -1148,7 +925,6 @@ class LLM: ...@@ -1148,7 +925,6 @@ class LLM:
def embed( def embed(
self, self,
prompts: Union[PromptType, Sequence[PromptType]], prompts: Union[PromptType, Sequence[PromptType]],
/,
*, *,
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
...@@ -1198,7 +974,6 @@ class LLM: ...@@ -1198,7 +974,6 @@ class LLM:
def classify( def classify(
self, self,
prompts: Union[PromptType, Sequence[PromptType]], prompts: Union[PromptType, Sequence[PromptType]],
/,
*, *,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
...@@ -1348,7 +1123,7 @@ class LLM: ...@@ -1348,7 +1123,7 @@ class LLM:
_validate_truncation_size(model_config.max_model_len, _validate_truncation_size(model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs) truncate_prompt_tokens, tokenization_kwargs)
parsed_prompts = [] prompts = list[PromptType]()
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
...@@ -1372,10 +1147,10 @@ class LLM: ...@@ -1372,10 +1147,10 @@ class LLM:
else: else:
pooling_params_list.append(pooling_params) pooling_params_list.append(pooling_params)
parsed_prompts.append(engine_prompt) prompts.append(engine_prompt)
self._validate_and_add_requests( self._validate_and_add_requests(
prompts=parsed_prompts, prompts=prompts,
params=pooling_params_list, params=pooling_params_list,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
...@@ -1585,48 +1360,6 @@ class LLM: ...@@ -1585,48 +1360,6 @@ class LLM:
assert isinstance(self.llm_engine, V1LLMEngine) assert isinstance(self.llm_engine, V1LLMEngine)
return self.llm_engine.get_metrics() return self.llm_engine.get_metrics()
# LEGACY
def _convert_v1_inputs(
self,
prompts: Optional[Union[str, list[str]]],
prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
):
# skip_tokenizer_init is now checked in engine
if prompts is None and prompt_token_ids is None:
raise ValueError(
"Either prompts or prompt_token_ids must be provided.")
if prompts is not None and prompt_token_ids is not None \
and len(prompts) != len(prompt_token_ids):
raise ValueError(
"The lengths of prompts and prompt_token_ids must be the same."
)
if prompts is not None:
prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
if prompt_token_ids is not None:
prompt_token_ids = [
p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
]
if prompts is not None:
num_requests = len(prompts)
elif prompt_token_ids is not None:
num_requests = len(prompt_token_ids)
parsed_prompts: list[PromptType] = []
for i in range(num_requests):
item: PromptType
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError
parsed_prompts.append(item)
return parsed_prompts
def _validate_and_add_requests( def _validate_and_add_requests(
self, self,
prompts: Union[PromptType, Sequence[PromptType]], prompts: Union[PromptType, Sequence[PromptType]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment