Unverified Commit 82ec66f5 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[V0 Deprecation] Remove Prompt Adapters (#20588)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 78c13e30
...@@ -14,7 +14,6 @@ API documentation for vLLM's configuration classes. ...@@ -14,7 +14,6 @@ API documentation for vLLM's configuration classes.
- [vllm.config.DeviceConfig][] - [vllm.config.DeviceConfig][]
- [vllm.config.SpeculativeConfig][] - [vllm.config.SpeculativeConfig][]
- [vllm.config.LoRAConfig][] - [vllm.config.LoRAConfig][]
- [vllm.config.PromptAdapterConfig][]
- [vllm.config.MultiModalConfig][] - [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][] - [vllm.config.PoolerConfig][]
- [vllm.config.DecodingConfig][] - [vllm.config.DecodingConfig][]
......
...@@ -34,23 +34,22 @@ th:not(:first-child) { ...@@ -34,23 +34,22 @@ th:not(:first-child) {
} }
</style> </style>
| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | <abbr title="Prompt Adapter">prmpt adptr</abbr> | [SD](spec_decode.md) | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search | | Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | [SD](spec_decode.md) | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | | | [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | |
| [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | | | | [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | | |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | | | | [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | | |
| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | | | [SD](spec_decode.md) | ✅ | ✅ | ❌ | ✅ | | | | | | | | | | |
| [SD](spec_decode.md) | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | | | CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | |
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | <abbr title="Pooling Models">pooling</abbr> | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | |
| <abbr title="Pooling Models">pooling</abbr> | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | | | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ❌ | [](gh-issue:7366) | ❌ | [](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ❌ | [](gh-issue:7366) | ❌ | ❌ | [](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | | | <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
| <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | |
| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | | | <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | | | multi-step | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | |
| multi-step | ❌ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | | | <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | |
| <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | | best-of | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ✅ | ✅ | |
| best-of | ✅ | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ✅ | ✅ | | | beam-search | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ❔ | ✅ | ✅ |
| beam-search | ✅ | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ❔ | ✅ | ✅ |
[](){ #feature-x-hardware } [](){ #feature-x-hardware }
...@@ -61,7 +60,6 @@ th:not(:first-child) { ...@@ -61,7 +60,6 @@ th:not(:first-child) {
| [CP][chunked-prefill] | [](gh-issue:2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [CP][chunked-prefill] | [](gh-issue:2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [APC](automatic_prefix_caching.md) | [](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [APC](automatic_prefix_caching.md) | [](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | [](gh-issue:8475) | ✅ | ❌ |
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | | CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ |
| <abbr title="Pooling Models">pooling</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ❌ | | <abbr title="Pooling Models">pooling</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ❌ |
......
...@@ -72,7 +72,6 @@ line-length = 80 ...@@ -72,7 +72,6 @@ line-length = 80
"vllm/core/**/*.py" = ["UP006", "UP035"] "vllm/core/**/*.py" = ["UP006", "UP035"]
"vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"]
# Python 3.8 typing - skip utils for ROCm # Python 3.8 typing - skip utils for ROCm
"vllm/utils/__init__.py" = ["UP006", "UP035"] "vllm/utils/__init__.py" = ["UP006", "UP035"]
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# imports for guided decoding tests # imports for guided decoding tests
import json import json
import os
import shutil import shutil
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Optional from typing import Optional
...@@ -26,10 +27,6 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" ...@@ -26,10 +27,6 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically these adapters use a different base model, # technically these adapters use a different base model,
# but we're not testing generation quality here # but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora" LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
...@@ -56,13 +53,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files): ...@@ -56,13 +53,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def zephyr_pa_files(): def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files):
return snapshot_download(repo_id=PA_NAME)
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
return [ return [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -81,15 +72,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, ...@@ -81,15 +72,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
"64", "64",
"--max-cpu-loras", "--max-cpu-loras",
"2", "2",
# pa config
"--enable-prompt-adapter",
"--prompt-adapters",
f"zephyr-pa={zephyr_pa_files}",
f"zephyr-pa2={zephyr_pa_files}",
"--max-prompt-adapters",
"2",
"--max-prompt-adapter-token",
"128",
] ]
...@@ -98,8 +80,19 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, ...@@ -98,8 +80,19 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
def server(default_server_args, request): def server(default_server_args, request):
if request.param: if request.param:
default_server_args.append(request.param) default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
original_value = os.environ.get('VLLM_USE_V1')
os.environ['VLLM_USE_V1'] = '0'
try:
with RemoteOpenAIServer(MODEL_NAME,
default_server_args) as remote_server:
yield remote_server yield remote_server
finally:
# Restore original env value
if original_value is None:
os.environ.pop('VLLM_USE_V1', None)
else:
os.environ['VLLM_USE_V1'] = original_value
@pytest_asyncio.fixture @pytest_asyncio.fixture
...@@ -110,14 +103,11 @@ async def client(server): ...@@ -110,14 +103,11 @@ async def client(server):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters # first test base model, then test loras
"model_name,num_virtual_tokens", "model_name",
[(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
) )
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
num_virtual_tokens: int):
completion = await client.completions.create(model=model_name, completion = await client.completions.create(model=model_name,
prompt="Hello, my name is", prompt="Hello, my name is",
max_tokens=5, max_tokens=5,
...@@ -130,9 +120,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, ...@@ -130,9 +120,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
assert len(choice.text) >= 5 assert len(choice.text) >= 5
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage( assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, completion_tokens=5, prompt_tokens=6, total_tokens=11)
prompt_tokens=6 + num_virtual_tokens,
total_tokens=11 + num_virtual_tokens)
# test using token IDs # test using token IDs
completion = await client.completions.create( completion = await client.completions.create(
...@@ -175,9 +163,9 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): ...@@ -175,9 +163,9 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters # first test base model, then test loras
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
) )
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs # test using token IDs
...@@ -194,9 +182,9 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): ...@@ -194,9 +182,9 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# just test 1 lora and 1 pa hereafter # just test 1 lora
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs # test using token IDs
...@@ -217,7 +205,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): ...@@ -217,7 +205,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs # test using token IDs
...@@ -238,7 +226,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): ...@@ -238,7 +226,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
model_name: str): model_name: str):
...@@ -314,7 +302,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, ...@@ -314,7 +302,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_completion_streaming(client: openai.AsyncOpenAI, async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str): model_name: str):
...@@ -348,7 +336,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, ...@@ -348,7 +336,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling. """Streaming for parallel sampling.
...@@ -382,7 +370,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): ...@@ -382,7 +370,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_completion_stream_options(client: openai.AsyncOpenAI, async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str): model_name: str):
...@@ -519,7 +507,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, ...@@ -519,7 +507,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test both text and token IDs # test both text and token IDs
......
...@@ -13,7 +13,6 @@ from ...utils import RemoteOpenAIServer ...@@ -13,7 +13,6 @@ from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args # noqa: F401 from .test_completion import default_server_args # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401 from .test_completion import zephyr_lora_files # noqa: F401
from .test_completion import zephyr_pa_files # noqa: F401
from .test_completion import MODEL_NAME from .test_completion import MODEL_NAME
......
...@@ -32,8 +32,7 @@ async def _async_serving_models_init() -> OpenAIServingModels: ...@@ -32,8 +32,7 @@ async def _async_serving_models_init() -> OpenAIServingModels:
serving_models = OpenAIServingModels(engine_client=mock_engine_client, serving_models = OpenAIServingModels(engine_client=mock_engine_client,
base_model_paths=BASE_MODEL_PATHS, base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config, model_config=mock_model_config,
lora_modules=None, lora_modules=None)
prompt_adapters=None)
await serving_models.init_static_loras() await serving_models.init_static_loras()
return serving_models return serving_models
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
def do_sample(llm, pa_name: str, pa_id: int):
prompts = [
"Tweet text : @nationalgridus I have no water and the bill is \
current and paid. Can you do something about this? Label : ",
"Tweet text : @nationalgridus Looks good thanks! Label : "
]
sampling_params = vllm.SamplingParams(temperature=0.0,
max_tokens=3,
stop_token_ids=[3])
outputs = llm.generate(prompts,
sampling_params,
prompt_adapter_request=PromptAdapterRequest(
pa_name, pa_id, PA_PATH, 8) if pa_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
llm = vllm.LLM(MODEL_PATH,
enforce_eager=enforce_eager,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
expected_output = ['complaint', 'no complaint']
assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'
def do_sample(engine):
prompts = [
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3), None),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("complain", 3, pa_path, 8)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_multi_prompt_adapters():
engine_args = EngineArgs(model=MODEL_PATH,
max_prompt_adapters=3,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
expected_output = {
' quot;I', 'hate speech', 'no complaint', 'not hate speech'
}
assert do_sample(engine) == expected_output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def do_sample(engine):
prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501
# first prompt with a prompt adapter and second without adapter
prompts = [
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]),
PromptAdapterRequest("hate_speech", 1, pa_path,
8), LoRARequest("sql_test", 1, lora_path)),
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]), None,
LoRARequest("sql_test", 1, lora_path)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request,
lora_request=lora_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_lora_prompt_adapter():
engine_args = EngineArgs(model=MODEL_PATH,
enable_prompt_adapter=True,
enable_lora=True,
max_num_seqs=60,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
result = do_sample(engine)
expected_output = {
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
}
assert result == expected_output
...@@ -31,6 +31,5 @@ run_mypy vllm/inputs ...@@ -31,6 +31,5 @@ run_mypy vllm/inputs
run_mypy vllm/lora run_mypy vllm/lora
run_mypy vllm/model_executor run_mypy vllm/model_executor
run_mypy vllm/plugins run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/worker run_mypy vllm/worker
run_mypy vllm/v1 run_mypy vllm/v1
...@@ -3143,59 +3143,6 @@ class LoRAConfig: ...@@ -3143,59 +3143,6 @@ class LoRAConfig:
self.lora_dtype = getattr(torch, self.lora_dtype) self.lora_dtype = getattr(torch, self.lora_dtype)
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class PromptAdapterConfig:
"""Configuration for PromptAdapters."""
max_prompt_adapters: int = 1
"""Max number of PromptAdapters in a batch."""
max_prompt_adapter_token: int = 0
"""Max number of PromptAdapters tokens."""
max_cpu_prompt_adapters: Optional[int] = None
"""Maximum number of PromptAdapters to store in CPU memory. Must be >= than
`max_prompt_adapters`."""
prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
"""Data type for PromptAdapter. If auto, will default to base model dtype.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
if self.max_prompt_adapters < 1:
raise ValueError(f"max_prompt_adapters "
f"({self.max_prompt_adapters}) must be >= 1.")
if self.max_prompt_adapter_token == 0:
raise ValueError("max_prompt_adapter_token must be set.")
if self.max_cpu_prompt_adapters is None:
self.max_cpu_prompt_adapters = self.max_prompt_adapters
def verify_with_model_config(self, model_config: ModelConfig):
if self.prompt_adapter_dtype == "auto":
self.prompt_adapter_dtype = model_config.dtype
elif isinstance(self.prompt_adapter_dtype, str):
self.prompt_adapter_dtype = getattr(torch,
self.prompt_adapter_dtype)
@config @config
@dataclass @dataclass
class MultiModalConfig: class MultiModalConfig:
...@@ -4402,8 +4349,6 @@ class VllmConfig: ...@@ -4402,8 +4349,6 @@ class VllmConfig:
"""Decoding configuration.""" """Decoding configuration."""
observability_config: Optional[ObservabilityConfig] = None observability_config: Optional[ObservabilityConfig] = None
"""Observability configuration.""" """Observability configuration."""
prompt_adapter_config: Optional[PromptAdapterConfig] = None
"""Prompt adapter configuration."""
quant_config: Optional[QuantizationConfig] = None quant_config: Optional[QuantizationConfig] = None
"""Quantization configuration.""" """Quantization configuration."""
compilation_config: CompilationConfig = field( compilation_config: CompilationConfig = field(
...@@ -4500,10 +4445,6 @@ class VllmConfig: ...@@ -4500,10 +4445,6 @@ class VllmConfig:
vllm_factors.append(self.observability_config.compute_hash()) vllm_factors.append(self.observability_config.compute_hash())
else: else:
vllm_factors.append("None") vllm_factors.append("None")
if self.prompt_adapter_config:
vllm_factors.append(self.prompt_adapter_config.compute_hash())
else:
vllm_factors.append("None")
if self.quant_config: if self.quant_config:
pass # should be captured by model_config.quantization pass # should be captured by model_config.quantization
if self.compilation_config: if self.compilation_config:
...@@ -4611,9 +4552,6 @@ class VllmConfig: ...@@ -4611,9 +4552,6 @@ class VllmConfig:
if self.lora_config is not None: if self.lora_config is not None:
self.lora_config.verify_with_cache_config(self.cache_config) self.lora_config.verify_with_cache_config(self.cache_config)
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
if self.prompt_adapter_config is not None:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
if self.quant_config is None and self.model_config is not None: if self.quant_config is None and self.model_config is not None:
self.quant_config = VllmConfig._get_quantization_config( self.quant_config = VllmConfig._get_quantization_config(
......
...@@ -15,7 +15,6 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig ...@@ -15,7 +15,6 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupBase, SequenceGroupMetadata, SequenceGroupBase, SequenceGroupMetadata,
SequenceGroupMetadataDelta, SequenceStage, SequenceGroupMetadataDelta, SequenceStage,
...@@ -165,8 +164,6 @@ class SchedulerOutputs: ...@@ -165,8 +164,6 @@ class SchedulerOutputs:
if self.num_loras > 0: if self.num_loras > 0:
self._sort_by_lora_ids() self._sort_by_lora_ids()
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
def is_empty(self) -> bool: def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups. # NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
...@@ -194,14 +191,6 @@ class SchedulerOutputs: ...@@ -194,14 +191,6 @@ class SchedulerOutputs:
if g.seq_group.lora_request is not None if g.seq_group.lora_request is not None
} }
@property
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
return {
g.seq_group.prompt_adapter_request
for g in self.scheduled_seq_groups
if g.seq_group.prompt_adapter_request is not None
}
@dataclass @dataclass
class SchedulerRunningOutputs: class SchedulerRunningOutputs:
...@@ -1648,7 +1637,6 @@ class Scheduler: ...@@ -1648,7 +1637,6 @@ class Scheduler:
multi_modal_placeholders=( multi_modal_placeholders=(
seq_group.multi_modal_placeholders seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None), if scheduler_outputs.num_prefill_groups > 0 else None),
prompt_adapter_request=seq_group.prompt_adapter_request,
) )
else: else:
# When SPMD mode is enabled, we only send delta data except for # When SPMD mode is enabled, we only send delta data except for
......
...@@ -30,9 +30,9 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ...@@ -30,9 +30,9 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
LogprobsMode, LoRAConfig, ModelConfig, ModelDType, LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig, ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
PromptAdapterConfig, SchedulerConfig, SchedulerPolicy, SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
SpeculativeConfig, TaskOption, TokenizerMode, TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
VllmConfig, get_attr_docs, get_field) get_field)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
...@@ -358,11 +358,6 @@ class EngineArgs: ...@@ -358,11 +358,6 @@ class EngineArgs:
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
# PromptAdapter fields
enable_prompt_adapter: bool = False
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
max_prompt_adapter_token: int = \
PromptAdapterConfig.max_prompt_adapter_token
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
...@@ -437,6 +432,8 @@ class EngineArgs: ...@@ -437,6 +432,8 @@ class EngineArgs:
ParallelConfig.enable_multimodal_encoder_data_parallel ParallelConfig.enable_multimodal_encoder_data_parallel
async_scheduling: bool = SchedulerConfig.async_scheduling async_scheduling: bool = SchedulerConfig.async_scheduling
# DEPRECATED
enable_prompt_adapter: bool = False
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
...@@ -729,23 +726,6 @@ class EngineArgs: ...@@ -729,23 +726,6 @@ class EngineArgs:
lora_group.add_argument("--default-mm-loras", lora_group.add_argument("--default-mm-loras",
**lora_kwargs["default_mm_loras"]) **lora_kwargs["default_mm_loras"])
# PromptAdapter related configs
prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
prompt_adapter_group = parser.add_argument_group(
title="PromptAdapterConfig",
description=PromptAdapterConfig.__doc__,
)
prompt_adapter_group.add_argument(
"--enable-prompt-adapter",
action=argparse.BooleanOptionalAction,
help="If True, enable handling of PromptAdapters.")
prompt_adapter_group.add_argument(
"--max-prompt-adapters",
**prompt_adapter_kwargs["max_prompt_adapters"])
prompt_adapter_group.add_argument(
"--max-prompt-adapter-token",
**prompt_adapter_kwargs["max_prompt_adapter_token"])
# Speculative arguments # Speculative arguments
speculative_group = parser.add_argument_group( speculative_group = parser.add_argument_group(
title="SpeculativeConfig", title="SpeculativeConfig",
...@@ -850,6 +830,12 @@ class EngineArgs: ...@@ -850,6 +830,12 @@ class EngineArgs:
parser.add_argument('--disable-log-stats', parser.add_argument('--disable-log-stats',
action='store_true', action='store_true',
help='Disable logging statistics.') help='Disable logging statistics.')
parser.add_argument('--enable-prompt-adapter',
action='store_true',
deprecated=True,
help='[DEPRECATED] Prompt adapter has been '
'removed. Setting this flag to True or False'
' has no effect on vLLM behavior.')
return parser return parser
...@@ -1234,11 +1220,6 @@ class EngineArgs: ...@@ -1234,11 +1220,6 @@ class EngineArgs:
load_config = self.create_load_config() load_config = self.create_load_config()
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig( decoding_config = DecodingConfig(
backend=self.guided_decoding_backend, backend=self.guided_decoding_backend,
disable_fallback=self.guided_decoding_disable_fallback, disable_fallback=self.guided_decoding_disable_fallback,
...@@ -1266,7 +1247,6 @@ class EngineArgs: ...@@ -1266,7 +1247,6 @@ class EngineArgs:
load_config=load_config, load_config=load_config,
decoding_config=decoding_config, decoding_config=decoding_config,
observability_config=observability_config, observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config, compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config, kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config, kv_events_config=self.kv_events_config,
...@@ -1342,12 +1322,6 @@ class EngineArgs: ...@@ -1342,12 +1322,6 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# No Prompt Adapter so far.
if self.enable_prompt_adapter:
_raise_or_fallback(feature_name="--enable-prompt-adapter",
recommend_to_remove=False)
return False
# No text embedding inputs so far. # No text embedding inputs so far.
if self.enable_prompt_embeds: if self.enable_prompt_embeds:
_raise_or_fallback(feature_name="--enable-prompt-embeds", _raise_or_fallback(feature_name="--enable-prompt-embeds",
...@@ -1469,7 +1443,6 @@ class EngineArgs: ...@@ -1469,7 +1443,6 @@ class EngineArgs:
if (is_gpu and not use_sliding_window and not use_spec_decode if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora and not self.enable_lora
and not self.enable_prompt_adapter
and model_config.runner_type != "pooling"): and model_config.runner_type != "pooling"):
self.enable_chunked_prefill = True self.enable_chunked_prefill = True
logger.warning( logger.warning(
......
...@@ -29,7 +29,6 @@ from vllm.model_executor.guided_decoding import ( ...@@ -29,7 +29,6 @@ from vllm.model_executor.guided_decoding import (
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
...@@ -435,7 +434,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -435,7 +434,6 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
...@@ -468,7 +466,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -468,7 +466,6 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs = await self.input_preprocessor.preprocess_async( processed_inputs = await self.input_preprocessor.preprocess_async(
prompt, prompt,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
...@@ -491,7 +488,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -491,7 +488,6 @@ class _AsyncLLMEngine(LLMEngine):
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
) )
...@@ -861,7 +857,6 @@ class AsyncLLMEngine(EngineClient): ...@@ -861,7 +857,6 @@ class AsyncLLMEngine(EngineClient):
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
...@@ -889,7 +884,6 @@ class AsyncLLMEngine(EngineClient): ...@@ -889,7 +884,6 @@ class AsyncLLMEngine(EngineClient):
arrival_time=arrival_time or time.time(), arrival_time=arrival_time or time.time(),
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
...@@ -904,7 +898,6 @@ class AsyncLLMEngine(EngineClient): ...@@ -904,7 +898,6 @@ class AsyncLLMEngine(EngineClient):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
...@@ -922,8 +915,6 @@ class AsyncLLMEngine(EngineClient): ...@@ -922,8 +915,6 @@ class AsyncLLMEngine(EngineClient):
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: The priority of the request. priority: The priority of the request.
Only applicable with priority scheduling. Only applicable with priority scheduling.
data_parallel_rank: The (global) data parallel rank that must data_parallel_rank: The (global) data parallel rank that must
...@@ -983,7 +974,6 @@ class AsyncLLMEngine(EngineClient): ...@@ -983,7 +974,6 @@ class AsyncLLMEngine(EngineClient):
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
): ):
......
...@@ -44,7 +44,6 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor ...@@ -44,7 +44,6 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput, from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
PoolingSequenceGroupOutput, Sequence, SequenceGroup, PoolingSequenceGroupOutput, Sequence, SequenceGroup,
...@@ -223,7 +222,6 @@ class LLMEngine: ...@@ -223,7 +222,6 @@ class LLMEngine:
self.load_config = vllm_config.load_config self.load_config = vllm_config.load_config
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
) )
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
) )
...@@ -294,8 +292,6 @@ class LLMEngine: ...@@ -294,8 +292,6 @@ class LLMEngine:
# Feature flags # Feature flags
"enable_lora": "enable_lora":
bool(self.lora_config), bool(self.lora_config),
"enable_prompt_adapter":
bool(self.prompt_adapter_config),
"enable_prefix_caching": "enable_prefix_caching":
self.cache_config.enable_prefix_caching, self.cache_config.enable_prefix_caching,
"enforce_eager": "enforce_eager":
...@@ -542,9 +538,6 @@ class LLMEngine: ...@@ -542,9 +538,6 @@ class LLMEngine:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
self.scheduler_config) self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _add_processed_request( def _add_processed_request(
self, self,
...@@ -553,7 +546,6 @@ class LLMEngine: ...@@ -553,7 +546,6 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
) -> Optional[SequenceGroup]: ) -> Optional[SequenceGroup]:
...@@ -569,7 +561,6 @@ class LLMEngine: ...@@ -569,7 +561,6 @@ class LLMEngine:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
) )
return None return None
...@@ -583,11 +574,10 @@ class LLMEngine: ...@@ -583,11 +574,10 @@ class LLMEngine:
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request) lora_request)
encoder_seq = (None if encoder_inputs is None else Sequence( encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request, seq_id, encoder_inputs, block_size, eos_token_id, lora_request))
prompt_adapter_request))
# Create a SequenceGroup based on SamplingParams or PoolingParams # Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams): if isinstance(params, SamplingParams):
...@@ -598,7 +588,6 @@ class LLMEngine: ...@@ -598,7 +588,6 @@ class LLMEngine:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq, encoder_seq=encoder_seq,
priority=priority) priority=priority)
elif isinstance(params, PoolingParams): elif isinstance(params, PoolingParams):
...@@ -608,7 +597,6 @@ class LLMEngine: ...@@ -608,7 +597,6 @@ class LLMEngine:
params, params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq, encoder_seq=encoder_seq,
priority=priority) priority=priority)
else: else:
...@@ -637,7 +625,6 @@ class LLMEngine: ...@@ -637,7 +625,6 @@ class LLMEngine:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
...@@ -658,7 +645,6 @@ class LLMEngine: ...@@ -658,7 +645,6 @@ class LLMEngine:
the current monotonic time. the current monotonic time.
lora_request: The LoRA request to add. lora_request: The LoRA request to add.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: The prompt adapter request to add.
priority: The priority of the request. priority: The priority of the request.
Only applicable with priority scheduling. Only applicable with priority scheduling.
...@@ -719,7 +705,6 @@ class LLMEngine: ...@@ -719,7 +705,6 @@ class LLMEngine:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
) )
self._add_processed_request( self._add_processed_request(
...@@ -728,7 +713,6 @@ class LLMEngine: ...@@ -728,7 +713,6 @@ class LLMEngine:
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
) )
...@@ -741,7 +725,6 @@ class LLMEngine: ...@@ -741,7 +725,6 @@ class LLMEngine:
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
priority: int = 0, priority: int = 0,
) -> SequenceGroup: ) -> SequenceGroup:
...@@ -769,14 +752,12 @@ class LLMEngine: ...@@ -769,14 +752,12 @@ class LLMEngine:
if self.vllm_config.speculative_config is not None: if self.vllm_config.speculative_config is not None:
draft_size = \ draft_size = \
self.vllm_config.speculative_config.num_speculative_tokens + 1 self.vllm_config.speculative_config.num_speculative_tokens + 1
seq_group = SequenceGroup( seq_group = SequenceGroup(request_id=request_id,
request_id=request_id,
seqs=[seq], seqs=[seq],
arrival_time=arrival_time, arrival_time=arrival_time,
sampling_params=sampling_params, sampling_params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq, encoder_seq=encoder_seq,
priority=priority, priority=priority,
draft_size=draft_size) draft_size=draft_size)
...@@ -790,7 +771,6 @@ class LLMEngine: ...@@ -790,7 +771,6 @@ class LLMEngine:
pooling_params: PoolingParams, pooling_params: PoolingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
priority: int = 0, priority: int = 0,
) -> SequenceGroup: ) -> SequenceGroup:
...@@ -798,13 +778,11 @@ class LLMEngine: ...@@ -798,13 +778,11 @@ class LLMEngine:
# Defensive copy of PoolingParams, which are used by the pooler # Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone() pooling_params = pooling_params.clone()
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup( seq_group = SequenceGroup(request_id=request_id,
request_id=request_id,
seqs=[seq], seqs=[seq],
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
pooling_params=pooling_params, pooling_params=pooling_params,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq, encoder_seq=encoder_seq,
priority=priority) priority=priority)
return seq_group return seq_group
...@@ -1834,16 +1812,6 @@ class LLMEngine: ...@@ -1834,16 +1812,6 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id) return self.model_executor.pin_lora(lora_id)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.model_executor.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> List[int]:
return self.model_executor.list_prompt_adapters()
def start_profile(self) -> None: def start_profile(self) -> None:
self.model_executor.start_profile() self.model_executor.start_profile()
......
...@@ -10,7 +10,6 @@ from vllm import PoolingParams ...@@ -10,7 +10,6 @@ from vllm import PoolingParams
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import Device from vllm.utils import Device
...@@ -33,7 +32,6 @@ class RPCProcessRequest: ...@@ -33,7 +32,6 @@ class RPCProcessRequest:
request_id: str request_id: str
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0 priority: int = 0
def __init__( def __init__(
...@@ -43,7 +41,6 @@ class RPCProcessRequest: ...@@ -43,7 +41,6 @@ class RPCProcessRequest:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -53,7 +50,6 @@ class RPCProcessRequest: ...@@ -53,7 +50,6 @@ class RPCProcessRequest:
self.request_id = request_id self.request_id = request_id
self.lora_request = lora_request self.lora_request = lora_request
self.trace_headers = trace_headers self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority self.priority = priority
......
...@@ -45,7 +45,6 @@ from vllm.logger import init_logger ...@@ -45,7 +45,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device from vllm.utils import Device
...@@ -448,7 +447,6 @@ class MQLLMEngineClient(EngineClient): ...@@ -448,7 +447,6 @@ class MQLLMEngineClient(EngineClient):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request. """Generate outputs for a request.
...@@ -465,8 +463,6 @@ class MQLLMEngineClient(EngineClient): ...@@ -465,8 +463,6 @@ class MQLLMEngineClient(EngineClient):
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling). priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the Any priority other than 0 will lead to an error if the
scheduling policy is not "priority". scheduling policy is not "priority".
...@@ -474,8 +470,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -474,8 +470,7 @@ class MQLLMEngineClient(EngineClient):
return cast( return cast(
AsyncGenerator[RequestOutput, None], AsyncGenerator[RequestOutput, None],
self._process_request(prompt, sampling_params, request_id, self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, lora_request, trace_headers, priority))
prompt_adapter_request, priority))
def encode( def encode(
self, self,
...@@ -521,7 +516,6 @@ class MQLLMEngineClient(EngineClient): ...@@ -521,7 +516,6 @@ class MQLLMEngineClient(EngineClient):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]: PoolingRequestOutput, None]]:
...@@ -575,7 +569,6 @@ class MQLLMEngineClient(EngineClient): ...@@ -575,7 +569,6 @@ class MQLLMEngineClient(EngineClient):
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
)) ))
......
...@@ -304,13 +304,11 @@ class MQLLMEngine: ...@@ -304,13 +304,11 @@ class MQLLMEngine:
self._send_outputs(rpc_err) self._send_outputs(rpc_err)
try: try:
self.engine.add_request( self.engine.add_request(request_id=request_id,
request_id=request_id,
prompt=request.prompt, prompt=request.prompt,
params=request.params, params=request.params,
lora_request=request.lora_request, lora_request=request.lora_request,
trace_headers=request.trace_headers, trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority) priority=request.priority)
if self.log_requests: if self.log_requests:
......
...@@ -16,7 +16,6 @@ from vllm.lora.request import LoRARequest ...@@ -16,7 +16,6 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid from vllm.utils import Device, collect_from_async_generator, random_uuid
...@@ -55,7 +54,6 @@ class EngineClient(ABC): ...@@ -55,7 +54,6 @@ class EngineClient(ABC):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.""" """Generate outputs for a request."""
......
...@@ -45,7 +45,6 @@ from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, ...@@ -45,7 +45,6 @@ from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput, PoolingRequestOutput, RequestOutput,
ScoringRequestOutput) ScoringRequestOutput)
from vllm.pooling_params import PoolingParams, PoolingTask from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams) RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
...@@ -314,7 +313,6 @@ class LLM: ...@@ -314,7 +313,6 @@ class LLM:
*, *,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
...@@ -330,7 +328,6 @@ class LLM: ...@@ -330,7 +328,6 @@ class LLM:
prompt_token_ids: Optional[list[int]] = None, prompt_token_ids: Optional[list[int]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
...@@ -346,7 +343,6 @@ class LLM: ...@@ -346,7 +343,6 @@ class LLM:
prompt_token_ids: Optional[list[list[int]]] = None, prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
...@@ -363,7 +359,6 @@ class LLM: ...@@ -363,7 +359,6 @@ class LLM:
prompt_token_ids: list[int], prompt_token_ids: list[int],
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
...@@ -380,7 +375,6 @@ class LLM: ...@@ -380,7 +375,6 @@ class LLM:
prompt_token_ids: list[list[int]], prompt_token_ids: list[list[int]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
...@@ -395,7 +389,6 @@ class LLM: ...@@ -395,7 +389,6 @@ class LLM:
prompt_token_ids: Union[list[int], list[list[int]]], prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
...@@ -415,7 +408,6 @@ class LLM: ...@@ -415,7 +408,6 @@ class LLM:
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
priority: Optional[list[int]] = None, priority: Optional[list[int]] = None,
...@@ -440,8 +432,6 @@ class LLM: ...@@ -440,8 +432,6 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
priority: The priority of the requests, if any. priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled. Only applicable when priority scheduling policy is enabled.
...@@ -507,7 +497,6 @@ class LLM: ...@@ -507,7 +497,6 @@ class LLM:
params=sampling_params, params=sampling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request, guided_options=guided_options_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=priority, priority=priority,
...@@ -963,7 +952,6 @@ class LLM: ...@@ -963,7 +952,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
...@@ -980,7 +968,6 @@ class LLM: ...@@ -980,7 +968,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
...@@ -997,7 +984,6 @@ class LLM: ...@@ -997,7 +984,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
...@@ -1015,7 +1001,6 @@ class LLM: ...@@ -1015,7 +1001,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
...@@ -1033,7 +1018,6 @@ class LLM: ...@@ -1033,7 +1018,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
...@@ -1049,7 +1033,6 @@ class LLM: ...@@ -1049,7 +1033,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
...@@ -1070,7 +1053,6 @@ class LLM: ...@@ -1070,7 +1053,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
...@@ -1092,8 +1074,6 @@ class LLM: ...@@ -1092,8 +1074,6 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
pooling_task: Override the pooling task to use. pooling_task: Override the pooling task to use.
Returns: Returns:
...@@ -1150,7 +1130,6 @@ class LLM: ...@@ -1150,7 +1130,6 @@ class LLM:
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
) )
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
...@@ -1167,7 +1146,6 @@ class LLM: ...@@ -1167,7 +1146,6 @@ class LLM:
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[EmbeddingRequestOutput]: ) -> list[EmbeddingRequestOutput]:
""" """
Generate an embedding vector for each prompt. Generate an embedding vector for each prompt.
...@@ -1187,8 +1165,6 @@ class LLM: ...@@ -1187,8 +1165,6 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns: Returns:
A list of `EmbeddingRequestOutput` objects containing the A list of `EmbeddingRequestOutput` objects containing the
...@@ -1205,7 +1181,6 @@ class LLM: ...@@ -1205,7 +1181,6 @@ class LLM:
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
pooling_params=pooling_params, pooling_params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="embed", pooling_task="embed",
) )
...@@ -1218,7 +1193,6 @@ class LLM: ...@@ -1218,7 +1193,6 @@ class LLM:
*, *,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ClassificationRequestOutput]: ) -> list[ClassificationRequestOutput]:
""" """
Generate class logits for each prompt. Generate class logits for each prompt.
...@@ -1236,8 +1210,6 @@ class LLM: ...@@ -1236,8 +1210,6 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns: Returns:
A list of `ClassificationRequestOutput` objects containing the A list of `ClassificationRequestOutput` objects containing the
...@@ -1253,7 +1225,6 @@ class LLM: ...@@ -1253,7 +1225,6 @@ class LLM:
prompts, prompts,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="classify", pooling_task="classify",
) )
...@@ -1267,7 +1238,6 @@ class LLM: ...@@ -1267,7 +1238,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
encoded_output: list[PoolingRequestOutput] = self.encode( encoded_output: list[PoolingRequestOutput] = self.encode(
...@@ -1275,7 +1245,6 @@ class LLM: ...@@ -1275,7 +1245,6 @@ class LLM:
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="embed", pooling_task="embed",
) )
...@@ -1303,7 +1272,6 @@ class LLM: ...@@ -1303,7 +1272,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
if isinstance(tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
...@@ -1361,7 +1329,6 @@ class LLM: ...@@ -1361,7 +1329,6 @@ class LLM:
params=pooling_params, params=pooling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
) )
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
...@@ -1381,7 +1348,6 @@ class LLM: ...@@ -1381,7 +1348,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs `<text,text_pair>` or """Generate similarity scores for all pairs `<text,text_pair>` or
`<multi-modal data, multi-modal data pair>`. `<multi-modal data, multi-modal data pair>`.
...@@ -1412,8 +1378,6 @@ class LLM: ...@@ -1412,8 +1378,6 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns: Returns:
A list of `ScoringRequestOutput` objects containing the A list of `ScoringRequestOutput` objects containing the
...@@ -1504,8 +1468,7 @@ class LLM: ...@@ -1504,8 +1468,7 @@ class LLM:
data_2, # type: ignore[arg-type] data_2, # type: ignore[arg-type]
truncate_prompt_tokens, truncate_prompt_tokens,
use_tqdm, use_tqdm,
lora_request, lora_request)
prompt_adapter_request)
else: else:
return self._embedding_score( return self._embedding_score(
tokenizer, tokenizer,
...@@ -1513,8 +1476,7 @@ class LLM: ...@@ -1513,8 +1476,7 @@ class LLM:
data_2, # type: ignore[arg-type] data_2, # type: ignore[arg-type]
truncate_prompt_tokens, truncate_prompt_tokens,
use_tqdm, use_tqdm,
lora_request, lora_request)
prompt_adapter_request)
def start_profile(self) -> None: def start_profile(self) -> None:
self.llm_engine.start_profile() self.llm_engine.start_profile()
...@@ -1625,7 +1587,6 @@ class LLM: ...@@ -1625,7 +1587,6 @@ class LLM:
*, *,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
guided_options: Optional[GuidedDecodingRequest] = None, guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[list[int]] = None, priority: Optional[list[int]] = None,
...@@ -1671,7 +1632,6 @@ class LLM: ...@@ -1671,7 +1632,6 @@ class LLM:
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance( lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request, lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority[i] if priority else 0, priority=priority[i] if priority else 0,
) )
...@@ -1681,7 +1641,6 @@ class LLM: ...@@ -1681,7 +1641,6 @@ class LLM:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
...@@ -1691,7 +1650,6 @@ class LLM: ...@@ -1691,7 +1650,6 @@ class LLM:
params, params,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment