worker_manager.py 11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from contextlib import contextmanager
5
from typing import Any, Literal, Optional, Union
6
7
8

import torch

9
from vllm.config import VllmConfig
10
from vllm.logger import init_logger
Terry's avatar
Terry committed
11
from vllm.lora.models import (LoRAModel, LoRAModelManager,
12
                              LRUCacheLoRAModelManager, create_lora_manager)
13
from vllm.lora.peft_helper import PEFTHelper
14
from vllm.lora.request import LoRARequest
15
from vllm.lora.utils import get_adapter_absolute_path
16

17
# Module-level logger, named after this module.
logger = init_logger(__name__)
18
19


20
class WorkerLoRAManager:
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    # Manager class handed to the module-level `create_lora_manager` factory;
    # subclasses override it to select a different manager implementation.
    _manager_cls: type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        embedding_modules: dict[str, str],
        embedding_padding_modules: list[str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
    ):
        """Capture the configuration needed to build the manager and load
        adapters later.

        Args:
            vllm_config: Source of the scheduler limits, vocab size, LoRA
                config and the model's text config.
            device: Device forwarded to the adapter manager factory.
            embedding_modules: Forwarded to checkpoint loading and to dummy
                LoRA creation.
            embedding_padding_modules: Forwarded to checkpoint loading.
            lora_model_cls: Class whose `from_local_checkpoint` is used to
                load adapters.
        """
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Tri-state dummy-LoRA cache:
        #   False     -> caching disabled (default, outside dummy_lora_cache)
        #   None      -> caching enabled, nothing cached yet
        #   LoRAModel -> cached dummy that later dummies are cloned from
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        self.vocab_size = vllm_config.model_config.get_vocab_size()
        self.lora_config = vllm_config.lora_config

        # Use get_text_config() in case of multimodal models
        text_config = vllm_config.model_config.hf_config.get_text_config()

        self.max_position_embeddings = text_config.max_position_embeddings
        self.device = device
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly.

        Caching is switched on for the duration of the ``with`` body and is
        always switched off again on exit, including when the body raises."""
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            # Reset even on error so a failed `with` body cannot leave caching
            # enabled (and a dummy LoRA referenced) after the context exits.
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        """Always True for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the adapter manager for `model` and remember it.

        Calls the module-level `create_lora_manager` factory with the settings
        captured in `__init__` and `_manager_cls`, stores the result in
        `self._adapter_manager`, and returns the manager's `model` attribute.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the adapter described by `lora_request` onto the CPU.

        Raises:
            ValueError: If no adapter files are found for the request's path
                (wrapping the underlying FileNotFoundError), or if the loaded
                adapter's `extra_vocab_size` exceeds
                `lora_config.lora_extra_vocab_size`.

        Any other exception raised while loading propagates unchanged.
        """
        try:
            supported_lora_modules = (
                self._adapter_manager.supported_lora_modules)
            packed_modules_mapping = (
                self._adapter_manager.packed_modules_mapping)
            # Expand packed modules into their entries from the mapping;
            # other supported modules are expected under their own name.
            expected_lora_modules: list[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)

            # De-duplicate the expected module names.
            expected_lora_modules = list(set(expected_lora_modules))
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path, self.max_position_embeddings,
                lora_request.tensorizer_config_dict)

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
                tensorizer_config_dict=lora_request.tensorizer_config_dict,
                weights_mapper=hf_to_vllm_mapper)

        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            # For NotFoundError
            raise ValueError(
                f"Loading lora {lora_request.lora_name} failed: No adapter "
                f"found for {lora_request.lora_path}") from e
        # Every other exception (For BadRequestError) is deliberately left to
        # propagate to the caller as-is; no handler is needed for that.

        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA under the request's `lora_int_id`.

        Returns False without doing anything if that id is already listed.
        Otherwise the dummy is cloned from the cached one when available, or
        created via the adapter manager (and cached when caching is enabled),
        and the result of adding it to the adapter manager is returned.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, self.embedding_modules)
            # `None` means we are inside dummy_lora_cache() with an empty
            # cache; `False` means caching is off and nothing is stored.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Delegate to the adapter manager's `pin_adapter`."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: set[Any],
                            mapping: Optional[Any]) -> None:
        """Apply the requested adapters, then install `mapping` on the adapter
        manager when one is given."""
        self._apply_adapters(requests)
        if mapping is not None:
            self._adapter_manager.set_adapter_mapping(mapping)

    def _apply_adapters(self, adapter_requests: set[Any]) -> None:
        """Make the loaded adapters match `adapter_requests`.

        Falsy entries in `adapter_requests` are ignored. Adapters that are
        loaded but not requested are removed; requested adapters that are not
        loaded are added.

        Raises:
            RuntimeError: If more distinct adapters are requested than the
                adapter manager has slots.
        """
        existing_adapters = self.list_adapters()
        models_map = {
            adapter_request.adapter_id: adapter_request
            for adapter_request in adapter_requests if adapter_request
        }
        if len(models_map) > self._adapter_manager.adapter_slots:
            raise RuntimeError(
                f"Number of requested models ({len(models_map)}) is greater "
                "than the number of GPU model slots "
                f"({self._adapter_manager.adapter_slots}).")
        requested_ids = set(models_map)
        # Remove first, then add.
        for adapter_id in existing_adapters - requested_ids:
            self.remove_adapter(adapter_id)
        for adapter_id in requested_ids - existing_adapters:
            self.add_adapter(models_map[adapter_id])

    def add_adapter(self, adapter_request: Any) -> bool:
        """Load, add and activate the requested adapter.

        Returns False if the adapter id is already listed; otherwise returns
        the result of the adapter manager's `add_adapter`.
        """
        if adapter_request.adapter_id in self.list_adapters():
            return False
        loaded_adapter = self._load_adapter(adapter_request)
        loaded = self._adapter_manager.add_adapter(loaded_adapter)
        self._adapter_manager.activate_adapter(loaded_adapter.id)
        return loaded

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate to the adapter manager's `remove_adapter`."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Delegate to the adapter manager's `remove_all_adapters`."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> set[int]:
        """Return the adapter manager's listed adapters as a set of ids."""
        return set(self._adapter_manager.list_adapters())
