Unverified Commit ad932a22 authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

[Core] Faster startup for LoRA enabled models (#4634)

parent 5510cf0e
...@@ -119,6 +119,16 @@ class LoRAModel: ...@@ -119,6 +119,16 @@ class LoRAModel:
self.rank = rank self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras self.loras: Dict[str, LoRALayerWeights] = loras
def clone(self, lora_model_id: int) -> "LoRAModel":
    """Build a new instance of the same class under a different id.

    The ``loras`` mapping is shallow-copied, so the clone gets its own
    dict while the values stored in it (and therefore the underlying
    tensors) are shared with this object. ``rank`` is carried over as is.
    """
    model_cls = self.__class__
    # dict.copy() is shallow: new container, same value objects.
    shared_loras = self.loras.copy()
    return model_cls(lora_model_id, rank=self.rank, loras=shared_loras)
@property @property
def extra_vocab_size(self) -> int: def extra_vocab_size(self) -> int:
return max(lora.extra_vocab_size return max(lora.extra_vocab_size
......
from abc import ABC, abstractmethod, abstractproperty from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Set, Type from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Set, Type, Union
import torch import torch
...@@ -25,6 +26,17 @@ class AbstractWorkerLoRAManager(ABC): ...@@ -25,6 +26,17 @@ class AbstractWorkerLoRAManager(ABC):
self.device = device self.device = device
self.lora_config = lora_config self.lora_config = lora_config
# If False, do not cache. If None, cache is empty.
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
@contextmanager
def dummy_lora_cache(self):
    """Enable reuse of a single dummy LoRA model inside the ``with`` body.

    On entry ``self._cached_dummy_lora`` is set to ``None`` (caching
    enabled, cache empty). On exit it is set back to ``False`` (caching
    disabled), discarding whatever was stored there in between.

    The reset runs in a ``finally`` clause so that an exception raised
    inside the ``with`` body cannot leave caching switched on.
    """
    self._cached_dummy_lora = None
    try:
        yield
    finally:
        # Always disable caching again, even when the body raised.
        self._cached_dummy_lora = False
@abstractproperty @abstractproperty
def is_enabled(self) -> bool: def is_enabled(self) -> bool:
... ...
...@@ -174,9 +186,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -174,9 +186,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_loras(): if lora_request.lora_int_id in self.list_loras():
return False return False
return self._lora_manager.add_lora( if isinstance(self._cached_dummy_lora, LoRAModel):
self._lora_manager.create_dummy_lora(lora_request.lora_int_id, dummy_lora = self._cached_dummy_lora.clone(
rank, self.embedding_modules)) lora_request.lora_int_id)
else:
dummy_lora = self._lora_manager.create_dummy_lora(
lora_request.lora_int_id, rank, self.embedding_modules)
if self._cached_dummy_lora is None:
self._cached_dummy_lora = dummy_lora
return self._lora_manager.add_lora(dummy_lora)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id in self.list_loras(): if lora_request.lora_int_id in self.list_loras():
......
...@@ -835,20 +835,21 @@ class ModelRunner: ...@@ -835,20 +835,21 @@ class ModelRunner:
dummy_lora_requests = [] dummy_lora_requests = []
dummy_lora_requests_per_seq = [] dummy_lora_requests_per_seq = []
if self.lora_config: if self.lora_config:
for idx in range(self.lora_config.max_loras): with self.lora_manager.dummy_lora_cache():
lora_id = idx + 1 for idx in range(self.lora_config.max_loras):
dummy_lora_request = LoRARequest( lora_id = idx + 1
lora_name=f"warmup_{lora_id}", dummy_lora_request = LoRARequest(
lora_int_id=lora_id, lora_name=f"warmup_{lora_id}",
lora_local_path="/not/a/real/path", lora_int_id=lora_id,
) lora_local_path="/not/a/real/path",
self.lora_manager.add_dummy_lora(dummy_lora_request, )
rank=LORA_WARMUP_RANK) self.lora_manager.add_dummy_lora(dummy_lora_request,
dummy_lora_requests.append(dummy_lora_request) rank=LORA_WARMUP_RANK)
dummy_lora_requests_per_seq = [ dummy_lora_requests.append(dummy_lora_request)
dummy_lora_requests[idx % len(dummy_lora_requests)] dummy_lora_requests_per_seq = [
for idx in range(max_num_seqs) dummy_lora_requests[idx % len(dummy_lora_requests)]
] for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the total # Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens. # number of tokens equal to max_num_batched_tokens.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment