Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82ec66f5
Unverified
Commit
82ec66f5
authored
Jul 23, 2025
by
Michael Goin
Committed by
GitHub
Jul 23, 2025
Browse files
[V0 Deprecation] Remove Prompt Adapters (#20588)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
78c13e30
Changes
60
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
32 additions
and
695 deletions
+32
-695
vllm/entrypoints/logger.py
vllm/entrypoints/logger.py
+2
-5
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+0
-1
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+1
-35
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+0
-1
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+3
-8
vllm/entrypoints/openai/serving_classification.py
vllm/entrypoints/openai/serving_classification.py
+1
-9
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-6
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+1
-8
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+8
-23
vllm/entrypoints/openai/serving_models.py
vllm/entrypoints/openai/serving_models.py
+0
-31
vllm/entrypoints/openai/serving_pooling.py
vllm/entrypoints/openai/serving_pooling.py
+2
-10
vllm/entrypoints/openai/serving_responses.py
vllm/entrypoints/openai/serving_responses.py
+2
-7
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+3
-19
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+4
-17
vllm/entrypoints/openai/speech_to_text.py
vllm/entrypoints/openai/speech_to_text.py
+2
-10
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+0
-31
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+2
-33
vllm/prompt_adapter/__init__.py
vllm/prompt_adapter/__init__.py
+0
-0
vllm/prompt_adapter/layers.py
vllm/prompt_adapter/layers.py
+0
-83
vllm/prompt_adapter/models.py
vllm/prompt_adapter/models.py
+0
-358
No files found.
vllm/entrypoints/logger.py
View file @
82ec66f5
...
...
@@ -8,7 +8,6 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
logger
=
init_logger
(
__name__
)
...
...
@@ -30,7 +29,6 @@ class RequestLogger:
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
,
BeamSearchParams
]],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
None
:
max_log_len
=
self
.
max_log_len
if
max_log_len
is
not
None
:
...
...
@@ -44,7 +42,6 @@ class RequestLogger:
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"prompt_embeds shape: %s, "
"lora_request: %s, prompt_adapter_request: %s."
,
request_id
,
prompt
,
params
,
prompt_token_ids
,
"lora_request: %s."
,
request_id
,
prompt
,
params
,
prompt_token_ids
,
prompt_embeds
.
shape
if
prompt_embeds
is
not
None
else
None
,
lora_request
,
prompt_adapter_request
)
lora_request
)
vllm/entrypoints/openai/api_server.py
View file @
82ec66f5
...
...
@@ -1620,7 +1620,6 @@ async def init_app_state(
model_config
=
model_config
,
base_model_paths
=
base_model_paths
,
lora_modules
=
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
)
await
state
.
openai_serving_models
.
init_static_loras
()
state
.
openai_serving_responses
=
OpenAIServingResponses
(
...
...
vllm/entrypoints/openai/cli_args.py
View file @
82ec66f5
...
...
@@ -20,8 +20,7 @@ from vllm.config import config
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
optional_type
from
vllm.entrypoints.chat_utils
import
(
ChatTemplateContentFormatOption
,
validate_chat_template
)
from
vllm.entrypoints.openai.serving_models
import
(
LoRAModulePath
,
PromptAdapterPath
)
from
vllm.entrypoints.openai.serving_models
import
LoRAModulePath
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.logger
import
init_logger
from
vllm.utils
import
FlexibleArgumentParser
...
...
@@ -65,27 +64,6 @@ class LoRAParserAction(argparse.Action):
setattr
(
namespace
,
self
.
dest
,
lora_list
)
class
PromptAdapterParserAction
(
argparse
.
Action
):
def
__call__
(
self
,
parser
:
argparse
.
ArgumentParser
,
namespace
:
argparse
.
Namespace
,
values
:
Optional
[
Union
[
str
,
Sequence
[
str
]]],
option_string
:
Optional
[
str
]
=
None
,
):
if
values
is
None
:
values
=
[]
if
isinstance
(
values
,
str
):
raise
TypeError
(
"Expected values to be a list"
)
adapter_list
:
list
[
PromptAdapterPath
]
=
[]
for
item
in
values
:
name
,
path
=
item
.
split
(
'='
)
adapter_list
.
append
(
PromptAdapterPath
(
name
,
path
))
setattr
(
namespace
,
self
.
dest
,
adapter_list
)
@
config
@
dataclass
class
FrontendArgs
:
...
...
@@ -115,9 +93,6 @@ class FrontendArgs:
or JSON list format. Example (old format): `'name=path'` Example (new
format): `{
\"
name
\"
:
\"
name
\"
,
\"
path
\"
:
\"
lora_path
\"
,
\"
base_model_name
\"
:
\"
id
\"
}`"""
prompt_adapters
:
Optional
[
list
[
PromptAdapterPath
]]
=
None
"""Prompt adapter configurations in the format name=path. Multiple adapters
can be specified."""
chat_template
:
Optional
[
str
]
=
None
"""The file path to the chat template, or the template in single-line form
for the specified model."""
...
...
@@ -207,12 +182,6 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
frontend_kwargs
[
"lora_modules"
][
"type"
]
=
optional_type
(
str
)
frontend_kwargs
[
"lora_modules"
][
"action"
]
=
LoRAParserAction
# Special case: Prompt adapters need custom parser action and
# optional_type(str)
frontend_kwargs
[
"prompt_adapters"
][
"type"
]
=
optional_type
(
str
)
frontend_kwargs
[
"prompt_adapters"
][
"action"
]
=
PromptAdapterParserAction
# Special case: Middleware needs append action
frontend_kwargs
[
"middleware"
][
"action"
]
=
"append"
frontend_kwargs
[
"middleware"
][
"type"
]
=
str
...
...
@@ -288,9 +257,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
if
args
.
enable_auto_tool_choice
and
not
args
.
tool_call_parser
:
raise
TypeError
(
"Error: --enable-auto-tool-choice requires "
"--tool-call-parser"
)
if
args
.
enable_prompt_embeds
and
args
.
enable_prompt_adapter
:
raise
ValueError
(
"Cannot use prompt embeds and prompt adapter at the same time."
)
def
log_non_default_args
(
args
:
argparse
.
Namespace
):
...
...
vllm/entrypoints/openai/run_batch.py
View file @
82ec66f5
...
...
@@ -337,7 +337,6 @@ async def main(args):
model_config
=
model_config
,
base_model_paths
=
base_model_paths
,
lora_modules
=
None
,
prompt_adapters
=
None
,
)
openai_serving_chat
=
OpenAIServingChat
(
engine
,
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
82ec66f5
...
...
@@ -147,11 +147,8 @@ class OpenAIServingChat(OpenAIServing):
raise
self
.
engine_client
.
dead_error
try
:
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
,
supports_default_mm_loras
=
True
)
lora_request
=
self
.
_maybe_get_adapters
(
request
,
supports_default_mm_loras
=
True
)
model_name
=
self
.
_get_model_name
(
request
.
model
,
lora_request
)
...
...
@@ -239,8 +236,7 @@ class OpenAIServingChat(OpenAIServing):
self
.
_log_inputs
(
request_id
,
request_prompts
[
i
],
params
=
sampling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
lora_request
=
lora_request
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
))
...
...
@@ -259,7 +255,6 @@ class OpenAIServingChat(OpenAIServing):
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
request
.
priority
,
)
...
...
vllm/entrypoints/openai/serving_classification.py
View file @
82ec66f5
...
...
@@ -49,19 +49,11 @@ class ClassificationMixin(OpenAIServing):
return
None
try
:
(
ctx
.
lora_request
,
ctx
.
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
ctx
.
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
ctx
.
lora_request
)
if
ctx
.
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported for classification models"
)
(
ctx
.
request_prompts
,
ctx
.
engine_prompts
,
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
82ec66f5
...
...
@@ -121,10 +121,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request
.
state
.
request_metadata
=
request_metadata
try
:
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
...
...
@@ -197,7 +194,6 @@ class OpenAIServingCompletion(OpenAIServing):
request_prompts
[
i
],
params
=
sampling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
...
...
@@ -221,7 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params
,
request_id_item
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
82ec66f5
...
...
@@ -53,18 +53,11 @@ class EmbeddingMixin(OpenAIServing):
)
->
Optional
[
ErrorResponse
]:
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
try
:
(
ctx
.
lora_request
,
ctx
.
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
ctx
.
lora_request
)
if
ctx
.
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for embedding models"
)
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
(
_
,
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
82ec66f5
...
...
@@ -68,7 +68,6 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
MultiModalDataDict
)
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
Logprob
,
PromptLogprobs
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
...
...
@@ -161,7 +160,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
request_id
:
str
created_time
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
lora_request
:
Optional
[
LoRARequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
# Shared across most requests
tokenizer
:
Optional
[
AnyTokenizer
]
=
None
...
...
@@ -343,12 +341,10 @@ class OpenAIServing:
return
self
.
create_error_response
(
"Request prompts not available"
)
self
.
_log_inputs
(
request_id_item
,
ctx
.
request_prompts
[
i
],
params
=
pooling_params
,
lora_request
=
ctx
.
lora_request
,
prompt_adapter_request
=
ctx
.
prompt_adapter_request
)
self
.
_log_inputs
(
request_id_item
,
ctx
.
request_prompts
[
i
],
params
=
pooling_params
,
lora_request
=
ctx
.
lora_request
)
# Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`:
...
...
@@ -450,11 +446,6 @@ class OpenAIServing:
if
isinstance
(
load_result
,
ErrorResponse
)
and
\
load_result
.
code
==
HTTPStatus
.
BAD_REQUEST
.
value
:
error_response
=
load_result
if
request
.
model
in
[
prompt_adapter
.
prompt_adapter_name
for
prompt_adapter
in
self
.
models
.
prompt_adapter_requests
]:
return
None
return
error_response
or
self
.
create_error_response
(
message
=
f
"The model `
{
request
.
model
}
` does not exist."
,
...
...
@@ -489,25 +480,21 @@ class OpenAIServing:
self
,
request
:
AnyRequest
,
supports_default_mm_loras
:
bool
=
False
,
)
->
Union
[
tuple
[
None
,
None
],
tuple
[
LoRARequest
,
None
],
tuple
[
None
,
PromptAdapterRequest
]]:
)
->
Optional
[
LoRARequest
]:
if
request
.
model
in
self
.
models
.
lora_requests
:
return
self
.
models
.
lora_requests
[
request
.
model
]
,
None
return
self
.
models
.
lora_requests
[
request
.
model
]
# Currently only support default modality specific loras
# if we have exactly one lora matched on the request.
if
supports_default_mm_loras
:
default_mm_lora
=
self
.
_get_active_default_mm_loras
(
request
)
if
default_mm_lora
is
not
None
:
return
default_mm_lora
,
None
return
default_mm_lora
if
self
.
_is_model_supported
(
request
.
model
):
return
None
,
None
return
None
for
prompt_adapter
in
self
.
models
.
prompt_adapter_requests
:
if
request
.
model
==
prompt_adapter
.
prompt_adapter_name
:
return
None
,
prompt_adapter
# if _check_model has been called earlier, this will be unreachable
raise
ValueError
(
f
"The model `
{
request
.
model
}
` does not exist."
)
...
...
@@ -987,7 +974,6 @@ class OpenAIServing:
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
,
BeamSearchParams
]],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
None
:
if
self
.
request_logger
is
None
:
return
...
...
@@ -1009,7 +995,6 @@ class OpenAIServing:
prompt_embeds
,
params
=
params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
async
def
_get_trace_headers
(
...
...
vllm/entrypoints/openai/serving_models.py
View file @
82ec66f5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
pathlib
from
asyncio
import
Lock
from
collections
import
defaultdict
from
dataclasses
import
dataclass
...
...
@@ -19,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.resolver
import
LoRAResolver
,
LoRAResolverRegistry
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.utils
import
AtomicCounter
logger
=
init_logger
(
__name__
)
...
...
@@ -31,12 +28,6 @@ class BaseModelPath:
model_path
:
str
@
dataclass
class
PromptAdapterPath
:
name
:
str
local_path
:
str
@
dataclass
class
LoRAModulePath
:
name
:
str
...
...
@@ -60,7 +51,6 @@ class OpenAIServingModels:
base_model_paths
:
list
[
BaseModelPath
],
*
,
lora_modules
:
Optional
[
list
[
LoRAModulePath
]]
=
None
,
prompt_adapters
:
Optional
[
list
[
PromptAdapterPath
]]
=
None
,
):
super
().
__init__
()
...
...
@@ -81,20 +71,6 @@ class OpenAIServingModels:
LoRAResolverRegistry
.
get_resolver
(
lora_resolver_name
))
self
.
lora_resolver_lock
:
dict
[
str
,
Lock
]
=
defaultdict
(
Lock
)
self
.
prompt_adapter_requests
=
[]
if
prompt_adapters
is
not
None
:
for
i
,
prompt_adapter
in
enumerate
(
prompt_adapters
,
start
=
1
):
with
pathlib
.
Path
(
prompt_adapter
.
local_path
,
"adapter_config.json"
).
open
()
as
f
:
adapter_config
=
json
.
load
(
f
)
num_virtual_tokens
=
adapter_config
[
"num_virtual_tokens"
]
self
.
prompt_adapter_requests
.
append
(
PromptAdapterRequest
(
prompt_adapter_name
=
prompt_adapter
.
name
,
prompt_adapter_id
=
i
,
prompt_adapter_local_path
=
prompt_adapter
.
local_path
,
prompt_adapter_num_virtual_tokens
=
num_virtual_tokens
))
async
def
init_static_loras
(
self
):
"""Loads all static LoRA modules.
Raises if any fail to load"""
...
...
@@ -141,14 +117,7 @@ class OpenAIServingModels:
permission
=
[
ModelPermission
()])
for
lora
in
self
.
lora_requests
.
values
()
]
prompt_adapter_cards
=
[
ModelCard
(
id
=
prompt_adapter
.
prompt_adapter_name
,
root
=
self
.
base_model_paths
[
0
].
name
,
permission
=
[
ModelPermission
()])
for
prompt_adapter
in
self
.
prompt_adapter_requests
]
model_cards
.
extend
(
lora_cards
)
model_cards
.
extend
(
prompt_adapter_cards
)
return
ModelList
(
data
=
model_cards
)
async
def
load_lora_adapter
(
...
...
vllm/entrypoints/openai/serving_pooling.py
View file @
82ec66f5
...
...
@@ -94,17 +94,10 @@ class OpenAIServingPooling(OpenAIServing):
try
:
truncate_prompt_tokens
=
_validate_truncation_size
(
self
.
max_model_len
,
truncate_prompt_tokens
)
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
if
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for pooling models"
)
if
isinstance
(
request
,
PoolingChatRequest
):
(
_
,
...
...
@@ -153,8 +146,7 @@ class OpenAIServingPooling(OpenAIServing):
self
.
_log_inputs
(
request_id_item
,
request_prompts
[
i
],
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
lora_request
=
lora_request
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
))
...
...
vllm/entrypoints/openai/serving_responses.py
View file @
82ec66f5
...
...
@@ -133,10 +133,7 @@ class OpenAIServingResponses(OpenAIServing):
messages
=
self
.
_construct_input_messages
(
request
,
prev_response
)
try
:
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
model_name
=
self
.
_get_model_name
(
request
.
model
,
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
...
...
@@ -169,8 +166,7 @@ class OpenAIServingResponses(OpenAIServing):
self
.
_log_inputs
(
request
.
request_id
,
request_prompts
[
i
],
params
=
sampling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
lora_request
=
lora_request
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
))
...
...
@@ -181,7 +177,6 @@ class OpenAIServingResponses(OpenAIServing):
request
.
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
request
.
priority
,
)
generators
.
append
(
generator
)
...
...
vllm/entrypoints/openai/serving_score.py
View file @
82ec66f5
...
...
@@ -27,7 +27,6 @@ from vllm.inputs.data import TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
make_async
,
merge_async_iterators
...
...
@@ -58,8 +57,6 @@ class ServingScores(OpenAIServing):
request_id
:
str
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
Union
[
LoRARequest
,
None
]]
=
None
,
prompt_adapter_request
:
Optional
[
Union
[
PromptAdapterRequest
,
None
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
input_texts
=
texts_1
+
texts_2
...
...
@@ -100,8 +97,7 @@ class ServingScores(OpenAIServing):
self
.
_log_inputs
(
request_id_item
,
input_texts
[
i
],
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
lora_request
=
lora_request
)
generators
.
append
(
self
.
engine_client
.
encode
(
...
...
@@ -176,8 +172,6 @@ class ServingScores(OpenAIServing):
request_id
:
str
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
Union
[
LoRARequest
,
None
]]
=
None
,
prompt_adapter_request
:
Optional
[
Union
[
PromptAdapterRequest
,
None
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
request_prompts
:
list
[
str
]
=
[]
...
...
@@ -261,8 +255,7 @@ class ServingScores(OpenAIServing):
self
.
_log_inputs
(
request_id_item
,
request_prompts
[
i
],
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
lora_request
=
lora_request
)
generator
=
self
.
engine_client
.
encode
(
engine_prompt
,
...
...
@@ -295,14 +288,7 @@ class ServingScores(OpenAIServing):
raw_request
:
Optional
[
Request
]
=
None
,
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
if
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for scoring models"
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
...
...
@@ -340,7 +326,6 @@ class ServingScores(OpenAIServing):
request_id
=
request_id
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
)
else
:
...
...
@@ -352,7 +337,6 @@ class ServingScores(OpenAIServing):
request_id
=
request_id
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
)
async
def
create_score
(
...
...
vllm/entrypoints/openai/serving_tokenization.py
View file @
82ec66f5
...
...
@@ -60,10 +60,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_id
=
f
"tokn-
{
self
.
_base_request_id
(
raw_request
)
}
"
try
:
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
...
...
@@ -104,11 +101,8 @@ class OpenAIServingTokenization(OpenAIServing):
self
.
_log_inputs
(
request_id
,
request_prompts
[
i
],
params
=
None
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
lora_request
=
lora_request
)
# Silently ignore prompt adapter since it does not affect
# tokenization (Unlike in Embeddings API where an error is raised)
if
isinstance
(
engine_prompt
,
dict
)
and
"prompt_token_ids"
in
engine_prompt
:
input_ids
.
extend
(
engine_prompt
[
"prompt_token_ids"
])
...
...
@@ -133,21 +127,14 @@ class OpenAIServingTokenization(OpenAIServing):
request_id
=
f
"tokn-
{
self
.
_base_request_id
(
raw_request
)
}
"
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
self
.
_log_inputs
(
request_id
,
request
.
tokens
,
params
=
None
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
# Silently ignore prompt adapter since it does not affect tokenization
# (Unlike in Embeddings API where an error is raised)
lora_request
=
lora_request
)
prompt_input
=
await
self
.
_tokenize_prompt_input_async
(
request
,
...
...
vllm/entrypoints/openai/speech_to_text.py
View file @
82ec66f5
...
...
@@ -150,19 +150,12 @@ class OpenAISpeechToText(OpenAIServing):
raw_request
.
state
.
request_metadata
=
request_metadata
try
:
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
if
lora_request
:
return
self
.
create_error_response
(
"Currently do not support LoRA for "
f
"
{
self
.
task_type
.
title
()
}
."
)
if
prompt_adapter_request
:
return
self
.
create_error_response
(
f
"Currently do not support PromptAdapter for "
f
"
{
self
.
task_type
.
title
()
}
."
)
prompts
,
duration_s
=
await
self
.
_preprocess_speech_to_text
(
request
=
request
,
...
...
@@ -188,8 +181,7 @@ class OpenAISpeechToText(OpenAIServing):
# It will not display special tokens like <|startoftranscript|>
request
.
prompt
,
params
=
sampling_params
,
lora_request
=
None
,
prompt_adapter_request
=
None
)
lora_request
=
None
)
list_result_generator
=
[
self
.
engine_client
.
generate
(
...
...
vllm/executor/executor_base.py
View file @
82ec66f5
...
...
@@ -17,7 +17,6 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.pooling_params
import
PoolingTask
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.utils
import
make_async
from
vllm.worker.worker_base
import
WorkerBase
...
...
@@ -50,7 +49,6 @@ class ExecutorBase(ABC):
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
device_config
=
vllm_config
.
device_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
_init_executor
()
self
.
is_sleeping
=
False
...
...
@@ -171,35 +169,6 @@ class ExecutorBase(ABC):
assert
s
==
sets
[
0
],
"All workers should have the same LORAs."
return
sets
[
0
]
def
add_prompt_adapter
(
self
,
prompt_adapter_request
:
PromptAdapterRequest
)
->
bool
:
assert
prompt_adapter_request
.
prompt_adapter_id
>
0
,
\
"prompt_adapter_id must be greater than 0."
return
all
(
self
.
collective_rpc
(
"add_prompt_adapter"
,
args
=
(
prompt_adapter_request
,
)))
def
remove_prompt_adapter
(
self
,
prompt_adapter_id
:
int
)
->
bool
:
assert
prompt_adapter_id
>
0
,
\
"prompt_adapter_id must be greater than 0."
return
all
(
self
.
collective_rpc
(
"remove_prompt_adapter"
,
args
=
(
prompt_adapter_id
,
)))
def
pin_prompt_adapter
(
self
,
prompt_adapter_id
:
int
)
->
bool
:
assert
prompt_adapter_id
>
0
,
\
"prompt_adapter_id must be greater than 0."
return
all
(
self
.
collective_rpc
(
"pin_prompt_adapter"
,
args
=
(
prompt_adapter_id
,
)))
def
list_prompt_adapters
(
self
)
->
Set
[
int
]:
sets
=
self
.
collective_rpc
(
"list_prompt_adapters"
)
for
s
in
sets
:
assert
(
s
==
sets
[
0
]
),
"All workers should have the same prompt adapters."
return
sets
[
0
]
def
start_profile
(
self
)
->
None
:
self
.
collective_rpc
(
"start_profile"
)
...
...
vllm/inputs/preprocess.py
View file @
82ec66f5
...
...
@@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
)
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
...
...
@@ -168,18 +167,6 @@ class InputPreprocessor:
return
decoder_input_ids
def
_apply_prompt_adapter
(
self
,
prompt_token_ids
:
list
[
int
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
list
[
int
]:
if
prompt_adapter_request
:
prompt_token_ids
=
(
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
+
prompt_token_ids
)
return
prompt_token_ids
def
_get_tokenization_kw
(
self
,
overrides
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
...
...
@@ -786,15 +773,10 @@ class InputPreprocessor:
def
_build_decoder_only_llm_inputs
(
self
,
prompt_inputs
:
DecoderOnlyInputs
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
DecoderOnlyInputs
:
if
"prompt_token_ids"
in
prompt_inputs
:
prompt_inputs
=
cast
(
Union
[
TokenInputs
,
MultiModalInputs
],
prompt_inputs
)
# Needed for mypy
prompt_inputs
[
"prompt_token_ids"
]
=
self
.
_apply_prompt_adapter
(
prompt_inputs
[
"prompt_token_ids"
],
prompt_adapter_request
=
prompt_adapter_request
,
)
return
prompt_inputs
...
...
@@ -803,7 +785,6 @@ class InputPreprocessor:
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
DecoderOnlyInputs
:
"""
...
...
@@ -815,7 +796,6 @@ class InputPreprocessor:
* prompt: input prompt
* lora_request
* prompt_adapter_request
* return_mm_hashes
Returns:
...
...
@@ -830,17 +810,13 @@ class InputPreprocessor:
return_mm_hashes
=
return_mm_hashes
,
)
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
,
prompt_adapter_request
=
prompt_adapter_request
,
)
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
)
async
def
_process_decoder_only_prompt_async
(
self
,
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
DecoderOnlyInputs
:
"""
...
...
@@ -854,17 +830,13 @@ class InputPreprocessor:
return_mm_hashes
=
return_mm_hashes
,
)
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
,
prompt_adapter_request
=
prompt_adapter_request
,
)
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
)
def
preprocess
(
self
,
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
ProcessorInputs
:
"""Preprocess the input prompt."""
...
...
@@ -886,7 +858,6 @@ class InputPreprocessor:
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
return_mm_hashes
=
return_mm_hashes
,
)
...
...
@@ -895,7 +866,6 @@ class InputPreprocessor:
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
ProcessorInputs
:
"""
...
...
@@ -919,6 +889,5 @@ class InputPreprocessor:
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
return_mm_hashes
=
return_mm_hashes
,
)
vllm/prompt_adapter/__init__.py
deleted
100644 → 0
View file @
78c13e30
vllm/prompt_adapter/layers.py
deleted
100644 → 0
View file @
78c13e30
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Optional
import
torch
from
torch
import
nn
from
vllm.adapter_commons.layers
import
AdapterMapping
from
vllm.config
import
PromptAdapterConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
@
dataclass
class
PromptAdapterMapping
(
AdapterMapping
):
pass
class
VocabParallelEmbeddingWithPromptAdapter
(
nn
.
Module
):
def
__init__
(
self
,
base_layer
:
VocabParallelEmbedding
)
->
None
:
super
().
__init__
()
self
.
base_layer
=
base_layer
self
.
emb_layer
=
self
.
base_layer
if
'LoRA'
in
base_layer
.
__class__
.
__name__
:
self
.
emb_layer
=
self
.
base_layer
.
base_layer
def
create_prompt_adapter_weights
(
self
,
prompt_adapter_config
:
PromptAdapterConfig
):
self
.
embeddings_tensors
=
torch
.
zeros
(
(
prompt_adapter_config
.
max_prompt_adapters
,
prompt_adapter_config
.
max_prompt_adapter_token
,
self
.
emb_layer
.
embedding_dim
,
),
dtype
=
self
.
emb_layer
.
weight
.
dtype
,
device
=
self
.
emb_layer
.
weight
.
device
,
)
self
.
adapter_lengths
=
torch
.
zeros
(
prompt_adapter_config
.
max_prompt_adapters
,
dtype
=
torch
.
long
,
device
=
self
.
emb_layer
.
weight
.
device
)
self
.
indices_gpu
:
torch
.
Tensor
self
.
embedding_indices_gpu
:
torch
.
Tensor
def
reset_prompt_adapter
(
self
,
index
:
int
):
self
.
embeddings_tensors
[
index
]
=
0
def
set_prompt_adapter
(
self
,
index
:
int
,
adapter_model
:
Optional
[
torch
.
Tensor
],
):
self
.
reset_prompt_adapter
(
index
)
if
adapter_model
is
not
None
:
length
=
adapter_model
.
shape
[
0
]
self
.
embeddings_tensors
[
index
,
:
length
]
=
adapter_model
self
.
adapter_lengths
[
index
]
=
length
def
set_mapping
(
self
,
prompt_indices
:
torch
.
Tensor
,
prompt_embedding_indices
:
torch
.
Tensor
,
):
self
.
indices_gpu
=
prompt_indices
.
to
(
device
=
self
.
emb_layer
.
weight
.
device
)
self
.
embedding_indices_gpu
=
prompt_embedding_indices
.
to
(
device
=
self
.
emb_layer
.
weight
.
device
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
base_layer
(
x
)
if
self
.
embedding_indices_gpu
.
ndim
>
1
:
valid_mask
=
self
.
indices_gpu
!=
-
1
gathered_embeddings
=
self
.
embeddings_tensors
[
self
.
embedding_indices_gpu
[:,
0
],
self
.
embedding_indices_gpu
[:,
1
]]
# Update hidden states
hidden_states
[
valid_mask
]
=
gathered_embeddings
return
hidden_states
\ No newline at end of file
vllm/prompt_adapter/models.py
deleted
100644 → 0
View file @
78c13e30
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
logging
import
math
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Type
import
torch
from
torch
import
nn
from
vllm.adapter_commons.models
import
(
AdapterLRUCache
,
AdapterModel
,
AdapterModelManager
)
from
vllm.adapter_commons.utils
import
(
add_adapter
,
deactivate_adapter
,
get_adapter
,
list_adapters
,
remove_adapter
,
set_adapter_mapping
)
from
vllm.config
import
PromptAdapterConfig
from
vllm.prompt_adapter.layers
import
(
VocabParallelEmbeddingWithPromptAdapter
)
# yapf: disable
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.utils
import
load_peft_weights
logger
=
logging
.
getLogger
(
__name__
)
_GLOBAL_PROMPT_ADAPTER_ID
=
0
def
get_prompt_adapter_id
():
global
_GLOBAL_PROMPT_ADAPTER_ID
_GLOBAL_PROMPT_ADAPTER_ID
+=
1
return
_GLOBAL_PROMPT_ADAPTER_ID
def
convert_to_embedding_indices
(
indices
):
embedding_indices
=
[]
count
=
0
for
value
in
indices
:
if
value
==
-
1
:
count
=
0
else
:
embedding_indices
.
append
([
value
,
count
])
count
+=
1
return
torch
.
tensor
(
embedding_indices
)
def
convert_mapping
(
mapping
:
PromptAdapterMapping
,
prompt_adapter_index_to_id
:
List
[
Optional
[
int
]],
)
->
torch
.
Tensor
:
"""Converts PromptAdapterMapping to index tensors.
Args:
mapping: PromptAdapterMapping mapping rows in a
batch to PromptAdapter ids.
prompt_adapter_index_to_id: List mapping PromptAdapter
ids to PromptAdapter indices.
Returns:
pa_indices: Tensor of shape [batch_size] mapping batch rows to
PromptAdapter indices.
"""
id_to_index
=
{
id_
:
idx
for
idx
,
id_
in
enumerate
(
prompt_adapter_index_to_id
)
if
id_
is
not
None
}
pa_indices
=
([
id_to_index
.
get
(
id_
,
-
1
)
if
id_
>
0
else
-
1
for
id_
in
mapping
.
index_mapping
])
pa_embedding_mapping
=
convert_to_embedding_indices
(
pa_indices
)
pa_indices
=
torch
.
tensor
(
pa_indices
)
return
pa_indices
,
pa_embedding_mapping
class
PromptAdapterModel
(
AdapterModel
):
def
__init__
(
self
,
prompt_adapter_id
=
None
,
num_virtual_tokens
=
None
,
prompt_embedding
=
None
)
->
None
:
self
.
id
=
prompt_adapter_id
self
.
prompt_embedding
=
prompt_embedding
self
.
num_virtual_tokens
=
num_virtual_tokens
@
classmethod
def
from_local_checkpoint
(
cls
,
adapter_model_path
:
str
,
prompt_adapter_id
:
int
,
num_virtual_tokens
:
int
,
config
:
PromptAdapterConfig
,
device
:
str
=
"cuda"
,
)
->
"PromptAdapterModel"
:
if
num_virtual_tokens
>
config
.
max_prompt_adapter_token
:
raise
ValueError
(
f
'num_virtual_tokens (
{
num_virtual_tokens
}
) should be <= '
f
'max_prompt_adapter_token(
{
config
.
max_prompt_adapter_token
}
)'
)
adapters_weights
=
load_peft_weights
(
adapter_model_path
,
device
)
prompt_embedding
=
adapters_weights
[
"prompt_embeddings"
].
to
(
config
.
prompt_adapter_dtype
)
return
cls
(
prompt_adapter_id
,
num_virtual_tokens
,
prompt_embedding
)
class
PromptAdapterModelManager
(
AdapterModelManager
):
"""A manager that manages multiple Prompt Adapter models."""
def
__init__
(
self
,
model
:
nn
.
Module
,
max_num_seqs
:
int
,
max_num_batched_tokens
:
int
,
prompt_adapter_config
:
PromptAdapterConfig
,
):
"""Create a PromptAdapterModel and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
prompt_adapter_config: the PromptAdapter config,
"""
self
.
model
:
nn
.
Module
=
model
# Dict instead of a Set for compatibility with LRUCache.
self
.
prompt_adapter_index_to_id
:
List
[
Optional
[
int
]]
=
[
None
]
*
self
.
prompt_adapter_slots
self
.
max_num_seqs
=
max_num_seqs
self
.
max_num_batched_tokens
=
math
.
ceil
(
max_num_batched_tokens
/
8
)
*
8
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
model
.
prompt_adapter_manager
=
self
self
.
adapter_type
=
'PromptAdapter'
self
.
base_indices
=
torch
.
tensor
([
-
1
])
self
.
base_embedding_indices
=
torch
.
tensor
([])
self
.
modules
:
Dict
[
str
,
nn
.
Module
]
=
{}
self
.
_create_prompt_adapter_modules
()
self
.
_last_mapping
:
Optional
[
PromptAdapterMapping
]
=
None
@
property
def
prompt_adapter_slots
(
self
)
->
int
:
return
self
.
prompt_adapter_config
.
max_prompt_adapters
@
property
def
adapter_slots
(
self
)
->
int
:
return
self
.
prompt_adapter_slots
@
property
def
capacity
(
self
)
->
int
:
return
self
.
prompt_adapter_config
.
max_cpu_prompt_adapters
def
activate_adapter
(
self
,
prompt_adapter_id
:
int
,
)
->
bool
:
"""Move PromptAdapter into a GPU buffer
to be used in the forward pass."""
if
prompt_adapter_id
in
self
.
_active_adapters
:
return
False
first_free_slot
=
next
(
((
i
,
prompt_adapter_id
)
for
i
,
prompt_adapter_id
in
enumerate
(
self
.
prompt_adapter_index_to_id
)
if
prompt_adapter_id
is
None
),
None
)
if
first_free_slot
is
None
:
raise
ValueError
(
"No free prompt_adapter slots"
)
index
,
_
=
first_free_slot
self
.
_active_adapters
[
prompt_adapter_id
]
=
None
prompt_adapter_model
=
(
self
.
_registered_adapters
[
prompt_adapter_id
])
logger
.
debug
(
"Activating prompt_adapter. int id: %d, slot index: %d"
,
prompt_adapter_model
.
id
,
index
)
self
.
prompt_adapter_index_to_id
[
index
]
=
prompt_adapter_model
.
id
for
_
,
v
in
self
.
modules
.
items
():
v
.
set_prompt_adapter
(
index
,
prompt_adapter_model
.
prompt_embedding
)
return
True
def
_deactivate_adapter
(
self
,
prompt_adapter_id
:
int
):
try
:
index
=
self
.
prompt_adapter_index_to_id
.
index
(
prompt_adapter_id
)
self
.
prompt_adapter_index_to_id
[
index
]
=
None
for
_
,
v
in
self
.
modules
.
items
():
v
.
reset_prompt_adapter
(
index
)
except
ValueError
:
pass
def
_add_adapter
(
self
,
prompt_adapter
:
PromptAdapterModel
):
self
.
_registered_adapters
[
prompt_adapter
.
id
]
=
prompt_adapter
def
_set_adapter_mapping
(
self
,
mapping
:
PromptAdapterMapping
)
->
None
:
base_indices
,
base_embedding_indices
=
convert_mapping
(
mapping
,
self
.
prompt_adapter_index_to_id
)
for
k
,
v
in
self
.
modules
.
items
():
v
.
set_mapping
(
base_indices
,
base_embedding_indices
)
def
_create_prompt_adapter_modules
(
self
):
for
module_name
,
module
in
self
.
model
.
named_modules
(
remove_duplicate
=
False
):
if
"VocabParallel"
in
module
.
__class__
.
__name__
:
new_module
=
VocabParallelEmbeddingWithPromptAdapter
(
module
)
new_module
.
create_prompt_adapter_weights
(
self
.
prompt_adapter_config
)
replaced_module
=
self
.
replace_submodule
(
self
.
model
,
module_name
,
new_module
)
self
.
register_module
(
module
.
__class__
.
__name__
,
replaced_module
)
replaced_module
.
set_mapping
(
self
.
base_indices
,
self
.
base_embedding_indices
)
break
def
replace_submodule
(
self
,
model
:
nn
.
Module
,
module_name
:
str
,
new_module
:
nn
.
Module
)
->
nn
.
Module
:
"""Replace a submodule in a model with a new module."""
parent
=
model
.
get_submodule
(
"."
.
join
(
module_name
.
split
(
"."
)[:
-
1
]))
target_name
=
module_name
.
split
(
"."
)[
-
1
]
setattr
(
parent
,
target_name
,
new_module
)
return
new_module
def
register_module
(
self
,
module_name
:
str
,
module
:
nn
.
Module
):
self
.
modules
[
module_name
]
=
module
def
pin_adapter
(
self
,
prompt_adapter_id
:
int
)
->
bool
:
"""Pin a PromptAdapterModel in the manager cache."""
raise
NotImplementedError
(
"Pinning is not supported in PromptAdapterModelManager. "
"Use LRUCachePromptAdapterModelManager for pinning"
)
# type: ignore
def
remove_all_adapters
(
self
):
"""Remove all PromptAdapterModel from the manager."""
self
.
_registered_adapters
.
clear
()
self
.
prompt_adapter_index_to_id
=
[
None
]
*
self
.
prompt_adapter_slots
self
.
_active_adapters
.
clear
()
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
return
deactivate_adapter
(
adapter_id
,
self
.
_active_adapters
,
self
.
_deactivate_adapter
)
def
add_adapter
(
self
,
adapter
:
PromptAdapterModel
)
->
bool
:
return
add_adapter
(
adapter
,
self
.
_registered_adapters
,
self
.
capacity
,
self
.
_add_adapter
)
def
set_adapter_mapping
(
self
,
mapping
:
PromptAdapterMapping
)
->
None
:
self
.
_last_mapping
=
set_adapter_mapping
(
mapping
,
self
.
_last_mapping
,
self
.
_set_adapter_mapping
)
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
return
remove_adapter
(
adapter_id
,
self
.
_registered_adapters
,
self
.
deactivate_adapter
)
def
list_adapters
(
self
)
->
Dict
[
int
,
Any
]:
return
list_adapters
(
self
.
_registered_adapters
)
def
get_adapter
(
self
,
adapter_id
:
int
)
->
Optional
[
Any
]:
return
get_adapter
(
adapter_id
,
self
.
_registered_adapters
)
class
PromptAdapterLRUCache
(
AdapterLRUCache
[
PromptAdapterModel
]):
def
__init__
(
self
,
capacity
:
int
,
deactivate_prompt_adapter_fn
:
Callable
[[
int
],
bool
]):
super
().
__init__
(
capacity
,
deactivate_prompt_adapter_fn
)
class
LRUCachePromptAdapterModelManager
(
PromptAdapterModelManager
):
"""A model manager that manages multiple prompt_adapters with LRU cache."""
def
__init__
(
self
,
model
:
nn
.
Module
,
max_num_seqs
:
int
,
max_num_batched_tokens
:
int
,
prompt_adapter_config
:
PromptAdapterConfig
,
):
self
.
prompt_adapter_config
=
prompt_adapter_config
super
().
__init__
(
model
,
max_num_seqs
,
max_num_batched_tokens
,
prompt_adapter_config
)
self
.
_registered_adapters
=
PromptAdapterLRUCache
(
self
.
capacity
,
self
.
deactivate_adapter
)
self
.
_active_adapters
=
PromptAdapterLRUCache
(
self
.
prompt_adapter_slots
,
self
.
_deactivate_adapter
)
def
list_adapters
(
self
)
->
Dict
[
int
,
PromptAdapterModel
]:
"""List all registered PromptAdapterModel."""
return
dict
(
self
.
_registered_adapters
.
cache
)
def
add_adapter
(
self
,
prompt_adapter
:
PromptAdapterModel
)
->
bool
:
"""Add a PromptAdapterModel to the manager."""
if
prompt_adapter
.
id
not
in
self
.
_registered_adapters
:
self
.
_add_adapter
(
prompt_adapter
)
was_added
=
True
else
:
# We always touch to update the LRU cache order
self
.
_registered_adapters
.
touch
(
prompt_adapter
.
id
)
was_added
=
False
return
was_added
def
activate_adapter
(
self
,
prompt_adapter_id
:
int
,
)
->
bool
:
if
prompt_adapter_id
not
in
self
.
_active_adapters
and
len
(
self
.
_active_adapters
)
>=
self
.
prompt_adapter_slots
:
self
.
_active_adapters
.
remove_oldest
()
result
=
super
().
activate_adapter
(
prompt_adapter_id
)
# We always touch to update the LRU cache order
self
.
_active_adapters
.
touch
(
prompt_adapter_id
)
return
result
def
remove_oldest_adapter
(
self
)
->
bool
:
if
len
(
self
.
_registered_adapters
)
>
0
:
self
.
_registered_adapters
.
remove_oldest
()
return
True
return
False
def
pin_adapter
(
self
,
prompt_adapter_id
:
int
)
->
bool
:
"""Pin a PromptAdapterModel in the manager cache."""
self
.
_pin_prompt_adapter_in_cpu_cache
(
prompt_adapter_id
)
self
.
_pin_prompt_adapter_in_gpu_cache
(
prompt_adapter_id
)
return
True
def
_pin_prompt_adapter_in_cpu_cache
(
self
,
prompt_adapter_id
:
int
):
try
:
self
.
_registered_adapters
.
pin
(
prompt_adapter_id
)
except
ValueError
as
err
:
raise
ValueError
(
"Pinning failed. "
f
"Prompt Adapter
{
prompt_adapter_id
}
is not registered."
)
from
err
def
_pin_prompt_adapter_in_gpu_cache
(
self
,
prompt_adapter_id
:
int
):
if
prompt_adapter_id
not
in
self
.
_active_adapters
:
# move adapter to gpu if not already active
self
.
activate_adapter
(
prompt_adapter_id
)
self
.
_active_adapters
.
pin
(
prompt_adapter_id
)
def
create_prompt_adapter_manager
(
model
:
nn
.
Module
,
max_num_seqs
:
int
,
max_num_batched_tokens
:
int
,
prompt_adapter_config
:
PromptAdapterConfig
,
prompt_adapter_manager_cls
:
Type
[
PromptAdapterModelManager
]
=
PromptAdapterModelManager
,
**
kwargs
)
->
PromptAdapterModelManager
:
"""Create a PromptAdapterModel for a given model."""
prompt_adapter_manager
=
prompt_adapter_manager_cls
(
model
=
model
,
max_num_seqs
=
max_num_seqs
,
max_num_batched_tokens
=
max_num_batched_tokens
,
prompt_adapter_config
=
prompt_adapter_config
,
**
kwargs
)
return
prompt_adapter_manager
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment