Commit 9bde9933 authored by Yoach Lacombe

add DAC

parent 71d87fd3
@@ -90,4 +90,5 @@ Convert sequence of discrete labels to text description (using an LLM).
## Step 5: Train the Model
Train a MusicGen-style model on the TTS task. This needs DAC (the Descript Audio Codec).
import dac
from transformers import AutoConfig, AutoModel, EncodecFeatureExtractor
from stable_speech import DACConfig, DACModel

# Download the original 44 kHz DAC checkpoint and load it
model_path = dac.utils.download(model_type="44khz")
model = dac.DAC.load(model_path)

# Wrap it in the transformers-compatible DACModel and copy the weights over
hf_dac = DACModel(DACConfig())
hf_dac.model.load_state_dict(model.state_dict())

# Register the wrapper with the Auto classes and push it to the Hub
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
hf_dac.push_to_hub("ylacombe/dac_44khZ_8kbps")
# DACConfig.register_for_auto_class()
# DACModel.register_for_auto_class("AutoModel")

# Push a matching feature extractor (DAC reuses Encodec's preprocessing)
EncodecFeatureExtractor(sampling_rate=44100).push_to_hub("ylacombe/dac_44khZ_8kbps")
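
As a quick sanity check (a sketch; it assumes the checkpoint above has been pushed and the `dac` package is installed), the wrapped codec can be loaded back through the Auto classes registered above and used to encode a waveform:

import torch
from transformers import AutoModel, EncodecFeatureExtractor

feature_extractor = EncodecFeatureExtractor.from_pretrained("ylacombe/dac_44khZ_8kbps")
codec = AutoModel.from_pretrained("ylacombe/dac_44khZ_8kbps")

# Encode one second of dummy mono audio into discrete codes
waveform = torch.randn(44100)
inputs = feature_extractor(raw_audio=waveform, sampling_rate=44100, return_tensors="pt")
codes = codec.encode(inputs["input_values"], sample_rate=44100).audio_codes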
@@ -60,6 +60,12 @@ from transformers.optimization import get_scheduler
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from transformers.integrations import is_wandb_available
from transformers import AutoConfig, AutoModel
from stable_speech import DACConfig, DACModel
# Make the DAC wrapper discoverable through the Auto classes
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
from accelerate import Accelerator
from accelerate.utils import set_seed
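# A sketch (not part of the original script) of how the registered codec can
# later be loaded from the Hub, assuming the checkpoint pushed in step 5:
#
#     audio_encoder = AutoModel.from_pretrained("ylacombe/dac_44khZ_8kbps")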
......
from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
from .dac_wrapper import DACConfig, DACModel
from .configuration_dac import DACConfig
from .modeling_dac import DACModel
from transformers import PretrainedConfig
class DACConfig(PretrainedConfig):
model_type = "dac"
def __init__(
self,
num_codebooks: int = 9,
model_bitrate: int = 8, # kbps
codebook_size: int = 1024,
latent_dim: int = 1024,
frame_rate: int = 86,
**kwargs,
):
self.codebook_size = codebook_size
self.model_bitrate = model_bitrate
self.latent_dim = latent_dim
self.num_codebooks = num_codebooks
self.frame_rate = frame_rate
super().__init__(**kwargs)
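
# Example (a sketch, not part of the original file): the defaults above mirror
# the 44 kHz / 8 kbps DAC checkpoint, i.e. 9 codebooks of size 1024 at ~86 frames/s:
#
#     config = DACConfig()
#     config.save_pretrained("dac_config")  # serialized with model_type "dac"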
import torch
from transformers import PreTrainedModel
from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput
from .configuration_dac import DACConfig
from dac.model import DAC
# Note: the model doesn't support batching yet
class DACModel(PreTrainedModel):
config_class = DACConfig
def __init__(self, config):
super().__init__(config)
self.model = DAC(
n_codebooks = config.num_codebooks,
latent_dim = config.latent_dim,
codebook_size = config.codebook_size,
)
def encode(self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None):
"""
Encodes the input audio waveform into discrete codes.
Args:
input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
Float values of the input audio waveform.
padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
Padding mask used to pad the `input_values`.
bandwidth (`float`, *optional*):
Not used, kept to have the same interface as HF Encodec.
n_quantizers (`int`, *optional*):
Number of quantizers to use. If `None`, all quantizers are used.
sample_rate (`int`, *optional*):
Sampling rate of the input audio waveform.
Returns:
A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
factors for each chunk. Each frame is a tuple `(codebook, scale)`, with `codebook` of shape
`[batch_size, num_codebooks, frames]`. `scale` is not used here and is always `None`.
"""
_, channels, input_length = input_values.shape
if channels < 1 or channels > 2:
raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")
audio_data = self.model.preprocess(input_values, sample_rate)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# TODO: for now, no chunk length
chunk_length = None # self.config.chunk_length
if chunk_length is None:
chunk_length = input_length
stride = input_length
else:
stride = self.config.chunk_stride
if padding_mask is None:
padding_mask = torch.ones_like(input_values).bool()
encoded_frames = []
scales = []
step = chunk_length - stride
if (input_length % stride) - step != 0:
raise ValueError(
"The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly."
)
for offset in range(0, input_length - step, stride):
# mask is computed for parity with HF Encodec but is not used yet
mask = padding_mask[..., offset : offset + chunk_length].bool()
frame = audio_data[:, :, offset : offset + chunk_length]
scale = None
_, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
encoded_frames.append(encoded_frame)
scales.append(scale)
encoded_frames = torch.stack(encoded_frames)
if not return_dict:
return (encoded_frames, scales)
return EncodecEncoderOutput(encoded_frames, scales)
def decode(
self,
audio_codes,
audio_scales,
padding_mask = None,
return_dict = None,
):
"""
Decodes the given frames into an output audio waveform.
Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
trimmed.
Args:
audio_codes (`torch.FloatTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
Discrete code embeddings computed using `model.encode`.
audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
Not used, kept to have the same interface as HF Encodec.
padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
Padding mask used to pad the `input_values`. Not used yet, kept to have the same interface as HF Encodec.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
# TODO: for now, no chunk length
if len(audio_codes) != 1:
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
audio_values = self.model.decode(audio_values)
if not return_dict:
return (audio_values,)
return EncodecDecoderOutput(audio_values)
def forward(self, tensor):
raise NotImplementedError("`DACModel.forward` is not implemented yet; use `encode` and `decode` instead.")
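
# Example usage (a sketch, not part of the original file): a mono round trip at
# 44.1 kHz, assuming `hf_dac` is a DACModel with weights loaded as in step 5:
#
#     audio = torch.randn(1, 1, 44100)                      # (batch, channels, samples)
#     enc = hf_dac.encode(audio, sample_rate=44100)         # codes: (1, batch, num_codebooks, frames)
#     dec = hf_dac.decode(enc.audio_codes, enc.audio_scales)
#     reconstructed = dec.audio_values                      # (batch, channels, samples), length padded up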