# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Parler-TTS model configuration"""

from transformers import AutoConfig, logging
from transformers.configuration_utils import PretrainedConfig


logger = logging.get_logger(__name__)

PARLER_TTS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/parler_tts-small": "https://huggingface.co/facebook/parler_tts-small/resolve/main/config.json",
    # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
}


class ParlerTTSDecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ParlerTTSDecoder`]. It is used to instantiate a
    Parler-TTS decoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Parler-TTS
    [facebook/parler_tts-small](https://huggingface.co/facebook/parler_tts-small) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 2049):
            Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
            represented by the `input_ids` passed when calling [`ParlerTTSDecoder`].
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer block.
        num_key_value_heads (`int`, *optional*):
            This is the number of key/value heads that should be used to implement Grouped Query Attention (GQA). If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by mean-pooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
            `num_attention_heads`.
        num_cross_attention_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers.
            If it is not specified, will default to `num_key_value_heads`.
        ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically, set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_factor (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values attentions (not used by all models).
        num_codebooks (`int`, *optional*, defaults to 4):
            The number of parallel codebooks forwarded to the model.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether input and output word embeddings should be tied.
        rope_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to use RoPE or absolute positional embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        cross_attention_implementation_strategy (`str`, *optional*):
            If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
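
    Example:

    ```python
    >>> from parler_tts import ParlerTTSDecoderConfig

    >>> # A minimal sketch with illustrative values (not those of a released checkpoint):
    >>> # grouped-query attention, with 16 query heads sharing 4 key/value heads, and RoPE.
    >>> configuration = ParlerTTSDecoderConfig(
    ...     num_attention_heads=16,
    ...     num_key_value_heads=4,
    ...     rope_embeddings=True,
    ... )

    >>> # Cross-attention key/value heads fall back to `num_key_value_heads` when unset
    >>> configuration.num_cross_attention_key_value_heads
    4
    ```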
    """

    model_type = "parler_tts_decoder"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos)
        max_position_embeddings=2048,
        num_hidden_layers=24,
        ffn_dim=4096,
        num_attention_heads=16,
        num_key_value_heads=None,
        num_cross_attention_key_value_heads=None,
        layerdrop=0.0,
        use_cache=True,
        activation_function="gelu",
        hidden_size=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        initializer_factor=0.02,
        scale_embedding=False,
        num_codebooks=4,
        pad_token_id=2048,
        bos_token_id=2049,
        eos_token_id=2048,
        tie_word_embeddings=False,
        rope_embeddings=False,
        rope_theta=10_000.0,
        cross_attention_implementation_strategy=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        if num_cross_attention_key_value_heads is None:
            num_cross_attention_key_value_heads = num_key_value_heads
        self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.initializer_factor = initializer_factor
        self.layerdrop = layerdrop
        self.use_cache = use_cache
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.num_codebooks = num_codebooks
        self.rope_embeddings = rope_embeddings
        self.rope_theta = rope_theta
        self.cross_attention_implementation_strategy = cross_attention_implementation_strategy

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class ParlerTTSConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ParlerTTSModel`]. It is used to instantiate a
    Parler-TTS model according to the specified arguments, defining the text encoder, audio encoder and Parler-TTS decoder
    configs.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 1024):
            Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
            represented by the `prompt_input_ids`.
        prompt_cross_attention (`bool`, *optional*, defaults to `False`):
            Whether to use cross-attention conditioning for the prompt (as well as the description).
        kwargs (*optional*):
            Dictionary of keyword arguments. Notably:

                - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
                  defines the text encoder config.
                - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
                  defines the audio encoder config.
                - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
                  the decoder config.

    Example:

    ```python
    >>> from transformers import T5Config, EncodecConfig
    >>> from parler_tts import (
    ...     ParlerTTSConfig,
    ...     ParlerTTSDecoderConfig,
    ...     ParlerTTSForConditionalGeneration,
    ... )

    >>> # Initializing text encoder, audio encoder, and decoder model configurations
    >>> text_encoder_config = T5Config()
    >>> audio_encoder_config = EncodecConfig()
    >>> decoder_config = ParlerTTSDecoderConfig()

    >>> configuration = ParlerTTSConfig.from_sub_models_config(
    ...     text_encoder_config, audio_encoder_config, decoder_config
    ... )

    >>> # Initializing a ParlerTTSForConditionalGeneration (with random weights) from the facebook/parler_tts-small style configuration
    >>> model = ParlerTTSForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    >>> config_text_encoder = model.config.text_encoder
    >>> config_audio_encoder = model.config.audio_encoder
    >>> config_decoder = model.config.decoder

    >>> # Saving the model, including its configuration
    >>> model.save_pretrained("parler_tts-model")

    >>> # loading model and config from pretrained folder
    >>> parler_tts_config = ParlerTTSConfig.from_pretrained("parler_tts-model")
    >>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler_tts-model", config=parler_tts_config)
    ```"""

    model_type = "parler_tts"
    is_composition = True

    def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
        super().__init__(**kwargs)
        if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
            raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")

        text_encoder_config = kwargs.pop("text_encoder")
        text_encoder_model_type = text_encoder_config.pop("model_type")

        audio_encoder_config = kwargs.pop("audio_encoder")
        audio_encoder_model_type = audio_encoder_config.pop("model_type")

        decoder_config = kwargs.pop("decoder")

        self.vocab_size = vocab_size
        self.prompt_cross_attention = prompt_cross_attention
        self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
        self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
        self.decoder = ParlerTTSDecoderConfig(**decoder_config)
        self.is_encoder_decoder = True

    @classmethod
    def from_sub_models_config(
        cls,
        text_encoder_config: PretrainedConfig,
        audio_encoder_config: PretrainedConfig,
        decoder_config: ParlerTTSDecoderConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`ParlerTTSConfig`] (or a derived class) from text encoder, audio encoder and decoder
        configurations.

        Returns:
            [`ParlerTTSConfig`]: An instance of a configuration object
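
        Example:

        ```python
        >>> from transformers import T5Config, EncodecConfig
        >>> from parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig

        >>> # A minimal sketch: default sub-configs stand in for real checkpoint values
        >>> config = ParlerTTSConfig.from_sub_models_config(
        ...     T5Config(), EncodecConfig(), ParlerTTSDecoderConfig()
        ... )
        >>> config.model_type
        'parler_tts'
        ```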
        """

        return cls(
            text_encoder=text_encoder_config.to_dict(),
            audio_encoder=audio_encoder_config.to_dict(),
            decoder=decoder_config.to_dict(),
            **kwargs,
        )

    @property
    # This is a property because you might want to change the codec model on the fly
    def sampling_rate(self):
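        """Sampling rate (in Hz) of the generated audio, read from the audio encoder sub-config."""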
        return self.audio_encoder.sampling_rate

    # Copied from musicgen
    @property
    def _attn_implementation(self):
        # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
        if hasattr(self, "_attn_implementation_internal"):
            if self._attn_implementation_internal is None:
                # `config.attn_implementation` should never be None, for backward compatibility.
                return "eager"
            else:
                return self._attn_implementation_internal
        else:
            return "eager"

    @_attn_implementation.setter
    def _attn_implementation(self, value):
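        # Keep the composite config and the decoder sub-config in sync, so the decoder
        # layers pick up the same attention backend (e.g. "eager" or "sdpa").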
        self._attn_implementation_internal = value
        self.decoder._attn_implementation = value