worker_manager.py 13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from contextlib import contextmanager
5
from typing import Any, Literal
6
7
8

import torch

9
from vllm.config import VllmConfig
10
from vllm.exceptions import LoRAAdapterNotFoundError
11
from vllm.logger import init_logger
12
13
from vllm.lora.lora_model import LoRAModel
from vllm.lora.model_manager import (
14
15
16
17
    LoRAModelManager,
    LRUCacheLoRAModelManager,
    create_lora_manager,
)
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.request import LoRARequest
20
21
22
23
24
from vllm.lora.utils import (
    get_adapter_absolute_path,
    is_in_target_modules,
    is_supported_lora_module,
)
25

26
# Module-level logger; used in this file for one-time warnings about adapter
# modules that will be ignored when a LoRA is loaded.
logger = init_logger(__name__)
27
28


29
class WorkerLoRAManager:
    """Manage the LoRA adapters that live on one worker.

    Each time the active set is applied, the requested LoRAs are loaded
    (unless they are already loaded) and every other loaded LoRA is unloaded.

    Adapter storage and activation are delegated to a model manager of type
    ``_manager_cls``. That manager does not exist until
    ``create_lora_manager`` has been called, so every method that touches
    ``self._adapter_manager`` must run after that call.
    """

    # Model-manager class handed to ``create_lora_manager``.
    _manager_cls: type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        embedding_modules: dict[str, str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
    ):
        """Capture the configuration needed to build and load LoRA adapters.

        Args:
            vllm_config: Source of the scheduler limits, vocabulary size,
                LoRA config and HF text config read below.
            device: Device forwarded to ``create_lora_manager``.
            embedding_modules: Forwarded to ``create_dummy_lora`` when a
                dummy adapter is built.
            lora_model_cls: Class whose ``from_local_checkpoint`` is used to
                load adapters.
        """
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        # Tri-state dummy-LoRA cache:
        #   False     -> caching disabled (default),
        #   None      -> caching enabled, nothing stored yet,
        #   LoRAModel -> cached dummy that later dummies are cloned from.
        self._cached_dummy_lora: None | Literal[False] | LoRAModel = False
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens
        )
        self.vocab_size = vllm_config.model_config.get_vocab_size()
        self.lora_config = vllm_config.lora_config

        # Use get_text_config() in case of multimodal models
        text_config = vllm_config.model_config.hf_config.get_text_config()

        # For encoder-decoder models (e.g., Whisper), use max_target_positions
        # instead of max_position_embeddings
        # TODO: Generalize max_position_embeddings handling for
        # out-of-tree (OOT) encoder-decoder models
        if vllm_config.model_config.is_encoder_decoder:
            self.max_position_embeddings = getattr(
                text_config, "max_target_positions", None
            )
        else:
            self.max_position_embeddings = getattr(
                text_config, "max_position_embeddings", None
            )
        self.device = device
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Enable reuse of the dummy LoRA model for the duration of the block.

        While the context is active, the first dummy LoRA created by
        ``add_dummy_lora`` is remembered and later ones are cloned from it
        instead of being created again. The cache is always switched off
        again on exit, including when the body raises, so a failure cannot
        leave the cached model referenced or caching permanently enabled.
        """
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        """Always ``True`` for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
        vllm_config: VllmConfig | None = None,
    ) -> Any:
        """Build the model manager for ``model`` and remember it.

        Args:
            model: Model passed to the module-level ``create_lora_manager``.
            vllm_config: Optional config forwarded unchanged.

        Returns:
            The ``model`` attribute of the newly created manager.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
            vllm_config=vllm_config,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the adapter described by ``lora_request`` onto the CPU.

        The PEFT configuration is read and validated before any weights are
        loaded. After loading, a one-time warning is logged for every adapter
        module that will be ignored.

        Args:
            lora_request: Request naming the adapter path and id to load.

        Returns:
            The loaded LoRA model.

        Raises:
            LoRAAdapterNotFoundError: If a ``FileNotFoundError`` is raised
                while resolving or loading the adapter.

        Any other exception propagates to the caller unchanged.
        """
        try:
            supported_lora_modules = self._adapter_manager.supported_lora_modules
            packed_modules_mapping = self._adapter_manager.packed_modules_mapping
            # Packed modules are expanded into the names listed for them in
            # the mapping; everything else is expected under its own name.
            expected_lora_modules: set[str] = set()
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.update(packed_modules_mapping[module])
                else:
                    expected_lora_modules.add(module)
                # "experts" is always expected under its own name as well,
                # even when it is also expanded through the packed mapping.
                if module == "experts":
                    expected_lora_modules.add(module)
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path,
                self.max_position_embeddings,
                lora_request.tensorizer_config_dict,
            )

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            # Get model-defined prefixes to skip during LoRA loading.
            lora_skip_prefixes = getattr(model, "lora_skip_prefixes", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                model_vocab_size=self.vocab_size,
                tensorizer_config_dict=lora_request.tensorizer_config_dict,
                weights_mapper=hf_to_vllm_mapper,
                skip_prefixes=lora_skip_prefixes,
            )

            # Warn about adapter modules that will be ignored.
            target_modules = self.lora_config.target_modules
            expected_lora_modules_lst = list(expected_lora_modules)
            for module_name in lora.loras:
                if not is_supported_lora_module(module_name, expected_lora_modules_lst):
                    logger.warning_once(
                        "LoRA module '%s' in adapter '%s' is not in the "
                        "model's supported LoRA target modules [%s]. "
                        "These parameters will be ignored, which may "
                        "cause abnormal model behavior.",
                        module_name,
                        lora_request.lora_path,
                        ", ".join(sorted(expected_lora_modules_lst)),
                    )
                elif not is_in_target_modules(module_name, target_modules):
                    logger.warning_once(
                        "LoRA module '%s' in adapter '%s' is not in the "
                        "deployment-time target_modules restriction [%s]."
                        " These parameters will be ignored.",
                        module_name,
                        lora_request.lora_path,
                        ", ".join(sorted(target_modules)),
                    )

        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            # For NotFoundError
            raise LoRAAdapterNotFoundError(
                lora_request.lora_name, lora_request.lora_path
            ) from e

        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA under ``lora_request.lora_int_id``.

        Args:
            lora_request: Request whose ``lora_int_id`` names the dummy.
            rank: Rank used when a new dummy has to be created. It is not
                used when the dummy is cloned from the cached one.

        Returns:
            ``False`` if an adapter with that id is already loaded, otherwise
            the result of the manager's ``add_adapter``.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, self.embedding_modules
            )
            # ``None`` means caching is enabled but empty (see
            # ``dummy_lora_cache``); ``False`` means do not cache.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Forward to the manager's ``pin_adapter`` and return its result."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: set[Any], mapping: Any | None) -> None:
        """Apply the requested adapter set, then install ``mapping`` if given.

        Args:
            requests: Adapter requests that should be active.
            mapping: Adapter mapping forwarded to the manager; skipped when
                ``None``.
        """
        self._apply_adapters(requests)
        if mapping is not None:
            self._adapter_manager.set_adapter_mapping(mapping)

    def supports_tower_connector_lora(self) -> bool:
        """Return whether the manager reports both ``supports_mm`` and
        ``supports_tower_connector_lora``."""
        return (
            self._adapter_manager.supports_mm
            and self._adapter_manager.supports_tower_connector_lora
        )

    def _apply_adapters(self, adapter_requests: set[Any]) -> None:
        """Make the loaded adapters exactly match ``adapter_requests``.

        Falsy entries in ``adapter_requests`` are ignored. Adapters that are
        loaded but not requested are removed first, then the missing ones are
        added.

        Raises:
            RuntimeError: If more distinct adapters are requested than the
                manager has slots.
        """
        existing_adapters = self.list_adapters()
        models_map = {
            adapter_request.adapter_id: adapter_request
            for adapter_request in adapter_requests
            if adapter_request
        }
        if len(models_map) > self._adapter_manager.adapter_slots:
            raise RuntimeError(
                f"Number of requested models ({len(models_map)}) is greater "
                "than the number of GPU model slots "
                f"({self._adapter_manager.adapter_slots})."
            )
        requested_ids = set(models_map)
        for adapter_id in existing_adapters - requested_ids:
            self.remove_adapter(adapter_id)
        for adapter_id in requested_ids - existing_adapters:
            self.add_adapter(models_map[adapter_id])

    def add_adapter(self, adapter_request: Any) -> bool:
        """Load, register and activate the requested adapter.

        Returns:
            ``False`` if the adapter id is already loaded (nothing is done),
            otherwise the result of the manager's ``add_adapter``.
        """
        if adapter_request.adapter_id in self.list_adapters():
            return False
        loaded_adapter = self._load_adapter(adapter_request)
        loaded = self._adapter_manager.add_adapter(loaded_adapter)
        self._adapter_manager.activate_adapter(loaded_adapter.id)
        return loaded

    def remove_adapter(self, adapter_id: int) -> bool:
        """Forward to the manager's ``remove_adapter`` and return its result."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self) -> None:
        """Remove every adapter held by the manager."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> set[int]:
        """Return the ids of the adapters the manager lists, as a new set."""
        return set(self._adapter_manager.list_adapters())
