# worker_manager.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from contextlib import contextmanager
5
from typing import Any, Literal
6
7
8

import torch

9
from vllm.config import VllmConfig
10
from vllm.exceptions import LoRAAdapterNotFoundError
11
from vllm.logger import init_logger
12
13
from vllm.lora.lora_model import LoRAModel
from vllm.lora.model_manager import (
14
15
16
17
    LoRAModelManager,
    LRUCacheLoRAModelManager,
    create_lora_manager,
)
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.request import LoRARequest
20
21
22
23
24
from vllm.lora.utils import (
    get_adapter_absolute_path,
    is_in_target_modules,
    is_supported_lora_module,
)
25

26
logger = init_logger(__name__)
27
28


29
class WorkerLoRAManager:
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    # Manager class handed to the module-level ``create_lora_manager``;
    # subclasses override this attribute to swap the implementation.
    _manager_cls: type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        embedding_modules: dict[str, str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
    ):
        """Capture the settings needed to build the manager and load adapters.

        Args:
            vllm_config: Engine configuration. The scheduler limits, the
                vocab size, the LoRA config and the text config's position
                limit are read from it.
            device: Device forwarded to ``create_lora_manager``.
            embedding_modules: Forwarded to the adapter manager when a dummy
                LoRA is created.
            lora_model_cls: Class whose ``from_local_checkpoint`` is used to
                load adapters.
        """
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        # Tri-state dummy-LoRA cache (see ``add_dummy_lora``):
        #   False     -> caching disabled, nothing is stored
        #   None      -> caching enabled, nothing stored yet
        #   LoRAModel -> cached dummy that is cloned for later requests
        self._cached_dummy_lora: None | Literal[False] | LoRAModel = False
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens
        )
        self.vocab_size = vllm_config.model_config.get_vocab_size()
        self.lora_config = vllm_config.lora_config

        # Use get_text_config() in case of multimodal models
        text_config = vllm_config.model_config.hf_config.get_text_config()

        # For encoder-decoder models (e.g., Whisper), use max_target_positions
        # instead of max_position_embeddings
        # TODO: Generalize max_position_embeddings handling for
        # out-of-tree (OOT) encoder-decoder models
        position_attr = (
            "max_target_positions"
            if vllm_config.model_config.is_encoder_decoder
            else "max_position_embeddings"
        )
        self.max_position_embeddings = getattr(text_config, position_attr, None)
        self.device = device
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly.

        The cache is switched off again when the block exits, including when
        the block raises, so a failure cannot leave caching enabled.
        """
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        """Always ``True`` for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
        vllm_config: VllmConfig | None = None,
    ) -> Any:
        """Build the adapter manager for ``model`` and remember it.

        Args:
            model: Model passed to the module-level ``create_lora_manager``.
            vllm_config: Optional engine configuration forwarded unchanged.

        Returns:
            The ``model`` attribute of the newly created manager.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
            vllm_config=vllm_config,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the adapter described by ``lora_request`` onto the CPU.

        Args:
            lora_request: Request naming the adapter path, id and optional
                tensorizer configuration.

        Returns:
            The loaded LoRA model.

        Raises:
            LoRAAdapterNotFoundError: If any step raises ``FileNotFoundError``.
            Exception: Any other error (for example from
                ``validate_legal``) propagates unchanged.
        """
        try:
            supported_lora_modules = self._adapter_manager.supported_lora_modules
            packed_modules_mapping = self._adapter_manager.packed_modules_mapping
            expected_lora_lst: list[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    # Packed modules are expected under their sub-module names.
                    expected_lora_lst.extend(packed_modules_mapping[module])
                else:
                    expected_lora_lst.append(module)
                if module == "experts":
                    # "experts" itself is always expected, even when it is
                    # also expanded through the packed mapping above.
                    expected_lora_lst.append(module)
            expected_lora_modules = set(expected_lora_lst)
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path,
                self.max_position_embeddings,
                lora_request.tensorizer_config_dict,
            )

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            # Get model-defined prefixes to skip during LoRA loading.
            lora_skip_prefixes = getattr(model, "lora_skip_prefixes", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                model_vocab_size=self.vocab_size,
                tensorizer_config_dict=lora_request.tensorizer_config_dict,
                weights_mapper=hf_to_vllm_mapper,
                skip_prefixes=lora_skip_prefixes,
            )

            self._warn_ignored_lora_modules(
                lora,
                lora_request,
                expected_lora_modules,
                packed_modules_mapping,
            )
        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            # For NotFoundError
            raise LoRAAdapterNotFoundError(
                lora_request.lora_name, lora_request.lora_path
            ) from e

        return lora

    def _warn_ignored_lora_modules(
        self,
        lora: LoRAModel,
        lora_request: LoRARequest,
        expected_lora_modules: set[str],
        packed_modules_mapping: Any,
    ) -> None:
        """Warn about adapter modules that will be ignored."""
        target_modules = self.lora_config.target_modules
        expected_lora_modules_lst = list(expected_lora_modules)
        for module_name in lora.loras:
            if not is_supported_lora_module(module_name, expected_lora_modules_lst):
                logger.warning_once(
                    "LoRA module '%s' in adapter '%s' is not in the "
                    "model's supported LoRA target modules [%s]. "
                    "These parameters will be ignored, which may "
                    "cause abnormal model behavior.",
                    module_name,
                    lora_request.lora_path,
                    ", ".join(sorted(expected_lora_modules_lst)),
                )
            elif not is_in_target_modules(
                module_name,
                target_modules,
                packed_modules_mapping,
            ):
                logger.warning_once(
                    "LoRA module '%s' in adapter '%s' is not in the "
                    "deployment-time target_modules restriction [%s]."
                    " These parameters will be ignored.",
                    module_name,
                    lora_request.lora_path,
                    ", ".join(sorted(target_modules)),
                )

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA under ``lora_request.lora_int_id``.

        When the dummy cache holds a model it is cloned; otherwise a new
        dummy is created and, if caching is enabled, stored for reuse.

        Returns:
            ``False`` if the id is already listed, otherwise the result of
            the adapter manager's ``add_adapter``.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, self.embedding_modules
            )
            # ``None`` means caching is on but empty; ``False`` means off.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def get_dummy_lora_warmup_rank(self, default_rank: int) -> int:
        """Delegate to the adapter manager's ``get_dummy_lora_warmup_rank``."""
        return self._adapter_manager.get_dummy_lora_warmup_rank(default_rank)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Delegate to the adapter manager's ``pin_adapter``."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: set[Any], mapping: Any | None) -> None:
        """Apply ``requests`` and, if given, install ``mapping``.

        Args:
            requests: Adapter requests passed to ``_apply_adapters``.
            mapping: Passed to the adapter manager's ``set_adapter_mapping``
                unless it is ``None``.
        """
        self._apply_adapters(requests)
        if mapping is not None:
            self._adapter_manager.set_adapter_mapping(mapping)

    def supports_tower_connector_lora(self) -> bool:
        """Return whether the manager reports both multimodal and
        tower/connector LoRA support."""
        return (
            self._adapter_manager.supports_mm
            and self._adapter_manager.supports_tower_connector_lora
        )

    def _apply_adapters(self, adapter_requests: set[Any]) -> None:
        """Make the loaded adapters match ``adapter_requests`` exactly.

        Falsy entries are skipped. Adapters that are loaded but not requested
        are removed; requested adapters that are not loaded are added.

        Raises:
            RuntimeError: If more adapters are requested than the manager has
                adapter slots.
        """
        existing_adapters = self.list_adapters()
        models_map = {
            adapter_request.adapter_id: adapter_request
            for adapter_request in adapter_requests
            if adapter_request
        }
        if len(models_map) > self._adapter_manager.adapter_slots:
            raise RuntimeError(
                f"Number of requested models ({len(models_map)}) is greater "
                "than the number of GPU model slots "
                f"({self._adapter_manager.adapter_slots})."
            )
        requested_ids = set(models_map)
        for adapter_id in existing_adapters - requested_ids:
            self.remove_adapter(adapter_id)
        for adapter_id in requested_ids - existing_adapters:
            self.add_adapter(models_map[adapter_id])

    def add_adapter(self, adapter_request: Any) -> bool:
        """Load, register and activate the requested adapter.

        Returns:
            ``False`` if the adapter id is already listed, otherwise the
            result of the adapter manager's ``add_adapter``.
        """
        if adapter_request.adapter_id in self.list_adapters():
            return False
        loaded_adapter = self._load_adapter(adapter_request)
        loaded = self._adapter_manager.add_adapter(loaded_adapter)
        self._adapter_manager.activate_adapter(loaded_adapter.id)
        return loaded

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate to the adapter manager's ``remove_adapter``."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self) -> None:
        """Delegate to the adapter manager's ``remove_all_adapters``."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> set[int]:
        """Return the adapter ids reported by the adapter manager as a set."""
        return set(self._adapter_manager.list_adapters())
256
257
258
259
260
261
262
263
264


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """Worker-side LoRA manager whose adapters live in an LRU cache.

    On every request the requested LoRAs are loaded unless they are already
    present. Nothing is unloaded eagerly; the least recently used LoRA is
    evicted only when adding a new one would exceed the cache capacity."""

    _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
        vllm_config: VllmConfig | None = None,
    ) -> Any:
        """Build the LRU adapter manager for ``model`` and remember it.

        Returns:
            The ``model`` attribute of the newly created manager.
        """
        manager_kwargs = {
            "lora_manager_cls": self._manager_cls,
            "max_num_seqs": self.max_num_seqs,
            "vocab_size": self.vocab_size,
            "lora_config": self.lora_config,
            "device": self.device,
            "max_num_batched_tokens": self.max_num_batched_tokens,
            "vllm_config": vllm_config,
        }
        self._adapter_manager = create_lora_manager(model, **manager_kwargs)
        return self._adapter_manager.model

    def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
        """Ensure every requested LoRA is loaded and active.

        Falsy entries are skipped. Adapters that were not requested are left
        in place.

        Raises:
            RuntimeError: If more LoRAs are requested than the manager has
                LoRA slots.
        """
        requested: dict[int, LoRARequest] = {}
        for request in lora_requests:
            if request:
                requested[request.lora_int_id] = request

        if len(requested) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(requested)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots})."
            )

        for request in requested.values():
            self.add_adapter(request)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Load (or refresh) the requested LoRA and activate it.

        Returns:
            For a newly loaded adapter, the result of the manager's
            ``add_adapter``. For an adapter that is already listed, whether
            the manager's ``get_adapter`` returned it.
        """
        # Note that this method is not thread-safe. It may be invoked multiple
        # times for the same adapter when using multiple API servers.
        # This is ok because it's currently only called from
        # the single-threaded core engine loop.
        manager = self._adapter_manager
        request_id = lora_request.lora_int_id

        already_listed = request_id in self.list_adapters()
        if already_listed and not lora_request.load_inplace:
            # Already loaded: just touch it so its position in the caches
            # is updated.
            loaded = manager.get_adapter(request_id) is not None
        else:
            # Load the new adapter first to ensure it is actually valid, before
            # evicting any existing adapters.
            # This may cause the # of loaded lora adapters to very temporarily
            # exceed `--max-cpu-loras`.
            fresh_lora = self._load_adapter(lora_request)

            # In-place reload: drop the existing copy of this adapter, if any.
            manager.remove_adapter(fresh_lora.id)

            # Loading succeeded; evict the oldest adapter if adding one more
            # would exceed the cache capacity.
            if len(manager) + 1 > manager.capacity:
                assert isinstance(manager, LRUCacheLoRAModelManager)
                manager.remove_oldest_adapter()

            # Then add the new adapter to the cache.
            loaded = manager.add_adapter(fresh_lora)

        manager.activate_adapter(request_id)
        return loaded