worker_manager.py 10.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from contextlib import contextmanager
5
from typing import Any, Literal, Optional, Union
6
7
8

import torch

9
10
11
12
13
from vllm.adapter_commons.utils import (add_adapter_worker,
                                        apply_adapters_worker,
                                        list_adapters_worker,
                                        set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
Terry's avatar
Terry committed
16
from vllm.lora.models import (LoRAModel, LoRAModelManager,
17
                              LRUCacheLoRAModelManager, create_lora_manager)
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.request import LoRARequest
20
from vllm.lora.utils import get_adapter_absolute_path
21

22
logger = init_logger(__name__)
23
24


25
class WorkerLoRAManager(AbstractWorkerManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    # Model-manager class instantiated by ``create_lora_manager``; subclasses
    # override this to change how adapters are managed.
    _manager_cls: type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: dict[str, str],
        embedding_padding_modules: list[str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
        max_position_embeddings: Optional[int] = None,
    ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Three-state dummy-LoRA cache consulted by ``add_dummy_lora``:
        #   False     -> caching disabled (outside ``dummy_lora_cache``);
        #   None      -> caching enabled, nothing cached yet;
        #   LoRAModel -> cached dummy, cloned for each new lora_int_id.
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.lora_config = lora_config
        self.max_position_embeddings = max_position_embeddings
        super().__init__(device)
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Reuse a single dummy LoRA model within the ``with`` block.

        While active, the first dummy LoRA created by ``add_dummy_lora`` is
        cached and subsequent calls clone it instead of creating a new one.
        The cache is disabled again on exit -- also when the block raises --
        so the cached model is neither retained nor reused afterwards.
        """
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Wrap ``model`` in a ``_manager_cls`` adapter manager.

        The manager is stored as ``self._adapter_manager`` and its ``model``
        attribute is returned.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the LoRA checkpoint described by ``lora_request`` onto CPU.

        Raises:
            ValueError: if no adapter files are found for
                ``lora_request.lora_path``, or if the adapter's extra vocab
                size exceeds ``lora_config.lora_extra_vocab_size``.

        Any other exception (e.g. from ``validate_legal``) propagates
        unchanged.
        """
        try:
            supported_lora_modules = (
                self._adapter_manager.supported_lora_modules)
            packed_modules_mapping = (
                self._adapter_manager.packed_modules_mapping)
            # Packed modules are expanded into the module names they map to,
            # so the checkpoint is matched against the individual names.
            expected_lora_modules: list[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)

            expected_lora_modules = list(set(expected_lora_modules))
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path, self.max_position_embeddings,
                lora_request.tensorizer_config_dict)

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
                tensorizer_config_dict=lora_request.tensorizer_config_dict,
                weights_mapper=hf_to_vllm_mapper)

        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            # For NotFoundError
            raise ValueError(
                f"Loading lora {lora_request.lora_name} failed: No adapter "
                f"found for {lora_request.lora_path}") from e
        # Every other exception (the BadRequestError path) is intentionally
        # not caught here and propagates with its original traceback.

        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Add a dummy LoRA with id ``lora_request.lora_int_id``.

        Returns False without doing anything if that id is already loaded;
        otherwise returns the adapter manager's ``add_adapter`` result.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, self.embedding_modules)
            # None means dummy_lora_cache() is active and nothing is cached.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Delegate pinning of ``adapter_id`` to the adapter manager."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: set[Any],
                            mapping: Optional[Any]) -> None:
        """Apply ``requests`` and install ``mapping`` on the adapter manager."""
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def _apply_adapters(self, adapter_requests: set[Any]) -> None:
        """Make the loaded adapter set match ``adapter_requests``."""
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def add_adapter(self, adapter_request: Any) -> bool:
        """Load (via ``_load_adapter``), add and activate an adapter."""
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate removal of ``adapter_id`` to the adapter manager."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Remove every adapter from the adapter manager."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> set[int]:
        """Return the ids of the adapters currently held by the manager."""
        return list_adapters_worker(self._adapter_manager.list_adapters)
189
190
191
192
193
194
195
196
197


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    # The inherited ``create_lora_manager`` forwards ``_manager_cls`` as
    # ``lora_manager_cls``, so overriding this attribute is all that is needed
    # to get an LRU-cached adapter manager; no method override is required.
    _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
        """Make every request in ``lora_requests`` resident and active.

        Falsy entries are skipped and duplicates are collapsed by
        ``lora_int_id``. This method only adds adapters; it never removes
        unrequested ones -- space is reclaimed by the eviction performed in
        ``add_adapter``.

        Raises:
            RuntimeError: if more distinct LoRAs are requested than there are
                GPU LoRA slots.
        """
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots}).")
        for lora in loras_map.values():
            self.add_adapter(lora)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Ensure the requested LoRA is cached, then activate it.

        Returns the adapter manager's ``add_adapter`` result for a newly
        loaded LoRA, or whether ``get_adapter`` found an already-loaded one.
        Errors from ``_load_adapter`` propagate before anything is evicted.
        """
        # Note that this method is not thread-safe. It may be invoked multiple
        # times for the same adapter when using multiple API servers.
        # This is ok because it's currently only called from
        # the single-threaded core engine loop.

        if lora_request.lora_int_id not in self.list_adapters():
            # Load the new adapter first to ensure it is actually valid, before
            # evicting any existing adapters.
            # This may cause the # of loaded lora adapters to very temporarily
            # exceed `--max-cpu-loras`.
            lora = self._load_adapter(lora_request)

            # Loading succeeded; if adding it would exceed the cache capacity,
            # evict the oldest adapter first.
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager,
                                  LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
            # Then add the new adapter to the cache
            loaded = self._adapter_manager.add_adapter(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = self._adapter_manager.get_adapter(
                lora_request.lora_int_id) is not None
        self._adapter_manager.activate_adapter(lora_request.lora_int_id)
        return loaded