worker_manager.py 9.33 KB
Newer Older
1
from contextlib import contextmanager
2
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
3
4
5

import torch

6
7
8
9
10
from vllm.adapter_commons.utils import (add_adapter_worker,
                                        apply_adapters_worker,
                                        list_adapters_worker,
                                        set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
11
from vllm.config import LoRAConfig
12
from vllm.logger import init_logger
Terry's avatar
Terry committed
13
from vllm.lora.models import (LoRAModel, LoRAModelManager,
14
15
                              LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
16
from vllm.lora.utils import get_adapter_absolute_path
17

18
logger = init_logger(__name__)
19
20


21
class WorkerLoRAManager(AbstractWorkerManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    # Manager class handed to create_lora_manager(); subclasses override it.
    _manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
        max_position_embeddings: Optional[int] = None,
    ):
        """Store the configuration later used to build and load LoRAs.

        Args:
            max_num_seqs: Forwarded to create_lora_manager().
            max_num_batched_tokens: Forwarded to create_lora_manager().
            vocab_size: Forwarded to create_lora_manager(); also added to
                ``lora_config.lora_extra_vocab_size`` to form the embedding
                padding target when loading a checkpoint.
            lora_config: LoRA configuration (dtype, max rank, extra vocab).
            device: Passed to the base-class initializer.
            embedding_modules: Forwarded to checkpoint loading and to dummy
                LoRA creation.
            embedding_padding_modules: Forwarded to checkpoint loading.
            lora_model_cls: Class whose ``from_local_checkpoint`` is used to
                load adapters.
            max_position_embeddings: Forwarded to checkpoint loading.
        """
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Tri-state dummy-LoRA cache:
        #   False     -> caching disabled (outside dummy_lora_cache()).
        #   None      -> caching enabled, nothing cached yet.
        #   LoRAModel -> cached dummy model that later calls clone.
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.lora_config = lora_config
        self.max_position_embeddings = max_position_embeddings
        super().__init__(device)
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly.

        The cache is always disabled again on exit, including when the
        ``with`` body raises, so a cached dummy model is never left attached
        to the manager after a failure."""
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        """Always ``True`` for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the adapter manager for ``model`` and remember it.

        Args:
            model: The model passed to create_lora_manager().

        Returns:
            The ``model`` attribute of the newly created manager.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the LoRA described by ``lora_request`` onto the CPU.

        Args:
            lora_request: Request carrying ``lora_path`` and ``lora_int_id``.

        Returns:
            The loaded LoRA model.

        Raises:
            RuntimeError: If any step of resolving or loading the checkpoint
                fails; the original exception is chained as the cause.
            ValueError: If the loaded LoRA's rank exceeds
                ``lora_config.max_lora_rank`` or its extra vocab size exceeds
                ``lora_config.lora_extra_vocab_size``.
        """
        # Seed with the raw request path so the error message below is always
        # formattable, even if the failure happens before (or while) the
        # absolute path is resolved.
        lora_path = lora_request.lora_path
        try:
            model = self._adapter_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            # Expand packed modules into their constituent module names.
            expected_lora_modules: List[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                max_position_embeddings=self.max_position_embeddings,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
                weights_mapper=hf_to_vllm_mapper)

        except Exception as e:
            raise RuntimeError(f"Loading lora {lora_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}.")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA under ``lora_request.lora_int_id``.

        Inside ``dummy_lora_cache()`` the first dummy model created is cached
        and subsequent calls clone it instead of creating a new one.

        Args:
            lora_request: Request whose ``lora_int_id`` identifies the dummy.
            rank: Rank passed to the manager's ``create_dummy_lora``.

        Returns:
            ``False`` if the id is already listed; otherwise the result of
            the manager's ``add_adapter``.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, 1, self.embedding_modules)
            # None means caching is enabled but empty; False means disabled.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Forward to the adapter manager's ``pin_adapter``."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        """Delegate to ``set_active_adapters_worker`` with this manager's
        ``_apply_adapters`` and the adapter manager's
        ``set_adapter_mapping`` as callbacks."""
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        """Delegate to ``apply_adapters_worker`` with this manager's list,
        remove and add callbacks and the adapter manager's slot count."""
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def add_adapter(self, adapter_request: Any) -> bool:
        """Delegate to ``add_adapter_worker`` with this manager's list and
        load callbacks and the adapter manager's add/activate callbacks."""
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Forward to the adapter manager's ``remove_adapter``."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Forward to the adapter manager's ``remove_all_adapters``."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        """Return the ids produced by ``list_adapters_worker`` for the
        adapter manager's ``list_adapters``."""
        return list_adapters_worker(self._adapter_manager.list_adapters)
168
169
170
171
172
173
174
175
176


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """Worker-side LoRA manager backed by an LRU cache.

    On each request the requested LoRAs are loaded unless already present;
    when the cache is over capacity the least recently used LoRAs are
    unloaded."""

    _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Create the LRU adapter manager for ``model`` and store it.

        Returns:
            The ``model`` attribute of the created manager.
        """
        self._adapter_manager = create_lora_manager(
            model,
            lora_manager_cls=self._manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        return self._adapter_manager.model

    def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
        """Add every truthy request, keyed by ``lora_int_id``.

        Raises:
            RuntimeError: If more distinct LoRAs are requested than the
                adapter manager's ``lora_slots``.
        """
        # Falsy entries are skipped; requests sharing an id collapse to one.
        requested: Dict[int, LoRARequest] = {}
        for request in lora_requests:
            if request:
                requested[request.lora_int_id] = request

        slot_count = self._adapter_manager.lora_slots
        if len(requested) > slot_count:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(requested)}) is greater "
                "than the number of GPU LoRA slots "
                f"({slot_count}).")

        for request in requested.values():
            self.add_adapter(request)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Load (or touch) the requested LoRA and activate it.

        Returns:
            For a new LoRA, the result of the adapter manager's
            ``add_adapter``; for one already listed, whether
            ``get_adapter`` returned a non-``None`` value.
        """
        lora_id = lora_request.lora_int_id
        if lora_id in self.list_adapters():
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            touched = self._adapter_manager.get_adapter(lora_id)
            loaded = touched is not None
        else:
            # Remove before we load the new lora to save memory
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager,
                                  LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
            new_lora = self._load_adapter(lora_request)
            loaded = self._adapter_manager.add_adapter(new_lora)
        self._adapter_manager.activate_adapter(lora_request.lora_int_id)
        return loaded