Unverified Commit 82ec66f5 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[V0 Deprecation] Remove Prompt Adapters (#20588)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 78c13e30
......@@ -14,7 +14,6 @@ API documentation for vLLM's configuration classes.
- [vllm.config.DeviceConfig][]
- [vllm.config.SpeculativeConfig][]
- [vllm.config.LoRAConfig][]
- [vllm.config.PromptAdapterConfig][]
- [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][]
- [vllm.config.DecodingConfig][]
......
......@@ -34,23 +34,22 @@ th:not(:first-child) {
}
</style>
| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | <abbr title="Prompt Adapter">prmpt adptr</abbr> | [SD](spec_decode.md) | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | [SD](spec_decode.md) | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | |
| [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | | |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | | |
| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | |
| [SD](spec_decode.md) | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | |
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | |
| <abbr title="Pooling Models">pooling</abbr> | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ❌ | [](gh-issue:7366) | ❌ | ❌ | [](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | |
| <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | |
| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
| multi-step | ❌ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | |
| <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | |
| best-of | ✅ | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ✅ | ✅ | |
| beam-search | ✅ | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ❔ | ✅ | ✅ |
| [SD](spec_decode.md) | ✅ | ✅ | ❌ | ✅ | | | | | | | | | | |
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | |
| <abbr title="Pooling Models">pooling</abbr> | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ❌ | [](gh-issue:7366) | ❌ | [](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | |
| <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | |
| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
| multi-step | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | |
| <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | |
| best-of | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ✅ | ✅ | |
| beam-search | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ❔ | ✅ | ✅ |
[](){ #feature-x-hardware }
......@@ -59,10 +58,9 @@ th:not(:first-child) {
| Feature | Volta | Turing | Ampere | Ada | Hopper | CPU | AMD | TPU |
|-----------------------------------------------------------|---------------------|-----------|-----------|--------|------------|--------------------|--------|-----|
| [CP][chunked-prefill] | [](gh-issue:2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [APC](automatic_prefix_caching.md) | [](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | [](gh-issue:8475) | ✅ | ❌ |
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| [APC](automatic_prefix_caching.md) | [](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ |
| <abbr title="Pooling Models">pooling</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ❌ |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
......
......@@ -72,7 +72,6 @@ line-length = 80
"vllm/core/**/*.py" = ["UP006", "UP035"]
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"]
# Python 3.8 typing - skip utils for ROCm
"vllm/utils/__init__.py" = ["UP006", "UP035"]
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# imports for guided decoding tests
import json
import os
import shutil
from tempfile import TemporaryDirectory
from typing import Optional
......@@ -26,10 +27,6 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically these adapters use a different base model,
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
......@@ -56,13 +53,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files):
@pytest.fixture(scope="module")
def zephyr_pa_files():
return snapshot_download(repo_id=PA_NAME)
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
......@@ -81,15 +72,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
"64",
"--max-cpu-loras",
"2",
# pa config
"--enable-prompt-adapter",
"--prompt-adapters",
f"zephyr-pa={zephyr_pa_files}",
f"zephyr-pa2={zephyr_pa_files}",
"--max-prompt-adapters",
"2",
"--max-prompt-adapter-token",
"128",
]
......@@ -98,8 +80,19 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
def server(default_server_args, request):
if request.param:
default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
original_value = os.environ.get('VLLM_USE_V1')
os.environ['VLLM_USE_V1'] = '0'
try:
with RemoteOpenAIServer(MODEL_NAME,
default_server_args) as remote_server:
yield remote_server
finally:
# Restore original env value
if original_value is None:
os.environ.pop('VLLM_USE_V1', None)
else:
os.environ['VLLM_USE_V1'] = original_value
@pytest_asyncio.fixture
......@@ -110,14 +103,11 @@ async def client(server):
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
"model_name,num_virtual_tokens",
[(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
num_virtual_tokens: int):
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
......@@ -130,9 +120,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5,
prompt_tokens=6 + num_virtual_tokens,
total_tokens=11 + num_virtual_tokens)
completion_tokens=5, prompt_tokens=6, total_tokens=11)
# test using token IDs
completion = await client.completions.create(
......@@ -175,9 +163,9 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
......@@ -194,9 +182,9 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora and 1 pa hereafter
# just test 1 lora
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
......@@ -217,7 +205,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
......@@ -238,7 +226,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
model_name: str):
......@@ -314,7 +302,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str):
......@@ -348,7 +336,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling.
......@@ -382,7 +370,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str):
......@@ -519,7 +507,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test both text and token IDs
......
......@@ -13,7 +13,6 @@ from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
from .test_completion import zephyr_pa_files # noqa: F401
from .test_completion import MODEL_NAME
......
......@@ -32,8 +32,7 @@ async def _async_serving_models_init() -> OpenAIServingModels:
serving_models = OpenAIServingModels(engine_client=mock_engine_client,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config,
lora_modules=None,
prompt_adapters=None)
lora_modules=None)
await serving_models.init_static_loras()
return serving_models
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
def do_sample(llm, pa_name: str, pa_id: int):
prompts = [
"Tweet text : @nationalgridus I have no water and the bill is \
current and paid. Can you do something about this? Label : ",
"Tweet text : @nationalgridus Looks good thanks! Label : "
]
sampling_params = vllm.SamplingParams(temperature=0.0,
max_tokens=3,
stop_token_ids=[3])
outputs = llm.generate(prompts,
sampling_params,
prompt_adapter_request=PromptAdapterRequest(
pa_name, pa_id, PA_PATH, 8) if pa_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
llm = vllm.LLM(MODEL_PATH,
enforce_eager=enforce_eager,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
expected_output = ['complaint', 'no complaint']
assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'
def do_sample(engine):
prompts = [
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3), None),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("complain", 3, pa_path, 8)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_multi_prompt_adapters():
engine_args = EngineArgs(model=MODEL_PATH,
max_prompt_adapters=3,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
expected_output = {
' quot;I', 'hate speech', 'no complaint', 'not hate speech'
}
assert do_sample(engine) == expected_output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def do_sample(engine):
prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501
# first prompt with a prompt adapter and second without adapter
prompts = [
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]),
PromptAdapterRequest("hate_speech", 1, pa_path,
8), LoRARequest("sql_test", 1, lora_path)),
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]), None,
LoRARequest("sql_test", 1, lora_path)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request,
lora_request=lora_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_lora_prompt_adapter():
engine_args = EngineArgs(model=MODEL_PATH,
enable_prompt_adapter=True,
enable_lora=True,
max_num_seqs=60,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
result = do_sample(engine)
expected_output = {
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
}
assert result == expected_output
......@@ -31,6 +31,5 @@ run_mypy vllm/inputs
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/worker
run_mypy vllm/v1
......@@ -3143,59 +3143,6 @@ class LoRAConfig:
self.lora_dtype = getattr(torch, self.lora_dtype)
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class PromptAdapterConfig:
"""Configuration for PromptAdapters."""
max_prompt_adapters: int = 1
"""Max number of PromptAdapters in a batch."""
max_prompt_adapter_token: int = 0
"""Max number of PromptAdapters tokens."""
max_cpu_prompt_adapters: Optional[int] = None
"""Maximum number of PromptAdapters to store in CPU memory. Must be >= than
`max_prompt_adapters`."""
prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
"""Data type for PromptAdapter. If auto, will default to base model dtype.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
if self.max_prompt_adapters < 1:
raise ValueError(f"max_prompt_adapters "
f"({self.max_prompt_adapters}) must be >= 1.")
if self.max_prompt_adapter_token == 0:
raise ValueError("max_prompt_adapter_token must be set.")
if self.max_cpu_prompt_adapters is None:
self.max_cpu_prompt_adapters = self.max_prompt_adapters
def verify_with_model_config(self, model_config: ModelConfig):
if self.prompt_adapter_dtype == "auto":
self.prompt_adapter_dtype = model_config.dtype
elif isinstance(self.prompt_adapter_dtype, str):
self.prompt_adapter_dtype = getattr(torch,
self.prompt_adapter_dtype)
@config
@dataclass
class MultiModalConfig:
......@@ -4402,8 +4349,6 @@ class VllmConfig:
"""Decoding configuration."""
observability_config: Optional[ObservabilityConfig] = None
"""Observability configuration."""
prompt_adapter_config: Optional[PromptAdapterConfig] = None
"""Prompt adapter configuration."""
quant_config: Optional[QuantizationConfig] = None
"""Quantization configuration."""
compilation_config: CompilationConfig = field(
......@@ -4500,10 +4445,6 @@ class VllmConfig:
vllm_factors.append(self.observability_config.compute_hash())
else:
vllm_factors.append("None")
if self.prompt_adapter_config:
vllm_factors.append(self.prompt_adapter_config.compute_hash())
else:
vllm_factors.append("None")
if self.quant_config:
pass # should be captured by model_config.quantization
if self.compilation_config:
......@@ -4611,9 +4552,6 @@ class VllmConfig:
if self.lora_config is not None:
self.lora_config.verify_with_cache_config(self.cache_config)
self.lora_config.verify_with_model_config(self.model_config)
if self.prompt_adapter_config is not None:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
if self.quant_config is None and self.model_config is not None:
self.quant_config = VllmConfig._get_quantization_config(
......
......@@ -15,7 +15,6 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupBase, SequenceGroupMetadata,
SequenceGroupMetadataDelta, SequenceStage,
......@@ -165,8 +164,6 @@ class SchedulerOutputs:
if self.num_loras > 0:
self._sort_by_lora_ids()
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
......@@ -194,14 +191,6 @@ class SchedulerOutputs:
if g.seq_group.lora_request is not None
}
@property
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
return {
g.seq_group.prompt_adapter_request
for g in self.scheduled_seq_groups
if g.seq_group.prompt_adapter_request is not None
}
@dataclass
class SchedulerRunningOutputs:
......@@ -1648,7 +1637,6 @@ class Scheduler:
multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
# When SPMD mode is enabled, we only send delta data except for
......
......@@ -30,9 +30,9 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
VllmConfig, get_attr_docs, get_field)
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
......@@ -358,11 +358,6 @@ class EngineArgs:
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
# PromptAdapter fields
enable_prompt_adapter: bool = False
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
max_prompt_adapter_token: int = \
PromptAdapterConfig.max_prompt_adapter_token
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
......@@ -437,6 +432,8 @@ class EngineArgs:
ParallelConfig.enable_multimodal_encoder_data_parallel
async_scheduling: bool = SchedulerConfig.async_scheduling
# DEPRECATED
enable_prompt_adapter: bool = False
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
......@@ -729,23 +726,6 @@ class EngineArgs:
lora_group.add_argument("--default-mm-loras",
**lora_kwargs["default_mm_loras"])
# PromptAdapter related configs
prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
prompt_adapter_group = parser.add_argument_group(
title="PromptAdapterConfig",
description=PromptAdapterConfig.__doc__,
)
prompt_adapter_group.add_argument(
"--enable-prompt-adapter",
action=argparse.BooleanOptionalAction,
help="If True, enable handling of PromptAdapters.")
prompt_adapter_group.add_argument(
"--max-prompt-adapters",
**prompt_adapter_kwargs["max_prompt_adapters"])
prompt_adapter_group.add_argument(
"--max-prompt-adapter-token",
**prompt_adapter_kwargs["max_prompt_adapter_token"])
# Speculative arguments
speculative_group = parser.add_argument_group(
title="SpeculativeConfig",
......@@ -850,6 +830,12 @@ class EngineArgs:
parser.add_argument('--disable-log-stats',
action='store_true',
help='Disable logging statistics.')
parser.add_argument('--enable-prompt-adapter',
action='store_true',
deprecated=True,
help='[DEPRECATED] Prompt adapter has been '
'removed. Setting this flag to True or False'
' has no effect on vLLM behavior.')
return parser
......@@ -1234,11 +1220,6 @@ class EngineArgs:
load_config = self.create_load_config()
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig(
backend=self.guided_decoding_backend,
disable_fallback=self.guided_decoding_disable_fallback,
......@@ -1266,7 +1247,6 @@ class EngineArgs:
load_config=load_config,
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
......@@ -1342,12 +1322,6 @@ class EngineArgs:
recommend_to_remove=False)
return False
# No Prompt Adapter so far.
if self.enable_prompt_adapter:
_raise_or_fallback(feature_name="--enable-prompt-adapter",
recommend_to_remove=False)
return False
# No text embedding inputs so far.
if self.enable_prompt_embeds:
_raise_or_fallback(feature_name="--enable-prompt-embeds",
......@@ -1469,7 +1443,6 @@ class EngineArgs:
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter
and model_config.runner_type != "pooling"):
self.enable_chunked_prefill = True
logger.warning(
......
......@@ -29,7 +29,6 @@ from vllm.model_executor.guided_decoding import (
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
......@@ -435,7 +434,6 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
......@@ -468,7 +466,6 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
tokenization_kwargs=tokenization_kwargs,
)
......@@ -491,7 +488,6 @@ class _AsyncLLMEngine(LLMEngine):
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
)
......@@ -861,7 +857,6 @@ class AsyncLLMEngine(EngineClient):
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
......@@ -889,7 +884,6 @@ class AsyncLLMEngine(EngineClient):
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
tokenization_kwargs=tokenization_kwargs,
......@@ -904,7 +898,6 @@ class AsyncLLMEngine(EngineClient):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
......@@ -922,8 +915,6 @@ class AsyncLLMEngine(EngineClient):
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
data_parallel_rank: The (global) data parallel rank that must
......@@ -983,7 +974,6 @@ class AsyncLLMEngine(EngineClient):
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
):
......
......@@ -44,7 +44,6 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
PoolingSequenceGroupOutput, Sequence, SequenceGroup,
......@@ -223,7 +222,6 @@ class LLMEngine:
self.load_config = vllm_config.load_config
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)
......@@ -294,8 +292,6 @@ class LLMEngine:
# Feature flags
"enable_lora":
bool(self.lora_config),
"enable_prompt_adapter":
bool(self.prompt_adapter_config),
"enable_prefix_caching":
self.cache_config.enable_prefix_caching,
"enforce_eager":
......@@ -542,9 +538,6 @@ class LLMEngine:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _add_processed_request(
self,
......@@ -553,7 +546,6 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> Optional[SequenceGroup]:
......@@ -569,7 +561,6 @@ class LLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
return None
......@@ -583,11 +574,10 @@ class LLMEngine:
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
lora_request)
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
prompt_adapter_request))
seq_id, encoder_inputs, block_size, eos_token_id, lora_request))
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
......@@ -598,7 +588,6 @@ class LLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
elif isinstance(params, PoolingParams):
......@@ -608,7 +597,6 @@ class LLMEngine:
params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
else:
......@@ -637,7 +625,6 @@ class LLMEngine:
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
"""Add a request to the engine's request pool.
......@@ -658,7 +645,6 @@ class LLMEngine:
the current monotonic time.
lora_request: The LoRA request to add.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: The prompt adapter request to add.
priority: The priority of the request.
Only applicable with priority scheduling.
......@@ -719,7 +705,6 @@ class LLMEngine:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request(
......@@ -728,7 +713,6 @@ class LLMEngine:
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
)
......@@ -741,7 +725,6 @@ class LLMEngine:
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup:
......@@ -769,17 +752,15 @@ class LLMEngine:
if self.vllm_config.speculative_config is not None:
draft_size = \
self.vllm_config.speculative_config.num_speculative_tokens + 1
seq_group = SequenceGroup(
request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority,
draft_size=draft_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
encoder_seq=encoder_seq,
priority=priority,
draft_size=draft_size)
return seq_group
......@@ -790,7 +771,6 @@ class LLMEngine:
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup:
......@@ -798,15 +778,13 @@ class LLMEngine:
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(
request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
encoder_seq=encoder_seq,
priority=priority)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
......@@ -1834,16 +1812,6 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.model_executor.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> List[int]:
return self.model_executor.list_prompt_adapters()
def start_profile(self) -> None:
self.model_executor.start_profile()
......
......@@ -10,7 +10,6 @@ from vllm import PoolingParams
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.utils import Device
......@@ -33,7 +32,6 @@ class RPCProcessRequest:
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
def __init__(
......@@ -43,7 +41,6 @@ class RPCProcessRequest:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
super().__init__()
......@@ -53,7 +50,6 @@ class RPCProcessRequest:
self.request_id = request_id
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority
......
......@@ -45,7 +45,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device
......@@ -448,7 +447,6 @@ class MQLLMEngineClient(EngineClient):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
......@@ -465,8 +463,6 @@ class MQLLMEngineClient(EngineClient):
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
......@@ -474,8 +470,7 @@ class MQLLMEngineClient(EngineClient):
return cast(
AsyncGenerator[RequestOutput, None],
self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request, priority))
lora_request, trace_headers, priority))
def encode(
self,
......@@ -521,7 +516,6 @@ class MQLLMEngineClient(EngineClient):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
......@@ -575,7 +569,6 @@ class MQLLMEngineClient(EngineClient):
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
))
......
......@@ -304,14 +304,12 @@ class MQLLMEngine:
self._send_outputs(rpc_err)
try:
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority)
self.engine.add_request(request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
priority=request.priority)
if self.log_requests:
logger.info("Added request %s.", request.request_id)
......
......@@ -16,7 +16,6 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid
......@@ -55,7 +54,6 @@ class EngineClient(ABC):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
......
......@@ -45,7 +45,6 @@ from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
......@@ -314,7 +313,6 @@ class LLM:
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
......@@ -330,7 +328,6 @@ class LLM:
prompt_token_ids: Optional[list[int]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
......@@ -346,7 +343,6 @@ class LLM:
prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
......@@ -363,7 +359,6 @@ class LLM:
prompt_token_ids: list[int],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
......@@ -380,7 +375,6 @@ class LLM:
prompt_token_ids: list[list[int]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
......@@ -395,7 +389,6 @@ class LLM:
prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
......@@ -415,7 +408,6 @@ class LLM:
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
priority: Optional[list[int]] = None,
......@@ -440,8 +432,6 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled.
......@@ -507,7 +497,6 @@ class LLM:
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
......@@ -963,7 +952,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
......@@ -980,7 +968,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
......@@ -997,7 +984,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
......@@ -1015,7 +1001,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
......@@ -1033,7 +1018,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
......@@ -1049,7 +1033,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
......@@ -1070,7 +1053,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
......@@ -1092,8 +1074,6 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
pooling_task: Override the pooling task to use.
Returns:
......@@ -1150,7 +1130,6 @@ class LLM:
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
......@@ -1167,7 +1146,6 @@ class LLM:
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[EmbeddingRequestOutput]:
"""
Generate an embedding vector for each prompt.
......@@ -1187,8 +1165,6 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
......@@ -1205,7 +1181,6 @@ class LLM:
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="embed",
)
......@@ -1218,7 +1193,6 @@ class LLM:
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ClassificationRequestOutput]:
"""
Generate class logits for each prompt.
......@@ -1236,8 +1210,6 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `ClassificationRequestOutput` objects containing the
......@@ -1253,7 +1225,6 @@ class LLM:
prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="classify",
)
......@@ -1267,7 +1238,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]:
encoded_output: list[PoolingRequestOutput] = self.encode(
......@@ -1275,7 +1245,6 @@ class LLM:
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="embed",
)
......@@ -1303,7 +1272,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]:
if isinstance(tokenizer, MistralTokenizer):
......@@ -1361,7 +1329,6 @@ class LLM:
params=pooling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
......@@ -1381,7 +1348,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs `<text,text_pair>` or
`<multi-modal data, multi-modal data pair>`.
......@@ -1412,8 +1378,6 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `ScoringRequestOutput` objects containing the
......@@ -1504,8 +1468,7 @@ class LLM:
data_2, # type: ignore[arg-type]
truncate_prompt_tokens,
use_tqdm,
lora_request,
prompt_adapter_request)
lora_request)
else:
return self._embedding_score(
tokenizer,
......@@ -1513,8 +1476,7 @@ class LLM:
data_2, # type: ignore[arg-type]
truncate_prompt_tokens,
use_tqdm,
lora_request,
prompt_adapter_request)
lora_request)
def start_profile(self) -> None:
self.llm_engine.start_profile()
......@@ -1625,7 +1587,6 @@ class LLM:
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[list[int]] = None,
......@@ -1671,7 +1632,6 @@ class LLM:
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority[i] if priority else 0,
)
......@@ -1681,7 +1641,6 @@ class LLM:
params: Union[SamplingParams, PoolingParams],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
request_id = str(next(self.request_counter))
......@@ -1691,7 +1650,6 @@ class LLM:
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment