"vllm/model_executor/models/nemotron.py" did not exist on "63e7176f265be43dcc425f5ab4ab45c90234f5c3"
worker_manager.py 11.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from contextlib import contextmanager
from typing import Any, Literal

import torch

from vllm.config import VllmConfig
from vllm.exceptions import LoRAAdapterNotFoundError
from vllm.logger import init_logger
from vllm.lora.lora_model import LoRAModel
from vllm.lora.model_manager import (
    LoRAModelManager,
    LRUCacheLoRAModelManager,
    create_lora_manager,
)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path

# Module-level logger, named after this module.
logger = init_logger(__name__)


class WorkerLoRAManager:
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    # Manager class handed to ``create_lora_manager``; subclasses override it.
    _manager_cls: type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        embedding_modules: dict[str, str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
    ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        # Dummy-LoRA cache state, driven by ``dummy_lora_cache`` and read by
        # ``add_dummy_lora``:
        #   False     -> caching disabled
        #   None      -> caching enabled, nothing cached yet
        #   LoRAModel -> cached dummy LoRA that later calls clone
        self._cached_dummy_lora: None | Literal[False] | LoRAModel = False
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens
        )
        self.vocab_size = vllm_config.model_config.get_vocab_size()
        self.lora_config = vllm_config.lora_config

        # Use get_text_config() in case of multimodal models
        text_config = vllm_config.model_config.hf_config.get_text_config()

        # For encoder-decoder models (e.g., Whisper), use max_target_positions
        # instead of max_position_embeddings
        # TODO: Generalize max_position_embeddings handling for
        # out-of-tree (OOT) encoder-decoder models
        if vllm_config.model_config.is_encoder_decoder:
            self.max_position_embeddings = getattr(
                text_config, "max_target_positions", None
            )
        else:
            self.max_position_embeddings = getattr(
                text_config, "max_position_embeddings", None
            )
        self.device = device
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Reuse one dummy LoRA model for the duration of a ``with`` block.

        While the block is active, the first dummy LoRA built by
        ``add_dummy_lora`` is cached and later calls clone it instead of
        building a new one. On exit, caching is disabled again and the cached
        reference is dropped. The ``finally`` ensures this also happens when
        the block raises.
        """
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        """Always ``True`` for this manager."""
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
        vllm_config: VllmConfig | None = None,
    ) -> Any:
        """Build the adapter manager for ``model`` and store it on ``self``.

        The manager type comes from ``self._manager_cls``.

        Returns:
            The ``model`` attribute of the created manager.
        """
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
            vllm_config=vllm_config,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Validate and load the adapter described by ``lora_request`` on CPU.

        Raises:
            LoRAAdapterNotFoundError: if any step raises ``FileNotFoundError``.
                The original error is chained as ``__cause__``.
            Any other exception (e.g. from ``validate_legal``) propagates
            unchanged.
        """
        try:
            supported_lora_modules = self._adapter_manager.supported_lora_modules
            packed_modules_mapping = self._adapter_manager.packed_modules_mapping
            expected_lora_lst: list[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_lst.extend(packed_modules_mapping[module])
                else:
                    expected_lora_lst.append(module)
                # "experts" is always expected under its own name as well,
                # even when it is also expanded via packed_modules_mapping.
                if module == "experts":
                    expected_lora_lst.append(module)
            expected_lora_modules = set(expected_lora_lst)
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path,
                self.max_position_embeddings,
                lora_request.tensorizer_config_dict,
            )

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            # Get model-defined prefixes to skip during LoRA loading.
            lora_skip_prefixes = getattr(model, "lora_skip_prefixes", None)

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                model_vocab_size=self.vocab_size,
                tensorizer_config_dict=lora_request.tensorizer_config_dict,
                weights_mapper=hf_to_vllm_mapper,
                skip_prefixes=lora_skip_prefixes,
            )

        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            # - No adapter found to download from huggingface (or in
            #       offline mode)
            # - No local adapter files found at `lora_request.lora_path`
            # For NotFoundError
            raise LoRAAdapterNotFoundError(
                lora_request.lora_name, lora_request.lora_path
            ) from e

        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a dummy LoRA with id ``lora_request.lora_int_id``.

        Inside ``dummy_lora_cache``, a previously created dummy is cloned
        rather than rebuilt.

        Returns:
            ``False`` if the id is already loaded, otherwise the result of the
            adapter manager's ``add_adapter``.
        """
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, self.embedding_modules
            )
            # None means caching is enabled but empty: remember this dummy.
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        """Delegate pinning of ``adapter_id`` to the adapter manager."""
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: set[Any], mapping: Any | None) -> None:
        """Apply ``requests`` and, if given, install ``mapping``.

        The set of loaded adapters is updated via ``_apply_adapters``.
        """
        self._apply_adapters(requests)
        if mapping is not None:
            self._adapter_manager.set_adapter_mapping(mapping)

    def supports_tower_connector_lora(self) -> bool:
        """True when the manager supports multimodal and tower/connector LoRA."""
        return (
            self._adapter_manager.supports_mm
            and self._adapter_manager.supports_tower_connector_lora
        )

    def _apply_adapters(self, adapter_requests: set[Any]) -> None:
        """Make the loaded adapters exactly the (truthy) requested ones.

        Adapters that are loaded but not requested are removed, and requested
        adapters that are not loaded are added.

        Raises:
            RuntimeError: if more adapters are requested than ``adapter_slots``.
        """
        existing_adapters = self.list_adapters()
        models_map = {
            adapter_request.adapter_id: adapter_request
            for adapter_request in adapter_requests
            if adapter_request
        }
        if len(models_map) > self._adapter_manager.adapter_slots:
            raise RuntimeError(
                f"Number of requested models ({len(models_map)}) is greater "
                "than the number of GPU model slots "
                f"({self._adapter_manager.adapter_slots})."
            )
        requested_ids = set(models_map)
        for adapter_id in existing_adapters - requested_ids:
            self.remove_adapter(adapter_id)
        for adapter_id in requested_ids - existing_adapters:
            self.add_adapter(models_map[adapter_id])

    def add_adapter(self, adapter_request: Any) -> bool:
        """Load, register and activate an adapter unless it is already loaded.

        Returns:
            ``False`` if the adapter id is already loaded, otherwise the result
            of the adapter manager's ``add_adapter``.
        """
        if adapter_request.adapter_id in self.list_adapters():
            return False
        loaded_adapter = self._load_adapter(adapter_request)
        loaded = self._adapter_manager.add_adapter(loaded_adapter)
        self._adapter_manager.activate_adapter(loaded_adapter.id)
        return loaded

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate removal of ``adapter_id`` to the adapter manager."""
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Remove every adapter held by the adapter manager."""
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> set[int]:
        """Return the ids of the currently loaded adapters as a set."""
        return set(self._adapter_manager.list_adapters())


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    # ``create_lora_manager`` is inherited: the base implementation already
    # builds the manager from ``self._manager_cls``, so overriding the class
    # attribute is enough to get an LRU-cached manager.
    _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
        """Ensure every (truthy) requested LoRA is loaded and activated.

        Unlike the base class, adapters that were not requested are not
        removed here. Capacity is enforced by evicting the oldest adapter in
        ``add_adapter``.

        Raises:
            RuntimeError: if more distinct LoRAs are requested than
                ``lora_slots``.
        """
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests
            if lora_request
        }
        if len(loras_map) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots})."
            )
        for lora in loras_map.values():
            self.add_adapter(lora)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Load (or refresh) and activate the LoRA for ``lora_request``.

        The LoRA is (re)loaded when its id is not loaded yet or when
        ``lora_request.load_inplace`` is set. Otherwise the existing entry is
        only activated.

        Returns:
            The manager's ``add_adapter`` result on a (re)load. Otherwise,
            whether ``get_adapter`` finds the already-loaded id.
        """
        # Note that this method is not thread-safe. It may be invoked multiple
        # times for the same adapter when using multiple API servers.
        # This is ok because it's currently only called from
        # the single-threaded core engine loop.

        if (
            lora_request.lora_int_id not in self.list_adapters()
            or lora_request.load_inplace
        ):
            # Load the new adapter first to ensure it is actually valid, before
            # evicting any existing adapters.
            # This may cause the # of loaded lora adapters to very temporarily
            # exceed `--max-cpu-loras`.
            lora = self._load_adapter(lora_request)

            # Remove the existing adapter if it exists
            # Use case for LoRA inplace
            self._adapter_manager.remove_adapter(lora.id)

            # Loading succeeded; if adding it would exceed the cache capacity,
            # evict the oldest adapter first.
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                # Narrows the type: only the LRU manager has remove_oldest_adapter.
                assert isinstance(self._adapter_manager, LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
            # Then add the new adapter to the cache
            loaded = self._adapter_manager.add_adapter(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = (
                self._adapter_manager.get_adapter(lora_request.lora_int_id) is not None
            )
        self._adapter_manager.activate_adapter(lora_request.lora_int_id)
        return loaded