worker_manager.py 10.2 KB
Newer Older
1
from contextlib import contextmanager
2
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
3
4
5

import torch

6
7
8
9
10
from vllm.adapter_commons.utils import (add_adapter_worker,
                                        apply_adapters_worker,
                                        list_adapters_worker,
                                        set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
11
from vllm.config import LoRAConfig
12
from vllm.logger import init_logger
Terry's avatar
Terry committed
13
from vllm.lora.models import (LoRAModel, LoRAModelManager,
14
15
                              LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
16
from vllm.lora.utils import get_adapter_absolute_path
17

18
# Logger for this module, created through vllm.logger.init_logger.
logger = init_logger(__name__)
19
20


21
class WorkerLoRAManager(AbstractWorkerManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    # Manager class handed to `create_lora_manager`; subclasses override it.
    _manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
        max_position_embeddings: Optional[int] = None,
    ):
        """Record the settings later used to build the manager and load LoRAs.

        Args:
            max_num_seqs: Forwarded to `create_lora_manager`.
            max_num_batched_tokens: Forwarded to `create_lora_manager`.
            vocab_size: Forwarded to `create_lora_manager`; also added to
                `lora_config.lora_extra_vocab_size` to form the
                `target_embedding_padding` used when loading a checkpoint.
            lora_config: LoRA configuration (`lora_dtype`, `max_lora_rank`
                and `lora_extra_vocab_size` are read by this class).
            device: Passed to the base class constructor.
            embedding_modules: Forwarded to checkpoint loading and to dummy
                LoRA creation.
            embedding_padding_modules: Forwarded to checkpoint loading.
            lora_model_cls: Class whose `from_local_checkpoint` is used to
                load adapters.
            max_position_embeddings: Forwarded to checkpoint loading.
        """
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Tri-state dummy-LoRA cache:
        #   False     -> caching disabled (default),
        #   None      -> caching enabled, nothing cached yet,
        #   LoRAModel -> cached dummy LoRA that later requests clone.
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.lora_config = lora_config
        self.max_position_embeddings = max_position_embeddings
        # NOTE(review): the base class presumably stores `device` as
        # `self.device`, which `create_lora_manager` reads - confirm.
        super().__init__(device)
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly.

        Caching is switched on for the duration of the `with` block and is
        always switched off again on exit, including when the block raises,
        so a failure cannot leave a stale dummy LoRA cached."""
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        """Always True for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the LoRA manager for `model` and remember it.

        Calls the module-level `create_lora_manager` with the settings stored
        on this instance and `_manager_cls`, stores the result as
        `self._adapter_manager`, and returns the manager's `model` attribute.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the LoRA described by `lora_request` from a local checkpoint.

        Raises:
            ValueError: If a FileNotFoundError occurs while loading, if the
                loaded rank exceeds `lora_config.max_lora_rank`, or if the
                loaded extra vocab size exceeds
                `lora_config.lora_extra_vocab_size`.
            RuntimeError: If any other exception occurs while loading.
        """
        # Start from the path as requested so the error messages below always
        # have a path to report, even when resolving the path (or an earlier
        # step) is what fails. Without this the handlers would themselves
        # fail with UnboundLocalError and hide the real cause.
        lora_path = lora_request.lora_path
        try:
            model = self._adapter_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            expected_lora_modules: List[str] = []
            # Packed modules are replaced by the entries they map to.
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)

            # Drop duplicates.
            expected_lora_modules = list(set(expected_lora_modules))
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                max_position_embeddings=self.max_position_embeddings,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
                weights_mapper=hf_to_vllm_mapper)

        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            raise ValueError(
                f"Loading lora {lora_request.lora_name} failed: No adapter "
                f"found for {lora_path}") from e
        except Exception as e:
            raise RuntimeError(f"Loading lora {lora_path} failed") from e

        # Reject adapters that exceed the configured limits.
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}.")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Add a dummy LoRA of the given rank under the request's id.

        Returns False without doing anything when the id is already listed;
        otherwise returns the result of the manager's `add_adapter`.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            # Reuse the cached dummy LoRA by cloning it under the new id.
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            # NOTE(review): the literal 1 is an opaque positional argument of
            # create_dummy_lora (presumably a scaling factor) - confirm.
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, 1, self.embedding_modules)
            # None means caching is enabled (see dummy_lora_cache) but empty.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Delegate to the manager's `pin_adapter`."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        """Delegate to `set_active_adapters_worker`, supplying
        `_apply_adapters` and the manager's `set_adapter_mapping`."""
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        """Delegate to `apply_adapters_worker`, supplying this object's
        list/remove/add callables and the manager's `adapter_slots`."""
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def add_adapter(self, adapter_request: Any) -> bool:
        """Delegate to `add_adapter_worker`, supplying `_load_adapter` and
        the manager's `add_adapter` / `activate_adapter` callables."""
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate to the manager's `remove_adapter`."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self) -> None:
        """Delegate to the manager's `remove_all_adapters`."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        """Return the adapter ids reported by `list_adapters_worker` for the
        manager's `list_adapters` callable."""
        return list_adapters_worker(self._adapter_manager.list_adapters)
178
179
180
181
182
183
184
185
186


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """Worker-side LoRA manager that keeps adapters in an LRU cache.

    On every request the requested LoRAs are loaded (unless they are already
    loaded), and the least recently used LoRAs are unloaded if the cache is
    above capacity."""

    _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the LRU-cache LoRA manager for `model`, store it as
        `self._adapter_manager`, and return the manager's `model`."""
        built_manager = create_lora_manager(
            model,
            lora_manager_cls=self._manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        self._adapter_manager = built_manager
        return built_manager.model

    def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
        """Add every truthy request, keyed by its `lora_int_id`.

        Raises:
            RuntimeError: If more distinct LoRAs are requested than the
                manager's `lora_slots`.
        """
        # Key by id so that requests sharing an id collapse into one entry.
        requested: Dict[int, LoRARequest] = {}
        for request in lora_requests:
            if request:
                requested[request.lora_int_id] = request

        num_requested = len(requested)
        num_slots = self._adapter_manager.lora_slots
        if num_requested > num_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({num_requested}) is greater "
                "than the number of GPU LoRA slots "
                f"({num_slots}).")

        for request in requested.values():
            self.add_adapter(request)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Make sure the requested LoRA is in the cache, then activate it.

        Returns the result of the manager's `add_adapter` for a newly loaded
        LoRA, or whether `get_adapter` found it when it was already listed.
        """
        adapter_id = lora_request.lora_int_id
        manager = self._adapter_manager
        if adapter_id in self.list_adapters():
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = manager.get_adapter(adapter_id) is not None
        else:
            # Load the new adapter first to ensure it is actually valid,
            # before evicting any existing adapters.
            # This may cause the # of loaded lora adapters to very
            # temporarily exceed `--max-cpu-loras`.
            new_lora = self._load_adapter(lora_request)

            # Loading succeeded; now check whether adding it would exceed
            # the cache capacity and, if so, evict the oldest adapter.
            if len(manager) + 1 > manager.capacity:
                assert isinstance(manager, LRUCacheLoRAModelManager)
                manager.remove_oldest_adapter()
            # Then add the new adapter to the cache
            loaded = manager.add_adapter(new_lora)
        manager.activate_adapter(adapter_id)
        return loaded