Unverified commit 4d6ada94, authored by Swapnil Parekh and committed by GitHub
Browse files

[CORE] Adding support for insertion of soft-tuned prompts (#4645)


Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
parent a0550cbc
...@@ -116,7 +116,7 @@ async def detokenize(request: DetokenizeRequest): ...@@ -116,7 +116,7 @@ async def detokenize(request: DetokenizeRequest):
@app.get("/v1/models") @app.get("/v1/models")
async def show_available_models(): async def show_available_models():
models = await openai_serving_chat.show_available_models() models = await openai_serving_completion.show_available_models()
return JSONResponse(content=models.model_dump()) return JSONResponse(content=models.model_dump())
...@@ -236,7 +236,8 @@ if __name__ == "__main__": ...@@ -236,7 +236,8 @@ if __name__ == "__main__":
args.lora_modules, args.lora_modules,
args.chat_template) args.chat_template)
openai_serving_completion = OpenAIServingCompletion( openai_serving_completion = OpenAIServingCompletion(
engine, model_config, served_model_names, args.lora_modules) engine, model_config, served_model_names, args.lora_modules,
args.prompt_adapters)
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names) served_model_names)
app.root_path = args.root_path app.root_path = args.root_path
......
...@@ -9,7 +9,8 @@ import json ...@@ -9,7 +9,8 @@ import json
import ssl import ssl
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRAModulePath from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -23,6 +24,16 @@ class LoRAParserAction(argparse.Action): ...@@ -23,6 +24,16 @@ class LoRAParserAction(argparse.Action):
setattr(namespace, self.dest, lora_list) setattr(namespace, self.dest, lora_list)
class PromptAdapterParserAction(argparse.Action):
    """Parse ``name=path`` command-line items into PromptAdapterPath objects.

    The resulting list is stored on the namespace under ``self.dest``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Convert every item of *values* and store the list on *namespace*.

        Args:
            parser: The argument parser invoking this action (unused).
            namespace: Namespace object that receives the parsed list.
            values: Iterable of strings, each in the form ``name=path``.
            option_string: The option string that triggered the action
                (unused).

        Raises:
            ValueError: If an item does not contain ``=``.
        """
        adapter_list = []
        for item in values:
            # Split on the first '=' only: the path part may itself contain
            # '=' and must stay in one piece.
            name, separator, path = item.partition('=')
            if not separator:
                raise ValueError(
                    f"Invalid prompt adapter {item!r}: expected name=path.")
            adapter_list.append(PromptAdapterPath(name, path))
        setattr(namespace, self.dest, adapter_list)
def make_arg_parser(): def make_arg_parser():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.") description="vLLM OpenAI-Compatible RESTful API server.")
...@@ -65,6 +76,14 @@ def make_arg_parser(): ...@@ -65,6 +76,14 @@ def make_arg_parser():
action=LoRAParserAction, action=LoRAParserAction,
help="LoRA module configurations in the format name=path. " help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.") "Multiple modules can be specified.")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
default=None,
nargs='+',
action=PromptAdapterParserAction,
help="Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified.")
parser.add_argument("--chat-template", parser.add_argument("--chat-template",
type=nullable_str, type=nullable_str,
default=None, default=None,
......
...@@ -258,7 +258,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -258,7 +258,7 @@ class OpenAIServingChat(OpenAIServing):
prompt=prompt, prompt=prompt,
add_special_tokens=request.add_special_tokens) add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request) _, lora_request = self._maybe_get_adapter(request)
decoding_config = await self.engine.get_decoding_config() decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \ guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend or decoding_config.guided_decoding_backend
......
...@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, ...@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
TokenizeResponse, UsageInfo) TokenizeResponse, UsageInfo)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing) OpenAIServing,
PromptAdapterPath)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor) get_guided_decoding_logits_processor)
...@@ -67,11 +68,13 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -67,11 +68,13 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]]): lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]]):
super().__init__(engine=engine, super().__init__(engine=engine,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules) lora_modules=lora_modules,
prompt_adapters=prompt_adapters)
async def create_completion(self, request: CompletionRequest, async def create_completion(self, request: CompletionRequest,
raw_request: Request): raw_request: Request):
...@@ -101,7 +104,12 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -101,7 +104,12 @@ class OpenAIServingCompletion(OpenAIServing):
generators: List[AsyncIterator[RequestOutput]] = [] generators: List[AsyncIterator[RequestOutput]] = []
try: try:
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request) adapter_type, adapter_request = self._maybe_get_adapter(request)
lora_request, prompt_adapter_request = None, None
if adapter_type == 'LoRA':
lora_request, prompt_adapter_request = adapter_request, None
elif adapter_type == 'PromptAdapter':
lora_request, prompt_adapter_request = None, adapter_request
decoding_config = await self.engine.get_decoding_config() decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \ guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend or decoding_config.guided_decoding_backend
...@@ -147,6 +155,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -147,6 +155,7 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params, sampling_params,
f"{request_id}-{i}", f"{request_id}-{i}",
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
) )
......
...@@ -16,12 +16,19 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -16,12 +16,19 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ModelPermission, TokenizeRequest) ModelPermission, TokenizeRequest)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
# Name/location pair describing one prompt adapter made available by the
# server. Instances are built from `name=path` command-line items by
# PromptAdapterParserAction and consumed by OpenAIServing.__init__.
@dataclass
class PromptAdapterPath:
# Identifier of the adapter; OpenAIServing copies it into
# PromptAdapterRequest.prompt_adapter_name, which is matched against
# `request.model`.
name: str
# Directory of the adapter; OpenAIServing opens
# `./<local_path>/adapter_config.json` to read `num_virtual_tokens`.
local_path: str
@dataclass @dataclass
class LoRAModulePath: class LoRAModulePath:
name: str name: str
...@@ -30,9 +37,14 @@ class LoRAModulePath: ...@@ -30,9 +37,14 @@ class LoRAModulePath:
class OpenAIServing: class OpenAIServing:
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, def __init__(
served_model_names: List[str], self,
lora_modules: Optional[List[LoRAModulePath]]): engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]] = None,
):
super().__init__() super().__init__()
self.engine = engine self.engine = engine
...@@ -49,9 +61,8 @@ class OpenAIServing: ...@@ -49,9 +61,8 @@ class OpenAIServing:
self.served_model_names = served_model_names self.served_model_names = served_model_names
if lora_modules is None: self.lora_requests = []
self.lora_requests = [] if lora_modules is not None:
else:
self.lora_requests = [ self.lora_requests = [
LoRARequest( LoRARequest(
lora_name=lora.name, lora_name=lora.name,
...@@ -60,6 +71,20 @@ class OpenAIServing: ...@@ -60,6 +71,20 @@ class OpenAIServing:
) for i, lora in enumerate(lora_modules, start=1) ) for i, lora in enumerate(lora_modules, start=1)
] ]
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with open(f"./{prompt_adapter.local_path}"
f"/adapter_config.json") as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
PromptAdapterRequest(
prompt_adapter_name=prompt_adapter.name,
prompt_adapter_id=i,
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
async def show_available_models(self) -> ModelList: async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
model_cards = [ model_cards = [
...@@ -75,7 +100,14 @@ class OpenAIServing: ...@@ -75,7 +100,14 @@ class OpenAIServing:
permission=[ModelPermission()]) permission=[ModelPermission()])
for lora in self.lora_requests for lora in self.lora_requests
] ]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.served_model_names[0],
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
model_cards.extend(lora_cards) model_cards.extend(lora_cards)
model_cards.extend(prompt_adapter_cards)
return ModelList(data=model_cards) return ModelList(data=model_cards)
def create_error_response( def create_error_response(
...@@ -109,20 +141,29 @@ class OpenAIServing: ...@@ -109,20 +141,29 @@ class OpenAIServing:
return None return None
if request.model in [lora.lora_name for lora in self.lora_requests]: if request.model in [lora.lora_name for lora in self.lora_requests]:
return None return None
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.prompt_adapter_requests
]:
return None
return self.create_error_response( return self.create_error_response(
message=f"The model `{request.model}` does not exist.", message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError", err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND) status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora( def _maybe_get_adapter(
self, request: Union[CompletionRequest, ChatCompletionRequest, self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest] EmbeddingRequest]
) -> Optional[LoRARequest]: ) -> Tuple[Optional[str], Optional[Union[LoRARequest,
PromptAdapterRequest]]]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None return None, None
for lora in self.lora_requests: for lora in self.lora_requests:
if request.model == lora.lora_name: if request.model == lora.lora_name:
return lora return 'LoRA', lora
for prompt_adapter in self.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return 'PromptAdapter', prompt_adapter
# if _check_model has been called earlier, this will be unreachable # if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.") raise ValueError(f"The model `{request.model}` does not exist.")
......
...@@ -7,6 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig ...@@ -7,6 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
...@@ -48,6 +49,7 @@ class CPUExecutor(ExecutorBase): ...@@ -48,6 +49,7 @@ class CPUExecutor(ExecutorBase):
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config, multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=True, is_driver_worker=True,
) )
self.driver_worker.init_device() self.driver_worker.init_device()
...@@ -90,6 +92,19 @@ class CPUExecutor(ExecutorBase): ...@@ -90,6 +92,19 @@ class CPUExecutor(ExecutorBase):
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras() return self.driver_worker.list_loras()
def add_prompt_adapter(
        self, prompt_adapter_request: PromptAdapterRequest) -> bool:
    """Hand the prompt-adapter request to the driver worker.

    Returns:
        The value returned by the driver worker's ``add_prompt_adapter``.
    """
    worker = self.driver_worker
    return worker.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Ask the driver worker to remove the prompt adapter with this id.

    Returns:
        The value returned by the driver worker's ``remove_prompt_adapter``.
    """
    worker = self.driver_worker
    return worker.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
    """Return the prompt-adapter ids reported by the driver worker."""
    worker = self.driver_worker
    return worker.list_prompt_adapters()
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Ask the driver worker to pin the prompt adapter with this id.

    Returns:
        The value returned by the driver worker's ``pin_prompt_adapter``.
    """
    worker = self.driver_worker
    return worker.pin_prompt_adapter(prompt_adapter_id)
