from diffusers import AutoencoderOobleck
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import torch.nn.functional as F
import numpy as np
from tangoflux.model import TangoFlux
from huggingface_hub import snapshot_download
from tqdm import tqdm
from typing import Optional, Union, List
from datasets import load_dataset, Audio
from math import pi
import json
import inspect
import yaml
from safetensors.torch import load_file


class TangoFluxInference:
    """Inference wrapper for TangoFlux: downloads a model snapshot from the
    Hugging Face Hub, loads the Oobleck VAE and TangoFlux weights, and exposes
    a ``generate`` method that turns a text prompt into an audio waveform."""

    def __init__(
        self,
        name="declare-lab/TangoFlux",
        device="cuda" if torch.cuda.is_available() else "cpu",
    ):

        # Instantiate the Oobleck VAE with its default configuration; the
        # pretrained weights are loaded below from the downloaded snapshot.
        self.vae = AutoencoderOobleck()

        # Download (or reuse a cached copy of) the model snapshot from the Hugging Face Hub.
        paths = snapshot_download(repo_id=name)
        vae_weights = load_file("{}/vae.safetensors".format(paths))
        self.vae.load_state_dict(vae_weights)
        weights = load_file("{}/tangoflux.safetensors".format(paths))

        with open("{}/config.json".format(paths), "r") as f:
            config = json.load(f)
        self.model = TangoFlux(config)
        self.model.load_state_dict(weights, strict=False)
        # load_state_dict reports _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'],
        # unexpected_keys=[]); this behaviour is expected.
        self.vae.to(device)
        self.model.to(device)

    def generate(self, prompt, steps=25, duration=10, guidance_scale=4.5):
        """Generate a waveform for ``prompt``.

        ``duration`` is in seconds, ``steps`` is the number of flow-matching
        inference steps, and ``guidance_scale`` is the classifier-free guidance
        weight. Returns a ``[channels, samples]`` float tensor on the CPU.
        """
        with torch.no_grad():
            latents = self.model.inference_flow(
                prompt,
                duration=duration,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
            )

            wave = self.vae.decode(latents.transpose(2, 1)).sample.cpu()[0]
        waveform_end = int(duration * self.vae.config.sampling_rate)
        wave = wave[:, :waveform_end]
        return wave
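

# Minimal usage sketch (an assumption, not part of the original file): the
# package is importable as ``tangoflux`` (see the ``tangoflux.model`` import
# above), and torchaudio is used here only as one convenient way to write the
# returned [channels, samples] float tensor to disk; any writer that accepts
# such a tensor plus a sample rate would work equally well.
#
#     import torchaudio
#     from tangoflux import TangoFluxInference
#
#     model = TangoFluxInference(name="declare-lab/TangoFlux")
#     audio = model.generate("gentle rain on a tin roof", steps=25, duration=10)
#     torchaudio.save("output.wav", audio, sample_rate=model.vae.config.sampling_rate)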