Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82ec66f5
Unverified
Commit
82ec66f5
authored
Jul 23, 2025
by
Michael Goin
Committed by
GitHub
Jul 23, 2025
Browse files
[V0 Deprecation] Remove Prompt Adapters (#20588)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
78c13e30
Changes
60
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
32 additions
and
695 deletions
+32
-695
vllm/entrypoints/logger.py
vllm/entrypoints/logger.py
+2
-5
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+0
-1
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+1
-35
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+0
-1
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+3
-8
vllm/entrypoints/openai/serving_classification.py
vllm/entrypoints/openai/serving_classification.py
+1
-9
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-6
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+1
-8
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+8
-23
vllm/entrypoints/openai/serving_models.py
vllm/entrypoints/openai/serving_models.py
+0
-31
vllm/entrypoints/openai/serving_pooling.py
vllm/entrypoints/openai/serving_pooling.py
+2
-10
vllm/entrypoints/openai/serving_responses.py
vllm/entrypoints/openai/serving_responses.py
+2
-7
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+3
-19
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+4
-17
vllm/entrypoints/openai/speech_to_text.py
vllm/entrypoints/openai/speech_to_text.py
+2
-10
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+0
-31
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+2
-33
vllm/prompt_adapter/__init__.py
vllm/prompt_adapter/__init__.py
+0
-0
vllm/prompt_adapter/layers.py
vllm/prompt_adapter/layers.py
+0
-83
vllm/prompt_adapter/models.py
vllm/prompt_adapter/models.py
+0
-358
No files found.
vllm/entrypoints/logger.py
View file @
82ec66f5
...
@@ -8,7 +8,6 @@ import torch
...
@@ -8,7 +8,6 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -30,7 +29,6 @@ class RequestLogger:
...
@@ -30,7 +29,6 @@ class RequestLogger:
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
,
BeamSearchParams
]],
BeamSearchParams
]],
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
None
:
)
->
None
:
max_log_len
=
self
.
max_log_len
max_log_len
=
self
.
max_log_len
if
max_log_len
is
not
None
:
if
max_log_len
is
not
None
:
...
@@ -44,7 +42,6 @@ class RequestLogger:
...
@@ -44,7 +42,6 @@ class RequestLogger:
"Received request %s: prompt: %r, "
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"params: %s, prompt_token_ids: %s, "
"prompt_embeds shape: %s, "
"prompt_embeds shape: %s, "
"lora_request: %s, prompt_adapter_request: %s."
,
request_id
,
"lora_request: %s."
,
request_id
,
prompt
,
params
,
prompt_token_ids
,
prompt
,
params
,
prompt_token_ids
,
prompt_embeds
.
shape
if
prompt_embeds
is
not
None
else
None
,
prompt_embeds
.
shape
if
prompt_embeds
is
not
None
else
None
,
lora_request
,
prompt_adapter_request
)
lora_request
)
vllm/entrypoints/openai/api_server.py
View file @
82ec66f5
...
@@ -1620,7 +1620,6 @@ async def init_app_state(
...
@@ -1620,7 +1620,6 @@ async def init_app_state(
model_config
=
model_config
,
model_config
=
model_config
,
base_model_paths
=
base_model_paths
,
base_model_paths
=
base_model_paths
,
lora_modules
=
lora_modules
,
lora_modules
=
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
)
)
await
state
.
openai_serving_models
.
init_static_loras
()
await
state
.
openai_serving_models
.
init_static_loras
()
state
.
openai_serving_responses
=
OpenAIServingResponses
(
state
.
openai_serving_responses
=
OpenAIServingResponses
(
...
...
vllm/entrypoints/openai/cli_args.py
View file @
82ec66f5
...
@@ -20,8 +20,7 @@ from vllm.config import config
...
@@ -20,8 +20,7 @@ from vllm.config import config
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
optional_type
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
optional_type
from
vllm.entrypoints.chat_utils
import
(
ChatTemplateContentFormatOption
,
from
vllm.entrypoints.chat_utils
import
(
ChatTemplateContentFormatOption
,
validate_chat_template
)
validate_chat_template
)
from
vllm.entrypoints.openai.serving_models
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_models
import
LoRAModulePath
PromptAdapterPath
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
...
@@ -65,27 +64,6 @@ class LoRAParserAction(argparse.Action):
...
@@ -65,27 +64,6 @@ class LoRAParserAction(argparse.Action):
setattr
(
namespace
,
self
.
dest
,
lora_list
)
setattr
(
namespace
,
self
.
dest
,
lora_list
)
class
PromptAdapterParserAction
(
argparse
.
Action
):
def
__call__
(
self
,
parser
:
argparse
.
ArgumentParser
,
namespace
:
argparse
.
Namespace
,
values
:
Optional
[
Union
[
str
,
Sequence
[
str
]]],
option_string
:
Optional
[
str
]
=
None
,
):
if
values
is
None
:
values
=
[]
if
isinstance
(
values
,
str
):
raise
TypeError
(
"Expected values to be a list"
)
adapter_list
:
list
[
PromptAdapterPath
]
=
[]
for
item
in
values
:
name
,
path
=
item
.
split
(
'='
)
adapter_list
.
append
(
PromptAdapterPath
(
name
,
path
))
setattr
(
namespace
,
self
.
dest
,
adapter_list
)
@
config
@
config
@
dataclass
@
dataclass
class
FrontendArgs
:
class
FrontendArgs
:
...
@@ -115,9 +93,6 @@ class FrontendArgs:
...
@@ -115,9 +93,6 @@ class FrontendArgs:
or JSON list format. Example (old format): `'name=path'` Example (new
or JSON list format. Example (old format): `'name=path'` Example (new
format): `{
\"
name
\"
:
\"
name
\"
,
\"
path
\"
:
\"
lora_path
\"
,
format): `{
\"
name
\"
:
\"
name
\"
,
\"
path
\"
:
\"
lora_path
\"
,
\"
base_model_name
\"
:
\"
id
\"
}`"""
\"
base_model_name
\"
:
\"
id
\"
}`"""
prompt_adapters
:
Optional
[
list
[
PromptAdapterPath
]]
=
None
"""Prompt adapter configurations in the format name=path. Multiple adapters
can be specified."""
chat_template
:
Optional
[
str
]
=
None
chat_template
:
Optional
[
str
]
=
None
"""The file path to the chat template, or the template in single-line form
"""The file path to the chat template, or the template in single-line form
for the specified model."""
for the specified model."""
...
@@ -207,12 +182,6 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
...
@@ -207,12 +182,6 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
frontend_kwargs
[
"lora_modules"
][
"type"
]
=
optional_type
(
str
)
frontend_kwargs
[
"lora_modules"
][
"type"
]
=
optional_type
(
str
)
frontend_kwargs
[
"lora_modules"
][
"action"
]
=
LoRAParserAction
frontend_kwargs
[
"lora_modules"
][
"action"
]
=
LoRAParserAction
# Special case: Prompt adapters need custom parser action and
# optional_type(str)
frontend_kwargs
[
"prompt_adapters"
][
"type"
]
=
optional_type
(
str
)
frontend_kwargs
[
"prompt_adapters"
][
"action"
]
=
PromptAdapterParserAction
# Special case: Middleware needs append action
# Special case: Middleware needs append action
frontend_kwargs
[
"middleware"
][
"action"
]
=
"append"
frontend_kwargs
[
"middleware"
][
"action"
]
=
"append"
frontend_kwargs
[
"middleware"
][
"type"
]
=
str
frontend_kwargs
[
"middleware"
][
"type"
]
=
str
...
@@ -288,9 +257,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
...
@@ -288,9 +257,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
if
args
.
enable_auto_tool_choice
and
not
args
.
tool_call_parser
:
if
args
.
enable_auto_tool_choice
and
not
args
.
tool_call_parser
:
raise
TypeError
(
"Error: --enable-auto-tool-choice requires "
raise
TypeError
(
"Error: --enable-auto-tool-choice requires "
"--tool-call-parser"
)
"--tool-call-parser"
)
if
args
.
enable_prompt_embeds
and
args
.
enable_prompt_adapter
:
raise
ValueError
(
"Cannot use prompt embeds and prompt adapter at the same time."
)
def
log_non_default_args
(
args
:
argparse
.
Namespace
):
def
log_non_default_args
(
args
:
argparse
.
Namespace
):
...
...
vllm/entrypoints/openai/run_batch.py
View file @
82ec66f5
...
@@ -337,7 +337,6 @@ async def main(args):
...
@@ -337,7 +337,6 @@ async def main(args):
model_config
=
model_config
,
model_config
=
model_config
,
base_model_paths
=
base_model_paths
,
base_model_paths
=
base_model_paths
,
lora_modules
=
None
,
lora_modules
=
None
,
prompt_adapters
=
None
,
)
)
openai_serving_chat
=
OpenAIServingChat
(
openai_serving_chat
=
OpenAIServingChat
(
engine
,
engine
,
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
82ec66f5
...
@@ -147,11 +147,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -147,11 +147,8 @@ class OpenAIServingChat(OpenAIServing):
raise
self
.
engine_client
.
dead_error
raise
self
.
engine_client
.
dead_error
try
:
try
:
(
lora_request
=
self
.
_maybe_get_adapters
(
lora_request
,
request
,
supports_default_mm_loras
=
True
)
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
,
supports_default_mm_loras
=
True
)
model_name
=
self
.
_get_model_name
(
request
.
model
,
lora_request
)
model_name
=
self
.
_get_model_name
(
request
.
model
,
lora_request
)
...
@@ -239,8 +236,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -239,8 +236,7 @@ class OpenAIServingChat(OpenAIServing):
self
.
_log_inputs
(
request_id
,
self
.
_log_inputs
(
request_id
,
request_prompts
[
i
],
request_prompts
[
i
],
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
)
prompt_adapter_request
=
prompt_adapter_request
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
))
self
.
_get_trace_headers
(
raw_request
.
headers
))
...
@@ -259,7 +255,6 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -259,7 +255,6 @@ class OpenAIServingChat(OpenAIServing):
request_id
,
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
)
)
...
...
vllm/entrypoints/openai/serving_classification.py
View file @
82ec66f5
...
@@ -49,19 +49,11 @@ class ClassificationMixin(OpenAIServing):
...
@@ -49,19 +49,11 @@ class ClassificationMixin(OpenAIServing):
return
None
return
None
try
:
try
:
(
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
ctx
.
lora_request
,
ctx
.
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
ctx
.
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
ctx
.
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
ctx
.
lora_request
)
ctx
.
lora_request
)
if
ctx
.
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported for classification models"
)
(
(
ctx
.
request_prompts
,
ctx
.
request_prompts
,
ctx
.
engine_prompts
,
ctx
.
engine_prompts
,
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
82ec66f5
...
@@ -121,10 +121,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -121,10 +121,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request
.
state
.
request_metadata
=
request_metadata
raw_request
.
state
.
request_metadata
=
request_metadata
try
:
try
:
(
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
...
@@ -197,7 +194,6 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -197,7 +194,6 @@ class OpenAIServingCompletion(OpenAIServing):
request_prompts
[
i
],
request_prompts
[
i
],
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
trace_headers
=
(
None
if
raw_request
is
None
else
await
...
@@ -221,7 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -221,7 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params
,
sampling_params
,
request_id_item
,
request_id_item
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
)
)
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
82ec66f5
...
@@ -53,18 +53,11 @@ class EmbeddingMixin(OpenAIServing):
...
@@ -53,18 +53,11 @@ class EmbeddingMixin(OpenAIServing):
)
->
Optional
[
ErrorResponse
]:
)
->
Optional
[
ErrorResponse
]:
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
try
:
try
:
(
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
ctx
.
lora_request
,
ctx
.
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
ctx
.
lora_request
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
ctx
.
lora_request
)
)
if
ctx
.
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for embedding models"
)
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
(
(
_
,
_
,
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
82ec66f5
...
@@ -68,7 +68,6 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
...
@@ -68,7 +68,6 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
MultiModalDataDict
)
MultiModalDataDict
)
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
Logprob
,
PromptLogprobs
from
vllm.sequence
import
Logprob
,
PromptLogprobs
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
...
@@ -161,7 +160,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
...
@@ -161,7 +160,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
request_id
:
str
request_id
:
str
created_time
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created_time
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
# Shared across most requests
# Shared across most requests
tokenizer
:
Optional
[
AnyTokenizer
]
=
None
tokenizer
:
Optional
[
AnyTokenizer
]
=
None
...
@@ -343,12 +341,10 @@ class OpenAIServing:
...
@@ -343,12 +341,10 @@ class OpenAIServing:
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"Request prompts not available"
)
"Request prompts not available"
)
self
.
_log_inputs
(
self
.
_log_inputs
(
request_id_item
,
request_id_item
,
ctx
.
request_prompts
[
i
],
ctx
.
request_prompts
[
i
],
params
=
pooling_params
,
params
=
pooling_params
,
lora_request
=
ctx
.
lora_request
,
lora_request
=
ctx
.
lora_request
)
prompt_adapter_request
=
ctx
.
prompt_adapter_request
)
# Mypy has an existing bug related to inferring the variance of
# Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`:
# TypedDicts with `builtins.enumerate`:
...
@@ -450,11 +446,6 @@ class OpenAIServing:
...
@@ -450,11 +446,6 @@ class OpenAIServing:
if
isinstance
(
load_result
,
ErrorResponse
)
and
\
if
isinstance
(
load_result
,
ErrorResponse
)
and
\
load_result
.
code
==
HTTPStatus
.
BAD_REQUEST
.
value
:
load_result
.
code
==
HTTPStatus
.
BAD_REQUEST
.
value
:
error_response
=
load_result
error_response
=
load_result
if
request
.
model
in
[
prompt_adapter
.
prompt_adapter_name
for
prompt_adapter
in
self
.
models
.
prompt_adapter_requests
]:
return
None
return
error_response
or
self
.
create_error_response
(
return
error_response
or
self
.
create_error_response
(
message
=
f
"The model `
{
request
.
model
}
` does not exist."
,
message
=
f
"The model `
{
request
.
model
}
` does not exist."
,
...
@@ -489,25 +480,21 @@ class OpenAIServing:
...
@@ -489,25 +480,21 @@ class OpenAIServing:
self
,
self
,
request
:
AnyRequest
,
request
:
AnyRequest
,
supports_default_mm_loras
:
bool
=
False
,
supports_default_mm_loras
:
bool
=
False
,
)
->
Union
[
tuple
[
None
,
None
],
tuple
[
LoRARequest
,
None
],
tuple
[
)
->
Optional
[
LoRARequest
]:
None
,
PromptAdapterRequest
]]:
if
request
.
model
in
self
.
models
.
lora_requests
:
if
request
.
model
in
self
.
models
.
lora_requests
:
return
self
.
models
.
lora_requests
[
request
.
model
]
,
None
return
self
.
models
.
lora_requests
[
request
.
model
]
# Currently only support default modality specific loras
# Currently only support default modality specific loras
# if we have exactly one lora matched on the request.
# if we have exactly one lora matched on the request.
if
supports_default_mm_loras
:
if
supports_default_mm_loras
:
default_mm_lora
=
self
.
_get_active_default_mm_loras
(
request
)
default_mm_lora
=
self
.
_get_active_default_mm_loras
(
request
)
if
default_mm_lora
is
not
None
:
if
default_mm_lora
is
not
None
:
return
default_mm_lora
,
None
return
default_mm_lora
if
self
.
_is_model_supported
(
request
.
model
):
if
self
.
_is_model_supported
(
request
.
model
):
return
None
,
None
return
None
for
prompt_adapter
in
self
.
models
.
prompt_adapter_requests
:
if
request
.
model
==
prompt_adapter
.
prompt_adapter_name
:
return
None
,
prompt_adapter
# if _check_model has been called earlier, this will be unreachable
# if _check_model has been called earlier, this will be unreachable
raise
ValueError
(
f
"The model `
{
request
.
model
}
` does not exist."
)
raise
ValueError
(
f
"The model `
{
request
.
model
}
` does not exist."
)
...
@@ -987,7 +974,6 @@ class OpenAIServing:
...
@@ -987,7 +974,6 @@ class OpenAIServing:
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
,
BeamSearchParams
]],
BeamSearchParams
]],
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
None
:
)
->
None
:
if
self
.
request_logger
is
None
:
if
self
.
request_logger
is
None
:
return
return
...
@@ -1009,7 +995,6 @@ class OpenAIServing:
...
@@ -1009,7 +995,6 @@ class OpenAIServing:
prompt_embeds
,
prompt_embeds
,
params
=
params
,
params
=
params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
)
async
def
_get_trace_headers
(
async
def
_get_trace_headers
(
...
...
vllm/entrypoints/openai/serving_models.py
View file @
82ec66f5
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
pathlib
from
asyncio
import
Lock
from
asyncio
import
Lock
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -19,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
...
@@ -19,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.resolver
import
LoRAResolver
,
LoRAResolverRegistry
from
vllm.lora.resolver
import
LoRAResolver
,
LoRAResolverRegistry
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.utils
import
AtomicCounter
from
vllm.utils
import
AtomicCounter
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -31,12 +28,6 @@ class BaseModelPath:
...
@@ -31,12 +28,6 @@ class BaseModelPath:
model_path
:
str
model_path
:
str
@
dataclass
class
PromptAdapterPath
:
name
:
str
local_path
:
str
@
dataclass
@
dataclass
class
LoRAModulePath
:
class
LoRAModulePath
:
name
:
str
name
:
str
...
@@ -60,7 +51,6 @@ class OpenAIServingModels:
...
@@ -60,7 +51,6 @@ class OpenAIServingModels:
base_model_paths
:
list
[
BaseModelPath
],
base_model_paths
:
list
[
BaseModelPath
],
*
,
*
,
lora_modules
:
Optional
[
list
[
LoRAModulePath
]]
=
None
,
lora_modules
:
Optional
[
list
[
LoRAModulePath
]]
=
None
,
prompt_adapters
:
Optional
[
list
[
PromptAdapterPath
]]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -81,20 +71,6 @@ class OpenAIServingModels:
...
@@ -81,20 +71,6 @@ class OpenAIServingModels:
LoRAResolverRegistry
.
get_resolver
(
lora_resolver_name
))
LoRAResolverRegistry
.
get_resolver
(
lora_resolver_name
))
self
.
lora_resolver_lock
:
dict
[
str
,
Lock
]
=
defaultdict
(
Lock
)
self
.
lora_resolver_lock
:
dict
[
str
,
Lock
]
=
defaultdict
(
Lock
)
self
.
prompt_adapter_requests
=
[]
if
prompt_adapters
is
not
None
:
for
i
,
prompt_adapter
in
enumerate
(
prompt_adapters
,
start
=
1
):
with
pathlib
.
Path
(
prompt_adapter
.
local_path
,
"adapter_config.json"
).
open
()
as
f
:
adapter_config
=
json
.
load
(
f
)
num_virtual_tokens
=
adapter_config
[
"num_virtual_tokens"
]
self
.
prompt_adapter_requests
.
append
(
PromptAdapterRequest
(
prompt_adapter_name
=
prompt_adapter
.
name
,
prompt_adapter_id
=
i
,
prompt_adapter_local_path
=
prompt_adapter
.
local_path
,
prompt_adapter_num_virtual_tokens
=
num_virtual_tokens
))
async
def
init_static_loras
(
self
):
async
def
init_static_loras
(
self
):
"""Loads all static LoRA modules.
"""Loads all static LoRA modules.
Raises if any fail to load"""
Raises if any fail to load"""
...
@@ -141,14 +117,7 @@ class OpenAIServingModels:
...
@@ -141,14 +117,7 @@ class OpenAIServingModels:
permission
=
[
ModelPermission
()])
permission
=
[
ModelPermission
()])
for
lora
in
self
.
lora_requests
.
values
()
for
lora
in
self
.
lora_requests
.
values
()
]
]
prompt_adapter_cards
=
[
ModelCard
(
id
=
prompt_adapter
.
prompt_adapter_name
,
root
=
self
.
base_model_paths
[
0
].
name
,
permission
=
[
ModelPermission
()])
for
prompt_adapter
in
self
.
prompt_adapter_requests
]
model_cards
.
extend
(
lora_cards
)
model_cards
.
extend
(
lora_cards
)
model_cards
.
extend
(
prompt_adapter_cards
)
return
ModelList
(
data
=
model_cards
)
return
ModelList
(
data
=
model_cards
)
async
def
load_lora_adapter
(
async
def
load_lora_adapter
(
...
...
vllm/entrypoints/openai/serving_pooling.py
View file @
82ec66f5
...
@@ -94,17 +94,10 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -94,17 +94,10 @@ class OpenAIServingPooling(OpenAIServing):
try
:
try
:
truncate_prompt_tokens
=
_validate_truncation_size
(
truncate_prompt_tokens
=
_validate_truncation_size
(
self
.
max_model_len
,
truncate_prompt_tokens
)
self
.
max_model_len
,
truncate_prompt_tokens
)
(
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
if
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for pooling models"
)
if
isinstance
(
request
,
PoolingChatRequest
):
if
isinstance
(
request
,
PoolingChatRequest
):
(
(
_
,
_
,
...
@@ -153,8 +146,7 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -153,8 +146,7 @@ class OpenAIServingPooling(OpenAIServing):
self
.
_log_inputs
(
request_id_item
,
self
.
_log_inputs
(
request_id_item
,
request_prompts
[
i
],
request_prompts
[
i
],
params
=
pooling_params
,
params
=
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
)
prompt_adapter_request
=
prompt_adapter_request
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
))
self
.
_get_trace_headers
(
raw_request
.
headers
))
...
...
vllm/entrypoints/openai/serving_responses.py
View file @
82ec66f5
...
@@ -133,10 +133,7 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -133,10 +133,7 @@ class OpenAIServingResponses(OpenAIServing):
messages
=
self
.
_construct_input_messages
(
request
,
prev_response
)
messages
=
self
.
_construct_input_messages
(
request
,
prev_response
)
try
:
try
:
(
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
model_name
=
self
.
_get_model_name
(
request
.
model
,
lora_request
)
model_name
=
self
.
_get_model_name
(
request
.
model
,
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
...
@@ -169,8 +166,7 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -169,8 +166,7 @@ class OpenAIServingResponses(OpenAIServing):
self
.
_log_inputs
(
request
.
request_id
,
self
.
_log_inputs
(
request
.
request_id
,
request_prompts
[
i
],
request_prompts
[
i
],
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
)
prompt_adapter_request
=
prompt_adapter_request
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
))
self
.
_get_trace_headers
(
raw_request
.
headers
))
...
@@ -181,7 +177,6 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -181,7 +177,6 @@ class OpenAIServingResponses(OpenAIServing):
request
.
request_id
,
request
.
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
)
)
generators
.
append
(
generator
)
generators
.
append
(
generator
)
...
...
vllm/entrypoints/openai/serving_score.py
View file @
82ec66f5
...
@@ -27,7 +27,6 @@ from vllm.inputs.data import TokensPrompt
...
@@ -27,7 +27,6 @@ from vllm.inputs.data import TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
make_async
,
merge_async_iterators
from
vllm.utils
import
make_async
,
merge_async_iterators
...
@@ -58,8 +57,6 @@ class ServingScores(OpenAIServing):
...
@@ -58,8 +57,6 @@ class ServingScores(OpenAIServing):
request_id
:
str
,
request_id
:
str
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
Union
[
LoRARequest
,
None
]]
=
None
,
lora_request
:
Optional
[
Union
[
LoRARequest
,
None
]]
=
None
,
prompt_adapter_request
:
Optional
[
Union
[
PromptAdapterRequest
,
None
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
input_texts
=
texts_1
+
texts_2
input_texts
=
texts_1
+
texts_2
...
@@ -100,8 +97,7 @@ class ServingScores(OpenAIServing):
...
@@ -100,8 +97,7 @@ class ServingScores(OpenAIServing):
self
.
_log_inputs
(
request_id_item
,
self
.
_log_inputs
(
request_id_item
,
input_texts
[
i
],
input_texts
[
i
],
params
=
pooling_params
,
params
=
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
)
prompt_adapter_request
=
prompt_adapter_request
)
generators
.
append
(
generators
.
append
(
self
.
engine_client
.
encode
(
self
.
engine_client
.
encode
(
...
@@ -176,8 +172,6 @@ class ServingScores(OpenAIServing):
...
@@ -176,8 +172,6 @@ class ServingScores(OpenAIServing):
request_id
:
str
,
request_id
:
str
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
Union
[
LoRARequest
,
None
]]
=
None
,
lora_request
:
Optional
[
Union
[
LoRARequest
,
None
]]
=
None
,
prompt_adapter_request
:
Optional
[
Union
[
PromptAdapterRequest
,
None
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
request_prompts
:
list
[
str
]
=
[]
request_prompts
:
list
[
str
]
=
[]
...
@@ -261,8 +255,7 @@ class ServingScores(OpenAIServing):
...
@@ -261,8 +255,7 @@ class ServingScores(OpenAIServing):
self
.
_log_inputs
(
request_id_item
,
self
.
_log_inputs
(
request_id_item
,
request_prompts
[
i
],
request_prompts
[
i
],
params
=
pooling_params
,
params
=
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
)
prompt_adapter_request
=
prompt_adapter_request
)
generator
=
self
.
engine_client
.
encode
(
generator
=
self
.
engine_client
.
encode
(
engine_prompt
,
engine_prompt
,
...
@@ -295,14 +288,7 @@ class ServingScores(OpenAIServing):
...
@@ -295,14 +288,7 @@ class ServingScores(OpenAIServing):
raw_request
:
Optional
[
Request
]
=
None
,
raw_request
:
Optional
[
Request
]
=
None
,
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
(
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
if
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for scoring models"
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
...
@@ -340,7 +326,6 @@ class ServingScores(OpenAIServing):
...
@@ -340,7 +326,6 @@ class ServingScores(OpenAIServing):
request_id
=
request_id
,
request_id
=
request_id
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
)
trace_headers
=
trace_headers
)
else
:
else
:
...
@@ -352,7 +337,6 @@ class ServingScores(OpenAIServing):
...
@@ -352,7 +337,6 @@ class ServingScores(OpenAIServing):
request_id
=
request_id
,
request_id
=
request_id
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
)
trace_headers
=
trace_headers
)
async
def
create_score
(
async
def
create_score
(
...
...
vllm/entrypoints/openai/serving_tokenization.py
View file @
82ec66f5
...
@@ -60,10 +60,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -60,10 +60,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_id
=
f
"tokn-
{
self
.
_base_request_id
(
raw_request
)
}
"
request_id
=
f
"tokn-
{
self
.
_base_request_id
(
raw_request
)
}
"
try
:
try
:
(
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
...
@@ -104,11 +101,8 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -104,11 +101,8 @@ class OpenAIServingTokenization(OpenAIServing):
self
.
_log_inputs
(
request_id
,
self
.
_log_inputs
(
request_id
,
request_prompts
[
i
],
request_prompts
[
i
],
params
=
None
,
params
=
None
,
lora_request
=
lora_request
,
lora_request
=
lora_request
)
prompt_adapter_request
=
prompt_adapter_request
)
# Silently ignore prompt adapter since it does not affect
# tokenization (Unlike in Embeddings API where an error is raised)
if
isinstance
(
engine_prompt
,
if
isinstance
(
engine_prompt
,
dict
)
and
"prompt_token_ids"
in
engine_prompt
:
dict
)
and
"prompt_token_ids"
in
engine_prompt
:
input_ids
.
extend
(
engine_prompt
[
"prompt_token_ids"
])
input_ids
.
extend
(
engine_prompt
[
"prompt_token_ids"
])
...
@@ -133,21 +127,14 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -133,21 +127,14 @@ class OpenAIServingTokenization(OpenAIServing):
request_id
=
f
"tokn-
{
self
.
_base_request_id
(
raw_request
)
}
"
request_id
=
f
"tokn-
{
self
.
_base_request_id
(
raw_request
)
}
"
(
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
self
.
_log_inputs
(
request_id
,
self
.
_log_inputs
(
request_id
,
request
.
tokens
,
request
.
tokens
,
params
=
None
,
params
=
None
,
lora_request
=
lora_request
,
lora_request
=
lora_request
)
prompt_adapter_request
=
prompt_adapter_request
)
# Silently ignore prompt adapter since it does not affect tokenization
# (Unlike in Embeddings API where an error is raised)
prompt_input
=
await
self
.
_tokenize_prompt_input_async
(
prompt_input
=
await
self
.
_tokenize_prompt_input_async
(
request
,
request
,
...
...
vllm/entrypoints/openai/speech_to_text.py
View file @
82ec66f5
...
@@ -150,19 +150,12 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -150,19 +150,12 @@ class OpenAISpeechToText(OpenAIServing):
raw_request
.
state
.
request_metadata
=
request_metadata
raw_request
.
state
.
request_metadata
=
request_metadata
try
:
try
:
(
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
if
lora_request
:
if
lora_request
:
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"Currently do not support LoRA for "
"Currently do not support LoRA for "
f
"
{
self
.
task_type
.
title
()
}
."
)
f
"
{
self
.
task_type
.
title
()
}
."
)
if
prompt_adapter_request
:
return
self
.
create_error_response
(
f
"Currently do not support PromptAdapter for "
f
"
{
self
.
task_type
.
title
()
}
."
)
prompts
,
duration_s
=
await
self
.
_preprocess_speech_to_text
(
prompts
,
duration_s
=
await
self
.
_preprocess_speech_to_text
(
request
=
request
,
request
=
request
,
...
@@ -188,8 +181,7 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -188,8 +181,7 @@ class OpenAISpeechToText(OpenAIServing):
# It will not display special tokens like <|startoftranscript|>
# It will not display special tokens like <|startoftranscript|>
request
.
prompt
,
request
.
prompt
,
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
None
,
lora_request
=
None
)
prompt_adapter_request
=
None
)
list_result_generator
=
[
list_result_generator
=
[
self
.
engine_client
.
generate
(
self
.
engine_client
.
generate
(
...
...
vllm/executor/executor_base.py
View file @
82ec66f5
...
@@ -17,7 +17,6 @@ from vllm.logger import init_logger
...
@@ -17,7 +17,6 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.pooling_params
import
PoolingTask
from
vllm.pooling_params
import
PoolingTask
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.utils
import
make_async
from
vllm.utils
import
make_async
from
vllm.worker.worker_base
import
WorkerBase
from
vllm.worker.worker_base
import
WorkerBase
...
@@ -50,7 +49,6 @@ class ExecutorBase(ABC):
...
@@ -50,7 +49,6 @@ class ExecutorBase(ABC):
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
device_config
=
vllm_config
.
device_config
self
.
device_config
=
vllm_config
.
device_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
_init_executor
()
self
.
_init_executor
()
self
.
is_sleeping
=
False
self
.
is_sleeping
=
False
...
@@ -171,35 +169,6 @@ class ExecutorBase(ABC):
...
@@ -171,35 +169,6 @@ class ExecutorBase(ABC):
assert
s
==
sets
[
0
],
"All workers should have the same LORAs."
assert
s
==
sets
[
0
],
"All workers should have the same LORAs."
return
sets
[
0
]
return
sets
[
0
]
def add_prompt_adapter(
        self, prompt_adapter_request: PromptAdapterRequest) -> bool:
    """Forward a prompt adapter registration through ``collective_rpc``.

    Args:
        prompt_adapter_request: Request whose ``prompt_adapter_id`` must be
            strictly positive.

    Returns:
        True only when every result returned by the ``add_prompt_adapter``
        RPC is truthy.
    """
    adapter_id = prompt_adapter_request.prompt_adapter_id
    assert adapter_id > 0, "prompt_adapter_id must be greater than 0."
    rpc_results = self.collective_rpc("add_prompt_adapter",
                                      args=(prompt_adapter_request, ))
    return all(rpc_results)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Forward a prompt adapter removal through ``collective_rpc``.

    Args:
        prompt_adapter_id: Id of the adapter to remove; must be > 0.

    Returns:
        True only when every result returned by the
        ``remove_prompt_adapter`` RPC is truthy.
    """
    assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
    rpc_results = self.collective_rpc("remove_prompt_adapter",
                                      args=(prompt_adapter_id, ))
    return all(rpc_results)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Forward a prompt adapter pin request through ``collective_rpc``.

    Args:
        prompt_adapter_id: Id of the adapter to pin; must be > 0.

    Returns:
        True only when every result returned by the ``pin_prompt_adapter``
        RPC is truthy.
    """
    assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
    rpc_results = self.collective_rpc("pin_prompt_adapter",
                                      args=(prompt_adapter_id, ))
    return all(rpc_results)
def list_prompt_adapters(self) -> Set[int]:
    """Return the prompt adapter ids reported by the ``list_prompt_adapters``
    RPC.

    Every entry of the RPC result is compared against the first one; any
    mismatch trips the assertion. The first entry is returned.
    """
    per_worker = self.collective_rpc("list_prompt_adapters")
    assert all(ids == per_worker[0] for ids in per_worker), \
        "All workers should have the same prompt adapters."
    return per_worker[0]
def
start_profile
(
self
)
->
None
:
def
start_profile
(
self
)
->
None
:
self
.
collective_rpc
(
"start_profile"
)
self
.
collective_rpc
(
"start_profile"
)
...
...
vllm/inputs/preprocess.py
View file @
82ec66f5
...
@@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest
...
@@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
...
@@ -168,18 +167,6 @@ class InputPreprocessor:
...
@@ -168,18 +167,6 @@ class InputPreprocessor:
return
decoder_input_ids
return
decoder_input_ids
def
_apply_prompt_adapter
(
self
,
prompt_token_ids
:
list
[
int
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
list
[
int
]:
if
prompt_adapter_request
:
prompt_token_ids
=
(
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
+
prompt_token_ids
)
return
prompt_token_ids
def
_get_tokenization_kw
(
def
_get_tokenization_kw
(
self
,
self
,
overrides
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
overrides
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
...
@@ -786,15 +773,10 @@ class InputPreprocessor:
...
@@ -786,15 +773,10 @@ class InputPreprocessor:
def
_build_decoder_only_llm_inputs
(
def
_build_decoder_only_llm_inputs
(
self
,
self
,
prompt_inputs
:
DecoderOnlyInputs
,
prompt_inputs
:
DecoderOnlyInputs
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
DecoderOnlyInputs
:
)
->
DecoderOnlyInputs
:
if
"prompt_token_ids"
in
prompt_inputs
:
if
"prompt_token_ids"
in
prompt_inputs
:
prompt_inputs
=
cast
(
Union
[
TokenInputs
,
MultiModalInputs
],
prompt_inputs
=
cast
(
Union
[
TokenInputs
,
MultiModalInputs
],
prompt_inputs
)
# Needed for mypy
prompt_inputs
)
# Needed for mypy
prompt_inputs
[
"prompt_token_ids"
]
=
self
.
_apply_prompt_adapter
(
prompt_inputs
[
"prompt_token_ids"
],
prompt_adapter_request
=
prompt_adapter_request
,
)
return
prompt_inputs
return
prompt_inputs
...
@@ -803,7 +785,6 @@ class InputPreprocessor:
...
@@ -803,7 +785,6 @@ class InputPreprocessor:
prompt
:
SingletonPrompt
,
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
return_mm_hashes
:
bool
=
False
,
)
->
DecoderOnlyInputs
:
)
->
DecoderOnlyInputs
:
"""
"""
...
@@ -815,7 +796,6 @@ class InputPreprocessor:
...
@@ -815,7 +796,6 @@ class InputPreprocessor:
* prompt: input prompt
* prompt: input prompt
* lora_request
* lora_request
* prompt_adapter_request
* return_mm_hashes
* return_mm_hashes
Returns:
Returns:
...
@@ -830,17 +810,13 @@ class InputPreprocessor:
...
@@ -830,17 +810,13 @@ class InputPreprocessor:
return_mm_hashes
=
return_mm_hashes
,
return_mm_hashes
=
return_mm_hashes
,
)
)
return
self
.
_build_decoder_only_llm_inputs
(
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
)
prompt_comps
,
prompt_adapter_request
=
prompt_adapter_request
,
)
async
def
_process_decoder_only_prompt_async
(
async
def
_process_decoder_only_prompt_async
(
self
,
self
,
prompt
:
SingletonPrompt
,
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
return_mm_hashes
:
bool
=
False
,
)
->
DecoderOnlyInputs
:
)
->
DecoderOnlyInputs
:
"""
"""
...
@@ -854,17 +830,13 @@ class InputPreprocessor:
...
@@ -854,17 +830,13 @@ class InputPreprocessor:
return_mm_hashes
=
return_mm_hashes
,
return_mm_hashes
=
return_mm_hashes
,
)
)
return
self
.
_build_decoder_only_llm_inputs
(
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
)
prompt_comps
,
prompt_adapter_request
=
prompt_adapter_request
,
)
def
preprocess
(
def
preprocess
(
self
,
self
,
prompt
:
PromptType
,
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
return_mm_hashes
:
bool
=
False
,
)
->
ProcessorInputs
:
)
->
ProcessorInputs
:
"""Preprocess the input prompt."""
"""Preprocess the input prompt."""
...
@@ -886,7 +858,6 @@ class InputPreprocessor:
...
@@ -886,7 +858,6 @@ class InputPreprocessor:
prompt
,
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
return_mm_hashes
=
return_mm_hashes
,
return_mm_hashes
=
return_mm_hashes
,
)
)
...
@@ -895,7 +866,6 @@ class InputPreprocessor:
...
@@ -895,7 +866,6 @@ class InputPreprocessor:
prompt
:
PromptType
,
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
return_mm_hashes
:
bool
=
False
,
)
->
ProcessorInputs
:
)
->
ProcessorInputs
:
"""
"""
...
@@ -919,6 +889,5 @@ class InputPreprocessor:
...
@@ -919,6 +889,5 @@ class InputPreprocessor:
prompt
,
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
return_mm_hashes
=
return_mm_hashes
,
return_mm_hashes
=
return_mm_hashes
,
)
)
vllm/prompt_adapter/__init__.py
deleted
100644 → 0
View file @
78c13e30
vllm/prompt_adapter/layers.py
deleted
100644 → 0
View file @
78c13e30
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Optional
import
torch
from
torch
import
nn
from
vllm.adapter_commons.layers
import
AdapterMapping
from
vllm.config
import
PromptAdapterConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
@dataclass
class PromptAdapterMapping(AdapterMapping):
    """Adapter mapping used for prompt adapters; declares no extra fields."""
class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
    """Embedding wrapper that overwrites selected output rows with
    prompt-adapter embeddings.

    The wrapped layer produces the regular hidden states; rows selected by
    the current mapping are then replaced with rows gathered from the
    per-slot adapter embedding buffer.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # ``emb_layer`` is the layer whose weight/embedding_dim are read.
        # When the wrapped layer's class name contains 'LoRA', that layer
        # is one level further down (``base_layer.base_layer``).
        self.emb_layer = self.base_layer
        if 'LoRA' in base_layer.__class__.__name__:
            self.emb_layer = self.base_layer.base_layer

    def create_prompt_adapter_weights(
            self, prompt_adapter_config: PromptAdapterConfig):
        """Allocate the zero-filled adapter buffers.

        ``embeddings_tensors`` has shape
        (max_prompt_adapters, max_prompt_adapter_token, embedding_dim) and
        ``adapter_lengths`` has shape (max_prompt_adapters,); both live on
        the device of the embedding weight.
        """
        weight = self.emb_layer.weight
        self.embeddings_tensors = torch.zeros(
            (
                prompt_adapter_config.max_prompt_adapters,
                prompt_adapter_config.max_prompt_adapter_token,
                self.emb_layer.embedding_dim,
            ),
            dtype=weight.dtype,
            device=weight.device,
        )
        self.adapter_lengths = torch.zeros(
            prompt_adapter_config.max_prompt_adapters,
            dtype=torch.long,
            device=weight.device)

        # Declared here, assigned by set_mapping(); forward() requires
        # set_mapping() to have been called first.
        self.indices_gpu: torch.Tensor
        self.embedding_indices_gpu: torch.Tensor

    def reset_prompt_adapter(self, index: int):
        """Clear slot ``index``: zero its embeddings and its length."""
        self.embeddings_tensors[index] = 0
        # Keep the length in sync with the cleared embeddings so a freed
        # slot does not keep reporting the previous adapter's length.
        self.adapter_lengths[index] = 0

    def set_prompt_adapter(
        self,
        index: int,
        adapter_model: Optional[torch.Tensor],
    ):
        """Load ``adapter_model`` into slot ``index``.

        The slot is always cleared first. With ``adapter_model=None`` the
        slot simply stays empty; otherwise the first ``shape[0]`` rows of
        the slot are overwritten and the length recorded.
        """
        self.reset_prompt_adapter(index)
        if adapter_model is not None:
            length = adapter_model.shape[0]
            self.embeddings_tensors[index, :length] = adapter_model
            self.adapter_lengths[index] = length

    def set_mapping(
        self,
        prompt_indices: torch.Tensor,
        prompt_embedding_indices: torch.Tensor,
    ):
        """Store the mapping tensors on the embedding weight's device.

        ``prompt_indices`` is compared against -1 in forward() to pick the
        rows to overwrite; ``prompt_embedding_indices`` supplies, per
        selected row, the (slot, position) pair to gather from the buffer.
        """
        target_device = self.emb_layer.weight.device
        self.indices_gpu = prompt_indices.to(device=target_device)
        self.embedding_indices_gpu = prompt_embedding_indices.to(
            device=target_device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped layer, then patch in adapter embeddings.

        The patch is skipped when ``embedding_indices_gpu`` is not at least
        2-D (e.g. the empty tensor used when no adapter is mapped).
        """
        hidden_states = self.base_layer(x)
        if self.embedding_indices_gpu.ndim > 1:
            valid_mask = self.indices_gpu != -1
            gathered_embeddings = self.embeddings_tensors[
                self.embedding_indices_gpu[:, 0],
                self.embedding_indices_gpu[:, 1]]

            # Update hidden states in place for the mapped rows only.
            hidden_states[valid_mask] = gathered_embeddings
        return hidden_states
\ No newline at end of file
vllm/prompt_adapter/models.py
deleted
100644 → 0
View file @
78c13e30
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
logging
import
math
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Type
import
torch
from
torch
import
nn
from
vllm.adapter_commons.models
import
(
AdapterLRUCache
,
AdapterModel
,
AdapterModelManager
)
from
vllm.adapter_commons.utils
import
(
add_adapter
,
deactivate_adapter
,
get_adapter
,
list_adapters
,
remove_adapter
,
set_adapter_mapping
)
from
vllm.config
import
PromptAdapterConfig
from
vllm.prompt_adapter.layers
import
(
VocabParallelEmbeddingWithPromptAdapter
)
# yapf: disable
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.utils
import
load_peft_weights
logger = logging.getLogger(__name__)

# Module-level counter behind get_prompt_adapter_id(); 0 means no id has
# been handed out yet.
_GLOBAL_PROMPT_ADAPTER_ID = 0


def get_prompt_adapter_id():
    """Return the next prompt adapter id.

    Each call bumps the module-level counter by one and returns the new
    value, so the first id handed out is 1.
    """
    global _GLOBAL_PROMPT_ADAPTER_ID
    _GLOBAL_PROMPT_ADAPTER_ID = _GLOBAL_PROMPT_ADAPTER_ID + 1
    return _GLOBAL_PROMPT_ADAPTER_ID
def convert_to_embedding_indices(indices):
    """Pair each non-negative-one entry with its position in its run.

    Walks ``indices`` in order. Every value other than -1 yields a
    ``[value, position]`` pair, where ``position`` counts the entries seen
    since the last -1 (or since the start). A -1 emits nothing and resets
    the position to 0. The pairs are returned as a tensor; with no pairs
    the result is the 1-D empty tensor produced by ``torch.tensor([])``.
    """
    pairs = []
    position = 0

    for slot in indices:
        if slot != -1:
            pairs.append([slot, position])
            position += 1
            continue
        # A -1 entry ends the current run.
        position = 0

    return torch.tensor(pairs)
def convert_mapping(
    mapping: PromptAdapterMapping,
    prompt_adapter_index_to_id: List[Optional[int]],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Converts PromptAdapterMapping to index tensors.

    Args:
        mapping: PromptAdapterMapping whose ``index_mapping`` lists one
            PromptAdapter id per entry.
        prompt_adapter_index_to_id: Slot table; position ``i`` holds the id
            of the PromptAdapter loaded in slot ``i``, or None if the slot
            is free.

    Returns:
        A ``(pa_indices, pa_embedding_mapping)`` pair:
            pa_indices: 1-D tensor with, for every entry of
                ``mapping.index_mapping``, the slot index of its adapter,
                or -1 when the id is not positive or not found in the slot
                table.
            pa_embedding_mapping: result of
                ``convert_to_embedding_indices`` applied to those slot
                indices.
    """
    # Invert the slot table: adapter id -> slot index (free slots skipped).
    id_to_index = {
        id_: idx
        for idx, id_ in enumerate(prompt_adapter_index_to_id)
        if id_ is not None
    }
    pa_indices = [
        id_to_index.get(id_, -1) if id_ > 0 else -1
        for id_ in mapping.index_mapping
    ]

    # Must be computed from the plain list, before it becomes a tensor.
    pa_embedding_mapping = convert_to_embedding_indices(pa_indices)
    pa_indices = torch.tensor(pa_indices)
    return pa_indices, pa_embedding_mapping
class PromptAdapterModel(AdapterModel):
    """One prompt adapter: its id, virtual-token count and embedding."""

    def __init__(self,
                 prompt_adapter_id=None,
                 num_virtual_tokens=None,
                 prompt_embedding=None) -> None:
        self.id = prompt_adapter_id
        self.num_virtual_tokens = num_virtual_tokens
        self.prompt_embedding = prompt_embedding

    @classmethod
    def from_local_checkpoint(
        cls,
        adapter_model_path: str,
        prompt_adapter_id: int,
        num_virtual_tokens: int,
        config: PromptAdapterConfig,
        device: str = "cuda",
    ) -> "PromptAdapterModel":
        """Build a model from weights loaded with ``load_peft_weights``.

        Args:
            adapter_model_path: Location passed to ``load_peft_weights``.
            prompt_adapter_id: Id stored on the resulting model.
            num_virtual_tokens: Virtual-token count; must not exceed
                ``config.max_prompt_adapter_token``.
            config: Supplies the token limit and the dtype the embedding
                is converted to.
            device: Device argument passed to ``load_peft_weights``.

        Raises:
            ValueError: If ``num_virtual_tokens`` exceeds the limit.
        """
        within_limit = num_virtual_tokens <= config.max_prompt_adapter_token
        if not within_limit:
            raise ValueError(
                f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
                f'max_prompt_adapter_token({config.max_prompt_adapter_token})')

        weights = load_peft_weights(adapter_model_path, device)
        embedding = weights["prompt_embeddings"]
        embedding = embedding.to(config.prompt_adapter_dtype)

        return cls(prompt_adapter_id, num_virtual_tokens, embedding)
class PromptAdapterModelManager(AdapterModelManager):
    """A manager that manages multiple Prompt Adapter models."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        """Create a PromptAdapterModel and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            prompt_adapter_config: the PromptAdapter config,
        """
        self.model: nn.Module = model
        # The config must be stored before anything touches the
        # prompt_adapter_slots / capacity properties, which read it.
        self.prompt_adapter_config = prompt_adapter_config
        # Slot table: position i holds the id of the adapter loaded in
        # slot i, or None while the slot is free.
        self.prompt_adapter_index_to_id: List[
            Optional[int]] = [None] * self.prompt_adapter_slots
        self.max_num_seqs = max_num_seqs
        # Token budget rounded up to the next multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.model.prompt_adapter_manager = self
        self.adapter_type = 'PromptAdapter'

        # Mapping installed on freshly wrapped modules: a single -1 index
        # and an empty embedding-index tensor.
        self.base_indices = torch.tensor([-1])
        self.base_embedding_indices = torch.tensor([])

        self.modules: Dict[str, nn.Module] = {}
        self._create_prompt_adapter_modules()
        self._last_mapping: Optional[PromptAdapterMapping] = None
        # NOTE(review): _registered_adapters and _active_adapters are used
        # by the methods below but not assigned here; presumably they come
        # from AdapterModelManager or a subclass - confirm.

    @property
    def prompt_adapter_slots(self) -> int:
        """Number of adapter slots (``max_prompt_adapters``)."""
        return self.prompt_adapter_config.max_prompt_adapters

    @property
    def adapter_slots(self) -> int:
        """Alias of ``prompt_adapter_slots``."""
        return self.prompt_adapter_slots

    @property
    def capacity(self) -> int:
        """Registry capacity (``max_cpu_prompt_adapters``)."""
        return self.prompt_adapter_config.max_cpu_prompt_adapters

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        """Move PromptAdapter into a GPU buffer
        to be used in the forward pass.

        Returns:
            False if the adapter was already active, True once it has been
            placed in a free slot.

        Raises:
            ValueError: If no slot is free.
        """
        if prompt_adapter_id in self._active_adapters:
            return False
        index = next(
            (i for i, slot_id in enumerate(self.prompt_adapter_index_to_id)
             if slot_id is None), None)
        if index is None:
            raise ValueError("No free prompt_adapter slots")
        # Look the model up before marking the id active, so a failed
        # lookup does not leave an active entry without a loaded slot.
        prompt_adapter_model = self._registered_adapters[prompt_adapter_id]
        self._active_adapters[prompt_adapter_id] = None
        logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
                     prompt_adapter_model.id, index)
        self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
        for module in self.modules.values():
            module.set_prompt_adapter(index,
                                      prompt_adapter_model.prompt_embedding)
        return True

    def _deactivate_adapter(self, prompt_adapter_id: int):
        """Free the slot holding ``prompt_adapter_id``; no-op if absent."""
        try:
            index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
        except ValueError:
            # The id occupies no slot: nothing to release.
            return
        self.prompt_adapter_index_to_id[index] = None
        for module in self.modules.values():
            module.reset_prompt_adapter(index)

    def _add_adapter(self, prompt_adapter: PromptAdapterModel):
        """Store ``prompt_adapter`` in the registry under its id."""
        self._registered_adapters[prompt_adapter.id] = prompt_adapter

    def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        """Convert ``mapping`` to index tensors and push them to every
        registered module."""
        base_indices, base_embedding_indices = convert_mapping(
            mapping, self.prompt_adapter_index_to_id)
        for module in self.modules.values():
            module.set_mapping(base_indices, base_embedding_indices)

    def _create_prompt_adapter_modules(self):
        """Wrap the first module whose class name contains "VocabParallel".

        The wrapper replaces the module inside the model, is registered
        under the original class name, and receives the base mapping.
        Only the first match is wrapped.
        """
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if "VocabParallel" in module.__class__.__name__:
                new_module = VocabParallelEmbeddingWithPromptAdapter(module)
                new_module.create_prompt_adapter_weights(
                    self.prompt_adapter_config)
                replaced_module = self.replace_submodule(
                    self.model, module_name, new_module)
                self.register_module(module.__class__.__name__,
                                     replaced_module)
                replaced_module.set_mapping(self.base_indices,
                                            self.base_embedding_indices)
                break

    def replace_submodule(self, model: nn.Module, module_name: str,
                          new_module: nn.Module) -> nn.Module:
        """Replace a submodule in a model with a new module."""
        # "a.b.c" -> parent path "a.b", attribute "c"; a name without dots
        # gives an empty parent path.
        parent_path, _, target_name = module_name.rpartition(".")
        parent = model.get_submodule(parent_path)
        setattr(parent, target_name, new_module)
        return new_module

    def register_module(self, module_name: str, module: nn.Module):
        """Record ``module`` under ``module_name`` in ``self.modules``."""
        self.modules[module_name] = module

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in PromptAdapterModelManager. "
            "Use LRUCachePromptAdapterModelManager for pinning"
        )  # type: ignore

    def remove_all_adapters(self):
        """Remove all PromptAdapterModel from the manager."""
        self._registered_adapters.clear()
        self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
        self._active_adapters.clear()

    # The methods below delegate to the module-level helpers of the same
    # name imported from vllm.adapter_commons.utils.

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: PromptAdapterModel) -> bool:
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)
class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):
    """``AdapterLRUCache`` specialised for ``PromptAdapterModel`` values."""

    def __init__(self, capacity: int,
                 deactivate_prompt_adapter_fn: Callable[[int], bool]):
        # Both arguments are handed to the parent cache unchanged.
        super().__init__(capacity, deactivate_prompt_adapter_fn)
class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
    """A model manager that manages multiple prompt_adapters with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        # Stored before delegating: the parent initializer reads
        # prompt_adapter_slots, which depends on this attribute.
        self.prompt_adapter_config = prompt_adapter_config
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         prompt_adapter_config)
        # Registry and active set are both LRU caches; each is given the
        # matching deactivation method.
        self._registered_adapters = PromptAdapterLRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters = PromptAdapterLRUCache(
            self.prompt_adapter_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, PromptAdapterModel]:
        """List all registered PromptAdapterModel."""
        registered = self._registered_adapters.cache
        return dict(registered)

    def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
        """Add a PromptAdapterModel to the manager.

        Returns True when the adapter was newly registered, False when it
        was already present (in which case it is only touched).
        """
        if prompt_adapter.id not in self._registered_adapters:
            self._add_adapter(prompt_adapter)
            return True

        # We always touch to update the LRU cache order
        self._registered_adapters.touch(prompt_adapter.id)
        return False

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        """Activate an adapter, first dropping the oldest active one when
        the adapter is not yet active and all slots are taken."""
        must_evict = (prompt_adapter_id not in self._active_adapters
                      and len(self._active_adapters)
                      >= self.prompt_adapter_slots)
        if must_evict:
            self._active_adapters.remove_oldest()

        activated = super().activate_adapter(prompt_adapter_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(prompt_adapter_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Drop the oldest registered adapter; False if none is registered."""
        if len(self._registered_adapters) <= 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        # Registry first, then the active set.
        self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
        self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
        return True

    def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
        """Pin the adapter in the registry cache.

        Raises:
            ValueError: Re-raised with a clearer message when the registry
                cache rejects the pin with a ValueError.
        """
        try:
            self._registered_adapters.pin(prompt_adapter_id)
        except ValueError as err:
            raise ValueError(
                "Pinning failed. "
                f"Prompt Adapter {prompt_adapter_id} is not registered."
            ) from err

    def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
        """Pin the adapter in the active cache, activating it if needed."""
        already_active = prompt_adapter_id in self._active_adapters
        if not already_active:
            # move adapter to gpu if not already active
            self.activate_adapter(prompt_adapter_id)
        self._active_adapters.pin(prompt_adapter_id)
def create_prompt_adapter_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_manager_cls: Type[
            PromptAdapterModelManager] = PromptAdapterModelManager,
        **kwargs) -> PromptAdapterModelManager:
    """Instantiate a prompt adapter manager for ``model``.

    All named arguments are passed by keyword to
    ``prompt_adapter_manager_cls``, together with any extra ``kwargs``,
    and the resulting manager instance is returned.
    """
    manager_args = dict(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        prompt_adapter_config=prompt_adapter_config,
    )
    return prompt_adapter_manager_cls(**manager_args, **kwargs)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment