worker_manager.py 11.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from contextlib import contextmanager
5
from typing import Any, Literal
6
7
8

import torch

9
from vllm.config import VllmConfig
10
from vllm.logger import init_logger
11
12
13
14
15
16
from vllm.lora.models import (
    LoRAModel,
    LoRAModelManager,
    LRUCacheLoRAModelManager,
    create_lora_manager,
)
17
from vllm.lora.peft_helper import PEFTHelper
18
from vllm.lora.request import LoRARequest
19
from vllm.lora.utils import get_adapter_absolute_path
20

21
logger = init_logger(__name__)
22
23


24
class WorkerLoRAManager:
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded.

    The underlying ``LoRAModelManager`` is not built in ``__init__``; call
    ``create_lora_manager`` before using any method that touches adapters."""

    # Manager class handed to the module-level ``create_lora_manager`` factory.
    # Subclasses override this to select a different manager implementation.
    _manager_cls: type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        embedding_modules: dict[str, str],
        embedding_padding_modules: list[str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
    ):
        """Capture the configuration needed to build and load LoRA adapters.

        Args:
            vllm_config: Source of the scheduler limits, vocab size, LoRA
                config and the text config's ``max_position_embeddings``.
            device: Device forwarded to the LoRA manager factory.
            embedding_modules: Forwarded unchanged to checkpoint loading and
                to dummy-LoRA creation.
            embedding_padding_modules: Forwarded unchanged to checkpoint
                loading.
            lora_model_cls: Class whose ``from_local_checkpoint`` is used to
                load adapters.
        """
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Tri-state dummy-LoRA cache (see ``add_dummy_lora``):
        #   False     -> caching disabled
        #   None      -> caching enabled, nothing cached yet
        #   LoRAModel -> cached dummy that later calls clone
        self._cached_dummy_lora: None | Literal[False] | LoRAModel = False
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens
        )
        self.vocab_size = vllm_config.model_config.get_vocab_size()
        self.lora_config = vllm_config.lora_config

        # Use get_text_config() in case of multimodal models
        text_config = vllm_config.model_config.hf_config.get_text_config()

        self.max_position_embeddings = text_config.max_position_embeddings
        self.device = device
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly.

        Caching is enabled on entry and always disabled again on exit, even
        when the ``with`` body raises, so a cached dummy model is never left
        behind by a failed run."""
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            # Drop the cached model and turn caching off unconditionally.
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        """Always ``True`` for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the LoRA manager for ``model`` and remember it.

        Args:
            model: The module passed to the ``create_lora_manager`` factory.

        Returns:
            The ``model`` attribute of the newly created manager.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the adapter described by ``lora_request`` onto the CPU.

        Args:
            lora_request: Request naming the adapter path, id and optional
                tensorizer configuration.

        Returns:
            The loaded ``LoRAModel``.

        Raises:
            ValueError: If no adapter files are found for the request, or if
                the adapter's extra vocab size exceeds
                ``lora_config.lora_extra_vocab_size``.
            Exception: Any other error raised while validating or loading the
                adapter propagates unchanged.
        """
        try:
            supported_lora_modules = self._adapter_manager.supported_lora_modules
            packed_modules_mapping = self._adapter_manager.packed_modules_mapping
            expected_lora_modules: list[str] = []
            for module in supported_lora_modules:
                # Packed modules are expanded into their mapped entries.
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
                # "experts" is always expected under its own name as well.
                if module == "experts":
                    expected_lora_modules.append(module)
            # De-duplicate; the resulting order is not significant here.
            expected_lora_modules = list(set(expected_lora_modules))
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path,
                self.max_position_embeddings,
                lora_request.tensorizer_config_dict,
            )

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size
                + self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
                tensorizer_config_dict=lora_request.tensorizer_config_dict,
                weights_mapper=hf_to_vllm_mapper,
            )

        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            # For NotFoundError
            raise ValueError(
                f"Loading lora {lora_request.lora_name} failed: No adapter "
                f"found for {lora_request.lora_path}"
            ) from e
        # Every other exception propagates unchanged to the caller (the
        # "BadRequestError" path); no catch-and-re-raise clause is needed.

        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(
                f"LoRA added vocab size {lora.extra_vocab_size} "
                f"is greater than lora_extra_vocab_size "
                f"{self.lora_config.lora_extra_vocab_size}."
            )
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA under ``lora_request.lora_int_id``.

        Inside ``dummy_lora_cache`` the first dummy created is cached and
        later calls clone it instead of creating a new one.

        Args:
            lora_request: Request supplying the adapter id.
            rank: Rank passed to the manager's ``create_dummy_lora``.

        Returns:
            ``False`` if the id is already registered, otherwise the result
            of the manager's ``add_adapter``.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, self.embedding_modules
            )
            # None means caching is enabled but empty; False means disabled.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Delegate to the manager's ``pin_adapter`` and return its result."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: set[Any], mapping: Any | None) -> None:
        """Apply ``requests`` and, if given, install the adapter ``mapping``.

        Args:
            requests: Adapter requests that should be loaded.
            mapping: Passed to the manager's ``set_adapter_mapping`` unless
                it is ``None``.
        """
        self._apply_adapters(requests)
        if mapping is not None:
            self._adapter_manager.set_adapter_mapping(mapping)

    def _apply_adapters(self, adapter_requests: set[Any]) -> None:
        """Make the loaded adapters match ``adapter_requests`` exactly.

        Adapters that are loaded but not requested are removed; requested
        adapters that are not loaded are added. Falsy requests are ignored.

        Raises:
            RuntimeError: If more distinct adapters are requested than the
                manager's ``adapter_slots``.
        """
        existing_adapters = self.list_adapters()
        models_map = {
            adapter_request.adapter_id: adapter_request
            for adapter_request in adapter_requests
            if adapter_request
        }
        if len(models_map) > self._adapter_manager.adapter_slots:
            raise RuntimeError(
                f"Number of requested models ({len(models_map)}) is greater "
                "than the number of GPU model slots "
                f"({self._adapter_manager.adapter_slots})."
            )
        requested_ids = set(models_map)
        # Remove first, then add.
        for adapter_id in existing_adapters - requested_ids:
            self.remove_adapter(adapter_id)
        for adapter_id in requested_ids - existing_adapters:
            self.add_adapter(models_map[adapter_id])

    def add_adapter(self, adapter_request: Any) -> bool:
        """Load, register and activate the adapter for ``adapter_request``.

        Returns:
            ``False`` if the adapter id is already registered, otherwise the
            result of the manager's ``add_adapter``.
        """
        if adapter_request.adapter_id in self.list_adapters():
            return False
        loaded_adapter = self._load_adapter(adapter_request)
        loaded = self._adapter_manager.add_adapter(loaded_adapter)
        self._adapter_manager.activate_adapter(loaded_adapter.id)
        return loaded

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate to the manager's ``remove_adapter`` and return its result."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Delegate to the manager's ``remove_all_adapters``."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> set[int]:
        """Return the manager's ``list_adapters()`` result as a set."""
        return set(self._adapter_manager.list_adapters())
210
211
212
213
214
215
216
217
218


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """Worker-side LoRA manager backed by an LRU cache.

    Every request, the requested LoRAs are loaded (unless they are already
    loaded); when the cache is above capacity the least recently used LoRA
    is unloaded."""

    _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the LRU-cache LoRA manager for ``model`` and remember it.

        Returns:
            The ``model`` attribute of the newly created manager.
        """
        factory_kwargs = {
            "lora_manager_cls": self._manager_cls,
            "max_num_seqs": self.max_num_seqs,
            "vocab_size": self.vocab_size,
            "lora_config": self.lora_config,
            "device": self.device,
            "max_num_batched_tokens": self.max_num_batched_tokens,
        }
        built_manager = create_lora_manager(model, **factory_kwargs)
        self._adapter_manager = built_manager
        return built_manager.model

    def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
        """Add every requested LoRA; nothing is removed explicitly here.

        Falsy requests are skipped and requests sharing an id collapse to
        one entry.

        Raises:
            RuntimeError: If more distinct LoRAs are requested than the
                manager's ``lora_slots``.
        """
        manager = self._adapter_manager

        requests_by_id = {}
        for request in lora_requests:
            if request:
                requests_by_id[request.lora_int_id] = request

        if len(requests_by_id) > manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(requests_by_id)}) is greater "
                "than the number of GPU LoRA slots "
                f"({manager.lora_slots})."
            )

        for request in requests_by_id.values():
            self.add_adapter(request)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Ensure the requested LoRA is in the cache, then activate it.

        Returns:
            For a new adapter, the result of the manager's ``add_adapter``;
            for one already registered, whether ``get_adapter`` returned it.
        """
        # Note that this method is not thread-safe. It may be invoked multiple
        # times for the same adapter when using multiple API servers.
        # This is ok because it's currently only called from
        # the single-threaded core engine loop.
        manager = self._adapter_manager
        lora_id = lora_request.lora_int_id

        if lora_id in self.list_adapters():
            # Already loaded: just touch it to update its position in the
            # caches.
            loaded = manager.get_adapter(lora_id) is not None
        else:
            # Load the new adapter before evicting anything, so that an
            # invalid adapter never costs an existing one its cache slot.
            # This may cause the # of loaded lora adapters to very temporarily
            # exceed `--max-cpu-loras`.
            new_lora = self._load_adapter(lora_request)

            # Loading succeeded; evict the oldest adapter if adding this one
            # would exceed the cache capacity.
            would_overflow = len(manager) + 1 > manager.capacity
            if would_overflow:
                assert isinstance(manager, LRUCacheLoRAModelManager)
                manager.remove_oldest_adapter()

            # Then add the new adapter to the cache.
            loaded = manager.add_adapter(new_lora)

        manager.activate_adapter(lora_id)
        return loaded