worker_manager.py 10.3 KB
Newer Older
1
from contextlib import contextmanager
2
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
3
4
5

import torch

6
7
8
9
10
from vllm.adapter_commons.utils import (add_adapter_worker,
                                        apply_adapters_worker,
                                        list_adapters_worker,
                                        set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
11
from vllm.config import LoRAConfig
12
from vllm.logger import init_logger
Terry's avatar
Terry committed
13
from vllm.lora.models import (LoRAModel, LoRAModelManager,
14
                              LRUCacheLoRAModelManager, create_lora_manager)
15
from vllm.lora.peft_helper import PEFTHelper
16
from vllm.lora.request import LoRARequest
17
from vllm.lora.utils import get_adapter_absolute_path
18

19
# Module-level logger, created through vLLM's init_logger for this module.
logger = init_logger(__name__)
20
21


22
class WorkerLoRAManager(AbstractWorkerManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    # Manager class handed to create_lora_manager(); subclasses override it.
    _manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
        max_position_embeddings: Optional[int] = None,
    ):
        """Record the settings needed to build the manager and load adapters.

        Args:
            max_num_seqs: Forwarded to create_lora_manager().
            max_num_batched_tokens: Forwarded to create_lora_manager().
            vocab_size: Forwarded to create_lora_manager(); also added to
                lora_config.lora_extra_vocab_size to form the embedding
                padding target when loading an adapter.
            lora_config: LoRA configuration used for validation, dtype and
                the extra-vocab limit.
            device: Passed to the base class constructor.
            embedding_modules: Forwarded to the LoRA model loader and to
                dummy LoRA creation.
            embedding_padding_modules: Forwarded to the LoRA model loader.
            lora_model_cls: Class whose from_local_checkpoint() is used to
                load adapters.
            max_position_embeddings: Forwarded to PEFTHelper.from_local_dir().
        """
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Three-state cache used by add_dummy_lora():
        #   False     -> caching disabled (default, outside dummy_lora_cache)
        #   None      -> caching enabled, nothing cached yet
        #   LoRAModel -> cached dummy model that later calls clone
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.lora_config = lora_config
        self.max_position_embeddings = max_position_embeddings
        super().__init__(device)
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly.

        The cache is always switched off again on exit, including when the
        body of the ``with`` block raises."""
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            # Without the finally, an exception in the with-body would leave
            # the manager permanently in caching mode.
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        """Always True for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the adapter manager for ``model`` and keep a reference to it.

        Returns:
            The ``model`` attribute of the manager created by
            create_lora_manager().
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the LoRA described by ``lora_request`` onto the CPU.

        Args:
            lora_request: Request carrying the adapter path, name and id.

        Returns:
            The loaded LoRA model.

        Raises:
            ValueError: If no adapter files are found (wrapping the
                FileNotFoundError), or if the adapter's extra vocab size
                exceeds ``lora_config.lora_extra_vocab_size``.
        """
        # Pre-bind so the FileNotFoundError message below is well-defined
        # even if the failure happens before the path has been resolved.
        lora_path = lora_request.lora_path
        try:
            model = self._adapter_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            # Packed modules are expanded through packed_modules_mapping;
            # all other supported modules are used as-is.
            expected_lora_modules: List[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)

            # Drop duplicates.
            expected_lora_modules = list(set(expected_lora_modules))
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path, self.max_position_embeddings)

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
                weights_mapper=hf_to_vllm_mapper)

        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            # For NotFoundError
            raise ValueError(
                f"Loading lora {lora_request.lora_name} failed: No adapter "
                f"found for {lora_path}") from e
        # Any other exception (e.g. for BadRequestError) propagates
        # unchanged; no catch-and-re-raise is needed for that.

        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA under ``lora_request.lora_int_id``.

        Inside dummy_lora_cache(), the first dummy model created is cached
        and later calls clone it instead of building a new one.

        Returns:
            False if the id is already listed; otherwise the result of the
            adapter manager's add_adapter().
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, 1, self.embedding_modules)
            # None means caching is enabled but empty; False means disabled.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Forward to the adapter manager's pin_adapter()."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        """Delegate to set_active_adapters_worker(), supplying this
        manager's _apply_adapters and the adapter manager's
        set_adapter_mapping as callbacks."""
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        """Delegate to apply_adapters_worker(), supplying the list, remove
        and add callbacks plus the manager's adapter_slots."""
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def add_adapter(self, adapter_request: Any) -> bool:
        """Delegate to add_adapter_worker(), supplying the list, load, add
        and activate callbacks."""
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Forward to the adapter manager's remove_adapter()."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Forward to the adapter manager's remove_all_adapters()."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        """Return the adapter ids, via list_adapters_worker()."""
        return list_adapters_worker(self._adapter_manager.list_adapters)
185
186
187
188
189
190
191
192
193


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """Worker-side LoRA manager backed by an LRU cache.

    On each request the requested LoRAs are loaded unless already present,
    and the least recently used LoRAs are unloaded when the cache is above
    capacity."""

    # Manager class handed to create_lora_manager().
    _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Create the LRU adapter manager for ``model`` and remember it.

        Returns:
            The ``model`` attribute of the newly created manager.
        """
        manager_kwargs = dict(
            lora_manager_cls=self._manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        lora_manager = create_lora_manager(model, **manager_kwargs)
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
        """Add every distinct, truthy request in ``lora_requests``.

        Raises:
            RuntimeError: If more distinct LoRA ids are requested than the
                adapter manager has ``lora_slots``.
        """
        # Key by lora_int_id so repeated ids collapse to a single entry.
        requested: Dict[int, LoRARequest] = {}
        for request in lora_requests:
            if request:
                requested[request.lora_int_id] = request

        if len(requested) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(requested)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots}).")

        for request in requested.values():
            self.add_adapter(request)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Make sure the requested LoRA is cached, then activate it.

        Returns:
            For an already listed adapter, whether get_adapter() found it;
            otherwise the result of the adapter manager's add_adapter().
        """
        adapter_id = lora_request.lora_int_id
        manager = self._adapter_manager

        if adapter_id in self.list_adapters():
            # Already loaded: fetching it touches the entry so its position
            # in the caches is refreshed.
            loaded = manager.get_adapter(adapter_id) is not None
        else:
            # Load first so an invalid adapter never causes an eviction.
            # This can very briefly push the number of loaded adapters above
            # `--max-cpu-loras`.
            lora = self._load_adapter(lora_request)

            # Loading succeeded; evict the oldest adapter if adding this one
            # would exceed the cache capacity.
            if len(manager) + 1 > manager.capacity:
                assert isinstance(manager, LRUCacheLoRAModelManager)
                manager.remove_oldest_adapter()

            # Finally, put the new adapter into the cache.
            loaded = manager.add_adapter(lora)

        manager.activate_adapter(adapter_id)
        return loaded