Commit bb1969ca authored by comfyanonymous

Initial support for the stable audio open model.

parent 1281f933
@@ -135,3 +135,6 @@ class SD3(LatentFormat):
    def process_out(self, latent):
        return (latent / self.scale_factor) + self.shift_factor
+
+class StableAudio1(LatentFormat):
+    latent_channels = 64
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
import torch
from torch import nn
from typing import Literal, Dict, Any
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
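# VAE sampling via the reparameterization trick: the encoder output is split into
# (mean, scale), scale is mapped to a positive stdev with softplus, the latent is drawn as
# mean + stdev * N(0, 1), and kl is the closed-form KL divergence against a unit Gaussian
# (summed over channels, averaged over the batch).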
def vae_sample(mean, scale):
stdev = nn.functional.softplus(scale) + 1e-4
var = stdev * stdev
logvar = torch.log(var)
latents = torch.randn_like(mean) * stdev + mean
kl = (mean * mean + var - logvar - 1).sum(1).mean()
return latents, kl
class VAEBottleneck(nn.Module):
def __init__(self):
super().__init__()
self.is_discrete = False
def encode(self, x, return_info=False, **kwargs):
info = {}
mean, scale = x.chunk(2, dim=1)
x, kl = vae_sample(mean, scale)
info["kl"] = kl
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
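# snake_beta(x) = x + (1/beta) * sin^2(alpha * x): the periodic "snake" activation used by the
# Oobleck encoder/decoder blocks below; the small constant keeps the division stable when beta is near zero.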
def snake_beta(x, alpha, beta):
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
class SnakeBeta(nn.Module):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
# self.alpha.requires_grad = alpha_trainable
# self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = snake_beta(x, alpha, beta)
return x
def WNConv1d(*args, **kwargs):
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":
act = torch.nn.ELU()
elif activation == "snake":
act = SnakeBeta(channels)
elif activation == "none":
act = torch.nn.Identity()
else:
raise ValueError(f"Unknown activation {activation}")
if antialias:
act = Activation1d(act)
return act
class ResidualUnit(nn.Module):
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
super().__init__()
self.dilation = dilation
padding = (dilation * (7-1)) // 2
self.layers = nn.Sequential(
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
WNConv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=7, dilation=dilation, padding=padding),
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
WNConv1d(in_channels=out_channels, out_channels=out_channels,
kernel_size=1)
)
def forward(self, x):
res = x
#x = checkpoint(self.layers, x)
x = self.layers(x)
return x + res
class EncoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
super().__init__()
self.layers = nn.Sequential(
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=1, use_snake=use_snake),
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=3, use_snake=use_snake),
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=9, use_snake=use_snake),
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
WNConv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
)
def forward(self, x):
return self.layers(x)
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
super().__init__()
if use_nearest_upsample:
upsample_layer = nn.Sequential(
nn.Upsample(scale_factor=stride, mode="nearest"),
WNConv1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=2*stride,
stride=1,
bias=False,
padding='same')
)
else:
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
self.layers = nn.Sequential(
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
upsample_layer,
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=1, use_snake=use_snake),
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=3, use_snake=use_snake),
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=9, use_snake=use_snake),
)
def forward(self, x):
return self.layers(x)
class OobleckEncoder(nn.Module):
def __init__(self,
in_channels=2,
channels=128,
latent_dim=32,
c_mults = [1, 2, 4, 8],
strides = [2, 4, 8, 8],
use_snake=False,
antialias_activation=False
):
super().__init__()
c_mults = [1] + c_mults
self.depth = len(c_mults)
layers = [
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
]
for i in range(self.depth-1):
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
layers += [
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class OobleckDecoder(nn.Module):
def __init__(self,
out_channels=2,
channels=128,
latent_dim=32,
c_mults = [1, 2, 4, 8],
strides = [2, 4, 8, 8],
use_snake=False,
antialias_activation=False,
use_nearest_upsample=False,
final_tanh=True):
super().__init__()
c_mults = [1] + c_mults
self.depth = len(c_mults)
layers = [
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
]
for i in range(self.depth-1, 0, -1):
layers += [DecoderBlock(
in_channels=c_mults[i]*channels,
out_channels=c_mults[i-1]*channels,
stride=strides[i-1],
use_snake=use_snake,
antialias_activation=antialias_activation,
use_nearest_upsample=use_nearest_upsample
)
]
layers += [
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
nn.Tanh() if final_tanh else nn.Identity()
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class AudioOobleckVAE(nn.Module):
def __init__(self,
in_channels=2,
channels=128,
latent_dim=64,
c_mults = [1, 2, 4, 8, 16],
strides = [2, 4, 4, 8, 8],
use_snake=True,
antialias_activation=False,
use_nearest_upsample=False,
final_tanh=False):
super().__init__()
self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
self.bottleneck = VAEBottleneck()
def encode(self, x):
return self.bottleneck.encode(self.encoder(x))
def decode(self, x):
return self.decoder(self.bottleneck.decode(x))
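Illustrative usage (not part of the diff): a rough shape check of the Oobleck VAE defaults, which match the stable audio checkpoint (strides 2*4*4*8*8 give 2048x temporal downsampling into 64 latent channels). Weights here are untrained; in practice they are loaded from the checkpoint's pretransform.model. keys.

import torch
from comfy.ldm.audio.autoencoder import AudioOobleckVAE

vae = AudioOobleckVAE()                  # default config used by the stable audio open model
audio = torch.randn(1, 2, 2048 * 1024)   # (batch, stereo channels, samples)
latent = vae.encode(audio)               # -> roughly (1, 64, 1024): 2048x shorter along time
recon = vae.decode(latent)               # -> roughly (1, 2, 2048 * 1024)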
This diff is collapsed.
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
import torch
import torch.nn as nn
from torch import Tensor, einsum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from einops import rearrange
import math
import comfy.ops
class LearnedPositionalEmbedding(nn.Module):
"""Used for continuous time"""
def __init__(self, dim: int):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.empty(half_dim))
def forward(self, x: Tensor) -> Tensor:
x = rearrange(x, "b -> b 1")
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((x, fouriered), dim=-1)
return fouriered
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
return nn.Sequential(
LearnedPositionalEmbedding(dim),
comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
)
class NumberEmbedder(nn.Module):
def __init__(
self,
features: int,
dim: int = 256,
):
super().__init__()
self.features = features
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
if not torch.is_tensor(x):
device = next(self.embedding.parameters()).device
x = torch.tensor(x, device=device)
assert isinstance(x, Tensor)
shape = x.shape
x = rearrange(x, "... -> (...)")
embedding = self.embedding(x)
x = embedding.view(*shape, self.features)
return x # type: ignore
class Conditioner(nn.Module):
def __init__(
self,
dim: int,
output_dim: int,
project_out: bool = False
):
super().__init__()
self.dim = dim
self.output_dim = output_dim
self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
def forward(self, x):
raise NotImplementedError()
class NumberConditioner(Conditioner):
'''
Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
'''
def __init__(self,
output_dim: int,
min_val: float=0,
max_val: float=1
):
super().__init__(output_dim, output_dim)
self.min_val = min_val
self.max_val = max_val
self.embedder = NumberEmbedder(features=output_dim)
def forward(self, floats, device=None):
# Cast the inputs to floats
floats = [float(x) for x in floats]
if device is None:
device = next(self.embedder.parameters()).device
floats = torch.tensor(floats).to(device)
floats = floats.clamp(self.min_val, self.max_val)
normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
# Cast floats to same type as embedder
embedder_dtype = next(self.embedder.parameters()).dtype
normalized_floats = normalized_floats.to(embedder_dtype)
float_embeds = self.embedder(normalized_floats).unsqueeze(1)
return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
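Illustrative usage (not part of the diff): the seconds_start/seconds_total conditioners created in model_base.StableAudio1 further down are NumberConditioner(768, min_val=0, max_val=512); each call returns a (batch, 1, 768) embedding plus an all-ones mask. Weights are untrained here; the real ones come from the checkpoint's conditioner.conditioners.* keys.

from comfy.ldm.audio.embedders import NumberConditioner

seconds_total = NumberConditioner(768, min_val=0, max_val=512)
embed, mask = seconds_total([47.0])       # accepts a plain list of floats
print(embed.shape, mask.shape)            # torch.Size([1, 1, 768]) torch.Size([1, 1])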
@@ -86,22 +86,32 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

-def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

-    b, _, dim_head = q.shape
-    dim_head //= heads
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+
    scale = dim_head ** -0.5

    h = heads
-    q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(b, -1, heads, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b * heads, -1, dim_head)
-        .contiguous(),
-        (q, k, v),
-    )
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, -1, heads, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * heads, -1, dim_head)
+            .contiguous(),
+            (q, k, v),
+        )

    # force cast to fp32 to avoid overflowing
    if attn_precision == torch.float32:

@@ -138,17 +148,26 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
    return out

-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

-    b, _, dim_head = query.shape
-    dim_head //= heads
+    if skip_reshape:
+        b, _, _, dim_head = query.shape
+    else:
+        b, _, dim_head = query.shape
+        dim_head //= heads

    scale = dim_head ** -0.5

-    query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
-    value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
-    key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
+    if skip_reshape:
+        query = query.reshape(b * heads, -1, dim_head)
+        value = value.reshape(b * heads, -1, dim_head)
+        key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
+    else:
+        query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+        value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+        key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

    dtype = query.dtype
    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32

@@ -200,22 +219,32 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None)
    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
    return hidden_states

-def attention_split(q, k, v, heads, mask=None, attn_precision=None):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

-    b, _, dim_head = q.shape
-    dim_head //= heads
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads

    scale = dim_head ** -0.5

    h = heads
-    q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(b, -1, heads, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b * heads, -1, dim_head)
-        .contiguous(),
-        (q, k, v),
-    )
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, -1, heads, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * heads, -1, dim_head)
+            .contiguous(),
+            (q, k, v),
+        )

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

@@ -311,9 +340,12 @@ try:
except:
    pass

-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
-    b, _, dim_head = q.shape
-    dim_head //= heads
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads

    disabled_xformers = False

@@ -328,10 +360,16 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
    if disabled_xformers:
        return attention_pytorch(q, k, v, heads, mask)

-    q, k, v = map(
-        lambda t: t.reshape(b, -1, heads, dim_head),
-        (q, k, v),
-    )
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.reshape(b, -1, heads, dim_head),
+            (q, k, v),
+        )

    if mask is not None:
        pad = 8 - q.shape[1] % 8

@@ -341,18 +379,30 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

-    out = (
-        out.reshape(b, -1, heads * dim_head)
-    )
+    if skip_reshape:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, -1, heads * dim_head)
+        )
+    else:
+        out = (
+            out.reshape(b, -1, heads * dim_head)
+        )

    return out

-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
-    b, _, dim_head = q.shape
-    dim_head //= heads
-
-    q, k, v = map(
-        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
-        (q, k, v),
-    )
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    out = (
...
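The skip_reshape flag added above lets callers such as the new audio DiT pass q/k/v that are already split into heads. A minimal sketch (not part of the diff) of the two accepted layouts:

import torch
from comfy.ldm.modules.attention import attention_pytorch

b, heads, tokens, dim_head = 2, 8, 64, 64
flat = torch.randn(b, tokens, heads * dim_head)    # default layout: (B, T, heads * dim_head)
split = torch.randn(b, heads, tokens, dim_head)    # pre-split layout: (B, heads, T, dim_head)

out_flat = attention_pytorch(flat, flat, flat, heads)                          # reshaped internally
out_split = attention_pytorch(split, split, split, heads, skip_reshape=True)   # used as-is
print(out_flat.shape, out_split.shape)   # both (B, T, heads * dim_head)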
@@ -6,12 +6,15 @@ from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
+import comfy.ldm.audio.dit
+import comfy.ldm.audio.embedders
import comfy.model_management
import comfy.conds
import comfy.ops
from enum import Enum
from . import utils
import comfy.latent_formats
+import math

@@ -20,9 +23,10 @@ class ModelType(Enum):
    STABLE_CASCADE = 4
    EDM = 5
    FLOW = 6
+    V_PREDICTION_CONTINUOUS = 7

-from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
+from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV


def model_sampling(model_config, model_type):

@@ -44,6 +48,9 @@ def model_sampling(model_config, model_type):
    elif model_type == ModelType.EDM:
        c = EDM
        s = ModelSamplingContinuousEDM
+    elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
+        c = V_PREDICTION
+        s = ModelSamplingContinuousV

    class ModelSampling(s, c):
        pass

@@ -236,11 +243,11 @@ class BaseModel(torch.nn.Module):
            if self.manual_cast_dtype is not None:
                dtype = self.manual_cast_dtype
            #TODO: this needs to be tweaked
-            area = input_shape[0] * input_shape[2] * input_shape[3]
+            area = input_shape[0] * math.prod(input_shape[2:])
            return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
        else:
            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
-            area = input_shape[0] * input_shape[2] * input_shape[3]
+            area = input_shape[0] * math.prod(input_shape[2:])
            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)

@@ -590,3 +597,33 @@ class SD3(BaseModel):
        else:
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (area * 0.3) * (1024 * 1024)
+
+class StableAudio1(BaseModel):
+    def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
+        self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
+        self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
+        self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
+        self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        seconds_start = kwargs.get("seconds_start", 0)
+        seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 21.53))
+
+        seconds_start_embed = self.seconds_start_embedder([seconds_start])[0].to(device)
+        seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device)
+
+        global_embed = torch.cat([seconds_start_embed, seconds_total_embed], dim=-1).reshape((1, -1))
+        out['global_embed'] = comfy.conds.CONDRegular(global_embed)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            cross_attn = torch.cat([cross_attn.to(device), seconds_start_embed.repeat((cross_attn.shape[0], 1, 1)), seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        return out
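A shape sketch (illustrative, not from the diff) of what extra_conds produces, assuming a (1, 64, 1024) noise latent and a (1, 128, 768) T5-base cross-attention tensor:

#   seconds_total defaults to int(1024 / 21.53) = 47 seconds when not supplied
#   seconds_start_embed, seconds_total_embed : (1, 1, 768) each
#   global_embed = cat(start, total, dim=-1).reshape(1, -1)   -> (1, 1536)
#   c_crossattn  = cat([cross_attn, start, total], dim=1)     -> (1, 130, 768)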
@@ -96,6 +96,11 @@ def detect_unet_config(state_dict, key_prefix):
        unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
        return unet_config

+    if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
+        unet_config = {}
+        unet_config["audio_model"] = "dit1.0"
+        return unet_config
+
    unet_config = {
        "use_checkpoint": False,
        "image_size": 32,

@@ -236,6 +241,13 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
    else:
        return model_config

+def unet_prefix_from_state_dict(state_dict):
+    if "model.model.postprocess_conv.weight" in state_dict: #audio models
+        unet_key_prefix = "model.model."
+    else:
+        unet_key_prefix = "model.diffusion_model."
+    return unet_key_prefix
+
def convert_config(unet_config):
    new_config = unet_config.copy()
    num_res_blocks = new_config.get("num_res_blocks", None)
...
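A quick illustrative check of the new prefix helper (not from the diff); the image-model key below is just an example:

from comfy.model_detection import unet_prefix_from_state_dict

# audio checkpoints keep the DiT under "model.model.", other checkpoints under "model.diffusion_model."
assert unet_prefix_from_state_dict({"model.model.postprocess_conv.weight": None}) == "model.model."
assert unet_prefix_from_state_dict({"model.diffusion_model.input_blocks.0.0.weight": None}) == "model.diffusion_model."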
@@ -169,6 +169,14 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
        return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)

+
+class ModelSamplingContinuousV(ModelSamplingContinuousEDM):
+    def timestep(self, sigma):
+        return sigma.atan() / math.pi * 2
+
+    def sigma(self, timestep):
+        return (timestep * math.pi / 2).tan()
+
def time_snr_shift(alpha, t):
    if alpha == 1.0:
        return t
...
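A small sanity sketch (not part of the diff): the new V-prediction continuous sampling maps sigma into a (0, 1) timestep with atan, and sigma() is its exact inverse:

import math
import torch

sigma = torch.tensor([0.03, 1.0, 500.0])   # default stable audio sigma range (see supported_models below)
t = sigma.atan() / math.pi * 2             # ModelSamplingContinuousV.timestep -> ~[0.019, 0.500, 0.999]
print((t * math.pi / 2).tan())             # ModelSamplingContinuousV.sigma recovers [0.03, 1.0, 500.0]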
@@ -51,6 +51,20 @@ class disable_weight_init:
            else:
                return super().forward(*args, **kwargs)

+    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

@@ -133,6 +147,27 @@ class disable_weight_init:
            else:
                return super().forward(*args, **kwargs)

+    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input, output_size=None):
+            num_spatial_dims = 1
+            output_padding = self._output_padding(
+                input, output_size, self.stride, self.padding, self.kernel_size,
+                num_spatial_dims, self.dilation)
+
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.conv_transpose1d(
+                input, weight, bias, self.stride, self.padding,
+                output_padding, self.groups, self.dilation)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        if dims == 2:

@@ -147,6 +182,9 @@ class manual_cast(disable_weight_init):
    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

+    class Conv1d(disable_weight_init.Conv1d):
+        comfy_cast_weights = True
+
    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

@@ -161,3 +199,6 @@ class manual_cast(disable_weight_init):
    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True
+
+    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
+        comfy_cast_weights = True
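Illustrative usage (not from the diff): the new manual_cast.Conv1d keeps its stored weights as-is but casts them to the input's dtype/device on each forward via cast_bias_weight, which is what lets the audio VAE and DiT run in fp16/bf16. Weights below are uninitialized; real ones come from a checkpoint.

import torch
import comfy.ops

conv = comfy.ops.manual_cast.Conv1d(64, 64, kernel_size=3, padding=1)
x = torch.randn(1, 64, 1024, dtype=torch.bfloat16)
print(conv(x).dtype)   # torch.bfloat16 -- weights were cast per call, not stored in bf16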
from comfy import sd1_clip
from transformers import T5TokenizerFast
import comfy.t5
import os
class T5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
class SAT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="t5base", clip_model=T5BaseModel, **kwargs)
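Illustrative only: once a stable audio checkpoint (or a standalone T5-base file) is loaded, the SAT5 pair above sits behind the usual ComfyUI CLIP wrapper, so prompting goes through the standard tokenize/encode path; the clip object below is assumed to come from such a load.

# `clip` is assumed to wrap SAT5Model/SAT5Tokenizer, e.g. from comfy.sd.load_clip(..., clip_type=CLIPType.STABLE_AUDIO)
tokens = clip.tokenize("128 bpm tech house drum loop")
cond = clip.encode_from_tokens(tokens)   # T5-base features, 768 per token, padded to at least 128 tokens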
@@ -6,7 +6,7 @@ from comfy import model_management
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
+from .ldm.audio.autoencoder import AudioOobleckVAE
import yaml

import comfy.utils

@@ -20,6 +20,7 @@ from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
from . import sd3_clip
+from . import sa_t5

import comfy.model_patcher
import comfy.lora

@@ -174,6 +175,7 @@ class VAE:
        self.downscale_ratio = 8
        self.upscale_ratio = 8
        self.latent_channels = 4
+        self.output_channels = 3
        self.process_input = lambda image: image * 2.0 - 1.0
        self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

@@ -232,6 +234,16 @@ class VAE:
            self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
                                                        encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
                                                        decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
+        elif "decoder.layers.0.weight_v" in sd:
+            self.first_stage_model = AudioOobleckVAE()
+            self.memory_used_encode = lambda shape, dtype: (1767 * shape[2]) * model_management.dtype_size(dtype) #TODO: tweak for the audio VAE
+            self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * 64) * model_management.dtype_size(dtype)
+            self.latent_channels = 64
+            self.output_channels = 2
+            self.upscale_ratio = 2048
+            self.downscale_ratio = 2048
+            self.process_output = lambda audio: audio
+            self.process_input = lambda audio: audio
        else:
            logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
            self.first_stage_model = None

@@ -260,12 +272,12 @@ class VAE:
        self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)

    def vae_encode_crop_pixels(self, pixels):
-        x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
-        y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
-        if pixels.shape[1] != x or pixels.shape[2] != y:
-            x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
-            y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
-            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+        dims = pixels.shape[1:-1]
+        for d in range(len(dims)):
+            x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
+            x_offset = (dims[d] % self.downscale_ratio) // 2
+            if x != dims[d]:
+                pixels = pixels.narrow(d + 1, x_offset, x)
        return pixels

    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):

@@ -303,7 +315,7 @@ class VAE:
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)

-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
+            pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
            for x in range(0, samples_in.shape[0], batch_number):
                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())

@@ -328,7 +340,7 @@ class VAE:
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)
-            samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
+            samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
            for x in range(0, pixel_samples.shape[0], batch_number):
                pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()

@@ -371,6 +383,7 @@ class CLIPType(Enum):
    STABLE_DIFFUSION = 1
    STABLE_CASCADE = 2
    SD3 = 3
+    STABLE_AUDIO = 4

def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
    clip_data = []

@@ -404,6 +417,9 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
            dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
            clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
            clip_target.tokenizer = sd3_clip.SD3Tokenizer
+        elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
+            clip_target.clip = sa_t5.SAT5Model
+            clip_target.tokenizer = sa_t5.SAT5Tokenizer
        else:
            clip_target.clip = sd1_clip.SD1ClipModel
            clip_target.tokenizer = sd1_clip.SD1Tokenizer

@@ -470,10 +486,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    model_patcher = None
    clip_target = None

-    parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
+    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
+    parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
    load_device = model_management.get_torch_device()

-    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
+    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

@@ -488,8 +505,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    if output_model:
        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
        offload_device = model_management.unet_offload_device()
-        model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
-        model.load_model_weights(sd, "model.diffusion_model.")
+        model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
+        model.load_model_weights(sd, diffusion_model_prefix)

    if output_vae:
        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
...
@@ -6,6 +6,7 @@ from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
from . import sd3_clip
+from . import sa_t5

from . import supported_models_base
from . import latent_formats

@@ -524,7 +525,35 @@ class SD3(supported_models_base.BASE):
        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))

+class StableAudio(supported_models_base.BASE):
+    unet_config = {
+        "audio_model": "dit1.0",
+    }
+
+    sampling_settings = {"sigma_max": 500.0, "sigma_min": 0.03}
+
+    unet_extra_config = {}
+    latent_format = latent_formats.StableAudio1
+
+    text_encoder_key_prefix = ["text_encoders."]
+    vae_key_prefix = ["pretransform.model."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        seconds_start_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_start.": ""}, filter_keys=True)
+        seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
+        return model_base.StableAudio1(self, seconds_start_embedder_weights=seconds_start_sd, seconds_total_embedder_weights=seconds_total_sd, device=device)
+
+    def process_unet_state_dict(self, state_dict):
+        for k in list(state_dict.keys()):
+            if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
+                state_dict.pop(k)
+        return state_dict
+
+    def clip_target(self, state_dict={}):
+        return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
+
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]

models += [SVD_img2vid]
import torchaudio
import torch
import comfy.model_management
import folder_paths
import os
import hashlib  # used by LoadAudio.IS_CHANGED below
class EmptyLatentAudio:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": {}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "_for_testing/audio"
def generate(self):
batch_size = 1
latent = torch.zeros([batch_size, 64, 1024], device=self.device)
return ({"samples":latent, "type": "audio"}, )
class VAEEncodeAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "_for_testing/audio"
def encode(self, vae, audio):
t = vae.encode(audio["waveform"].movedim(1, -1))
return ({"samples":t}, )
class VAEDecodeAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "decode"
CATEGORY = "_for_testing/audio"
def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1)
return ({"waveform": audio, "sample_rate": 44100}, )
class SaveAudio:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
self.compress_level = 4
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_audio"
OUTPUT_NODE = True
CATEGORY = "_for_testing/audio"
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results = list()
for (batch_number, waveform) in enumerate(audio["waveform"]):
#TODO: metadata
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.flac"
torchaudio.save(os.path.join(full_output_folder, file), waveform, audio["sample_rate"], format="FLAC")
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "audio": results } }
class LoadAudio:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required": {"audio": [sorted(files), ]}, }
CATEGORY = "_for_testing/audio"
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
multiplier = 1.0
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )
@classmethod
def IS_CHANGED(s, audio):
image_path = folder_paths.get_annotated_filepath(audio)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, audio):
if not folder_paths.exists_annotated_filepath(audio):
return "Invalid audio file: {}".format(audio)
return True
NODE_CLASS_MAPPINGS = {
"EmptyLatentAudio": EmptyLatentAudio,
"VAEEncodeAudio": VAEEncodeAudio,
"VAEDecodeAudio": VAEDecodeAudio,
"SaveAudio": SaveAudio,
"LoadAudio": LoadAudio,
}
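Illustrative only (not part of the diff): the nodes above can also be chained directly from Python; vae is assumed to be the audio VAE from a loaded stable audio checkpoint, and in a real workflow the latent would come from the sampler rather than being all zeros.

latent = EmptyLatentAudio().generate()[0]           # {"samples": zeros (1, 64, 1024), "type": "audio"}
audio = VAEDecodeAudio().decode(vae, latent)[0]     # {"waveform": (1, 2, ~47 s of samples), "sample_rate": 44100}
SaveAudio().save_audio(audio, filename_prefix="audio/ComfyUI")   # writes a FLAC into the output directory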
@@ -196,6 +196,36 @@ class ModelSamplingContinuousEDM:
        m.add_object_patch("latent_format", latent_format)
        return (m, )

+class ModelSamplingContinuousV:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "sampling": (["v_prediction"],),
+                              "sigma_max": ("FLOAT", {"default": 500.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
+                              "sigma_min": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
+                              }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, sampling, sigma_max, sigma_min):
+        m = model.clone()
+
+        latent_format = None
+        sigma_data = 1.0
+        if sampling == "v_prediction":
+            sampling_type = comfy.model_sampling.V_PREDICTION
+
+        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousV, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
+        m.add_object_patch("model_sampling", model_sampling)
+        return (m, )
+
class RescaleCFG:
    @classmethod
    def INPUT_TYPES(s):

@@ -238,6 +268,7 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = {
    "ModelSamplingDiscrete": ModelSamplingDiscrete,
    "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
+    "ModelSamplingContinuousV": ModelSamplingContinuousV,
    "ModelSamplingStableCascade": ModelSamplingStableCascade,
    "ModelSamplingSD3": ModelSamplingSD3,
    "RescaleCFG": RescaleCFG,
...
@@ -818,7 +818,7 @@ class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
-                             "type": (["stable_diffusion", "stable_cascade", "sd3"], ),
+                             "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ),
                              }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"

@@ -826,11 +826,14 @@ class CLIPLoader:
    CATEGORY = "advanced/loaders"

    def load_clip(self, clip_name, type="stable_diffusion"):
-        clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
        if type == "stable_cascade":
            clip_type = comfy.sd.CLIPType.STABLE_CASCADE
        elif type == "sd3":
            clip_type = comfy.sd.CLIPType.SD3
+        elif type == "stable_audio":
+            clip_type = comfy.sd.CLIPType.STABLE_AUDIO
+        else:
+            clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION

        clip_path = folder_paths.get_full_path("clip", clip_name)
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)

@@ -1973,6 +1976,7 @@ def init_custom_nodes():
        "nodes_attention_multiply.py",
        "nodes_advanced_samplers.py",
        "nodes_webcam.py",
+        "nodes_audio.py",
        "nodes_sd3.py",
    ]
...