Unverified commit b5b4a398, authored by SangBin Cho, committed by GitHub
Browse files

[Mypy] Typing lora folder (#4337)

parent f4bc4de1
...@@ -33,8 +33,6 @@ jobs: ...@@ -33,8 +33,6 @@ jobs:
- name: Mypy - name: Mypy
run: | run: |
mypy vllm/attention --config-file pyproject.toml mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml
...@@ -44,8 +42,9 @@ jobs: ...@@ -44,8 +42,9 @@ jobs:
mypy vllm/engine --config-file pyproject.toml mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
# TODO(sang): Fix nested dir # TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml mypy vllm/model_executor/*.py --config-file pyproject.toml
# TODO(sang): Fix nested dir mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --config-file pyproject.toml
...@@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml ...@@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml mypy vllm/model_executor/*.py --config-file pyproject.toml
# mypy vllm/lora/*.py --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml
CODESPELL_EXCLUDES=( CODESPELL_EXCLUDES=(
......
...@@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None: def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]
def create_lora_weights( def create_lora_weights(
self, self,
...@@ -233,9 +235,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -233,9 +235,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2], self.lora_a_stacked.shape[2],
) )
self.indices: Optional[torch.Tensor] = None # Lazily initialized.
self.indices_len: Optional[List[int]] = None self.indices: torch.Tensor
self.embeddings_indices = None self.indices_len: List[int]
self.embeddings_indices: torch.Tensor
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
...@@ -267,6 +270,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -267,6 +270,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.embeddings_tensors.shape[1], self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2] self.embeddings_tensors.shape[2]
)[self.embeddings_slice[0]:self.embeddings_slice[1]] )[self.embeddings_slice[0]:self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def set_mapping( def set_mapping(
...@@ -343,11 +347,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -343,11 +347,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) )
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[2] self.output_dim = self.lora_b_stacked.shape[2]
# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0 self.lora_b_stacked[index] = 0
...@@ -475,8 +480,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -475,8 +480,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
device=self.device, device=self.device,
) for _ in range(n_slices)) ) for _ in range(n_slices))
self.indices: Optional[torch.Tensor] = None
self.output_dim = self.lora_b_stacked[0].shape[2] self.output_dim = self.lora_b_stacked[0].shape[2]
# Lazily initialized.
self.indices: torch.Tensor
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0 self.lora_a_stacked[0][index] = 0
...@@ -690,7 +696,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -690,7 +696,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.kv_proj_shard_size) self.kv_proj_shard_size)
self.packed_indices: Optional[torch.Tensor] = None self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None self.standard_indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None # lazily initialized.
self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0 self.lora_a_stacked[0][index] = 0
...@@ -814,8 +821,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -814,8 +821,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) )
self.indices: Optional[torch.Tensor] = None # Lazily initialized
self.indices_len: Optional[List[int]] = None self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
...@@ -991,9 +999,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -991,9 +999,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype=self.dtype, dtype=self.dtype,
device=self.device, device=self.device,
) )
self.indices = None # Lazily initialized.
self.indices_padded = None self.indices: torch.Tensor
self.indices_len = None self.indices_len: List[int]
self.indices_padded: torch.Tensor
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
......
...@@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights):
self, self,
module_name: str, module_name: str,
rank: int, rank: int,
lora_alphas: List[int], lora_alphas: List[Optional[int]],
lora_a: List[torch.Tensor], lora_a: List[Optional[torch.Tensor]],
lora_b: List[torch.Tensor], lora_b: List[Optional[torch.Tensor]],
scaling: Optional[List[float]] = None, scaling: Optional[List[float]] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
...@@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alpha=0, lora_alpha=0,
lora_a=lora_a, lora_a=lora_a,
lora_b=lora_b, lora_b=lora_b,
scaling=scaling, scaling=scaling, # type: ignore
embeddings_tensor=None, embeddings_tensor=None,
) )
self.lora_alphas = lora_alphas self.lora_alphas = lora_alphas
if scaling is None: if scaling is None:
self.scaling = [ self.scaling = [ # type: ignore
lora_alpha / self.rank for lora_alpha in self.lora_alphas lora_alpha / self.rank # type: ignore # noqa
for lora_alpha in self.lora_alphas
] ]
@classmethod @classmethod
def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": def pack(
cls, loras: List[Optional["LoRALayerWeights"]]
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA. """Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA. If LoRA is None, it signifies that the submodule does not have a LoRA.
...@@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras],
scaling=[1 if lora is not None else None for lora in loras]) scaling=[
1 if lora is not None else None # type: ignore
for lora in loras
])
return obj return obj
def optimize(self) -> "PackedLoRALayerWeights": def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b.""" """Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)): for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None: if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
continue continue
self.lora_b[i] *= self.scaling[i] self.lora_b[i] *= self.scaling[i] # type: ignore
self.scaling[i] = 1 self.scaling[i] = 1 # type: ignore
return self return self
@property @property
......
...@@ -3,7 +3,7 @@ import json ...@@ -3,7 +3,7 @@ import json
import math import math
import os import os
import re import re
from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type from typing import Callable, Dict, List, Optional, Tuple, Type
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -53,26 +53,27 @@ def convert_mapping( ...@@ -53,26 +53,27 @@ def convert_mapping(
embeddings. embeddings.
indices_len: List of lengths of the above tensors. indices_len: List of lengths of the above tensors.
""" """
indices = list(mapping.index_mapping).copy() index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = indices.copy() embedding_indices = index_mapping_indices.copy()
lora_indices = indices.copy() lora_indices = index_mapping_indices.copy()
prompt_mapping = [ prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1 lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping for x in mapping.prompt_mapping
] ]
lora_idx = None lora_idx = None
for i in range(len(indices)): for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize # TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(indices[i]) lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if indices[i] > 0 else -1) if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if indices[i] > 0 else 0 embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
indices[i] = i index_mapping_indices[i] = i
lora_indices[i] = lora_idx lora_indices[i] = lora_idx
indices = torch.tensor([indices, lora_indices, embedding_indices], indices = torch.tensor(
[index_mapping_indices, lora_indices, embedding_indices],
dtype=torch.long, dtype=torch.long,
device="cuda") device="cuda")
prompt_mapping = torch.tensor(prompt_mapping, prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda", device="cuda",
dtype=torch.long) dtype=torch.long)
embeddings_indices = torch.stack([ embeddings_indices = torch.stack([
...@@ -81,16 +82,17 @@ def convert_mapping( ...@@ -81,16 +82,17 @@ def convert_mapping(
]) ])
embeddings_indices[embeddings_indices == -1] = max_loras - 1 embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1] base_indices = indices[1]
sampler_indices = prompt_mapping sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone() sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = ( sampler_indices_padded = (
torch.arange( torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded))) (sampler_indices_padded * len(sampler_indices_padded)))
indices_len = (base_indices.shape[-1], sampler_indices.shape[-1], indices_len = [
sampler_indices_padded.shape[-1], base_indices.shape[-1], sampler_indices.shape[-1],
embeddings_indices.shape[-1]) sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]
return (base_indices, sampler_indices, sampler_indices_padded, return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len) embeddings_indices, indices_len)
...@@ -149,6 +151,7 @@ class LoRAModel: ...@@ -149,6 +151,7 @@ class LoRAModel:
if module_name not in loras: if module_name not in loras:
lora_embeddings_tensor = None lora_embeddings_tensor = None
if embeddings: if embeddings:
assert embedding_modules is not None
embeddings_module = next( embeddings_module = next(
(k for k in embedding_modules if k in module_name), (k for k in embedding_modules if k in module_name),
None) None)
...@@ -171,6 +174,7 @@ class LoRAModel: ...@@ -171,6 +174,7 @@ class LoRAModel:
else: else:
loras[module_name].lora_b = tensor.to(device=device, loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t() dtype=dtype).t()
assert embedding_padding_modules is not None
if any(name in module_name if any(name in module_name
for name in embedding_padding_modules for name in embedding_padding_modules
) and target_embedding_padding is not None: ) and target_embedding_padding is not None:
...@@ -295,11 +299,10 @@ class LoRAModelManager: ...@@ -295,11 +299,10 @@ class LoRAModelManager:
self.max_num_batched_tokens, self.max_num_batched_tokens,
dtype=torch.long, dtype=torch.long,
device="cuda") device="cuda")
self.offsets = []
# 4 is the number of indicies tensors defined above # 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded, # base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices # embeddings_indices
self.indices_len = [None] * 4 self.indices_len: List[Optional[int]] = [None] * 4
self.model: nn.Module = model self.model: nn.Module = model
if hasattr(self.model, "supported_lora_modules"): if hasattr(self.model, "supported_lora_modules"):
...@@ -312,7 +315,7 @@ class LoRAModelManager: ...@@ -312,7 +315,7 @@ class LoRAModelManager:
self._registered_loras: Dict[int, LoRAModel] = {} self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache. # Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {} self._active_loras: Dict[int, None] = {}
self._last_mapping = None self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self self.model.lora_manager = self
...@@ -370,7 +373,7 @@ class LoRAModelManager: ...@@ -370,7 +373,7 @@ class LoRAModelManager:
return True return True
return False return False
def _add_lora(self, lora: LoRAModel) -> bool: def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora) self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora self._registered_loras[lora.id] = lora
...@@ -418,7 +421,7 @@ class LoRAModelManager: ...@@ -418,7 +421,7 @@ class LoRAModelManager:
def get_lora(self, lora_id: int) -> Optional[LoRAModel]: def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None) return self._registered_loras.get(lora_id, None)
def remove_all_loras(self) -> bool: def remove_all_loras(self):
"""Remove all LoRAModels from the manager.""" """Remove all LoRAModels from the manager."""
self._registered_loras.clear() self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots self.lora_index_to_id = [None] * self.lora_slots
...@@ -467,6 +470,7 @@ class LoRAModelManager: ...@@ -467,6 +470,7 @@ class LoRAModelManager:
continue continue
parts = module_name.split(".") parts = module_name.split(".")
if module_name not in self.packed_modules: if module_name not in self.packed_modules:
assert embedding_modules is not None
if parts[-1] in embedding_modules: if parts[-1] in embedding_modules:
input_dim = (module.base_layer.org_vocab_size + input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if self.lora_config.lora_extra_vocab_size if
...@@ -500,7 +504,7 @@ class LoRAModelManager: ...@@ -500,7 +504,7 @@ class LoRAModelManager:
else: else:
parts = module_name.split(".") parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]] replacements = self.packed_modules_mapping[parts[-1]]
subloras = [] subloras: List[Optional["LoRALayerWeights"]] = []
for i, r in enumerate(replacements): for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights( lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r, module_name + "." + r,
...@@ -538,7 +542,7 @@ class LoRAModelManager: ...@@ -538,7 +542,7 @@ class LoRAModelManager:
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items(): for module_name, new_module_names in self.packed_modules.items():
replacement_loras = [] replacement_loras: List[Optional[LoRALayerWeights]] = []
has_replacement = False has_replacement = False
for r in new_module_names: for r in new_module_names:
lora = lora_model.get_lora(r) lora = lora_model.get_lora(r)
...@@ -557,12 +561,12 @@ class LoRAModelManager: ...@@ -557,12 +561,12 @@ class LoRAModelManager:
class LoRALRUCache(LRUCache[LoRAModel]): class LoRALRUCache(LRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
None]): bool]):
super().__init__(capacity) super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: Hashable, value: LoRAModel): def _on_remove(self, key: int, value: LoRAModel):
logger.debug(f"Removing LoRA. int id: {key}") logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key) self.deactivate_lora_fn(key)
return super()._on_remove(key, value) return super()._on_remove(key, value)
......
from abc import ABC, abstractmethod, abstractproperty from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Optional, Set, Type from typing import Any, Dict, List, Set, Type
import torch import torch
...@@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC): ...@@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC):
... ...
@abstractmethod @abstractmethod
def set_active_loras(self, lora_requests: List[LoRARequest], def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None: lora_mapping: LoRAMapping) -> None:
... ...
...@@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC): ...@@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC):
... ...
@abstractmethod @abstractmethod
def remove_all_loras(self) -> bool: def remove_all_loras(self):
... ...
@abstractmethod @abstractmethod
...@@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
embedding_padding_modules: List[str], embedding_padding_modules: List[str],
lora_model_cls: Type[LoRAModel] = LoRAModel, lora_model_cls: Type[LoRAModel] = LoRAModel,
): ):
self._lora_manager: Optional[LoRAModelManager] = None
self._lora_model_cls = lora_model_cls self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules self.embedding_padding_modules = embedding_padding_modules
# Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device) lora_config, device)
...@@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora_config=self.lora_config, lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls, lora_manager_cls=self._lora_manager_cls,
) )
self._lora_manager: LoRAModelManager = lora_manager self._lora_manager = lora_manager
return lora_manager.model return lora_manager.model
def set_active_loras(self, lora_requests: Set[LoRARequest], def set_active_loras(self, lora_requests: Set[LoRARequest],
...@@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id) return self._lora_manager.remove_lora(lora_id)
def remove_all_loras(self) -> bool: def remove_all_loras(self):
self._lora_manager.remove_all_loras() self._lora_manager.remove_all_loras()
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
...@@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
lora_config=self.lora_config, lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens, max_num_batched_tokens=self.max_num_batched_tokens,
) )
self._lora_manager: LRUCacheLoRAModelManager = lora_manager self._lora_manager = lora_manager
return lora_manager.model return lora_manager.model
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_map = { loras_map = {
lora_request.lora_int_id: lora_request lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request for lora_request in lora_requests if lora_request
...@@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
if lora_request.lora_int_id not in self.list_loras(): if lora_request.lora_int_id not in self.list_loras():
# Remove before we load the new lora to save memory # Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity: if len(self._lora_manager) + 1 > self._lora_manager.capacity:
assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
self._lora_manager.remove_oldest_lora() self._lora_manager.remove_oldest_lora()
lora = self._load_lora(lora_request) lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora) loaded = self._lora_manager.add_lora(lora)
else: else:
# If the lora is already loaded, just touch it to # If the lora is already loaded, just touch it to
# update its position in the caches # update its position in the caches
loaded = self._lora_manager.get_lora(lora_request.lora_int_id) loaded = self._lora_manager.get_lora(
lora_request.lora_int_id) is not None
self._lora_manager.activate_lora(lora_request.lora_int_id) self._lora_manager.activate_lora(lora_request.lora_int_id)
return loaded return loaded
...@@ -928,10 +928,10 @@ class ModelRunner: ...@@ -928,10 +928,10 @@ class ModelRunner:
torch.cuda.synchronize() torch.cuda.synchronize()
return return
def remove_all_loras(self) -> bool: def remove_all_loras(self):
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_all_loras() self.lora_manager.remove_all_loras()
def set_active_loras(self, lora_requests: Set[LoRARequest], def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None: lora_mapping: LoRAMapping) -> None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment