utils.py 8.26 KB
Newer Older
1
import copy
2
import os
3
4
import re
from typing import List, Optional, Set, Tuple, Type, Union
5

6
7
8
import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
                                   HFValidationError, RepositoryNotFoundError)
9
from torch import nn
10
from transformers import PretrainedConfig
11

12
from vllm.config import LoRAConfig
13
from vllm.logger import init_logger
14
15
16
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
17
18
    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA)
19
20
21
22
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
23
                              LinearScalingRotaryEmbeddingWithLora,
24
25
26
27
                              LogitsProcessorWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
28
                              ReplicatedLinearWithLoRA,
29
30
31
32
33
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
34
35
from vllm.model_executor.models.utils import WeightsMapper
from vllm.utils import print_warning_once
36
37

logger = init_logger(__name__)

# Every LoRA wrapper class that `from_layer` may substitute for a base layer.
# `from_layer` walks this collection and uses the first class whose
# `can_replace_layer` accepts the layer. Because this is a set of classes
# (hashed by identity), its iteration order is not guaranteed, so presumably
# at most one class accepts any given layer.
# NOTE(review): confirm the `can_replace_layer` checks are mutually exclusive.
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLora,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA,
    LinearScalingRotaryEmbeddingWithLora,
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Return ``layer`` wrapped in a LoRA-enabled class, if one applies.

    Candidates in ``_all_lora_classes`` are asked, one at a time, whether
    they can replace ``layer``; the first that answers yes is instantiated
    around ``layer`` and has its LoRA weights allocated for ``max_loras``
    adapters. If no candidate accepts, ``layer`` is returned unchanged.
    """
    # Keyword arguments are used so that decorators wrapping
    # `can_replace_layer` can look the values up by name.
    chosen_cls = next(
        (candidate for candidate in _all_lora_classes
         if candidate.can_replace_layer(
             source_layer=layer,
             lora_config=lora_config,
             packed_modules_list=packed_modules_list,
             model_config=model_config)),
        None,
    )
    if chosen_cls is None:
        return layer

    wrapped = chosen_cls(layer)
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor with LoRA support tied to ``lm_head``.

    The wrapper is configured from ``lm_head``'s embedding dimension, the
    dtype and device of its weight, and its
    ``get_sharded_to_full_mapping()`` result. It then allocates LoRA weights
    for ``max_loras`` adapters.

    Args:
        layer: the base logits processor to wrap.
        lm_head: the LM head whose properties configure the wrapper.
        max_loras: number of LoRA slots to allocate.
        lora_config: LoRA configuration passed to weight creation.
        model_config: optional model config passed to weight creation.

    Returns:
        The newly created ``LogitsProcessorWithLoRA``.
    """
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device,
                                  lm_head.get_sharded_to_full_mapping())
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret


def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Swap the submodule at dotted path ``module_name`` for ``new_module``.

    The path is split at its last dot. The part before the dot is resolved
    with ``model.get_submodule``; with no dot, that part is the empty string.
    The final component is then set as an attribute on the resolved parent.

    Returns:
        ``new_module`` itself.
    """
    parent_path, _, attr_name = module_name.rpartition(".")
    owner = model.get_submodule(parent_path)
    setattr(owner, attr_name, new_module)
    return new_module


def parse_fine_tuned_lora_name(
        name: str,
        weights_mapper: Optional["WeightsMapper"] = None
) -> Tuple[str, bool, bool]:
    """Parse the name of a LoRA weight tensor into its module name and kind.

    The first two dot-separated components (e.g. ``base_model.model``) and
    the trailing LoRA-specific components are stripped. What remains is the
    name of the module the tensor belongs to.

    Args:
        name: the name of the fine-tuned LoRA tensor, e.g.
            ``base_model.model.model.dense1.lora_A.weight``.
        weights_mapper: optional mapper applied to the parsed module name,
            e.g. ``model.`` -> ``language_model.model.``. Substring and
            suffix mappings are not supported; they are cleared (with a
            warning) on a private copy of the mapper before use.

    Returns:
        A ``(module_name, is_lora_a, is_bias)`` tuple:
            module_name: the (possibly mapped) module name, e.g.
                ``model.dense1``.
            is_lora_a: True for a lora_A / lora_embedding_A tensor, False
                for lora_B / lora_embedding_B tensors and for biases.
            is_bias: whether the tensor is a LoRA bias.

    Raises:
        ValueError: if ``name`` is not a recognised LoRA weight name.
    """
    w_mapper = None
    if weights_mapper:
        # Work on a copy so clearing unsupported mappings below does not
        # mutate the caller's mapper.
        w_mapper = copy.deepcopy(weights_mapper)
        # TODO: Currently only supports mapping for prefix; mapping for
        # substr and suffix will be supported in the future.
        for attr, mapping in [
            ("orig_to_new_substr", w_mapper.orig_to_new_substr),
            ("orig_to_new_suffix", w_mapper.orig_to_new_suffix),
        ]:
            if mapping:
                print_warning_once(
                    f"vLLM currently does not support mapping of LoRA weights "
                    f"for {mapping}.")
                setattr(w_mapper, attr, {})

    def mapper(module_name: str) -> str:
        # Identity when no mapper was supplied.
        if w_mapper is None:
            return module_name
        return w_mapper._map_name(module_name)

    parts = name.split(".")
    last = parts[-1]

    # The length check guards `parts[-2]`, so a single-component name such
    # as "weight" raises the documented ValueError rather than IndexError.
    if (last == "weight" and len(parts) >= 2
            and parts[-2] in ("lora_A", "lora_B")):
        new_name = ".".join(parts[2:-2])
        return mapper(new_name), parts[-2] == "lora_A", False

    if last in ("lora_embedding_A", "lora_embedding_B"):
        new_name = ".".join(parts[2:-1])
        return mapper(new_name), last == "lora_embedding_A", False

    if last == "bias":
        new_name = ".".join(parts[2:-2])
        return mapper(new_name), False, True

    raise ValueError(f"{name} is unsupported LoRA weight")


def is_regex_target_modules(load_modules: Union[str, List[str]],
                            expected_lora_modules: List[str]) -> bool:
    """Check whether a PEFT regex `target_modules` names only known modules.

    PEFT accepts `target_modules` as a regular expression such as
    `model.*(q_proj|k_proj|v_proj)$`. This returns True only when
    `load_modules` is a string that compiles as a regex and ends in a
    parenthesised group (optionally followed by `$`), and every
    `|`-separated alternative in that group appears in
    `expected_lora_modules`. In every other case it returns False.
    """
    # Mirror PEFT: only a plain string is ever treated as a regex.
    if not isinstance(load_modules, str):
        return False

    try:
        re.compile(load_modules)
    except re.error:
        return False

    group_match = re.search(r"\((.*?)\)\$?$", load_modules)
    if group_match is None:
        return False

    alternatives = set(group_match.group(1).split("|"))
    return alternatives <= set(expected_lora_modules)


def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolve ``lora_path`` to a local path for the LoRA adapter.

    Resolution order:
      1. An absolute path is returned as-is, whether or not it exists.
      2. A path starting with ``~`` is passed through
         ``os.path.expanduser`` (existence is not checked).
      3. A relative path that exists locally is made absolute.
      4. Anything else is treated as a Hugging Face repo id and downloaded
         with ``huggingface_hub.snapshot_download``. On a download error,
         the error is logged and the original ``lora_path`` is returned.

    Parameters:
    lora_path (str): An absolute path, a relative path, or a Hugging Face
                     model identifier.

    Returns:
    str: The resolved local path. This is ``lora_path`` unchanged if
         downloading fails.
    """
    if os.path.isabs(lora_path):
        return lora_path

    # NOTE: expanduser returns the path unchanged if expansion fails, so this
    # result is not guaranteed to be absolute.
    if lora_path.startswith('~'):
        return os.path.expanduser(lora_path)

    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # Not found locally: fall back to treating it as a Hugging Face repo id.
    try:
        return huggingface_hub.snapshot_download(repo_id=lora_path)
    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
            HFValidationError):
        # Deliberately best-effort: log and hand back the original path
        # instead of raising here.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path