module_mapping.py 1.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
# Adapted from
#  https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py

from dataclasses import dataclass, field
8
from typing import Union
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49


@dataclass
class ModelKeys:
    """Optional string keys naming the individual components of a model.

    Every field defaults to ``None``, which means the key is not set.
    Fields can be passed positionally in declaration order or by keyword.

    NOTE(review): each value presumably identifies a submodule of the
    model (for example by attribute path or name pattern) -- confirm
    against the callers before relying on a specific format.
    """

    model_type: Union[str, None] = None

    module_list: Union[str, None] = None

    embedding: Union[str, None] = None

    mlp: Union[str, None] = None

    down_proj: Union[str, None] = None

    attention: Union[str, None] = None

    o_proj: Union[str, None] = None

    q_proj: Union[str, None] = None

    k_proj: Union[str, None] = None

    v_proj: Union[str, None] = None

    qkv_proj: Union[str, None] = None

    qk_proj: Union[str, None] = None

    qa_proj: Union[str, None] = None

    qb_proj: Union[str, None] = None

    kva_proj: Union[str, None] = None

    kvb_proj: Union[str, None] = None

    output: Union[str, None] = None


@dataclass
class MultiModelKeys(ModelKeys):
    """``ModelKeys`` extended with list-valued keys for multi-part models.

    Each of the four list fields defaults to its own fresh empty list, so
    instances never share a default list.
    """

    language_model: list[str] = field(default_factory=list)
    connector: list[str] = field(default_factory=list)
    # vision tower and audio tower
    tower_model: list[str] = field(default_factory=list)
    generator: list[str] = field(default_factory=list)

    @staticmethod
    def from_string_field(
            language_model: Union[str, list[str], None] = None,
            connector: Union[str, list[str], None] = None,
            tower_model: Union[str, list[str], None] = None,
            generator: Union[str, list[str], None] = None,
            **kwargs) -> 'MultiModelKeys':
        """Build a ``MultiModelKeys`` from string-or-list arguments.

        Args:
            language_model: A single string, an iterable of strings, or
                ``None``.
            connector: Same accepted forms as ``language_model``.
            tower_model: Same accepted forms as ``language_model``.
            generator: Same accepted forms as ``language_model``.
            **kwargs: Forwarded unchanged to the ``MultiModelKeys``
                constructor (e.g. the inherited ``ModelKeys`` fields).

        Returns:
            A new ``MultiModelKeys`` whose four list fields hold the
            normalized values.
        """

        def to_list(value):
            # None -> empty list; a lone string -> one-element list (a
            # string must not be split into characters); anything else is
            # copied into a new list so the caller's object is not shared.
            if value is None:
                return []
            return [value] if isinstance(value, str) else list(value)

        return MultiModelKeys(language_model=to_list(language_model),
                              connector=to_list(connector),
                              tower_model=to_list(tower_model),
                              generator=to_list(generator),
                              **kwargs)