Unverified commit 9b945daa authored by Antoni Baum, committed by GitHub

[Experimental] Add multi-LoRA support (#1804)


Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
parent 9c1352eb
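For context, a minimal usage sketch of the API this change introduces (illustrative only, not part of the diff; the model name and adapter path are placeholders, and it assumes `LLM` forwards `enable_lora` and `max_loras` through to `EngineArgs`):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Hypothetical example: serve a base model with LoRA enabled and route one
# request through a specific adapter (name, integer id, local path).
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=4)
outputs = llm.generate(
    ["Write a SQL query listing all users."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("sql-adapter", 1, "/path/to/sql_lora"),
)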
import os
import random
import tempfile
from unittest.mock import patch
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig
from vllm.worker.worker import Worker
@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
worker = Worker(
model_config=ModelConfig(
"meta-llama/Llama-2-7b-hf",
"meta-llama/Llama-2-7b-hf",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
),
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32, 256),
local_rank=0,
rank=0,
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
max_loras=32),
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
)
worker.init_model()
worker.load_model()
worker.model_runner.set_active_loras([], LoRAMapping([], []))
assert worker.list_loras() == set()
n_loras = 32
lora_requests = [
LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
]
worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
assert worker.list_loras() == {
lora_request.lora_int_id
for lora_request in lora_requests
}
for i in range(32):
random.seed(i)
iter_lora_requests = random.choices(lora_requests,
k=random.randint(1, n_loras))
random.shuffle(iter_lora_requests)
iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)]
worker.model_runner.set_active_loras(iter_lora_requests,
LoRAMapping([], []))
assert worker.list_loras().issuperset(
{lora_request.lora_int_id
for lora_request in iter_lora_requests})
from typing import List, Optional
import torch
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
class DummyLoRAManager:
def __init__(self):
super().__init__()
self._loras = {}
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
self._loras[module_name] = lora
def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
return self._loras.get(module_name, None)
def init_random_lora(self,
module_name: str,
weight: torch.Tensor,
rank: int = 8,
generate_embeddings_tensor: int = 0):
lora = LoRALayerWeights(
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([weight.shape[1], rank],
dtype=weight.dtype,
device="cuda"),
lora_b=torch.rand([rank, weight.shape[0]],
dtype=weight.dtype,
device="cuda"),
)
if generate_embeddings_tensor:
lora.embeddings_tensor = torch.rand(5,
generate_embeddings_tensor,
dtype=weight.dtype,
device="cuda")
self.set_module_lora(module_name, lora)
return lora
def init_lora(self,
module_name: str,
input_dim: int,
output_dim: int,
rank=8,
noop=False,
embeddings_tensor=None):
lora = LoRALayerWeights(
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([input_dim, rank], device="cuda"),
lora_b=torch.rand([rank, output_dim], device="cuda"),
embeddings_tensor=embeddings_tensor,
)
self.set_module_lora(module_name, lora)
return lora
def reset_lora(self):
self._loras = {}
def init_packed_lora(
self,
module_name: str,
input_dim: int,
output_dims: List[int],
noop_lora_index: List[int] = None,
rank=8,
):
base_loras = []
noop_lora_index = set(noop_lora_index or [])
for i, out_dim in enumerate(output_dims):
base_lora = self.init_lora(
module_name + "_000_" + str(i),
input_dim,
out_dim,
rank=rank,
noop=i in noop_lora_index,
)
base_loras.append(base_lora)
packed_lora = PackedLoRALayerWeights.pack(base_loras)
self.set_module_lora(module_name, packed_lora)
return packed_lora
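A small illustration of how this test helper might be exercised (a sketch, not part of the change; it requires a CUDA device because the helper allocates on "cuda", and the module name is a placeholder):

manager = DummyLoRAManager()
layer_weight = torch.rand(256, 128, dtype=torch.float16, device="cuda")
lora = manager.init_random_lora("model.layers.0.mlp.down_proj", layer_weight, rank=8)
# lora_a maps the input dim to the rank; lora_b maps the rank to the output dim.
assert lora.lora_a.shape == (128, 8) and lora.lora_b.shape == (8, 256)
assert manager.get_module_lora("model.layers.0.mlp.down_proj") is lora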
......@@ -19,10 +19,11 @@ class MockLogitsSampler(Sampler):
self.fake_logits = fake_logits
def forward(self, *args, **kwargs):
with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
lambda x, y: x), patch(
"vllm.model_executor.layers.sampler._get_logits",
lambda *args, **kwargs: self.fake_logits):
with patch(
"vllm.model_executor.layers.sampler._prune_hidden_states",
lambda x, y: x), patch(
"vllm.model_executor.layers.sampler.Sampler._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)
......@@ -38,7 +39,7 @@ def _prepare_test(
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
model_runner = ModelRunner(None, None, None)
model_runner = ModelRunner(None, None, None, None)
return input_tensor, fake_logits, sampler, model_runner
......@@ -266,7 +267,7 @@ def test_sampler_top_k_top_p(seed: int):
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
model_runner = ModelRunner(None, None, None)
model_runner = ModelRunner(None, None, None, None)
generation_model = GenerationMixin()
generation_config = GenerationConfig(top_k=top_k,
......
......@@ -83,8 +83,8 @@ def create_worker(cls: type,
enforce_eager=enforce_eager,
)
(model_config, cache_config, parallel_config,
scheduler_config) = engine_args.create_engine_configs()
(model_config, cache_config, parallel_config, scheduler_config,
_) = engine_args.create_engine_configs()
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
......
......@@ -6,7 +6,7 @@ from vllm.worker.model_runner import ModelRunner
def test_prepare_prompt():
model_runner = ModelRunner(None, None, None)
model_runner = ModelRunner(None, None, None, None)
model_runner.set_block_size(16)
batch_size = random.randint(1, 256)
......@@ -33,7 +33,7 @@ def test_prepare_prompt():
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
input_tokens, input_positions, _, return_prompt_lens, _ = (
input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
......
from typing import Optional, Union
from typing import Optional, Union, ClassVar
from dataclasses import dataclass
import os
import torch
......@@ -397,6 +398,54 @@ class SchedulerConfig:
f"({self.max_num_seqs}).")
@dataclass
class LoRAConfig:
max_lora_rank: int
max_loras: int
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[torch.dtype] = None
lora_extra_vocab_size: int = 256
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
def __post_init__(self):
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
possible_max_ranks = (8, 16, 32, 64)
possible_lora_extra_vocab_size = (0, 256, 512)
if self.max_lora_rank not in possible_max_ranks:
raise ValueError(
f"max_lora_rank ({self.max_lora_rank}) must be one of "
f"{possible_max_ranks}.")
if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
raise ValueError(
f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
f"must be one of {possible_lora_extra_vocab_size}.")
if self.max_loras < 1:
raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
if self.max_cpu_loras is None:
self.max_cpu_loras = self.max_loras
elif self.max_cpu_loras < self.max_loras:
raise ValueError(
f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
f"max_num_seqs ({self.max_loras})")
def verify_with_model_config(self, model_config: ModelConfig):
if self.lora_dtype in (None, "auto"):
self.lora_dtype = model_config.dtype
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
if model_config.quantization is not None:
raise ValueError(
"LoRA is not supported with quantized models yet.")
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528:
raise ValueError(
"Due to limitations of the custom LoRA CUDA kernel, "
"max_num_batched_tokens must be <= 65528 when "
"LoRA is enabled.")
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
......
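A quick sketch of how the `LoRAConfig` validation above behaves (illustrative, not part of the diff):

from vllm.config import LoRAConfig

config = LoRAConfig(max_lora_rank=16, max_loras=4)
assert config.max_cpu_loras == 4  # defaults to max_loras in __post_init__

try:
    LoRAConfig(max_lora_rank=12, max_loras=4)
except ValueError:
    pass  # rank must be one of (8, 16, 32, 64) to match the punica kernels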
from collections import deque
import enum
import time
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set
from vllm.config import CacheConfig, SchedulerConfig
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.lora.request import LoRARequest
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
......@@ -49,11 +50,25 @@ class SchedulerOutputs:
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups
self.num_loras = len(self.lora_requests)
if self.num_loras > 0:
self._sort_by_lora_ids()
def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)
def _sort_by_lora_ids(self) -> bool:
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.lora_request.lora_int_id
if g.lora_request else 0, g.request_id))
@property
def lora_requests(self) -> Set[LoRARequest]:
return {g.lora_request for g in self.scheduled_seq_groups}
class Scheduler:
......@@ -61,9 +76,14 @@ class Scheduler:
self,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
# simple and NOT fair. It can lead to starvation of some
# LoRAs. This should be improved in the future.
self.lora_config = lora_config
self.prompt_limit = min(self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
......@@ -87,6 +107,10 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
self.swapped: Deque[SequenceGroup] = deque()
@property
def lora_enabled(self) -> bool:
return bool(self.lora_config)
def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
self.waiting.append(seq_group)
......@@ -150,14 +174,17 @@ class Scheduler:
# requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
seq_lens: List[int] = []
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
leftover_waiting_sequences = deque()
while self.waiting:
seq_group = self.waiting[0]
waiting_seqs = seq_group.get_seqs(
status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, (
......@@ -188,6 +215,17 @@ class Scheduler:
self.waiting.popleft()
continue
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len(
curr_loras) >= self.lora_config.max_loras:
                    # We don't have space for another LoRA, so
                    # we defer this request for now.
leftover_waiting_sequences.appendleft(seq_group)
self.waiting.popleft()
continue
# If the number of batched tokens exceeds the limit, stop.
new_seq_lens = seq_lens + [num_prompt_tokens]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
......@@ -207,12 +245,16 @@ class Scheduler:
break
seq_lens = new_seq_lens
seq_group = self.waiting.popleft()
if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.waiting.popleft()
self._allocate(seq_group)
self.running.append(seq_group)
num_curr_seqs += num_new_seqs
scheduled.append(seq_group)
self.waiting.extendleft(leftover_waiting_sequences)
if scheduled or ignored_seq_groups:
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
......@@ -260,9 +302,25 @@ class Scheduler:
if not preempted:
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
leftover_swapped = deque()
while self.swapped:
seq_group = self.swapped[0]
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len(
curr_loras) >= self.lora_config.max_loras:
                # We don't have space for another LoRA, so
                # we defer this request for now.
leftover_swapped.appendleft(seq_group)
self.swapped.popleft()
continue
# If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group):
break
......@@ -274,12 +332,16 @@ class Scheduler:
self.scheduler_config.max_num_seqs):
break
seq_group = self.swapped.popleft()
if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.swapped.popleft()
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs
self.running.append(seq_group)
self.swapped.extendleft(leftover_swapped)
# Each sequence in the generation phase only takes one token slot.
# Therefore, the number of batched tokens is equal to the number of
# sequences in the RUNNING state.
......@@ -320,6 +382,7 @@ class Scheduler:
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
lora_request=seq_group.lora_request,
prefix=seq_group.prefix,
)
seq_group_metadata_list.append(seq_group_metadata)
......
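The scheduling change above amounts to an admission check: a waiting (or swapped) sequence group is only scheduled if its adapter is already active in the batch or a LoRA slot is free. A simplified restatement as a hypothetical helper (not part of the diff):

def can_admit_lora(lora_int_id: int, curr_loras: set, max_loras: int) -> bool:
    # 0 means "no LoRA"; otherwise the adapter must already be resident in
    # this batch, or there must be room for one more active adapter.
    return (lora_int_id == 0 or lora_int_id in curr_loras
            or len(curr_loras) < max_loras)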
......@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Optional, Tuple
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
SchedulerConfig, LoRAConfig)
@dataclass
......@@ -35,6 +35,12 @@ class EngineArgs:
quantization: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: int = 8192
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
lora_extra_vocab_size: int = 256
    lora_dtype: str = 'auto'
max_cpu_loras: Optional[int] = None
def __post_init__(self):
if self.tokenizer is None:
......@@ -202,6 +208,39 @@ class EngineArgs:
help='maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.')
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--max-loras',
type=int,
default=EngineArgs.max_loras,
help='Max number of LoRAs in a single batch.')
parser.add_argument('--max-lora-rank',
type=int,
default=EngineArgs.max_lora_rank,
help='Max LoRA rank.')
parser.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument(
'--lora-dtype',
type=str,
default=EngineArgs.lora_dtype,
choices=['auto', 'float16', 'bfloat16', 'float32'],
help=('Data type for LoRA. If auto, will default to '
'base model dtype.'))
parser.add_argument(
'--max-cpu-loras',
type=int,
default=EngineArgs.max_cpu_loras,
help=('Maximum number of LoRAs to store in CPU memory. '
                              'Must be >= max_loras. '
                              'Defaults to max_loras.'))
return parser
@classmethod
......@@ -214,7 +253,8 @@ class EngineArgs:
def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
Optional[LoRAConfig]]:
model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format,
......@@ -234,7 +274,14 @@ class EngineArgs:
self.max_num_seqs,
model_config.max_model_len,
self.max_paddings)
return model_config, cache_config, parallel_config, scheduler_config
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
return model_config, cache_config, parallel_config, scheduler_config, lora_config
@dataclass
......
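With these arguments, `create_engine_configs()` now returns a fifth element, which is `None` unless LoRA is enabled. A sketch (the model name is a placeholder; building the configs fetches the HF model config, so network access is assumed):

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m",
                         enable_lora=True,
                         max_loras=4,
                         max_lora_rank=16)
(model_config, cache_config, parallel_config, scheduler_config,
 lora_config) = engine_args.create_engine_configs()
assert lora_config is not None  # would be None with enable_lora=False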
......@@ -4,6 +4,7 @@ from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator)
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
......@@ -203,6 +204,52 @@ class _AsyncLLMEngine(LLMEngine):
return self._process_model_outputs(output, scheduler_outputs)
async def encode_request_async(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = await self.tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
async def add_request_async(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
return self.add_request(
request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
async def _run_workers_async(
self,
method: str,
......@@ -332,7 +379,7 @@ class AsyncLLMEngine:
if self.engine_use_ray:
await self.engine.add_request.remote(**new_request)
else:
self.engine.add_request(**new_request)
await self.engine.add_request_async(**new_request)
if finished_requests:
await self._engine_abort(finished_requests)
......@@ -371,6 +418,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncStream:
if self.log_requests:
......@@ -386,7 +434,8 @@ class AsyncLLMEngine:
f"prompt: {shortened_prompt!r}, "
f"prefix_pos: {prefix_pos},"
f"sampling params: {sampling_params}, "
f"prompt token ids: {shortened_token_ids}.")
f"prompt token ids: {shortened_token_ids}, "
f"lora_request: {lora_request}.")
if not self.is_running:
if self.start_engine_loop:
......@@ -398,12 +447,21 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.engine.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos)
return stream
......@@ -414,6 +472,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
......@@ -429,6 +488,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
......@@ -487,12 +547,15 @@ class AsyncLLMEngine:
arrival_time = time.monotonic()
try:
stream = await self.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
prefix_pos=prefix_pos)
stream = await self.add_request(
request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
async for request_output in stream:
yield request_output
......
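For the async path, `generate()` accepts the same `lora_request` argument. A hedged sketch (placeholder model and adapter path; assumes an environment where the engine can actually start):

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams


async def run_one_request() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True))
    results = engine.generate("Hello, my name is",
                              SamplingParams(max_tokens=16),
                              request_id="request-0",
                              lora_request=LoRARequest("demo", 1,
                                                       "/path/to/adapter"))
    final_output = None
    async for request_output in results:
        final_output = request_output
    if final_output is not None:
        print(final_output.outputs[0].text)


asyncio.run(run_one_request())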
......@@ -5,8 +5,9 @@ import time
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union)
from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
SchedulerConfig, LoRAConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import record_metrics
......@@ -17,7 +18,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
TokenizerGroup)
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
if ray:
......@@ -64,6 +65,7 @@ class LLMEngine:
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
placement_group: Optional["PlacementGroup"],
log_stats: bool,
) -> None:
......@@ -87,17 +89,13 @@ class LLMEngine:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
self._verify_args()
self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
tokenizer_revision=model_config.tokenizer_revision,
revision=model_config.revision)
self._init_tokenizer()
self.seq_counter = Counter()
# Create the parallel GPU workers.
......@@ -114,7 +112,7 @@ class LLMEngine:
self._init_cache()
# Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config)
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Logging.
self.last_logging_time = 0.0
......@@ -123,6 +121,9 @@ class LLMEngine:
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
def get_tokenizer_for_seq(self, sequence: Sequence):
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _init_workers(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
......@@ -141,11 +142,24 @@ class LLMEngine:
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
is_driver_worker=True,
)
self._run_workers("init_model")
self._run_workers("load_model")
def _init_tokenizer(self, **tokenizer_init_kwargs):
init_kwargs = dict(
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: TokenizerGroup = TokenizerGroup(
self.model_config.tokenizer, **init_kwargs)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
......@@ -233,6 +247,7 @@ class LLMEngine:
local_rank,
rank,
distributed_init_method,
lora_config=self.lora_config,
))
driver_rank = 0
......@@ -244,6 +259,7 @@ class LLMEngine:
driver_local_rank,
driver_rank,
distributed_init_method,
lora_config=self.lora_config,
is_driver_worker=True,
)
......@@ -257,6 +273,10 @@ class LLMEngine:
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache.
......@@ -332,6 +352,20 @@ class LLMEngine:
log_stats=not engine_args.disable_log_stats)
return engine
def encode_request(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
def add_request(
self,
request_id: str,
......@@ -339,6 +373,7 @@ class LLMEngine:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
"""Add a request to the engine's request pool.
......@@ -386,24 +421,31 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.monotonic()
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(prompt)
prompt_token_ids = self.encode_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
lora_request)
# Check whether the input specifies prefix
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
if lora_request else 0) if prefix_pos is not None else None
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time, prefix)
arrival_time, lora_request, prefix)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
......@@ -453,11 +495,13 @@ class LLMEngine:
current_worst_score = (current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
eos_token_id=self.get_tokenizer_for_seq(
current_worst_seq).eos_token_id))
if early_stopping is False:
highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
eos_token_id=self.get_tokenizer_for_seq(
best_running_seq).eos_token_id))
else:
assert early_stopping == "never"
if length_penalty > 0.0:
......@@ -471,7 +515,8 @@ class LLMEngine:
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.get_tokenizer_for_seq(
best_running_seq).eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
......@@ -480,7 +525,8 @@ class LLMEngine:
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
eos_token_id=self.get_tokenizer_for_seq(
best_running_seq).eos_token_id))
return current_worst_score >= highest_attainable_score
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
......@@ -571,7 +617,7 @@ class LLMEngine:
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id),
eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
......@@ -599,7 +645,7 @@ class LLMEngine:
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id),
eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True)
# Check if we can stop the beam search.
......@@ -837,7 +883,7 @@ class LLMEngine:
"""Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally(
self.tokenizer,
self.get_tokenizer_for_seq(seq),
all_input_ids=seq.get_token_ids(),
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
......@@ -879,11 +925,28 @@ class LLMEngine:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == self.tokenizer.eos_token_id):
if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
== self.get_tokenizer_for_seq(seq).eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED
return
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def list_loras(self) -> List[int]:
return self._run_workers("list_loras")
def _run_workers(
self,
method: str,
......
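Adapters can also be registered and dropped on a running engine through the worker RPCs above. A hypothetical helper using the `LLMEngine` class and the `LoRARequest` import already present in this file (the adapter path is a placeholder):

def register_sql_adapter(engine: LLMEngine) -> None:
    # Illustrative only: load an adapter, confirm it is resident, then drop it.
    engine.add_lora(LoRARequest("sql-adapter", 1, "/path/to/sql_lora"))
    assert 1 in engine.list_loras()
    engine.remove_lora(1)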
......@@ -3,6 +3,7 @@ from typing import List, Optional, Union
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.lora.request import LoRARequest
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
......@@ -122,6 +123,7 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
prefix_pos: Optional[Union[int, List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
......@@ -141,6 +143,7 @@ class LLM:
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
Returns:
A list of `RequestOutput` objects containing the generated
......@@ -168,7 +171,11 @@ class LLM:
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(prompt, sampling_params, token_ids, prefix_pos_i)
self._add_request(prompt,
sampling_params,
token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos_i)
return self._run_engine(use_tqdm)
def _add_request(
......@@ -176,6 +183,7 @@ class LLM:
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
request_id = str(next(self.request_counter))
......@@ -183,6 +191,7 @@ class LLM:
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
......
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather,
)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear,
QKVParallelLinear,
MergedColumnParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim
if TYPE_CHECKING:
pass
def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
indices: torch.Tensor,
output: torch.Tensor,
):
"""Applies lora to each input.
This method applies all loras to each input. It uses the
indices vector to determine which lora yields the
correct output. An index of -1 means no lora should be
applied. This method adds the final lora results to the
output.
Input shapes:
x: (batch_size, hidden_dim)
lora_a_stacked: (num_loras, lora_rank, hidden_dim)
lora_b_stacked: (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output = output
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
return output.view_as(org_output)
def _apply_lora_packed_nslice(
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
):
"""Applies lora to each input.
This method applies all loras to each input. It uses the
indices vector to determine which lora yields the
correct output. An index of -1 means no lora should be
applied. This method adds the final lora results to the
output.
This method is used for layers that are composed of multiple sublayers
(slices) packed together.
Input shapes:
x: (batch_size, hidden_dim)
lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
        output_slices: n-element tuple of slice sizes, where n is the number of slices
"""
org_output = output
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx in range(len(output_slices)):
add_lora_slice(output, x, lora_a_stacked[slice_idx],
lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
output_slices[slice_idx])
offset_left += output_slices[slice_idx]
return output.view_as(org_output)
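For intuition, `_apply_lora` is equivalent to the following per-token loop (a simplified reference, not part of the change; it assumes the 3-D shapes from the docstring, a single layer slot, and scale 1.0, whereas the real path dispatches to the batched punica kernel):

def _apply_lora_reference(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
) -> torch.Tensor:
    # Hypothetical reference implementation, for clarity only.
    for i, lora_idx in enumerate(indices.tolist()):
        if lora_idx < 0:
            continue  # -1 means this token gets no LoRA contribution.
        a = lora_a_stacked[lora_idx]  # (lora_rank, hidden_dim)
        b = lora_b_stacked[lora_idx]  # (output_dim, lora_rank)
        output[i] += x[i] @ a.t() @ b.t()
    return output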
@dataclass
class LoRAMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
class BaseLayerWithLoRA(nn.Module):
def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
model_config: PretrainedConfig) -> None:
"""Initializes lora matrices."""
...
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
...
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
"""Overwrites lora tensors at index."""
...
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
"""Sets the mapping indices."""
...
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
lora_vocab_start_idx = self.base_layer.org_vocab_size
weights_idx = None
if self.base_layer.vocab_end_index > lora_vocab_start_idx:
# We can start adding lora weights
weights_idx = max(
lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
self.embeddings_slice = (self.base_layer.vocab_start_index -
self.base_layer.org_vocab_size +
weights_idx,
self.base_layer.vocab_end_index -
self.base_layer.org_vocab_size)
self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
self.embeddings_weights.fill_(0)
else:
self.embeddings_slice = None
self.embeddings_weights = None
self.embeddings_tensors = torch.zeros(
(
max_loras,
lora_config.lora_extra_vocab_size,
self.base_layer.embedding_dim,
),
dtype=self.base_layer.weight.dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.org_vocab_size +
lora_config.lora_extra_vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.embedding_dim,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked_2d = self.lora_a_stacked.view(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.embeddings_indices = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1]].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part
embeddings = self.embeddings_tensors.view(
self.embeddings_tensors.shape[0] *
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2]
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.embeddings_indices = embeddings_indices
self.indices_len = indices_len
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
full_lora_a_embeddings = F.embedding(
x + indices,
self.lora_a_stacked_2d,
)
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))
full_output_org = full_output
if full_output.ndim == 3:
full_output = full_output.view(
full_output.shape[0] * full_output.shape[1], -1)
if full_lora_a_embeddings.ndim == 3:
full_lora_a_embeddings = full_lora_a_embeddings.view(
full_lora_a_embeddings.shape[0] *
full_lora_a_embeddings.shape[1], -1)
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
return full_output.view_as(full_output_org)
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_a_stacked = torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
max_loras,
1,
self.base_layer.weight.shape[0],
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[1]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
)
return output
def forward(self, input_):
"""Forward of ColumnParallelLinear
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = (self.base_layer.bias
if not self.base_layer.skip_bias_add else None)
# Matrix multiply.
output_parallel = self.apply_weights(input_, bias)
if self.base_layer.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = (self.base_layer.bias
if self.base_layer.skip_bias_add else None)
return output, output_bias
@property
def linear_weights(self):
return self.base_layer.linear_weights
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).
This means we have 2 LoRAs, each applied to one half of the layer.
Both slices must have the same size.
"""
def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
super().__init__(base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
n_slices = 2
if not (len(self.base_layer.output_sizes) == n_slices
and self.base_layer.output_sizes[0]
== self.base_layer.output_sizes[1]):
raise ValueError(
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size.")
self.tp_size = get_tensor_model_parallel_world_size()
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
) for _ in range(n_slices))
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
self.base_layer.weight.shape[0] // 2,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
) for _ in range(n_slices))
self.indices: Optional[torch.Tensor] = None
self.output_dim = self.lora_b_stacked[0].shape[2]
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
self.lora_a_stacked[1][index] = 0
self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[0][:,
start_idx:end_idx], lora_b[1][:,
start_idx:end_idx]
if lora_a[0] is not None:
self.lora_a_stacked[0][
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
lora_a[0].T, non_blocking=True)
self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True)
if lora_a[1] is not None:
self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
lora_a[1].T, non_blocking=True)
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
(self.output_dim, self.output_dim),
)
return output
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
This means we have 3 LoRAs, each applied to one slice of the layer.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
# q, k, v
self.lora_a_stacked = (
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
)
self.lora_b_stacked = (
torch.zeros(
max_loras,
1,
self.q_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
)
self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
self.kv_proj_shard_size)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
self.lora_b_stacked[0][index] = 0
self.lora_a_stacked[1][index] = 0
self.lora_b_stacked[1][index] = 0
self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
self.lora_b_stacked[0][
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
lora_b_q.T, non_blocking=True)
if lora_b[1] is not None:
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
self.lora_b_stacked[1][
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
lora_b_k.T, non_blocking=True)
if lora_b[2] is not None:
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
self.lora_b_stacked[2][
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
lora_b_v.T, non_blocking=True)
else:
if lora_b[0] is not None:
self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True)
if lora_b[1] is not None:
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
if lora_b[2] is not None:
self.lora_b_stacked[2][
index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
lora_b[2].T, non_blocking=True)
if lora_a[0] is not None:
self.lora_a_stacked[0][
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
lora_a[0].T, non_blocking=True)
if lora_a[1] is not None:
self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
lora_a[1].T, non_blocking=True)
if lora_a[2] is not None:
self.lora_a_stacked[2][
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
self.output_slices,
)
return output
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_a_stacked = torch.zeros(
(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.weight.shape[0],
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.base_layer.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.base_layer.weight.shape[1]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x)
_apply_lora(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
)
return output
def forward(self, input_):
"""Forward of RowParallelLinear
Args:
input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns:
- output
- bias
"""
# Set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (output_ + self.base_layer.bias
if self.base_layer.bias is not None else output_)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
@property
def weight(self):
return self.base_layer.weight
class SamplerWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: Sampler,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
) -> None:
super().__init__()
self.base_layer = base_layer
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
@property
def vocab_size(self):
return self.base_layer.vocab_size
@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
@property
def include_gpu_probs_tensor(self):
return self.base_layer.include_gpu_probs_tensor
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not 32000 <= self.base_layer.vocab_size <= 33024:
            raise ValueError(
                "When using LoRA, vocab size must satisfy "
                "32000 <= vocab_size <= 33024.")
self.lora_a_stacked = torch.zeros(
(
max_loras,
1,
lora_config.max_lora_rank,
self.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
# Pad for kernel compatibility
math.ceil(self.base_layer.vocab_size /
lora_config.lora_vocab_padding_size) *
lora_config.lora_vocab_padding_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.embeddings_tensors = torch.full(
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
fill_value=float("-inf"),
dtype=self.dtype,
device=self.device,
)
self.indices = None
self.indices_padded = None
self.indices_len = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = float("-inf")
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1], ] = embeddings_tensor
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = sampler_indices
self.indices_padded = sampler_indices_padded
self.indices_len = indices_len
def _get_logits(
self,
hidden_states: torch.Tensor,
embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
if logits is None:
return None
lora_logits = torch.empty(
self.embeddings_tensors.shape[0] + 1,
self.embeddings_tensors.shape[1],
hidden_states.shape[0],
dtype=self.embeddings_tensors.dtype,
device=self.embeddings_tensors.device,
)
torch.matmul(self.embeddings_tensors,
hidden_states.T,
out=lora_logits[:-1])
lora_logits[-1] = float("-inf")
lora_logits = lora_logits.mT
lora_logits = (lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2],
).index_select(0,
self.indices_padded[:self.indices_len[2]]).nan_to_num_(
nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))
logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits
_apply_lora(
hidden_states,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[1]],
logits,
)
# Remove paddings in vocab (if any).
logits = logits[:, :self.base_layer.vocab_size]
return logits
def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs)
def from_layer(
layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
supported_layer_types = {
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLora,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer
def from_layer_sampler(
layer: Sampler,
lm_head: ParallelLMHead,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA:
ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
from typing import List, Optional
import torch
from vllm.utils import in_wsl
class LoRALayerWeights:
"""LoRA weights for a layer composed of two low rank matrixes."""
def __init__(
self,
module_name: str,
rank: int,
lora_alpha: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor] = None,
scaling: Optional[float] = None,
) -> None:
self.module_name = module_name
self.rank = rank
self.lora_alpha = lora_alpha
self.lora_a = lora_a
self.lora_b = lora_b
self.embeddings_tensor = embeddings_tensor
if scaling is None:
self.scaling = self.lora_alpha / self.rank
else:
self.scaling = scaling
def optimize(self) -> "LoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
        if self.scaling == 1:
            return self
self.lora_b *= self.scaling
self.scaling = 1
return self
@property
def input_dim(self) -> int:
return self.lora_a.shape[0]
@property
def output_dim(self) -> int:
return self.lora_b.shape[1]
@property
def is_packed(self) -> bool:
return False
@property
def extra_vocab_size(self) -> int:
return self.embeddings_tensor.shape[
0] if self.embeddings_tensor is not None else 0
@classmethod
def create_dummy_lora_weights(
cls,
module_name: str,
input_dim: int,
output_dim: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and not in_wsl()
lora_a = torch.zeros([input_dim, rank],
dtype=dtype,
device=device,
pin_memory=pin_memory)
lora_b = torch.zeros([rank, output_dim],
dtype=dtype,
device=device,
pin_memory=pin_memory)
embeddings_tensor = torch.rand(
10,
embeddings_tensor_dim,
dtype=dtype,
device=device,
pin_memory=pin_memory) if embeddings_tensor_dim else None
return cls(
module_name,
rank=rank,
lora_alpha=1,
lora_a=lora_a,
lora_b=lora_b,
embeddings_tensor=embeddings_tensor,
)
class PackedLoRALayerWeights(LoRALayerWeights):
"""LoRA used for packed layers (eg. qkv_proj)."""
def __init__(
self,
module_name: str,
rank: int,
lora_alphas: List[int],
lora_a: List[torch.Tensor],
lora_b: List[torch.Tensor],
scaling: Optional[List[float]] = None,
) -> None:
super().__init__(
module_name=module_name,
rank=rank,
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling,
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
self.scaling = [
lora_alpha / self.rank for lora_alpha in self.lora_alphas
]
@classmethod
def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
    If an entry in the list is None, it signifies that the corresponding
    submodule does not have a LoRA.
"""
first_lora = next(lora for lora in loras if lora is not None)
for lora in loras:
if lora is None:
continue
lora.optimize()
rank = first_lora.rank
module_name = first_lora.module_name
obj = cls(
module_name,
rank,
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
scaling=[1 if lora is not None else None for lora in loras])
return obj
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None:
continue
self.lora_b[i] *= self.scaling[i]
self.scaling[i] = 1
return self
@property
def input_dim(self) -> int:
raise NotImplementedError()
@property
def output_dim(self) -> int:
raise NotImplementedError()
@property
def is_packed(self) -> bool:
return True
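# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of building per-projection LoRA weights and packing them
# for a fused qkv projection. The module names, shapes, and lora_alpha values
# below are assumptions made for the example, not values taken from vLLM.
_q = LoRALayerWeights("layers.0.self_attn.q_proj", rank=8, lora_alpha=16,
                      lora_a=torch.rand(4096, 8), lora_b=torch.rand(8, 4096))
_k = LoRALayerWeights("layers.0.self_attn.k_proj", rank=8, lora_alpha=16,
                      lora_a=torch.rand(4096, 8), lora_b=torch.rand(8, 1024))
_v = None  # a submodule without a LoRA is represented as None
_packed = PackedLoRALayerWeights.pack([_q, _k, _v])
assert _packed.is_packed and _packed.lora_b[2] is None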
import copy
import json
import logging
import math
import os
import re
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type,
Union)
import safetensors.torch
import torch
from torch import nn
from vllm.config import LoRAConfig
from vllm.utils import LRUCache, in_wsl
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
logger = logging.getLogger(__name__)
# TODO: The mappings below should be moved to individual model classes.
PACKED_MODULES_CFG = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
TARGET_MODULES_QKV = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
EMBEDDING_PADDING_MODULES = ["lm_head"]
_GLOBAL_LORA_ID = 0
def convert_mapping(
mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
max_loras: int, vocab_size: int, extra_vocab_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for sampler with padding.
                Same as sampler_indices, but -1 is replaced with
                max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
indices_len: List of lengths of the above tensors.
"""
indices = list(mapping.index_mapping).copy()
embedding_indices = indices.copy()
lora_indices = indices.copy()
prompt_mapping = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(indices[i])
if indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if indices[i] > 0 else 0
indices[i] = i
lora_indices[i] = lora_idx
indices = torch.tensor([indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1])
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
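# --- Usage sketch (illustrative, not part of the module) ---
# A small example of convert_mapping, assuming a CUDA device is available and
# that GPU slot 0 currently holds the LoRA with integer id 2. Three token rows
# (the first two use LoRA 2, the last uses no LoRA) and two requests.
#
#   example_mapping = LoRAMapping((2, 2, 0), (2, 0))
#   base, sampler, sampler_padded, emb, lens = convert_mapping(
#       example_mapping,
#       lora_index_to_id=[2, None],
#       max_loras=2,
#       vocab_size=32000,
#       extra_vocab_size=256,
#   )
#   # base == tensor([0, 0, -1], device='cuda:0'): the slot index for each
#   # token row, with -1 meaning "no LoRA".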
def get_lora_id():
global _GLOBAL_LORA_ID
_GLOBAL_LORA_ID += 1
return _GLOBAL_LORA_ID
class LoRAModel:
"""A LoRA fine-tuned model."""
def __init__(
self,
lora_model_id: int,
rank: int,
loras: Dict[str, LoRALayerWeights],
) -> None:
self.id = lora_model_id
assert (lora_model_id >
0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras
@property
def extra_vocab_size(self) -> int:
return max(lora.extra_vocab_size
for lora in self.loras.values()) if self.loras else 0
def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
"""Get LoRA for a given module by name"""
return self.loras.get(module_name, None)
# (yard1): TODO see if we can derive target_embedding_padding automatically
@classmethod
def from_lora_tensors(
cls,
lora_model_id: int,
rank: int,
lora_alpha: int,
tensors: Dict[str, torch.Tensor],
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and not in_wsl()
loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name),
None)
if embeddings_module:
lora_embeddings_tensor = embeddings[
EMBEDDING_MODULES[embeddings_module]].to(
device=device, dtype=dtype)
if pin_memory:
lora_embeddings_tensor = (
lora_embeddings_tensor.pin_memory())
loras[module_name] = LoRALayerWeights(module_name, rank,
lora_alpha, None, None,
lora_embeddings_tensor)
if is_lora_a:
loras[module_name].lora_a = tensor.to(device=device,
dtype=dtype).t()
if pin_memory:
loras[module_name].lora_a = loras[
module_name].lora_a.pin_memory()
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
if any(name in module_name
for name in EMBEDDING_PADDING_MODULES
) and target_embedding_padding is not None:
lora_b = loras[module_name].lora_b
assert target_embedding_padding >= lora_b.shape[1]
addition = target_embedding_padding - lora_b.shape[1]
loras[module_name].lora_b = torch.nn.functional.pad(
lora_b, (0, addition))
if pin_memory:
loras[module_name].lora_b = loras[
module_name].lora_b.pin_memory()
for lora in loras.values():
lora.optimize()
return cls(lora_model_id, rank, loras)
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
target_embedding_padding: Optional[int] = None) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint."""
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin")
if os.path.isfile(lora_tensor_path):
tensors = safetensors.torch.load_file(lora_tensor_path)
elif os.path.isfile(lora_bin_file_path):
tensors = torch.load(lora_bin_file_path)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")
embeddings = None
if os.path.isfile(new_embeddings_tensor_path):
embeddings = safetensors.torch.load_file(
new_embeddings_tensor_path)
elif os.path.isfile(new_embeddings_bin_file_path):
embeddings = torch.load(new_embeddings_bin_file_path)
with open(lora_config_path) as f:
config = json.load(f)
rank = config["r"]
lora_alpha = config["lora_alpha"]
return cls.from_lora_tensors(
lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id,
rank=rank,
lora_alpha=lora_alpha,
tensors=tensors,
device=device,
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
)
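# --- Usage sketch (illustrative, not part of the module) ---
# Loading a LoRAModel from a local PEFT-style checkpoint directory. The path
# is a placeholder; the directory must contain adapter_config.json plus
# adapter_model.safetensors or adapter_model.bin.
#
#   lora_model = LoRAModel.from_local_checkpoint(
#       "/path/to/sql_lora",
#       lora_model_id=1,
#       device="cpu",
#       dtype=torch.float16,
#   )
#   print(lora_model.rank, sorted(lora_model.loras))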
class LoRAModelManager:
"""A manager that manages multiple LoRA-fine-tuned models."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
):
"""Create a LoRAModelManager and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
vocab_size: the vocab size of the model.
lora_config: the LoRA configuration.
            lora_target_modules: the target module patterns to be adapted.
                Supports both a single module name and a list of module names.
            packed_modules_mapping: the mapping for packed modules. vLLM
                packs some modules into one module, e.g., qkv_proj is a
                packing of q_proj, k_proj, and v_proj. These are a single
                fused layer in the vLLM model, but their LoRA weights come
                as separate modules and must be packed to match.
"""
self.lora_config = lora_config
self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.base_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.embeddings_indices = torch.empty(2,
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.offsets = []
        # 4 is the number of indices tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len = [None] * 4
self.model: nn.Module = model
        self.lora_target_modules: List[str] = copy.deepcopy(
            [lora_target_modules] if isinstance(lora_target_modules, str) else
            lora_target_modules)
self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping = None
self._create_lora_modules()
self.model.lora_manager = self
@property
def capacity(self) -> int:
return self.lora_config.max_cpu_loras
@property
def lora_slots(self) -> int:
return self.lora_config.max_loras
def __len__(self) -> int:
return len(self._registered_loras)
def activate_lora(
self,
lora_id: int,
) -> bool:
"""Move LoRA into a GPU buffer to be used in the forward pass."""
if lora_id in self._active_loras:
return False
first_free_slot = next(
((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
if lora_id is None), None)
if first_free_slot is None:
raise ValueError("No free lora slots")
index, _ = first_free_slot
self._active_loras[lora_id] = None
lora_model = self._registered_loras[lora_id]
logger.debug(
f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = lora_model.get_lora(module_name)
if module_lora:
module_lora.optimize()
module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
module_lora.embeddings_tensor)
else:
module.reset_lora(index)
return True
def _deactivate_lora(self, lora_id: int):
try:
index = self.lora_index_to_id.index(lora_id)
self.lora_index_to_id[index] = None
except ValueError:
pass
def deactivate_lora(self, lora_id: int) -> bool:
"""Remove a LoRA from a GPU buffer."""
if lora_id in self._active_loras:
self._deactivate_lora(lora_id)
self._active_loras.pop(lora_id)
return True
return False
def _add_lora(self, lora: LoRAModel) -> bool:
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager CPU cache."""
if lora.id not in self._registered_loras:
if len(self._registered_loras) >= self.capacity:
raise RuntimeError("No free LoRA slots.")
self._add_lora(lora)
return True
return False
def remove_lora(self, lora_id: int) -> bool:
"""Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora?
self.deactivate_lora(lora_id)
return bool(self._registered_loras.pop(lora_id, None))
# TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices,
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
self.lora_slots + 1, self.vocab_size,
self.lora_config.lora_extra_vocab_size)
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded)
self.embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
# Maintain the reference
self.indices_len[:] = indices_len
def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
if self._last_mapping != lora_mapping:
self._set_lora_mapping(lora_mapping)
self._last_mapping = lora_mapping
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras)
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)
def remove_all_loras(self) -> bool:
"""Remove all LoRAModels from the manager."""
self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots
self._active_loras.clear()
def _create_lora_modules(self):
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name):
continue
new_module = replace_submodule(
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
self.model.config))
# (yard1): TODO make this more robust
if "lm_head" in module_name:
sampler_module = self.model.get_submodule("sampler")
new_module = replace_submodule(
self.model, "sampler",
from_layer_sampler(sampler_module, module, self.lora_slots,
self.lora_config, self.model.config))
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices,
self.sampler_indices_padded,
self.embeddings_indices, self.indices_len)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA)
self.modules[module_name] = module
def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {})
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
if parts[-1] in EMBEDDING_MODULES:
input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if
hasattr(module.base_layer, "org_vocab_size")
else module.base_layer.weight.shape[1])
output_dim = module.base_layer.embedding_dim if hasattr(
module.base_layer,
"embedding_dim") else module.base_layer.weight.shape[0]
embeddings_tensor_dim = (module.base_layer.embedding_dim if
hasattr(module.base_layer,
"embedding_dim") else
module.base_layer.weight.shape[1])
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
input_dim,
output_dim,
rank,
module.lora_a_stacked.dtype,
"cpu",
embeddings_tensor_dim=embeddings_tensor_dim)
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.lora_a_stacked.shape[-1],
module.lora_b_stacked.shape[-2],
rank,
module.lora_a_stacked.dtype,
"cpu",
)
lora.optimize()
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
module.lora_a_stacked[i].shape[-1],
module.lora_b_stacked[i].shape[-2],
rank,
module.lora_a_stacked[i].dtype,
"cpu",
)
lora.optimize()
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora
return model
def _match_target_modules(self, module_name: str):
return any(
re.match(
r".*\.{target_module}$".format(target_module=target_module),
module_name) or target_module == module_name
for target_module in self.lora_target_modules)
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
replacements = self.packed_modules_mapping.get(module_name)
if not replacements:
return
prefix = ".".join(parts[:-1])
self.packed_modules[module_full_name] = [
prefix + "." + r if prefix else r for r in replacements
]
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras = []
has_replacement = False
for r in new_module_names:
lora = lora_model.get_lora(r)
replacement_loras.append(lora)
if lora:
has_replacement = True
if not has_replacement:
continue
for i in range(len(replacement_loras)):
if replacement_loras[i]:
continue
replacement_loras[i] = None
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras)
class LoRALRUCache(LRUCache):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
None]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: Hashable, value: Any):
logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
class LRUCacheLoRAModelManager(LoRAModelManager):
"""A model manager that manages multiple LoRAs with LRU cache."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
):
super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config, lora_target_modules,
packed_modules_mapping)
self._registered_loras: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_lora)
self._active_loras: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_lora)
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras.cache)
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager."""
if lora.id not in self._registered_loras:
self._add_lora(lora)
was_added = True
else:
# We always touch to update the LRU cache order
self._registered_loras.touch(lora.id)
was_added = False
return was_added
def activate_lora(
self,
lora_id: int,
) -> bool:
if lora_id not in self._active_loras and len(
self._active_loras) >= self.lora_slots:
self._active_loras.remove_oldest()
result = super().activate_lora(lora_id)
# We always touch to update the LRU cache order
self._active_loras.touch(lora_id)
return result
def remove_oldest_lora(self) -> bool:
if len(self._registered_loras) > 0:
self._registered_loras.remove_oldest()
return True
return False
def create_lora_manager(
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not getattr(model, "supports_lora", False):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
vocab_size=vocab_size,
lora_config=lora_config,
lora_target_modules=target_modules,
**kwargs)
return lora_manager
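# --- Usage sketch (illustrative, not part of the module) ---
# Wiring a model up to a manager. `model` stands in for a vLLM model instance
# that sets `supports_lora = True`; the configuration values are assumptions.
#
#   lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=2)
#   manager = create_lora_manager(
#       model,
#       max_num_seqs=32,
#       max_num_batched_tokens=2048,
#       vocab_size=32000,
#       lora_config=lora_config,
#       lora_manager_cls=LRUCacheLoRAModelManager,
#   )
#   sql_lora = LoRAModel.from_local_checkpoint("/path/to/sql_lora",
#                                              lora_model_id=1)
#   manager.add_lora(sql_lora)
#   manager.activate_lora(sql_lora.id)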
# Based on code from https://github.com/punica-ai/punica
from typing import Optional
import torch
import_exc = None
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
import_exc = e
if import_exc is None:
def bgmv(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
matrices.
indicies: Shape: `[B]`. Indices of the weight matrices.
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
"""
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def add_lora(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
*,
buffer: Optional[torch.Tensor] = None):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
"""
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default to avoid
            # numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
1.0)
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
scale)
def add_lora_slice(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
y_offset: int,
y_slice_size: int,
*,
buffer: Optional[torch.Tensor] = None):
"""
        Same as `add_lora` but operates on a slice of y.
        Pass the whole y and define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default to avoid
# numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
punica_kernels.dispatch_bgmv_low_level(
buffer,
x,
wa_t_all,
indicies,
layer_idx,
1.0,
x.size(1),
buffer.size(1),
0,
)
punica_kernels.dispatch_bgmv_low_level(
y,
buffer,
wb_t_all,
indicies,
layer_idx,
scale,
buffer.size(1),
y_slice_size,
y_offset,
)
else:
def _raise_exc(
*args, # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
):
if torch.cuda.get_device_capability() < (8, 0):
raise ImportError(
"LoRA kernels require compute capability>=8.0") from import_exc
else:
raise import_exc
bgmv = _raise_exc
add_lora = _raise_exc
add_lora_slice = _raise_exc
__all__ = [
"bgmv",
"add_lora",
"add_lora_slice",
]
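# --- Reference sketch (illustrative, not part of the module) ---
# A pure-PyTorch restatement of the `add_lora` semantics documented above,
# useful for checking shapes without the compiled punica kernels. The sizes
# (B, H1, H2, R, L, N) are arbitrary assumptions.
def _add_lora_reference(y, x, wa_t_all, wb_t_all, indices, layer_idx, scale):
    # y: [B, H2], x: [B, H1], wa_t_all: [N, L, R, H1], wb_t_all: [N, L, H2, R]
    for i in range(x.size(0)):
        wa = wa_t_all[indices[i], layer_idx].transpose(-1, -2)  # [H1, R]
        wb = wb_t_all[indices[i], layer_idx].transpose(-1, -2)  # [R, H2]
        y[i] += (x[i].unsqueeze(0) @ wa @ wb * scale).squeeze(0)
    return y

_B, _H1, _H2, _R, _L, _N = 4, 16, 32, 8, 1, 2
_y = _add_lora_reference(torch.zeros(_B, _H2), torch.randn(_B, _H1),
                         torch.randn(_N, _L, _R, _H1),
                         torch.randn(_N, _L, _H2, _R),
                         torch.randint(0, _N, (_B,)), layer_idx=0, scale=1.0)
assert _y.shape == (_B, _H2)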
from dataclasses import dataclass
@dataclass
class LoRARequest:
"""
Request for a LoRA adapter.
    Note that this class should be used internally. For online
serving, it is recommended to not allow users to use this class but
instead provide another layer of abstraction to prevent users from
accessing unauthorized LoRA adapters.
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""
lora_name: str
lora_int_id: int
lora_local_path: str
def __post_init__(self):
if self.lora_int_id < 1:
raise ValueError(
f"lora_int_id must be > 0, got {self.lora_int_id}")
def __eq__(self, value: object) -> bool:
return isinstance(
value, LoRARequest) and self.lora_int_id == value.lora_int_id
def __hash__(self) -> int:
return self.lora_int_id
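# --- Usage sketch (illustrative, not part of the module) ---
# Creating a request for a LoRA adapter. The name, id, and local path below
# are placeholders. Equality and hashing are based on lora_int_id only.
_sql_lora = LoRARequest(
    lora_name="sql-adapter",
    lora_int_id=1,
    lora_local_path="/path/to/sql_lora",
)
assert _sql_lora == LoRARequest("another-name", 1, "/another/path")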
import logging
from typing import Tuple
from torch import nn
logger = logging.getLogger(__name__)
def replace_submodule(model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
"""Parse the name of lora weights.
    Args:
        name: the name of the fine-tuned LoRA, e.g.
            base_model.model.dense1.weight
    Returns:
        Tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g. model.dense1,
            is_lora_a: whether the tensor is lora_a or lora_b.
"""
parts = name.split(".")
assert parts[0] == "base_model"
assert parts[1] == "model"
if parts[-1] == "weight":
assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
raise ValueError(f"{name} is unsupported format")
import logging
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, List, Optional, Set, Type, Union
import torch
from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
from vllm.lora.layers import LoRAMapping
from vllm.config import LoRAConfig
logger = logging.getLogger(__name__)
class AbstractWorkerLoRAManager(ABC):
"""Abstract class for managing LoRA models on the worker side."""
def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
vocab_size: int, lora_config: LoRAConfig,
device: torch.device):
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.device = device
self.lora_config = lora_config
@abstractproperty
def is_enabled(self) -> bool:
...
@abstractmethod
def create_lora_manager(
self,
model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
) -> Any:
...
@abstractmethod
def set_active_loras(self, lora_requests: List[LoRARequest],
lora_mapping: LoRAMapping) -> None:
...
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
...
@abstractmethod
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
...
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
...
@abstractmethod
def remove_all_loras(self) -> bool:
...
@abstractmethod
def list_loras(self) -> Set[int]:
...
class WorkerLoRAManager(AbstractWorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side.
Every request, the requested LoRAs will be loaded (unless they are already
loaded), and every other LoRA will be unloaded."""
_lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager
def __init__(
self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
lora_model_cls: Type[LoRAModel] = LoRAModel,
):
self._lora_manager: Optional[LoRAModelManager] = None
self._lora_model_cls = lora_model_cls
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device)
@property
def is_enabled(self) -> bool:
return True
def create_lora_manager(
self,
model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
) -> Any:
lora_manager = create_lora_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
target_modules=target_modules,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls,
)
self._lora_manager: LoRAModelManager = lora_manager
return lora_manager.model
def set_active_loras(self, lora_requests: List[LoRARequest],
lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping)
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
loras_that_exist = self.list_loras()
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
new_loras = set(loras_map)
loras_to_add = new_loras - loras_that_exist
loras_to_remove = loras_that_exist - new_loras
for lora_id in loras_to_remove:
self.remove_lora(lora_id)
for lora_id in loras_to_add:
self.add_lora(loras_map[lora_id])
def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
try:
lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path,
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
target_embedding_padding=self.vocab_size +
self.lora_config.lora_extra_vocab_size,
)
except Exception as e:
raise RuntimeError(
f"Loading lora {lora_request.lora_local_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
raise ValueError(
f"LoRA rank {lora.rank} is greater than max_lora_rank "
f"{self.lora_config.max_lora_rank}.")
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
raise ValueError(
f"LoRA added vocab size {lora.extra_vocab_size} is greater than "
f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}."
)
return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_loras():
return False
return self._lora_manager.add_lora(
self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
rank))
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id in self.list_loras():
return False
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
self._lora_manager.activate_lora(lora.id)
return loaded
def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id)
def remove_all_loras(self) -> bool:
self._lora_manager.remove_all_loras()
def list_loras(self) -> Set[int]:
return set(self._lora_manager.list_loras())
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side.
Uses an LRU Cache. Every request, the requested LoRAs will be loaded
(unless they are already loaded) and least recently used LoRAs will
be unloaded if the cache is above capacity."""
_lora_manager_cls: Type[
LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager(
self,
model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
) -> Any:
lora_manager = create_lora_manager(
model,
target_modules=target_modules,
lora_manager_cls=self._lora_manager_cls,
max_num_seqs=self.max_num_seqs,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens,
)
self._lora_manager: LRUCacheLoRAModelManager = lora_manager
return lora_manager.model
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
for lora in loras_map.values():
self.add_lora(lora)
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id not in self.list_loras():
# Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity:
self._lora_manager.remove_oldest_lora()
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
else:
# If the lora is already loaded, just touch it to
# update its position in the caches
loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
self._lora_manager.activate_lora(lora_request.lora_int_id)
return loaded
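# --- Usage sketch (illustrative, not part of the module) ---
# A rough outline of the worker-side flow with the LRU-cache manager. The
# model, device, and config objects are assumed to exist already; the adapter
# name and path are placeholders.
#
#   manager = LRUCacheWorkerLoRAManager(max_num_seqs, max_num_batched_tokens,
#                                       vocab_size, lora_config, device)
#   model = manager.create_lora_manager(model)
#   manager.set_active_loras(
#       [LoRARequest("sql-adapter", 1, "/path/to/sql_lora")],
#       LoRAMapping([], []))
#   assert 1 in manager.list_loras()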
@@ -27,9 +27,25 @@ class Sampler(nn.Module):
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
"""
def __init__(self, vocab_size: int) -> None:
def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None) -> None:
super().__init__()
self.vocab_size = vocab_size
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
def forward(
self,
@@ -42,8 +58,7 @@ class Sampler(nn.Module):
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
# Get the logits for the next tokens.
logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size)
logits = self._get_logits(hidden_states, embedding, embedding_bias)
# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
@@ -98,20 +113,6 @@ class Sampler(nn.Module):
prompt_logprobs, sample_logprobs)
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor],
vocab_size: int) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :vocab_size]
return logits
def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,