# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py
from typing import Any
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

import transformers


class UltravoxConfig(transformers.PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`UltravoxForConditionalGeneration`]. It is used to instantiate an
    Ultravox model according to the specified arguments, defining the model
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Args:
        audio_config (`dict[str, Any]`, *optional*):
            Keyword arguments used to build the audio backbone config. Only
            consulted when `audio_model_id` is `None`; the config class is
            selected by its `"model_type"` entry, defaulting to `"whisper"`.
        text_config (`dict[str, Any]`, *optional*):
            Keyword arguments used to build the text backbone config. Only
            consulted when `text_model_id` is `None`; the config class is
            selected by its `"model_type"` entry, defaulting to `"llama"`.
        audio_model_id (`str`, *optional*):
            The model ID of the audio backbone. When given, the audio config
            is loaded from this ID and `audio_config` is ignored.
        text_model_id (`str`, *optional*):
            The model ID of the text backbone. When given, the wrapped model
            config is loaded from this ID and `text_config` is ignored.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        audio_token_index (`int`, *optional*, defaults to 32000):
            The audio token index to encode the audio prompt.
        hidden_size (`int`, *optional*, defaults to 4096):
            Hidden size stored on the config. Presumably the hidden dimension
            of the multimodal projector -- confirm against the model code.
        stack_factor (`int`, *optional*, defaults to 8):
            Audio downsampling factor for the multimodal projector.
        norm_init (`float`, *optional*, defaults to 0.4):
            The initialization value for the layer normalization.
        projector_act (`str`, *optional*, defaults to `"swiglu"`):
            The activation function used by the multimodal projector.
        projector_ln_mid (`bool`, *optional*, defaults to `False`):
            Whether to apply layer normalization at the middle of the
            projector or at the end. Versions v0.4.1 and below
            use `False`, but v0.5 and above use `True`.
        num_projector_layers (`int`, *optional*, defaults to 0):
            Value stored on the config for the projector depth. Its exact
            interpretation is defined by the model implementation, not here.
    """

    # Config of the wrapped language model. Populated either in ``__init__``
    # (from ``text_config``) or whenever ``text_model_id`` is assigned.
    wrapped_model_config: transformers.PretrainedConfig
    model_type = "ultravox"
    audio_token = "<|audio|>"
    is_composition = False

    def __init__(
        self,
        audio_config: dict[str, Any] | None = None,
        text_config: dict[str, Any] | None = None,
        audio_model_id: str | None = None,
        text_model_id: str | None = None,
        ignore_index: int = -100,
        audio_token_index: int = 32000,
        hidden_size: int = 4096,
        stack_factor: int = 8,
        norm_init: float = 0.4,
        projector_act: str = "swiglu",
        projector_ln_mid: bool = False,
        num_projector_layers: int = 0,
        **kwargs,
    ):
        """Store the projector settings and resolve the text/audio configs.

        See the class docstring for the meaning of each argument. Remaining
        ``kwargs`` are forwarded to ``PretrainedConfig.__init__``.
        """
        self.ignore_index = ignore_index
        self.audio_token_index = audio_token_index

        self.hidden_size = hidden_size
        self.stack_factor = stack_factor
        self.norm_init = norm_init
        self.projector_act = projector_act
        self.projector_ln_mid = projector_ln_mid
        self.num_projector_layers = num_projector_layers

        # N.B. Assigning a non-None ID goes through ``__setattr__`` below,
        # which loads ``wrapped_model_config`` as a side effect.
        self.text_model_id = text_model_id
        if text_model_id is None:
            self.wrapped_model_config = self._build_sub_config(
                text_config, "llama"
            )

        # N.B. Same mechanism as above, but for ``audio_config``.
        self.audio_model_id = audio_model_id
        if audio_model_id is None:
            self.audio_config = self._build_sub_config(audio_config, "whisper")

        super().__init__(**kwargs)

    @staticmethod
    def _build_sub_config(
        config: dict[str, Any] | None,
        default_model_type: str,
    ) -> transformers.PretrainedConfig:
        """Instantiate a backbone config from a (possibly missing) dict.

        The config class is looked up in ``transformers.CONFIG_MAPPING`` by
        the dict's ``"model_type"`` entry, falling back to
        ``default_model_type``; the whole dict is passed as keyword arguments.
        """
        config = config or {}
        config_cls = transformers.CONFIG_MAPPING[
            config.get("model_type", default_model_type)
        ]
        return config_cls(**config)

    def __setattr__(self, key: str, value: Any):
        # Since --hf-overrides are applied _after_ the UltravoxConfig is
        # instantiated, load the configs implicitly when assigning text_model_id
        # or audio_model_id. This allows:
        #
        #   --hf-overrides.text_model_id=<quantized variant>
        #
        # to behave as intended.
        if value is not None and key in ("text_model_id", "audio_model_id"):
            # Imported lazily, inside the method, as in the original code.
            from vllm.transformers_utils.config import get_config

            loaded = get_config(value, trust_remote_code=False)
            # The sub-config is stored before the ID attribute itself is set.
            if key == "text_model_id":
                self.wrapped_model_config = loaded
            else:
                self.audio_config = loaded

        return super().__setattr__(key, value)

    @property
    def text_config(self) -> transformers.PretrainedConfig:
        """Return the text config of the wrapped model."""
        # When Ultravox wraps a multi-modal model (e.g. Gemma), we instantiate
        # the full model, but the text config is the text config of the inner
        # model.
        return self.wrapped_model_config.get_text_config()