# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from contextlib import contextmanager
from typing import Any, Literal

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.lora_model import LoRAModel
from vllm.lora.model_manager import (
    LoRAModelManager,
    LRUCacheLoRAModelManager,
    create_lora_manager,
)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path

logger = init_logger(__name__)


class WorkerLoRAManager:
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded.

    The underlying model manager is not built in ``__init__``; call
    ``create_lora_manager`` before using any method that touches
    ``self._adapter_manager``."""

    _manager_cls: type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        embedding_modules: dict[str, str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
    ):
        """Capture the configuration needed to build and load LoRA adapters.

        Args:
            vllm_config: Source of the scheduler limits, vocab size, LoRA
                config and HF text config read below.
            device: Stored and later forwarded to ``create_lora_manager``.
            embedding_modules: Forwarded to ``create_dummy_lora`` when a
                dummy LoRA is built.
            lora_model_cls: Class whose ``from_local_checkpoint`` is used to
                load adapters from disk.
        """
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        # Tri-state dummy-LoRA cache (see dummy_lora_cache / add_dummy_lora):
        #   False     -> caching disabled (default),
        #   None      -> caching enabled, nothing cached yet,
        #   LoRAModel -> cached dummy LoRA that later requests clone.
        self._cached_dummy_lora: None | Literal[False] | LoRAModel = False

        scheduler_config = vllm_config.scheduler_config
        self.max_num_seqs = scheduler_config.max_num_seqs
        self.max_num_batched_tokens = scheduler_config.max_num_batched_tokens

        model_config = vllm_config.model_config
        self.vocab_size = model_config.get_vocab_size()
        self.lora_config = vllm_config.lora_config

        # Use get_text_config() in case of multimodal models
        text_config = model_config.hf_config.get_text_config()

        # For encoder-decoder models (e.g., Whisper), use max_target_positions
        # instead of max_position_embeddings
        # TODO: Generalize max_position_embeddings handling for
        # out-of-tree (OOT) encoder-decoder models
        if model_config.is_encoder_decoder:
            position_attr = "max_target_positions"
        else:
            position_attr = "max_position_embeddings"
        # Falls back to None when the text config lacks the attribute.
        self.max_position_embeddings = getattr(text_config, position_attr, None)

        self.device = device
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly.

        Caching is switched on for the duration of the ``with`` body and is
        always switched off again on exit, including when the body raises,
        so a failed run cannot leave a cached dummy LoRA behind."""
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        """Always ``True`` for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
        vllm_config: VllmConfig | None = None,
    ) -> Any:
        """Build the model manager for ``model`` and remember it.

        Args:
            model: Module handed to the module-level ``create_lora_manager``.
            vllm_config: Optional config forwarded unchanged.

        Returns:
            The ``model`` attribute of the newly created manager.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
            vllm_config=vllm_config,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the adapter described by ``lora_request`` onto the CPU.

        Args:
            lora_request: Supplies the adapter path, name, integer id and
                tensorizer config.

        Returns:
            The model produced by ``self._lora_model_cls.from_local_checkpoint``.

        Raises:
            ValueError: If a ``FileNotFoundError`` is raised while resolving
                or loading the adapter (chained from the original error).
            Exception: Any other error, including a failed
                ``peft_helper.validate_legal``, propagates unchanged.
        """
        try:
            supported_lora_modules = self._adapter_manager.supported_lora_modules
            packed_modules_mapping = self._adapter_manager.packed_modules_mapping
            # Packed modules are expanded to their mapped entries; all other
            # modules are expected under their own name.
            expected_lora_modules: set[str] = set()
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.update(packed_modules_mapping[module])
                else:
                    expected_lora_modules.add(module)
                # "experts" is always expected under its own name, even when
                # it also has a packed mapping.
                if module == "experts":
                    expected_lora_modules.add(module)
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path,
                self.max_position_embeddings,
                lora_request.tensorizer_config_dict,
            )

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            # Get model-defined prefixes to skip during LoRA loading.
            lora_skip_prefixes = getattr(model, "lora_skip_prefixes", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                model_vocab_size=self.vocab_size,
                tensorizer_config_dict=lora_request.tensorizer_config_dict,
                weights_mapper=hf_to_vllm_mapper,
                skip_prefixes=lora_skip_prefixes,
            )
        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            # NOTE(review): the original comments tie this ValueError to a
            # "NotFoundError" and every other exception to a
            # "BadRequestError" -- presumably mapped by the caller; confirm.
            raise ValueError(
                f"Loading lora {lora_request.lora_name} failed: No adapter "
                f"found for {lora_request.lora_path}"
            ) from e
        # Any other exception propagates to the caller unchanged.

        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA under ``lora_request.lora_int_id``.

        Args:
            lora_request: Only its ``lora_int_id`` is read.
            rank: Rank passed to ``create_dummy_lora`` when a new dummy is
                built (unused when a cached dummy is cloned).

        Returns:
            ``False`` if the id is already listed; otherwise the result of
            the manager's ``add_adapter``.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            # Reuse the cached dummy instead of building a new one.
            dummy_lora = self._cached_dummy_lora.clone(lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, self.embedding_modules
            )
            # None means caching is enabled but empty; False means disabled.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Delegate to the manager's ``pin_adapter`` and return its result."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: set[Any], mapping: Any | None) -> None:
        """Apply ``requests`` and, if given, install the adapter ``mapping``.

        Args:
            requests: Adapter requests passed to ``_apply_adapters``.
            mapping: Passed to ``set_adapter_mapping`` unless it is ``None``.
        """
        self._apply_adapters(requests)
        if mapping is not None:
            self._adapter_manager.set_adapter_mapping(mapping)

    def supports_tower_connector_lora(self) -> bool:
        """Return the manager's ``supports_mm and supports_tower_connector_lora``."""
        return (
            self._adapter_manager.supports_mm
            and self._adapter_manager.supports_tower_connector_lora
        )

    def _apply_adapters(self, adapter_requests: set[Any]) -> None:
        """Make the loaded adapters match exactly the requested ones.

        Falsy entries in ``adapter_requests`` are ignored. Adapters that are
        loaded but not requested are removed first, then missing ones are
        added.

        Raises:
            RuntimeError: If more distinct adapters are requested than the
                manager's ``adapter_slots``.
        """
        existing_adapters = self.list_adapters()
        models_map = {
            adapter_request.adapter_id: adapter_request
            for adapter_request in adapter_requests
            if adapter_request
        }
        if len(models_map) > self._adapter_manager.adapter_slots:
            raise RuntimeError(
                f"Number of requested models ({len(models_map)}) is greater "
                "than the number of GPU model slots "
                f"({self._adapter_manager.adapter_slots})."
            )
        requested_ids = set(models_map)
        for adapter_id in existing_adapters - requested_ids:
            self.remove_adapter(adapter_id)
        for adapter_id in requested_ids - existing_adapters:
            self.add_adapter(models_map[adapter_id])

    def add_adapter(self, adapter_request: Any) -> bool:
        """Load, register and activate the requested adapter.

        Returns:
            ``False`` if the adapter id is already listed (nothing is
            loaded or activated); otherwise the result of the manager's
            ``add_adapter``.
        """
        if adapter_request.adapter_id in self.list_adapters():
            return False
        loaded_adapter = self._load_adapter(adapter_request)
        loaded = self._adapter_manager.add_adapter(loaded_adapter)
        self._adapter_manager.activate_adapter(loaded_adapter.id)
        return loaded

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate to the manager's ``remove_adapter`` and return its result."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Delegate to the manager's ``remove_all_adapters``."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> set[int]:
        """Return the manager's listed adapter ids as a new ``set``."""
        return set(self._adapter_manager.list_adapters())
