# Prediction interface for Cog ⚙️
# https://cog.run/python

import os
import subprocess
import time
import json
from cog import BasePredictor, Input, Path
from diffusers import AutoencoderOobleck
import soundfile as sf
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from tangoflux.model import TangoFlux
from tangoflux import TangoFluxInference

MODEL_CACHE = "model_cache"
MODEL_URL = (
    "https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
)


class CachedTangoFluxInference(TangoFluxInference):
    """TangoFluxInference that loads pre-fetched weights from a local
    replicate.delivery cache for faster cold boots."""

    def __init__(self, name="declare-lab/TangoFlux", device="cuda", cached_paths=None):
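        # Prefer a pre-downloaded local cache; otherwise pull the repo from the Hugging Face Hub.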
        if cached_paths:
            paths = cached_paths
        else:
            paths = snapshot_download(repo_id=name)

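        # Instantiate the Oobleck VAE and restore its weights from the cached safetensors file.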
        self.vae = AutoencoderOobleck()
        vae_weights = load_file(f"{paths}/vae.safetensors")
        self.vae.load_state_dict(vae_weights)
        weights = load_file(f"{paths}/tangoflux.safetensors")

        with open(f"{paths}/config.json", "r") as f:
            config = json.load(f)
        self.model = TangoFlux(config)
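        # strict=False tolerates keys missing from the checkpoint (assumption: e.g. the
        # text encoder, which TangoFlux loads separately from pretrained weights).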
        self.model.load_state_dict(weights, strict=False)
        self.vae.to(device)
        self.model.to(device)


def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""

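        # Fetch and extract the weights tarball on first boot; later boots reuse the cache.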
        if not os.path.exists(MODEL_CACHE):
            print("downloading")
            download_weights(MODEL_URL, MODEL_CACHE)

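        # The tarball unpacks to model_cache/declare-lab/TangoFlux, containing
        # config.json, vae.safetensors, and tangoflux.safetensors (see the loader above).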
        self.model = CachedTangoFluxInference(
            cached_paths=f"{MODEL_CACHE}/declare-lab/TangoFlux"
        )

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt", default="Hammer slowly hitting the wooden table"
        ),
        duration: int = Input(
            description="Duration of the output audio in seconds", default=10
        ),
        steps: int = Input(
            description="Number of inference steps", ge=1, le=200, default=25
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=4.5
        ),
    ) -> Path:
        """Run a single prediction on the model"""

        audio = self.model.generate(
            prompt,
            steps=steps,
            guidance_scale=guidance_scale,
            duration=duration,
        )
        audio_numpy = audio.numpy()
        out_path = "/tmp/out.wav"

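        # soundfile expects (frames, channels), so transpose the (channels, frames) array;
        # the sampling rate comes from the VAE config.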
        sf.write(
            out_path, audio_numpy.T, samplerate=self.model.vae.config.sampling_rate
        )
        return Path(out_path)
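

# Local smoke test (assumption: run outside Cog, with the weights already in model_cache):
#   predictor = Predictor()
#   predictor.setup()
#   print(predictor.predict(prompt="Rain on a tin roof", duration=5,
#                           steps=25, guidance_scale=4.5))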