# models.py: LoRA model containers, LoRA index-mapping helpers and LoRA model managers.
import copy
import json
import math
import os
import re
6
7
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
8
9
10
11
12
13

import safetensors.torch
import torch
from torch import nn

from vllm.config import LoRAConfig
14
from vllm.logger import init_logger
15
16
17
from vllm.lora.layers import (BaseLayerWithLoRA,
                              LinearScalingRotaryEmbeddingWithLora,
                              LoRAMapping)
18
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
19
20
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                             parse_fine_tuned_lora_name, replace_submodule)
21
from vllm.utils import LRUCache, is_pin_memory_available
22

23
logger = init_logger(__name__)

# Process-wide counter; get_lora_id() increments it to mint new LoRA ids.
_GLOBAL_LORA_ID = 0


28
29
30
31
32
33
34
35
36
37
38
39
@dataclass
class LongContextLoRAContext:
    """Shared bookkeeping for LoRA adapters that support long context.

    Attributes:
        scaling_factors: RoPE scaling factors of the long-context LoRA
            fine-tuned models that can be served.
        rot_dim: Number of dimensions the rotary embedding is applied to.
        offsets_by_lora_id: LoRA id -> offset into the sin/cos cache. It
            starts empty and is filled in / pruned at runtime as LoRAs are
            added and removed.
    """
    scaling_factors: List[float]
    rot_dim: int
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)


40
def convert_mapping(
    mapping: LoRAMapping,
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    long_lora_context: Optional[LongContextLoRAContext] = None,
    device: Union[str, torch.device] = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA slot indices to LoRA ids
            (None for an empty slot).
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        long_lora_context: Passed if there are long context lora in a batch.
        device: Device the returned tensors are created on.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size]. Same as
                sampler_indices, but with -1 replaced by max_loras - 1, then
                turned into a flat index ``padded_index * batch_size + row``.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            long_lora_indices: Tensor of shape [batch_size] mapping
                requests to RoPE offsets and rot dims for long LoRAs.
                None if long context lora doesn't exist.
            indices_len: List of lengths of the above tensors.
                Used to index into each tensor. It contains length for
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices). If long_lora doesn't
                exist, it only contains first 4 entries.

    Raises:
        ValueError: if a positive LoRA id in the mapping is not present in
            lora_index_to_id.
    """
    # Build the id -> slot lookup once instead of calling list.index() for
    # every token. setdefault keeps the first occurrence, like list.index().
    id_to_index: Dict[int, int] = {}
    for slot, slot_lora_id in enumerate(lora_index_to_id):
        if slot_lora_id is not None:
            id_to_index.setdefault(slot_lora_id, slot)

    def _lora_index(lora_id: int) -> int:
        # Non-positive ids mean "no LoRA" for this row/request.
        if not lora_id > 0:
            return -1
        try:
            return id_to_index[lora_id]
        except KeyError:
            # Keep the exception type list.index() used to raise.
            raise ValueError(f"{lora_id!r} is not in list") from None

    index_mapping_indices: List[int] = list(mapping.index_mapping)
    lora_indices = [_lora_index(x) for x in index_mapping_indices]
    # Rows without a LoRA use embedding index 0 instead of -1.
    embedding_indices = [
        idx if x > 0 else 0
        for x, idx in zip(index_mapping_indices, lora_indices)
    ]
    prompt_mapping: List[int] = [
        _lora_index(x) for x in mapping.prompt_mapping
    ]

    indices_list: List[List[int]] = [
        index_mapping_indices, lora_indices, embedding_indices
    ]
    if long_lora_context:
        # Build offsets on the host. Writing them element-by-element into a
        # device tensor would issue one host->device copy per token.
        offsets_by_lora_id = long_lora_context.offsets_by_lora_id
        indices_list.append(
            [offsets_by_lora_id.get(x, 0) for x in index_mapping_indices])
    indices = torch.tensor(indices_list, dtype=torch.long, device=device)
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         device=device,
                                         dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size)
    ])
    # NOTE(review): indices[2] is never negative (see embedding_indices
    # above), so this replacement looks like a no-op; kept for parity.
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    # Flatten (lora index, row) into a single index: idx * batch + row.
    sampler_indices_padded = (
        torch.arange(
            0, len(sampler_indices_padded), device=device, dtype=torch.long) +
        (sampler_indices_padded * len(sampler_indices_padded)))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contain length of indices tensors. Used to index into each tensor.
    indices_len = [
        base_indices.shape[-1], sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
    ]
    if long_lora_indices_len is not None:
        indices_len.append(long_lora_indices_len)

    return (base_indices, sampler_indices, sampler_indices_padded,
            embeddings_indices, long_lora_indices, indices_len)


def get_lora_id():
    """Hand out the next process-unique LoRA integer id.

    Bumps the module-level ``_GLOBAL_LORA_ID`` counter by one and returns
    the new value, so successive calls yield strictly increasing ids.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


class LoRAModel:
    """A LoRA fine-tuned model.

    Holds the per-module LoRA weights of one adapter together with its id,
    rank and optional long-context scaling factor.
    """

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (lora_model_id >
                0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with a different id.

        The ``loras`` dict is copied shallowly, so the underlying tensors are
        shared. The long-context ``scaling_factor`` is carried over; without
        it the clone would lose its RoPE scaling.
        """
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
            scaling_factor=self.scaling_factor,
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest extra vocab size across all layers (0 if there are none)."""
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        rank: int,
        lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        scaling_factor: Optional[float] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Id of the new model (must be > 0).
            rank: lora rank.
            lora_alpha: lora alpha passed to every LoRALayerWeights.
            tensors: tensor name -> tensor. Each name is split by
                parse_fine_tuned_lora_name into (module name, is_lora_a).
                lora_a / lora_b are stored transposed.
            device: Device the weights are moved to. On "cpu" they are also
                pinned when pinning is available.
            dtype: dtype the weights are cast to.
            embeddings: Optional extra embedding tensors, looked up through
                embedding_modules. Required to be paired with
                embedding_modules.
            target_embedding_padding: If set, lora_b of modules matching
                embedding_padding_modules is zero-padded on its second
                dimension up to this size.
            scaling_factor: Long-context scaling factor, or None.
            embedding_modules: Substring of a module name -> key into
                embeddings.
            embedding_padding_modules: Module-name substrings whose lora_b
                gets padded. None is treated as "no modules".

        Returns:
            The constructed LoRAModel.
        """
        # None is the declared default; treat it as "nothing to pad" instead
        # of failing an assertion on the first lora_b tensor.
        if embedding_padding_modules is None:
            embedding_padding_modules = []
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights(module_name, rank,
                                                      lora_alpha, None, None,
                                                      lora_embeddings_tensor)
            if is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                if target_embedding_padding is not None and any(
                        name in module_name
                        for name in embedding_padding_modules):
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()
        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        *,
        max_position_embeddings: Optional[int] = None,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            max_position_embeddings: Max position embedding length. Used to
                scale the largest context length. If None, the lora model's
                context length is not scaled.
            lora_model_id: Lora model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to from_lora_tensors.
            embedding_modules: Forwarded to from_lora_tensors.
            embedding_padding_modules: Forwarded to from_lora_tensors.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the adapter targets unexpected modules or the
                directory contains no adapter tensors.
        """
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        with open(lora_config_path) as f:
            config = json.load(f)
        target_modules = config["target_modules"]
        # Compare only the last dotted component, so fully qualified names
        # such as "layers.11.self_attn.k_proj" are accepted as well.
        unexpected_modules = [
            module for module in target_modules
            if module.split(".")[-1] not in expected_lora_modules
        ]
        # loaded lora's target modules must be a subset of expected_lora_modules
        if unexpected_modules:
            raise ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected_modules}."
                f" Please verify that the loaded LoRA module is correct")
        if os.path.isfile(lora_tensor_path):
            tensors = safetensors.torch.load_file(lora_tensor_path)
        elif os.path.isfile(lora_bin_file_path):
            # NOTE: torch.load unpickles the file and can execute arbitrary
            # code; only load adapters from trusted sources.
            tensors = torch.load(lora_bin_file_path)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            # NOTE: same pickle caveat as above.
            embeddings = torch.load(new_embeddings_bin_file_path)

        rank = config["r"]
        lora_alpha = config["lora_alpha"]
        context_length = config.get("context_length", None)
        scaling_factor = None
        if context_length:
            if max_position_embeddings is None:
                max_position_embeddings = context_length
            scaling_factor = float(
                math.ceil(context_length / max_position_embeddings))

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            rank=rank,
            lora_alpha=lora_alpha,
            tensors=tensors,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            scaling_factor=scaling_factor,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )


class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
        """
        self.lora_config = lora_config
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Round the token budget up to a multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> id of the LoRA occupying that slot (None = free).
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.base_indices = torch.empty(self.max_num_batched_tokens,
                                        dtype=torch.long,
                                        device="cuda")
        self.sampler_indices = torch.empty(self.max_num_batched_tokens,
                                           dtype=torch.long,
                                           device="cuda")
        self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
                                                  dtype=torch.long,
                                                  device="cuda")
        self.embeddings_indices = torch.empty(2,
                                              self.max_num_batched_tokens,
                                              dtype=torch.long,
                                              device="cuda")
        self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
                                             dtype=torch.long,
                                             device="cuda")
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        # 4 is the number of indices tensors defined above:
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices. _set_lora_mapping may extend the list in place
        # with a fifth length for long_lora_indices.
        self.indices_len: List[Optional[int]] = [None] * 4

        self.model: nn.Module = model
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            if lora_config.long_lora_scaling_factors:
                # We need to replace rotary emb layer to do batch computation
                # for long lora.
                self.supported_lora_modules.append("rotary_emb")
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
        self._registered_loras: Dict[int, LoRAModel] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_loras: Dict[int, None] = {}
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self

    @property
    def capacity(self) -> int:
        """Maximum number of LoRAs that can be registered (CPU cache)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of GPU slots for simultaneously active LoRAs."""
        return self.lora_config.max_loras

    def __len__(self) -> int:
        return len(self._registered_loras)

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the LoRA was already active, True otherwise.

        Raises:
            ValueError: if every slot is occupied.
        """
        if lora_id in self._active_loras:
            return False
        # Use a distinct loop name so it cannot be confused with the
        # lora_id argument.
        index = next((slot
                      for slot, slot_lora_id in enumerate(self.lora_index_to_id)
                      if slot_lora_id is None), None)
        if index is None:
            raise ValueError("No free lora slots")
        self._active_loras[lora_id] = None
        lora_model = self._registered_loras[lora_id]
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor)
            else:
                module.reset_lora(index)
        return True

    def _deactivate_lora(self, lora_id: int):
        """Free the slot held by lora_id; a no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def deactivate_lora(self, lora_id: int) -> bool:
        """Remove a LoRA from a GPU buffer."""
        if lora_id in self._active_loras:
            self._deactivate_lora(lora_id)
            self._active_loras.pop(lora_id)
            return True
        return False

    def _set_long_lora_context(self, lora: LoRAModel):
        """Record the sin/cos cache offset for a long-context LoRA.

        Raises:
            ValueError: if the LoRA's scaling factor was not registered by
                the rotary embedding layer.
        """
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        # Record the offset even when it is 0; a truthiness check would
        # silently skip that case.
        self.long_lora_context.offsets_by_lora_id[lora.id] = (
            self.scaling_factor_to_offset[lora.scaling_factor])

    def _add_lora(self, lora: LoRAModel):
        self._create_merged_loras_inplace(lora)
        self._registered_loras[lora.id] = lora
        self._set_long_lora_context(lora)

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager CPU cache.

        Returns:
            True if newly added, False if the id was already registered.

        Raises:
            RuntimeError: if the cache is already at capacity.
        """
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id not in self._registered_loras:
            if len(self._registered_loras) >= self.capacity:
                raise RuntimeError("No free LoRA slots.")
            self._add_lora(lora)
            return True
        return False

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRAModel from the manager CPU cache."""
        # TODO: should we check active lora?
        self.deactivate_lora(lora_id)
        if self.long_lora_context:
            self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
        return bool(self._registered_loras.pop(lora_id, None))

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    # TODO see if this can be vectorized
    def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
        """Recompute index tensors for mapping and copy them in place.

        The preallocated buffers are copied into rather than replaced, because
        every LoRA layer holds references to them (see _create_lora_modules).
        """
        (base_indices, sampler_indices, sampler_indices_padded,
         embeddings_indices, long_lora_offsets_tensor,
         indices_len) = convert_mapping(mapping, self.lora_index_to_id,
                                        self.lora_slots + 1, self.vocab_size,
                                        self.lora_config.lora_extra_vocab_size,
                                        self.long_lora_context)
        self.base_indices[:base_indices.shape[0]].copy_(base_indices)
        self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self.embeddings_indices[:embeddings_indices.
                                shape[0], :embeddings_indices.shape[1]].copy_(
                                    embeddings_indices)
        if long_lora_offsets_tensor is not None:
            self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
                long_lora_offsets_tensor)
        else:
            self.long_lora_indices.zero_()
        # Maintain the reference
        self.indices_len[:] = indices_len

    def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
        """Apply lora_mapping, skipping the work if it did not change."""
        if self._last_mapping != lora_mapping:
            self._set_lora_mapping(lora_mapping)
        self._last_mapping = lora_mapping

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_loras)

    def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
        return self._registered_loras.get(lora_id, None)

    def remove_all_loras(self):
        """Remove all LoRAModels from the manager."""
        self._registered_loras.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_loras.clear()
        # Drop stale long-LoRA offsets too; otherwise a LoRA re-added later
        # under the same id (without a scaling factor) would reuse them.
        if self.long_lora_context:
            self.long_lora_context.offsets_by_lora_id.clear()

    def _create_lora_modules(self):
        """Replace every targeted submodule with its LoRA-enabled wrapper."""
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if not self._match_target_modules(module_name):
                continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_moduled_lst, self.model.config))
            # LinearScalingRotaryEmbeddingWithLora is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # Hand the shared index buffers to the layer once; later mapping
            # updates copy into them in place.
            new_module.set_mapping(self.base_indices, self.sampler_indices,
                                   self.sampler_indices_padded,
                                   self.embeddings_indices,
                                   self.long_lora_indices, self.indices_len)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: Id for the dummy model.
            rank: lora rank of the dummy weights.
            scaling_factor: Long-context scaling factor, or None.
            embedding_modules: Module names that are embeddings. None is
                treated as "no embedding modules".
        """
        # None is the declared default; treat it as an empty mapping instead
        # of failing an assertion.
        if embedding_modules is None:
            embedding_modules = {}
        dummy_model = LoRAModel(lora_id, rank, {}, scaling_factor)
        for module_name, module in self.model.named_modules():
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or isinstance(module,
                                  LinearScalingRotaryEmbeddingWithLora)):
                continue
            last_part = module_name.split(".")[-1]
            if module_name not in self.packed_modules:
                if last_part in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked.shape[-1],
                        module.lora_b_stacked.shape[-2],
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                    )
                lora.optimize()
            else:
                replacements = self.packed_modules_mapping[last_part]
                subloras: List[Optional["LoRALayerWeights"]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            dummy_model.loras[module_name] = lora
        return dummy_model

    def _match_target_modules(self, module_name: str):
        """True if module_name equals, or ends in ".<name>" for, a target."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the sub-module names that make up a packed module."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack per-sub-module LoRAs of lora_model into packed modules."""
        for module_name, new_module_names in self.packed_modules.items():
            # Missing (falsy) sub-LoRAs are normalized to None.
            replacement_loras: List[Optional[LoRALayerWeights]] = [
                lora_model.get_lora(r) or None for r in new_module_names
            ]
            # Only pack when at least one sub-module has LoRA weights.
            if not any(replacement_loras):
                continue
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)


class LoRALRUCache(LRUCache[LoRAModel]):
    """LRU cache of LoRAModels keyed by integer LoRA id.

    Whenever the cache's ``_on_remove`` hook runs for an entry, the supplied
    ``deactivate_lora_fn`` is called with that entry's id before the
    base-class hook runs. (When ``_on_remove`` fires is decided by the base
    ``LRUCache`` — presumably on eviction/removal; confirm there.)
    """

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
        super().__init__(capacity)
        # Called with the removed LoRA's id; its return value is ignored.
        self.deactivate_lora_fn = deactivate_lora_fn

    def _on_remove(self, key: int, value: LoRAModel):
        logger.debug("Removing LoRA. int id: %d", key)
        self.deactivate_lora_fn(key)
        return super()._on_remove(key, value)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config)
        # Registered LoRAs, sized by self.capacity. Removing an entry from
        # this cache calls self.deactivate_lora with its id.
        self._registered_loras: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_lora)
        # Active LoRAs, sized by self.lora_slots. Removing an entry from this
        # cache calls self._deactivate_lora with its id.
        self._active_loras: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_lora)

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels.

        Returns a new dict (a shallow copy of the registered cache's
        contents), so callers may mutate it without affecting the cache.
        """
        return dict(self._registered_loras.cache)

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns:
            True if the LoRA was newly registered; False if a LoRA with the
            same id was already registered, in which case only its position
            in the LRU order is refreshed.
        """
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id not in self._registered_loras:
            self._add_lora(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_loras.touch(lora.id)
            was_added = False
        return was_added

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the oldest active LoRA if needed.

        If ``lora_id`` is not already active and every one of
        ``self.lora_slots`` is occupied, the oldest active LoRA is removed
        first. Activation itself is delegated to the base class, whose
        result is returned.
        """
        if lora_id not in self._active_loras and len(
                self._active_loras) >= self.lora_slots:
            self._active_loras.remove_oldest()
        result = super().activate_lora(lora_id)
        # We always touch to update the LRU cache order
        self._active_loras.touch(lora_id)
        return result

    def remove_oldest_lora(self) -> bool:
        """Remove the oldest registered LoRA.

        Returns:
            True if a LoRA was removed, False if none were registered.
        """
        if len(self._registered_loras) > 0:
            self._registered_loras.remove_oldest()
            return True
        return False

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        Pins ``lora_id`` in the registered-LoRA cache and then in the
        active-LoRA cache (activating it first if necessary).

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        # Re-raise the cache's ValueError with a message naming the LoRA id.
        try:
            self._registered_loras.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        if lora_id not in self._active_loras:
            # move lora to gpu if not already active
            self.activate_lora(lora_id)

        self._active_loras.pin(lora_id)

def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA manager for a given model.

    Args:
        model: Model to manage LoRAs for. It must expose a
            ``supported_lora_modules`` attribute.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: Forwarded to the manager constructor.
        lora_manager_cls: Manager class to instantiate.
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed ``lora_manager_cls`` instance.

    Raises:
        ValueError: If ``model`` has no ``supported_lora_modules`` attribute.
    """
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        **kwargs)
    return lora_manager