worker_manager.py 10.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from contextlib import contextmanager
4
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
5
6
7

import torch

8
9
10
11
12
from vllm.adapter_commons.utils import (add_adapter_worker,
                                        apply_adapters_worker,
                                        list_adapters_worker,
                                        set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
13
from vllm.config import LoRAConfig
14
from vllm.logger import init_logger
Terry's avatar
Terry committed
15
from vllm.lora.models import (LoRAModel, LoRAModelManager,
16
                              LRUCacheLoRAModelManager, create_lora_manager)
17
from vllm.lora.peft_helper import PEFTHelper
18
from vllm.lora.request import LoRARequest
19
from vllm.lora.utils import get_adapter_absolute_path
20

21
logger = init_logger(__name__)
22
23


24
class WorkerLoRAManager(AbstractWorkerManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    _manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
        max_position_embeddings: Optional[int] = None,
    ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Dummy-LoRA cache state, toggled by `dummy_lora_cache`:
        #   False     -> caching disabled (default)
        #   None      -> caching enabled, nothing cached yet
        #   LoRAModel -> cached dummy LoRA, cloned by `add_dummy_lora`
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.lora_config = lora_config
        self.max_position_embeddings = max_position_embeddings
        super().__init__(device)
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly.

        Inside the ``with`` block the first dummy LoRA created by
        `add_dummy_lora` is cached and later ones are cloned from it. The
        cache is always disabled again on exit, even if the body raises.
        """
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            # Reset in `finally` so an exception inside the `with` block
            # cannot leave caching silently enabled.
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Create the adapter manager (of type ``_manager_cls``) for ``model``.

        The manager is stored on ``self._adapter_manager``; its ``model``
        attribute is returned.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the LoRA checkpoint referenced by ``lora_request`` on CPU.

        Raises:
            ValueError: if no adapter files are found for the request, or if
                the LoRA's extra vocab size exceeds
                ``lora_config.lora_extra_vocab_size``. Any other exception
                (e.g. from config validation) propagates unchanged.
        """
        try:
            model = self._adapter_manager.model
            packed_modules_mapping = model.packed_modules_mapping

            # Expand packed modules into their constituent module names.
            expected_lora_modules: List[str] = []
            for module in model.supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
            # De-duplicate while keeping a deterministic order (a plain set's
            # iteration order depends on the per-process hash seed).
            expected_lora_modules = list(dict.fromkeys(expected_lora_modules))

            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path, self.max_position_embeddings)

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
                weights_mapper=hf_to_vllm_mapper)

        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            # For NotFoundError
            raise ValueError(
                f"Loading lora {lora_request.lora_name} failed: No adapter "
                f"found for {lora_request.lora_path}") from e

        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA with id ``lora_request.lora_int_id``.

        Returns False if an adapter with that id is already listed; otherwise
        returns the adapter manager's ``add_adapter`` result. Inside
        `dummy_lora_cache`, the dummy is cloned from a cached instance.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, 1, self.embedding_modules)
            # `None` means caching is enabled but empty: store this one.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Delegate pinning of ``adapter_id`` to the adapter manager."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        """Apply ``requests`` and set ``mapping`` via the shared helper."""
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        """Reconcile loaded adapters with ``adapter_requests``."""
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def add_adapter(self, adapter_request: Any) -> bool:
        """Load (if needed), add and activate the requested adapter."""
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate removal of ``adapter_id`` to the adapter manager."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Delegate removal of every adapter to the adapter manager."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        """Return the ids of the adapters known to the adapter manager."""
        return list_adapters_worker(self._adapter_manager.list_adapters)


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    # `create_lora_manager` is inherited: the base implementation already
    # builds the manager from `self._manager_cls`, so overriding `_manager_cls`
    # alone yields an LRUCacheLoRAModelManager.
    _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
        """Ensure every requested LoRA is loaded and activated.

        Falsy entries in ``lora_requests`` are skipped, and requests are
        de-duplicated by ``lora_int_id``. Unlike the base class, adapters that
        are not requested are not removed here; eviction is left to the LRU
        policy in `add_adapter`.

        Raises:
            RuntimeError: if more distinct LoRAs are requested than the
                adapter manager's ``lora_slots``.
        """
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots}).")
        for lora in loras_map.values():
            self.add_adapter(lora)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Load and add ``lora_request`` if needed, then activate it.

        Returns the manager's ``add_adapter`` result for a newly loaded LoRA,
        or whether ``get_adapter`` found it when it was already present.
        """
        if lora_request.lora_int_id not in self.list_adapters():
            # Load the new adapter first to ensure it is actually valid, before
            # evicting any existing adapters.
            # This may cause the # of loaded lora adapters to very temporarily
            # exceed `--max-cpu-loras`.
            lora = self._load_adapter(lora_request)

            # Loading succeeded, now check if we will exceed cache capacity and
            # evict the oldest adapter if so.
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager,
                                  LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
            # Then add the new adapter to the cache
            loaded = self._adapter_manager.add_adapter(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = self._adapter_manager.get_adapter(
                lora_request.lora_int_id) is not None
        self._adapter_manager.activate_adapter(lora_request.lora_int_id)
        return loaded