Unverified Commit ad932a22 authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

[Core] Faster startup for LoRA enabled models (#4634)

parent 5510cf0e
......@@ -119,6 +119,16 @@ class LoRAModel:
self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras
def clone(self, lora_model_id: int) -> "LoRAModel":
    """Build a new model of the same class under ``lora_model_id``.

    The rank is carried over unchanged. The ``loras`` mapping is
    shallow-copied: the clone gets its own dict, but the layer-weight
    objects stored in it (and therefore their tensors) are shared
    with this instance.
    """
    model_cls = self.__class__
    shared_loras = self.loras.copy()
    return model_cls(lora_model_id, rank=self.rank, loras=shared_loras)
@property
def extra_vocab_size(self) -> int:
return max(lora.extra_vocab_size
......
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Set, Type
from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Set, Type, Union
import torch
......@@ -25,6 +26,17 @@ class AbstractWorkerLoRAManager(ABC):
self.device = device
self.lora_config = lora_config
# If False, do not cache. If None, cache is empty.
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
@contextmanager
def dummy_lora_cache(self):
    """Enable reuse of a single dummy LoRA model inside the ``with`` body.

    On entry ``_cached_dummy_lora`` is set to ``None`` ("caching on,
    nothing cached yet"). On exit it is set back to ``False`` ("do not
    cache"), which also drops the reference to any model cached while
    the context was active.
    """
    self._cached_dummy_lora = None
    try:
        yield
    finally:
        # Reset even when the body raises; otherwise caching would stay
        # enabled and the cached dummy model would remain referenced.
        self._cached_dummy_lora = False
@property
@abstractmethod
def is_enabled(self) -> bool:
    """Abstract read-only property; concrete subclasses must override it."""
    ...
......@@ -174,9 +186,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
    """Register a dummy LoRA under ``lora_request.lora_int_id``.

    Returns ``False`` without doing anything when that id is already in
    ``self.list_loras()``; otherwise returns the result of
    ``self._lora_manager.add_lora``.

    ``self._cached_dummy_lora`` selects how the dummy model is obtained:
    a ``LoRAModel`` is cloned under the new id, ``None`` means caching is
    on but empty (create and remember the model), and ``False`` means
    caching is off (create only).
    """
    if lora_request.lora_int_id in self.list_loras():
        return False
    if isinstance(self._cached_dummy_lora, LoRAModel):
        # Reuse the cached model; only the id differs. Note that `rank`
        # is not consulted on this path - the clone keeps the cached rank.
        dummy_lora = self._cached_dummy_lora.clone(
            lora_request.lora_int_id)
    else:
        dummy_lora = self._lora_manager.create_dummy_lora(
            lora_request.lora_int_id, rank, self.embedding_modules)
        if self._cached_dummy_lora is None:
            # Caching is enabled but empty: remember this model.
            self._cached_dummy_lora = dummy_lora
    return self._lora_manager.add_lora(dummy_lora)
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id in self.list_loras():
......
......@@ -835,6 +835,7 @@ class ModelRunner:
dummy_lora_requests = []
dummy_lora_requests_per_seq = []
if self.lora_config:
with self.lora_manager.dummy_lora_cache():
for idx in range(self.lora_config.max_loras):
lora_id = idx + 1
dummy_lora_request = LoRARequest(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment