worker_manager.py 9.15 KB
Newer Older
1
from abc import ABC, abstractmethod, abstractproperty
2
from typing import Any, Dict, List, Set, Type
3
4
5

import torch

6
from vllm.config import LoRAConfig
7
from vllm.logger import init_logger
8
from vllm.lora.layers import LoRAMapping
Terry's avatar
Terry committed
9
from vllm.lora.models import (LoRAModel, LoRAModelManager,
10
11
12
                              LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest

13
# Module-level logger, created with vLLM's init_logger helper and named
# after this module.
logger = init_logger(__name__)
14
15


Terry's avatar
Terry committed
16
class AbstractWorkerLoRAManager(ABC):
    """Abstract class for managing LoRA models on the worker side.

    Stores the sizing parameters shared by every implementation; subclasses
    provide loading, activation and removal of LoRA adapters.
    """

    def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
                 vocab_size: int, lora_config: LoRAConfig,
                 device: torch.device):
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.device = device
        self.lora_config = lora_config

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``property`` on top of ``abstractmethod`` is the documented replacement
    # and is equally enforced when instantiating subclasses.
    @property
    @abstractmethod
    def is_enabled(self) -> bool:
        """Whether this manager performs LoRA handling.

        WorkerLoRAManager in this module always returns True.
        """

    @abstractmethod
    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Attach LoRA management to ``model`` and return the model to run.

        The implementations in this module build a LoRA model manager around
        ``model`` and return that manager's ``model`` attribute.
        """

    @abstractmethod
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Make ``lora_requests`` available, then install ``lora_mapping``."""

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load the LoRA described by ``lora_request`` and return a bool.

        The implementations in this module return False without reloading
        when the LoRA id is already present.
        """

    @abstractmethod
    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a placeholder LoRA of the given ``rank``.

        NOTE(review): presumably used for profiling / warm-up rather than
        real inference — confirm against callers.
        """

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """Unload the LoRA with integer id ``lora_id`` and return a bool."""

    @abstractmethod
    def remove_all_loras(self) -> None:
        """Unload every LoRA held by this manager."""

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """Return the integer ids of the LoRAs currently held."""


Terry's avatar
Terry committed
65
class WorkerLoRAManager(AbstractWorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    # Manager class instantiated by create_lora_manager(); subclasses such as
    # LRUCacheWorkerLoRAManager override it.
    _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
    ):
        """Store configuration; the model manager is created later.

        Args:
            embedding_modules: Forwarded to checkpoint loading and to
                dummy-LoRA creation.
            embedding_padding_modules: Forwarded to checkpoint loading.
            lora_model_cls: Class whose ``from_local_checkpoint`` loads
                adapters from disk.

        The remaining arguments are stored by the base class.
        """
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Lazily initialized by create_lora_manager; the other methods use
        # it, so create_lora_manager must be called before them.
        self._lora_manager: LoRAModelManager
        super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
                         lora_config, device)

    @property
    def is_enabled(self) -> bool:
        """Always True for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the LoRA model manager for ``model`` and return its model.

        The manager class is taken from ``_lora_manager_cls``; the manager is
        kept on ``self._lora_manager`` for the other methods.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            lora_manager_cls=self._lora_manager_cls,
        )
        self._lora_manager = lora_manager
        return lora_manager.model

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Sync loaded LoRAs with ``lora_requests``, then set the mapping."""
        self._apply_loras(lora_requests)
        self._lora_manager.set_lora_mapping(lora_mapping)

    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
        """Load every requested LoRA and unload every LoRA not requested.

        Raises:
            RuntimeError: if more distinct LoRAs are requested than the
                manager has slots.
        """
        loras_that_exist = self.list_loras()
        # Falsy entries are skipped (presumably None placeholders — confirm);
        # keying by id also collapses duplicate requests for one adapter.
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._lora_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._lora_manager.lora_slots}).")

        new_loras = set(loras_map)
        loras_to_add = new_loras - loras_that_exist
        loras_to_remove = loras_that_exist - new_loras

        # Stale LoRAs are removed before new ones are added.
        for lora_id in loras_to_remove:
            self.remove_lora(lora_id)

        for lora_id in loras_to_add:
            self.add_lora(loras_map[lora_id])

    def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
        """Load ``lora_request``'s checkpoint on CPU and validate it.

        Raises:
            RuntimeError: if loading fails for any reason (the original
                exception is chained as the cause).
            ValueError: if the LoRA's rank exceeds ``max_lora_rank`` or its
                extra vocab size exceeds ``lora_extra_vocab_size``.
        """
        try:
            model = self._lora_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            # Packed modules are replaced by the module names they map to
            # (presumably how they are named in checkpoints — confirm).
            expected_lora_modules = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_request.lora_local_path,
                expected_lora_modules,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
            raise RuntimeError(
                f"Loading lora {lora_request.lora_local_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}.")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA of ``rank`` under ``lora_request``'s id.

        Returns False without doing anything if that id is already loaded.
        """
        if lora_request.lora_int_id in self.list_loras():
            return False
        return self._lora_manager.add_lora(
            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
                                                 rank, self.embedding_modules))

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load and activate ``lora_request``.

        Returns False without loading if its id is already present;
        otherwise returns the model manager's ``add_lora`` result.
        """
        if lora_request.lora_int_id in self.list_loras():
            return False
        lora = self._load_lora(lora_request)
        loaded = self._lora_manager.add_lora(lora)
        self._lora_manager.activate_lora(lora.id)
        return loaded

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate removal of ``lora_id`` to the model manager."""
        return self._lora_manager.remove_lora(lora_id)

    def remove_all_loras(self) -> None:
        """Delegate removal of every LoRA to the model manager."""
        self._lora_manager.remove_all_loras()

    def list_loras(self) -> Set[int]:
        """Return the model manager's LoRA ids as a set."""
        return set(self._lora_manager.list_loras())


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    # create_lora_manager() is inherited: WorkerLoRAManager passes
    # ``self._lora_manager_cls`` to create_lora_manager, so overriding this
    # attribute is enough to obtain an LRUCacheLoRAModelManager. (A former
    # override here duplicated the parent's body argument-for-argument.)
    _lora_manager_cls: Type[
        LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
        """Ensure every requested LoRA is loaded and activated.

        Unlike the parent class, LoRAs that are not requested are left in
        place here; eviction only happens inside add_lora when the cache is
        full.

        Raises:
            RuntimeError: if more distinct LoRAs are requested than the
                manager has slots.
        """
        # Falsy entries are skipped (presumably None placeholders — confirm).
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._lora_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._lora_manager.lora_slots}).")
        for lora in loras_map.values():
            self.add_lora(lora)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load (or refresh) ``lora_request`` and activate it.

        Returns:
            For a LoRA not yet loaded, the model manager's ``add_lora``
            result; for one already loaded, whether ``get_lora`` found it.
        """
        if lora_request.lora_int_id not in self.list_loras():
            # Remove before we load the new lora to save memory
            if len(self._lora_manager) + 1 > self._lora_manager.capacity:
                # Narrows the manager's type for remove_oldest_lora below.
                assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
                self._lora_manager.remove_oldest_lora()
            lora = self._load_lora(lora_request)
            loaded = self._lora_manager.add_lora(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = self._lora_manager.get_lora(
                lora_request.lora_int_id) is not None
        self._lora_manager.activate_lora(lora_request.lora_int_id)
        return loaded