Unverified Commit 1453841b authored by Chia-Yu Hung, committed by GitHub

Merge pull request #6 from fakerybakery/main

pip package
parents 7090a624 8fd3cc82
__pycache__/
*.py[cod]
*$py.class
@@ -168,3 +168,8 @@ cython_debug/
# PyPI configuration file
.pypirc
.DS_Store
*.wav
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/declare-lab/TangoFlux.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import IPython\n",
"import torchaudio\n",
"from tangoflux import TangoFluxInference\n",
"from IPython.display import Audio\n",
"\n",
"model = TangoFluxInference(name='declare-lab/TangoFlux')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# @title Generate Audio\n",
"\n",
"prompt = 'Hammer slowly hitting the wooden table' # @param {type:\"string\"}\n",
"duration = 10 # @param {type:\"number\"}\n",
"steps = 50 # @param {type:\"number\"}\n",
"\n",
"audio = model.generate(prompt, steps=steps, duration=duration)\n",
"\n",
"Audio(data=audio, rate=44100)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization
<div align="center">
<img src="assets/tf_teaser.png" alt="TangoFlux" width="1000" />
<br/>
[![arXiv](https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf)](https://arxiv.org/abs/2412.21037) [![Static Badge](https://img.shields.io/badge/TangoFlux-Hugging_Face-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat)](https://tangoflux.github.io/) [![Static Badge](https://img.shields.io/badge/TangoFlux-Hugging_Face_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/TangoFlux_Dataset-Hugging_Face-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/datasets/declare-lab/CRPO) [![Replicate](https://replicate.com/chenxwh/tangoflux/badge)](https://replicate.com/chenxwh/tangoflux)
<br/>
</div>
## Demos
[![Hugging Face Space](https://img.shields.io/badge/Hugging_Face_Space-TangoFlux-blue?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/declare-lab/TangoFlux/blob/main/Demo.ipynb)
## Overall Pipeline
TangoFlux consists of FluxTransformer blocks, which are Diffusion Transformer (DiT) and Multimodal Diffusion Transformer (MMDiT) blocks conditioned on a textual prompt and a duration embedding, and it generates 44.1 kHz audio up to 30 seconds long. TangoFlux learns a rectified-flow trajectory to an audio latent representation encoded by a variational autoencoder (VAE). The TangoFlux training pipeline consists of three stages: pre-training, fine-tuning, and preference optimization with CRPO. In particular, CRPO iteratively generates new synthetic data and constructs preference pairs for preference optimization, using the DPO loss adapted to flow matching.
![cover-photo](assets/tangoflux.png)
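For intuition, the snippet below sketches the rectified-flow (flow-matching) objective described above. It is an illustrative, stand-alone example, not the training code in this repository; the velocity-model signature, tensor shapes, and conditioning arguments are placeholders.

```python
import torch
import torch.nn.functional as F


def flow_matching_loss(velocity_model, x1, text_emb, duration_emb):
    """Rectified-flow / conditional flow-matching loss on VAE latents.

    x1 are clean audio latents of shape (batch, time, channels). The model is
    trained to predict the constant velocity (x1 - x0) along the straight path
    x_t = (1 - t) * x0 + t * x1 between Gaussian noise x0 and data x1.
    """
    x0 = torch.randn_like(x1)                      # noise endpoint of the path
    t = torch.rand(x1.shape[0], device=x1.device)  # one timestep per example
    t_ = t.view(-1, 1, 1)
    xt = (1 - t_) * x0 + t_ * x1                   # point on the straight path
    target_velocity = x1 - x0                      # rectified-flow target
    pred_velocity = velocity_model(xt, t, text_emb, duration_emb)
    return F.mse_loss(pred_velocity, target_velocity)


if __name__ == "__main__":
    def dummy_model(xt, t, text, dur):
        # Placeholder velocity model that ignores its conditioning, just to show the call.
        return torch.zeros_like(xt)

    fake_latents = torch.randn(2, 100, 64)
    loss = flow_matching_loss(dummy_model, fake_latents, text_emb=None, duration_emb=None)
    print(loss.item())
```

CRPO builds on this by generating several candidates per prompt, ranking them with CLAP to form preference pairs, and optimizing a DPO-style preference loss on top of the flow-matching objective.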
## Quickstart on Google Colab
🚀 **TangoFlux can generate 44.1kHz stereo audio up to 30 seconds in ~3 seconds on a single A40 GPU.**
| Colab |
| --- |
| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1j__4fl_BlaVS_225M34d-EKxsVDJPRiR?usp=sharing) |
## Installation
```bash
pip install git+https://github.com/declare-lab/TangoFlux
```
## Inference
TangoFlux can generate audio up to 30 seconds long. When using the Python API, control the length by passing a `duration` (in seconds) to the `model.generate` function; the duration must be between 1 and 30.
### Web Interface
Run the following command to start the web interface:
```bash
tangoflux-demo
```
### CLI
Use the CLI to generate audio from text.
```bash
tangoflux "Hammer slowly hitting the wooden table" output.wav --duration 10 --steps 50
```
### Python API

Download the TangoFlux model and generate audio from a text prompt:
```python
import torchaudio
from tangoflux import TangoFluxInference

model = TangoFluxInference(name='declare-lab/TangoFlux')
audio = model.generate('Hammer slowly hitting the wooden table', steps=50, duration=10)

torchaudio.save('output.wav', audio, 44100)
```
Our evaluation shows that inference with 50 steps yields the best results. CFG scales of 3.5, 4, and 4.5 yield output of similar quality. Inference with 25 steps yields similar audio quality at a faster speed.
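As a rough illustration of that trade-off (this assumes `model` and `torchaudio` are already set up as in the Python API example above; `guidance_scale` defaults to 4.5 in `TangoFluxInference.generate`):

```python
prompt = 'Hammer slowly hitting the wooden table'

# Best quality per our evaluation: 50 steps with the default CFG scale.
best_quality = model.generate(prompt, steps=50, duration=10, guidance_scale=4.5)

# Faster: 25 steps yields similar audio quality in less time.
faster = model.generate(prompt, steps=25, duration=10)

torchaudio.save('hammer_50_steps.wav', best_quality, 44100)
torchaudio.save('hammer_25_steps.wav', faster, 44100)
```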
## Training
We use the `accelerate` package from Hugging Face for multi-GPU training. Run `accelerate config` to set up your run configuration. The default accelerate config is provided in the `configs` folder. Specify the paths to your training files in `configs/tangoflux_config.yaml`. Samples of `train.json` and `val.json` are provided; replace them with your own audio.
`tangoflux_config.yaml` defines the training file paths and model hyperparameters:
```bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' src/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
```
To perform DPO training, modify the training files such that each data point contains "chosen", "reject", "caption" and "duration" fields. Please specify the path to your training files in `configs/tangoflux_config.yaml`. An example has been provided in `train_dpo.json`. Replace it with your own audio.
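As a rough sketch of that layout (the file paths below are placeholders; check the provided `train_dpo.json` for the exact structure expected by the training script):

```python
import json

# One hypothetical preference pair using the fields described above.
entry = {
    "chosen": "audio/hammer_preferred.wav",   # preferred (winning) audio
    "reject": "audio/hammer_rejected.wav",    # rejected (losing) audio
    "caption": "Hammer slowly hitting the wooden table",
    "duration": 10,
}

with open("train_dpo_sample.json", "w") as f:
    json.dump([entry], f, indent=2)
```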
```bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' src/train_dpo.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
```
## Evaluation Scripts
@@ -77,15 +92,13 @@ The key comparison metrics include:
All inference times were measured on the same A40 GPU. The number of trainable parameters is reported in the **\#Params** column.
| Model | \#Params | Duration | Steps | FD<sub>openl3</sub> ↓ | KL<sub>passt</sub> ↓ | CLAP<sub>score</sub> ↑ | IS ↑ | Inference Time (s) |
|---------------------------------|-----------|----------|-------|-----------------------|----------------------|------------------------|------|--------------------|
| **AudioLDM 2-large** | 712M | 10 sec | 200 | 108.3 | 1.81 | 0.419 | 7.9 | 24.8 |
| **Stable Audio Open** | 1056M | 47 sec | 100 | 89.2 | 2.58 | 0.291 | 9.9 | 8.6 |
| **Tango 2** | 866M | 10 sec | 200 | 108.4 | **1.11** | 0.447 | 9.0 | 22.8 |
| **TangoFlux-base** | **515M** | 30 sec | 50 | 80.2 | 1.22 | 0.431 | 11.7 | **3.7** |
| **TangoFlux** | **515M** | 30 sec | 50 | **75.1** | 1.15 | **0.480** | **12.2** | **3.7** |
## Citation
@@ -100,3 +113,7 @@ All the inference times are observed on the same A40 GPU. The counts of trainabl
url={https://arxiv.org/abs/2412.21037},
}
```
## License
TangoFlux is licensed under the MIT License. See the `LICENSE` file for more details.
*.wav
\ No newline at end of file
import torchaudio
from tangoflux import TangoFluxInference
model = TangoFluxInference(name="declare-lab/TangoFlux")
audio = model.generate("Hammer slowly hitting the wooden table", steps=50, duration=10)
torchaudio.save("output.wav", audio, sample_rate=44100)
@@ -10,11 +10,13 @@ from diffusers import AutoencoderOobleck
import soundfile as sf
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from tangoflux.model import TangoFlux
from tangoflux import TangoFluxInference
MODEL_CACHE = "model_cache"
MODEL_URL = "https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
MODEL_URL = (
"https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
)
class CachedTangoFluxInference(TangoFluxInference):
......
from setuptools import setup

setup(
    name="tangoflux",
    description="TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching",
    version="0.1.0",
    packages=["tangoflux"],
    install_requires=[
        "torch==2.4.0",
        "torchaudio==2.4.0",
        "torchlibrosa==0.1.0",
        "torchvision==0.19.0",
        "transformers==4.44.0",
        "diffusers==0.30.0",
        "accelerate==0.34.2",
        "datasets==2.21.0",
        "librosa",
        "tqdm",
        "wandb",
        "click",
        "gradio",
    ],
    entry_points={
        "console_scripts": [
            "tangoflux=tangoflux.cli:main",
            "tangoflux-demo=tangoflux.demo:main",
        ],
    },
)
from diffusers import AutoencoderOobleck
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler
@@ -9,10 +9,10 @@ from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import torch.nn.functional as F
import numpy as np
from tangoflux.model import TangoFlux
from huggingface_hub import snapshot_download
from tqdm import tqdm
from typing import Optional, Union, List
from datasets import load_dataset, Audio
from math import pi
import json
@@ -23,39 +23,38 @@ from safetensors.torch import load_file
class TangoFluxInference:

    def __init__(
        self,
        name="declare-lab/TangoFlux",
        device="cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.vae = AutoencoderOobleck()

        # Download the checkpoint from the Hugging Face Hub and load the VAE and transformer weights.
        paths = snapshot_download(repo_id=name)
        vae_weights = load_file("{}/vae.safetensors".format(paths))
        self.vae.load_state_dict(vae_weights)
        weights = load_file("{}/tangoflux.safetensors".format(paths))

        with open("{}/config.json".format(paths), "r") as f:
            config = json.load(f)
        self.model = TangoFlux(config)
        self.model.load_state_dict(weights, strict=False)
        # _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) -- this behaviour is expected
        self.vae.to(device)
        self.model.to(device)

    def generate(self, prompt, steps=25, duration=10, guidance_scale=4.5):

        with torch.no_grad():
            latents = self.model.inference_flow(
                prompt,
                duration=duration,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
            )

            wave = self.vae.decode(latents.transpose(2, 1)).sample.cpu()[0]

        # Trim the decoded waveform to the requested duration.
        waveform_end = int(duration * self.vae.config.sampling_rate)
        wave = wave[:, :waveform_end]
        return wave
import click
import torchaudio
from tangoflux import TangoFluxInference


@click.command()
@click.argument('prompt')
@click.argument('output_file')
@click.option('--duration', default=10, type=int, help='Duration in seconds (1-30)')
@click.option('--steps', default=50, type=int, help='Number of inference steps (10-100)')
def main(prompt: str, output_file: str, duration: int, steps: int):
    """Generate audio from text using TangoFlux.

    Args:
        prompt: Text description of the audio to generate
        output_file: Path to save the generated audio file
        duration: Duration of generated audio in seconds (default: 10)
        steps: Number of inference steps (default: 50)
    """
    if not 1 <= duration <= 30:
        raise click.BadParameter('Duration must be between 1 and 30 seconds')
    if not 10 <= steps <= 100:
        raise click.BadParameter('Steps must be between 10 and 100')

    model = TangoFluxInference(name="declare-lab/TangoFlux")
    audio = model.generate(prompt, steps=steps, duration=duration)
    torchaudio.save(output_file, audio, sample_rate=44100)


if __name__ == '__main__':
    main()
import gradio as gr
import torchaudio
import click
import tempfile
from tangoflux import TangoFluxInference

model = TangoFluxInference(name="declare-lab/TangoFlux")


def generate_audio(prompt, duration, steps):
    audio = model.generate(prompt, steps=steps, duration=duration)
    # Write to a temporary wav file so Gradio's Audio component can play it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        torchaudio.save(f.name, audio, sample_rate=44100)
        return f.name


examples = [
    ["Hammer slowly hitting the wooden table", 10, 50],
    ["Gentle rain falling on a tin roof", 15, 50],
    ["Wind chimes tinkling in a light breeze", 10, 50],
    ["Rhythmic wooden table tapping overlaid with steady water pouring sound", 10, 50],
]

with gr.Blocks(title="TangoFlux Text-to-Audio Generation") as demo:
    gr.Markdown("# TangoFlux Text-to-Audio Generation")
    gr.Markdown("Generate audio from text descriptions using TangoFlux")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Text Prompt", placeholder="Enter your audio description..."
            )
            duration = gr.Slider(
                minimum=1, maximum=30, value=10, step=1, label="Duration (seconds)"
            )
            steps = gr.Slider(
                minimum=10, maximum=100, value=50, step=10, label="Number of Steps"
            )
            generate_btn = gr.Button("Generate Audio")
        with gr.Column():
            audio_output = gr.Audio(label="Generated Audio")

    generate_btn.click(
        fn=generate_audio, inputs=[prompt, duration, steps], outputs=audio_output
    )

    gr.Examples(
        examples=examples,
        inputs=[prompt, duration, steps],
        outputs=audio_output,
        fn=generate_audio,
    )


@click.command()
@click.option('--host', default='127.0.0.1', help='Host to bind to')
@click.option('--port', default=None, type=int, help='Port to bind to')
@click.option('--share', is_flag=True, help='Enable sharing via Gradio')
def main(host, port, share):
    demo.queue().launch(server_name=host, server_port=port, share=share)


if __name__ == "__main__":
    main()
@@ -11,7 +11,7 @@ import numpy as np
import numpy as np


def normalize_wav(waveform):
    waveform = waveform - torch.mean(waveform)
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
@@ -20,7 +20,7 @@ def normalize_wav(waveform):
def pad_wav(waveform, segment_length):
    waveform_length = len(waveform)

    if segment_length is None or waveform_length == segment_length:
        return waveform
    elif waveform_length > segment_length:
@@ -29,40 +29,47 @@ def pad_wav(waveform, segment_length):
        padded_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
        waveform = torch.cat([waveform, padded_wav])
        return waveform


def read_wav_file(filename, duration_sec):
    info = torchaudio.info(filename)
    sample_rate = info.sample_rate

    # Calculate the number of frames corresponding to the desired duration
    num_frames = int(sample_rate * duration_sec)
    waveform, sr = torchaudio.load(filename, num_frames=num_frames)  # Faster!!!

    if waveform.shape[0] == 2:  ## Stereo audio
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=44100)
        resampled_waveform = resampler(waveform)
        # print(resampled_waveform.shape)
        padded_left = pad_wav(
            resampled_waveform[0], int(44100 * duration_sec)
        )  ## We pad left and right separately
        padded_right = pad_wav(resampled_waveform[1], int(44100 * duration_sec))

        return torch.stack([padded_left, padded_right])

    else:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sr, new_freq=44100
        )[0]
        waveform = pad_wav(waveform, int(44100 * duration_sec)).unsqueeze(0)
        return waveform


class DPOText2AudioDataset(Dataset):
    def __init__(
        self,
        dataset,
        prefix,
        text_column,
        audio_w_column,
        audio_l_column,
        duration,
        num_examples=-1,
    ):

        inputs = list(dataset[text_column])
        self.inputs = [prefix + inp for inp in inputs]
@@ -72,11 +79,18 @@ class DPOText2AudioDataset(Dataset):
        self.indices = list(range(len(self.inputs)))

        self.mapper = {}
        for index, audio_w, audio_l, duration, text in zip(
            self.indices, self.audios_w, self.audios_l, self.durations, inputs
        ):
            self.mapper[index] = [audio_w, audio_l, duration, text]

        if num_examples != -1:
            self.inputs, self.audios_w, self.audios_l, self.durations = (
                self.inputs[:num_examples],
                self.audios_w[:num_examples],
                self.audios_l[:num_examples],
                self.durations[:num_examples],
            )
            self.indices = self.indices[:num_examples]

    def __len__(self):
@@ -86,15 +100,24 @@ class DPOText2AudioDataset(Dataset):
        return len(self.inputs)

    def __getitem__(self, index):
        s1, s2, s3, s4, s5 = (
            self.inputs[index],
            self.audios_w[index],
            self.audios_l[index],
            self.durations[index],
            self.indices[index],
        )
        return s1, s2, s3, s4, s5

    def collate_fn(self, data):
        dat = pd.DataFrame(data)
        return [dat[i].tolist() for i in dat]


class Text2AudioDataset(Dataset):
    def __init__(
        self, dataset, prefix, text_column, audio_column, duration, num_examples=-1
    ):

        inputs = list(dataset[text_column])
        self.inputs = [prefix + inp for inp in inputs]
@@ -103,11 +126,17 @@ class Text2AudioDataset(Dataset):
        self.indices = list(range(len(self.inputs)))

        self.mapper = {}
        for index, audio, duration, text in zip(
            self.indices, self.audios, self.durations, inputs
        ):
            self.mapper[index] = [audio, text, duration]

        if num_examples != -1:
            self.inputs, self.audios, self.durations = (
                self.inputs[:num_examples],
                self.audios[:num_examples],
                self.durations[:num_examples],
            )
            self.indices = self.indices[:num_examples]

    def __len__(self):
@@ -117,7 +146,12 @@ class Text2AudioDataset(Dataset):
        return len(self.inputs)

    def __getitem__(self, index):
        s1, s2, s3, s4 = (
            self.inputs[index],
            self.audios[index],
            self.durations[index],
            self.indices[index],
        )
        return s1, s2, s3, s4

    def collate_fn(self, data):
......
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' tangoflux/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'