Unverified commit 82ec66f5, authored by Michael Goin, committed by GitHub
Browse files

[V0 Deprecation] Remove Prompt Adapters (#20588)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 78c13e30
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -30,7 +29,6 @@ class RequestLogger: ...@@ -30,7 +29,6 @@ class RequestLogger:
params: Optional[Union[SamplingParams, PoolingParams, params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]], BeamSearchParams]],
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None: ) -> None:
max_log_len = self.max_log_len max_log_len = self.max_log_len
if max_log_len is not None: if max_log_len is not None:
...@@ -44,7 +42,6 @@ class RequestLogger: ...@@ -44,7 +42,6 @@ class RequestLogger:
"Received request %s: prompt: %r, " "Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, " "params: %s, prompt_token_ids: %s, "
"prompt_embeds shape: %s, " "prompt_embeds shape: %s, "
"lora_request: %s, prompt_adapter_request: %s.", request_id, "lora_request: %s.", request_id, prompt, params, prompt_token_ids,
prompt, params, prompt_token_ids,
prompt_embeds.shape if prompt_embeds is not None else None, prompt_embeds.shape if prompt_embeds is not None else None,
lora_request, prompt_adapter_request) lora_request)
...@@ -1620,7 +1620,6 @@ async def init_app_state( ...@@ -1620,7 +1620,6 @@ async def init_app_state(
model_config=model_config, model_config=model_config,
base_model_paths=base_model_paths, base_model_paths=base_model_paths,
lora_modules=lora_modules, lora_modules=lora_modules,
prompt_adapters=args.prompt_adapters,
) )
await state.openai_serving_models.init_static_loras() await state.openai_serving_models.init_static_loras()
state.openai_serving_responses = OpenAIServingResponses( state.openai_serving_responses = OpenAIServingResponses(
......
...@@ -20,8 +20,7 @@ from vllm.config import config ...@@ -20,8 +20,7 @@ from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template) validate_chat_template)
from vllm.entrypoints.openai.serving_models import (LoRAModulePath, from vllm.entrypoints.openai.serving_models import LoRAModulePath
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -65,27 +64,6 @@ class LoRAParserAction(argparse.Action): ...@@ -65,27 +64,6 @@ class LoRAParserAction(argparse.Action):
setattr(namespace, self.dest, lora_list) setattr(namespace, self.dest, lora_list)
class PromptAdapterParserAction(argparse.Action):
    """Argparse action that parses `name=path` items into
    `PromptAdapterPath` objects and stores the list on the namespace."""

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        """Parse every `name=path` item and store the result under
        `self.dest`.

        Args:
            parser: The parser invoking this action (unused).
            namespace: Namespace that receives the parsed adapter list.
            values: Items supplied for the option; `None` is treated as an
                empty list.
            option_string: Option string that triggered the action (unused).

        Raises:
            TypeError: If `values` is a single string instead of a list.
            ValueError: If an item does not contain a `=` separator.
        """
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        adapter_list: list[PromptAdapterPath] = []
        for item in values:
            # Split on the first '=' only, so a path that itself contains
            # '=' stays intact instead of breaking the two-way unpacking.
            name, separator, path = item.partition('=')
            if not separator:
                raise ValueError(
                    f"Invalid prompt adapter {item!r}: expected the format "
                    "name=path")
            adapter_list.append(PromptAdapterPath(name, path))
        setattr(namespace, self.dest, adapter_list)
@config @config
@dataclass @dataclass
class FrontendArgs: class FrontendArgs:
...@@ -115,9 +93,6 @@ class FrontendArgs: ...@@ -115,9 +93,6 @@ class FrontendArgs:
or JSON list format. Example (old format): `'name=path'` Example (new or JSON list format. Example (old format): `'name=path'` Example (new
format): `{\"name\": \"name\", \"path\": \"lora_path\", format): `{\"name\": \"name\", \"path\": \"lora_path\",
\"base_model_name\": \"id\"}`""" \"base_model_name\": \"id\"}`"""
prompt_adapters: Optional[list[PromptAdapterPath]] = None
"""Prompt adapter configurations in the format name=path. Multiple adapters
can be specified."""
chat_template: Optional[str] = None chat_template: Optional[str] = None
"""The file path to the chat template, or the template in single-line form """The file path to the chat template, or the template in single-line form
for the specified model.""" for the specified model."""
...@@ -207,12 +182,6 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" ...@@ -207,12 +182,6 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
frontend_kwargs["lora_modules"]["type"] = optional_type(str) frontend_kwargs["lora_modules"]["type"] = optional_type(str)
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
# Special case: Prompt adapters need custom parser action and
# optional_type(str)
frontend_kwargs["prompt_adapters"]["type"] = optional_type(str)
frontend_kwargs["prompt_adapters"][
"action"] = PromptAdapterParserAction
# Special case: Middleware needs append action # Special case: Middleware needs append action
frontend_kwargs["middleware"]["action"] = "append" frontend_kwargs["middleware"]["action"] = "append"
frontend_kwargs["middleware"]["type"] = str frontend_kwargs["middleware"]["type"] = str
...@@ -288,9 +257,6 @@ def validate_parsed_serve_args(args: argparse.Namespace): ...@@ -288,9 +257,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
if args.enable_auto_tool_choice and not args.tool_call_parser: if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires " raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser") "--tool-call-parser")
if args.enable_prompt_embeds and args.enable_prompt_adapter:
raise ValueError(
"Cannot use prompt embeds and prompt adapter at the same time.")
def log_non_default_args(args: argparse.Namespace): def log_non_default_args(args: argparse.Namespace):
......
...@@ -337,7 +337,6 @@ async def main(args): ...@@ -337,7 +337,6 @@ async def main(args):
model_config=model_config, model_config=model_config,
base_model_paths=base_model_paths, base_model_paths=base_model_paths,
lora_modules=None, lora_modules=None,
prompt_adapters=None,
) )
openai_serving_chat = OpenAIServingChat( openai_serving_chat = OpenAIServingChat(
engine, engine,
......
...@@ -147,11 +147,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -147,11 +147,8 @@ class OpenAIServingChat(OpenAIServing):
raise self.engine_client.dead_error raise self.engine_client.dead_error
try: try:
( lora_request = self._maybe_get_adapters(
lora_request, request, supports_default_mm_loras=True)
prompt_adapter_request,
) = self._maybe_get_adapters(request,
supports_default_mm_loras=True)
model_name = self._get_model_name(request.model, lora_request) model_name = self._get_model_name(request.model, lora_request)
...@@ -239,8 +236,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -239,8 +236,7 @@ class OpenAIServingChat(OpenAIServing):
self._log_inputs(request_id, self._log_inputs(request_id,
request_prompts[i], request_prompts[i],
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers)) self._get_trace_headers(raw_request.headers))
...@@ -259,7 +255,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -259,7 +255,6 @@ class OpenAIServingChat(OpenAIServing):
request_id, request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority, priority=request.priority,
) )
......
...@@ -49,19 +49,11 @@ class ClassificationMixin(OpenAIServing): ...@@ -49,19 +49,11 @@ class ClassificationMixin(OpenAIServing):
return None return None
try: try:
( ctx.lora_request = self._maybe_get_adapters(ctx.request)
ctx.lora_request,
ctx.prompt_adapter_request,
) = self._maybe_get_adapters(ctx.request)
ctx.tokenizer = await self.engine_client.get_tokenizer( ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request) ctx.lora_request)
if ctx.prompt_adapter_request is not None:
raise NotImplementedError(
"Prompt adapter is not supported for classification models"
)
( (
ctx.request_prompts, ctx.request_prompts,
ctx.engine_prompts, ctx.engine_prompts,
......
...@@ -121,10 +121,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -121,10 +121,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
try: try:
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
...@@ -197,7 +194,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -197,7 +194,6 @@ class OpenAIServingCompletion(OpenAIServing):
request_prompts[i], request_prompts[i],
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
) )
trace_headers = (None if raw_request is None else await trace_headers = (None if raw_request is None else await
...@@ -221,7 +217,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -221,7 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params, sampling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
) )
......
...@@ -53,18 +53,11 @@ class EmbeddingMixin(OpenAIServing): ...@@ -53,18 +53,11 @@ class EmbeddingMixin(OpenAIServing):
) -> Optional[ErrorResponse]: ) -> Optional[ErrorResponse]:
ctx = cast(EmbeddingServeContext, ctx) ctx = cast(EmbeddingServeContext, ctx)
try: try:
( ctx.lora_request = self._maybe_get_adapters(ctx.request)
ctx.lora_request,
ctx.prompt_adapter_request,
) = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
) )
if ctx.prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for embedding models")
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
( (
_, _,
......
...@@ -68,7 +68,6 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error ...@@ -68,7 +68,6 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
MultiModalDataDict) MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob, PromptLogprobs from vllm.sequence import Logprob, PromptLogprobs
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
...@@ -161,7 +160,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, ...@@ -161,7 +160,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
request_id: str request_id: str
created_time: int = Field(default_factory=lambda: int(time.time())) created_time: int = Field(default_factory=lambda: int(time.time()))
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
# Shared across most requests # Shared across most requests
tokenizer: Optional[AnyTokenizer] = None tokenizer: Optional[AnyTokenizer] = None
...@@ -343,12 +341,10 @@ class OpenAIServing: ...@@ -343,12 +341,10 @@ class OpenAIServing:
return self.create_error_response( return self.create_error_response(
"Request prompts not available") "Request prompts not available")
self._log_inputs( self._log_inputs(request_id_item,
request_id_item, ctx.request_prompts[i],
ctx.request_prompts[i], params=pooling_params,
params=pooling_params, lora_request=ctx.lora_request)
lora_request=ctx.lora_request,
prompt_adapter_request=ctx.prompt_adapter_request)
# Mypy has an existing bug related to inferring the variance of # Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`: # TypedDicts with `builtins.enumerate`:
...@@ -450,11 +446,6 @@ class OpenAIServing: ...@@ -450,11 +446,6 @@ class OpenAIServing:
if isinstance(load_result, ErrorResponse) and \ if isinstance(load_result, ErrorResponse) and \
load_result.code == HTTPStatus.BAD_REQUEST.value: load_result.code == HTTPStatus.BAD_REQUEST.value:
error_response = load_result error_response = load_result
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.models.prompt_adapter_requests
]:
return None
return error_response or self.create_error_response( return error_response or self.create_error_response(
message=f"The model `{request.model}` does not exist.", message=f"The model `{request.model}` does not exist.",
...@@ -489,25 +480,21 @@ class OpenAIServing: ...@@ -489,25 +480,21 @@ class OpenAIServing:
self, self,
request: AnyRequest, request: AnyRequest,
supports_default_mm_loras: bool = False, supports_default_mm_loras: bool = False,
) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[ ) -> Optional[LoRARequest]:
None, PromptAdapterRequest]]:
if request.model in self.models.lora_requests: if request.model in self.models.lora_requests:
return self.models.lora_requests[request.model], None return self.models.lora_requests[request.model]
# Currently only support default modality specific loras # Currently only support default modality specific loras
# if we have exactly one lora matched on the request. # if we have exactly one lora matched on the request.
if supports_default_mm_loras: if supports_default_mm_loras:
default_mm_lora = self._get_active_default_mm_loras(request) default_mm_lora = self._get_active_default_mm_loras(request)
if default_mm_lora is not None: if default_mm_lora is not None:
return default_mm_lora, None return default_mm_lora
if self._is_model_supported(request.model): if self._is_model_supported(request.model):
return None, None return None
for prompt_adapter in self.models.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable # if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.") raise ValueError(f"The model `{request.model}` does not exist.")
...@@ -987,7 +974,6 @@ class OpenAIServing: ...@@ -987,7 +974,6 @@ class OpenAIServing:
params: Optional[Union[SamplingParams, PoolingParams, params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]], BeamSearchParams]],
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None: ) -> None:
if self.request_logger is None: if self.request_logger is None:
return return
...@@ -1009,7 +995,6 @@ class OpenAIServing: ...@@ -1009,7 +995,6 @@ class OpenAIServing:
prompt_embeds, prompt_embeds,
params=params, params=params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
) )
async def _get_trace_headers( async def _get_trace_headers(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pathlib
from asyncio import Lock from asyncio import Lock
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
...@@ -19,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ...@@ -19,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter from vllm.utils import AtomicCounter
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,12 +28,6 @@ class BaseModelPath: ...@@ -31,12 +28,6 @@ class BaseModelPath:
model_path: str model_path: str
@dataclass
class PromptAdapterPath:
    """A named prompt adapter together with the local path it is read from."""
    # Identifier of the adapter (the `name` half of a `name=path` argument).
    name: str
    # Local directory of the adapter; `adapter_config.json` is opened from it.
    local_path: str
@dataclass @dataclass
class LoRAModulePath: class LoRAModulePath:
name: str name: str
...@@ -60,7 +51,6 @@ class OpenAIServingModels: ...@@ -60,7 +51,6 @@ class OpenAIServingModels:
base_model_paths: list[BaseModelPath], base_model_paths: list[BaseModelPath],
*, *,
lora_modules: Optional[list[LoRAModulePath]] = None, lora_modules: Optional[list[LoRAModulePath]] = None,
prompt_adapters: Optional[list[PromptAdapterPath]] = None,
): ):
super().__init__() super().__init__()
...@@ -81,20 +71,6 @@ class OpenAIServingModels: ...@@ -81,20 +71,6 @@ class OpenAIServingModels:
LoRAResolverRegistry.get_resolver(lora_resolver_name)) LoRAResolverRegistry.get_resolver(lora_resolver_name))
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with pathlib.Path(prompt_adapter.local_path,
"adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
PromptAdapterRequest(
prompt_adapter_name=prompt_adapter.name,
prompt_adapter_id=i,
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
async def init_static_loras(self): async def init_static_loras(self):
"""Loads all static LoRA modules. """Loads all static LoRA modules.
Raises if any fail to load""" Raises if any fail to load"""
...@@ -141,14 +117,7 @@ class OpenAIServingModels: ...@@ -141,14 +117,7 @@ class OpenAIServingModels:
permission=[ModelPermission()]) permission=[ModelPermission()])
for lora in self.lora_requests.values() for lora in self.lora_requests.values()
] ]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.base_model_paths[0].name,
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
model_cards.extend(lora_cards) model_cards.extend(lora_cards)
model_cards.extend(prompt_adapter_cards)
return ModelList(data=model_cards) return ModelList(data=model_cards)
async def load_lora_adapter( async def load_lora_adapter(
......
...@@ -94,17 +94,10 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -94,17 +94,10 @@ class OpenAIServingPooling(OpenAIServing):
try: try:
truncate_prompt_tokens = _validate_truncation_size( truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens) self.max_model_len, truncate_prompt_tokens)
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for pooling models")
if isinstance(request, PoolingChatRequest): if isinstance(request, PoolingChatRequest):
( (
_, _,
...@@ -153,8 +146,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -153,8 +146,7 @@ class OpenAIServingPooling(OpenAIServing):
self._log_inputs(request_id_item, self._log_inputs(request_id_item,
request_prompts[i], request_prompts[i],
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers)) self._get_trace_headers(raw_request.headers))
......
...@@ -133,10 +133,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -133,10 +133,7 @@ class OpenAIServingResponses(OpenAIServing):
messages = self._construct_input_messages(request, prev_response) messages = self._construct_input_messages(request, prev_response)
try: try:
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_name = self._get_model_name(request.model, lora_request) model_name = self._get_model_name(request.model, lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
...@@ -169,8 +166,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -169,8 +166,7 @@ class OpenAIServingResponses(OpenAIServing):
self._log_inputs(request.request_id, self._log_inputs(request.request_id,
request_prompts[i], request_prompts[i],
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers)) self._get_trace_headers(raw_request.headers))
...@@ -181,7 +177,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -181,7 +177,6 @@ class OpenAIServingResponses(OpenAIServing):
request.request_id, request.request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority, priority=request.priority,
) )
generators.append(generator) generators.append(generator)
......
...@@ -27,7 +27,6 @@ from vllm.inputs.data import TokensPrompt ...@@ -27,7 +27,6 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import make_async, merge_async_iterators from vllm.utils import make_async, merge_async_iterators
...@@ -58,8 +57,6 @@ class ServingScores(OpenAIServing): ...@@ -58,8 +57,6 @@ class ServingScores(OpenAIServing):
request_id: str, request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None, lora_request: Optional[Union[LoRARequest, None]] = None,
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
None]] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]: ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
input_texts = texts_1 + texts_2 input_texts = texts_1 + texts_2
...@@ -100,8 +97,7 @@ class ServingScores(OpenAIServing): ...@@ -100,8 +97,7 @@ class ServingScores(OpenAIServing):
self._log_inputs(request_id_item, self._log_inputs(request_id_item,
input_texts[i], input_texts[i],
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
generators.append( generators.append(
self.engine_client.encode( self.engine_client.encode(
...@@ -176,8 +172,6 @@ class ServingScores(OpenAIServing): ...@@ -176,8 +172,6 @@ class ServingScores(OpenAIServing):
request_id: str, request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None, lora_request: Optional[Union[LoRARequest, None]] = None,
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
None]] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]: ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
request_prompts: list[str] = [] request_prompts: list[str] = []
...@@ -261,8 +255,7 @@ class ServingScores(OpenAIServing): ...@@ -261,8 +255,7 @@ class ServingScores(OpenAIServing):
self._log_inputs(request_id_item, self._log_inputs(request_id_item,
request_prompts[i], request_prompts[i],
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, engine_prompt,
...@@ -295,14 +288,7 @@ class ServingScores(OpenAIServing): ...@@ -295,14 +288,7 @@ class ServingScores(OpenAIServing):
raw_request: Optional[Request] = None, raw_request: Optional[Request] = None,
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]: ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
...@@ -340,7 +326,6 @@ class ServingScores(OpenAIServing): ...@@ -340,7 +326,6 @@ class ServingScores(OpenAIServing):
request_id=request_id, request_id=request_id,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers) trace_headers=trace_headers)
else: else:
...@@ -352,7 +337,6 @@ class ServingScores(OpenAIServing): ...@@ -352,7 +337,6 @@ class ServingScores(OpenAIServing):
request_id=request_id, request_id=request_id,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers) trace_headers=trace_headers)
async def create_score( async def create_score(
......
...@@ -60,10 +60,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -60,10 +60,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokn-{self._base_request_id(raw_request)}" request_id = f"tokn-{self._base_request_id(raw_request)}"
try: try:
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
...@@ -104,11 +101,8 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -104,11 +101,8 @@ class OpenAIServingTokenization(OpenAIServing):
self._log_inputs(request_id, self._log_inputs(request_id,
request_prompts[i], request_prompts[i],
params=None, params=None,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
# Silently ignore prompt adapter since it does not affect
# tokenization (Unlike in Embeddings API where an error is raised)
if isinstance(engine_prompt, if isinstance(engine_prompt,
dict) and "prompt_token_ids" in engine_prompt: dict) and "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"]) input_ids.extend(engine_prompt["prompt_token_ids"])
...@@ -133,21 +127,14 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -133,21 +127,14 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokn-{self._base_request_id(raw_request)}" request_id = f"tokn-{self._base_request_id(raw_request)}"
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
self._log_inputs(request_id, self._log_inputs(request_id,
request.tokens, request.tokens,
params=None, params=None,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
# Silently ignore prompt adapter since it does not affect tokenization
# (Unlike in Embeddings API where an error is raised)
prompt_input = await self._tokenize_prompt_input_async( prompt_input = await self._tokenize_prompt_input_async(
request, request,
......
...@@ -150,19 +150,12 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -150,19 +150,12 @@ class OpenAISpeechToText(OpenAIServing):
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
try: try:
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
if lora_request: if lora_request:
return self.create_error_response( return self.create_error_response(
"Currently do not support LoRA for " "Currently do not support LoRA for "
f"{self.task_type.title()}.") f"{self.task_type.title()}.")
if prompt_adapter_request:
return self.create_error_response(
f"Currently do not support PromptAdapter for "
f"{self.task_type.title()}.")
prompts, duration_s = await self._preprocess_speech_to_text( prompts, duration_s = await self._preprocess_speech_to_text(
request=request, request=request,
...@@ -188,8 +181,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -188,8 +181,7 @@ class OpenAISpeechToText(OpenAIServing):
# It will not display special tokens like <|startoftranscript|> # It will not display special tokens like <|startoftranscript|>
request.prompt, request.prompt,
params=sampling_params, params=sampling_params,
lora_request=None, lora_request=None)
prompt_adapter_request=None)
list_result_generator = [ list_result_generator = [
self.engine_client.generate( self.engine_client.generate(
......
...@@ -17,7 +17,6 @@ from vllm.logger import init_logger ...@@ -17,7 +17,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.pooling_params import PoolingTask from vllm.pooling_params import PoolingTask
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
...@@ -50,7 +49,6 @@ class ExecutorBase(ABC): ...@@ -50,7 +49,6 @@ class ExecutorBase(ABC):
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self._init_executor() self._init_executor()
self.is_sleeping = False self.is_sleeping = False
...@@ -171,35 +169,6 @@ class ExecutorBase(ABC): ...@@ -171,35 +169,6 @@ class ExecutorBase(ABC):
assert s == sets[0], "All workers should have the same LORAs." assert s == sets[0], "All workers should have the same LORAs."
return sets[0] return sets[0]
def add_prompt_adapter(
        self, prompt_adapter_request: PromptAdapterRequest) -> bool:
    """Forward an "add_prompt_adapter" call through ``collective_rpc``.

    Args:
        prompt_adapter_request: Request whose ``prompt_adapter_id`` must be
            strictly positive.

    Returns:
        True only if every result returned by ``collective_rpc`` is truthy
        (presumably one result per worker -- confirm against
        ``collective_rpc``).

    Raises:
        AssertionError: If the adapter id is not greater than 0.
    """
    adapter_id = prompt_adapter_request.prompt_adapter_id
    assert adapter_id > 0, "prompt_adapter_id must be greater than 0."
    rpc_results = self.collective_rpc("add_prompt_adapter",
                                      args=(prompt_adapter_request, ))
    return all(rpc_results)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Forward a "remove_prompt_adapter" call through ``collective_rpc``.

    Args:
        prompt_adapter_id: Strictly positive adapter id to remove.

    Returns:
        True only if every result returned by ``collective_rpc`` is truthy.

    Raises:
        AssertionError: If ``prompt_adapter_id`` is not greater than 0.
    """
    assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
    rpc_results = self.collective_rpc("remove_prompt_adapter",
                                      args=(prompt_adapter_id, ))
    return all(rpc_results)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Forward a "pin_prompt_adapter" call through ``collective_rpc``.

    Args:
        prompt_adapter_id: Strictly positive adapter id to pin.

    Returns:
        True only if every result returned by ``collective_rpc`` is truthy.

    Raises:
        AssertionError: If ``prompt_adapter_id`` is not greater than 0.
    """
    assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
    rpc_results = self.collective_rpc("pin_prompt_adapter",
                                      args=(prompt_adapter_id, ))
    return all(rpc_results)
def list_prompt_adapters(self) -> Set[int]:
    """Return the prompt-adapter set reported through ``collective_rpc``.

    Every entry of the RPC result must compare equal to the first one; the
    first entry is what gets returned.

    Raises:
        AssertionError: If any entry differs from the first.
        IndexError: If the RPC result is empty.
    """
    reported = self.collective_rpc("list_prompt_adapters")
    reference = reported[0]
    assert all(entry == reference for entry in reported), \
        "All workers should have the same prompt adapters."
    return reference
def start_profile(self) -> None: def start_profile(self) -> None:
self.collective_rpc("start_profile") self.collective_rpc("start_profile")
......
...@@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest ...@@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs) MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
...@@ -168,18 +167,6 @@ class InputPreprocessor: ...@@ -168,18 +167,6 @@ class InputPreprocessor:
return decoder_input_ids return decoder_input_ids
def _apply_prompt_adapter(
    self,
    prompt_token_ids: list[int],
    prompt_adapter_request: Optional[PromptAdapterRequest],
) -> list[int]:
    """Prepend zero placeholders for a prompt adapter's virtual tokens.

    Args:
        prompt_token_ids: Token ids of the prompt.
        prompt_adapter_request: Adapter request, or a falsy value when no
            adapter applies.

    Returns:
        ``prompt_token_ids`` unchanged when there is no adapter request;
        otherwise a new list made of
        ``prompt_adapter_num_virtual_tokens`` zeros followed by the
        original token ids.
    """
    if not prompt_adapter_request:
        return prompt_token_ids

    num_virtual = prompt_adapter_request.prompt_adapter_num_virtual_tokens
    return [0] * num_virtual + prompt_token_ids
def _get_tokenization_kw( def _get_tokenization_kw(
self, self,
overrides: Optional[dict[str, Any]] = None, overrides: Optional[dict[str, Any]] = None,
...@@ -786,15 +773,10 @@ class InputPreprocessor: ...@@ -786,15 +773,10 @@ class InputPreprocessor:
def _build_decoder_only_llm_inputs( def _build_decoder_only_llm_inputs(
self, self,
prompt_inputs: DecoderOnlyInputs, prompt_inputs: DecoderOnlyInputs,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
if "prompt_token_ids" in prompt_inputs: if "prompt_token_ids" in prompt_inputs:
prompt_inputs = cast(Union[TokenInputs, MultiModalInputs], prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
prompt_inputs) # Needed for mypy prompt_inputs) # Needed for mypy
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request,
)
return prompt_inputs return prompt_inputs
...@@ -803,7 +785,6 @@ class InputPreprocessor: ...@@ -803,7 +785,6 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
...@@ -815,7 +796,6 @@ class InputPreprocessor: ...@@ -815,7 +796,6 @@ class InputPreprocessor:
* prompt: input prompt * prompt: input prompt
* lora_request * lora_request
* prompt_adapter_request
* return_mm_hashes * return_mm_hashes
Returns: Returns:
...@@ -830,17 +810,13 @@ class InputPreprocessor: ...@@ -830,17 +810,13 @@ class InputPreprocessor:
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
return self._build_decoder_only_llm_inputs( return self._build_decoder_only_llm_inputs(prompt_comps)
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
async def _process_decoder_only_prompt_async( async def _process_decoder_only_prompt_async(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
...@@ -854,17 +830,13 @@ class InputPreprocessor: ...@@ -854,17 +830,13 @@ class InputPreprocessor:
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
return self._build_decoder_only_llm_inputs( return self._build_decoder_only_llm_inputs(prompt_comps)
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
def preprocess( def preprocess(
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> ProcessorInputs: ) -> ProcessorInputs:
"""Preprocess the input prompt.""" """Preprocess the input prompt."""
...@@ -886,7 +858,6 @@ class InputPreprocessor: ...@@ -886,7 +858,6 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
...@@ -895,7 +866,6 @@ class InputPreprocessor: ...@@ -895,7 +866,6 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
...@@ -919,6 +889,5 @@ class InputPreprocessor: ...@@ -919,6 +889,5 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import PromptAdapterConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@dataclass
class PromptAdapterMapping(AdapterMapping):
    """Adapter mapping used for prompt adapters.

    Declares no fields or behaviour of its own; everything comes from
    ``AdapterMapping``.
    """
class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
    """Embedding wrapper that overlays prompt-adapter embeddings.

    The wrapped layer computes the regular embeddings; positions selected by
    the current mapping are then overwritten with rows taken from the
    per-slot prompt-adapter embedding buffer.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # If the layer handed in is itself a LoRA wrapper (detected by class
        # name), the embedding that owns the weight sits one level down.
        self.emb_layer = self.base_layer
        if 'LoRA' in base_layer.__class__.__name__:
            self.emb_layer = self.base_layer.base_layer

    def create_prompt_adapter_weights(
            self, prompt_adapter_config: PromptAdapterConfig):
        """Allocate the zero-initialised per-slot buffers.

        ``embeddings_tensors`` has shape
        ``(max_prompt_adapters, max_prompt_adapter_token, embedding_dim)``
        and shares dtype/device with the embedding weight;
        ``adapter_lengths`` holds one ``long`` per slot.
        """
        self.embeddings_tensors = torch.zeros(
            (
                prompt_adapter_config.max_prompt_adapters,
                prompt_adapter_config.max_prompt_adapter_token,
                self.emb_layer.embedding_dim,
            ),
            dtype=self.emb_layer.weight.dtype,
            device=self.emb_layer.weight.device,
        )
        self.adapter_lengths = torch.zeros(
            prompt_adapter_config.max_prompt_adapters,
            dtype=torch.long,
            device=self.emb_layer.weight.device)

        # Declared here, assigned by set_mapping(); forward() requires
        # set_mapping() to have been called first.
        self.indices_gpu: torch.Tensor
        self.embedding_indices_gpu: torch.Tensor

    def reset_prompt_adapter(self, index: int):
        """Clear slot ``index``: zero its embeddings and its stored length."""
        self.embeddings_tensors[index] = 0
        # Keep the length bookkeeping in step with the cleared embeddings so
        # a reset slot never reports the previous adapter's length.
        self.adapter_lengths[index] = 0

    def set_prompt_adapter(
        self,
        index: int,
        adapter_model: Optional[torch.Tensor],
    ):
        """Load ``adapter_model`` into slot ``index``.

        The slot is reset first. When ``adapter_model`` is None the slot
        simply stays cleared; otherwise its first ``adapter_model.shape[0]``
        rows are filled and that length is recorded.
        """
        self.reset_prompt_adapter(index)
        if adapter_model is not None:
            length = adapter_model.shape[0]
            self.embeddings_tensors[index, :length] = adapter_model
            self.adapter_lengths[index] = length

    def set_mapping(
        self,
        prompt_indices: torch.Tensor,
        prompt_embedding_indices: torch.Tensor,
    ):
        """Store the index tensors used by ``forward`` on the weight's device.

        Args:
            prompt_indices: Per-position slot index, ``-1`` where no adapter
                applies.
            prompt_embedding_indices: Either an empty 1-D tensor (no adapter
                positions) or a 2-D tensor whose columns are
                ``(slot, row within slot)``.
        """
        self.indices_gpu = prompt_indices.to(
            device=self.emb_layer.weight.device)
        self.embedding_indices_gpu = prompt_embedding_indices.to(
            device=self.emb_layer.weight.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and overwrite adapter positions in place.

        The overwrite only happens when the embedding index tensor is 2-D,
        i.e. when at least one position maps to an adapter slot.
        """
        hidden_states = self.base_layer(x)
        if self.embedding_indices_gpu.ndim > 1:
            valid_mask = self.indices_gpu != -1
            gathered_embeddings = self.embeddings_tensors[
                self.embedding_indices_gpu[:, 0],
                self.embedding_indices_gpu[:, 1]]

            # Update hidden states
            hidden_states[valid_mask] = gathered_embeddings
        return hidden_states
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Type
import torch
from torch import nn
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
get_adapter, list_adapters,
remove_adapter, set_adapter_mapping)
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.layers import (
VocabParallelEmbeddingWithPromptAdapter) # yapf: disable
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.utils import load_peft_weights
logger = logging.getLogger(__name__)
_GLOBAL_PROMPT_ADAPTER_ID = 0


def get_prompt_adapter_id():
    """Return the next prompt-adapter id from the module-level counter.

    The counter starts at 0 and is incremented before being returned, so
    the first id handed out is 1.
    """
    global _GLOBAL_PROMPT_ADAPTER_ID
    next_id = _GLOBAL_PROMPT_ADAPTER_ID + 1
    _GLOBAL_PROMPT_ADAPTER_ID = next_id
    return next_id
def convert_to_embedding_indices(indices):
    """Pair every non ``-1`` slot index with its offset in the current run.

    The offset counts consecutive non ``-1`` entries and restarts at 0 after
    each ``-1``. Entries equal to ``-1`` produce no output row.

    Returns:
        ``torch.tensor`` built from the list of ``[slot, offset]`` pairs
        (an empty 1-D tensor when there are no pairs).
    """
    pairs = []
    offset = 0

    for slot in indices:
        if slot != -1:
            pairs.append([slot, offset])
            offset += 1
        else:
            offset = 0

    return torch.tensor(pairs)
def convert_mapping(
    mapping: PromptAdapterMapping,
    prompt_adapter_index_to_id: List[Optional[int]],
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Converts PromptAdapterMapping to index tensors.

    Args:
        mapping: PromptAdapterMapping mapping rows in a
            batch to PromptAdapter ids.
        prompt_adapter_index_to_id: List mapping slot indices to the
            PromptAdapter id held in that slot (None for a free slot).

    Returns:
        A ``(pa_indices, pa_embedding_mapping)`` pair:

        pa_indices: 1-D tensor with one entry per element of
            ``mapping.index_mapping``: the slot index holding that id, or
            ``-1`` when the id is not positive or is not in any slot.
        pa_embedding_mapping: Result of ``convert_to_embedding_indices``
            applied to those slot indices.
    """
    # Invert the slot table: adapter id -> slot index (free slots skipped).
    id_to_index = {
        id_: idx
        for idx, id_ in enumerate(prompt_adapter_index_to_id)
        if id_ is not None
    }

    slot_indices = [
        id_to_index.get(id_, -1) if id_ > 0 else -1
        for id_ in mapping.index_mapping
    ]

    pa_embedding_mapping = convert_to_embedding_indices(slot_indices)
    pa_indices = torch.tensor(slot_indices)
    return pa_indices, pa_embedding_mapping
class PromptAdapterModel(AdapterModel):
    """A prompt adapter: its id, virtual-token count and embedding tensor."""

    def __init__(self,
                 prompt_adapter_id=None,
                 num_virtual_tokens=None,
                 prompt_embedding=None) -> None:
        self.id = prompt_adapter_id
        self.num_virtual_tokens = num_virtual_tokens
        self.prompt_embedding = prompt_embedding

    @classmethod
    def from_local_checkpoint(
        cls,
        adapter_model_path: str,
        prompt_adapter_id: int,
        num_virtual_tokens: int,
        config: PromptAdapterConfig,
        device: str = "cuda",
    ) -> "PromptAdapterModel":
        """Build a model from PEFT weights stored at ``adapter_model_path``.

        The ``"prompt_embeddings"`` entry of the loaded weights is cast to
        ``config.prompt_adapter_dtype``.

        Raises:
            ValueError: If ``num_virtual_tokens`` exceeds
                ``config.max_prompt_adapter_token``.
        """
        token_limit = config.max_prompt_adapter_token
        if num_virtual_tokens > token_limit:
            raise ValueError(
                f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
                f'max_prompt_adapter_token({token_limit})')

        loaded_weights = load_peft_weights(adapter_model_path, device)
        raw_embedding = loaded_weights["prompt_embeddings"]
        prompt_embedding = raw_embedding.to(config.prompt_adapter_dtype)
        return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding)
class PromptAdapterModelManager(AdapterModelManager):
    """A manager that manages multiple Prompt Adapter models."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        """Create a PromptAdapterModel and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            prompt_adapter_config: the PromptAdapter config,
        """
        self.model: nn.Module = model
        # Assigned first: the prompt_adapter_slots / capacity properties read
        # it, and prompt_adapter_slots is needed a few lines below.
        self.prompt_adapter_config = prompt_adapter_config
        # This __init__ does not call the parent initializer, so the two
        # registries used throughout the class are created here. Subclasses
        # may replace them after calling super().__init__().
        self._registered_adapters: Dict[int, Any] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_adapters: Dict[int, None] = {}
        # Slot index -> id of the adapter resident in that slot (None = free).
        self.prompt_adapter_index_to_id: List[
            Optional[int]] = [None] * self.prompt_adapter_slots
        self.max_num_seqs = max_num_seqs
        # Rounded up to the next multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.model.prompt_adapter_manager = self
        self.adapter_type = 'PromptAdapter'

        # Mapping installed on freshly wrapped modules: no adapter anywhere.
        self.base_indices = torch.tensor([-1])
        self.base_embedding_indices = torch.tensor([])

        self.modules: Dict[str, nn.Module] = {}
        self._create_prompt_adapter_modules()
        self._last_mapping: Optional[PromptAdapterMapping] = None

    @property
    def prompt_adapter_slots(self) -> int:
        """Number of slots, taken from ``max_prompt_adapters``."""
        return self.prompt_adapter_config.max_prompt_adapters

    @property
    def adapter_slots(self) -> int:
        """Alias of ``prompt_adapter_slots``."""
        return self.prompt_adapter_slots

    @property
    def capacity(self) -> int:
        """Registry capacity, taken from ``max_cpu_prompt_adapters``."""
        return self.prompt_adapter_config.max_cpu_prompt_adapters

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        """Move PromptAdapter into a GPU buffer
        to be used in the forward pass.

        Returns:
            False if the adapter was already active, True once it has been
            placed in the first free slot.

        Raises:
            ValueError: If every slot is occupied.
            KeyError: If the adapter id has not been registered.
        """
        if prompt_adapter_id in self._active_adapters:
            return False
        free_index = next(
            (index
             for index, slot_id in enumerate(self.prompt_adapter_index_to_id)
             if slot_id is None), None)
        if free_index is None:
            raise ValueError("No free prompt_adapter slots")
        # Resolve the model before marking the id active, so an unregistered
        # id fails without leaving a stale entry in the active set.
        prompt_adapter_model = self._registered_adapters[prompt_adapter_id]
        self._active_adapters[prompt_adapter_id] = None
        logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
                     prompt_adapter_model.id, free_index)
        self.prompt_adapter_index_to_id[free_index] = prompt_adapter_model.id
        for module in self.modules.values():
            module.set_prompt_adapter(free_index,
                                      prompt_adapter_model.prompt_embedding)
        return True

    def _deactivate_adapter(self, prompt_adapter_id: int):
        """Free the slot holding ``prompt_adapter_id``; no-op if not resident."""
        try:
            index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
        except ValueError:
            # The adapter occupies no slot, so there is nothing to release.
            return
        self.prompt_adapter_index_to_id[index] = None
        for module in self.modules.values():
            module.reset_prompt_adapter(index)

    def _add_adapter(self, prompt_adapter: PromptAdapterModel):
        """Store ``prompt_adapter`` in the registry under its id."""
        self._registered_adapters[prompt_adapter.id] = prompt_adapter

    def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        """Convert ``mapping`` to index tensors and push them to all modules."""
        base_indices, base_embedding_indices = convert_mapping(
            mapping, self.prompt_adapter_index_to_id)
        for module in self.modules.values():
            module.set_mapping(base_indices, base_embedding_indices)

    def _create_prompt_adapter_modules(self):
        """Wrap the first module whose class name contains "VocabParallel".

        The wrapper gets its buffers allocated, replaces the original
        submodule in the model, is registered under the original module's
        class name and receives the base (no-adapter) mapping. Only the first
        match is wrapped.
        """
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if "VocabParallel" in module.__class__.__name__:
                new_module = VocabParallelEmbeddingWithPromptAdapter(module)
                new_module.create_prompt_adapter_weights(
                    self.prompt_adapter_config)
                replaced_module = self.replace_submodule(
                    self.model, module_name, new_module)
                self.register_module(module.__class__.__name__,
                                     replaced_module)
                replaced_module.set_mapping(self.base_indices,
                                            self.base_embedding_indices)
                break

    def replace_submodule(self, model: nn.Module, module_name: str,
                          new_module: nn.Module) -> nn.Module:
        """Replace a submodule in a model with a new module."""
        # Split "a.b.c" into parent path "a.b" and attribute "c"; a name
        # without dots yields an empty parent path.
        parent_path, _, target_name = module_name.rpartition(".")
        parent = model.get_submodule(parent_path)
        setattr(parent, target_name, new_module)
        return new_module

    def register_module(self, module_name: str, module: nn.Module):
        """Record ``module`` under ``module_name`` in ``self.modules``."""
        self.modules[module_name] = module

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in PromptAdapterModelManager. "
            "Use LRUCachePromptAdapterModelManager for pinning"
        )  # type: ignore

    def remove_all_adapters(self):
        """Remove all PromptAdapterModel from the manager."""
        self._registered_adapters.clear()
        self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
        self._active_adapters.clear()

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Delegate to the shared ``deactivate_adapter`` helper."""
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: PromptAdapterModel) -> bool:
        """Delegate to the shared ``add_adapter`` helper."""
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        """Delegate to the shared ``set_adapter_mapping`` helper and remember
        the mapping it returns."""
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate to the shared ``remove_adapter`` helper."""
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        """Delegate to the shared ``list_adapters`` helper."""
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        """Delegate to the shared ``get_adapter`` helper."""
        return get_adapter(adapter_id, self._registered_adapters)
class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):
    """``AdapterLRUCache`` specialised to ``PromptAdapterModel`` entries."""

    def __init__(self, capacity: int,
                 deactivate_prompt_adapter_fn: Callable[[int], bool]):
        # Both arguments are handed to the generic cache untouched.
        super().__init__(capacity, deactivate_prompt_adapter_fn)
class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
    """A model manager that manages multiple prompt_adapters with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        # Set ahead of the parent initializer so the config is available
        # while that initializer runs.
        self.prompt_adapter_config = prompt_adapter_config
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         prompt_adapter_config)
        # Registry and active set are both LRU caches here; each is given
        # the matching deactivation callback.
        self._registered_adapters = PromptAdapterLRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters = PromptAdapterLRUCache(
            self.prompt_adapter_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, PromptAdapterModel]:
        """List all registered PromptAdapterModel."""
        registry_snapshot = dict(self._registered_adapters.cache)
        return registry_snapshot

    def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
        """Add a PromptAdapterModel to the manager."""
        if prompt_adapter.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(prompt_adapter.id)
            return False
        self._add_adapter(prompt_adapter)
        return True

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        """Activate an adapter, dropping the oldest active one when the
        slots are full and the requested adapter is not yet active."""
        already_active = prompt_adapter_id in self._active_adapters
        if not already_active and len(
                self._active_adapters) >= self.prompt_adapter_slots:
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(prompt_adapter_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(prompt_adapter_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Drop the oldest registered adapter; False if none is registered."""
        if len(self._registered_adapters) <= 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
        self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
        return True

    def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
        """Pin the adapter in the registry cache.

        Raises:
            ValueError: If the registry cache rejects the pin with a
                ValueError; re-raised with a "not registered" message.
        """
        try:
            self._registered_adapters.pin(prompt_adapter_id)
        except ValueError as err:
            raise ValueError(
                "Pinning failed. "
                f"Prompt Adapter {prompt_adapter_id} is not registered."
            ) from err

    def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
        """Pin the adapter in the active cache, activating it first if
        necessary."""
        needs_activation = prompt_adapter_id not in self._active_adapters
        if needs_activation:
            # move adapter to gpu if not already active
            self.activate_adapter(prompt_adapter_id)
        self._active_adapters.pin(prompt_adapter_id)
def create_prompt_adapter_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_manager_cls: Type[
            PromptAdapterModelManager] = PromptAdapterModelManager,
        **kwargs) -> PromptAdapterModelManager:
    """Instantiate a prompt-adapter manager for ``model``.

    Args:
        model: Model handed to the manager.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        prompt_adapter_config: Forwarded to the manager constructor.
        prompt_adapter_manager_cls: Manager class to instantiate.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The newly constructed manager instance.
    """
    return prompt_adapter_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        prompt_adapter_config=prompt_adapter_config,
        **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment