import torch
from dac.model import DAC
from transformers import PreTrainedModel
from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput

from .configuration_dac import DACConfig


# model doesn't support batching yet


class DACModel(PreTrainedModel):
    config_class = DACConfig

    def __init__(self, config):
        super().__init__(config)

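        # Map the HF-style `DACConfig` fields onto the upstream descript-audio-codec constructor.
        # As an illustration (these values are assumptions, not taken from this file), a config
        # could look like: DACConfig(num_codebooks=9, latent_dim=1024, codebook_size=1024).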
        self.model = DAC(
            n_codebooks=config.num_codebooks,
            latent_dim=config.latent_dim,
            codebook_size=config.codebook_size,
        )

    def encode(
        self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
    ):
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
            bandwidth (`float`, *optional*):
                Not used, kept to have the same interface as HF encodec.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use. If `None`, all quantizers are used.
            sample_rate (`int`, *optional*):
                Sampling rate of the input audio signal, in Hz.

        Returns:
            A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
            factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
            `codebook` of shape `[batch_size, num_codebooks, frames]`.
            Scale is not used here.

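        Example (a minimal usage sketch; the `DACConfig()` defaults and the 44100 Hz sample rate are
        illustrative assumptions, not values taken from this file):

            >>> import torch
            >>> config = DACConfig()
            >>> model = DACModel(config)
            >>> audio = torch.randn(1, 1, 44100)  # (batch_size, channels, sequence_length), mono
            >>> encoder_outputs = model.encode(audio, sample_rate=44100)
            >>> audio_codes = encoder_outputs.audio_codes  # shape (1, batch_size, num_codebooks, frames)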
        """
        _, channels, input_length = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")

        audio_data = self.model.preprocess(input_values, sample_rate)

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: for now, no chunk length

        chunk_length = None  # self.config.chunk_length
        if chunk_length is None:
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.config.chunk_stride

        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()

        encoded_frames = []
        scales = []

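        # With chunking disabled above, chunk_length == stride == input_length, so step == 0,
        # the padding check below always passes, and the loop encodes the full input in a single pass.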
        step = chunk_length - stride
        if (input_length % stride) - step != 0:
            raise ValueError(
                "The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[..., offset : offset + chunk_length].bool()
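            # NOTE: `mask` is computed but not yet passed to `self.model.encode` below.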
            frame = audio_data[:, :, offset : offset + chunk_length]

            scale = None

            _, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = torch.stack(encoded_frames)

        if not return_dict:
            return (encoded_frames, scales)

        return EncodecEncoderOutput(encoded_frames, scales)

    def decode(
        self,
        audio_codes,
        audio_scales,
        padding_mask=None,
        return_dict=None,
    ):
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.FloatTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
                Not used, kept to have the same interface as HF encodec.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
                Not used yet, kept to have the same interface as HF encodec.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

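        Example (a minimal usage sketch; assumes `model` and `encoder_outputs` come from the `encode`
        example above):

            >>> decoder_outputs = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales)
            >>> audio_values = decoder_outputs.audio_values  # reconstructed waveform, possibly slightly longer than the input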
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: for now, no chunk length

        if len(audio_codes) != 1:
            raise ValueError(f"Expected one frame, got {len(audio_codes)}")

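        # `quantizer.from_codes` maps the discrete codes back to continuous latents,
        # and `decode` reconstructs the waveform from those latents.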
        audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
        audio_values = self.model.decode(audio_values)
        if not return_dict:
            return (audio_values,)
        return EncodecDecoderOutput(audio_values)

    def forward(self, tensor):
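        # A combined encode/decode forward pass is not supported yet; call `encode` and `decode` directly.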
        raise ValueError("`DACModel.forward` not implemented yet")