Commit 55ba477f authored by LucipherDev

Add ComfyUI custom node

parent f7adb62a
try:
    from .comfyui import *
except:
    pass
# ComfyUI-TangoFlux
ComfyUI Custom Nodes for ["TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching"](https://arxiv.org/abs/2412.21037). These nodes, adapted from [the official implementation](https://github.com/declare-lab/TangoFlux/), generate high-quality 44.1kHz audio up to 30 seconds long from just a text prompt.
## Installation
1. Navigate to your ComfyUI's custom_nodes directory:
```bash
cd ComfyUI/custom_nodes
```
2. Clone this repository:
```bash
git clone https://github.com/declare-lab/TangoFlux ComfyUI-TangoFlux
```
3. Install requirements:
```bash
cd ComfyUI-TangoFlux/comfyui
python install.py
```
### Or Install via ComfyUI Manager
#### Check out some demos from [the official demo page](https://tangoflux.github.io/)
## Example Workflow
![example_workflow](https://github.com/user-attachments/assets/afbf7b53-d712-4c9c-a538-53f0dc001f45)
## Usage
**All the necessary models should be automatically downloaded when the TangoFluxLoader node is used for the first time.**
**Models can also be downloaded using the `install.py` script**
![models_folder_structure](https://github.com/user-attachments/assets/94d8a54a-10d6-4f90-bb4d-3ee181dee3a2)
**Manual Download:**
- Download TangoFlux from [here](https://huggingface.co/declare-lab/TangoFlux/tree/main) into `models/tangoflux`
- Download text encoders from [here](https://huggingface.co/google/flan-t5-large/tree/main) into `models/text_encoders/google-flan-t5-large`
*(Include everything as shown in the screenshot above; do not rename anything. A scripted download is sketched below.)*
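
For a scripted download, here is a minimal sketch using `huggingface_hub.snapshot_download`. It mirrors what `install.py` does; the target paths assume a default ComfyUI layout (`ComfyUI/models/...`) and should be adjusted if yours differs.

```python
# Sketch only: fetch both repos into the folders the loader expects.
# The "ComfyUI/models" prefix is an assumption about your install location.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="declare-lab/TangoFlux",
    allow_patterns=["*.json", "*.safetensors"],
    local_dir="ComfyUI/models/tangoflux",
)
snapshot_download(
    repo_id="google/flan-t5-large",
    allow_patterns=["*.json", "*.safetensors", "*.model"],
    local_dir="ComfyUI/models/text_encoders/google-flan-t5-large",
)
```
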
The nodes can be found in the "TangoFlux" category as `TangoFluxLoader`, `TangoFluxSampler`, and `TangoFluxVAEDecodeAndPlay`.
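
In a workflow they are chained Loader → Sampler → VAEDecodeAndPlay, as in the example workflow above. For orientation only, the same chain can be driven from Python; this is a rough sketch that assumes ComfyUI's Python environment is importable and the classes from `comfyui/nodes.py` are in scope, not a supported API.

```python
# Sketch of the node chain outside the graph UI (illustrative values, CUDA assumed).
model, vae = TangoFluxLoader().load_tangoflux(enable_teacache=True, rel_l1_thresh=0.25)
(latents,) = TangoFluxSampler().sample(
    model,
    "A dog barking near the ocean, ocean waves crashing.",  # prompt from the example workflow
    steps=50,
    guidance_scale=3,
    duration=10,
    seed=0,
)
# Decodes the latents and writes a .wav into ComfyUI's output directory
TangoFluxVAEDecodeAndPlay().play(vae, latents, filename_prefix="TangoFlux", format="wav", save_output=True)
```
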
![teacache_options](https://github.com/user-attachments/assets/29e676d9-902b-4ea2-9f72-18d3607996e8)
> [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up TangoFlux 2x without much audio quality degradation, in a training-free manner.
>
> ## 📈 Inference Latency Comparisons on a Single A800
>
> | TangoFlux | TeaCache (0.25) | TeaCache (0.4) |
> |:-------------------:|:----------------------------:|:--------------------:|
> | ~4.08 s | ~2.42 s | ~1.95 s |
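
Roughly, TeaCache decides at every denoising step whether to rerun the transformer blocks or reuse the residual cached from the last fully computed step. The sketch below follows the `teacache_forward` implementation bundled with these nodes (same polynomial coefficients); the function and variable names are illustrative, and `rel_l1_thresh` corresponds to the loader widget (0.25 / 0.4 in the table above).

```python
import numpy as np

# Empirical rescaling polynomial taken from the bundled teacache_forward
rescale = np.poly1d([4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01])

def should_recompute(modulated, prev_modulated, state, rel_l1_thresh, step, num_steps):
    """Sketch of the per-step TeaCache decision; `state["acc"]` accumulates rescaled relative L1 change."""
    if step == 0 or step == num_steps - 1 or prev_modulated is None:
        state["acc"] = 0.0
        return True  # always compute the first and last steps
    rel_l1 = (modulated - prev_modulated).abs().mean() / prev_modulated.abs().mean()
    state["acc"] += rescale(rel_l1.cpu().item())
    if state["acc"] < rel_l1_thresh:
        return False  # skip the blocks and reuse the cached residual
    state["acc"] = 0.0
    return True
```
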
## Citation
```bibtex
@misc{hung2024tangofluxsuperfastfaithful,
title={TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization},
author={Chia-Yu Hung and Navonil Majumder and Zhifeng Kong and Ambuj Mehrish and Rafael Valle and Bryan Catanzaro and Soujanya Poria},
year={2024},
eprint={2412.21037},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2412.21037},
}
```
```bibtex
@article{liu2024timestep,
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
journal={arXiv preprint arXiv:2411.19108},
year={2024}
}
```
from .nodes import NODE_CLASS_MAPPINGS
from .server import *
WEB_DIRECTORY = "./comfyui/web"
__all__ = ["NODE_CLASS_MAPPINGS", "WEB_DIRECTORY"]
{
"last_node_id": 13,
"last_link_id": 15,
"nodes": [
{
"id": 10,
"type": "TangoFluxLoader",
"pos": [
380,
320
],
"size": [
210,
102
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "model",
"type": "TANGOFLUX_MODEL",
"links": [
11
],
"slot_index": 0
},
{
"name": "vae",
"type": "TANGOFLUX_VAE",
"links": [
15
],
"slot_index": 1
}
],
"properties": {
"Node name for S&R": "TangoFluxLoader"
},
"widgets_values": [
false,
0.25
]
},
{
"id": 13,
"type": "TangoFluxVAEDecodeAndPlay",
"pos": [
1060,
320
],
"size": [
315,
126
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"name": "vae",
"type": "TANGOFLUX_VAE",
"link": 15
},
{
"name": "latents",
"type": "TANGOFLUX_LATENTS",
"link": 14
}
],
"outputs": [],
"properties": {
"Node name for S&R": "TangoFluxVAEDecodeAndPlay"
},
"widgets_values": [
"TangoFlux",
"wav",
true
]
},
{
"id": 11,
"type": "TangoFluxSampler",
"pos": [
620,
320
],
"size": [
400,
220
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "TANGOFLUX_MODEL",
"link": 11
}
],
"outputs": [
{
"name": "latents",
"type": "TANGOFLUX_LATENTS",
"links": [
14
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "TangoFluxSampler"
},
"widgets_values": [
"A dog barking near the ocean, ocean waves crashing.",
50,
3,
10,
106139285587780,
"randomize",
1
]
}
],
"links": [
[
11,
10,
0,
11,
0,
"TANGOFLUX_MODEL"
],
[
14,
11,
0,
13,
1,
"TANGOFLUX_LATENTS"
],
[
15,
10,
1,
13,
0,
"TANGOFLUX_VAE"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.9480295566502464,
"offset": [
-200.83333333333337,
-102.2460379319304
]
},
"node_versions": {
"comfyui-tangoflux": "1.0.4"
}
},
"version": 0.4
}
import sys
import os
import logging
import subprocess
import traceback
import json
import re
log = logging.getLogger("TangoFlux")
download_models = True
EXT_PATH = os.path.dirname(os.path.abspath(__file__))
try:
    # Locate ComfyUI's folder_paths module so models land in the standard models directory
    folder_paths_path = os.path.abspath(os.path.join(EXT_PATH, "..", "..", "..", "folder_paths.py"))
    sys.path.append(os.path.dirname(folder_paths_path))
    import folder_paths

    TANGOFLUX_DIR = os.path.join(folder_paths.models_dir, "tangoflux")
    TEXT_ENCODER_DIR = os.path.join(folder_paths.models_dir, "text_encoders")
except Exception:
    download_models = False

try:
    log.info("Installing requirements")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", f"{EXT_PATH}/requirements.txt", "--no-warn-script-location"])

    if download_models:
        from huggingface_hub import snapshot_download

        log.info("Downloading Necessary models")

        try:
            log.info(f"Downloading TangoFlux models to: {TANGOFLUX_DIR}")
            snapshot_download(
                repo_id="declare-lab/TangoFlux",
                allow_patterns=["*.json", "*.safetensors"],
                local_dir=TANGOFLUX_DIR,
                local_dir_use_symlinks=False,
            )
        except Exception:
            traceback.print_exc()
            log.error("Failed to download TangoFlux models")

        log.info("Loading config")
        with open(os.path.join(TANGOFLUX_DIR, "config.json"), "r") as f:
            config = json.load(f)

        try:
            # The text encoder repo id doubles as a folder name, so strip characters invalid in paths
            text_encoder = re.sub(r'[<>:"/\\|?*]', '-', config.get("text_encoder_name", "google/flan-t5-large"))
            text_encoder_path = os.path.join(TEXT_ENCODER_DIR, text_encoder)
            log.info(f"Downloading text encoders to: {text_encoder_path}")
            snapshot_download(
                repo_id=config.get("text_encoder_name", "google/flan-t5-large"),
                allow_patterns=["*.json", "*.safetensors", "*.model"],
                local_dir=text_encoder_path,
                local_dir_use_symlinks=False,
            )
        except Exception:
            traceback.print_exc()
            log.error("Failed to download text encoders")

    try:
        log.info("Installing TangoFlux module")
        subprocess.check_call([sys.executable, "-m", "pip", "install", os.path.join(EXT_PATH, "..")])
    except Exception:
        traceback.print_exc()
        log.error("Failed to install TangoFlux module")

    log.info("TangoFlux Installation completed")
except Exception:
    traceback.print_exc()
    log.error("TangoFlux Installation failed")
import os
import logging
import json
import random
import torch
import torchaudio
import re
from diffusers import AutoencoderOobleck, FluxTransformer2DModel
from huggingface_hub import snapshot_download
from comfy.utils import load_torch_file, ProgressBar
import folder_paths
from tangoflux.model import TangoFlux
from .teacache import teacache_forward
log = logging.getLogger("TangoFlux")
TANGOFLUX_DIR = os.path.join(folder_paths.models_dir, "tangoflux")
if "tangoflux" not in folder_paths.folder_names_and_paths:
current_paths = [TANGOFLUX_DIR]
else:
current_paths, _ = folder_paths.folder_names_and_paths["tangoflux"]
folder_paths.folder_names_and_paths["tangoflux"] = (
current_paths,
folder_paths.supported_pt_extensions,
)
TEXT_ENCODER_DIR = os.path.join(folder_paths.models_dir, "text_encoders")
class TangoFluxLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"enable_teacache": ("BOOLEAN", {"default": False}),
"rel_l1_thresh": (
"FLOAT",
{"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.01},
),
},
}
RETURN_TYPES = ("TANGOFLUX_MODEL", "TANGOFLUX_VAE")
RETURN_NAMES = ("model", "vae")
OUTPUT_TOOLTIPS = ("TangoFlux Model", "TangoFlux Vae")
CATEGORY = "TangoFlux"
FUNCTION = "load_tangoflux"
DESCRIPTION = "Load TangoFlux model"
def __init__(self):
self.model = None
self.vae = None
self.enable_teacache = False
self.rel_l1_thresh = 0.25
self.original_forward = FluxTransformer2DModel.forward
def load_tangoflux(
self,
enable_teacache=False,
rel_l1_thresh=0.25,
tangoflux_path=TANGOFLUX_DIR,
text_encoder_path=TEXT_ENCODER_DIR,
device="cuda",
):
if self.model is None or self.enable_teacache != enable_teacache:
pbar = ProgressBar(6)
snapshot_download(
repo_id="declare-lab/TangoFlux",
allow_patterns=["*.json", "*.safetensors"],
local_dir=tangoflux_path,
local_dir_use_symlinks=False,
)
pbar.update(1)
log.info("Loading config")
with open(os.path.join(tangoflux_path, "config.json"), "r") as f:
config = json.load(f)
pbar.update(1)
text_encoder = re.sub(
r'[<>:"/\\|?*]',
"-",
config.get("text_encoder_name", "google/flan-t5-large"),
)
text_encoder_path = os.path.join(text_encoder_path, text_encoder)
snapshot_download(
repo_id=config.get("text_encoder_name", "google/flan-t5-large"),
allow_patterns=["*.json", "*.safetensors", "*.model"],
local_dir=text_encoder_path,
local_dir_use_symlinks=False,
)
pbar.update(1)
log.info("Loading TangoFlux models")
del self.model
self.model = None
model_weights = load_torch_file(
os.path.join(tangoflux_path, "tangoflux.safetensors"),
device=torch.device(device),
)
pbar.update(1)
if enable_teacache:
log.info("Enabling TeaCache")
FluxTransformer2DModel.forward = teacache_forward
else:
log.info("Disabling TeaCache")
FluxTransformer2DModel.forward = self.original_forward
model = TangoFlux(config=config, text_encoder_dir=text_encoder_path)
model.load_state_dict(model_weights, strict=False)
model.to(device)
if enable_teacache:
model.transformer.__class__.enable_teacache = True
model.transformer.__class__.cnt = 0
model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
model.transformer.__class__.accumulated_rel_l1_distance = 0
model.transformer.__class__.previous_modulated_input = None
model.transformer.__class__.previous_residual = None
pbar.update(1)
self.model = model
del model
self.enable_teacache = enable_teacache
self.rel_l1_thresh = rel_l1_thresh
if self.vae is None:
log.info("Loading TangoFlux VAE")
vae_weights = load_torch_file(
os.path.join(tangoflux_path, "vae.safetensors")
)
self.vae = AutoencoderOobleck()
self.vae.load_state_dict(vae_weights)
self.vae.to(device)
pbar.update(1)
if self.enable_teacache == True and self.rel_l1_thresh != rel_l1_thresh:
self.model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
self.rel_l1_thresh = rel_l1_thresh
return (self.model, self.vae)
class TangoFluxSampler:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("TANGOFLUX_MODEL",),
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"steps": ("INT", {"default": 50, "min": 1, "max": 10000, "step": 1}),
"guidance_scale": (
"FLOAT",
{"default": 3, "min": 1, "max": 100, "step": 1},
),
"duration": ("INT", {"default": 10, "min": 1, "max": 30, "step": 1}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
}
RETURN_TYPES = ("TANGOFLUX_LATENTS",)
RETURN_NAMES = ("latents",)
OUTPUT_TOOLTIPS = "TangoFlux Sample"
CATEGORY = "TangoFlux"
FUNCTION = "sample"
DESCRIPTION = "Sampler for TangoFlux"
def sample(
self,
model,
prompt,
steps=50,
guidance_scale=3,
duration=10,
seed=0,
batch_size=1,
device="cuda",
):
pbar = ProgressBar(steps)
with torch.no_grad():
model.to(device)
try:
if model.transformer.__class__.enable_teacache:
model.transformer.__class__.num_steps = steps
except:
pass
log.info("Generating latents with TangoFlux")
latents = model.inference_flow(
prompt,
duration=duration,
num_inference_steps=steps,
guidance_scale=guidance_scale,
seed=seed,
num_samples_per_prompt=batch_size,
callback_on_step_end=lambda: pbar.update(1),
)
return ({"latents": latents, "duration": duration},)
class TangoFluxVAEDecodeAndPlay:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"vae": ("TANGOFLUX_VAE",),
"latents": ("TANGOFLUX_LATENTS",),
"filename_prefix": ("STRING", {"default": "TangoFlux"}),
"format": (
["wav", "mp3", "flac", "aac", "wma"],
{"default": "wav"},
),
"save_output": ("BOOLEAN", {"default": True}),
},
}
RETURN_TYPES = ()
OUTPUT_NODE = True
CATEGORY = "TangoFlux"
FUNCTION = "play"
DESCRIPTION = "Decoder and Player for TangoFlux"
def decode(self, vae, latents):
results = []
for latent in latents:
decoded = vae.decode(latent.unsqueeze(0).transpose(2, 1)).sample.cpu()
results.append(decoded)
results = torch.cat(results, dim=0)
return results
def play(
self,
vae,
latents,
filename_prefix="TangoFlux",
format="wav",
save_output=True,
device="cuda",
):
audios = []
pbar = ProgressBar(len(latents) + 2)
if save_output:
output_dir = folder_paths.get_output_directory()
prefix_append = ""
type = "output"
else:
output_dir = folder_paths.get_temp_directory()
prefix_append = "_temp_" + "".join(
random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)
)
type = "temp"
filename_prefix += prefix_append
full_output_folder, filename, counter, subfolder, _ = (
folder_paths.get_save_image_path(filename_prefix, output_dir)
)
os.makedirs(full_output_folder, exist_ok=True)
pbar.update(1)
duration = latents["duration"]
latents = latents["latents"]
vae.to(device)
log.info("Decoding Tangoflux latents")
waves = self.decode(vae, latents)
pbar.update(1)
for wave in waves:
waveform_end = int(duration * vae.config.sampling_rate)
wave = wave[:, :waveform_end]
file = f"{filename}_{counter:05}_.{format}"
torchaudio.save(
os.path.join(full_output_folder, file), wave, sample_rate=44100
)
counter += 1
audios.append({"filename": file, "subfolder": subfolder, "type": type})
pbar.update(1)
return {
"ui": {"audios": audios},
}
NODE_CLASS_MAPPINGS = {
"TangoFluxLoader": TangoFluxLoader,
"TangoFluxSampler": TangoFluxSampler,
"TangoFluxVAEDecodeAndPlay": TangoFluxVAEDecodeAndPlay,
}
torchaudio
torchlibrosa
torchvision
diffusers
accelerate
datasets
librosa
wandb
tqdm
import os
import server
import folder_paths
web = server.web
@server.PromptServer.instance.routes.get("/tangoflux/playaudio")
async def play_audio(request):
    query = request.rel_url.query

    filename = query.get("filename", None)
    if filename is None:
        return web.Response(status=404)

    # Reject absolute paths and path traversal attempts
    if filename[0] == "/" or ".." in filename:
        return web.Response(status=403)

    filename, output_dir = folder_paths.annotated_filepath(filename)
    if not output_dir:
        file_type = query.get("type", "output")
        output_dir = folder_paths.get_directory_by_type(file_type)
    if output_dir is None:
        return web.Response(status=400)

    subfolder = query.get("subfolder", None)
    if subfolder:
        full_output_dir = os.path.join(output_dir, subfolder)
        # Make sure the resolved folder is still inside the allowed output directory
        if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
            return web.Response(status=403)
        output_dir = full_output_dir

    filename = os.path.basename(filename)
    file_path = os.path.join(output_dir, filename)
    if not os.path.isfile(file_path):
        return web.Response(status=404)

    _, ext = os.path.splitext(filename)
    ext = ext.lower()
    content_types = {
        ".wav": "audio/wav",
        ".mp3": "audio/mpeg",
        ".flac": "audio/flac",
        ".aac": "audio/aac",
        ".wma": "audio/x-ms-wma",
    }
    content_type = content_types.get(ext, None)
    if content_type is None:
        return web.Response(status=400)

    try:
        with open(file_path, "rb") as file:
            data = file.read()
    except OSError:
        return web.Response(status=500)

    return web.Response(body=data, content_type=content_type)
# Code from https://github.com/ali-vilab/TeaCache/blob/main/TeaCache4TangoFlux/teacache_tango_flux.py
from typing import Any, Dict, Optional, Union
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
import torch
import numpy as np
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
    """
    The [`FluxTransformer2DModel`] forward method.

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
            Input `hidden_states`.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
            from the embeddings of input conditions.
        timestep (`torch.LongTensor`):
            Used to indicate denoising step.
        block_controlnet_hidden_states: (`list` of `torch.Tensor`):
            A list of tensors that if specified are added to the residuals of transformer blocks.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if (
            joint_attention_kwargs is not None
            and joint_attention_kwargs.get("scale", None) is not None
        ):
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

    hidden_states = self.x_embedder(hidden_states)

    timestep = timestep.to(hidden_states.dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000
    else:
        guidance = None

    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    ids = torch.cat((txt_ids, img_ids), dim=1)
    image_rotary_emb = self.pos_embed(ids)

    if self.enable_teacache:
        inp = hidden_states.clone()
        temb_ = temb.clone()
        modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.transformer_blocks[0].norm1(inp, emb=temb_)
        )
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            # Always compute the first and last denoising steps
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # Rescale the relative L1 change of the modulated input with an empirical polynomial,
            # then skip the transformer blocks while the accumulated value stays under the threshold
            coefficients = [
                4.98651651e02,
                -2.83781631e02,
                5.58554382e01,
                -3.82021401e00,
                2.64230861e-01,
            ]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                (
                    (modulated_inp - self.previous_modulated_input).abs().mean()
                    / self.previous_modulated_input.abs().mean()
                )
                .cpu()
                .item()
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0

    if self.enable_teacache:
        if not should_calc:
            # Reuse the residual cached from the last fully computed step
            hidden_states += self.previous_residual
        else:
            ori_hidden_states = hidden_states.clone()
            for index_block, block in enumerate(self.transformer_blocks):
                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = (
                        {"use_reentrant": False}
                        if is_torch_version(">=", "1.11.0")
                        else {}
                    )
                    encoder_hidden_states, hidden_states = (
                        torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            hidden_states,
                            encoder_hidden_states,
                            temb,
                            image_rotary_emb,
                            **ckpt_kwargs,
                        )
                    )
                else:
                    encoder_hidden_states, hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                    )

            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

            for index_block, block in enumerate(self.single_transformer_blocks):
                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = (
                        {"use_reentrant": False}
                        if is_torch_version(">=", "1.11.0")
                        else {}
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        temb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                    )

            hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
            # Cache the residual so skipped steps can reuse it
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                encoder_hidden_states, hidden_states = (
                    torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
import { app } from "../../../scripts/app.js";
import { api } from "../../../scripts/api.js";

app.registerExtension({
  name: "TangoFlux.playAudio",
  async beforeRegisterNodeDef(nodeType, nodeData, app) {
    if (nodeData.name === "TangoFluxVAEDecodeAndPlay") {
      const originalNodeCreated = nodeType.prototype.onNodeCreated;
      nodeType.prototype.onNodeCreated = async function () {
        originalNodeCreated?.apply(this, arguments);
        this.widgets_count = this.widgets?.length || 0;

        this.addAudioWidgets = (audios) => {
          if (this.widgets) {
            for (let i = 0; i < this.widgets.length; i++) {
              if (this.widgets[i].name.startsWith("_playaudio")) {
                this.widgets[i].onRemove?.();
              }
            }
            this.widgets.length = this.widgets_count;
          }

          let index = 0;
          for (const params of audios) {
            const audioElement = document.createElement("audio");
            audioElement.controls = true;
            this.addDOMWidget("_playaudio" + index, "playaudio", audioElement, {
              serialize: false,
              hideOnZoom: false,
            });
            audioElement.src = api.apiURL(
              `/tangoflux/playaudio?${new URLSearchParams(params)}`
            );
            index++;
          }

          requestAnimationFrame(() => {
            const newSize = this.computeSize();
            newSize[0] = Math.max(newSize[0], this.size[0]);
            newSize[1] = Math.max(newSize[1], this.size[1]);
            this.onResize?.(newSize);
            app.graph.setDirtyCanvas(true, false);
          });
        };
      };

      const originalNodeExecuted = nodeType.prototype.onExecuted;
      nodeType.prototype.onExecuted = async function (message) {
        originalNodeExecuted?.apply(this, arguments);
        if (message?.audios) {
          this.addAudioWidgets(message.audios);
        }
      };
    }
  },
});
@@ -138,7 +138,7 @@ def retrieve_timesteps(
 class TangoFlux(nn.Module):
-    def __init__(self, config, initialize_reference_model=False):
+    def __init__(self, config, text_encoder_dir=None, initialize_reference_model=False,):
         super().__init__()
@@ -156,8 +156,12 @@ class TangoFlux(nn.Module):
         self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
         self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
         self.max_text_seq_len = 64
-        self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
-        self.tokenizer = T5TokenizerFast.from_pretrained(self.text_encoder_name)
+        self.text_encoder = T5EncoderModel.from_pretrained(
+            text_encoder_dir if text_encoder_dir is not None else self.text_encoder_name
+        )
+        self.tokenizer = T5TokenizerFast.from_pretrained(
+            text_encoder_dir if text_encoder_dir is not None else self.text_encoder_name
+        )
         self.text_embedding_dim = self.text_encoder.config.d_model
         self.fc = nn.Sequential(
@@ -282,10 +286,18 @@ class TangoFlux(nn.Module):
         timesteps=None,
         guidance_scale=3,
         duration=10,
+        seed=0,
         disable_progress=False,
         num_samples_per_prompt=1,
+        callback_on_step_end=None,
     ):
         """Only tested for single inference. Haven't test for batch inference"""
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
+            torch.cuda.manual_seed_all(seed)
+            torch.backends.cudnn.deterministic = True
         bsz = num_samples_per_prompt
         device = self.transformer.device
@@ -376,6 +388,11 @@ class TangoFlux(nn.Module):
             latents = scheduler.step(noise_pred, t, latents).prev_sample
+            progress_bar.update(1)
+            if callback_on_step_end is not None:
+                callback_on_step_end()
         return latents

     def forward(self, latents, prompt, duration=torch.tensor([10]), sft=True):