Commit 0db67215 authored by mrfakename

pip package

parent f7cfcd21
__pycache__/
*.py[cod]
*$py.class
@@ -168,3 +168,8 @@ cython_debug/
# PyPI configuration file
.pypirc
.DS_Store
*.wav
<h1 align="center">
<br/>
TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization
<br/>
✨✨✨
</h1>
<div align="center">
<img src="assets/tf_teaser.png" alt="TangoFlux" width="1000" />
<br/>
[![arXiv](https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf)](https://arxiv.org/abs/2412.21037) [![Static Badge](https://img.shields.io/badge/TangoFlux-Huggingface-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat)](https://tangoflux.github.io/) [![Static Badge](https://img.shields.io/badge/TangoFlux-Hugging_Face_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/TangoFlux_Dataset-Huggingface-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/datasets/declare-lab/CRPO) [![Replicate](https://replicate.com/chenxwh/tangoflux/badge)](https://replicate.com/chenxwh/tangoflux)
</div>
## Demos
[![Hugging Face Space](https://img.shields.io/badge/Hugging_Face_Space-TangoFlux-blue?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux)
## Quickstart on Google Colab
| Colab |
| --- |
| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1j__4fl_BlaVS_225M34d-EKxsVDJPRiR?usp=sharing) |
## Overall Pipeline
TangoFlux consists of FluxTransformer blocks, which are Diffusion Transformers (DiT) and Multimodal Diffusion Transformers (MMDiT) conditioned on a textual prompt and a duration embedding to generate 44.1kHz audio up to 30 seconds long. TangoFlux learns a rectified flow trajectory to an audio latent representation encoded by a variational autoencoder (VAE). The TangoFlux training pipeline consists of three stages: pre-training, fine-tuning, and preference optimization with CRPO. In particular, CRPO iteratively generates new synthetic data and constructs preference pairs for preference optimization using the DPO loss for flow matching.
![cover-photo](assets/tangoflux.png)
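For intuition, here is a minimal sketch of the rectified flow-matching objective described above. This is illustrative only; the function and tensor names are our assumptions, not TangoFlux's actual training code.

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x1, text_emb, duration_emb):
    """Sketch of a rectified flow-matching step.

    x1: clean audio latents from the VAE encoder (any shape with batch first).
    model: a stand-in for the DiT/MMDiT velocity predictor (hypothetical signature).
    """
    x0 = torch.randn_like(x1)                      # Gaussian noise endpoint
    t = torch.rand(x1.shape[0], device=x1.device)  # uniform timestep in [0, 1]
    t_ = t.view(-1, *([1] * (x1.dim() - 1)))       # broadcast t over latent dims
    xt = (1 - t_) * x0 + t_ * x1                   # straight-line (rectified) interpolation
    v_target = x1 - x0                             # constant velocity along the line
    v_pred = model(xt, t, text_emb, duration_emb)  # predicted velocity field
    return F.mse_loss(v_pred, v_target)
```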
🚀 **TangoFlux can generate up to 30 seconds of 44.1kHz stereo audio in ~3 seconds on a single A40 GPU.**
## Installation
```bash
pip install tangoflux
```
## Inference
TangoFlux can generate audio up to 30 seconds long. You must pass a duration to the `model.generate` function when using the Python API. Please note that duration should be between 1 and 30.
### Web Interface
Run the following command to start the web interface:
```bash
tangoflux-demo
```
### CLI
Use the CLI to generate audio from text.
```bash
tangoflux "Hammer slowly hitting the wooden table" output.wav --duration 10 --steps 50
```
### Python API
Download the TangoFlux model and generate audio from a text prompt:
```python
import torchaudio
from tangoflux import TangoFluxInference
model = TangoFluxInference(name='declare-lab/TangoFlux')
audio = model.generate('Hammer slowly hitting the wooden table', steps=50, duration=10)
torchaudio.save('output.wav', audio, 44100)
```
Our evaluation shows that inference with 50 steps yields the best results, and CFG scales of 3.5, 4, and 4.5 yield similar output quality. Inference with 25 steps gives similar audio quality at a faster speed.
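For example, reusing `model` from the snippet above, faster generation is just a matter of lowering `steps`:

```python
# 25 steps trades a little fidelity for roughly half the sampling time.
audio_fast = model.generate('Gentle rain falling on a tin roof', steps=25, duration=10)
```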
## Training
We use the `accelerate` package from Hugging Face for multi-GPU training. Run `accelerate config` to set up your run configuration. The default accelerate config is in the `configs` folder. Please specify the path to your training files in `configs/tangoflux_config.yaml`. Samples of `train.json` and `val.json` have been provided. Replace them with your own audio.
`tangoflux_config.yaml` defines the training file paths and model hyperparameters:
```bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' src/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
```
To perform DPO training, modify the training files such that each data point contains "chosen", "reject", "caption", and "duration" fields. Please specify the path to your training files in `configs/tangoflux_config.yaml`. An example has been provided in `train_dpo.json`; a sample record is shown below. Replace it with your own audio.
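One `train_dpo.json` record might look like the following. This is a plausible shape inferred from the field names above, and the file paths are hypothetical; consult the provided `train_dpo.json` for the exact schema.

```json
{
  "chosen": "audio/chosen_001.wav",
  "reject": "audio/reject_001.wav",
  "caption": "Hammer slowly hitting the wooden table",
  "duration": 10
}
```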
```bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' src/train_dpo.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
```
## Evaluation Scripts
@@ -77,15 +93,13 @@ The key comparison metrics include:
All inference times were measured on the same A40 GPU. The counts of trainable parameters are reported in the **Params** column.
| Model | Params | Duration | Steps | FD<sub>openl3</sub> ↓ | KL<sub>passt</sub> ↓ | CLAP<sub>score</sub> ↑ | IS ↑ | Inference Time (s) |
|---|---|---|---|---|---|---|---|---|
| **AudioLDM 2 (Large)** | 712M | 10 sec | 200 | 108.3 | 1.81 | 0.419 | 7.9 | 24.8 |
| **Stable Audio Open** | 1056M | 47 sec | 100 | 89.2 | 2.58 | 0.291 | 9.9 | 8.6 |
| **Tango 2** | 866M | 10 sec | 200 | 108.4 | 1.11 | 0.447 | 9.0 | 22.8 |
| **TangoFlux (Base)** | 515M | 30 sec | 50 | 80.2 | 1.22 | 0.431 | 11.7 | 3.7 |
| **TangoFlux** | 515M | 30 sec | 50 | 75.1 | 1.15 | 0.480 | 12.2 | 3.7 |
## Citation
@@ -100,3 +114,7 @@
url={https://arxiv.org/abs/2412.21037},
}
```
## License
TangoFlux is licensed under the MIT License. See the `LICENSE` file for more details.
\ No newline at end of file
*.wav
\ No newline at end of file
import torchaudio
from tangoflux import TangoFluxInference
model = TangoFluxInference(name="declare-lab/TangoFlux")
audio = model.generate("Hammer slowly hitting the wooden table", steps=50, duration=10)
torchaudio.save("output.wav", audio, sample_rate=44100)
@@ -10,11 +10,13 @@ from diffusers import AutoencoderOobleck
import soundfile as sf
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from tangoflux.model import TangoFlux
from tangoflux import TangoFluxInference
MODEL_CACHE = "model_cache"
MODEL_URL = "https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
MODEL_URL = (
"https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
)
class CachedTangoFluxInference(TangoFluxInference):
from setuptools import setup
setup(
name="tangoflux",
description="TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching",
version="0.1.0",
packages=["tangoflux"],
install_requires=[
"torch==2.4.0",
"torchaudio==2.4.0",
"torchlibrosa==0.1.0",
"torchvision==0.19.0",
"transformers==4.44.0",
"diffusers==0.30.0",
"accelerate==0.34.2",
"datasets==2.21.0",
"librosa",
"tqdm",
"wandb",
"click",
"gradio",
"torchaudio",
],
entry_points={
"console_scripts": [
"tangoflux=tangoflux.cli:main",
"tangoflux-demo=tangoflux.demo:main",
],
},
)
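Given this `setup.py`, an editable install from a source checkout is the usual development workflow and also registers the two console scripts:

```bash
pip install -e .   # exposes the `tangoflux` and `tangoflux-demo` commands
```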
from diffusers import AutoencoderOobleck
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List
@@ -9,10 +9,10 @@ from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import torch.nn.functional as F
import numpy as np
from tangoflux.model import TangoFlux
from huggingface_hub import snapshot_download
from tqdm import tqdm
from typing import Optional, Union, List
from datasets import load_dataset, Audio
from math import pi
import json
@@ -23,8 +23,11 @@ from safetensors.torch import load_file
class TangoFluxInference:
def __init__(
self,
name="declare-lab/TangoFlux",
device="cuda" if torch.cuda.is_available() else "cpu",
):
self.vae = AutoencoderOobleck()
@@ -33,29 +36,25 @@ class TangoFluxInference:
self.vae.load_state_dict(vae_weights)
weights = load_file("{}/tangoflux.safetensors".format(paths))
with open("{}/config.json".format(paths), "r") as f:
config = json.load(f)
self.model = TangoFlux(config)
self.model.load_state_dict(weights, strict=False)
# _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
self.vae.to(device)
self.model.to(device)
def generate(self, prompt, steps=25, duration=10, guidance_scale=4.5):
with torch.no_grad():
latents = self.model.inference_flow(
prompt,
duration=duration,
num_inference_steps=steps,
guidance_scale=guidance_scale,
)
wave = self.vae.decode(latents.transpose(2, 1)).sample.cpu()[0]
waveform_end = int(duration * self.vae.config.sampling_rate)
wave = wave[:, :waveform_end]
return wave
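Putting the pieces together, `generate` can also be called with an explicit `guidance_scale`. This is a sketch that assumes a CUDA-capable machine and the default checkpoint:

```python
from tangoflux import TangoFluxInference

model = TangoFluxInference(name="declare-lab/TangoFlux")
# CFG scales of 3.5-4.5 give similar quality per the evaluation note above.
wave = model.generate(
    "Wind chimes tinkling in a light breeze", steps=50, duration=10, guidance_scale=4.5
)
```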
import click
import torchaudio
from tangoflux import TangoFluxInference
@click.command()
@click.argument('prompt')
@click.argument('output_file')
@click.option('--duration', default=10, type=int, help='Duration in seconds (1-30)')
@click.option('--steps', default=50, type=int, help='Number of inference steps (10-100)')
def main(prompt: str, output_file: str, duration: int, steps: int):
"""Generate audio from text using TangoFlux.
Args:
prompt: Text description of the audio to generate
output_file: Path to save the generated audio file
duration: Duration of generated audio in seconds (default: 10)
steps: Number of inference steps (default: 50)
"""
if not 1 <= duration <= 30:
raise click.BadParameter('Duration must be between 1 and 30 seconds')
if not 10 <= steps <= 100:
raise click.BadParameter('Steps must be between 10 and 100')
model = TangoFluxInference(name="declare-lab/TangoFlux")
audio = model.generate(prompt, steps=steps, duration=duration)
torchaudio.save(output_file, audio, sample_rate=44100)
if __name__ == '__main__':
main()
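For example, an invocation that stays within the ranges validated above:

```bash
tangoflux "Gentle rain falling on a tin roof" rain.wav --duration 15 --steps 50
```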
import gradio as gr
import torchaudio
import click
import tempfile
from tangoflux import TangoFluxInference
model = TangoFluxInference(name="declare-lab/TangoFlux")
def generate_audio(prompt, duration, steps):
audio = model.generate(prompt, steps=steps, duration=duration)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
torchaudio.save(f.name, audio, sample_rate=44100)
return f.name
examples = [
["Hammer slowly hitting the wooden table", 10, 50],
["Gentle rain falling on a tin roof", 15, 50],
["Wind chimes tinkling in a light breeze", 10, 50],
["Rhythmic wooden table tapping overlaid with steady water pouring sound", 10, 50],
]
with gr.Blocks(title="TangoFlux Text-to-Audio Generation") as demo:
gr.Markdown("# TangoFlux Text-to-Audio Generation")
gr.Markdown("Generate audio from text descriptions using TangoFlux")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
label="Text Prompt", placeholder="Enter your audio description..."
)
duration = gr.Slider(
minimum=1, maximum=30, value=10, step=1, label="Duration (seconds)"
)
steps = gr.Slider(
minimum=10, maximum=100, value=50, step=10, label="Number of Steps"
)
generate_btn = gr.Button("Generate Audio")
with gr.Column():
audio_output = gr.Audio(label="Generated Audio")
generate_btn.click(
fn=generate_audio, inputs=[prompt, duration, steps], outputs=audio_output
)
gr.Examples(
examples=examples,
inputs=[prompt, duration, steps],
outputs=audio_output,
fn=generate_audio,
)
@click.command()
@click.option('--host', default='127.0.0.1', help='Host to bind to')
@click.option('--port', default=None, type=int, help='Port to bind to')
@click.option('--share', is_flag=True, help='Enable sharing via Gradio')
def main(host, port, share):
demo.queue().launch(server_name=host, server_port=port, share=share)
if __name__ == "__main__":
main()
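For instance, to expose the demo on all interfaces with a fixed port and a public Gradio link (flags as defined above):

```bash
tangoflux-demo --host 0.0.0.0 --port 7860 --share
```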
@@ -31,8 +31,6 @@ def pad_wav(waveform, segment_length):
return waveform
def read_wav_file(filename, duration_sec):
info = torchaudio.info(filename)
sample_rate = info.sample_rate
@@ -40,29 +38,38 @@ def read_wav_file(filename, duration_sec):
# Calculate the number of frames corresponding to the desired duration
num_frames = int(sample_rate * duration_sec)
waveform, sr = torchaudio.load(filename, num_frames=num_frames) # Faster!!!
if waveform.shape[0] == 2: ## Stereo audio
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=44100)
resampled_waveform = resampler(waveform)
# print(resampled_waveform.shape)
padded_left = pad_wav(
resampled_waveform[0], int(44100 * duration_sec)
) ## We pad left and right separately
padded_right = pad_wav(resampled_waveform[1], int(44100 * duration_sec))
return torch.stack([padded_left, padded_right])
else:
waveform = torchaudio.functional.resample(
waveform, orig_freq=sr, new_freq=44100
)[0]
waveform = pad_wav(waveform, int(44100 * duration_sec)).unsqueeze(0)
return waveform
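A quick usage sketch of the helper above (the file path is hypothetical): it always returns a 44.1 kHz waveform padded or truncated to the requested duration.

```python
# Mono input -> shape (1, 441000); stereo input -> shape (2, 441000) for 10 s.
wav = read_wav_file("example.wav", duration_sec=10)
```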
class DPOText2AudioDataset(Dataset):
def __init__(
self,
dataset,
prefix,
text_column,
audio_w_column,
audio_l_column,
duration,
num_examples=-1,
):
inputs = list(dataset[text_column])
self.inputs = [prefix + inp for inp in inputs]
@@ -72,11 +79,18 @@ class DPOText2AudioDataset(Dataset):
self.indices = list(range(len(self.inputs)))
self.mapper = {}
for index, audio_w, audio_l, duration, text in zip(
self.indices, self.audios_w, self.audios_l, self.durations, inputs
):
self.mapper[index] = [audio_w, audio_l, duration, text]
if num_examples != -1:
self.inputs, self.audios_w, self.audios_l, self.durations = (
self.inputs[:num_examples],
self.audios_w[:num_examples],
self.audios_l[:num_examples],
self.durations[:num_examples],
)
self.indices = self.indices[:num_examples]
def __len__(self):
@@ -86,15 +100,24 @@
return len(self.inputs)
def __getitem__(self, index):
s1, s2, s3, s4, s5 = (
self.inputs[index],
self.audios_w[index],
self.audios_l[index],
self.durations[index],
self.indices[index],
)
return s1, s2, s3, s4, s5
def collate_fn(self, data):
dat = pd.DataFrame(data)
return [dat[i].tolist() for i in dat]
class Text2AudioDataset(Dataset):
def __init__(
self, dataset, prefix, text_column, audio_column, duration, num_examples=-1
):
inputs = list(dataset[text_column])
self.inputs = [prefix + inp for inp in inputs]
@@ -103,11 +126,17 @@
self.indices = list(range(len(self.inputs)))
self.mapper = {}
for index, audio, duration, text in zip(
self.indices, self.audios, self.durations, inputs
):
self.mapper[index] = [audio, text, duration]
if num_examples != -1:
self.inputs, self.audios, self.durations = (
self.inputs[:num_examples],
self.audios[:num_examples],
self.durations[:num_examples],
)
self.indices = self.indices[:num_examples]
def __len__(self):
@@ -117,7 +146,12 @@
return len(self.inputs)
def __getitem__(self, index):
s1, s2, s3, s4 = (
self.inputs[index],
self.audios[index],
self.durations[index],
self.indices[index],
)
return s1, s2, s3, s4
def collate_fn(self, data):
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' tangoflux/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
\ No newline at end of file