Commit 5e2c95b7 authored by wuxk1's avatar wuxk1
Browse files

init commit for comui

parents
Pipeline #3174 failed with stages
in 0 seconds
import torch
from typing import Optional
import comfy.ldm.modules.diffusionmodules.mmdit
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
def __init__(
self,
num_blocks = None,
control_latent_channels = None,
dtype = None,
device = None,
operations = None,
**kwargs,
):
super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
# controlnet_blocks
self.controlnet_blocks = torch.nn.ModuleList([])
for _ in range(len(self.joint_blocks)):
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
if control_latent_channels is None:
control_latent_channels = self.in_channels
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
None,
self.patch_size,
control_latent_channels,
self.hidden_size,
bias=True,
strict_img_size=False,
dtype=dtype,
device=device,
operations=operations
)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint = None,
) -> torch.Tensor:
#weird sd3 controlnet specific stuff
y = torch.zeros_like(y)
if self.context_processor is not None:
context = self.context_processor(context)
hw = x.shape[-2:]
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
x += self.pos_embed_input(hint)
c = self.t_embedder(timesteps, dtype=x.dtype)
if y is not None and self.y_embedder is not None:
y = self.y_embedder(y)
c = c + y
if context is not None:
context = self.context_embedder(context)
output = []
blocks = len(self.joint_blocks)
for i in range(blocks):
context, x = self.joint_blocks[i](
context,
x,
c=c,
use_checkpoint=self.use_checkpoint,
)
out = self.controlnet_blocks[i](x)
count = self.depth // blocks
if i == blocks - 1:
count -= 1
for j in range(count):
output.append(out)
return {"output": output}
import argparse
import enum
import os
import comfy.options
class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
"""
def __init__(self, **kwargs):
# Pop off the type value
enum_type = kwargs.pop("type", None)
# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")
# Generate choices from the Enum
choices = tuple(e.value for e in enum_type)
kwargs.setdefault("choices", choices)
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
super(EnumAction, self).__init__(**kwargs)
self._enum = enum_type
def __call__(self, parser, namespace, values, option_string=None):
# Convert value back into an Enum
value = self._enum(values)
setattr(namespace, self.dest, value)
parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
Auto = "auto"
Latent2RGB = "latent2rgb"
TAESD = "taesd"
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
parser.add_argument(
"--front-end-version",
type=str,
default=DEFAULT_VERSION_STRING,
help="""
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
download available frontend implementations from GitHub releases.
The version string should be in the format of:
[repoOwner]/[repoName]@[version]
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
""",
)
def is_valid_directory(path: str) -> str:
"""Validate if the given path is a directory, and check permissions."""
if not os.path.exists(path):
raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
if not os.access(path, os.R_OK):
raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
return path
parser.add_argument(
"--front-end-root",
type=is_valid_directory,
default=None,
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
parser.add_argument(
"--comfy-api-base",
type=str,
default="https://api.comfy.org",
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
)
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:
args = parser.parse_args([])
if args.windows_standalone_build:
args.auto_launch = True
if args.disable_auto_launch:
args.auto_launch = False
if args.force_fp16:
args.fp16_unet = True
# '--fast' is not provided, use an empty set
if args.fast is None:
args.fast = set()
# '--fast' is provided with an empty list, enable all optimizations
elif args.fast == []:
args.fast = set(PerformanceFeature)
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)
{
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"hidden_act": "gelu",
"hidden_size": 1280,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 77,
"model_type": "clip_text_model",
"num_attention_heads": 20,
"num_hidden_layers": 32,
"pad_token_id": 1,
"projection_dim": 1280,
"torch_dtype": "float32",
"vocab_size": 49408
}
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
self.heads = heads
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None, optimized_attention=None):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = optimized_attention(q, k, v, self.heads, mask)
return self.out_proj(out)
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
}
class CLIPMLP(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
super().__init__()
self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
self.activation = ACTIVATIONS[activation]
self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class CLIPLayer(torch.nn.Module):
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
super().__init__()
self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
def forward(self, x, mask=None, optimized_attention=None):
x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
x += self.mlp(self.layer_norm2(x))
return x
class CLIPEncoder(torch.nn.Module):
def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
super().__init__()
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
for i, l in enumerate(self.layers):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
super().__init__()
self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, dtype=torch.float32):
return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
class CLIPTextModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
num_positions = config_dict["max_position_embeddings"]
self.eos_token_id = config_dict["eos_token_id"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else:
x = self.embeddings(input_tokens, dtype=dtype)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
x = self.final_layer_norm(x)
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
if num_tokens is not None:
pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
else:
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.dtype = dtype
def get_input_embeddings(self):
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, embeddings):
self.text_model.embeddings.token_embedding = embeddings
def forward(self, *args, **kwargs):
x = self.text_model(*args, **kwargs)
out = self.text_projection(x[2])
return (x[0], x[1], out, x[2])
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
super().__init__()
num_patches = (image_size // patch_size) ** 2
if model_type == "siglip_vision_model":
self.class_embedding = None
patch_bias = True
else:
num_patches = num_patches + 1
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
patch_bias = False
self.patch_embedding = operations.Conv2d(
in_channels=num_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=patch_bias,
dtype=dtype,
device=device
)
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
if self.class_embedding is not None:
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
class CLIPVision(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.output_layernorm = False
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x = self.pre_layrnorm(x)
#TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
if self.output_layernorm:
x = self.post_layernorm(x)
pooled_output = x
else:
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class LlavaProjector(torch.nn.Module):
def __init__(self, in_dim, out_dim, dtype, device, operations):
super().__init__()
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
if "projection_dim" in config_dict:
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
else:
self.visual_projection = lambda a: a
if "llava3" == config_dict.get("projector_type", None):
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
else:
self.multi_modal_projector = None
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2])
projected = None
if self.multi_modal_projector is not None:
projected = self.multi_modal_projector(x[1])
return (x[0], x[1], out, projected)
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import os
import torch
import json
import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
class Output:
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
outputs["mm_projected"] = out[3]
return outputs
def convert_to_transformers(sd, prefix):
sd_k = sd.keys()
if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
keys_to_replace = {
"{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
"{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
"{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
"{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
"{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
"{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
"{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
}
for x in keys_to_replace:
if x in sd_k:
sd[keys_to_replace[x]] = sd.pop(x)
if "{}proj".format(prefix) in sd_k:
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
sd = transformers_convert(sd, prefix, "vision_model.", 48)
else:
replace_prefix = {prefix: ""}
sd = state_dict_prefix_replace(sd, replace_prefix)
return sd
def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
if convert_keys:
sd = convert_to_transformers(sd, prefix)
if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
elif embed_shape == 577:
if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
elif "embeddings.patch_embeddings.projection.weight" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
else:
return None
clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd)
if len(m) > 0:
logging.warning("missing clip vision: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:
if k not in u:
sd.pop(k)
return clip
def load(ckpt_path):
sd = load_torch_file(ckpt_path)
if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
else:
return load_clipvision_from_sd(sd)
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "gelu",
"hidden_size": 1664,
"image_size": 224,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 8192,
"layer_norm_eps": 1e-05,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 48,
"patch_size": 14,
"projection_dim": 1280,
"torch_dtype": "float32"
}
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "gelu",
"hidden_size": 1280,
"image_size": 224,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_norm_eps": 1e-05,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 32,
"patch_size": 14,
"projection_dim": 1024,
"torch_dtype": "float32"
}
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 224,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"torch_dtype": "float32"
}
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 336,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-5,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"torch_dtype": "float32"
}
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 336,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-5,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"projector_type": "llava3",
"torch_dtype": "float32"
}
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 512,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 16,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}
# Comfy Typing
## Type hinting for ComfyUI Node development
This module provides type hinting and concrete convenience types for node developers.
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
```python
from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin
class ExampleNode(ComfyNodeABC):
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {"required": {}}
```
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
# Types
A few primary types are documented below. More complete information is available via the docstrings on each type.
## `IO`
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
- `ANY`: `"*"`
- `NUMBER`: `"FLOAT,INT"`
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
## `ComfyNodeABC`
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
### Type hinting for `INPUT_TYPES`
![INPUT_TYPES auto-completion in Visual Studio Code](examples/input_types.png)
### `INPUT_TYPES` return dict
![INPUT_TYPES return value type hinting in Visual Studio Code](examples/required_hint.png)
### Options for individual inputs
![INPUT_TYPES return value option auto-completion in Visual Studio Code](examples/input_options.png)
import torch
from typing import Callable, Protocol, TypedDict, Optional, List
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator
class UnetApplyFunction(Protocol):
"""Function signature protocol on comfy.model_base.BaseModel.apply_model"""
def __call__(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
pass
class UnetApplyConds(TypedDict):
"""Optional conditions for unet apply function."""
c_concat: Optional[torch.Tensor]
c_crossattn: Optional[torch.Tensor]
control: Optional[torch.Tensor]
transformer_options: Optional[dict]
class UnetParams(TypedDict):
# Tensor of shape [B, C, H, W]
input: torch.Tensor
# Tensor of shape [B]
timestep: torch.Tensor
c: UnetApplyConds
# List of [0, 1], [0], [1], ...
# 0 means conditional, 1 means conditional unconditional
cond_or_uncond: List[int]
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
__all__ = [
"UnetWrapperFunction",
UnetApplyConds.__name__,
UnetParams.__name__,
UnetApplyFunction.__name__,
IO.__name__,
InputTypeDict.__name__,
ComfyNodeABC.__name__,
CheckLazyMixin.__name__,
FileLocator.__name__,
]
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc
class ExampleNode(ComfyNodeABC):
"""An example node that just adds 1 to an input integer.
* Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
* This node is intended as an example for developers only.
"""
DESCRIPTION = cleandoc(__doc__)
CATEGORY = "examples"
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"input_int": (IO.INT, {"defaultInput": True}),
}
}
RETURN_TYPES = (IO.INT,)
RETURN_NAMES = ("input_plus_one",)
FUNCTION = "execute"
def execute(self, input_int: int):
return (input_int + 1,)
"""Comfy-specific type hinting"""
from __future__ import annotations
from typing import Literal, TypedDict, Optional
from typing_extensions import NotRequired
from abc import ABC, abstractmethod
from enum import Enum
class StrEnum(str, Enum):
"""Base class for string enums. Python's StrEnum is not available until 3.11."""
def __str__(self) -> str:
return self.value
class IO(StrEnum):
"""Node input/output data types.
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
"""
STRING = "STRING"
IMAGE = "IMAGE"
MASK = "MASK"
LATENT = "LATENT"
BOOLEAN = "BOOLEAN"
INT = "INT"
FLOAT = "FLOAT"
COMBO = "COMBO"
CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS"
GUIDER = "GUIDER"
NOISE = "NOISE"
CLIP = "CLIP"
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
LORA_MODEL = "LORA_MODEL"
LOSS_MAP = "LOSS_MAP"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"
GLIGEN = "GLIGEN"
UPSCALE_MODEL = "UPSCALE_MODEL"
AUDIO = "AUDIO"
WEBCAM = "WEBCAM"
POINT = "POINT"
FACE_ANALYSIS = "FACE_ANALYSIS"
BBOX = "BBOX"
SEGS = "SEGS"
VIDEO = "VIDEO"
ANY = "*"
"""Always matches any type, but at a price.
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
"""
NUMBER = "FLOAT,INT"
"""A float or an int - could be either"""
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
"""Could be any of: string, float, int, or bool"""
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class RemoteInputOptions(TypedDict):
route: str
"""The route to the remote source."""
refresh_button: bool
"""Specifies whether to show a refresh button in the UI below the widget."""
control_after_refresh: Literal["first", "last"]
"""Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on."""
timeout: int
"""The maximum amount of time to wait for a response from the remote source in milliseconds."""
max_retries: int
"""The maximum number of retries before aborting the request."""
refresh: int
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
class MultiSelectOptions(TypedDict):
placeholder: NotRequired[str]
"""The placeholder text to display in the multi-select widget when no items are selected."""
chip: NotRequired[bool]
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function.
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
"""
default: NotRequired[bool | str | float | int | list | tuple]
"""The default value of the widget"""
defaultInput: NotRequired[bool]
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
- defaultInput on required inputs should be dropped.
- defaultInput on optional inputs should be replaced with forceInput.
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
"""
forceInput: NotRequired[bool]
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
lazy: NotRequired[bool]
"""Declares that this input uses lazy evaluation"""
rawLink: NotRequired[bool]
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
tooltip: NotRequired[str]
"""Tooltip for the input (or widget), shown on pointer hover"""
socketless: NotRequired[bool]
"""All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
Available from frontend v1.17.5
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
"""
widgetType: NotRequired[str]
"""Specifies a type to be used for widget initialization if different from the input type.
Available from frontend v1.18.0
https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550"""
# class InputTypeNumber(InputTypeOptions):
# default: float | int
min: NotRequired[float]
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
max: NotRequired[float]
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
step: NotRequired[float]
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
round: NotRequired[float]
"""Floats are rounded by this value (``FLOAT``)"""
# class InputTypeBoolean(InputTypeOptions):
# default: bool
label_on: NotRequired[str]
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_off: NotRequired[str]
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
multiline: NotRequired[bool]
"""Use a multiline text box (``STRING``)"""
placeholder: NotRequired[str]
"""Placeholder text to display in the UI when empty (``STRING``)"""
# Deprecated:
# defaultVal: str
dynamicPrompts: NotRequired[bool]
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
# class InputTypeCombo(InputTypeOptions):
image_upload: NotRequired[bool]
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
image_folder: NotRequired[Literal["input", "output", "temp"]]
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
"""
remote: NotRequired[RemoteInputOptions]
"""Specifies the configuration for a remote input.
Available after ComfyUI frontend v1.9.7
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
control_after_generate: NotRequired[bool]
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
options: NotRequired[list[str | int | float]]
"""COMBO type only. Specifies the selectable options for the combo widget.
Prefer:
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
Over:
[["Option 1", "Option 2", "Option 3"]]
"""
multi_select: NotRequired[MultiSelectOptions]
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
class HiddenInputTypeDict(TypedDict):
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
node_id: NotRequired[Literal["UNIQUE_ID"]]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
unique_id: NotRequired[Literal["UNIQUE_ID"]]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt: NotRequired[Literal["PROMPT"]]
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]]
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt: NotRequired[Literal["DYNPROMPT"]]
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
class InputTypeDict(TypedDict):
"""Provides type hinting for node INPUT_TYPES.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
"""
required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
"""Describes all inputs that must be connected for the node to execute."""
optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
"""Describes inputs which do not need to be connected."""
hidden: NotRequired[HiddenInputTypeDict]
"""Offers advanced functionality and server-client communication.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
"""
class ComfyNodeABC(ABC):
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview
"""
DESCRIPTION: str
"""Node description, shown as a tooltip when hovering over the node.
Usage::
# Explicitly define the description
DESCRIPTION = "Example description here."
# Use the docstring of the node class.
DESCRIPTION = cleandoc(__doc__)
"""
CATEGORY: str
"""The category of the node, as per the "Add Node" menu.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#category
"""
EXPERIMENTAL: bool
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
API_NODE: Optional[bool]
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
@classmethod
@abstractmethod
def INPUT_TYPES(s) -> InputTypeDict:
"""Defines node inputs.
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
* The ``optional`` key can be added to describe inputs which do not need to be connected.
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#input-types
"""
return {"required": {}}
OUTPUT_NODE: bool
"""Flags this node as an output node, causing any inputs it requires to be executed.
If a node is not connected to any output nodes, that node will not be executed. Usage::
OUTPUT_NODE = True
From the docs:
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
"""
INPUT_IS_LIST: bool
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
From the docs:
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
OUTPUT_IS_LIST: tuple[bool, ...]
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::
RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally
From the docs:
In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
specifying which outputs which should be so treated.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
RETURN_TYPES: tuple[IO, ...]
"""A tuple representing the outputs of this node.
Usage::
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
"""
RETURN_NAMES: tuple[str, ...]
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
"""
OUTPUT_TOOLTIPS: tuple[str, ...]
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
FUNCTION: str
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#function
"""
class CheckLazyMixin:
"""Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""
def check_lazy_status(self, **kwargs) -> list[str]:
"""Returns a list of input names that should be evaluated.
This basic mixin impl. requires all inputs.
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]
return need
class FileLocator(TypedDict):
"""Provides type hinting for the file location"""
filename: str
"""The filename of the file."""
subfolder: str
"""The subfolder of the file."""
type: Literal["input", "output", "temp"]
"""The root folder of the file."""
import torch
import math
import comfy.utils
import logging
class CONDRegular:
def __init__(self, cond):
self.cond = cond
def _copy_with(self, cond):
return self.__class__(cond)
def process_cond(self, batch_size, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device, skipping concat.")
return False
return True
def concat(self, others):
conds = [self.cond]
for x in others:
conds.append(x.cond)
return torch.cat(conds)
def size(self):
return list(self.cond.size())
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, area, **kwargs):
data = self.cond
if area is not None:
dims = len(area) // 2
for i in range(dims):
data = data.narrow(i + 2, area[i + dims], area[i])
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
class CONDCrossAttn(CONDRegular):
def can_concat(self, other):
s1 = self.cond.shape
s2 = other.cond.shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = math.lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device: skipping concat.")
return False
return True
def concat(self, others):
conds = [self.cond]
crossattn_max_len = self.cond.shape[1]
for x in others:
c = x.cond
crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
conds.append(c)
out = []
for c in conds:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
out.append(c)
return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
if self.cond != other.cond:
return False
return True
def concat(self, others):
return self.cond
def size(self):
return [1]
class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
return self._copy_with(out)
def can_concat(self, other):
if len(self.cond) != len(other.cond):
return False
for i in range(len(self.cond)):
if self.cond[i].shape != other.cond[i].shape:
return False
return True
def concat(self, others):
out = []
for i in range(len(self.cond)):
o = [self.cond[i]]
for x in others:
o.append(x.cond[i])
out.append(torch.cat(o))
return out
def size(self): # hackish implementation to make the mem estimation work
o = 0
c = 1
for c in self.cond:
size = c.size()
o += math.prod(size)
if len(size) > 1:
c = size[1]
return [1, c, o // c]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment