".buildkite/vscode:/vscode.git/clone" did not exist on "8e61425ee6d0bd03d3669c148eba8b263d101273"
lora.py 4.83 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union

import torch
8
from pydantic import ConfigDict, Field, model_validator
9
from pydantic.dataclasses import dataclass
10
from typing_extensions import Self
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform

# The config classes are imported only for static type checking. At runtime
# the same names are bound to `Any`, so annotations that mention them (see the
# `verify_with_*` methods below) still resolve without importing those modules.
# NOTE(review): presumably this avoids an import cycle with `vllm.config` --
# confirm before turning these into unconditional imports.
if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.config.cache import CacheConfig
else:
    ModelConfig = Any
    CacheConfig = Any

# Logger for this module, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

# String spellings accepted for `LoRAConfig.lora_dtype` (a `torch.dtype`
# instance is accepted there as well).
LoRADType = Literal["auto", "float16", "bfloat16"]
27
28
# Closed set of values accepted for `LoRAConfig.max_lora_rank`.
MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
# Closed set of values accepted for the deprecated
# `LoRAConfig.lora_extra_vocab_size` field.
LoRAExtraVocabSize = Literal[256, 512]
29
30
31
32
33
34
35


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class LoRAConfig:
    """Configuration for LoRA."""

    max_lora_rank: MaxLoRARanks = 16
    """Max LoRA rank."""
    max_loras: int = Field(default=1, ge=1)
    """Max number of LoRAs in a single batch."""
    fully_sharded_loras: bool = False
    """By default, only half of the LoRA computation is sharded with tensor
    parallelism. Enabling this will use the fully sharded layers. At high
    sequence length, max rank or tensor parallel size, this is likely faster.
    """
    max_cpu_loras: Optional[int] = None
    """Maximum number of LoRAs to store in CPU memory. Must be >=
    `max_loras`. Defaults to `max_loras` when left unset."""
    lora_dtype: Union[torch.dtype, LoRADType] = "auto"
    """Data type for LoRA. If auto, will default to base model dtype."""
    lora_extra_vocab_size: LoRAExtraVocabSize = Field(
        default=256,
        deprecated=(
            "`lora_extra_vocab_size` is deprecated and will be removed "
            "in v0.12.0. Additional vocabulary support for "
            "LoRA adapters is being phased out."
        ),
    )
    """(Deprecated) Maximum size of extra vocabulary that can be present in a
    LoRA adapter. Will be removed in v0.12.0."""
    lora_vocab_padding_size: ClassVar[int] = (
        current_platform.get_lora_vocab_padding_size()
    )
    default_mm_loras: Optional[dict[str, str]] = None
    """Dictionary mapping specific modalities to LoRA model paths; this field
    is only applicable to multimodal models and should be leveraged when a
    model always expects a LoRA to be active when a given modality is present.
    Note that currently, if a request provides multiple additional
    modalities, each of which have their own LoRA, we do NOT apply
    default_mm_loras because we currently only support one lora adapter
    per prompt. When run in offline mode, the lora IDs for n modalities
    will be automatically assigned to 1-n with the names of the modalities
    in alphabetic order."""
    bias_enabled: bool = Field(
        default=False,
        deprecated="`bias_enabled` is deprecated and will be removed in v0.12.0.",
    )
    """[DEPRECATED] Enable bias for LoRA adapters. This option will be
    removed in v0.12.0."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            The hex MD5 digest of `str(factors)`.
        """
        # The digest is taken over `str(factors)`, so both the membership and
        # the ORDER of this list determine the hash; append new factors at
        # the end rather than reordering existing ones.
        factors: list[Any] = [
            self.max_lora_rank,
            self.max_loras,
            self.fully_sharded_loras,
            self.lora_dtype,
            self.lora_extra_vocab_size,
            self.lora_vocab_padding_size,
            self.bias_enabled,
        ]
        # MD5 is used as a fingerprint only, hence `usedforsecurity=False`.
        return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()

    @model_validator(mode="after")
    def _validate_lora_config(self) -> Self:
        """Fill in and check `max_cpu_loras` after field validation.

        When `max_cpu_loras` is unset it is defaulted to `max_loras`.

        Returns:
            This config instance.

        Raises:
            ValueError: If `max_cpu_loras` is set to less than `max_loras`.
        """
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})"
            )

        return self

    def verify_with_cache_config(self, cache_config: CacheConfig) -> None:
        """Check this config against the cache configuration.

        Args:
            cache_config: Cache configuration; only `cpu_offload_gb` is read.

        Raises:
            ValueError: If `cpu_offload_gb` is positive while
                `envs.VLLM_USE_V1` is falsy.
        """
        if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1:
            raise ValueError("V0 LoRA does not support CPU offload, please use V1.")

    def verify_with_model_config(self, model_config: ModelConfig) -> None:
        """Resolve `lora_dtype` in place using the model configuration.

        `None` or `"auto"` is replaced by `model_config.dtype`; any other
        string is replaced by the `torch` attribute of that name (e.g.
        `"float16"` becomes `torch.float16`). Any non-string value is left
        unchanged.

        Args:
            model_config: Model configuration; only `dtype` is read.
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)