def check_health(self) -> None: def check_health(self) -> None:
# CPUExecutor will always be healthy as long as # CPUExecutor will always be healthy as long as
# it's running. # it's running.
......
...@@ -4,8 +4,10 @@ from typing import List, Optional, Set, Tuple ...@@ -4,8 +4,10 @@ from typing import List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig) PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
...@@ -28,6 +30,7 @@ class ExecutorBase(ABC): ...@@ -28,6 +30,7 @@ class ExecutorBase(ABC):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
...@@ -38,6 +41,7 @@ class ExecutorBase(ABC): ...@@ -38,6 +41,7 @@ class ExecutorBase(ABC):
self.device_config = device_config self.device_config = device_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.speculative_config = speculative_config self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self._init_executor() self._init_executor()
...@@ -95,6 +99,23 @@ class ExecutorBase(ABC): ...@@ -95,6 +99,23 @@ class ExecutorBase(ABC):
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
raise NotImplementedError raise NotImplementedError
# Abstract hook: register the prompt adapter described by
# `prompt_adapter_request`. The base version only raises
# NotImplementedError; CPUExecutor and GPUExecutor implement it by
# delegating to their driver worker.
@abstractmethod
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
raise NotImplementedError
# Abstract hook: remove the prompt adapter identified by
# `prompt_adapter_id`. The base version only raises NotImplementedError;
# CPUExecutor and GPUExecutor implement it by delegating to their driver
# worker.
@abstractmethod
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError
# Abstract hook: pin the prompt adapter identified by `prompt_adapter_id`.
# The base version only raises NotImplementedError; CPUExecutor and
# GPUExecutor implement it by delegating to their driver worker.
@abstractmethod
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError # type: ignore
# Abstract hook: return a set of integer prompt-adapter ids. The base
# version only raises NotImplementedError; CPUExecutor and GPUExecutor
# implement it by delegating to their driver worker.
@abstractmethod
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError
@abstractmethod @abstractmethod
def check_health(self) -> None: def check_health(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an """Checks if the executor is healthy. If not, it should raise an
...@@ -122,12 +143,14 @@ class ExecutorAsyncBase(ExecutorBase): ...@@ -122,12 +143,14 @@ class ExecutorAsyncBase(ExecutorBase):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
) -> None: ) -> None:
self.pp_locks: Optional[List[asyncio.Lock]] = None self.pp_locks: Optional[List[asyncio.Lock]] = None
super().__init__(model_config, cache_config, parallel_config, super().__init__(model_config, cache_config, parallel_config,
scheduler_config, device_config, load_config, scheduler_config, device_config, load_config,
lora_config, multimodal_config, speculative_config) lora_config, multimodal_config, speculative_config,
prompt_adapter_config)
@abstractmethod @abstractmethod
async def execute_model_async( async def execute_model_async(
......
...@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union ...@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
...@@ -45,6 +46,7 @@ class GPUExecutor(ExecutorBase): ...@@ -45,6 +46,7 @@ class GPUExecutor(ExecutorBase):
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config, multimodal_config=self.multimodal_config,
speculative_config=self.speculative_config, speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config) is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0), or (rank % self.parallel_config.tensor_parallel_size == 0),
) )
...@@ -107,6 +109,25 @@ class GPUExecutor(ExecutorBase): ...@@ -107,6 +109,25 @@ class GPUExecutor(ExecutorBase):
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras() return self.driver_worker.list_loras()
def add_prompt_adapter(
        self, prompt_adapter_request: PromptAdapterRequest) -> bool:
    """Validate the request's adapter id, then forward it to the driver worker.

    Returns:
        The value returned by the driver worker's ``add_prompt_adapter``.

    Raises:
        AssertionError: If ``prompt_adapter_request.prompt_adapter_id`` is
            not greater than 0.
    """
    adapter_id = prompt_adapter_request.prompt_adapter_id
    assert adapter_id > 0, "prompt_adapter_id must be greater than 0."
    worker = self.driver_worker
    return worker.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Validate the id, then ask the driver worker to remove that adapter.

    Returns:
        The value returned by the driver worker's ``remove_prompt_adapter``.

    Raises:
        AssertionError: If ``prompt_adapter_id`` is not greater than 0.
    """
    assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
    worker = self.driver_worker
    return worker.remove_prompt_adapter(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Validate the id, then ask the driver worker to pin that adapter.

    Returns:
        The value returned by the driver worker's ``pin_prompt_adapter``.

    Raises:
        AssertionError: If ``prompt_adapter_id`` is not greater than 0.
    """
    assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
    worker = self.driver_worker
    return worker.pin_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
    """Return the prompt-adapter ids reported by the driver worker."""
    worker = self.driver_worker
    return worker.list_prompt_adapters()
def check_health(self) -> None: def check_health(self) -> None:
# GPUExecutor will always be healthy as long as # GPUExecutor will always be healthy as long as
# it's running. # it's running.
......
...@@ -8,7 +8,8 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, ...@@ -8,7 +8,8 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig) PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.executor.distributed_gpu_executor import ( # yapf: disable from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync) DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.ray_utils import RayWorkerWrapper, ray
...@@ -44,6 +45,7 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -44,6 +45,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], multimodal_config: Optional[MultiModalConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
) -> None: ) -> None:
assert device_config.device_type == "xpu" assert device_config.device_type == "xpu"
...@@ -58,6 +60,7 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -58,6 +60,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config
placement_group = self.parallel_config.placement_group placement_group = self.parallel_config.placement_group
......
...@@ -4,7 +4,8 @@ import torch ...@@ -4,7 +4,8 @@ import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig) PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -27,6 +28,7 @@ class XPUExecutor(GPUExecutor): ...@@ -27,6 +28,7 @@ class XPUExecutor(GPUExecutor):
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], multimodal_config: Optional[MultiModalConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
) -> None: ) -> None:
assert device_config.device_type == "xpu" assert device_config.device_type == "xpu"
...@@ -43,6 +45,7 @@ class XPUExecutor(GPUExecutor): ...@@ -43,6 +45,7 @@ class XPUExecutor(GPUExecutor):
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config
self.speculative_config = None self.speculative_config = None
# Instantiate the worker and load the model to GPU. # Instantiate the worker and load the model to GPU.
......
...@@ -8,6 +8,7 @@ import torch.nn as nn ...@@ -8,6 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
...@@ -134,15 +135,8 @@ def _apply_lora_packed_nslice( ...@@ -134,15 +135,8 @@ def _apply_lora_packed_nslice(
@dataclass @dataclass
class LoRAMapping: class LoRAMapping(AdapterMapping):
# Per every token in input_ids: pass
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
class BaseLayerWithLoRA(nn.Module): class BaseLayerWithLoRA(nn.Module):
......
...@@ -4,12 +4,17 @@ import math ...@@ -4,12 +4,17 @@ import math
import os import os
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import safetensors.torch import safetensors.torch
import torch import torch
from torch import nn from torch import nn
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
get_adapter, list_adapters,
remove_adapter, set_adapter_mapping)
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA, from vllm.lora.layers import (BaseLayerWithLoRA,
...@@ -19,7 +24,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights ...@@ -19,7 +24,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor, from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.utils import LRUCache, is_pin_memory_available from vllm.utils import is_pin_memory_available
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -153,7 +158,7 @@ def get_lora_id(): ...@@ -153,7 +158,7 @@ def get_lora_id():
return _GLOBAL_LORA_ID return _GLOBAL_LORA_ID
class LoRAModel: class LoRAModel(AdapterModel):
"""A LoRA fine-tuned model.""" """A LoRA fine-tuned model."""
def __init__( def __init__(
...@@ -388,7 +393,7 @@ class LoRAModel: ...@@ -388,7 +393,7 @@ class LoRAModel:
) )
class LoRAModelManager: class LoRAModelManager(AdapterModelManager):
"""A manager that manages multiple LoRA-fine-tuned models.""" """A manager that manages multiple LoRA-fine-tuned models."""
def __init__( def __init__(
...@@ -440,8 +445,7 @@ class LoRAModelManager: ...@@ -440,8 +445,7 @@ class LoRAModelManager:
# base_indices, sampler_indices, sampler_indices_padded, # base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices # embeddings_indices
self.indices_len: List[Optional[int]] = [None] * 4 self.indices_len: List[Optional[int]] = [None] * 4
super().__init__(model)
self.model = model
if hasattr(self.model, "supported_lora_modules"): if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy( self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules) self.model.supported_lora_modules)
...@@ -453,11 +457,11 @@ class LoRAModelManager: ...@@ -453,11 +457,11 @@ class LoRAModelManager:
self.model.packed_modules_mapping) self.model.packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {} self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache. # Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping: Optional[LoRAMapping] = None self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self
self.adapter_type = 'LoRa'
@property @property
def capacity(self) -> int: def capacity(self) -> int:
...@@ -467,15 +471,16 @@ class LoRAModelManager: ...@@ -467,15 +471,16 @@ class LoRAModelManager:
def lora_slots(self) -> int: def lora_slots(self) -> int:
return self.lora_config.max_loras return self.lora_config.max_loras
def __len__(self) -> int: @property
return len(self._registered_loras) def adapter_slots(self) -> int:
return self.lora_slots
def activate_lora( def activate_adapter(
self, self,
lora_id: int, lora_id: int,
) -> bool: ) -> bool:
"""Move LoRA into a GPU buffer to be used in the forward pass.""" """Move LoRA into a GPU buffer to be used in the forward pass."""
if lora_id in self._active_loras: if lora_id in self._active_adapters:
return False return False
first_free_slot = next( first_free_slot = next(
((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
...@@ -483,8 +488,8 @@ class LoRAModelManager: ...@@ -483,8 +488,8 @@ class LoRAModelManager:
if first_free_slot is None: if first_free_slot is None:
raise ValueError("No free lora slots") raise ValueError("No free lora slots")
index, _ = first_free_slot index, _ = first_free_slot
self._active_loras[lora_id] = None self._active_adapters[lora_id] = None
lora_model = self._registered_loras[lora_id] lora_model = self._registered_adapters[lora_id]
logger.debug("Activating LoRA. int id: %d, slot index: %d", logger.debug("Activating LoRA. int id: %d, slot index: %d",
lora_model.id, index) lora_model.id, index)
self.lora_index_to_id[index] = lora_model.id self.lora_index_to_id[index] = lora_model.id
...@@ -498,21 +503,13 @@ class LoRAModelManager: ...@@ -498,21 +503,13 @@ class LoRAModelManager:
module.reset_lora(index) module.reset_lora(index)
return True return True
def _deactivate_lora(self, lora_id: int): def _deactivate_adapter(self, lora_id: int):
try: try:
index = self.lora_index_to_id.index(lora_id) index = self.lora_index_to_id.index(lora_id)
self.lora_index_to_id[index] = None self.lora_index_to_id[index] = None
except ValueError: except ValueError:
pass pass
def deactivate_lora(self, lora_id: int) -> bool:
"""Remove a LoRA from a GPU buffer."""
if lora_id in self._active_loras:
self._deactivate_lora(lora_id)
self._active_loras.pop(lora_id)
return True
return False
def _set_long_lora_context(self, lora: LoRAModel): def _set_long_lora_context(self, lora: LoRAModel):
if self.long_lora_context is None: if self.long_lora_context is None:
return return
...@@ -528,40 +525,19 @@ class LoRAModelManager: ...@@ -528,40 +525,19 @@ class LoRAModelManager:
if offsets: if offsets:
self.long_lora_context.offsets_by_lora_id[lora.id] = offsets self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
def _add_lora(self, lora: LoRAModel): def _add_adapter(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora) self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora self._registered_adapters[lora.id] = lora
self._set_long_lora_context(lora) self._set_long_lora_context(lora)
def add_lora(self, lora: LoRAModel) -> bool: def pin_adapter(self, lora_id: int) -> bool:
"""Add a LoRAModel to the manager CPU cache."""
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras:
if len(self._registered_loras) >= self.capacity:
raise RuntimeError("No free LoRA slots.")
self._add_lora(lora)
return True
return False
def remove_lora(self, lora_id: int) -> bool:
"""Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora?
self.deactivate_lora(lora_id)
if self.long_lora_context:
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
return bool(self._registered_loras.pop(lora_id, None))
def pin_lora(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache.""" """Pin a LoRAModel in the manager cache."""
raise NotImplementedError( raise NotImplementedError(
"Pinning is not supported in LoRAModelManager." "Pinning is not supported in LoRAModelManager."
"Use LRUCacheLoRAModelManager for pinning") # type: ignore "Use LRUCacheLoRAModelManager for pinning") # type: ignore
# TODO see if this can be vectorized # TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None: def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded, (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_offsets_tensor, embeddings_indices, long_lora_offsets_tensor,
indices_len) = convert_mapping(mapping, self.lora_index_to_id, indices_len) = convert_mapping(mapping, self.lora_index_to_id,
...@@ -583,23 +559,11 @@ class LoRAModelManager: ...@@ -583,23 +559,11 @@ class LoRAModelManager:
# Maintain the reference # Maintain the reference
self.indices_len[:] = indices_len self.indices_len[:] = indices_len
def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: def remove_all_adapters(self):
if self._last_mapping != lora_mapping:
self._set_lora_mapping(lora_mapping)
self._last_mapping = lora_mapping
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras)
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)
def remove_all_loras(self):
"""Remove all LoRAModels from the manager.""" """Remove all LoRAModels from the manager."""
self._registered_loras.clear() self._registered_adapters.clear()
self.lora_index_to_id = [None] * self.lora_slots self.lora_index_to_id = [None] * self.lora_slots
self._active_loras.clear() self._active_adapters.clear()
def _create_lora_modules(self): def _create_lora_modules(self):
for module_name, module in self.model.named_modules( for module_name, module in self.model.named_modules(
...@@ -743,18 +707,39 @@ class LoRAModelManager: ...@@ -743,18 +707,39 @@ class LoRAModelManager:
lora_model.loras[module_name] = PackedLoRALayerWeights.pack( lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras) replacement_loras)
def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters,
self._deactivate_adapter)
def add_adapter(self, adapter: LoRAModel) -> bool:
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", adapter.id, adapter.id,
adapter.scaling_factor)
return add_adapter(adapter, self._registered_adapters, self.capacity,
self._add_adapter)
class LoRALRUCache(LRUCache[LoRAModel]): def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
self._set_adapter_mapping)
def remove_adapter(self, adapter_id: int) -> bool:
return remove_adapter(adapter_id, self._registered_adapters,
self.deactivate_adapter)
def list_adapters(self) -> Dict[int, Any]:
return list_adapters(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[Any]:
return get_adapter(adapter_id, self._registered_adapters)
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
bool]): bool]):
super().__init__(capacity) super().__init__(capacity, deactivate_lora_fn)
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: int, value: LoRAModel):
logger.debug("Removing LoRA. int id: %d", key)
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
class LRUCacheLoRAModelManager(LoRAModelManager): class LRUCacheLoRAModelManager(LoRAModelManager):
...@@ -770,49 +755,49 @@ class LRUCacheLoRAModelManager(LoRAModelManager): ...@@ -770,49 +755,49 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
): ):
super().__init__(model, max_num_seqs, max_num_batched_tokens, super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config) vocab_size, lora_config)
self._registered_loras: LoRALRUCache = LoRALRUCache( self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_lora) self.capacity, self.deactivate_adapter)
self._active_loras: LoRALRUCache = LoRALRUCache( self._active_adapters: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_lora) self.lora_slots, self._deactivate_adapter)
def list_loras(self) -> Dict[int, LoRAModel]: def list_adapters(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels.""" """List all registered LoRAModels."""
return dict(self._registered_loras.cache) return dict(self._registered_adapters.cache)
def add_lora(self, lora: LoRAModel) -> bool: def add_adapter(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager.""" """Add a LoRAModel to the manager."""
logger.debug( logger.debug(
"Adding lora. Model id: %d, " "Adding lora. Model id: %d, "
"int id: %d, " "int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor) "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras: if lora.id not in self._registered_adapters:
self._add_lora(lora) self._add_adapter(lora)
was_added = True was_added = True
else: else:
# We always touch to update the LRU cache order # We always touch to update the LRU cache order
self._registered_loras.touch(lora.id) self._registered_adapters.touch(lora.id)
was_added = False was_added = False
return was_added return was_added
def activate_lora( def activate_adapter(
self, self,
lora_id: int, lora_id: int,
) -> bool: ) -> bool:
if lora_id not in self._active_loras and len( if lora_id not in self._active_adapters and len(
self._active_loras) >= self.lora_slots: self._active_adapters) >= self.lora_slots:
self._active_loras.remove_oldest() self._active_adapters.remove_oldest()
result = super().activate_lora(lora_id) result = super().activate_adapter(lora_id)
# We always touch to update the LRU cache order # We always touch to update the LRU cache order
self._active_loras.touch(lora_id) self._active_adapters.touch(lora_id)
return result return result
def remove_oldest_lora(self) -> bool: def remove_oldest_adapter(self) -> bool:
if len(self._registered_loras) > 0: if len(self._registered_adapters) > 0:
self._registered_loras.remove_oldest() self._registered_adapters.remove_oldest()
return True return True
return False return False
def pin_lora(self, lora_id: int) -> bool: def pin_adapter(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache.""" """Pin a LoRAModel in the manager cache."""
self._pin_lora_in_cpu_cache(lora_id) self._pin_lora_in_cpu_cache(lora_id)
self._pin_lora_in_gpu_cache(lora_id) self._pin_lora_in_gpu_cache(lora_id)
...@@ -820,17 +805,17 @@ class LRUCacheLoRAModelManager(LoRAModelManager): ...@@ -820,17 +805,17 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
def _pin_lora_in_cpu_cache(self, lora_id: int): def _pin_lora_in_cpu_cache(self, lora_id: int):
try: try:
self._registered_loras.pin(lora_id) self._registered_adapters.pin(lora_id)
except ValueError as err: except ValueError as err:
raise ValueError("Pinning failed. " raise ValueError("Pinning failed. "
f"LoRA {lora_id} is not registered.") from err f"LoRA {lora_id} is not registered.") from err
def _pin_lora_in_gpu_cache(self, lora_id: int): def _pin_lora_in_gpu_cache(self, lora_id: int):
if lora_id not in self._active_loras: if lora_id not in self._active_adapters:
# move lora to gpu if not already active # move lora to gpu if not already active
self.activate_lora(lora_id) self.activate_adapter(lora_id)
self._active_loras.pin(lora_id) self._active_adapters.pin(lora_id)
def create_lora_manager( def create_lora_manager(
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from vllm.adapter_commons.request import AdapterRequest
@dataclass @dataclass
class LoRARequest: class LoRARequest(AdapterRequest):
""" """
Request for a LoRA adapter. Request for a LoRA adapter.
Note that this class should be be used internally. For online Note that this class should be used internally. For online
serving, it is recommended to not allow users to use this class but serving, it is recommended to not allow users to use this class but
instead provide another layer of abstraction to prevent users from instead provide another layer of abstraction to prevent users from
accessing unauthorized LoRA adapters. accessing unauthorized LoRA adapters.
...@@ -20,15 +22,16 @@ class LoRARequest: ...@@ -20,15 +22,16 @@ class LoRARequest:
lora_int_id: int lora_int_id: int
lora_local_path: str lora_local_path: str
long_lora_max_len: Optional[int] = None long_lora_max_len: Optional[int] = None
__hash__ = AdapterRequest.__hash__
def __post_init__(self): @property
if self.lora_int_id < 1: def adapter_id(self):
raise ValueError( return self.lora_int_id
f"lora_int_id must be > 0, got {self.lora_int_id}")
def __eq__(self, value: object) -> bool: @property
return isinstance( def name(self):
value, LoRARequest) and self.lora_int_id == value.lora_int_id return self.lora_name
def __hash__(self) -> int: @property
return self.lora_int_id def local_path(self):
return self.lora_local_path
from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
import torch import torch
from vllm.adapter_commons.utils import (add_adapter_worker,
apply_adapters_worker,
list_adapters_worker,
set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import (LoRAModel, LoRAModelManager, from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager) LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -14,79 +17,13 @@ from vllm.lora.request import LoRARequest ...@@ -14,79 +17,13 @@ from vllm.lora.request import LoRARequest
logger = init_logger(__name__) logger = init_logger(__name__)
class AbstractWorkerLoRAManager(ABC): class WorkerLoRAManager(AbstractWorkerManager):
"""Abstract class for managing LoRA models on the worker side."""
def __init__(self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
max_position_embeddings: Optional[int] = None):
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_position_embeddings = max_position_embeddings
self.vocab_size = vocab_size
self.device = device
self.lora_config = lora_config
# If False, do not cache. If None, cache is empty.
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
@contextmanager
def dummy_lora_cache(self):
"""Use this context manager to reuse the dummy lora model
to avoid creating it repeatedly."""
self._cached_dummy_lora = None
yield
self._cached_dummy_lora = False
@property
@abstractmethod
def is_enabled(self) -> bool:
...
@abstractmethod
def create_lora_manager(
self,
model: torch.nn.Module,
) -> Any:
...
@abstractmethod
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
...
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
...
@abstractmethod
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
...
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
...
@abstractmethod
def remove_all_loras(self):
...
@abstractmethod
def list_loras(self) -> Set[int]:
...
class WorkerLoRAManager(AbstractWorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side. """WorkerLoRAManager that manages LoRA models on the worker side.
Every request, the requested LoRAs will be loaded (unless they are already Every request, the requested LoRAs will be loaded (unless they are already
loaded), and every other LoRA will be unloaded.""" loaded), and every other LoRA will be unloaded."""
_lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager _manager_cls: Type[LoRAModelManager] = LoRAModelManager
def __init__( def __init__(
self, self,
...@@ -103,16 +40,23 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -103,16 +40,23 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
self._lora_model_cls = lora_model_cls self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules self.embedding_padding_modules = embedding_padding_modules
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.lora_config = lora_config
self.max_position_embeddings = max_position_embeddings
super().__init__(device)
# Lazily initialized by create_lora_manager. # Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager self._adapter_manager: LoRAModelManager
super().__init__(
max_num_seqs, @contextmanager
max_num_batched_tokens, def dummy_lora_cache(self):
vocab_size, """Use this context manager to reuse the dummy lora model
lora_config, to avoid creating it repeatedly."""
device, self._cached_dummy_lora = None
max_position_embeddings=max_position_embeddings, yield
) self._cached_dummy_lora = False
@property @property
def is_enabled(self) -> bool: def is_enabled(self) -> bool:
...@@ -128,41 +72,14 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -128,41 +72,14 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
max_num_batched_tokens=self.max_num_batched_tokens, max_num_batched_tokens=self.max_num_batched_tokens,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
lora_config=self.lora_config, lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls, lora_manager_cls=self._manager_cls,
) )
self._lora_manager = lora_manager self._adapter_manager = lora_manager
return lora_manager.model return lora_manager.model
def set_active_loras(self, lora_requests: Set[LoRARequest], def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping)
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_that_exist = self.list_loras()
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
new_loras = set(loras_map)
loras_to_add = new_loras - loras_that_exist
loras_to_remove = loras_that_exist - new_loras
for lora_id in loras_to_remove:
self.remove_lora(lora_id)
for lora_id in loras_to_add:
self.add_lora(loras_map[lora_id])
def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
try: try:
model = self._lora_manager.model model = self._adapter_manager.model
supported_lora_modules = model.supported_lora_modules supported_lora_modules = model.supported_lora_modules
packed_modules_mapping = model.packed_modules_mapping packed_modules_mapping = model.packed_modules_mapping
expected_lora_modules: List[str] = [] expected_lora_modules: List[str] = []
...@@ -198,37 +115,45 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -198,37 +115,45 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
return lora return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_loras(): if lora_request.lora_int_id in self.list_adapters():
return False return False
if isinstance(self._cached_dummy_lora, LoRAModel): if isinstance(self._cached_dummy_lora, LoRAModel):
dummy_lora = self._cached_dummy_lora.clone( dummy_lora = self._cached_dummy_lora.clone(
lora_request.lora_int_id) lora_request.lora_int_id)
else: else:
dummy_lora = self._lora_manager.create_dummy_lora( dummy_lora = self._adapter_manager.create_dummy_lora(
lora_request.lora_int_id, rank, 1, self.embedding_modules) lora_request.lora_int_id, rank, 1, self.embedding_modules)
if self._cached_dummy_lora is None: if self._cached_dummy_lora is None:
self._cached_dummy_lora = dummy_lora self._cached_dummy_lora = dummy_lora
return self._lora_manager.add_lora(dummy_lora) return self._adapter_manager.add_adapter(dummy_lora)
def add_lora(self, lora_request: LoRARequest) -> bool: def pin_adapter(self, adapter_id: int) -> bool:
if lora_request.lora_int_id in self.list_loras(): return self._adapter_manager.pin_adapter(adapter_id)
return False
lora = self._load_lora(lora_request) def set_active_adapters(self, requests: Set[Any],
loaded = self._lora_manager.add_lora(lora) mapping: Optional[Any]) -> None:
self._lora_manager.activate_lora(lora.id) set_active_adapters_worker(requests, mapping, self._apply_adapters,
return loaded self._adapter_manager.set_adapter_mapping)
def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
apply_adapters_worker(adapter_requests, self.list_adapters,
self._adapter_manager.adapter_slots,
self.remove_adapter, self.add_adapter)
def remove_lora(self, lora_id: int) -> bool: def add_adapter(self, adapter_request: Any) -> bool:
return self._lora_manager.remove_lora(lora_id) return add_adapter_worker(adapter_request, self.list_adapters,
self._load_adapter,
self._adapter_manager.add_adapter,
self._adapter_manager.activate_adapter)
def pin_lora(self, lora_id: int) -> bool: def remove_adapter(self, adapter_id: int) -> bool:
return self._lora_manager.pin_lora(lora_id) return self._adapter_manager.remove_adapter(adapter_id)
def remove_all_loras(self): def remove_all_adapters(self):
self._lora_manager.remove_all_loras() self._adapter_manager.remove_all_adapters()
def list_loras(self) -> Set[int]: def list_adapters(self) -> Set[int]:
return set(self._lora_manager.list_loras()) return list_adapters_worker(self._adapter_manager.list_adapters)
class LRUCacheWorkerLoRAManager(WorkerLoRAManager): class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
...@@ -238,8 +163,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -238,8 +163,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
(unless they are already loaded) and least recently used LoRAs will (unless they are already loaded) and least recently used LoRAs will
be unloaded if the cache is above capacity.""" be unloaded if the cache is above capacity."""
_lora_manager_cls: Type[ _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager( def create_lora_manager(
self, self,
...@@ -247,40 +171,41 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -247,40 +171,41 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
) -> Any: ) -> Any:
lora_manager = create_lora_manager( lora_manager = create_lora_manager(
model, model,
lora_manager_cls=self._lora_manager_cls, lora_manager_cls=self._manager_cls,
max_num_seqs=self.max_num_seqs, max_num_seqs=self.max_num_seqs,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
lora_config=self.lora_config, lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens, max_num_batched_tokens=self.max_num_batched_tokens,
) )
self._lora_manager = lora_manager self._adapter_manager = lora_manager
return lora_manager.model return lora_manager.model
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
loras_map = { loras_map = {
lora_request.lora_int_id: lora_request lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request for lora_request in lora_requests if lora_request
} }
if len(loras_map) > self._lora_manager.lora_slots: if len(loras_map) > self._adapter_manager.lora_slots:
raise RuntimeError( raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater " f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots " "than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).") f"({self._adapter_manager.lora_slots}).")
for lora in loras_map.values(): for lora in loras_map.values():
self.add_lora(lora) self.add_adapter(lora)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_adapter(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id not in self.list_loras(): if lora_request.lora_int_id not in self.list_adapters():
# Remove before we load the new lora to save memory # Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity: if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
assert isinstance(self._lora_manager, LRUCacheLoRAModelManager) assert isinstance(self._adapter_manager,
self._lora_manager.remove_oldest_lora() LRUCacheLoRAModelManager)
lora = self._load_lora(lora_request) self._adapter_manager.remove_oldest_adapter()
loaded = self._lora_manager.add_lora(lora) lora = self._load_adapter(lora_request)
loaded = self._adapter_manager.add_adapter(lora)
else: else:
# If the lora is already loaded, just touch it to # If the lora is already loaded, just touch it to
# update its position in the caches # update its position in the caches
loaded = self._lora_manager.get_lora( loaded = self._adapter_manager.get_adapter(
lora_request.lora_int_id) is not None lora_request.lora_int_id) is not None
self._lora_manager.activate_lora(lora_request.lora_int_id) self._adapter_manager.activate_adapter(lora_request.lora_int_id)
return loaded return loaded
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import PromptAdapterConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@dataclass
class PromptAdapterMapping(AdapterMapping):
    """Adapter mapping used for prompt adapters.

    Declares no fields or methods of its own; everything is inherited from
    ``AdapterMapping``.
    """
class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
    """Embedding wrapper that overwrites selected rows with prompt-adapter
    embeddings.

    The wrapped layer computes the regular embeddings. ``forward`` then
    replaces the rows selected by the current mapping with rows gathered
    from a per-slot buffer of adapter embeddings.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        """Wrap ``base_layer``.

        Args:
            base_layer: the embedding layer to wrap. If its class name
                contains ``"LoRA"`` it is treated as a wrapper itself and
                its own ``base_layer`` is used as the reference layer.
        """
        super().__init__()
        self.base_layer = base_layer
        # ``emb_layer`` supplies embedding_dim and the weight dtype/device
        # used when allocating the adapter buffers.
        self.emb_layer = self.base_layer
        if 'LoRA' in base_layer.__class__.__name__:
            self.emb_layer = self.base_layer.base_layer

    def create_prompt_adapter_weights(
            self, prompt_adapter_config: PromptAdapterConfig):
        """Allocate the zero-initialised adapter buffers.

        Creates ``embeddings_tensors`` with shape
        ``(max_prompt_adapters, max_prompt_adapter_token, embedding_dim)``
        and ``adapter_lengths`` with shape ``(max_prompt_adapters,)``, both
        on the device of the reference layer's weight.

        Args:
            prompt_adapter_config: provides ``max_prompt_adapters`` and
                ``max_prompt_adapter_token``.
        """
        self.embeddings_tensors = torch.zeros(
            (
                prompt_adapter_config.max_prompt_adapters,
                prompt_adapter_config.max_prompt_adapter_token,
                self.emb_layer.embedding_dim,
            ),
            dtype=self.emb_layer.weight.dtype,
            device=self.emb_layer.weight.device,
        )
        self.adapter_lengths = torch.zeros(
            prompt_adapter_config.max_prompt_adapters,
            dtype=torch.long,
            device=self.emb_layer.weight.device)

        # Declared here, assigned by set_mapping(); forward() needs
        # set_mapping() to have been called at least once.
        self.indices_gpu: torch.Tensor
        self.embedding_indices_gpu: torch.Tensor

    def reset_prompt_adapter(self, index: int):
        """Clear slot ``index``: zero its embedding rows and its length."""
        self.embeddings_tensors[index] = 0
        # Keep the recorded length in sync with the cleared rows so an
        # emptied slot does not keep reporting the previous adapter's size.
        self.adapter_lengths[index] = 0

    def set_prompt_adapter(
        self,
        index: int,
        adapter_model: Optional[torch.Tensor],
    ):
        """Load an adapter embedding into slot ``index``.

        The slot is cleared first. When ``adapter_model`` is ``None`` the
        slot is simply left empty.

        Args:
            index: slot to write.
            adapter_model: tensor whose first dimension is the number of
                rows to copy into the slot, or ``None``.
        """
        self.reset_prompt_adapter(index)
        if adapter_model is not None:
            length = adapter_model.shape[0]
            self.embeddings_tensors[index, :length] = adapter_model
            self.adapter_lengths[index] = length

    def set_mapping(
        self,
        prompt_indices: torch.Tensor,
        prompt_embedding_indices: torch.Tensor,
    ):
        """Store the index tensors used by ``forward``.

        Both tensors are moved to the device of the reference layer's
        weight.

        Args:
            prompt_indices: per-row slot index; ``-1`` marks rows that are
                left untouched by ``forward``.
            prompt_embedding_indices: tensor whose columns 0 and 1 index
                the first two dimensions of ``embeddings_tensors``.
        """
        self.indices_gpu = prompt_indices.to(
            device=self.emb_layer.weight.device)
        self.embedding_indices_gpu = prompt_embedding_indices.to(
            device=self.emb_layer.weight.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and overwrite the mapped rows with adapter rows.

        Returns:
            The wrapped layer's output, modified in place where the
            mapping selects an adapter row.
        """
        hidden_states = self.base_layer(x)
        # A mapping with at most one dimension means there is nothing to
        # gather, so the base output is returned unchanged.
        if self.embedding_indices_gpu.ndim > 1:
            valid_mask = self.indices_gpu != -1
            gathered_embeddings = self.embeddings_tensors[
                self.embedding_indices_gpu[:, 0],
                self.embedding_indices_gpu[:, 1]]

            # Update hidden states
            # NOTE(review): assumes the number of rows selected by
            # valid_mask equals the number of gathered rows - confirm
            # against the code that builds the mapping.
            hidden_states[valid_mask] = gathered_embeddings
        return hidden_states
\ No newline at end of file
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch
from torch import nn

from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                         AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.layers import (
    VocabParallelEmbeddingWithPromptAdapter)  # yapf: disable
from vllm.prompt_adapter.layers import PromptAdapterMapping
logger = logging.getLogger(__name__)

# Last id handed out by get_prompt_adapter_id(); 0 means none issued yet.
_GLOBAL_PROMPT_ADAPTER_ID = 0


def get_prompt_adapter_id():
    """Return the next prompt-adapter id from the module-wide counter.

    The first call returns 1; every call bumps the counter by one.
    """
    global _GLOBAL_PROMPT_ADAPTER_ID
    next_id = _GLOBAL_PROMPT_ADAPTER_ID + 1
    _GLOBAL_PROMPT_ADAPTER_ID = next_id
    return next_id
def convert_to_embedding_indices(indices):
    """Pair each active slot index with its position inside its run.

    Walks ``indices`` in order. Entries equal to ``-1`` produce no output.
    Every other entry yields ``[slot_index, position]`` where ``position``
    counts from 0 within the current run of identical slot indices. The
    position restarts after a ``-1`` and also whenever the slot index
    changes, so two different adapters that sit next to each other each
    start at their own row 0.

    Args:
        indices: iterable of slot indices, ``-1`` meaning "no adapter".

    Returns:
        ``torch.tensor`` built from the list of pairs. With no active
        entries this is ``torch.tensor([])``, a 1-D empty tensor.
    """
    embedding_indices = []
    position = 0
    previous = -1
    for value in indices:
        if value != previous:
            # A new run starts here (either a gap or a different slot).
            position = 0
        previous = value
        if value == -1:
            continue
        embedding_indices.append([value, position])
        position += 1
    return torch.tensor(embedding_indices)
def convert_mapping(
    mapping: PromptAdapterMapping,
    prompt_adapter_index_to_id: List[Optional[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Converts PromptAdapterMapping to index tensors.

    Args:
        mapping: PromptAdapterMapping whose ``index_mapping`` holds one
            PromptAdapter id per entry. Ids that are not positive mean
            "no adapter".
        prompt_adapter_index_to_id: List indexed by slot; each element is
            the PromptAdapter id held in that slot, or ``None`` when the
            slot is free.

    Returns:
        pa_indices: Tensor with one element per entry of
            ``mapping.index_mapping``: the slot index of that entry's
            adapter, or -1 when the id is not positive or is not present
            in ``prompt_adapter_index_to_id``.
        pa_embedding_mapping: The result of
            ``convert_to_embedding_indices`` applied to the same slot
            indices.
    """
    # Invert the slot table: adapter id -> slot index (free slots skipped).
    id_to_index = {
        id_: idx
        for idx, id_ in enumerate(prompt_adapter_index_to_id)
        if id_ is not None
    }
    pa_indices = ([
        id_to_index.get(id_, -1) if id_ > 0 else -1
        for id_ in mapping.index_mapping
    ])

    # Built from the plain list before it is turned into a tensor.
    pa_embedding_mapping = convert_to_embedding_indices(pa_indices)
    pa_indices = torch.tensor(pa_indices)
    return pa_indices, pa_embedding_mapping
class PromptAdapterModel(AdapterModel):
    """One prompt adapter: its id, virtual-token count and embedding."""

    def __init__(self,
                 prompt_adapter_id=None,
                 num_virtual_tokens=None,
                 prompt_embedding=None) -> None:
        """Store the three values as ``id``, ``num_virtual_tokens`` and
        ``prompt_embedding``."""
        self.id = prompt_adapter_id
        self.num_virtual_tokens = num_virtual_tokens
        self.prompt_embedding = prompt_embedding

    @classmethod
    def from_local_checkpoint(
        cls,
        adapter_model_path: str,
        prompt_adapter_id: int,
        num_virtual_tokens: int,
        config: PromptAdapterConfig,
        device: str = "cuda",
    ) -> "PromptAdapterModel":
        """Build a model from PEFT weights stored at ``adapter_model_path``.

        Args:
            adapter_model_path: location handed to ``load_peft_weights``.
            prompt_adapter_id: id stored on the resulting model.
            num_virtual_tokens: virtual-token count stored on the model.
            config: supplies ``max_prompt_adapter_token`` (upper bound for
                ``num_virtual_tokens``) and ``prompt_adapter_dtype`` (dtype
                the embedding is converted to).
            device: device handed to ``load_peft_weights``.

        Raises:
            ValueError: if ``num_virtual_tokens`` exceeds
                ``config.max_prompt_adapter_token``.
        """
        # Imported lazily so peft is only needed when a checkpoint is
        # actually loaded.
        from peft.utils import load_peft_weights

        if num_virtual_tokens > config.max_prompt_adapter_token:
            raise ValueError(
                f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
                f'max_prompt_adapter_token({config.max_prompt_adapter_token})')

        loaded_weights = load_peft_weights(adapter_model_path, device)
        raw_embedding = loaded_weights["prompt_embeddings"]
        converted_embedding = raw_embedding.to(config.prompt_adapter_dtype)
        return cls(prompt_adapter_id, num_virtual_tokens, converted_embedding)
class PromptAdapterModelManager(AdapterModelManager):
"""A manager that manages multiple Prompt Adapter models."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
prompt_adapter_config: PromptAdapterConfig,
):
"""Create a PromptAdapterModel and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
prompt_adapter_config: the PromptAdapter config,
"""
self.model: nn.Module = model
# Dict instead of a Set for compatibility with LRUCache.
self.prompt_adapter_index_to_id: List[
Optional[int]] = [None] * self.prompt_adapter_slots
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.prompt_adapter_config = prompt_adapter_config
self.model.prompt_adapter_manager = self
self.adapter_type = 'PromptAdapter'
self.base_indices = torch.tensor([-1])
self.base_embedding_indices = torch.tensor([])
self.modules: Dict[str, nn.Module] = {}
self._create_prompt_adapter_modules()
self._last_mapping: Optional[PromptAdapterMapping] = None
@property
def prompt_adapter_slots(self) -> int:
return self.prompt_adapter_config.max_prompt_adapters
@property
def adapter_slots(self) -> int:
return self.prompt_adapter_slots
@property
def capacity(self) -> int:
return self.prompt_adapter_config.max_cpu_prompt_adapters
def activate_adapter(
self,
prompt_adapter_id: int,
) -> bool:
"""Move PromptAdapter into a GPU buffer
to be used in the forward pass."""
if prompt_adapter_id in self._active_adapters:
return False
first_free_slot = next(
((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate(
self.prompt_adapter_index_to_id) if prompt_adapter_id is None),
None)
if first_free_slot is None:
raise ValueError("No free prompt_adapter slots")
index, _ = first_free_slot
self._active_adapters[prompt_adapter_id] = None
prompt_adapter_model = (self._registered_adapters[prompt_adapter_id])
logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
prompt_adapter_model.id, index)
self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
for _, v in self.modules.items():
v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding)
return True
def _deactivate_adapter(self, prompt_adapter_id: int):
try:
index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
self.prompt_adapter_index_to_id[index] = None
for _, v in self.modules.items():
v.reset_prompt_adapter(index)
except ValueError:
pass
def _add_adapter(self, prompt_adapter: PromptAdapterModel):
self._registered_adapters[prompt_adapter.id] = prompt_adapter
def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
base_indices, base_embedding_indices = convert_mapping(
mapping, self.prompt_adapter_index_to_id)
for k, v in self.modules.items():
v.set_mapping(base_indices, base_embedding_indices)
def _create_prompt_adapter_modules(self):
for module_name, module in self.model.named_modules(
remove_duplicate=False):
if "VocabParallel" in module.__class__.__name__:
new_module = VocabParallelEmbeddingWithPromptAdapter(module)
new_module.create_prompt_adapter_weights(
self.prompt_adapter_config)
replaced_module = self.replace_submodule(
self.model, module_name, new_module)
self.register_module(module.__class__.__name__,
replaced_module)
replaced_module.set_mapping(self.base_indices,
self.base_embedding_indices)
break
def replace_submodule(self, model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Swap the submodule at dotted path ``module_name`` for ``new_module``.

    The part before the last dot names the parent (an empty parent path
    makes ``get_submodule`` return ``model`` itself); the part after it is
    the attribute that gets overwritten. ``new_module`` is returned.
    """
    parent_path, _, attr_name = module_name.rpartition(".")
    owner = model.get_submodule(parent_path)
    setattr(owner, attr_name, new_module)
    return new_module
def register_module(self, module_name: str, module: nn.Module):
    """Record ``module`` in ``self.modules`` under ``module_name``.

    An existing entry with the same name is overwritten.
    """
    registry = self.modules
    registry[module_name] = module
def pin_adapter(self, prompt_adapter_id: int) -> bool:
    """Pin a PromptAdapterModel in the manager cache.

    This manager does not implement pinning, so the call always fails.

    Raises:
        NotImplementedError: Always.
    """
    # The two literals are joined by implicit concatenation, so the
    # separating space has to be part of the first literal.
    raise NotImplementedError(
        "Pinning is not supported in PromptAdapterModelManager. "
        "Use LRUCachePromptAdapterModelManager for pinning"
    )  # type: ignore
def remove_all_adapters(self):
    """Remove all PromptAdapterModel from the manager.

    Clears the registered adapters, replaces the slot table with
    ``prompt_adapter_slots`` empty (``None``) entries, then clears the
    active adapters.
    """
    self._registered_adapters.clear()
    empty_slots = [None for _ in range(self.prompt_adapter_slots)]
    self.prompt_adapter_index_to_id = empty_slots
    self._active_adapters.clear()
def deactivate_adapter(self, adapter_id: int) -> bool:
    """Delegate to the shared ``deactivate_adapter`` helper.

    Inside the body the bare name ``deactivate_adapter`` is the
    module-level helper (defined outside this excerpt), not this method.
    It is given the active-adapter container and
    ``self._deactivate_adapter`` as callback; its result is returned
    unchanged.
    """
    active = self._active_adapters
    release_slot = self._deactivate_adapter
    return deactivate_adapter(adapter_id, active, release_slot)
def add_adapter(self, adapter: PromptAdapterModel) -> bool:
    """Delegate to the shared ``add_adapter`` helper.

    The module-level helper of the same name (defined outside this
    excerpt) receives the registered-adapter container, ``self.capacity``
    and ``self._add_adapter`` as callback; its result is returned
    unchanged.
    """
    registered = self._registered_adapters
    limit = self.capacity
    return add_adapter(adapter, registered, limit, self._add_adapter)
def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
    """Delegate to the shared ``set_adapter_mapping`` helper.

    The module-level helper of the same name (defined outside this
    excerpt) receives the new mapping, the previously stored
    ``self._last_mapping`` and ``self._set_adapter_mapping`` as callback.
    Whatever it returns becomes the new ``self._last_mapping``.
    """
    previous = self._last_mapping
    apply_mapping = self._set_adapter_mapping
    self._last_mapping = set_adapter_mapping(mapping, previous,
                                             apply_mapping)
def remove_adapter(self, adapter_id: int) -> bool:
    """Delegate to the shared ``remove_adapter`` helper.

    The module-level helper of the same name (defined outside this
    excerpt) receives the registered-adapter container and the public
    ``self.deactivate_adapter`` method as callback; its result is
    returned unchanged.
    """
    registered = self._registered_adapters
    deactivate = self.deactivate_adapter
    return remove_adapter(adapter_id, registered, deactivate)
def list_adapters(self) -> Dict[int, Any]:
    """Return what the shared ``list_adapters`` helper (module-level,
    defined outside this excerpt) produces for the registered adapters."""
    registered = self._registered_adapters
    return list_adapters(registered)
def get_adapter(self, adapter_id: int) -> Optional[Any]:
    """Return what the shared ``get_adapter`` helper (module-level,
    defined outside this excerpt) finds for ``adapter_id`` among the
    registered adapters."""
    registered = self._registered_adapters
    return get_adapter(adapter_id, registered)
class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):
    """``AdapterLRUCache`` specialised for ``PromptAdapterModel`` values.

    Adds no behaviour of its own; the constructor forwards both arguments
    to the base class unchanged.
    """

    def __init__(self, capacity: int,
                 deactivate_prompt_adapter_fn: Callable[[int], bool]):
        super().__init__(capacity, deactivate_prompt_adapter_fn)
class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
    """A model manager that manages multiple prompt_adapters with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        """Initialise the base manager, then swap both adapter containers
        for ``PromptAdapterLRUCache`` instances."""
        self.prompt_adapter_config = prompt_adapter_config
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         prompt_adapter_config)
        # Registered adapters: sized by ``capacity``; the cache is handed
        # the public deactivate_adapter as its callback.
        self._registered_adapters = PromptAdapterLRUCache(
            self.capacity, self.deactivate_adapter)
        # Active adapters: sized by the slot count; the cache is handed
        # the private _deactivate_adapter as its callback.
        self._active_adapters = PromptAdapterLRUCache(
            self.prompt_adapter_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, PromptAdapterModel]:
        """List all registered PromptAdapterModel."""
        snapshot = dict(self._registered_adapters.cache)
        return snapshot

    def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
        """Add a PromptAdapterModel to the manager.

        Returns ``True`` when the adapter was newly registered and
        ``False`` when it was already present.
        """
        if prompt_adapter.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(prompt_adapter.id)
            return False
        self._add_adapter(prompt_adapter)
        return True

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        """Activate an adapter, first dropping the oldest active one when
        the adapter is not active yet and all slots are taken."""
        is_new = prompt_adapter_id not in self._active_adapters
        if is_new and (len(self._active_adapters) >=
                       self.prompt_adapter_slots):
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(prompt_adapter_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(prompt_adapter_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Drop the oldest registered adapter.

        Returns ``False`` when nothing is registered, ``True`` otherwise.
        """
        if len(self._registered_adapters) == 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        # Registered cache first, then the active cache.
        self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
        self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
        return True

    def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
        """Pin the adapter in the registered-adapter cache.

        A ``ValueError`` from the cache is re-raised as a ``ValueError``
        saying that the adapter is not registered.
        """
        try:
            self._registered_adapters.pin(prompt_adapter_id)
        except ValueError as err:
            message = ("Pinning failed. "
                       f"Prompt Adapter {prompt_adapter_id} is not "
                       "registered.")
            raise ValueError(message) from err

    def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
        """Pin the adapter in the active-adapter cache, activating it
        first when it is not active yet."""
        already_active = prompt_adapter_id in self._active_adapters
        if not already_active:
            # move adapter to gpu if not already active
            self.activate_adapter(prompt_adapter_id)
        self._active_adapters.pin(prompt_adapter_id)
def create_prompt_adapter_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_manager_cls: Type[
            PromptAdapterModelManager] = PromptAdapterModelManager,
        **kwargs) -> PromptAdapterModelManager:
    """Create a prompt adapter manager for a given model.

    Instantiates ``prompt_adapter_manager_cls`` with the four named
    settings passed as keyword arguments, forwards any extra ``kwargs``
    unchanged, and returns the new instance.
    """
    manager_kwargs = dict(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        prompt_adapter_config=prompt_adapter_config,
    )
    # A name present in both dicts raises TypeError here, exactly as a
    # duplicated keyword in a direct call would.
    return prompt_adapter_manager_cls(**manager_kwargs, **kwargs)
from dataclasses import dataclass
from vllm.adapter_commons.request import AdapterRequest
@dataclass
class PromptAdapterRequest(AdapterRequest):
    """
    Request for a Prompt adapter.

    The three properties below expose the ``prompt_adapter_*`` fields
    under the shorter names ``adapter_id``, ``name`` and ``local_path``.
    """
    prompt_adapter_name: str
    prompt_adapter_id: int
    prompt_adapter_local_path: str
    prompt_adapter_num_virtual_tokens: int

    def __hash__(self):
        # Defined explicitly so that @dataclass (which generates __eq__)
        # keeps a usable __hash__; the value comes from the parent class.
        return super().__hash__()

    @property
    def adapter_id(self):
        """Alias for ``prompt_adapter_id``."""
        return self.prompt_adapter_id

    @property
    def name(self):
        """Alias for ``prompt_adapter_name``."""
        return self.prompt_adapter_name

    @property
    def local_path(self):
        """Alias for ``prompt_adapter_local_path``."""
        return self.prompt_adapter_local_path
import logging
from typing import Any, Optional, Set, Type
import torch
from vllm.adapter_commons.utils import (add_adapter_worker,
apply_adapters_worker,
list_adapters_worker,
set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
PromptAdapterModel,
PromptAdapterModelManager,
create_prompt_adapter_manager)
from vllm.prompt_adapter.request import PromptAdapterRequest
logger = logging.getLogger(__name__)
class WorkerPromptAdapterManager(AbstractWorkerManager):
    """WorkerPromptAdapterManager that manages
    prompt_adapter models on the worker side.

    Every request, the requested prompt_adapters will be
    loaded (unless they are already loaded),
    and every other prompt_adapter will be unloaded."""

    # Manager class passed to the module-level
    # create_prompt_adapter_manager() factory.
    _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        device: torch.device,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel
    ):
        """Store the settings, then let the base class handle ``device``."""
        # Annotation only: the attribute is assigned later, in
        # create_prompt_adapter_manager().
        self._adapter_manager: PromptAdapterModelManager
        self.prompt_adapter_config = prompt_adapter_config
        self._prompt_adapter_model_cls = prompt_adapter_model_cls
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs
        super().__init__(device)

    @property
    def is_enabled(self) -> bool:
        """Always ``True`` for this manager."""
        return True

    def create_prompt_adapter_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the adapter manager for ``model`` and keep it on ``self``.

        Returns the ``model`` attribute of the newly built manager.
        """
        # The bare name below is the module-level factory, not this method.
        manager = create_prompt_adapter_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            prompt_adapter_config=self.prompt_adapter_config,
            prompt_adapter_manager_cls=self._manager_cls,
        )
        self._adapter_manager = manager
        return manager.model

    def _load_adapter(
            self, prompt_adapter_request: PromptAdapterRequest
    ) -> PromptAdapterModel:
        """Load the requested adapter from its local checkpoint.

        Raises:
            RuntimeError: If loading raises any ``Exception``; the
                original exception is chained as the cause.
        """
        checkpoint = prompt_adapter_request.prompt_adapter_local_path
        try:
            prompt_adapter = (
                self._prompt_adapter_model_cls.from_local_checkpoint(
                    checkpoint,
                    prompt_adapter_id=(
                        prompt_adapter_request.prompt_adapter_id),
                    num_virtual_tokens=(
                        prompt_adapter_request.
                        prompt_adapter_num_virtual_tokens),
                    config=self.prompt_adapter_config,
                    device=str(self.device),
                ))
        except Exception as e:
            # The path is read again so the message is built exactly as
            # before, from the request itself.
            raise RuntimeError(
                "Loading prompt_adapter "
                f"{prompt_adapter_request.prompt_adapter_local_path}"
                " failed") from e
        return prompt_adapter

    def add_dummy_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Do nothing and report success."""
        return True

    def pin_adapter(self, adapter_id: int) -> bool:
        """Forward to the adapter manager's ``pin_adapter``."""
        manager = self._adapter_manager
        return manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        """Hand the requests and mapping to ``set_active_adapters_worker``
        together with the two callbacks it needs."""
        apply_fn = self._apply_adapters
        set_mapping_fn = self._adapter_manager.set_adapter_mapping
        set_active_adapters_worker(requests, mapping, apply_fn,
                                   set_mapping_fn)

    def add_adapter(self, adapter_request: Any) -> bool:
        """Delegate to ``add_adapter_worker`` and return its result."""
        manager = self._adapter_manager
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter, manager.add_adapter,
                                  manager.activate_adapter)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        """Delegate to ``apply_adapters_worker``."""
        slots = self._adapter_manager.adapter_slots
        apply_adapters_worker(adapter_requests, self.list_adapters, slots,
                              self.remove_adapter, self.add_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Forward to the adapter manager's ``remove_adapter``."""
        manager = self._adapter_manager
        return manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Forward to the adapter manager's ``remove_all_adapters``."""
        manager = self._adapter_manager
        manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        """Return the result of ``list_adapters_worker`` applied to the
        adapter manager's ``list_adapters``."""
        lister = self._adapter_manager.list_adapters
        return list_adapters_worker(lister)
class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager):
    """WorkerPromptAdapterManager that manages
    prompt_adapter models on the worker side.

    Uses an LRU Cache. Every request, the requested
    prompt_adapters will be loaded (unless they are already loaded)
    and least recently used prompt_adapters will
    be unloaded if the cache is above capacity."""

    # Manager class used by create_prompt_adapter_manager() below.
    _prompt_adapter_manager_cls: Type[
        LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager

    def create_prompt_adapter_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the LRU adapter manager for ``model`` and keep it on
        ``self``.

        Returns the ``model`` attribute of the newly built manager.
        """
        # The bare name below is the module-level factory, not this method.
        manager = create_prompt_adapter_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            prompt_adapter_config=self.prompt_adapter_config,
            prompt_adapter_manager_cls=self._prompt_adapter_manager_cls)
        self._adapter_manager: LRUCachePromptAdapterModelManager = manager
        return manager.model

    def _apply_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None:
        """Add every requested adapter.

        Falsy entries are skipped and requests are keyed by
        ``prompt_adapter_id``, so a later request with the same id
        replaces an earlier one.

        Raises:
            RuntimeError: If more distinct adapters are requested than the
                manager has ``prompt_adapter_slots``.
        """
        prompt_adapters_map = {}
        for request in prompt_adapter_requests:
            if request:
                prompt_adapters_map[request.prompt_adapter_id] = request

        requested = len(prompt_adapters_map)
        if requested > self._adapter_manager.prompt_adapter_slots:
            raise RuntimeError(
                "Number of requested prompt_adapters "
                f"({requested}) is greater "
                "than the number of GPU prompt_adapter slots "
                f"({self._adapter_manager.prompt_adapter_slots}).")

        for request in prompt_adapters_map.values():
            self.add_adapter(request)

    def add_adapter(self,
                    prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Make sure the requested adapter is registered, then activate it.

        Returns the manager's ``add_adapter`` result for an adapter that
        had to be loaded; for one that was already listed, whether
        ``get_adapter`` finds it.
        """
        adapter_id = prompt_adapter_request.prompt_adapter_id
        manager = self._adapter_manager
        if adapter_id in self.list_adapters():
            # If the prompt_adapter is already loaded, just touch it to
            # update its position in the caches
            loaded = manager.get_adapter(adapter_id) is not None
        else:
            # Remove before we load the new prompt_adapter to save memory
            if len(manager) + 1 > manager.capacity:
                manager.remove_oldest_adapter()
            prompt_adapter = self._load_adapter(prompt_adapter_request)
            loaded = manager.add_adapter(prompt_adapter)
        manager.activate_adapter(adapter_id)
        return loaded
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -238,21 +239,25 @@ class Sequence: ...@@ -238,21 +239,25 @@ class Sequence:
block_size: The block size of the sequence. Should be the same as the block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine. block size used by the block manager and cache engine.
lora_request: LoRA request. lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request.
""" """
def __init__( def __init__(
self, self,
seq_id: int, seq_id: int,
inputs: "LLMInputs", inputs: "LLMInputs",
block_size: int, block_size: int,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.inputs = inputs self.inputs = inputs
self.block_size = block_size self.block_size = block_size
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.data = SequenceData(self.prompt_token_ids) self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
...@@ -287,6 +292,11 @@ class Sequence: ...@@ -287,6 +292,11 @@ class Sequence:
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
def get_output_text_to_return(self, buffer_length: int): def get_output_text_to_return(self, buffer_length: int):
# We return the full output text if the sequence is finished. # We return the full output text if the sequence is finished.
truncate = buffer_length and not self.is_finished() truncate = buffer_length and not self.is_finished()
...@@ -414,6 +424,7 @@ class SequenceGroup: ...@@ -414,6 +424,7 @@ class SequenceGroup:
encoder_seq: Optional, the single encoder sequence. Should be None encoder_seq: Optional, the single encoder sequence. Should be None
unless you are working with an encoder/decoder model. unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
""" """
def __init__( def __init__(
...@@ -427,6 +438,7 @@ class SequenceGroup: ...@@ -427,6 +438,7 @@ class SequenceGroup:
pooling_params: Optional[PoolingParams] = None, pooling_params: Optional[PoolingParams] = None,
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.seqs_dict = {seq.seq_id: seq for seq in seqs}
...@@ -441,6 +453,7 @@ class SequenceGroup: ...@@ -441,6 +453,7 @@ class SequenceGroup:
self.state = SequenceGroupState() self.state = SequenceGroupState()
self.embeddings = embeddings self.embeddings = embeddings
self.pooling_params = pooling_params self.pooling_params = pooling_params
self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq self.encoder_seq = encoder_seq
self.trace_headers = trace_headers self.trace_headers = trace_headers
...@@ -466,6 +479,16 @@ class SequenceGroup: ...@@ -466,6 +479,16 @@ class SequenceGroup:
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
@property
def prompt_adapter_num_virtual_tokens(self) -> int:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
if self.prompt_adapter_request else 0
def get_last_latency(self, now: float) -> Optional[float]: def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings.""" """Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error. # If still in prefill phase, raise Error.
...@@ -624,6 +647,7 @@ class SequenceGroupMetadata: ...@@ -624,6 +647,7 @@ class SequenceGroupMetadata:
(SequenceGroup.encoder_seq). Should be None (SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder unless you are working with an encoder/decoder
model. model.
prompt_adapter_request: Prompt Adapter request.
""" """
def __init__( def __init__(
...@@ -642,6 +666,7 @@ class SequenceGroupMetadata: ...@@ -642,6 +666,7 @@ class SequenceGroupMetadata:
multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None, encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None, cross_block_table: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.is_prompt = is_prompt self.is_prompt = is_prompt
...@@ -650,6 +675,7 @@ class SequenceGroupMetadata: ...@@ -650,6 +675,7 @@ class SequenceGroupMetadata:
self.block_tables = block_tables self.block_tables = block_tables
self.pooling_params = pooling_params self.pooling_params = pooling_params
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state self.state = SequenceGroupState() if state is None else state
...@@ -674,6 +700,16 @@ class SequenceGroupMetadata: ...@@ -674,6 +700,16 @@ class SequenceGroupMetadata:
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
@property
def prompt_adapter_num_virtual_tokens(self) -> int:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
if self.prompt_adapter_request else 0
@property @property
def token_chunk_size(self) -> int: def token_chunk_size(self) -> int:
"""Return the number of tokens to be processed (chunk size).""" """Return the number of tokens to be processed (chunk size)."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment