worker_manager.py 13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from contextlib import contextmanager
5
from typing import Any, Literal
6
7
8

import torch

9
from vllm.config import VllmConfig
10
from vllm.exceptions import LoRAAdapterNotFoundError
11
from vllm.logger import init_logger
12
13
from vllm.lora.lora_model import LoRAModel
from vllm.lora.model_manager import (
14
15
16
17
    LoRAModelManager,
    LRUCacheLoRAModelManager,
    create_lora_manager,
)
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.request import LoRARequest
20
21
22
23
24
from vllm.lora.utils import (
    get_adapter_absolute_path,
    is_in_target_modules,
    is_supported_lora_module,
)
25

26
# Module-level logger; used below for the ignored-LoRA-module warnings.
logger = init_logger(__name__)
27
28


29
class WorkerLoRAManager:
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    # Model-manager class handed to ``create_lora_manager``; subclasses
    # override this attribute to select a different manager implementation.
    _manager_cls: type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        embedding_modules: dict[str, str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
    ):
        """Capture the configuration needed to build and load LoRA adapters.

        Args:
            vllm_config: Engine config; scheduler, model and LoRA settings are
                read from it.
            device: Device handed to the LoRA model manager.
            embedding_modules: Forwarded to ``create_dummy_lora`` when building
                dummy adapters.
            lora_model_cls: Class whose ``from_local_checkpoint`` loads adapter
                checkpoints.
        """
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        # Dummy-LoRA cache state:
        #   False     -> caching disabled (outside ``dummy_lora_cache``)
        #   None      -> caching enabled, nothing cached yet
        #   LoRAModel -> cached dummy adapter, cloned for new ids
        self._cached_dummy_lora: None | Literal[False] | LoRAModel = False
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens
        )
        self.vocab_size = vllm_config.model_config.get_vocab_size()
        self.lora_config = vllm_config.lora_config

        # Use get_text_config() in case of multimodal models
        text_config = vllm_config.model_config.hf_config.get_text_config()

        # For encoder-decoder models (e.g., Whisper), use max_target_positions
        # instead of max_position_embeddings
        # TODO: Generalize max_position_embeddings handling for
        # out-of-tree (OOT) encoder-decoder models
        if vllm_config.model_config.is_encoder_decoder:
            self.max_position_embeddings = getattr(
                text_config, "max_target_positions", None
            )
        else:
            self.max_position_embeddings = getattr(
                text_config, "max_position_embeddings", None
            )
        self.device = device
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Reuse a single dummy LoRA model for the duration of the context.

        While the context is active, the first dummy LoRA created by
        ``add_dummy_lora`` is cached, and later calls clone it instead of
        building a new one. The cache is cleared on exit, including when the
        body of the ``with`` statement raises.
        """
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            # Without ``finally``, an exception in the body would leave caching
            # enabled (and any cached dummy model referenced) after exit.
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        """LoRA support is always enabled for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
        vllm_config: VllmConfig | None = None,
    ) -> Any:
        """Wrap ``model`` in a LoRA model manager of type ``_manager_cls``.

        The manager is stored on ``self`` and used by all later adapter
        operations.

        Returns:
            The manager's ``model`` attribute.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
            vllm_config=vllm_config,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _expected_lora_modules(self) -> set[str]:
        """Return the module names an adapter checkpoint is expected to target."""
        packed_modules_mapping = self._adapter_manager.packed_modules_mapping
        expected: set[str] = set()
        for module in self._adapter_manager.supported_lora_modules:
            if module in packed_modules_mapping:
                # Packed modules contribute their constituent names instead.
                expected.update(packed_modules_mapping[module])
            else:
                expected.add(module)
            # "experts" is always expected under its own name, even when
            # packed_modules_mapping also expands it.
            if module == "experts":
                expected.add(module)
        return expected

    def _warn_ignored_lora_modules(
        self,
        lora: LoRAModel,
        lora_request: LoRARequest,
        supported_lora_modules,
    ) -> None:
        """Warn (once per message) about adapter modules that will be ignored."""
        target_modules = self.lora_config.target_modules
        for module_name in lora.loras:
            if not is_supported_lora_module(module_name, supported_lora_modules):
                logger.warning_once(
                    "LoRA module '%s' in adapter '%s' is not in the "
                    "model's supported LoRA target modules [%s]. "
                    "These parameters will be ignored, which may "
                    "cause abnormal model behavior.",
                    module_name,
                    lora_request.lora_path,
                    ", ".join(sorted(supported_lora_modules)),
                )
            elif not is_in_target_modules(module_name, target_modules):
                logger.warning_once(
                    "LoRA module '%s' in adapter '%s' is not in the "
                    "deployment-time target_modules restriction [%s]."
                    " These parameters will be ignored.",
                    module_name,
                    lora_request.lora_path,
                    ", ".join(sorted(target_modules)),
                )

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the adapter described by ``lora_request`` onto the CPU.

        The PEFT config is validated against ``self.lora_config`` before any
        weights are loaded.

        Raises:
            LoRAAdapterNotFoundError: if loading raises ``FileNotFoundError``.
            Any other exception from validation or loading propagates unchanged.
        """
        try:
            supported_lora_modules = self._adapter_manager.supported_lora_modules
            expected_lora_modules = self._expected_lora_modules()
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path,
                self.max_position_embeddings,
                lora_request.tensorizer_config_dict,
            )

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            # Get model-defined prefixes to skip during LoRA loading.
            lora_skip_prefixes = getattr(model, "lora_skip_prefixes", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                model_vocab_size=self.vocab_size,
                tensorizer_config_dict=lora_request.tensorizer_config_dict,
                weights_mapper=hf_to_vllm_mapper,
                skip_prefixes=lora_skip_prefixes,
            )
        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            raise LoRAAdapterNotFoundError(
                lora_request.lora_name, lora_request.lora_path
            ) from e

        self._warn_ignored_lora_modules(
            lora, lora_request, supported_lora_modules
        )
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA under ``lora_request.lora_int_id``.

        Returns ``False`` without doing anything if an adapter with that id is
        already loaded; otherwise returns the manager's ``add_adapter`` result.
        Inside ``dummy_lora_cache`` the first dummy built is cached and cloned
        for later ids.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, self.embedding_modules
            )
            # None means caching is enabled but empty: remember this one.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Delegate pinning of ``adapter_id`` to the adapter manager."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: set[Any], mapping: Any | None) -> None:
        """Apply the requested adapters, then install ``mapping`` if given."""
        self._apply_adapters(requests)
        if mapping is not None:
            self._adapter_manager.set_adapter_mapping(mapping)

    def supports_tower_connector_lora(self) -> bool:
        """Return whether the manager reports both ``supports_mm`` and
        ``supports_tower_connector_lora``."""
        return (
            self._adapter_manager.supports_mm
            and self._adapter_manager.supports_tower_connector_lora
        )

    def _apply_adapters(self, adapter_requests: set[Any]) -> None:
        """Make the loaded adapters match ``adapter_requests`` exactly.

        Falsy entries are ignored. Adapters that are loaded but not requested
        are removed, and requested adapters that are missing are added.

        Raises:
            RuntimeError: if more adapters are requested than the manager has
                ``adapter_slots``.
        """
        existing_adapters = self.list_adapters()
        models_map = {
            adapter_request.adapter_id: adapter_request
            for adapter_request in adapter_requests
            if adapter_request
        }
        if len(models_map) > self._adapter_manager.adapter_slots:
            raise RuntimeError(
                f"Number of requested models ({len(models_map)}) is greater "
                "than the number of GPU model slots "
                f"({self._adapter_manager.adapter_slots})."
            )
        requested_ids = set(models_map)
        for adapter_id in existing_adapters - requested_ids:
            self.remove_adapter(adapter_id)
        for adapter_id in requested_ids - existing_adapters:
            self.add_adapter(models_map[adapter_id])

    def add_adapter(self, adapter_request: Any) -> bool:
        """Load, register and activate an adapter.

        Returns ``False`` without doing anything if the id is already loaded;
        otherwise returns the manager's ``add_adapter`` result.
        """
        if adapter_request.adapter_id in self.list_adapters():
            return False
        loaded_adapter = self._load_adapter(adapter_request)
        loaded = self._adapter_manager.add_adapter(loaded_adapter)
        self._adapter_manager.activate_adapter(loaded_adapter.id)
        return loaded

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate removal of ``adapter_id`` to the adapter manager."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Delegate removal of every adapter to the adapter manager."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> set[int]:
        """Return the ids of the adapters the manager currently holds."""
        return set(self._adapter_manager.list_adapters())


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    # The inherited ``create_lora_manager`` builds its manager from
    # ``self._manager_cls``, so overriding this attribute is all that is
    # needed to get an LRU-cached model manager.
    _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
        """Load (if needed) and activate every requested LoRA.

        Unlike the base class, adapters that are not requested are not
        removed here; eviction happens in ``add_adapter`` when the cache is
        full. Falsy entries in ``lora_requests`` are ignored.

        Raises:
            RuntimeError: if more distinct LoRAs are requested than the
                manager has ``lora_slots``.
        """
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests
            if lora_request
        }
        if len(loras_map) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots})."
            )
        for lora in loras_map.values():
            self.add_adapter(lora)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Ensure the adapter for ``lora_request`` is cached, then activate it.

        The adapter is (re)loaded when its id is not cached yet or when
        ``lora_request.load_inplace`` is set; otherwise the cached copy is
        touched to refresh its cache position.

        Returns:
            The manager's ``add_adapter`` result after a load, or whether the
            adapter was found in the cache otherwise.
        """
        # Note that this method is not thread-safe. It may be invoked multiple
        # times for the same adapter when using multiple API servers.
        # This is ok because it's currently only called from
        # the single-threaded core engine loop.
        if (
            lora_request.lora_int_id not in self.list_adapters()
            or lora_request.load_inplace
        ):
            # Load the new adapter first to ensure it is actually valid, before
            # evicting any existing adapters.
            # This may cause the number of loaded LoRA adapters to briefly
            # exceed `--max-cpu-loras`.
            lora = self._load_adapter(lora_request)

            # Remove the existing adapter if it exists, so that an in-place
            # reload (load_inplace) replaces it.
            self._adapter_manager.remove_adapter(lora.id)

            # Loading succeeded; now evict the oldest adapter if adding this
            # one would exceed the cache capacity.
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager, LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
            # Then add the new adapter to the cache
            loaded = self._adapter_manager.add_adapter(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = (
                self._adapter_manager.get_adapter(lora_request.lora_int_id) is not None
            )
        self._adapter_manager.activate_adapter(lora_request.lora_int_id)
        return loaded