249
250
251
252
253
254
255
256
257


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """Worker-side LoRA manager backed by an LRU cache.

    On every request the requested LoRAs are loaded unless they are already
    loaded. Unrequested LoRAs stay resident; the least recently used one is
    unloaded only when the cache would otherwise exceed its capacity.
    """

    # Model-manager class handed to ``create_lora_manager``.
    _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
        vllm_config: VllmConfig | None = None,
    ) -> Any:
        """Build the LRU model manager for ``model`` and remember it.

        Returns:
            The ``model`` attribute of the newly created manager.
        """
        manager = create_lora_manager(
            model,
            lora_manager_cls=self._manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vllm_config=vllm_config,
        )
        self._adapter_manager = manager
        return manager.model

    def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
        """Ensure every requested LoRA is loaded and active.

        Falsy entries are skipped. Nothing is removed here; eviction is left
        to ``add_adapter``.

        Raises:
            RuntimeError: If more distinct LoRAs are requested than the
                manager has LoRA slots.
        """
        loras_map: dict[int, LoRARequest] = {}
        for request in lora_requests:
            if request:
                loras_map[request.lora_int_id] = request

        if len(loras_map) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots})."
            )

        for request in loras_map.values():
            self.add_adapter(request)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Load the adapter if needed, then activate it.

        Returns:
            For a fresh or in-place load, the result of the manager's
            ``add_adapter``. For an adapter that is already loaded, whether
            the manager's ``get_adapter`` returned it.
        """
        # Note that this method is not thread-safe. It may be invoked multiple
        # times for the same adapter when using multiple API servers.
        # This is ok because it's currently only called from
        # the single-threaded core engine loop.
        adapter_id = lora_request.lora_int_id
        already_loaded = adapter_id in self.list_adapters()

        if already_loaded and not lora_request.load_inplace:
            # Already resident: just touch it so its position in the
            # caches is refreshed.
            loaded = self._adapter_manager.get_adapter(adapter_id) is not None
        else:
            # Load first so an invalid adapter fails before anything is
            # evicted. The number of loaded adapters may therefore exceed
            # `--max-cpu-loras` for a brief moment.
            new_lora = self._load_adapter(lora_request)

            # Drop any existing copy under the same id (the in-place reload
            # use case).
            self._adapter_manager.remove_adapter(new_lora.id)

            # Loading succeeded; evict the oldest adapter if adding this one
            # would exceed the cache capacity.
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager, LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()

            # Finally insert the new adapter into the cache.
            loaded = self._adapter_manager.add_adapter(new_lora)

        self._adapter_manager.activate_adapter(lora_request.lora_int_id)
        return loaded