# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
#  https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py

from dataclasses import dataclass, field


@dataclass
11
class MultiModelKeys:
12
13
    language_model: list[str] = field(default_factory=list)
    connector: list[str] = field(default_factory=list)
14
    # vision tower and audio tower
15
16
    tower_model: list[str] = field(default_factory=list)
    generator: list[str] = field(default_factory=list)
17
18

    @staticmethod
19
    def from_string_field(
20
21
22
23
        language_model: str | list[str] = None,
        connector: str | list[str] = None,
        tower_model: str | list[str] = None,
        generator: str | list[str] = None,
24
25
        **kwargs,
    ) -> "MultiModelKeys":
26
27
28
29
30
        def to_list(value):
            if value is None:
                return []
            return [value] if isinstance(value, str) else list(value)

31
32
33
34
35
36
37
        return MultiModelKeys(
            language_model=to_list(language_model),
            connector=to_list(connector),
            tower_model=to_list(tower_model),
            generator=to_list(generator),
            **kwargs,
        )