Unverified Commit 4d6ada94 authored by Swapnil Parekh's avatar Swapnil Parekh Committed by GitHub
Browse files

[CORE] Adding support for insertion of soft-tuned prompts (#4645)


Co-authored-by: default avatarSwapnil Parekh <swapnilp@ibm.com>
Co-authored-by: default avatarJoe G <joseph.granados@h2o.ai>
Co-authored-by: default avatarAntoni Baum <antoni.baum@protonmail.com>
parent a0550cbc
...@@ -111,6 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml ...@@ -111,6 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml mypy vllm/logging --config-file pyproject.toml
mypy vllm/prompt_adapter --config-file pyproject.toml
mypy tests --config-file pyproject.toml mypy tests --config-file pyproject.toml
......
...@@ -92,11 +92,10 @@ def batched_generate( ...@@ -92,11 +92,10 @@ def batched_generate(
for input in inputs: for input in inputs:
prompt, sampling_param, lora_req = input prompt, sampling_param, lora_req = input
# Add requests to the engine and run the engine # Add requests to the engine and run the engine
llm._validate_and_add_requests( llm._validate_and_add_requests(prompt,
prompt, sampling_param,
sampling_param, lora_request=lora_req,
lora_request=lora_req, prompt_adapter_request=None)
)
outputs = llm._run_engine(use_tqdm=True) outputs = llm._run_engine(use_tqdm=True)
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
......
This diff is collapsed.
import pytest
import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
def do_sample(llm, pa_name: str, pa_id: int):
prompts = [
"Tweet text : @nationalgridus I have no water and the bill is \
current and paid. Can you do something about this? Label : ",
"Tweet text : @nationalgridus Looks good thanks! Label : "
]
sampling_params = vllm.SamplingParams(temperature=0.0,
max_tokens=3,
stop_token_ids=[3])
outputs = llm.generate(prompts,
sampling_params,
prompt_adapter_request=PromptAdapterRequest(
pa_name, pa_id, PA_PATH, 8) if pa_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
llm = vllm.LLM(MODEL_PATH,
enforce_eager=enforce_eager,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
expected_output = ['complaint', 'no complaint']
assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'
def do_sample(engine):
prompts = [
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3), None),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("complain", 3, pa_path, 8)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_multi_prompt_adapters():
engine_args = EngineArgs(model=MODEL_PATH,
max_prompt_adapters=3,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
expected_output = {
' quot;I', 'hate speech', 'no complaint', 'not hate speech'
}
assert do_sample(engine) == expected_output
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def do_sample(engine):
prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501
# first prompt with a prompt adapter and second without adapter
prompts = [
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]),
PromptAdapterRequest("hate_speech", 1, pa_path,
8), LoRARequest("sql_test", 1, lora_path)),
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]), None,
LoRARequest("sql_test", 1, lora_path)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request,
lora_request=lora_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_lora_prompt_adapter():
engine_args = EngineArgs(model=MODEL_PATH,
enable_prompt_adapter=True,
enable_lora=True,
max_num_seqs=60,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
result = do_sample(engine)
expected_output = {
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
}
assert result == expected_output
...@@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest ...@@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -92,6 +93,7 @@ class AsyncLLM: ...@@ -92,6 +93,7 @@ class AsyncLLM:
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalDataDict] = None, multi_modal_data: Optional[MultiModalDataDict] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> List[RequestOutput]: ) -> List[RequestOutput]:
if prompts is None: if prompts is None:
......
...@@ -23,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: ...@@ -23,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
cache_config=engine_config.cache_config, cache_config=engine_config.cache_config,
load_config=engine_config.load_config, load_config=engine_config.load_config,
lora_config=engine_config.lora_config, lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
is_driver_worker=True, is_driver_worker=True,
) )
return model_runner return model_runner
......
from dataclasses import dataclass
from typing import Tuple
@dataclass
class AdapterMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
\ No newline at end of file
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
from torch import nn
from vllm.logger import init_logger
from vllm.utils import LRUCache
logger = init_logger(__name__)
class AdapterModel(ABC):
def __init__(self, model_id=None):
self.id = model_id
@abstractmethod
def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
# Common initialization code
# Load weights or embeddings from local checkpoint
raise NotImplementedError("Subclasses must implement this method.")
T = TypeVar('T')
class AdapterLRUCache(LRUCache[T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
None]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn
def _on_remove(self, key: Hashable, value: T):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
class AdapterModelManager(ABC):
def __init__(
self,
model: nn.Module,
):
"""Create a AdapterModelManager and adapter for a given model.
Args:
model: the model to be adapted.
"""
self.model: nn.Module = model
self._registered_adapters: Dict[int, Any] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_adapters: Dict[int, None] = {}
self.adapter_type = 'Adapter'
self._last_mapping = None
def __len__(self) -> int:
return len(self._registered_adapters)
@property
@abstractmethod
def adapter_slots(self):
...
@property
@abstractmethod
def capacity(self):
...
@abstractmethod
def activate_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def deactivate_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def add_adapter(self, adapter: Any) -> bool:
...
@abstractmethod
def set_adapter_mapping(self, mapping: Any) -> None:
...
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def remove_all_adapters(self):
...
@abstractmethod
def get_adapter(self, adapter_id: int) -> Optional[Any]:
...
@abstractmethod
def list_adapters(self) -> Dict[int, Any]:
...
@abstractmethod
def pin_adapter(self, adapter_id: int) -> bool:
...
from abc import abstractmethod
from dataclasses import dataclass
@dataclass
class AdapterRequest:
"""
Base class for adapter requests.
"""
@property
@abstractmethod
def adapter_id(self):
...
def __post_init__(self):
if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}")
def __eq__(self, value: object) -> bool:
return isinstance(
value, self.__class__) and self.adapter_id == value.adapter_id
def __hash__(self) -> int:
return hash(self.adapter_id)
from typing import Any, Callable, Dict, Optional, Set
## model functions
def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
deactivate_func: Callable) -> bool:
if adapter_id in active_adapters:
deactivate_func(adapter_id)
active_adapters.pop(adapter_id)
return True
return False
def add_adapter(adapter: Any, registered_adapters: Dict[int, Any],
capacity: int, add_func: Callable) -> bool:
if adapter.id not in registered_adapters:
if len(registered_adapters) >= capacity:
raise RuntimeError('No free adapter slots.')
add_func(adapter)
registered_adapters[adapter.id] = adapter
return True
return False
def set_adapter_mapping(mapping: Any, last_mapping: Any,
set_mapping_func: Callable) -> Any:
if last_mapping != mapping:
set_mapping_func(mapping)
return mapping
return last_mapping
def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any],
deactivate_func: Callable) -> bool:
deactivate_func(adapter_id)
return bool(registered_adapters.pop(adapter_id, None))
def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
return dict(registered_adapters)
def get_adapter(adapter_id: int,
registered_adapters: Dict[int, Any]) -> Optional[Any]:
return registered_adapters.get(adapter_id, None)
## worker functions
def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any],
apply_adapters_func,
set_adapter_mapping_func) -> None:
apply_adapters_func(requests)
set_adapter_mapping_func(mapping)
def add_adapter_worker(adapter_request: Any, list_adapters_func,
load_adapter_func, add_adapter_func,
activate_adapter_func) -> bool:
if adapter_request.adapter_id in list_adapters_func():
return False
loaded_adapter = load_adapter_func(adapter_request)
loaded = add_adapter_func(loaded_adapter)
activate_adapter_func(loaded_adapter.id)
return loaded
def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
adapter_slots: int, remove_adapter_func,
add_adapter_func) -> None:
models_that_exist = list_adapters_func()
models_map = {
adapter_request.adapter_id: adapter_request
for adapter_request in adapter_requests if adapter_request
}
if len(models_map) > adapter_slots:
raise RuntimeError(
f"Number of requested models ({len(models_map)}) is greater "
f"than the number of GPU model slots "
f"({adapter_slots}).")
new_models = set(models_map)
models_to_add = new_models - models_that_exist
models_to_remove = models_that_exist - new_models
for adapter_id in models_to_remove:
remove_adapter_func(adapter_id)
for adapter_id in models_to_add:
add_adapter_func(models_map[adapter_id])
def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]:
return set(adapter_manager_list_adapters_func())
from abc import ABC, abstractmethod
from typing import Any, Optional, Set
import torch
class AbstractWorkerManager(ABC):
def __init__(self, device: torch.device):
self.device = device
@property
@abstractmethod
def is_enabled(self) -> bool:
...
@abstractmethod
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
...
@abstractmethod
def add_adapter(self, adapter_request: Any) -> bool:
...
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def remove_all_adapters(self):
...
@abstractmethod
def list_adapters(self) -> Set[int]:
...
...@@ -1285,6 +1285,39 @@ class LoRAConfig: ...@@ -1285,6 +1285,39 @@ class LoRAConfig:
raise ValueError("LoRA is not supported with chunked prefill yet.") raise ValueError("LoRA is not supported with chunked prefill yet.")
@dataclass
class PromptAdapterConfig:
max_prompt_adapters: int
max_prompt_adapter_token: int
max_cpu_prompt_adapters: Optional[int] = None
prompt_adapter_dtype: Optional[torch.dtype] = None
def __post_init__(self):
library_name = 'peft'
try:
__import__(library_name)
except ImportError as e:
raise ImportError(
f"'{library_name}' is not installed for prompt adapter support."
f"Please install it using 'pip install {library_name}'."
) from e
if self.max_prompt_adapters < 1:
raise ValueError(f"max_prompt_adapters "
f"({self.max_prompt_adapters}) must be >= 1.")
if self.max_prompt_adapter_token == 0:
raise ValueError("max_prompt_adapter_token must be set.")
if self.max_cpu_prompt_adapters is None:
self.max_cpu_prompt_adapters = self.max_prompt_adapters
def verify_with_model_config(self, model_config: ModelConfig):
if self.prompt_adapter_dtype in (None, "auto"):
self.prompt_adapter_dtype = model_config.dtype
elif isinstance(self.prompt_adapter_dtype, str):
self.prompt_adapter_dtype = getattr(torch,
self.prompt_adapter_dtype)
@dataclass @dataclass
class MultiModalConfig: class MultiModalConfig:
"""Configs the input data format and how models should run for """Configs the input data format and how models should run for
...@@ -1518,6 +1551,7 @@ class EngineConfig: ...@@ -1518,6 +1551,7 @@ class EngineConfig:
speculative_config: Optional[SpeculativeConfig] speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig] decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig] observability_config: Optional[ObservabilityConfig]
prompt_adapter_config: Optional[PromptAdapterConfig]
def __post_init__(self): def __post_init__(self):
"""Verify configs are valid & consistent with each other. """Verify configs are valid & consistent with each other.
...@@ -1529,6 +1563,9 @@ class EngineConfig: ...@@ -1529,6 +1563,9 @@ class EngineConfig:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
self.scheduler_config) self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def to_dict(self): def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs. """Return the configs as a dictionary, for use in **kwargs.
......
...@@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager ...@@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus) SequenceGroupMetadata, SequenceStatus)
...@@ -139,6 +140,8 @@ class SchedulerOutputs: ...@@ -139,6 +140,8 @@ class SchedulerOutputs:
if self.num_loras > 0: if self.num_loras > 0:
self._sort_by_lora_ids() self._sort_by_lora_ids()
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
def is_empty(self) -> bool: def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups. # NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
...@@ -157,6 +160,14 @@ class SchedulerOutputs: ...@@ -157,6 +160,14 @@ class SchedulerOutputs:
if g.seq_group.lora_request is not None if g.seq_group.lora_request is not None
} }
@property
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
return {
g.seq_group.prompt_adapter_request
for g in self.scheduled_seq_groups
if g.seq_group.prompt_adapter_request is not None
}
@dataclass @dataclass
class SchedulerRunningOutputs: class SchedulerRunningOutputs:
...@@ -1024,6 +1035,7 @@ class Scheduler: ...@@ -1024,6 +1035,7 @@ class Scheduler:
# `multi_modal_data` will be None. # `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None, if scheduler_outputs.num_prefill_groups > 0 else None,
prompt_adapter_request=seq_group.prompt_adapter_request,
) )
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
......
...@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple, Union ...@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig, PromptAdapterConfig, SchedulerConfig,
TokenizerPoolConfig) SpeculativeConfig, TokenizerPoolConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -66,6 +66,9 @@ class EngineArgs: ...@@ -66,6 +66,9 @@ class EngineArgs:
enable_lora: bool = False enable_lora: bool = False
max_loras: int = 1 max_loras: int = 1
max_lora_rank: int = 16 max_lora_rank: int = 16
enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256 lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None long_lora_scaling_factors: Optional[Tuple[float]] = None
...@@ -449,6 +452,17 @@ class EngineArgs: ...@@ -449,6 +452,17 @@ class EngineArgs:
'Enabling this will use the fully sharded layers. ' 'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or ' 'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.')) 'tensor parallel size, this is likely faster.'))
parser.add_argument('--enable-prompt-adapter',
action='store_true',
help='If True, enable handling of PromptAdapters.')
parser.add_argument('--max-prompt-adapters',
type=int,
default=EngineArgs.max_prompt_adapters,
help='Max number of PromptAdapters in a batch.')
parser.add_argument('--max-prompt-adapter-token',
type=int,
default=EngineArgs.max_prompt_adapter_token,
help='Max number of PromptAdapters tokens')
parser.add_argument("--device", parser.add_argument("--device",
type=str, type=str,
default=EngineArgs.device, default=EngineArgs.device,
...@@ -726,6 +740,11 @@ class EngineArgs: ...@@ -726,6 +740,11 @@ class EngineArgs:
model_loader_extra_config=self.model_loader_extra_config, model_loader_extra_config=self.model_loader_extra_config,
) )
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig( decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend) guided_decoding_backend=self.guided_decoding_backend)
...@@ -751,6 +770,7 @@ class EngineArgs: ...@@ -751,6 +770,7 @@ class EngineArgs:
load_config=load_config, load_config=load_config,
decoding_config=decoding_config, decoding_config=decoding_config,
observability_config=observability_config, observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
) )
......
...@@ -18,6 +18,7 @@ from vllm.logger import init_logger ...@@ -18,6 +18,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -264,6 +265,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -264,6 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
request_id: str, request_id: str,
inputs: PromptInputs, inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> LLMInputs:
if isinstance(inputs, str): if isinstance(inputs, str):
inputs = {"prompt": inputs} inputs = {"prompt": inputs}
...@@ -279,6 +281,12 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -279,6 +281,12 @@ class _AsyncLLMEngine(LLMEngine):
else: else:
prompt_token_ids = inputs["prompt_token_ids"] prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = [
0
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
prompt_token_ids
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"), prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data")) multi_modal_data=inputs.get("multi_modal_data"))
...@@ -286,13 +294,14 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -286,13 +294,14 @@ class _AsyncLLMEngine(LLMEngine):
return self.input_processor(llm_inputs) return self.input_processor(llm_inputs)
async def add_request_async( async def add_request_async(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None: ) -> None:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
...@@ -301,7 +310,10 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -301,7 +310,10 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time() arrival_time = time.time()
processed_inputs = await self.process_model_inputs_async( processed_inputs = await self.process_model_inputs_async(
request_id=request_id, inputs=inputs, lora_request=lora_request) request_id=request_id,
inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
...@@ -309,6 +321,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -309,6 +321,7 @@ class _AsyncLLMEngine(LLMEngine):
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
) )
...@@ -627,6 +640,7 @@ class AsyncLLMEngine: ...@@ -627,6 +640,7 @@ class AsyncLLMEngine:
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream: ) -> AsyncStream:
if self.log_requests: if self.log_requests:
if isinstance(inputs, str): if isinstance(inputs, str):
...@@ -669,7 +683,7 @@ class AsyncLLMEngine: ...@@ -669,7 +683,7 @@ class AsyncLLMEngine:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
) prompt_adapter_request=prompt_adapter_request)
return stream return stream
...@@ -680,6 +694,7 @@ class AsyncLLMEngine: ...@@ -680,6 +694,7 @@ class AsyncLLMEngine:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]: ) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request. """Generate outputs for a request.
...@@ -695,6 +710,8 @@ class AsyncLLMEngine: ...@@ -695,6 +710,8 @@ class AsyncLLMEngine:
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
Yields: Yields:
The output `RequestOutput` objects from the LLMEngine The output `RequestOutput` objects from the LLMEngine
...@@ -749,6 +766,7 @@ class AsyncLLMEngine: ...@@ -749,6 +766,7 @@ class AsyncLLMEngine:
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
): ):
yield LLMEngine.validate_output(output, RequestOutput) yield LLMEngine.validate_output(output, RequestOutput)
...@@ -837,6 +855,7 @@ class AsyncLLMEngine: ...@@ -837,6 +855,7 @@ class AsyncLLMEngine:
*, *,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or """Common logic to process requests with SamplingParams or
PoolingParams.""" PoolingParams."""
...@@ -849,6 +868,7 @@ class AsyncLLMEngine: ...@@ -849,6 +868,7 @@ class AsyncLLMEngine:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
) )
try: try:
......
...@@ -8,7 +8,8 @@ from transformers import PreTrainedTokenizer ...@@ -8,7 +8,8 @@ from transformers import PreTrainedTokenizer
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, MultiModalConfig, LoRAConfig, ModelConfig, MultiModalConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs) SchedulerOutputs)
...@@ -27,6 +28,7 @@ from vllm.lora.request import LoRARequest ...@@ -27,6 +28,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
PoolerOutput, SamplerOutput, Sequence, PoolerOutput, SamplerOutput, Sequence,
...@@ -93,6 +95,8 @@ class LLMEngine: ...@@ -93,6 +95,8 @@ class LLMEngine:
decoding. decoding.
executor_class: The model executor class for managing distributed executor_class: The model executor class for managing distributed
execution. execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection. usage_context: Specified entry point, used for usage info collection.
""" """
...@@ -161,6 +165,7 @@ class LLMEngine: ...@@ -161,6 +165,7 @@ class LLMEngine:
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig], decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig], observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[ExecutorBase], executor_class: Type[ExecutorBase],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
...@@ -222,6 +227,7 @@ class LLMEngine: ...@@ -222,6 +227,7 @@ class LLMEngine:
self.speculative_config = speculative_config self.speculative_config = speculative_config
self.load_config = load_config self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig() self.decoding_config = decoding_config or DecodingConfig()
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config or ObservabilityConfig( self.observability_config = observability_config or ObservabilityConfig(
) )
self.log_stats = log_stats self.log_stats = log_stats
...@@ -250,6 +256,7 @@ class LLMEngine: ...@@ -250,6 +256,7 @@ class LLMEngine:
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
speculative_config=speculative_config, speculative_config=speculative_config,
load_config=load_config, load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
) )
if not self.model_config.embedding_mode: if not self.model_config.embedding_mode:
...@@ -282,6 +289,8 @@ class LLMEngine: ...@@ -282,6 +289,8 @@ class LLMEngine:
# Feature flags # Feature flags
"enable_lora": "enable_lora":
bool(lora_config), bool(lora_config),
"enable_prompt_adapter":
bool(prompt_adapter_config),
"enable_prefix_caching": "enable_prefix_caching":
cache_config.enable_prefix_caching, cache_config.enable_prefix_caching,
"enforce_eager": "enforce_eager":
...@@ -376,7 +385,6 @@ class LLMEngine: ...@@ -376,7 +385,6 @@ class LLMEngine:
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
distributed_executor_backend = ( distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend) engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class. # Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor from vllm.executor.neuron_executor import NeuronExecutor
...@@ -409,7 +417,6 @@ class LLMEngine: ...@@ -409,7 +417,6 @@ class LLMEngine:
else: else:
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor executor_class = GPUExecutor
# Create the LLM engine. # Create the LLM engine.
engine = cls( engine = cls(
**engine_config.to_dict(), **engine_config.to_dict(),
...@@ -470,6 +477,9 @@ class LLMEngine: ...@@ -470,6 +477,9 @@ class LLMEngine:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
self.scheduler_config) self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _get_eos_token_id( def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]: self, lora_request: Optional[LoRARequest]) -> Optional[int]:
...@@ -487,6 +497,7 @@ class LLMEngine: ...@@ -487,6 +497,7 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Dict[str, str]] = None,
) -> None: ) -> None:
# Create the sequences. # Create the sequences.
...@@ -495,7 +506,7 @@ class LLMEngine: ...@@ -495,7 +506,7 @@ class LLMEngine:
eos_token_id = self._get_eos_token_id(lora_request) eos_token_id = self._get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request) lora_request, prompt_adapter_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams # Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams): if isinstance(params, SamplingParams):
...@@ -506,7 +517,7 @@ class LLMEngine: ...@@ -506,7 +517,7 @@ class LLMEngine:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
) prompt_adapter_request=prompt_adapter_request)
elif isinstance(params, PoolingParams): elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling( seq_group = self._create_sequence_group_with_pooling(
request_id, request_id,
...@@ -514,7 +525,7 @@ class LLMEngine: ...@@ -514,7 +525,7 @@ class LLMEngine:
params, params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
) prompt_adapter_request=prompt_adapter_request)
else: else:
raise ValueError( raise ValueError(
"Either SamplingParams or PoolingParams must be provided.") "Either SamplingParams or PoolingParams must be provided.")
...@@ -535,6 +546,7 @@ class LLMEngine: ...@@ -535,6 +546,7 @@ class LLMEngine:
request_id: str, request_id: str,
inputs: PromptInputs, inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> LLMInputs:
if isinstance(inputs, str): if isinstance(inputs, str):
inputs = {"prompt": inputs} inputs = {"prompt": inputs}
...@@ -549,6 +561,11 @@ class LLMEngine: ...@@ -549,6 +561,11 @@ class LLMEngine:
else: else:
prompt_token_ids = inputs["prompt_token_ids"] prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = \
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
+ prompt_token_ids
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"), prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data")) multi_modal_data=inputs.get("multi_modal_data"))
...@@ -563,6 +580,7 @@ class LLMEngine: ...@@ -563,6 +580,7 @@ class LLMEngine:
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
...@@ -612,9 +630,11 @@ class LLMEngine: ...@@ -612,9 +630,11 @@ class LLMEngine:
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
processed_inputs = self.process_model_inputs(request_id=request_id, processed_inputs = self.process_model_inputs(
inputs=inputs, request_id=request_id,
lora_request=lora_request) inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
...@@ -622,6 +642,7 @@ class LLMEngine: ...@@ -622,6 +642,7 @@ class LLMEngine:
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
) )
...@@ -633,6 +654,7 @@ class LLMEngine: ...@@ -633,6 +654,7 @@ class LLMEngine:
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams.""" """Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs max_logprobs = self.get_model_config().max_logprobs
...@@ -658,7 +680,7 @@ class LLMEngine: ...@@ -658,7 +680,7 @@ class LLMEngine:
sampling_params=sampling_params, sampling_params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
) prompt_adapter_request=prompt_adapter_request)
return seq_group return seq_group
...@@ -669,16 +691,19 @@ class LLMEngine: ...@@ -669,16 +691,19 @@ class LLMEngine:
pooling_params: PoolingParams, pooling_params: PoolingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams.""" """Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler # Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone() pooling_params = pooling_params.clone()
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id=request_id, seq_group = SequenceGroup(
seqs=[seq], request_id=request_id,
arrival_time=arrival_time, seqs=[seq],
lora_request=lora_request, arrival_time=arrival_time,
pooling_params=pooling_params) lora_request=lora_request,
pooling_params=pooling_params,
prompt_adapter_request=prompt_adapter_request)
return seq_group return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
...@@ -1082,6 +1107,16 @@ class LLMEngine: ...@@ -1082,6 +1107,16 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id) return self.model_executor.pin_lora(lora_id)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.model_executor.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> List[int]:
return self.model_executor.list_prompt_adapters()
def check_health(self) -> None: def check_health(self) -> None:
if self.tokenizer: if self.tokenizer:
self.tokenizer.check_health() self.tokenizer.check_health()
......
...@@ -13,6 +13,7 @@ from vllm.logger import init_logger ...@@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -255,6 +256,7 @@ class LLM: ...@@ -255,6 +256,7 @@ class LLM:
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -271,6 +273,8 @@ class LLM: ...@@ -271,6 +273,8 @@ class LLM:
prompts and it is paired one by one with the prompt. prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns: Returns:
A list of `RequestOutput` objects containing the A list of `RequestOutput` objects containing the
...@@ -304,7 +308,7 @@ class LLM: ...@@ -304,7 +308,7 @@ class LLM:
inputs=inputs, inputs=inputs,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) prompt_adapter_request=prompt_adapter_request)
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput) return LLMEngine.validate_outputs(outputs, RequestOutput)
...@@ -397,6 +401,7 @@ class LLM: ...@@ -397,6 +401,7 @@ class LLM:
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -412,6 +417,8 @@ class LLM: ...@@ -412,6 +417,8 @@ class LLM:
use the default pooling parameters. use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns: Returns:
A list of `EmbeddingRequestOutput` objects containing the A list of `EmbeddingRequestOutput` objects containing the
...@@ -445,6 +452,7 @@ class LLM: ...@@ -445,6 +452,7 @@ class LLM:
inputs=inputs, inputs=inputs,
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
) )
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
...@@ -504,6 +512,7 @@ class LLM: ...@@ -504,6 +512,7 @@ class LLM:
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]], Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None: ) -> None:
if isinstance(inputs, (str, dict)): if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list. # Convert a single prompt to a list.
...@@ -526,19 +535,23 @@ class LLM: ...@@ -526,19 +535,23 @@ class LLM:
params[i] if isinstance(params, Sequence) else params, params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance( lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request, lora_request, Sequence) else lora_request,
) prompt_adapter_request=prompt_adapter_request)
def _add_request( def _add_request(
self, self,
inputs: PromptInputs, inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest],
LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id, self.llm_engine.add_request(
inputs, request_id,
params, inputs,
lora_request=lora_request) params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
def _run_engine( def _run_engine(
self, *, use_tqdm: bool self, *, use_tqdm: bool
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment