Unverified commit d32c883c, authored by Soujanya Poria, committed by GitHub
Browse files

Merge pull request #4 from chenxwh/main

Add Replicate demo and API
parents f86e0b9b 555a6c1d
......@@ -12,7 +12,8 @@ TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching a
<br/>
[![arXiv](https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf)](https://arxiv.org/abs/2412.21037) [![Static Badge](https://img.shields.io/badge/TangoFlux-Huggingface-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat)](https://tangoflux.github.io/) [![Static Badge](https://img.shields.io/badge/TangoFlux-Huggingface_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/TangoFlux_Dataset-Huggingface-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/datasets/declare-lab/CRPO)
[![arXiv](https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf)](https://arxiv.org/abs/2412.21037) [![Static Badge](https://img.shields.io/badge/TangoFlux-Huggingface-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat)](https://tangoflux.github.io/) [![Static Badge](https://img.shields.io/badge/TangoFlux-Huggingface_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/TangoFlux_Dataset-Huggingface-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/datasets/declare-lab/CRPO) [![Replicate](https://replicate.com/chenxwh/tangoflux/badge)](https://replicate.com/chenxwh/tangoflux)
......
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml
# Describes the build environment for the predictor defined in predict.py.
build:
# set to true if your model requires a GPU
gpu: true
# a list of ubuntu apt packages to install
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
# python version in the form '3.11' or '3.11.4'
python_version: "3.11"
# a list of packages in the format <package-name>==<version>
# NOTE(review): predict.py also imports soundfile, safetensors and
# huggingface_hub, which are not listed here; presumably they are pulled in
# transitively by the packages below - confirm.
python_packages:
- torch==2.4.0
- torchaudio==2.4.0
- torchlibrosa==0.1.0
- torchvision==0.19.0
- transformers==4.44.0
- diffusers==0.30.0
- accelerate==0.34.2
- datasets==2.21.0
# NOTE(review): librosa and ipython are unpinned, unlike the entries above.
- librosa
- ipython
# commands run while building the environment
run:
# install the pget binary (v0.6.0); predict.py's download_weights invokes it
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
# <file>:<class> of the predictor; points at the Predictor class in predict.py
predict: "predict.py:Predictor"
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import subprocess
import time
import json
from cog import BasePredictor, Input, Path
from diffusers import AutoencoderOobleck
import soundfile as sf
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from src.model import TangoFlux
from tangoflux import TangoFluxInference
# Local directory that Predictor.setup checks for and, when it is missing,
# passes to download_weights as the destination.
MODEL_CACHE = "model_cache"
# Tar archive with the model files; fetched with `pget -x` by download_weights.
MODEL_URL = "https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
class CachedTangoFluxInference(TangoFluxInference):
    """TangoFluxInference that can be built from an already-downloaded directory.

    Lets the weights come from a local cache (e.g. one filled from
    replicate.delivery) so that booting is faster.
    """

    def __init__(self, name="declare-lab/TangoFlux", device="cuda", cached_paths=None):
        """Build the VAE and the TangoFlux model and move both to ``device``.

        Args:
            name: Hugging Face repo id, used only when ``cached_paths`` is falsy.
            device: Device both modules are moved to after loading.
            cached_paths: Directory containing ``vae.safetensors``,
                ``tangoflux.safetensors`` and ``config.json``. When falsy the
                directory is obtained via ``snapshot_download(repo_id=name)``.

        NOTE(review): the parent initialiser is intentionally not invoked here;
        presumably it would always fetch from the Hub - confirm.
        """
        weights_dir = cached_paths if cached_paths else snapshot_download(repo_id=name)

        # Audio autoencoder: instantiate with defaults, then load its weights.
        self.vae = AutoencoderOobleck()
        self.vae.load_state_dict(load_file(f"{weights_dir}/vae.safetensors"))

        # Read the checkpoint before the config, as the original did, so a
        # missing file fails at the same point.
        flux_state = load_file(f"{weights_dir}/tangoflux.safetensors")
        with open(f"{weights_dir}/config.json", "r") as config_file:
            model_config = json.load(config_file)

        self.model = TangoFlux(model_config)
        # strict=False: key mismatches between checkpoint and module are tolerated.
        self.model.load_state_dict(flux_state, strict=False)

        for module in (self.vae, self.model):
            module.to(device)
def download_weights(url, dest):
    """Download ``url`` to ``dest`` with the ``pget`` CLI, printing progress.

    Prints the source, the destination and the elapsed wall-clock seconds.

    Args:
        url: Location of the archive to fetch.
        dest: Destination path handed to ``pget``.

    Raises:
        subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
    """
    began = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # "-x" is passed through to pget (presumably "extract the archive" - see pget docs).
    command = ["pget", "-x", url, dest]
    subprocess.check_call(command, close_fds=False)
    elapsed = time.time() - began
    print("downloading took: ", elapsed)
class Predictor(BasePredictor):
    """Cog predictor that turns a text prompt into a WAV file with TangoFlux."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Fetch the weight archive only when the cache directory is absent.
        if not os.path.exists(MODEL_CACHE):
            print("downloading")
            download_weights(MODEL_URL, MODEL_CACHE)
        self.model = CachedTangoFluxInference(
            cached_paths=f"{MODEL_CACHE}/declare-lab/TangoFlux"
        )

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt", default="Hammer slowly hitting the wooden table"
        ),
        duration: int = Input(
            # ge=1 rejects zero/negative durations up front instead of letting
            # them reach the model and the file writer.
            description="Duration of the output audio in seconds",
            ge=1,
            default=10,
        ),
        steps: int = Input(
            description="Number of inference steps", ge=1, le=200, default=25
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=4.5
        ),
    ) -> Path:
        """Run a single prediction on the model

        Args:
            prompt: Text description of the audio to generate.
            duration: Requested length of the audio in seconds (at least 1).
            steps: Number of inference steps (1-200).
            guidance_scale: Classifier-free guidance scale (1-20).

        Returns:
            Path to the WAV file written at ``/tmp/out.wav``.
        """
        audio = self.model.generate(
            prompt,
            steps=steps,
            guidance_scale=guidance_scale,
            duration=duration,
        )
        # Tensor.numpy() fails for tensors on an accelerator or tracking
        # gradients; detach().cpu() is a no-op when neither applies.
        audio_numpy = audio.detach().cpu().numpy()
        out_path = "/tmp/out.wav"
        # Transposed before writing; presumably generate() yields
        # (channels, frames) while soundfile wants (frames, channels) - confirm.
        sf.write(
            out_path, audio_numpy.T, samplerate=self.model.vae.config.sampling_rate
        )
        return Path(out_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment