Unverified commit 1453841b authored by Chia-Yu Hung, committed by GitHub

Merge pull request #6 from fakerybakery/main

pip package
parents 7090a624 8fd3cc82
~__pycache__/
__pycache__/
*.py[cod]
*$py.class
@@ -168,3 +168,8 @@ cython_debug/
# PyPI configuration file
.pypirc
.DS_Store
*.wav
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/declare-lab/TangoFlux.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import IPython\n",
"import torchaudio\n",
"from tangoflux import TangoFluxInference\n",
"from IPython.display import Audio\n",
"\n",
"model = TangoFluxInference(name='declare-lab/TangoFlux')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# @title Generate Audio\n",
"\n",
"prompt = 'Hammer slowly hitting the wooden table' # @param {type:\"string\"}\n",
"duration = 10 # @param {type:\"number\"}\n",
"steps = 50 # @param {type:\"number\"}\n",
"\n",
"audio = model.generate(prompt, steps=steps, duration=duration)\n",
"\n",
"Audio(data=audio, rate=44100)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
<h1 align="center">
<br/>
TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization
<br/>
✨✨✨
</h1>
# TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization
<div align="center">
<img src="assets/tf_teaser.png" alt="TangoFlux" width="1000" />
<br/>
[![arXiv](https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf)](https://arxiv.org/abs/2412.21037) [![Static Badge](https://img.shields.io/badge/TangoFlux-Hugging_Face-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat)](https://tangoflux.github.io/) [![Static Badge](https://img.shields.io/badge/TangoFlux-Hugging_Face_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/TangoFlux_Dataset-Hugging_Face-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/datasets/declare-lab/CRPO) [![Replicate](https://replicate.com/chenxwh/tangoflux/badge)](https://replicate.com/chenxwh/tangoflux)
<br/>
</div>
[![arXiv](https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf)](https://arxiv.org/abs/2412.21037) [![Static Badge](https://img.shields.io/badge/TangoFlux-Huggingface-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat)](https://tangoflux.github.io/) [![Static Badge](https://img.shields.io/badge/TangoFlux-Huggingface_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/TangoFlux_Dataset-Huggingface-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/datasets/declare-lab/CRPO) [![Replicate](https://replicate.com/chenxwh/tangoflux/badge)](https://replicate.com/chenxwh/tangoflux)
## Demos
[![Hugging Face Space](https://img.shields.io/badge/Hugging_Face_Space-TangoFlux-blue?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/declare-lab/TangoFlux/blob/main/Demo.ipynb)
## Overall Pipeline
TangoFlux consists of FluxTransformer blocks, which are Diffusion Transformers (DiT) and Multimodal Diffusion Transformers (MMDiT) conditioned on a textual prompt and a duration embedding to generate 44.1kHz audio up to 30 seconds long. TangoFlux learns a rectified flow trajectory to an audio latent representation encoded by a variational autoencoder (VAE). The TangoFlux training pipeline consists of three stages: pre-training, fine-tuning, and preference optimization with CRPO. In particular, CRPO iteratively generates new synthetic data and constructs preference pairs for preference optimization using a DPO loss adapted for flow matching.
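For intuition, here is a minimal, self-contained sketch of the two training objectives described above: the rectified-flow (flow-matching) loss and the CRPO preference loss. Shapes and variable names are illustrative placeholders; the actual implementation lives in `tangoflux/model.py` further down in this diff.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: a batch of 2 latent sequences, 645 frames, 64 channels.
latents = torch.randn(2, 645, 64)           # clean audio latents x1 (from the VAE)
noise = torch.randn_like(latents)           # Gaussian prior sample x0
sigma = torch.rand(latents.shape[0], 1, 1)  # sampled noise level per example

# Rectified flow: interpolate along a straight path and regress the constant velocity.
noisy_input = (1.0 - sigma) * latents + sigma * noise
target = noise - latents
model_pred = torch.randn_like(target)       # stand-in for the FluxTransformer output
flow_matching_loss = F.mse_loss(model_pred.float(), target.float())

# CRPO preference step (sketch): per-sample flow-matching losses for the chosen and
# rejected halves of a doubled batch, compared against a frozen reference model.
def dpo_flow_loss(model_losses, ref_losses, beta=2000.0):
    model_w, model_l = model_losses.chunk(2)   # chosen, rejected
    ref_w, ref_l = ref_losses.chunk(2)
    inside = -0.5 * beta * ((model_w - model_l) - (ref_w - ref_l))
    return -F.logsigmoid(inside).mean() + model_w.mean()
```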
![cover-photo](assets/tangoflux.png)
## Quickstart on Google Colab
🚀 **TangoFlux can generate 44.1kHz stereo audio up to 30 seconds in ~3 seconds on a single A40 GPU.**
| Colab |
| --- |
| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1j__4fl_BlaVS_225M34d-EKxsVDJPRiR?usp=sharing) |
## Installation
```bash
pip install git+https://github.com/declare-lab/TangoFlux
```
## Inference
TangoFlux can generate audio up to 30 seconds long. You must pass a duration to the `model.generate` function when using the Python API; the duration should be between 1 and 30 seconds.
🚀 **TangoFlux can generate 44.1kHz stereo audio up to 30 seconds long in about 3 seconds on an A40 GPU.**
### Web Interface
## Training TangoFlux
We use the `accelerate` package from Hugging Face for multi-GPU training. Run `accelerate config` from the terminal and set up your run configuration by answering the questions asked. We have placed the default accelerator config in the `configs` folder. Please specify the path to your training files in `configs/tangoflux_config.yaml`. Samples of `train.json` and `val.json` have been provided. Replace them with your own audio.
Run the following command to start the web interface:
```bash
tangoflux-demo
```
To perform DPO training, modify the training files such that each data point contains "chosen", "reject", "caption", and "duration" fields. Please specify the path to your training files in `configs/tangoflux_config.yaml`. An example has been provided in `train_dpo.json`. Replace it with your own audio.
### CLI
Use the CLI to generate audio from text.
```bash
tangoflux "Hammer slowly hitting the wooden table" output.wav --duration 10 --steps 50
```
## Inference with TangoFlux
Download the TangoFlux model and generate audio from a text prompt.
TangoFlux can generate audio up to 30 seconds long by passing a duration value to the `model.generate` function. Please note that the duration should be strictly greater than 1 and less than 30.
### Python API
```python
import torchaudio
from tangoflux import TangoFluxInference
from IPython.display import Audio
model = TangoFluxInference(name='declare-lab/TangoFlux')
audio = model.generate('Hammer slowly hitting the wooden table', steps=50, duration=10)
Audio(data=audio, rate=44100)
torchaudio.save('output.wav', audio, 44100)
```
Our evaluation shows that inference with 50 steps yields the best results. A CFG scale of 3.5, 4, and 4.5 yield similar quality output. Inference with 25 steps yields similar audio quality at a faster speed.
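As a rough guide, both `steps` and `guidance_scale` are exposed by the `TangoFluxInference.generate` method included in this PR (defaults of 25 and 4.5, respectively), so the quality/speed trade-off can be explored as in the sketch below; it is illustrative, not a benchmark.

```python
from tangoflux import TangoFluxInference

model = TangoFluxInference(name='declare-lab/TangoFlux')

# 50 steps for the best quality ...
best = model.generate('Hammer slowly hitting the wooden table', steps=50, duration=10)

# ... or 25 steps (and an explicit CFG scale) for faster inference at similar quality.
fast = model.generate('Hammer slowly hitting the wooden table', steps=25, duration=10,
                      guidance_scale=4.5)
```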
## Training
We use the `accelerate` package from Hugging Face for multi-GPU training. Run `accelerate config` to set up your run configuration. The default accelerate config is in the `configs` folder. Please specify the path to your training files in `configs/tangoflux_config.yaml`. Samples of `train.json` and `val.json` have been provided. Replace them with your own audio.
`tangoflux_config.yaml` defines the training file paths and model hyperparameters:
```bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' src/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
```
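For orientation, below is a minimal sketch of the configuration keys that `src/train.py` (included in this diff) reads from the YAML file. The values are placeholders; refer to the shipped `configs/tangoflux_config.yaml` for the real defaults.

```python
import yaml

# Placeholder values for illustration only.
config_text = """
paths:
  train_file: train.json
  val_file: val.json
  test_file: ''
  output_dir: saved/tangoflux_run
  resume_from_checkpoint: ''
model:
  text_encoder_name: google/flan-t5-large
training:
  learning_rate: 3.0e-5
  num_train_epochs: 80
  num_warmup_steps: 100
  per_device_batch_size: 2
  gradient_accumulation_steps: 1
  max_audio_duration: 30
"""

config = yaml.safe_load(config_text)
print(config["training"]["learning_rate"], config["paths"]["train_file"])
```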
To perform DPO training, modify the training files such that each data point contains "chosen", "reject", "caption" and "duration" fields. Please specify the path to your training files in `configs/tangoflux_config.yaml`. An example has been provided in `train_dpo.json`. Replace it with your own audio.
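The record below is a hypothetical illustration of that structure (file paths and values are made up; consult the bundled `train_dpo.json` for the exact field names and layout):

```python
import json

# Hypothetical preference pair: "chosen" and "reject" point to your preferred and
# rejected audio files for the same caption.
dpo_examples = [
    {
        "caption": "Hammer slowly hitting the wooden table",
        "chosen": "audio/hammer_preferred.wav",
        "reject": "audio/hammer_rejected.wav",
        "duration": 10,
    }
]

with open("train_dpo.json", "w") as f:
    json.dump(dpo_examples, f, indent=2)
```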
```bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' src/train_dpo.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
```
Our evaluation shows that inference with 50 steps yields the best results. CFG scales of 3.5, 4, and 4.5 yield similar quality output.
For faster inference, consider setting steps to 25, which yields similar audio quality.
## Evaluation Scripts
@@ -77,15 +92,13 @@ The key comparison metrics include:
All inference times were measured on the same A40 GPU. The number of trainable parameters is reported in the **\#Params** column.
| Model | \#Params | Duration | Steps | FD<sub>openl3</sub> ↓ | KL<sub>passt</sub> ↓ | CLAP<sub>score</sub> ↑ | IS ↑ | Inference Time (s) |
|---------------------------------|-----------|----------|-------|-----------------------|----------------------|------------------------|------|--------------------|
| **AudioLDM 2-large** | 712M | 10 sec | 200 | 108.3 | 1.81 | 0.419 | 7.9 | 24.8 |
| **Stable Audio Open** | 1056M | 47 sec | 100 | 89.2 | 2.58 | 0.291 | 9.9 | 8.6 |
| **Tango 2** | 866M | 10 sec | 200 | 108.4 | **1.11** | 0.447 | 9.0 | 22.8 |
| **TangoFlux-base** | **515M** | 30 sec | 50 | 80.2 | 1.22 | 0.431 | 11.7 | **3.7** |
| **TangoFlux** | **515M** | 30 sec | 50 | **75.1** | 1.15 | **0.480** | **12.2** | **3.7** |
| Model | Params | Duration | Steps | FD<sub>openl3</sub> ↓ | KL<sub>passt</sub> ↓ | CLAP<sub>score</sub> ↑ | IS ↑ | Inference Time (s) |
|---|---|---|---|---|---|---|---|---|
| **AudioLDM 2 (Large)** | 712M | 10 sec | 200 | 108.3 | 1.81 | 0.419 | 7.9 | 24.8 |
| **Stable Audio Open** | 1056M | 47 sec | 100 | 89.2 | 2.58 | 0.291 | 9.9 | 8.6 |
| **Tango 2** | 866M | 10 sec | 200 | 108.4 | 1.11 | 0.447 | 9.0 | 22.8 |
| **TangoFlux (Base)** | 515M | 30 sec | 50 | 80.2 | 1.22 | 0.431 | 11.7 | 3.7 |
| **TangoFlux** | 515M | 30 sec | 50 | 75.1 | 1.15 | 0.480 | 12.2 | 3.7 |
## Citation
@@ -100,3 +113,7 @@ All the inference times are observed on the same A40 GPU. The counts of trainabl
url={https://arxiv.org/abs/2412.21037},
}
```
## License
TangoFlux is licensed under the MIT License. See the `LICENSE` file for more details.
*.wav
import torchaudio
from tangoflux import TangoFluxInference
model = TangoFluxInference(name="declare-lab/TangoFlux")
audio = model.generate("Hammer slowly hitting the wooden table", steps=50, duration=10)
torchaudio.save("output.wav", audio, sample_rate=44100)
@@ -10,11 +10,13 @@ from diffusers import AutoencoderOobleck
import soundfile as sf
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from src.model import TangoFlux
from tangoflux.model import TangoFlux
from tangoflux import TangoFluxInference
MODEL_CACHE = "model_cache"
MODEL_URL = "https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
MODEL_URL = (
"https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
)
class CachedTangoFluxInference(TangoFluxInference):
......
from setuptools import setup
setup(
name="tangoflux",
description="TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching",
version="0.1.0",
packages=["tangoflux"],
install_requires=[
"torch==2.4.0",
"torchaudio==2.4.0",
"torchlibrosa==0.1.0",
"torchvision==0.19.0",
"transformers==4.44.0",
"diffusers==0.30.0",
"accelerate==0.34.2",
"datasets==2.21.0",
"librosa",
"tqdm",
"wandb",
"click",
"gradio",
"torchaudio",
],
entry_points={
"console_scripts": [
"tangoflux=tangoflux.cli:main",
"tangoflux-demo=tangoflux.demo:main",
],
},
)
from diffusers import AutoencoderOobleck
import torch
from transformers import T5EncoderModel,T5TokenizerFast
from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler
@@ -9,10 +9,10 @@ from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import torch.nn.functional as F
import numpy as np
from src.model import TangoFlux
from tangoflux.model import TangoFlux
from huggingface_hub import snapshot_download
from tqdm import tqdm
from typing import Optional,Union,List
from typing import Optional, Union, List
from datasets import load_dataset, Audio
from math import pi
import json
@@ -23,39 +23,38 @@ from safetensors.torch import load_file
class TangoFluxInference:
def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
def __init__(
self,
name="declare-lab/TangoFlux",
device="cuda" if torch.cuda.is_available() else "cpu",
):
self.vae = AutoencoderOobleck()
paths = snapshot_download(repo_id=name)
paths = snapshot_download(repo_id=name)
vae_weights = load_file("{}/vae.safetensors".format(paths))
self.vae.load_state_dict(vae_weights)
weights = load_file("{}/tangoflux.safetensors".format(paths))
with open('{}/config.json'.format(paths),'r') as f:
with open("{}/config.json".format(paths), "r") as f:
config = json.load(f)
self.model = TangoFlux(config)
self.model.load_state_dict(weights,strict=False)
# _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
self.model.load_state_dict(weights, strict=False)
# _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
self.vae.to(device)
self.model.to(device)
def generate(self,prompt,steps=25,duration=10,guidance_scale=4.5):
with torch.no_grad():
latents = self.model.inference_flow(prompt,
duration=duration,
num_inference_steps=steps,
guidance_scale=guidance_scale)
def generate(self, prompt, steps=25, duration=10, guidance_scale=4.5):
wave = self.vae.decode(latents.transpose(2,1)).sample.cpu()[0]
with torch.no_grad():
latents = self.model.inference_flow(
prompt,
duration=duration,
num_inference_steps=steps,
guidance_scale=guidance_scale,
)
wave = self.vae.decode(latents.transpose(2, 1)).sample.cpu()[0]
waveform_end = int(duration * self.vae.config.sampling_rate)
wave = wave[:, :waveform_end]
wave = wave[:, :waveform_end]
return wave
import click
import torchaudio
from tangoflux import TangoFluxInference
@click.command()
@click.argument('prompt')
@click.argument('output_file')
@click.option('--duration', default=10, type=int, help='Duration in seconds (1-30)')
@click.option('--steps', default=50, type=int, help='Number of inference steps (10-100)')
def main(prompt: str, output_file: str, duration: int, steps: int):
"""Generate audio from text using TangoFlux.
Args:
prompt: Text description of the audio to generate
output_file: Path to save the generated audio file
duration: Duration of generated audio in seconds (default: 10)
steps: Number of inference steps (default: 50)
"""
if not 1 <= duration <= 30:
raise click.BadParameter('Duration must be between 1 and 30 seconds')
if not 10 <= steps <= 100:
raise click.BadParameter('Steps must be between 10 and 100')
model = TangoFluxInference(name="declare-lab/TangoFlux")
audio = model.generate(prompt, steps=steps, duration=duration)
torchaudio.save(output_file, audio, sample_rate=44100)
if __name__ == '__main__':
main()
import gradio as gr
import torchaudio
import click
import tempfile
from tangoflux import TangoFluxInference
model = TangoFluxInference(name="declare-lab/TangoFlux")
def generate_audio(prompt, duration, steps):
audio = model.generate(prompt, steps=steps, duration=duration)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
torchaudio.save(f.name, audio, sample_rate=44100)
return f.name
examples = [
["Hammer slowly hitting the wooden table", 10, 50],
["Gentle rain falling on a tin roof", 15, 50],
["Wind chimes tinkling in a light breeze", 10, 50],
["Rhythmic wooden table tapping overlaid with steady water pouring sound", 10, 50],
]
with gr.Blocks(title="TangoFlux Text-to-Audio Generation") as demo:
gr.Markdown("# TangoFlux Text-to-Audio Generation")
gr.Markdown("Generate audio from text descriptions using TangoFlux")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
label="Text Prompt", placeholder="Enter your audio description..."
)
duration = gr.Slider(
minimum=1, maximum=30, value=10, step=1, label="Duration (seconds)"
)
steps = gr.Slider(
minimum=10, maximum=100, value=50, step=10, label="Number of Steps"
)
generate_btn = gr.Button("Generate Audio")
with gr.Column():
audio_output = gr.Audio(label="Generated Audio")
generate_btn.click(
fn=generate_audio, inputs=[prompt, duration, steps], outputs=audio_output
)
gr.Examples(
examples=examples,
inputs=[prompt, duration, steps],
outputs=audio_output,
fn=generate_audio,
)
@click.command()
@click.option('--host', default='127.0.0.1', help='Host to bind to')
@click.option('--port', default=None, help='Port to bind to')
@click.option('--share', is_flag=True, help='Enable sharing via Gradio')
def main(host, port, share):
demo.queue().launch(server_name=host, server_port=port, share=share)
if __name__ == "__main__":
main()
from transformers import T5EncoderModel,T5TokenizerFast
from transformers import T5EncoderModel, T5TokenizerFast
import torch
from diffusers import FluxTransformer2DModel
from diffusers import FluxTransformer2DModel
from torch import nn
import random
from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling
@@ -11,19 +11,16 @@ import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from typing import Optional,Union,List
from typing import Optional, Union, List
from datasets import load_dataset, Audio
from math import pi
import inspect
import yaml
class StableAudioPositionalEmbedding(nn.Module):
"""Used for continuous time
Adapted from stable audio open.
Adapted from Stable Audio Open.
"""
def __init__(self, dim: int):
@@ -38,7 +35,8 @@ class StableAudioPositionalEmbedding(nn.Module):
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((times, fouriered), dim=-1)
return fouriered
class DurationEmbedder(nn.Module):
"""
A simple linear projection model to map numbers to a latent space.
@@ -73,7 +71,7 @@ class DurationEmbedder(nn.Module):
self.number_embedding_dim = number_embedding_dim
self.min_value = min_value
self.max_value = max_value
self.dtype = torch.float32
self.dtype = torch.float32
def forward(
self,
@@ -81,7 +79,9 @@ ):
):
floats = floats.clamp(self.min_value, self.max_value)
normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)
normalized_floats = (floats - self.min_value) / (
self.max_value - self.min_value
)
# Cast floats to same type as embedder
embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
@@ -103,9 +103,13 @@ def retrieve_timesteps(
):
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
raise ValueError(
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
)
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
accepts_timesteps = "timesteps" in set(
inspect.signature(scheduler.set_timesteps).parameters.keys()
)
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -115,7 +119,9 @@ timesteps = scheduler.timesteps
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
accept_sigmas = "sigmas" in set(
inspect.signature(scheduler.set_timesteps).parameters.keys()
)
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -128,113 +134,115 @@ def retrieve_timesteps(
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class TangoFlux(nn.Module):
def __init__(self,config,initialize_reference_model=False):
def __init__(self, config, initialize_reference_model=False):
super().__init__()
self.num_layers = config.get('num_layers', 6)
self.num_single_layers = config.get('num_single_layers', 18)
self.in_channels = config.get('in_channels', 64)
self.attention_head_dim = config.get('attention_head_dim', 128)
self.joint_attention_dim = config.get('joint_attention_dim', 1024)
self.num_attention_heads = config.get('num_attention_heads', 8)
self.audio_seq_len = config.get('audio_seq_len', 645)
self.max_duration = config.get('max_duration', 30)
self.uncondition = config.get('uncondition', False)
self.text_encoder_name = config.get('text_encoder_name', "google/flan-t5-large")
self.num_layers = config.get("num_layers", 6)
self.num_single_layers = config.get("num_single_layers", 18)
self.in_channels = config.get("in_channels", 64)
self.attention_head_dim = config.get("attention_head_dim", 128)
self.joint_attention_dim = config.get("joint_attention_dim", 1024)
self.num_attention_heads = config.get("num_attention_heads", 8)
self.audio_seq_len = config.get("audio_seq_len", 645)
self.max_duration = config.get("max_duration", 30)
self.uncondition = config.get("uncondition", False)
self.text_encoder_name = config.get("text_encoder_name", "google/flan-t5-large")
self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
self.max_text_seq_len = 64
self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
self.tokenizer = T5TokenizerFast.from_pretrained(self.text_encoder_name)
self.text_embedding_dim = self.text_encoder.config.d_model
self.fc = nn.Sequential(nn.Linear(self.text_embedding_dim,self.joint_attention_dim),nn.ReLU())
self.duration_emebdder = DurationEmbedder(self.text_embedding_dim,min_value=0,max_value=self.max_duration)
self.fc = nn.Sequential(
nn.Linear(self.text_embedding_dim, self.joint_attention_dim), nn.ReLU()
)
self.duration_emebdder = DurationEmbedder(
self.text_embedding_dim, min_value=0, max_value=self.max_duration
)
self.transformer = FluxTransformer2DModel(
in_channels=self.in_channels,
num_layers=self.num_layers,
num_single_layers=self.num_single_layers,
attention_head_dim=self.attention_head_dim,
num_attention_heads=self.num_attention_heads,
joint_attention_dim=self.joint_attention_dim,
pooled_projection_dim=self.text_embedding_dim,
guidance_embeds=False)
self.beta_dpo = 2000 ## this is used for dpo training
def get_sigmas(self,timesteps, n_dim=3, dtype=torch.float32):
in_channels=self.in_channels,
num_layers=self.num_layers,
num_single_layers=self.num_single_layers,
attention_head_dim=self.attention_head_dim,
num_attention_heads=self.num_attention_heads,
joint_attention_dim=self.joint_attention_dim,
pooled_projection_dim=self.text_embedding_dim,
guidance_embeds=False,
)
self.beta_dpo = 2000 ## this is used for dpo training
def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
device = self.text_encoder.device
sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
device = self.text_encoder.device
batch = self.tokenizer(
prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
prompt,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
)
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
device
)
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
with torch.no_grad():
prompt_embeds = self.text_encoder(
input_ids=input_ids, attention_mask=attention_mask
)[0]
prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
# get unconditional embeddings for classifier free guidance
uncond_tokens = [""]
uncond_tokens = [""]
max_length = prompt_embeds.shape[1]
uncond_batch = self.tokenizer(
uncond_tokens, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt",
uncond_tokens,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_batch.input_ids.to(device)
uncond_attention_mask = uncond_batch.attention_mask.to(device)
with torch.no_grad():
negative_prompt_embeds = self.text_encoder(
input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
)[0]
negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(
num_samples_per_prompt, 0
)
uncond_attention_mask = uncond_attention_mask.repeat_interleave(
num_samples_per_prompt, 0
)
# For classifier free guidance, we need to do two forward passes.
# We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
boolean_prompt_mask = (prompt_mask == 1).to(device)
@@ -245,266 +253,287 @@ class TangoFlux(nn.Module):
def encode_text(self, prompt):
device = self.text_encoder.device
batch = self.tokenizer(
prompt, max_length=self.max_text_seq_len, padding=True, truncation=True, return_tensors="pt")
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
prompt,
max_length=self.max_text_seq_len,
padding=True,
truncation=True,
return_tensors="pt",
)
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
device
)
encoder_hidden_states = self.text_encoder(
input_ids=input_ids, attention_mask=attention_mask)[0]
input_ids=input_ids, attention_mask=attention_mask
)[0]
boolean_encoder_mask = (attention_mask == 1).to(device)
return encoder_hidden_states, boolean_encoder_mask
def encode_duration(self,duration):
return self.duration_emebdder(duration)
def encode_duration(self, duration):
return self.duration_emebdder(duration)
@torch.no_grad()
def inference_flow(self, prompt,
num_inference_steps=50,
timesteps=None,
guidance_scale=3,
duration=10,
disable_progress=False,
num_samples_per_prompt=1):
'''Only tested for single inference. Batch inference has not been tested.'''
def inference_flow(
self,
prompt,
num_inference_steps=50,
timesteps=None,
guidance_scale=3,
duration=10,
disable_progress=False,
num_samples_per_prompt=1,
):
"""Only tested for single inference. Haven't test for batch inference"""
bsz = num_samples_per_prompt
device = self.transformer.device
scheduler = self.noise_scheduler
if not isinstance(prompt,list):
if not isinstance(prompt, list):
prompt = [prompt]
if not isinstance(duration,torch.Tensor):
duration = torch.tensor([duration],device=device)
if not isinstance(duration, torch.Tensor):
duration = torch.tensor([duration], device=device)
classifier_free_guidance = guidance_scale > 1.0
duration_hidden_states = self.encode_duration(duration)
if classifier_free_guidance:
bsz = 2 * num_samples_per_prompt
encoder_hidden_states, boolean_encoder_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt=num_samples_per_prompt)
duration_hidden_states = duration_hidden_states.repeat(bsz,1,1)
encoder_hidden_states, boolean_encoder_mask = (
self.encode_text_classifier_free(
prompt, num_samples_per_prompt=num_samples_per_prompt
)
)
duration_hidden_states = duration_hidden_states.repeat(bsz, 1, 1)
else:
encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt,num_samples_per_prompt=num_samples_per_prompt)
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states)
masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
encoder_hidden_states, boolean_encoder_mask = self.encode_text(
prompt, num_samples_per_prompt=num_samples_per_prompt
)
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(
encoder_hidden_states
)
masked_data = torch.where(
mask_expanded, encoder_hidden_states, torch.tensor(float("nan"))
)
pooled = torch.nanmean(masked_data, dim=1)
pooled_projection = self.fc(pooled)
encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim)
encoder_hidden_states = torch.cat(
[encoder_hidden_states, duration_hidden_states], dim=1
) ## (bs,seq_len,dim)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
device,
timesteps,
sigmas
scheduler, num_inference_steps, device, timesteps, sigmas
)
latents = torch.randn(num_samples_per_prompt,self.audio_seq_len,64)
latents = torch.randn(num_samples_per_prompt, self.audio_seq_len, 64)
weight_dtype = latents.dtype
progress_bar = tqdm(range(num_inference_steps), disable=disable_progress)
txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
audio_ids = torch.arange(self.audio_seq_len).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)
txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
audio_ids = (
torch.arange(self.audio_seq_len)
.unsqueeze(0)
.unsqueeze(-1)
.repeat(bsz, 1, 3)
.to(device)
)
timesteps = timesteps.to(device)
latents = latents.to(device)
encoder_hidden_states = encoder_hidden_states.to(device)
for i, t in enumerate(timesteps):
latents_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
latents_input = (
torch.cat([latents] * 2) if classifier_free_guidance else latents
)
noise_pred = self.transformer(
hidden_states=latents_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
timestep=torch.tensor([t/1000],device=device),
guidance = None,
pooled_projections=pooled_projection,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=audio_ids,
return_dict=False,
)[0]
hidden_states=latents_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
timestep=torch.tensor([t / 1000], device=device),
guidance=None,
pooled_projections=pooled_projection,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=audio_ids,
return_dict=False,
)[0]
if classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents).prev_sample
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
latents = scheduler.step(noise_pred, t, latents).prev_sample
return latents
def forward(self,
latents,
prompt,
duration=torch.tensor([10]),
sft=True
):
def forward(self, latents, prompt, duration=torch.tensor([10]), sft=True):
device = latents.device
audio_seq_length = self.audio_seq_len
bsz = latents.shape[0]
encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
duration_hidden_states = self.encode_duration(duration)
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states)
masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(
encoder_hidden_states
)
masked_data = torch.where(
mask_expanded, encoder_hidden_states, torch.tensor(float("nan"))
)
pooled = torch.nanmean(masked_data, dim=1)
pooled_projection = self.fc(pooled)
## Add duration hidden states to encoder hidden states
encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim)
encoder_hidden_states = torch.cat(
[encoder_hidden_states, duration_hidden_states], dim=1
) ## (bs,seq_len,dim)
txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
audio_ids = (
torch.arange(audio_seq_length)
.unsqueeze(0)
.unsqueeze(-1)
.repeat(bsz, 1, 3)
.to(device)
)
txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
audio_ids = torch.arange(audio_seq_length).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)
if sft:
if self.uncondition:
mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
if len(mask_indices) > 0:
encoder_hidden_states[mask_indices] = 0
noise = torch.randn_like(latents)
u = compute_density_for_timestep_sampling(
weighting_scheme='logit_normal',
batch_size=bsz,
logit_mean=0,
logit_std=1,
mode_scale=None,
)
weighting_scheme="logit_normal",
batch_size=bsz,
logit_mean=0,
logit_std=1,
mode_scale=None,
)
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
timesteps = self.noise_scheduler_copy.timesteps[indices].to(
device=latents.device
)
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
model_pred = self.transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
model_pred = self.transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
timestep=timesteps/1000,
return_dict=False)[0]
timestep=timesteps / 1000,
return_dict=False,
)[0]
target = noise - latents
loss = torch.mean(
( (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1,
)
((model_pred.float() - target.float()) ** 2).reshape(
target.shape[0], -1
),
1,
)
loss = loss.mean()
raw_model_loss, raw_ref_loss,implicit_acc = 0,0,0 ## default this to 0 if doing sft
raw_model_loss, raw_ref_loss, implicit_acc = (
0,
0,
0,
) ## default this to 0 if doing sft
else:
encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
pooled_projection = pooled_projection.repeat(2,1)
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1) ## Have to sample same noise for preferred and rejected
pooled_projection = pooled_projection.repeat(2, 1)
noise = (
torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1)
) ## Have to sample same noise for preferred and rejected
u = compute_density_for_timestep_sampling(
weighting_scheme='logit_normal',
batch_size=bsz//2,
logit_mean=0,
logit_std=1,
mode_scale=None,
)
weighting_scheme="logit_normal",
batch_size=bsz // 2,
logit_mean=0,
logit_std=1,
mode_scale=None,
)
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
timesteps = self.noise_scheduler_copy.timesteps[indices].to(
device=latents.device
)
timesteps = timesteps.repeat(2)
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
model_pred = self.transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
model_pred = self.transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
timestep=timesteps/1000,
return_dict=False)[0]
timestep=timesteps / 1000,
return_dict=False,
)[0]
target = noise - latents
model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
model_losses = F.mse_loss(
model_pred.float(), target.float(), reduction="none"
)
model_losses = model_losses.mean(
dim=list(range(1, len(model_losses.shape)))
)
model_losses_w, model_losses_l = model_losses.chunk(2)
model_diff = model_losses_w - model_losses_l
model_diff = model_losses_w - model_losses_l
raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
with torch.no_grad():
ref_preds = self.ref_transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
timestep=timesteps/1000,
return_dict=False)[0]
ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
timestep=timesteps / 1000,
return_dict=False,
)[0]
ref_loss = F.mse_loss(
ref_preds.float(), target.float(), reduction="none"
)
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean()
scale_term = -0.5 * self.beta_dpo
inside_term = scale_term * (model_diff - ref_diff)
implicit_acc = (scale_term * (model_diff - ref_diff) > 0).sum().float() / inside_term.size(0)
loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()
## raw_model_loss, raw_ref_loss, implicit_acc are used to help analyze DPO behaviour.
return loss, raw_model_loss, raw_ref_loss, implicit_acc
inside_term = scale_term * (model_diff - ref_diff)
implicit_acc = (
scale_term * (model_diff - ref_diff) > 0
).sum().float() / inside_term.size(0)
loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()
## raw_model_loss, raw_ref_loss, implicit_acc are used to help analyze DPO behaviour.
return loss, raw_model_loss, raw_ref_loss, implicit_acc
@@ -22,136 +22,170 @@ from tqdm.auto import tqdm
from transformers import SchedulerType, get_scheduler
from model import TangoFlux
from datasets import load_dataset, Audio
from utils import Text2AudioDataset,read_wav_file,pad_wav
from utils import Text2AudioDataset, read_wav_file, pad_wav
from diffusers import AutoencoderOobleck
import torchaudio
logger = get_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Rectified flow for text to audio generation task.")
parser = argparse.ArgumentParser(
description="Rectified flow for text to audio generation task."
)
parser.add_argument(
"--num_examples", type=int, default=-1,
"--num_examples",
type=int,
default=-1,
help="How many examples to use for training and validation.",
)
parser.add_argument(
"--text_column", type=str, default="captions",
"--text_column",
type=str,
default="captions",
help="The name of the column in the datasets containing the input texts.",
)
parser.add_argument(
"--audio_column", type=str, default="location",
"--audio_column",
type=str,
default="location",
help="The name of the column in the datasets containing the audio paths.",
)
parser.add_argument(
"--adam_beta1", type=float, default=0.9,
help="The beta1 parameter for the Adam optimizer."
"--adam_beta1",
type=float,
default=0.9,
help="The beta1 parameter for the Adam optimizer.",
)
parser.add_argument(
"--adam_beta2", type=float, default=0.95,
help="The beta2 parameter for the Adam optimizer."
"--adam_beta2",
type=float,
default=0.95,
help="The beta2 parameter for the Adam optimizer.",
)
parser.add_argument(
"--config", type=str, default='tangoflux_config.yaml',
"--config",
type=str,
default="tangoflux_config.yaml",
help="Config file defining the model size as well as other hyper parameter.",
)
parser.add_argument(
"--prefix", type=str, default='',
"--prefix",
type=str,
default="",
help="Add prefix in text prompts.",
)
parser.add_argument(
"--learning_rate", type=float, default=3e-5,
"--learning_rate",
type=float,
default=3e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--weight_decay", type=float, default=1e-8,
help="Weight decay to use."
"--weight_decay", type=float, default=1e-8, help="Weight decay to use."
)
parser.add_argument(
"--max_train_steps", type=int, default=None,
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--lr_scheduler_type", type=SchedulerType, default="linear",
"--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
choices=[
"linear",
"cosine",
"cosine_with_restarts",
"polynomial",
"constant",
"constant_with_warmup",
],
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler."
"--num_warmup_steps",
type=int,
default=0,
help="Number of steps for the warmup in the lr scheduler.",
)
parser.add_argument(
"--adam_epsilon", type=float, default=1e-08,
help="Epsilon value for the Adam optimizer"
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer",
)
parser.add_argument(
"--adam_weight_decay", type=float, default=1e-2,
help="Epsilon value for the Adam optimizer"
"--adam_weight_decay",
type=float,
default=1e-2,
help="Epsilon value for the Adam optimizer",
)
parser.add_argument(
"--seed", type=int, default=None,
help="A seed for reproducible training."
"--seed", type=int, default=None, help="A seed for reproducible training."
)
parser.add_argument(
"--checkpointing_steps", type=str, default="best",
"--checkpointing_steps",
type=str,
default="best",
help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
)
parser.add_argument(
"--save_every", type=int, default=5,
help="Save model after every how many epochs when checkpointing_steps is set to best."
"--save_every",
type=int,
default=5,
help="Save model after every how many epochs when checkpointing_steps is set to best.",
)
parser.add_argument(
"--resume_from_checkpoint", type=str, default=None,
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a local checkpoint folder.",
)
parser.add_argument(
"--load_from_checkpoint", type=str, default=None,
"--load_from_checkpoint",
type=str,
default=None,
help="Whether to continue training from a model weight",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
accelerator_log_kwargs = {}
def load_config(config_path):
with open(config_path, 'r') as file:
with open(config_path, "r") as file:
return yaml.safe_load(file)
config = load_config(args.config)
learning_rate = float(config['training']['learning_rate'])
num_train_epochs = int(config['training']['num_train_epochs'])
num_warmup_steps = int(config['training']['num_warmup_steps'])
per_device_batch_size = int(config['training']['per_device_batch_size'])
gradient_accumulation_steps = int(config['training']['gradient_accumulation_steps'])
learning_rate = float(config["training"]["learning_rate"])
num_train_epochs = int(config["training"]["num_train_epochs"])
num_warmup_steps = int(config["training"]["num_warmup_steps"])
per_device_batch_size = int(config["training"]["per_device_batch_size"])
gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])
output_dir = config['paths']['output_dir']
output_dir = config["paths"]["output_dir"]
accelerator = Accelerator(
gradient_accumulation_steps=gradient_accumulation_steps,
**accelerator_log_kwargs,
)
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, **accelerator_log_kwargs)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -172,12 +206,12 @@ def main():
if accelerator.is_main_process:
if output_dir is None or output_dir == "":
output_dir = "saved/" + str(int(time.time()))
if not os.path.exists("saved"):
os.makedirs("saved")
os.makedirs(output_dir, exist_ok=True)
elif output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
@@ -187,74 +221,120 @@ def main():
accelerator.project_configuration.automatic_checkpoint_naming = False
wandb.init(project="Text to Audio Flow matching",settings=wandb.Settings(_disable_stats=True))
wandb.init(
project="Text to Audio Flow matching",
settings=wandb.Settings(_disable_stats=True),
)
accelerator.wait_for_everyone()
# Get the datasets
data_files = {}
#if args.train_file is not None:
if config['paths']['train_file'] != '':
data_files["train"] = config['paths']['train_file']
# if args.validation_file is not None:
if config['paths']['val_file'] != '':
data_files["validation"] = config['paths']['val_file']
if config['paths']['test_file'] != '':
data_files["test"] = config['paths']['test_file']
else:
data_files["test"] = config['paths']['val_file']
extension = 'json'
# if args.train_file is not None:
if config["paths"]["train_file"] != "":
data_files["train"] = config["paths"]["train_file"]
# if args.validation_file is not None:
if config["paths"]["val_file"] != "":
data_files["validation"] = config["paths"]["val_file"]
if config["paths"]["test_file"] != "":
data_files["test"] = config["paths"]["test_file"]
else:
data_files["test"] = config["paths"]["val_file"]
extension = "json"
raw_datasets = load_dataset(extension, data_files=data_files)
text_column, audio_column = args.text_column, args.audio_column
model = TangoFlux(config=config['model'])
vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0",subfolder='vae')
model = TangoFlux(config=config["model"])
vae = AutoencoderOobleck.from_pretrained(
"stabilityai/stable-audio-open-1.0", subfolder="vae"
)
## Freeze vae
for param in vae.parameters():
param.requires_grad = False
vae.eval()
## Freeze text encoder param
for param in model.text_encoder.parameters():
param.requires_grad = False
model.text_encoder.eval()
prefix = args.prefix
prefix = args.prefix
with accelerator.main_process_first():
train_dataset = Text2AudioDataset(raw_datasets["train"], prefix, text_column, audio_column,'duration', args.num_examples)
eval_dataset = Text2AudioDataset(raw_datasets["validation"], prefix, text_column, audio_column,'duration', args.num_examples)
test_dataset = Text2AudioDataset(raw_datasets["test"], prefix, text_column, audio_column,'duration', args.num_examples)
accelerator.print("Num instances in train: {}, validation: {}, test: {}".format(train_dataset.get_num_instances(), eval_dataset.get_num_instances(), test_dataset.get_num_instances()))
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=config['training']['per_device_batch_size'], collate_fn=train_dataset.collate_fn)
eval_dataloader = DataLoader(eval_dataset, shuffle=True, batch_size=config['training']['per_device_batch_size'], collate_fn=eval_dataset.collate_fn)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=config['training']['per_device_batch_size'], collate_fn=test_dataset.collate_fn)
train_dataset = Text2AudioDataset(
raw_datasets["train"],
prefix,
text_column,
audio_column,
"duration",
args.num_examples,
)
eval_dataset = Text2AudioDataset(
raw_datasets["validation"],
prefix,
text_column,
audio_column,
"duration",
args.num_examples,
)
test_dataset = Text2AudioDataset(
raw_datasets["test"],
prefix,
text_column,
audio_column,
"duration",
args.num_examples,
)
accelerator.print(
"Num instances in train: {}, validation: {}, test: {}".format(
train_dataset.get_num_instances(),
eval_dataset.get_num_instances(),
test_dataset.get_num_instances(),
)
)
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
batch_size=config["training"]["per_device_batch_size"],
collate_fn=train_dataset.collate_fn,
)
eval_dataloader = DataLoader(
eval_dataset,
shuffle=True,
batch_size=config["training"]["per_device_batch_size"],
collate_fn=eval_dataset.collate_fn,
)
test_dataloader = DataLoader(
test_dataset,
shuffle=False,
batch_size=config["training"]["per_device_batch_size"],
collate_fn=test_dataset.collate_fn,
)
# Optimizer
optimizer_parameters = list(model.transformer.parameters())+list(model.fc.parameters())
num_trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))
optimizer_parameters = list(model.transformer.parameters()) + list(
model.fc.parameters()
)
num_trainable_parameters = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))
if args.load_from_checkpoint:
from safetensors.torch import load_file
w1 = load_file(args.load_from_checkpoint)
model.load_state_dict(w1,strict=False)
model.load_state_dict(w1, strict=False)
logger.info("Weights loaded from{}".format(args.load_from_checkpoint))
optimizer = torch.optim.AdamW(
optimizer_parameters, lr=learning_rate,
optimizer_parameters,
lr=learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
@@ -262,31 +342,35 @@ def main():
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
if args.max_train_steps is None:
args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps * gradient_accumulation_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * gradient_accumulation_steps
num_warmup_steps=num_warmup_steps
* gradient_accumulation_steps
* accelerator.num_processes,
num_training_steps=args.max_train_steps * gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
vae, model, optimizer, lr_scheduler = accelerator.prepare(
vae, model, optimizer, lr_scheduler
vae, model, optimizer, lr_scheduler
)
train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
train_dataloader, eval_dataloader, test_dataloader
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
if overrode_max_train_steps:
args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
@@ -299,42 +383,44 @@ def main():
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
# Train!
total_batch_size = per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
total_batch_size = (
per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar = tqdm(
range(args.max_train_steps), disable=not accelerator.is_local_main_process
)
completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save
resume_from_checkpoint = config['paths']['resume_from_checkpoint']
if resume_from_checkpoint!= '':
resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
if resume_from_checkpoint != "":
accelerator.load_state(resume_from_checkpoint)
accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")
# Duration of the audio clips in seconds
best_loss = np.inf
length = config['training']['max_audio_duration']
best_loss = np.inf
length = config["training"]["max_audio_duration"]
for epoch in range(starting_epoch, num_train_epochs):
model.train()
total_loss, total_val_loss = 0, 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
device = model.device
@@ -342,43 +428,44 @@ def main():
with torch.no_grad():
audio_list = []
for audio_path in audios:
wav = read_wav_file(audio_path,length) ## Only read the first 30 seconds of audio
if wav.shape[0] == 1 : ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
wav = wav.repeat(2,1)
audio_list.append(wav)
wav = read_wav_file(
audio_path, length
) ## Only read the first 30 seconds of audio
if (
wav.shape[0] == 1
): ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
wav = wav.repeat(2, 1)
audio_list.append(wav)
audio_input = torch.stack(audio_list,dim=0)
audio_input = torch.stack(audio_list, dim=0)
audio_input = audio_input.to(device)
unwrapped_vae = accelerator.unwrap_model(vae)
duration = torch.tensor(duration,device=device)
duration = torch.clamp(duration, max=length) ## clamp duration to max audio length
audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
audio_latent = audio_latent.transpose(1,2) ## Transpose to (bsz, seq_len, channel)
loss, _, _,_ = model(audio_latent, text ,duration=duration)
duration = torch.tensor(duration, device=device)
duration = torch.clamp(
duration, max=length
) ## clamp duration to max audio length
audio_latent = unwrapped_vae.encode(
audio_input
).latent_dist.sample()
audio_latent = audio_latent.transpose(
1, 2
) ## Transpose to (bsz, seq_len, channel)
loss, _, _, _ = model(audio_latent, text, duration=duration)
total_loss += loss.detach().float()
accelerator.backward(loss)
if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
optimizer.step()
lr_scheduler.step()
if completed_steps % 10 == 0 and accelerator.is_main_process:
@@ -388,20 +475,21 @@ def main():
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
logger.info(f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}")
total_norm = total_norm**0.5
logger.info(
f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
)
lr = lr_scheduler.get_last_lr()[0]
result = {
"train_loss": loss.item(),
"grad_norm": total_norm,
"learning_rate": lr
"learning_rate": lr,
}
# result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
wandb.log(result, step=completed_steps)
# Checks if the accelerator has performed an optimization step behind the scenes
if isinstance(checkpointing_steps, int):
@@ -415,55 +503,59 @@ def main():
break
model.eval()
eval_progress_bar = tqdm(range(len(eval_dataloader)), disable=not accelerator.is_local_main_process)
eval_progress_bar = tqdm(
range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
)
for step, batch in enumerate(eval_dataloader):
with accelerator.accumulate(model), torch.no_grad():
device = model.device
text, audios, duration, _ = batch
audio_list = []
for audio_path in audios:
wav = read_wav_file(audio_path,length) ## make sure no audio exceeds 30 sec
if wav.shape[0] == 1 : ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
wav = wav.repeat(2,1)
audio_list.append(wav)
wav = read_wav_file(
audio_path, length
) ## make sure no audio exceeds 30 sec
if (
wav.shape[0] == 1
): ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
wav = wav.repeat(2, 1)
audio_list.append(wav)
audio_input = torch.stack(audio_list,dim=0)
audio_input = torch.stack(audio_list, dim=0)
audio_input = audio_input.to(device)
duration = torch.tensor(duration,device=device)
duration = torch.tensor(duration, device=device)
unwrapped_vae = accelerator.unwrap_model(vae)
audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
audio_latent = audio_latent.transpose(1,2) ## Transpose to (bsz, seq_len, channel)
val_loss,_, _,_ = model(audio_latent, text , duration=duration)
audio_latent = audio_latent.transpose(
1, 2
) ## Transpose to (bsz, seq_len, channel)
val_loss, _, _, _ = model(audio_latent, text, duration=duration)
total_val_loss += val_loss.detach().float()
eval_progress_bar.update(1)
if accelerator.is_main_process:
result = {}
result["epoch"] = float(epoch+1)
result["epoch/train_loss"] = round(total_loss.item()/len(train_dataloader), 4)
result["epoch/val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
result["epoch"] = float(epoch + 1)
result["epoch/train_loss"] = round(
total_loss.item() / len(train_dataloader), 4
)
result["epoch/val_loss"] = round(
total_val_loss.item() / len(eval_dataloader), 4
)
wandb.log(result, step=completed_steps)
wandb.log(result,step=completed_steps)
result_string = "Epoch: {}, Loss Train: {}, Val: {}\n".format(
epoch, result["epoch/train_loss"], result["epoch/val_loss"]
)
result_string = "Epoch: {}, Loss Train: {}, Val: {}\n".format(epoch, result["epoch/train_loss"],result["epoch/val_loss"])
accelerator.print(result_string)
with open("{}/summary.jsonl".format(output_dir), "a") as f:
f.write(json.dumps(result) + "\n\n")
@@ -480,13 +572,17 @@ def main():
if accelerator.is_main_process and args.checkpointing_steps == "best":
if save_checkpoint:
accelerator.save_state("{}/{}".format(output_dir, "best"))
if (epoch + 1) % args.save_every == 0:
accelerator.save_state("{}/{}".format(output_dir, "epoch_" + str(epoch+1)))
accelerator.save_state(
"{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
)
if accelerator.is_main_process and args.checkpointing_steps == "epoch":
accelerator.save_state("{}/{}".format(output_dir, "epoch_" + str(epoch+1)))
accelerator.save_state(
"{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
)
if __name__ == "__main__":
main()
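
A minimal sketch of how the totals logged by this training script relate to the config values. All numbers below are illustrative and are not taken from tangoflux_config.yaml:

```python
import math

# Illustrative values only; the real ones come from configs/tangoflux_config.yaml
# and the accelerate launch setup.
per_device_batch_size = 2
num_processes = 2                  # e.g. two GPUs
gradient_accumulation_steps = 4
num_train_epochs = 80
batches_per_epoch = 1000           # len(train_dataloader) per process

# Effective batch size consumed by one optimizer update (as logged above).
total_batch_size = per_device_batch_size * num_processes * gradient_accumulation_steps

# When --max_train_steps is not passed, it is derived from the epoch count.
num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch

print(total_batch_size, max_train_steps)   # 16 20000
```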
@@ -5,6 +5,7 @@ import logging
import math
import os
import yaml
# from tqdm import tqdm
import copy
from pathlib import Path
@@ -22,9 +23,9 @@ from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import SchedulerType, get_scheduler
from src.model import TangoFlux
from tangoflux.model import TangoFlux
from datasets import load_dataset, Audio
from src.utils import Text2AudioDataset,read_wav_file,DPOText2AudioDataset
from tangoflux.utils import Text2AudioDataset, read_wav_file, DPOText2AudioDataset
from diffusers import AutoencoderOobleck
import torchaudio
@@ -32,84 +33,119 @@ import torchaudio
logger = get_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Rectified flow for text to audio generation task.")
parser = argparse.ArgumentParser(
description="Rectified flow for text to audio generation task."
)
parser.add_argument(
"--num_examples", type=int, default=-1,
"--num_examples",
type=int,
default=-1,
help="How many examples to use for training and validation.",
)
parser.add_argument(
"--text_column", type=str, default="captions",
"--text_column",
type=str,
default="captions",
help="The name of the column in the datasets containing the input texts.",
)
parser.add_argument(
"--audio_column", type=str, default="location",
"--audio_column",
type=str,
default="location",
help="The name of the column in the datasets containing the audio paths.",
)
parser.add_argument(
"--adam_beta1", type=float, default=0.9,
help="The beta1 parameter for the Adam optimizer."
"--adam_beta1",
type=float,
default=0.9,
help="The beta1 parameter for the Adam optimizer.",
)
parser.add_argument(
"--adam_beta2", type=float, default=0.95,
help="The beta2 parameter for the Adam optimizer."
"--adam_beta2",
type=float,
default=0.95,
help="The beta2 parameter for the Adam optimizer.",
)
parser.add_argument(
"--config", type=str, default='tangoflux_config.yaml',
"--config",
type=str,
default="tangoflux_config.yaml",
help="Config file defining the model size.",
)
parser.add_argument(
"--weight_decay", type=float, default=1e-8,
help="Weight decay to use."
"--weight_decay", type=float, default=1e-8, help="Weight decay to use."
)
parser.add_argument(
"--max_train_steps", type=int, default=None,
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--lr_scheduler_type", type=SchedulerType, default="linear",
"--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
choices=[
"linear",
"cosine",
"cosine_with_restarts",
"polynomial",
"constant",
"constant_with_warmup",
],
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler."
"--num_warmup_steps",
type=int,
default=0,
help="Number of steps for the warmup in the lr scheduler.",
)
parser.add_argument(
"--adam_epsilon", type=float, default=1e-08,
help="Epsilon value for the Adam optimizer"
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer",
)
parser.add_argument(
"--adam_weight_decay", type=float, default=1e-2,
help="Epsilon value for the Adam optimizer"
"--adam_weight_decay",
type=float,
default=1e-2,
help="Epsilon value for the Adam optimizer",
)
parser.add_argument(
"--seed", type=int, default=None,
help="A seed for reproducible training."
"--seed", type=int, default=None, help="A seed for reproducible training."
)
parser.add_argument(
"--checkpointing_steps", type=str, default="best",
"--checkpointing_steps",
type=str,
default="best",
help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
)
parser.add_argument(
"--save_every", type=int, default=5,
help="Save model after every how many epochs when checkpointing_steps is set to best."
"--save_every",
type=int,
default=5,
help="Save model after every how many epochs when checkpointing_steps is set to best.",
)
parser.add_argument(
"--resume_from_checkpoint", type=str, default=None,
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a local checkpoint folder.",
)
parser.add_argument(
"--report_to", type=str, default="all",
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
@@ -117,60 +153,58 @@ def parse_args():
),
)
parser.add_argument(
"--load_from_checkpoint", type=str, default=None,
"--load_from_checkpoint",
type=str,
default=None,
help="Whether to continue training from a model weight",
)
parser.add_argument(
"--audio_length", type=float, default=30,
"--audio_length",
type=float,
default=30,
help="Audio duration",
)
args = parser.parse_args()
# Sanity checks
#if args.train_file is None and args.validation_file is None:
# raise ValueError("Need a training/validation file.")
#else:
# if args.train_file is not None:
# extension = args.train_file.split(".")[-1]
# assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
#if args.validation_file is not None:
# extension = args.validation_file.split(".")[-1]
# assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
# if args.train_file is None and args.validation_file is None:
# raise ValueError("Need a training/validation file.")
# else:
# if args.train_file is not None:
# extension = args.train_file.split(".")[-1]
# assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
# if args.validation_file is not None:
# extension = args.validation_file.split(".")[-1]
# assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
return args
def main():
args = parse_args()
accelerator_log_kwargs = {}
def load_config(config_path):
with open(config_path, 'r') as file:
with open(config_path, "r") as file:
return yaml.safe_load(file)
config = load_config(args.config)
learning_rate = float(config['training']['learning_rate'])
num_train_epochs = int(config['training']['num_train_epochs'])
num_warmup_steps = int(config['training']['num_warmup_steps'])
per_device_batch_size = int(config['training']['per_device_batch_size'])
gradient_accumulation_steps = int(config['training']['gradient_accumulation_steps'])
learning_rate = float(config["training"]["learning_rate"])
num_train_epochs = int(config["training"]["num_train_epochs"])
num_warmup_steps = int(config["training"]["num_warmup_steps"])
per_device_batch_size = int(config["training"]["per_device_batch_size"])
gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])
output_dir = config['paths']['output_dir']
output_dir = config["paths"]["output_dir"]
accelerator = Accelerator(
gradient_accumulation_steps=gradient_accumulation_steps,
**accelerator_log_kwargs,
)
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, **accelerator_log_kwargs)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -191,12 +225,12 @@ def main():
if accelerator.is_main_process:
if output_dir is None or output_dir == "":
output_dir = "saved/" + str(int(time.time()))
if not os.path.exists("saved"):
os.makedirs("saved")
os.makedirs(output_dir, exist_ok=True)
elif output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
@@ -206,73 +240,122 @@ def main():
accelerator.project_configuration.automatic_checkpoint_naming = False
wandb.init(project="Text to Audio Flow matching DPO",settings=wandb.Settings(_disable_stats=True))
wandb.init(
project="Text to Audio Flow matching DPO",
settings=wandb.Settings(_disable_stats=True),
)
accelerator.wait_for_everyone()
# Get the datasets
data_files = {}
#if args.train_file is not None:
if config['paths']['train_file'] != '':
data_files["train"] = config['paths']['train_file']
# if args.validation_file is not None:
if config['paths']['val_file'] != '':
data_files["validation"] = config['paths']['val_file']
if config['paths']['test_file'] != '':
data_files["test"] = config['paths']['test_file']
else:
data_files["test"] = config['paths']['val_file']
extension = 'json'
train_dataset = load_dataset(extension,data_files=data_files['train'])
data_files.pop('train')
# if args.train_file is not None:
if config["paths"]["train_file"] != "":
data_files["train"] = config["paths"]["train_file"]
# if args.validation_file is not None:
if config["paths"]["val_file"] != "":
data_files["validation"] = config["paths"]["val_file"]
if config["paths"]["test_file"] != "":
data_files["test"] = config["paths"]["test_file"]
else:
data_files["test"] = config["paths"]["val_file"]
extension = "json"
train_dataset = load_dataset(extension, data_files=data_files["train"])
data_files.pop("train")
raw_datasets = load_dataset(extension, data_files=data_files)
text_column, audio_column = args.text_column, args.audio_column
model = TangoFlux(config=config['model'],initialize_reference_model=True)
vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0",subfolder='vae')
model = TangoFlux(config=config["model"], initialize_reference_model=True)
vae = AutoencoderOobleck.from_pretrained(
"stabilityai/stable-audio-open-1.0", subfolder="vae"
)
## Freeze vae
for param in vae.parameters():
param.requires_grad = False
vae.eval()
## Freeze text encoder param
for param in model.text_encoder.parameters():
param.requires_grad = False
model.text_encoder.eval()
prefix = ""
prefix = ""
with accelerator.main_process_first():
train_dataset = DPOText2AudioDataset(train_dataset["train"], prefix, text_column, 'chosen','reject','duration', args.num_examples)
eval_dataset = Text2AudioDataset(raw_datasets["validation"], prefix, text_column, audio_column,'duration', args.num_examples)
test_dataset = Text2AudioDataset(raw_datasets["test"], prefix, text_column, audio_column,'duration', args.num_examples)
accelerator.print("Num instances in train: {}, validation: {}, test: {}".format(train_dataset.get_num_instances(), eval_dataset.get_num_instances(), test_dataset.get_num_instances()))
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=config['training']['per_device_batch_size'], collate_fn=train_dataset.collate_fn)
eval_dataloader = DataLoader(eval_dataset, shuffle=True, batch_size=config['training']['per_device_batch_size'], collate_fn=eval_dataset.collate_fn)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=config['training']['per_device_batch_size'], collate_fn=test_dataset.collate_fn)
train_dataset = DPOText2AudioDataset(
train_dataset["train"],
prefix,
text_column,
"chosen",
"reject",
"duration",
args.num_examples,
)
eval_dataset = Text2AudioDataset(
raw_datasets["validation"],
prefix,
text_column,
audio_column,
"duration",
args.num_examples,
)
test_dataset = Text2AudioDataset(
raw_datasets["test"],
prefix,
text_column,
audio_column,
"duration",
args.num_examples,
)
accelerator.print(
"Num instances in train: {}, validation: {}, test: {}".format(
train_dataset.get_num_instances(),
eval_dataset.get_num_instances(),
test_dataset.get_num_instances(),
)
)
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
batch_size=config["training"]["per_device_batch_size"],
collate_fn=train_dataset.collate_fn,
)
eval_dataloader = DataLoader(
eval_dataset,
shuffle=True,
batch_size=config["training"]["per_device_batch_size"],
collate_fn=eval_dataset.collate_fn,
)
test_dataloader = DataLoader(
test_dataset,
shuffle=False,
batch_size=config["training"]["per_device_batch_size"],
collate_fn=test_dataset.collate_fn,
)
# Optimizer
optimizer_parameters = list(model.transformer.parameters())+list(model.fc.parameters())
num_trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))
optimizer_parameters = list(model.transformer.parameters()) + list(
model.fc.parameters()
)
num_trainable_parameters = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))
if args.load_from_checkpoint:
from safetensors.torch import load_file
w1 = load_file(args.load_from_checkpoint)
model.load_state_dict(w1,strict=False)
model.load_state_dict(w1, strict=False)
logger.info("Weights loaded from{}".format(args.load_from_checkpoint))
import copy
model.ref_transformer = copy.deepcopy(model.transformer)
model.ref_transformer.requires_grad_(False)
model.ref_transformer.eval()
@@ -280,48 +363,49 @@ def main():
param.requires_grad = False
@torch.no_grad()
def initialize_or_update_ref_transformer(model, accelerator: Accelerator,alpha=0.5):
def initialize_or_update_ref_transformer(
model, accelerator: Accelerator, alpha=0.5
):
"""
Initializes or updates ref_transformer as alpha * ref + (1 - alpha) * transformer.
Args:
model (torch.nn.Module): The main model containing the 'transformer' attribute.
accelerator (Accelerator): The Accelerator instance used to unwrap the model.
initial_ref_model (torch.nn.Module, optional): An optional initial reference model.
initial_ref_model (torch.nn.Module, optional): An optional initial reference model.
If not provided, ref_transformer is initialized as a copy of transformer.
Returns:
torch.nn.Module: The model with the updated ref_transformer.
"""
# Unwrap the model to access the original underlying model
unwrapped_model = accelerator.unwrap_model(model)
with torch.no_grad():
for ref_param, model_param in zip(unwrapped_model.ref_transformer.parameters(),
unwrapped_model.transformer.parameters()):
average_param = alpha * ref_param.data + (1-alpha) * model_param.data
for ref_param, model_param in zip(
unwrapped_model.ref_transformer.parameters(),
unwrapped_model.transformer.parameters(),
):
average_param = alpha * ref_param.data + (1 - alpha) * model_param.data
ref_param.data.copy_(average_param)
unwrapped_model.ref_transformer.eval()
unwrapped_model.ref_transformer.requires_grad_= False
unwrapped_model.ref_transformer.requires_grad_ = False
for param in unwrapped_model.ref_transformer.parameters():
param.requires_grad = False
return model
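
The helper above blends the frozen reference transformer toward the current policy but does not appear to be invoked in the hunks shown. For reference, a self-contained sketch of the same parameter update; the function name blend_reference is illustrative and not part of the repo:

```python
import torch


@torch.no_grad()
def blend_reference(ref_module: torch.nn.Module, live_module: torch.nn.Module, alpha: float = 0.5):
    """Illustrative only: in-place update ref <- alpha * ref + (1 - alpha) * live."""
    for ref_p, live_p in zip(ref_module.parameters(), live_module.parameters()):
        ref_p.data.mul_(alpha).add_(live_p.data, alpha=1.0 - alpha)
    ref_module.eval()
    for p in ref_module.parameters():
        p.requires_grad = False
```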
model.ref_transformer = copy.deepcopy(model.transformer)
model.ref_transformer.requires_grad_(False)
model.ref_transformer.eval()
for param in model.ref_transformer.parameters():
param.requires_grad = False
optimizer = torch.optim.AdamW(
optimizer_parameters, lr=learning_rate,
optimizer_parameters,
lr=learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
@@ -329,31 +413,35 @@ def main():
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
if args.max_train_steps is None:
args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps * gradient_accumulation_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * gradient_accumulation_steps
num_warmup_steps=num_warmup_steps
* gradient_accumulation_steps
* accelerator.num_processes,
num_training_steps=args.max_train_steps * gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
vae, model, optimizer, lr_scheduler = accelerator.prepare(
vae, model, optimizer, lr_scheduler
vae, model, optimizer, lr_scheduler
)
train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
train_dataloader, eval_dataloader, test_dataloader
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
if overrode_max_train_steps:
args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
@@ -366,96 +454,108 @@ def main():
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initialize automatically on the main process.
# Train!
total_batch_size = per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
total_batch_size = (
per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar = tqdm(
range(args.max_train_steps), disable=not accelerator.is_local_main_process
)
completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save
resume_from_checkpoint = config['paths']['resume_from_checkpoint']
if resume_from_checkpoint!= '':
resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
if resume_from_checkpoint != "":
accelerator.load_state(resume_from_checkpoint)
accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")
# Duration of the audio clips in seconds
best_loss = np.inf
length = config['training']['max_audio_duration']
best_loss = np.inf
length = config["training"]["max_audio_duration"]
for epoch in range(starting_epoch, num_train_epochs):
model.train()
total_loss, total_val_loss = 0, 0
for step, batch in enumerate(train_dataloader):
optimizer.zero_grad()
with accelerator.accumulate(model):
optimizer.zero_grad()
device = accelerator.device
text, audio_w,audio_l, duration, _ = batch
text, audio_w, audio_l, duration, _ = batch
with torch.no_grad():
audio_list_w = []
audio_list_l = []
for audio_path in audio_w:
wav = read_wav_file(audio_path,length) ## Only read the first 30 seconds of audio
if wav.shape[0] == 1 : ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
wav = wav.repeat(2,1)
wav = read_wav_file(
audio_path, length
) ## Only read the first 30 seconds of audio
if (
wav.shape[0] == 1
): ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
wav = wav.repeat(2, 1)
audio_list_w.append(wav)
for audio_path in audio_l:
wav = read_wav_file(audio_path,length) ## Only read the first 30 seconds of audio
if wav.shape[0] == 1 : ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
wav = wav.repeat(2,1)
wav = read_wav_file(
audio_path, length
) ## Only read the first 30 seconds of audio
if (
wav.shape[0] == 1
): ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
wav = wav.repeat(2, 1)
audio_list_l.append(wav)
audio_input_w = torch.stack(audio_list_w,dim=0).to(device)
audio_input_l = torch.stack(audio_list_l,dim=0).to(device)
#audio_input_ = audio_input.to(device)
audio_input_w = torch.stack(audio_list_w, dim=0).to(device)
audio_input_l = torch.stack(audio_list_l, dim=0).to(device)
# audio_input_ = audio_input.to(device)
unwrapped_vae = accelerator.unwrap_model(vae)
duration = torch.tensor(duration,device=device)
duration = torch.clamp(duration, max=length) ## max duration is 30 sec
audio_latent_w = unwrapped_vae.encode(audio_input_w).latent_dist.sample()
audio_latent_l = unwrapped_vae.encode(audio_input_l).latent_dist.sample()
audio_latent = torch.cat((audio_latent_w,audio_latent_l),dim=0)
audio_latent = audio_latent.transpose(1,2) ## Transpose to (bsz, seq_len, channel)
loss, raw_model_loss, raw_ref_loss,implicit_acc = model(audio_latent, text ,duration=duration,sft=False)
duration = torch.tensor(duration, device=device)
duration = torch.clamp(
duration, max=length
) ## max duration is 30 sec
audio_latent_w = unwrapped_vae.encode(
audio_input_w
).latent_dist.sample()
audio_latent_l = unwrapped_vae.encode(
audio_input_l
).latent_dist.sample()
audio_latent = torch.cat((audio_latent_w, audio_latent_l), dim=0)
audio_latent = audio_latent.transpose(
1, 2
) ## Transpose to (bsz, seq_len, channel)
loss, raw_model_loss, raw_ref_loss, implicit_acc = model(
audio_latent, text, duration=duration, sft=False
)
total_loss += loss.detach().float()
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
#if accelerator.sync_gradients:
# if accelerator.sync_gradients:
if accelerator.sync_gradients:
#accelerator.clip_grad_value_(model.parameters(),1.0)
# accelerator.clip_grad_value_(model.parameters(),1.0)
progress_bar.update(1)
completed_steps += 1
if completed_steps % 10 == 0 and accelerator.is_main_process:
@@ -465,26 +565,25 @@ def main():
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
logger.info(f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}")
total_norm = total_norm**0.5
logger.info(
f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
)
lr = lr_scheduler.get_last_lr()[0]
result = {
"train_loss": loss.item(),
"grad_norm": total_norm,
"learning_rate": lr,
'raw_model_loss':raw_model_loss,
'raw_ref_loss': raw_ref_loss,
'implicit_acc':implicit_acc
"raw_model_loss": raw_model_loss,
"raw_ref_loss": raw_ref_loss,
"implicit_acc": implicit_acc,
}
# result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
wandb.log(result, step=completed_steps)
# Checks if the accelerator has performed an optimization step behind the scenes
if isinstance(checkpointing_steps, int):
@@ -497,72 +596,78 @@ def main():
if completed_steps >= args.max_train_steps:
break
model.eval()
eval_progress_bar = tqdm(range(len(eval_dataloader)), disable=not accelerator.is_local_main_process)
eval_progress_bar = tqdm(
range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
)
for step, batch in enumerate(eval_dataloader):
with accelerator.accumulate(model), torch.no_grad():
device = model.device
text, audios, duration, _ = batch
audio_list = []
for audio_path in audios:
wav = read_wav_file(audio_path,length) ## Only read the first 30 seconds of audio
if wav.shape[0] == 1 : ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
wav = wav.repeat(2,1)
wav = read_wav_file(
audio_path, length
) ## Only read the first 30 seconds of audio
if (
wav.shape[0] == 1
): ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
wav = wav.repeat(2, 1)
audio_list.append(wav)
audio_input = torch.stack(audio_list,dim=0)
audio_input = torch.stack(audio_list, dim=0)
audio_input = audio_input.to(device)
duration = torch.tensor(duration,device=device)
duration = torch.tensor(duration, device=device)
unwrapped_vae = accelerator.unwrap_model(vae)
audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
audio_latent = audio_latent.transpose(1,2) ## Transpose to (bsz, seq_len, channel)
val_loss, _, _, _ = model(audio_latent, text , duration=duration,sft=True)
audio_latent = audio_latent.transpose(
1, 2
) ## Transpose to (bsz, seq_len, channel)
val_loss, _, _, _ = model(
audio_latent, text, duration=duration, sft=True
)
total_val_loss += val_loss.detach().float()
eval_progress_bar.update(1)
if accelerator.is_main_process:
result = {}
result["epoch"] = float(epoch+1)
result["epoch/train_loss"] = round(total_loss.item()/len(train_dataloader), 4)
result["epoch/val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
result = {}
result["epoch"] = float(epoch + 1)
wandb.log(result,step=completed_steps)
result["epoch/train_loss"] = round(
total_loss.item() / len(train_dataloader), 4
)
result["epoch/val_loss"] = round(
total_val_loss.item() / len(eval_dataloader), 4
)
wandb.log(result, step=completed_steps)
with open("{}/summary.jsonl".format(output_dir), "a") as f:
f.write(json.dumps(result) + "\n\n")
logger.info(result)
save_checkpoint= True
save_checkpoint = True
accelerator.wait_for_everyone()
if accelerator.is_main_process and args.checkpointing_steps == "best":
if save_checkpoint:
accelerator.save_state("{}/{}".format(output_dir, "best"))
if (epoch + 1) % args.save_every == 0:
accelerator.save_state("{}/{}".format(output_dir, "epoch_" + str(epoch+1)))
accelerator.save_state(
"{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
)
if accelerator.is_main_process and args.checkpointing_steps == "epoch":
accelerator.save_state("{}/{}".format(output_dir, "epoch_" + str(epoch+1)))
accelerator.save_state(
"{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
)
if __name__ == "__main__":
main()
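
For clarity on the preference-pair forward pass above: chosen ("winner") and rejected ("loser") latents are concatenated along the batch dimension before a single call with sft=False. A small shape sketch follows; all sizes are illustrative:

```python
import torch

bsz, channels, seq_len = 2, 64, 645                     # illustrative sizes only
audio_latent_w = torch.randn(bsz, channels, seq_len)    # chosen samples from the VAE
audio_latent_l = torch.randn(bsz, channels, seq_len)    # rejected samples

audio_latent = torch.cat((audio_latent_w, audio_latent_l), dim=0)  # (2*bsz, channels, seq_len)
audio_latent = audio_latent.transpose(1, 2)                        # (2*bsz, seq_len, channels)
assert audio_latent.shape == (2 * bsz, seq_len, channels)
# Rows [0, bsz) are the preferred samples, rows [bsz, 2*bsz) the rejected ones.
```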
@@ -11,7 +11,7 @@ import numpy as np
import numpy as np
def normalize_wav(waveform):
waveform = waveform - torch.mean(waveform)
waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
@@ -20,7 +20,7 @@ def normalize_wav(waveform):
def pad_wav(waveform, segment_length):
waveform_length = len(waveform)
if segment_length is None or waveform_length == segment_length:
return waveform
elif waveform_length > segment_length:
@@ -29,40 +29,47 @@ def pad_wav(waveform, segment_length):
padded_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
waveform = torch.cat([waveform, padded_wav])
return waveform
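
A small usage sketch of the zero-padding path shown above, assuming the tangoflux package from this PR is installed:

```python
import torch
from tangoflux.utils import pad_wav

wav = torch.ones(44100)                  # 1 s of mono audio at 44.1 kHz
padded = pad_wav(wav, 44100 * 2)         # zero-pad up to 2 s
assert padded.shape[0] == 44100 * 2
assert padded[44100:].abs().sum() == 0   # the tail is silence
```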
def read_wav_file(filename, duration_sec):
info = torchaudio.info(filename)
sample_rate = info.sample_rate
# Calculate the number of frames corresponding to the desired duration
num_frames = int(sample_rate * duration_sec)
waveform, sr = torchaudio.load(filename,num_frames=num_frames) # Faster!!!
waveform, sr = torchaudio.load(filename, num_frames=num_frames) # Faster!!!
if waveform.shape[0] == 2 : ## Stereo audio
if waveform.shape[0] == 2: ## Stereo audio
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=44100)
resampled_waveform = resampler(waveform)
#print(resampled_waveform.shape)
padded_left = pad_wav(resampled_waveform[0], int(44100*duration_sec)) ## We pad left and right separately
padded_right = pad_wav(resampled_waveform[1], int(44100*duration_sec))
# print(resampled_waveform.shape)
padded_left = pad_wav(
resampled_waveform[0], int(44100 * duration_sec)
) ## We pad left and right separately
padded_right = pad_wav(resampled_waveform[1], int(44100 * duration_sec))
return torch.stack([padded_left,padded_right])
return torch.stack([padded_left, padded_right])
else:
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=44100)[0]
waveform = pad_wav(waveform, int(44100*duration_sec)).unsqueeze(0)
waveform = torchaudio.functional.resample(
waveform, orig_freq=sr, new_freq=44100
)[0]
waveform = pad_wav(waveform, int(44100 * duration_sec)).unsqueeze(0)
return waveform
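
A hedged usage sketch tying read_wav_file to the "fake stereo" handling in the training loops above; the file paths are placeholders and the package layout is assumed from this PR:

```python
import torch
from tangoflux.utils import read_wav_file

max_audio_duration = 30                        # mirrors config['training']['max_audio_duration']
paths = ["example_a.wav", "example_b.wav"]     # placeholder paths

audio_list = []
for path in paths:
    wav = read_wav_file(path, max_audio_duration)   # (channels, ~44100 * duration) at 44.1 kHz
    if wav.shape[0] == 1:                           # mono -> duplicate channel ("fake stereo")
        wav = wav.repeat(2, 1)
    audio_list.append(wav)

audio_input = torch.stack(audio_list, dim=0)        # (batch, 2, samples), ready for the VAE
```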
class DPOText2AudioDataset(Dataset):
def __init__(self, dataset, prefix, text_column, audio_w_column, audio_l_column, duration, num_examples=-1):
def __init__(
self,
dataset,
prefix,
text_column,
audio_w_column,
audio_l_column,
duration,
num_examples=-1,
):
inputs = list(dataset[text_column])
self.inputs = [prefix + inp for inp in inputs]
@@ -72,11 +79,18 @@ class DPOText2AudioDataset(Dataset):
self.indices = list(range(len(self.inputs)))
self.mapper = {}
for index, audio_w, audio_l, duration, text in zip(self.indices, self.audios_w,self.audios_l,self.durations,inputs):
for index, audio_w, audio_l, duration, text in zip(
self.indices, self.audios_w, self.audios_l, self.durations, inputs
):
self.mapper[index] = [audio_w, audio_l, duration, text]
if num_examples != -1:
self.inputs, self.audios_w, self.audios_l, self.durations = self.inputs[:num_examples], self.audios_w[:num_examples], self.audios_l[:num_examples], self.durations[:num_examples]
self.inputs, self.audios_w, self.audios_l, self.durations = (
self.inputs[:num_examples],
self.audios_w[:num_examples],
self.audios_l[:num_examples],
self.durations[:num_examples],
)
self.indices = self.indices[:num_examples]
def __len__(self):
@@ -86,15 +100,24 @@ class DPOText2AudioDataset(Dataset):
return len(self.inputs)
def __getitem__(self, index):
s1, s2, s3, s4, s5 = self.inputs[index], self.audios_w[index], self.audios_l[index], self.durations[index], self.indices[index]
s1, s2, s3, s4, s5 = (
self.inputs[index],
self.audios_w[index],
self.audios_l[index],
self.durations[index],
self.indices[index],
)
return s1, s2, s3, s4, s5
def collate_fn(self, data):
dat = pd.DataFrame(data)
return [dat[i].tolist() for i in dat]
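
The collate_fn above transposes a list of per-example tuples into per-field lists; a small illustration with made-up rows:

```python
import pandas as pd

# Each row mimics DPOText2AudioDataset.__getitem__:
# (text, chosen_audio_path, rejected_audio_path, duration, index)
rows = [
    ("a dog barking", "w_0.wav", "l_0.wav", 10.0, 0),
    ("rain on a tin roof", "w_1.wav", "l_1.wav", 8.5, 1),
]
dat = pd.DataFrame(rows)
batch = [dat[i].tolist() for i in dat]
# batch == [['a dog barking', 'rain on a tin roof'], ['w_0.wav', 'w_1.wav'],
#           ['l_0.wav', 'l_1.wav'], [10.0, 8.5], [0, 1]]
```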
class Text2AudioDataset(Dataset):
def __init__(self, dataset, prefix, text_column, audio_column, duration, num_examples=-1):
def __init__(
self, dataset, prefix, text_column, audio_column, duration, num_examples=-1
):
inputs = list(dataset[text_column])
self.inputs = [prefix + inp for inp in inputs]
@@ -103,11 +126,17 @@ class Text2AudioDataset(Dataset):
self.indices = list(range(len(self.inputs)))
self.mapper = {}
for index, audio, duration,text in zip(self.indices, self.audios, self.durations,inputs):
self.mapper[index] = [audio, text,duration]
for index, audio, duration, text in zip(
self.indices, self.audios, self.durations, inputs
):
self.mapper[index] = [audio, text, duration]
if num_examples != -1:
self.inputs, self.audios, self.durations = self.inputs[:num_examples], self.audios[:num_examples], self.durations[:num_examples]
self.inputs, self.audios, self.durations = (
self.inputs[:num_examples],
self.audios[:num_examples],
self.durations[:num_examples],
)
self.indices = self.indices[:num_examples]
def __len__(self):
@@ -117,7 +146,12 @@ class Text2AudioDataset(Dataset):
return len(self.inputs)
def __getitem__(self, index):
s1, s2, s3, s4 = self.inputs[index], self.audios[index], self.durations[index], self.indices[index]
s1, s2, s3, s4 = (
self.inputs[index],
self.audios[index],
self.durations[index],
self.indices[index],
)
return s1, s2, s3, s4
def collate_fn(self, data):
......
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' src/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
\ No newline at end of file
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' tangoflux/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
\ No newline at end of file