import torch

from transformers import PreTrainedModel
from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput
from .configuration_dac import DACConfig

from dac.model import DAC


# model doesn't support batching yet


class DACModel(PreTrainedModel):
    config_class = DACConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = DAC(
            n_codebooks=config.num_codebooks,
            latent_dim=config.latent_dim,
            codebook_size=config.codebook_size,
        )

    def encode(
        self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
    ):
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
            bandwidth (`float`, *optional*):
                Not used; kept to match the interface of HF EnCodec.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use. If `None`, all quantizers are used.
            sample_rate (`int`, *optional*):
                Sampling rate of the input audio waveform.

        Returns:
            A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
            factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
            `codebook` of shape `[batch_size, num_codebooks, frames]`.
            `scale` is not used here.
        """
        _, channels, input_length = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")

        audio_data = self.model.preprocess(input_values, sample_rate)

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: for now, no chunk length

        chunk_length = None  # self.config.chunk_length
        if chunk_length is None:
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.config.chunk_stride

        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()

        encoded_frames = []
        scales = []

        # With chunk_length == stride == input_length (no chunking), step is 0 and the
        # loop below runs exactly once over the whole input.
        step = chunk_length - stride
        if (input_length % stride) - step != 0:
            raise ValueError(
                "The input length is not properly padded for batched chunked encoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[..., offset : offset + chunk_length].bool()  # currently unused
            frame = audio_data[:, :, offset : offset + chunk_length]
            scale = None
            _, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = torch.stack(encoded_frames)

        if not return_dict:
            return (encoded_frames, scales)

        return EncodecEncoderOutput(encoded_frames, scales)
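
    # A minimal usage sketch for `encode` (illustrative only, not part of the original
    # module; assumes a mono waveform at the model's sampling rate):
    #
    #     waveform = torch.randn(1, 1, 44100)  # (batch, channels, samples)
    #     enc = model.encode(waveform, sample_rate=44100)
    #     enc.audio_codes  # (num_chunks, batch_size, num_codebooks, codes_length)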

    def decode(
        self,
        audio_codes,
        audio_scales,
        padding_mask=None,
        return_dict=None,
    ):
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
                Not used; kept to match the interface of HF EnCodec.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
                Not used yet; kept to match the interface of HF EnCodec.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: for now, no chunk length

        if len(audio_codes) != 1:
            raise ValueError(f"Expected one frame, got {len(audio_codes)}")

        # `quantizer.from_codes` reconstructs the continuous latents from the discrete codes
        audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
        audio_values = self.model.decode(audio_values)
        if not return_dict:
            return (audio_values,)
        return EncodecDecoderOutput(audio_values)
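
    # A minimal round-trip sketch for `decode` (illustrative only; assumes `enc` from
    # the `encode` sketch above):
    #
    #     dec = model.decode(enc.audio_codes, enc.audio_scales)
    #     dec.audio_values  # reconstructed waveform; may be slightly longer than the input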

    def forward(self, tensor):
        raise NotImplementedError("`DACModel.forward` is not implemented yet")
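

# A minimal end-to-end sketch (assumed usage, not part of the original module; it
# requires the `descript-audio-codec` package, assumes `DACConfig` provides usable
# defaults, and must be run as a module, e.g. `python -m <package>.modeling_dac`,
# because of the relative import above):
if __name__ == "__main__":
    config = DACConfig()
    model = DACModel(config).eval()

    waveform = torch.randn(1, 1, 44100)  # one second of mono audio at 44.1 kHz
    with torch.no_grad():
        enc = model.encode(waveform, sample_rate=44100)
        dec = model.decode(enc.audio_codes, enc.audio_scales)

    # The decoded output may be slightly longer than the input; trim the extra steps.
    audio = dec.audio_values[..., : waveform.shape[-1]]
    print(enc.audio_codes.shape, audio.shape)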