200
201
202
203
204
205
206
207
208


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Create the LRU-cache adapter manager for `model`, keep a reference
        to it on this object, and return the manager's `model` attribute."""
        manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = manager
        return manager.model

    def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
        """Add every requested LoRA; falsy entries are skipped.

        Raises a RuntimeError when the number of distinct requested LoRA ids
        exceeds the adapter manager's `lora_slots`.
        """
        # Index the non-empty requests by their integer id.
        requested: dict[int, LoRARequest] = {}
        for request in lora_requests:
            if request:
                requested[request.lora_int_id] = request

        slot_count = self._adapter_manager.lora_slots
        if len(requested) > slot_count:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(requested)}) is greater "
                "than the number of GPU LoRA slots "
                f"({slot_count}).")

        for request in requested.values():
            self.add_adapter(request)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Ensure the requested LoRA is in the cache and activate it.

        Returns the adapter manager's `add_adapter` result for a newly loaded
        LoRA, or whether `get_adapter` found it when it was already listed.
        """
        # Note that this method is not thread-safe. It may be invoked multiple
        # times for the same adapter when using multiple API servers.
        # This is ok because it's currently only called from
        # the single-threaded core engine loop.
        manager = self._adapter_manager
        lora_id = lora_request.lora_int_id

        if lora_id in self.list_adapters():
            # Already loaded: looking it up touches it, which updates its
            # position in the caches.
            loaded = manager.get_adapter(lora_id) is not None
        else:
            # Load before evicting anything so that an invalid adapter never
            # costs us an existing one. The number of loaded LoRA adapters
            # may therefore very temporarily exceed `--max-cpu-loras`.
            new_lora = self._load_adapter(lora_request)

            # The load succeeded; evict the oldest adapter if adding one more
            # would exceed the cache capacity.
            if len(manager) + 1 > manager.capacity:
                assert isinstance(manager, LRUCacheLoRAModelManager)
                manager.remove_oldest_adapter()

            # Finally insert the new adapter into the cache.
            loaded = manager.add_adapter(new_lora)

        manager.activate_adapter(lora_id)
        return loaded