Commit bb1969ca authored by comfyanonymous's avatar comfyanonymous
Browse files

Initial support for the stable audio open model.

parent 1281f933
......@@ -135,3 +135,6 @@ class SD3(LatentFormat):
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
class StableAudio1(LatentFormat):
latent_channels = 64
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
import torch
from torch import nn
from typing import Literal, Dict, Any
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
def vae_sample(mean, scale):
stdev = nn.functional.softplus(scale) + 1e-4
var = stdev * stdev
logvar = torch.log(var)
latents = torch.randn_like(mean) * stdev + mean
kl = (mean * mean + var - logvar - 1).sum(1).mean()
return latents, kl
class VAEBottleneck(nn.Module):
def __init__(self):
super().__init__()
self.is_discrete = False
def encode(self, x, return_info=False, **kwargs):
info = {}
mean, scale = x.chunk(2, dim=1)
x, kl = vae_sample(mean, scale)
info["kl"] = kl
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
def snake_beta(x, alpha, beta):
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
class SnakeBeta(nn.Module):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
# self.alpha.requires_grad = alpha_trainable
# self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = snake_beta(x, alpha, beta)
return x
def WNConv1d(*args, **kwargs):
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":
act = torch.nn.ELU()
elif activation == "snake":
act = SnakeBeta(channels)
elif activation == "none":
act = torch.nn.Identity()
else:
raise ValueError(f"Unknown activation {activation}")
if antialias:
act = Activation1d(act)
return act
class ResidualUnit(nn.Module):
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
super().__init__()
self.dilation = dilation
padding = (dilation * (7-1)) // 2
self.layers = nn.Sequential(
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
WNConv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=7, dilation=dilation, padding=padding),
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
WNConv1d(in_channels=out_channels, out_channels=out_channels,
kernel_size=1)
)
def forward(self, x):
res = x
#x = checkpoint(self.layers, x)
x = self.layers(x)
return x + res
class EncoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
super().__init__()
self.layers = nn.Sequential(
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=1, use_snake=use_snake),
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=3, use_snake=use_snake),
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=9, use_snake=use_snake),
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
WNConv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
)
def forward(self, x):
return self.layers(x)
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
super().__init__()
if use_nearest_upsample:
upsample_layer = nn.Sequential(
nn.Upsample(scale_factor=stride, mode="nearest"),
WNConv1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=2*stride,
stride=1,
bias=False,
padding='same')
)
else:
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
self.layers = nn.Sequential(
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
upsample_layer,
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=1, use_snake=use_snake),
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=3, use_snake=use_snake),
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=9, use_snake=use_snake),
)
def forward(self, x):
return self.layers(x)
class OobleckEncoder(nn.Module):
def __init__(self,
in_channels=2,
channels=128,
latent_dim=32,
c_mults = [1, 2, 4, 8],
strides = [2, 4, 8, 8],
use_snake=False,
antialias_activation=False
):
super().__init__()
c_mults = [1] + c_mults
self.depth = len(c_mults)
layers = [
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
]
for i in range(self.depth-1):
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
layers += [
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class OobleckDecoder(nn.Module):
def __init__(self,
out_channels=2,
channels=128,
latent_dim=32,
c_mults = [1, 2, 4, 8],
strides = [2, 4, 8, 8],
use_snake=False,
antialias_activation=False,
use_nearest_upsample=False,
final_tanh=True):
super().__init__()
c_mults = [1] + c_mults
self.depth = len(c_mults)
layers = [
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
]
for i in range(self.depth-1, 0, -1):
layers += [DecoderBlock(
in_channels=c_mults[i]*channels,
out_channels=c_mults[i-1]*channels,
stride=strides[i-1],
use_snake=use_snake,
antialias_activation=antialias_activation,
use_nearest_upsample=use_nearest_upsample
)
]
layers += [
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
nn.Tanh() if final_tanh else nn.Identity()
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class AudioOobleckVAE(nn.Module):
def __init__(self,
in_channels=2,
channels=128,
latent_dim=64,
c_mults = [1, 2, 4, 8, 16],
strides = [2, 4, 4, 8, 8],
use_snake=True,
antialias_activation=False,
use_nearest_upsample=False,
final_tanh=False):
super().__init__()
self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
self.bottleneck = VAEBottleneck()
def encode(self, x):
return self.bottleneck.encode(self.encoder(x))
def decode(self, x):
return self.decoder(self.bottleneck.decode(x))
This diff is collapsed.
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
import torch
import torch.nn as nn
from torch import Tensor, einsum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from einops import rearrange
import math
import comfy.ops
class LearnedPositionalEmbedding(nn.Module):
"""Used for continuous time"""
def __init__(self, dim: int):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.empty(half_dim))
def forward(self, x: Tensor) -> Tensor:
x = rearrange(x, "b -> b 1")
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((x, fouriered), dim=-1)
return fouriered
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
return nn.Sequential(
LearnedPositionalEmbedding(dim),
comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
)
class NumberEmbedder(nn.Module):
def __init__(
self,
features: int,
dim: int = 256,
):
super().__init__()
self.features = features
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
if not torch.is_tensor(x):
device = next(self.embedding.parameters()).device
x = torch.tensor(x, device=device)
assert isinstance(x, Tensor)
shape = x.shape
x = rearrange(x, "... -> (...)")
embedding = self.embedding(x)
x = embedding.view(*shape, self.features)
return x # type: ignore
class Conditioner(nn.Module):
def __init__(
self,
dim: int,
output_dim: int,
project_out: bool = False
):
super().__init__()
self.dim = dim
self.output_dim = output_dim
self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
def forward(self, x):
raise NotImplementedError()
class NumberConditioner(Conditioner):
'''
Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
'''
def __init__(self,
output_dim: int,
min_val: float=0,
max_val: float=1
):
super().__init__(output_dim, output_dim)
self.min_val = min_val
self.max_val = max_val
self.embedder = NumberEmbedder(features=output_dim)
def forward(self, floats, device=None):
# Cast the inputs to floats
floats = [float(x) for x in floats]
if device is None:
device = next(self.embedder.parameters()).device
floats = torch.tensor(floats).to(device)
floats = floats.clamp(self.min_val, self.max_val)
normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
# Cast floats to same type as embedder
embedder_dtype = next(self.embedder.parameters()).dtype
normalized_floats = normalized_floats.to(embedder_dtype)
float_embeds = self.embedder(normalized_floats).unsqueeze(1)
return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
......@@ -86,22 +86,32 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)
b, _, dim_head = q.shape
dim_head //= heads
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
if skip_reshape:
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
else:
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
# force cast to fp32 to avoid overflowing
if attn_precision == torch.float32:
......@@ -138,17 +148,26 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
return out
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)
b, _, dim_head = query.shape
dim_head //= heads
if skip_reshape:
b, _, _, dim_head = query.shape
else:
b, _, dim_head = query.shape
dim_head //= heads
scale = dim_head ** -0.5
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
else:
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype
upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
......@@ -200,22 +219,32 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None)
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None):
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)
b, _, dim_head = q.shape
dim_head //= heads
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
if skip_reshape:
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
else:
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
......@@ -311,9 +340,12 @@ try:
except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape
dim_head //= heads
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
disabled_xformers = False
......@@ -328,10 +360,16 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask)
q, k, v = map(
lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v),
)
if skip_reshape:
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
else:
q, k, v = map(
lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v),
)
if mask is not None:
pad = 8 - q.shape[1] % 8
......@@ -341,18 +379,30 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = (
out.reshape(b, -1, heads * dim_head)
)
if skip_reshape:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
else:
out = (
out.reshape(b, -1, heads * dim_head)
)
return out
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
......
......@@ -6,12 +6,15 @@ from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.model_management
import comfy.conds
import comfy.ops
from enum import Enum
from . import utils
import comfy.latent_formats
import math
class ModelType(Enum):
EPS = 1
......@@ -20,9 +23,10 @@ class ModelType(Enum):
STABLE_CASCADE = 4
EDM = 5
FLOW = 6
V_PREDICTION_CONTINUOUS = 7
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
def model_sampling(model_config, model_type):
......@@ -44,6 +48,9 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.EDM:
c = EDM
s = ModelSamplingContinuousEDM
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
c = V_PREDICTION
s = ModelSamplingContinuousV
class ModelSampling(s, c):
pass
......@@ -236,11 +243,11 @@ class BaseModel(torch.nn.Module):
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
area = input_shape[0] * math.prod(input_shape[2:])
return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3]
area = input_shape[0] * math.prod(input_shape[2:])
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
......@@ -590,3 +597,33 @@ class SD3(BaseModel):
else:
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * 0.3) * (1024 * 1024)
class StableAudio1(BaseModel):
def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
def extra_conds(self, **kwargs):
out = {}
noise = kwargs.get("noise", None)
device = kwargs["device"]
seconds_start = kwargs.get("seconds_start", 0)
seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 21.53))
seconds_start_embed = self.seconds_start_embedder([seconds_start])[0].to(device)
seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device)
global_embed = torch.cat([seconds_start_embed, seconds_total_embed], dim=-1).reshape((1, -1))
out['global_embed'] = comfy.conds.CONDRegular(global_embed)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
cross_attn = torch.cat([cross_attn.to(device), seconds_start_embed.repeat((cross_attn.shape[0], 1, 1)), seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
......@@ -96,6 +96,11 @@ def detect_unet_config(state_dict, key_prefix):
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
unet_config = {}
unet_config["audio_model"] = "dit1.0"
return unet_config
unet_config = {
"use_checkpoint": False,
"image_size": 32,
......@@ -236,6 +241,13 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
else:
return model_config
def unet_prefix_from_state_dict(state_dict):
if "model.model.postprocess_conv.weight" in state_dict: #audio models
unet_key_prefix = "model.model."
else:
unet_key_prefix = "model.diffusion_model."
return unet_key_prefix
def convert_config(unet_config):
new_config = unet_config.copy()
num_res_blocks = new_config.get("num_res_blocks", None)
......
......@@ -169,6 +169,14 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
class ModelSamplingContinuousV(ModelSamplingContinuousEDM):
def timestep(self, sigma):
return sigma.atan() / math.pi * 2
def sigma(self, timestep):
return (timestep * math.pi / 2).tan()
def time_snr_shift(alpha, t):
if alpha == 1.0:
return t
......
......@@ -51,6 +51,20 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
def reset_parameters(self):
return None
......@@ -133,6 +147,27 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input, output_size=None):
num_spatial_dims = 1
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
......@@ -147,6 +182,9 @@ class manual_cast(disable_weight_init):
class Linear(disable_weight_init.Linear):
comfy_cast_weights = True
class Conv1d(disable_weight_init.Conv1d):
comfy_cast_weights = True
class Conv2d(disable_weight_init.Conv2d):
comfy_cast_weights = True
......@@ -161,3 +199,6 @@ class manual_cast(disable_weight_init):
class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
comfy_cast_weights = True
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
comfy_cast_weights = True
from comfy import sd1_clip
from transformers import T5TokenizerFast
import comfy.t5
import os
class T5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
class SAT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="t5base", clip_model=T5BaseModel, **kwargs)
......@@ -6,7 +6,7 @@ from comfy import model_management
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import yaml
import comfy.utils
......@@ -20,6 +20,7 @@ from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
from . import sd3_clip
from . import sa_t5
import comfy.model_patcher
import comfy.lora
......@@ -174,6 +175,7 @@ class VAE:
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
self.output_channels = 3
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
......@@ -232,6 +234,16 @@ class VAE:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.0.weight_v" in sd:
self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2]) * model_management.dtype_size(dtype) #TODO: tweak for the audio VAE
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * 64) * model_management.dtype_size(dtype)
self.latent_channels = 64
self.output_channels = 2
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
......@@ -260,12 +272,12 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
def vae_encode_crop_pixels(self, pixels):
x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
dims = pixels.shape[1:-1]
for d in range(len(dims)):
x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
x_offset = (dims[d] % self.downscale_ratio) // 2
if x != dims[d]:
pixels = pixels.narrow(d + 1, x_offset, x)
return pixels
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
......@@ -303,7 +315,7 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
......@@ -328,7 +340,7 @@ class VAE:
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
......@@ -371,6 +383,7 @@ class CLIPType(Enum):
STABLE_DIFFUSION = 1
STABLE_CASCADE = 2
SD3 = 3
STABLE_AUDIO = 4
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
clip_data = []
......@@ -404,6 +417,9 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
clip_target.clip = sa_t5.SAT5Model
clip_target.tokenizer = sa_t5.SAT5Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
......@@ -470,10 +486,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model_patcher = None
clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
......@@ -488,8 +505,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.")
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model.load_model_weights(sd, diffusion_model_prefix)
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
......
......@@ -6,6 +6,7 @@ from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
from . import sd3_clip
from . import sa_t5
from . import supported_models_base
from . import latent_formats
......@@ -524,7 +525,35 @@ class SD3(supported_models_base.BASE):
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
class StableAudio(supported_models_base.BASE):
unet_config = {
"audio_model": "dit1.0",
}
sampling_settings = {"sigma_max": 500.0, "sigma_min": 0.03}
unet_extra_config = {}
latent_format = latent_formats.StableAudio1
text_encoder_key_prefix = ["text_encoders."]
vae_key_prefix = ["pretransform.model."]
def get_model(self, state_dict, prefix="", device=None):
seconds_start_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_start.": ""}, filter_keys=True)
seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
return model_base.StableAudio1(self, seconds_start_embedder_weights=seconds_start_sd, seconds_total_embedder_weights=seconds_total_sd, device=device)
def process_unet_state_dict(self, state_dict):
for k in list(state_dict.keys()):
if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
state_dict.pop(k)
return state_dict
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
models += [SVD_img2vid]
import torchaudio
import torch
import comfy.model_management
import folder_paths
import os
class EmptyLatentAudio:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": {}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "_for_testing/audio"
def generate(self):
batch_size = 1
latent = torch.zeros([batch_size, 64, 1024], device=self.device)
return ({"samples":latent, "type": "audio"}, )
class VAEEncodeAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "_for_testing/audio"
def encode(self, vae, audio):
t = vae.encode(audio["waveform"].movedim(1, -1))
return ({"samples":t}, )
class VAEDecodeAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "decode"
CATEGORY = "_for_testing/audio"
def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1)
return ({"waveform": audio, "sample_rate": 44100}, )
class SaveAudio:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
self.compress_level = 4
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_audio"
OUTPUT_NODE = True
CATEGORY = "_for_testing/audio"
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results = list()
for (batch_number, waveform) in enumerate(audio["waveform"]):
#TODO: metadata
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.flac"
torchaudio.save(os.path.join(full_output_folder, file), waveform, audio["sample_rate"], format="FLAC")
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "audio": results } }
class LoadAudio:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required": {"audio": [sorted(files), ]}, }
CATEGORY = "_for_testing/audio"
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
multiplier = 1.0
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )
@classmethod
def IS_CHANGED(s, audio):
image_path = folder_paths.get_annotated_filepath(audio)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, audio):
if not folder_paths.exists_annotated_filepath(audio):
return "Invalid audio file: {}".format(audio)
return True
NODE_CLASS_MAPPINGS = {
"EmptyLatentAudio": EmptyLatentAudio,
"VAEEncodeAudio": VAEEncodeAudio,
"VAEDecodeAudio": VAEDecodeAudio,
"SaveAudio": SaveAudio,
"LoadAudio": LoadAudio,
}
......@@ -196,6 +196,36 @@ class ModelSamplingContinuousEDM:
m.add_object_patch("latent_format", latent_format)
return (m, )
class ModelSamplingContinuousV:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction"],),
"sigma_max": ("FLOAT", {"default": 500.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
latent_format = None
sigma_data = 1.0
if sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousV, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class RescaleCFG:
@classmethod
def INPUT_TYPES(s):
......@@ -238,6 +268,7 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
"ModelSamplingContinuousV": ModelSamplingContinuousV,
"ModelSamplingStableCascade": ModelSamplingStableCascade,
"ModelSamplingSD3": ModelSamplingSD3,
"RescaleCFG": RescaleCFG,
......
......@@ -818,7 +818,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
......@@ -826,11 +826,14 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name, type="stable_diffusion"):
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "stable_audio":
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_path = folder_paths.get_full_path("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
......@@ -1973,6 +1976,7 @@ def init_custom_nodes():
"nodes_attention_multiply.py",
"nodes_advanced_samplers.py",
"nodes_webcam.py",
"nodes_audio.py",
"nodes_sd3.py",
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment