Unverified commit 1453841b authored by Chia-Yu Hung, committed by GitHub

Merge pull request #6 from fakerybakery/main

pip package
parents 7090a624 8fd3cc82
__pycache__/
*.py[cod]
*$py.class
@@ -168,3 +168,8 @@ cython_debug/
# PyPI configuration file
.pypirc
.DS_Store
*.wav
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/declare-lab/TangoFlux.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import IPython\n",
"import torchaudio\n",
"from tangoflux import TangoFluxInference\n",
"from IPython.display import Audio\n",
"\n",
"model = TangoFluxInference(name='declare-lab/TangoFlux')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# @title Generate Audio\n",
"\n",
"prompt = 'Hammer slowly hitting the wooden table' # @param {type:\"string\"}\n",
"duration = 10 # @param {type:\"number\"}\n",
"steps = 50 # @param {type:\"number\"}\n",
"\n",
"audio = model.generate(prompt, steps=steps, duration=duration)\n",
"\n",
"Audio(data=audio, rate=44100)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization

<div align="center">
<img src="assets/tf_teaser.png" alt="TangoFlux" width="1000" />
<br/>

[![arXiv](https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf)](https://arxiv.org/abs/2412.21037) [![Static Badge](https://img.shields.io/badge/TangoFlux-Hugging_Face-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat)](https://tangoflux.github.io/) [![Static Badge](https://img.shields.io/badge/TangoFlux-Hugging_Face_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/TangoFlux_Dataset-Hugging_Face-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/datasets/declare-lab/CRPO) [![Replicate](https://replicate.com/chenxwh/tangoflux/badge)](https://replicate.com/chenxwh/tangoflux)
</div>

## Demos

[![Hugging Face Space](https://img.shields.io/badge/Hugging_Face_Space-TangoFlux-blue?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/declare-lab/TangoFlux/blob/main/Demo.ipynb)

## Overall Pipeline

TangoFlux consists of FluxTransformer blocks (Diffusion Transformer (DiT) and Multimodal Diffusion Transformer (MMDiT) blocks) conditioned on a textual prompt and a duration embedding to generate 44.1kHz audio up to 30 seconds long. TangoFlux learns a rectified flow trajectory to an audio latent representation encoded by a variational autoencoder (VAE). The training pipeline consists of three stages: pre-training, fine-tuning, and preference optimization with CRPO. In particular, CRPO iteratively generates new synthetic data and constructs preference pairs for preference optimization using a DPO loss for flow matching.
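To make the flow-matching objective above concrete, here is a minimal, self-contained sketch of the rectified-flow loss: clean audio latents are linearly interpolated toward Gaussian noise at a sampled noise level, and the network regresses the constant velocity of that straight path. The `model` callable and tensor shapes are illustrative stand-ins; the actual implementation is `TangoFlux.forward` in `tangoflux/model.py`, shown later in this diff.

```python
import torch


def rectified_flow_loss(model, x0, cond, sigmas):
    """Sketch of the flow-matching loss used for pre-training and fine-tuning.

    model:  stand-in callable (noisy_latents, cond, sigmas) -> velocity prediction
    x0:     clean VAE audio latents, shape (batch, seq_len, channels)
    cond:   conditioning (text + duration embeddings); shape is illustrative
    sigmas: per-example noise levels in [0, 1], shape (batch, 1, 1)
    """
    noise = torch.randn_like(x0)
    x_t = (1.0 - sigmas) * x0 + sigmas * noise  # point on the straight path x0 -> noise
    target = noise - x0                         # constant velocity along that path
    pred = model(x_t, cond, sigmas)
    return ((pred.float() - target.float()) ** 2).mean()
```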
![cover-photo](assets/tangoflux.png)

🚀 **TangoFlux can generate up to 30 seconds of 44.1kHz stereo audio in about 3 seconds on a single A40 GPU.**

## Installation

```bash
pip install git+https://github.com/declare-lab/TangoFlux
```

## Inference

TangoFlux can generate audio up to 30 seconds long. When using the Python API, you must pass a duration (in seconds) to the `model.generate` function; the duration should be between 1 and 30.

### Web Interface

Run the following command to start the web interface:

```bash
tangoflux-demo
```

### CLI

Use the CLI to generate audio from text:

```bash
tangoflux "Hammer slowly hitting the wooden table" output.wav --duration 10 --steps 50
```

### Python API

```python
import torchaudio
from tangoflux import TangoFluxInference

model = TangoFluxInference(name='declare-lab/TangoFlux')
audio = model.generate('Hammer slowly hitting the wooden table', steps=50, duration=10)

torchaudio.save('output.wav', audio, 44100)
```
Our evaluation shows that inference with 50 steps yields the best results, and CFG scales of 3.5, 4, and 4.5 produce output of similar quality. Inference with 25 steps gives comparable audio quality at a faster speed.
## Training
We use the `accelerate` package from Hugging Face for multi-GPU training. Run `accelerate config` to set up your run configuration. The default accelerate config is in the `configs` folder. Please specify the paths to your training files in `configs/tangoflux_config.yaml`. Samples of `train.json` and `val.json` have been provided; replace them with your own audio.
`tangoflux_config.yaml` defines the training file paths and model hyperparameters:
```bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' src/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
```
To perform DPO training, modify the training files so that each data point contains "chosen", "reject", "caption", and "duration" fields (a sketch of this record format follows the command below). Please specify the path to your training files in `configs/tangoflux_config.yaml`. An example has been provided in `train_dpo.json`; replace it with your own audio.
```bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' src/train_dpo.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
```
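To make the expected DPO record layout concrete, the sketch below writes one preference pair to a JSON file. Only the four field names come from the description above; the file paths and values are made up, so treat the provided `train_dpo.json` as the authoritative example.

```python
import json

# One hypothetical preference pair: "chosen"/"reject" point to the preferred and
# rejected audio for the same caption; "duration" is the clip length in seconds.
record = {
    "chosen": "audio/hammer_table_preferred.wav",
    "reject": "audio/hammer_table_rejected.wav",
    "caption": "Hammer slowly hitting the wooden table",
    "duration": 10,
}

with open("my_train_dpo.json", "w") as f:
    json.dump([record], f, indent=2)
```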
## Evaluation Scripts

@@ -77,15 +92,13 @@ This key comparison metrics include:

All the inference times are observed on the same A40 GPU. The counts of trainable parameters are reported in the **\#Params** column.

| Model | \#Params | Duration | Steps | FD<sub>openl3</sub> ↓ | KL<sub>passt</sub> ↓ | CLAP<sub>score</sub> ↑ | IS ↑ | Inference Time (s) |
|---|---|---|---|---|---|---|---|---|
| **AudioLDM 2 (Large)** | 712M | 10 sec | 200 | 108.3 | 1.81 | 0.419 | 7.9 | 24.8 |
| **Stable Audio Open** | 1056M | 47 sec | 100 | 89.2 | 2.58 | 0.291 | 9.9 | 8.6 |
| **Tango 2** | 866M | 10 sec | 200 | 108.4 | 1.11 | 0.447 | 9.0 | 22.8 |
| **TangoFlux (Base)** | 515M | 30 sec | 50 | 80.2 | 1.22 | 0.431 | 11.7 | 3.7 |
| **TangoFlux** | 515M | 30 sec | 50 | 75.1 | 1.15 | 0.480 | 12.2 | 3.7 |

## Citation

@@ -100,3 +113,7 @@ All the inference times are observed on the same A40 GPU. The counts of trainabl

url={https://arxiv.org/abs/2412.21037},
}
```
## License
TangoFlux is licensed under the MIT License. See the `LICENSE` file for more details.
*.wav
\ No newline at end of file
import torchaudio
from tangoflux import TangoFluxInference
model = TangoFluxInference(name="declare-lab/TangoFlux")
audio = model.generate("Hammer slowly hitting the wooden table", steps=50, duration=10)
torchaudio.save("output.wav", audio, sample_rate=44100)
@@ -10,11 +10,13 @@ from diffusers import AutoencoderOobleck
import soundfile as sf
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from tangoflux.model import TangoFlux
from tangoflux import TangoFluxInference

MODEL_CACHE = "model_cache"
MODEL_URL = (
    "https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
)


class CachedTangoFluxInference(TangoFluxInference):
...
from setuptools import setup
setup(
name="tangoflux",
description="TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching",
version="0.1.0",
packages=["tangoflux"],
install_requires=[
"torch==2.4.0",
"torchaudio==2.4.0",
"torchlibrosa==0.1.0",
"torchvision==0.19.0",
"transformers==4.44.0",
"diffusers==0.30.0",
"accelerate==0.34.2",
"datasets==2.21.0",
"librosa",
"tqdm",
"wandb",
"click",
"gradio",
"torchaudio",
],
entry_points={
"console_scripts": [
"tangoflux=tangoflux.cli:main",
"tangoflux-demo=tangoflux.demo:main",
],
},
)
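As a quick sanity check that the two console scripts declared in `setup.py` are registered after `pip install .`, the standard library can list them. This is a sketch; the `group=` keyword requires Python 3.10 or newer.

```python
from importlib.metadata import entry_points

# Map console-script names to their "module:function" targets and look up the
# two entry points declared above.
scripts = {ep.name: ep.value for ep in entry_points(group="console_scripts")}
print(scripts.get("tangoflux"))       # expected: tangoflux.cli:main
print(scripts.get("tangoflux-demo"))  # expected: tangoflux.demo:main
```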
from diffusers import AutoencoderOobleck
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler
@@ -9,10 +9,10 @@ from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import torch.nn.functional as F
import numpy as np
from tangoflux.model import TangoFlux
from huggingface_hub import snapshot_download
from tqdm import tqdm
from typing import Optional, Union, List
from datasets import load_dataset, Audio
from math import pi
import json
@@ -23,39 +23,38 @@ from safetensors.torch import load_file


class TangoFluxInference:

    def __init__(
        self,
        name="declare-lab/TangoFlux",
        device="cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.vae = AutoencoderOobleck()

        paths = snapshot_download(repo_id=name)
        vae_weights = load_file("{}/vae.safetensors".format(paths))
        self.vae.load_state_dict(vae_weights)
        weights = load_file("{}/tangoflux.safetensors".format(paths))

        with open("{}/config.json".format(paths), "r") as f:
            config = json.load(f)
        self.model = TangoFlux(config)
        self.model.load_state_dict(weights, strict=False)
        # _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
        self.vae.to(device)
        self.model.to(device)

    def generate(self, prompt, steps=25, duration=10, guidance_scale=4.5):

        with torch.no_grad():
            latents = self.model.inference_flow(
                prompt,
                duration=duration,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
            )

            wave = self.vae.decode(latents.transpose(2, 1)).sample.cpu()[0]

        waveform_end = int(duration * self.vae.config.sampling_rate)
        wave = wave[:, :waveform_end]
        return wave
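Since the constructor above now falls back to CPU when CUDA is unavailable, the device can also be forced explicitly. A brief usage sketch follows; the prompt and filename are illustrative, and CPU inference is much slower than the GPU timings quoted in the README.

```python
import torchaudio
from tangoflux import TangoFluxInference

# Force CPU inference; omit `device` to auto-select CUDA when it is available.
model = TangoFluxInference(name="declare-lab/TangoFlux", device="cpu")
audio = model.generate("Distant thunder rolling over a field", steps=25, duration=5)
torchaudio.save("thunder.wav", audio, sample_rate=44100)
```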
import click
import torchaudio
from tangoflux import TangoFluxInference
@click.command()
@click.argument('prompt')
@click.argument('output_file')
@click.option('--duration', default=10, type=int, help='Duration in seconds (1-30)')
@click.option('--steps', default=50, type=int, help='Number of inference steps (10-100)')
def main(prompt: str, output_file: str, duration: int, steps: int):
"""Generate audio from text using TangoFlux.
Args:
prompt: Text description of the audio to generate
output_file: Path to save the generated audio file
duration: Duration of generated audio in seconds (default: 10)
steps: Number of inference steps (default: 50)
"""
if not 1 <= duration <= 30:
raise click.BadParameter('Duration must be between 1 and 30 seconds')
if not 10 <= steps <= 100:
raise click.BadParameter('Steps must be between 10 and 100')
model = TangoFluxInference(name="declare-lab/TangoFlux")
audio = model.generate(prompt, steps=steps, duration=duration)
torchaudio.save(output_file, audio, sample_rate=44100)
if __name__ == '__main__':
main()
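The command above can also be exercised in-process, for example in a quick smoke test, using click's testing utilities. The prompt and output filename below are made up, and the call still downloads the model weights exactly like the real CLI.

```python
from click.testing import CliRunner

from tangoflux.cli import main

# Invoke the CLI entry point without spawning a subprocess.
runner = CliRunner()
result = runner.invoke(
    main,
    ["Gentle rain falling on a tin roof", "rain.wav", "--duration", "5", "--steps", "25"],
)
print(result.exit_code)  # 0 on success; result.output holds anything printed
```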
import gradio as gr
import torchaudio
import click
import tempfile
from tangoflux import TangoFluxInference
model = TangoFluxInference(name="declare-lab/TangoFlux")
def generate_audio(prompt, duration, steps):
audio = model.generate(prompt, steps=steps, duration=duration)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
torchaudio.save(f.name, audio, sample_rate=44100)
return f.name
examples = [
["Hammer slowly hitting the wooden table", 10, 50],
["Gentle rain falling on a tin roof", 15, 50],
["Wind chimes tinkling in a light breeze", 10, 50],
["Rhythmic wooden table tapping overlaid with steady water pouring sound", 10, 50],
]
with gr.Blocks(title="TangoFlux Text-to-Audio Generation") as demo:
gr.Markdown("# TangoFlux Text-to-Audio Generation")
gr.Markdown("Generate audio from text descriptions using TangoFlux")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
label="Text Prompt", placeholder="Enter your audio description..."
)
duration = gr.Slider(
minimum=1, maximum=30, value=10, step=1, label="Duration (seconds)"
)
steps = gr.Slider(
minimum=10, maximum=100, value=50, step=10, label="Number of Steps"
)
generate_btn = gr.Button("Generate Audio")
with gr.Column():
audio_output = gr.Audio(label="Generated Audio")
generate_btn.click(
fn=generate_audio, inputs=[prompt, duration, steps], outputs=audio_output
)
gr.Examples(
examples=examples,
inputs=[prompt, duration, steps],
outputs=audio_output,
fn=generate_audio,
)
@click.command()
@click.option('--host', default='127.0.0.1', help='Host to bind to')
@click.option('--port', default=None, type=int, help='Port to bind to')
@click.option('--share', is_flag=True, help='Enable sharing via Gradio')
def main(host, port, share):
demo.queue().launch(server_name=host, server_port=port, share=share)
if __name__ == "__main__":
main()
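Because `demo` is a module-level `gr.Blocks` object (and importing `tangoflux.demo` loads the model weights), the interface can also be launched from Python rather than through the `tangoflux-demo` script. A sketch, with host and port values chosen purely for illustration:

```python
from tangoflux.demo import demo  # importing builds the UI and loads the model

# Equivalent to: tangoflux-demo --host 0.0.0.0 --port 7860
demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
```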
from transformers import T5EncoderModel,T5TokenizerFast from transformers import T5EncoderModel, T5TokenizerFast
import torch import torch
from diffusers import FluxTransformer2DModel from diffusers import FluxTransformer2DModel
from torch import nn from torch import nn
import random
from typing import List from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling from diffusers.training_utils import compute_density_for_timestep_sampling
...@@ -11,19 +11,16 @@ import torch.nn.functional as F ...@@ -11,19 +11,16 @@ import torch.nn.functional as F
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from typing import Optional,Union,List from typing import Optional, Union, List
from datasets import load_dataset, Audio from datasets import load_dataset, Audio
from math import pi from math import pi
import inspect import inspect
import yaml import yaml
class StableAudioPositionalEmbedding(nn.Module):
    """Used for continuous time.

    Adapted from Stable Audio Open.
    """
def __init__(self, dim: int): def __init__(self, dim: int):
...@@ -38,7 +35,8 @@ class StableAudioPositionalEmbedding(nn.Module): ...@@ -38,7 +35,8 @@ class StableAudioPositionalEmbedding(nn.Module):
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((times, fouriered), dim=-1) fouriered = torch.cat((times, fouriered), dim=-1)
return fouriered return fouriered
class DurationEmbedder(nn.Module): class DurationEmbedder(nn.Module):
""" """
A simple linear projection model to map numbers to a latent space. A simple linear projection model to map numbers to a latent space.
...@@ -73,7 +71,7 @@ class DurationEmbedder(nn.Module): ...@@ -73,7 +71,7 @@ class DurationEmbedder(nn.Module):
self.number_embedding_dim = number_embedding_dim self.number_embedding_dim = number_embedding_dim
self.min_value = min_value self.min_value = min_value
self.max_value = max_value self.max_value = max_value
self.dtype = torch.float32 self.dtype = torch.float32
def forward( def forward(
self, self,
...@@ -81,7 +79,9 @@ class DurationEmbedder(nn.Module): ...@@ -81,7 +79,9 @@ class DurationEmbedder(nn.Module):
): ):
floats = floats.clamp(self.min_value, self.max_value) floats = floats.clamp(self.min_value, self.max_value)
normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value) normalized_floats = (floats - self.min_value) / (
self.max_value - self.min_value
)
# Cast floats to same type as embedder # Cast floats to same type as embedder
embedder_dtype = next(self.time_positional_embedding.parameters()).dtype embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
...@@ -103,9 +103,13 @@ def retrieve_timesteps( ...@@ -103,9 +103,13 @@ def retrieve_timesteps(
): ):
if timesteps is not None and sigmas is not None: if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") raise ValueError(
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
)
if timesteps is not None: if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) accepts_timesteps = "timesteps" in set(
inspect.signature(scheduler.set_timesteps).parameters.keys()
)
if not accepts_timesteps: if not accepts_timesteps:
raise ValueError( raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
...@@ -115,7 +119,9 @@ def retrieve_timesteps( ...@@ -115,7 +119,9 @@ def retrieve_timesteps(
timesteps = scheduler.timesteps timesteps = scheduler.timesteps
num_inference_steps = len(timesteps) num_inference_steps = len(timesteps)
elif sigmas is not None: elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) accept_sigmas = "sigmas" in set(
inspect.signature(scheduler.set_timesteps).parameters.keys()
)
if not accept_sigmas: if not accept_sigmas:
raise ValueError( raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
...@@ -128,113 +134,115 @@ def retrieve_timesteps( ...@@ -128,113 +134,115 @@ def retrieve_timesteps(
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps timesteps = scheduler.timesteps
return timesteps, num_inference_steps return timesteps, num_inference_steps
class TangoFlux(nn.Module): class TangoFlux(nn.Module):
def __init__(self, config, initialize_reference_model=False):
def __init__(self,config,initialize_reference_model=False):
super().__init__() super().__init__()
self.num_layers = config.get("num_layers", 6)
self.num_single_layers = config.get("num_single_layers", 18)
self.num_layers = config.get('num_layers', 6) self.in_channels = config.get("in_channels", 64)
self.num_single_layers = config.get('num_single_layers', 18) self.attention_head_dim = config.get("attention_head_dim", 128)
self.in_channels = config.get('in_channels', 64) self.joint_attention_dim = config.get("joint_attention_dim", 1024)
self.attention_head_dim = config.get('attention_head_dim', 128) self.num_attention_heads = config.get("num_attention_heads", 8)
self.joint_attention_dim = config.get('joint_attention_dim', 1024) self.audio_seq_len = config.get("audio_seq_len", 645)
self.num_attention_heads = config.get('num_attention_heads', 8) self.max_duration = config.get("max_duration", 30)
self.audio_seq_len = config.get('audio_seq_len', 645) self.uncondition = config.get("uncondition", False)
self.max_duration = config.get('max_duration', 30) self.text_encoder_name = config.get("text_encoder_name", "google/flan-t5-large")
self.uncondition = config.get('uncondition', False)
self.text_encoder_name = config.get('text_encoder_name', "google/flan-t5-large")
self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler) self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
self.max_text_seq_len = 64 self.max_text_seq_len = 64
self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name) self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
self.tokenizer = T5TokenizerFast.from_pretrained(self.text_encoder_name) self.tokenizer = T5TokenizerFast.from_pretrained(self.text_encoder_name)
self.text_embedding_dim = self.text_encoder.config.d_model self.text_embedding_dim = self.text_encoder.config.d_model
self.fc = nn.Sequential(
self.fc = nn.Sequential(nn.Linear(self.text_embedding_dim,self.joint_attention_dim),nn.ReLU()) nn.Linear(self.text_embedding_dim, self.joint_attention_dim), nn.ReLU()
self.duration_emebdder = DurationEmbedder(self.text_embedding_dim,min_value=0,max_value=self.max_duration) )
self.duration_emebdder = DurationEmbedder(
self.text_embedding_dim, min_value=0, max_value=self.max_duration
)
self.transformer = FluxTransformer2DModel( self.transformer = FluxTransformer2DModel(
in_channels=self.in_channels, in_channels=self.in_channels,
num_layers=self.num_layers, num_layers=self.num_layers,
num_single_layers=self.num_single_layers, num_single_layers=self.num_single_layers,
attention_head_dim=self.attention_head_dim, attention_head_dim=self.attention_head_dim,
num_attention_heads=self.num_attention_heads, num_attention_heads=self.num_attention_heads,
joint_attention_dim=self.joint_attention_dim, joint_attention_dim=self.joint_attention_dim,
pooled_projection_dim=self.text_embedding_dim, pooled_projection_dim=self.text_embedding_dim,
guidance_embeds=False) guidance_embeds=False,
)
self.beta_dpo = 2000 ## this is used for dpo training
self.beta_dpo = 2000 ## this is used for dpo training
def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
def get_sigmas(self,timesteps, n_dim=3, dtype=torch.float32):
device = self.text_encoder.device device = self.text_encoder.device
sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype) sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device) schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
timesteps = timesteps.to(device) timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim: while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1) sigma = sigma.unsqueeze(-1)
return sigma return sigma
def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1): def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
device = self.text_encoder.device device = self.text_encoder.device
batch = self.tokenizer( batch = self.tokenizer(
prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt" prompt,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
)
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
device
) )
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
with torch.no_grad(): with torch.no_grad():
prompt_embeds = self.text_encoder( prompt_embeds = self.text_encoder(
input_ids=input_ids, attention_mask=attention_mask input_ids=input_ids, attention_mask=attention_mask
)[0] )[0]
prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0) prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0) attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
uncond_tokens = [""] uncond_tokens = [""]
max_length = prompt_embeds.shape[1] max_length = prompt_embeds.shape[1]
uncond_batch = self.tokenizer( uncond_batch = self.tokenizer(
uncond_tokens, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt", uncond_tokens,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
) )
uncond_input_ids = uncond_batch.input_ids.to(device) uncond_input_ids = uncond_batch.input_ids.to(device)
uncond_attention_mask = uncond_batch.attention_mask.to(device) uncond_attention_mask = uncond_batch.attention_mask.to(device)
with torch.no_grad(): with torch.no_grad():
negative_prompt_embeds = self.text_encoder( negative_prompt_embeds = self.text_encoder(
input_ids=uncond_input_ids, attention_mask=uncond_attention_mask input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
)[0] )[0]
negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0) negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(
uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0) num_samples_per_prompt, 0
)
uncond_attention_mask = uncond_attention_mask.repeat_interleave(
num_samples_per_prompt, 0
)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_mask = torch.cat([uncond_attention_mask, attention_mask]) prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
boolean_prompt_mask = (prompt_mask == 1).to(device) boolean_prompt_mask = (prompt_mask == 1).to(device)
...@@ -245,266 +253,287 @@ class TangoFlux(nn.Module): ...@@ -245,266 +253,287 @@ class TangoFlux(nn.Module):
def encode_text(self, prompt): def encode_text(self, prompt):
device = self.text_encoder.device device = self.text_encoder.device
batch = self.tokenizer( batch = self.tokenizer(
prompt, max_length=self.max_text_seq_len, padding=True, truncation=True, return_tensors="pt") prompt,
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device) max_length=self.max_text_seq_len,
padding=True,
truncation=True,
return_tensors="pt",
)
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
device
)
encoder_hidden_states = self.text_encoder( encoder_hidden_states = self.text_encoder(
input_ids=input_ids, attention_mask=attention_mask)[0] input_ids=input_ids, attention_mask=attention_mask
)[0]
boolean_encoder_mask = (attention_mask == 1).to(device) boolean_encoder_mask = (attention_mask == 1).to(device)
return encoder_hidden_states, boolean_encoder_mask return encoder_hidden_states, boolean_encoder_mask
    def encode_duration(self, duration):
        return self.duration_emebdder(duration)
@torch.no_grad() @torch.no_grad()
def inference_flow(self, prompt, def inference_flow(
num_inference_steps=50, self,
timesteps=None, prompt,
guidance_scale=3, num_inference_steps=50,
duration=10, timesteps=None,
disable_progress=False, guidance_scale=3,
num_samples_per_prompt=1): duration=10,
disable_progress=False,
'''Only tested for single inference. Haven't test for batch inference''' num_samples_per_prompt=1,
):
"""Only tested for single inference. Haven't test for batch inference"""
bsz = num_samples_per_prompt bsz = num_samples_per_prompt
device = self.transformer.device device = self.transformer.device
scheduler = self.noise_scheduler scheduler = self.noise_scheduler
if not isinstance(prompt,list): if not isinstance(prompt, list):
prompt = [prompt] prompt = [prompt]
if not isinstance(duration,torch.Tensor): if not isinstance(duration, torch.Tensor):
duration = torch.tensor([duration],device=device) duration = torch.tensor([duration], device=device)
classifier_free_guidance = guidance_scale > 1.0 classifier_free_guidance = guidance_scale > 1.0
duration_hidden_states = self.encode_duration(duration) duration_hidden_states = self.encode_duration(duration)
if classifier_free_guidance: if classifier_free_guidance:
bsz = 2 * num_samples_per_prompt bsz = 2 * num_samples_per_prompt
encoder_hidden_states, boolean_encoder_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt=num_samples_per_prompt) encoder_hidden_states, boolean_encoder_mask = (
duration_hidden_states = duration_hidden_states.repeat(bsz,1,1) self.encode_text_classifier_free(
prompt, num_samples_per_prompt=num_samples_per_prompt
)
)
duration_hidden_states = duration_hidden_states.repeat(bsz, 1, 1)
else: else:
encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt,num_samples_per_prompt=num_samples_per_prompt) encoder_hidden_states, boolean_encoder_mask = self.encode_text(
prompt, num_samples_per_prompt=num_samples_per_prompt
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states) )
masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(
encoder_hidden_states
)
masked_data = torch.where(
mask_expanded, encoder_hidden_states, torch.tensor(float("nan"))
)
pooled = torch.nanmean(masked_data, dim=1) pooled = torch.nanmean(masked_data, dim=1)
pooled_projection = self.fc(pooled) pooled_projection = self.fc(pooled)
encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim) encoder_hidden_states = torch.cat(
[encoder_hidden_states, duration_hidden_states], dim=1
) ## (bs,seq_len,dim)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
scheduler, scheduler, num_inference_steps, device, timesteps, sigmas
num_inference_steps,
device,
timesteps,
sigmas
) )
latents = torch.randn(num_samples_per_prompt,self.audio_seq_len,64) latents = torch.randn(num_samples_per_prompt, self.audio_seq_len, 64)
weight_dtype = latents.dtype weight_dtype = latents.dtype
progress_bar = tqdm(range(num_inference_steps), disable=disable_progress) progress_bar = tqdm(range(num_inference_steps), disable=disable_progress)
txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device) txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
audio_ids = torch.arange(self.audio_seq_len).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device) audio_ids = (
torch.arange(self.audio_seq_len)
.unsqueeze(0)
.unsqueeze(-1)
.repeat(bsz, 1, 3)
.to(device)
)
timesteps = timesteps.to(device) timesteps = timesteps.to(device)
latents = latents.to(device) latents = latents.to(device)
encoder_hidden_states = encoder_hidden_states.to(device) encoder_hidden_states = encoder_hidden_states.to(device)
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
latents_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
latents_input = (
torch.cat([latents] * 2) if classifier_free_guidance else latents
)
noise_pred = self.transformer( noise_pred = self.transformer(
hidden_states=latents_input, hidden_states=latents_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=torch.tensor([t/1000],device=device), timestep=torch.tensor([t / 1000], device=device),
guidance = None, guidance=None,
pooled_projections=pooled_projection, pooled_projections=pooled_projection,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids, txt_ids=txt_ids,
img_ids=audio_ids, img_ids=audio_ids,
return_dict=False, return_dict=False,
)[0] )[0]
if classifier_free_guidance: if classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
latents = scheduler.step(noise_pred, t, latents).prev_sample
latents = scheduler.step(noise_pred, t, latents).prev_sample
return latents return latents
def forward(self, def forward(self, latents, prompt, duration=torch.tensor([10]), sft=True):
latents,
prompt,
duration=torch.tensor([10]),
sft=True
):
device = latents.device device = latents.device
audio_seq_length = self.audio_seq_len audio_seq_length = self.audio_seq_len
bsz = latents.shape[0] bsz = latents.shape[0]
encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt) encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
duration_hidden_states = self.encode_duration(duration) duration_hidden_states = self.encode_duration(duration)
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states) encoder_hidden_states
masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan'))) )
masked_data = torch.where(
mask_expanded, encoder_hidden_states, torch.tensor(float("nan"))
)
pooled = torch.nanmean(masked_data, dim=1) pooled = torch.nanmean(masked_data, dim=1)
pooled_projection = self.fc(pooled) pooled_projection = self.fc(pooled)
## Add duration hidden states to encoder hidden states ## Add duration hidden states to encoder hidden states
encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim) encoder_hidden_states = torch.cat(
[encoder_hidden_states, duration_hidden_states], dim=1
) ## (bs,seq_len,dim)
txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
audio_ids = (
torch.arange(audio_seq_length)
.unsqueeze(0)
.unsqueeze(-1)
.repeat(bsz, 1, 3)
.to(device)
)
txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
audio_ids = torch.arange(audio_seq_length).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)
if sft: if sft:
if self.uncondition: if self.uncondition:
mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1] mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
if len(mask_indices) > 0: if len(mask_indices) > 0:
encoder_hidden_states[mask_indices] = 0 encoder_hidden_states[mask_indices] = 0
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
u = compute_density_for_timestep_sampling( u = compute_density_for_timestep_sampling(
weighting_scheme='logit_normal', weighting_scheme="logit_normal",
batch_size=bsz, batch_size=bsz,
logit_mean=0, logit_mean=0,
logit_std=1, logit_std=1,
mode_scale=None, mode_scale=None,
) )
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device) timesteps = self.noise_scheduler_copy.timesteps[indices].to(
device=latents.device
)
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
model_pred = self.transformer(
hidden_states=noisy_model_input,
model_pred = self.transformer( encoder_hidden_states=encoder_hidden_states,
hidden_states=noisy_model_input, pooled_projections=pooled_projection,
encoder_hidden_states=encoder_hidden_states, img_ids=audio_ids,
pooled_projections=pooled_projection, txt_ids=txt_ids,
img_ids=audio_ids, guidance=None,
txt_ids=txt_ids,
guidance=None,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps/1000, timestep=timesteps / 1000,
return_dict=False)[0] return_dict=False,
)[0]
target = noise - latents target = noise - latents
loss = torch.mean( loss = torch.mean(
( (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), ((model_pred.float() - target.float()) ** 2).reshape(
1, target.shape[0], -1
) ),
1,
)
loss = loss.mean() loss = loss.mean()
raw_model_loss, raw_ref_loss,implicit_acc = 0,0,0 ## default this to 0 if doing sft raw_model_loss, raw_ref_loss, implicit_acc = (
0,
0,
0,
) ## default this to 0 if doing sft
else: else:
encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1) encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
pooled_projection = pooled_projection.repeat(2,1) pooled_projection = pooled_projection.repeat(2, 1)
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1) ## Have to sample same noise for preferred and rejected noise = (
torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1)
) ## Have to sample same noise for preferred and rejected
u = compute_density_for_timestep_sampling( u = compute_density_for_timestep_sampling(
weighting_scheme='logit_normal', weighting_scheme="logit_normal",
batch_size=bsz//2, batch_size=bsz // 2,
logit_mean=0, logit_mean=0,
logit_std=1, logit_std=1,
mode_scale=None, mode_scale=None,
) )
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device) timesteps = self.noise_scheduler_copy.timesteps[indices].to(
device=latents.device
)
timesteps = timesteps.repeat(2) timesteps = timesteps.repeat(2)
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
model_pred = self.transformer( model_pred = self.transformer(
hidden_states=noisy_model_input, hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection, pooled_projections=pooled_projection,
img_ids=audio_ids, img_ids=audio_ids,
txt_ids=txt_ids, txt_ids=txt_ids,
guidance=None, guidance=None,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps/1000, timestep=timesteps / 1000,
return_dict=False)[0] return_dict=False,
)[0]
target = noise - latents target = noise - latents
model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none") model_losses = F.mse_loss(
model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape)))) model_pred.float(), target.float(), reduction="none"
)
model_losses = model_losses.mean(
dim=list(range(1, len(model_losses.shape)))
)
model_losses_w, model_losses_l = model_losses.chunk(2) model_losses_w, model_losses_l = model_losses.chunk(2)
model_diff = model_losses_w - model_losses_l model_diff = model_losses_w - model_losses_l
raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean()) raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
with torch.no_grad(): with torch.no_grad():
ref_preds = self.ref_transformer( ref_preds = self.ref_transformer(
hidden_states=noisy_model_input, hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection, pooled_projections=pooled_projection,
img_ids=audio_ids, img_ids=audio_ids,
txt_ids=txt_ids, txt_ids=txt_ids,
guidance=None, guidance=None,
timestep=timesteps/1000, timestep=timesteps / 1000,
return_dict=False)[0] return_dict=False,
)[0]
ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none") ref_loss = F.mse_loss(
ref_preds.float(), target.float(), reduction="none"
)
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape)))) ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
ref_losses_w, ref_losses_l = ref_loss.chunk(2) ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean() raw_ref_loss = ref_loss.mean()
scale_term = -0.5 * self.beta_dpo scale_term = -0.5 * self.beta_dpo
inside_term = scale_term * (model_diff - ref_diff) inside_term = scale_term * (model_diff - ref_diff)
implicit_acc = (scale_term * (model_diff - ref_diff) > 0).sum().float() / inside_term.size(0) implicit_acc = (
loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean() scale_term * (model_diff - ref_diff) > 0
).sum().float() / inside_term.size(0)
## raw_model_loss, raw_ref_loss, implicit_acc is used to help to analyze dpo behaviour. loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()
return loss, raw_model_loss, raw_ref_loss, implicit_acc
## raw_model_loss, raw_ref_loss, implicit_acc is used to help to analyze dpo behaviour.
\ No newline at end of file return loss, raw_model_loss, raw_ref_loss, implicit_acc
...@@ -22,136 +22,170 @@ from tqdm.auto import tqdm ...@@ -22,136 +22,170 @@ from tqdm.auto import tqdm
from transformers import SchedulerType, get_scheduler from transformers import SchedulerType, get_scheduler
from model import TangoFlux from model import TangoFlux
from datasets import load_dataset, Audio from datasets import load_dataset, Audio
from utils import Text2AudioDataset,read_wav_file,pad_wav from utils import Text2AudioDataset, read_wav_file, pad_wav
from diffusers import AutoencoderOobleck from diffusers import AutoencoderOobleck
import torchaudio import torchaudio
logger = get_logger(__name__) logger = get_logger(__name__)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Rectified flow for text to audio generation task.") parser = argparse.ArgumentParser(
description="Rectified flow for text to audio generation task."
)
parser.add_argument( parser.add_argument(
"--num_examples", type=int, default=-1, "--num_examples",
type=int,
default=-1,
help="How many examples to use for training and validation.", help="How many examples to use for training and validation.",
) )
parser.add_argument( parser.add_argument(
"--text_column", type=str, default="captions", "--text_column",
type=str,
default="captions",
help="The name of the column in the datasets containing the input texts.", help="The name of the column in the datasets containing the input texts.",
) )
parser.add_argument( parser.add_argument(
"--audio_column", type=str, default="location", "--audio_column",
type=str,
default="location",
help="The name of the column in the datasets containing the audio paths.", help="The name of the column in the datasets containing the audio paths.",
) )
parser.add_argument( parser.add_argument(
"--adam_beta1", type=float, default=0.9, "--adam_beta1",
help="The beta1 parameter for the Adam optimizer." type=float,
default=0.9,
help="The beta1 parameter for the Adam optimizer.",
) )
parser.add_argument( parser.add_argument(
"--adam_beta2", type=float, default=0.95, "--adam_beta2",
help="The beta2 parameter for the Adam optimizer." type=float,
default=0.95,
help="The beta2 parameter for the Adam optimizer.",
) )
parser.add_argument( parser.add_argument(
"--config", type=str, default='tangoflux_config.yaml', "--config",
type=str,
default="tangoflux_config.yaml",
help="Config file defining the model size as well as other hyper parameter.", help="Config file defining the model size as well as other hyper parameter.",
) )
parser.add_argument( parser.add_argument(
"--prefix", type=str, default='', "--prefix",
type=str,
default="",
help="Add prefix in text prompts.", help="Add prefix in text prompts.",
) )
parser.add_argument( parser.add_argument(
"--learning_rate", type=float, default=3e-5, "--learning_rate",
type=float,
default=3e-5,
help="Initial learning rate (after the potential warmup period) to use.", help="Initial learning rate (after the potential warmup period) to use.",
) )
parser.add_argument( parser.add_argument(
"--weight_decay", type=float, default=1e-8, "--weight_decay", type=float, default=1e-8, help="Weight decay to use."
help="Weight decay to use."
) )
parser.add_argument( parser.add_argument(
"--max_train_steps", type=int, default=None, "--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.", help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
) )
parser.add_argument( parser.add_argument(
"--lr_scheduler_type", type=SchedulerType, default="linear", "--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.", help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], choices=[
"linear",
"cosine",
"cosine_with_restarts",
"polynomial",
"constant",
"constant_with_warmup",
],
) )
parser.add_argument( parser.add_argument(
"--num_warmup_steps", type=int, default=0, "--num_warmup_steps",
help="Number of steps for the warmup in the lr scheduler." type=int,
default=0,
help="Number of steps for the warmup in the lr scheduler.",
) )
parser.add_argument( parser.add_argument(
"--adam_epsilon", type=float, default=1e-08, "--adam_epsilon",
help="Epsilon value for the Adam optimizer" type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer",
) )
parser.add_argument( parser.add_argument(
"--adam_weight_decay", type=float, default=1e-2, "--adam_weight_decay",
help="Epsilon value for the Adam optimizer" type=float,
default=1e-2,
help="Epsilon value for the Adam optimizer",
) )
parser.add_argument( parser.add_argument(
"--seed", type=int, default=None, "--seed", type=int, default=None, help="A seed for reproducible training."
help="A seed for reproducible training."
) )
parser.add_argument( parser.add_argument(
"--checkpointing_steps", type=str, default="best", "--checkpointing_steps",
type=str,
default="best",
help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.", help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
) )
parser.add_argument( parser.add_argument(
"--save_every", type=int, default=5, "--save_every",
help="Save model after every how many epochs when checkpointing_steps is set to best." type=int,
default=5,
help="Save model after every how many epochs when checkpointing_steps is set to best.",
) )
parser.add_argument( parser.add_argument(
"--resume_from_checkpoint", type=str, default=None, "--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a local checkpoint folder.", help="If the training should continue from a local checkpoint folder.",
) )
parser.add_argument( parser.add_argument(
"--load_from_checkpoint", type=str, default=None, "--load_from_checkpoint",
type=str,
default=None,
help="Whether to continue training from a model weight", help="Whether to continue training from a model weight",
) )
args = parser.parse_args() args = parser.parse_args()
return args return args
def main():
    args = parse_args()
    accelerator_log_kwargs = {}

    def load_config(config_path):
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    config = load_config(args.config)

    learning_rate = float(config["training"]["learning_rate"])
    num_train_epochs = int(config["training"]["num_train_epochs"])
    num_warmup_steps = int(config["training"]["num_warmup_steps"])
    per_device_batch_size = int(config["training"]["per_device_batch_size"])
    gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])

    output_dir = config["paths"]["output_dir"]

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        **accelerator_log_kwargs,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",

@@ -172,12 +206,12 @@ def main():

    if accelerator.is_main_process:
        if output_dir is None or output_dir == "":
            output_dir = "saved/" + str(int(time.time()))

            if not os.path.exists("saved"):
                os.makedirs("saved")

            os.makedirs(output_dir, exist_ok=True)

        elif output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
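The hyperparameters above all come from the YAML file passed via --config. For orientation, here is a minimal sketch of a config containing only the keys this script reads; every value and path is illustrative, and the "model" section is just a placeholder for fields defined elsewhere in the repository.

# Minimal sketch of the YAML structure this script expects (keys taken from the
# config[...] accesses in this file). All values and paths are hypothetical.
import yaml

example_config = yaml.safe_load(
    """
paths:
  train_file: "data/train.json"       # hypothetical path
  val_file: "data/val.json"           # hypothetical path
  test_file: ""                       # falls back to val_file when empty
  output_dir: "saved/tangoflux_run"   # hypothetical path
  resume_from_checkpoint: ""          # empty string means "do not resume"

training:
  learning_rate: 1e-4                 # illustrative value; the script casts it with float()
  num_train_epochs: 10                # illustrative value
  num_warmup_steps: 100               # illustrative value
  per_device_batch_size: 4            # illustrative value
  gradient_accumulation_steps: 2      # illustrative value
  max_audio_duration: 30              # seconds; used to clamp durations and pad audio

model: {}                             # placeholder; the real fields live in the repo's tangoflux_config.yaml
"""
)
assert float(example_config["training"]["learning_rate"]) == 1e-4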
@@ -187,74 +221,120 @@ def main():

        accelerator.project_configuration.automatic_checkpoint_naming = False

        wandb.init(
            project="Text to Audio Flow matching",
            settings=wandb.Settings(_disable_stats=True),
        )

    accelerator.wait_for_everyone()

    # Get the datasets
    data_files = {}
    # if args.train_file is not None:
    if config["paths"]["train_file"] != "":
        data_files["train"] = config["paths"]["train_file"]
    # if args.validation_file is not None:
    if config["paths"]["val_file"] != "":
        data_files["validation"] = config["paths"]["val_file"]

    if config["paths"]["test_file"] != "":
        data_files["test"] = config["paths"]["test_file"]
    else:
        data_files["test"] = config["paths"]["val_file"]

    extension = "json"
    raw_datasets = load_dataset(extension, data_files=data_files)
    text_column, audio_column = args.text_column, args.audio_column

    model = TangoFlux(config=config["model"])
    vae = AutoencoderOobleck.from_pretrained(
        "stabilityai/stable-audio-open-1.0", subfolder="vae"
    )

    ## Freeze vae
    for param in vae.parameters():
        param.requires_grad = False
    vae.eval()

    ## Freeze text encoder param
    for param in model.text_encoder.parameters():
        param.requires_grad = False
    model.text_encoder.eval()

    prefix = args.prefix

    with accelerator.main_process_first():
        train_dataset = Text2AudioDataset(
            raw_datasets["train"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )
        eval_dataset = Text2AudioDataset(
            raw_datasets["validation"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )
        test_dataset = Text2AudioDataset(
            raw_datasets["test"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )

        accelerator.print(
            "Num instances in train: {}, validation: {}, test: {}".format(
                train_dataset.get_num_instances(),
                eval_dataset.get_num_instances(),
                test_dataset.get_num_instances(),
            )
        )

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=train_dataset.collate_fn,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=True,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=eval_dataset.collate_fn,
    )
    test_dataloader = DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=test_dataset.collate_fn,
    )

    # Optimizer
    optimizer_parameters = list(model.transformer.parameters()) + list(
        model.fc.parameters()
    )
    num_trainable_parameters = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))

    if args.load_from_checkpoint:
        from safetensors.torch import load_file

        w1 = load_file(args.load_from_checkpoint)
        model.load_state_dict(w1, strict=False)
        logger.info("Weights loaded from {}".format(args.load_from_checkpoint))

    optimizer = torch.optim.AdamW(
        optimizer_parameters,
        lr=learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,

@@ -262,31 +342,35 @@ def main():

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps
        * gradient_accumulation_steps
        * accelerator.num_processes,
        num_training_steps=args.max_train_steps * gradient_accumulation_steps,
    )
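The warmup and total-step arguments passed to get_scheduler above are rescaled by gradient accumulation (and, for warmup, by the number of processes). A small worked rehearsal of the same arithmetic with made-up numbers (batch count and process count are illustrative only):

# Worked example of the step bookkeeping used above (illustrative numbers only).
import math

num_batches_per_epoch = 1000      # len(train_dataloader), hypothetical
gradient_accumulation_steps = 4
num_train_epochs = 10
num_warmup_steps = 100            # value from the config
num_processes = 2                 # accelerator.num_processes, hypothetical

# One optimizer update happens every `gradient_accumulation_steps` batches.
num_update_steps_per_epoch = math.ceil(num_batches_per_epoch / gradient_accumulation_steps)  # 250
max_train_steps = num_train_epochs * num_update_steps_per_epoch                              # 2500

# Mirroring the arguments above: warmup is scaled by accumulation and process count,
# the total scheduler horizon only by accumulation.
scheduler_warmup = num_warmup_steps * gradient_accumulation_steps * num_processes            # 800
scheduler_total = max_train_steps * gradient_accumulation_steps                              # 10000
print(num_update_steps_per_epoch, max_train_steps, scheduler_warmup, scheduler_total)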
    # Prepare everything with our `accelerator`.
    vae, model, optimizer, lr_scheduler = accelerator.prepare(
        vae, model, optimizer, lr_scheduler
    )

    train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
        train_dataloader, eval_dataloader, test_dataloader
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
@@ -299,42 +383,44 @@ def main():

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.

    # Train!
    total_batch_size = (
        per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )

    completed_steps = 0
    starting_epoch = 0
    # Potentially load in the weights and states from a previous save
    resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
    if resume_from_checkpoint != "":
        accelerator.load_state(resume_from_checkpoint)
        accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")

    # Duration of the audio clips in seconds
    best_loss = np.inf
    length = config["training"]["max_audio_duration"]

    for epoch in range(starting_epoch, num_train_epochs):
        model.train()
        total_loss, total_val_loss = 0, 0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                device = model.device
@@ -342,43 +428,44 @@ def main():

                with torch.no_grad():
                    audio_list = []
                    for audio_path in audios:
                        wav = read_wav_file(
                            audio_path, length
                        )  ## Only read the first 30 seconds of audio
                        if (
                            wav.shape[0] == 1
                        ):  ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
                            wav = wav.repeat(2, 1)
                        audio_list.append(wav)

                    audio_input = torch.stack(audio_list, dim=0)
                    audio_input = audio_input.to(device)
                    unwrapped_vae = accelerator.unwrap_model(vae)

                    duration = torch.tensor(duration, device=device)
                    duration = torch.clamp(
                        duration, max=length
                    )  ## clamp duration to max audio length

                    audio_latent = unwrapped_vae.encode(
                        audio_input
                    ).latent_dist.sample()
                    audio_latent = audio_latent.transpose(
                        1, 2
                    )  ## Transpose to (bsz, seq_len, channel)

                loss, _, _, _ = model(audio_latent, text, duration=duration)
                total_loss += loss.detach().float()
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    completed_steps += 1

                optimizer.step()
                lr_scheduler.step()

                if completed_steps % 10 == 0 and accelerator.is_main_process:

@@ -388,20 +475,21 @@ def main():

                            param_norm = p.grad.data.norm(2)
                            total_norm += param_norm.item() ** 2
                    total_norm = total_norm**0.5
                    logger.info(
                        f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
                    )

                    lr = lr_scheduler.get_last_lr()[0]

                    result = {
                        "train_loss": loss.item(),
                        "grad_norm": total_norm,
                        "learning_rate": lr,
                    }

                    # result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
                    wandb.log(result, step=completed_steps)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if isinstance(checkpointing_steps, int):
@@ -415,55 +503,59 @@ def main():

                    break

        model.eval()
        eval_progress_bar = tqdm(
            range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
        )
        for step, batch in enumerate(eval_dataloader):
            with accelerator.accumulate(model) and torch.no_grad():
                device = model.device
                text, audios, duration, _ = batch

                audio_list = []
                for audio_path in audios:
                    wav = read_wav_file(
                        audio_path, length
                    )  ## make sure none of the audio exceeds 30 sec
                    if (
                        wav.shape[0] == 1
                    ):  ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
                        wav = wav.repeat(2, 1)
                    audio_list.append(wav)

                audio_input = torch.stack(audio_list, dim=0)
                audio_input = audio_input.to(device)
                duration = torch.tensor(duration, device=device)
                unwrapped_vae = accelerator.unwrap_model(vae)
                audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
                audio_latent = audio_latent.transpose(
                    1, 2
                )  ## Transpose to (bsz, seq_len, channel)

                val_loss, _, _, _ = model(audio_latent, text, duration=duration)
                total_val_loss += val_loss.detach().float()
                eval_progress_bar.update(1)

        if accelerator.is_main_process:
            result = {}
            result["epoch"] = float(epoch + 1)

            result["epoch/train_loss"] = round(
                total_loss.item() / len(train_dataloader), 4
            )
            result["epoch/val_loss"] = round(
                total_val_loss.item() / len(eval_dataloader), 4
            )

            wandb.log(result, step=completed_steps)

            result_string = "Epoch: {}, Loss Train: {}, Val: {}\n".format(
                epoch, result["epoch/train_loss"], result["epoch/val_loss"]
            )

            accelerator.print(result_string)

            with open("{}/summary.jsonl".format(output_dir), "a") as f:
                f.write(json.dumps(result) + "\n\n")
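Each epoch appends one JSON record plus a trailing blank line to summary.jsonl (the write above adds an extra "\n"). A small reader sketch that tolerates those blank lines; the output path is hypothetical:

# Read back the per-epoch summaries written above; skip the blank lines left by
# the extra "\n" in f.write(...).
import json

def load_summaries(path="saved/tangoflux_run/summary.jsonl"):  # hypothetical path
    records = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records

# Example usage:
# for r in load_summaries():
#     print(r["epoch"], r["epoch/val_loss"])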
@@ -480,13 +572,17 @@ def main():

        if accelerator.is_main_process and args.checkpointing_steps == "best":
            if save_checkpoint:
                accelerator.save_state("{}/{}".format(output_dir, "best"))

            if (epoch + 1) % args.save_every == 0:
                accelerator.save_state(
                    "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
                )

        if accelerator.is_main_process and args.checkpointing_steps == "epoch":
            accelerator.save_state(
                "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
            )


if __name__ == "__main__":
    main()
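The forward call model(audio_latent, text, duration=duration) above returns a scalar training loss whose internals are not part of this diff. As background for the "rectified flow" wording in the argparse description, a generic flow-matching objective over a latent batch looks roughly like the sketch below; this is a textbook-style illustration, not the TangoFlux implementation, and velocity_model is a stand-in for any network predicting a velocity field.

# Generic rectified-flow (conditional flow matching) loss sketch. NOT the TangoFlux code:
# `velocity_model` is a placeholder network that predicts a velocity v(x_t, t, cond)
# for latents of shape (bsz, seq_len, channels).
import torch
import torch.nn.functional as F


def flow_matching_loss(velocity_model, x1, cond):
    """x1: clean latents (bsz, seq_len, channels); cond: conditioning, e.g. text embeddings."""
    bsz = x1.shape[0]
    x0 = torch.randn_like(x1)                 # noise sample
    t = torch.rand(bsz, device=x1.device)     # uniform time in [0, 1]
    t_ = t.view(bsz, 1, 1)
    x_t = (1.0 - t_) * x0 + t_ * x1           # straight-line interpolation between noise and data
    target_velocity = x1 - x0                 # constant velocity along that line
    pred_velocity = velocity_model(x_t, t, cond)
    return F.mse_loss(pred_velocity, target_velocity)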
@@ -5,6 +5,7 @@ import logging

import math
import os
import yaml

# from tqdm import tqdm
import copy
from pathlib import Path

@@ -22,9 +23,9 @@ from datasets import load_dataset

from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import SchedulerType, get_scheduler
from tangoflux.model import TangoFlux
from datasets import load_dataset, Audio
from tangoflux.utils import Text2AudioDataset, read_wav_file, DPOText2AudioDataset
from diffusers import AutoencoderOobleck
import torchaudio

@@ -32,84 +33,119 @@ import torchaudio

logger = get_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Rectified flow for text to audio generation task."
    )
    parser.add_argument(
        "--num_examples",
        type=int,
        default=-1,
        help="How many examples to use for training and validation.",
    )
    parser.add_argument(
        "--text_column",
        type=str,
        default="captions",
        help="The name of the column in the datasets containing the input texts.",
    )
    parser.add_argument(
        "--audio_column",
        type=str,
        default="location",
        help="The name of the column in the datasets containing the audio paths.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.95,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="tangoflux_config.yaml",
        help="Config file defining the model size.",
    )
    parser.add_argument(
        "--weight_decay", type=float, default=1e-8, help="Weight decay to use."
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=[
            "linear",
            "cosine",
            "cosine_with_restarts",
            "polynomial",
            "constant",
            "constant_with_warmup",
        ],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay value for the Adam optimizer",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default="best",
        help="Whether the training state should be saved at the end of every 'epoch', or only ('best') whenever the validation loss decreases.",
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=5,
        help="Save a checkpoint every this many epochs when checkpointing_steps is set to 'best'.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a local checkpoint folder.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'

@@ -117,60 +153,58 @@ def parse_args():

        ),
    )
    parser.add_argument(
        "--load_from_checkpoint",
        type=str,
        default=None,
        help="Path to model weights to initialize training from.",
    )
    parser.add_argument(
        "--audio_length",
        type=float,
        default=30,
        help="Audio duration in seconds.",
    )

    args = parser.parse_args()

    # Sanity checks
    # if args.train_file is None and args.validation_file is None:
    #     raise ValueError("Need a training/validation file.")
    # else:
    #     if args.train_file is not None:
    #         extension = args.train_file.split(".")[-1]
    #         assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
    #     if args.validation_file is not None:
    #         extension = args.validation_file.split(".")[-1]
    #         assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."

    return args
def main():
    args = parse_args()
    accelerator_log_kwargs = {}

    def load_config(config_path):
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    config = load_config(args.config)

    learning_rate = float(config["training"]["learning_rate"])
    num_train_epochs = int(config["training"]["num_train_epochs"])
    num_warmup_steps = int(config["training"]["num_warmup_steps"])
    per_device_batch_size = int(config["training"]["per_device_batch_size"])
    gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])

    output_dir = config["paths"]["output_dir"]

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        **accelerator_log_kwargs,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",

@@ -191,12 +225,12 @@ def main():

    if accelerator.is_main_process:
        if output_dir is None or output_dir == "":
            output_dir = "saved/" + str(int(time.time()))

            if not os.path.exists("saved"):
                os.makedirs("saved")

            os.makedirs(output_dir, exist_ok=True)

        elif output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
@@ -206,73 +240,122 @@ def main():

        accelerator.project_configuration.automatic_checkpoint_naming = False

        wandb.init(
            project="Text to Audio Flow matching DPO",
            settings=wandb.Settings(_disable_stats=True),
        )

    accelerator.wait_for_everyone()

    # Get the datasets
    data_files = {}
    # if args.train_file is not None:
    if config["paths"]["train_file"] != "":
        data_files["train"] = config["paths"]["train_file"]
    # if args.validation_file is not None:
    if config["paths"]["val_file"] != "":
        data_files["validation"] = config["paths"]["val_file"]

    if config["paths"]["test_file"] != "":
        data_files["test"] = config["paths"]["test_file"]
    else:
        data_files["test"] = config["paths"]["val_file"]

    extension = "json"
    train_dataset = load_dataset(extension, data_files=data_files["train"])
    data_files.pop("train")
    raw_datasets = load_dataset(extension, data_files=data_files)
    text_column, audio_column = args.text_column, args.audio_column

    model = TangoFlux(config=config["model"], initialize_reference_model=True)
    vae = AutoencoderOobleck.from_pretrained(
        "stabilityai/stable-audio-open-1.0", subfolder="vae"
    )

    ## Freeze vae
    for param in vae.parameters():
        param.requires_grad = False
    vae.eval()

    ## Freeze text encoder param
    for param in model.text_encoder.parameters():
        param.requires_grad = False
    model.text_encoder.eval()

    prefix = ""

    with accelerator.main_process_first():
        train_dataset = DPOText2AudioDataset(
            train_dataset["train"],
            prefix,
            text_column,
            "chosen",
            "reject",
            "duration",
            args.num_examples,
        )
        eval_dataset = Text2AudioDataset(
            raw_datasets["validation"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )
        test_dataset = Text2AudioDataset(
            raw_datasets["test"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )

        accelerator.print(
            "Num instances in train: {}, validation: {}, test: {}".format(
                train_dataset.get_num_instances(),
                eval_dataset.get_num_instances(),
                test_dataset.get_num_instances(),
            )
        )

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=train_dataset.collate_fn,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=True,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=eval_dataset.collate_fn,
    )
    test_dataloader = DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=test_dataset.collate_fn,
    )

    # Optimizer
    optimizer_parameters = list(model.transformer.parameters()) + list(
        model.fc.parameters()
    )
    num_trainable_parameters = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))

    if args.load_from_checkpoint:
        from safetensors.torch import load_file

        w1 = load_file(args.load_from_checkpoint)
        model.load_state_dict(w1, strict=False)
        logger.info("Weights loaded from {}".format(args.load_from_checkpoint))

        import copy

        model.ref_transformer = copy.deepcopy(model.transformer)
        model.ref_transformer.requires_grad_ = False
        model.ref_transformer.eval()
@@ -280,48 +363,49 @@ def main():

        param.requires_grad = False

    @torch.no_grad()
    def initialize_or_update_ref_transformer(
        model, accelerator: Accelerator, alpha=0.5
    ):
        """
        Initializes or updates ref_transformer as alpha * ref + (1 - alpha) * transformer.

        Args:
            model (torch.nn.Module): The main model containing the 'transformer' attribute.
            accelerator (Accelerator): The Accelerator instance used to unwrap the model.
            alpha (float): Weight given to the existing reference parameters in the blend.

        Returns:
            torch.nn.Module: The model with the updated ref_transformer.
        """
        # Unwrap the model to access the original underlying model
        unwrapped_model = accelerator.unwrap_model(model)

        with torch.no_grad():
            for ref_param, model_param in zip(
                unwrapped_model.ref_transformer.parameters(),
                unwrapped_model.transformer.parameters(),
            ):
                average_param = alpha * ref_param.data + (1 - alpha) * model_param.data
                ref_param.data.copy_(average_param)

        unwrapped_model.ref_transformer.eval()
        unwrapped_model.ref_transformer.requires_grad_ = False
        for param in unwrapped_model.ref_transformer.parameters():
            param.requires_grad = False

        return model

    model.ref_transformer = copy.deepcopy(model.transformer)
    model.ref_transformer.requires_grad_ = False
    model.ref_transformer.eval()
    for param in model.ref_transformer.parameters():
        param.requires_grad = False
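initialize_or_update_ref_transformer above blends the frozen reference weights with the current policy weights. A toy check of the same alpha-mix rule on plain tensors, independent of the model classes in this file:

# Toy check of the blending rule used above: new_ref = alpha * ref + (1 - alpha) * current.
import torch

alpha = 0.5
ref_param = torch.tensor([1.0, 2.0, 3.0])
model_param = torch.tensor([3.0, 2.0, 1.0])
blended = alpha * ref_param + (1 - alpha) * model_param
assert torch.allclose(blended, torch.tensor([2.0, 2.0, 2.0]))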
    optimizer = torch.optim.AdamW(
        optimizer_parameters,
        lr=learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
@@ -329,31 +413,35 @@ def main():

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps
        * gradient_accumulation_steps
        * accelerator.num_processes,
        num_training_steps=args.max_train_steps * gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    vae, model, optimizer, lr_scheduler = accelerator.prepare(
        vae, model, optimizer, lr_scheduler
    )

    train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
        train_dataloader, eval_dataloader, test_dataloader
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
@@ -366,96 +454,108 @@ def main():

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.

    # Train!
    total_batch_size = (
        per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )

    completed_steps = 0
    starting_epoch = 0
    # Potentially load in the weights and states from a previous save
    resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
    if resume_from_checkpoint != "":
        accelerator.load_state(resume_from_checkpoint)
        accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")

    # Duration of the audio clips in seconds
    best_loss = np.inf
    length = config["training"]["max_audio_duration"]

    for epoch in range(starting_epoch, num_train_epochs):
        model.train()
        total_loss, total_val_loss = 0, 0
        for step, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                device = accelerator.device
                text, audio_w, audio_l, duration, _ = batch

                with torch.no_grad():
                    audio_list_w = []
                    audio_list_l = []
                    for audio_path in audio_w:
                        wav = read_wav_file(
                            audio_path, length
                        )  ## Only read the first 30 seconds of audio
                        if (
                            wav.shape[0] == 1
                        ):  ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
                            wav = wav.repeat(2, 1)
                        audio_list_w.append(wav)

                    for audio_path in audio_l:
                        wav = read_wav_file(
                            audio_path, length
                        )  ## Only read the first 30 seconds of audio
                        if (
                            wav.shape[0] == 1
                        ):  ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
                            wav = wav.repeat(2, 1)
                        audio_list_l.append(wav)

                    audio_input_w = torch.stack(audio_list_w, dim=0).to(device)
                    audio_input_l = torch.stack(audio_list_l, dim=0).to(device)
                    # audio_input_ = audio_input.to(device)
                    unwrapped_vae = accelerator.unwrap_model(vae)

                    duration = torch.tensor(duration, device=device)
                    duration = torch.clamp(
                        duration, max=length
                    )  ## max duration is 30 sec

                    audio_latent_w = unwrapped_vae.encode(
                        audio_input_w
                    ).latent_dist.sample()
                    audio_latent_l = unwrapped_vae.encode(
                        audio_input_l
                    ).latent_dist.sample()
                    audio_latent = torch.cat((audio_latent_w, audio_latent_l), dim=0)
                    audio_latent = audio_latent.transpose(
                        1, 2
                    )  ## Transpose to (bsz, seq_len, channel)

                loss, raw_model_loss, raw_ref_loss, implicit_acc = model(
                    audio_latent, text, duration=duration, sft=False
                )
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()

                # if accelerator.sync_gradients:
                if accelerator.sync_gradients:
                    # accelerator.clip_grad_value_(model.parameters(),1.0)
                    progress_bar.update(1)
                    completed_steps += 1

                if completed_steps % 10 == 0 and accelerator.is_main_process:
@@ -465,26 +565,25 @@ def main():

                            param_norm = p.grad.data.norm(2)
                            total_norm += param_norm.item() ** 2
                    total_norm = total_norm**0.5
                    logger.info(
                        f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
                    )

                    lr = lr_scheduler.get_last_lr()[0]

                    result = {
                        "train_loss": loss.item(),
                        "grad_norm": total_norm,
                        "learning_rate": lr,
                        "raw_model_loss": raw_model_loss,
                        "raw_ref_loss": raw_ref_loss,
                        "implicit_acc": implicit_acc,
                    }

                    # result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
                    wandb.log(result, step=completed_steps)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if isinstance(checkpointing_steps, int):
@@ -497,72 +596,78 @@ def main():

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        eval_progress_bar = tqdm(
            range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
        )
        for step, batch in enumerate(eval_dataloader):
            with accelerator.accumulate(model) and torch.no_grad():
                device = model.device
                text, audios, duration, _ = batch
                audio_list = []
                for audio_path in audios:
                    wav = read_wav_file(
                        audio_path, length
                    )  ## Only read the first 30 seconds of audio
                    if (
                        wav.shape[0] == 1
                    ):  ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
                        wav = wav.repeat(2, 1)
                    audio_list.append(wav)

                audio_input = torch.stack(audio_list, dim=0)
                audio_input = audio_input.to(device)
                duration = torch.tensor(duration, device=device)
                unwrapped_vae = accelerator.unwrap_model(vae)
                audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
                audio_latent = audio_latent.transpose(
                    1, 2
                )  ## Transpose to (bsz, seq_len, channel)

                val_loss, _, _, _ = model(
                    audio_latent, text, duration=duration, sft=True
                )
                total_val_loss += val_loss.detach().float()
                eval_progress_bar.update(1)

        if accelerator.is_main_process:
            result = {}
            result["epoch"] = float(epoch + 1)
            result["epoch/train_loss"] = round(
                total_loss.item() / len(train_dataloader), 4
            )
            result["epoch/val_loss"] = round(
                total_val_loss.item() / len(eval_dataloader), 4
            )
            wandb.log(result, step=completed_steps)

            with open("{}/summary.jsonl".format(output_dir), "a") as f:
                f.write(json.dumps(result) + "\n\n")

            logger.info(result)

        save_checkpoint = True
        accelerator.wait_for_everyone()
        if accelerator.is_main_process and args.checkpointing_steps == "best":
            if save_checkpoint:
                accelerator.save_state("{}/{}".format(output_dir, "best"))

            if (epoch + 1) % args.save_every == 0:
                accelerator.save_state(
                    "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
                )

        if accelerator.is_main_process and args.checkpointing_steps == "epoch":
            accelerator.save_state(
                "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
            )


if __name__ == "__main__":
    main()
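The sft=False branch above returns raw model and reference losses plus an implicit accuracy for each chosen/rejected pair, but the objective itself lives inside the model class and is not shown in this diff. The sketch below illustrates the general DPO-style preference loss such a branch typically computes from per-sample losses; the function name, argument layout, and beta value are assumptions, not taken from this code.

# Generic DPO-style preference objective over per-sample diffusion/flow losses.
# Illustrative only: the real TangoFlux/CRPO loss is implemented inside the model.
import torch
import torch.nn.functional as F


def preference_loss(model_loss_w, model_loss_l, ref_loss_w, ref_loss_l, beta=2000.0):
    """All inputs are per-sample losses (lower = the network reconstructs that sample better)."""
    # How much better the policy and the reference each handle chosen vs. rejected audio.
    model_diff = model_loss_w - model_loss_l
    ref_diff = ref_loss_w - ref_loss_l
    # Positive when the policy prefers the chosen sample more strongly than the reference does.
    logits = -(model_diff - ref_diff)
    loss = -F.logsigmoid(beta * logits).mean()
    implicit_acc = (logits > 0).float().mean()  # fraction of pairs ranked correctly
    return loss, implicit_acc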
@@ -11,7 +11,7 @@ import numpy as np

import numpy as np


def normalize_wav(waveform):
    waveform = waveform - torch.mean(waveform)
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)

@@ -20,7 +20,7 @@ def normalize_wav(waveform):

def pad_wav(waveform, segment_length):
    waveform_length = len(waveform)

    if segment_length is None or waveform_length == segment_length:
        return waveform
    elif waveform_length > segment_length:

@@ -29,40 +29,47 @@ def pad_wav(waveform, segment_length):

        padded_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
        waveform = torch.cat([waveform, padded_wav])
        return waveform


def read_wav_file(filename, duration_sec):
    info = torchaudio.info(filename)
    sample_rate = info.sample_rate

    # Calculate the number of frames corresponding to the desired duration
    num_frames = int(sample_rate * duration_sec)

    waveform, sr = torchaudio.load(filename, num_frames=num_frames)  # Faster!!!
    if waveform.shape[0] == 2:  ## Stereo audio
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=44100)
        resampled_waveform = resampler(waveform)
        # print(resampled_waveform.shape)
        padded_left = pad_wav(
            resampled_waveform[0], int(44100 * duration_sec)
        )  ## We pad left and right separately
        padded_right = pad_wav(resampled_waveform[1], int(44100 * duration_sec))

        return torch.stack([padded_left, padded_right])
    else:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sr, new_freq=44100
        )[0]
        waveform = pad_wav(waveform, int(44100 * duration_sec)).unsqueeze(0)
        return waveform
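A quick usage sketch for read_wav_file as defined above (the file path is hypothetical). It resamples to 44.1 kHz and pads or trims toward duration_sec seconds, and mono input comes back as a single-channel tensor, which is why the training loops call wav.repeat(2, 1) to fake a stereo pair before stacking a batch.

# Usage sketch (hypothetical path), reusing read_wav_file defined above.
import torch

duration_sec = 10
wav = read_wav_file("example_mono.wav", duration_sec)   # hypothetical file
# Stereo files -> (2, 441000); mono files -> (1, 441000)
if wav.shape[0] == 1:
    wav = wav.repeat(2, 1)                              # "fake stereo", as in the training loops
assert wav.shape == (2, 44100 * duration_sec)
batch = torch.stack([wav, wav], dim=0)                  # (bsz, 2, samples), ready for the VAE encoder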
class DPOText2AudioDataset(Dataset):
    def __init__(
        self,
        dataset,
        prefix,
        text_column,
        audio_w_column,
        audio_l_column,
        duration,
        num_examples=-1,
    ):

        inputs = list(dataset[text_column])
        self.inputs = [prefix + inp for inp in inputs]

@@ -72,11 +79,18 @@ class DPOText2AudioDataset(Dataset):
        self.indices = list(range(len(self.inputs)))

        self.mapper = {}
        for index, audio_w, audio_l, duration, text in zip(
            self.indices, self.audios_w, self.audios_l, self.durations, inputs
        ):
            self.mapper[index] = [audio_w, audio_l, duration, text]

        if num_examples != -1:
            self.inputs, self.audios_w, self.audios_l, self.durations = (
                self.inputs[:num_examples],
                self.audios_w[:num_examples],
                self.audios_l[:num_examples],
                self.durations[:num_examples],
            )
            self.indices = self.indices[:num_examples]

    def __len__(self):

@@ -86,15 +100,24 @@ class DPOText2AudioDataset(Dataset):
        return len(self.inputs)

    def __getitem__(self, index):
        s1, s2, s3, s4, s5 = (
            self.inputs[index],
            self.audios_w[index],
            self.audios_l[index],
            self.durations[index],
            self.indices[index],
        )
        return s1, s2, s3, s4, s5

    def collate_fn(self, data):
        dat = pd.DataFrame(data)
        return [dat[i].tolist() for i in dat]
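A minimal wiring sketch for `DPOText2AudioDataset`, assuming a preference dataset with text, winning-audio, losing-audio, and duration columns. The dataset repo, split, and column names below are illustrative assumptions only.

```python
from datasets import load_dataset
from torch.utils.data import DataLoader

# Illustrative column names; match them to your own preference data.
raw = load_dataset("declare-lab/CRPO", split="train")
train_set = DPOText2AudioDataset(
    dataset=raw,
    prefix="",
    text_column="captions",
    audio_w_column="chosen",
    audio_l_column="reject",
    duration="duration",
    num_examples=-1,
)
loader = DataLoader(
    train_set, batch_size=4, shuffle=True, collate_fn=train_set.collate_fn
)
# collate_fn returns one list per field:
texts, audios_w, audios_l, durations, indices = next(iter(loader))
```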
class Text2AudioDataset(Dataset):
    def __init__(
        self, dataset, prefix, text_column, audio_column, duration, num_examples=-1
    ):

        inputs = list(dataset[text_column])
        self.inputs = [prefix + inp for inp in inputs]

@@ -103,11 +126,17 @@ class Text2AudioDataset(Dataset):
        self.indices = list(range(len(self.inputs)))

        self.mapper = {}
        for index, audio, duration, text in zip(
            self.indices, self.audios, self.durations, inputs
        ):
            self.mapper[index] = [audio, text, duration]

        if num_examples != -1:
            self.inputs, self.audios, self.durations = (
                self.inputs[:num_examples],
                self.audios[:num_examples],
                self.durations[:num_examples],
            )
            self.indices = self.indices[:num_examples]

    def __len__(self):

@@ -117,7 +146,12 @@ class Text2AudioDataset(Dataset):
        return len(self.inputs)

    def __getitem__(self, index):
        s1, s2, s3, s4 = (
            self.inputs[index],
            self.audios[index],
            self.durations[index],
            self.indices[index],
        )
        return s1, s2, s3, s4

    def collate_fn(self, data):
        ...  # remainder collapsed in the diff view
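`Text2AudioDataset` follows the same pattern with a single audio column; a minimal sketch with hypothetical file and column names:

```python
from datasets import load_dataset

# Hypothetical local JSON file with "captions", "location", and "duration" fields.
raw = load_dataset("json", data_files="train.json", split="train")
train_set = Text2AudioDataset(
    dataset=raw,
    prefix="",
    text_column="captions",
    audio_column="location",
    duration="duration",
    num_examples=-1,
)
text, audio_path, dur, idx = train_set[0]  # (prompt, audio path, duration, dataset index)
```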
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' tangoflux/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'