222
223
224
225
226
227
228
229
230


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """Worker-side LoRA manager backed by an LRU cache.

    On every request the requested LoRAs are loaded (unless they are
    already loaded), and the least recently used LoRAs are unloaded when
    the cache is above capacity."""

    _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
        vllm_config: VllmConfig | None = None,
    ) -> Any:
        """Build the LRU-cache model manager for ``model`` and remember it.

        Returns:
            The ``model`` attribute of the newly created manager.
        """
        manager = create_lora_manager(
            model,
            lora_manager_cls=self._manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vllm_config=vllm_config,
        )
        self._adapter_manager = manager
        return manager.model

    def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
        """Add every requested LoRA after checking the slot limit.

        Falsy entries are skipped and requests are de-duplicated by
        ``lora_int_id``. Nothing is removed here.

        Raises:
            RuntimeError: If more distinct LoRAs are requested than the
                manager's ``lora_slots``.
        """
        requests_by_id: dict[int, LoRARequest] = {}
        for request in lora_requests:
            if request:
                requests_by_id[request.lora_int_id] = request

        if len(requests_by_id) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(requests_by_id)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots})."
            )

        for request in requests_by_id.values():
            self.add_adapter(request)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Ensure the requested LoRA is cached, then activate it.

        Returns:
            For a fresh or in-place load, the result of the manager's
            ``add_adapter``; for an already cached adapter, whether
            ``get_adapter`` returned something other than ``None``.
        """
        # Note that this method is not thread-safe. It may be invoked multiple
        # times for the same adapter when using multiple API servers.
        # This is ok because it's currently only called from
        # the single-threaded core engine loop.
        manager = self._adapter_manager
        adapter_id = lora_request.lora_int_id

        already_cached = adapter_id in self.list_adapters()
        if already_cached and not lora_request.load_inplace:
            # Already loaded: just touch it so its position in the caches
            # is refreshed.
            loaded = manager.get_adapter(adapter_id) is not None
        else:
            # Load the new adapter first to ensure it is actually valid,
            # before evicting any existing adapters. This may cause the
            # number of loaded lora adapters to very temporarily exceed
            # `--max-cpu-loras`.
            new_lora = self._load_adapter(lora_request)

            # Drop any existing adapter with the same id (the in-place
            # reload use case).
            manager.remove_adapter(new_lora.id)

            # Loading succeeded; evict the oldest adapter if adding one more
            # would exceed the cache capacity.
            if len(manager) + 1 > manager.capacity:
                assert isinstance(manager, LRUCacheLoRAModelManager)
                manager.remove_oldest_adapter()

            # Then add the new adapter to the cache.
            loaded = manager.add_adapter(new_lora)

        self._adapter_manager.activate_adapter(adapter_id)
        return loaded