"vscode:/vscode.git/clone" did not exist on "b6101d384db5709b4422ebd05fe84f0891ff63ce"
Unverified Commit 4d6ada94 authored by Swapnil Parekh's avatar Swapnil Parekh Committed by GitHub
Browse files

[CORE] Adding support for insertion of soft-tuned prompts (#4645)


Co-authored-by: default avatarSwapnil Parekh <swapnilp@ibm.com>
Co-authored-by: default avatarJoe G <joseph.granados@h2o.ai>
Co-authored-by: default avatarAntoni Baum <antoni.baum@protonmail.com>
parent a0550cbc
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
...@@ -48,6 +48,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -48,6 +48,7 @@ class TP1DraftModelRunner(ModelRunner):
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
multimodal_config: Optional[MultiModalConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
): ):
if return_hidden_states: if return_hidden_states:
...@@ -66,6 +67,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -66,6 +67,7 @@ class TP1DraftModelRunner(ModelRunner):
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states, return_hidden_states=return_hidden_states,
) )
...@@ -136,6 +138,13 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -136,6 +138,13 @@ class TP1DraftModelRunner(ModelRunner):
self.set_active_loras(model_input.lora_requests, self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping) model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
virtual_engine = model_input.virtual_engine virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = [] outputs: List[SamplerOutput] = []
for step in range(num_steps): for step in range(num_steps):
......
...@@ -8,7 +8,7 @@ from torch import nn ...@@ -8,7 +8,7 @@ from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -81,6 +81,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -81,6 +81,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
*args, *args,
**kwargs, **kwargs,
...@@ -94,6 +95,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -94,6 +95,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
......
...@@ -7,7 +7,7 @@ import torch.distributed ...@@ -7,7 +7,7 @@ import torch.distributed
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -133,6 +133,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -133,6 +133,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
...@@ -145,6 +146,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -145,6 +146,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
...@@ -167,6 +169,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -167,6 +169,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config, multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -40,6 +40,7 @@ class EmbeddingModelRunner( ...@@ -40,6 +40,7 @@ class EmbeddingModelRunner(
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
): ):
super().__init__(model_config, super().__init__(model_config,
...@@ -51,6 +52,7 @@ class EmbeddingModelRunner( ...@@ -51,6 +52,7 @@ class EmbeddingModelRunner(
lora_config=lora_config, lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config) multimodal_config=multimodal_config)
@torch.inference_mode() @torch.inference_mode()
...@@ -71,6 +73,13 @@ class EmbeddingModelRunner( ...@@ -71,6 +73,13 @@ class EmbeddingModelRunner(
self.set_active_loras(model_input.lora_requests, self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping) model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
# Currently cuda graph is only supported by the decode phase. # Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
......
...@@ -25,7 +25,7 @@ except ImportError: ...@@ -25,7 +25,7 @@ except ImportError:
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
...@@ -40,6 +40,10 @@ from vllm.model_executor.models.interfaces import (supports_lora, ...@@ -40,6 +40,10 @@ from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision) supports_vision)
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs) MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
...@@ -85,6 +89,8 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -85,6 +89,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_mapping: Optional["LoRAMapping"] = None lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None
prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None
...@@ -97,6 +103,8 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -97,6 +103,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"prompt_adapter_mapping": self.prompt_adapter_mapping,
"prompt_adapter_requests": self.prompt_adapter_requests,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids, "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids, "finished_requests_ids": self.finished_requests_ids,
...@@ -133,6 +141,8 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): ...@@ -133,6 +141,8 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"prompt_adapter_mapping": self.prompt_adapter_mapping,
"prompt_adapter_requests": self.prompt_adapter_requests,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids, "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids, "finished_requests_ids": self.finished_requests_ids,
...@@ -172,6 +182,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -172,6 +182,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
): ):
...@@ -183,6 +194,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -183,6 +194,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.return_hidden_states = return_hidden_states self.return_hidden_states = return_hidden_states
...@@ -232,6 +244,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -232,6 +244,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.model: nn.Module # Set after load_model self.model: nn.Module # Set after load_model
# Set after load_model. # Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
self.flashinfer_decode_workspace_buffer = None self.flashinfer_decode_workspace_buffer = None
self.flashinfer_decode_wrapper = None self.flashinfer_decode_wrapper = None
...@@ -240,16 +253,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -240,16 +253,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def load_model(self) -> None: def load_model(self) -> None:
with CudaMemoryProfiler() as m: with CudaMemoryProfiler() as m:
self.model = get_model( self.model = get_model(model_config=self.model_config,
model_config=self.model_config,
device_config=self.device_config, device_config=self.device_config,
load_config=self.load_config, load_config=self.load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config, multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
cache_config=self.cache_config, cache_config=self.cache_config)
)
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB", logger.info("Loading model weights took %.4f GB",
...@@ -274,6 +285,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -274,6 +285,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
) )
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
if self.prompt_adapter_config:
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.device,
self.prompt_adapter_config)
self.model = (
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))
if self.kv_cache_dtype == "fp8" and is_hip(): if self.kv_cache_dtype == "fp8" and is_hip():
# Currently only ROCm accepts kv-cache scaling factors # Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated # via quantization_param_path and this will be deprecated
...@@ -354,6 +374,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -354,6 +374,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_index_mapping: List[int] = [] lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = [] lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set() lora_requests: Set[LoRARequest] = set()
prompt_adapter_index_mapping: List[int] = []
prompt_adapter_prompt_mapping: List[int] = []
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
seq_lens: List[int] = [] seq_lens: List[int] = []
prefill_seq_lens: List[int] = [] prefill_seq_lens: List[int] = []
...@@ -504,6 +527,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -504,6 +527,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
input_tokens.extend(tokens) input_tokens.extend(tokens)
input_positions.extend(list(range(context_len, seq_len))) input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id lora_id = seq_group_metadata.lora_int_id
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if is_prompt: if is_prompt:
assert len(seq_ids) == 1 assert len(seq_ids) == 1
...@@ -534,6 +558,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -534,6 +558,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
mm_kwargs = self.multi_modal_input_mapper(mm_data) mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs) multi_modal_inputs_list.append(mm_kwargs)
if prompt_adapter_id > 0 and is_prompt:
prompt_adapter_requests.add(
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.\
prompt_adapter_num_virtual_tokens
pm = [prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
prompt_adapter_index_mapping += pm
prompt_adapter_prompt_mapping.extend(
[prompt_adapter_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))
is_profile_run = _is_block_tables_empty( is_profile_run = _is_block_tables_empty(
seq_group_metadata.block_tables) seq_group_metadata.block_tables)
if is_profile_run: if is_profile_run:
...@@ -618,12 +657,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -618,12 +657,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seq_lens.append(1) seq_lens.append(1)
block_tables.append([]) block_tables.append([])
lora_index_mapping.append(0) lora_index_mapping.append(0)
prompt_adapter_index_mapping.append(0)
if self.attn_backend.get_name() == "flashinfer": if self.attn_backend.get_name() == "flashinfer":
last_paged_kv_indptr = paged_kv_indptr[-1] last_paged_kv_indptr = paged_kv_indptr[-1]
paged_kv_indptr.append(last_paged_kv_indptr) paged_kv_indptr.append(last_paged_kv_indptr)
paged_kv_last_page_len.append(0) paged_kv_last_page_len.append(0)
batch_size = graph_batch_size batch_size = graph_batch_size
num_decode_tokens = batch_size num_decode_tokens = batch_size
...@@ -759,6 +797,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -759,6 +797,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
else: else:
lora_mapping = None lora_mapping = None
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
else:
prompt_adapter_mapping = None
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device) device=self.device)
request_ids_to_seq_ids = { request_ids_to_seq_ids = {
...@@ -776,7 +822,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -776,7 +822,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_requests=lora_requests, lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs, multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids, request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=finished_requests_ids) finished_requests_ids=finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests,
)
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
...@@ -878,33 +927,67 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -878,33 +927,67 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def remove_all_loras(self): def remove_all_loras(self):
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
self.lora_manager.remove_all_loras() self.lora_manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest], def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None: lora_mapping: LoRAMapping) -> None:
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_loras(lora_requests, lora_mapping) self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_lora(lora_request) return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_lora(lora_id) return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_lora(lora_id) return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_loras() return self.lora_manager.list_adapters()
def remove_all_prompt_adapters(self):
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
self.prompt_adapter_manager.remove_all_adapters()
def set_active_prompt_adapters(
self, prompt_adapter_requests: Set[PromptAdapterRequest],
prompt_adapter_mapping: PromptAdapterMapping) -> None:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
self.prompt_adapter_manager.set_active_adapters(
prompt_adapter_requests, prompt_adapter_mapping)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.list_adapters()
@torch.inference_mode() @torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
...@@ -1063,6 +1146,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1063,6 +1146,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
) )
self.set_active_loras(set(), lora_mapping) self.set_active_loras(set(), lora_mapping)
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
[-1] * batch_size,
[-1] * batch_size,
)
self.set_active_prompt_adapters(
set(), prompt_adapter_mapping)
graph_runner = CUDAGraphRunner( graph_runner = CUDAGraphRunner(
self.model, self.attn_backend.get_name()) self.model, self.attn_backend.get_name())
...@@ -1189,6 +1280,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1189,6 +1280,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.set_active_loras(model_input.lora_requests, self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping) model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
if self.attn_backend.get_name() == "flashinfer": if self.attn_backend.get_name() == "flashinfer":
assert model_input.attn_metadata is not None assert model_input.attn_metadata is not None
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
......
...@@ -8,7 +8,8 @@ import torch.distributed ...@@ -8,7 +8,8 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig) PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
...@@ -16,6 +17,7 @@ from vllm.lora.request import LoRARequest ...@@ -16,6 +17,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.embedding_model_runner import EmbeddingModelRunner
...@@ -45,6 +47,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -45,6 +47,7 @@ class Worker(LocalOrDistributedWorkerBase):
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None, speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
) -> None: ) -> None:
...@@ -59,6 +62,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -59,6 +62,7 @@ class Worker(LocalOrDistributedWorkerBase):
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if parallel_config and is_driver_worker: if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \ assert rank % parallel_config.tensor_parallel_size == 0, \
...@@ -92,6 +96,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -92,6 +96,7 @@ class Worker(LocalOrDistributedWorkerBase):
lora_config=self.lora_config, lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
**speculative_args, **speculative_args,
) )
...@@ -296,6 +301,19 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -296,6 +301,19 @@ class Worker(LocalOrDistributedWorkerBase):
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
return self.model_runner.list_loras() return self.model_runner.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.model_runner.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_runner.remove_lora(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
return self.model_runner.list_prompt_adapters()
@property @property
def max_model_len(self) -> int: def max_model_len(self) -> int:
return self.model_config.max_model_len return self.model_config.max_model_len
......
...@@ -8,7 +8,7 @@ import torch.nn as nn ...@@ -8,7 +8,7 @@ import torch.nn as nn
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -88,6 +88,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -88,6 +88,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
*args, *args,
**kwargs, **kwargs,
...@@ -98,6 +99,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -98,6 +99,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config self.load_config = load_config
self.cache_config = cache_config self.cache_config = cache_config
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
......
...@@ -10,7 +10,8 @@ import torch.distributed ...@@ -10,7 +10,8 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig) PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -47,6 +48,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): ...@@ -47,6 +48,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None, speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
assert device_config.device_type == "xpu" assert device_config.device_type == "xpu"
...@@ -63,6 +65,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): ...@@ -63,6 +65,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment