Commit 8c4a9bef authored by comfyanonymous's avatar comfyanonymous
Browse files

SD3 Support.

parent a5e6a632
...@@ -25,8 +25,9 @@ class SD15(LatentFormat): ...@@ -25,8 +25,9 @@ class SD15(LatentFormat):
self.taesd_decoder_name = "taesd_decoder" self.taesd_decoder_name = "taesd_decoder"
class SDXL(LatentFormat): class SDXL(LatentFormat):
scale_factor = 0.13025
def __init__(self): def __init__(self):
self.scale_factor = 0.13025
self.latent_rgb_factors = [ self.latent_rgb_factors = [
# R G B # R G B
[ 0.3920, 0.4054, 0.4549], [ 0.3920, 0.4054, 0.4549],
...@@ -104,3 +105,33 @@ class SC_B(LatentFormat): ...@@ -104,3 +105,33 @@ class SC_B(LatentFormat):
[-0.3087, -0.1535, 0.0366], [-0.3087, -0.1535, 0.0366],
[ 0.0290, -0.1574, -0.4078] [ 0.0290, -0.1574, -0.4078]
] ]
class SD3(LatentFormat):
latent_channels = 16
def __init__(self):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
self.latent_rgb_factors = [
[-0.0645, 0.0177, 0.1052],
[ 0.0028, 0.0312, 0.0650],
[ 0.1848, 0.0762, 0.0360],
[ 0.0944, 0.0360, 0.0889],
[ 0.0897, 0.0506, -0.0364],
[-0.0020, 0.1203, 0.0284],
[ 0.0855, 0.0118, 0.0283],
[-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700],
[-0.0412, 0.0281, -0.0039],
[ 0.1106, 0.1171, 0.1220],
[-0.0248, 0.0682, -0.0481],
[ 0.0815, 0.0846, 0.1207],
[-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456],
[-0.1418, -0.1457, -0.1259]
]
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
import logging
import math
from typing import Dict, Optional
import numpy as np
import torch
import torch.nn as nn
from .. import attention
from einops import rearrange, repeat
def default(x, y):
if x is not None:
return x
return y
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
drop_probs = drop
linear_layer = partial(operations.Conv2d, kernel_size=1) if use_conv else operations.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs)
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
self.drop2 = nn.Dropout(drop_probs)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
dynamic_img_pad: torch.jit.Final[bool]
def __init__(
self,
img_size: Optional[int] = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
flatten: bool = True,
bias: bool = True,
strict_img_size: bool = True,
dynamic_img_pad: bool = True,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.patch_size = (patch_size, patch_size)
if img_size is not None:
self.img_size = (img_size, img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.img_size = None
self.grid_size = None
self.num_patches = None
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
# if self.img_size is not None:
# if self.strict_img_size:
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
# elif not self.dynamic_img_pad:
# _assert(
# H % self.patch_size[0] == 0,
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
# )
# _assert(
# W % self.patch_size[1] == 0,
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
# )
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
x = self.norm(x)
return x
def modulate(x, shift, scale):
if shift is None:
shift = torch.zeros_like(scale)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
def get_2d_sincos_pos_embed(
embed_dim,
grid_size,
cls_token=False,
extra_tokens=0,
scaling_factor=None,
offset=None,
):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
if scaling_factor is not None:
grid = grid / scaling_factor
if offset is not None:
grid = grid - offset
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=torch.float32):
omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude=7.5, device=None, dtype=torch.float32):
small = min(h, w)
val_h = (h / small) * val_magnitude
val_w = (w / small) * val_magnitude
grid_h, grid_w = torch.meshgrid(torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), indexing='ij')
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
if torch.is_floating_point(t):
embedding = embedding.to(dtype=t.dtype)
return embedding
def forward(self, t, dtype, **kwargs):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class VectorEmbedder(nn.Module):
"""
Embeds a flat vector of dimension input_dim
"""
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
emb = self.mlp(x)
return emb
#################################################################################
# Core DiT Model #
#################################################################################
def split_qkv(qkv, head_dim):
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
return qkv[0], qkv[1], qkv[2]
def optimized_attention(qkv, num_heads):
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
class SelfAttention(nn.Module):
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: Optional[float] = None,
proj_drop: float = 0.0,
attn_mode: str = "xformers",
pre_only: bool = False,
qk_norm: Optional[str] = None,
rmsnorm: bool = False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
if not pre_only:
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
self.proj_drop = nn.Dropout(proj_drop)
assert attn_mode in self.ATTENTION_MODES
self.attn_mode = attn_mode
self.pre_only = pre_only
if qk_norm == "rms":
self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
elif qk_norm == "ln":
self.ln_q = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
self.ln_k = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
elif qk_norm is None:
self.ln_q = nn.Identity()
self.ln_k = nn.Identity()
else:
raise ValueError(qk_norm)
def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
B, L, C = x.shape
qkv = self.qkv(x)
q, k, v = split_qkv(qkv, self.head_dim)
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
return (q, k, v)
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
assert not self.pre_only
x = self.proj(x)
x = self.proj_drop(x)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv = self.pre_attention(x)
x = optimized_attention(
qkv, num_heads=self.num_heads
)
x = self.post_attention(x)
return x
class RMSNorm(torch.nn.Module):
def __init__(
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
super().__init__()
self.eps = eps
self.learnable_scale = elementwise_affine
if self.learnable_scale:
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
else:
self.register_parameter("weight", None)
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
x = self._norm(x)
if self.learnable_scale:
return x * self.weight.to(device=x.device, dtype=x.dtype)
else:
return x
class SwiGLUFeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float] = None,
):
"""
Initialize the FeedForward module.
Args:
dim (int): Input dimension.
hidden_dim (int): Hidden dimension of the feedforward layer.
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
Attributes:
w1 (ColumnParallelLinear): Linear transformation for the first layer.
w2 (RowParallelLinear): Linear transformation for the second layer.
w3 (ColumnParallelLinear): Linear transformation for the third layer.
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
class DismantledBlock(nn.Module):
"""
A DiT block with gated adaptive layer norm (adaLN) conditioning.
"""
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: str = "xformers",
qkv_bias: bool = False,
pre_only: bool = False,
rmsnorm: bool = False,
scale_mod_only: bool = False,
swiglu: bool = False,
qk_norm: Optional[str] = None,
dtype=None,
device=None,
operations=None,
**block_kwargs,
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if not rmsnorm:
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
else:
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=pre_only,
qk_norm=qk_norm,
rmsnorm=rmsnorm,
dtype=dtype,
device=device,
operations=operations
)
if not pre_only:
if not rmsnorm:
self.norm2 = operations.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
)
else:
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
if not pre_only:
if not swiglu:
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=lambda: nn.GELU(approximate="tanh"),
drop=0,
dtype=dtype,
device=device,
operations=operations
)
else:
self.mlp = SwiGLUFeedForward(
dim=hidden_size,
hidden_dim=mlp_hidden_dim,
multiple_of=256,
)
self.scale_mod_only = scale_mod_only
if not scale_mod_only:
n_mods = 6 if not pre_only else 2
else:
n_mods = 4 if not pre_only else 1
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device)
)
self.pre_only = pre_only
def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
if not self.pre_only:
if not self.scale_mod_only:
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
) = self.adaLN_modulation(c).chunk(6, dim=1)
else:
shift_msa = None
shift_mlp = None
(
scale_msa,
gate_msa,
scale_mlp,
gate_mlp,
) = self.adaLN_modulation(
c
).chunk(4, dim=1)
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
return qkv, (
x,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
)
else:
if not self.scale_mod_only:
(
shift_msa,
scale_msa,
) = self.adaLN_modulation(
c
).chunk(2, dim=1)
else:
shift_msa = None
scale_msa = self.adaLN_modulation(c)
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
return qkv, None
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
assert not self.pre_only
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
x = x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
return x
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert not self.pre_only
qkv, intermediates = self.pre_attention(x, c)
attn = optimized_attention(
qkv,
num_heads=self.attn.num_heads,
)
return self.post_attention(attn, *intermediates)
def block_mixing(*args, use_checkpoint=True, **kwargs):
if use_checkpoint:
return torch.utils.checkpoint.checkpoint(
_block_mixing, *args, use_reentrant=False, **kwargs
)
else:
return _block_mixing(*args, **kwargs)
def _block_mixing(context, x, context_block, x_block, c):
context_qkv, context_intermediates = context_block.pre_attention(context, c)
x_qkv, x_intermediates = x_block.pre_attention(x, c)
o = []
for t in range(3):
o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
qkv = tuple(o)
attn = optimized_attention(
qkv,
num_heads=x_block.attn.num_heads,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
attn[:, context_qkv[0].shape[1] :],
)
if not context_block.pre_only:
context = context_block.post_attention(context_attn, *context_intermediates)
else:
context = None
x = x_block.post_attention(x_attn, *x_intermediates)
return context, x
class JointBlock(nn.Module):
"""just a small wrapper to serve as a fsdp unit"""
def __init__(
self,
*args,
**kwargs,
):
super().__init__()
pre_only = kwargs.pop("pre_only")
qk_norm = kwargs.pop("qk_norm", None)
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
def forward(self, *args, **kwargs):
return block_mixing(
*args, context_block=self.context_block, x_block=self.x_block, **kwargs
)
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(
self,
hidden_size: int,
patch_size: int,
out_channels: int,
total_out_channels: Optional[int] = None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = (
operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
if (total_out_channels is None)
else operations.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class SelfAttentionContext(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operations=None):
super().__init__()
dim_head = dim // heads
inner_dim = dim
self.heads = heads
self.dim_head = dim_head
self.qkv = operations.Linear(dim, dim * 3, bias=True, dtype=dtype, device=device)
self.proj = operations.Linear(inner_dim, dim, dtype=dtype, device=device)
def forward(self, x):
qkv = self.qkv(x)
q, k, v = split_qkv(qkv, self.dim_head)
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
return self.proj(x)
class ContextProcessorBlock(nn.Module):
def __init__(self, context_size, dtype=None, device=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = SelfAttentionContext(context_size, dtype=dtype, device=device, operations=operations)
self.norm2 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp = Mlp(in_features=context_size, hidden_features=(context_size * 4), act_layer=lambda: nn.GELU(approximate="tanh"), drop=0, dtype=dtype, device=device, operations=operations)
def forward(self, x):
x += self.attn(self.norm1(x))
x += self.mlp(self.norm2(x))
return x
class ContextProcessor(nn.Module):
def __init__(self, context_size, num_layers, dtype=None, device=None, operations=None):
super().__init__()
self.layers = torch.nn.ModuleList([ContextProcessorBlock(context_size, dtype=dtype, device=device, operations=operations) for i in range(num_layers)])
self.norm = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
def forward(self, x):
for i, l in enumerate(self.layers):
x = l(x)
return self.norm(x)
class MMDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size: int = 32,
patch_size: int = 2,
in_channels: int = 4,
depth: int = 28,
# hidden_size: Optional[int] = None,
# num_heads: Optional[int] = None,
mlp_ratio: float = 4.0,
learn_sigma: bool = False,
adm_in_channels: Optional[int] = None,
context_embedder_config: Optional[Dict] = None,
compile_core: bool = False,
use_checkpoint: bool = False,
register_length: int = 0,
attn_mode: str = "torch",
rmsnorm: bool = False,
scale_mod_only: bool = False,
swiglu: bool = False,
out_channels: Optional[int] = None,
pos_embed_scaling_factor: Optional[float] = None,
pos_embed_offset: Optional[float] = None,
pos_embed_max_size: Optional[int] = None,
num_patches = None,
qk_norm: Optional[str] = None,
qkv_bias: bool = True,
context_processor_layers = None,
context_size = 4096,
dtype = None, #TODO
device = None,
operations = None,
):
super().__init__()
self.dtype = dtype
self.learn_sigma = learn_sigma
self.in_channels = in_channels
default_out_channels = in_channels * 2 if learn_sigma else in_channels
self.out_channels = default(out_channels, default_out_channels)
self.patch_size = patch_size
self.pos_embed_scaling_factor = pos_embed_scaling_factor
self.pos_embed_offset = pos_embed_offset
self.pos_embed_max_size = pos_embed_max_size
# hidden_size = default(hidden_size, 64 * depth)
# num_heads = default(num_heads, hidden_size // 64)
# apply magic --> this defines a head_size of 64
self.hidden_size = 64 * depth
num_heads = depth
self.num_heads = num_heads
self.x_embedder = PatchEmbed(
input_size,
patch_size,
in_channels,
self.hidden_size,
bias=True,
strict_img_size=self.pos_embed_max_size is None,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
self.y_embedder = None
if adm_in_channels is not None:
assert isinstance(adm_in_channels, int)
self.y_embedder = VectorEmbedder(adm_in_channels, self.hidden_size, dtype=dtype, device=device, operations=operations)
if context_processor_layers is not None:
self.context_processor = ContextProcessor(context_size, context_processor_layers, dtype=dtype, device=device, operations=operations)
else:
self.context_processor = None
self.context_embedder = nn.Identity()
if context_embedder_config is not None:
if context_embedder_config["target"] == "torch.nn.Linear":
self.context_embedder = operations.Linear(**context_embedder_config["params"], dtype=dtype, device=device)
self.register_length = register_length
if self.register_length > 0:
self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size, dtype=dtype, device=device))
# num_patches = self.x_embedder.num_patches
# Will use fixed sin-cos embedding:
# just use a buffer already
if num_patches is not None:
self.register_buffer(
"pos_embed",
torch.empty(1, num_patches, self.hidden_size, dtype=dtype, device=device),
)
else:
self.pos_embed = None
self.use_checkpoint = use_checkpoint
self.joint_blocks = nn.ModuleList(
[
JointBlock(
self.hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=i == depth - 1,
rmsnorm=rmsnorm,
scale_mod_only=scale_mod_only,
swiglu=swiglu,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations
)
for i in range(depth)
]
)
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
# self.initialize_weights()
if compile_core:
assert False
self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
def initialize_weights(self):
# TODO: Init context_embedder?
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding
if self.pos_embed is not None:
pos_embed_grid_size = (
int(self.x_embedder.num_patches**0.5)
if self.pos_embed_max_size is None
else self.pos_embed_max_size
)
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.x_embedder.num_patches**0.5),
pos_embed_grid_size,
scaling_factor=self.pos_embed_scaling_factor,
offset=self.pos_embed_offset,
)
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.pos_embed.shape[-2]**0.5),
scaling_factor=self.pos_embed_scaling_factor,
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
if hasattr(self, "y_embedder"):
nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.joint_blocks:
nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def cropped_pos_embed(self, hw, device=None):
p = self.x_embedder.patch_size[0]
h, w = hw
# patched size
h = (h + 1) // p
w = (w + 1) // p
if self.pos_embed is None:
return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
assert self.pos_embed_max_size is not None
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
top = (self.pos_embed_max_size - h) // 2
left = (self.pos_embed_max_size - w) // 2
spatial_pos_embed = rearrange(
self.pos_embed,
"1 (h w) c -> 1 h w c",
h=self.pos_embed_max_size,
w=self.pos_embed_max_size,
)
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
# print(spatial_pos_embed, top, left, h, w)
# # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.875, 7.875, device=device) #matches exactly for 1024 res
# t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.5, 7.5, device=device) #scales better
# # print(t)
# return t
return spatial_pos_embed
def unpatchify(self, x, hw=None):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
if hw is None:
h = w = int(x.shape[1] ** 0.5)
else:
h, w = hw
h = (h + 1) // p
w = (w + 1) // p
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
def forward_core_with_concat(
self,
x: torch.Tensor,
c_mod: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.register_length > 0:
context = torch.cat(
(
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
default(context, torch.Tensor([]).type_as(x)),
),
1,
)
# context is B, L', D
# x is B, L, D
for block in self.joint_blocks:
context, x = block(
context,
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
)
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
return x
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
if self.context_processor is not None:
context = self.context_processor(context)
hw = x.shape[-2:]
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype)
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
if y is not None and self.y_embedder is not None:
y = self.y_embedder(y) # (N, D)
c = c + y # (N, D)
if context is not None:
context = self.context_embedder(context)
x = self.forward_core_with_concat(x, c, context)
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
return x[:,:,:hw[-2],:hw[-1]]
class OpenAISignatureMMDITWrapper(MMDiT):
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return super().forward(x, timesteps, context=context, y=y)
...@@ -5,11 +5,13 @@ from comfy.ldm.cascade.stage_c import StageC ...@@ -5,11 +5,13 @@ from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.model_management import comfy.model_management
import comfy.conds import comfy.conds
import comfy.ops import comfy.ops
from enum import Enum from enum import Enum
from . import utils from . import utils
import comfy.latent_formats
class ModelType(Enum): class ModelType(Enum):
EPS = 1 EPS = 1
...@@ -17,6 +19,7 @@ class ModelType(Enum): ...@@ -17,6 +19,7 @@ class ModelType(Enum):
V_PREDICTION_EDM = 3 V_PREDICTION_EDM = 3
STABLE_CASCADE = 4 STABLE_CASCADE = 4
EDM = 5 EDM = 5
FLOW = 6
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
...@@ -32,6 +35,9 @@ def model_sampling(model_config, model_type): ...@@ -32,6 +35,9 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_EDM: elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION c = V_PREDICTION
s = ModelSamplingContinuousEDM s = ModelSamplingContinuousEDM
elif model_type == ModelType.FLOW:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingDiscreteFlow
elif model_type == ModelType.STABLE_CASCADE: elif model_type == ModelType.STABLE_CASCADE:
c = EPS c = EPS
s = StableCascadeSampling s = StableCascadeSampling
...@@ -557,3 +563,23 @@ class StableCascade_B(BaseModel): ...@@ -557,3 +563,23 @@ class StableCascade_B(BaseModel):
out["effnet"] = comfy.conds.CONDRegular(prior) out["effnet"] = comfy.conds.CONDRegular(prior)
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
return out return out
class SD3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = {}
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
import comfy.supported_models import comfy.supported_models
import comfy.supported_models_base import comfy.supported_models_base
import math
import logging import logging
def count_blocks(state_dict_keys, prefix_string): def count_blocks(state_dict_keys, prefix_string):
...@@ -26,12 +27,47 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): ...@@ -26,12 +27,47 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
return None return None
def detect_unet_config(state_dict, key_prefix): def detect_unet_config(state_dict, key_prefix):
state_dict_keys = list(state_dict.keys()) state_dict_keys = list(state_dict.keys())
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
unet_config = {}
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
unet_config["patch_size"] = patch_size
unet_config["out_channels"] = state_dict['{}final_layer.linear.weight'.format(key_prefix)].shape[0] // (patch_size * patch_size)
unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
unet_config["input_size"] = None
y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
if y_key in state_dict_keys:
unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
context_key = '{}context_embedder.weight'.format(key_prefix)
if context_key in state_dict_keys:
in_features = state_dict[context_key].shape[1]
out_features = state_dict[context_key].shape[0]
unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
num_patches_key = '{}pos_embed'.format(key_prefix)
if num_patches_key in state_dict_keys:
num_patches = state_dict[num_patches_key].shape[1]
unet_config["num_patches"] = num_patches
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
rms_qk = '{}joint_blocks.0.context_block.attn.ln_q.weight'.format(key_prefix)
if rms_qk in state_dict_keys:
unet_config["qk_norm"] = "rms"
unet_config["pos_embed_scaling_factor"] = None #unused for inference
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
if context_processor in state_dict_keys:
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
return unet_config
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
unet_config = {} unet_config = {}
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix) text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
...@@ -58,7 +94,6 @@ def detect_unet_config(state_dict, key_prefix): ...@@ -58,7 +94,6 @@ def detect_unet_config(state_dict, key_prefix):
unet_config['nhead'] = [-1, 9, 18, 18] unet_config['nhead'] = [-1, 9, 18, 18]
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]] unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config return unet_config
unet_config = { unet_config = {
...@@ -93,6 +128,7 @@ def detect_unet_config(state_dict, key_prefix): ...@@ -93,6 +128,7 @@ def detect_unet_config(state_dict, key_prefix):
use_linear_in_transformer = False use_linear_in_transformer = False
video_model = False video_model = False
video_model_cross = False
current_res = 1 current_res = 1
count = 0 count = 0
...@@ -136,6 +172,7 @@ def detect_unet_config(state_dict, key_prefix): ...@@ -136,6 +172,7 @@ def detect_unet_config(state_dict, key_prefix):
context_dim = out[1] context_dim = out[1]
use_linear_in_transformer = out[2] use_linear_in_transformer = out[2]
video_model = out[3] video_model = out[3]
video_model_cross = out[4]
else: else:
transformer_depth.append(0) transformer_depth.append(0)
...@@ -176,6 +213,7 @@ def detect_unet_config(state_dict, key_prefix): ...@@ -176,6 +213,7 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["video_kernel_size"] = [3, 1, 1] unet_config["video_kernel_size"] = [3, 1, 1]
unet_config["use_temporal_resblock"] = True unet_config["use_temporal_resblock"] = True
unet_config["use_temporal_attention"] = True unet_config["use_temporal_attention"] = True
unet_config["disable_temporal_crossattention"] = not video_model_cross
else: else:
unet_config["use_temporal_resblock"] = False unet_config["use_temporal_resblock"] = False
unet_config["use_temporal_attention"] = False unet_config["use_temporal_attention"] = False
......
...@@ -33,6 +33,19 @@ class EDM(V_PREDICTION): ...@@ -33,6 +33,19 @@ class EDM(V_PREDICTION):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class CONST:
def calculate_input(self, sigma, noise):
return noise
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
return latent / (1.0 - sigma)
class ModelSamplingDiscrete(torch.nn.Module): class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None): def __init__(self, model_config=None):
...@@ -104,6 +117,12 @@ class ModelSamplingDiscrete(torch.nn.Module): ...@@ -104,6 +117,12 @@ class ModelSamplingDiscrete(torch.nn.Module):
percent = 1.0 - percent percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item() return self.sigma(torch.tensor(percent * 999.0)).item()
class ModelSamplingDiscreteEDM(ModelSamplingDiscrete):
def timestep(self, sigma):
return 0.25 * sigma.log()
def sigma(self, timestep):
return (timestep / 0.25).exp()
class ModelSamplingContinuousEDM(torch.nn.Module): class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None): def __init__(self, model_config=None):
...@@ -149,6 +168,48 @@ class ModelSamplingContinuousEDM(torch.nn.Module): ...@@ -149,6 +168,48 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
log_sigma_min = math.log(self.sigma_min) log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
def time_snr_shift(alpha, t):
if alpha == 1.0:
return t
return alpha * t / (1 + (alpha - 1) * t)
class ModelSamplingDiscreteFlow(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.0))
def set_parameters(self, shift=1.0, timesteps=1000):
self.shift = shift
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
self.register_buffer('sigmas', ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
def sigma(self, timestep):
return time_snr_shift(self.shift, timestep / 1000)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
class StableCascadeSampling(ModelSamplingDiscrete): class StableCascadeSampling(ModelSamplingDiscrete):
def __init__(self, model_config=None): def __init__(self, model_config=None):
super().__init__() super().__init__()
......
...@@ -19,6 +19,7 @@ from . import model_detection ...@@ -19,6 +19,7 @@ from . import model_detection
from . import sd1_clip from . import sd1_clip
from . import sd2_clip from . import sd2_clip
from . import sdxl_clip from . import sdxl_clip
from . import sd3_clip
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
...@@ -395,9 +396,12 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI ...@@ -395,9 +396,12 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
else: else:
clip_target.clip = sd1_clip.SD1ClipModel clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer clip_target.tokenizer = sd1_clip.SD1Tokenizer
else: elif len(clip_data) == 2:
clip_target.clip = sdxl_clip.SDXLClipModel clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif len(clip_data) == 3:
clip_target.clip = sd3_clip.SD3ClipModel
clip_target.tokenizer = sd3_clip.SD3Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory) clip = CLIP(clip_target, embedding_directory=embedding_directory)
for c in clip_data: for c in clip_data:
......
from comfy import sd1_clip
from comfy import sdxl_clip
from transformers import T5TokenizerFast
import comfy.t5
import torch
import os
import comfy.model_management
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
class SDT5XXLModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
class SD3Tokenizer:
def __init__(self, embedding_directory=None):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
class SD3ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
self.t5xxl = T5XXLModel(device=device, dtype=dtype)
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
self.clip_g.set_clip_options(options)
self.t5xxl.set_clip_options(options)
def reset_clip_options(self):
self.clip_g.reset_clip_options()
self.clip_l.reset_clip_options()
self.t5xxl.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
lg_out = None
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
out = lg_out
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
else:
pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device())
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2)
else:
out = t5_out
return out, pooled
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return self.clip_g.load_sd(sd)
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
else:
return self.t5xxl.load_sd(sd)
...@@ -5,6 +5,7 @@ from . import utils ...@@ -5,6 +5,7 @@ from . import utils
from . import sd1_clip from . import sd1_clip
from . import sd2_clip from . import sd2_clip
from . import sdxl_clip from . import sdxl_clip
from . import sd3_clip
from . import supported_models_base from . import supported_models_base
from . import latent_formats from . import latent_formats
...@@ -488,6 +489,28 @@ class SDXL_instructpix2pix(SDXL): ...@@ -488,6 +489,28 @@ class SDXL_instructpix2pix(SDXL):
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device) return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] class SD3(supported_models_base.BASE):
unet_config = {
"in_channels": 16,
"pos_embed_scaling_factor": None,
}
sampling_settings = {
"shift": 3.0,
}
unet_extra_config = {}
latent_format = latent_formats.SD3
text_encoder_key_prefix = ["text_encoders."] #TODO?
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD3(self, device=device)
return out
def clip_target(self):
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO?
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
models += [SVD_img2vid] models += [SVD_img2vid]
import torch
import math
from comfy.ldm.modules.attention import optimized_attention_for_device
class T5LayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
self.variance_epsilon = eps
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
return self.weight.to(device=x.device, dtype=x.dtype) * x
class T5DenseActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device, operations):
super().__init__()
self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
# self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x):
x = torch.nn.functional.relu(self.wi(x))
# x = self.dropout(x)
x = self.wo(x)
return x
class T5DenseGatedActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device, operations):
super().__init__()
self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
# self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x):
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
hidden_linear = self.wi_1(x)
x = hidden_gelu * hidden_linear
# x = self.dropout(x)
x = self.wo(x)
return x
class T5LayerFF(torch.nn.Module):
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
super().__init__()
if ff_activation == "gelu_pytorch_tanh":
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations)
elif ff_activation == "relu":
self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations)
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
# self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x):
forwarded_states = self.layer_norm(x)
forwarded_states = self.DenseReluDense(forwarded_states)
# x = x + self.dropout(forwarded_states)
x += forwarded_states
return x
class T5Attention(torch.nn.Module):
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations):
super().__init__()
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.k = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.v = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.o = operations.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
self.num_heads = num_heads
self.relative_attention_bias = None
if relative_attention_bias:
self.relative_attention_num_buckets = 32
self.relative_attention_max_distance = 128
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length, device):
"""Compute binned relative position bias"""
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=True,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
q = self.q(x)
k = self.k(x)
v = self.v(x)
if self.relative_attention_bias is not None:
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
if past_bias is not None:
if mask is not None:
mask = mask + past_bias
else:
mask = past_bias
out = optimized_attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
return self.o(out), past_bias
class T5LayerSelfAttention(torch.nn.Module):
def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations):
super().__init__()
self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations)
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
# self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
normed_hidden_states = self.layer_norm(x)
output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
# x = x + self.dropout(attention_output)
x += output
return x, past_bias
class T5Block(torch.nn.Module):
def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList()
self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations))
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
x = self.layer[-1](x)
return x, past_bias
class T5Stack(torch.nn.Module):
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations):
super().__init__()
self.block = torch.nn.ModuleList(
[T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
)
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
# self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
intermediate = None
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
past_bias = None
for i, l in enumerate(self.block):
x, past_bias = l(x, mask, past_bias, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
x = self.final_layer_norm(x)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.final_layer_norm(intermediate)
return x, intermediate
class T5(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.num_layers = config_dict["num_layers"]
model_dim = config_dict["d_model"]
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations)
self.dtype = dtype
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, embeddings):
self.shared = embeddings
def forward(self, input_ids, *args, **kwargs):
x = self.shared(input_ids)
return self.encoder(x, *args, **kwargs)
{
"d_ff": 3072,
"d_kv": 64,
"d_model": 768,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "relu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 12,
"num_heads": 12,
"num_layers": 12,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 32128
}
{
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 32128
}
{
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}
This source diff could not be displayed because it is too large. You can view the blob instead.
{
"added_tokens_decoder": {
"0": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32000": {
"content": "<extra_id_99>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32001": {
"content": "<extra_id_98>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32002": {
"content": "<extra_id_97>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32003": {
"content": "<extra_id_96>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32004": {
"content": "<extra_id_95>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32005": {
"content": "<extra_id_94>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32006": {
"content": "<extra_id_93>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32007": {
"content": "<extra_id_92>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32008": {
"content": "<extra_id_91>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32009": {
"content": "<extra_id_90>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32010": {
"content": "<extra_id_89>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32011": {
"content": "<extra_id_88>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32012": {
"content": "<extra_id_87>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32013": {
"content": "<extra_id_86>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32014": {
"content": "<extra_id_85>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32015": {
"content": "<extra_id_84>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32016": {
"content": "<extra_id_83>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32017": {
"content": "<extra_id_82>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32018": {
"content": "<extra_id_81>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32019": {
"content": "<extra_id_80>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32020": {
"content": "<extra_id_79>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32021": {
"content": "<extra_id_78>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32022": {
"content": "<extra_id_77>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32023": {
"content": "<extra_id_76>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32024": {
"content": "<extra_id_75>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32025": {
"content": "<extra_id_74>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32026": {
"content": "<extra_id_73>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32027": {
"content": "<extra_id_72>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32028": {
"content": "<extra_id_71>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32029": {
"content": "<extra_id_70>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32030": {
"content": "<extra_id_69>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32031": {
"content": "<extra_id_68>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32032": {
"content": "<extra_id_67>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32033": {
"content": "<extra_id_66>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32034": {
"content": "<extra_id_65>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32035": {
"content": "<extra_id_64>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32036": {
"content": "<extra_id_63>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32037": {
"content": "<extra_id_62>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32038": {
"content": "<extra_id_61>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32039": {
"content": "<extra_id_60>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32040": {
"content": "<extra_id_59>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32041": {
"content": "<extra_id_58>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32042": {
"content": "<extra_id_57>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32043": {
"content": "<extra_id_56>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32044": {
"content": "<extra_id_55>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32045": {
"content": "<extra_id_54>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32046": {
"content": "<extra_id_53>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32047": {
"content": "<extra_id_52>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32048": {
"content": "<extra_id_51>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32049": {
"content": "<extra_id_50>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32050": {
"content": "<extra_id_49>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32051": {
"content": "<extra_id_48>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32052": {
"content": "<extra_id_47>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32053": {
"content": "<extra_id_46>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32054": {
"content": "<extra_id_45>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32055": {
"content": "<extra_id_44>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32056": {
"content": "<extra_id_43>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32057": {
"content": "<extra_id_42>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32058": {
"content": "<extra_id_41>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32059": {
"content": "<extra_id_40>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32060": {
"content": "<extra_id_39>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32061": {
"content": "<extra_id_38>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32062": {
"content": "<extra_id_37>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32063": {
"content": "<extra_id_36>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32064": {
"content": "<extra_id_35>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32065": {
"content": "<extra_id_34>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32066": {
"content": "<extra_id_33>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32067": {
"content": "<extra_id_32>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32068": {
"content": "<extra_id_31>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32069": {
"content": "<extra_id_30>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32070": {
"content": "<extra_id_29>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32071": {
"content": "<extra_id_28>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32072": {
"content": "<extra_id_27>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32073": {
"content": "<extra_id_26>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32074": {
"content": "<extra_id_25>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32075": {
"content": "<extra_id_24>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32076": {
"content": "<extra_id_23>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32077": {
"content": "<extra_id_22>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32078": {
"content": "<extra_id_21>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32079": {
"content": "<extra_id_20>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32080": {
"content": "<extra_id_19>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32081": {
"content": "<extra_id_18>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32082": {
"content": "<extra_id_17>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32083": {
"content": "<extra_id_16>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32084": {
"content": "<extra_id_15>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32085": {
"content": "<extra_id_14>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32086": {
"content": "<extra_id_13>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32087": {
"content": "<extra_id_12>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32088": {
"content": "<extra_id_11>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32089": {
"content": "<extra_id_10>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32090": {
"content": "<extra_id_9>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32091": {
"content": "<extra_id_8>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32092": {
"content": "<extra_id_7>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32093": {
"content": "<extra_id_6>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32094": {
"content": "<extra_id_5>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32095": {
"content": "<extra_id_4>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32096": {
"content": "<extra_id_3>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32097": {
"content": "<extra_id_2>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32098": {
"content": "<extra_id_1>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32099": {
"content": "<extra_id_0>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"clean_up_tokenization_spaces": true,
"eos_token": "</s>",
"extra_ids": 100,
"legacy": false,
"model_max_length": 512,
"pad_token": "<pad>",
"sp_model_kwargs": {},
"tokenizer_class": "T5Tokenizer",
"unk_token": "<unk>"
}
...@@ -132,6 +132,32 @@ class ModelSamplingStableCascade: ...@@ -132,6 +132,32 @@ class ModelSamplingStableCascade:
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
return (m, ) return (m, )
class ModelSamplingSD3:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, shift):
m = model.clone()
sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
sampling_type = comfy.model_sampling.CONST
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class ModelSamplingContinuousEDM: class ModelSamplingContinuousEDM:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
...@@ -213,5 +239,6 @@ NODE_CLASS_MAPPINGS = { ...@@ -213,5 +239,6 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
"ModelSamplingStableCascade": ModelSamplingStableCascade, "ModelSamplingStableCascade": ModelSamplingStableCascade,
"ModelSamplingSD3": ModelSamplingSD3,
"RescaleCFG": RescaleCFG, "RescaleCFG": RescaleCFG,
} }
import folder_paths
import comfy.sd
import comfy.model_management
import nodes
import torch
class TripleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name3": (folder_paths.get_filename_list("clip"), )
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
class EmptySD3LatentImage:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/sd3"
def generate(self, width, height, batch_size=1):
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
return ({"samples":latent}, )
class CLIPTextEncodeSD3:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"empty_padding": (["none", "empty_prompt"], )
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
no_padding = empty_padding == "none"
tokens = clip.tokenize(clip_g)
if len(clip_g) == 0 and no_padding:
tokens["g"] = []
if len(clip_l) == 0 and no_padding:
tokens["l"] = []
else:
tokens["l"] = clip.tokenize(clip_l)["l"]
if len(t5xxl) == 0 and no_padding:
tokens["t5xxl"] = []
else:
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
if len(tokens["l"]) != len(tokens["g"]):
empty = clip.tokenize("")
while len(tokens["l"]) < len(tokens["g"]):
tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"]
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
return ([[cond, {"pooled_output": pooled}]], )
NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader,
"EmptySD3LatentImage": EmptySD3LatentImage,
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
}
...@@ -1964,6 +1964,7 @@ def init_custom_nodes(): ...@@ -1964,6 +1964,7 @@ def init_custom_nodes():
"nodes_attention_multiply.py", "nodes_attention_multiply.py",
"nodes_advanced_samplers.py", "nodes_advanced_samplers.py",
"nodes_webcam.py", "nodes_webcam.py",
"nodes_sd3.py",
] ]
import_failed = [] import_failed = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment